# Implementing the original Diffusion model for the HAND-MNIST dataset

## Boring imports and datasets (Same for most of the Notebooks that use MNIST or HANDMNIST)

In [1]:
import torch 
import torch.nn as nn 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import sklearn.model_selection as ms
import math
from tqdm import tqdm,trange
import albumentations as A
from torchvision import transforms
import torch.optim as optim
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
import os 

In [2]:
class InverseNormalize(transforms.Normalize):
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).view(3,1,1)
        self.std = torch.tensor(std).view(3,1,1)
        
    def __call__(self, tensor):
        tensor = (tensor * self.std.to(tensor.device)) + self.mean.to(tensor.device)
        return tensor

In [57]:
class Hyperparams:
    num_epochs = 20
    lr = 1e-4
    batch_size_train = 64
    batch_size_valid = 64
    num_latent_features = 100
    embedding_dim = 256

    image_size = (64,64,3)

    T = 5
    
    discriminator_steps = 5
    grad_penalty_lambda = 10 
    
    normalise_transform = transforms.Compose([
        transforms.Normalize(mean=(0.5), std=(0.5))
        ])

    inverse_normalise_transform = InverseNormalize(mean=[227.8477, 229.4812, 222.2282], std=[21.2764, 14.5848, 29.2370])

In [4]:
class MNIST_Dataset(Dataset):
    def __init__(self, metadata_df, images, normalise_transform = Hyperparams.normalise_transform ):
        self.metadata_df = metadata_df
        self.images = images
        self.normalise_transform = normalise_transform

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self,idx):
        idx = int(idx)

        image = torch.tensor(images[idx]).unsqueeze(0)
        label = torch.tensor([1])

        if self.normalise_transform:
            image = self.normalise_transform(image.float())

        return image, label

In [50]:
metadata = pd.read_csv('data/HANDMNIST/metadata.csv')
train_metadata, valid_metadata = ms.train_test_split(metadata, test_size=0.2, train_size=0.8, random_state=19, shuffle=True, stratify=metadata['label'])

images = np.load('data/HANDMNIST/images.npy')

In [51]:
train_dataset = MNIST_Dataset(train_metadata.reset_index(), images[train_metadata.index])
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=Hyperparams.batch_size_train)

valid_dataset = MNIST_Dataset(valid_metadata.reset_index(), images[valid_metadata.index])
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=Hyperparams.batch_size_valid)

## Fun part: models

In [52]:
class DoubleConv(nn.Module):
    def __init__(self,in_channels, out_channels, residual = False):
        super(DoubleConv, self).__init__()

        self.residual = residual
        
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False), 
            nn.GroupNorm(1,out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), 
            nn.GroupNorm(1,out_channels),
            nn.ReLU(inplace=True)
        ) 

        
    def forward(self,x):
        x1 = self.double_conv(x)
        if self.residual:
            return x1 + x
        else:
            return x1

class Down(nn.Module):
    def __init__(self, input_channels, output_channels, embedding_dim = Hyperparams.embedding_dim):
        super(Down, self).__init__()
        
        self.input_channels = input_channels
        self.output_channels = output_channels

        self.down_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(input_channels, input_channels, residual = True),
            DoubleConv(input_channels, output_channels, residual = False),
        )

        self.embedding_layer = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Linear(embedding_dim, output_channels),
        )

    def forward(self,x , t):
        x = self.down_conv(x)
        new_embedding = self.embedding_layer(t).view(t.shape[0],self.output_channels,1,1).repeat(1,1,x.shape[2], x.shape[3])
        return x + new_embedding

class Up(nn.Module):
    def __init__(self, input_channels, output_channels, embedding_dim = Hyperparams.embedding_dim):
        super(Up, self).__init__()
        
        self.input_channels = input_channels
        self.output_channels = output_channels

        self.up_conv = nn.Sequential(
            DoubleConv(input_channels, input_channels, residual = True),
            DoubleConv(input_channels, output_channels, residual = False),
            nn.ConvTranspose2d(output_channels, output_channels, kernel_size=4, stride=2, padding=1),
        )

        self.embedding_layer = nn.Sequential(
            nn.LeakyReLU(0.1),
            nn.Linear(embedding_dim, output_channels),
        )

    def forward(self,x, from_skip_x,t):
        x = torch.cat([x, from_skip_x], dim = 1)
        x = self.up_conv(x)
        new_embedding = self.embedding_layer(t).view(t.shape[0],self.output_channels,1,1).repeat(1,1,x.shape[2], x.shape[3])
        return x + new_embedding

    

class Unet(nn.Module):
    def __init__(self, input_channels, output_channels, img_size):
        super(Unet, self).__init__()

        self.input_channels = input_channels
        self.output_channels = output_channels

        self.img_size = img_size

        self.down1 = Down(input_channels, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)

        self.latent = DoubleConv(512,512)

        self.up1 = Up(1024, 256)
        self.up2 = Up(512, 128)
        self.up3 = Up(256, 64)
        self.up4 = Up(128, output_channels)

    def pos_encoding(self,t, embedding_dim = Hyperparams.embedding_dim):
        
        inv_freq = 1 / (10_000 ** torch.arange(0,embedding_dim,2) / embedding_dim).float()

        pos_encoding_a = torch.sin(t * inv_freq)
        pos_encoding_b = torch.cos(t * inv_freq)

        pos_encoding = torch.cat([pos_encoding_a, pos_encoding_b], 0)
        return pos_encoding.unsqueeze(0)

    def forward(self,x, t):
        t = self.pos_encoding(t)
        
        x1 = self.down1(x, t)
        x2 = self.down2(x1, t)
        x3 = self.down3(x2, t)
        x4 = self.down4(x3, t)

        latent = self.latent(x4)
        
        x5 = self.up1(latent,x4,t)
        x6 = self.up2(x5,x3,t)
        x7 = self.up3(x6,x2,t)
        x8 = self.up4(x7,x1,t)

        return x8

## Some necessary functions for diffusion:

In [53]:
class Diffusion:
    def __init__(self,img_size, num_channels = 3, beta_start = 1e-4 , beta_end = 2e-2, T = 1000):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.T = T
        self.img_size = img_size
        self.num_channels = num_channels


        self.betas = self.make_betas()
        self.alphas = self.make_alphas()
        self.alpha_bars = self.make_alpha_bars()
        

    def make_betas(self):
        return torch.linspace(self.beta_start,self.beta_end,self.T)

    def make_alphas(self):
        return 1 - self.make_betas()

    def make_alpha_bars(self):
        return torch.cumprod(self.make_alphas(), 0)

    def add_noise_to_img(self, image, t):
        alpha_bar = self.alpha_bars[t]

        sqrt_alpha_bar = torch.sqrt(alpha_bar).view(-1,1,1,1).repeat(1,1,image.shape[2],image.shape[3])
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar).view(-1,1,1,1).repeat(1,1,image.shape[2],image.shape[3])
        noise = torch.randn(image.shape)

        noised_image = sqrt_alpha_bar * image + sqrt_one_minus_alpha_bar * noise
        return noised_image, noise

    def sample_timesteps(self, n):
        return torch.randint(0, self.T, (n,))

    def sample(self, model, fixed_noise, n): #sampling n new images for inference time
        with torch.no_grad():
            if fixed_noise == None:
                noise_to_clear = torch.randn((n, self.num_channels, self.img_size, self.img_size)).to(device)
            else:
                noise_to_clear = fixed_noise

            for t in range(self.T-1, -1, -1):
                t_tensor = torch.tensor([t])
                predicted_noise =  model(noise_to_clear, t_tensor)

                beta = self.betas[t]
                alpha = self.alphas[t]
                alpha_bar = self.alpha_bars[t]

                if t != 0:
                    helping_noise = torch.randn_like(noise_to_clear)
                else:
                    helping_noise = torch.zeros_like(noise_to_clear)

                noise_to_clear = ((1 / torch.sqrt(alpha))  * (noise_to_clear - ((beta / torch.sqrt(1 - alpha_bar)) * predicted_noise))) + torch.sqrt(beta) * helping_noise

            new_images = inverse_normalise_transform(noise_to_clear)
        return new_images

In [66]:
def train(trainloader, model, diffusion_tools, optimizer, criterion,device, epoch):
    model.train()

    for index, data, in tqdm(enumerate(trainloader), total = len(train_loader)):
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        batch_size = images.shape[0]
        timesteps = diffusion_tools.sample_timesteps(batch_size)

        noised_images, noise = diffusion_tools.add_noise_to_img(images, timesteps)
        predicted_noise = model(noised_images,timesteps)

        loss = criterion(predicted_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [67]:
def validate(model, fixed_input_noise, diffusion_tools, device, epoch):
    model.eval()
    with torch.no_grad():
        generated_images = diffusion_tools.sample(model, fixed_input_noise, fixed_input_noise.shape[0])
        generated_images = generated_images.cpu().numpy()

        batch_size = generated_images.shape[0]
        
        for i in range(batch_size):
            plt.subplot(1, batch_size, i + 1)
            image = generated_images[i]
            
            if image.shape[0] > 1:
                image = image.transpose(1, 2, 0)
                
            plt.imshow(image, cmap='gray' if image.shape[-1] != 3 else None)
            plt.axis('off')
            
        plt.show()

In [68]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

diffusion_tools = Diffusion(
    img_size = Hyperparams.image_size[0], 
    num_channels = Hyperparams.image_size[2],
    beta_start = 1e-4 , 
    beta_end = 2e-2, 
    T = Hyperparams.T)

model = Unet(
    input_channels = Hyperparams.image_size[2], 
    output_channels = Hyperparams.image_size[2], 
    img_size= Hyperparams.image_size[0]
).to(device)

fixed_input_noise = torch.randn((5,3,28,28)).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=Hyperparams.lr, betas=(0.5, 0.9))

In [69]:
for epoch in range(Hyperparams.num_epochs):
    print(f'epoch {epoch}/{Hyperparams.num_epochs}')
    train(train_loader, model, diffusion_tools, optimizer, criterion,device, epoch)
    validate(model, fixed_input_noise, diffusion_tools, device, epoch)

epoch 0/20


  0%|          | 0/433 [00:00<?, ?it/s]

torch.Size([64, 1, 28, 28])





RuntimeError: The size of tensor a (64) must match the size of tensor b (128) at non-singleton dimension 0