## Image Generation Using Diffusion Models

In [1]:
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from matplotlib import pyplot as plt
from PIL import Image
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm




# Root logger: INFO and above, rendered as "<hh:mm:ss> - <LEVEL>: <message>"
# (the timestamp uses a 12-hour clock via %I).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s: %(message)s",
    datefmt="%I:%M:%S",
)

In [2]:
class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=32, device="cuda"):
        self.noise_steps = noise_steps  # Number of sampling steps (e.g., 1000 as proposed in DDPM paper)
        self.beta_start = beta_start    # Lower bound for beta
        self.beta_end = beta_end        # Upper bound for beta
        self.img_size = img_size
        self.device = device

        # Prepare noise schedule (beta), alpha, and cumulative product of alphas
        self.beta = self._prepare_noise_schedule().to(self.device)  
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)  # Cumulative product for reverse process

    def _prepare_noise_schedule(self):
        """Returns a 1D tensor of evenly spaced values for noise schedule (beta)"""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps) 

    def noise_images(self, x, t):
        """Adds noise to the image x at time step t, creating a noisy image"""
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]  # Expand dimensions for broadcasting
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        epsilon = torch.randn_like(x)  # Random noise tensor with the same shape as x

        # Return the noisy image and the noise added to it
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, n):
        """Randomly samples timesteps (used during training)"""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        """Reverse process: Generates n new images by sampling the diffusion model"""
        logging.info(f"Sampling {n} new images....")
        model.eval()  # Set model to evaluation mode

        with torch.no_grad():  # Disable gradient computation
            # Initialize noisy images with Gaussian noise
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)

            # Reverse process: iterate through timesteps in reverse
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = torch.full((n,), i, dtype=torch.long).to(self.device)  # Time step tensor

                # Predict noise at timestep t
                predicted_noise = model(x, t)
                
                # Retrieve alpha and beta values for timestep t
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]

                # Add noise for all steps except the last
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)

                # Update x based on the predicted noise and noise schedule
                x = (x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha) 
                x += torch.sqrt(beta) * noise # adding a bit of noise to denoised image to preserve diversity and probabilistic behavior in generated images (like in the paper)

        model.train()  # Set model back to training mode

        # Post-process the output to match image format (scale pixel values to [0, 255])
        x = (x.clamp(-1, 1) + 1) / 2  # Clamp to [-1, 1] and rescale to [0, 1]
        x = (x * 255).type(torch.uint8)  # Convert to uint8 format for RGB images
        return x


In [3]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> GroupNorm) stages with a GELU in between.

    Spatial size is preserved (padding=1). When ``mid_channels`` is falsy the
    hidden width equals ``out_channels``. With ``residual=True`` the block returns
    ``gelu(x + convs(x))``, which needs ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return F.gelu(x + out) if self.residual else out


class Down(nn.Module):
    """Encoder stage: 2x max-pool, residual DoubleConv, DoubleConv to ``out_channels``,
    then add a per-channel projection of the time embedding."""

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        halve = nn.MaxPool2d(2)  # halves H and W
        self.maxpool_conv = nn.Sequential(
            halve,
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # The time embedding is emb_dim wide while this stage has out_channels
        # channels, so project it with SiLU -> Linear before adding it.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        features = self.maxpool_conv(x)
        # (B, C) -> (B, C, 1, 1): broadcasting spreads the embedding over every pixel.
        time_bias = self.emb_layer(t)[:, :, None, None]
        return features + time_bias


class Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, concatenate with encoder skip features
    along channels, two DoubleConvs, then add a projected time embedding.

    ``in_channels`` is the channel count *after* concatenation.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Project the emb_dim-wide time embedding onto out_channels.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        upsampled = self.up(x)
        # Skip features come first along the channel axis.
        merged = torch.cat((skip_x, upsampled), dim=1)
        features = self.conv(merged)
        return features + self.emb_layer(t)[:, :, None, None]
    

class SelfAttention(nn.Module):
    """Pre-LayerNorm transformer block (4-head self-attention + feed-forward) over pixels.

    Every spatial position becomes one token of width ``channels``. Both sub-layers
    are residual, and the output has the same shape as the input.

    Args:
        channels: feature channels C, used as the attention embedding dim
            (must be divisible by the 4 heads).
        size: nominal spatial side length, kept as an attribute for compatibility.
            ``forward`` reads the real H and W from its input, so non-square or
            differently sized feature maps are handled correctly instead of being
            silently re-batched by a fixed-size ``view``.
    """

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.reshape(b, c, h * w).transpose(1, 2)  # (B, C, H, W) -> (B, H*W, C)
        x_ln = self.ln(tokens)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + tokens
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.transpose(1, 2).reshape(b, c, h, w)
    

class UNet(nn.Module):
    """Noise-prediction U-Net with self-attention after each encoder/decoder stage.

    The encoder has three ``Down`` stages, so ``img_size`` should presumably be
    divisible by 8 (the attention blocks are sized ``img_size // 2, // 4, // 8``).
    The timestep is injected through a sinusoidal embedding of width ``time_dim``.

    Args:
        c_in / c_out: input and output image channels.
        time_dim: width of the sinusoidal time embedding (expected to be even).
        device: kept as an attribute for compatibility. The time embedding is built
            on the device of the timestep tensor, so moving the model with
            ``.to()`` works regardless of this value.
        img_size: spatial side length of the input images.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda", img_size=32):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # encoder
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, img_size // 2)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, img_size // 4)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, img_size // 8)

        # bottleneck
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # decoder (Up in_channels = upsampled channels + skip channels)
        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, img_size // 4)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, img_size // 2)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, img_size)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal timestep embedding.

        Args:
            t: float tensor of shape (B, 1) holding the timestep values.
            channels: embedding width; the first half is sin, the second half cos.

        Returns:
            Tensor of shape (B, channels) on the same device as ``t``.
        """
        # Build the frequencies next to `t` rather than on a fixed device, so the
        # embedding follows the model/input wherever it is moved.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t`` (1-D long tensor, one per image)."""
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        x1 = self.inc(x)
        x2 = self.sa1(self.down1(x1, t))
        x3 = self.sa2(self.down2(x2, t))
        x4 = self.sa3(self.down3(x3, t))

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Each Up stage also consumes the matching encoder output as a skip connection.
        x = self.sa4(self.up1(x4, x3, t))
        x = self.sa5(self.up2(x, x2, t))
        x = self.sa6(self.up3(x, x1, t))
        return self.outc(x)

In [4]:
# some classes or functions that we use a lot
@dataclass
class Args:
    """Configuration for one training run, consumed by train() and get_data().

    As a dataclass it accepts the same positional/keyword arguments as before and
    additionally provides a readable repr and field-wise equality.
    """

    run_name: str      # sub-directory name used under models/, results/ and runs/
    epochs: int        # number of passes over the dataset
    batch_size: int    # DataLoader batch size
    image_size: int    # side length images are randomly cropped/resized to
    dataset_path: str  # root folder read with torchvision ImageFolder
    device: str        # torch device string, e.g. "cuda" or "cpu"
    lr: float          # AdamW learning rate
        

def plot_images(images):
    """Show a batch of images side by side in one large matplotlib figure.

    Items along dim 0 are laid out left to right (concatenated on their last axis),
    and the strip is permuted (C, H, W) -> (H, W, C) for imshow.
    """
    plt.figure(figsize=(32, 32))
    strip = torch.cat(list(images.cpu()), dim=-1)
    plt.imshow(strip.permute(1, 2, 0))
    plt.show()


def save_images(images, path, **kwargs):
    """Tile ``images`` with torchvision's make_grid and write the grid to ``path`` via PIL.

    Extra keyword arguments go straight to make_grid. Presumably ``images`` is a
    uint8 (N, C, H, W) batch such as Diffusion.sample returns; PIL needs uint8 for RGB.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    hwc = grid.permute(1, 2, 0).to('cpu').numpy()  # CHW -> HWC layout expected by PIL
    Image.fromarray(hwc).save(path)


def get_data(args):
    """Return a shuffled DataLoader over the ImageFolder dataset at ``args.dataset_path``.

    Each image is resized so its shorter side is 1.25x ``args.image_size``. A
    RandomResizedCrop then keeps 80-100% of the area and resizes it to a
    square of ``args.image_size``. Finally, pixels are normalized from [0, 1] to [-1, 1].
    """
    # Scale slightly above the target before the random crop. The previous
    # hard-coded 80 only matched this 1.25x rule for image_size=64.
    resize_to = args.image_size + args.image_size // 4
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resize_to),
        torchvision.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=True)


def setup_logging(run_name):
    """Make sure models/<run_name>/ and results/<run_name>/ exist; existing dirs are left as-is."""
    for root in ("models", "results"):
        os.makedirs(root, exist_ok=True)
        os.makedirs(os.path.join(root, run_name), exist_ok=True)

In [5]:
# the training loop
def train(args):
    """Train a DDPM noise-prediction UNet on the ImageFolder dataset described by ``args``.

    Each epoch does one pass over the data, minimizing the MSE between the injected
    and the predicted noise. It then samples one batch of images into
    results/<run_name>/<epoch>.jpg. A checkpoint goes to
    models/<run_name>/epoch_<epoch>.pt every 50 epochs and after the final epoch.
    Per-step losses are logged to TensorBoard under runs/<run_name>.
    """
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    # The UNet sizes its attention blocks from img_size and builds its time
    # embedding on `device`, so both must follow the run configuration.
    model = UNet(img_size=args.image_size, device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    num_batches = len(dataloader)
    training_loss = []

    try:
        for epoch in range(args.epochs):
            logging.info(f"Starting epoch {epoch}:")
            pbar = tqdm(dataloader)
            for i, (images, _) in enumerate(pbar):
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)  # random timesteps
                x_t, noise = diffusion.noise_images(images, t)  # forward-noise the batch
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()  # single device sync per step, reused below
                pbar.set_postfix(MSE=loss_value)
                logger.add_scalar("MSE", loss_value, global_step=epoch * num_batches + i)
                training_loss.append(loss_value)

            sampled_images = diffusion.sample(model, n=images.shape[0])
            save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
            # Periodic checkpoints, plus the last epoch so the final weights are always kept.
            if epoch % 50 == 0 or epoch == args.epochs - 1:
                _save_checkpoint(
                    os.path.join("models", args.run_name, f"epoch_{epoch}.pt"),
                    model, optimizer, epoch, loss_value, training_loss,
                )
    finally:
        logger.close()  # flush pending TensorBoard events even if training is interrupted


def _save_checkpoint(path, model, optimizer, epoch, loss, training_loss):
    """Write model/optimizer state, the last loss, the epoch and the loss history to ``path``."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
            "epoch": epoch,
            "training_loss": training_loss,
        },
        path,
    )


            
if __name__ == '__main__':
    # Unconditional DDPM run on the Kaggle "natural-landscapes" images at 32x32.
    args = Args(
        run_name="DDPM_Uncondtional",
        dataset_path="/kaggle/input/natural-landscapes/Natural Landscapes",
        image_size=32,
        batch_size=24,
        epochs=500,
        lr=3e-4,
        device="cuda",
    )
    train(args)

100%|██████████| 180/180 [01:16<00:00,  2.34it/s, MSE=0.0891]
999it [00:24, 41.53it/s]
100%|██████████| 180/180 [00:59<00:00,  3.00it/s, MSE=0.0588]
999it [00:23, 41.64it/s]
100%|██████████| 180/180 [01:00<00:00,  2.97it/s, MSE=0.0268]
999it [00:24, 41.61it/s]
100%|██████████| 180/180 [01:00<00:00,  2.96it/s, MSE=0.0533]
999it [00:24, 41.41it/s]
100%|██████████| 180/180 [01:00<00:00,  2.97it/s, MSE=0.0382]
999it [00:24, 41.56it/s]
100%|██████████| 180/180 [01:00<00:00,  3.00it/s, MSE=0.0328]
999it [00:24, 41.54it/s]
100%|██████████| 180/180 [01:00<00:00,  2.98it/s, MSE=0.0475]
999it [00:24, 41.58it/s]
100%|██████████| 180/180 [01:00<00:00,  3.00it/s, MSE=0.0233]
999it [00:24, 41.46it/s]
100%|██████████| 180/180 [00:59<00:00,  3.01it/s, MSE=0.0279]
999it [00:24, 41.46it/s]
100%|██████████| 180/180 [01:00<00:00,  3.00it/s, MSE=0.0478]
999it [00:24, 41.50it/s]
100%|██████████| 180/180 [01:00<00:00,  2.99it/s, MSE=0.0282]
999it [00:24, 41.53it/s]
100%|██████████| 180/180 [01:00<00:00,  2.9