In [1]:
# Importing the libraries

import torch
import matplotlib.pyplot as plt
import matplotlib as mpl
import logging
from torch import tensor
import numpy as np
import torchvision 
import random,math
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from operator import attrgetter
from functools import partial
from torch.utils.data import DataLoader,default_collate,Dataset
from copy import copy
from torch.optim.lr_scheduler import ExponentialLR
from collections.abc import Mapping
from diffusers import UNet2DModel
from tqdm import tqdm
from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, EulerAncestralDiscreteScheduler
from scipy import linalg
from torch.optim.lr_scheduler import ExponentialLR
from pytorch_fid.inception import InceptionV3

In [2]:
# Get the device
def_device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Function to send data to device
def to_device(x, device=def_device):
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k:v.to(device) for k,v in x.items()}
    return type(x)(to_device(o, device) for o in x)

def to_cpu(x):
    if isinstance(x, Mapping): return {k:to_cpu(v) for k,v in x.items()}
    if isinstance(x, list): return [to_cpu(o) for o in x]
    if isinstance(x, tuple): return tuple(to_cpu(list(x)))
    res = x.detach().cpu()
    return res.float() if res.dtype==torch.float16 else res

In [3]:
# Load the datasets
# Remember to pad the images with 2 pixels on each side i.e. to make the image size 32x32

batch_size = 64

transforms = T.Compose([T.ToTensor(),T.Pad(2)])

train_ds = torchvision.datasets.FashionMNIST(root = './data/train',train = True,download = True,transform = transforms)
valid_ds = torchvision.datasets.FashionMNIST(root = './data/valid',train = False,download = True,transform = transforms)

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=4)

In [4]:
# Define the model

class DDPM_model(nn.Module):
    def __init__(self, model, beta_min = 0.0001, beta_max = 0.02, n_steps = 1000, cosine_schedule = False):
        super().__init__()
        self.model = model
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.n_steps = n_steps
        if cosine_schedule:
            def abar(t, T): return (t/T*math.pi/2).cos()**2
            timesteps = torch.linspace(0, self.n_steps -1 , self.n_steps)
            self.alphabar = abar(timesteps, self.n_steps)
            self.alpha = self.alphabar/abar(timesteps-1, self.n_steps)
            self.beta = 1 - self.alpha
            self.sigma = self.beta.sqrt()
        else:
            self.beta = torch.linspace(beta_min, beta_max, self.n_steps)
            self.alpha = 1. - self.beta
            self.alphabar = self.alpha.cumprod(dim=0)
            self.sigma = self.beta.sqrt()

    def add_noise(self, x_0):
        device = x_0.device
        n = len(x_0)
        timesteps = torch.randint(0, self.n_steps, (n,), device=device)
        alphabar_t = self.alphabar.to(device)[timesteps].reshape(-1, 1, 1, 1)
        noise = torch.randn_like(x_0, device=device)
        x_t = x_0 * alphabar_t.sqrt() + (1. - alphabar_t).sqrt()* noise
        return (x_t, timesteps.to(device)), noise

    def forward(self, x):
        return self.model.forward(*x).sample

In [5]:
# Define the model
# Take in a scheduler and sample from the model using the scheduler

def sample(sched,model,sz = (16,3,32,32)):
    preds = []
    device = next(model.parameters()).device
    x_t = torch.randn(sz).to(device)
    for t in tqdm(sched.timesteps,total=len(sched.timesteps)):
        with torch.no_grad(): noise = model((x_t, t))
        x_t = sched.step(noise, t, x_t).prev_sample
        preds.append(x_t.float().cpu())
    return preds

In [6]:
# Functions to calculate the inception score

def calc_stats(feats):
    feats = feats.squeeze()
    return feats.mean(0),feats.T.cov()

def calc_fid(m1,c1,m2,c2):
    csr = tensor(linalg.sqrtm(c1@c2, 256).real)
    return (((m1-m2)**2).sum() + c1.trace() + c2.trace() - 2*csr.trace()).item()

class Inception(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = InceptionV3(resize_input=True)
    def forward(self, x): return self.model(x.repeat(1,3,1,1))[0]

def get_fid(real_batch, fake_batch, device=def_device, model=Inception().to(def_device)):
    model.eval()
    with torch.no_grad():
        real_feats = model(to_device(real_batch, device))
        fake_feats = model(to_device(fake_batch, device))
        real_feats = to_cpu(real_feats)
        fake_feats = to_cpu(fake_feats)
        m1,c1 = calc_stats(real_feats)
        m2,c2 = calc_stats(fake_feats)
        return calc_fid(m1,c1,m2,c2)

In [7]:
# Get FID for two batches of real images

it = iter(train_dl)
real_batch,_ = next(it)
real_batch_2,_ = next(it)

fid = get_fid(real_batch, real_batch_2)
print(fid)

102.07628140364267


In [8]:
# Load the model

unet_model = UNet2DModel(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 128),norm_num_groups=8)
model = DDPM_model(unet_model, 0.0001, 0.02, 1000, cosine_schedule=False)
model.load_state_dict(torch.load("DDPM_state_dict.pth",map_location=torch.device('cpu')))

<All keys matched successfully>

In [9]:
# Get FID using DDPM scheduler for 1000 timesteps

sched = DDPMScheduler(beta_schedule="scaled_linear", num_train_timesteps=1000)
preds = sample(sched,model,(64,1,32,32))
fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 1000/1000 [08:39<00:00,  1.93it/s]


187.42056070609073


In [10]:
# Get FID using DDIM scheduler for 500 timesteps

sched = DDIMScheduler(beta_schedule="scaled_linear", num_train_timesteps=1000)
sched.set_timesteps(500)
preds = sample(sched,model,(64,1,32,32))


fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 500/500 [04:20<00:00,  1.92it/s]


341.5343243573738


In [11]:
# Get FID using DDIM scheduler for 333 timesteps

sched = DDIMScheduler(beta_schedule="scaled_linear", num_train_timesteps=1000)
sched.set_timesteps(333)
preds = sample(sched,model,(64,1,32,32))


fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 333/333 [03:03<00:00,  1.81it/s]


350.20018457582114


In [12]:
# Get FID using DDIM scheduler for 50 timesteps

sched = DDIMScheduler(beta_schedule="scaled_linear", num_train_timesteps=1000)
sched.set_timesteps(50)
preds = sample(sched,model,(64,1,32,32))


fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 50/50 [00:27<00:00,  1.82it/s]


315.70180505655634


In [13]:
# Get FID using PNDM scheduler for 50 timesteps

sched = PNDMScheduler(beta_schedule="scaled_linear", num_train_timesteps=1000)
sched.set_timesteps(50)
preds = sample(sched,model,(64,1,32,32))


fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 59/59 [00:33<00:00,  1.75it/s]


288.7396716038522


In [14]:
# Load the datasets
# Remember to pad the images with 2 pixels on each side i.e. to make the image size 32x32
# There is an extra normalization transform here to make the images between -0.5 and 0.5

batch_size = 64

transforms = T.Compose([T.ToTensor(),T.Pad(2),lambda x: x- 0.5])

train_ds = torchvision.datasets.FashionMNIST(root = './data/train',train = True,download = True,transform = transforms)
valid_ds = torchvision.datasets.FashionMNIST(root = './data/valid',train = False,download = True,transform = transforms)

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=4)

In [15]:
# Get FID between two batches of real images

it = iter(train_dl)
real_batch,_ = next(it)
real_batch_2,_ = next(it)

fid = get_fid(real_batch, real_batch_2)
print(fid)

113.313574425312


In [17]:
# Load the model which was trained on normalized images and using cosine schedule

unet_model = UNet2DModel(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 256),norm_num_groups=8)
model = DDPM_model(unet_model, 0.0001, 0.02, 1000, cosine_schedule=True)
model.load_state_dict(torch.load("DDPM_cosine_state_dict.pth",map_location=torch.device('cpu')))

<All keys matched successfully>

In [18]:
# Get FID using DDPM scheduler for 1000 timesteps

sched = DDPMScheduler()
preds = sample(sched,model,(64,1,32,32))

fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 1000/1000 [10:23<00:00,  1.60it/s]


193.21688921323997


In [19]:
# Get FID using DDIM scheduler for 500 timesteps

sched = DDIMScheduler(beta_schedule="scaled_linear", num_train_timesteps=1000)
sched.set_timesteps(333)
preds = sample(sched,model,(64,1,32,32))


fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 333/333 [03:23<00:00,  1.64it/s]


146.48139774630687


In [20]:
# Get FID using DDIM scheduler for 333 timesteps

sched = DDIMScheduler(beta_schedule="scaled_linear", num_train_timesteps=1000)
sched.set_timesteps(50)
preds = sample(sched,model,(64,1,32,32))


fid = get_fid(real_batch, preds[-1])
print(fid)

100%|██████████| 50/50 [00:29<00:00,  1.69it/s]


120.99068507811174
