In [None]:
from tqdm import tqdm
from PIL import Image

import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib
import importlib
import functools
import itertools
import random
import torch
import math
import os
import io

import sampling as sampling
import datasets
import models

from models.ema import ExponentialMovingAverage
from utils import restore_checkpoint
from losses import get_optimizer

from models import utils as mutils
from models import ncsnpp

from sde_lib import VESDE
from sampling import (ReverseDiffusionPredictor, 
                      LangevinCorrector, 
                      EulerMaruyamaPredictor,
                      AncestralSamplingPredictor,
                      NoneCorrector, 
                      NonePredictor,
                      AnnealedLangevinDynamics)

In [None]:
class SimpleCNN(nn.Module):
    """Small four-stage convolutional network with 1000 output scores.

    Every stage applies a 3x3 convolution (stride 1, padding 1), a ReLU and a
    2x2 average pooling with stride 2, so the spatial size is halved four
    times. The linear head expects 1024 flattened features, i.e. 256 channels
    on a 2x2 grid, which is what a 3-channel 32x32 input produces.
    """

    def __init__(self):
        super().__init__()

        # The attribute names below are also the state-dict keys, so they must
        # stay as they are for saved checkpoints to load.
        self.act = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)     # 16x16 after pooling
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)    # 8x8 after pooling
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)   # 4x4 after pooling
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)  # 2x2 after pooling
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        """Return the (batch, 1000) output scores for a batch of images."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(self.act(conv(x)))
        return self.fc(x.flatten(start_dim=1))

class SimpleCNN256(nn.Module):
    """Four-stage convolutional network with 1000 output scores for 256x256 inputs.

    The first three stages are a 3x3 convolution (stride 1, padding 1), a ReLU
    and an average pooling with kernel 2 and stride 4, each shrinking the
    spatial size by a factor of four (256 -> 64 -> 16 -> 4). The last stage
    uses a regular 2x2 / stride-2 average pooling (4 -> 2), leaving
    256 * 2 * 2 = 1024 features for the linear head.
    """

    def __init__(self):
        super().__init__()

        # The attribute names below are also the state-dict keys, so they must
        # stay as they are for saved checkpoints to load.
        self.act = nn.ReLU()
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # Kernel 2 with stride 4: only a 2x2 window of every 4x4 tile is averaged.
        self.pool4 = nn.AvgPool2d(kernel_size=2, stride=4)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)     # 64x64 after pooling
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)    # 16x16 after pooling
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)   # 4x4 after pooling
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)  # 2x2 after pooling
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        """Return the (batch, 1000) output scores for a batch of images."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool4(self.act(conv(x)))
        x = self.pool2(self.act(self.conv4(x)))
        return self.fc(x.flatten(start_dim=1))

def get_sigma(y, sigma_max):
    """Look up noise levels for integer indices on a 1000-entry geometric grid.

    Entry k of the grid is 0.01 * (sigma_max / 0.01) ** t_k, where t_k runs
    linearly from 1.0 (k = 0, giving sigma_max) down to 1e-3 (k = 999).

    Args:
        y: integer index tensor; its device is used for the result.
        sigma_max: scalar, the largest noise level of the grid.

    Returns:
        Tensor of noise levels, one per entry of y.
    """
    device = y.device
    lowest = torch.tensor([0.01]).to(device)
    highest = torch.tensor([sigma_max]).to(device)
    exponents = torch.linspace(1.0, 1e-3, 1000).to(device)
    ratio = (highest / lowest).to(device)
    table = lowest * ratio ** exponents
    return table[y]

In [None]:
# ---------------------------------------------------------------------------
# Experiment setup: pick the dataset, restore the pretrained score model and
# load the classifier (`net`) whose argmax is fed to get_sigma by the cells
# below.
# ---------------------------------------------------------------------------
dataset = 'celebahq'
dataset_key = dataset.lower()

if dataset_key == 'cifar10':
    from configs.ve import cifar10_ncsnpp_continuous as configs
    ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
    sampling_eps = 1e-5   # `eps` handed to the ODE / PC sampler factories
    twd_eps = 1e-02       # noise level used in the final Tweedie denoising step
    snr = 0.16            # `snr` handed to the PC sampler factory
    batch_size = 200
    FID_N = 50000         # number of images to generate
    classifier_cls = SimpleCNN
    classifier_ckpt = './logs/CIFAR10/200.pt'
elif dataset_key == 'celebahq':
    from configs.ve import celebahq_256_ncsnpp_continuous as configs
    ckpt_filename = "exp/ve/celebahq_256_ncsnpp_continuous/checkpoint_48.pth"
    sampling_eps = 1e-5   # `eps` handed to the ODE / PC sampler factories
    twd_eps = 1e-3        # noise level used in the final Tweedie denoising step
    snr = 0.075           # `snr` handed to the PC sampler factory
    batch_size = 20
    FID_N = 10000         # number of images to generate
    classifier_cls = SimpleCNN256
    classifier_ckpt = './logs/CELEBAHQ/100.pt'
else:
    # Fail here instead of with a NameError on `config` / `net` further down.
    raise ValueError(
        "Unsupported dataset {!r}; expected 'cifar10' or 'celebahq'".format(dataset))

config = configs.get_config()
config.training.batch_size = batch_size
config.eval.batch_size = batch_size

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)

# Restore the score-model checkpoint, then overwrite the model weights with
# their exponential-moving-average copy.
optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer, model=score_model, ema=ema)
state = restore_checkpoint(ckpt_filename, state, config.device)
ema.copy_to(score_model.parameters())

img_size = config.data.image_size
channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size)

# Classifier: load the weights on the CPU first so the checkpoint does not
# depend on the GPU index it was saved from; load_state_dict copies them onto
# the model's device.
net = classifier_cls().cuda()
ckpt = torch.load(classifier_ckpt, map_location='cpu')
net.load_state_dict(ckpt['net'])

In [None]:
# Keyword arguments common to both sampler factories.
_sampler_common = dict(
    shape=shape,
    inverse_scaler=inverse_scaler,
    eps=sampling_eps,
    device=config.device,
)

# ODE sampler factory: call as get_ode_sampler(sde).
get_ode_sampler = functools.partial(
    sampling.get_ode_sampler,
    denoise=False,
    rtol=1e-5,
    atol=1e-5,
    **_sampler_common,
)

# Predictor-corrector sampler factory: call as get_pc_sampler(sde).
# Alternatives that were tried:
#   predictor=EulerMaruyamaPredictor or predictor=AncestralSamplingPredictor
#   corrector=LangevinCorrector together with n_steps=1
get_pc_sampler = functools.partial(
    sampling.get_pc_sampler,
    predictor=ReverseDiffusionPredictor,
    corrector=None,
    snr=snr,
    n_steps=0,
    probability_flow=False,
    continuous=config.training.continuous,
    **_sampler_common,
)

@torch.no_grad()
def karras_alg1(score_model, x_init, sigma_max, N, rho=7, sigma_min=0.01):
    """Deterministic second-order (Heun) sampler on a rho-warped noise schedule.

    The schedule has N levels per sample: N - 1 values interpolated linearly in
    sigma ** (1 / rho) from sigma_max down to sigma_min, followed by 0. Each
    step takes an Euler step with the derivative d = -sigma * score(x, sigma)
    and, except for the final step to sigma = 0, averages it with the
    derivative evaluated at the Euler proposal.
    (Presumably Algorithm 1 of Karras et al., 2022 - confirm.)

    Args:
        score_model: module called as score_model(x, sigma) with sigma of shape
            (batch,); it is switched to eval mode.
        x_init: starting batch; it is not modified.
        sigma_max: starting noise level, either one value per sample (shape
            (batch,)) or a single value shared by the whole batch.
        N: number of schedule levels including the final 0; must be >= 2.
        rho: warping exponent of the schedule.
        sigma_min: smallest non-zero noise level of the schedule.

    Returns:
        (x, nfe): the final batch and the number of score_model evaluations,
        2 * (N - 1) - 1.

    Raises:
        ValueError: if N < 2 (no integration step would be taken).
    """
    if N < 2:
        raise ValueError('N must be at least 2 (got {})'.format(N))

    device = x_init.device
    batch = x_init.shape[0]

    def per_sample(v):
        # Turn a (batch,) vector into one that broadcasts over x's other dims.
        return v.reshape(-1, *([1] * (x_init.dim() - 1)))

    sigma_min = torch.as_tensor(sigma_min, dtype=torch.float32).reshape(-1).to(device)
    sigma_max = torch.as_tensor(sigma_max).reshape(-1).to(device)

    # Interpolate linearly in sigma ** (1 / rho), then warp back.
    ls = torch.linspace(0.0, 1.0, N - 1).reshape(1, N - 1).to(device)
    max_root = sigma_max.pow(1 / rho).reshape(-1, 1)
    min_root = sigma_min.pow(1 / rho)
    sigma_root = ls * (min_root - max_root) + max_root
    # expand() is a no-op for per-sample sigma_max and broadcasts a shared one.
    sigma = sigma_root.pow(rho).expand(batch, -1)
    sigma = torch.cat([sigma, torch.zeros(size=[batch, 1]).to(device)], dim=1)

    score_model.eval()
    x = x_init.detach().clone()
    for i in range(N - 1):
        sig_cur = sigma[:, i]
        sig_next = sigma[:, i + 1]
        step = per_sample(sig_next - sig_cur)

        d1 = -per_sample(sig_cur) * score_model(x, sig_cur)
        x_hat = x + step * d1
        if i < N - 2:
            # Second-order correction; skipped on the last step because
            # sig_next is 0 there.
            d2 = -per_sample(sig_next) * score_model(x_hat, sig_next)
            x = x + 0.5 * step * (d1 + d2)
        else:
            x = x_hat
    return x.detach().clone(), 2 * (N - 1) - 1

@torch.no_grad()
def karras_alg2(score_model, x_init, sigma_max, N, noise=0.007,
                rho=7, sigma_min=0.01, S_churn=80, S_tmin=0.05, S_tmax=1.0):
    """Stochastic second-order sampler on a rho-warped noise schedule.

    The schedule has N levels per sample: N - 1 values interpolated linearly in
    sigma ** (1 / rho) from sigma_max down to sigma_min, followed by 0. At each
    step whose noise level lies in [S_tmin, S_tmax], fresh Gaussian noise first
    raises the level from sigma to sigma_hat = (1 + gamma) * sigma, with
    gamma = min(S_churn / N, sqrt(2) - 1). A Heun step with the derivative
    d = -sigma * score(x, sigma) then moves from sigma_hat to the next level;
    the correction is skipped on the final step to sigma = 0.
    (Presumably Algorithm 2 of Karras et al., 2022 - confirm.)

    Args:
        score_model: module called as score_model(x, sigma) with sigma of shape
            (batch,); it is switched to eval mode.
        x_init: starting batch; it is not modified.
        sigma_max: starting noise level, either one value per sample (shape
            (batch,)) or a single value shared by the whole batch.
        N: number of schedule levels including the final 0; must be >= 2.
        noise: the injected noise has standard-deviation multiplier 1 + noise.
        rho: warping exponent of the schedule.
        sigma_min: smallest non-zero noise level of the schedule.
        S_churn: churn budget; gamma is min(S_churn / N, sqrt(2) - 1).
        S_tmin, S_tmax: noise-level range in which churn is applied.

    Returns:
        (x, nfe): the final batch and the number of score_model evaluations,
        2 * (N - 1) - 1.

    Raises:
        ValueError: if N < 2 (no integration step would be taken).
    """
    if N < 2:
        raise ValueError('N must be at least 2 (got {})'.format(N))

    device = x_init.device
    batch = x_init.shape[0]

    def per_sample(v):
        # Turn a (batch,) vector into one that broadcasts over x's other dims.
        return v.reshape(-1, *([1] * (x_init.dim() - 1)))

    sigma_min = torch.as_tensor(sigma_min, dtype=torch.float32).reshape(-1).to(device)
    sigma_max = torch.as_tensor(sigma_max).reshape(-1).to(device)

    # Interpolate linearly in sigma ** (1 / rho), then warp back.
    ls = torch.linspace(0.0, 1.0, N - 1).reshape(1, N - 1).to(device)
    max_root = sigma_max.pow(1 / rho).reshape(-1, 1)
    min_root = sigma_min.pow(1 / rho)
    sigma_root = ls * (min_root - max_root) + max_root
    # expand() is a no-op for per-sample sigma_max and broadcasts a shared one.
    sigma = sigma_root.pow(rho).expand(batch, -1)
    sigma = torch.cat([sigma, torch.zeros(size=[batch, 1]).to(device)], dim=1)

    S_noise = 1 + noise
    # Loop-invariant; a plain float keeps the tensor arithmetic below in torch.
    gamma_max = float(np.minimum(S_churn / N, np.sqrt(2) - 1))

    score_model.eval()
    x = x_init.detach().clone()
    for i in range(N - 1):
        sig_cur = sigma[:, i]
        sig_next = sigma[:, i + 1]

        # Churn: raise the noise level to sigma_hat where sig_cur is in range.
        eps = S_noise * torch.randn_like(x).to(device)
        cond = (sig_cur >= S_tmin) * (sig_cur <= S_tmax)
        gamma = gamma_max * cond
        sigma_hat = (1 + gamma) * sig_cur
        x_hat = x + per_sample((sigma_hat ** 2 - sig_cur ** 2).sqrt()) * eps

        # Heun step from sigma_hat to sig_next.
        step = per_sample(sig_next - sigma_hat)
        d1 = -per_sample(sigma_hat) * score_model(x_hat, sigma_hat)
        x_next = x_hat + step * d1
        if i < N - 2:
            d2 = -per_sample(sig_next) * score_model(x_next, sig_next)
            x = x_hat + 0.5 * step * (d1 + d2)
        else:
            # Last step goes to sigma = 0, so no correction is possible.
            x = x_next
    return x.detach().clone(), 2 * (N - 1) - 1

### DLG: Langevin-Gibbs MCMC over (x, sigma), followed by denoising of the chain states

In [None]:
@torch.no_grad()
def _langevin_gibbs_step(x, s, step_size):
    """One Langevin-Gibbs sweep: update x given s, then re-estimate s from x.

    Uses the notebook-level `score_model`, `net`, `config` and `get_sigma`.
    Returns the new (x, s) pair; the inputs are not modified.
    """
    x = x + 0.5 * step_size * score_model(x, s) + np.sqrt(step_size) * torch.randn_like(x)  # x ~ p(x|s)
    s = get_sigma(net(x).argmax(dim=1), config.model.sigma_max)  # s ~ p(s|x)
    return x, s


# N    : number of noise levels used by the denoising integrator
# eta  : Langevin step size of the MCMC chain
# skp  : Langevin-Gibbs sweeps between two consecutive denoising runs
for N, eta, skp in [(20, 0.8, 10)]:

    save_dir = './FID/{}/kar2_{}_dmcmc_{:.1e}_{}'.format(dataset.upper(), N, eta, skp)
    print('save dir : {}'.format(save_dir))
    contin = False  # set to True to continue an interrupted run from save_dir

    score_model.eval()
    resume = contin and os.path.isdir(save_dir)
    if resume:
        # Everything in save_dir except ckpt.pt is an already generated image.
        S = len(os.listdir(save_dir)) - 1
        n_samples = math.ceil((FID_N - S) / batch_size)
        ckpt = torch.load(save_dir + '/ckpt.pt')
        x = ckpt['x'].cuda()
        s = ckpt['s'].cuda()
    else:
        # Fresh run: start from an empty output directory.
        if os.path.isdir(save_dir):
            for f in os.listdir(save_dir):
                os.remove(os.path.join(save_dir, f))
        else:
            os.makedirs(save_dir)

        S = 0
        n_samples = math.ceil(FID_N / batch_size)

        # Burn-in schedule as (number of sweeps, step size) pairs. The dataset
        # name is matched case-insensitively, like in the configuration cell.
        if 'cifar10' in dataset.lower():
            burn_in = ((10, 0.5), (10, eta))
        else:
            burn_in = ((50, 1.6), (20, eta))

        # Generating MCMC chain initialization points
        with torch.no_grad():
            x = config.model.sigma_max * torch.randn(size=shape).cuda()
            s = config.model.sigma_max * torch.ones(batch_size).cuda()

            # Denoise pure noise once, perturb the result, and let the
            # classifier assign the initial noise levels.
            sampling_fn = functools.partial(karras_alg2, sigma_max=s.cpu(), N=20)
            x, _ = sampling_fn(score_model, x)
            x = x + 0.5 * torch.randn_like(x).cuda()
            s = get_sigma(net(x).argmax(dim=1), config.model.sigma_max)

            for n_sweeps, step_size in burn_in:
                for _ in range(n_sweeps):
                    x, s = _langevin_gibbs_step(x, s, step_size)

    # Sampling
    with torch.no_grad():
        for i in tqdm(range(int(n_samples))):
            # Langevin Gibbs using n_skip NFE
            x_min = x.detach().clone()
            s_min = s.detach().clone()
            for _ in range(skp):
                x, s = _langevin_gibbs_step(x, s, eta)

                # Choose MCMC samples of smallest noise levels
                cond = (s < s_min)
                x_min = cond.reshape(-1, 1, 1, 1) * x + (~cond).reshape(-1, 1, 1, 1) * x_min
                s_min = cond * s + (~cond) * s_min

            # Define reverse-SDE/ODE integrator
            # sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=s_min.cpu(), N=N) # needed by the ODE / PC integrators only
            # sampling_fn = get_ode_sampler(sde) # ODE integrator
            # sampling_fn = get_pc_sampler(sde) # PC integrator
            # sampling_fn = functools.partial(karras_alg1, sigma_max=s_min.cpu(), N=N) # KAR1 integrator
            sampling_fn = functools.partial(karras_alg2, sigma_max=s_min.cpu(), N=N) # KAR2 integrator

            # Run denoising using n_den NFE
            x_den, n = sampling_fn(score_model, x_min)

            # Tweedie's denoising formula
            x_den = x_den + twd_eps**2 * score_model(x_den, twd_eps * torch.ones_like(s).cuda())

            # Persist the chain state so the run can be resumed.
            ckpt = {
                'x' : x.detach().clone(),
                's' : s.detach().clone(),
            }
            torch.save(ckpt, save_dir + '/ckpt.pt')

            # Write the batch as images, stopping once FID_N images exist.
            for j in range(batch_size):
                if int(batch_size*i+j+S) >= FID_N:
                    break
                img = torch.clamp(x_den[j], 0.0, 1.0)
                img = Image.fromarray((255 * img.permute(1,2,0).detach().cpu().numpy()).astype(np.uint8))
                img.save(save_dir + '/{}.jpg'.format(int(batch_size*i+j+S)))

### Vanilla: baseline sampling that starts every batch from pure noise at sigma_max

In [None]:
# N : number of noise levels used by the integrator
for N in [201]:

    save_dir = './FID/{}/kar2_{}_vanilla'.format(dataset.upper(), N)
    print('save dir : {}'.format(save_dir))
    contin = False  # set to True to keep the images already in save_dir

    # S is the number of images that already exist and is used as index offset.
    if os.path.isdir(save_dir):
        if contin:
            S = len(os.listdir(save_dir))
        else:
            for f in os.listdir(save_dir):
                os.remove(os.path.join(save_dir,f))
            S = 0
    else:
        os.makedirs(save_dir)
        S = 0

    # Only the images that are still missing have to be generated; without the
    # offset a continued run would compute batches whose images are discarded.
    n_samples = max(0, math.ceil((FID_N - S) / batch_size))

    with torch.no_grad():
        for i in tqdm(range(n_samples)):
            # Start from pure noise at the largest noise level.
            x = (config.model.sigma_max * torch.randn(size=shape)).cuda()
            s = (config.model.sigma_max * torch.ones(batch_size)).cuda()
            # sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=s.cpu(), N=N)
            # sampling_fn = get_ode_sampler(sde)
            # sampling_fn = get_pc_sampler(sde)
            # sampling_fn = functools.partial(karras_alg1, sigma_max=s.cpu(), N=N)
            sampling_fn = functools.partial(karras_alg2, sigma_max=s.cpu(), N=N)
            x_den, n = sampling_fn(score_model, x)
            # Tweedie's denoising formula at noise level twd_eps
            x_den = x_den + twd_eps**2 * score_model(x_den, twd_eps * torch.ones_like(s).cuda())

            # Write the batch as images, stopping once FID_N images exist.
            for j in range(batch_size):
                if int(batch_size*i+j+S) >= FID_N:
                    break
                img = torch.clamp(x_den[j], 0.0, 1.0)
                img = Image.fromarray((255 * img.permute(1,2,0).detach().cpu().numpy()).astype(np.uint8))
                img.save(save_dir + '/{}.jpg'.format(int(batch_size*i+j+S)))