In [None]:
import os
import sys
sys.path.append(os.pardir)

import numpy as np
import torch as th
import torch.distributed as dist
import math

from guided_diffusion import dist_util, logger
from guided_diffusion.image_datasets import load_data_sde
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
from torchvision import utils
from easydict import EasyDict
import PIL.Image as Image
import torchvision
import lpips
from psnrssim import MetricI2I

# Converts a (C, H, W) uint8 tensor into a PIL image for inline display.
to_pil = torchvision.transforms.ToPILImage()

# NOTE: hard-coded GPU index; adjust to the local machine before running.
device = 'cuda:5'
batch_size = 4

# Every evaluation cell below draws fresh Gaussian noise (th.randn_like), so seed
# the RNGs once here to make a Restart & Run All reproduce the same noise draws.
# Successive trials still see different noise because the RNG stream advances.
SEED = 0
np.random.seed(SEED)
th.manual_seed(SEED)

In [None]:
# Image-to-image metric accumulator shared by the evaluation cells below; each of
# them calls metric.reset() and sets metric.scope before a run.
# TODO: for SDEdit, measure L2, SSIM and PSNR between the input image (FFHQ) and
# the output image (MetFaces-style image) with last_step = 20.
metric = MetricI2I(device=device, scope='')

In [None]:
# Source model configuration. The architecture flags (channels, heads, res blocks,
# attention resolutions, learn_sigma, ...) must describe the checkpoint at
# model_path, otherwise the strict load_state_dict below fails.
# Presumably "p2" refers to P2 loss weighting (cf. p2_gamma / p2_k) — TODO confirm.
args_src = EasyDict(
    attention_resolutions='16',
    batch_size=batch_size,
    channel_mult='',
    class_cond=False,
    clip_denoised=True,
    diffusion_steps=1000,
    dropout=0.0,
    image_size=256,
    learn_sigma=True,
    model_path='../models/ffhq_p2.pt',
    noise_schedule='linear',
    num_channels=128,
    num_head_channels=64,
    num_heads=4,
    num_heads_upsample=-1,
    num_res_blocks=1,
    num_samples=10000,
    p2_gamma=0,
    p2_k=1,
    predict_xstart=False,
    resblock_updown=True,
    rescale_learned_sigmas=False,
    rescale_timesteps=False,
    sample_dir='samples/metface_transfer_step_40_40k',
    timestep_respacing='ddim40',  # 40 respaced DDIM steps
    use_checkpoint=False,
    use_ddim=True,
    use_fp16=True,
    use_kl=False,
    use_new_attention_order=False,
    use_scale_shift_norm=True,
    data_dir='',
)

# Only the keys known to model_and_diffusion_defaults() are forwarded to the factory.
model_src, diffusion = create_model_and_diffusion(
    **args_to_dict(args_src, model_and_diffusion_defaults().keys())
)

# Load weights on CPU first, then move to the GPU and switch to fp16 inference mode.
model_src.load_state_dict(
    dist_util.load_state_dict(args_src.model_path, map_location="cpu")
)
model_src.to(device)
model_src.convert_to_fp16()
model_src.eval()
print('done')

In [None]:
# Evaluation inputs: FFHQ images, loaded with deterministic=True (presumably a
# fixed, unshuffled order — TODO confirm in load_data_sde).
# NOTE(review): several cells below iterate `data` to completion, once per model
# and per trial; this assumes load_data_sde returns a re-iterable object (e.g. a
# DataLoader) rather than a one-shot generator — verify.
# Alternative input directory used previously: '/mnt/raid/FFHQ_10k/'
data = load_data_sde(
    data_dir='../FFHQ_partial/',
    deterministic=True,
    class_cond=args_src.class_cond,
    image_size=args_src.image_size,
    batch_size=args_src.batch_size,
)

In [None]:
# SDEdit evaluation over several checkpoints: each FFHQ input is perturbed to
# timestep `t_end` with diffusion.q_sample, denoised back by the model, and the
# (input, output) pair is fed to `metric`. Per-model results are printed and also
# written to log_SDEdit_AAHQ1400.txt.
models = [
    '../models/ffhq_p2.pt',
    '../logs/aahq_limited1400_distill_p2_0.2_3_aux_0.2_50/model030000.pt',
    '../logs/aahq_limited1400_distill_p2_0.1_3_aux_0.3_50/model020000.pt',
    '../logs/aahq_limited1400_distill_p2_0.1_3_aux_0.3_50/model010000.pt',
]
t_end = 20      # perturbation depth, used for q_sample and as the sampler's t_end
n_trial = 1     # independent noise draws per model
model_kwargs = {}

# `with` closes (and flushes) the log even if loading or sampling a model raises.
with open('log_SDEdit_AAHQ1400.txt', 'w') as f:
    for model_path in models:
        args_src = EasyDict({'attention_resolutions':'16',
                'batch_size':batch_size,
                'channel_mult':'',
                'class_cond':False,
                'clip_denoised':True,
                'diffusion_steps':1000,
                'dropout':0.0,
                'image_size':256,
                'learn_sigma':True,
                'model_path':model_path,
                'noise_schedule':'linear',
                'num_channels':128,
                'num_head_channels':64,
                'num_heads':4,
                'num_heads_upsample':-1,
                'num_res_blocks':1,
                'num_samples':10000,
                'p2_gamma':0,
                'p2_k':1,
                'predict_xstart':False,
                'resblock_updown':True,
                'rescale_learned_sigmas':False,
                'rescale_timesteps':False,
                'sample_dir':'samples/metface_transfer_step_40_40k',
                'timestep_respacing':'ddim40',
                'use_checkpoint':False,
                'use_ddim':True,
                'use_fp16':True,
                'use_kl':False,
                'use_new_attention_order':False,
                'use_scale_shift_norm':True,
                'data_dir':''})

        model_src, diffusion = create_model_and_diffusion(
            **args_to_dict(args_src, model_and_diffusion_defaults().keys())
        )
        model_src.load_state_dict(
            dist_util.load_state_dict(args_src.model_path, map_location="cpu")
        )
        model_src.to(device)
        model_src.convert_to_fp16()
        model_src.eval()
        print('done')

        # Do not bind `dist` here: it would shadow the `torch.distributed as dist` import.
        sample_fn = (
            diffusion.p_sample_loop if not args_src.use_ddim else diffusion.ddim_sample_loop
        )

        for n in range(1, n_trial + 1):
            metric.reset()
            metric.scope = f'SDEdit {n}th trial'
            with th.no_grad():
                for img, _cond in data:
                    img = img.to(device)
                    noise = th.randn_like(img)
                    # One timestep index per image in the batch.
                    t = th.full((img.shape[0],), t_end, device=device, dtype=th.long)
                    perturbed = diffusion.q_sample(img, t, noise=noise)
                    sample_src = sample_fn(
                        model_src,
                        # Use the real batch shape so a short final batch still works.
                        tuple(img.shape),
                        clip_denoised=args_src.clip_denoised,
                        model_kwargs=model_kwargs,
                        noise=perturbed,
                        t_start=0,
                        t_end=t_end
                    )
                    metric.update(img, sample_src)

            print(model_path)
            print(model_path, file=f)
            # Compute the summary once and reuse it for stdout and the log file.
            summary = metric.print_metrics()
            print(summary)
            print(summary, file=f)

In [None]:
# Visual check of the last evaluated batch: map both tensors (assumed in [-1, 1])
# to uint8 [0, 255], then show the first input image followed by its sample.
pil_img, pil_sam = (
    ((tensor + 1) * 127.5).clamp(0, 255).to(th.uint8) for tensor in (img, sample_src)
)
display(to_pil(pil_img[0]), to_pil(pil_sam[0]))

In [None]:
# Single-checkpoint SDEdit check: repeat the SDEdit evaluation `n_trial` times with
# fresh noise for one model and print each trial's metrics (stdout only).
model_path = '../logs/metface_distill_14/model010000.pt'
args_src = EasyDict({'attention_resolutions':'16',
        'batch_size':batch_size,
        'channel_mult':'',
        'class_cond':False,
        'clip_denoised':True,
        'diffusion_steps':1000,
        'dropout':0.0,
        'image_size':256,
        'learn_sigma':True,
        'model_path':model_path,
        'noise_schedule':'linear',
        'num_channels':128,
        'num_head_channels':64,
        'num_heads':4,
        'num_heads_upsample':-1,
        'num_res_blocks':1,
        'num_samples':10000,
        'p2_gamma':0,
        'p2_k':1,
        'predict_xstart':False,
        'resblock_updown':True,
        'rescale_learned_sigmas':False,
        'rescale_timesteps':False,
        'sample_dir':'samples/metface_transfer_step_40_40k',
        'timestep_respacing':'ddim40',
        'use_checkpoint':False,
        'use_ddim':True,
        'use_fp16':True,
        'use_kl':False,
        'use_new_attention_order':False,
        'use_scale_shift_norm':True,
        'data_dir':''})

model_src, diffusion = create_model_and_diffusion(
    **args_to_dict(args_src, model_and_diffusion_defaults().keys())
)
model_src.load_state_dict(
    dist_util.load_state_dict(args_src.model_path, map_location="cpu")
)
model_src.to(device)
model_src.convert_to_fp16()
model_src.eval()
print('done')

model_kwargs = {}
sample_fn = (
    diffusion.p_sample_loop if not args_src.use_ddim else diffusion.ddim_sample_loop
)
# Do not bind `dist` here: it would shadow the `torch.distributed as dist` import.
t_end = 20      # perturbation depth, used for q_sample and as the sampler's t_end
n_trial = 3

# Trials are numbered from 1 so the labels read "1th", "2th", ... rather than "0th".
for n in range(1, n_trial + 1):
    metric.reset()
    metric.scope = f'SDEdit {n}th trial'
    with th.no_grad():
        for img, _cond in data:
            img = img.to(device)
            noise = th.randn_like(img)
            t = th.full((img.shape[0],), t_end, device=device, dtype=th.long)
            perturbed = diffusion.q_sample(img, t, noise=noise)
            sample_src = sample_fn(
                model_src,
                # Use the real batch shape so a short final batch still works.
                tuple(img.shape),
                clip_denoised=args_src.clip_denoised,
                model_kwargs=model_kwargs,
                noise=perturbed,
                t_start=0,
                t_end=t_end
            )
            metric.update(img, sample_src)

    print(model_path)
    print(metric.print_metrics())

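## ILVR (Iterative Latent Variable Refinement)

Reference-guided sampling: each sample starts from pure noise, and its low-frequency content is steered toward the FFHQ input through a downsample/upsample (`Resizer`) filter.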

In [None]:
from resizer import Resizer

# ILVR low-pass operators: `down` shrinks a 256x256 batch by a factor of down_N and
# `up` enlarges the low-resolution result back to full size. Both are constructed
# for a fixed batch shape (presumably every batch passed through them must contain
# exactly `batch_size` images — TODO confirm in resizer.Resizer).
down_N = 16
shape = (batch_size, 3, 256, 256)
shape_d = (batch_size, 3, 256 // down_N, 256 // down_N)
down, up = (
    Resizer(shape, 1 / down_N).to(device),
    Resizer(shape_d, down_N).to(device),
)
resizers = (down, up)

# DDIM or ancestral sampling, following the current configuration.
sample_fn = diffusion.ddim_sample_loop if args_src.use_ddim else diffusion.p_sample_loop

In [None]:
from resizer import Resizer

# ILVR evaluation over several checkpoints: each sample starts from pure noise and
# is guided by the FFHQ input via the (down, up) resizers. The (input, output) pair
# is fed to `metric`; per-trial results are printed and written to log_ILVR.txt.
models = ['../models/ffhq_p2.pt',
          '../logs/metface_scratch_noglass/model060000.pt',
          '../logs/metface_transfer_noglass/model060000.pt',
          '../logs/metface_distill_3/model030000.pt',
          '../logs/metface_distill_14/model010000.pt'
          ]

# ILVR low-pass operators (downsample by down_N, then upsample back to full size).
down_N = 16
shape = (batch_size, 3, 256, 256)
shape_d = (batch_size, 3, 256 // down_N, 256 // down_N)
down = Resizer(shape, 1 / down_N).to(device)
up = Resizer(shape_d, down_N).to(device)
resizers = (down, up)

n_trial = 3
model_kwargs = {}

# `with` closes (and flushes) the log even if loading or sampling a model raises.
with open('log_ILVR.txt', 'w') as f:
    for model_path in models:
        args_src = EasyDict({'attention_resolutions':'16',
                'batch_size':batch_size,
                'channel_mult':'',
                'class_cond':False,
                'clip_denoised':True,
                'diffusion_steps':1000,
                'dropout':0.0,
                'image_size':256,
                'learn_sigma':True,
                'model_path':model_path,
                'noise_schedule':'linear',
                'num_channels':128,
                'num_head_channels':64,
                'num_heads':4,
                'num_heads_upsample':-1,
                'num_res_blocks':1,
                'num_samples':10000,
                'p2_gamma':0,
                'p2_k':1,
                'predict_xstart':False,
                'resblock_updown':True,
                'rescale_learned_sigmas':False,
                'rescale_timesteps':False,
                'sample_dir':'samples/metface_transfer_step_40_40k',
                'timestep_respacing':'ddim40',
                'use_checkpoint':False,
                'use_ddim':True,
                'use_fp16':True,
                'use_kl':False,
                'use_new_attention_order':False,
                'use_scale_shift_norm':True,
                'data_dir':''})

        model_src, diffusion = create_model_and_diffusion(
            **args_to_dict(args_src, model_and_diffusion_defaults().keys())
        )
        model_src.load_state_dict(
            dist_util.load_state_dict(args_src.model_path, map_location="cpu")
        )
        model_src.to(device)
        model_src.convert_to_fp16()
        model_src.eval()
        print('done')

        # Bound from this model's own diffusion object (not a stale one from an earlier cell).
        sample_fn = (
            diffusion.p_sample_loop if not args_src.use_ddim else diffusion.ddim_sample_loop
        )

        for n in range(1, n_trial + 1):
            metric.reset()
            metric.scope = f'ILVR {n}th trial'
            with th.no_grad():
                for img, _cond in data:
                    img = img.to(device)
                    # ILVR starts from pure noise; no q_sample perturbation of the input is used.
                    noise = th.randn_like(img)
                    sample_src = sample_fn(
                        model_src,
                        tuple(img.shape),
                        clip_denoised=args_src.clip_denoised,
                        model_kwargs=model_kwargs,
                        noise=noise,
                        resizers=resizers,
                        range_t=20,  # ILVR refinement range forwarded to the sampler (semantics per sampler impl)
                        ref_img=img,
                    )
                    metric.update(img, sample_src)

            print(model_path)
            print(model_path, file=f)
            # Compute the summary once and reuse it for stdout and the log file.
            summary = metric.print_metrics()
            print(summary)
            print(summary, file=f)

In [None]:
# Visual spot check: draw one ILVR sample for the first batch of `data`, using the
# model_src / args_src / sample_fn / down / up currently bound in the namespace by
# the cells above.
batch = next(iter(data), None)
if batch is None:
    # Fail loudly instead of silently reusing stale img / sample_src from an earlier cell.
    raise RuntimeError('`data` produced no batches; nothing to sample.')
img, cond = batch
img = img.to(device)
noise = th.randn_like(img)
model_kwargs = {}
with th.no_grad():
    sample_src = sample_fn(
        model_src,
        tuple(img.shape),
        clip_denoised=args_src.clip_denoised,
        model_kwargs=model_kwargs,
        noise=noise,
        resizers=(down, up),
        range_t=20,
        ref_img=img,
    )

In [None]:
# Visual check of the ILVR preview: map both tensors (assumed in [-1, 1]) to
# uint8 [0, 255], then show the reference image followed by its ILVR sample.
pil_img, pil_sam = (
    ((tensor + 1) * 127.5).clamp(0, 255).to(th.uint8) for tensor in (img, sample_src)
)
display(to_pil(pil_img[0]), to_pil(pil_sam[0]))