In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1rhbO-3WtvU0YjQYQrCebpE_IATSZyRCX", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 Latent Video Diffusion and the Diffusion Transformer (DiT)

*Part 3 of the Vizuara series on Diffusion Models for Video Generation*
*Estimated time: 50 minutes*

In Notebooks 1 and 2, we built video diffusion models that operate directly in pixel space. This works for our 32×32 Moving MNIST examples, but real videos are 256×256 or higher, which makes pixel-space diffusion prohibitively expensive.

In this notebook, we tackle the two biggest ideas in modern video generation:
1. **Latent Video Diffusion**: compress the video into a much smaller latent space first, then run diffusion there
2. **Diffusion Transformers (DiT)**: replace the U-Net entirely with a Transformer over spacetime patches

By the end of this notebook, you will:
- Build a simple video VAE (encoder + decoder)
- Run diffusion in latent space for massive efficiency gains
- Implement text conditioning via cross-attention
- Build a mini Diffusion Transformer with spacetime patches
- Generate text-conditioned videos from your DiT model

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**. It has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/diffusion-models-video-generation/practice/3/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
#@title üéß Listen: Why It Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_why_it_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 1. Why Does This Matter?

Let us revisit the numbers from the article.

A 16-frame video at 256×256 resolution: $16 \times 256 \times 256 \times 3 = 3{,}145{,}728$ values.

After encoding with a VAE (spatial downsampling 8×, 4 latent channels): $16 \times 32 \times 32 \times 4 = 65{,}536$ values.

That is a **48× compression**. The diffusion model now operates on a tensor that is 48 times smaller. Training is faster, inference is faster, and memory usage drops dramatically.
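
The arithmetic above can be checked in a few lines of Python. This is only a minimal sketch; the helper name `num_values` is ours and does not appear elsewhere in the notebook.

```python
def num_values(frames, height, width, channels):
    """Count the scalar values in a (frames, height, width, channels) video tensor."""
    return frames * height * width * channels

# Raw RGB video: 16 frames at 256x256
pixel_values = num_values(16, 256, 256, 3)

# Latent video: 8x spatial downsampling, 4 latent channels, same number of frames
latent_values = num_values(16, 256 // 8, 256 // 8, 4)

compression = pixel_values / latent_values
print(pixel_values, latent_values, compression)
```

Note that only the spatial dimensions are compressed here; the number of frames is unchanged.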

On top of this, the Diffusion Transformer (DiT) architecture further simplifies the design. Instead of carefully engineering spatial convolutions, temporal convolutions, and separate attention blocks in a U-Net, we simply:
1. Cut the video into spacetime patches
2. Flatten them into a token sequence
3. Run a standard Transformer

This is the architecture behind Sora, and it scales predictably with model size, a property inherited from the language modeling world.

In [None]:
#@title üéß Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### The VAE Analogy

Think of the VAE as a skilled summarizer. Imagine you need to email a friend about a movie you watched. You would not send them the raw video file (too large). Instead, you would write a concise summary that captures the key plot points, visual style, and emotional beats in a much smaller representation. Your friend (the decoder) can then reconstruct a mental image of the movie from your summary.

The video VAE does the same thing with learned representations. The encoder compresses each frame from 256×256×3 to 32×32×4, keeping the essential visual information while discarding redundancy.

### The DiT Analogy

Think of the shift from U-Net to Transformer like the shift from specialized to general-purpose computing. A U-Net is like a custom-built circuit: highly optimized for the task but rigid in structure. A Transformer is like a general-purpose processor: you can make it bigger, train it on more data, and it tends to keep improving. This is why language models (which are Transformers) scale so well, and it is why the video generation field is moving in the same direction.

### 🤔 Think About This

Before we proceed:
1. What information might the VAE lose during compression? How would this affect video quality?
2. Why do spacetime patches mix spatial and temporal information, while factorized attention keeps them separate?
3. If you had unlimited compute, would you still use a VAE, or run diffusion directly in pixel space?

*Take 2 minutes. Then scroll down.*

In [None]:
#@title üéß Listen: Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### Video VAE

The encoder maps a video $\mathbf{v}$ to a latent representation $\mathbf{z}$:

$$\mathbf{z} = \mathcal{E}(\mathbf{v}), \quad \mathbf{v} \in \mathbb{R}^{T \times H \times W \times 3}, \quad \mathbf{z} \in \mathbb{R}^{T \times h \times w \times c}$$

where $h = H / f_s$, $w = W / f_s$, and $f_s$ is the spatial downsampling factor (typically 8).

Computationally: take each frame, run it through a series of strided convolutions that progressively halve the spatial resolution (3 times for factor 8: 256→128→64→32), while expanding the channel dimension.
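
To see why three stride-2 convolutions give a factor of 8, the standard convolution output-size formula can be applied repeatedly. The sketch below assumes kernel size 3, stride 2, and padding 1 (the settings the encoder in Section 4.1 uses); `conv_out_size` is a hypothetical helper, not part of the notebook.

```python
def conv_out_size(size, kernel=3, stride=2, padding=1):
    """Spatial output size of a convolution (floor formula, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# Apply three stride-2 convolutions to a 256-pixel side length
sizes = [256]
for _ in range(3):
    sizes.append(conv_out_size(sizes[-1]))

print(sizes)
```

Each stride-2 layer halves the side length, so three of them give a spatial downsampling factor of $2^3 = 8$.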

The decoder inverts this:

$$\hat{\mathbf{v}} = \mathcal{D}(\mathbf{z})$$

The VAE is trained with a reconstruction loss plus a KL divergence regularizer:

$$\mathcal{L}_{\text{VAE}} = \|\mathbf{v} - \hat{\mathbf{v}}\|^2 + \beta \cdot D_{\text{KL}}(q(\mathbf{z}|\mathbf{v}) \| p(\mathbf{z}))$$

The first term ensures the reconstruction is accurate. The second term keeps the latent space well-organized (close to a standard Gaussian), which is crucial because the diffusion model will sample from this space.

Let us plug in numbers. If $\|\mathbf{v} - \hat{\mathbf{v}}\|^2 = 0.015$ (good reconstruction) and $D_{\text{KL}} = 3.2$ with $\beta = 0.001$:

$\mathcal{L}_{\text{VAE}} = 0.015 + 0.001 \times 3.2 = 0.015 + 0.0032 = 0.0182$

The reconstruction loss dominates: we want high-quality decoding above all.
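
The loss above can be sketched in NumPy as follows. This is an illustrative version under stated assumptions: it uses mean squared error (an averaged form of the squared norm) and the closed-form KL divergence between a diagonal Gaussian and a standard Gaussian, averaged over elements. The function name `vae_loss` is ours.

```python
import numpy as np

def vae_loss(v, v_hat, mean, logvar, beta=0.001):
    """MSE reconstruction term plus beta-weighted KL to a standard Gaussian."""
    recon = np.mean((v - v_hat) ** 2)
    # Closed-form KL(N(mean, exp(logvar)) || N(0, 1)), averaged over elements
    kl = -0.5 * np.mean(1.0 + logvar - mean ** 2 - np.exp(logvar))
    return recon + beta * kl, recon, kl

# Plugging in the numbers from the text directly
total = 0.015 + 0.001 * 3.2
print(round(total, 4))

# Sanity check: a perfect reconstruction with a standard-Gaussian posterior costs nothing
v = np.random.default_rng(0).random((2, 3))
loss, recon, kl = vae_loss(v, v, np.zeros(4), np.zeros(4))
print(loss, recon, kl)
```

With such a small $\beta$, the KL term only nudges the latents toward a standard Gaussian rather than forcing them there.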

### Spacetime Patches for DiT

Given a video $\mathbf{v} \in \mathbb{R}^{T \times H \times W \times C}$, we divide it into non-overlapping 3D patches of size $t_p \times h_p \times w_p$:

$$\text{Number of tokens} = \frac{T}{t_p} \times \frac{H}{h_p} \times \frac{W}{w_p}$$

Each patch is flattened and linearly projected to the Transformer's hidden dimension $d$:

$$z_i = \text{Linear}(\text{flatten}(p_i)) \in \mathbb{R}^d$$

For a 16-frame video at 32×32 latent resolution with patches of size $2 \times 4 \times 4$:

$\frac{16}{2} \times \frac{32}{4} \times \frac{32}{4} = 8 \times 8 \times 8 = 512 \text{ tokens}$

Each token encodes a small 3D cube of spacetime, containing information about both spatial content and temporal evolution within that cube.
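
The token count is easy to compute for any configuration. The sketch below uses a hypothetical helper, `num_tokens`, and evaluates it for the example above and for the smaller 8-frame, 8×8 latents used later in this notebook.

```python
def num_tokens(T, H, W, patch_t, patch_h, patch_w):
    """Number of non-overlapping spacetime patches; each dimension must divide evenly."""
    assert T % patch_t == 0 and H % patch_h == 0 and W % patch_w == 0
    return (T // patch_t) * (H // patch_h) * (W // patch_w)

tokens_example = num_tokens(16, 32, 32, 2, 4, 4)   # the example from the text
tokens_notebook = num_tokens(8, 8, 8, 2, 4, 4)     # 8 frames of 8x8 latents

# Flattened size of one patch before the linear projection (4 latent channels)
patch_dim = 4 * 2 * 4 * 4

print(tokens_example, tokens_notebook, patch_dim)
```

Because standard self-attention compares every token with every other token, its cost grows quadratically with the token count, which is one reason patch size matters.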

### Cross-Attention for Text Conditioning

To condition on text, we add cross-attention layers where queries come from the video tokens and keys/values come from text embeddings:

$$\text{CrossAttn}(Q_{\text{video}}, K_{\text{text}}, V_{\text{text}}) = \text{softmax}\left(\frac{Q_{\text{video}} K_{\text{text}}^T}{\sqrt{d_k}}\right) \cdot V_{\text{text}}$$

Computationally: each video patch looks at the full text description and extracts the information relevant to that spacetime location. A patch showing a dog's legs would attend strongly to the word "running", while a patch showing the sky would attend to "sunny day".
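
The formula can be sketched directly in NumPy. This is a single-head illustration with random inputs, not the notebook's PyTorch implementation: in practice $Q$, $K$, and $V$ come from learned linear projections, which are omitted here, and the sizes (512 video tokens, 6 text tokens, width 64) are assumptions chosen for the example.

```python
import numpy as np

def cross_attention(q_video, k_text, v_text):
    """Scaled dot-product attention where video tokens query text tokens.

    q_video: (N_video, d_k), k_text: (N_text, d_k), v_text: (N_text, d_v)
    Returns the attended output (N_video, d_v) and the weights (N_video, N_text).
    """
    d_k = q_video.shape[-1]
    scores = q_video @ k_text.T / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over text tokens
    return weights @ v_text, weights

rng = np.random.default_rng(0)
q = rng.standard_normal((512, 64))  # one query per video token
k = rng.standard_normal((6, 64))    # one key per text token
v = rng.standard_normal((6, 64))    # one value per text token

out, attn = cross_attention(q, k, v)
print(out.shape, attn.shape)
```

Each row of `attn` is a probability distribution over the text tokens, so every video token receives a weighted mix of text information.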

In [None]:
#@title üéß Listen: Building Encoder
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_building_encoder.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It, Component by Component

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

In [None]:
#@title üéß Listen: Encoder Code
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_encoder_code.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.1 Video VAE: Encoder

We will build a simple video VAE that compresses spatial dimensions. For our 32×32 Moving MNIST, we will downsample by 4× to get 8×8 latents.

In [None]:
class VideoEncoder(nn.Module):
    """
    Encodes video frames into a latent representation.
    Processes each frame independently with 2D convolutions.

    Input: (B, C_in, T, H, W)
    Output: mean and log_var, each (B, C_latent, T, H/f, W/f)
    """
    def __init__(self, in_channels=1, latent_channels=4, base_dim=32):
        super().__init__()
        # Downsample 4x: 32 -> 16 -> 8
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, base_dim, 3, stride=2, padding=1),
            nn.GroupNorm(4, base_dim),
            nn.GELU(),
            nn.Conv2d(base_dim, base_dim * 2, 3, stride=2, padding=1),
            nn.GroupNorm(4, base_dim * 2),
            nn.GELU(),
            nn.Conv2d(base_dim * 2, base_dim * 2, 3, padding=1),
            nn.GELU(),
        )
        # Output mean and log_var
        self.to_mean = nn.Conv2d(base_dim * 2, latent_channels, 1)
        self.to_logvar = nn.Conv2d(base_dim * 2, latent_channels, 1)

    def forward(self, x):
        # x: (B, C, T, H, W)
        B, C, T, H, W = x.shape
        # Process each frame independently
        x_frames = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        h = self.encoder(x_frames)
        mean = self.to_mean(h)
        logvar = self.to_logvar(h)
        # Reshape back: (B*T, C_lat, H/f, W/f) -> (B, C_lat, T, H/f, W/f)
        _, C_lat, Hf, Wf = mean.shape
        mean = mean.reshape(B, T, C_lat, Hf, Wf).permute(0, 2, 1, 3, 4)
        logvar = logvar.reshape(B, T, C_lat, Hf, Wf).permute(0, 2, 1, 3, 4)
        return mean, logvar

    def sample(self, mean, logvar):
        """Reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + std * eps

In [None]:
# Test the encoder
encoder = VideoEncoder(in_channels=1, latent_channels=4).to(device)
test_video = torch.randn(2, 1, 8, 32, 32, device=device)
mean, logvar = encoder(test_video)
z = encoder.sample(mean, logvar)

print(f"Input video:  {test_video.shape}  (B, C, T, H, W)")
print(f"Latent mean:  {mean.shape}  (B, C_lat, T, H/4, W/4)")
print(f"Latent sample: {z.shape}")
compression = test_video.numel() / z.numel()
print(f"Compression ratio: {compression:.1f}x")
print("✅ Encoder works!")

In [None]:
#@title üéß Listen: Decoder
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_decoder.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 Video VAE: Decoder

In [None]:
class VideoDecoder(nn.Module):
    """
    Decodes latent representation back to video frames.

    Input: (B, C_latent, T, H/f, W/f)
    Output: (B, C_out, T, H, W)
    """
    def __init__(self, out_channels=1, latent_channels=4, base_dim=32):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv2d(latent_channels, base_dim * 2, 3, padding=1),
            nn.GroupNorm(4, base_dim * 2),
            nn.GELU(),
            nn.ConvTranspose2d(base_dim * 2, base_dim, 2, stride=2),
            nn.GroupNorm(4, base_dim),
            nn.GELU(),
            nn.ConvTranspose2d(base_dim, base_dim, 2, stride=2),
            nn.GELU(),
            nn.Conv2d(base_dim, out_channels, 1),
            nn.Sigmoid()  # Output in [0, 1]
        )

    def forward(self, z):
        # z: (B, C_lat, T, H/f, W/f)
        B, C_lat, T, Hf, Wf = z.shape
        z_frames = z.permute(0, 2, 1, 3, 4).reshape(B * T, C_lat, Hf, Wf)
        decoded = self.decoder(z_frames)
        _, C_out, H, W = decoded.shape
        return decoded.reshape(B, T, C_out, H, W).permute(0, 2, 1, 3, 4)

In [None]:
# Test encoder → decoder roundtrip
decoder = VideoDecoder(out_channels=1, latent_channels=4).to(device)
reconstructed = decoder(z)

print(f"Original video:      {test_video.shape}")
print(f"Latent:              {z.shape}")
print(f"Reconstructed video: {reconstructed.shape}")
print("✅ Full VAE roundtrip works!")

In [None]:
#@title üéß Listen: Vae Pipeline Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_vae_pipeline_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 Visualize the compression pipeline
fig, axes = plt.subplots(3, 8, figsize=(14, 5))

# Create a sample video (random noise here; the VAE is still untrained)
sample_video = torch.randn(1, 1, 8, 32, 32, device=device).clamp(0, 1)

# Original
for col in range(8):
    axes[0, col].imshow(sample_video[0, 0, col].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    axes[0, col].axis('off')
    if col == 0:
        axes[0, col].set_ylabel('Original\n32×32', fontsize=10, rotation=0, labelpad=50)

# Latent (show first channel)
mean_vis, logvar_vis = encoder(sample_video)
z_vis = encoder.sample(mean_vis, logvar_vis)
for col in range(8):
    axes[1, col].imshow(z_vis[0, 0, col].cpu().detach().numpy(), cmap='viridis')
    axes[1, col].axis('off')
    if col == 0:
        axes[1, col].set_ylabel('Latent\n8×8', fontsize=10, rotation=0, labelpad=50)

# Reconstructed
recon_vis = decoder(z_vis)
for col in range(8):
    axes[2, col].imshow(recon_vis[0, 0, col].cpu().detach().numpy(), cmap='gray', vmin=0, vmax=1)
    axes[2, col].axis('off')
    if col == 0:
        axes[2, col].set_ylabel('Decoded\n32×32', fontsize=10, rotation=0, labelpad=50)

fig.suptitle('Video VAE Pipeline: Encode → Latent → Decode', fontsize=14)
plt.tight_layout()
plt.show()
print(f"💡 Latent space is {z_vis.shape[-2]}×{z_vis.shape[-1]}, 4x smaller in each spatial dimension")

In [None]:
#@title üéß Listen: Train Vae
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_train_vae.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Training the VAE

In [None]:
def create_moving_mnist_video(num_frames=16, size=32, digit_size=12):
    """Create a single Moving-MNIST-style video (a filled disc stands in for the digit)."""
    video = np.zeros((num_frames, size, size), dtype=np.float32)
    x = np.random.randint(0, size - digit_size)
    y = np.random.randint(0, size - digit_size)
    vx = np.random.choice([-2, -1, 1, 2])
    vy = np.random.choice([-2, -1, 1, 2])

    digit = np.zeros((digit_size, digit_size), dtype=np.float32)
    center = digit_size // 2
    for i in range(digit_size):
        for j in range(digit_size):
            if ((i - center)**2 + (j - center)**2) ** 0.5 < center:
                digit[i, j] = 1.0

    for t in range(num_frames):
        video[t, y:y+digit_size, x:x+digit_size] = digit
        x += vx; y += vy
        if x <= 0 or x >= size - digit_size: vx = -vx; x = max(0, min(size - digit_size, x))
        if y <= 0 or y >= size - digit_size: vy = -vy; y = max(0, min(size - digit_size, y))
    return video

def create_dataset(num_videos=512, num_frames=8, size=32):
    videos = np.stack([create_moving_mnist_video(num_frames, size) for _ in range(num_videos)])
    return torch.tensor(videos).unsqueeze(1)

dataset = create_dataset(512, 8, 32)
print(f"Dataset: {dataset.shape}")

In [None]:
# Train the VAE
encoder = VideoEncoder(in_channels=1, latent_channels=4).to(device)
decoder = VideoDecoder(out_channels=1, latent_channels=4).to(device)

vae_params = list(encoder.parameters()) + list(decoder.parameters())
vae_optimizer = torch.optim.Adam(vae_params, lr=3e-4)
beta_kl = 0.0001  # Small KL weight: we prioritize reconstruction

vae_losses = []
print("Training Video VAE...")

for epoch in range(40):
    epoch_loss = 0
    perm = torch.randperm(len(dataset))

    for i in range(0, len(dataset) - 16 + 1, 16):
        batch = dataset[perm[i:i+16]].to(device)

        mean, logvar = encoder(batch)
        z = encoder.sample(mean, logvar)
        recon = decoder(z)

        # Reconstruction loss
        recon_loss = F.mse_loss(recon, batch)

        # KL divergence
        kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())

        loss = recon_loss + beta_kl * kl_loss

        vae_optimizer.zero_grad()
        loss.backward()
        vae_optimizer.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / (len(dataset) // 16)
    vae_losses.append(avg_loss)

    if (epoch + 1) % 10 == 0:
        print(f"  Epoch {epoch+1}/40 - Loss: {avg_loss:.4f} "
              f"(Recon: {recon_loss.item():.4f}, KL: {kl_loss.item():.2f})")

print("VAE training complete!")

In [None]:
# 📊 VAE reconstruction quality
fig, axes = plt.subplots(2, 8, figsize=(14, 3.5))

sample = dataset[:1].to(device)
with torch.no_grad():
    mean, logvar = encoder(sample)
    z = encoder.sample(mean, logvar)
    recon = decoder(z)

for col in range(8):
    axes[0, col].imshow(sample[0, 0, col].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    axes[0, col].axis('off')
    axes[1, col].imshow(recon[0, 0, col].cpu().detach().numpy(), cmap='gray', vmin=0, vmax=1)
    axes[1, col].axis('off')

axes[0, 0].set_ylabel('Original', fontsize=10, rotation=0, labelpad=45)
axes[1, 0].set_ylabel('Reconstructed', fontsize=10, rotation=0, labelpad=45)
fig.suptitle('VAE Reconstruction Quality', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Latent Diffusion Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_latent_diffusion_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. Latent Video Diffusion

Now that we have a working VAE, let us run diffusion in the latent space. The idea is simple:
1. Encode videos to latents using the frozen VAE encoder
2. Train a diffusion model to denoise in latent space
3. At inference: sample noise → denoise in latent space → decode to pixels
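
Step 2 trains the denoiser on latents corrupted by the standard DDPM forward process, $\mathbf{z}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{z}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}$. The NumPy sketch below assumes the same linear beta schedule as the training cell in Section 5.3 (500 steps from 0.0001 to 0.02); `add_noise` is a hypothetical helper name, and the random latent stands in for a real encoded video.

```python
import numpy as np

num_timesteps = 500
betas = np.linspace(1e-4, 0.02, num_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)  # fraction of signal variance kept at each step

def add_noise(z0, t, rng):
    """Sample z_t from q(z_t | z_0) in closed form; returns (z_t, noise)."""
    noise = rng.standard_normal(z0.shape)
    z_t = np.sqrt(alphas_cumprod[t]) * z0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise
    return z_t, noise

rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 8, 8, 8))  # stand-in latent video: (C_lat, T, H_lat, W_lat)

z_early, _ = add_noise(z0, 0, rng)                  # almost clean
z_late, _ = add_noise(z0, num_timesteps - 1, rng)   # almost pure noise

print(alphas_cumprod[0], alphas_cumprod[-1])
```

The denoiser is trained to predict `noise` from `z_t` and `t`; sampling then runs this process in reverse before the decoder maps the result back to pixels.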

### 5.1 Pre-encode the Dataset

In [None]:
# Freeze the VAE and encode all training videos to latent space
encoder.eval()
decoder.eval()

with torch.no_grad():
    latent_dataset = []
    for i in range(0, len(dataset), 32):
        batch = dataset[i:i+32].to(device)
        mean, logvar = encoder(batch)
        z = encoder.sample(mean, logvar)
        latent_dataset.append(z.cpu())

latent_dataset = torch.cat(latent_dataset, dim=0)
print(f"Pixel dataset:  {dataset.shape}    ({dataset.numel():,} values)")
print(f"Latent dataset: {latent_dataset.shape}  ({latent_dataset.numel():,} values)")
print(f"Compression: {dataset.numel() / latent_dataset.numel():.1f}x")

In [None]:
#@title üéß Listen: Latent Unet
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_latent_unet.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 5.2 Simple Latent Diffusion Model

For the latent diffusion model, we will use a small 2D U-Net that processes each frame of the latent independently, plus temporal mixing layers.

In [None]:
class LatentDiffusionUNet(nn.Module):
    """
    A U-Net for denoising in latent space.
    Input: (B, C_lat, T, H_lat, W_lat), the noisy latent video
    Output: (B, C_lat, T, H_lat, W_lat), the predicted noise
    """
    def __init__(self, latent_channels=4, base_dim=64):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(base_dim, base_dim * 4),
            nn.GELU(),
            nn.Linear(base_dim * 4, base_dim)
        )
        self.time_dim = base_dim

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(latent_channels, base_dim, 3, padding=1),
            nn.GroupNorm(4, base_dim), nn.GELU())
        self.enc2 = nn.Sequential(
            nn.Conv2d(base_dim, base_dim * 2, 3, stride=2, padding=1),
            nn.GroupNorm(4, base_dim * 2), nn.GELU())

        # Temporal mixing (1D conv across frames)
        self.temporal1 = nn.Conv1d(base_dim, base_dim, 3, padding=1)
        self.temporal2 = nn.Conv1d(base_dim * 2, base_dim * 2, 3, padding=1)

        # Decoder
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(base_dim * 4, base_dim, 2, stride=2),
            nn.GroupNorm(4, base_dim), nn.GELU())
        self.out = nn.Conv2d(base_dim * 2, latent_channels, 1)

        self.time_proj1 = nn.Linear(base_dim, base_dim)
        self.time_proj2 = nn.Linear(base_dim, base_dim * 2)

    def forward(self, x, t):
        B, C, T, H, W = x.shape

        # Time embedding
        half = self.time_dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        t_emb = torch.cat([
            (t[:, None].float() * freqs[None]).sin(),
            (t[:, None].float() * freqs[None]).cos()
        ], dim=-1)
        t_emb = self.time_mlp(t_emb)

        # Process frames with 2D convs
        x_frames = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)

        h1 = self.enc1(x_frames)  # (B*T, base, H, W)
        # Add time embedding
        te1 = self.time_proj1(t_emb).repeat_interleave(T, 0)[:, :, None, None]
        h1 = h1 + te1

        # Temporal mixing on h1
        _, C1, H1, W1 = h1.shape
        h1_t = h1.reshape(B, T, C1, H1, W1).permute(0, 3, 4, 2, 1)  # (B, H, W, C, T)
        h1_t = h1_t.reshape(B * H1 * W1, C1, T)
        h1_t = h1_t + self.temporal1(h1_t)
        h1_t = h1_t.reshape(B, H1, W1, C1, T).permute(0, 4, 3, 1, 2)  # (B, T, C, H, W)
        h1 = h1_t.reshape(B * T, C1, H1, W1)

        h2 = self.enc2(h1)  # (B*T, base*2, H/2, W/2)
        te2 = self.time_proj2(t_emb).repeat_interleave(T, 0)[:, :, None, None]
        h2 = h2 + te2

        # Temporal mixing on h2
        _, C2, H2, W2 = h2.shape
        h2_t = h2.reshape(B, T, C2, H2, W2).permute(0, 3, 4, 2, 1).reshape(B * H2 * W2, C2, T)
        h2_t = h2_t + self.temporal2(h2_t)
        h2_t = h2_t.reshape(B, H2, W2, C2, T).permute(0, 4, 3, 1, 2)
        h2 = h2_t.reshape(B * T, C2, H2, W2)

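        # Decode: this shallow U-Net has no deeper bottleneck, so h2 is
        # concatenated with itself to fill dec2's base_dim*4 input channels;
        # the actual skip connection is the concat with h1 below
        dec = self.dec2(torch.cat([h2, h2], dim=1))
        out = self.out(torch.cat([dec, h1], dim=1))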

        return out.reshape(B, T, -1, H, W).permute(0, 2, 1, 3, 4)

In [None]:
# Test
latent_model = LatentDiffusionUNet(latent_channels=4, base_dim=64).to(device)
test_z = torch.randn(2, 4, 8, 8, 8, device=device)
test_t = torch.randint(0, 500, (2,), device=device)
test_out = latent_model(test_z, test_t)

print(f"Latent input:  {test_z.shape}")
print(f"Noise output:  {test_out.shape}")
params = sum(p.numel() for p in latent_model.parameters())
print(f"Parameters: {params:,}")
print("✅ Latent diffusion model works!")

In [None]:
#@title üéß Listen: Train Latent
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_train_latent.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 5.3 Train Latent Diffusion

In [None]:
# Diffusion schedule
num_timesteps = 500
betas = torch.linspace(0.0001, 0.02, num_timesteps, device=device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

latent_model = LatentDiffusionUNet(latent_channels=4, base_dim=64).to(device)
optimizer = torch.optim.Adam(latent_model.parameters(), lr=2e-4)

latent_losses = []
print("Training Latent Diffusion Model...")

for epoch in range(30):
    epoch_loss = 0
    perm = torch.randperm(len(latent_dataset))

    for i in range(0, len(latent_dataset) - 16 + 1, 16):
        batch = latent_dataset[perm[i:i+16]].to(device)
        B = batch.shape[0]

        t = torch.randint(0, num_timesteps, (B,), device=device)
        noise = torch.randn_like(batch)

        # Forward diffusion in latent space
        sqrt_ac = sqrt_alphas_cumprod[t][:, None, None, None, None]
        sqrt_omac = sqrt_one_minus_alphas_cumprod[t][:, None, None, None, None]
        noisy_latent = sqrt_ac * batch + sqrt_omac * noise

        # Predict noise
        pred_noise = latent_model(noisy_latent, t)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg = epoch_loss / (len(latent_dataset) // 16)
    latent_losses.append(avg)
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/30 - Loss: {avg:.4f}")

print("Latent diffusion training complete!")

In [None]:
# 📊 Training curve
plt.figure(figsize=(10, 4))
plt.plot(latent_losses, 'b-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss (in latent space)')
plt.title('Latent Video Diffusion: Training Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Latent Sampling
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_latent_sampling.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 5.4 Sampling from Latent Diffusion

In [None]:
@torch.no_grad()
def sample_latent_diffusion(model, decoder, num_samples=8, num_frames=8,
                             latent_h=8, latent_w=8, latent_c=4):
    """Sample videos via latent diffusion + VAE decoding."""
    model.eval()

    # Start from noise in latent space (much smaller!)
    z = torch.randn(num_samples, latent_c, num_frames, latent_h, latent_w, device=device)

    for t_idx in reversed(range(num_timesteps)):
        t = torch.full((num_samples,), t_idx, device=device, dtype=torch.long)
        pred_noise = model(z, t)

        beta_t = betas[t_idx]
        alpha_t = alphas[t_idx]
        alpha_cumprod_t = alphas_cumprod[t_idx]

        coeff1 = 1.0 / torch.sqrt(alpha_t)
        coeff2 = beta_t / torch.sqrt(1.0 - alpha_cumprod_t)
        mean = coeff1 * (z - coeff2 * pred_noise)

        if t_idx > 0:
            z = mean + torch.sqrt(beta_t) * torch.randn_like(z)
        else:
            z = mean

    # Decode latents to pixels
    videos = decoder(z)
    model.train()
    return videos.clamp(0, 1)

print("Generating videos via latent diffusion...")
generated_latent = sample_latent_diffusion(latent_model, decoder)
print(f"Generated: {generated_latent.shape}")

In [None]:
# 📊 Display latent diffusion results
fig, axes = plt.subplots(4, 8, figsize=(14, 7))
for row in range(4):
    for col in range(8):
        axes[row, col].imshow(generated_latent[row, 0, col].cpu().numpy(),
                             cmap='gray', vmin=0, vmax=1)
        axes[row, col].axis('off')
        if row == 0:
            axes[row, col].set_title(f'f={col}', fontsize=9)
    axes[row, 0].set_ylabel(f'Video {row+1}', fontsize=9, rotation=0, labelpad=40)

fig.suptitle('Generated Videos: Latent Video Diffusion\n'
             'Diffusion runs in 8×8 latent space, decoded to 32×32 pixels', fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Dit Intro
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_dit_intro.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. The Diffusion Transformer (DiT)

Now let us build the other major innovation: replacing the U-Net with a Transformer. This is the architecture behind Sora.

### 6.1 Spacetime Patchification

In [None]:
class SpacetimePatchEmbed(nn.Module):
    """
    Convert a video into a sequence of spacetime patch tokens.

    Input: (B, C, T, H, W)
    Output: (B, num_patches, embed_dim)
    """
    def __init__(self, in_channels=4, embed_dim=128,
                 patch_t=2, patch_h=4, patch_w=4):
        super().__init__()
        self.patch_t = patch_t
        self.patch_h = patch_h
        self.patch_w = patch_w
        patch_dim = in_channels * patch_t * patch_h * patch_w
        self.proj = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        B, C, T, H, W = x.shape
        pt, ph, pw = self.patch_t, self.patch_h, self.patch_w

        # Reshape into patches
        nt, nh, nw = T // pt, H // ph, W // pw
        x = x.reshape(B, C, nt, pt, nh, ph, nw, pw)
        x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)  # (B, nt, nh, nw, C, pt, ph, pw)
        x = x.reshape(B, nt * nh * nw, C * pt * ph * pw)  # (B, num_patches, patch_dim)

        return self.proj(x), (nt, nh, nw)

In [None]:
# 📊 Visualize patchification
patcher = SpacetimePatchEmbed(in_channels=4, embed_dim=128,
                               patch_t=2, patch_h=4, patch_w=4).to(device)
test_latent = torch.randn(1, 4, 8, 8, 8, device=device)
tokens, (nt, nh, nw) = patcher(test_latent)

print(f"Input latent: {test_latent.shape}  (B, C, T, H, W)")
print(f"Patches: {nt}t × {nh}h × {nw}w = {nt*nh*nw} tokens")
print(f"Token sequence: {tokens.shape}  (B, num_tokens, embed_dim)")
print(f"\n💡 Each token encodes a {patcher.patch_t}×{patcher.patch_h}×{patcher.patch_w} "
      f"cube of spacetime!")

In [None]:
#@title üéß Listen: Todo Unpatchify
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/14_todo_unpatchify.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 🔧 TODO: Implement Spacetime Unpatchification

Patchification converts a video into tokens. After the Transformer processes these tokens, we need to convert them **back** into a video tensor. This is the inverse operation: unpatchification.

In [None]:
def unpatchify(tokens, num_channels, nt, nh, nw, patch_t, patch_h, patch_w):
    """
    Convert a sequence of patch tokens back into a video tensor.

    Args:
        tokens: (B, num_patches, patch_dim), the output from the Transformer head
                where patch_dim = num_channels * patch_t * patch_h * patch_w
        num_channels: number of channels (e.g., 4 for latent space)
        nt, nh, nw: number of patches along time, height, width
        patch_t, patch_h, patch_w: patch dimensions

    Returns:
        video: (B, num_channels, T, H, W)
               where T = nt*patch_t, H = nh*patch_h, W = nw*patch_w
    """
    B = tokens.shape[0]

    # ============ TODO ============
    # Step 1: Reshape tokens from (B, nt*nh*nw, patch_dim)
    #         to (B, nt, nh, nw, C, patch_t, patch_h, patch_w)
    # Step 2: Permute to (B, C, nt, patch_t, nh, patch_h, nw, patch_w)
    #         This interleaves the grid indices with patch indices
    # Step 3: Reshape to (B, C, T, H, W) where T=nt*pt, H=nh*ph, W=nw*pw
    # ==============================

    video = ???  # YOUR CODE HERE

    return video

In [None]:
# ✅ Verification
torch.manual_seed(42)
B, C, T, H, W = 2, 4, 8, 8, 8
pt, ph, pw = 2, 4, 4
_nt, _nh, _nw = T // pt, H // ph, W // pw
patch_dim = C * pt * ph * pw
test_tokens = torch.randn(B, _nt * _nh * _nw, patch_dim, device=device)

try:
    result = unpatchify(test_tokens, C, _nt, _nh, _nw, pt, ph, pw)
    assert result.shape == (B, C, T, H, W), f"Wrong shape: {result.shape}, expected ({B}, {C}, {T}, {H}, {W})"

    # A true patchify → unpatchify roundtrip would require inverting the learned
    # linear projection inside `patcher`, so here we only check that the patch
    # grid reported by `patcher` matches the grid used above.
    orig = torch.randn(B, C, T, H, W, device=device)
    tokens_rt, (_nt2, _nh2, _nw2) = patcher(orig)
    assert (_nt2, _nh2, _nw2) == (_nt, _nh, _nw), "Patch grid mismatch"
    print(f"✅ Unpatchify works! Output shape: {result.shape}")
    print(f"   {_nt}×{_nh}×{_nw} patches → {T}×{H}×{W} video")
except Exception as e:
    print(f"❌ Error: {e}")
    print("Hint: reshape to (B, nt, nh, nw, C, pt, ph, pw),")
    print("then permute to (B, C, nt, pt, nh, ph, nw, pw),")
    print("then reshape to (B, C, nt*pt, nh*ph, nw*pw)")
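
The verification above only checks shapes because `SpacetimePatchEmbed` applies a learned linear projection, which a plain reshape cannot undo. If we skip the projection, an exact roundtrip becomes testable. The sketch below assumes the same reshape and permute order as `SpacetimePatchEmbed`; `patchify_raw` and `unpatchify_ref` are hypothetical helper names that do not exist elsewhere in the notebook. Try the TODO yourself first, then compare.

```python
import torch

def patchify_raw(x, pt, ph, pw):
    """Split (B, C, T, H, W) into raw spacetime patches, with no linear projection."""
    B, C, T, H, W = x.shape
    nt, nh, nw = T // pt, H // ph, W // pw
    x = x.reshape(B, C, nt, pt, nh, ph, nw, pw)
    x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)          # (B, nt, nh, nw, C, pt, ph, pw)
    return x.reshape(B, nt * nh * nw, C * pt * ph * pw), (nt, nh, nw)

def unpatchify_ref(tokens, C, nt, nh, nw, pt, ph, pw):
    """Reference inverse of patchify_raw (hypothetical helper, for checking only)."""
    B = tokens.shape[0]
    x = tokens.reshape(B, nt, nh, nw, C, pt, ph, pw)
    x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)          # (B, C, nt, pt, nh, ph, nw, pw)
    return x.reshape(B, C, nt * pt, nh * ph, nw * pw)

video = torch.randn(2, 4, 8, 8, 8)
raw_tokens, (nt, nh, nw) = patchify_raw(video, 2, 4, 4)
restored = unpatchify_ref(raw_tokens, 4, nt, nh, nw, 2, 4, 4)

# Token 0 should hold exactly the first 2x4x4 cube of every channel.
first_cube = video[:, :, 0:2, 0:4, 0:4].reshape(2, -1)
print(torch.equal(raw_tokens[:, 0], first_cube), torch.equal(restored, video))
```

Because both functions only move elements around, the comparison can use exact equality rather than a tolerance.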

In [None]:
#@title 🎧 Listen: DiT Block
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/15_dit_block.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 6.2 DiT Block

In [None]:
class DiTBlock(nn.Module):
    """
    A single Transformer block for the Diffusion Transformer.
    Contains: self-attention + cross-attention (for conditioning) + feedforward.
    Timestep information is injected via adaptive layer norm (adaLN).
    """
    def __init__(self, dim, num_heads=4, mlp_ratio=4.0):
        super().__init__()
        # Self-attention
        self.norm1 = nn.LayerNorm(dim)
        self.self_attn_qkv = nn.Linear(dim, dim * 3)
        self.self_attn_proj = nn.Linear(dim, dim)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Cross-attention (for text conditioning)
        self.norm2 = nn.LayerNorm(dim)
        self.cross_q = nn.Linear(dim, dim)
        self.cross_kv = nn.Linear(dim, dim * 2)
        self.cross_proj = nn.Linear(dim, dim)

        # Feedforward
        self.norm3 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

        # AdaLN modulation from timestep
        self.adaLN = nn.Sequential(
            nn.GELU(),
            nn.Linear(dim, dim * 6)  # 6 = scale+shift for 3 norms
        )

    def forward(self, x, t_emb, context=None):
        """
        x: (B, N, D) - patch tokens
        t_emb: (B, D) - timestep embedding
        context: (B, L, D) - text embeddings (optional)
        """
        # AdaLN parameters from timestep
        ada = self.adaLN(t_emb)  # (B, 6*D)
        s1, b1, s2, b2, s3, b3 = ada.chunk(6, dim=-1)

        # Self-attention with adaLN
        h = self.norm1(x) * (1 + s1.unsqueeze(1)) + b1.unsqueeze(1)
        B, N, D = h.shape
        qkv = self.self_attn_qkv(h).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        h = (attn @ v).transpose(1, 2).reshape(B, N, D)
        h = self.self_attn_proj(h)
        x = x + h

        # Cross-attention (if conditioning context provided)
        if context is not None:
            h = self.norm2(x) * (1 + s2.unsqueeze(1)) + b2.unsqueeze(1)
            q = self.cross_q(h).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            kv = self.cross_kv(context).reshape(B, -1, 2, self.num_heads, self.head_dim)
            kv = kv.permute(2, 0, 3, 1, 4)
            k_c, v_c = kv.unbind(0)
            attn = (q @ k_c.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            h = (attn @ v_c).transpose(1, 2).reshape(B, N, D)
            x = x + self.cross_proj(h)

        # Feedforward with adaLN
        h = self.norm3(x) * (1 + s3.unsqueeze(1)) + b3.unsqueeze(1)
        x = x + self.ff(h)

        return x
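
To see what the adaLN modulation in `forward` does in isolation, here is a minimal sketch. `modulate` is a hypothetical helper name (it is not defined in this notebook); it applies the same `norm(x) * (1 + scale) + shift` expression used three times in the block above. With zero scale and shift it reduces to a plain LayerNorm, which is what the `1 +` offset buys us.

```python
import torch
import torch.nn as nn

def modulate(normed, scale, shift):
    """Apply a per-sample scale and shift to normalized tokens.

    normed: (B, N, D); scale and shift: (B, D), broadcast over the token axis.
    """
    return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

torch.manual_seed(0)
B, N, D = 2, 16, 32
x = torch.randn(B, N, D)
norm = nn.LayerNorm(D)

# Zero modulation leaves the LayerNorm output unchanged.
zeros = torch.zeros(B, D)
identity_out = modulate(norm(x), zeros, zeros)

# A non-zero shift moves every token of a sample by the same vector.
shifted_out = modulate(norm(x), zeros, torch.ones(B, D))
print(torch.allclose(identity_out, norm(x)), torch.allclose(shifted_out, norm(x) + 1.0))
```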

In [None]:
#@title 🎧 Listen: Full DiT
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/16_full_dit.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 6.3 Full Mini-DiT Model

In [None]:
class MiniDiT(nn.Module):
    """
    A minimal Diffusion Transformer for video generation.

    Architecture:
    1. Spacetime patchify the latent video
    2. Add positional embeddings
    3. Pass through N DiT blocks (self-attn + cross-attn + FFN)
    4. Unpatchify back to latent video shape
    """
    def __init__(self, latent_channels=4, embed_dim=128, num_heads=4,
                 depth=4, patch_t=2, patch_h=4, patch_w=4,
                 context_dim=64):
        super().__init__()
        self.patch_embed = SpacetimePatchEmbed(
            latent_channels, embed_dim, patch_t, patch_h, patch_w)
        self.patch_t = patch_t
        self.patch_h = patch_h
        self.patch_w = patch_w
        self.latent_channels = latent_channels

        # Positional embedding (learnable)
        self.pos_embed = nn.Parameter(torch.randn(1, 512, embed_dim) * 0.02)

        # Timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.time_dim = embed_dim

        # Context projection (for text conditioning)
        self.context_proj = nn.Linear(context_dim, embed_dim)

        # DiT blocks
        self.blocks = nn.ModuleList([
            DiTBlock(embed_dim, num_heads) for _ in range(depth)
        ])

        # Output: project back to patch dimension
        self.norm_out = nn.LayerNorm(embed_dim)
        patch_dim = latent_channels * patch_t * patch_h * patch_w
        self.head = nn.Linear(embed_dim, patch_dim)

    def forward(self, x, t, context=None):
        """
        x: (B, C, T, H, W) - noisy latent video
        t: (B,) - diffusion timestep
        context: (B, L, context_dim) - text embeddings (optional)
        """
        B, C, T, H, W = x.shape

        # Patchify
        tokens, (nt, nh, nw) = self.patch_embed(x)
        num_tokens = tokens.shape[1]

        # Add positional embedding
        tokens = tokens + self.pos_embed[:, :num_tokens]

        # Timestep embedding
        half = self.time_dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        t_emb = torch.cat([
            (t[:, None].float() * freqs[None]).sin(),
            (t[:, None].float() * freqs[None]).cos()
        ], dim=-1)
        t_emb = self.time_embed(t_emb)

        # Project context if provided
        if context is not None:
            context = self.context_proj(context)

        # DiT blocks
        for block in self.blocks:
            tokens = block(tokens, t_emb, context)

        # Unpatchify
        tokens = self.head(self.norm_out(tokens))

        # Reshape back to (B, C, T, H, W)
        pt, ph, pw = self.patch_t, self.patch_h, self.patch_w
        tokens = tokens.reshape(B, nt, nh, nw, C, pt, ph, pw)
        tokens = tokens.permute(0, 4, 1, 5, 2, 6, 3, 7)  # (B, C, nt, pt, nh, ph, nw, pw)
        out = tokens.reshape(B, C, T, H, W)

        return out

In [None]:
# Test the DiT
dit = MiniDiT(latent_channels=4, embed_dim=128, num_heads=4,
              depth=4, context_dim=64).to(device)
test_z = torch.randn(2, 4, 8, 8, 8, device=device)
test_t = torch.randint(0, 500, (2,), device=device)
test_ctx = torch.randn(2, 10, 64, device=device)  # 10 "text" tokens

out_uncond = dit(test_z, test_t)
out_cond = dit(test_z, test_t, context=test_ctx)

print(f"Input latent:     {test_z.shape}")
print(f"Output (uncond):  {out_uncond.shape}")
print(f"Output (cond):    {out_cond.shape}")
params = sum(p.numel() for p in dit.parameters())
print(f"DiT parameters: {params:,}")
print("✅ Mini DiT works, with and without text conditioning!")
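
The sinusoidal timestep embedding computed inline in `MiniDiT.forward` can be pulled out into a small function for inspection. This is a sketch that assumes the same frequency formula as the model code; `sinusoidal_embedding` is a hypothetical name.

```python
import math
import torch

def sinusoidal_embedding(t, dim):
    """Map integer timesteps (B,) to (B, dim) sin/cos features."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
    args = t[:, None].float() * freqs[None]          # (B, half)
    return torch.cat([args.sin(), args.cos()], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 1, 499]), dim=128)
print(emb.shape)  # torch.Size([3, 128])
# At t = 0 every sine feature is 0 and every cosine feature is 1.
print(emb[0, :64].abs().max().item(), emb[0, 64:].min().item())
```

The first half of each row holds the sine features and the second half the cosine features, so nearby timesteps get similar but distinguishable vectors before the `time_embed` MLP sees them.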

In [None]:
#@title 🎧 Listen: Todo Cross Attention
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/17_todo_cross_attention.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. 🔧 Your Turn: Implement Cross-Attention

### TODO: Complete the cross-attention function

Cross-attention is how the video model "looks at" the text description. Complete the implementation below:

In [None]:
def cross_attention(query, key, value, num_heads=4):
    """
    Compute multi-head cross-attention.

    Args:
        query: (B, N_q, D) - queries from video features
        key:   (B, N_kv, D) - keys from text embeddings
        value: (B, N_kv, D) - values from text embeddings
        num_heads: number of attention heads

    Returns:
        (B, N_q, D) - video features enriched with text information
    """
    B, N_q, D = query.shape
    _, N_kv, _ = key.shape
    head_dim = D // num_heads

    # ============ TODO ============
    # Step 1: Reshape query to (B, num_heads, N_q, head_dim)
    # Step 2: Reshape key to (B, num_heads, N_kv, head_dim)
    # Step 3: Reshape value to (B, num_heads, N_kv, head_dim)
    # Step 4: Compute attention scores: (q @ k^T) / sqrt(head_dim)
    # Step 5: Apply softmax over the key dimension (last dim)
    # Step 6: Compute weighted sum: attn_weights @ value
    # Step 7: Reshape output back to (B, N_q, D)
    # ==============================

    q = ???  # YOUR CODE HERE (Step 1)
    k = ???  # YOUR CODE HERE (Step 2)
    v = ???  # YOUR CODE HERE (Step 3)

    scores = ???  # YOUR CODE HERE (Step 4)
    attn_weights = ???  # YOUR CODE HERE (Step 5)
    out = ???  # YOUR CODE HERE (Step 6)

    output = ???  # YOUR CODE HERE (Step 7)
    return output

In [None]:
# ✅ Verification
torch.manual_seed(42)
B, N_q, N_kv, D = 2, 16, 8, 32
test_q = torch.randn(B, N_q, D, device=device)
test_k = torch.randn(B, N_kv, D, device=device)
test_v = torch.randn(B, N_kv, D, device=device)

try:
    result = cross_attention(test_q, test_k, test_v, num_heads=4)
    assert result.shape == (B, N_q, D), f"Wrong shape: {result.shape}, expected ({B}, {N_q}, {D})"

    # Verify attention is over key dimension
    # Each query should produce a different output even with same keys
    test_q2 = torch.randn(B, N_q, D, device=device)
    result2 = cross_attention(test_q2, test_k, test_v, num_heads=4)
    assert not torch.allclose(result, result2, atol=1e-3), "Different queries should give different outputs"

    print(f"✅ Cross-attention works! Output shape: {result.shape}")
    print(f"   {N_q} video tokens attended to {N_kv} text tokens")
except Exception as e:
    print(f"❌ Error: {e}")
    print("Hint: reshape using .reshape(B, N, num_heads, head_dim).transpose(1, 2)")
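
Once your implementation passes, you can cross-check its values against PyTorch's built-in attention. The sketch below assumes PyTorch 2.0 or newer, where `torch.nn.functional.scaled_dot_product_attention` is available; `reference_cross_attention` is a hypothetical helper name. Like the exercise, it has no learned projections.

```python
import torch
import torch.nn.functional as F

def reference_cross_attention(query, key, value, num_heads=4):
    """Multi-head cross-attention via F.scaled_dot_product_attention (PyTorch >= 2.0)."""
    B, N_q, D = query.shape
    N_kv = key.shape[1]
    head_dim = D // num_heads
    # Split the feature axis into heads: (B, N, D) -> (B, heads, N, head_dim)
    q = query.reshape(B, N_q, num_heads, head_dim).transpose(1, 2)
    k = key.reshape(B, N_kv, num_heads, head_dim).transpose(1, 2)
    v = value.reshape(B, N_kv, num_heads, head_dim).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v)    # default scale is 1/sqrt(head_dim)
    return out.transpose(1, 2).reshape(B, N_q, D)

torch.manual_seed(0)
q = torch.randn(2, 16, 32)
v = torch.randn(2, 8, 32)
# If all keys are identical, attention is uniform, so each query gets the mean value.
k_same = torch.randn(2, 1, 32).expand(2, 8, 32)
out = reference_cross_attention(q, k_same, v)
print(torch.allclose(out, v.mean(dim=1, keepdim=True).expand_as(out), atol=1e-5))
```

Comparing your function with `torch.allclose(cross_attention(q, k, v), reference_cross_attention(q, k, v), atol=1e-5)` is a stricter test than the shape check above.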

In [None]:
#@title 🎧 Listen: Train DiT
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/18_train_dit.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. Training the DiT with "Text" Conditioning

For our Moving MNIST demo, we will create simple conditioning labels (e.g., velocity direction) as a stand-in for real text embeddings.

In [None]:
# Create labeled dataset: encode velocity direction as a simple "text" embedding
def create_labeled_dataset(num_videos=512, num_frames=8, size=32, embed_dim=64):
    """Create Moving MNIST with direction labels as conditioning."""
    videos = []
    labels = []

    for _ in range(num_videos):
        video = np.zeros((num_frames, size, size), dtype=np.float32)
        digit_size = 12
        x = np.random.randint(0, size - digit_size)
        y = np.random.randint(0, size - digit_size)

        # 4 directions: right, left, down, up
        direction = np.random.randint(0, 4)
        vx = [2, -2, 0, 0][direction]
        vy = [0, 0, 2, -2][direction]

        digit = np.zeros((digit_size, digit_size), dtype=np.float32)
        center = digit_size // 2
        for i in range(digit_size):
            for j in range(digit_size):
                if ((i - center)**2 + (j - center)**2) ** 0.5 < center:
                    digit[i, j] = 1.0

        for t in range(num_frames):
            video[t, y:y+digit_size, x:x+digit_size] = digit
            x += vx; y += vy
            if x <= 0 or x >= size - digit_size: vx = -vx; x = max(0, min(size - digit_size, x))
            if y <= 0 or y >= size - digit_size: vy = -vy; y = max(0, min(size - digit_size, y))

        videos.append(video)

        # Create a simple "text embedding" from direction
        label = np.zeros(embed_dim, dtype=np.float32)
        label[direction * (embed_dim // 4):(direction + 1) * (embed_dim // 4)] = 1.0
        labels.append(label)

    videos = torch.tensor(np.stack(videos)).unsqueeze(1)
    labels = torch.tensor(np.stack(labels))
    return videos, labels

cond_videos, cond_labels = create_labeled_dataset(512, 8, 32, 64)
print(f"Videos: {cond_videos.shape}, Labels: {cond_labels.shape}")
print("Directions: right=0, left=1, down=2, up=3")
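
The block one-hot label built inside `create_labeled_dataset` is rebuilt by hand again at sampling time. A small helper could keep the two in sync. This is only a sketch: `direction_to_context` is a hypothetical name and is not used elsewhere in the notebook.

```python
import torch
import torch.nn.functional as F

def direction_to_context(directions, embed_dim=64, num_directions=4):
    """Turn direction ids (B,) into a (B, 1, embed_dim) one-token context.

    Each direction switches on its own contiguous block of
    embed_dim // num_directions entries, as in create_labeled_dataset.
    """
    block = embed_dim // num_directions
    one_hot = F.one_hot(directions, num_directions).float()   # (B, num_directions)
    labels = one_hot.repeat_interleave(block, dim=1)          # (B, embed_dim)
    return labels.unsqueeze(1)

ctx = direction_to_context(torch.tensor([0, 3]))
print(ctx.shape)  # torch.Size([2, 1, 64])
print(ctx[0, 0, :16].sum().item(), ctx[1, 0, 48:].sum().item())
```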

In [None]:
# Encode to latent space
with torch.no_grad():
    cond_latents = []
    for i in range(0, len(cond_videos), 32):
        batch = cond_videos[i:i+32].to(device)
        mean, logvar = encoder(batch)
        z = encoder.sample(mean, logvar)
        cond_latents.append(z.cpu())
    cond_latents = torch.cat(cond_latents)
print(f"Latent dataset: {cond_latents.shape}")
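
`encoder.sample` is defined earlier in the notebook and is not shown in this section. VAEs commonly implement this step with the reparameterization trick, so the sketch below assumes that form; `reparameterize` is a hypothetical stand-in, not the notebook's actual method.

```python
import torch

def reparameterize(mean, logvar):
    """Draw z ~ N(mean, exp(logvar)) as mean + std * eps, with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(mean)

torch.manual_seed(0)
mean = torch.zeros(10000)
z = reparameterize(mean, torch.full_like(mean, -2.0))   # variance exp(-2), about 0.135
print(z.shape, round(z.var().item(), 3))
```

For a fixed latent dataset like `cond_latents`, using `mean` directly instead of a sample is another option worth trying; it removes the sampling noise from the diffusion model's training data.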

In [None]:
# Train the DiT with conditioning
dit = MiniDiT(latent_channels=4, embed_dim=128, num_heads=4,
              depth=4, context_dim=64).to(device)
optimizer = torch.optim.Adam(dit.parameters(), lr=2e-4)

dit_losses = []
print("Training Mini DiT with direction conditioning...")

for epoch in range(30):
    epoch_loss = 0
    perm = torch.randperm(len(cond_latents))

    for i in range(0, len(cond_latents) - 16 + 1, 16):
        idx = perm[i:i+16]
        batch = cond_latents[idx].to(device)
        labels = cond_labels[idx].to(device)
        B = batch.shape[0]

        # Context: direction label reshaped as 1-token sequence
        context = labels.unsqueeze(1)  # (B, 1, 64)

        t = torch.randint(0, num_timesteps, (B,), device=device)
        noise = torch.randn_like(batch)
        sqrt_ac = sqrt_alphas_cumprod[t][:, None, None, None, None]
        sqrt_omac = sqrt_one_minus_alphas_cumprod[t][:, None, None, None, None]
        noisy = sqrt_ac * batch + sqrt_omac * noise

        pred = dit(noisy, t, context)
        loss = F.mse_loss(pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg = epoch_loss / (len(cond_latents) // 16)
    dit_losses.append(avg)
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/30 - Loss: {avg:.4f}")

print("DiT training complete!")
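
The training loop relies on `num_timesteps`, `sqrt_alphas_cumprod`, and `sqrt_one_minus_alphas_cumprod`, which are defined earlier in the notebook. For reference, here is a self-contained sketch of how such a schedule and the forward noising step fit together. The linear beta range (1e-4 to 0.02) and the 500 steps are assumptions for illustration, and `make_linear_schedule` and `q_sample` are hypothetical names; the `demo_` prefix avoids clashing with the notebook's own variables.

```python
import torch

def make_linear_schedule(num_steps=500, beta_start=1e-4, beta_end=0.02):
    """Return sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) for a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, num_steps)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return alphas_cumprod.sqrt(), (1.0 - alphas_cumprod).sqrt()

def q_sample(x0, t, noise, sqrt_ac, sqrt_omac):
    """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    # Pick one coefficient per sample, then broadcast over (C, T, H, W).
    a = sqrt_ac[t][:, None, None, None, None]
    b = sqrt_omac[t][:, None, None, None, None]
    return a * x0 + b * noise

torch.manual_seed(0)
demo_sqrt_ac, demo_sqrt_omac = make_linear_schedule()
demo_x0 = torch.randn(4, 4, 8, 8, 8)
demo_t = torch.tensor([0, 100, 300, 499])
demo_xt = q_sample(demo_x0, demo_t, torch.randn_like(demo_x0), demo_sqrt_ac, demo_sqrt_omac)

# The squared coefficients sum to 1 at every step, so unit-variance data stays unit-variance.
print(demo_xt.shape, torch.allclose(demo_sqrt_ac**2 + demo_sqrt_omac**2, torch.ones(500), atol=1e-6))
```

The model is trained to predict `noise` from `x_t`, which is why the loss above is an MSE between `pred` and `noise`.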

In [None]:
# 📊 Training curve comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

axes[0].plot(latent_losses, 'b-', linewidth=2, label='Latent U-Net')
axes[0].plot(dit_losses, 'r-', linewidth=2, label='DiT')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss Comparison')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Architecture comparison
archs = ['Pixel U-Net\n(NB1)', 'Factorized\nU-Net (NB2)', 'Latent\nU-Net', 'DiT']
desc = ['3D conv in\npixel space', 'Spatial+Temporal\nattn in pixels', 'U-Net in\nlatent space', 'Transformer\nin latent space']

axes[1].barh(archs, [1, 2, 3, 4], color=['#dd8452', '#55a868', '#4c72b0', '#c44e52'])
axes[1].set_xlabel('Approach Sophistication →')
axes[1].set_title('Video Diffusion Architecture Evolution')
for i, d in enumerate(desc):
    axes[1].text(0.5, i, d, va='center', fontsize=8, color='white', fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Conditioned Sampling
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/19_conditioned_sampling.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 8.1 Conditioned Sampling from DiT

In [None]:
@torch.no_grad()
def sample_dit(dit_model, decoder, direction, num_samples=4,
               num_frames=8, latent_h=8, latent_w=8, latent_c=4):
    """Sample videos conditioned on a direction label."""
    dit_model.eval()

    # Create conditioning
    context_dim = 64
    label = torch.zeros(num_samples, context_dim, device=device)
    label[:, direction * (context_dim // 4):(direction + 1) * (context_dim // 4)] = 1.0
    context = label.unsqueeze(1)  # (B, 1, 64)

    z = torch.randn(num_samples, latent_c, num_frames, latent_h, latent_w, device=device)

    for t_idx in reversed(range(num_timesteps)):
        t = torch.full((num_samples,), t_idx, device=device, dtype=torch.long)
        pred_noise = dit_model(z, t, context)

        beta_t = betas[t_idx]
        alpha_t = alphas[t_idx]
        alpha_cumprod_t = alphas_cumprod[t_idx]

        coeff1 = 1.0 / torch.sqrt(alpha_t)
        coeff2 = beta_t / torch.sqrt(1.0 - alpha_cumprod_t)
        mean = coeff1 * (z - coeff2 * pred_noise)

        if t_idx > 0:
            z = mean + torch.sqrt(beta_t) * torch.randn_like(z)
        else:
            z = mean

    videos = decoder(z)
    dit_model.train()
    return videos.clamp(0, 1)

# Generate for each direction
direction_names = ['Right →', 'Left ←', 'Down ↓', 'Up ↑']
print("Generating conditioned videos for each direction...")
all_gen = []
for d in range(4):
    vids = sample_dit(dit, decoder, direction=d, num_samples=2)
    all_gen.append(vids)
    print(f"  Generated {direction_names[d]}")
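
The body of the sampling loop is a single DDPM update repeated for every timestep. The sketch below isolates that update so it can be checked on its own. `ddpm_step` is a hypothetical helper, and the linear schedule is an assumption for illustration (the notebook defines its own `betas` and `alphas_cumprod` earlier).

```python
import torch

def ddpm_step(z, pred_noise, t_idx, betas, alphas_cumprod):
    """One ancestral DDPM update z_t -> z_{t-1}, given the predicted noise."""
    beta_t = betas[t_idx]
    alpha_t = 1.0 - beta_t
    coeff = beta_t / torch.sqrt(1.0 - alphas_cumprod[t_idx])
    mean = (z - coeff * pred_noise) / torch.sqrt(alpha_t)
    if t_idx == 0:
        return mean                                   # the last step adds no noise
    return mean + torch.sqrt(beta_t) * torch.randn_like(z)

# Assumed schedule, named demo_* to avoid clashing with the notebook's variables.
demo_betas = torch.linspace(1e-4, 0.02, 500)
demo_alphas_cumprod = torch.cumprod(1.0 - demo_betas, dim=0)

torch.manual_seed(0)
clean = torch.randn(1, 4, 8, 8, 8)
eps = torch.randn_like(clean)
ac0 = demo_alphas_cumprod[0]
noised = ac0.sqrt() * clean + (1 - ac0).sqrt() * eps  # forward process at t = 0

# With a perfect noise prediction, the final step recovers the clean latent.
recovered = ddpm_step(noised, eps, 0, demo_betas, demo_alphas_cumprod)
print(torch.allclose(recovered, clean, atol=1e-4))
```

In practice the prediction is never perfect, so sample quality depends on how well the DiT estimates the noise at every step, not just the last one.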

In [None]:
#@title 🎧 Listen: Final Output
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/20_final_output.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. 🎯 Final Output: Text-Conditioned Video Generation

In [None]:
# Display conditioned generation results
fig, axes = plt.subplots(8, 8, figsize=(14, 14))

for d in range(4):
    for sample in range(2):
        row = d * 2 + sample
        for col in range(8):
            axes[row, col].imshow(
                all_gen[d][sample, 0, col].cpu().numpy(),
                cmap='gray', vmin=0, vmax=1)
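            # Hide ticks only: axis('off') would also hide the ylabel set below
            axes[row, col].set_xticks([])
            axes[row, col].set_yticks([])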
            if row == 0:
                axes[row, col].set_title(f'Frame {col}', fontsize=9)
        axes[row, 0].set_ylabel(f'{direction_names[d]}',
                               fontsize=10, rotation=0, labelpad=50,
                               color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'][d])

fig.suptitle('🎯 Conditioned Video Generation with Mini DiT\n'
             'Each pair of rows shows a different motion direction', fontsize=14)
plt.tight_layout()
plt.show()

print("🎉 Congratulations! You've built:")
print("  1. A Video VAE (48x compression)")
print("  2. Latent Video Diffusion (efficient denoising)")
print("  3. A Diffusion Transformer with spacetime patches")
print("  4. Text-conditioned video generation via cross-attention")
print("\nThis is the same architecture family behind Sora!")

In [None]:
#@title 🎧 Listen: Reflection
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/21_reflection.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 10. Reflection and Next Steps

### 🤔 Reflection Questions

1. **VAE quality tradeoff:** We used a small $\beta_{\text{KL}}$ to prioritize reconstruction. What happens if we increase it? How would this affect the diffusion model's job?

2. **Patch size matters:** We used $2 \times 4 \times 4$ patches. What happens with smaller patches (more tokens, slower but more detail) vs larger patches (fewer tokens, faster but less detail)?

3. **Factorized vs DiT:** The factorized U-Net separates spatial and temporal attention. The DiT mixes them via spacetime patches. When might one approach be better than the other?

4. **Scaling:** Sora reportedly uses a much larger DiT (billions of parameters). What changes when you scale up? Do you expect the same architecture to work, just bigger?

5. **Real text conditioning:** We used a simple one-hot direction label. How would you integrate a real text encoder like CLIP? What additional challenges arise?
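
For question 2, a quick back-of-the-envelope calculation helps. The sketch below counts tokens for a few patch sizes on the 8×8×8 latent used in this notebook and reports the number of pairwise attention scores per head, which grows with the square of the token count. `token_stats` is a hypothetical helper.

```python
def token_stats(latent_shape, patch):
    """Return (num_tokens, attention_pairs) for a (T, H, W) latent and a (pt, ph, pw) patch."""
    T, H, W = latent_shape
    pt, ph, pw = patch
    assert T % pt == 0 and H % ph == 0 and W % pw == 0, "patch must divide the latent"
    num_tokens = (T // pt) * (H // ph) * (W // pw)
    return num_tokens, num_tokens ** 2     # self-attention scores every token pair

for patch in [(2, 4, 4), (2, 2, 2), (1, 1, 1)]:
    n, pairs = token_stats((8, 8, 8), patch)
    print(f"patch {patch}: {n} tokens, {pairs:,} attention scores per head")
```

With 1×1×1 patches the 8×8×8 latent already yields 512 tokens, which is exactly the size of the learnable `pos_embed` table in `MiniDiT`; anything longer would need a larger table.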

### 🏆 Optional Challenges

1. **3D VAE:** Extend the VAE to compress temporally as well (reduce 8 frames to 4 latent frames). How does this affect reconstruction quality and diffusion efficiency?

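2. **Classifier-free guidance:** Implement classifier-free guidance by randomly dropping the conditioning during training (replace it with zeros 10% of the time). At inference, run the model with and without the conditioning and combine the two predictions as $\epsilon_{\text{uncond}} + w\,(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$. A guidance scale $w > 1$ extrapolates beyond the conditional prediction rather than interpolating between the two.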

3. **Longer videos:** Try generating 16 or 32 frame videos. Does the DiT handle longer sequences gracefully?
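
As a starting point for challenge 2, here is a minimal sketch of the two pieces of classifier-free guidance: dropping the context during training and combining the two predictions at inference. It assumes a model with the same `(x, t, context)` call signature as `MiniDiT`, and that an all-zeros context stands for "unconditional", as the challenge suggests. `drop_context`, `guided_noise`, and `guidance_scale` are hypothetical names.

```python
import torch

def drop_context(context, drop_prob=0.1):
    """Zero out the whole context for a random subset of samples (training time)."""
    keep = (torch.rand(context.shape[0], device=context.device) >= drop_prob).float()
    return context * keep[:, None, None]

@torch.no_grad()
def guided_noise(model, z, t, context, guidance_scale=3.0):
    """Combine conditional and unconditional noise predictions (inference time)."""
    eps_cond = model(z, t, context)
    eps_uncond = model(z, t, torch.zeros_like(context))
    # scale = 0 -> unconditional, 1 -> conditional, > 1 -> extrapolate past conditional
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# Toy stand-in model so the sketch runs on its own: prediction = z + mean of the context.
toy_model = lambda z, t, ctx: z + ctx.mean(dim=(1, 2))[:, None, None, None, None]

z = torch.zeros(2, 4, 8, 8, 8)
t = torch.zeros(2, dtype=torch.long)
ctx = torch.ones(2, 1, 64)
out = guided_noise(toy_model, z, t, ctx, guidance_scale=3.0)
print(out.mean().item())  # 3.0 for this toy model: 0 + 3 * (1 - 0)
```

In the training loop, `drop_context` would be applied to `context` just before the model call; in `sample_dit`, `guided_noise` would replace the single model call. Each sampling step then costs two forward passes.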

### Series Summary

Across these three notebooks, you have built the complete toolkit for video diffusion:
- **Notebook 1:** Video diffusion basics (forward/reverse process, 3D convolutions, DDPM sampling)
- **Notebook 2:** Factorized attention (spatial and temporal attention, computational efficiency)
- **Notebook 3:** Modern architectures (Video VAE, latent diffusion, DiT with spacetime patches, text conditioning)

These are core building blocks behind many modern video generation systems, from Stable Video Diffusion to Sora.