# Lab 2.6.1 Solution: Diffusion Theory

This solution notebook contains completed exercises from Lab 2.6.1.

---

## Exercise: Experiment with Different Images

The exercise asked you to visualize forward diffusion on different MNIST digits.

In [None]:
# Solution: Visualize different digits
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Load MNIST as tensors, then shift/scale each pixel with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Remember the dataset index of the first example seen for every label,
# stopping as soon as all ten digits have been found
digit_indices = {}
for position, sample in enumerate(train_dataset):
    digit_indices.setdefault(sample[1], position)
    if len(digit_indices) == 10:
        break

# One row per digit, one column per noise level
timesteps = [0, 250, 500, 750, 999]
fig, axes = plt.subplots(10, len(timesteps), figsize=(12, 24))

for digit in range(10):
    image = train_dataset[digit_indices[digit]][0].unsqueeze(0)
    # A single noise sample per digit is reused across all timesteps
    noise = torch.randn_like(image)

    for column, t in enumerate(timesteps):
        # Simplified linear mixing coefficient (a real pipeline would take
        # this from a noise scheduler)
        alpha = 1.0 - (t / 1000)
        signal_scale = alpha**0.5
        noise_scale = (1 - alpha)**0.5
        noisy = signal_scale * image + noise_scale * noise

        # Undo the normalization before display
        img = (noisy.squeeze().numpy() + 1) / 2

        panel = axes[digit, column]
        panel.imshow(img, cmap='gray', vmin=0, vmax=1)
        panel.axis('off')
        # Only the top row carries the timestep titles
        if digit == 0:
            panel.set_title(f't={t}')

plt.suptitle('Forward Diffusion on All MNIST Digits', fontsize=14, y=1.01)
plt.tight_layout()
plt.show()

## Challenge 1: Class-Conditional Generation

Add class conditioning to generate specific digits.

In [None]:
# Solution: Class-Conditional U-Net
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal embeddings for a batch of diffusion timesteps.

    Each timestep is multiplied by ``embedding_dim // 2`` geometrically spaced
    frequencies; the sines of those products fill the first half of the
    output and the cosines the second half. When ``embedding_dim`` is odd a
    trailing zero column is appended so the width is exactly ``embedding_dim``.

    Args:
        timesteps: Tensor of timesteps indexed as ``(B,)``.
        embedding_dim: Width of the embedding to produce.

    Returns:
        Float tensor of shape ``(B, embedding_dim)`` on the same device as
        ``timesteps``.
    """
    half_dim = embedding_dim // 2
    # Guard the denominator: half_dim == 1 would otherwise divide by zero.
    # For half_dim >= 2 this is identical to dividing by (half_dim - 1).
    scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -scale)
    args = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:
        # sin/cos halves only cover an even number of columns; zero-pad the
        # last one so downstream layers sized by embedding_dim still fit.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)
    return emb

class ConditionalResidualBlock(nn.Module):
    """Two-convolution residual block modulated by a conditioning vector.

    The conditioning vector is linearly projected to ``out_channels`` and
    added to the feature map between the two convolutions. The residual path
    is a 1x1 convolution when the channel count changes, otherwise identity.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block (GroupNorm uses 8 groups).
        emb_dim: Width of the conditioning vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.emb_proj = nn.Linear(emb_dim, out_channels)

        # Only project the shortcut when the channel counts differ
        self.skip = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, 1)
        )

    def forward(self, x, emb):
        """Apply the block to feature map ``x`` using conditioning ``emb``."""
        shortcut = self.skip(x)

        hidden = F.silu(self.norm1(self.conv1(x)))

        # Broadcast the projected embedding over the two spatial axes
        bias = self.emb_proj(emb).unsqueeze(-1).unsqueeze(-1)
        hidden = hidden + bias

        hidden = F.silu(self.norm2(self.conv2(hidden)))

        return hidden + shortcut

class ConditionalUNet(nn.Module):
    """
    Class-conditional U-Net.

    A three-level encoder/decoder built from ConditionalResidualBlock, where
    every block receives the sum of a timestep embedding and a learned class
    embedding so the class label can guide generation.
    """

    def __init__(self, in_ch=1, out_ch=1, base_ch=64, time_emb=128, num_classes=10):
        super().__init__()

        self.time_emb_dim = time_emb
        hidden = time_emb * 4
        ch1, ch2, ch4 = base_ch, base_ch * 2, base_ch * 4

        # MLP applied on top of the sinusoidal timestep embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb, hidden),
            nn.SiLU(),
            nn.Linear(hidden, time_emb),
        )

        # One learnable vector per class, same width as the time embedding
        self.class_emb = nn.Embedding(num_classes, time_emb)

        # Encoder: channel count doubles at each level
        self.enc1 = ConditionalResidualBlock(in_ch, ch1, time_emb)
        self.enc2 = ConditionalResidualBlock(ch1, ch2, time_emb)
        self.enc3 = ConditionalResidualBlock(ch2, ch4, time_emb)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConditionalResidualBlock(ch4, ch4, time_emb)

        # Decoder: each block sees upsampled features concatenated with the
        # matching encoder output, hence the doubled input channels
        self.up3 = nn.ConvTranspose2d(ch4, ch4, 2, stride=2)
        self.dec3 = ConditionalResidualBlock(ch4 * 2, ch2, time_emb)
        self.up2 = nn.ConvTranspose2d(ch2, ch2, 2, stride=2)
        self.dec2 = ConditionalResidualBlock(ch2 * 2, ch1, time_emb)
        self.up1 = nn.ConvTranspose2d(ch1, ch1, 2, stride=2)
        self.dec1 = ConditionalResidualBlock(ch1 * 2, ch1, time_emb)

        self.out = nn.Conv2d(ch1, out_ch, 1)

    def forward(self, x, t, c):
        """
        Args:
            x: Noisy image (B, 1, 28, 28)
            t: Timesteps (B,)
            c: Class labels (B,)
        """
        # Conditioning vector: processed time embedding plus class embedding
        emb = self.time_mlp(get_timestep_embedding(t, self.time_emb_dim)) + self.class_emb(c)

        # Encoder: keep each level's output for the skip connections and
        # downsample before moving to the next level
        skips = []
        h = x
        for encoder in (self.enc1, self.enc2, self.enc3):
            h = encoder(h, emb)
            skips.append(h)
            h = self.pool(h)

        h = self.bottleneck(h, emb)

        # Decoder: walk the levels deepest-first, pairing each with its skip
        stages = (
            (self.up3, self.dec3),
            (self.up2, self.dec2),
            (self.up1, self.dec1),
        )
        for (upsample, decoder), skip in zip(stages, reversed(skips)):
            h = upsample(h)
            # Resize to the skip's spatial size so the concatenation is valid
            h = F.interpolate(h, size=skip.shape[2:])
            h = decoder(torch.cat([h, skip], dim=1), emb)

        return self.out(h)

# Test: run the conditional U-Net once on a random batch as a smoke check
model = ConditionalUNet()
x = torch.randn(4, 1, 28, 28)
t = torch.randint(0, 1000, (4,))
c = torch.randint(0, 10, (4,))  # Class labels
out = model(x, t, c)
# Verify the result before reporting success: with the default in_ch == out_ch
# the prediction must have the same shape as the input batch
assert out.shape == x.shape, f"Expected output shape {tuple(x.shape)}, got {tuple(out.shape)}"
print(f"Input: {x.shape}")
print(f"Output: {out.shape}")
print(f"Class labels: {c.tolist()}")
print("\n✅ Class-conditional U-Net works!")

## Challenge 2: Fashion-MNIST

Adapt the model for Fashion-MNIST.

In [None]:
# Solution: Load Fashion-MNIST
from torchvision import datasets

# Fashion-MNIST training split, preprocessed with the `transform` pipeline
# defined earlier in the notebook
fashion_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

# Human-readable name for each label index
fashion_classes = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]

# Preview ten examples taken at every 1000th dataset index
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for slot, panel in enumerate(axes.flatten()):
    sample, class_id = fashion_dataset[slot * 1000]
    # Rescale for display
    pixels = sample.squeeze().numpy() * 0.5 + 0.5
    panel.imshow(pixels, cmap='gray')
    panel.set_title(fashion_classes[class_id])
    panel.axis('off')
plt.suptitle('Fashion-MNIST Samples')
plt.tight_layout()
plt.show()

# Short checklist for reusing the training loop on this dataset
for message in (
    "\nTo train on Fashion-MNIST:",
    "1. Replace train_dataset with fashion_dataset",
    "2. Train with same hyperparameters",
    "3. May need more epochs (20-50) for good quality",
):
    print(message)

---

## Key Takeaways

1. **Forward diffusion** gradually adds Gaussian noise to images
2. **Cosine schedule** works better than linear by preserving signal longer
3. **U-Net** architecture uses skip connections for denoising
4. **Timestep embeddings** help the model know the noise level
5. **Class conditioning** enables controllable generation