# Denoising Diffusion Probabilistic Model

In [1]:
import torch
import torch.nn as nn

In [None]:
# Compute device shared by the notebook: the first CUDA GPU when available,
# otherwise the CPU. It must be a torch.device (not a list) so that it can be
# passed directly to `.to(...)` and `map_location=`.
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class LinearNoiseSchedular:
    """Linear beta schedule with the DDPM forward-noising and reverse-step maths.

    The constructor precomputes, for every timestep, ``betas`` (linearly
    spaced between ``beta_start`` and ``beta_end``), ``alphas = 1 - betas``,
    their cumulative product and the two square roots used by the forward
    process. All schedule tensors are created on the CPU and are moved to the
    data's device on use.
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        """Precompute the schedule.

        Args:
            num_timesteps: number of diffusion steps (length of the schedule).
            beta_start: first value of the linear beta schedule.
            beta_end: last value of the linear beta schedule.
        """
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1 - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        """Return the noised sample ``sqrt(a_bar_t) * original + sqrt(1 - a_bar_t) * noise``.

        Args:
            original: batch of clean samples; the first dimension is the batch.
            noise: tensor with the same shape as ``original``.
            t: one timestep index per batch element.

        Returns:
            Tensor with the same shape as ``original``.
        """
        batch_size = original.shape[0]
        device = original.device

        # Move the schedule to the data's device before indexing so that a
        # CUDA index tensor does not hit a CPU schedule tensor.
        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(device)[t]
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(device)[t]

        # [batch_size] -> [batch_size, 1, 1, ...] so the per-sample scalars
        # broadcast over every non-batch dimension of `original`.
        broadcast_shape = (batch_size,) + (1,) * (original.dim() - 1)
        sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.reshape(broadcast_shape)
        sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.reshape(broadcast_shape)

        return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """Take one reverse-diffusion step from ``x_t`` towards ``x_{t-1}``.

        Args:
            xt: current noisy batch.
            noise_pred: the model's noise prediction for ``xt``.
            t: a single timestep (int or 0-dim tensor) shared by the batch.

        Returns:
            ``(x_prev, x0)``: the sample for the previous timestep (just the
            mean when ``t == 0``, otherwise mean plus scaled Gaussian noise)
            and the estimate of the clean sample clamped to ``[-1, 1]``.
        """
        device = xt.device
        beta_t = self.betas.to(device)[t]
        alpha_t = self.alphas.to(device)[t]
        alpha_bar_t = self.alpha_cum_prod.to(device)[t]
        sqrt_alpha_bar_t = self.sqrt_alpha_cum_prod.to(device)[t]
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cum_prod.to(device)[t]

        # Invert the forward process to estimate the clean sample.
        x0 = (xt - sqrt_one_minus_alpha_bar_t * noise_pred) / sqrt_alpha_bar_t
        x0 = torch.clamp(x0, -1, 1)

        mean = xt - (beta_t * noise_pred) / sqrt_one_minus_alpha_bar_t
        mean = mean / torch.sqrt(alpha_t)

        if t == 0:
            # Last step: no noise is added.
            return mean, x0

        alpha_bar_prev = self.alpha_cum_prod.to(device)[t - 1]
        # variance = beta_t * (1 - a_bar_{t-1}) / (1 - a_bar_t); the
        # parentheses around the denominator are required.
        variance = (1 - alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        sigma = variance ** 0.5
        z = torch.randn_like(xt)

        return mean + sigma * z, x0
        





In [1]:
# we also give the timestep we are at along with the image 

def get_time_embedding(time_steps, t_emb_dim, max_period=10000):
    """Return sinusoidal embeddings for a batch of timesteps.

    Column ``i`` of the first half is ``sin(t / max_period ** (i / half))``
    and column ``i`` of the second half is the matching cosine, where
    ``half = t_emb_dim // 2``.

    Args:
        time_steps: 1-D tensor with one timestep per batch element.
        t_emb_dim: embedding width; must be even.
        max_period: base of the geometric frequency progression. Defaults to
            10000, the value conventionally used for sinusoidal embeddings.

    Returns:
        Tensor of shape ``(len(time_steps), t_emb_dim)``.

    Raises:
        ValueError: if ``t_emb_dim`` is odd (the sin/cos halves would not add
            up to ``t_emb_dim`` columns).
    """
    if t_emb_dim % 2 != 0:
        raise ValueError("t_emb_dim must be even, got {}".format(t_emb_dim))

    half_dim = t_emb_dim // 2
    exponents = torch.arange(start=0, end=half_dim, device=time_steps.device) / half_dim
    factor = max_period ** exponents

    # (batch, 1) / (half_dim,) broadcasts to (batch, half_dim).
    t_emb = time_steps[:, None] / factor

    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)


class DownBlock(nn.Module):
    """Encoder block: ResNet block with time conditioning, self-attention, optional downsample.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        t_emb_dim: width of the time embedding passed to ``forward``.
        down_sample: when truthy, finish with a stride-2 convolution (halves
            even spatial sizes); otherwise the spatial size is kept.
        num_heads: number of attention heads.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, down_sample, num_heads):
        super().__init__()
        self.down_sample = down_sample

        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )
        # Projects the time embedding to one bias per output channel.
        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels),
        )
        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        # 1x1 conv so the skip path matches out_channels without changing the
        # spatial size (a larger unpadded kernel would break the residual add).
        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        self.down_sample_conv = (
            nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
            if self.down_sample
            else nn.Identity()
        )

    def forward(self, x, t_emb):
        """Apply the block to feature map ``x`` (B, C, H, W) with time embedding ``t_emb`` (B, t_emb_dim)."""
        # ResNet block
        resnet_input = x
        out = self.resnet_conv_first(x)
        out = out + self.t_emb_layers(t_emb)[:, :, None, None]
        out = self.resnet_conv_second(out)
        out = out + self.residual_input_conv(resnet_input)

        # Attention block: flatten the spatial grid into a token sequence.
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)  # channels last, as batch_first attention expects
        out_attn, _ = self.attention(in_attn, in_attn, in_attn)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

        return self.down_sample_conv(out)


class MidBlock(nn.Module):
    """Bottleneck block: ResNet block -> self-attention -> ResNet block.

    The spatial size is preserved; only the channel count changes from
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        t_emb_dim: width of the time embedding passed to ``forward``.
        num_heads: number of attention heads.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads):
        super().__init__()
        # Index 0 belongs to the first ResNet block (in -> out channels),
        # index 1 to the second (out -> out channels).
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, in_channels),
                nn.SiLU(),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ),
        ])

        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels),
            ),
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels),
            ),
        ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            ),
        ])

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

        # 1x1 convs so each skip path matches the channel count of its block.
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
        ])

    def _resnet(self, idx, x, t_emb):
        """Run ResNet block ``idx`` (conv -> add time bias -> conv -> add skip)."""
        out = self.resnet_conv_first[idx](x)
        out = out + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[idx](out)
        return out + self.residual_input_conv[idx](x)

    def forward(self, x, t_emb):
        """Apply the block to feature map ``x`` (B, C, H, W) with time embedding ``t_emb`` (B, t_emb_dim)."""
        # First ResNet block
        out = self._resnet(0, x, t_emb)

        # Attention block: flatten the spatial grid into a token sequence.
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)  # channels last, as batch_first attention expects
        out_attn, _ = self.attention(in_attn, in_attn, in_attn)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

        # Second ResNet block
        return self._resnet(1, out, t_emb)
    


class UpBlock(nn.Module):
    """Decoder block: optional upsample, skip concatenation, ResNet block, self-attention.

    Args:
        in_channels: channels AFTER concatenating the (optionally upsampled)
            input with the skip connection; the input itself is expected to
            carry ``in_channels // 2`` channels.
        out_channels: channels of the output feature map.
        t_emb_dim: width of the time embedding passed to ``forward``.
        up_sample: when truthy, start with a stride-2 transposed convolution
            that doubles the spatial size; otherwise the size is kept.
        num_heads: number of attention heads.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads):
        super().__init__()
        self.up_sample = up_sample

        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )
        # Projects the time embedding to one bias per output channel.
        self.t_emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels),
        )
        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Operates on the pre-concatenation tensor, hence in_channels // 2.
        self.up_sample_conv = (
            nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=4, stride=2, padding=1)
            if self.up_sample
            else nn.Identity()
        )

    def forward(self, x, out_down, t_emb):
        """Upsample ``x``, concatenate the skip tensor ``out_down`` on the channel axis, then refine.

        ``t_emb`` is the (B, t_emb_dim) time embedding.
        """
        x = self.up_sample_conv(x)
        x = torch.cat([x, out_down], dim=1)

        # ResNet block
        resnet_input = x
        out = self.resnet_conv_first(x)
        out = out + self.t_emb_layers(t_emb)[:, :, None, None]
        out = self.resnet_conv_second(out)
        out = out + self.residual_input_conv(resnet_input)

        # Attention block: flatten the spatial grid into a token sequence.
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)  # channels last, as batch_first attention expects
        out_attn, _ = self.attention(in_attn, in_attn, in_attn)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

        return out
        
        

class UNet(nn.Module):
    """Time-conditioned U-Net that predicts the noise added to an image.

    The encoder is a stack of ``DownBlock``s, the bottleneck a stack of
    ``MidBlock``s and the decoder a stack of ``UpBlock``s that consume the
    encoder inputs as skip connections in reverse order.

    Args:
        im_channels: number of channels of the input (and output) image.
    """

    def __init__(self, im_channels):
        super().__init__()
        self.down_channels = [32, 64, 128, 256]
        self.mid_channels = [256, 256, 128]
        self.t_emb_dim = 128
        # One flag per DownBlock: whether that block halves the spatial size.
        self.down_sample = [True, True, False]

        # The sinusoidal time embedding is projected once per forward pass by
        # this small MLP (linear -> SiLU -> linear). This is separate from the
        # per-block `t_emb_layers`, which each map the projected embedding to
        # that block's channel count.
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=1)

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                        self.t_emb_dim, down_sample=self.down_sample[i], num_heads=4))

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1],
                                      self.t_emb_dim, num_heads=4))

        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            # Input channels are doubled by the skip concatenation inside
            # UpBlock; the last decoder block narrows to 16 channels, which
            # is what norm_out / conv_out below expect.
            self.ups.append(UpBlock(self.down_channels[i] * 2,
                                    self.down_channels[i - 1] if i != 0 else 16,
                                    self.t_emb_dim, up_sample=self.down_sample[i], num_heads=4))

        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timestep(s) ``t``.

        Args:
            x: image batch of shape (B, im_channels, H, W).
            t: 1-D tensor of timesteps, fed to ``get_time_embedding``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        out = self.conv_in(x)
        t_emb = get_time_embedding(t, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # Keep the input of every encoder block as a skip connection.
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb)

        for mid in self.mids:
            out = mid(out, t_emb)

        # Decoder consumes the skips last-in, first-out.
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)

        out = self.norm_out(out)
        out = nn.functional.silu(out)
        out = self.conv_out(out)

        return out

NameError: name 'nn' is not defined

## MNIST_Dataset

In [None]:
import os
from torch.utils.data import Dataset
import tqdm
import glob
import torchvision


class MnistDataset(Dataset):
    """Image-folder dataset for MNIST-style data.

    A custom dataset class is used rather than the torchvision one so that it
    can be swapped for another image dataset. Images are expected under
    ``im_path/<label>/*.<im_ext>``, where each sub-directory name is the
    integer label of the images it contains.

    Args:
        split: name of the split; only used in the log message.
        im_path: root directory of the images.
        im_ext: file extension (without the dot) of the images to load.
    """

    def __init__(self, split, im_path, im_ext='png'):
        self.split = split
        self.im_ext = im_ext
        self.images, self.labels = self.load_images(im_path)

    def load_images(self, im_path):
        """Collect the path and integer label of every image under ``im_path``.

        Returns:
            ``(paths, labels)``: two parallel lists.

        Raises:
            AssertionError: if ``im_path`` does not exist.
            ValueError: if a sub-directory containing images has a
                non-integer name.
        """
        assert os.path.exists(im_path), "images path {} does not exists".format(im_path)
        ims = []
        labels = []
        # `tqdm` is imported as a module in this file, so the progress bar
        # class is `tqdm.tqdm`.
        for d_name in tqdm.tqdm(os.listdir(im_path)):
            pattern = os.path.join(im_path, d_name, '*.{}'.format(self.im_ext))
            for fname in glob.glob(pattern):
                ims.append(fname)
                labels.append(int(d_name))

        print('Found {} images for split {}'.format(len(ims), self.split))

        return ims, labels

    def __len__(self):
        """Return the number of images found."""
        return len(self.images)

    def __getitem__(self, index):
        """Load image ``index`` and return it as a float tensor scaled to [-1, 1]."""
        # Local import: `Image` is not imported at file level. Pillow is the
        # image backend torchvision's ToTensor/ToPILImage already rely on.
        from PIL import Image

        # Context manager closes the file handle once the pixels are read.
        with Image.open(self.images[index]) as im:
            im_tensor = torchvision.transforms.ToTensor()(im)

        # ToTensor yields values in [0, 1]; rescale to [-1, 1].
        im_tensor = (2 * im_tensor) - 1
        return im_tensor





## Train


In [None]:
from torch.utils.data import DataLoader
import yaml
import numpy as np
from  torch.optim import Adam


def train(args):
    """Train the noise-prediction U-Net and checkpoint it after every epoch.

    Args:
        args: object with a ``config_path`` attribute pointing to a YAML file
            that provides the ``diffusion_params``, ``dataset_params``,
            ``model_params`` and ``train_params`` sections read below.

    Raises:
        yaml.YAMLError: if the config file cannot be parsed.
    """
    # Read the config file
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            # Nothing below can run without a config; do not continue.
            raise

    print(config)

    diffusion_config = config['diffusion_params']
    dataset_config = config['dataset_params']
    model_config = config['model_params']
    train_config = config['train_params']

    # Create the noise schedular
    schedular = LinearNoiseSchedular(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])

    # Create the dataset
    mnist = MnistDataset("train", im_path=dataset_config['im_path'])
    mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Instantiate the model; UNet takes the image channel count, not the
    # whole config section.
    model = UNet(model_config['im_channels']).to(device)
    model.train()

    # Create the output directory
    os.makedirs(train_config['task_name'], exist_ok=True)
    ckpt_path = os.path.join(train_config['task_name'], train_config['ckpt_name'])

    # Resume from a checkpoint if one exists
    if os.path.exists(ckpt_path):
        print('Loading checkpoint as found one')
        model.load_state_dict(torch.load(ckpt_path, map_location=device))

    # Training parameters
    num_epochs = train_config['num_epochs']
    optimizer = Adam(model.parameters(), lr=train_config['lr'])
    criterion = torch.nn.MSELoss()

    # Run training
    for epoch_idx in range(num_epochs):
        losses = []
        # `tqdm` is imported as a module in this file, hence `tqdm.tqdm`.
        for im in tqdm.tqdm(mnist_loader):
            optimizer.zero_grad()
            im = im.float().to(device)

            # Sample random noise
            noise = torch.randn_like(im)

            # Sample one timestep per image
            t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)

            # Add noise to the images according to their timestep, then let
            # the model predict that noise.
            noisy_im = schedular.add_noise(im, noise, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

        print('Finished epoch:{} | Loss: {:.4f}'.format(
            epoch_idx + 1,
            np.mean(losses)
        ))

        torch.save(model.state_dict(), ckpt_path)
        


    




In [None]:
import argparse

def infer(args):
    """Load the trained U-Net from its checkpoint and generate samples.

    Args:
        args: object with a ``config_path`` attribute pointing to a YAML file
            that provides the ``diffusion_params``, ``model_params`` and
            ``train_params`` sections read below.

    Raises:
        yaml.YAMLError: if the config file cannot be parsed.
    """
    with open(args.config_path, 'r') as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as exc:
            print(exc)
            # Nothing below can run without a config; do not continue.
            raise

    print(config)

    diffusion_config = config['diffusion_params']
    model_config = config['model_params']
    train_config = config["train_params"]

    # Resolved locally so this function does not depend on a notebook global.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the model from its checkpoint; UNet takes the image channel count,
    # not the whole config section.
    model = UNet(model_config['im_channels']).to(device)
    ckpt_path = os.path.join(train_config['task_name'], train_config['ckpt_name'])
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()

    # Create the noise scheduler
    scheduler = LinearNoiseSchedular(num_timesteps=diffusion_config['num_timesteps'],
                                     beta_start=diffusion_config['beta_start'],
                                     beta_end=diffusion_config['beta_end'])

    with torch.no_grad():
        sample(model, scheduler, train_config, model_config, diffusion_config)


def sample(model, scheduler, train_config, model_config, diffusion_config):
    """Generate images by reverse diffusion, saving a grid image at every timestep.

    Starts from Gaussian noise and walks the timesteps backwards, one at a
    time. After each step the current sample ``x_{t-1}`` is written to
    ``<task_name>/samples/x0_<t>.png``.

    Args:
        model: noise-prediction network called as ``model(xt, t)``.
        scheduler: object providing ``sample_prev_timestep(xt, noise_pred, t)``.
        train_config: mapping with ``num_samples``, ``num_grid_rows`` and
            ``task_name``.
        model_config: mapping with ``im_channels`` and ``im_size``.
        diffusion_config: mapping with ``num_timesteps``.
    """
    # Run on whatever device the model lives on rather than a notebook global.
    device = next(model.parameters()).device

    xt = torch.randn((train_config['num_samples'],
                      model_config['im_channels'],
                      model_config['im_size'],
                      model_config['im_size'])).to(device)

    # Create the output directory once, before the loop.
    samples_dir = os.path.join(train_config['task_name'], 'samples')
    os.makedirs(samples_dir, exist_ok=True)

    num_timesteps = diffusion_config['num_timesteps']
    # `tqdm` is imported as a module in this file, hence `tqdm.tqdm`; `total`
    # is given because a reversed range has no length for the progress bar.
    for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        t = torch.as_tensor(i).to(device)

        # Get the noise prediction; the model expects a 1-D tensor of timesteps.
        noise_pred = model(xt, t.unsqueeze(0))

        # Use the scheduler to get x_{t-1} (and the x0 estimate, unused here).
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, t)

        # Save the current sample, mapped from [-1, 1] back to [0, 1].
        ims = torch.clamp(xt, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = torchvision.utils.make_grid(ims, nrow=train_config['num_grid_rows'])

        img = torchvision.transforms.ToPILImage()(grid)
        img.save(os.path.join(samples_dir, 'x0_{}.png'.format(i)))
        img.close()




if __name__=='__main__':
    # `argparse` is the standard-library module; it is used directly rather
    # than being called and rebound.
    parser=argparse.ArgumentParser(description='Argument for ddpm image generation')