In [None]:
import os
import sys
sys.path.append(os.pardir)

import numpy as np
import torch as th
import torch.distributed as dist

from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data_sde, ImageDataset
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
from torchvision import utils
from easydict import EasyDict
import PIL.Image as Image
import torchvision

# Notebook-wide runtime settings shared by the cells below.
batch_size = 4       # number of images per batch (fed into args_src and the data loader)
device = 'cuda:2'    # device the model and every image tensor are moved to

# Tensor -> PIL.Image converter.
to_pil = torchvision.transforms.ToPILImage()

In [None]:
# Hyper-parameters for the source model and its diffusion process.
# Only the keys listed by model_and_diffusion_defaults() are forwarded to
# create_model_and_diffusion(); the remaining entries are read directly by
# this notebook (e.g. model_path, use_ddim, clip_denoised, batch_size).
# NOTE(review): 'data_dir' here is not what the visible data-loading cell
# uses (it reads '../FFHQ_partial/') -- confirm which one is intended.
args_src = EasyDict({
    'attention_resolutions': '16',
    'batch_size': batch_size,
    'channel_mult': '',
    'class_cond': False,
    'clip_denoised': True,
    'diffusion_steps': 1000,
    'dropout': 0.0,
    'image_size': 256,
    'learn_sigma': True,
    'model_path': '../models/metface_distill.pt',
    'noise_schedule': 'linear',
    'num_channels': 128,
    'num_head_channels': 64,
    'num_heads': 4,
    'num_heads_upsample': -1,
    'num_res_blocks': 1,
    'num_samples': 10000,
    'p2_gamma': 0,
    'p2_k': 1,
    'predict_xstart': False,
    'resblock_updown': True,
    'rescale_learned_sigmas': False,
    'rescale_timesteps': False,
    'sample_dir': 'samples/metface_distill',
    'timestep_respacing': 'ddim40',
    'use_checkpoint': False,
    'use_ddim': True,
    'use_fp16': True,
    'use_kl': False,
    'use_new_attention_order': False,
    'use_scale_shift_norm': True,
    'data_dir': '/mnt/disk1/metface/no_glass/',
})

# Build the network and the diffusion helper from the subset of args_src
# that matches the library's default argument names.
model_src, diffusion = create_model_and_diffusion(
    **args_to_dict(args_src, model_and_diffusion_defaults().keys())
)

# Deserialize the checkpoint on CPU first, then move the model to `device`.
model_src.load_state_dict(
    dist_util.load_state_dict(args_src.model_path, map_location="cpu")
)
model_src.to(device)

# Cast to half precision only when the configuration asks for it, so that
# setting use_fp16=False above does not still convert the model.
if args_src.use_fp16:
    model_src.convert_to_fp16()

# Inference only: switch the module to evaluation mode.
model_src.eval()
print('done')

In [None]:
# Source of image batches for the sampling and metric cells below.
# Batch size, resolution and class-conditioning follow args_src; the loader
# is asked for deterministic ordering.
data = load_data_sde(
    data_dir='../FFHQ_partial/',
    image_size=args_src.image_size,
    batch_size=args_src.batch_size,
    class_cond=args_src.class_cond,
    deterministic=True,
)

### Sample images

In [None]:
# Generate samples with the source model and write them to disk as PNGs,
# one file per image, named by a running index.
t_end = 20                # forwarded to the sampler as `t_end`
num_images = 10000        # total number of PNG files to write
out_dir = '/sampling/root/sdeit_20_0'
# save_image does not create missing directories, so make sure it exists.
os.makedirs(out_dir, exist_ok=True)

# DDIM sampling when args_src.use_ddim is set, ancestral sampling otherwise.
sample_fn = (
    diffusion.ddim_sample_loop if args_src.use_ddim else diffusion.p_sample_loop
)
model_kwargs = {}
save_idx = 0
while save_idx < num_images:
    saved_before = save_idx
    for img, cond in data:
        img = img.to(device)
        # NOTE(review): `noise` is drawn independently of `img` and the
        # sampler call below never receives `img`, so from what is visible
        # here the source image only determines the noise shape/device.
        # If SDEdit-style editing is intended, confirm whether `img` should
        # first be diffused to step `t_end` (e.g. via q_sample).
        noise = th.randn_like(img)

        # `t_start` / `t_end` are passed through to the sampler as-is;
        # presumably they bound the denoising range -- TODO confirm.
        sample_src = sample_fn(
            model_src,
            (args_src.batch_size, 3, args_src.image_size, args_src.image_size),
            clip_denoised=args_src.clip_denoised,
            model_kwargs=model_kwargs,
            noise=noise,
            t_start=0,
            t_end=t_end
        )

        # Never write more than `num_images` files, even if the batch size
        # does not divide the target evenly.
        n_save = min(img.shape[0], num_images - save_idx)
        for i in range(n_save):
            out_path = f'{out_dir}/{save_idx}.png'
            # NOTE(review): newer torchvision releases rename `range` to
            # `value_range` -- confirm against the installed version.
            torchvision.utils.save_image(
                sample_src[i].unsqueeze(0),
                out_path,
                nrow=1,
                normalize=True,
                range=(-1, 1),
            )
            save_idx += 1

        # `>=` (not `>`): stop as soon as the target is reached instead of
        # sampling one extra batch when the count lands exactly on it.
        if save_idx >= num_images:
            break

    # Guard against spinning forever when `data` yields no batches.
    if save_idx == saved_before:
        raise RuntimeError('data yielded no batches; cannot reach the requested number of images')
            

### Calculate metric

In [4]:
# Accumulate an image-to-image metric between input images and the samples
# produced by the source model, then report it.
# NOTE(review): MetricI2I is not imported in any cell visible in this
# notebook; this cell only works if the name already exists in the kernel.
# Add the proper import to the first cell.
metric = MetricI2I(scope='SDEdit', device=device)
metric.reset()

t_end = 20                # forwarded to the sampler as `t_end`
num_images = 10000        # number of images to include in the metric

# DDIM sampling when args_src.use_ddim is set, ancestral sampling otherwise.
sample_fn = (
    diffusion.ddim_sample_loop if args_src.use_ddim else diffusion.p_sample_loop
)
model_kwargs = {}
save_idx = 0
while save_idx < num_images:
    seen_before = save_idx
    for img, cond in data:
        img = img.to(device)
        # NOTE(review): `noise` is independent of `img` and the sampler never
        # receives `img`; confirm the samples are meant to be compared with
        # these inputs without first diffusing `img` to step `t_end`.
        noise = th.randn_like(img)

        # `t_start` / `t_end` are passed through to the sampler as-is;
        # presumably they bound the denoising range -- TODO confirm.
        sample_src = sample_fn(
            model_src,
            (args_src.batch_size, 3, args_src.image_size, args_src.image_size),
            clip_denoised=args_src.clip_denoised,
            model_kwargs=model_kwargs,
            noise=noise,
            t_start=0,
            t_end=t_end
        )
        metric.update(img, sample_src)

        save_idx += args_src.batch_size

        # `>=` (not `>`): stop once the target is reached instead of adding
        # one extra batch when the count lands exactly on it.
        if save_idx >= num_images:
            break

    # Guard against spinning forever when `data` yields no batches.
    if save_idx == seen_before:
        raise RuntimeError('data yielded no batches; cannot reach the requested number of images')

# NOTE(review): if print_metrics() prints by itself and returns None, this
# additionally prints "None" -- confirm its return value.
print(metric.print_metrics())
            