# PyTorch Tutorial: Generative Diffusion Models

Diffusion models (such as Stable Diffusion and DALL-E 2) generate images by learning to reverse a gradual noising process. They start from pure noise and refine it step by step into an image.

In this notebook, we will implement the core mathematics of **DDPM (Denoising Diffusion Probabilistic Models)**.

## Learning Objectives
- Understand the Forward Diffusion Process (Adding Noise)
- Understand the Reverse Diffusion Process (Denoising)
- Implement the Noise Schedule
- Build a simplified U-Net for noise prediction


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

torch.manual_seed(42)

## 1. The Forward Process (Adding Noise)

We take an image $x_0$ and add Gaussian noise over $T$ steps until it becomes pure noise $x_T$.

Formula:
$$ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I) $$

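Defining $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, we can jump directly to any step $t$ without iterating through the earlier ones:
$$ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon $$
where $\epsilon \sim \mathcal{N}(0, I)$.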
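
One way to check the closed form is to track, step by step, the coefficient on $x_0$ and the accumulated noise variance, then compare them with $\sqrt{\bar{\alpha}_t}$ and $1 - \bar{\alpha}_t$, where $\bar{\alpha}_t$ is the cumulative product of $1 - \beta_s$. The sketch below does this for an assumed linear schedule. The `demo_` names are illustrative and are not used by the rest of the notebook.

```python
import torch

# Assumed schedule for illustration: linear betas over 200 steps (float64 for a tight comparison).
demo_betas = torch.linspace(1e-4, 0.02, 200, dtype=torch.float64)
demo_alphas = 1.0 - demo_betas
demo_alphas_cumprod = torch.cumprod(demo_alphas, dim=0)

# Apply q(x_t | x_{t-1}) one step at a time, tracking only the two scalars that matter.
signal_coeff = torch.tensor(1.0, dtype=torch.float64)  # coefficient on x_0
noise_var = torch.tensor(0.0, dtype=torch.float64)     # variance of the accumulated noise
signal_history, var_history = [], []
for alpha_t, beta_t in zip(demo_alphas, demo_betas):
    signal_coeff = torch.sqrt(alpha_t) * signal_coeff  # the mean is scaled by sqrt(1 - beta_t)
    noise_var = alpha_t * noise_var + beta_t           # old noise shrinks, fresh noise is added
    signal_history.append(signal_coeff)
    var_history.append(noise_var)

demo_signal = torch.stack(signal_history)
demo_var = torch.stack(var_history)

print(torch.allclose(demo_signal, torch.sqrt(demo_alphas_cumprod)))
print(torch.allclose(demo_var, 1.0 - demo_alphas_cumprod))
```

Both comparisons should agree, because $1 - \bar{\alpha}_t = \alpha_t (1 - \bar{\alpha}_{t-1}) + \beta_t$.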

In [None]:
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

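# Define schedule
# Note: beta_end=0.02 was chosen for T=1000. With T=200 (kept small for speed),
# x_T still retains some signal rather than being pure noise.
T = 200
betas = linear_beta_schedule(timesteps=T)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha-bar_t: cumulative product of alphas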

def get_index_from_list(vals, t, x_shape):
    """Helper to get value at index t and reshape to match x"""
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """Takes an image and a timestep t and returns the noisy image at t"""
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(torch.sqrt(alphas_cumprod), t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(torch.sqrt(1. - alphas_cumprod), t, x_0.shape)
    
    # Scaled signal (mean) + noise scaled by the standard deviation
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise

# Visualize
image = torch.zeros((1, 3, 64, 64)) # Dummy black image
image[:, :, 16:48, 16:48] = 1.0 # White square in middle

plt.figure(figsize=(10, 3))
for idx, t_val in enumerate([0, 50, 100, 199]):
    t = torch.tensor([t_val])
    noisy_image, _ = forward_diffusion_sample(image, t)
    
    plt.subplot(1, 4, idx+1)
    plt.imshow(noisy_image[0].permute(1, 2, 0).clamp(0, 1))
    plt.title(f"t={t_val}")
    plt.axis('off')
plt.show()
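
How close $x_T$ gets to pure noise depends on both the schedule and the number of steps. As a rough check, the sketch below computes $\bar{\alpha}_T$, the fraction of signal variance left at the last step, for the linear schedule at two step counts. `final_signal_fraction` is a hypothetical helper written for this illustration.

```python
import torch

def final_signal_fraction(num_steps, beta_start=1e-4, beta_end=0.02):
    """Return alpha-bar at the last step of a linear beta schedule (hypothetical helper)."""
    demo_betas = torch.linspace(beta_start, beta_end, num_steps)
    return torch.cumprod(1.0 - demo_betas, dim=0)[-1].item()

for num_steps in (200, 1000):
    a_bar = final_signal_fraction(num_steps)
    # sqrt(alpha-bar_T) is the factor multiplying x_0 in the closed-form sample
    print(f"T={num_steps}: alpha_bar_T={a_bar:.5f}, signal scale={a_bar ** 0.5:.3f}")
```

With the 200 steps used here, a noticeable fraction of the signal survives at the last step. With 1000 steps, the setting these beta values were originally chosen for, almost none remains.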

## 2. The Reverse Process (The Model)

We need a neural network that takes a noisy image $x_t$ and the timestep $t$, and predicts the noise $\epsilon$ that was added.

We typically use a **U-Net**.

In [None]:
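class SimpleUNet(nn.Module):
    """A very simplified U-Net-style network for demonstration.

    Note: every layer keeps the input resolution and there are no skip
    connections, so this only mimics the encoder/decoder layout of a real U-Net.
    """
    def __init__(self):
        super().__init__()
        # "Down" path: widen channels (stride 1, so no actual downsampling)
        self.down1 = nn.Conv2d(3, 64, 3, padding=1)
        self.down2 = nn.Conv2d(64, 128, 3, padding=1)
        
        # Time embedding (simplified): a linear map of the raw timestep
        self.time_mlp = nn.Linear(1, 128)
        
        # "Up" path: narrow channels back to RGB (stride 1, so no actual upsampling)
        self.up1 = nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.up2 = nn.ConvTranspose2d(64, 3, 3, padding=1)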

    def forward(self, x, t):
        # Embed time
        t = t.float().view(-1, 1)
        t_emb = self.time_mlp(t).view(-1, 128, 1, 1)
        
        # Down
        x1 = F.relu(self.down1(x))
        x2 = F.relu(self.down2(x1))
        
        # Add time info
        x2 = x2 + t_emb
        
        # Up
        x = F.relu(self.up1(x2))
        x = self.up2(x)
        return x

model = SimpleUNet()
print("Model created!")
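
The model above feeds the raw integer timestep through a single linear layer. Full-size implementations commonly use a sinusoidal embedding instead, in the style of Transformer position encodings, so that nearby timesteps get similar but distinguishable vectors. A minimal sketch, assuming an even embedding size; the function name and its defaults are illustrative and not part of this notebook's model:

```python
import math
import torch

def sinusoidal_time_embedding(t, dim=128, max_period=10000.0):
    """Map integer timesteps of shape (B,) to embeddings of shape (B, dim).

    Hypothetical helper; `dim` is assumed to be even.
    """
    half = dim // 2
    # Geometrically spaced frequencies, starting at 1 and decaying toward 1/max_period
    exponents = torch.arange(half, dtype=torch.float32, device=t.device) / half
    freqs = torch.exp(-math.log(max_period) * exponents)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)            # (B, half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # (B, dim)

demo_t = torch.tensor([0, 50, 199])
demo_emb = sinusoidal_time_embedding(demo_t)
print(demo_emb.shape)
```

The output could replace the input to a time MLP such as `nn.Linear(128, 128)`. All values stay in [-1, 1] regardless of how large `T` is, unlike the raw timestep.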

## 3. Training Loop

1. Sample a random image $x_0$.
2. Sample a random timestep $t$.
3. Add noise to get $x_t$.
4. Model predicts the noise: $\hat{\epsilon} = \text{Model}(x_t, t)$.
5. Loss is MSE between real noise $\epsilon$ and predicted noise $\hat{\epsilon}$.

In [None]:
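# Dummy Training Step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 1. Get data
x_0 = torch.randn(4, 3, 64, 64) # Random stand-in for a batch of 4 images
t = torch.randint(0, T, (4,))   # One random timestep per image

# 2. Forward diffusion
x_t, noise = forward_diffusion_sample(x_0, t)

# 3. Predict noise
noise_pred = model(x_t, t)

# 4. Loss (prediction first, target second)
loss = F.mse_loss(noise_pred, noise)

# 5. Backpropagate and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Loss: {loss.item():.4f}")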

## 4. Sampling (Generation)

To generate an image:
1. Start with pure noise $x_T$.
2. Loop backwards over the timesteps, from the last one down to the first.
3. At each step, remove a bit of noise using the model's prediction.

Each step subtracts the scaled predicted noise and then, except at the final step, adds a small amount of fresh random noise back. The next section implements this loop.

## 5. Full DDPM Sampling Algorithm

Let's implement the complete sampling (generation) loop.
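
For reference, the loop below is meant to implement the standard DDPM reverse step, written here with $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t$ as the cumulative product of the $\alpha$ values. Taking $\sigma_t^2$ to be the posterior variance is an assumption that matches the code; it is one of the two choices discussed in the DDPM paper.

$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I) $$

$$ \sigma_t^2 = \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \, \beta_t, \qquad \bar{\alpha}_0 := 1 $$

At the final step no noise is added, so the returned image is just the mean term.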

In [None]:
@torch.no_grad()
def ddpm_sample(model, image_shape, T, betas, alphas_cumprod, device="cpu"):
    """
    DDPM Sampling: Start from noise, iteratively denoise.
    
    Algorithm:
    1. Sample x_T ~ N(0, I)
    2. For t = T-1, T-2, ..., 0:
       a. Predict noise: ε_θ(x_t, t)
       b. Compute x_{t-1} using the reverse process formula
    3. Return x_0
    """
    batch_size = image_shape[0]
    
    # Precompute coefficients
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)
    alphas = 1 - betas
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    
    # Posterior variance: how much noise to add back.
    # alpha-bar_{t-1} is alpha-bar shifted right by one step, with value 1 before the
    # first step, so every tensor here has length T. (Slicing with [:-1] and [1:] gives
    # length T-1, which does not broadcast against betas.)
    alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
    posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
    
    # Start from pure noise
    x_t = torch.randn(image_shape, device=device)
    
    # Iteratively denoise
    for t in reversed(range(T)):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        
        # Predict noise
        predicted_noise = model(x_t, t_batch)
        
        # Compute x_{t-1}
        # Formula: x_{t-1} = 1/sqrt(α_t) * (x_t - β_t/sqrt(1-α̅_t) * ε_θ) + σ_t * z
        beta_t = betas[t]
        
        # Mean of p(x_{t-1} | x_t)
        model_mean = sqrt_recip_alphas[t] * (
            x_t - beta_t / sqrt_one_minus_alphas_cumprod[t] * predicted_noise
        )
        
        if t > 0:
            # Add noise (except for final step)
            noise = torch.randn_like(x_t)
            x_t = model_mean + torch.sqrt(posterior_variance[t]) * noise
        else:
            x_t = model_mean
    
    return x_t

# Demo sampling (with untrained model - will be noise, but shows the algorithm)
print("Sampling from untrained model (will be random, but shows algorithm)...")
sampled = ddpm_sample(model, (4, 3, 64, 64), T, betas, alphas_cumprod)
print(f"Sampled images shape: {sampled.shape}")

# Visualize
plt.figure(figsize=(12, 3))
for i in range(4):
    plt.subplot(1, 4, i+1)
    plt.imshow(sampled[i].permute(1, 2, 0).clamp(0, 1).numpy())
    plt.axis('off')
    plt.title(f'Sample {i+1}')
plt.suptitle('DDPM Samples (untrained model = noise)')
plt.show()

## 6. DDIM: Faster Sampling

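DDPM typically samples with as many steps as it was trained on, often around 1000. **DDIM (Denoising Diffusion Implicit Models)** reaches comparable quality in roughly 20-50 steps.

Key insight: in its default form (no noise injected during sampling), DDIM is a **deterministic** sampler, so the same starting noise always produces the same image.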

In [None]:
@torch.no_grad()
def ddim_sample(model, image_shape, T, alphas_cumprod, num_inference_steps=50, device="cpu"):
    """
    DDIM Sampling: Deterministic, fewer steps needed.
    
    Key difference from DDPM:
    - Uses a subset of timesteps (e.g., every 20th step)
    - No stochastic noise during sampling
    - Same trained model, different inference algorithm
    """
    batch_size = image_shape[0]
    
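    # Select an evenly spaced subset of timesteps, ordered from noisiest to cleanest
    step_size = max(T // num_inference_steps, 1)  # avoid a zero step when num_inference_steps > T
    timesteps = list(range(0, T, step_size))
    timesteps = list(reversed(timesteps))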
    
    # Start from pure noise
    x_t = torch.randn(image_shape, device=device)
    
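    for i, t in enumerate(timesteps):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        
        # Predict noise
        predicted_noise = model(x_t, t_batch)
        
        # Get alpha-bar for the current timestep and for the next (less noisy) one.
        # After the last timestep the target is the clean image, i.e. alpha-bar = 1.
        alpha_t = alphas_cumprod[t]
        if i + 1 < len(timesteps):
            alpha_t_prev = alphas_cumprod[timesteps[i + 1]]
        else:
            alpha_t_prev = torch.ones_like(alpha_t)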
        
        # Predict x_0 from x_t
        pred_x0 = (x_t - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
        
        # Direction pointing to x_t
        direction = torch.sqrt(1 - alpha_t_prev) * predicted_noise
        
        # DDIM update (deterministic!)
        x_t = torch.sqrt(alpha_t_prev) * pred_x0 + direction
    
    return x_t

print("DDIM uses fewer steps:")
print(f"  DDPM: {T} steps required")
print(f"  DDIM: 50 steps (or fewer) for similar quality")
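
To see why the DDIM update is consistent with the forward process, it helps to run one step with a "perfect" noise prediction. In the sketch below the true noise stands in for the model, and the two alpha-bar values are arbitrary assumptions. The predicted $x_0$ should then match the original, and the updated sample should equal the same image noised directly to the lower noise level.

```python
import torch

gen = torch.Generator().manual_seed(0)        # local generator, leaves the global seed untouched
demo_alpha_bar = torch.tensor(0.3)            # assumed alpha-bar at the current step
demo_alpha_bar_prev = torch.tensor(0.6)       # assumed alpha-bar at the next, less noisy step
demo_x0 = torch.rand(1, 3, 8, 8, generator=gen)
demo_eps = torch.randn(demo_x0.shape, generator=gen)

# Forward process: noise x_0 to the current level
demo_xt = torch.sqrt(demo_alpha_bar) * demo_x0 + torch.sqrt(1 - demo_alpha_bar) * demo_eps

# One DDIM step, using the true noise in place of the model's prediction
demo_pred_x0 = (demo_xt - torch.sqrt(1 - demo_alpha_bar) * demo_eps) / torch.sqrt(demo_alpha_bar)
demo_x_prev = (torch.sqrt(demo_alpha_bar_prev) * demo_pred_x0
               + torch.sqrt(1 - demo_alpha_bar_prev) * demo_eps)

# Reference: the same x_0 and noise, noised directly to the less noisy level
demo_expected = (torch.sqrt(demo_alpha_bar_prev) * demo_x0
                 + torch.sqrt(1 - demo_alpha_bar_prev) * demo_eps)

print(torch.allclose(demo_pred_x0, demo_x0, atol=1e-5))
print(torch.allclose(demo_x_prev, demo_expected, atol=1e-5))
```

With a trained model the prediction is only approximate, so the match is not exact, but the step contains no random draw, which is what makes the sampler deterministic.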

## 7. Classifier-Free Guidance (CFG)

CFG is a major reason text-to-image models such as Stable Diffusion follow their prompts so closely. It allows **trading off diversity for quality**.

The idea: Amplify the difference between conditional and unconditional predictions.

In [None]:
class ConditionalUNet(nn.Module):
    """U-Net that can accept a text/class condition."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.down1 = nn.Conv2d(3, 64, 3, padding=1)
        self.down2 = nn.Conv2d(64, 128, 3, padding=1)
        
        # Conditioning: embed class into same dimension
        self.class_emb = nn.Embedding(num_classes + 1, 128)  # +1 for "no class" (unconditional)
        self.time_mlp = nn.Linear(1, 128)
        
        self.up1 = nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.up2 = nn.ConvTranspose2d(64, 3, 3, padding=1)
    
    def forward(self, x, t, class_label=None):
        # Time embedding
        t = t.float().view(-1, 1)
        t_emb = self.time_mlp(t).view(-1, 128, 1, 1)
        
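        # Class embedding (the last embedding index is the "null" class for unconditional)
        if class_label is None:
            null_index = self.class_emb.num_embeddings - 1  # avoids hard-coding 10
            class_label = torch.full((x.shape[0],), null_index, dtype=torch.long, device=x.device)  # Null class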
        c_emb = self.class_emb(class_label).view(-1, 128, 1, 1)
        
        # Forward pass with conditioning
        x1 = F.relu(self.down1(x))
        x2 = F.relu(self.down2(x1))
        x2 = x2 + t_emb + c_emb  # Add both time and class info
        x = F.relu(self.up1(x2))
        x = self.up2(x)
        return x

def cfg_guided_noise(model, x_t, t, class_label, guidance_scale=7.5):
    """
    Classifier-Free Guidance: Blend conditional and unconditional predictions.
    
    Formula: ε_guided = ε_uncond + w * (ε_cond - ε_uncond)
    
    Where w is the guidance scale:
    - w = 1: No guidance (same as conditional)
    - w = 7.5: Typical value (strong adherence to prompt)
    - w > 10: Very strong guidance (may cause artifacts)
    """
    # Get unconditional prediction (null class)
    noise_uncond = model(x_t, t, class_label=None)
    
    # Get conditional prediction
    noise_cond = model(x_t, t, class_label=class_label)
    
    # CFG formula
    noise_guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
    
    return noise_guided

# Demo
cond_model = ConditionalUNet(num_classes=10)
x = torch.randn(2, 3, 64, 64)
t = torch.tensor([50, 100])
labels = torch.tensor([3, 7])  # Classes 3 and 7

noise_no_cfg = cond_model(x, t, labels)
noise_with_cfg = cfg_guided_noise(cond_model, x, t, labels, guidance_scale=7.5)

print(f"Without CFG (guidance_scale=1): noise range [{noise_no_cfg.min():.2f}, {noise_no_cfg.max():.2f}]")
print(f"With CFG (guidance_scale=7.5): noise range [{noise_with_cfg.min():.2f}, {noise_with_cfg.max():.2f}]")
print("\nHigher guidance = stronger adherence to condition, but may cause oversaturation")
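
A tiny numeric example can make the role of the guidance scale concrete. The sketch below applies the same blending formula to two hand-picked vectors. `blend_cfg` is a hypothetical standalone helper, and the values are made up for illustration.

```python
import torch

def blend_cfg(noise_uncond, noise_cond, guidance_scale):
    """Standalone version of the CFG blending formula (hypothetical helper)."""
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

demo_uncond = torch.tensor([0.0, 1.0])
demo_cond = torch.tensor([1.0, 3.0])

for w in (0.0, 1.0, 7.5):
    print(w, blend_cfg(demo_uncond, demo_cond, w))
# w = 0.0 returns the unconditional prediction: [0.0, 1.0]
# w = 1.0 returns the conditional prediction:   [1.0, 3.0]
# w = 7.5 extrapolates past the conditional:    [7.5, 16.0]
```

For any scale above 1 the result lies beyond the conditional prediction, which is consistent with the wider noise range printed above and with the artifacts seen at very high scales.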

## 8. Noise Schedules

The beta schedule significantly affects generation quality. Let's compare different schedules.

In [None]:
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine schedule from 'Improved DDPM' paper."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clamp(betas, 0.0001, 0.9999)

def quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

# Compare schedules
# Use separate names here so the T, betas and alphas_cumprod defined earlier are not overwritten
T_compare = 1000
schedules = {
    'Linear': linear_beta_schedule(T_compare),
    'Cosine': cosine_beta_schedule(T_compare),
    'Quadratic': quadratic_beta_schedule(T_compare)
}

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Plot betas
for name, schedule_betas in schedules.items():
    axes[0].plot(schedule_betas.numpy(), label=name)
axes[0].set_xlabel('Timestep')
axes[0].set_ylabel('β_t')
axes[0].set_title('Beta Schedule')
axes[0].legend()

# Plot cumulative alpha (signal remaining)
for name, schedule_betas in schedules.items():
    schedule_alphas_cumprod = torch.cumprod(1 - schedule_betas, dim=0)
    axes[1].plot(schedule_alphas_cumprod.numpy(), label=name)
axes[1].set_xlabel('Timestep')
axes[1].set_ylabel('α̅_t (signal remaining)')
axes[1].set_title('Signal Retention Over Time')
axes[1].legend()

plt.tight_layout()
plt.show()

print("Cosine schedule (introduced in Improved DDPM):")
print("  - Preserves signal longer at beginning")
print("  - Smoother transition to noise")
print("  - Helps most on low-resolution images (e.g. 32x32, 64x64)")
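
The claim that the cosine schedule preserves signal longer can also be checked numerically, by finding the first timestep at which $\bar{\alpha}_t$ drops below one half. The sketch below re-declares both schedules under `demo_` names so it runs on its own. `first_step_below` is a hypothetical helper and the 0.5 threshold is an arbitrary choice.

```python
import math
import torch

def demo_linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    return torch.linspace(beta_start, beta_end, num_steps)

def demo_cosine_betas(num_steps, s=0.008):
    x = torch.linspace(0, num_steps, num_steps + 1)
    a_bar = torch.cos(((x / num_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
    a_bar = a_bar / a_bar[0]
    return torch.clamp(1 - a_bar[1:] / a_bar[:-1], 0.0001, 0.9999)

def first_step_below(betas, threshold=0.5):
    """Index of the first timestep whose alpha-bar falls below `threshold` (hypothetical helper)."""
    a_bar = torch.cumprod(1 - betas, dim=0)
    return (a_bar < threshold).nonzero()[0].item()

demo_linear_half = first_step_below(demo_linear_betas(1000))
demo_cosine_half = first_step_below(demo_cosine_betas(1000))
print(f"alpha-bar < 0.5 first at step {demo_linear_half} (linear) vs {demo_cosine_half} (cosine)")
```

The cosine schedule crosses the threshold noticeably later, so more of its timesteps are spent at noise levels where the image is still partly visible.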

## 9. FAANG Interview Questions

### Q1: Explain the forward and reverse diffusion processes. Why does this work for generation?

**Answer**:

**Forward Process (Fixed)**:
- Gradually adds Gaussian noise over T steps: $q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I)$
- By t=T, the image is approximately pure noise (provided the schedule drives $\bar{\alpha}_T$ close to 0)
- Key property: We can jump directly to any step: $x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon$

**Reverse Process (Learned)**:
- Neural network learns to predict the noise added at each step
- Start from noise, iteratively remove predicted noise
- $p_\theta(x_{t-1}|x_t)$ approximates the true reverse conditional $q(x_{t-1}|x_t)$

**Why it works**:
1. Each denoising step is a small, learnable transformation
2. The objective (predict noise) is simple and stable
3. The model learns the data distribution implicitly through noise prediction

---

### Q2: What is Classifier-Free Guidance and why is it important?

**Answer**: CFG combines conditional and unconditional predictions to control generation quality.

**Formula**: $\epsilon_{guided} = \epsilon_{uncond} + w \cdot (\epsilon_{cond} - \epsilon_{uncond})$

**Training**:
- Randomly drop conditioning (e.g., set prompt to empty) with probability ~10%
- Model learns both conditional and unconditional generation

**Inference**:
- w=1: Pure conditional (diverse but may not match prompt)
- w=7.5: Typical (good balance)
- w>10: Strong adherence (may cause artifacts)

**Why important**: Allows users to control the diversity-quality tradeoff without retraining.
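
The training-time conditioning dropout described above can be sketched as follows for class labels. `drop_labels` is a hypothetical helper. Using index 10 as the null label follows the `ConditionalUNet` convention in this notebook, and the 10% rate is the approximate figure quoted above.

```python
import torch

def drop_labels(class_labels, null_label, drop_prob=0.1, generator=None):
    """Replace each label with `null_label` with probability `drop_prob` (hypothetical helper).

    Assumes CPU tensors, to keep the sketch simple.
    """
    drop_mask = torch.rand(class_labels.shape, generator=generator) < drop_prob
    return torch.where(drop_mask, torch.full_like(class_labels, null_label), class_labels)

gen = torch.Generator().manual_seed(0)
demo_labels = torch.randint(0, 10, (10000,), generator=gen)
demo_dropped = drop_labels(demo_labels, null_label=10, drop_prob=0.1, generator=gen)

demo_fraction = (demo_dropped == 10).float().mean().item()
print(f"Fraction of labels replaced by the null class: {demo_fraction:.3f}")
```

During training, the dropped labels would be passed to the model in place of the originals, so the same network sees both conditional and unconditional examples.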

---

### Q3: Compare DDPM vs DDIM sampling. When would you use each?

**Answer**:

| Aspect | DDPM | DDIM |
|--------|------|------|
| **Steps** | 1000 (slow) | 20-50 (fast) |
| **Stochastic** | Yes (adds noise each step) | No (deterministic) |
| **Reproducible** | Only with a fixed random seed | Yes (same latent → same image) |
| **Quality** | Best at the full step count | Comparable at ~50 steps, degrades with very few |
| **Use case** | When sampling time is not a constraint | Fast iteration, interpolation |

Key insight: DDIM uses the same trained model but a different inference algorithm.

---

### Q4: How does Stable Diffusion achieve high-resolution generation efficiently?

**Answer**: Stable Diffusion operates in **latent space**, not pixel space.

**Architecture**:
1. **VAE Encoder**: Compress 512x512x3 image → 64x64x4 latent (8x smaller per side)
2. **U-Net**: Denoise in latent space (much cheaper!)
3. **VAE Decoder**: Decompress 64x64 latent → 512x512 image

**Benefits**:
- 8x smaller per side = 64x fewer spatial positions, so convolutions cost roughly 64x less and self-attention (quadratic in positions) shrinks far more
- Latent space captures semantic information (easier to learn)
- Comparable quality to pixel-space diffusion at a fraction of the compute

**Cost**: Slight loss in fine details (VAE reconstruction isn't perfect)
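
The savings can be worked out with a few lines of arithmetic. This is a back-of-the-envelope sketch that counts spatial positions only. It ignores channel counts and the fact that real U-Nets apply attention at only some resolutions, so the ratios are rough upper bounds rather than measured speedups.

```python
pixel_side, latent_side = 512, 64

pixel_positions = pixel_side ** 2      # 262,144 spatial positions
latent_positions = latent_side ** 2    # 4,096 spatial positions

# Per-position operations (e.g. convolutions) scale linearly with the number of positions
position_ratio = pixel_positions // latent_positions

# Full self-attention compares every position with every other, so it scales quadratically
attention_ratio = pixel_positions ** 2 // latent_positions ** 2

print(position_ratio, attention_ratio)  # → 64 4096
```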

---

### Q5: What are the main failure modes of diffusion models and how do you address them?

**Answer**:

| Failure | Cause | Solution |
|---------|-------|----------|
| **Low diversity** | Limited training data, or guidance scale too high | More data, lower guidance scale |
| **Artifacts** | High CFG, few steps | Lower guidance (5-8), more steps |
| **Poor text following** | Weak text encoder | Stronger text encoder (e.g. larger CLIP or T5) |
| **Slow generation** | Many diffusion steps | DDIM, distillation, consistency models |
| **Out-of-distribution** | Unusual prompts | Fine-tune on domain data |
| **Composition failures** | "A red cube on blue sphere" | Better attention, layout guidance |

## Key Takeaways

1. **Diffusion**: A process of slowly destroying data with noise, then learning to reverse it.
2. **Forward Process**: Fixed mathematical formula (no learning) - adds noise.
3. **Reverse Process**: Learned by a U-Net - predicts and removes noise.
4. **Objective**: Predict the noise that was added (simple MSE loss).
5. **DDIM**: Faster, deterministic sampling using the same trained model.
6. **CFG**: Trade off diversity for quality by amplifying the gap between conditional and unconditional predictions.
7. **Noise Schedules**: Cosine schedule preserves signal longer, which helps most at low resolutions.
8. **Latent Diffusion**: Operate in compressed latent space for efficiency (Stable Diffusion).