# Baseline Models + HiMAC-JEPA Training

This notebook provides an interactive interface for training all baseline models AND HiMAC-JEPA.

**Models:**
1. Camera-Only (ResNet18 + LSTM)
2. LiDAR-Only (PointNet++)
3. Radar-Only (3D CNN)
4. I-JEPA (ViT + JEPA, camera-only)
5. V-JEPA (Multi-modal JEPA)
6. **HiMAC-JEPA** (Multi-modal + Hierarchical Actions + JEPA)

**Usage:**
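- Run the cells in order; later cells depend on names defined in earlier ones.
- Adjust the hyperparameters in the Configuration cell as needed.
- Monitor training progress via the per-epoch progress bars.
- Best-model checkpoints are saved under `checkpoints/`; the training curves and summary table are saved under `results/`.
- Note: as written, training uses randomly generated dummy data (`create_dummy_dataloader`); replace it with a real dataset for meaningful results.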

## Setup

In [None]:
import sys
import os
from pathlib import Path
import yaml
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Add project root to path.
# The notebook is expected to be launched either from the repository root or
# from its `notebooks/` directory. Compare the *last* path component instead
# of substring-matching the whole path: a substring test also fires when any
# ancestor directory merely contains "notebooks" in its name (for example
# ~/notebooks_2024/repo), which would wrongly select the repo's parent.
project_root = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()
# Put the root first so project-local `src.*` imports win over other entries.
sys.path.insert(0, str(project_root))

# Environment report, useful when comparing runs across machines.
print(f"Project root: {project_root}")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Import baseline models
# NOTE: these are project-local packages; the imports resolve only when the
# project root has been inserted into sys.path (done in the setup cell).
from src.models.baselines import (
    CameraOnlyBaseline,
    LiDAROnlyBaseline,
    RadarOnlyBaseline,
    IJEPABaseline,
    VJEPABaseline
)

# Import HiMAC-JEPA
from src.models.himac_jepa import HiMACJEPA

# Reaching this line means every import above succeeded.
print("✓ All models imported successfully")

## Configuration

In [None]:
# Models to train, in run order; remove an entry to skip that model.
MODELS_TO_TRAIN = [
    'camera_only', 'lidar_only', 'radar_only',  # single-modality baselines
    'ijepa', 'vjepa',                           # JEPA baselines
    'himac_jepa',                               # HiMAC-JEPA
]

# Notebook-level hyperparameters; load_config() writes these over the values
# read from the YAML configs.
NUM_EPOCHS = 50       # kept low for notebook runs (use 100 for full training)
BATCH_SIZE = 16       # adjust based on GPU memory
LEARNING_RATE = 1e-4
USE_WANDB = False     # set to True to enable W&B logging

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Echo the effective settings, one per line.
print(
    "Training configuration:",
    f"  Models: {MODELS_TO_TRAIN}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Device: {DEVICE}",
    sep="\n",
)

## Helper Functions

In [None]:
def load_config(model_name: str) -> dict:
    """Load a model's YAML configuration and apply the notebook overrides.

    ``'himac_jepa'`` is read from ``configs/config.yaml``; any other name is
    read from ``configs/baseline/<model_name>.yaml``. Both paths are relative
    to the module-level ``project_root``.

    Args:
        model_name: Key identifying the model whose config should be loaded.

    Returns:
        The parsed configuration with ``training.num_epochs``,
        ``training.batch_size``, ``training.learning_rate`` and
        ``logging.wandb.enabled`` replaced by the notebook-level settings.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If the YAML document is not a mapping (e.g. empty file).
    """
    if model_name == 'himac_jepa':
        config_path = project_root / "configs/config.yaml"
    else:
        config_path = project_root / f"configs/baseline/{model_name}.yaml"

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; fail with a clear message
    # instead of an opaque "NoneType is not subscriptable" further down.
    if not isinstance(config, dict):
        raise ValueError(f"Config {config_path} must contain a YAML mapping")

    # Override with notebook settings. setdefault() creates a section that the
    # file omits, so a config without `training` or `logging.wandb` still
    # receives the overrides instead of raising KeyError.
    training = config.setdefault('training', {})
    training['num_epochs'] = NUM_EPOCHS
    training['batch_size'] = BATCH_SIZE
    training['learning_rate'] = LEARNING_RATE
    config.setdefault('logging', {}).setdefault('wandb', {})['enabled'] = USE_WANDB

    return config


def create_model(model_name: str, config: dict):
    """Instantiate the model selected by ``model_name``.

    The model class receives ``config['model']`` as its only constructor
    argument.

    Args:
        model_name: One of ``'camera_only'``, ``'lidar_only'``,
            ``'radar_only'``, ``'ijepa'``, ``'vjepa'`` or ``'himac_jepa'``.
        config: Full configuration dict; only its ``'model'`` entry is used.

    Returns:
        The newly constructed model instance.

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    # Guard-clause style: each recognised name returns immediately, so the
    # model classes are only looked up for the branch that actually runs.
    if model_name == 'camera_only':
        return CameraOnlyBaseline(config['model'])
    if model_name == 'lidar_only':
        return LiDAROnlyBaseline(config['model'])
    if model_name == 'radar_only':
        return RadarOnlyBaseline(config['model'])
    if model_name == 'ijepa':
        return IJEPABaseline(config['model'])
    if model_name == 'vjepa':
        return VJEPABaseline(config['model'])
    if model_name == 'himac_jepa':
        return HiMACJEPA(config['model'])

    raise ValueError(f"Unknown model: {model_name}")


def create_dummy_dataloader(model_name: str, batch_size: int, num_batches: int = 100):
    """Build a shuffled DataLoader of random tensors shaped for ``model_name``.

    Placeholder for a real dataset: every sample is freshly drawn with
    ``torch.randn`` each time it is indexed.

    Args:
        model_name: Selects which keys/shapes each sample contains. A name
            that is not recognised yields empty-dict samples.
        batch_size: Batch size handed to the DataLoader.
        num_batches: Number of batches; the dataset holds
            ``num_batches * batch_size`` samples.

    Returns:
        A ``DataLoader`` (``shuffle=True``, ``num_workers=0``) over the dataset.
    """
    # Ordered (key, shape) pairs per model; order fixes the sequence of
    # torch.randn calls for a sample.
    single_frame_camera = [('camera', (3, 224, 224))]
    temporal_sensors = [
        ('camera', (5, 3, 224, 224)),  # (T, C, H, W)
        ('lidar', (5, 2048, 3)),
        ('radar', (5, 1, 128, 128)),
    ]
    sample_layouts = {
        'camera_only': single_frame_camera,
        'ijepa': single_frame_camera,
        'lidar_only': [('lidar', (2048, 3))],
        'radar_only': [('radar', (1, 128, 128))],
        # Temporal data for V-JEPA
        'vjepa': temporal_sensors,
        # Multi-modal + hierarchical actions for HiMAC-JEPA
        'himac_jepa': temporal_sensors + [
            ('strategic_actions', (5, 4)),  # (T, strategic_dim)
            ('tactical_actions', (5, 8)),   # (T, tactical_dim)
        ],
    }

    class DummyDataset:
        """Map-style dataset that generates a random sample on every access."""

        def __init__(self, model_name, num_samples):
            self.model_name = model_name
            self.num_samples = num_samples

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            fields = sample_layouts.get(self.model_name, ())
            return {key: torch.randn(*shape) for key, shape in fields}

    total_samples = num_batches * batch_size
    return DataLoader(
        DummyDataset(model_name, total_samples),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )


print("✓ Helper functions defined")

## Training Function

In [None]:
def _batch_to_device(batch: dict, device: str) -> dict:
    """Return a copy of ``batch`` with every tensor value moved to ``device``."""
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


def train_baseline(model_name: str, config: dict, device: str):
    """Train a single model (baseline or HiMAC-JEPA) and checkpoint the best epoch.

    The model is built with ``create_model`` and must provide
    ``get_num_parameters()``, ``get_model_size_mb()``,
    ``train_step(batch, optimizer)``, ``val_step(batch)`` and
    ``save_checkpoint(path, epoch=..., optimizer=...)``. This loop never calls
    ``backward()`` or ``optimizer.step()`` itself; ``train_step`` receives the
    optimizer and is expected to perform the update. Both step methods must
    return a mapping with a scalar ``'loss'`` entry.

    Data currently comes from ``create_dummy_dataloader`` (50 training and 10
    validation batches of random tensors); replace with a real dataset.

    Args:
        model_name: Model key understood by ``create_model``.
        config: Configuration dict. Reads ``training.learning_rate``,
            ``training.weight_decay``, ``training.num_epochs``,
            ``training.warmup_epochs``, ``training.batch_size`` and, for
            models other than ``'himac_jepa'``, ``checkpoint.save_dir``.
        device: Device string passed to ``.to()`` for the model and batches.

    Returns:
        A dict with per-epoch lists ``'train_loss'``, ``'val_loss'`` and
        ``'lr'`` (the learning rate that was in effect during that epoch).
    """

    print(f"\n{'='*60}")
    print(f"Training: {model_name.upper()}")
    print(f"{'='*60}")

    # Create model
    model = create_model(model_name, config)
    model = model.to(device)

    print(f"Model parameters: {model.get_num_parameters():,}")
    print(f"Model size: {model.get_model_size_mb():.2f} MB")

    train_cfg = config['training']
    num_epochs = train_cfg['num_epochs']
    warmup_epochs = train_cfg['warmup_epochs']

    # Create optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_cfg['learning_rate'],
        weight_decay=train_cfg['weight_decay']
    )

    # Create scheduler. The cosine schedule spans only the post-warmup epochs;
    # clamp to at least 1 so a warmup as long as (or longer than) the whole
    # run cannot produce a zero or negative T_max.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, num_epochs - warmup_epochs),
        eta_min=1e-6
    )

    # Create dataloaders (using dummy data - replace with actual dataset)
    train_loader = create_dummy_dataloader(model_name, train_cfg['batch_size'], num_batches=50)
    val_loader = create_dummy_dataloader(model_name, train_cfg['batch_size'], num_batches=10)

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'lr': []
    }

    # Training loop
    best_val_loss = float('inf')

    for epoch in range(1, num_epochs + 1):
        # Record the LR before the scheduler is stepped: this is the rate the
        # epoch actually trains with. Reading it after scheduler.step() would
        # log the *next* epoch's rate and shift the LR curve by one epoch.
        epoch_lr = optimizer.param_groups[0]['lr']

        # Train
        model.train()
        train_losses = []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")
        for batch in pbar:
            batch = _batch_to_device(batch, device)

            # Training step
            metrics = model.train_step(batch, optimizer)
            # float() accepts Python numbers and 0-d tensors alike; storing
            # plain floats keeps np.mean working when the loss is a (CUDA)
            # tensor and avoids retaining references to device tensors.
            loss_value = float(metrics['loss'])
            train_losses.append(loss_value)

            pbar.set_postfix(loss=loss_value)

        avg_train_loss = np.mean(train_losses)

        # Validate
        model.eval()
        val_losses = []

        with torch.no_grad():
            for batch in val_loader:
                batch = _batch_to_device(batch, device)
                metrics = model.val_step(batch)
                val_losses.append(float(metrics['loss']))

        avg_val_loss = np.mean(val_losses)

        # Update scheduler. During the warmup epochs the scheduler is simply
        # not stepped, so the LR stays at its configured base value (there is
        # no ramp-up); cosine decay starts afterwards.
        if epoch > warmup_epochs:
            scheduler.step()

        # Save history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['lr'].append(epoch_lr)

        # Print progress
        print(f"Epoch {epoch}: train_loss={avg_train_loss:.4f}, val_loss={avg_val_loss:.4f}, lr={epoch_lr:.6f}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # HiMAC-JEPA uses different checkpoint directory
            if model_name == 'himac_jepa':
                save_dir = project_root / 'checkpoints/himac_jepa'
            else:
                save_dir = project_root / config['checkpoint']['save_dir']
            save_dir.mkdir(parents=True, exist_ok=True)
            model.save_checkpoint(
                str(save_dir / 'best_model.pth'),
                epoch=epoch,
                optimizer=optimizer
            )
            print(f"  ✓ Saved best model (val_loss={best_val_loss:.4f})")

    print(f"\n✓ Training complete for {model_name}")
    print(f"  Best val loss: {best_val_loss:.4f}")

    return history


print("✓ Training function defined")

## Train All Models (Baselines + HiMAC-JEPA)

In [None]:
import traceback

# Training histories keyed by model name; a model that fails gets no entry.
all_histories = {}

for model_name in MODELS_TO_TRAIN:
    try:
        # Load the config, then run the full training loop for this model.
        config = load_config(model_name)
        history = train_baseline(model_name, config, DEVICE)
    except Exception as e:
        # Best-effort run: report the failure with its traceback and carry on
        # with the remaining models rather than aborting the whole cell.
        print(f"\n❌ Error training {model_name}: {e}")
        traceback.print_exc()
    else:
        all_histories[model_name] = history

print("\n" + "=" * 60)
print("All model training complete!")
print("=" * 60)

## Visualize Training Progress

In [None]:
# Plot training curves: one panel each for training loss, validation loss and
# learning rate, with one line per successfully trained model.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (history key, y-axis label, panel title) for each of the three panels.
panels = [
    ('train_loss', 'Training Loss', 'Training Loss Curves'),
    ('val_loss', 'Validation Loss', 'Validation Loss Curves'),
    ('lr', 'Learning Rate', 'Learning Rate Schedule'),
]

for ax, (key, ylabel, title) in zip(axes, panels):
    for model_name, history in all_histories.items():
        values = history[key]
        # Epochs are numbered from 1 during training, so plot against 1..N
        # rather than the default 0-based sample index.
        ax.plot(range(1, len(values) + 1), values, label=model_name, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # legend() warns when there is nothing to label (every model failed).
    if all_histories:
        ax.legend()
    ax.grid(True, alpha=0.3)

# The learning rate decays over orders of magnitude; show it on a log axis.
axes[2].set_yscale('log')

plt.tight_layout()

# savefig does not create missing directories, so make sure results/ exists.
results_dir = project_root / 'results'
results_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(results_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Training curves saved to results/training_curves.png")

## Training Summary

In [None]:
import pandas as pd

# Column order of the summary table. Passing it explicitly keeps the columns
# present even when no model trained successfully, so sorting below cannot
# fail with KeyError on an empty table.
summary_columns = [
    'Model',
    'Final Train Loss',
    'Final Val Loss',
    'Best Val Loss',
    'Best Epoch',
]

# Create summary table: one row per model that finished training.
summary_data = []
for model_name, history in all_histories.items():
    summary_data.append({
        'Model': model_name,
        'Final Train Loss': history['train_loss'][-1],
        'Final Val Loss': history['val_loss'][-1],
        'Best Val Loss': min(history['val_loss']),
        # argmin is 0-based; epochs are reported 1-based.
        'Best Epoch': np.argmin(history['val_loss']) + 1
    })

summary_df = pd.DataFrame(summary_data, columns=summary_columns)
# Best (lowest validation loss) model first.
summary_df = summary_df.sort_values('Best Val Loss')

print("\nTraining Summary:")
print(summary_df.to_string(index=False))

# Save summary. to_csv does not create missing directories, so make sure
# results/ exists first.
results_dir = project_root / 'results'
results_dir.mkdir(parents=True, exist_ok=True)
summary_df.to_csv(results_dir / 'training_summary.csv', index=False)
print("\n✓ Summary saved to results/training_summary.csv")

## Next Steps

Now that all models are trained:
1. Run `02_evaluate_baselines.ipynb` to evaluate models on test set
2. Run `03_results_analysis.ipynb` to analyze and compare results
3. Run `04_visualize_predictions.ipynb` for qualitative visualization

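**Checkpoints saved to** (baseline locations are taken from each config's `checkpoint.save_dir`; the paths below are the expected defaults):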
- `checkpoints/baselines/camera_only/best_model.pth`
- `checkpoints/baselines/lidar_only/best_model.pth`
- `checkpoints/baselines/radar_only/best_model.pth`
- `checkpoints/baselines/ijepa/best_model.pth`
- `checkpoints/baselines/vjepa/best_model.pth`
- `checkpoints/himac_jepa/best_model.pth`