# 🚀 MNIST Diffusion Model Training (Local)

This notebook is adapted for local execution. It covers:
- **Forward Diffusion Process** (Adding noise)
- **U-Net Architecture** (Predicting noise)
- **Training Loop** (Minimizing MSE loss)
- **Reverse Diffusion** (Generating new digits)

---

In [None]:
import math
import os
import time
import csv
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# ------------------------------------------------
# Hyperparameters
# ------------------------------------------------
T = 500                 # Number of diffusion timesteps
BATCH_SIZE = 128        # Images per optimisation step
LR = 2e-4               # Optimizer learning rate
EPOCHS = 10             # Suggested 15, lowered for CPU speed
USE_CONDITIONAL = True  # Condition the model on the digit label

# ------------------------------------------------
# Data Loading
# ------------------------------------------------
# Convert images to tensors, then shift/scale with mean 0.5 and std 0.5,
# i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split; downloaded into ./data on first use.
train_ds = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)

print(f"Dataset loaded. {len(train_ds)} samples.")

## 1) Noise Schedule (Beta, Alpha, Alpha_bar)

Diffusion works by adding noise according to a schedule. 
- **Beta (β)**: Amount of noise added at each step.
- **Alpha (α)**: 1 - β (amount of signal kept at each step).
- **Alpha_bar (ᾱ)**: Cumulative product of alphas (total signal kept up to step t).

In [None]:
def make_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Build a linear noise (beta) schedule.

    Args:
        T: Number of diffusion steps, i.e. the length of the schedule.
        beta_start: Value of the first entry.
        beta_end: Value of the last entry.

    Returns:
        1-D tensor of ``T`` values evenly spaced from ``beta_start`` to
        ``beta_end`` (both endpoints included).
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=T)
    return schedule

# Precompute the per-step schedule tensors once, on the training device:
#   betas      - noise added at each step
#   alphas     - 1 - beta, signal kept at each step
#   alpha_bars - running product of alphas, signal kept up to each step
betas = make_beta_schedule(T).to(device)
alphas = 1.0 - betas
alpha_bars = alphas.cumprod(dim=0).to(device)

def q_sample(x0, t, eps):
    """Forward diffusion: noise clean images ``x0`` to timestep ``t``.

    Computes ``x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps``.

    Args:
        x0: Clean images; the per-sample coefficient is reshaped to
            ``(-1, 1, 1, 1)``, so a 4-D batch is expected.
        t: Integer tensor of timestep indices into ``alpha_bars``.
        eps: Noise tensor broadcastable against ``x0``.

    Returns:
        The noised images ``x_t``.
    """
    # One cumulative-alpha value per sample, broadcast over the image dims.
    keep = alpha_bars[t].view(-1, 1, 1, 1)
    scaled_image = keep.sqrt() * x0
    scaled_noise = (1.0 - keep).sqrt() * eps
    return scaled_image + scaled_noise

## 2) Time Embedding

Since the U-Net needs to know which step `t` it is cleaning, we encode `t` into a vector using a sinusoidal function (similar to Transformers).

In [None]:
def sinusoidal_time_embedding(timesteps, dim):
    """Encode diffusion timesteps as sinusoidal feature vectors.

    Args:
        timesteps: 1-D tensor of timestep indices, one per sample.
        dim: Width of the returned embedding.

    Returns:
        Float tensor of shape ``(len(timesteps), dim)``. The first
        ``dim // 2`` columns hold sines and the next ``dim // 2`` hold
        cosines of ``timestep * freq``, with frequencies decaying
        geometrically from 1 towards 1/10000. When ``dim`` is odd, one
        trailing zero column is appended so the width is exactly ``dim``.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(0, half, device=timesteps.device).float() / half
    )
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)

    # sin/cos halves only cover 2 * (dim // 2) columns; pad odd widths so
    # downstream layers sized by `dim` receive the shape they expect.
    if dim % 2 == 1:
        emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
    return emb

## 3) U-Net Components

- **ResBlock**: Convolutional block that injects time information.
- **Down**: Downsamples the image (28x28 -> 14x14 -> 7x7).
- **Up**: Upsamples the image using Transposed Convolutions.

In [None]:
class ResBlock(nn.Module):
    """Residual block of two 3x3 convolutions with a time-embedding bias.

    The time embedding is projected to ``out_ch`` values and added to every
    spatial position after the first convolution. The input is added back at
    the end, passing through a 1x1 convolution when the channel count changes.
    """

    def __init__(self, in_ch, out_ch, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_ch)
        # The residual path only needs a projection when channels differ.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        """Apply the block to feature map ``x`` given embedding ``t_emb``."""
        residual = self.skip(x)

        hidden = F.silu(self.conv1(x))
        channels = hidden.size(1)
        # Reshape to (batch, channels, 1, 1) so the bias broadcasts over H and W.
        time_bias = self.time_proj(t_emb).view(-1, channels, 1, 1)
        hidden = hidden + time_bias
        hidden = F.silu(self.conv2(hidden))

        return hidden + residual

class Down(nn.Module):
    """Downsampling layer: a 4x4 convolution with stride 2 and padding 1.

    The channel count is preserved.
    """

    def __init__(self, ch):
        super().__init__()
        self.down = nn.Conv2d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Return the strided-convolution output for feature map ``x``."""
        reduced = self.down(x)
        return reduced

class Up(nn.Module):
    """Upsampling layer: a 4x4 transposed convolution, stride 2, padding 1.

    The channel count is preserved.
    """

    def __init__(self, ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels=ch,
            out_channels=ch,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Return the transposed-convolution output for feature map ``x``."""
        enlarged = self.up(x)
        return enlarged

## 4) U-Net Model Architecture

The model architecture that predicts noise from the noisy `x_t`.

In [None]:
class UNet(nn.Module):
    """Small U-Net that predicts the noise contained in a noisy image ``x_t``.

    Layout: input conv -> two downsampling encoder stages -> three-block
    bottleneck -> two upsampling decoder stages with skip connections ->
    GroupNorm + SiLU + output conv. Every ResBlock is conditioned on the
    timestep embedding, optionally summed with a learned class embedding.

    Args:
        in_ch: Number of image channels. The output has the same number of
            channels so the predicted noise matches the input shape.
        base: Base channel width. Must be divisible by 8 because the output
            GroupNorm uses 8 groups.
        time_dim: Width of the timestep (and class) embedding.
        num_classes: If given, the model is class-conditional and ``forward``
            requires labels ``y``; if ``None`` the model is unconditional.
    """

    def __init__(self, in_ch=1, base=128, time_dim=128, num_classes=None):
        super().__init__()
        self.time_dim = time_dim
        self.num_classes = num_classes

        # Class embedding shares the width of the time embedding so the two
        # can simply be added together.
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim)
        )

        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)

        # Encoder
        self.rb1 = ResBlock(base, base, time_dim)
        self.down1 = Down(base)
        self.rb2 = ResBlock(base, base * 2, time_dim)
        self.down2 = Down(base * 2)
        self.rb3 = ResBlock(base * 2, base * 2, time_dim)

        # Bottleneck
        self.mid1 = ResBlock(base * 2, base * 4, time_dim)
        self.mid2 = ResBlock(base * 4, base * 4, time_dim)
        self.mid3 = ResBlock(base * 4, base * 2, time_dim)

        # Decoder (input widths include the concatenated skip connection:
        # 2*base + 2*base for rb4, 2*base + base for rb5)
        self.up1 = Up(base * 2)
        self.rb4 = ResBlock(base * 4, base * 2, time_dim)
        self.up2 = Up(base * 2)
        self.rb5 = ResBlock(base * 3, base, time_dim)

        self.out_norm = nn.GroupNorm(8, base)
        # Emit as many channels as the input has, so the prediction can be
        # compared against noise of the same shape as the input.
        self.out_conv = nn.Conv2d(base, in_ch, 3, padding=1)

    def forward(self, x, t, y=None):
        """Predict the noise in ``x`` at timesteps ``t``.

        Args:
            x: Noisy image batch with ``in_ch`` channels. Height and width
                must survive two halvings and two doublings unchanged
                (e.g. 28x28) for the skip concatenations to line up.
            t: 1-D tensor of timestep indices, one per sample.
            y: Class labels, one per sample. Required when the model was
                built with ``num_classes``; ignored otherwise.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            TypeError: If the model is class-conditional and ``y`` is None.
        """
        t_emb = sinusoidal_time_embedding(t, self.time_dim)
        t_emb = self.time_mlp(t_emb)

        if self.num_classes is not None:
            if y is None:
                raise TypeError(
                    "UNet was built with num_classes set; forward() requires class labels `y`"
                )
            t_emb = t_emb + self.label_emb(y)

        # Encoder: keep x1 and x2 for the decoder's skip connections.
        x1 = self.rb1(self.in_conv(x), t_emb)
        x2 = self.rb2(self.down1(x1), t_emb)
        x3 = self.rb3(self.down2(x2), t_emb)

        h = self.mid1(x3, t_emb)
        h = self.mid2(h, t_emb)
        h = self.mid3(h, t_emb)

        # Decoder: upsample, then fuse with the matching encoder feature map.
        h = self.up1(h)
        h = self.rb4(torch.cat([h, x2], dim=1), t_emb)
        h = self.up2(h)
        h = self.rb5(torch.cat([h, x1], dim=1), t_emb)

        return self.out_conv(F.silu(self.out_norm(h)))

## 5) Training Loop

The model learns to predict the noise $\epsilon$ added to the image $x_0$.

**Note:** Training on a CPU can be slow. Reduced epochs are set by default.

In [None]:
# Build the noise-prediction network; it is class-conditional (10 digits)
# when USE_CONDITIONAL is set.
model = UNet(in_ch=1, base=128, time_dim=128, num_classes=(10 if USE_CONDITIONAL else None)).to(device)
opt = optim.AdamW(model.parameters(), lr=LR)
loss_fn = nn.MSELoss()

print("🚀 Starting Training Loop...")
model.train()
loss_hist = []  # per-step loss over the whole run

for ep in range(1, EPOCHS + 1):
    epoch_losses = []
    t0 = time.time()
    for i, (x0, y) in enumerate(train_loader, start=1):
        x0, y = x0.to(device), y.to(device)

        # Draw one random timestep and one noise sample per image, then
        # build the noised input the model has to denoise.
        t = torch.randint(0, T, (x0.size(0),), device=device)
        eps = torch.randn_like(x0)
        xt = q_sample(x0, t, eps)

        if USE_CONDITIONAL:
            eps_pred = model(xt, t, y=y)
        else:
            eps_pred = model(xt, t)

        # Objective: mean squared error between predicted and true noise.
        loss = loss_fn(eps_pred, eps)
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Read the scalar once: each .item() call copies the value to the host.
        loss_val = loss.item()
        loss_hist.append(loss_val)
        epoch_losses.append(loss_val)

        if i % 100 == 0:
            print(f"Epoch [{ep}/{EPOCHS}] | Step {i} | Loss: {loss_val:.4f}")

    avg_loss = sum(epoch_losses) / len(epoch_losses)
    print(f"✅ Epoch {ep} done! Average Loss: {avg_loss:.4f} | Time: {time.time()-t0:.2f}s")

# Save the model locally
torch.save(model.state_dict(), "diffusion_unet_mnist.pth")
print("💾 Model saved as diffusion_unet_mnist.pth")

## 6) Image Generation (Reverse Diffusion)

Now we use the trained model to generate images from noise.

In [None]:
@torch.no_grad()
def p_sample(x, t, y=None):
    """Run one reverse-diffusion step, taking ``x_t`` to ``x_{t-1}``.

    Uses the model's noise prediction to form the posterior mean
    ``(x - beta_t / sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_t)`` and
    adds ``sqrt(beta_t)``-scaled Gaussian noise for every sample whose
    timestep is greater than zero.

    Args:
        x: Current noisy batch ``x_t``.
        t: 1-D integer tensor of timestep indices, one per sample. Entries
            may differ within a batch; each sample uses its own coefficients.
        y: Optional class labels forwarded to the model.

    Returns:
        Tensor with the same shape as ``x``.
    """
    eps_pred = model(x, t, y=y)

    # Per-sample schedule values, reshaped to broadcast over the non-batch
    # dimensions of x (the same indexing scheme q_sample uses).
    bshape = (-1,) + (1,) * (x.dim() - 1)
    beta_t = betas[t].view(bshape)
    alpha_t = alphas[t].view(bshape)
    alpha_bar_t = alpha_bars[t].view(bshape)

    mean = (1.0 / torch.sqrt(alpha_t)) * (x - (beta_t / torch.sqrt(1.0 - alpha_bar_t)) * eps_pred)

    # No noise is added at the final step (t == 0). Skip drawing it entirely
    # when no sample needs it, so random-number consumption matches the
    # uniform-timestep case.
    needs_noise = t > 0
    if not needs_noise.any():
        return mean

    z = torch.randn_like(x)
    noise_mask = needs_noise.to(x.dtype).view(bshape)
    return mean + noise_mask * torch.sqrt(beta_t) * z

@torch.no_grad()
def sample(n=16, target_digit=None):
    """Generate ``n`` images by running the full reverse diffusion chain.

    Starts from Gaussian noise of shape ``(n, 1, 28, 28)`` and applies
    ``p_sample`` for timesteps ``T - 1`` down to ``0``. Puts the global
    model into eval mode.

    Args:
        n: Number of images to generate.
        target_digit: Class label to condition every image on. If ``None``
            and the model is class-conditional (has a non-None
            ``num_classes``), a random label is drawn for each image; if the
            model is unconditional no labels are used.

    Returns:
        Tensor of shape ``(n, 1, 28, 28)`` on ``device``.
    """
    model.eval()
    x = torch.randn(n, 1, 28, 28, device=device)

    if target_digit is not None:
        y = torch.full((n,), target_digit, device=device, dtype=torch.long)
    else:
        # A class-conditional model cannot run without labels, so fall back
        # to random classes rather than passing None into its embedding.
        num_classes = getattr(model, "num_classes", None)
        if num_classes is not None:
            y = torch.randint(0, num_classes, (n,), device=device)
        else:
            y = None

    for t_inv in reversed(range(T)):
        t_batch = torch.full((n,), t_inv, device=device, dtype=torch.long)
        x = p_sample(x, t_batch, y=y)
    return x

In [None]:
# Sample 16 images of a single digit class and show them in a 4x4 grid.
target = 3  # Try changing this digit
gen_imgs = sample(n=16, target_digit=target)

plt.figure(figsize=(6, 6))
# Take the single channel of every image and move the batch to the CPU once.
for i, img in enumerate(gen_imgs[:, 0].cpu()):
    plt.subplot(4, 4, i + 1)
    plt.imshow(img, cmap="gray")
    plt.axis("off")
plt.suptitle(f"Generated Digits: {target}")
plt.show()