In [1]:
import os

# Run from the repository root so the local packages imported below (modeling, diffusion,
# utils) and the relative 'outputs/...' checkpoint paths resolve. Set NOHLAB_DIFFUSION_ROOT
# to point at a checkout on another machine instead of editing the hard-coded default.
project_root = os.environ.get('NOHLAB_DIFFUSION_ROOT', '/raid/jimyeong/nohlab_diffusion')
os.chdir(project_root)

In [2]:
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import torch
from torch.utils.data import TensorDataset, DataLoader

# Notebook display setup: render matplotlib figures inline and use seaborn's darkgrid style.
%matplotlib inline
sns.set_style('darkgrid')

In [3]:
from modeling import get_model
from diffusion.sampler import get_sampler
from diffusion.scheduler import get_scheduler
from utils.fid import load_hidden_parameters, calc_hidden_parameters, calculate_frechet_distance

In [4]:
# GPU used for sampling and Inception features; pick a free one on the host (training used another).
device = torch.device('cuda', 2)
device

device(type='cuda', index=2)

In [5]:
# Reference statistics that generated samples are compared against in the FID
# (loaded from the cifar10-test Inception reference file).
ref_dataset = 'cifar10-test'
ref_mu, ref_sigma = load_hidden_parameters(ref_dataset)
image_shape = (3, 32, 32)
num_classes = 10
# The effective model batch doubles whenever guidance_scale is anything other than exactly 0 or 1
# (e.g. 0.5 or 2.0): classifier-free guidance runs the conditional and unconditional passes together.
batch_size = 256
inception_batch_size = 512
num_examples = 20000    # generated samples per FID evaluation
seed = 42               # seeds the fixed noise / label draw shared by all evaluations

/home/jimyeong/datasets/fid_reference_files/cifar10-test-inception.npz


In [6]:
# Training run to evaluate; its saved arguments describe the model, sampler and scheduler.
ckpt_dir = 'outputs/rf_cifar10_base'
train_args_path = os.path.join(ckpt_dir, 'train_args.json')

with open(train_args_path, 'r') as fp:
    train_args = json.load(fp)
train_args

{'output_dir': 'outputs/rf_cifar10_base',
 'wandb_run_name': 'rf_cifar10_base',
 'wandb_run_id': '50c7vwdy',
 'max_steps': 200000,
 'logging_steps': 50,
 'eval_steps': 1000,
 'save_steps': 10000,
 'eval_n_examples': 40,
 'guidance_scale': 1.0,
 'batch_size': 128,
 'lr': 0.0002,
 'optimizer': 'adamw',
 'adam_betas': [0.9, 0.99],
 'clip_grad_norm': 1.0,
 'use_ema': True,
 'ema_inv_gamma': 1.0,
 'ema_power': 0.75,
 'dataset': 'cifar10',
 'dataset_dir': '~/datasets',
 'augmentations': [],
 'dataloader_num_workers': 2,
 'dataloader_drop_last': True,
 'dataloader_pin_memory': True,
 'device': 'cuda:6',
 'seed': 42,
 'p_uncond': 0.2,
 'model_type': 'unet',
 'model_cfg': {'in_channels': 3,
  'out_channels': 3,
  'down_block_types': ['UNetDownBlock', 'UNetDownBlock', 'UNetDownBlock'],
  'up_block_types': ['UNetUpBlock', 'UNetUpBlock', 'UNetUpBlock'],
  'block_out_channels': [128, 256, 512],
  'n_blocks_per_layer': 2,
  'mid_attention': True,
  'n_dim_attention_head': 8,
  'norm_groups': 32,
  '

In [7]:
# Scheduler rebuilt from the training config; reused by every sampling call in generate_examples.
scheduler = get_scheduler(
    train_args['scheduler_type'],
    **train_args['scheduler_cfg'],
)
def load_model(ckpt_name, ema=True, map_location='cpu'):
    """Build the model described by ``train_args`` and load a saved checkpoint into it.

    Args:
        ckpt_name: checkpoint sub-directory under ``<output_dir>/ckpts`` (e.g. ``'ckpt-200000'``).
        ema: load ``ema_model.pt`` when True, otherwise ``model.pt``.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so tensors are not
            restored onto whatever device they lived on when saved (training used
            ``train_args['device']``, which differs from this notebook's device); callers
            move the model to their own device afterwards.

    Returns:
        The freshly constructed model with the checkpoint weights loaded.
    """
    ckpt_file = 'ema_model.pt' if ema else 'model.pt'
    ckpt_path = os.path.join(train_args['output_dir'], 'ckpts', ckpt_name, ckpt_file)

    model = get_model(train_args['model_type'], **train_args['model_cfg'])
    state_dict = torch.load(ckpt_path, map_location=map_location)
    # Show the load report (missing / unexpected keys) in the cell output.
    print(model.load_state_dict(state_dict))
    return model
def load_sampler(**kwargs):
    """Build the sampler from the training config, with keyword overrides.

    ``train_args['sampler_cfg']`` is shallow-copied into a new dict before the
    overrides are applied, so the stored training config is never mutated.
    """
    merged_cfg = {**train_args['sampler_cfg'], **kwargs}
    return get_sampler(train_args['sampler_type'], **merged_cfg)

In [8]:
def generate_noise(num_examples, image_shape, num_classes, seed):
    """Draw a reproducible set of starting noise tensors and class labels.

    Both tensors come from a single CPU ``torch.Generator`` seeded with ``seed``.
    The noise is drawn before the labels, and that order matters for reproducibility.

    Returns:
        ``(noise, labels)``: standard-normal noise of shape ``(num_examples, *image_shape)``
        and integer labels of shape ``(num_examples,)`` drawn from ``[0, num_classes)``.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)

    noise_shape = (num_examples,) + tuple(image_shape)
    noise = torch.randn(noise_shape, generator=rng)
    labels = torch.randint(0, num_classes, (num_examples,), generator=rng)

    print(f'noise: {noise.shape}, labels: {labels.shape}')
    return noise, labels

# Fixed (noise, label) pairs reused by every generate_examples call, so FIDs for different
# settings start from the same points.
noise_dataset = TensorDataset(*generate_noise(
    num_examples=num_examples,
    image_shape=image_shape,
    num_classes=num_classes,
    seed=seed,
))
noise_dataset

noise: torch.Size([20000, 3, 32, 32]), labels: torch.Size([20000])


<torch.utils.data.dataset.TensorDataset at 0x7f4683d9ffb0>

In [9]:
def get_eps_pred_func(model, cls=None, guidance_scale=1.0):
    """Wrap ``model`` into a ``pred_fn(z, t)`` callable for the sampler.

    The branch is fixed by the arguments given here:
      * ``cls is None`` or ``guidance_scale == 0``: unconditional ``model(z, t)``.
      * ``guidance_scale == 1.0``: plain conditional ``model(z, t, cls=cls)``.
      * anything else: classifier-free guidance. Both passes run in one doubled batch,
        and the outputs are combined as ``uncond + guidance_scale * (cond - uncond)``.
    """
    def pred_fn(z, t):
        if cls is None or guidance_scale == 0:
            return model(z, t)
        if guidance_scale == 1.0:
            return model(z, t, cls=cls)

        # The first half of the doubled batch is the unconditional pass (uncond_mask = 1).
        # The second half is the conditional pass (uncond_mask = 0).
        doubled_z = torch.cat((z, z), dim=0)
        doubled_t = torch.cat((t, t), dim=0)
        doubled_cls = torch.cat((cls, cls), dim=0)
        uncond_mask = torch.cat((torch.ones_like(cls), torch.zeros_like(cls)), dim=0)

        out = model(doubled_z, doubled_t, cls=doubled_cls, uncond_mask=uncond_mask)
        eps_uncond, eps_cond = out.chunk(2, dim=0)
        return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    return pred_fn

In [10]:
@torch.no_grad()
def generate_examples(
        model_ckpt_name,
        model_ema=True,
        n_sampling_steps=None,
        guidance_scale=1.0,
):
    """Sample one image for every (noise, label) pair in the global ``noise_dataset``.

    The checkpoint is loaded via ``load_model``, moved to ``device`` and put in eval mode.
    The sampler comes from the training config with a nested progress bar. When
    ``n_sampling_steps`` is given, it overrides the sampler's ``n_steps``.

    Returns:
        A CPU tensor of sampler outputs concatenated in ``noise_dataset`` order
        (the loader is not shuffled).
    """
    net = load_model(model_ckpt_name, ema=model_ema)
    net.to(device)
    net.eval()

    overrides = {
        'pbar': True,
        'pbar_kwargs': {'position': 1, 'leave': False, 'desc': 'sampling'},
    }
    if n_sampling_steps is not None:
        overrides['n_steps'] = n_sampling_steps
    sampler = load_sampler(**overrides)

    loader = DataLoader(noise_dataset, num_workers=1, batch_size=batch_size, shuffle=False, drop_last=False)
    batches = []
    for noise_batch, label_batch in tqdm(loader, leave=False, desc='generating'):
        noise_batch = noise_batch.to(device)
        label_batch = label_batch.to(device)

        # The prediction function closes over this batch's labels, so it is rebuilt per batch.
        pred_fn = get_eps_pred_func(net, cls=label_batch, guidance_scale=guidance_scale)
        batches.append(sampler.sample(noise_batch, scheduler, pred_fn).cpu())
    return torch.cat(batches, dim=0)

In [11]:
def calc_fid_model(
        model_ckpt_name,
        model_ema=True,
        n_sampling_steps=None,
        guidance_scale=1.0,
):
    """Generate samples from a checkpoint and return their FID against ``ref_mu`` / ``ref_sigma``.

    The arguments are forwarded unchanged to ``generate_examples``. Feature statistics of the
    samples are computed on ``device`` in batches of ``inception_batch_size``.
    """
    samples = generate_examples(
        model_ckpt_name=model_ckpt_name,
        model_ema=model_ema,
        n_sampling_steps=n_sampling_steps,
        guidance_scale=guidance_scale,
    )
    gen_mu, gen_sigma = calc_hidden_parameters(
        samples,
        batch_size=inception_batch_size,
        device=device,
        pbar=True,
        pbar_kwargs={'leave': False, 'desc': 'inception'},
    )
    return calculate_frechet_distance(gen_mu, gen_sigma, ref_mu, ref_sigma)

In [12]:
%%time
fid_ = calc_fid_model(
    model_ckpt_name='ckpt-200000',
    model_ema=True,
    n_sampling_steps=50,
    guidance_scale=1.0,
)
print(f'FID: {fid_:.6f}')

<All keys matched successfully>


generating:   0%|          | 0/79 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

sampling:   0%|          | 0/50 [00:00<?, ?it/s]

inception:   0%|          | 0/40 [00:00<?, ?it/s]

FID: 10.632688
CPU times: user 10min 17s, sys: 3.84 s, total: 10min 21s
Wall time: 9min 42s


In [None]:
# (name='ckpt-200000', ema=True, n_steps=10, guidance_scale=1.0) - FID: 19.797695
# (name='ckpt-200000', ema=True, n_steps=20, guidance_scale=1.0) - FID: 14.584010
# (name='ckpt-200000', ema=True, n_steps=30, guidance_scale=1.0) - FID: 13.101262
# (name='ckpt-200000', ema=True, n_steps=40, guidance_scale=1.0) - FID: 12.385719
# (name='ckpt-200000', ema=True, n_steps=50, guidance_scale=1.0) - FID: 11.989206
# (name='ckpt-200000', ema=True, n_steps=100, guidance_scale=1.0) - FID: 11.337518

# (name='ckpt-200000', ema=True, n_steps=10, guidance_scale=2.0) - FID: 18.849593
# (name='ckpt-200000', ema=True, n_steps=20, guidance_scale=2.0) - FID: 13.987997

# (name='ckpt-200000', ema=False, n_steps=20, guidance_scale=1.0) - FID: 23.878510
# (name='ckpt-200000', ema=False, n_steps=50, guidance_scale=1.0) - FID: 21.450654

# sample 5000 (name='ckpt-200000', ema=True, n_steps=50, guidance_scale=1.0) - FID: 14.466134
# sample 20000 (name='ckpt-200000', ema=True, n_steps=50, guidance_scale=1.0) - FID: 10.632688