# Specific Test IV. Diffusion Models 


Task: Develop a generative model to simulate realistic strong gravitational lensing images. Train a diffusion model (DDPM) to generate lensing images. You are encouraged to explore various architectures and implementations within the diffusion model framework. Please implement your approach in PyTorch or Keras and discuss your strategy.

## Step 1: Imports
We start by importing everything we will need to work with and visualize the data. I am using PyTorch to create my final solution.

In [2]:
import torchvision

torchvision.disable_beta_transforms_warning() # this is to disable warnings with torchvision v2 transforms

import math
import os
import sys
from math import pi

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from ema_pytorch import EMA
from torch import einsum
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import v2
from tqdm.auto import tqdm

## Creating the dataset

This dataset consists of 10,000 training images, which is somewhat small for a diffusion model but still sufficient. For the transforms, I resize each image from 155×155 to 160×160 (the nearest multiple of 32), apply random horizontal flips (training only), and standardize the pixel values to zero mean and unit variance using hard-coded dataset statistics.

In [3]:
class StrongLensingDataset(Dataset):
    """Dataset of strong-lensing images stored as ``sample{i}.npy`` files.

    Files must be named ``sample1.npy`` ... ``sampleN.npy`` and live directly
    inside ``imgs``; item ``idx`` (0-based) maps to ``sample{idx + 1}.npy``.

    Args:
        imgs: Directory containing the ``.npy`` samples.
        train: Accepted for symmetry with ``get_dataloader``; currently unused
            (the split is selected by the directory and the transform).
        transform: Optional callable applied to each loaded float32 tensor.
    """

    def __init__(self, imgs, train=True, transform=None):
        self.imgs = imgs
        # Count only the top-level sample*.npy files: __getitem__ never reads
        # subdirectories, and stray files (.DS_Store, notes, ...) would
        # otherwise inflate the length and make the last indices point at
        # files that do not exist.
        with os.scandir(imgs) as entries:
            self.len = sum(
                1
                for entry in entries
                if entry.is_file()
                and entry.name.startswith("sample")
                and entry.name.endswith(".npy")
            )
        self.transform = transform

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Files are numbered from 1 on disk while dataset indices start at 0.
        path = os.path.join(self.imgs, f"sample{idx + 1}.npy")
        img = torch.from_numpy(np.load(path)).to(torch.float32)

        if self.transform is not None:
            img = self.transform(img)

        return img


def get_dataloader(file_path, image_size=160, train=True, batch_size=16, num_workers=2):
    """Build a DataLoader over a directory of ``sample{i}.npy`` lensing images.

    Args:
        file_path: Directory passed to ``StrongLensingDataset``.
        image_size: Side length images are resized to (160 is the multiple of
            32 closest to the native 155 px).
        train: If True, add random horizontal flips and shuffle each epoch;
            otherwise use a deterministic pipeline and keep the file order.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.

    Returns:
        A ``DataLoader`` yielding batches of standardized float32 images.
    """
    steps = [v2.Resize((image_size, image_size), antialias=True)]
    if train:
        # Augmentation only for the training split.
        steps.append(v2.RandomHorizontalFlip())
    steps += [
        v2.ToDtype(torch.float32, scale=True),
        # Standardize to mean 0 / std 1; presumably these are the training-set
        # statistics -- recompute if the dataset changes.
        v2.Normalize(mean=[0.0615], std=[0.1164]),
    ]

    dataset = StrongLensingDataset(file_path, train=train, transform=v2.Compose(steps))
    # Shuffling is only useful for training; evaluation stays in a fixed order.
    return DataLoader(
        dataset, batch_size, shuffle=train, num_workers=num_workers, pin_memory=True
    )

## The Model

For the model, I closely followed OpenAI's guided-diffusion U-Net, using Phil Wang's implementation as a reference (specifically for the linear and full attention blocks). The model uses GroupNorm and SiLU, and the U-Net passes two skip connections per resolution level (one after each of the level's two residual blocks). The model also takes a time embedding, which is injected into every residual block using the Adaptive Group Normalization (AdaGN) technique.

The specific improvements adopted from the paper *Diffusion Models Beat GANs on Image Synthesis* are AdaGN (implemented as a learned scale and shift), attention at every resolution level (linear attention in the encoder and decoder, full attention in the middle block), and BigGAN-style residual blocks.

In [4]:
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution.

    Args:
        channels: Input channel count.
        out_channels: Output channel count; defaults to ``channels``.
        end: If True, skip the upsampling and apply only the convolution
            (the U-Net uses this for its outermost, full-resolution level).
    """

    def __init__(self, channels, out_channels=None, end=False):
        super().__init__()
        self.end = end
        if not end:
            self.upsample = nn.Upsample(scale_factor=2)
        if out_channels is None:
            out_channels = channels
        self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        if not self.end:
            x = self.upsample(x)
        return self.conv(x)


class Downsample(nn.Module):
    """2x average pooling followed by a 3x3 convolution.

    Args:
        channels: Input channel count.
        out_channels: Output channel count; defaults to ``channels``.
        end: If True, skip the pooling and apply only the convolution (the
            U-Net uses this for its innermost level, which keeps its size).
    """

    def __init__(self, channels, out_channels=None, end=False):
        super().__init__()
        self.end = end
        if not end:
            self.downsample = nn.AvgPool2d(kernel_size=2)
        if out_channels is None:
            out_channels = channels
        self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        if not self.end:
            x = self.downsample(x)
        return self.conv(x)


# use same as BigGan according to paper
class ResidualBlock(nn.Module):
    """Two conv-GroupNorm-SiLU stages with a 1x1 shortcut, conditioned via AdaGN.

    The time embedding is projected to ``2 * out_channels`` values and split
    into a per-channel (scale, shift) pair. That pair modulates the output of
    the first GroupNorm as ``h * (1 + scale) + shift``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        emb_size: Width of the time embedding passed to ``forward``.
        groups: GroupNorm group count (must divide ``out_channels``).
        dropout: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, in_channels, out_channels, emb_size, groups=32, dropout=0.2):
        super().__init__()

        # SiLU -> Linear producing the AdaGN scale and shift in one vector.
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(emb_size, 2 * out_channels))

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(groups, out_channels)
        self.act1 = nn.SiLU()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.act2 = nn.SiLU()

        # Always a learned 1x1 projection, even when in == out channels.
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, t):
        # Append two singleton spatial dims so the modulation broadcasts over H, W.
        modulation = self.time_mlp(t).unsqueeze(-1).unsqueeze(-1)
        scale, shift = modulation.chunk(2, dim=1)

        hidden = self.norm1(self.conv1(x))
        hidden = self.act1(hidden * (1 + scale) + shift)

        hidden = self.act2(self.norm2(self.conv2(hidden)))

        return hidden + self.shortcut(x)

class Residual(nn.Module):
    """Wrap ``fn`` with an identity skip connection: ``fn(x, ...) + x``.

    Extra positional and keyword arguments are forwarded to ``fn`` unchanged.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x

# huggingface blog
class RMSNorm(nn.Module):
    """RMS normalization over the channel dimension of an NCHW tensor.

    Each spatial position's channel vector is divided by its root-mean-square
    (over channels) and multiplied by a learned per-channel gain ``g``.

    Args:
        dim: Number of channels (size of ``x.shape[1]``).
    """

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # F.normalize divides by the L2 norm over dim=1 (channels). Since
        # RMS = L2 / sqrt(C), multiplying by sqrt(C) gives x / RMS. Using the
        # spatial width (x.shape[-1]) here would make the output scale depend
        # on the feature-map resolution rather than the channel count.
        return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)

#huggingface blog
class PreNorm(nn.Module):
    """Pre-normalization wrapper: apply ``RMSNorm`` first, then ``fn``.

    Args:
        dim: Channel count of the input, forwarded to ``RMSNorm``.
        fn: Module (or callable) applied to the normalized tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = RMSNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

# hugging face blog -> phil wang
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of an NCHW map.

    Queries are softmax-normalized over the per-head feature axis and keys
    over the position axis. The key/value "context" is therefore a
    (dim_head x dim_head) matrix per head, and the cost grows linearly with
    the number of positions H*W.

    Args:
        dim: Input/output channel count.
        heads: Number of attention heads.
        dim_head: Channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # A single 1x1 conv produces queries, keys and values together.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))

    def forward(self, x):
        _, _, height, width = x.shape
        split_heads = "b (h c) x y -> b h c (x y)"
        q, k, v = (
            rearrange(part, split_heads, h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        # Summarize values by keys first: (b, h, d, e), independent of H*W.
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        attended = torch.einsum("b h d e, b h d n -> b h e n", context, q)

        merged = rearrange(
            attended, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width
        )
        return self.to_out(merged)

# huggingface blog -> phil wang
class Attention(nn.Module):
    """Standard multi-head softmax self-attention over spatial positions.

    It builds an (H*W x H*W) attention matrix per head, so cost is quadratic in
    the number of positions; the U-Net in this file uses it only in the
    middle (lowest-resolution) block.

    Args:
        dim: Input/output channel count.
        heads: Number of attention heads.
        dim_head: Channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        split_heads = "b (h c) x y -> b h c (x y)"
        q, k, v = (
            rearrange(part, split_heads, h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        # Scaled dot-product similarity between every pair of positions.
        logits = einsum("b h d i, b h d j -> b h i j", q * self.scale, k)
        weights = logits.softmax(dim=-1)
        mixed = einsum("b h i j, b h d j -> b h i d", weights, v)

        merged = rearrange(mixed, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(merged)


class Unet(nn.Module):
    """Time-conditioned U-Net noise predictor for DDPM training.

    Encoder: one level per adjacent pair of channel multipliers. Each level
    applies two ``ResidualBlock``s and a residual linear-attention layer,
    pushes two skip tensors, then downsamples. The innermost level only
    changes channels and keeps its resolution.
    Middle: (ResidualBlock -> full attention -> ResidualBlock) x ``mid_blocks``.
    Decoder: mirrors the encoder, concatenating the skips scaled by 1/sqrt(2).
    A final ResidualBlock and conv head map back to ``output_channels``.

    Args:
        channels: Base channel width after the initial convolution.
        input_channels: Channels of the input image.
        output_channels: Channels of the prediction.
        mid_blocks: Number of middle (bottleneck) stacks.
        channel_mults: Per-level channel multipliers relative to ``channels``.
            The caller's sequence is not modified.
        T: Size of the learned timestep-embedding table; timesteps passed to
            ``forward`` must lie in ``[0, T)``.
        groups: GroupNorm group count for the residual blocks (must divide
            every level's channel count).
    """

    def __init__(
        self,
        channels,
        input_channels=3,
        output_channels=3,
        mid_blocks=1,
        channel_mults=(1, 2, 4, 8),
        T=1000,
        groups=32,
    ):
        super().__init__()
        self.channels = channels
        # Build a fresh list with the base multiplier 1 prepended, so neither
        # the caller's list nor a shared default is mutated between calls.
        self.channel_mults = [1, *channel_mults]
        self.in_out_channels = [
            (int(channels * m_in), int(channels * m_out))
            for m_in, m_out in zip(self.channel_mults[:-1], self.channel_mults[1:])
        ]

        self.downs = nn.ModuleList([])
        self.middles = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        self.initial_conv = nn.Conv2d(
            input_channels, channels, kernel_size=3, padding=1
        )
        time_emb_size = channels * 4  # 4x the base width, following the papers

        # Learned embedding table indexed by integer timestep, then an MLP.
        self.time_embed = nn.Sequential(
            nn.Embedding(T, channels),
            nn.Linear(channels, time_emb_size),
            nn.SiLU(),
            nn.Linear(time_emb_size, time_emb_size),
        )

        num_levels = len(self.in_out_channels)
        current_channel = channels
        for i, (in_channel, out_channel) in enumerate(self.in_out_channels):
            # The innermost level changes channels but does not downsample.
            end = i + 1 == num_levels
            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualBlock(in_channel, in_channel, time_emb_size, groups),
                        ResidualBlock(in_channel, in_channel, time_emb_size, groups),
                        # RMSNorm pre-norm instead of the blog's GroupNorm.
                        Residual(PreNorm(in_channel, LinearAttention(in_channel))),
                        Downsample(in_channel, out_channel, end),
                    ]
                )
            )
            current_channel = out_channel

        for _ in range(mid_blocks):
            self.middles.append(
                nn.ModuleList(
                    [
                        ResidualBlock(
                            current_channel, current_channel, time_emb_size, groups
                        ),
                        Residual(PreNorm(current_channel, Attention(current_channel))),
                        ResidualBlock(
                            current_channel, current_channel, time_emb_size, groups
                        ),
                    ]
                )
            )

        for i, (in_channel, out_channel) in enumerate(reversed(self.in_out_channels)):
            # The outermost level is already at full resolution: no upsampling.
            end = i + 1 == num_levels
            self.ups.append(
                nn.ModuleList(
                    [
                        # Each input is concatenated with an in_channel-wide skip.
                        ResidualBlock(
                            out_channel + in_channel, out_channel, time_emb_size, groups
                        ),
                        ResidualBlock(
                            out_channel + in_channel, out_channel, time_emb_size, groups
                        ),
                        Residual(PreNorm(out_channel, LinearAttention(out_channel))),
                        Upsample(out_channel, in_channel, end),
                    ]
                )
            )
            current_channel = in_channel

        # Output head, built once after the decoder has been assembled.
        self.out1 = ResidualBlock(
            current_channel, current_channel, time_emb_size, groups=8
        )
        self.final = nn.Sequential(
            nn.Conv2d(current_channel, current_channel, kernel_size=3, padding=1),
            # NOTE(review): BatchNorm uses batch statistics while training and
            # running statistics in eval mode (i.e. during sampling), unlike
            # the GroupNorm used everywhere else -- confirm this is intended.
            nn.BatchNorm2d(current_channel),
            nn.SiLU(),
            nn.Conv2d(current_channel, output_channels, kernel_size=1),
        )

    def forward(self, x, t):
        """Predict the noise in ``x`` for integer timesteps ``t`` in ``[0, T)``."""
        x = self.initial_conv(x)
        t = self.time_embed(t)
        skips = []
        for res1, res2, att, ds in self.downs:
            x = res1(x, t)
            skips.append(x)

            x = res2(x, t)
            x = att(x)
            skips.append(x)  # two skips per level, as in Phil Wang's U-Net

            x = ds(x)

        for res1, att, res2 in self.middles:
            x = res1(x, t)
            x = att(x)
            x = res2(x, t)

        # Skips are scaled by 1/sqrt(2) before concatenation.
        scale = 2 ** -0.5
        for res1, res2, att, us in self.ups:
            x = torch.cat([x, skips.pop() * scale], dim=1)
            x = res1(x, t)

            x = torch.cat([x, skips.pop() * scale], dim=1)
            x = res2(x, t)

            x = att(x)
            x = us(x)

        x = self.out1(x, t)
        return self.final(x)

We define a small helper that gathers the values of a per-timestep schedule tensor for a batch of timesteps and reshapes them to `(B, 1, ..., 1)` so that they broadcast against the image batch. This makes the later code much easier to read.

In [5]:
def extract(a, t, x_shape):
    """Gather ``a[t]`` for a batch of timesteps and reshape it for broadcasting.

    Args:
        a: 1-D tensor of per-timestep values (e.g. a noise-schedule buffer).
        t: 1-D integer tensor of timesteps, one per batch element.
        x_shape: Shape of the batch the result will be broadcast against;
            its first entry must equal ``len(t)``.

    Returns:
        Tensor of shape ``(B, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
    """
    batch, *_ = t.shape
    assert x_shape[0] == batch
    picked = a.gather(0, t)
    assert picked.shape == torch.Size([batch])
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.view(batch, *trailing_ones)

## The Diffusion Class

We define the diffusion class, which implements Algorithm 1 (training) and Algorithm 2 (sampling) from the original DDPM paper. By default we use a cosine schedule because of the relatively small image size, the small number of channels in the strong gravitational lensing images, and their relative simplicity (a linear schedule would destroy the image information too quickly).

I use the default of 1000 timesteps, which works well; sampling could be made considerably faster by implementing DDIM sampling.

In [6]:
class Diffusion(nn.Module):
    """DDPM wrapper: noise schedule, training loss (Algorithm 1) and sampling (Algorithm 2).

    Args:
        model: Noise predictor called as ``model(x_t, t)``; its output must
            have the same shape as ``x_t``.
        num_channels: Channels of generated images.
        schedule: ``"linear"`` or ``"cosine"`` beta schedule.
        image_size: Side length of generated (square) images.
        T: Number of diffusion steps.
        device: Device on which timesteps and sampling noise are created.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """

    def __init__(
        self,
        model,
        num_channels=1,
        schedule="cosine",
        image_size=160,
        T=1000,
        device=torch.device("cuda"),
    ):
        super().__init__()
        self.device = device
        self.model = model
        self.T = T
        self.num_channels = num_channels
        self.image_size = image_size

        if schedule == "linear":
            betas = self.linear_betas()
        elif schedule == "cosine":
            betas = self.cosine_betas()
        else:
            # Raise instead of sys.exit() so callers (and notebooks) get a
            # catchable, descriptive error.
            raise ValueError(f"Unsupported diffusion schedule: {schedule}")
        betas = betas.to(torch.float32)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", 1 - self.betas)
        self.register_buffer("alphas_cp", torch.cumprod(self.alphas, dim=0))
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 (kept float32 explicitly).
        self.register_buffer(
            "alphas_cp_prev",
            torch.cat([torch.ones(1, dtype=torch.float32), self.alphas_cp[:-1]]),
        )
        self.register_buffer("sqrt_alphas_cp", torch.sqrt(self.alphas_cp))
        self.register_buffer("sqrt_m1_alphas_cp", torch.sqrt(1 - self.alphas_cp))
        self.register_buffer("rec_sqrt_alphas", 1 / torch.sqrt(self.alphas))
        self.register_buffer("rec_sqrt_m1_alphas_cp", 1 / self.sqrt_m1_alphas_cp)
        # beta_t / sqrt(1 - alpha_bar_t): the eps coefficient in the reverse mean.
        self.register_buffer("coeffs", self.betas / self.sqrt_m1_alphas_cp)

        # Posterior variance beta_tilde_t of q(x_{t-1} | x_t, x_0).
        self.register_buffer(
            "posterior_var",
            self.betas * (1 - self.alphas_cp_prev) / (1 - self.alphas_cp),
        )
        # posterior_var[0] is exactly 0 (alphas_cp_prev[0] == 1); clamp before log.
        self.register_buffer(
            "log_posterior_var", torch.log(self.posterior_var.clamp(min=1e-15))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / self.alphas_cp)
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / self.alphas_cp - 1)
        )

    def linear_betas(self, start=1e-4, end=0.02):
        """Linear beta schedule from the original DDPM paper (``T`` evenly spaced values)."""
        return torch.linspace(start, end, self.T)

    # improved ddpm
    def cosine_betas(self, s=0.008):
        """Cosine beta schedule from Improved DDPM, computed in float64.

        ``alpha_bar(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2)``, normalized so
        that ``alpha_bar(0) = 1``; ``beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1)``,
        clipped at 0.999 to avoid a singularity near ``t = T``.
        """
        t = torch.arange(0, self.T + 1, dtype=torch.float64) / self.T
        alpha_hats = torch.cos((t + s) / (1 + s) * (pi / 2)).square()
        alpha_hats = alpha_hats / alpha_hats[0]
        betas = 1 - (alpha_hats[1:] / alpha_hats[:-1])
        return torch.clamp(betas, max=0.999)

    def get_loss(self, x):
        """Return the simple DDPM loss: MSE between the true and predicted noise.

        One timestep per batch element is drawn uniformly from ``[0, T)``, ``x``
        is noised to ``sqrt(abar_t) * x + sqrt(1 - abar_t) * eps``, and the model
        is asked to predict ``eps``.
        """
        batch = x.shape[0]

        noise = torch.randn_like(x)
        times = torch.randint(low=0, high=self.T, size=(batch,), device=self.device)
        signal = x * extract(self.sqrt_alphas_cp, times, x.shape)
        scaled_noise = noise * extract(self.sqrt_m1_alphas_cp, times, x.shape)
        output = self.model(signal + scaled_noise, times)
        return F.mse_loss(output, noise)

    def sample(self, num_samples):
        """Generate ``num_samples`` images by DDPM ancestral sampling.

        Starts from standard normal noise and applies ``T`` reverse steps using
        the posterior variance. The model's previous train/eval mode is
        restored afterwards, even if sampling is interrupted.

        Returns:
            Float32 tensor ``(num_samples, num_channels, image_size, image_size)``.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                xt = torch.randn(
                    (num_samples, self.num_channels, self.image_size, self.image_size)
                ).to(self.device)
                batch = xt.shape[0]
                for t in tqdm(reversed(range(self.T)), total=self.T):
                    t_b = torch.full((batch,), t).to(self.device)
                    z = torch.randn_like(xt)
                    eps = self.model(xt, t_b)
                    # mu = (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
                    mean = (xt - eps * extract(self.coeffs, t_b, eps.shape)) * extract(
                        self.rec_sqrt_alphas, t_b, eps.shape
                    )
                    if t > 0:
                        log_var = extract(self.log_posterior_var, t_b, eps.shape)
                        # exp(0.5 * log var) is the standard deviation.
                        noise_term = z * torch.exp(0.5 * log_var)
                    else:
                        noise_term = 0  # the final step is deterministic
                    xt = (mean + noise_term).to(torch.float32)
        finally:
            self.model.train(was_training)
        return xt

## Training
For training, I use mixed-precision (fp16) training because of the large number of convolutions in the network, together with a GradScaler that guards against gradient underflow in fp16. We also keep an exponential moving average (EMA) of the weights with a decay of 0.9999. The EMA starts after 1500 steps and updates every 10 steps. Checkpoints are saved every 100 epochs.

In [7]:
def train(diffuser, ema, optimizer, scaler, epochs, dataloader, device):
    """Train ``diffuser`` with fp16 autocast, an EMA, and periodic checkpoints.

    The per-epoch mean loss is printed and written to ``output-log.txt``. Every
    100 epochs (starting at epoch 0), the raw network weights are saved to
    ``ddpm-deeplense{epoch}.pt`` and the EMA weights to
    ``ema-ddpm-deeplense{epoch}.pt``.

    Args:
        diffuser: ``Diffusion`` instance; its ``model`` may be wrapped in
            ``nn.DataParallel`` / ``DistributedDataParallel``.
        ema: ``ema_pytorch.EMA`` tracking the network being trained.
        optimizer: Optimizer over the network's parameters.
        scaler: GradScaler used for fp16 loss scaling.
        epochs: Number of epochs to run.
        dataloader: Iterable of image batches.
        device: Device batches are moved to; fp16 autocast is enabled only
            when this is a CUDA device.
    """
    # Autocast here is CUDA fp16; on other devices it would only warn and
    # disable itself, so turn it off explicitly.
    use_amp = torch.device(device).type == "cuda"
    # Checkpoint the unwrapped network (keys without the "module." prefix),
    # taken from the diffuser rather than from a notebook-global variable.
    net = diffuser.model
    if isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        net = net.module

    steps = 0
    with open("output-log.txt", "w") as log_file:
        for epoch in range(epochs):
            running = 0.0
            total = 0
            for image in tqdm(dataloader):
                optimizer.zero_grad()
                with torch.autocast(
                    device_type="cuda", dtype=torch.float16, enabled=use_amp
                ):
                    loss = diffuser.get_loss(image.to(device))
                running += loss.item()
                total += 1
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                ema.update()
                steps += 1
            # An empty loader yields NaN instead of ZeroDivisionError.
            epoch_loss = running / total if total else float("nan")
            print(f"Epoch: {epoch}")
            print("Steps: ", steps)
            print(f"Epoch loss: {epoch_loss}")
            log_file.write(
                f"Epoch: {epoch:>3}, Steps: {steps:>5}, Epoch loss: {epoch_loss}\n"
            )
            log_file.flush()  # keep the log current during long runs
            if epoch % 100 == 0:
                torch.save(net.state_dict(), f"ddpm-deeplense{epoch}.pt")
                # ema_model holds the averaged weights; online_model is the
                # live network being optimized.
                torch.save(
                    ema.ema_model.state_dict(), f"ema-ddpm-deeplense{epoch}.pt"
                )

Defining the training hyperparameters

In [9]:
# Prefer CUDA (the reference run used 4 A40s via DataParallel), then Apple
# MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = Unet(160, input_channels=1, output_channels=1, channel_mults=[1, 1, 2, 2, 3, 4])
print(f"Model size: {sum(p.numel() for p in model.parameters())}")
# EMA only starts after 1500 steps and updates every 10 steps so that it
# doesn't slow down training too much.
ema = EMA(
    model,
    beta=0.9999,
    update_after_step=1500,
    update_every=10,
)
ema = ema.to(device)
model = nn.DataParallel(model)
model = model.to(device)
diffuser = Diffusion(model, num_channels=1, schedule="cosine", device=device).to(device)  # cosine schedule
# num_workers doesn't have to be high
trainloader = get_dataloader("dataset-ddpm/train", image_size=160, batch_size=155, train=True, num_workers=1)
testloader = get_dataloader("dataset-ddpm/val", image_size=160, batch_size=155, train=False, num_workers=1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # same as used in guided diffusion paper
# Loss scaling only matters for CUDA fp16 autocast; disable it elsewhere
# (on other devices it would warn and disable itself anyway).
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
epochs = 1000

Model size: 109324961


In [10]:
# Train on the training split built above.
train(diffuser, ema, optimizer, scaler, epochs, trainloader, device)