In [None]:
#@title 🎧 Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1ventuhdj998YNr_9KusKPNX2VFJg7Av1", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/00_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 Factorized Space-Time Attention for Video Diffusion

*Part 2 of the Vizuara series on Diffusion Models for Video Generation*
*Estimated time: 45 minutes*

In Notebook 1, we built a simple video diffusion model using 3D convolutions. It worked, but it was slow and inflexible. In this notebook, we will build the **real** architecture behind modern video diffusion models: **factorized spatial-temporal attention**. This is the core architectural innovation used by Video Diffusion Models (VDM), Imagen Video, Stable Video Diffusion, and many others.

By the end of this notebook, you will:
- Understand why full 3D attention is impractical for video
- Build spatial self-attention (per-frame) and temporal self-attention (across-frames) from scratch
- See how factorized attention cuts the attention cost by roughly **16×** for a 16-frame clip compared with full 3D attention
- Implement a training strategy that leverages pretrained image models
- Train a factorized video U-Net on Moving MNIST and compare it to the 3D conv approach

In [None]:
#@title 🎧 Listen: Why It Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/01_why_it_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 1. Why Does This Matter?

Think about how a film editor works. They do not look at every pixel of every frame simultaneously; that would be overwhelming. Instead, they work in two passes:

1. **Spatial pass:** Look at each frame individually. Is the composition good? Are the colors right? Is the lighting consistent?
2. **Temporal pass:** Play the frames in sequence. Is the motion smooth? Do objects move naturally? Are there any visual jumps?

This is exactly the idea behind **factorized attention** for video diffusion. Instead of trying to jointly attend over all space-time positions (which is absurdly expensive), we split the attention into two cheaper operations:
- **Spatial attention** within each frame (what does each frame look like?)
- **Temporal attention** across frames (how do things move over time?)

This simple factorization is what makes video diffusion models practical. Without it, the computational cost would be prohibitive for any reasonable resolution.

In [None]:
#@title 🎧 Listen: Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/02_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

### 🤔 Think About This

Before we dive into the math, consider this thought experiment:

You have a 16-frame video at 32×32 resolution. If you wanted every pixel to "talk to" every other pixel across all frames (full 3D attention), how many attention pairs would that be?

- Total positions: 16 × 32 × 32 = **16,384**
- Attention pairs: 16,384² ≈ **268 million**

Now consider the factorized approach:
- Spatial attention (per frame): 32 × 32 = 1,024 positions, repeated 16 times → 16 × 1,024² ≈ **16.8 million**
- Temporal attention (per position): 16 frames, repeated 1,024 times → 1,024 × 16² ≈ **262,000**

Total: ~**17 million** vs ~**268 million**, which is roughly a **16× reduction**!

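How the savings scale depends on which axis grows. The ratio of full 3D pairs to factorized pairs is $\frac{T \cdot HW}{HW + T}$, which approaches $T$ once $HW \gg T$. For a 16-frame video the saving therefore stays near **16×** even at 256×256; it is longer clips, not higher resolutions, that push the ratio up. What higher resolution does change is the absolute cost: full 3D attention over a 16-frame 256×256 video needs about $10^{12}$ attention pairs, which is far out of reach, while the factorized version needs about 16 times fewer.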
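The counts above can be reproduced with a few lines of plain Python. This is a minimal sketch: `attention_pairs` is a hypothetical helper introduced here for illustration, and it counts query-key pairs only, ignoring the number of heads and the channel width.

```python
def attention_pairs(T, H, W):
    """Return (full_3d, factorized) query-key pair counts for a T x H x W video."""
    n = H * W                      # spatial positions per frame
    full_3d = (T * n) ** 2         # every space-time position attends to every other
    spatial = T * n ** 2           # T frames, each with n x n pairs
    temporal = n * T ** 2          # n positions, each with T x T pairs
    return full_3d, spatial + temporal

# Try a few shapes: vary the resolution and the number of frames separately.
for T, H, W in [(16, 32, 32), (16, 256, 256), (64, 32, 32)]:
    full_3d, factorized = attention_pairs(T, H, W)
    print(f"T={T:>2}, {H}x{W}: full={full_3d:,}  factorized={factorized:,}  "
          f"ratio={full_3d / factorized:.2f}x")
```

Changing `T` and the resolution independently is a quick way to see which axis actually drives the ratio.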

### ✋ Stop and Think

Before scrolling down, ask yourself:
1. What information might we lose by factorizing? Can spatial-only attention capture an object moving across frames?
2. Why do we do spatial attention first and temporal attention second (and not the other way)?
3. Could there be scenarios where full 3D attention is actually better?

*Take 2 minutes. Then scroll down.*

In [None]:
#@title 🎧 Listen: Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/03_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### Self-Attention Recap

Recall that self-attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot V$$

where Q (queries), K (keys), and V (values) are linear projections of the input features.

Computationally, this says: for each query position, compute a weighted sum over all value positions, where the weights are determined by the similarity between the query and each key, scaled by $\sqrt{d_k}$ to prevent the dot products from becoming too large.
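As a concrete illustration of the formula, here is a minimal single-head sketch in PyTorch. The function name, the tensor sizes, and the random stand-in projection matrices `W_q`, `W_k`, `W_v` are all assumptions made for this example; the notebook's own multi-head module comes later.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """q, k, v: (seq_len, d_k). Returns (output, attention weights)."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (seq_len, seq_len) similarities
    weights = scores.softmax(dim=-1)                   # each row is a distribution over keys
    return weights @ v, weights                        # weighted sum of values

torch.manual_seed(0)
x = torch.randn(5, 8)                                  # 5 positions, 8 features each
W_q, W_k, W_v = (torch.randn(8, 8) for _ in range(3))  # stand-in linear projections
out, weights = scaled_dot_product_attention(x @ W_q, x @ W_k, x @ W_v)
print(out.shape, weights.sum(dim=-1))
```

Each row of `weights` sums to 1, so every output position is a convex combination of the value vectors.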

### Spatial Self-Attention

For spatial attention, we reshape the video features of shape $(B, T, H, W, C)$ so that the **time axis is merged with the batch axis**:

$$\mathbf{x}_{\text{spatial}} \in \mathbb{R}^{(B \cdot T) \times (H \cdot W) \times C}$$

Each frame is processed independently. The sequence length is $H \times W$, the number of spatial positions in one frame.

This means frame 1 attends only to itself, frame 2 attends only to itself, and so on. No cross-frame information flows during spatial attention.
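The reshape can be checked on a small tensor. The sketch below uses made-up sizes and an `arange` tensor so that every element is unique; it only demonstrates the indexing, not the attention itself.

```python
import torch

B, T, H, W, C = 2, 3, 4, 4, 5
x = torch.arange(B * T * H * W * C, dtype=torch.float32).reshape(B, T, H, W, C)

# Merge time into batch and flatten the H x W grid into one sequence axis.
x_spatial = x.reshape(B * T, H * W, C)

# Row b * T + t of the merged batch is exactly frame t of video b.
b, t = 1, 2
frame_tokens = x_spatial[b * T + t]                        # (H*W, C)
same_frame = torch.equal(frame_tokens, x[b, t].reshape(H * W, C))

# Reshaping back recovers the original layout.
round_trip_ok = torch.equal(x_spatial.reshape(B, T, H, W, C), x)
print(x_spatial.shape, same_frame, round_trip_ok)
```

Because `T` sits directly next to `B` in the layout, a plain `reshape` is enough here and no `permute` is needed.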

### Temporal Self-Attention

For temporal attention, we reshape so that the **spatial axes are merged with the batch axis**:

$$\mathbf{x}_{\text{temporal}} \in \mathbb{R}^{(B \cdot H \cdot W) \times T \times C}$$

Now, for each spatial position $(h, w)$, we have a sequence of $T$ features across time. The attention lets each frame "look at" all other frames at that same spatial position.

Computationally, the pixel at position $(3, 7)$ in frame 5 can attend to position $(3, 7)$ in frames 1, 2, 3, ..., T, but NOT to position $(10, 15)$ in any frame. This is the key restriction that makes it efficient.
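For the temporal case the axis order matters, because `T` is not adjacent to `C` in a $(B, T, H, W, C)$ tensor. The sketch below (sizes are made up for illustration) shows the permute-then-reshape pattern and why a plain reshape gives the right shape but the wrong contents.

```python
import torch

B, T, H, W, C = 2, 3, 4, 4, 5
x = torch.arange(B * T * H * W * C, dtype=torch.float32).reshape(B, T, H, W, C)

# Move T next to C first, then merge (B, H, W) into the batch axis.
x_temporal = x.permute(0, 2, 3, 1, 4).reshape(B * H * W, T, C)

# Row (b * H + h) * W + w holds the T-step history of position (h, w) in video b.
b, h, w = 1, 3, 2
track = x_temporal[(b * H + h) * W + w]                    # (T, C)
same_track = torch.equal(track, x[b, :, h, w])

# A plain reshape without the permute groups T neighbouring pixels of one frame,
# not the same pixel across T frames.
x_naive = x.reshape(B * H * W, T, C)
naive_matches = torch.equal(x_naive, x_temporal)
print(x_temporal.shape, same_track, naive_matches)
```

This kind of silent layout bug does not raise an error, so an explicit check like `same_track` is worth keeping in a test.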

### Complexity Comparison

| Approach | Sequence Length | Complexity |
|----------|----------------|------------|
| Full 3D | $T \cdot H \cdot W$ | $O((T \cdot H \cdot W)^2)$ |
| Spatial only | $H \cdot W$ (× $T$ frames) | $O(T \cdot (H \cdot W)^2)$ |
| Temporal only | $T$ (× $H \cdot W$ positions) | $O(H \cdot W \cdot T^2)$ |
| **Factorized** | Both | $O(T \cdot (H \cdot W)^2 + H \cdot W \cdot T^2)$ |
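To see where the savings come from, we can divide the full 3D cost by the factorized cost from the table. This short derivation uses only the symbols already defined; constant factors and the channel width are ignored.

$$\frac{(T \cdot H \cdot W)^2}{T \cdot (H \cdot W)^2 + H \cdot W \cdot T^2} = \frac{T^2 \, (H \cdot W)^2}{T \cdot H \cdot W \, (H \cdot W + T)} = \frac{T \cdot H \cdot W}{H \cdot W + T}$$

For $T = 16$ and $H = W = 32$ this gives $16{,}384 / 1{,}040 \approx 15.75$, matching the count in Section 2. When $H \cdot W \gg T$ the ratio approaches $T$, so the number of frames sets the ceiling on the saving.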

In [None]:
#@title 🎧 Listen: Building Components
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/04_building_components.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It, Component by Component

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

In [None]:
#@title 🎧 Listen: Self Attention
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/05_self_attention.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.1 Multi-Head Self-Attention

First, let us build a general-purpose multi-head self-attention module. We will reuse this for both spatial and temporal attention; the only difference is how we reshape the input.

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """
    Standard multi-head self-attention.
    Input: (batch, seq_len, dim)
    Output: (batch, seq_len, dim)
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, N, C = x.shape
        # Project to Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Weighted sum of values
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

Let us verify it works with a small example.

In [None]:
# 📊 Quick test
attn = MultiHeadSelfAttention(dim=32, num_heads=4).to(device)
test_input = torch.randn(2, 16, 32, device=device)  # batch=2, seq=16, dim=32
test_output = attn(test_input)
print(f"Input shape:  {test_input.shape}")
print(f"Output shape: {test_output.shape}")
assert test_input.shape == test_output.shape
print("✅ Self-attention module works!")

In [None]:
#@title 🎧 Listen: Spatial Attention
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/06_spatial_attention.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 Spatial Attention Block

Now let us wrap the self-attention into a **spatial attention** block. The key idea: merge the time axis into the batch dimension, so each frame is processed independently.

In [None]:
class SpatialAttentionBlock(nn.Module):
    """
    Self-attention over spatial positions within each frame.
    Input: (B, T, H, W, C) -> reshape to (B*T, H*W, C) -> attend -> reshape back
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)

    def forward(self, x):
        # x: (B, T, H, W, C)
        B, T, H, W, C = x.shape

        # Merge time into batch: each frame is independent
        x_flat = x.reshape(B * T, H * W, C)

        # Self-attention with residual connection
        x_flat = x_flat + self.attn(self.norm(x_flat))

        # Reshape back
        return x_flat.reshape(B, T, H, W, C)

In [None]:
# 📊 Visualize what spatial attention "sees"
# Each frame is processed independently: no cross-frame information
B, T, H, W, C = 1, 4, 8, 8, 32
test_video = torch.randn(B, T, H, W, C, device=device)

spatial_block = SpatialAttentionBlock(dim=C).to(device)
output = spatial_block(test_video)

print(f"Input shape:  {test_video.shape}  (B, T, H, W, C)")
print(f"Output shape: {output.shape}")

# Verify frames are processed independently:
# If we change frame 3, only frame 3's output should change
test_video_modified = test_video.clone()
test_video_modified[:, 2] = torch.randn(B, H, W, C, device=device)

output_orig = spatial_block(test_video)
output_mod = spatial_block(test_video_modified)

# Frames 0, 1, 3 should be unchanged
for t in [0, 1, 3]:
    diff = (output_orig[:, t] - output_mod[:, t]).abs().max().item()
    print(f"  Frame {t} max diff: {diff:.6f} (should be ~0)")

diff_f2 = (output_orig[:, 2] - output_mod[:, 2]).abs().max().item()
print(f"  Frame 2 max diff: {diff_f2:.4f} (should be large)")
print("✅ Spatial attention processes each frame independently!")

In [None]:
#@title 🎧 Listen: Temporal Attention
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/07_temporal_attention.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Temporal Attention Block

Next, the **temporal attention** block. Here, we merge the spatial axes into the batch dimension, so each spatial position attends across all frames.

In [None]:
class TemporalAttentionBlock(nn.Module):
    """
    Self-attention across frames at each spatial position.
    Input: (B, T, H, W, C) -> reshape to (B*H*W, T, C) -> attend -> reshape back
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)

    def forward(self, x):
        # x: (B, T, H, W, C)
        B, T, H, W, C = x.shape

        # Merge spatial into batch: each position attends across time
        x_flat = x.permute(0, 2, 3, 1, 4).reshape(B * H * W, T, C)

        # Self-attention with residual connection
        x_flat = x_flat + self.attn(self.norm(x_flat))

        # Reshape back
        return x_flat.reshape(B, H, W, T, C).permute(0, 3, 1, 2, 4)

In [None]:
# 📊 Visualize what temporal attention "sees"
temporal_block = TemporalAttentionBlock(dim=C).to(device)
output_temporal = temporal_block(test_video)

print(f"Input shape:  {test_video.shape}  (B, T, H, W, C)")
print(f"Output shape: {output_temporal.shape}")

# Verify: changing one spatial position should NOT affect other positions
test_video_2 = test_video.clone()
test_video_2[:, :, 5, 5] = torch.randn(B, T, C, device=device)

out_a = temporal_block(test_video)
out_b = temporal_block(test_video_2)

# Position (3,3) should be unchanged
diff_pos = (out_a[:, :, 3, 3] - out_b[:, :, 3, 3]).abs().max().item()
print(f"  Position (3,3) max diff: {diff_pos:.6f} (should be ~0)")

# Position (5,5) should change
diff_mod = (out_a[:, :, 5, 5] - out_b[:, :, 5, 5]).abs().max().item()
print(f"  Position (5,5) max diff: {diff_mod:.4f} (should be large)")
print("✅ Temporal attention processes each spatial position independently across time!")

In [None]:
#@title 🎧 Listen: Factorized Block
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/08_factorized_block.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 Factorized Space-Time Block

Now we combine spatial and temporal attention into a single **factorized space-time block**. This is the fundamental building block of modern video diffusion architectures.

In [None]:
class FactorizedSpaceTimeBlock(nn.Module):
    """
    A single block of factorized spatial-temporal attention.
    1. Spatial self-attention (within each frame)
    2. Temporal self-attention (across frames at each position)
    3. Feedforward network
    """
    def __init__(self, dim, num_heads=4, mlp_ratio=4.0):
        super().__init__()
        self.spatial_attn = SpatialAttentionBlock(dim, num_heads)
        self.temporal_attn = TemporalAttentionBlock(dim, num_heads)

        # Feedforward network
        self.norm_ff = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        # x: (B, T, H, W, C)
        # Step 1: Spatial attention (per frame)
        x = self.spatial_attn(x)

        # Step 2: Temporal attention (across frames)
        x = self.temporal_attn(x)

        # Step 3: Feedforward with residual
        B, T, H, W, C = x.shape
        x_flat = x.reshape(B * T * H * W, C)
        x_flat = x_flat + self.ff(self.norm_ff(x_flat))
        return x_flat.reshape(B, T, H, W, C)

In [None]:
# 📊 Test the full factorized block
block = FactorizedSpaceTimeBlock(dim=32, num_heads=4).to(device)
test_input = torch.randn(2, 4, 8, 8, 32, device=device)
test_output = block(test_input)

print(f"Input:  {test_input.shape}")
print(f"Output: {test_output.shape}")

num_params = sum(p.numel() for p in block.parameters())
print(f"Parameters: {num_params:,}")
print("✅ Factorized space-time block works!")
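A natural question is whether the two restricted passes, taken together, still let any space-time position influence any other. The sketch below checks this with uniform averaging as a stand-in for attention; this is an assumption made for simplicity. Softmax weights are strictly positive, so the set of reachable positions is the same even though the actual values differ. The helper names `spatial_mix` and `temporal_mix` are hypothetical.

```python
import torch

T, H, W = 4, 3, 3
x = torch.zeros(T, H, W)
x[0, 0, 0] = 1.0                       # a single impulse at frame 0, pixel (0, 0)

def spatial_mix(v):
    # Stand-in for spatial attention: every pixel averages over its own frame.
    return v.mean(dim=(1, 2), keepdim=True).expand_as(v)

def temporal_mix(v):
    # Stand-in for temporal attention: every pixel averages over its own time track.
    return v.mean(dim=0, keepdim=True).expand_as(v)

reached_temporal_only = temporal_mix(x) != 0
reached_spatial_only = spatial_mix(x) != 0
reached_both = temporal_mix(spatial_mix(x)) != 0

print("temporal only:   ", int(reached_temporal_only.sum()), "positions")
print("spatial only:    ", int(reached_spatial_only.sum()), "positions")
print("spatial+temporal:", int(reached_both.sum()), "of", T * H * W)
```

One spatial pass followed by one temporal pass is enough for the impulse to reach every position, which is why a single factorized block already has a full space-time receptive field, just through a two-hop path rather than a direct one.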

In [None]:
#@title 🎧 Listen: Cost Comparison
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/09_cost_comparison.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 📊 Computational Cost Comparison

Let us empirically verify the cost savings of factorized vs full 3D attention.

In [None]:
class Full3DAttention(nn.Module):
    """
    Full 3D self-attention over all space-time positions.
    Input: (B, T, H, W, C) -> reshape to (B, T*H*W, C) -> attend
    WARNING: Very expensive for large inputs!
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)

    def forward(self, x):
        B, T, H, W, C = x.shape
        x_flat = x.reshape(B, T * H * W, C)
        x_flat = x_flat + self.attn(self.norm(x_flat))
        return x_flat.reshape(B, T, H, W, C)

In [None]:
# Compare computation time: Factorized vs Full 3D
# Note: the factorized block also contains a feedforward layer, so its timing
# is slightly pessimistic relative to the attention-only Full3DAttention.
def time_forward(module, x, iters=10):
    """Average forward-pass time in seconds (one warm-up pass, no gradients)."""
    with torch.no_grad():
        module(x)  # warm-up: keeps one-time setup cost out of the measurement
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        if device.type == 'cuda':
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

results = {}
resolutions = [(4, 8, 8), (4, 16, 16), (8, 16, 16), (8, 32, 32)]

for T, H, W in resolutions:
    label = f"T={T}, {H}x{W}"
    seq_len = T * H * W
    x = torch.randn(1, T, H, W, 32, device=device)

    # Factorized
    fact_time = time_forward(FactorizedSpaceTimeBlock(dim=32).to(device), x)

    # Full 3D (skip if too large: the attention matrix is seq_len x seq_len)
    if seq_len <= 4096:
        full_time = time_forward(Full3DAttention(dim=32).to(device), x)
        speedup = full_time / fact_time
    else:
        full_time = float('nan')
        speedup = float('nan')

    results[label] = (fact_time, full_time, speedup)
    print(f"{label} (seq={seq_len:>5}): "
          f"Factorized={fact_time*1000:.1f}ms, "
          f"Full3D={full_time*1000:.1f}ms, "
          f"Speedup={speedup:.1f}x")

In [None]:
# 📊 Visualize the scaling
labels = list(results.keys())
fact_times = [results[l][0]*1000 for l in labels]
full_times = [results[l][1]*1000 for l in labels]

fig, ax = plt.subplots(figsize=(10, 5))
x_pos = np.arange(len(labels))
width = 0.35

bars1 = ax.bar(x_pos - width/2, fact_times, width, label='Factorized', color='#4c72b0')
bars2 = ax.bar(x_pos + width/2, full_times, width, label='Full 3D', color='#dd8452')

ax.set_xlabel('Video Resolution')
ax.set_ylabel('Time per forward pass (ms)')
ax.set_title('Factorized vs Full 3D Attention: Computation Time')
ax.set_xticks(x_pos)
ax.set_xticklabels(labels, rotation=15)
ax.legend()

# Add speedup annotations
for i, label in enumerate(labels):
    speedup = results[label][2]
    if not np.isnan(speedup):
        ax.annotate(f'{speedup:.1f}x faster',
                   xy=(x_pos[i], max(fact_times[i], full_times[i])),
                   xytext=(0, 10), textcoords='offset points',
                   ha='center', fontsize=9, color='green', fontweight='bold')

plt.tight_layout()
plt.show()
print("💡 Expect the gap to widen as the sequence grows; at tiny sizes, launch overhead can hide it.")

In [None]:
#@title 🎧 Listen: Full Unet
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/10_full_unet.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Building a Factorized Video U-Net

Now let us put it all together into a complete U-Net architecture with factorized attention. We will compare it against the 3D conv model from Notebook 1.

### 6.1 Dataset (same Moving MNIST)

In [None]:
def create_moving_mnist_video(num_frames=16, size=32, digit_size=12):
    """Create a single Moving MNIST-style video: a white blob (a stand-in for a digit) bouncing around."""
    video = np.zeros((num_frames, size, size), dtype=np.float32)

    # Random starting position and velocity
    x = np.random.randint(0, size - digit_size)
    y = np.random.randint(0, size - digit_size)
    vx = np.random.choice([-2, -1, 1, 2])
    vy = np.random.choice([-2, -1, 1, 2])

    # Simple digit pattern (circle-like blob)
    digit = np.zeros((digit_size, digit_size), dtype=np.float32)
    center = digit_size // 2
    for i in range(digit_size):
        for j in range(digit_size):
            dist = ((i - center)**2 + (j - center)**2) ** 0.5
            if dist < center:
                digit[i, j] = 1.0

    for t in range(num_frames):
        # Place digit
        video[t, y:y+digit_size, x:x+digit_size] = digit

        # Move with bouncing
        x += vx
        y += vy
        if x <= 0 or x >= size - digit_size:
            vx = -vx
            x = max(0, min(size - digit_size, x))
        if y <= 0 or y >= size - digit_size:
            vy = -vy
            y = max(0, min(size - digit_size, y))

    return video

def create_dataset(num_videos=512, num_frames=16, size=32):
    """Create a dataset of Moving MNIST videos."""
    videos = np.stack([create_moving_mnist_video(num_frames, size) for _ in range(num_videos)])
    # Shape: (N, T, H, W) -> (N, 1, T, H, W) for channel dim
    return torch.tensor(videos).unsqueeze(1)

# Create dataset
dataset = create_dataset(num_videos=512, num_frames=8, size=32)
print(f"Dataset shape: {dataset.shape}  (N, C, T, H, W)")

In [None]:
# 📊 Visualize some training samples
fig, axes = plt.subplots(3, 8, figsize=(14, 5))
for row in range(3):
    for col in range(8):
        axes[row, col].imshow(dataset[row, 0, col].numpy(), cmap='gray', vmin=0, vmax=1)
        axes[row, col].axis('off')
        if row == 0:
            axes[row, col].set_title(f'f={col}', fontsize=9)
fig.suptitle('Moving MNIST Training Samples', fontsize=14)
plt.tight_layout()
plt.show()

### 6.2 Factorized Video U-Net

In [None]:
class SinusoidalTimeEmbed(nn.Module):
    """Sinusoidal timestep embedding."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        return self.mlp(emb)
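To make the sin/cos features concrete, here is a plain-Python sketch of the same formula for a single timestep, without the learned MLP. The function name and the `max_period` parameter are assumptions introduced for this example; the frequencies follow the same `exp(-log(10000) * i / half)` rule as the module above.

```python
import math

def sinusoidal_embedding(t, dim, max_period=10000.0):
    """Sin/cos features for one integer timestep t (no learned layers)."""
    half = dim // 2
    # Geometrically spaced frequencies, from 1 down towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

for t in (0, 1, 250):
    emb = sinusoidal_embedding(t, dim=8)
    print(t, [round(v, 3) for v in emb])
```

The first half of the vector holds the sines and the second half the cosines, so nearby timesteps get similar vectors while distant ones are spread apart by the high-frequency components.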

In [None]:
class FactorizedVideoUNet(nn.Module):
    """
    A simple U-Net with factorized spatial-temporal attention.

    Architecture:
    - Encoder: 2D conv (spatial) → factorized attention
    - Bottleneck: factorized attention
    - Decoder: factorized attention → 2D conv (spatial)

    All convolutions are 2D (per frame). Temporal modeling is
    handled entirely by the temporal attention layers.
    """
    def __init__(self, in_channels=1, base_dim=32, num_heads=4):
        super().__init__()
        self.time_embed = SinusoidalTimeEmbed(base_dim)

        # Encoder
        self.enc1_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_dim, 3, padding=1),
            nn.GroupNorm(4, base_dim),
            nn.GELU()
        )
        self.enc1_attn = FactorizedSpaceTimeBlock(base_dim, num_heads)

        self.enc2_conv = nn.Sequential(
            nn.Conv2d(base_dim, base_dim * 2, 3, stride=2, padding=1),
            nn.GroupNorm(4, base_dim * 2),
            nn.GELU()
        )
        self.enc2_attn = FactorizedSpaceTimeBlock(base_dim * 2, num_heads)

        # Bottleneck
        self.bottleneck = FactorizedSpaceTimeBlock(base_dim * 2, num_heads)

        # Decoder
        self.dec2_conv = nn.Sequential(
            nn.ConvTranspose2d(base_dim * 4, base_dim, 2, stride=2),
            nn.GroupNorm(4, base_dim),
            nn.GELU()
        )
        self.dec2_attn = FactorizedSpaceTimeBlock(base_dim, num_heads)

        self.out_conv = nn.Conv2d(base_dim * 2, in_channels, 1)

        # Time projection layers
        self.time_proj1 = nn.Linear(base_dim, base_dim)
        self.time_proj2 = nn.Linear(base_dim, base_dim * 2)

    def forward(self, x, t):
        """
        x: (B, C, T, H, W) - noisy video
        t: (B,) - diffusion timestep
        Returns: (B, C, T, H, W) - predicted noise
        """
        B, C, T, H, W = x.shape
        t_emb = self.time_embed(t)  # (B, base_dim)

        # Process each frame with 2D convs, then do factorized attention
        # Reshape: (B, C, T, H, W) -> (B*T, C, H, W) for 2D convs
        x_frames = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)

        # Encoder block 1
        h1 = self.enc1_conv(x_frames)  # (B*T, base_dim, H, W)
        # Add time embedding
        t1 = self.time_proj1(t_emb)[:, :, None, None]  # (B, base_dim, 1, 1)
        t1 = t1.repeat_interleave(T, dim=0)  # (B*T, base_dim, 1, 1), matching the b-major frame order
        h1 = h1 + t1
        # Reshape for attention: (B*T, C, H, W) -> (B, T, H, W, C)
        h1_attn = h1.reshape(B, T, -1, H, W).permute(0, 1, 3, 4, 2)
        h1_attn = self.enc1_attn(h1_attn)
        h1 = h1_attn.permute(0, 1, 4, 2, 3).reshape(B * T, -1, H, W)

        # Encoder block 2 (with downsampling)
        h2 = self.enc2_conv(h1)  # (B*T, base_dim*2, H/2, W/2)
        H2, W2 = H // 2, W // 2
        t2 = self.time_proj2(t_emb)[:, :, None, None].repeat_interleave(T, dim=0)
        h2 = h2 + t2
        h2_attn = h2.reshape(B, T, -1, H2, W2).permute(0, 1, 3, 4, 2)
        h2_attn = self.enc2_attn(h2_attn)
        h2 = h2_attn.permute(0, 1, 4, 2, 3).reshape(B * T, -1, H2, W2)

        # Bottleneck
        bot = h2.reshape(B, T, -1, H2, W2).permute(0, 1, 3, 4, 2)
        bot = self.bottleneck(bot)
        bot = bot.permute(0, 1, 4, 2, 3).reshape(B * T, -1, H2, W2)

        # Decoder with skip connection
        dec = torch.cat([bot, h2], dim=1)  # Skip connection
        dec = self.dec2_conv(dec)  # (B*T, base_dim, H, W)
        dec_attn = dec.reshape(B, T, -1, H, W).permute(0, 1, 3, 4, 2)
        dec_attn = self.dec2_attn(dec_attn)
        dec = dec_attn.permute(0, 1, 4, 2, 3).reshape(B * T, -1, H, W)

        # Final output with skip connection
        out = self.out_conv(torch.cat([dec, h1], dim=1))  # (B*T, C, H, W)
        return out.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4)

In [None]:
# Test the model
model = FactorizedVideoUNet(in_channels=1, base_dim=32).to(device)
test_x = torch.randn(2, 1, 8, 32, 32, device=device)
test_t = torch.randint(0, 1000, (2,), device=device)
test_out = model(test_x, test_t)

print(f"Input:  {test_x.shape}")
print(f"Output: {test_out.shape}")
num_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {num_params:,}")
print("✅ Factorized Video U-Net works!")

In [None]:
#@title 🎧 Listen: Todo Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/11_todo_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. 🔧 Your Turn: Implement the Training Step

### TODO: Complete the training function

This is the same diffusion training loop from Notebook 1, but now with our factorized attention model. Complete the missing pieces:

In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule (same as Notebook 1)."""
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clamp(betas, 0.0001, 0.999)

# Precompute schedule values
num_timesteps = 500
betas = cosine_beta_schedule(num_timesteps).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)


def train_one_step(model, optimizer, video_batch):
    """
    Perform one training step of the diffusion model.

    Args:
        model: the noise prediction network
        optimizer: Adam optimizer
        video_batch: clean videos, shape (B, 1, T, H, W)

    Returns:
        loss value (float)
    """
    optimizer.zero_grad()
    B = video_batch.shape[0]

    # ============ TODO ============
    # Step 1: Sample random timesteps for each video in the batch
    #         t should be integers in [0, num_timesteps)
    # Step 2: Sample random Gaussian noise (same shape as video_batch)
    # Step 3: Create the noisy video using the forward diffusion formula:
    #         noisy = sqrt_alphas_cumprod[t] * clean + sqrt_one_minus_alphas_cumprod[t] * noise
    #         (Hint: you need to reshape the schedule values for broadcasting)
    # Step 4: Predict the noise using the model
    # Step 5: Compute MSE loss between predicted and actual noise
    # Step 6: Backpropagate and update weights
    # ==============================

    t = ???  # YOUR CODE HERE (Step 1)
    noise = ???  # YOUR CODE HERE (Step 2)
    noisy_video = ???  # YOUR CODE HERE (Step 3)
    predicted_noise = ???  # YOUR CODE HERE (Step 4)
    loss = ???  # YOUR CODE HERE (Step 5)

    loss.backward()  # Step 6
    optimizer.step()

    return loss.item()

In [None]:
# ✅ Verification: Test your implementation
_test_model = FactorizedVideoUNet(in_channels=1, base_dim=32).to(device)
_test_opt = torch.optim.Adam(_test_model.parameters(), lr=1e-3)
_test_batch = dataset[:4].to(device)

try:
    _test_loss = train_one_step(_test_model, _test_opt, _test_batch)
    assert isinstance(_test_loss, float), "Loss should be a float"
    assert 0 < _test_loss < 10, f"Loss {_test_loss} seems wrong (expected 0-10 range)"
    print(f"✅ Training step works! Loss = {_test_loss:.4f}")
except Exception as e:
    print(f"❌ Error: {e}")
    print("Hint: Make sure t has shape (B,), noise has same shape as video_batch,")
    print("and schedule values are reshaped to (B, 1, 1, 1, 1) for broadcasting.")

del _test_model, _test_opt
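If you get stuck, here is one possible way to fill in the five steps. It is a sketch rather than the official solution, written to run on its own: `train_one_step_reference`, `TinyNoisePredictor`, and the `demo_*` names are hypothetical, the schedule is passed in as arguments instead of read from globals, and a linear beta schedule stands in for the cosine one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_one_step_reference(model, optimizer, video_batch, sqrt_ac, sqrt_one_minus_ac):
    """One diffusion training step; sqrt_ac and sqrt_one_minus_ac have shape (num_timesteps,)."""
    optimizer.zero_grad()
    B = video_batch.shape[0]
    # Step 1: one integer timestep per video, in [0, num_timesteps)
    t = torch.randint(0, sqrt_ac.shape[0], (B,), device=video_batch.device)
    # Step 2: Gaussian noise with the same shape as the clean videos
    noise = torch.randn_like(video_batch)
    # Step 3: forward diffusion; reshape (B,) -> (B, 1, 1, 1, 1) so it broadcasts
    a = sqrt_ac[t].view(B, 1, 1, 1, 1)
    s = sqrt_one_minus_ac[t].view(B, 1, 1, 1, 1)
    noisy_video = a * video_batch + s * noise
    # Step 4: predict the noise
    predicted_noise = model(noisy_video, t)
    # Step 5: MSE between predicted and true noise
    loss = F.mse_loss(predicted_noise, noise)
    # Step 6: backpropagate and update
    loss.backward()
    optimizer.step()
    return loss.item()

class TinyNoisePredictor(nn.Module):
    """Hypothetical stand-in with the same call signature as the U-Net: model(x, t)."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 1, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)  # ignores t; enough to exercise the training step

torch.manual_seed(0)
demo_betas = torch.linspace(1e-4, 0.02, 500)            # stand-in linear schedule
demo_ac = torch.cumprod(1.0 - demo_betas, dim=0)
demo_model = TinyNoisePredictor()
demo_opt = torch.optim.Adam(demo_model.parameters(), lr=1e-3)
demo_loss = train_one_step_reference(
    demo_model, demo_opt, torch.rand(4, 1, 8, 32, 32),
    demo_ac.sqrt(), (1.0 - demo_ac).sqrt(),
)
print(f"loss = {demo_loss:.4f}")
```

In the notebook itself, the same five lines go inside `train_one_step`, indexing the precomputed `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` tensors with `t`.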

In [None]:
#@title 🎧 Listen: Training Results
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/12_training_results.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. Training the Model

In [None]:
# Training loop
model = FactorizedVideoUNet(in_channels=1, base_dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

num_epochs = 30
batch_size = 16
losses = []

print("Training factorized video diffusion model...")
for epoch in range(num_epochs):
    epoch_losses = []

    # Shuffle dataset
    perm = torch.randperm(len(dataset))

    for i in range(0, len(dataset) - batch_size + 1, batch_size):
        batch = dataset[perm[i:i+batch_size]].to(device)
        loss = train_one_step(model, optimizer, batch)
        epoch_losses.append(loss)

    avg_loss = np.mean(epoch_losses)
    losses.append(avg_loss)

    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

print("Training complete!")

In [None]:
# 📊 Training curve
plt.figure(figsize=(10, 4))
plt.plot(losses, 'b-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Factorized Video U-Net: Training Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### 8.1 Sampling
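
The sampling cell below applies a DDPM ancestral update once per timestep, from $t = T-1$ down to $t = 0$. Written out with the notebook's schedule variables (`betas`, `alphas`, `alphas_cumprod`), and assuming the usual definitions $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$ (the schedule itself is defined in an earlier cell), one step reads roughly as:

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I), \quad \sigma_t = \sqrt{\beta_t}
$$

Here $\epsilon_\theta(x_t, t)$ is the model's noise prediction (`predicted_noise` in the code), the bracketed term scaled by $1/\sqrt{\alpha_t}$ is `mean`, and the noise term $\sigma_t z$ is skipped at the final step $t = 0$.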

In [None]:
@torch.no_grad()
def sample_videos(model, num_samples=8, num_frames=8, size=32, num_channels=1):
    """Generate videos using DDPM sampling."""
    model.eval()

    # Start from pure noise
    x = torch.randn(num_samples, num_channels, num_frames, size, size, device=device)

    for t_idx in reversed(range(num_timesteps)):
        t = torch.full((num_samples,), t_idx, device=device, dtype=torch.long)

        # Predict noise
        predicted_noise = model(x, t)

        # DDPM update step
        beta_t = betas[t_idx]
        alpha_t = alphas[t_idx]
        alpha_cumprod_t = alphas_cumprod[t_idx]

        # Mean of p(x_{t-1} | x_t)
        coeff1 = 1.0 / torch.sqrt(alpha_t)
        coeff2 = beta_t / torch.sqrt(1.0 - alpha_cumprod_t)
        mean = coeff1 * (x - coeff2 * predicted_noise)

        # Add noise (except at t=0)
        if t_idx > 0:
            noise = torch.randn_like(x)
            sigma = torch.sqrt(beta_t)
            x = mean + sigma * noise
        else:
            x = mean

    model.train()
    return x.clamp(0, 1)

print("Generating videos...")
generated = sample_videos(model, num_samples=8)
print(f"Generated shape: {generated.shape}")

In [None]:
# 📊 Display generated videos
fig, axes = plt.subplots(4, 8, figsize=(14, 7))
for row in range(4):
    for col in range(8):
        axes[row, col].imshow(generated[row, 0, col].cpu().numpy(),
                             cmap='gray', vmin=0, vmax=1)
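        # Hide ticks only; axis('off') would also hide the row label set below
        axes[row, col].set_xticks([])
        axes[row, col].set_yticks([])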
        if row == 0:
            axes[row, col].set_title(f'f={col}', fontsize=9)
    axes[row, 0].set_ylabel(f'Video {row+1}', fontsize=10, rotation=0, labelpad=45)

fig.suptitle('Generated Videos: Factorized Attention Model', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
#@title 🎧 Listen: Todo Causal
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/13_todo_causal.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 9. 🔧 Your Turn: Experiment with Attention Masking

### TODO: Implement causal temporal attention

In the current temporal attention, every frame can attend to every other frame (bidirectional). But what if we want **causal** temporal attention, where each frame can only attend to itself and to earlier frames, never to future ones? This is useful for autoregressive video generation.
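
To build intuition for why masked scores are set to `-inf` before the softmax, here is a small pure-Python sketch on a toy 3-frame score matrix. The helper `causal_softmax` is a hypothetical name used only for illustration; it is not the PyTorch solution to the exercise below, just a demonstration of the effect the mask should have.

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where entry (i, j) is masked out whenever j > i.

    `scores` is a T x T list of lists (hypothetical toy input).
    """
    T = len(scores)
    weights = []
    for i in range(T):
        # Future positions get -inf, so they contribute exp(-inf) = 0.
        row = [scores[i][j] if j <= i else float("-inf") for j in range(T)]
        row_max = max(row)  # finite, because position j == i is never masked
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# With identical scores, each frame spreads its attention evenly
# over itself and the frames before it.
uniform_scores = [[0.0] * 3 for _ in range(3)]
weights = causal_softmax(uniform_scores)
for row in weights:
    print([round(w, 3) for w in row])
```

Each row still sums to 1, but all weight on future frames is exactly zero, which is the property the verification cell checks by perturbing the last frame.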

In [None]:
class CausalTemporalAttention(nn.Module):
    """
    Temporal self-attention with a causal mask.
    Frame t can only attend to frames 0, 1, ..., t (not future frames).

    Input: (B, T, H, W, C)
    Output: (B, T, H, W, C)
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, T, H, W, C = x.shape

        # Merge spatial into batch
        x_flat = x.permute(0, 2, 3, 1, 4).reshape(B * H * W, T, C)
        x_normed = self.norm(x_flat)

        BHW, T, C = x_normed.shape
        qkv = self.qkv(x_normed).reshape(BHW, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # ============ TODO ============
        # Create a causal mask that prevents attending to future frames.
        # The mask should be a (T, T) boolean tensor where:
        #   mask[i, j] = True if frame i should NOT attend to frame j
        #   (i.e., True for j > i, the future frames)
        # Apply the mask by setting masked positions to -inf before softmax
        # ==============================

        causal_mask = ???  # YOUR CODE HERE
        attn = ???  # YOUR CODE HERE: apply the mask

        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(BHW, T, C)
        out = self.proj(out)

        # Residual + reshape back
        x_flat = x_flat + out
        return x_flat.reshape(B, H, W, T, C).permute(0, 3, 1, 2, 4)

In [None]:
# ✅ Verification
causal = CausalTemporalAttention(dim=32, num_heads=4).to(device)
test_in = torch.randn(1, 4, 4, 4, 32, device=device)

try:
    test_out = causal(test_in)
    assert test_out.shape == test_in.shape, f"Shape mismatch: {test_out.shape}"

    # Verify causality: changing frame 3 should NOT affect frames 0, 1, 2
    test_in_mod = test_in.clone()
    test_in_mod[:, 3] = torch.randn(1, 4, 4, 32, device=device)

    out_orig = causal(test_in)
    out_mod = causal(test_in_mod)

    for t in range(3):
        diff = (out_orig[:, t] - out_mod[:, t]).abs().max().item()
        assert diff < 1e-5, f"Frame {t} changed when it shouldn't have! diff={diff}"

    diff_3 = (out_orig[:, 3] - out_mod[:, 3]).abs().max().item()
    assert diff_3 > 0.01, "Frame 3 didn't change when it should have!"

    print("✅ Causal temporal attention is correct!")
    print(f"   Frames 0-2 unaffected by changes to frame 3 ✓")
    print(f"   Frame 3 properly updated ✓")
except Exception as e:
    print(f"❌ Error: {e}")
    print("Hint: Use torch.triu to create an upper triangular mask,")
    print("then apply it with masked_fill(-inf) before softmax.")

In [None]:
#@title 🎧 Listen: Final Output
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/14_final_output.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 10. 🎯 Final Output: Side-by-Side Comparison

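Let us generate a final gallery from the trained factorized attention model. Each row is one generated video with its frames laid out side by side, so you can judge temporal coherence by reading a row from left to right.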

In [None]:
# Generate a final gallery of videos
print("Generating final showcase...")
final_videos = sample_videos(model, num_samples=8)

fig, axes = plt.subplots(8, 8, figsize=(14, 14))
for row in range(8):
    for col in range(8):
        axes[row, col].imshow(final_videos[row, 0, col].cpu().numpy(),
                             cmap='gray', vmin=0, vmax=1)
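        # Hide ticks only; axis('off') would also hide the row label set below
        axes[row, col].set_xticks([])
        axes[row, col].set_yticks([])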
        if row == 0:
            axes[row, col].set_title(f'Frame {col}', fontsize=9)
    axes[row, 0].set_ylabel(f'Video {row+1}', fontsize=9, rotation=0, labelpad=40)

fig.suptitle('🎯 Generated Videos: Factorized Space-Time Attention\n'
             'Each row is a separate video showing temporal coherence', fontsize=14)
plt.tight_layout()
plt.show()
print("🎉 You've built factorized space-time attention from scratch!")

In [None]:
#@title 🎧 Listen: Reflection
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/15_reflection.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 11. Reflection and Next Steps

### 🤔 Reflection Questions

1. **Why spatial first?** We do spatial attention before temporal attention. What would happen if we reversed the order? Would the model still work? Would training be different?

2. **Information bottleneck:** In factorized attention, a pixel at (3, 7) in frame 5 can only "communicate" with other frames through temporal attention at position (3, 7). What information might be lost? How could we mitigate this?

3. **Causal vs bidirectional:** We implemented both causal and bidirectional temporal attention. When would you use each? Think about autoregressive generation vs. denoising.

4. **Scaling behavior:** As video resolution increases, which component becomes the bottleneck: spatial or temporal attention? What does this tell us about designing architectures for high-resolution video?
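
For question 4, a rough count of query-key pairs can make the scaling concrete. The sketch below counts the pairs scored by one attention layer on a single video, for full 3D attention and for the two factorized passes. The function name and the feature-map sizes are assumptions chosen for illustration; they are not taken from the model in this notebook.

```python
def attention_pair_counts(num_frames, height, width):
    """Count query-key pairs per attention layer for one video (single head)."""
    tokens_per_frame = height * width
    total_tokens = num_frames * tokens_per_frame
    return {
        # Every space-time token attends to every other space-time token.
        "full_3d": total_tokens ** 2,
        # Spatial pass: each frame attends within itself.
        "spatial": num_frames * tokens_per_frame ** 2,
        # Temporal pass: each pixel location attends across frames.
        "temporal": tokens_per_frame * num_frames ** 2,
    }

# Hypothetical feature-map sizes; 8 frames throughout.
for size in (8, 16, 32):
    counts = attention_pair_counts(num_frames=8, height=size, width=size)
    print(f"{size}x{size}: {counts}")
```

Doubling the spatial side length multiplies the spatial count by 16 but the temporal count only by 4, which is one way to approach the question.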

### 🏆 Optional Challenges

1. **Alternating order:** Modify the factorized block to alternate: odd layers do spatial→temporal, even layers do temporal→spatial. Does this help?

2. **Local temporal attention:** Instead of attending to ALL frames, restrict temporal attention to a sliding window of ±3 frames. How does this affect quality vs speed?

3. **Cross-frame spatial attention:** Allow limited cross-frame communication in spatial attention by including features from adjacent frames as extra keys/values.
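
As a possible starting point for challenge 2, the sketch below builds the sliding-window mask in plain Python, using the same convention as the causal exercise (`True` means "do not attend"). The helper `local_window_mask` is a hypothetical name; wiring the mask into the temporal attention layer is left as part of the challenge.

```python
def local_window_mask(num_frames, window=3):
    """Return a T x T nested list of booleans; True means "block this pair".

    Frame i may attend to frame j only when |i - j| <= window.
    """
    return [
        [abs(i - j) > window for j in range(num_frames)]
        for i in range(num_frames)
    ]

mask = local_window_mask(num_frames=8, window=3)

# Print the band structure: '.' = allowed, 'x' = blocked.
for row in mask:
    print("".join("x" if blocked else "." for blocked in row))
```

A nested list like this can be converted with `torch.tensor(mask, dtype=torch.bool)` and applied with `masked_fill` before the softmax, in the same way as the causal mask.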

### What's Next?

In Notebook 3, we will tackle two final pieces of the puzzle:
- **Latent Video Diffusion:** Compressing videos with a VAE before running diffusion, for massive efficiency gains
- **Diffusion Transformers (DiT):** Replacing the U-Net entirely with a Transformer using spacetime patches

In [None]:
#@title 💬 AI Teaching Assistant: Click ▶ to start
#@markdown This AI chatbot reads your notebook and can answer questions about any concept, code, or exercise.

import json as _json
import requests as _requests
from google.colab import output as _output
from IPython.display import display, HTML as _HTML, Markdown as _Markdown

# --- Read notebook content for context ---
def _get_notebook_context():
    try:
        from google.colab import _message
        nb = _message.blocking_request("get_ipynb", request="", timeout_sec=10)
        cells = nb.get("ipynb", {}).get("cells", [])
        parts = []
        for cell in cells:
            src = "".join(cell.get("source", []))
            tags = cell.get("metadata", {}).get("tags", [])
            if "chatbot" in tags:
                continue
            if src.strip():
                ct = cell.get("cell_type", "unknown")
                parts.append(f"[{ct.upper()}]\n{src}")
        return "\n\n---\n\n".join(parts)
    except Exception:
        return "Notebook content unavailable."

_NOTEBOOK_CONTEXT = _get_notebook_context()
_CHAT_HISTORY = []
_API_URL = "https://course-creator-brown.vercel.app/api/chat"

def _notebook_chat(question):
    global _CHAT_HISTORY
    try:
        resp = _requests.post(_API_URL, json={
            'question': question,
            'context': _NOTEBOOK_CONTEXT[:100000],
            'history': _CHAT_HISTORY[-10:],
        }, timeout=60)
        data = resp.json()
        answer = data.get('answer', 'Sorry, I could not generate a response.')
        _CHAT_HISTORY.append({'role': 'user', 'content': question})
        _CHAT_HISTORY.append({'role': 'assistant', 'content': answer})
        return answer
    except Exception as e:
        return f'Error connecting to teaching assistant: {str(e)}'

_output.register_callback('notebook_chat', _notebook_chat)

def ask(question):
    """Ask the AI teaching assistant a question about this notebook."""
    answer = _notebook_chat(question)
    display(_Markdown(answer))

print("\u2705 AI Teaching Assistant is ready!")
print("\U0001f4a1 Use the chat below, or call ask(\'your question\') in any cell.")

# --- Display chat widget ---
display(_HTML('''<style>
  .vc-wrap{font-family:-apple-system,BlinkMacSystemFont,'Segoe UI',Roboto,sans-serif;max-width:100%;border-radius:16px;overflow:hidden;box-shadow:0 4px 24px rgba(0,0,0,.12);background:#fff;border:1px solid #e5e7eb}
  .vc-hdr{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;padding:16px 20px;display:flex;align-items:center;gap:12px}
  .vc-avatar{width:42px;height:42px;background:rgba(255,255,255,.2);border-radius:50%;display:flex;align-items:center;justify-content:center;font-size:22px}
  .vc-hdr h3{font-size:16px;font-weight:600;margin:0}
  .vc-hdr p{font-size:12px;opacity:.85;margin:2px 0 0}
  .vc-msgs{height:420px;overflow-y:auto;padding:16px;background:#f8f9fb;display:flex;flex-direction:column;gap:10px}
  .vc-msg{display:flex;flex-direction:column;animation:vc-fade .25s ease}
  .vc-msg.user{align-items:flex-end}
  .vc-msg.bot{align-items:flex-start}
  .vc-bbl{max-width:85%;padding:10px 14px;border-radius:16px;font-size:14px;line-height:1.55;word-wrap:break-word}
  .vc-msg.user .vc-bbl{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border-bottom-right-radius:4px}
  .vc-msg.bot .vc-bbl{background:#fff;color:#1a1a2e;border:1px solid #e8e8e8;border-bottom-left-radius:4px}
  .vc-bbl code{background:rgba(0,0,0,.07);padding:2px 6px;border-radius:4px;font-size:13px;font-family:'Fira Code',monospace}
  .vc-bbl pre{background:#1e1e2e;color:#cdd6f4;padding:12px;border-radius:8px;overflow-x:auto;margin:8px 0;font-size:13px}
  .vc-bbl pre code{background:none;padding:0;color:inherit}
  .vc-bbl h3,.vc-bbl h4{margin:10px 0 4px;font-size:15px}
  .vc-bbl ul,.vc-bbl ol{margin:4px 0;padding-left:20px}
  .vc-bbl li{margin:2px 0}
  .vc-chips{display:flex;flex-wrap:wrap;gap:8px;padding:0 16px 12px;background:#f8f9fb}
  .vc-chip{background:#fff;border:1px solid #d1d5db;border-radius:20px;padding:6px 14px;font-size:12px;cursor:pointer;transition:all .15s;color:#4b5563}
  .vc-chip:hover{border-color:#667eea;color:#667eea;background:#f0f0ff}
  .vc-input{display:flex;padding:12px 16px;background:#fff;border-top:1px solid #eee;gap:8px}
  .vc-input input{flex:1;padding:10px 16px;border:2px solid #e8e8e8;border-radius:24px;font-size:14px;outline:none;transition:border-color .2s}
  .vc-input input:focus{border-color:#667eea}
  .vc-input button{background:linear-gradient(135deg,#667eea 0%,#764ba2 100%);color:#fff;border:none;border-radius:50%;width:42px;height:42px;cursor:pointer;display:flex;align-items:center;justify-content:center;font-size:18px;transition:transform .1s}
  .vc-input button:hover{transform:scale(1.05)}
  .vc-input button:disabled{opacity:.5;cursor:not-allowed;transform:none}
  .vc-typing{display:flex;gap:5px;padding:4px 0}
  .vc-typing span{width:8px;height:8px;background:#667eea;border-radius:50%;animation:vc-bounce 1.4s infinite ease-in-out}
  .vc-typing span:nth-child(2){animation-delay:.2s}
  .vc-typing span:nth-child(3){animation-delay:.4s}
  @keyframes vc-bounce{0%,80%,100%{transform:scale(0)}40%{transform:scale(1)}}
  @keyframes vc-fade{from{opacity:0;transform:translateY(8px)}to{opacity:1;transform:translateY(0)}}
  .vc-note{text-align:center;font-size:11px;color:#9ca3af;padding:8px 16px 12px;background:#fff}
</style>
<div class="vc-wrap">
  <div class="vc-hdr">
    <div class="vc-avatar">&#129302;</div>
    <div>
      <h3>Vizuara Teaching Assistant</h3>
      <p>Ask me anything about this notebook</p>
    </div>
  </div>
  <div class="vc-msgs" id="vcMsgs">
    <div class="vc-msg bot">
      <div class="vc-bbl">&#128075; Hi! I've read through this entire notebook. Ask me about any concept, code block, or exercise &mdash; I'm here to help you learn!</div>
    </div>
  </div>
  <div class="vc-chips" id="vcChips">
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Explain the main concept</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Help with the TODO exercise</span>
    <span class="vc-chip" onclick="vcAsk(this.textContent)">Summarize what I learned</span>
  </div>
  <div class="vc-input">
    <input type="text" id="vcIn" placeholder="Ask about concepts, code, exercises..." />
    <button id="vcSend" onclick="vcSendMsg()">&#10148;</button>
  </div>
  <div class="vc-note">AI-generated &middot; Verify important information &middot; <a href="#" onclick="vcClear();return false" style="color:#667eea">Clear chat</a></div>
</div>
<script>
(function(){
  var msgs=document.getElementById('vcMsgs'),inp=document.getElementById('vcIn'),
      btn=document.getElementById('vcSend'),chips=document.getElementById('vcChips');

  function esc(s){var d=document.createElement('div');d.textContent=s;return d.innerHTML}

  function md(t){
    return t
      .replace(/```(\w*)\n([\s\S]*?)```/g,function(_,l,c){return '<pre><code>'+esc(c)+'</code></pre>'})
      .replace(/`([^`]+)`/g,'<code>$1</code>')
      .replace(/\*\*([^*]+)\*\*/g,'<strong>$1</strong>')
      .replace(/\*([^*]+)\*/g,'<em>$1</em>')
      .replace(/^#### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^### (.+)$/gm,'<h4>$1</h4>')
      .replace(/^## (.+)$/gm,'<h3>$1</h3>')
      .replace(/^\d+\. (.+)$/gm,'<li>$1</li>')
      .replace(/^- (.+)$/gm,'<li>$1</li>')
      .replace(/\n\n/g,'<br><br>')
      .replace(/\n/g,'<br>');
  }

  function addMsg(text,isUser){
    var m=document.createElement('div');m.className='vc-msg '+(isUser?'user':'bot');
    var b=document.createElement('div');b.className='vc-bbl';
    b.innerHTML=isUser?esc(text):md(text);
    m.appendChild(b);msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function showTyping(){
    var m=document.createElement('div');m.className='vc-msg bot';m.id='vcTyping';
    m.innerHTML='<div class="vc-bbl"><div class="vc-typing"><span></span><span></span><span></span></div></div>';
    msgs.appendChild(m);msgs.scrollTop=msgs.scrollHeight;
  }

  function hideTyping(){var e=document.getElementById('vcTyping');if(e)e.remove()}

  window.vcSendMsg=function(){
    var q=inp.value.trim();if(!q)return;
    inp.value='';chips.style.display='none';
    addMsg(q,true);showTyping();btn.disabled=true;
    google.colab.kernel.invokeFunction('notebook_chat',[q],{})
      .then(function(r){
        hideTyping();
        var a=r.data['application/json'];
        addMsg(typeof a==='string'?a:JSON.stringify(a),false);
      })
      .catch(function(){
        hideTyping();
        addMsg('Sorry, I encountered an error. Please check your internet connection and try again.',false);
      })
      .finally(function(){btn.disabled=false;inp.focus()});
  };

  window.vcAsk=function(q){inp.value=q;vcSendMsg()};
  window.vcClear=function(){
    msgs.innerHTML='<div class="vc-msg bot"><div class="vc-bbl">&#128075; Chat cleared. Ask me anything!</div></div>';
    chips.style.display='flex';
  };

  inp.addEventListener('keypress',function(e){if(e.key==='Enter')vcSendMsg()});
  inp.focus();
})();
</script>'''))