In [None]:
# This code is based on https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-/tree/57e53510e9c5932add4c09b23fb3296a207fa104, 
# which is released under the MIT licesne.

# This code is also based on "DL-Generative-Model-Assignment" Templete.

# Paper reference: 
# Classifier-free diffusion guidance: https://arxiv.org/pdf/2207.12598.pdf
# DDPM: https://arxiv.org/pdf/2006.11239.pdf

# Library used: pytorch, matplotlib, tqdm(for process bar), typing, PIL

import os

import numpy as np

import torch
import torch.nn as nn
from torch.nn import init
from tqdm import tqdm
from torch import optim
import torch.nn.functional as F

from torchvision.datasets import CIFAR10
from torchvision.datasets import STL10
from torch.optim.lr_scheduler import _LRScheduler

import torchvision
from torchvision import transforms

from torchvision.utils import save_image

from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from typing import Dict
cvhm
import math
torch.cuda.empty_cache()

In [None]:
!pip install matplotlib

In [None]:
!pip install tqdm

In [2]:
!pip install typing

Collecting typing
  Downloading typing-3.7.4.3.tar.gz (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: typing
  Building wheel for typing (setup.py) ... [?25ldone
[?25h  Created wheel for typing: filename=typing-3.7.4.3-py3-none-any.whl size=26306 sha256=da038ae193dbc8db3e11ee7713b909e0a11d8e3e895cdc003660d3bc64b9ed68
  Stored in directory: /root/.cache/pip/wheels/e5/23/95/593222493ec6253200e94e4a5ee4361d12112000816d840434
Successfully built typing
Installing collected packages: typing
Successfully installed typing-3.7.4.3
[0m

In [None]:
# Pre-setting, ensure put correct device(cuba), batch_size(64), img_size(base on dataset)
# If you are using saving/loading model or saving images function or running any evaluation-related function
# Make sure you have folder with same name, and you have model inside. The way to get model is from programming running itself with save-function.

modelConfig = {
    "epoch": 300,                                     #  epoch
    "batch_size": 64,                                 #  batch_size(normally is 64)
    "T": 500,                                         #  Time step
    "channel": 128,                                   #  number of channel
    "channel_mult": [1, 2, 2, 2],                     #  channel_mult
    "num_res_blocks": 2,                              #  num of res block
    "dropout": 0.15,                                  #  drop out rate
    "lr": 1e-4,                                       #  learing rate
    "multiplier": 2.5,                                #  multiplier
    "beta_1": 1e-4,                                   #  beta_1
    "beta_T": 0.028,                                  #  beta_T
    "img_size": 32,                                   #  img_size(base on the dataset), Cifar -> 32, STL-10 -> 48/64, the code has out of memory issue when on 96
    "grad_clip": 1.,                                  #  grad_clip
    "device": "cuda",                                 #  device(cuda)
    "w": 1.8,                                         #  w
    "save_dir": "./CheckpointsCondition_cifar/",      #  path save model               
    "test_load_weight": "ckpt_502_.pt",               #  for evaluation use, name of the model ckpt_epoch_.pt
    "sampled_dir": "./SampledImgs/",                  #  sample img path
    "sampledImgName": None,                           #  sample img name
    "nrow": 8                                         #  nrow
}

In [None]:
def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    device = t.device
    out = torch.gather(v, index=t, dim=0).float().to(device)
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))

In [None]:
# Add noise to the input image, 
# noise base on predefined sequence of noise levels determined by the betas. 

class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
        
    # For interpolation purpose 
    def GetAlphaOne(self, x_0):
            
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t =   extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + \
                extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
        return x_t

    def forward(self, x_0, labels):
        """
        Algorithm 1.
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t =   extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + \
                extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise
        loss = F.mse_loss(self.model(x_t, t, labels), noise, reduction='none')
        return loss

In [None]:
class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T, w = 0.):
        super().__init__()

        self.model = model
        self.T = T
        ### In the classifier free guidence paper, w is the key to control the gudience.
        ### w = 0 and with label = 0 means no guidence.
        ### w > 0 and label > 0 means guidence. Guidence would be stronger if w is bigger.
        self.w = w

        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return extract(self.coeff1, t, x_t.shape) * x_t - extract(self.coeff2, t, x_t.shape) * eps

    def p_mean_variance(self, x_t, t, labels):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        eps = self.model(x_t, t, labels)
        nonEps = self.model(x_t, t, torch.zeros_like(labels).to(labels.device))
        eps = (1. + self.w) * eps - self.w * nonEps
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

    def forward(self, x_T, labels):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var= self.p_mean_variance(x_t=x_t, t=t, labels=labels)
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1) 

In [None]:
# U-Net

def drop_connect(x, drop_ratio):
    keep_ratio = 1.0 - drop_ratio
    mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
    mask.bernoulli_(p=keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb, freeze=False),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        emb = self.timembedding(t)
        return emb

class ConditionalEmbedding(nn.Module):
    def __init__(self, num_labels, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        self.condEmbedding = nn.Sequential(
            nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model, padding_idx=0),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        emb = self.condEmbedding(t)
        return emb

class DownSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.c1 = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.c2 = nn.Conv2d(in_ch, in_ch, 5, stride=2, padding=2)

    def forward(self, x, temb, cemb):
        x = self.c1(x) + self.c2(x)
        return x

class UpSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.c = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.t = nn.ConvTranspose2d(in_ch, in_ch, 5, 2, 2, 1)

    def forward(self, x, temb, cemb):
        _, _, H, W = x.shape
        x = self.t(x)
        x = self.c(x)
        return x

class AttnBlock(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.cond_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()


    def forward(self, x, temb, labels):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]
        h += self.cond_proj(labels)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h


class UNet(nn.Module):
    def __init__(self, T, num_labels, ch, ch_mult, num_res_blocks, dropout):
        super().__init__()
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channel when dowmsample for upsample
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=False))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
 

    def forward(self, x, t, labels):
        # Timestep embedding
        temb = self.time_embedding(t)
        cemb = self.cond_embedding(labels)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, cemb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, cemb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb, cemb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

In [None]:
# A learning rate scheduler that gradually increases the learning rate during the initial training phase before reverting to a base scheduler.
# Help to allow model converge earlier

class GradualWarmupScheduler(_LRScheduler):
    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
        self.multiplier = multiplier
        self.total_epoch = warm_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        self.last_epoch = None
        self.base_lrs = None
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step(self, epoch=None, metrics=None):
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)

In [None]:
# Evaluation the model
# If you wish to use this, require model trained by this program -> see Train save module function
# If you wish not us this but still wish to see full pictures -> Uncomment specific fucntion in training process
def eval(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # load model and evaluate
    with torch.no_grad():
        
        # make labels for CFG
        step = 8
        labelList = []
        k = 0
        for i in range(1, modelConfig["batch_size"] + 1):
            labelList.append(torch.ones(size=[1]).long() * k)
            if i % step == 0:
                if k < 9:
                    k += 1
                    
        labels = torch.cat(labelList, dim=0).long().to(device) + 1
        
        model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
        
        # Model path
        ckpt = torch.load(os.path.join(
            modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
        
        model.load_state_dict(ckpt)
        
        print("model load weight done.")
        model.eval()
        
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)
        
        # Sampled from standard normal distribution
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
        
        sampledImgs = sampler(noisyImage, labels)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]
        
        # Uncomment if you wish to save the image to folder, change path in the dicturary.
        # save_image(sampledImgs, os.path.join(
        #     modelConfig["sampled_dir"],  modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])
        
        # show first 8 images, comment it if you don't wish to see.
        plt.rcParams['figure.dpi'] = 100
        plt.grid(False)
        plt.imshow(torchvision.utils.make_grid(sampledImgs[:8]).cpu().data.permute(0,2,1).contiguous().permute(2,1,0), cmap=plt.cm.binary)
        plt.show()
        plt.pause(0.0001)
        
        fig = plt.figure(figsize=(10, 10))
        plt.rcParams['figure.dpi'] = 100
        plt.grid(False)
        plt.imshow(torchvision.utils.make_grid(sampledImgs).cpu().data.permute(0,2,1).contiguous().permute(2,1,0), cmap=plt.cm.binary)
        plt.show()
        # Uncomment if you wish to save the image to folder, remember to create folder name called "submission".
        # fig.savefig(os.path.join("Submission", modelConfig["sampledImgName"]), dpi=fig.dpi)

In [None]:
# Training process

def train(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # dataset
    
    # Uncomment the dataset if you wish, change the img_size in dicturary.
    # dataset = STL10(
    #     root='./dataset', split='train+unlabeled', download=True,
    #     transform=transforms.Compose([
    #         transforms.Resize((modelConfig["img_size"], modelConfig["img_size"])), -> img_size should be 64 or 48
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    #     ]))
    
    dataset = CIFAR10(
        root='./dataset', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=False, num_workers=8, drop_last=True, pin_memory=True)

    # Continue last training if you wish, remember change "save_dir" to correct one.
    # state_dict = torch.load(os.path.join(modelConfig["save_dir"], modelConfig["test_load_weight"]))
    
    # model setup
    net_model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    
    # Uncomment if you need to resume the training for now.
    # net_model.load_state_dict(state_dict)
    
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    
    warmUpScheduler = GradualWarmupScheduler(optimizer=optimizer, multiplier=modelConfig["multiplier"],
                                             warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)

    # start training
    for e in range(modelConfig["epoch"]):
        # add process bar
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                b = images.shape[0]
                optimizer.zero_grad()
                x_0 = images.to(device)
                labels = labels.to(device) + 1
                if np.random.rand() < 0.1:
                    labels = torch.zeros_like(labels).to(device)
                loss = trainer(x_0, labels).sum() / b ** 2.
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])
                optimizer.step()
                # process bar setting
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })
                
        # learning rate warm up
        warmUpScheduler.step()
        
        # Uncomment to save model if you wish, remember to change "save_dir" to the path
        # torch.save(net_model.state_dict(), os.path.join(
        # modelConfig["save_dir"], 'ckpt_' + str(e+) + "_.pt"))
        
        # model name:
        # modelConfig["test_load_weight"] = 'ckpt_' + str(e) + "_.pt"
        # Update the sampleImgName
        # modelConfig["sampledImgName"] = "SampledGuidenceImgs_" + str(e) + ".png"
        
        # Evaluation:
        with torch.no_grad():
            step = 8
            labelList = []
            k = 0
            for i in range(1, modelConfig["batch_size"] + 1):
                labelList.append(torch.ones(size=[1]).long() * k)
                if i % step == 0:
                    if k < 9:
                        k += 1
            labels = torch.cat(labelList, dim=0).long().to(device) + 1
            
            sampler = GaussianDiffusionSampler(
                net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)

            noisyImage = torch.randn(
                size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)

            sampledImgs = sampler(noisyImage, labels)
            sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]

            # Uncomment if you wish to save the image to folder, change path in the dicturary.
            # save_image(sampledImgs, os.path.join(
            #     modelConfig["sampled_dir"],  modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])

            plt.rcParams['figure.dpi'] = 100
            plt.grid(False)
            plt.imshow(torchvision.utils.make_grid(sampledImgs[:8]).cpu().data.permute(0,2,1).contiguous().permute(2,1,0), cmap=plt.cm.binary)
            plt.show()
            plt.pause(0.0001)
            
            # Uncomment if you wish to see full picture 8x8
            plt.rcParams['figure.dpi'] = 100
            plt.grid(False)
            plt.imshow(torchvision.utils.make_grid(sampledImgs).cpu().data.permute(0,2,1).contiguous().permute(2,1,0), cmap=plt.cm.binary)
            plt.show()
            plt.pause(0.0001)
            
            # Uncommment if you wish to see the interpolation result
            z = trainer.GetAlphaOne(sampledImgs)

            col_size = int(np.sqrt(modelConfig["batch_size"]))

            z0 = z[0:col_size].repeat(col_size,1,1,1) # z for top row
            z1 = z[modelConfig["batch_size"]-col_size:].repeat(col_size,1,1,1) # z for bottom row

            t = torch.linspace(0,1,col_size).unsqueeze(1).repeat(1,col_size).view(modelConfig["batch_size"],1,1,1).to(device)

            lerp_z = (1-t)*z0 + t*z1 # linearly interpolate between two points in the latent space
            lerp_g = sampler(lerp_z,labels) # sample the model at the resulting interpolated latents

            lerp_g = lerp_g * 0.5 + 0.5

            plt.rcParams['figure.dpi'] = 175
            plt.grid(False)
            plt.imshow(torchvision.utils.make_grid(lerp_g).cpu().data.permute(0,2,1).contiguous().permute(2,1,0), cmap=plt.cm.binary)
            plt.show()
            

In [None]:
train(modelConfig)

In [None]:
# Interpolation process
# If you wish to run this function you need to have path for model and model as well.
# If you don't want to run this function but still want to see the interplation result, please uncomment Train function last part.

def eval_forInterpolation(modelConfig: Dict):
    device = torch.device(modelConfig["device"])

    with torch.no_grad():
        # labels for CFG
        step = 8
        labelList = []
        k = 0
        for i in range(1, modelConfig["batch_size"] + 1):
            labelList.append(torch.ones(size=[1]).long() * k)
            if i % step == 0:
                if k < 9:
                    k += 1

        labels = torch.cat(labelList, dim=0).long().to(device) + 1

        # Load model
        model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)

        ckpt = torch.load(os.path.join(
            modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)

        model.load_state_dict(ckpt)

        model.eval()
        
        # Load sampler and trainer
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)

        trainer = GaussianDiffusionTrainer(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)

        # Sample 8 pairs of random noise vectors
        noise_image = torch.randn(size=[modelConfig["batch_size"], 3, 32, 32], device=device)

        # Get sample image
        sample_image = sampler(noise_image,labels)

        z = trainer.GetAlphaOne(sample_image)

        # Follow Template
        col_size = int(np.sqrt(modelConfig["batch_size"]))

        z0 = z[0:col_size].repeat(col_size,1,1,1) # z for top row
        z1 = z[modelConfig["batch_size"]-col_size:].repeat(col_size,1,1,1) # z for bottom row

        t = torch.linspace(0,1,col_size).unsqueeze(1).repeat(1,col_size).view(modelConfig["batch_size"],1,1,1).to(device)

        lerp_z = (1-t)*z0 + t*z1 # linearly interpolate between two points in the latent space
        lerp_g = sampler(lerp_z,labels) # sample the model at the resulting interpolated latents

        lerp_g = lerp_g * 0.5 + 0.5

        # plot the result
        fig = plt.figure(figsize=(10, 10))
        plt.rcParams['figure.dpi'] = 175
        plt.grid(False)
        plt.imshow(torchvision.utils.make_grid(lerp_g).cpu().data.permute(0,2,1).contiguous().permute(2,1,0), cmap=plt.cm.binary)
        plt.show()
        
        # uncomment it if you wish to save
        # fig.savefig(os.path.join("Submission", "CifarInter.png"), dpi=fig.dpi)