In [1]:
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader 
import matplotlib.pyplot as plt
from tqdm import tqdm 
import torch.nn as nn 
import os 
import time 
import math
import torchvision
from PIL import Image
import imageio
import shutil

if os.path.exists('results'):
    shutil.rmtree('results')
else:
    os.mkdir('results')


BATCH_SIZE = 32
NUM_STEPS = 100
EPOCHS = 100
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
MNIST_DATA = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
MNIST_LOADER = DataLoader(MNIST_DATA, batch_size=BATCH_SIZE, shuffle=True)

# Making a Unet Network for 28x28 image

In [3]:
class Conv(nn.Module):
    def __init__(self, inc,ouc):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=inc,out_channels=ouc,kernel_size=3,padding=1),
            nn.GroupNorm(1,ouc),
            nn.GELU(),
            nn.Conv2d(in_channels=ouc,out_channels=ouc,kernel_size=3,padding=1),
            nn.GELU(),
        )
    def forward(self,x):
        return self.conv1(x)
    
class Down(nn.Module):
    def __init__(self, inc, ouc,emb_dim=512):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.conv = Conv(inc,ouc)
        self.embedding_layer = nn.Sequential(
            nn.SELU(),
            nn.Linear(emb_dim,ouc)
            )
    def forward(self,x,time_step):
        x = self.pool(x)
        x = self.conv(x)
        embedding = self.embedding_layer(time_step).squeeze(1)[:,:,None,None].repeat(1,1,x.shape[-2],x.shape[-1])
        return x+embedding
    
class Up(nn.Module):
    def __init__(self, inc,ouc,emb_dim=512):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2,mode='bilinear')
        self.conv = nn.Sequential(
            Conv(inc,ouc)
        )
        self.embedding_layer = nn.Sequential(
            nn.SELU(),
            nn.Linear(emb_dim,ouc)
            )
    def forward(self,x,skip,t):
        x = self.up(x)
        x = torch.cat([skip,x],dim=1)
        x = self.conv(x)
        emb = self.embedding_layer(t).squeeze(1)[:,:,None,None].repeat(1,1,x.shape[-2],x.shape[-1])
        return x+emb
        

class Unet(nn.Module):
    def __init__(self, inc=1,ouc=1,emb_dim=512,device='cpu',steps=NUM_STEPS):
        super().__init__()
        self.time_emb_layer = nn.Embedding(steps,embedding_dim=emb_dim)

        self.c1 = Conv(inc,64)
        self.down1 = Down(64,128)
        self.c2 = Conv(128,256)
        self.down2 = Down(256,256)

        self.b1 = Conv(256,512)
        self.b2 = Conv(512,512)
        self.b3 = Conv(512,256)

        self.up1 = Up(512,128)
        self.c3 = Conv(128,64)
        self.up2 = Up(128,64)

        self.ouc = nn.Conv2d(64,ouc,kernel_size=1)

    def forward(self,x,t):
        t = t.unsqueeze(-1)
        t = self.time_emb_layer(t)

        x1 = self.c1(x)
        x2 = self.down1(x1,t)
        x2 = self.c2(x2)
        x3 = self.down2(x2,t)

        x4 = self.b1(x3)
        x4 = self.b2(x4)
        x4 = self.b3(x4)    

        x = self.up1(x4,x2,t)
        x = self.c3(x)
        x = self.up2(x,x1,t)
        x = self.ouc(x)
        return x

   



In [4]:
class DiffusionModel:
    def __init__(self,num_steps=NUM_STEPS,device=DEVICE,beta_start=1e-4,beta_end=0.02,image_size=28):
        self.num_steps = num_steps
        self.device = device
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.image_size = image_size
        self.beta  = self.noise_scheduler().to(device)
        self.alpha = 1. - self.beta 
        self.alpha_hat = torch.cumprod(self.alpha , dim=0)
        self.lr = 3e-4

    def noise_scheduler(self, s: float = 0.008):
        steps = self.num_steps + 1
        x = torch.linspace(0, self.num_steps, steps, device=self.device)
        alphas_cumprod = torch.cos(
            ((x / self.num_steps + s) / (1 + s)) * math.pi * 0.5
        ) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        betas = torch.clip(betas, 1e-8, 0.999)  
        return betas

    def noise_images(self,x,t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:,None,None,None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1-self.alpha_hat[t])[:,None,None,None]
        E = torch.randn_like(x)
        return sqrt_alpha_hat *x + sqrt_one_minus_alpha_hat* E , E

    def sample_timestemps(self,n):
        return torch.randint(low=1,high=self.num_steps,size=(n,))
    
    def save_images(self,images, path, **kwargs):
        grid = torchvision.utils.make_grid(images, **kwargs)
        ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
        im = Image.fromarray(ndarr)
        im.save(path)

    def train(self,unet_model,dataloader,epochs):
        unet_model = unet_model.to(self.device)
        optimizer = torch.optim.AdamW(unet_model.parameters(),lr=self.lr)
        criterion = nn.MSELoss()
        l = len(dataloader)

        for epoch in range(epochs):
            p_bar = tqdm(dataloader)
            for i,(images,_) in enumerate(p_bar):
                images = images.to(self.device)
                t = self.sample_timestemps(images.shape[0]).to(self.device)
                x_t,noise = self.noise_images(images,t)
                predicted_noise = unet_model(x_t,t)
                loss = criterion(noise,predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                p_bar.set_postfix(MSE=loss.item())
            sampled_images = self.generate(unet_model,n=images.shape[0])
            self.save_images(sampled_images,os.path.join("results",f"{epoch}.jpg"))
            torch.save(unet_model.state_dict(),f"ckpt.pt")


    def generate(self,unet_model,n):
        unet_model.eval()
        with torch.no_grad():
            x = torch.randn((n,1,self.image_size,self.image_size)).to(self.device)
            for i in tqdm(reversed(range(1,self.num_steps))):
                t = (torch.ones(n)*i).long().to(self.device)
                predicted_noise = unet_model(x,t)
                alpha = self.alpha[t][:,None,None,None]
                alpha_hat = self.alpha_hat[t][:,None,None,None]
                beta = self.beta[t][:,None,None,None]
                if i>1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1/ torch.sqrt(alpha) * (x-((1-alpha)/(torch.sqrt(1-alpha_hat)))* predicted_noise) + torch.sqrt(beta)*noise 
            unet_model.train()
            unet_model.train()
            x = (x.clamp(-1,1)+1)/2
            x = (x * 255).type(torch.uint8)
            return x

In [5]:

DDPM = DiffusionModel()
# DDPM.train(
#     Unet(),
#     MNIST_LOADER,
#     EPOCHS,
# )

In [6]:
alphas = DDPM.alpha 
alpha_hats = DDPM.alpha_hat
betas = DDPM.beta

# Gif for posting

In [8]:
if os.path.exists('final'):
    shutil.rmtree('final')
os.mkdir('final')

num_images = 32
model = Unet()
model.load_state_dict(torch.load("ckpt.pt",map_location=DEVICE))
x = torch.randn((num_images,1,28,28)).to(DEVICE)
for i in tqdm(reversed(range(1,NUM_STEPS))):
            t = (torch.ones(1)*i).long().to(DEVICE)
            predicted_noise = model(x,t)
            alpha = alphas[t][:,None,None,None]
            alpha_hat = alpha_hats[t][:,None,None,None]
            beta = betas[t][:,None,None,None]
            if i>1:
                    noise = torch.randn_like(x)
            else:
                    noise = torch.zeros_like(x)
            x = 1/ torch.sqrt(alpha) * (x-((1-alpha)/(torch.sqrt(1-alpha_hat)))* predicted_noise) + torch.sqrt(beta)*noise
            img = (x.clamp(-1,1)+1)/2
            img = (img * 255).type(torch.uint8)
            DDPM.save_images(img,os.path.join("final",f"{i}.jpg"))


frames = []
for i in range(NUM_STEPS-1, 0, -1): 
    filename = f"final/{i}.jpg"
    if os.path.exists(filename):
        frames.append(imageio.imread(filename))
imageio.mimsave("final.gif", frames, duration=0.01)

99it [00:26,  3.77it/s]
  frames.append(imageio.imread(filename))
