# Diffusion Training and Experimentation

This notebook demonstrates training a class-conditional diffusion model on the MNIST dataset using a modular training pipeline implemented in PyTorch Lightning, and then sampling digits with DDIM and classifier-free guidance.


In [1]:
!git clone https://github.com/Reennon/gen-ai-cv-2-3.git
%cd gen-ai-cv-2-3

fatal: destination path 'gen-ai-cv-2-3' already exists and is not an empty directory.
/content/gen-ai-cv-2-3


In [2]:
!pip install -r requirements.txt



In [3]:
# Update the cloned repo from its upstream so the `src/` modules and `params/` configs used below are current.
!git pull

remote: Enumerating objects: 9, done.[K
remote: Counting objects:  11% (1/9)[Kremote: Counting objects:  22% (2/9)[Kremote: Counting objects:  33% (3/9)[Kremote: Counting objects:  44% (4/9)[Kremote: Counting objects:  55% (5/9)[Kremote: Counting objects:  66% (6/9)[Kremote: Counting objects:  77% (7/9)[Kremote: Counting objects:  88% (8/9)[Kremote: Counting objects: 100% (9/9)[Kremote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects:  33% (1/3)[Kremote: Compressing objects:  66% (2/3)[Kremote: Compressing objects: 100% (3/3)[Kremote: Compressing objects: 100% (3/3), done.[K
remote: Total 5 (delta 2), reused 5 (delta 2), pack-reused 0 (from 0)[K
Unpacking objects:  20% (1/5)Unpacking objects:  40% (2/5)Unpacking objects:  60% (3/5)Unpacking objects:  80% (4/5)Unpacking objects: 100% (5/5)Unpacking objects: 100% (5/5), 1.08 KiB | 1.08 MiB/s, done.
From https://github.com/Reennon/gen-ai-cv-2-3
   5a62562..b17cc25  main       -> origin/ma

In [4]:
import os
import dotenv
import wandb
import torch

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from src.training.diffusion_module import (DiffusionModel)
from src.training.trainer import train_model

from google.colab import userdata

In [6]:
# Expose the W&B API key (stored as a Colab secret) to this process only.
# It is deliberately NOT written to `.env`: the previous `!echo $WANDB_KEY >> .env`
# appended a duplicate line on every re-run and left the secret in a plain-text
# file inside the cloned repository, where it could be committed by accident.
# Child processes and `os.environ` readers still see it.
os.environ["WANDB_KEY"] = userdata.get("wandb_key")

In [7]:
# Merge any variables from a `.env` file into os.environ; values already set in the
# environment win (override=False, the library default). Returns True if a file was loaded.
dotenv.load_dotenv(override=False)

True

In [8]:
# Project hyper-parameter file.
# NOTE(review): `parameters` is not read by any cell visible in this notebook;
# the training cell hard-codes its own values.
parameters = OmegaConf.load("./params/diffusion.yml")

# Authenticate with W&B using the key placed in the environment above.
wandb_api_key = os.environ["WANDB_KEY"]
wandb.login(key=wandb_api_key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrkovalch[0m ([33mrkovalchuk[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [9]:
# Experiment-level settings.
# NOTE(review): the model trained below is class-conditional despite the project name,
# and `device` is not read by any visible cell (the Trainer picks its own accelerator).
wandb_project_name, device = "unconditional-unet-mnist", "gpu"

In [10]:
# Prepare the MNIST dataset.
# Pixels are scaled to [-1, 1], the same range `create_dataloader` feeds the diffusion
# model during training. Previously these loaders yielded [0, 1], so anything evaluated
# on them would see inputs outside the range the model was trained on.
# NOTE(review): `train_loader` / `val_loader` are not used by the training cell below,
# which builds its own loader.
transform = transforms.Compose([
    transforms.ToTensor(),                  # uint8 [0, 255] -> float [0, 1]
    transforms.Normalize((0.5,), (0.5,)),   # (x - 0.5) / 0.5 -> [-1, 1]
])
mnist_train = MNIST(root='data', train=True, download=True, transform=transform)
mnist_val = MNIST(root='data', train=False, download=True, transform=transform)
train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=128)


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import Adam
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image

# Building blocks used by CombinedUNet. These were previously only "assumed to be defined
# elsewhere", but nothing in the notebook defined or imported them, so a fresh kernel
# raised NameError when the U-Net was built. Their shapes follow from how CombinedUNet
# uses them: the time embedding must be (B, time_dim), and DownSample/UpSample must
# halve/double H and W so the decoder outputs line up with the skip connections.


class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of (integer) diffusion timesteps.

    forward(t): t is a 1-D tensor of B timesteps; returns a (B, dim) float tensor laid out
    as [sin(t * f_0..f_{dim/2-1}), cos(t * f_0..f_{dim/2-1})], with geometrically spaced
    frequencies f_i = 10000 ** (-i / (dim/2 - 1)).
    """

    def __init__(self, dim):
        super().__init__()
        if dim < 2 or dim % 2:
            raise ValueError(f"SinusoidalPosEmb needs an even dim >= 2, got {dim}")
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        exponents = torch.arange(half, device=t.device, dtype=torch.float32) / max(half - 1, 1)
        freqs = 10000.0 ** (-exponents)
        angles = t.float()[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)


class DownSample(nn.Module):
    """Stride-2 4x4 convolution: halves H and W (for even sizes) and maps in_ch -> out_ch."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class UpSample(nn.Module):
    """Stride-2 4x4 transposed convolution: doubles H and W and maps in_ch -> out_ch."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class BasicBlock(nn.Module):
    """Residual conv block with additive time conditioning.

    conv3x3 -> ReLU -> (+ projected time embedding, broadcast over H x W) -> conv3x3,
    then ReLU(out + shortcut(x)); the shortcut is a 1x1 conv when in_ch != out_ch.
    Spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.conv_a = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.act = nn.ReLU(inplace=True)
        self.conv_b = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_ch)
        # Use a projection if channel dimensions differ
        self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, time_emb):
        """x: (B, in_ch, H, W); time_emb: (B, time_dim). Returns (B, out_ch, H, W)."""
        h = self.act(self.conv_a(x))
        # Add time conditioning (one value per channel, broadcast over H x W).
        h = h + self.time_proj(time_emb)[:, :, None, None]
        h = self.conv_b(h)
        return self.act(h + self.shortcut(x))


class CrossAttnModule(nn.Module):
    """Cross-attention from image features (queries) to a conditioning vector (keys/values).

    Every spatial position attends over two tokens: the projected conditioning vector and a
    learned "null" token. The original module attended over the conditioning token alone, so
    softmax over a length-1 axis was identically 1. That made the output a constant
    broadcast of v, and `query_conv` / `key_fc` received exactly zero gradient. The null token
    lets each position decide how much conditioning to take in.

    Output: feat + out_proj(attention output), same shape as `feat`.
    """

    def __init__(self, feat_dim, embed_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(feat_dim, feat_dim, kernel_size=1)
        self.key_fc = nn.Linear(embed_dim, feat_dim)
        self.value_fc = nn.Linear(embed_dim, feat_dim)
        self.scale = feat_dim ** -0.5
        self.out_proj = nn.Conv2d(feat_dim, feat_dim, kernel_size=1)
        # Row 0: null key, row 1: null value (learned; zero-initialised).
        self.null_kv = nn.Parameter(torch.zeros(2, feat_dim))

    def forward(self, feat, emb):
        """feat: (B, feat_dim, H, W); emb: (B, embed_dim). Returns (B, feat_dim, H, W)."""
        B, C, H, W = feat.shape
        q = self.query_conv(feat).reshape(B, C, H * W).transpose(1, 2)  # (B, H*W, C)

        null_k = self.null_kv[0].expand(B, 1, C)
        null_v = self.null_kv[1].expand(B, 1, C)
        k = torch.cat([self.key_fc(emb).unsqueeze(1), null_k], dim=1)    # (B, 2, C)
        v = torch.cat([self.value_fc(emb).unsqueeze(1), null_v], dim=1)  # (B, 2, C)

        attn_weights = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)  # (B, H*W, 2)
        attn_output = (attn_weights @ v).transpose(1, 2).reshape(B, C, H, W)
        return feat + self.out_proj(attn_output)


class CombinedUNet(pl.LightningModule):
    """Noise-prediction U-Net with two class-conditioning paths.

    The label `y` is injected twice. First, a learned embedding is broadcast over H x W
    and concatenated to the image channels. Second, a separate embedding is fed to a
    cross-attention module at the bottleneck. A negative label means "no condition": both
    embeddings are zeroed for that sample, which is how the classifier-free-guidance
    condition dropping in DiffusionModule reaches the network.
    """

    def __init__(self, in_channels=1, base_ch=64, time_dim=128,
                 num_classes=10, cls_embed_dim=32, attn_embed_dim=64):
        """
        in_channels: number of image channels (e.g., 1 for MNIST)
        base_ch: base channel count for U-Net
        time_dim: dimension for time embedding
        num_classes: number of classes for conditioning
        cls_embed_dim: dimension for input conditioning embedding
        attn_embed_dim: dimension for cross-attention embedding
        """
        super().__init__()
        self.num_classes = num_classes

        # Class-conditioning tables (input channels / cross-attention) and time embedding.
        self.input_embed = nn.Embedding(num_classes, cls_embed_dim)
        self.attn_embed = nn.Embedding(num_classes, attn_embed_dim)
        self.time_emb = SinusoidalPosEmb(time_dim)

        ch1, ch2, ch4 = base_ch, base_ch * 2, base_ch * 4

        # Stem sees the image plus the broadcast class embedding as extra channels.
        self.init_conv = nn.Conv2d(in_channels + cls_embed_dim, ch1, kernel_size=3, padding=1)

        # Encoder.
        self.block1 = BasicBlock(ch1, ch1, time_dim)
        self.down1 = DownSample(ch1, ch2)
        self.block2 = BasicBlock(ch2, ch2, time_dim)
        self.down2 = DownSample(ch2, ch4)
        self.block3 = BasicBlock(ch4, ch4, time_dim)

        # Bottleneck followed by cross-attention on the label.
        self.bottleneck = BasicBlock(ch4, ch4, time_dim)
        self.cross_attn = CrossAttnModule(feat_dim=ch4, embed_dim=attn_embed_dim)

        # Decoder; skip connections are added in forward().
        self.up1 = UpSample(ch4, ch2)
        self.block4 = BasicBlock(ch2, ch2, time_dim)
        self.up2 = UpSample(ch2, ch1)
        self.block5 = BasicBlock(ch1, ch1, time_dim)

        self.final_conv = nn.Conv2d(ch1, in_channels, kernel_size=3, padding=1)

    @staticmethod
    def _masked_class_embedding(table, y):
        """Embed labels `y` with `table`, returning all-zero rows wherever y < 0."""
        keep = (y >= 0).float().unsqueeze(1)
        # Negative labels are not valid indices; any in-range index works because the
        # resulting row is multiplied by 0.
        return table(y.clamp(min=0)) * keep

    def forward(self, x, t, y):
        """x: noisy images (B, in_channels, H, W); t: timesteps; y: labels (-1 = unconditional).

        Returns the predicted noise with the same shape as x.
        """
        _, _, height, width = x.shape

        # Input-channel conditioning, broadcast spatially and stacked onto the image.
        class_map = self._masked_class_embedding(self.input_embed, y)
        class_map = class_map[:, :, None, None].expand(-1, -1, height, width)
        h = self.init_conv(torch.cat([x, class_map], dim=1))

        temb = self.time_emb(t)

        # Encoder; keep copies of the pre-downsampling features for the skips.
        h = self.block1(h, temb)
        skip_a = h.clone()
        h = self.block2(self.down1(h), temb)
        skip_b = h.clone()
        h = self.block3(self.down2(h), temb)

        # Bottleneck, then cross-attention on the separately embedded label.
        h = self.bottleneck(h, temb)
        h = self.cross_attn(h, self._masked_class_embedding(self.attn_embed, y))

        # Decoder with additive skip connections.
        h = self.block4(self.up1(h) + skip_b, temb)
        h = self.block5(self.up2(h) + skip_a, temb)
        return self.final_conv(h)


class DiffusionModule(pl.LightningModule):
    """DDPM-style training wrapper around a class-conditional noise-prediction U-Net.

    Training uses epsilon-prediction MSE under a linear beta schedule. Labels are replaced
    by -1 ("unconditional") with probability `cond_drop`, so one network learns both the
    conditional and unconditional predictions needed for classifier-free guidance.
    Sampling uses deterministic DDIM with classifier-free guidance (see `sample`).
    """

    def __init__(self, unet, total_steps=1000, beta_min=1e-4, beta_max=0.02,
                 lr=2e-4, cond_drop=0.1):
        """
        unet: the combined conditioning U-Net, called as unet(x, t, y)
        total_steps: total diffusion timesteps
        beta_min, beta_max: endpoints of the linear beta schedule
        lr: Adam learning rate
        cond_drop: probability to drop conditioning during training
        """
        super().__init__()
        self.unet = unet
        self.total_steps = total_steps
        self.lr = lr
        self.cond_drop = cond_drop

        # Registered as buffers so they follow the module across devices.
        self.register_buffer('beta', torch.linspace(beta_min, beta_max, total_steps))
        self.register_buffer('alpha', 1 - self.beta)
        self.register_buffer('alpha_cumprod', torch.cumprod(self.alpha, dim=0))

    def forward(self, x, t, y):
        return self.unet(x, t, y)

    def q_sample(self, x0, t, noise):
        """Draw x_t ~ q(x_t | x_0) = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise."""
        sqrt_alpha = self.alpha_cumprod[t].sqrt().view(-1, 1, 1, 1)
        sqrt_beta = (1 - self.alpha_cumprod[t]).sqrt().view(-1, 1, 1, 1)
        return sqrt_alpha * x0 + sqrt_beta * noise

    def training_step(self, batch, batch_idx):
        x, y = batch
        B = x.size(0)
        t = torch.randint(0, self.total_steps, (B,), device=x.device)
        noise = torch.randn_like(x)
        x_noisy = self.q_sample(x, t, noise)
        # Drop condition with a given probability (label -1 = unconditional).
        cond_mask = torch.rand(B, device=x.device) < self.cond_drop
        y_mod = y.clone()
        y_mod[cond_mask] = -1
        noise_hat = self.unet(x_noisy, t, y_mod)
        loss = F.mse_loss(noise_hat, noise)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)

    @torch.no_grad()
    def sample(self, num_imgs=16, ddim_steps=50, target_class=0, guidance=5.0,
               img_res=28, channels=1):
        """Generate samples using deterministic DDIM (eta = 0) with classifier-free guidance.

        num_imgs: number of images to generate
        ddim_steps: number of denoising steps, 1..total_steps. Timesteps are spaced evenly
            from total_steps - 1 down to 0, and the last step lands on alpha_bar = 1, so the
            result is the clean x0 estimate.
        target_class: desired class label (a valid index for the U-Net's class embedding;
            0-9 for MNIST)
        guidance: weight w in eps = eps_uncond + w * (eps_cond - eps_uncond)
        img_res, channels: spatial size and channel count of the generated images
        Returns a (num_imgs, channels, img_res, img_res) tensor on the U-Net's device.
        Raises ValueError if ddim_steps is outside 1..total_steps.
        The U-Net's train/eval mode is restored on exit.
        """
        if not 1 <= ddim_steps <= self.total_steps:
            raise ValueError(f"ddim_steps must be in [1, {self.total_steps}], got {ddim_steps}")

        was_training = self.unet.training
        self.unet.eval()
        try:
            device = next(self.unet.parameters()).device
            x = torch.randn(num_imgs, channels, img_res, img_res, device=device)
            ts = torch.linspace(self.total_steps - 1, 0, steps=ddim_steps, device=device).round().long()

            # Conditional and unconditional predictions share one batched forward pass.
            labels = torch.cat([
                torch.full((num_imgs,), target_class, device=device, dtype=torch.long),
                torch.full((num_imgs,), -1, device=device, dtype=torch.long),
            ])
            # alpha_bar "after" the final step: 1 means no noise left, so x becomes x0_pred.
            alpha_final = self.alpha_cumprod.new_ones(())

            for i in range(len(ts)):
                t_curr = ts[i]
                t_batch = torch.full((2 * num_imgs,), int(t_curr), device=device, dtype=torch.long)
                eps_cond, eps_uncond = self.unet(torch.cat([x, x]), t_batch, labels).chunk(2)
                eps = eps_uncond + guidance * (eps_cond - eps_uncond)

                alpha_curr = self.alpha_cumprod[t_curr]
                x0_pred = (x - (1 - alpha_curr).sqrt() * eps) / alpha_curr.sqrt()

                alpha_next = self.alpha_cumprod[ts[i + 1]] if i + 1 < len(ts) else alpha_final
                x = alpha_next.sqrt() * x0_pred + (1 - alpha_next).sqrt() * eps

            return x
        finally:
            self.unet.train(was_training)


def create_dataloader(batch_sz=128):
    """Build a shuffled loader over the MNIST *training* split with pixels in [-1, 1].

    batch_sz: number of images per batch.
    The dataset is stored under ./data and downloaded there if missing.
    """
    to_signed_unit_range = transforms.Compose([
        transforms.ToTensor(),                  # uint8 [0, 255] -> float [0, 1]
        transforms.Normalize((0.5,), (0.5,)),   # (x - 0.5) / 0.5 -> [-1, 1]
    ])
    train_split = datasets.MNIST(root='./data', train=True, download=True,
                                 transform=to_signed_unit_range)
    return DataLoader(train_split, batch_size=batch_sz, shuffle=True)


# Kept local to this cell so the cell runs standalone.
from pytorch_lightning.loggers import WandbLogger

# Seed Python / NumPy / torch RNGs so weight init, noise draws and shuffling
# are reproducible across Restart & Run All.
pl.seed_everything(42)

# Instantiate the models.
unet_model = CombinedUNet(
    in_channels=1, base_ch=64, time_dim=128,
    num_classes=10, cls_embed_dim=32, attn_embed_dim=64
)
diffusion_model = DiffusionModule(
    unet_model, total_steps=1000, beta_min=1e-4, beta_max=0.02,
    lr=2e-4, cond_drop=0.1
)

# Prepare data and trainer.
data_loader = create_dataloader(batch_sz=128)
use_cuda = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=30,
    accelerator="gpu" if use_cuda else "cpu",
    # `devices=None` is rejected for the CPU accelerator by recent Lightning versions; "auto" is valid.
    devices=1 if use_cuda else "auto",
    # Report to the W&B project set up above (wandb.login + wandb_project_name);
    # the default Lightning logger does not send anything to W&B.
    logger=WandbLogger(project=wandb_project_name),
)

# Train the diffusion model.
trainer.fit(diffusion_model, data_loader)

# Sample images using DDIM with classifier-free guidance.
sampled_images = diffusion_model.sample(num_imgs=16, ddim_steps=50, target_class=7, guidance=5.0)
save_image(sampled_images, "combined_conditioning_ddim_samples.png", nrow=4, normalize=True)
print("Sampling complete. Check 'combined_conditioning_ddim_samples.png' for generated images.")


In [14]:
# Close the active W&B run so everything logged to it is uploaded before the session ends.
wandb.finish()

VBox(children=(Label(value='9.905 MB of 9.905 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇█
learning_rate_optimizer_0,▁▁▁▁▁▁▁▁▁▁
train_loss,█▂▅▃▂▂▅▃▃▃▃▂▂▂▃▁▁▃▂▃▂▂▁▁▂▂▁▁▁▁▂▂▂▁▁▁▂▂▂▁
trainer/global_step,▁▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇████
val_loss,█▃▄▃▃▁▂▂▁▁

0,1
epoch,9.0
learning_rate_optimizer_0,0.001
train_loss,0.03237
trainer/global_step,9379.0
val_loss,0.03204


In [15]:
# from google.colab import runtime
# runtime.unassign()