# Denoising Diffusion Probabilistic Models (DDPMs) 

### Imports and Definitions

In [None]:
# Import Libraries 
import random
import imageio
import numpy as np

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import einops
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Grayscale
from torchvision.datasets.mnist import MNIST, FashionMNIST
from torchvision.datasets import Places365, Flowers102, Food101, CIFAR10

# Definitions: one seed shared by Python's, NumPy's and PyTorch's RNGs for reproducible runs
SEED = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Path for storing the model
#STORE_PATH_MNIST = f"ddpm_model_mnist.pt"
#STORE_PATH_FASHION = f"ddpm_model_fashion.pt"
#STORE_PATH_PLACES = f"ddpm_model_places365.pt"
#STORE_PATH_FLOWERS = f"ddpm_model_flowers102.pt"
#STORE_PATH_FOOD = f"ddpm_model_food101.pt"
#STORE_PATH_CIFAR = f"ddpm_model_cifar10.pt"

In [None]:
# Training Hyper-parameters
no_train: bool = True  # True skips the training cell; the checkpoint at store_path is loaded afterwards either way
dataset_loaded: str = 'flowers'  # 'fashion', 'places', 'flowers', 'food', 'cifar' or 'mnist'
batch_size: int = 28
epochs: int = 10
learning_rate: float = 0.001

In [None]:
# Setting up store path, generated file name and dataset to load.
# Each key accepted for `dataset_loaded` maps to (checkpoint path, output GIF name, torchvision dataset class).
_DATASET_CONFIGS = {
    'mnist': ("ddpm_mnist.pt", "mnist.gif", MNIST),
    'fashion': ("ddpm_model_fashion.pt", "fashion.gif", FashionMNIST),
    'places': ("ddpm_model_places365.pt", "places.gif", Places365),
    'flowers': ("ddpm_model_flowers102.pt", "flowers.gif", Flowers102),
    'food': ("ddpm_model_food101.pt", "food.gif", Food101),
    'cifar': ("ddpm_model_cifar10.pt", "cifar.gif", CIFAR10),
}

# Unknown names raise instead of silently falling back to MNIST: a typo would otherwise
# load MNIST and, when training, overwrite the MNIST checkpoint.
if dataset_loaded not in _DATASET_CONFIGS:
    raise ValueError(
        f"Unknown dataset_loaded={dataset_loaded!r}; expected one of {sorted(_DATASET_CONFIGS)}"
    )
store_path, gif_name, ds_fn = _DATASET_CONFIGS[dataset_loaded]

### Loading Dataset

In [None]:
# Function to show images
def show_images(images, title=""):
    """Show the provided images as sub-plots arranged in a near-square grid.

    Args:
        images: tensor or array indexed as ``images[i][0]``; only the first channel
            of each image is drawn, with a gray colormap.
        title: figure super-title.
    """
    # Converting images to CPU numpy arrays
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()

    n_images = len(images)
    if n_images == 0:
        return  # nothing to draw (and avoids a zero-row grid / division by zero)

    # rows ~ sqrt(n); cols rounded UP so rows * cols >= n.
    # (round() under-allocates, e.g. n=5 -> 2x2 and n=13 -> 3x4, dropping images.)
    rows = int(n_images ** (1 / 2))
    cols = -(-n_images // rows)

    # Populating figure with one sub-plot per image
    fig = plt.figure(figsize=(8, 8))
    for idx in range(n_images):
        fig.add_subplot(rows, cols, idx + 1)
        plt.imshow(images[idx][0], cmap="gray")
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()

In [None]:
# Function to show first batch of images
def show_first_batch(loader):
    """Display the images of the first batch yielded by ``loader``; does nothing if it is empty."""
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        show_images(first_batch[0], "Images in the first batch")

In [None]:
# Loading the data: convert to single-channel grayscale, resize to 32x32, convert to a
# tensor and normalize to [-1, 1].
# Grayscale keeps RGB datasets (flowers, food, places, cifar) compatible with the
# single-channel model (n_channels = 1) built in the DDPM section; already-grayscale
# MNIST/FashionMNIST images pass through unchanged.
transform = Compose([
    Grayscale(num_output_channels=1),
    Resize([32, 32]),
    ToTensor(),
    Lambda(lambda x: (x - 0.5) * 2),
])
dataset = ds_fn("./datasets", download=True, transform=transform)
loader = DataLoader(dataset, batch_size, shuffle=True)

In [None]:
# Display the first batch produced by the loader
show_first_batch(loader=loader)

### DDPM Model

In [None]:
# Getting device: CUDA when available, otherwise CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
device_label = torch.cuda.get_device_name(0) if use_cuda else "CPU"
print(f"Using device: {device}\t{device_label}")

In [None]:
# DDPM class
class MyDDPM(nn.Module):
    """DDPM wrapper: a linear beta schedule plus a noise-prediction network.

    Args:
        network: module called as ``network(x, t)`` returning its estimate of the noise in ``x``.
        n_steps: number of diffusion steps.
        min_beta, max_beta: endpoints of the linear beta schedule.
        device: device the network and schedule tensors are moved to at construction.
        image_chw: (channels, height, width) of the images; stored only, not used by this class.
    """

    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None, image_chw=(3, 32, 32)):
        super(MyDDPM, self).__init__()
        self.n_steps = n_steps
        # NOTE: recorded at construction; not updated if the module is later moved with .to()
        self.device = device
        self.image_chw = image_chw
        self.network = network.to(device)

        betas = torch.linspace(min_beta, max_beta, n_steps).to(device)  # Number of steps is typically in the order of thousands
        alphas = 1 - betas
        # alpha_bar_t = prod_{s <= t} alpha_s, computed in one pass instead of n separate products
        alpha_bars = torch.cumprod(alphas, dim=0)

        # Non-persistent buffers move with .to()/.cuda() but are excluded from state_dict,
        # so checkpoints saved when these were plain attributes still load with strict=True.
        self.register_buffer("betas", betas, persistent=False)
        self.register_buffer("alphas", alphas, persistent=False)
        self.register_buffer("alpha_bars", alpha_bars, persistent=False)

    def forward(self, x0, t, eta=None):
        """Noise ``x0`` directly to step ``t``: sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eta.

        Args:
            x0: clean images, shape (n, c, h, w).
            t: per-image timestep indices (tensor or list of ints of length n, also (n, 1)),
                or a single int applied to the whole batch.
            eta: noise with ``x0``'s shape; drawn from N(0, I) on ``x0``'s device when None.

        Returns:
            Noised images with the same shape as ``x0``.
        """
        # reshape(-1, ...) broadcasts over (c, h, w) and also accepts a scalar t
        a_bar = self.alpha_bars[t].reshape(-1, 1, 1, 1)

        if eta is None:
            eta = torch.randn_like(x0)

        return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eta

    def backward(self, x, t):
        """Return the network's estimate of the noise present in ``x`` at timesteps ``t``."""
        return self.network(x, t)

In [None]:
# Showing the forward process
def show_forward(ddpm, loader, device):
    """Show the first batch of ``loader`` and its noised versions at 25/50/75/100% of ``ddpm.n_steps``.

    Does nothing when ``loader`` yields no batches.
    """
    first_batch = next(iter(loader), None)
    if first_batch is None:
        return

    imgs = first_batch[0]
    show_images(imgs, "Original images")

    for percent in (0.25, 0.5, 0.75, 1):
        step = int(percent * ddpm.n_steps) - 1
        timesteps = [step] * len(imgs)
        show_images(
            ddpm(imgs.to(device), timesteps),
            f"DDPM Noisy images {int(percent * 100)}%",
        )

In [None]:
# Generate new image samples for a given DDPM model, a given number of samples to be generated and a given device
def _to_gif_frame(x):
    """Min-max rescale each sample of ``x`` (n, c, h, w) to [0, 255] and tile the batch into one uint8 (H, W, C) frame.

    The grid has int(sqrt(n)) rows, so n must be divisible by int(sqrt(n)).
    """
    shifted = x - x.amin(dim=(1, 2, 3), keepdim=True)
    peak = shifted.amax(dim=(1, 2, 3), keepdim=True)
    # clamp avoids 0/0 -> NaN for a constant sample (it then maps to all zeros)
    normalized = shifted * (255 / peak.clamp_min(1e-8))
    frame = einops.rearrange(normalized, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=int(len(x) ** 0.5))
    return frame.cpu().numpy().astype(np.uint8)


def generate_new_images(ddpm, n_samples=16, device=None, frames_per_gif=100, gif_name="sampling.gif", c=3, h=32, w=32):
    """Sample new images with DDPM ancestral sampling and store the denoising trajectory as a GIF.

    Starting from Gaussian noise, each step t = n_steps-1, ..., 0 applies
        x <- (x - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta(x, t)) / sqrt(alpha_t)
    and, for t > 0, adds sqrt(beta_t) * z with z ~ N(0, I).

    Args:
        ddpm: model exposing ``n_steps``, ``alphas``, ``alpha_bars``, ``betas``, ``device`` and ``backward(x, t)``.
        n_samples: number of images; must be divisible by int(sqrt(n_samples)) so frames tile into a grid.
        device: device to sample on; defaults to ``ddpm.device``.
        frames_per_gif: number of evenly spaced frames captured along the trajectory; the final
            frame is additionally repeated ``frames_per_gif // 3`` times.
        gif_name: output GIF path.
        c, h, w: shape of each sample.

    Returns:
        The final (un-normalized) samples, shape (n_samples, c, h, w).
    """
    # Denoising-loop indices at which a frame is captured (a set for cheap membership tests)
    frame_idxs = set(np.linspace(0, ddpm.n_steps, frames_per_gif).astype(int).tolist())
    frames = []

    with torch.no_grad():
        if device is None:
            device = ddpm.device

        # Starting from random noise
        x = torch.randn(n_samples, c, h, w).to(device)

        for idx, t in enumerate(reversed(range(ddpm.n_steps))):
            # Estimating noise to be removed
            time_tensor = (torch.ones(n_samples, 1) * t).to(device).long()
            eta_theta = ddpm.backward(x, time_tensor)
            alpha_t = ddpm.alphas[t]
            alpha_t_bar = ddpm.alpha_bars[t]

            # Partially denoising the image
            x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

            if t > 0:
                z = torch.randn(n_samples, c, h, w).to(device)
                # sigma_t squared = beta_t; adding noise in Langevin-dynamics fashion
                sigma_t = ddpm.betas[t].sqrt()
                x = x + sigma_t * z

            # Adding frames to the GIF (the final t == 0 frame is always captured)
            if idx in frame_idxs or t == 0:
                frames.append(_to_gif_frame(x))

    # Storing the gif, holding the last frame so the final result stays visible before looping
    with imageio.get_writer(gif_name, mode='I') as writer:
        for frame in frames:
            writer.append_data(frame)
        if frames:
            for _ in range(frames_per_gif // 3):
                writer.append_data(frames[-1])

    return x

In [None]:
# Function to return the standard (Vaswani et al.) sinusoidal positional embedding
def sinusoidal_embedding(n, d):
    """Return an (n, d) float32 table of sinusoidal embeddings for positions 0..n-1.

    Column 2i holds sin(pos / 10000^(2i/d)) and column 2i+1 holds cos(pos / 10000^(2i/d)).
    An odd ``d`` ends with a sin column.
    """
    positions = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # (n, 1)
    even_cols = torch.arange(0, d, 2, dtype=torch.float32)  # 2i for every sin column
    freqs = torch.pow(10000.0, -even_cols / d)  # (ceil(d/2),)
    angles = positions * freqs  # (n, ceil(d/2))

    embedding = torch.zeros(n, d)
    # Sin/cos alternate along the EMBEDDING dimension; each (sin, cos) pair shares one frequency
    embedding[:, 0::2] = torch.sin(angles)
    embedding[:, 1::2] = torch.cos(angles[:, : d // 2])

    return embedding

In [None]:
# Block class for convolution
class MyBlock(nn.Module):
    """Two stages of conv -> GroupNorm(1 group) -> activation (GELU by default).

    ``shape`` only sizes a LayerNorm that ``forward`` does not apply, and ``normalize`` is stored
    but never read; both remain so the constructor signature and registered parameters are unchanged.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):
        super().__init__()
        # NOTE: unused in forward(), but its weights are part of state_dict, so removing it would
        # break loading previously saved checkpoints. Creation order below is kept as-is because
        # it determines the order in which seeded RNG draws initialize the conv weights.
        self.ln = nn.LayerNorm(shape)
        self.gn1 = nn.GroupNorm(num_groups=1, num_channels=out_c)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)
        self.gn2 = nn.GroupNorm(num_groups=1, num_channels=out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)
        self.activation = activation if activation is not None else nn.GELU()
        self.normalize = normalize

    def forward(self, x):
        out = x
        for conv, norm in ((self.conv1, self.gn1), (self.conv2, self.gn2)):
            out = self.activation(norm(conv(out)))
        return out

In [None]:
# Self-attention (transformer encoder) block used inside the UNet
class MyTESA(nn.Module):
    """Transformer-encoder style self-attention over the spatial positions of a feature map.

    The (n, c, h, w) input is flattened to h*w tokens of width c and passed through:
    LayerNorm -> multi-head self-attention + residual, then LayerNorm -> MLP + residual.
    The result is reshaped back to (n, c, h, w).

    Args:
        in_c: channel count of the input, i.e. the token width; must be divisible by ``num_heads``.
        size: nominal spatial side length. It is kept for reference; the actual h and w are
            read from the input.
        num_heads: number of attention heads.
        hidden_dim: hidden width of the MLP.
        dropout: dropout probability inside the MLP.
    """

    def __init__(self, in_c, size, num_heads=4, hidden_dim=1024, dropout=0.0):
        super(MyTESA, self).__init__()
        self.in_c = in_c
        self.size = size
        self.mha = nn.MultiheadAttention(embed_dim=in_c, num_heads=num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm([in_c])
        self.ln2 = nn.LayerNorm([in_c])
        self.mlp = nn.Sequential(
            nn.Linear(in_features=in_c, out_features=hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden_dim, out_features=in_c),
            nn.GELU()
        )

    def forward(self, x):
        n, c, h, w = x.shape
        # Token count comes from the input rather than the stored ``size``, so non-square or
        # differently-sized maps work. The result is identical when h == w == size.
        tokens = x.reshape(n, c, h * w).permute(0, 2, 1)  # (n, h*w, c)
        normed = self.ln1(tokens)
        attention_value, _ = self.mha(query=normed, key=normed, value=normed)
        out = attention_value + tokens
        out = self.mlp(self.ln2(out)) + out
        return out.permute(0, 2, 1).reshape(n, c, h, w)

In [None]:
# Unet class
class MyUNet(nn.Module):
    """Noise-prediction U-Net for 32x32 inputs with additive time embeddings and self-attention.

    Structure: three down-sampling stages (32 -> 16 -> 8 -> 4), a bottleneck, and three
    up-sampling stages with skip connections. A learned projection of the time-embedding row for
    ``t`` is added, broadcast over height and width, to the input of every down and up stage.

    Args:
        n_channels: channels of the input images and of the predicted noise.
        n_steps: number of diffusion steps, i.e. rows of the fixed time-embedding table.
        time_emb_dim: width of the time embedding.
    """

    def __init__(self, n_channels=3, n_steps=500, time_emb_dim=256):
        super().__init__()

        # Fixed (non-trainable) table from sinusoidal_embedding, indexed by timestep.
        # The module creation order below is unchanged so seeded initialization is reproducible.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # Encoder: 32x32 -> 16x16 -> 8x8 -> 4x4
        self.input_conv = MyBlock((n_channels, 32, 32), n_channels, 64)
        self.down1 = nn.Sequential(
            nn.MaxPool2d(2),
            MyBlock((64, 16, 16), 64, 64),
            MyBlock((64, 16, 16), 64, 128),
        )
        self.te1 = self._make_te(time_emb_dim, 64)
        self.sa1 = MyTESA(128, 16)
        self.down2 = nn.Sequential(
            nn.MaxPool2d(2),
            MyBlock((128, 8, 8), 128, 128),
            MyBlock((128, 8, 8), 128, 256),
        )
        self.te2 = self._make_te(time_emb_dim, 128)
        self.sa2 = MyTESA(256, 8)
        self.down3 = nn.Sequential(
            nn.MaxPool2d(2),
            MyBlock((256, 4, 4), 256, 256),
            MyBlock((256, 4, 4), 256, 256),
        )

        # Bottleneck at 4x4
        self.te3 = self._make_te(time_emb_dim, 256)
        self.sa3 = MyTESA(256, 4)
        self.bottleneck = nn.Sequential(
            MyBlock((256, 4, 4), 256, 512),
            MyBlock((512, 4, 4), 512, 512),
            MyBlock((512, 4, 4), 512, 256),
        )

        # Decoder: 4x4 -> 8x8 -> 16x16 -> 32x32; each stage concatenates an encoder skip
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up1_conv = nn.Sequential(
            MyBlock((512, 8, 8), 512, 256),
            MyBlock((256, 8, 8), 256, 128),
        )
        self.te4 = self._make_te(time_emb_dim, 512)
        self.sa4 = MyTESA(128, 8)
        self.up2_conv = nn.Sequential(
            MyBlock((256, 16, 16), 256, 128),
            MyBlock((128, 16, 16), 128, 64),
        )
        self.te5 = self._make_te(time_emb_dim, 256)
        self.sa5 = MyTESA(64, 16)
        self.up3_conv = nn.Sequential(
            MyBlock((128, 32, 32), 128, 64),
            MyBlock((64, 32, 32), 64, 64),
        )
        self.te6 = self._make_te(time_emb_dim, 128)
        self.sa6 = MyTESA(64, 32)
        self.out_conv = nn.Conv2d(64, n_channels, (1, 1))

    def forward(self, x, t):
        emb = self.time_embed(t)
        batch = len(x)

        def with_time(features, projection):
            # Add the projected time embedding as a per-channel bias
            return features + projection(emb).reshape(batch, -1, 1, 1)

        skip1 = self.input_conv(x)
        skip2 = self.sa1(self.down1(with_time(skip1, self.te1)))
        skip3 = self.sa2(self.down2(with_time(skip2, self.te2)))
        deepest = self.sa3(self.down3(with_time(skip3, self.te3)))

        middle = self.bottleneck(deepest)

        dec1 = self.sa4(self.up1_conv(with_time(torch.cat((skip3, self.up(middle)), dim=1), self.te4)))
        dec2 = self.sa5(self.up2_conv(with_time(torch.cat((skip2, self.up(dec1)), dim=1), self.te5)))
        dec3 = self.sa6(self.up3_conv(with_time(torch.cat((skip1, self.up(dec2)), dim=1), self.te6)))
        return self.out_conv(dec3)

    def _make_te(self, dim_in, dim_out):
        """Time-embedding projection: SiLU followed by a Linear layer from ``dim_in`` to ``dim_out``."""
        return nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim_in, dim_out),
        )

In [None]:
# Defining the DDPM model: single-channel images, 500 diffusion steps, linear betas from 1e-4 to 0.02
n_channels = 1
n_steps = 500
min_beta = 10**-4
max_beta = 0.02
ddpm = MyDDPM(
    MyUNet(n_channels, n_steps),
    n_steps=n_steps,
    min_beta=min_beta,
    max_beta=max_beta,
    device=device,
)

In [None]:
# Visualize the forward (noising) process on the first batch
show_forward(ddpm=ddpm, loader=loader, device=device)

### Training the model

In [None]:
# Function for the training loop
def training_loop(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):
    """Train the DDPM's noise-prediction network and checkpoint the best epoch.

    Each step does the following:
    - noise a batch to random timesteps with the forward process;
    - have the network predict the added noise;
    - minimize the Smooth L1 loss between the predicted and true noise.
    The state dict is saved to ``store_path`` whenever the epoch's mean loss improves.

    Args:
        ddpm: model exposing ``n_steps``, ``forward(x0, t, eta)`` and ``backward(x, t)``.
        loader: DataLoader whose batches have the image tensor (n, c, h, w) as first element.
        n_epochs: number of passes over ``loader``.
        optim: optimizer over ``ddpm``'s parameters.
        device: device batches are moved to.
        display: if True, sample and show images after every epoch.
        store_path: file the best state dict is written to.
    """
    smooth_l1 = nn.SmoothL1Loss()
    best_loss = float("inf")
    n_steps = ddpm.n_steps
    # (c, h, w) of the training images, so display sampling matches the model's input channels
    sample_chw = None
    for epoch in tqdm(range(n_epochs), desc="Training progress", colour="#00ff00"):
        epoch_loss = 0.
        for batch in tqdm(loader, leave=False, desc=f"Epoch {epoch+1}/{n_epochs}", colour="#005500"):
            # Loading data
            x0 = batch[0].to(device)
            n = len(x0)
            sample_chw = tuple(x0.shape[1:])

            # Picking some noise for each of the images in the batch and a timestep
            eta = torch.randn_like(x0)
            t = torch.randint(0, n_steps, (n,)).to(device)

            # Computing the noisy image based on x0 and the time-step (forward process)
            noisy_imgs = ddpm(x0, t, eta)

            # Getting model estimation of noise based on the images and the time-step
            eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1))

            # Optimizing the Smooth L1 between the noise plugged and the predicted noise
            loss = smooth_l1(eta_theta, eta)
            optim.zero_grad()
            loss.backward()
            optim.step()

            # Sample-weighted mean loss over the whole dataset
            epoch_loss += loss.item() * n / len(loader.dataset)

        # Display images generated at this epoch. Pass the training image shape explicitly:
        # generate_new_images defaults to c=3, which mismatches single-channel models.
        if display:
            gen_kwargs = {}
            if sample_chw is not None:
                c, h, w = sample_chw
                gen_kwargs = {"c": c, "h": h, "w": w}
            show_images(generate_new_images(ddpm, device=device, **gen_kwargs), f"Images generated at epoch {epoch+1}")

        log_string = f"Loss at epoch {epoch+1}: {epoch_loss:.3f}"

        # Storing the model
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += "--> Best model ever (stored)"

        print(log_string)

In [None]:
# Training (skipped when no_train is True)
if not no_train:
    training_loop(
        ddpm,
        loader,
        epochs,
        optim=Adam(ddpm.parameters(), learning_rate),
        device=device,
        store_path=store_path,
    )

### Generating New Images

In [None]:
# Loading the trained model. Rebuild it with the same UNet/schedule configuration as the
# `ddpm` defined above (n_steps, min_beta, max_beta). Otherwise the time-embedding table and
# noise schedule only match the checkpoint while those values equal the class defaults.
best_model = MyDDPM(
    MyUNet(n_channels, n_steps),
    n_steps=n_steps,
    min_beta=min_beta,
    max_beta=max_beta,
    device=device,
)
best_model.load_state_dict(torch.load(store_path, map_location=device))
best_model.eval()
print("Model loaded: Generating new images")

In [None]:
# Generating new images: 49 samples, tiled as a 7x7 grid in the GIF
generated = generate_new_images(
    best_model,
    n_samples=49,
    device=device,
    gif_name=gif_name,
    c=n_channels,
)
show_images(generated, "Final result")

In [None]:
from IPython.display import Image

# Read the GIF inside a context manager so the file handle is closed promptly
with open(gif_name, 'rb') as gif_file:
    gif_bytes = gif_file.read()
Image(gif_bytes)

### Empty the GPU cache for training

In [None]:
# torch.cuda.empty_cache()

### Number of model parameters

In [None]:
#total_params = sum(
#	param.numel() for param in best_model.parameters()
#)
#total_params

In [None]:
#trainable_params = sum(
#	p.numel() for p in best_model.parameters() if p.requires_grad
#)
#trainable_params