In [1]:
import os
import sys
sys.path.append(os.pardir)

import numpy as np
import torch as th
import torch.distributed as dist
import math
import importlib

from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data_sde
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
    create_dse,
)
from torchvision import utils
from easydict import EasyDict
import PIL.Image as Image
import torchvision
import lpips
from psnrssim import MetricI2I

# Global run settings shared by the cells below.
batch_size = 4
device = 'cuda:0'

# Tensor -> PIL converter; not used by the cells shown here, kept for
# ad-hoc inspection of samples.
to_pil = torchvision.transforms.ToPILImage()

In [2]:
# Metric accumulator for the image-to-image evaluation; the summary it prints
# in the sweep cell reports psnr, ssim and lpips. `scope` is the label that
# prefixes the printed summary and is overwritten per setting in the sweep.
metric = MetricI2I(
    scope='EGSDE',
    device=device,
)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /opt/conda/lib/python3.8/site-packages/lpips/weights/v0.1/vgg.pth


In [3]:
# Source-domain image stream consumed by the sweep cell, which unpacks each
# item as an (image, cond) pair.
# NOTE(review): presumably a generator over the directory below — confirm
# whether it loops forever or stops after one pass.
data = load_data_sde(
    data_dir='/mnt/disk1/FFHQ_1000/',
    image_size=256,
    batch_size=batch_size,
    deterministic=True,
    class_cond=False,
)

In [4]:
# Load the domain-specific feature extractor (DSE) that is handed to the
# sampler through `egsde_kwargs` in the sweep cell.
dse = create_dse(image_size=256,
                    num_class=2,
                    classifier_use_fp16=True,
                    classifier_width=128,
                    classifier_depth=2,
                    classifier_attention_resolutions='32,16,8',
                    classifier_use_scale_shift_norm=True,
                    classifier_resblock_updown=True,
                    classifier_pool='attention',
                    phase='test')
# Deserialize onto the CPU first: without `map_location` the tensors are
# restored to whatever device they were saved from, which may not exist on
# this machine. This also matches how the diffusion checkpoint is loaded
# further down. The weights are moved to `device` explicitly afterwards.
states = th.load('../models/face2portrait.pt', map_location='cpu')
dse.load_state_dict(states)
dse.to(device)
# Cast to half precision only after the weights are on the target device.
dse.convert_to_fp16()
dse.eval()
print('Done')

Done


In [5]:
from resizer import Resizer

# Pair of resizers passed to the sampler as `die`: the first scales a
# full-resolution batch by 1/down_N, the second scales the reduced batch back
# up by down_N.
down_N = 32
shape = (batch_size, 3, 256, 256)
shape_d = (batch_size, 3, 256 // down_N, 256 // down_N)

down = Resizer(shape, 1 / down_N)
down = down.to(device)
up = Resizer(shape_d, down_N)
up = up.to(device)

die = (down, up)

In [6]:
# Settings for the pretrained diffusion model, grouped by concern. Only the
# keys listed by `model_and_diffusion_defaults()` are forwarded to the factory
# below; the rest are read directly from `args_src` by later cells.
args_src = EasyDict({
    # --- network architecture ---
    "image_size": 256,
    "num_channels": 128,
    "num_res_blocks": 1,
    "channel_mult": '',
    "attention_resolutions": '16',
    "num_heads": 4,
    "num_head_channels": 64,
    "num_heads_upsample": -1,
    "use_scale_shift_norm": True,
    "resblock_updown": True,
    "use_new_attention_order": False,
    "dropout": 0.0,
    "class_cond": False,
    "use_checkpoint": False,
    "use_fp16": True,
    # --- diffusion process ---
    "diffusion_steps": 1000,
    "noise_schedule": 'linear',
    "timestep_respacing": 'ddim40',
    "learn_sigma": True,
    "predict_xstart": False,
    "rescale_timesteps": False,
    "rescale_learned_sigmas": False,
    "use_kl": False,
    "p2_gamma": 0,
    "p2_k": 1,
    # --- sampling and I/O ---
    "batch_size": batch_size,
    "num_samples": 10000,
    "clip_denoised": True,
    "use_ddim": True,
    "model_path": '../logs/metface_distill_14/model010000.pt',
    "sample_dir": 'samples/metface_transfer_step_40_40k',
    "data_dir": '',
})

# Build the network and the diffusion wrapper from the subset of settings the
# factory understands.
model_src, diffusion = create_model_and_diffusion(
    **args_to_dict(args_src, model_and_diffusion_defaults().keys())
)

# Restore the checkpoint on the CPU, then move to the GPU and cast to fp16.
model_src.load_state_dict(
    dist_util.load_state_dict(args_src.model_path, map_location="cpu")
)
model_src.to(device)
model_src.convert_to_fp16()
model_src.eval()
print('done')

model_kwargs = {}

# Choose the sampling loop that matches the `use_ddim` switch.
if args_src.use_ddim:
    sample_fn = diffusion.ddim_sample_loop
else:
    sample_fn = diffusion.p_sample_loop

done


In [7]:
# Sweep the EGSDE `ls` setting and print the accumulated image-to-image
# metrics for each value.
#
# NOTE(review): the stop condition originally read `if idx >= :` — the bound
# had been deleted, which is a SyntaxError, so the cell could not run as saved.
# The bound is now an explicit, tunable batch budget. The value below
# presumably covers the dataset once (the directory name suggests 1000
# images) — confirm against the run that produced the recorded numbers.
max_batches = max(1, 1000 // batch_size)

for ls in [500, 400, 300, 200, 100]:
    # Start a fresh accumulation and label the summary with the current setting.
    metric.reset()
    metric.scope = f'EGSDE with ls:{ls}'
    # NOTE(review): `data` is not re-created per `ls`; if it is a generator,
    # each setting continues where the previous one stopped rather than
    # re-reading the same images — verify this is intended.
    for idx, batch in enumerate(data):
        img, cond = batch  # `cond` is not used by this evaluation
        img = img.to(device)
        # Fresh Gaussian noise per batch; no seed is set, so runs are not
        # bit-for-bit repeatable.
        noise = th.randn_like(img)
        model_kwargs = {}
        egsde_kwargs = {"dse":dse, "ls":ls, "li":2, "die":die}
        sample_src = sample_fn(
            model_src,
            (args_src.batch_size, 3, args_src.image_size, args_src.image_size),
            clip_denoised=args_src.clip_denoised,
            model_kwargs=model_kwargs,
            noise=noise,
            ref_img=img,
            egsde_kwargs=egsde_kwargs,
            range_t = 20
        )
        # Compare the translated sample against its source image.
        metric.update(img, sample_src)
        # Stop once exactly `max_batches` batches have been evaluated.
        if idx + 1 >= max_batches:
            break
    print(metric.print_metrics())

EGSDE with ls:500 - psnr: 12.743446, ssim: 0.283273, lpips: 0.545026, 
EGSDE with ls:400 - psnr: 12.606478, ssim: 0.272000, lpips: 0.551730, 
EGSDE with ls:300 - psnr: 12.700075, ssim: 0.276487, lpips: 0.547047, 
EGSDE with ls:200 - psnr: 12.700248, ssim: 0.273622, lpips: 0.550184, 
EGSDE with ls:100 - psnr: 12.643587, ssim: 0.274843, lpips: 0.550943, 
