<a href="https://colab.research.google.com/github/adityav1810/Implementing-Diffusion-Models/blob/main/Basic_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch import optim
from tqdm import tqdm
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader

In [None]:
import logging

# Configure the root logger once for the whole notebook: INFO level and
# messages prefixed with a wall-clock time (hours:minutes:seconds).
logging.basicConfig(
  format="%(asctime)s - %(levelname)s: %(message)s",
  level=logging.INFO,
  datefmt="%I:%M:%S",
)

In [None]:
class Diffusion:
  '''
  Utilities for a basic DDPM-style diffusion model.

  Provides the linear beta noise schedule, the closed-form forward (noising)
  process and the reverse (sampling) loop. The default hyper-parameters are the
  ones from the original paper (https://arxiv.org/abs/2006.11239).
  '''
  def __init__(self,noise_steps = 1000,beta_start = 1e-4,beta_end = 0.02,img_size = 64,device="cuda"):
    '''
    Build the noise schedule.

    Args:
      noise_steps: number of diffusion timesteps T.
      beta_start: first beta of the linear schedule.
      beta_end: last beta of the linear schedule.
      img_size: height and width of the square images produced by `sample`.
      device: torch device holding the schedule tensors and sampled images.

    beta[t] is the noise variance added at step t, alpha = 1 - beta, and
    alpha_hat[t] = alpha[0] * ... * alpha[t], which lets `noise_images`
    jump straight to any timestep.
    '''
    self.noise_steps = noise_steps
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.img_size = img_size
    self.device = device

    self.beta = self.prepare_noise_schedule().to(device)
    self.alpha = 1. - self.beta
    self.alpha_hat = torch.cumprod(self.alpha, dim=0)

  def prepare_noise_schedule(self):
    '''Return `noise_steps` betas spaced linearly from beta_start to beta_end.'''
    return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

  def noise_images(self, x, t):
    '''
    Noise a batch of images directly to timestep t [FORWARD DIFFUSION PROCESS].

    Instead of adding noise one step at a time, the closed form uses alpha_hat
    to reach timestep t in one shot.

    Args:
      x: batch of images; must be 4-D (N, C, H, W) because of the
        [:, None, None, None] broadcasting below.
      t: 1-D tensor of N integer timesteps, one per image.
    Returns:
      (x_t, e), where e is standard Gaussian noise shaped like x and
      x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * e.
    '''
    sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
    sqrt_one_minus_alpha_hat = torch.sqrt(1. - self.alpha_hat[t])[:, None, None, None]
    # DDPM's forward process uses N(0, I) noise (randn), not uniform noise (rand).
    e = torch.randn_like(x)
    return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * e, e

  def sample(self, model, n):
    '''
    Generate n images by running the reverse diffusion process from pure noise.

    Args:
      model: noise-prediction network, called as model(x, t). It must provide
        eval()/train() (e.g. an nn.Module) and return a tensor shaped like x.
      n: number of images to generate.
    Returns:
      uint8 tensor of shape (n, 3, img_size, img_size) with values in [0, 255].
      The model is switched back to training mode before returning.
    '''
    logging.info(f"Sampling {n} new images...")
    model.eval()
    with torch.no_grad():
      x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
      steps = range(1, self.noise_steps)
      for i in tqdm(reversed(steps), position=0, total=len(steps)):
        t = (torch.ones(n) * i).long().to(self.device)
        predicted_noise = model(x, t)
        alpha = self.alpha[t][:, None, None, None]
        alpha_hat = self.alpha_hat[t][:, None, None, None]
        beta = self.beta[t][:, None, None, None]
        # Fresh noise is injected at every step except the last (i == 1),
        # whose output is returned as the final image.
        if i > 1:
          noise = torch.randn_like(x)
        else:
          noise = torch.zeros_like(x)
        # x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_hat_t) * eps) + sqrt(beta_t) * z
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise) + torch.sqrt(beta) * noise
    model.train()
    # Map from [-1, 1] back to [0, 255] pixel values.
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8)
    return x











In [None]:
class UNet(nn.Module):
  '''
  U-Net noise-prediction network used by the diffusion model.

  Encoder: a DoubleConv stem, then three Down blocks, each followed by
  SelfAttention. Bottleneck: three DoubleConvs. Decoder: three Up blocks with
  skip connections, each followed by SelfAttention. A final 1x1 convolution
  maps back to c_out channels.

  The timestep is injected into every Down/Up block as a sinusoidal embedding
  of size time_dim. The hard-coded SelfAttention sizes (32, 16, 8, 16, 32, 64)
  appear to assume 64x64 inputs — confirm before using other image sizes.
  '''

  def __init__(self,c_in = 3,c_out = 3,time_dim = 256,device = "cuda"):
    super().__init__()
    self.device = device
    self.time_dim = time_dim

    # Encoder
    self.inc = DoubleConv(c_in, 64)
    self.down1 = Down(64, 128)
    self.sa1 = SelfAttention(128, 32)
    self.down2 = Down(128, 256)
    self.sa2 = SelfAttention(256, 16)
    self.down3 = Down(256, 256)
    self.sa3 = SelfAttention(256, 8)

    # Bottleneck
    self.bot1 = DoubleConv(256, 512)
    self.bot2 = DoubleConv(512, 512)
    self.bot3 = DoubleConv(512, 256)

    # Decoder. Each Up receives the upsampled features concatenated with the
    # matching encoder skip, hence in_channels = 2 * skip channels.
    self.up1 = Up(512, 128)
    self.sa4 = SelfAttention(128, 16)
    self.up2 = Up(256, 64)
    self.sa5 = SelfAttention(64, 32)
    self.up3 = Up(128, 64)
    self.sa6 = SelfAttention(64, 64)
    self.outc = nn.Conv2d(64, c_out, kernel_size=1)

  def pos_encoding(self, t, channels):
    '''
    Sinusoidal timestep embedding.

    Args:
      t: float tensor of timesteps, shape (N, 1) (see `forward`).
      channels: embedding size; assumed to be even.
    Returns:
      (N, channels) tensor: the sin(t * f_k) terms followed by the
      cos(t * f_k) terms, with frequencies f_k = 1 / 10000 ** (2k / channels).
    '''
    inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels))
    pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
    pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
    pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=1)
    return pos_enc

  def forward(self, x, t):
    '''
    Predict the noise contained in x at timestep t.

    Args:
      x: noisy images, (N, c_in, H, W).
      t: 1-D tensor of N timesteps.
    Returns:
      tensor with c_out channels and the same spatial size as x.
    '''
    t = t.unsqueeze(-1).type(torch.float)
    t = self.pos_encoding(t, self.time_dim)

    x1 = self.inc(x)
    x2 = self.down1(x1, t)
    x2 = self.sa1(x2)

    x3 = self.down2(x2, t)
    x3 = self.sa2(x3)

    x4 = self.down3(x3, t)
    x4 = self.sa3(x4)

    x4 = self.bot1(x4)
    x4 = self.bot2(x4)
    x4 = self.bot3(x4)

    x = self.up1(x4, x3, t)
    x = self.sa4(x)
    x = self.up2(x, x2, t)
    x = self.sa5(x)
    x = self.up3(x, x1, t)
    x = self.sa6(x)
    output = self.outc(x)
    return output










In [None]:
class DoubleConv(nn.Module):
  '''
  Two 3x3 convolutions with GELU in between.

  Each convolution uses padding 1, so H and W are preserved, and is followed
  by GroupNorm(1, ·).

  Args:
    in_channels: channels of the input.
    out_channels: channels of the output.
    mid_channels: channels between the two convolutions; defaults to out_channels.
    residual: if True, forward returns gelu(x + double_conv(x)). This requires
      in_channels == out_channels so the addition is shape-compatible.
  '''
  def __init__(self,in_channels,out_channels,mid_channels= None, residual = False):
    super().__init__()
    self.residual = residual
    if not mid_channels:
      mid_channels = out_channels
    self.double_conv = nn.Sequential(
        nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
        nn.GroupNorm(1, mid_channels),
        nn.GELU(),
        nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.GroupNorm(1, out_channels),
    )

  def forward(self, x):
    '''Apply the two convolutions, adding the input back first when residual=True.'''
    out = self.double_conv(x)
    if self.residual:
      return nn.functional.gelu(x + out)
    return out

In [None]:
class Down(nn.Module):
  '''
  Encoder stage.

  Applies a 2x2 max-pool (halves H and W), a residual DoubleConv, and a
  DoubleConv to out_channels. It then adds a per-channel projection of the
  time embedding.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels of the output feature map.
    emb_dim: size of the time embedding `t` passed to forward.
  '''
  def __init__(self,in_channels,out_channels,emb_dim=256):
    super().__init__()
    self.maxpool_conv = nn.Sequential(
        nn.MaxPool2d(2),
        DoubleConv(in_channels, in_channels, residual=True),
        DoubleConv(in_channels, out_channels, residual=False),
    )
    self.emb_layer = nn.Sequential(
        nn.SiLU(),
        nn.Linear(emb_dim, out_channels),
    )

  def forward(self, x, t):
    '''
    Args:
      x: (N, in_channels, H, W) feature map.
      t: (N, emb_dim) time embedding.
    Returns:
      (N, out_channels, H // 2, W // 2) feature map.
    '''
    x = self.maxpool_conv(x)
    # One value per channel, broadcast over every spatial position.
    emb = self.emb_layer(t)[:, :, None, None]
    return x + emb



In [None]:
class Up(nn.Module):
  '''
  Decoder stage.

  Upsamples x 2x (bilinear) and concatenates the encoder skip tensor along
  channels. It then applies a residual DoubleConv and a DoubleConv to
  out_channels (mid width in_channels // 2), and adds a per-channel
  projection of the time embedding.

  Args:
    in_channels: channels after concatenation, i.e. channels(skip_x) + channels(x).
    out_channels: channels of the output feature map.
    emb_dim: size of the time embedding `t` passed to forward.
  '''
  def __init__(self,in_channels,out_channels,emb_dim=256):
    super().__init__()
    self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    self.conv = nn.Sequential(
        DoubleConv(in_channels, in_channels, residual=True),
        DoubleConv(in_channels, out_channels, in_channels // 2, residual=False),
    )
    self.emb_layer = nn.Sequential(
        nn.SiLU(),
        nn.Linear(emb_dim, out_channels),
    )

  def forward(self, x, skip_x, t):
    '''
    Args:
      x: (N, C_x, H, W) features from the previous decoder/bottleneck stage.
      skip_x: (N, C_skip, 2H, 2W) encoder features at the target resolution.
      t: (N, emb_dim) time embedding.
    Returns:
      (N, out_channels, 2H, 2W) feature map.
    '''
    x = self.up(x)
    x = torch.cat([skip_x, x], dim=1)
    x = self.conv(x)
    # One value per channel, broadcast over every spatial position.
    emb = self.emb_layer(t)[:, :, None, None]
    return x + emb


In [None]:
class SelfAttention(nn.Module):
  '''
  Self-attention over the spatial positions of a feature map.

  Applies pre-LayerNorm multi-head attention (4 heads) with a residual
  connection. This is followed by a LayerNorm -> Linear -> GELU -> Linear
  feed-forward, also with a residual connection.

  Args:
    channels: feature channels C; must be divisible by the 4 attention heads.
    size: nominal spatial side length of the input. It is kept for reference
      only; forward reads the actual H and W from the input tensor.
  '''
  def __init__(self,channels,size):
    super().__init__()
    self.channels = channels
    self.size = size
    self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
    self.ln = nn.LayerNorm([channels])
    self.ff_self = nn.Sequential(
        nn.LayerNorm([channels]),
        nn.Linear(channels, channels),
        nn.GELU(),
        nn.Linear(channels, channels),
    )

  def forward(self, x):
    '''
    Forward pass for the attention layer.

    Args:
      x: (N, C, H, W) feature map.
    Returns:
      tensor of the same shape.

    Each of the H*W positions is treated as a token with C features. The map
    is therefore reshaped to (N, H*W, C) for attention, e.g. [1, 128, 32, 32]
    -> [1, 1024, 128], and swapped back afterwards.
    '''
    n, c, h, w = x.shape
    tokens = x.reshape(n, c, h * w).swapaxes(1, 2)
    x_ln = self.ln(tokens)
    attention_value, _ = self.mha(x_ln, x_ln, x_ln)
    attention_value = attention_value + tokens
    attention_value = self.ff_self(attention_value) + attention_value
    return attention_value.swapaxes(2, 1).reshape(n, c, h, w)



Let's create some utility functions.

In [None]:
def plot_images(images):
  '''
  Show a batch of images side by side in a single row.

  Args:
    images: (N, C, H, W) tensor, e.g. the uint8 output of Diffusion.sample.
  '''
  # define plot size
  plt.figure(figsize=(32, 32))

  # Concatenate the N images along the width axis -> (C, H, N*W), then move
  # channels last as imshow expects (H, W, C).
  row = torch.cat(list(images.cpu()), dim=-1)
  plt.imshow(row.permute(1, 2, 0))
  plt.show()

def save_images(images, image_path,**kwargs):
  '''
  Tile a batch of images into a grid and write it to image_path.

  Args:
    images: (N, C, H, W) tensor. Presumably uint8, like Diffusion.sample's
      output, since the pixels are handed to PIL unchanged — TODO confirm
      other dtypes.
    image_path: destination file path.
    **kwargs: forwarded to torchvision.utils.make_grid (e.g. nrow, padding).
  '''
  grid = torchvision.utils.make_grid(images, **kwargs)
  ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
  im = Image.fromarray(ndarr)
  im.save(image_path)

def get_data(args):
  '''
  Build a shuffled DataLoader over an ImageFolder dataset.

  Transforms applied to each image:
    1. Resize(80): resize so the shorter side is 80 px.
    2. RandomResizedCrop to args.img_size, sampling 80-100% of the image area.
    3. Convert to a tensor and normalise each channel with mean 0.5 and
       std 0.5, mapping [0, 1] to [-1, 1].

  Args:
    args: namespace providing dataset_path, img_size and batch_size.
  Returns:
    DataLoader over the transformed dataset.
  '''
  transforms = torchvision.transforms.Compose([
      torchvision.transforms.Resize(80),
      torchvision.transforms.RandomResizedCrop(args.img_size, scale=(0.8, 1.0)),
      torchvision.transforms.ToTensor(),
      torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])

  dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)
  dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
  return dataloader

def logger(run_name):
  '''
  Create the output directory layout for a training run.

  Creates models/, results/, models/<run_name>/ and results/<run_name>/, in
  that order. Directories that already exist are left untouched.
  '''
  roots = ("models", "results")
  for root in roots:
    os.makedirs(root, exist_ok=True)
  for root in roots:
    os.makedirs(os.path.join(root, run_name), exist_ok=True)

In [None]:
def train(args):
  logger(args.run_name)
  device = args.device
  dataloader=get_data(args)
  model = UNet().to(device)
  optimizer = optim.AdamW(model.parameters(),lr = args.lr)
  mse = nn.MSELoss()
  diffusion = Diffusion(img_size = args.image_size, device = device)
  logger = SummaryWriter()
  '''
  14:13
  https://www.youtube.com/watch?v=TBCRlnwJtZU
  '''