# Physics Video Prediction - Training

Train a modern transformer to predict physics simulation frames.

**Architecture**: LLaMA-style transformer (RMSNorm, SwiGLU, RoPE)
- Each 32x32 frame is flattened to a 1024-dim vector (one "token" per frame)
- Transformer predicts next frame from variable-length context
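- The context always starts at the first frame of a trajectory (index 0 in the code) and is built from ground-truth frames during training (teacher forcing)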

**Goal**: Model learns physics (position, velocity) from raw pixels.

In [1]:
# === COLAB SETUP ===
# Clone repo (run once)
# NOTE(review): every path below is relative to the current working directory.
# After the %cd at the end of this group has run once, re-running the cell
# would clone a nested copy inside the experiment folder - restart the runtime
# (or cd back to the original directory) before re-running.
!rm -rf MNIST_AI 2>/dev/null
!git clone https://github.com/Caleb-Briggs/MNIST_AI.git
%cd MNIST_AI/experiments/physics_prediction

# Create results directory (inside the experiment folder entered above)
!mkdir -p results

# === STORAGE OPTIONS ===
# Option 1: Google Drive (browser-based Colab only)
# Uncomment the four lines below to keep checkpoints on Drive.
# NOTE(review): `results` already exists as a directory at this point, so the
# `ln -sf` presumably creates the link *inside* results/ rather than replacing
# it - confirm before relying on Drive persistence.
# from google.colab import drive
# drive.mount('/content/drive')
# !mkdir -p /content/drive/MyDrive/physics_prediction_results
# !ln -sf /content/drive/MyDrive/physics_prediction_results results

# Option 2: For VS Code - checkpoints save locally to results/
# Download them manually or use the zip cell at the end of the notebook
# ("Download Checkpoints")

print("Setup complete! Checkpoints save to results/")

Cloning into 'MNIST_AI'...
remote: Enumerating objects: 170, done.
remote: Counting objects: 100% (98/98), done.
remote: Compressing objects: 100% (72/72), done.
remote: Total 170 (delta 60), reused 62 (delta 24), pack-reused 72 (from 1)
Receiving objects: 100% (170/170), 22.98 MiB | 27.65 MiB/s, done.
Resolving deltas: 100% (79/79), done.
/content/MNIST_AI/experiments/physics_prediction
Setup complete! Checkpoints save to results/


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from IPython.display import HTML
import matplotlib.animation as animation

from physics_sim import (
    Ball, Barrier, PhysicsSimulation,
    generate_trajectory, create_random_simulation, generate_dataset
)
from model import VideoPredictor, count_parameters

# Pick the compute device once; every later cell moves tensors/models to it.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Device: {device}")
    # Report which GPU was allocated (index 0 is the default CUDA device).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print(f"Device: {device}")

Device: cuda
GPU: NVIDIA A100-SXM4-40GB


## 1. Configuration

In [None]:
# Data config (forwarded to generate_dataset)
NUM_TRAJECTORIES = 5000  # Trajectories per generated dataset (train + val together)
FRAMES_PER_TRAJECTORY = 25  # Shorter trajectories, max context ~24 (the last frame can only be a target)
NUM_BARRIERS = 2
WITH_GRAVITY = False
RESOLUTION = 32  # 32x32 frames
DT = 6.0  # Time step - high enough that ball moves visibly

# Model config (forwarded to VideoPredictor)
N_HEADS = 8
N_LAYERS = 6
DROPOUT = 0.0

# Training config
BATCH_SIZE = 128
LEARNING_RATE = 3e-4  # Initial AdamW learning rate; decayed by the cosine schedule
NUM_EPOCHS = 1000  # Also used as T_max of the cosine schedule
SEED = 421  # Seeds torch/numpy here and is the base seed of the first dataset

# Variable context: sample context lengths from this range during training
# (both bounds are inclusive in PhysicsDataset)
MIN_CONTEXT = 3   # At least 3 frames to see motion
MAX_CONTEXT = 20  # Up to 20 frames of context

# Data diversity: regenerate trajectories every N epochs
REGEN_DATA_EVERY = 25

# Visualization: plot predictions, a rollout and the loss curve every N epochs
VIS_EVERY = 25

# Seed both RNGs used in this notebook so runs are repeatable
torch.manual_seed(SEED)
np.random.seed(SEED)

# Enable TF32 for faster matmuls on A100
# (TF32 trades a little float32 precision for speed on GPUs that support it)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# NOTE: frame_dim=1024 is hard-coded in the message below; it matches
# RESOLUTION * RESOLUTION only while RESOLUTION == 32.
print(f"Model: {N_LAYERS} layers, {N_HEADS} heads, frame_dim=1024")
print(f"Training: batch={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}")
print(f"Context length: {MIN_CONTEXT}-{MAX_CONTEXT} frames")
print(f"Data regeneration every {REGEN_DATA_EVERY} epochs")

## 2. Generate Training Data

In [None]:
print("Generating training data...")

# Build the initial dataset from the notebook-level configuration.
data = generate_dataset(
    base_seed=SEED,
    num_trajectories=NUM_TRAJECTORIES,
    num_frames=FRAMES_PER_TRAJECTORY,
    resolution=RESOLUTION,
    num_barriers=NUM_BARRIERS,
    with_gravity=WITH_GRAVITY,
    dt=DT,
)

print(f"Dataset shape: {data.shape}")
print(f"Memory: {data.nbytes / 1024 / 1024:.1f} MB")

# Frames stay single-channel, so the tensor keeps the (N, T, H, W) layout.
data_tensor = torch.from_numpy(data).to(torch.float32)
print(f"Tensor shape: {data_tensor.shape}")

In [None]:
# Show the first three trajectories, sampling every third frame
# (t = 0, 3, 6, 9, 12, 15, 18, 21) across the eight columns.
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
for (row, col), ax in np.ndenumerate(axes):
    t = 3 * col
    ax.imshow(data[row, t], cmap='gray', vmin=0, vmax=1)
    if row == 0:
        # Only the top row carries the time-step header.
        ax.set_title(f't={t}')
    ax.axis('off')
plt.suptitle('Sample Trajectories (32x32)')
plt.tight_layout()
plt.show()

## 3. Create Dataset and DataLoader

In [None]:
class PhysicsDataset(torch.utils.data.Dataset):
    """
    Dataset that enumerates (context, target) pairs from trajectories.

    Each sample:
    - context: frames 0 to context_len-1 (always starts from frame 0)
    - target: frame at context_len
    - context_len: number of frames in the context

    Every trajectory contributes one sample per context length in
    [min_context, max_context] (both inclusive), so the context length is a
    deterministic function of the sample index; randomness comes from the
    DataLoader's shuffling. Contexts are returned unpadded and therefore have
    different lengths; a collate function must pad them before batching.

    Args:
        trajectories: Tensor indexed as (N, T, H, W).
        min_context: Smallest context length to emit (must be >= 1).
        max_context: Largest context length to emit. It is clamped to T - 1
            because the frame right after the context is the target.

    Raises:
        ValueError: If min_context < 1 or if no context length in the
            requested range fits into the trajectories.
    """

    def __init__(self, trajectories: torch.Tensor, min_context: int, max_context: int):
        num_frames = trajectories.size(1)
        # The last frame can only ever be a target, never part of a context.
        clamped_max = min(max_context, num_frames - 1)

        if min_context < 1:
            raise ValueError(f"min_context must be >= 1, got {min_context}")
        if clamped_max < min_context:
            # Without this check context_range would be <= 0 and len() of the
            # dataset would fail later with a far less helpful message.
            raise ValueError(
                f"No valid context length: min_context={min_context}, "
                f"max_context={max_context}, but trajectories have only "
                f"{num_frames} frames (max usable context is {num_frames - 1})"
            )

        self.trajectories = trajectories  # (N, T, H, W)
        self.min_context = min_context
        self.max_context = clamped_max
        self.num_traj = trajectories.size(0)
        self.context_range = self.max_context - self.min_context + 1

    def __len__(self):
        # Each trajectory gives us (max_context - min_context + 1) samples
        return self.num_traj * self.context_range

    def __getitem__(self, idx):
        # Samples are laid out trajectory-major: consecutive indices walk
        # through the context lengths of one trajectory before moving on.
        traj_idx = idx // self.context_range
        context_idx = idx % self.context_range
        context_len = self.min_context + context_idx

        # Always start from frame 0, get context_len frames
        context = self.trajectories[traj_idx, :context_len]  # (context_len, H, W)
        target = self.trajectories[traj_idx, context_len]     # (H, W)

        return context, target, context_len


def collate_by_context_length(batch):
    """
    Left-pad variable-length contexts to the longest one in the batch.

    `batch` is a sequence of (context, target, context_len) triples. Real
    frames are right-aligned, so the most recent context frame of every
    sample sits in the last position and the zero padding comes first.

    Returns a tuple (contexts, targets, mask):
    - contexts: zero-padded tensor of shape (B, max_len, H, W)
    - targets: the stacked target frames
    - mask: bool tensor of shape (B, max_len); True marks a padded position
      that should be ignored, False marks a real frame
    """
    contexts, targets, lengths = zip(*batch)

    n_samples = len(contexts)
    longest = max(lengths)
    frame_shape = contexts[0].shape
    height, width = frame_shape[1], frame_shape[2]

    # Start fully padded / fully masked, then fill in the real frames.
    padded_contexts = torch.zeros(n_samples, longest, height, width)
    attention_mask = torch.ones(n_samples, longest, dtype=torch.bool)

    for row, (frames, n_frames) in enumerate(zip(contexts, lengths)):
        first_real = longest - n_frames  # number of padding slots in front
        padded_contexts[row, first_real:] = frames
        attention_mask[row, first_real:] = False

    return padded_contexts, torch.stack(targets), attention_mask


# 90/10 train/val split along the trajectory axis
train_size = int(0.9 * NUM_TRAJECTORIES)
train_data, val_data = data_tensor[:train_size], data_tensor[train_size:]

train_dataset = PhysicsDataset(train_data, MIN_CONTEXT, MAX_CONTEXT)
val_dataset = PhysicsDataset(val_data, MIN_CONTEXT, MAX_CONTEXT)

# Both loaders share everything except shuffling; the custom collate function
# pads the variable-length contexts of each batch.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_by_context_length,
)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **_loader_opts)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")

# Sanity check: the training frames must be finite; also report their value range
has_nan = torch.isnan(train_data).any().item()
print(f"\nData check:")
print(f"  Train data has NaN: {has_nan}")
print(f"  Train data range: [{train_data.min():.3f}, {train_data.max():.3f}]")

## 4. Create Model

In [None]:
# Build the predictor from the notebook-level configuration and move it to
# the selected device.
model = VideoPredictor(
    frame_size=RESOLUTION,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    dropout=DROPOUT
).to(device)

print(f"Model parameters: {count_parameters(model):,}")

# Test forward pass with variable context.
# The collate function yields (contexts, targets, padding_mask); unpacking
# into two names raised "too many values to unpack".
sample_context, sample_target, sample_mask = next(iter(train_loader))
sample_context = sample_context.to(device)
sample_mask = sample_mask.to(device)

with torch.no_grad():
    # Pass the mask exactly like the training loop does, so the zero-padded
    # positions of shorter contexts are ignored in this smoke test too.
    sample_pred = model(sample_context, padding_mask=sample_mask)

print(f"Context shape: {sample_context.shape}")
print(f"Target shape: {sample_target.shape}")
print(f"Prediction shape: {sample_pred.shape}")

## 5. Training Loop

In [None]:
# AdamW with light weight decay; the cosine schedule is stepped once per
# epoch and anneals the learning rate down to 1e-5 over NUM_EPOCHS.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-5,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-5,
)

# Pixel-wise mean squared error between predicted and true next frame
criterion = nn.MSELoss()

# Per-epoch training history (one entry appended per epoch)
history = {key: [] for key in ('train_loss', 'val_loss', 'lr')}

In [None]:
# Mixed precision training
# Gradient scaler handed to train_epoch, which uses it to scale the loss
# before backward(), unscale the gradients before clipping, and drive the
# optimizer step. It is created once so its state persists across epochs.
scaler = torch.amp.GradScaler('cuda')

def train_epoch(model, loader, optimizer, criterion, device, scaler):
    """Train for one epoch with mixed precision."""
    model.train()
    total_loss = 0
    num_batches = 0
    
    for context, target, mask in loader:
        context = context.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)
        
        optimizer.zero_grad(set_to_none=True)
        
        with torch.amp.autocast('cuda'):
            pred = model(context, padding_mask=mask)
            loss = criterion(pred, target)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        num_batches += 1
    
    return total_loss / num_batches


@torch.no_grad()
def eval_epoch(model, loader, criterion, device):
    """Compute the mean per-batch loss over `loader` without updating weights.

    Puts the model in eval mode and runs every (context, target,
    padding_mask) batch under autocast. Returns a Python float.
    """
    model.eval()
    batch_losses = []

    for frames, next_frame, pad_mask in loader:
        frames, next_frame, pad_mask = (
            tensor.to(device, non_blocking=True)
            for tensor in (frames, next_frame, pad_mask)
        )

        with torch.amp.autocast('cuda'):
            prediction = model(frames, padding_mask=pad_mask)
            loss = criterion(prediction, next_frame)

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(batch_losses)


@torch.no_grad()
def quick_visualize(model, val_data, device, epoch):
    """Plot single-step predictions for three validation trajectories.

    Each row uses a different trajectory and context length (5, 10 and 15
    frames by default) and shows the last three context frames, the predicted
    next frame, the ground-truth next frame and the absolute error. The
    figure is created but not shown; the caller calls ``plt.show()``.

    Args:
        model: Predictor called as ``model(context)`` on a batch of one.
        val_data: Tensor of trajectories indexed as ``[trajectory, frame]``.
            Needs at least 4 frames per trajectory (3 context + 1 target).
        device: Device the context is moved to before the forward pass.
        epoch: Epoch number shown in the figure title.
    """
    model.eval()

    num_traj, num_frames = val_data.size(0), val_data.size(1)
    fig, axes = plt.subplots(3, 6, figsize=(14, 7))

    for row in range(3):
        # Clamp the hard-coded picks so small validation sets or short
        # trajectories do not index out of range.
        traj_idx = min(row * 50, num_traj - 1)
        context_len = min(5 + row * 5, num_frames - 1)  # 5, 10, 15 frames of context

        context = val_data[traj_idx, :context_len].unsqueeze(0).to(device)
        target = val_data[traj_idx, context_len]
        pred = model(context).cpu().squeeze(0)  # No padding mask needed for single sample

        # Show last 3 context frames
        for i in range(3):
            ax = axes[row, i]
            ax.imshow(context[0, -(3-i), :, :].cpu(), cmap='gray', vmin=0, vmax=1)
            if row == 0:
                ax.set_title(f'Ctx t-{2-i}')
            if i == 0:
                # axis('off') would also hide the y-label, so for the labelled
                # column only the ticks and frame are removed.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                # Label row with context length
                ax.set_ylabel(f'ctx={context_len}', fontsize=10)
            else:
                ax.axis('off')

        # Prediction
        axes[row, 3].imshow(pred.clamp(0, 1), cmap='gray', vmin=0, vmax=1)
        if row == 0:
            axes[row, 3].set_title('Predicted')
        axes[row, 3].axis('off')

        # Ground truth
        axes[row, 4].imshow(target, cmap='gray', vmin=0, vmax=1)
        if row == 0:
            axes[row, 4].set_title('Ground Truth')
        axes[row, 4].axis('off')

        # Error (computed on the raw, unclamped prediction)
        error = torch.abs(pred - target)
        axes[row, 5].imshow(error, cmap='hot', vmin=0, vmax=0.5)
        if row == 0:
            axes[row, 5].set_title('Error')
        axes[row, 5].axis('off')

    plt.suptitle(f'Epoch {epoch}')
    plt.tight_layout()


def generate_fresh_data(seed_offset):
    """Build a new float tensor of trajectories using `seed_offset` as base seed.

    All other generation settings come from the notebook-level configuration
    constants, so only the seed differs between calls.
    """
    print(f"  Generating {NUM_TRAJECTORIES} fresh trajectories...")
    generation_settings = dict(
        num_trajectories=NUM_TRAJECTORIES,
        num_frames=FRAMES_PER_TRAJECTORY,
        num_barriers=NUM_BARRIERS,
        with_gravity=WITH_GRAVITY,
        resolution=RESOLUTION,
        dt=DT,
    )
    frames = generate_dataset(base_seed=seed_offset, **generation_settings)
    return torch.from_numpy(frames).to(torch.float32)

In [None]:
# Main training loop
print(f"Starting training for {NUM_EPOCHS} epochs...")
print(f"Data regeneration every {REGEN_DATA_EVERY} epochs")
print(f"Visualizations every {VIS_EVERY} epochs")
print("="*60)

best_val_loss = float('inf')
data_seed_offset = SEED

# Current data references (replaced whenever the data is regenerated)
current_train_data = train_data
current_val_data = val_data
current_train_loader = train_loader
current_val_loader = val_loader

for epoch in range(NUM_EPOCHS):
    # Regenerate data periodically
    if epoch > 0 and epoch % REGEN_DATA_EVERY == 0:
        print(f"\n{'='*60}")
        print(f"Regenerating training data at epoch {epoch}")
        # Advance the base seed by a full dataset's worth - presumably so the
        # new trajectories do not reuse seeds from the previous dataset
        # (depends on how generate_dataset derives per-trajectory seeds).
        data_seed_offset += NUM_TRAJECTORIES

        fresh_data = generate_fresh_data(data_seed_offset)

        # Same 90/10 split as the initial dataset
        train_size = int(0.9 * NUM_TRAJECTORIES)
        current_train_data = fresh_data[:train_size]
        current_val_data = fresh_data[train_size:]

        current_train_loader = torch.utils.data.DataLoader(
            PhysicsDataset(current_train_data, MIN_CONTEXT, MAX_CONTEXT),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
            pin_memory=True, collate_fn=collate_by_context_length
        )
        current_val_loader = torch.utils.data.DataLoader(
            PhysicsDataset(current_val_data, MIN_CONTEXT, MAX_CONTEXT),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
            pin_memory=True, collate_fn=collate_by_context_length
        )
        print(f"{'='*60}\n")

    # Train and evaluate
    train_loss = train_epoch(model, current_train_loader, optimizer, criterion, device, scaler)
    val_loss = eval_epoch(model, current_val_loader, criterion, device)

    # Read the learning rate BEFORE stepping the scheduler so that both the
    # history and the log line report the rate this epoch actually used
    # (previously the log line showed the next epoch's rate).
    lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['lr'].append(lr)

    scheduler.step()

    # NOTE: the validation set changes whenever data is regenerated, so this
    # "best" value compares losses measured on different validation sets.
    if val_loss < best_val_loss:
        best_val_loss = val_loss

    print(f"Epoch {epoch+1:4d}/{NUM_EPOCHS}: train={train_loss:.6f}, val={val_loss:.6f}, best={best_val_loss:.6f}, lr={lr:.2e}")

    # Periodic visualization
    if (epoch + 1) % VIS_EVERY == 0:
        print(f"\n{'='*60}")
        print(f"Visualization at epoch {epoch+1}")
        print(f"{'='*60}\n")

        quick_visualize(model, current_val_data, device, epoch + 1)
        plt.show()

        # Show rollout: 5 ground-truth context frames, then 8 predicted steps
        print("\nRollout comparison (8 steps):")
        model.eval()
        with torch.no_grad():
            traj_idx = 0
            context = current_val_data[traj_idx, :5].unsqueeze(0).to(device)
            gt_future = current_val_data[traj_idx, 5:13].cpu().numpy()
            predicted = model.rollout(context, 8).cpu().squeeze(0).numpy()

        fig, axes = plt.subplots(2, 8, figsize=(14, 3.5))
        rollout_rows = (
            (np.clip(predicted, 0, 1), 'Pred'),
            (gt_future, 'GT'),
        )
        for r, (frames, row_label) in enumerate(rollout_rows):
            for t in range(8):
                ax = axes[r, t]
                ax.imshow(frames[t], cmap='gray', vmin=0, vmax=1)
                if t == 0:
                    # axis('off') would also hide the row label, so only the
                    # ticks and frame are removed in the labelled column.
                    ax.set_xticks([])
                    ax.set_yticks([])
                    for spine in ax.spines.values():
                        spine.set_visible(False)
                    ax.set_ylabel(row_label, fontsize=10)
                else:
                    ax.axis('off')
                if r == 0:
                    ax.set_title(f't+{t+1}', fontsize=8)

        plt.suptitle(f'Rollout at Epoch {epoch+1}')
        plt.tight_layout()
        plt.show()

        # Loss curve (needs at least two points to be meaningful)
        if len(history['train_loss']) > 1:
            fig, ax = plt.subplots(figsize=(10, 3))
            ax.plot(history['train_loss'], label='Train', alpha=0.8)
            ax.plot(history['val_loss'], label='Val', alpha=0.8)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('MSE Loss')
            ax.set_title('Training Progress')
            ax.legend()
            ax.grid(alpha=0.3)
            ax.set_yscale('log')
            plt.tight_layout()
            plt.show()

        print(f"\n{'='*60}\n")

print(f"\nTraining complete! Best val_loss: {best_val_loss:.6f}")

## 6. Visualizations

### Training Curves

In [None]:
# Left panel: train/val loss per epoch. Right panel: learning rate per epoch.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, lr_ax = axes

# Loss curves
for history_key, curve_label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    loss_ax.plot(history[history_key], label=curve_label)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('MSE Loss')
loss_ax.set_title('Training Curves')
loss_ax.legend()
loss_ax.grid(alpha=0.3)

# Learning rate
lr_ax.plot(history['lr'])
lr_ax.set_xlabel('Epoch')
lr_ax.set_ylabel('Learning Rate')
lr_ax.set_title('Learning Rate Schedule')
lr_ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

### Single-Step Predictions vs Ground Truth

In [None]:
@torch.no_grad()
def visualize_predictions(model, val_data, device, num_samples=5):
    """Show context, prediction, ground truth, and error for random samples.

    Draws `num_samples` distinct trajectories (fewer if `val_data` holds
    fewer) and, for each, a random context length between MIN_CONTEXT and
    MAX_CONTEXT inclusive (capped so a target frame exists). Each row shows
    the last three context frames, the predicted next frame, the true next
    frame and the absolute error.

    Args:
        model: Predictor called as ``model(context)`` on a batch of one.
        val_data: Tensor of trajectories indexed as ``[trajectory, frame]``.
        device: Device the context is moved to before the forward pass.
        num_samples: Number of rows to plot.
    """
    model.eval()

    # Sampling without replacement cannot return more rows than exist.
    num_samples = min(num_samples, len(val_data))
    # The target is the frame after the context, so cap at T - 1.
    max_context = min(MAX_CONTEXT, val_data.size(1) - 1)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_samples, 6, figsize=(14, 2.5*num_samples), squeeze=False)

    indices = np.random.choice(len(val_data), num_samples, replace=False)

    for row, traj_idx in enumerate(indices):
        traj_idx = int(traj_idx)
        # randint's upper bound is exclusive; +1 makes MAX_CONTEXT reachable,
        # matching the inclusive range used for training.
        context_len = np.random.randint(MIN_CONTEXT, max_context + 1)

        context = val_data[traj_idx, :context_len].unsqueeze(0).to(device)
        target = val_data[traj_idx, context_len]
        pred = model(context).cpu().squeeze(0)

        # Show last 3 context frames
        for i in range(3):
            axes[row, i].imshow(context[0, -(3-i)].cpu(), cmap='gray', vmin=0, vmax=1)
            if row == 0:
                axes[row, i].set_title(f'Ctx t-{2-i}')
            axes[row, i].axis('off')

        axes[row, 3].imshow(pred.clamp(0, 1), cmap='gray', vmin=0, vmax=1)
        if row == 0:
            axes[row, 3].set_title('Predicted')
        axes[row, 3].axis('off')

        axes[row, 4].imshow(target, cmap='gray', vmin=0, vmax=1)
        if row == 0:
            axes[row, 4].set_title('Ground Truth')
        axes[row, 4].axis('off')

        # Error is computed on the raw (unclamped) prediction
        error = torch.abs(pred - target)
        axes[row, 5].imshow(error, cmap='hot', vmin=0, vmax=0.5)
        if row == 0:
            axes[row, 5].set_title('Error')
        axes[row, 5].axis('off')

    plt.tight_layout()
    plt.show()

visualize_predictions(model, val_data, device, num_samples=5)

### Autoregressive Rollouts

In [None]:
@torch.no_grad()
def visualize_rollout(model, trajectories, device, traj_idx=0, context_len=5, rollout_steps=15):
    """Compare autoregressive rollout with ground truth.

    Feeds the first `context_len` frames of one trajectory to
    ``model.rollout`` and plots up to eight evenly spaced predicted frames
    against the ground truth and their absolute error, followed by a
    per-step MSE curve. Predictions are clipped to [0, 1] before comparison.

    If the trajectory has fewer than `rollout_steps` frames after the
    context, the rollout is shortened to the available ground truth instead
    of failing with an index/shape error.

    Args:
        model: Model exposing ``rollout(context, steps)``.
        trajectories: Tensor of trajectories indexed as ``[trajectory, frame]``.
        device: Device the context is moved to.
        traj_idx: Which trajectory to use.
        context_len: Number of ground-truth frames given as context.
        rollout_steps: Requested number of predicted frames.

    Raises:
        ValueError: If `rollout_steps` < 1 or no frame follows the context.
    """
    model.eval()

    if rollout_steps < 1:
        raise ValueError(f"rollout_steps must be >= 1, got {rollout_steps}")
    available = trajectories.size(1) - context_len
    if available < 1:
        raise ValueError(
            f"context_len={context_len} leaves no ground-truth frames in a "
            f"trajectory of {trajectories.size(1)} frames"
        )
    steps = min(rollout_steps, available)
    if steps < rollout_steps:
        print(f"  Only {available} ground-truth frames follow a context of {context_len}; "
              f"showing {steps} rollout steps instead of {rollout_steps}.")

    context = trajectories[traj_idx, :context_len].unsqueeze(0).to(device)
    gt_future = trajectories[traj_idx, context_len:context_len+steps].cpu().numpy()
    # Clip once; every use below compares the clipped prediction.
    predicted = np.clip(model.rollout(context, steps).cpu().squeeze(0).numpy(), 0, 1)

    num_show = min(8, steps)
    step_indices = np.linspace(0, steps-1, num_show, dtype=int)

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(3, num_show, figsize=(2*num_show, 6), squeeze=False)

    gray = dict(cmap='gray', vmin=0, vmax=1)
    hot = dict(cmap='hot', vmin=0, vmax=0.5)

    for col, t in enumerate(step_indices):
        panels = (
            ('Predicted', predicted[t], gray),
            ('Ground Truth', gt_future[t], gray),
            ('Error', np.abs(predicted[t] - gt_future[t]), hot),
        )
        for row, (row_label, image, style) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(image, **style)
            if col == 0:
                # axis('off') would also hide the row label, so only the
                # ticks and frame are removed in the labelled column.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                ax.set_ylabel(row_label)
            else:
                ax.axis('off')
        axes[0, col].set_title(f't+{t+1}')

    plt.suptitle(f'Autoregressive Rollout ({steps} steps, context={context_len})')
    plt.tight_layout()
    plt.show()

    # MSE over time (one value per rollout step)
    mse_over_time = np.mean((predicted - gt_future)**2, axis=(1, 2))

    plt.figure(figsize=(8, 4))
    plt.plot(range(1, steps+1), mse_over_time)
    plt.xlabel('Rollout Step')
    plt.ylabel('MSE')
    plt.title('Prediction Error vs Rollout Length')
    plt.grid(alpha=0.3)
    plt.show()

# Test rollout with different context lengths
for ctx_len in [3, 8, 15]:
    print(f"\nContext length: {ctx_len}")
    visualize_rollout(model, val_data, device, traj_idx=0, context_len=ctx_len, rollout_steps=15)

### Animated Comparison

In [None]:
@torch.no_grad()
def animate_rollout_comparison(model, trajectories, device, traj_idx=0, context_len=5, rollout_steps=20):
    """Create side-by-side animation of prediction vs ground truth.

    Rolls the model out from the first `context_len` frames of one
    trajectory and animates the clipped prediction, the ground truth and
    their absolute error frame by frame. The rollout is shortened when fewer
    than `rollout_steps` ground-truth frames follow the context.

    Args:
        model: Model exposing ``rollout(context, steps)``.
        trajectories: Tensor of trajectories indexed as ``[trajectory, frame]``.
        device: Device the context is moved to.
        traj_idx: Which trajectory to use.
        context_len: Number of ground-truth frames given as context.
        rollout_steps: Requested number of animated steps.

    Returns:
        An ``IPython.display.HTML`` object embedding the animation.

    Raises:
        ValueError: If `rollout_steps` < 1 or no frame follows the context.
    """
    model.eval()

    if rollout_steps < 1:
        raise ValueError(f"rollout_steps must be >= 1, got {rollout_steps}")
    available = trajectories.size(1) - context_len
    if available < 1:
        raise ValueError(
            f"context_len={context_len} leaves no ground-truth frames in a "
            f"trajectory of {trajectories.size(1)} frames"
        )
    # Never animate past the available ground truth; indexing gt_future
    # beyond it would raise an IndexError inside the animation callback.
    steps = min(rollout_steps, available)

    context = trajectories[traj_idx, :context_len].unsqueeze(0).to(device)
    gt_future = trajectories[traj_idx, context_len:context_len+steps].cpu().numpy()
    # Clip once; every frame shown compares the clipped prediction.
    predicted = np.clip(model.rollout(context, steps).cpu().squeeze(0).numpy(), 0, 1)

    fig, axes = plt.subplots(1, 3, figsize=(10, 3.5))

    im_pred = axes[0].imshow(predicted[0], cmap='gray', vmin=0, vmax=1)
    axes[0].set_title('Predicted')
    axes[0].axis('off')

    im_gt = axes[1].imshow(gt_future[0], cmap='gray', vmin=0, vmax=1)
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    im_err = axes[2].imshow(np.abs(predicted[0] - gt_future[0]), cmap='hot', vmin=0, vmax=0.5)
    axes[2].set_title('Error')
    axes[2].axis('off')

    title = fig.suptitle('t+1')

    def update(frame):
        # Swap the image data in place instead of re-creating the artists.
        im_pred.set_array(predicted[frame])
        im_gt.set_array(gt_future[frame])
        im_err.set_array(np.abs(predicted[frame] - gt_future[frame]))
        title.set_text(f't+{frame+1}')
        return [im_pred, im_gt, im_err, title]

    anim = animation.FuncAnimation(fig, update, frames=steps, interval=150, blit=False)
    # Close this specific figure so the static first frame is not displayed
    # alongside the HTML animation.
    plt.close(fig)
    return HTML(anim.to_jshtml())

animate_rollout_comparison(model, val_data, device, traj_idx=0, context_len=5, rollout_steps=18)

## 7. Analysis: Does the Model Learn Physics?

### Prediction Error vs Context Length

In [None]:
@torch.no_grad()
def analyze_internal_representations(model, trajectories, device, num_traj=5):
    """Plot sample trajectories and single-step error versus context length.

    Despite its name, this function does not inspect attention maps or hidden
    activations. It produces two figures:

    1. Snapshots (t = 0, 5, 10, 15) of the first `num_traj` trajectories.
    2. Mean +/- std of the single-step MSE over the first 50 trajectories for
       context lengths 3, 5, 10, 15 and 20. Context lengths that do not fit
       into the trajectories are left out of the plot.

    Args:
        model: Predictor called as ``model(context)`` on a batch of one.
        trajectories: Tensor of trajectories indexed as ``[trajectory, frame]``.
        device: Device the context is moved to before the forward pass.
        num_traj: Number of trajectories shown in the snapshot figure.
    """
    model.eval()

    context_lengths = [3, 5, 10, 15, 20]
    snapshot_times = [0, 5, 10, 15]

    # --- Figure 1: trajectory snapshots ---
    num_traj = min(num_traj, len(trajectories))
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_traj, 4, figsize=(12, 3*num_traj), squeeze=False)

    for i in range(num_traj):
        traj = trajectories[i]  # (T, H, W)
        for j, t in enumerate(snapshot_times):
            if t < len(traj):
                axes[i, j].imshow(traj[t].cpu(), cmap='gray', vmin=0, vmax=1)
                axes[i, j].set_title(f't={t}')
            axes[i, j].axis('off')

    plt.suptitle('Sample Trajectories')
    plt.tight_layout()
    plt.show()

    # --- Figure 2: MSE vs context length (aggregate) ---
    all_mses = {ctx: [] for ctx in context_lengths}

    for i in range(min(50, len(trajectories))):
        traj = trajectories[i]
        for ctx_len in context_lengths:
            # The target is traj[ctx_len], so ctx_len must be a valid index.
            if ctx_len >= len(traj):
                continue
            context = traj[:ctx_len].unsqueeze(0).to(device)
            target = traj[ctx_len]
            pred = model(context).cpu().squeeze(0)
            mse = ((pred - target) ** 2).mean().item()
            all_mses[ctx_len].append(mse)

    # Only plot context lengths that were actually evaluated; an empty bucket
    # drawn as "MSE = 0" would look like a perfect prediction.
    ctx_lens = [c for c in sorted(all_mses) if all_mses[c]]
    means = [np.mean(all_mses[c]) for c in ctx_lens]
    stds = [np.std(all_mses[c]) for c in ctx_lens]

    plt.figure(figsize=(8, 4))
    plt.errorbar(ctx_lens, means, yerr=stds, marker='o', capsize=5)
    plt.xlabel('Context Length')
    plt.ylabel('MSE')
    plt.title('Prediction Error vs Context Length')
    plt.grid(alpha=0.3)
    plt.show()

analyze_internal_representations(model, val_data, device)

## 8. Save Model

In [None]:
# Bundle weights, optimizer state, training history and the hyper-parameters
# needed to rebuild the model into a single checkpoint file.
checkpoint = dict(
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    history=history,
    config=dict(
        frame_size=RESOLUTION,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        dropout=DROPOUT,
        min_context=MIN_CONTEXT,
        max_context=MAX_CONTEXT,
    ),
)
torch.save(checkpoint, 'results/model_checkpoint.pt')
print("Model saved to results/model_checkpoint.pt")

## Download Checkpoints (VS Code users)

Run this cell to zip everything in `results/` into a single timestamped archive. Then download the archive to your local machine via the VS Code file browser.

In [None]:
# Zip all checkpoints for easy download
import shutil
import os
from datetime import datetime

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
zip_name = f'checkpoints_{timestamp}'

# Create zip of results folder
shutil.make_archive(zip_name, 'zip', 'results')
print(f"Created {zip_name}.zip")
print(f"Size: {os.path.getsize(f'{zip_name}.zip') / 1024 / 1024:.1f} MB")
print("\nDownload via VS Code file browser (left panel) or run:")
print(f"  !cp {zip_name}.zip /content/")

# List what's in results
print("\nCheckpoints in results/:")
!ls -lh results/*.pt 2>/dev/null || echo "No .pt files yet"

## Summary

Key metrics to track:
- **Single-step MSE**: How well does model predict 1 frame ahead?
- **Rollout degradation**: How fast does error grow with longer rollouts?
- **Visual quality**: Do predictions look like valid physics?
- **Latent structure**: Does latent space encode position/velocity?