
---

# **CHAPTER 26: GENERATIVE AI & DIFFUSION MODELS**

*Engineering Systems for Creation and Synthesis*

## **Chapter Overview**

Generative AI represents a paradigm shift from discriminative models (classification, prediction) to creative systems that produce novel content. This chapter covers the full spectrum of generative architectures: autoregressive models for discrete sequences, Variational Autoencoders (VAEs) for structured latent spaces, Generative Adversarial Networks (GANs) for high-fidelity synthesis, Diffusion Models (the basis of modern image generators such as Stable Diffusion and DALL-E 3), and the newer flow-matching and consistency models that reduce the number of sampling steps.

**Estimated Time:** 45-55 hours (4 weeks)  
**Prerequisites:** Chapters 10-14 (Deep Learning, CNNs, Transformers), Chapter 25 (Advanced Transformers), strong PyTorch/TensorFlow skills

---

## **26.0 Learning Objectives**

By the end of this chapter, you will be able to:
1. Implement autoregressive generative models (PixelCNN, WaveNet) for discrete and continuous data
2. Engineer Variational Autoencoders with stable training objectives (ELBO, KL-annealing, β-VAE)
3. Architect and train GANs with spectral normalization, progressive growing, and mode collapse mitigation
4. Implement Denoising Diffusion Probabilistic Models (DDPM) from scratch with variance scheduling
5. Optimize diffusion inference with DDIM, classifier-free guidance, and latent diffusion architectures
6. Design conditional generation systems (ControlNet, inpainting) for controlled content creation

---

## **26.1 Autoregressive Models**

Autoregressive models factorize the joint distribution into a product of conditionals, one per element in a fixed ordering: $p(x) = \prod_{i} p(x_i \mid x_{<i})$.

#### **26.1.1 PixelCNN (Image Generation)**

Uses masked convolutions to preserve the autoregressive property: each pixel is predicted only from the pixels above it and to its left (raster-scan order).

```python
# pixelcnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
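    """
    Type A mask: current pixel excluded (first layer only)
    Type B mask: current pixel included (all subsequent layers)
    Simplification: unlike the original PixelCNN, the mask does not split channels,
    so the R, G, B values of a pixel are predicted independently given previous pixels.
    """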
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer('mask', torch.zeros_like(self.weight))
        
        _, _, kH, kW = self.weight.size()
        self.mask[:, :, :kH//2, :] = 1
        self.mask[:, :, kH//2, :kW//2] = 1
        
        if mask_type == 'B':
            self.mask[:, :, kH//2, kW//2] = 1
    
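    def forward(self, x):
        # Apply the mask functionally so the stored weights are never modified in place
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)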

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = MaskedConv2d('B', in_channels, in_channels//2, 1)
        self.conv2 = MaskedConv2d('B', in_channels//2, in_channels//2, 3, padding=1)
        self.conv3 = MaskedConv2d('B', in_channels//2, in_channels, 1)
        
    def forward(self, x):
        residual = x
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return F.relu(x + residual)

class PixelCNN(nn.Module):
    def __init__(self, in_channels=3, hidden_channels=128, num_classes=256):
        super().__init__()
        self.in_conv = MaskedConv2d('A', in_channels, hidden_channels, 7, padding=3)
        
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(hidden_channels) for _ in range(15)
        ])
        
        # Output 256-way softmax per color channel (quantized to 8-bit)
        self.out_conv = nn.Sequential(
            MaskedConv2d('B', hidden_channels, hidden_channels, 1),
            nn.ReLU(),
            MaskedConv2d('B', hidden_channels, in_channels * num_classes, 1)
        )
        
    def forward(self, x):
        # x: (batch, 3, H, W) normalized to [0, 1]
        x = self.in_conv(x)
        
        for block in self.residual_blocks:
            x = block(x)
            
        x = self.out_conv(x)
        return x.view(x.size(0), -1, self.in_conv.in_channels, x.size(2), x.size(3))  # (B, 256, C, H, W)
    
    @torch.no_grad()
    def sample(self, batch_size=16, image_size=28, device='cuda'):
        self.eval()
        channels = self.in_conv.in_channels
        samples = torch.zeros(batch_size, channels, image_size, image_size, device=device)
        
        # Raster-scan order: one full forward pass per pixel per channel
        for i in range(image_size):
            for j in range(image_size):
                for c in range(channels):
                    logits = self.forward(samples)[:, :, c, i, j]  # (B, 256)
                    probs = F.softmax(logits, dim=1)
                    pixel = torch.multinomial(probs, 1).float() / 255.0
                    samples[:, c, i, j] = pixel.squeeze(1)
        
        return samples
```
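The listing defines the network and the sampler but not the training objective. The following is a minimal sketch that assumes the `(B, 256, C, H, W)` logit layout returned by `forward` above and images scaled to [0, 1]. It quantizes the pixels back to integer classes and applies a per-subpixel 256-way cross-entropy. The helper names `pixelcnn_loss` and `bits_per_dim` are illustrative.

```python
import math
import torch
import torch.nn.functional as F

def pixelcnn_loss(logits, images, num_classes=256):
    """logits: (B, num_classes, C, H, W); images: (B, C, H, W) floats in [0, 1]."""
    # Recover the integer intensity class of every subpixel
    targets = (images * (num_classes - 1)).round().long().clamp(0, num_classes - 1)
    # cross_entropy takes class logits in dim 1 and allows extra trailing dims
    return F.cross_entropy(logits, targets)  # mean negative log-likelihood in nats

def bits_per_dim(nll_nats):
    # Common reporting unit for likelihood models: nats -> bits per subpixel
    return nll_nats / math.log(2)

# Uniform (all-zero) logits give every value probability 1/256, i.e. 8 bits per dimension
logits = torch.zeros(2, 256, 3, 4, 4)
images = torch.rand(2, 3, 4, 4)
loss = pixelcnn_loss(logits, images)
print(round(bits_per_dim(loss.item()), 4))
```

An untrained model sits near 8 bits/dim. Training should drive the value well below that.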

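**Limitations:** Sampling is slow and strictly sequential. The `sample` loop runs one full forward pass per pixel per channel, which is 28 × 28 × 3 = 2,352 passes for a 28×28 RGB image. Stacked masked convolutions also have a limited receptive field with a "blind spot", so long-range dependencies are harder to capture than with transformers.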

#### **26.1.2 WaveNet (Audio Generation)**

Dilated causal convolutions generate raw audio waveforms one sample at a time, typically at 16-24 kHz (tens of thousands of samples per second of audio).

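```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Conv1d):
    """Conv1d that only sees the current and past samples (construct with padding=0)."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Left padding must grow with dilation: (kernel_size - 1) * dilation
        self.causal_padding = (self.kernel_size[0] - 1) * self.dilation[0]
        
    def forward(self, x):
        x = F.pad(x, (self.causal_padding, 0))  # pad on the left (past) only
        return super().forward(x)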

class WaveNetBlock(nn.Module):
    def __init__(self, channels, kernel_size=2, dilation=1):
        super().__init__()
        self.filter_conv = CausalConv1d(channels, channels, kernel_size, dilation=dilation)
        self.gate_conv = CausalConv1d(channels, channels, kernel_size, dilation=dilation)
        self.residual_conv = nn.Conv1d(channels, channels, 1)
        self.skip_conv = nn.Conv1d(channels, channels, 1)
        
    def forward(self, x):
        filter_out = torch.tanh(self.filter_conv(x))
        gate_out = torch.sigmoid(self.gate_conv(x))
        z = filter_out * gate_out
        
        residual = self.residual_conv(z)
        skip = self.skip_conv(z)
        
        return x + residual, skip  # Residual for next layer, skip for output
```
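WaveNet gets a large receptive field at low cost by stacking blocks with exponentially increasing dilation. The sketch below assumes kernel size 2 and dilations 1, 2, 4, ..., 512, as in one WaveNet stack. It computes the receptive field in closed form, then checks it empirically with plain `F.conv1d`. The check uses all-ones single-channel filters so that every dependency shows up as a non-zero gradient. The helper names are illustrative.

```python
import torch
import torch.nn.functional as F

def receptive_field(kernel_size, dilations):
    # Each causal layer reaches (kernel_size - 1) * dilation samples further into the past
    return 1 + sum((kernel_size - 1) * d for d in dilations)

def causal_stack(x, kernel_size, dilations):
    # Single-channel dilated causal convolutions with all-ones weights (illustration only)
    w = torch.ones(1, 1, kernel_size)
    for d in dilations:
        x = F.conv1d(F.pad(x, ((kernel_size - 1) * d, 0)), w, dilation=d)
    return x

dilations = [2 ** i for i in range(10)]  # 1, 2, 4, ..., 512
print(receptive_field(2, dilations))     # 1024

# Empirical check: which input samples influence the last output sample?
x = torch.zeros(1, 1, 2048, requires_grad=True)
y = causal_stack(x, 2, dilations)
y[0, 0, -1].backward()
print(int((x.grad != 0).sum()))          # 1024, all of them at or before the output position
```

Ten layers cover 1,024 past samples. A plain (undilated) causal stack would need 1,023 layers for the same reach.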

---

## **26.2 Variational Autoencoders (VAEs)**

#### **26.2.1 The ELBO Objective**

Maximize the Evidence Lower Bound:

$$\mathcal{L}(\theta, \phi) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x) || p(z))$$

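The first term rewards accurate reconstruction of $x$ from the latent $z$. The second keeps the approximate posterior $q_\phi(z|x)$ close to the prior $p(z)$, typically $\mathcal{N}(0, I)$.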

```python
# vae.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
        
    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)  # Log variance for stability
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        """
        Reparameterization trick: z = mu + sigma * epsilon
        Allows backpropagation through stochastic node
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return recon, mu, log_var
    
    def loss_function(self, recon, x, mu, log_var, kl_weight=1.0):
        # Reconstruction loss (binary cross-entropy for MNIST)
        BCE = F.binary_cross_entropy(recon, x.view(-1, 784), reduction='sum')
        
        # KL Divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        return BCE + kl_weight * KLD, BCE, KLD
```

**Training Tips:**
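- **KL Annealing:** Gradually increase the KL weight from 0 to 1 over the first epochs. The encoder then learns informative latents before the KL term pulls them toward the prior, which mitigates posterior collapse.
- **Free Bits:** Stop penalizing the KL term once it falls below a threshold (a minimum number of nats per latent dimension). The optimizer then gains nothing by pushing latents all the way to the prior, so latent dimensions stay in use.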
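Both tips fit in a few lines. This is a sketch under stated assumptions: the annealing is linear, and free bits is applied per latent dimension after averaging over the batch. Published variants differ, for example by grouping latents or averaging differently. Both function names are illustrative.

```python
import torch

def kl_weight(epoch, warmup_epochs=20):
    """Linear KL annealing: 0 at epoch 0, reaching 1.0 at warmup_epochs and staying there."""
    return min(1.0, epoch / warmup_epochs)

def kl_free_bits(mu, log_var, free_bits=0.5):
    """KL(q(z|x) || N(0, I)) with a per-dimension floor of `free_bits` nats."""
    # Per-dimension KL averaged over the batch -> shape (latent_dim,)
    kl_per_dim = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean(dim=0)
    # Below the floor the clamp is constant, so no gradient pushes those dims further toward the prior
    return torch.clamp(kl_per_dim, min=free_bits).sum()

# Usage inside a training loop (sketch):
# loss = recon_loss + kl_weight(epoch) * kl_free_bits(mu, log_var)
```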

#### **26.2.2 β-VAE (Disentangled Representations)**

Scale the KL term by β > 1 to encourage factorized (disentangled) latent representations, at the cost of reconstruction quality.

```python
def beta_vae_loss(recon, x, mu, log_var, beta=4.0):
    BCE = F.binary_cross_entropy(recon, x.view_as(recon), reduction='sum')  # flatten x to match recon
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + beta * KLD
```

#### **26.2.3 VQ-VAE (Vector Quantized VAE)**

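VQ-VAE builds a discrete latent space by snapping each encoder output to its nearest codebook vector. The idea underpins later systems. The original DALL-E tokenized images with a closely related discrete VAE. The Latent Diffusion work behind Stable Diffusion trained both VQ-regularized and KL-regularized autoencoders, and the released Stable Diffusion checkpoints use the KL variant.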

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)
        
    def forward(self, inputs):
        # inputs: (batch, height, width, embedding_dim)
        # Flatten to (B*H*W, D)
        flat_input = inputs.reshape(-1, self.embedding_dim)
        
        # Calculate distances to codebook vectors
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self.embeddings.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self.embeddings.weight.t()))
        
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.size(0), self.num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize
        quantized = torch.matmul(encodings, self.embeddings.weight).view(inputs.shape)
        
        # VQ-VAE objective: codebook loss + weighted commitment loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)  # Commitment loss
        q_latent_loss = F.mse_loss(quantized, inputs.detach())   # Codebook loss
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        
        # Straight-through estimator (gradient flows through quantized as if identity)
        quantized = inputs + (quantized - inputs).detach()
        
        return quantized, loss, encoding_indices.view(inputs.shape[:-1])
```
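Codebook collapse (see Common Pitfalls) is easy to monitor with codebook perplexity, the exponential of the entropy of code usage. It equals K when all K codes are used equally and 1 when a single code absorbs every input. The sketch below assumes the `encoding_indices` returned by `VectorQuantizer.forward` above. `codebook_perplexity` is an illustrative name.

```python
import torch

def codebook_perplexity(encoding_indices, num_embeddings):
    """exp(entropy) of the empirical code-usage distribution."""
    counts = torch.bincount(encoding_indices.flatten(), minlength=num_embeddings).float()
    probs = counts / counts.sum()
    # Unused codes contribute 0 * log(eps) = 0 to the entropy
    entropy = -(probs * torch.log(probs + 1e-10)).sum()
    return torch.exp(entropy).item()

# Example: 512 latent vectors spread evenly over only 8 of 512 codes
indices = torch.arange(512) % 8
print(codebook_perplexity(indices, num_embeddings=512))  # approximately 8
```

A perplexity far below the codebook size during training is a signal to try the codebook fixes listed under Common Pitfalls.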

---

## **26.3 Generative Adversarial Networks (GANs)**

#### **26.3.1 DCGAN Architecture**

```python
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100, ngf=64, nc=3):
        super().__init__()
        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # State: (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # State: (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # State: (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # State: ngf x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: nc x 64 x 64
        )
        
    def forward(self, z):
        return self.main(z.unsqueeze(-1).unsqueeze(-1))

class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.main(x).view(-1, 1).squeeze(1)
```
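The architecture alone does not show how the two networks are trained. The sketch below shows one alternating update with the standard (vanilla) GAN loss, using the non-saturating generator objective from Goodfellow et al. (2014). It assumes a discriminator that outputs probabilities of shape `(B,)`, as the DCGAN `Discriminator` above does. `gan_train_step` is a hypothetical helper. The Adam settings in the usage comment (lr = 2e-4, β₁ = 0.5) are the ones the DCGAN paper reports.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gan_train_step(G, D, opt_G, opt_D, real, latent_dim=100):
    """One alternating update; assumes D returns probabilities of shape (B,)."""
    batch_size = real.size(0)
    ones = torch.ones(batch_size, device=real.device)
    zeros = torch.zeros(batch_size, device=real.device)

    # 1) Discriminator: push D(real) -> 1 and D(G(z)) -> 0
    z = torch.randn(batch_size, latent_dim, device=real.device)
    fake = G(z)
    d_loss = (F.binary_cross_entropy(D(real), ones)
              + F.binary_cross_entropy(D(fake.detach()), zeros))  # detach: no G update here
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # 2) Generator, non-saturating loss: maximize log D(G(z)) by labeling fakes as real
    g_loss = F.binary_cross_entropy(D(fake), ones)
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()

# Usage with the DCGAN models above:
# opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
# opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
# d_loss, g_loss = gan_train_step(G, D, opt_G, opt_D, real_batch)
```

The non-saturating form gives the generator strong gradients early in training, when the discriminator easily rejects its samples.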

#### **26.3.2 Training Stability Improvements**

**WGAN-GP (Wasserstein GAN with Gradient Penalty):**
```python
import torch

def gradient_penalty(discriminator, real_data, fake_data, device):
    batch_size = real_data.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    interpolates.requires_grad_(True)
    
    disc_interpolates = discriminator(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True
    )[0]
    
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    penalty = ((gradient_norm - 1) ** 2).mean()
    return penalty

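# Critic loss: D_loss = -mean(D(real)) + mean(D(fake)) + lambda * GP   (lambda = 10 in the WGAN-GP paper)
# Generator loss: G_loss = -mean(D(fake))
# The WGAN critic outputs an unbounded score, so drop the DCGAN Sigmoid. Also replace BatchNorm
# (the WGAN-GP paper suggests layer normalization), because the penalty is computed per sample.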
```

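**Spectral Normalization:** Divide each weight matrix by an estimate of its largest singular value, obtained by power iteration. This approximately constrains each layer's Lipschitz constant and, with 1-Lipschitz activations such as LeakyReLU, the discriminator's.

```python
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Wrap each conv/linear layer of the discriminator (channel sizes here are illustrative)
conv = spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))
```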
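You can check what the wrapper does directly. In training mode, each forward pass runs power-iteration steps and recomputes the normalized weight, so the largest singular value of `layer.weight` should approach 1. The sketch below uses an arbitrary `nn.Linear` size and iteration counts, chosen only to make the estimate converge.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
layer = spectral_norm(nn.Linear(32, 16), n_power_iterations=10)

x = torch.randn(8, 32)
for _ in range(50):  # each training-mode forward refines the singular-vector estimates
    layer(x)

# After the forward passes, layer.weight holds W / sigma_max(W)
sigma = torch.linalg.matrix_norm(layer.weight.detach(), ord=2)
print(f"largest singular value: {sigma.item():.4f}")  # close to 1.0
```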

---

## **26.4 Diffusion Models**

#### **26.4.1 DDPM (Denoising Diffusion Probabilistic Models)**

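The forward process gradually adds Gaussian noise over $T$ timesteps according to a variance schedule $\beta_1, \dots, \beta_T$. A neural network learns the reverse (denoising) process, here by predicting the noise that was added.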
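For reference, the quantities precomputed in the code below follow the DDPM formulation of Ho et al. (2020). Let $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$ (`alphas_cumprod`). Then:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{\alpha_t}\,x_{t-1},\ \beta_t I\right), \qquad q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t) I\right)$$

A noisy sample can therefore be drawn in one step as $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$; this is what `q_sample` does. The network $\epsilon_\theta$ is trained on $\mathbb{E}_{t, x_0, \epsilon}\big[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\big]$ in `p_losses`. Each reverse step then uses the mean and variance implemented in `p_sample` and `posterior_variance`:

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right), \qquad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$$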

```python
# ddpm.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

class Diffusion:
    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
        self.timesteps = timesteps
        self.device = device
        
        # Linear variance schedule
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Precompute values for sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        
        # Posterior variance
        self.posterior_variance = (self.betas * (1.0 - self.alphas_cumprod_prev) / 
                                  (1.0 - self.alphas_cumprod))
        
    def q_sample(self, x_start, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0)
        """
        if noise is None:
            noise = torch.randn_like(x_start)
            
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_losses(self, denoise_model, x_start, t, noise=None):
        """
        Training loss: MSE between predicted and actual noise
        """
        if noise is None:
            noise = torch.randn_like(x_start)
            
        x_noisy = self.q_sample(x_start, t, noise)
        predicted_noise = denoise_model(x_noisy, t)
        
        loss = F.mse_loss(predicted_noise, noise)
        return loss
    
    @torch.no_grad()
    def p_sample(self, model, x, t, t_index):
        """
        Single step of reverse diffusion
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
        
        # Equation 11 in DDPM paper
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )
        
        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_variance_t) * noise
    
    @torch.no_grad()
    def sample(self, model, image_size, batch_size=16, channels=3):
        """
        Generate images by iteratively denoising
        """
        model.eval()
        shape = (batch_size, channels, image_size, image_size)
        img = torch.randn(shape, device=self.device)
        
        for i in tqdm(reversed(range(self.timesteps)), desc='Sampling'):
            t = torch.full((batch_size,), i, device=self.device, dtype=torch.long)
            img = self.p_sample(model, img, t, i)
        
        return img

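def extract(a, t, shape):
    """Gather a[t] for each batch element and reshape to (B, 1, 1, ...) for broadcasting."""
    batch_size = t.shape[0]
    out = a.gather(-1, t.to(a.device))  # indices must live on the same device as the schedule
    return out.reshape(batch_size, *((1,) * (len(shape) - 1))).to(t.device)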

# U-Net architecture for denoising (simplified)
class UNet(nn.Module):
    def __init__(self, dim=64, dim_mults=(1, 2, 4, 8), channels=3):
        super().__init__()
        # Time embedding
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        
        # Downsampling/upsampling with attention
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        # ... (standard U-Net with ResNet blocks and attention)
        
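    def forward(self, x, time):
        t = self.time_mlp(time)  # (B, time_dim), injected into every ResNet block
        # ... U-Net forward with skip connections goes here
        return x  # placeholder: a real model returns predicted noise with the same shape as x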
```
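`UNet.__init__` uses `SinusoidalPositionEmbeddings`, which the listing never defines. One common definition, assumed here, borrows the transformer positional-encoding formula. Half the output dimensions are sines and half are cosines of the timestep, at geometrically spaced frequencies. `dim` should be even and at least 4.

```python
import math
import torch
import torch.nn as nn

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # time: (B,) integer or float timesteps -> (B, dim) embedding
        half_dim = self.dim // 2
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=time.device, dtype=torch.float32) * -scale)
        args = time.float()[:, None] * freqs[None, :]   # (B, half_dim)
        return torch.cat([args.sin(), args.cos()], dim=-1)

emb = SinusoidalPositionEmbeddings(64)(torch.tensor([0, 10, 999]))
print(emb.shape)  # torch.Size([3, 64])
```

The MLP in `time_mlp` then maps this fixed encoding to the `time_dim` vector that the U-Net blocks consume.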

#### **26.4.2 DDIM (Denoising Diffusion Implicit Models)**

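DDIM reuses a network trained with the DDPM objective but samples through a non-Markovian process. With η = 0 sampling is deterministic (DDPM sampling is stochastic), and it can skip timesteps, typically needing 10-50 steps instead of 1000.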

```python
class DDIM(Diffusion):
    def __init__(self, timesteps=1000, ddim_timesteps=50, **kwargs):
        super().__init__(timesteps, **kwargs)
        # Evenly spaced subset of the training timesteps (e.g. 50 of 1000)
        self.ddim_timesteps = torch.linspace(0, timesteps - 1, ddim_timesteps).long()
        
    @torch.no_grad()
    def sample(self, model, image_size, batch_size=16, channels=3, eta=0.0):
        """
        eta=0: deterministic (DDIM)
        eta=1: stochastic, with DDPM-like noise variance
        """
        model.eval()
        shape = (batch_size, channels, image_size, image_size)
        img = torch.randn(shape, device=self.device)
        
        for i in tqdm(reversed(range(len(self.ddim_timesteps))), desc='DDIM Sampling'):
            step = self.ddim_timesteps[i].item()
            t = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
            
            # Scalar alpha-bars broadcast over (B, C, H, W); the last step jumps to alpha_prev = 1
            alpha_t = self.alphas_cumprod[step]
            if i > 0:
                alpha_prev = self.alphas_cumprod[self.ddim_timesteps[i - 1].item()]
            else:
                alpha_prev = torch.tensor(1.0, device=self.device)
            
            predicted_noise = model(img, t)
            
            # Predict x_0 from x_t
            pred_x0 = (img - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
            
            # Noise scale (zero when eta=0)
            sigma_t = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))
            
            # Direction pointing to x_t (clamped against tiny negative round-off)
            dir_xt = torch.sqrt((1 - alpha_prev - sigma_t**2).clamp(min=0)) * predicted_noise
            
            noise = torch.randn_like(img) if i > 0 else torch.zeros_like(img)
            img = torch.sqrt(alpha_prev) * pred_x0 + dir_xt + sigma_t * noise
        
        return img
```
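A quick way to see why deterministic DDIM can skip steps is to build an "oracle" noise predictor that always returns the true noise ε. Under the η = 0 update, every iterate stays of the form $\sqrt{\bar\alpha}\,x_0 + \sqrt{1-\bar\alpha}\,\epsilon$ with the same ε, so 50 jumps recover $x_0$ exactly. The sketch below is self-contained: it re-creates the linear schedule used above, `ddim_step` is an illustrative helper, and it uses float64 to keep round-off negligible.

```python
import torch

def ddim_step(x_t, eps, alpha_t, alpha_prev):
    """One deterministic (eta=0) DDIM update from alpha-bar alpha_t to alpha_prev."""
    pred_x0 = (x_t - (1 - alpha_t).sqrt() * eps) / alpha_t.sqrt()
    return alpha_prev.sqrt() * pred_x0 + (1 - alpha_prev).sqrt() * eps

torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
steps = torch.linspace(0, 999, 50).long()  # 50 of the 1000 training timesteps

x0 = torch.randn(4, 3, 8, 8, dtype=torch.float64)
eps = torch.randn_like(x0)
a_T = alphas_cumprod[steps[-1]]
x = a_T.sqrt() * x0 + (1 - a_T).sqrt() * eps  # noisy sample at the last DDIM timestep

for i in reversed(range(len(steps))):
    alpha_prev = alphas_cumprod[steps[i - 1]] if i > 0 else torch.tensor(1.0, dtype=torch.float64)
    x = ddim_step(x, eps, alphas_cumprod[steps[i]], alpha_prev)  # oracle: the true eps

print(torch.allclose(x, x0))  # True
```

With a learned predictor, the remaining error comes from imperfect noise estimates, not from skipping timesteps.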

#### **26.4.3 Classifier-Free Guidance (CFG)**

Trade-off between diversity and fidelity by combining conditional and unconditional predictions.

```python
def guided_denoise_step(model, x, t, context, guidance_scale=7.5):
    """
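    context: text embedding or class label
    guidance_scale: 1 = plain conditional prediction (no extra guidance),
                    >1 = stronger guidance (7.5 is a common Stable Diffusion default)
    """
    # Conditional prediction
    noise_pred_cond = model(x, t, context=context)
    
    # Unconditional prediction (null embedding; must match the null condition used during training)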
    noise_pred_uncond = model(x, t, context=torch.zeros_like(context))
    
    # Guided prediction: extrapolate in direction of conditional
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
    
    return noise_pred
```
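Classifier-free guidance assumes one network can produce both the conditional and the unconditional prediction. The CFG paper (Ho & Salimans) achieves this by randomly replacing the condition with a null token during training, typically for around 10-20% of samples. The sketch below uses the all-zeros null embedding to match the `torch.zeros_like(context)` call above; Stable Diffusion uses the empty-prompt embedding instead. `drop_condition` is an illustrative name.

```python
import torch

def drop_condition(context, p_uncond=0.1):
    """Replace each sample's conditioning with the null (all-zeros) embedding with probability p_uncond.

    context: (B, ...) conditioning embeddings; returns a tensor of the same shape.
    """
    keep = torch.rand(context.size(0), device=context.device) >= p_uncond
    keep = keep.view(-1, *([1] * (context.dim() - 1))).to(context.dtype)
    return context * keep

# Training (sketch): the same network learns conditional and unconditional predictions
# noise_pred = model(x_noisy, t, context=drop_condition(context, p_uncond=0.1))
```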

#### **26.4.4 Latent Diffusion Models (Stable Diffusion)**

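Latent diffusion runs the diffusion process in the compressed latent space of a pretrained autoencoder rather than in pixel space, which cuts compute substantially. With a downsampling factor of 8, a 512×512×3 image becomes a 64×64×4 latent. Stable Diffusion uses a KL-regularized autoencoder; the original paper also explored VQ-regularized ones.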

```python
# Conceptual pipeline, assuming diffusers-style components: vae.decode(...).sample,
# unet(...).sample, and a scheduler with set_timesteps / scale_model_input / step
import torch

class StableDiffusion:
    def __init__(self, vae, text_encoder, unet, scheduler, device='cuda'):
        self.vae = vae                    # Pretrained KL-regularized autoencoder (downsampling factor 8)
        self.text_encoder = text_encoder  # CLIP text encoder (assumed to tokenize internally)
        self.unet = unet                  # Denoising U-Net with cross-attention
        self.scheduler = scheduler
        self.device = device
        
    def encode_prompt(self, prompt):
        return self.text_encoder(prompt)  # (1, seq_len, hidden_dim) token embeddings
    
    @torch.no_grad()
    def __call__(self, prompt, negative_prompt="", height=512, width=512,
                 num_inference_steps=50, guidance_scale=7.5):
        # 1. Encode text; an empty negative prompt yields the unconditional embedding
        text_embeddings = self.encode_prompt(prompt)
        uncond_embeddings = self.encode_prompt(negative_prompt)
        context = torch.cat([uncond_embeddings, text_embeddings])
        
        # 2. Prepare latents (random noise in latent space): 64x64 for a 512px image
        latents = torch.randn((1, 4, height // 8, width // 8), device=self.device)
        self.scheduler.set_timesteps(num_inference_steps)
        latents = latents * self.scheduler.init_noise_sigma
        
        # 3. Denoising loop in latent space
        for t in self.scheduler.timesteps:
            # Run unconditional and conditional branches in one batch
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=context).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            
            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        
        # 4. Undo the latent scaling used in training (0.18215 for SD 1.x), then decode
        image = self.vae.decode(latents / 0.18215).sample
        return image
```

#### **26.4.5 ControlNet**

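ControlNet injects spatial conditioning (Canny edges, depth maps, poses) into a frozen diffusion model through a trainable copy of the U-Net encoder. The copy connects to the original network through zero-initialized 1×1 convolutions, so training starts from the unchanged base model.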

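```python
import copy
import torch.nn as nn

def zero_conv(channels):
    # 1x1 conv initialized to zero: the control branch contributes nothing at the start of training
    conv = nn.Conv2d(channels, channels, 1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv

class ControlNet(nn.Module):
    """Conceptual sketch. `base_unet.encoder_blocks`, the block call signature, and the
    `control_features=` keyword are hypothetical names, not a specific library's API."""
    def __init__(self, base_unet, block_out_channels, latent_channels=4, control_channels=3):
        super().__init__()
        self.base_unet = base_unet
        # Lock base model parameters
        for param in base_unet.parameters():
            param.requires_grad = False
        
        # Encodes the control image; assumes it is already at latent resolution
        # (the real ControlNet downsamples it with a small strided conv stack)
        self.hint_encoder = nn.Conv2d(control_channels, latent_channels, 3, padding=1)
        
        # Trainable copies of the frozen U-Net encoder blocks
        self.controlnet_blocks = nn.ModuleList(
            [copy.deepcopy(block) for block in base_unet.encoder_blocks])
        for param in self.controlnet_blocks.parameters():
            param.requires_grad = True
        
        # One zero convolution per copied block (block_out_channels lists their output widths)
        self.zero_convs = nn.ModuleList([zero_conv(c) for c in block_out_channels])
        
    def forward(self, x, t, context, control_image):
        # The trainable copy sees the noisy latent plus the encoded control signal
        h = x + self.hint_encoder(control_image)
        control_features = []
        for block, zero_conv_layer in zip(self.controlnet_blocks, self.zero_convs):
            h = block(h, t, context)
            control_features.append(zero_conv_layer(h))  # exactly zero at initialization
        
        # Add control features to the frozen U-Net's skip connections
        return self.base_unet(x, t, context, control_features=control_features)
```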
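Zero-initialized convolutions still learn. The layer's output is zero at initialization, but the gradient with respect to its weights depends on the incoming features, not on the weights themselves. The first update therefore already moves the weights away from zero. A small self-contained check (shapes are arbitrary):

```python
import torch
import torch.nn as nn

zero_conv = nn.Conv2d(8, 8, kernel_size=1)
nn.init.zeros_(zero_conv.weight)
nn.init.zeros_(zero_conv.bias)

h = torch.randn(2, 8, 16, 16)   # features from the trainable copy
out = zero_conv(h)
print(out.abs().max().item())   # 0.0: the frozen U-Net's output is unchanged at step 0

out.sum().backward()
print(zero_conv.weight.grad.abs().sum().item() > 0)  # True: the weights still receive gradient
```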

---

## **26.5 Flow Matching & Consistency Models**

#### **26.5.1 Flow Matching (Rectified Flow)**

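Flow matching trains a network to regress the velocity (vector) field of a probability path that transports noise to data. Rectified flow uses straight-line paths between noise and data samples. Generation integrates the learned ODE starting from noise, so sampling is deterministic.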

```python
import torch
import torch.nn.functional as F

class FlowMatching:
    def __init__(self, sigma_min=0.0):
        self.sigma_min = sigma_min
        
    def sample_location_and_conditional_flow(self, x0, x1, t):
        """
        x0: noise (source), x1: data (target), t: broadcastable to x0's shape
        """
        # Point on the straight-line path from x0 to x1 (optionally jittered by sigma_min)
        xt = (1 - t) * x0 + t * x1 + self.sigma_min * torch.randn_like(x0)
        
        # Target vector field: constant velocity from x0 to x1
        ut = x1 - x0
        
        return xt, ut
    
    def loss(self, model, x1):
        """
        MSE between predicted and target vector field
        """
        x0 = torch.randn_like(x1)
        t = torch.rand(x1.size(0), device=x1.device)
        
        # Reshape t to (B, 1, 1, ...) so it broadcasts over any data shape
        t_b = t.view(-1, *([1] * (x1.dim() - 1)))
        xt, ut = self.sample_location_and_conditional_flow(x0, x1, t_b)
        
        vt = model(xt, t)  # Predicted velocity
        return F.mse_loss(vt, ut)
```
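Sampling from a trained flow-matching model means integrating $dx/dt = v_\theta(x, t)$ from noise at $t = 0$ to data at $t = 1$. Below is a minimal fixed-step Euler sketch. `euler_sample` and the toy `oracle_velocity` are illustrative. The oracle stands in for a trained model in the degenerate case where all data sits at one point `mu`, where the exact velocity of the straight-line path is $(\mu - x)/(1 - t)$. Because that path is straight, Euler integration is exact even with one step, which is the intuition behind few-step sampling with rectified flows. Learned fields are only approximately straight, so real models still need several steps.

```python
import torch

@torch.no_grad()
def euler_sample(velocity_fn, x0, num_steps=10):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with fixed-step Euler."""
    x = x0
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = torch.full((x.size(0),), k * dt, device=x.device)
        x = x + dt * velocity_fn(x, t)
    return x

# Toy oracle: all data at a single point mu
mu = torch.tensor([2.0, -1.0])
def oracle_velocity(x, t):
    return (mu - x) / (1 - t).unsqueeze(-1)

torch.manual_seed(0)
x0 = torch.randn(5, 2)
print(euler_sample(oracle_velocity, x0, num_steps=1))  # every row is (up to rounding) [2, -1]
```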

#### **26.5.2 Consistency Models**

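Consistency models learn a function $f(x_t, t)$ that maps any point on a probability-flow ODE trajectory directly to the trajectory's origin (the clean sample). This enables single-step generation, with optional extra steps to refine quality. They can be distilled from a pretrained diffusion model (consistency distillation) or trained directly (consistency training).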

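```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConsistencyModel(nn.Module):
    """Simplified consistency-training sketch in the noise-level (sigma) parameterization.
    `net(x, sigma)` is any denoiser; sigma_min and sigma_data follow Song et al. (2023)."""
    def __init__(self, net, sigma_min=0.002, sigma_data=0.5):
        super().__init__()
        self.net = net                               # student, trained by gradient descent
        self.target = copy.deepcopy(net).eval()      # EMA "teacher" (target network)
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.sigma_min, self.sigma_data = sigma_min, sigma_data

    def f(self, net, x, sigma):
        # Skip parameterization enforces the boundary condition f(x, sigma_min) = x
        s = sigma.view(-1, *([1] * (x.dim() - 1)))
        c_skip = self.sigma_data**2 / ((s - self.sigma_min)**2 + self.sigma_data**2)
        c_out = self.sigma_data * (s - self.sigma_min) / torch.sqrt(self.sigma_data**2 + s**2)
        return c_skip * x + c_out * net(x, sigma)

    @torch.no_grad()
    def update_target(self, ema_decay=0.999):
        # EMA update of the target network from the student
        for p_t, p_s in zip(self.target.parameters(), self.net.parameters()):
            p_t.mul_(ema_decay).add_(p_s, alpha=1 - ema_decay)

    def loss(self, x, sigma_n, sigma_next):
        # Adjacent noise levels (sigma_n < sigma_next) with the SAME noise z approximate two
        # points on one trajectory, so their predictions of the clean sample should agree
        z = torch.randn_like(x)
        shape = (-1,) + (1,) * (x.dim() - 1)
        pred = self.f(self.net, x + sigma_next.view(shape) * z, sigma_next)
        with torch.no_grad():
            target = self.f(self.target, x + sigma_n.view(shape) * z, sigma_n)
        return F.mse_loss(pred, target)  # the paper also uses LPIPS as the distance
```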

---

## **26.6 Workbook Labs**

### **Lab 1: VAE Implementation**
Build a convolutional VAE for CelebA:

1. **Architecture:** Encoder (4 conv layers) → latent (256-dim) → Decoder (4 transposed conv)
2. **Training:** Implement KL-annealing schedule (0 to 1 over 20 epochs)
3. **Analysis:** Visualize latent space (interpolation between faces, clustering by attributes)
4. **Ablation:** Compare β=1 vs β=4 (disentanglement vs reconstruction quality)

**Deliverable:** Generated face samples, latent space visualization, disentanglement metrics.

### **Lab 2: GAN Training**
Train DCGAN on CIFAR-10:

1. **Baseline:** Standard GAN (vanilla loss) - document mode collapse
2. **Improvement:** Implement WGAN-GP with spectral normalization
3. **Metrics:** Calculate FID (Fréchet Inception Distance) for both
4. **Conditional:** Extend to CGAN (class-conditional generation)

**Deliverable:** Comparative FID scores, generated image grids, training stability plots.

### **Lab 3: DDPM from Scratch**
Implement image diffusion on MNIST:

1. **Forward Process:** Implement noise schedule and q_sample
2. **Model:** U-Net with time embeddings and attention
3. **Training:** Train for 50 epochs, log noise prediction loss
4. **Sampling:** Implement both DDPM (1000 steps) and DDIM (50 steps) sampling
5. **Interpolation:** Sample two noise tensors x_T, interpolate between them, and decode every interpolant with deterministic DDIM sampling

**Deliverable:** Training curves, generated digit samples, interpolation GIFs.
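For the interpolation step in Lab 3, spherical interpolation (slerp) usually works better than linear interpolation on Gaussian noise. A high-dimensional standard Gaussian sample has norm close to $\sqrt{d}$, while the linear midpoint of two independent samples has norm close to $\sqrt{d/2}$, which is atypical input for the denoiser. The sketch below is one common way to write slerp; `slerp` is an illustrative helper, not part of the code above.

```python
import torch

def slerp(z0, z1, alpha):
    """Spherical interpolation between two noise tensors, alpha in [0, 1]."""
    a, b = z0.flatten(), z1.flatten()
    cos_omega = torch.dot(a, b) / (a.norm() * b.norm())
    omega = torch.arccos(cos_omega.clamp(-1.0, 1.0))
    if omega.abs() < 1e-6:  # nearly parallel: fall back to linear interpolation
        return (1 - alpha) * z0 + alpha * z1
    so = torch.sin(omega)
    return (torch.sin((1 - alpha) * omega) / so) * z0 + (torch.sin(alpha * omega) / so) * z1

# Usage (sketch): 16 frames between two starting noises, each then decoded with DDIM (eta=0)
torch.manual_seed(0)
zA, zB = torch.randn(1, 28, 28), torch.randn(1, 28, 28)
frames = [slerp(zA, zB, a) for a in torch.linspace(0, 1, 16)]
```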

### **Lab 4: Latent Diffusion**
Build simplified Stable Diffusion pipeline:

1. **Autoencoder:** Pretrain VQ-VAE on ImageNet subset (64x64 → 16x16 latent)
2. **UNet:** Denoising model with cross-attention for text conditioning
3. **Text Encoder:** Use pretrained CLIP or train small transformer
4. **Inference:** Implement classifier-free guidance, generate conditioned images

**Deliverable:** Text-to-image generation demo, guidance scale ablation study.

---

## **26.7 Common Pitfalls**

1. **Mode Collapse (GANs):** The generator maps many latent codes to one or a few outputs that fool the discriminator, losing diversity. **Solution:** WGAN-GP, spectral normalization, unrolled GANs, or diversity-regularized losses.

2. **Posterior Collapse (VAEs):** Encoder ignores input, KL term drives latent to prior. **Solution:** KL annealing, free bits technique, or aggressive encoder training cycles.

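3. **Training Instability (Diffusion):** Diffusion training is usually stable, but loss spikes or NaNs can appear with fp16 overflow, overly high learning rates, or a poorly matched noise schedule. **Solution:** Gradient clipping, a careful mixed-precision setup (loss scaling or bf16), and a cosine schedule, which destroys information more gradually than the linear one, especially at low resolutions.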

4. **Codebook Collapse (VQ-VAE):** Only subset of embeddings used. **Solution:** EMA updates for codebook, commitment loss weight tuning, random restarts of unused codes.

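5. **CFG Over-Saturation:** Very high guidance scales produce oversaturated, high-contrast images with artifacts, while too low a scale lets samples drift from the prompt. **Solution:** Tune the scale (7.5 is a common Stable Diffusion default) and use negative prompts to steer away from unwanted concepts.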
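The cosine schedule from pitfall 3 (Nichol & Dhariwal, 2021) defines $\bar\alpha_t$ directly and derives the betas from it. Here is a sketch with the offset `s = 0.008` and the 0.999 clipping from that paper. It is intended as a drop-in replacement for the `torch.linspace` betas in `Diffusion.__init__` (an assumption about how you wire it in).

```python
import math
import torch

def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """alpha_bar(t) = f(t) / f(0) with f(t) = cos^2(((t/T + s) / (1 + s)) * pi / 2)."""
    steps = torch.arange(timesteps + 1, dtype=torch.float64) / timesteps
    f = torch.cos((steps + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = f / f[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped to avoid a singularity near t = T
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(max=max_beta).float()

betas = cosine_beta_schedule(1000)
# In Diffusion.__init__: self.betas = cosine_beta_schedule(timesteps).to(device)
```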

---

## **26.8 Interview Questions**

**Q1:** Compare VAEs, GANs, and Diffusion Models. When would you use each?
*A: VAEs are probabilistic, with a structured latent space suited to representation learning and semi-supervised learning, but their samples tend to be blurry. GANs give sharp, high-quality samples and fast sampling (a single forward pass), but training is unstable and prone to mode collapse. Diffusion models give state-of-the-art quality, stable training, and good mode coverage, but sampling is slow and iterative. Use VAEs for compression and feature extraction, GANs for real-time generation (games, video), and diffusion for highest-quality image synthesis where latency is less critical (art, design tools). Hybrid approaches such as latent diffusion combine VAE compression with diffusion quality.*

**Q2:** Explain why Diffusion Models require many sampling steps, and how DDIM improves this.
*A: DDPM reverses a Markov chain in small steps, because each reverse transition is only well approximated by a Gaussian when the step is small. Each step requires one network evaluation, so 1000 steps are slow. DDIM keeps the same training objective, which depends only on the marginals q(x_t | x_0), so it can reuse a DDPM-trained network. It defines a non-Markovian generative process that can be deterministic (η = 0). At each step it predicts x_0 from x_t and then jumps directly to an earlier timestep (e.g., every 20th), so intermediate timesteps never need to be visited. This typically reduces sampling from 1000 to 50 steps with little quality loss.*

**Q3:** What is Classifier-Free Guidance and why does it work?
*A: CFG combines conditional (text prompt) and unconditional (empty prompt) noise predictions: pred = uncond + s · (cond − uncond). The difference cond − uncond is proportional to the gradient of an implicit classifier, since ∇ log p(c|x) = ∇ log p(x|c) − ∇ log p(x). Scaling it by s > 1 pushes samples toward regions where the condition is most recognizable, roughly sampling from p(x) p(c|x)^s. Fidelity to the prompt rises while diversity falls (samples approach the modes). The effect resembles, but is not exactly, lowering the temperature of the conditional distribution. A scale of 1.0 is plain conditional sampling; around 7.5 is a common balance of fidelity and diversity for Stable Diffusion.*

**Q4:** How does the reparameterization trick enable VAE training?
*A: Sampling z from N(μ, σ²) is a stochastic node, which blocks backpropagation to the encoder parameters. The trick rewrites the sample as z = μ + σ ⊙ ε with ε ~ N(0, I). Here ε is external noise that needs no gradient, while μ and σ are deterministic encoder outputs. Gradients therefore flow through μ and σ to the encoder, and all randomness comes from ε, so the encoder and decoder can be trained end to end by backpropagation.*

**Q5:** Design a system to generate 1024x1024 images in real-time (30 FPS).
*A: Start from latent diffusion, because pixel-space diffusion is far too slow, and note the budget: 30 FPS leaves about 33 ms per frame. Architecture:*
*(1) a high-compression VAE (f=8 or f=16, i.e., a 1024 image becomes a 128 or 64 latent);*
*(2) an efficient U-Net (MobileNet-style blocks, attention only at lower resolutions);*
*(3) a consistency model or distilled diffusion model (LCM/SDXL Turbo) for 1-4 sampling steps instead of 50;*
*(4) INT8 quantization and TensorRT optimization;*
*(5) tiled generation if memory is constrained;*
*(6) possibly cascaded generation (256 → 512 → 1024) with separate upscalers.*
*Hardware: multiple GPUs or TPUs with pipeline parallelism (different stages on different devices).*

---

## **26.9 Further Reading**

**Papers:**
- "Auto-Encoding Variational Bayes" (Kingma & Welling, 2014) - VAE foundation
- "Generative Adversarial Networks" (Goodfellow et al., 2014)
- "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- "Denoising Diffusion Implicit Models" (Song et al., 2021)
- "High-Resolution Image Synthesis with Latent Diffusion Models" (Rombach et al., 2022) - Stable Diffusion
- "Adding Conditional Control to Text-to-Image Diffusion Models" (Zhang et al., 2023) - ControlNet

**Tools:**
- **diffusers:** Hugging Face diffusion model library
- **k-diffusion:** Advanced samplers (DPM++ 2M Karras, etc.)
- **ComfyUI:** Node-based diffusion workflow engine
- **LoRA:** Low-rank adaptation for efficient fine-tuning

---

## **26.10 Checkpoint Project: Production Generative System**

Build a commercial-grade text-to-image generation service.

**Requirements:**

1. **Model Architecture:**
   - Base: Fine-tuned Stable Diffusion XL or custom Latent Diffusion
   - Conditioning: Text (CLIP), Canny edges (ControlNet), Depth (MiDaS)
   - Personalization: LoRA adapters for custom styles (loadable at runtime)

2. **Inference Optimization:**
   - Implement DPM++ 2M Karras sampler (quality/speed trade-off)
   - Compile model with TensorRT or ONNX Runtime
   - Batch requests dynamically (up to 4 images per batch)
   - Memory-efficient attention (xFormers/FlashAttention)

3. **API Design:**
   - Async generation endpoints (return job ID, webhook on completion)
   - Progress streaming (SSE or WebSocket showing denoising steps)
   - Negative prompt support, seed control for reproducibility
   - Safety filter: NSFW detection and blur/filter pipeline

4. **Scaling:**
   - Queue-based architecture (Redis/RabbitMQ) with worker pools
   - Model sharding: Different GPUs for different resolutions or LoRAs
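   - Cache text-encoder outputs for repeated prompts (the text conditioning, and hence the cross-attention keys/values, stays constant across denoising steps)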
   - Image upscaling pipeline (ESRGAN/Real-ESRGAN) for 4x resolution

5. **Evaluation:**
   - FID/CLIP score benchmarks on standard prompts
   - Human evaluation protocol (A/B testing)
   - Latency targets: 1024x1024 in <5 seconds on A100

**Deliverables:**
- `generative_service/` with FastAPI backend and diffusion pipeline
- Docker Compose with GPU worker setup
- Postman collection or Swagger docs for API
- Benchmark report: Latency vs. batch size, quality metrics

**Success Criteria:**
- Generate 1024x1024 images in <5 seconds end-to-end
- Support 10 concurrent users without queue overflow
- Successfully apply and switch LoRA weights without reloading base model
- NSFW filtering with <1% false positive rate on safe content

---

**End of Chapter 26**

*You have now covered the core generative AI architectures. Chapter 27 covers Multimodal AI, combining vision, language, and audio.*

<div style='width:100%; display:flex; justify-content:space-between; align-items:center; margin: 1em 0;'>
  <a href='25. transformer_architecture_deep_dive.ipynb' style='font-weight:bold; font-size:1.05em;'>&larr; Previous</a>
  <a href='../TOC.md' style='font-weight:bold; font-size:1.05em; text-align:center;'>Table of Contents</a>
  <a href='27. multimodal_ai.ipynb' style='font-weight:bold; font-size:1.05em;'>Next &rarr;</a>
</div>
