# Physics Video Prediction - Training

Train a transformer-based model to predict physics simulation frames.

**Architecture**: CNN Encoder → Transformer → CNN Decoder

**Goal**: The model learns to compress video into a latent state that captures the underlying physics (position, velocity, etc.).

In [None]:
# Clone repo (for Colab)
# Fetches the project and switches into the experiment directory, which is
# presumably where the local `physics_sim` and `model` modules imported by the
# next cell live — TODO confirm.
# NOTE(review): re-running this cell in the same session looks like it will
# fail (clone target already exists, relative %cd no longer resolves) — confirm.
!git clone https://github.com/Caleb-Briggs/MNIST_AI.git
%cd MNIST_AI/experiments/physics_prediction

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from IPython.display import HTML
import matplotlib.animation as animation

from physics_sim import (
    Ball, Barrier, PhysicsSimulation,
    generate_trajectory, create_random_simulation, generate_dataset
)
from model import VideoPredictor, count_parameters

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    # Report the name of GPU 0 so the run log records the hardware used.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Configuration

In [None]:
# Data config
NUM_TRAJECTORIES = 2000  # More trajectories since each terrain = one context window
FRAMES_PER_TRAJECTORY = 50  # Shorter trajectories (context + rollout is enough)
NUM_BARRIERS = 3
WITH_GRAVITY = False  # Start simple

# Model config
LATENT_DIM = 256
N_HEADS = 4
N_LAYERS = 4
CONTEXT_LEN = 8  # Fixed context window (not shifted during rollout)
ROLLOUT_STEPS = 10  # Predict this many frames from fixed context

# Training config
BATCH_SIZE = 128  # Larger batch for A100
LEARNING_RATE = 3e-4  # Slightly higher LR for larger batch
NUM_EPOCHS = 50
SEED = 42

# Derived
MAX_ROLLOUT = FRAMES_PER_TRAJECTORY - CONTEXT_LEN  # Max frames we can predict

# Fail fast on inconsistent settings. Training slices the targets to
# ROLLOUT_STEPS frames, so asking for more frames than a trajectory holds
# after the context would only surface later as a shape mismatch in the loss.
if not 0 < CONTEXT_LEN < FRAMES_PER_TRAJECTORY:
    raise ValueError(
        f"CONTEXT_LEN ({CONTEXT_LEN}) must be in [1, FRAMES_PER_TRAJECTORY - 1] "
        f"(FRAMES_PER_TRAJECTORY={FRAMES_PER_TRAJECTORY})"
    )
if not 0 < ROLLOUT_STEPS <= MAX_ROLLOUT:
    raise ValueError(
        f"ROLLOUT_STEPS ({ROLLOUT_STEPS}) must be in [1, MAX_ROLLOUT] "
        f"(MAX_ROLLOUT={MAX_ROLLOUT})"
    )

# Seed both RNGs used in this notebook so runs are repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Each trajectory: {CONTEXT_LEN} context frames -> up to {MAX_ROLLOUT} rollout frames")

## 2. Generate Training Data

In [None]:
print("Generating training data...")
# Simulate every trajectory up front; base_seed makes the dataset repeatable.
data = generate_dataset(
    base_seed=SEED,
    with_gravity=WITH_GRAVITY,
    num_barriers=NUM_BARRIERS,
    num_frames=FRAMES_PER_TRAJECTORY,
    num_trajectories=NUM_TRAJECTORIES,
)

print(f"Dataset shape: {data.shape}")
print(f"Memory: {data.nbytes / 1024 / 1024:.1f} MB")

# Wrap the NumPy array as a float32 tensor and insert a singleton channel
# axis at position 2, giving (N, T, 1, H, W).
data_tensor = torch.from_numpy(data).to(torch.float32)[:, :, None]
print(f"Tensor shape: {data_tensor.shape}")

In [None]:
# Visualize a few trajectories
# Spread the 8 displayed frames evenly over the actual trajectory length.
# A fixed stride of 10 would ask for frames up to t=70 and raise IndexError
# for any trajectory shorter than 71 frames.
frame_indices = np.linspace(0, data.shape[1] - 1, 8, dtype=int)

fig, axes = plt.subplots(3, 8, figsize=(16, 6))
for row in range(3):
    for col, t in enumerate(frame_indices):
        axes[row, col].imshow(data[row, t], cmap='gray', vmin=0, vmax=1)
        if row == 0:
            # Label the columns once, on the top row only.
            axes[row, col].set_title(f't={t}')
        axes[row, col].axis('off')
plt.suptitle('Sample Trajectories')
plt.tight_layout()
plt.show()

## 3. Create Dataset and DataLoader

In [None]:
class PhysicsDataset(torch.utils.data.Dataset):
    """Per-trajectory dataset yielding a fixed context window and its future.

    One item corresponds to one trajectory: the leading ``context_len`` frames
    form the context and every later frame is a prediction target. The window
    always starts at frame 0; it is never slid along the trajectory.
    """

    def __init__(self, trajectories: torch.Tensor, context_len: int):
        """Store the trajectories and remember where to split them.

        Args:
            trajectories: tensor laid out as (num_traj, num_frames, 1, H, W)
            context_len: how many leading frames make up the context
        """
        self.context_len = context_len
        self.trajectories = trajectories
        self.num_traj = trajectories.shape[0]
        self.num_frames = trajectories.shape[1]

    def __len__(self):
        # One sample per trajectory.
        return self.num_traj

    def __getitem__(self, idx):
        # Split the frame axis at context_len: everything before it is the
        # context, everything from it onwards is the target sequence.
        split = self.context_len
        past = self.trajectories[idx, :split]      # (context_len, 1, H, W)
        future = self.trajectories[idx, split:]    # (num_frames - context_len, 1, H, W)
        return past, future

# Split into train/val
# The first 90% of the trajectories are used for training; the rest are held out.
train_size = int(0.9 * NUM_TRAJECTORIES)
train_data, val_data = data_tensor[:train_size], data_tensor[train_size:]

train_dataset = PhysicsDataset(train_data, CONTEXT_LEN)
val_dataset = PhysicsDataset(val_data, CONTEXT_LEN)

# Both loaders share every setting except shuffling, which is training-only.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
)

print(f"Train trajectories: {len(train_dataset)}")
print(f"Val trajectories: {len(val_dataset)}")
print(f"Each sample: {CONTEXT_LEN} fixed context frames -> {MAX_ROLLOUT} target frames")

## 4. Create Model

In [None]:
# Build the predictor; the feed-forward width is tied to twice the latent size.
model = VideoPredictor(
    latent_dim=LATENT_DIM,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    dim_feedforward=2 * LATENT_DIM,
    dropout=0.1,
)
model = model.to(device)

print(f"Model parameters: {count_parameters(model):,}")

# Test forward pass
# Push one real training batch through the model without tracking gradients,
# purely to confirm the input and output shapes before training starts.
sample_context, sample_target = next(iter(train_loader))
sample_context = sample_context.to(device)
with torch.no_grad():
    sample_pred = model(sample_context)
print(f"Input shape: {sample_context.shape}")
print(f"Output shape: {sample_pred.shape}")

## 5. Training Loop

In [None]:
# AdamW with light weight decay; the learning rate follows a cosine schedule
# whose period (T_max) is the full training run.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    weight_decay=1e-5,
    lr=LEARNING_RATE,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS)

# Loss function: MSE on pixels
criterion = nn.MSELoss()

# Training history: one list per tracked quantity, appended to once per epoch.
history = {key: [] for key in ('train_loss', 'val_loss', 'lr')}

In [None]:
def train_epoch(model, loader, optimizer, criterion, device, rollout_steps):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    The model predicts ``rollout_steps`` future frames from each fixed context
    in a single call, and is supervised on the first ``rollout_steps`` target
    frames of every sample.

    Args:
        model: callable as ``model(context, n_future=...)``; put in train mode.
        loader: yields ``(context, targets)`` batches and exposes ``.dataset``.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to ``(predictions, targets)``.
        device: device the batches are moved to.
        rollout_steps: number of future frames to predict and supervise.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0

    for batch_context, batch_targets in loader:
        batch_context = batch_context.to(device)
        # Only the first rollout_steps future frames take part in the loss.
        supervised = batch_targets.to(device)[:, :rollout_steps]

        optimizer.zero_grad()

        predictions = model(batch_context, n_future=rollout_steps)
        batch_loss = criterion(predictions, supervised)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the average.
        weighted_loss_sum += batch_loss.item() * batch_context.size(0)

    return weighted_loss_sum / len(loader.dataset)

@torch.no_grad()
def eval_epoch(model, loader, criterion, device, rollout_steps):
    """Return the mean loss over ``loader`` without updating the model.

    Mirrors the training pass: the model predicts ``rollout_steps`` frames per
    context in one call and is scored against the first ``rollout_steps``
    target frames. The model is switched to eval mode and gradients are off.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.eval()
    weighted_loss_sum = 0

    for batch_context, batch_targets in loader:
        batch_context = batch_context.to(device)
        reference = batch_targets.to(device)[:, :rollout_steps]

        batch_loss = criterion(model(batch_context, n_future=rollout_steps), reference)
        # Weight by batch size so the result is a per-sample average.
        weighted_loss_sum += batch_loss.item() * batch_context.size(0)

    return weighted_loss_sum / len(loader.dataset)

@torch.no_grad()
def quick_visualize(model, val_data, device, epoch, context_len=8):
    """Quick visualization during training - shows predictions and rollout.

    Top row: for three validation trajectories, the last context frame next to
    an overlay of the single-step prediction (blue) on the ground truth (gray).
    Bottom row: a 6-step rollout on trajectory 0 overlaid on the ground truth.

    Args:
        model: exposes ``predict_next(context)`` and ``rollout(context, steps)``.
        val_data: trajectories indexed as [trajectory, frame, channel, H, W].
        device: device the context is moved to before calling the model.
        epoch: epoch number shown in the figure title.
        context_len: number of leading frames used as context.

    Raises:
        ValueError: if ``val_data`` contains no trajectories.
    """
    model.eval()

    num_val = len(val_data)
    if num_val == 0:
        raise ValueError("val_data must contain at least one trajectory")

    fig, axes = plt.subplots(2, 6, figsize=(15, 5))

    # Row 0: Single-step predictions (3 examples from different trajectories)
    for col in range(3):
        # Wrap around so small validation sets do not raise IndexError.
        traj_idx = (col * 10) % num_val
        context = val_data[traj_idx, :context_len].unsqueeze(0).to(device)
        target = val_data[traj_idx, context_len]  # (1, 64, 64)
        pred = model.predict_next(context).cpu().squeeze(0)  # (1, 64, 64)

        # Show last context frame
        axes[0, col*2].imshow(context[0, -1, 0].cpu(), cmap='gray', vmin=0, vmax=1)
        axes[0, col*2].set_title(f'Context (traj {traj_idx})')
        axes[0, col*2].axis('off')

        # Show prediction vs target overlay
        axes[0, col*2+1].imshow(target[0].numpy(), cmap='gray', vmin=0, vmax=1, alpha=0.5)
        axes[0, col*2+1].imshow(pred[0].clamp(0, 1).numpy(), cmap='Blues', vmin=0, vmax=1, alpha=0.5)
        axes[0, col*2+1].set_title('Pred vs GT')
        axes[0, col*2+1].axis('off')

    # Row 1: Longer rollout from first trajectory
    context = val_data[0, :context_len].unsqueeze(0).to(device)
    rollout = model.rollout(context, 6).cpu().squeeze(0)  # (6, 1, 64, 64)
    gt = val_data[0, context_len:context_len+6]  # up to (6, 1, 64, 64)

    for col in range(6):
        # Short trajectories may have fewer than 6 future frames; only overlay
        # ground truth where it exists.
        if col < gt.size(0):
            axes[1, col].imshow(gt[col, 0].numpy(), cmap='gray', vmin=0, vmax=1, alpha=0.5)
        axes[1, col].imshow(rollout[col, 0].clamp(0, 1).numpy(), cmap='Blues', vmin=0, vmax=1, alpha=0.5)
        axes[1, col].set_title(f't+{col+1}')
        axes[1, col].axis('off')

    plt.suptitle(f'Epoch {epoch} - Blue=Predicted, Gray=Ground Truth')
    plt.tight_layout()
    plt.show()

In [None]:
# Visualization frequency
VIS_EVERY = 10  # Show visualizations every N epochs

print(f"Training for {NUM_EPOCHS} epochs...")
print(f"Multi-step rollout training: {ROLLOUT_STEPS} steps")
print(f"Visualizations every {VIS_EVERY} epochs")
print("="*60)

for epoch in range(NUM_EPOCHS):
    # One full pass over the training set, then score the held-out set.
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device, ROLLOUT_STEPS)
    val_loss = eval_epoch(model, val_loader, criterion, device, ROLLOUT_STEPS)

    # Record the learning rate that was in effect for this epoch, i.e. before
    # the scheduler advances it.
    current_lr = optimizer.param_groups[0]['lr']
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['lr'].append(current_lr)

    scheduler.step()

    # The first epoch is always reported and plotted; afterwards only every
    # 5th epoch is printed and every VIS_EVERY-th epoch is visualized.
    first_epoch = epoch == 0

    if first_epoch or (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d}: train_loss={train_loss:.6f}, val_loss={val_loss:.6f}")

    if first_epoch or (epoch + 1) % VIS_EVERY == 0:
        quick_visualize(model, val_data, device, epoch + 1, CONTEXT_LEN)

print("="*60)
print("Training complete!")

## 6. Visualizations

### Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, lr_ax = axes

# Left panel: train and validation loss per epoch.
for series, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    loss_ax.plot(history[series], label=label)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('MSE Loss')
loss_ax.set_title('Training Curves')
loss_ax.legend()
loss_ax.grid(alpha=0.3)

# Right panel: the learning rate recorded at each epoch.
lr_ax.plot(history['lr'])
lr_ax.set_xlabel('Epoch')
lr_ax.set_ylabel('Learning Rate')
lr_ax.set_title('Learning Rate Schedule')
lr_ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

### Multi-Step Predictions vs Ground Truth

In [None]:
@torch.no_grad()
def visualize_predictions(model, dataset, device, num_samples=5, num_future=3):
    """Show context, predictions at different horizons, and ground truth.

    Each row is one randomly chosen sample from ``dataset``: first its
    CONTEXT_LEN context frames, then for each of the ``num_future`` horizons
    the predicted frame followed by the matching ground-truth frame.

    Args:
        model: callable as ``model(context, n_future=...)``.
        dataset: indexable, returning ``(context, targets)`` per sample.
        device: device the context is moved to before calling the model.
        num_samples: number of rows; capped at ``len(dataset)``.
        num_future: number of future frames to predict and display.

    Raises:
        ValueError: if there is nothing to draw (empty dataset or
            ``num_samples < 1``).
    """
    model.eval()

    # Sampling is without replacement, so more rows than samples is impossible.
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError("need at least one sample to visualize (check num_samples and len(dataset))")

    # Columns: context frames + predicted frames + ground truth frames
    num_cols = CONTEXT_LEN + num_future * 2
    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # axes[row, col] indexing below works for num_samples == 1 as well.
    fig, axes = plt.subplots(
        num_samples, num_cols, figsize=(2*num_cols, 2.5*num_samples), squeeze=False
    )

    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for row, idx in enumerate(indices):
        context, targets = dataset[idx]
        context = context.unsqueeze(0).to(device)  # Add batch dim
        preds = model(context, n_future=num_future).cpu().squeeze(0)  # (num_future, 1, 64, 64)
        targets = targets[:num_future]  # (num_future, 1, 64, 64)
        context = context.cpu().squeeze(0)

        col = 0
        # Show context frames
        for i in range(CONTEXT_LEN):
            axes[row, col].imshow(context[i, 0], cmap='gray', vmin=0, vmax=1)
            if row == 0:
                axes[row, col].set_title(f'Ctx {i+1}')
            axes[row, col].axis('off')
            col += 1

        # Show predicted and ground truth frames side by side
        for t in range(num_future):
            # Predicted (clamped to the displayable [0, 1] range)
            axes[row, col].imshow(preds[t, 0].clamp(0, 1), cmap='gray', vmin=0, vmax=1)
            if row == 0:
                axes[row, col].set_title(f'Pred t+{t+1}')
            axes[row, col].axis('off')
            col += 1

            # Ground truth
            axes[row, col].imshow(targets[t, 0], cmap='gray', vmin=0, vmax=1)
            if row == 0:
                axes[row, col].set_title(f'GT t+{t+1}')
            axes[row, col].axis('off')
            col += 1

    plt.suptitle('Context → Predictions vs Ground Truth')
    plt.tight_layout()
    plt.show()

visualize_predictions(model, val_dataset, device, num_samples=5, num_future=5)

### Autoregressive Rollouts

In [None]:
@torch.no_grad()
def visualize_rollout(model, trajectories, device, traj_idx=0, rollout_steps=30):
    """Compare autoregressive rollout with ground truth.

    Draws up to 10 evenly spaced rollout steps as three rows (prediction,
    ground truth, absolute error) and then plots per-step MSE.

    Args:
        model: exposes ``rollout(context, steps)``.
        trajectories: tensor indexed as [trajectory, frame, channel, H, W].
        device: device the context is moved to before calling the model.
        traj_idx: which trajectory to roll out.
        rollout_steps: requested number of steps; capped at the number of
            ground-truth frames available after the context.

    Raises:
        ValueError: if no ground-truth frames remain after the context.
    """
    model.eval()

    # Get initial context
    context = trajectories[traj_idx, :CONTEXT_LEN].unsqueeze(0).to(device)

    # Ground truth future
    gt_future = trajectories[traj_idx, CONTEXT_LEN:CONTEXT_LEN+rollout_steps].cpu().numpy()

    # Slicing silently truncates at the end of the trajectory, so never roll
    # out further than there is ground truth to compare against.
    rollout_steps = min(rollout_steps, len(gt_future))
    if rollout_steps < 1:
        raise ValueError("no ground-truth frames available after the context window")

    # Autoregressive rollout
    # Drop only the batch axis. A bare squeeze() would also remove the channel
    # axis, after which predicted[t, 0] is a pixel row instead of an image.
    predicted = model.rollout(context, rollout_steps).cpu().squeeze(0).numpy()

    # Visualize
    num_show = min(10, rollout_steps)
    step_indices = np.linspace(0, rollout_steps-1, num_show, dtype=int)

    # squeeze=False keeps axes 2-D even when only one column is shown.
    fig, axes = plt.subplots(3, num_show, figsize=(2*num_show, 6), squeeze=False)

    for col, t in enumerate(step_indices):
        # Predicted
        axes[0, col].imshow(predicted[t, 0], cmap='gray', vmin=0, vmax=1)
        if col == 0:
            axes[0, col].set_ylabel('Predicted')
        axes[0, col].set_title(f't+{t+1}')
        axes[0, col].axis('off')

        # Ground truth
        axes[1, col].imshow(gt_future[t, 0], cmap='gray', vmin=0, vmax=1)
        if col == 0:
            axes[1, col].set_ylabel('Ground Truth')
        axes[1, col].axis('off')

        # Error
        error = np.abs(predicted[t, 0] - gt_future[t, 0])
        axes[2, col].imshow(error, cmap='hot', vmin=0, vmax=0.5)
        if col == 0:
            axes[2, col].set_ylabel('Error')
        axes[2, col].axis('off')

    plt.suptitle(f'Autoregressive Rollout ({rollout_steps} steps)')
    plt.tight_layout()
    plt.show()

    # Compute MSE over time (mean over the two spatial axes of each frame)
    mse_over_time = np.mean((predicted[:, 0] - gt_future[:, 0])**2, axis=(1, 2))

    plt.figure(figsize=(8, 4))
    plt.plot(range(1, rollout_steps+1), mse_over_time)
    plt.xlabel('Rollout Step')
    plt.ylabel('MSE')
    plt.title('Prediction Error vs Rollout Length')
    plt.grid(alpha=0.3)
    plt.show()

# Test on a few different trajectories
for traj_idx in [0, 10, 20]:
    print(f"\nTrajectory {traj_idx}:")
    visualize_rollout(model, val_data, device, traj_idx=traj_idx % len(val_data), rollout_steps=30)

### Animated Comparison

In [None]:
@torch.no_grad()
def animate_rollout_comparison(model, trajectories, device, traj_idx=0, rollout_steps=50):
    """Create side-by-side animation of prediction vs ground truth.

    Args:
        model: exposes ``rollout(context, steps)``.
        trajectories: tensor indexed as [trajectory, frame, channel, H, W].
        device: device the context is moved to before calling the model.
        traj_idx: which trajectory to animate.
        rollout_steps: requested number of frames; capped at the number of
            ground-truth frames available after the context.

    Returns:
        An ``IPython.display.HTML`` object wrapping the JS animation.

    Raises:
        ValueError: if no ground-truth frames remain after the context.
    """
    model.eval()

    # Get data
    context = trajectories[traj_idx, :CONTEXT_LEN].unsqueeze(0).to(device)
    gt_future = trajectories[traj_idx, CONTEXT_LEN:CONTEXT_LEN+rollout_steps].cpu().numpy()

    # The slice above stops at the end of the trajectory. Animate only frames
    # that have ground truth, otherwise update() indexes past gt_future.
    rollout_steps = min(rollout_steps, len(gt_future))
    if rollout_steps < 1:
        raise ValueError("no ground-truth frames available after the context window")

    # Drop only the batch axis. A bare squeeze() would also remove the channel
    # axis, after which predicted[frame, 0] is a pixel row instead of an image.
    predicted = model.rollout(context, rollout_steps).cpu().squeeze(0).numpy()

    # Create animation
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    im_pred = axes[0].imshow(predicted[0, 0], cmap='gray', vmin=0, vmax=1)
    axes[0].set_title('Predicted')
    axes[0].axis('off')

    im_gt = axes[1].imshow(gt_future[0, 0], cmap='gray', vmin=0, vmax=1)
    axes[1].set_title('Ground Truth')
    axes[1].axis('off')

    error = np.abs(predicted[0, 0] - gt_future[0, 0])
    im_err = axes[2].imshow(error, cmap='hot', vmin=0, vmax=0.5)
    axes[2].set_title('Error')
    axes[2].axis('off')

    title = fig.suptitle('t=0')

    def update(frame):
        # Swap the image data in place instead of redrawing the axes.
        im_pred.set_array(predicted[frame, 0])
        im_gt.set_array(gt_future[frame, 0])
        error = np.abs(predicted[frame, 0] - gt_future[frame, 0])
        im_err.set_array(error)
        title.set_text(f't+{frame+1}')
        return [im_pred, im_gt, im_err, title]

    anim = animation.FuncAnimation(fig, update, frames=rollout_steps, interval=100, blit=False)
    # Close the figure so the notebook shows only the animation, not a static copy.
    plt.close()
    return HTML(anim.to_jshtml())

animate_rollout_comparison(model, val_data, device, traj_idx=0, rollout_steps=50)

## 7. Analysis: Does the Model Learn Physics?

### Latent Space Visualization

In [None]:
@torch.no_grad()
def analyze_latent_space(model, trajectories, device, num_traj=5):
    """Plot how the first two latent coordinates evolve along trajectories.

    The first ``num_traj`` trajectories are encoded frame by frame with
    ``model.encode_frames``. Three panels are drawn: latent dim 0 over time,
    latent dim 1 over time, and the path traced in the (dim 0, dim 1) plane
    with the start marked green and the end marked red.
    """
    model.eval()

    # Encode each trajectory as a batch of one and keep its (T, latent_dim) array.
    all_latents = [
        model.encode_frames(trajectories[i].unsqueeze(0).to(device)).cpu().numpy()[0]
        for i in range(num_traj)
    ]

    fig = plt.figure(figsize=(12, 4))

    # Panels 1 and 2: a single latent coordinate against the frame index.
    for dim in (0, 1):
        ax = fig.add_subplot(1, 3, dim + 1)
        for i, latents in enumerate(all_latents):
            ax.plot(latents[:, dim], label=f'Traj {i}')
        ax.set_xlabel('Frame')
        ax.set_ylabel(f'Latent dim {dim}')
        ax.set_title(f'Latent Dimension {dim} Over Time')
        if dim == 0:
            # Only the first panel carries the legend.
            ax.legend()
        ax.grid(alpha=0.3)

    # Panel 3: 2D projection of each latent trajectory.
    plane = fig.add_subplot(1, 3, 3)
    for i, latents in enumerate(all_latents):
        xs, ys = latents[:, 0], latents[:, 1]
        plane.plot(xs, ys, 'o-', markersize=2, alpha=0.7, label=f'Traj {i}')
        plane.plot(xs[0], ys[0], 'go', markersize=8)  # Start
        plane.plot(xs[-1], ys[-1], 'ro', markersize=8)  # End
    plane.set_xlabel('Latent dim 0')
    plane.set_ylabel('Latent dim 1')
    plane.set_title('Latent Space Trajectory')
    plane.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

analyze_latent_space(model, val_data, device)

## 8. Save Model

In [None]:
# Save model checkpoint
import os

# Bundle the weights, optimizer state, training history and the
# hyper-parameters passed to the model so the run can be inspected later.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'history': history,
    'config': {
        'latent_dim': LATENT_DIM,
        'n_heads': N_HEADS,
        'n_layers': N_LAYERS,
        'context_len': CONTEXT_LEN,
    }
}
# torch.save does not create missing parent directories, so make sure the
# output folder exists instead of failing after a full training run.
os.makedirs('results', exist_ok=True)
torch.save(checkpoint, 'results/model_checkpoint.pt')
print("Model saved to results/model_checkpoint.pt")

## Summary

Key metrics to track:
- **Single-step MSE**: How well does the model predict one frame ahead?
- **Rollout degradation**: How fast does error grow with longer rollouts?
- **Visual quality**: Do predictions look like valid physics?
- **Latent structure**: Does latent space encode position/velocity?