# Advanced Manipulation Transformer - W&B Hyperparameter Sweep

This notebook runs a Weights & Biases (W&B) sweep for hyperparameter optimization, focusing on:
- Augmentation parameters
- Dropout
- Learning-rate scheduler
- Loss weights
- Diversity margin
- Per-joint weighting
- Fingertip weight
- Batch size

The base learning rate is not swept; it is taken from `configs/default_config.yaml`.

Fixed settings for every run:
- 7 epochs per run
- 20,000 training samples
- 2,000 validation samples
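Because `batch_size` is swept while the sample and epoch counts stay fixed, the number of optimizer steps per run varies by roughly 8x across the sweep. Here is a quick sketch of that arithmetic. It assumes the training loader keeps the final partial batch (`drop_last=False`); whether `create_gpu_cached_dataloaders` actually does so is an assumption.

```python
import math

TRAIN_SAMPLES = 20_000  # fixed training subset per run
EPOCHS = 7              # fixed epochs per run


def steps_per_run(batch_size: int, drop_last: bool = False) -> int:
    """Optimizer steps in one sweep run (one step per training batch)."""
    if drop_last:
        batches = TRAIN_SAMPLES // batch_size
    else:
        batches = math.ceil(TRAIN_SAMPLES / batch_size)
    return batches * EPOCHS


for bs in (32, 64, 128, 256):
    print(f"batch_size={bs:>3}: {steps_per_run(bs):>5} steps")
```

All four LR schedulers step once per epoch, so every run sees the same seven scheduler steps. What changes with batch size is how many gradient updates happen between them.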

## 1. Environment Setup

In [None]:
# Standard imports
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import wandb
from pathlib import Path
from omegaconf import OmegaConf
import warnings
warnings.filterwarnings('ignore')

# Set environment variables
os.environ['DEX_YCB_DIR'] = '/home/n231/231nProjectV2/dex-ycb-toolkit/data'

# Add project root to path
project_root = Path('.').absolute().parent
sys.path.insert(0, str(project_root))

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

## 2. Define Sweep Configuration

In [None]:
# Define sweep configuration
sweep_config = {
    'method': 'bayes',  # Bayesian optimization
    'metric': {
        'name': 'val/hand_mpjpe',
        'goal': 'minimize'
    },
    'parameters': {
        # Batch size
        'batch_size': {
            'values': [32,64,128,256]
        },
        
        # Dropout
        'dropout': {
            'values': [0.1, 0.15, 0.2, 0.25, 0.3]
        },
        
        # Scheduler type
        'scheduler_type': {
            'values': ['cosine', 'cosine_warmup', 'step', 'exponential']
        },
        
        # Augmentation parameters
        'aug_rotation_range': {
            'distribution': 'uniform',
            'min': 5.0,
            'max': 15.0
        },
        'aug_scale_min': {
            'distribution': 'uniform',
            'min': 0.7,
            'max': 0.9
        },
        'aug_scale_max': {
            'distribution': 'uniform',
            'min': 1.1,
            'max': 1.3
        },
        'aug_translation_std': {
            'distribution': 'uniform',
            'min': 0.02,
            'max': 0.1
        },
        'aug_color_jitter': {
            'distribution': 'uniform',
            'min': 0.1,
            'max': 0.3
        },
        'aug_joint_noise_std': {
            'distribution': 'uniform',
            'min': 0.002,
            'max': 0.01
        },
        
        # Loss weights
        'loss_weight_hand_coarse': {
            'distribution': 'uniform',
            'min': 0.8,
            'max': 1.2
        },
        'loss_weight_hand_refined': {
            'distribution': 'uniform',
            'min': 1.0,
            'max': 1.5
        },
        'loss_weight_object_position': {
            'distribution': 'uniform',
            'min': 0.8,
            'max': 1.2
        },
        'loss_weight_object_rotation': {
            'distribution': 'uniform',
            'min': 0.3,
            'max': 0.7
        },
        'loss_weight_contact': {
            'distribution': 'uniform',
            'min': 0.2,
            'max': 0.5
        },
        'loss_weight_physics': {
            'distribution': 'uniform',
            'min': 0.05,
            'max': 0.2
        },
        'loss_weight_diversity': {
            'distribution': 'log_uniform_values',
            'min': 0.005,
            'max': 0.05
        },
        'loss_weight_reprojection': {
            'distribution': 'uniform',
            'min': 0.3,
            'max': 0.7
        },
        
        # Loss configuration
        'diversity_margin': {
            'distribution': 'uniform',
            'min': 0.005,
            'max': 0.02
        },
        'per_joint_weighting': {
            'values': [True, False]
        },
        'fingertip_weight': {
            'distribution': 'uniform',
            'min': 1.2,
            'max': 2.0
        }
    }
}

print("Sweep configuration defined:")
print(f"Method: {sweep_config['method']}")
print(f"Metric: {sweep_config['metric']['name']} ({sweep_config['metric']['goal']})")
print(f"Number of parameters: {len(sweep_config['parameters'])}")
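Most continuous parameters use `uniform`, but `loss_weight_diversity` uses `log_uniform_values`. Per the W&B sweep documentation, this samples values so that their logarithm is uniform between `log(min)` and `log(max)`. The sketch below imitates both distributions with Python's `random` module (it does not call W&B) to show how differently they cover the range [0.005, 0.05].

```python
import math
import random
import statistics

LOW, HIGH = 0.005, 0.05  # range of loss_weight_diversity in sweep_config


def sample_uniform(rng: random.Random) -> float:
    return rng.uniform(LOW, HIGH)


def sample_log_uniform_values(rng: random.Random) -> float:
    # log(value) is uniform on [log(LOW), log(HIGH)]
    return math.exp(rng.uniform(math.log(LOW), math.log(HIGH)))


rng = random.Random(0)
uni = [sample_uniform(rng) for _ in range(100_000)]
logu = [sample_log_uniform_values(rng) for _ in range(100_000)]

print(f"uniform median:            {statistics.median(uni):.4f}")   # close to (LOW + HIGH) / 2
print(f"log_uniform_values median: {statistics.median(logu):.4f}")  # close to sqrt(LOW * HIGH)
print(f"share below 0.01: uniform {sum(v < 0.01 for v in uni) / len(uni):.2f}, "
      f"log-uniform {sum(v < 0.01 for v in logu) / len(logu):.2f}")
```

On the log scale, about 30% of trials land below 0.01 (log10(2) ≈ 0.30), compared with about 11% under `uniform`. This is why log sampling suits weights whose plausible values span an order of magnitude.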

## 3. Define Training Function

In [None]:
def train():
    # Initialize wandb
    run = wandb.init()
    
    # Load base configuration
    config = OmegaConf.load('../configs/default_config.yaml')
    
    # Override with sweep parameters
    # Batch size and dropout (the learning rate keeps its default_config value)
    config.training.batch_size = wandb.config.batch_size
    config.model.dropout = wandb.config.dropout
    
    # Augmentation parameters
    config.data.augmentation.rotation_range = wandb.config.aug_rotation_range
    config.data.augmentation.scale_range = [
        wandb.config.aug_scale_min,
        wandb.config.aug_scale_max
    ]
    config.data.augmentation.translation_std = wandb.config.aug_translation_std
    config.data.augmentation.color_jitter = wandb.config.aug_color_jitter
    config.data.augmentation.joint_noise_std = wandb.config.aug_joint_noise_std
    
    # Loss weights
    config.loss.loss_weights.hand_coarse = wandb.config.loss_weight_hand_coarse
    config.loss.loss_weights.hand_refined = wandb.config.loss_weight_hand_refined
    config.loss.loss_weights.object_position = wandb.config.loss_weight_object_position
    config.loss.loss_weights.object_rotation = wandb.config.loss_weight_object_rotation
    config.loss.loss_weights.contact = wandb.config.loss_weight_contact
    config.loss.loss_weights.physics = wandb.config.loss_weight_physics
    config.loss.loss_weights.diversity = wandb.config.loss_weight_diversity
    config.loss.loss_weights.reprojection = wandb.config.loss_weight_reprojection
    
    # Loss configuration
    config.loss.diversity_margin = wandb.config.diversity_margin
    config.loss.per_joint_weighting = wandb.config.per_joint_weighting
    config.loss.fingertip_weight = wandb.config.fingertip_weight
    
    # Fixed parameters for sweep
    config.training.num_epochs = 7  # Fixed at 7 epochs
    config.training.use_wandb = True
    config.training.use_amp = True
    config.training.use_bf16 = True
    
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Import components
    from models.unified_model import UnifiedManipulationTransformer
    from training.losses import ComprehensiveLoss
    from data.gpu_cached_dataset import create_gpu_cached_dataloaders
    from solutions.mode_collapse import ModeCollapsePreventionModule
    from optimizations.pytorch_native_optimization import PyTorchNativeOptimizer
    
    try:
        # Create GPU-cached dataloaders with fixed sizes
        print(f"Creating dataloaders with 20k train, 2k val samples, batch_size={config.training.batch_size}...")
        gpu_config = {
            'gpu_max_samples': 20000,      # Fixed at 20k
            'gpu_max_samples_val': 2000,   # Fixed at 2k
            'gpu_cache_path': './gpu_cache_sweep',
            'batch_size': config.training.batch_size,  # swept value
            'use_bfloat16': config.training.use_bf16,
            'preload_dinov2': False
        }
        
        train_loader, val_loader = create_gpu_cached_dataloaders(gpu_config)
        print(f"Dataloaders created: {len(train_loader)} train batches, {len(val_loader)} val batches")
        
        # Create model
        model = UnifiedManipulationTransformer(config.model)
        
        # Apply mode collapse prevention
        mode_collapse_config = {
            'noise_std': 0.01,
            'drop_path_rate': 0.1,
            'mixup_alpha': 0.2
        }
        model = ModeCollapsePreventionModule.wrap_model(model, mode_collapse_config)
        
        # Initialize weights
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                if hasattr(module, 'weight') and module.weight is not None:
                    nn.init.xavier_uniform_(module.weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    nn.init.constant_(module.bias, 0.01)
        
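        # Re-initialize only the newly added layers; skip every module under the
        # pretrained DINOv2 backbone (its patch-embedding Conv2d included)
        for name, module in model.named_modules():
            if 'dinov2' not in name:
                init_weights(module)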
        
        # Apply optimizations (DISABLE torch.compile to avoid errors)
        native_optimizer = PyTorchNativeOptimizer()
        model = native_optimizer.optimize_model(model, {'use_compile': False})
        model = model.to(device)
        
        # Create optimizer with parameter groups
        dinov2_params = []
        encoder_params = []
        decoder_params = []
        other_params = []
        
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if 'dinov2' in name:
                dinov2_params.append(param)
            elif 'decoder' in name:
                decoder_params.append(param)
            elif 'encoder' in name:
                encoder_params.append(param)
            else:
                other_params.append(param)
        
        param_groups = []
        if dinov2_params:
            param_groups.append({
                'params': dinov2_params,
                'lr': config.training.learning_rate * 0.01,
                'name': 'dinov2'
            })
        if encoder_params:
            param_groups.append({
                'params': encoder_params,
                'lr': config.training.learning_rate * 0.5,
                'name': 'encoders'
            })
        if decoder_params:
            param_groups.append({
                'params': decoder_params,
                'lr': config.training.learning_rate,
                'name': 'decoders'
            })
        if other_params:
            param_groups.append({
                'params': other_params,
                'lr': config.training.learning_rate,
                'name': 'other'
            })
        
        # Fused AdamW kernels need CUDA parameters; use the default path on CPU
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01, fused=torch.cuda.is_available())
        
        # Create scheduler based on sweep parameter
        scheduler_type = wandb.config.scheduler_type
        if scheduler_type == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config.training.num_epochs, eta_min=1e-6
            )
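        elif scheduler_type == 'cosine_warmup':
            # Despite the option name, this is SGDR-style warm restarts, not a warmup;
            # with T_0=10 the first restart falls after the 7-epoch run ends
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=10, T_mult=2, eta_min=1e-6
            )
        elif scheduler_type == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=3, gamma=0.5
            )
        elif scheduler_type == 'exponential':
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=0.9
            )
        else:
            raise ValueError(f"Unknown scheduler_type: {scheduler_type}")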
        
        # Loss function
        criterion = ComprehensiveLoss(config.loss)
        
        # Training loop
        best_val_mpjpe = float('inf')
        
        for epoch in range(config.training.num_epochs):
            # Update loss epoch
            criterion.set_epoch(epoch)
            
            # Train
            model.train()
            train_loss = 0
            train_mpjpe = 0
            train_samples = 0
            
            for batch_idx, batch in enumerate(train_loader):
                # Convert BFloat16 images to Float32 for DINOv2
                if batch['image'].dtype == torch.bfloat16:
                    batch['image'] = batch['image'].float()
                
                optimizer.zero_grad()
                
                # Forward pass
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    outputs = model(batch)
                    losses = criterion(outputs, batch)
                    loss = losses['total'] if isinstance(losses, dict) else losses
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                
                batch_size = batch['image'].shape[0]
                train_loss += loss.item() * batch_size
                train_samples += batch_size
                
                # Calculate MPJPE in mm (joint coordinates assumed to be in metres)
                batch_mpjpe_mm = None
                if 'hand_joints' in outputs and 'hand_joints' in batch:
                    with torch.no_grad():
                        mpjpe = torch.norm(outputs['hand_joints'] - batch['hand_joints'], dim=-1).mean()
                        batch_mpjpe_mm = mpjpe.item() * 1000
                        train_mpjpe += batch_mpjpe_mm * batch_size
                
                # Log batch metrics; 'train/lr' reports the full-rate group, since
                # param_groups[0] is the DINOv2 group (0.01x) when it exists
                if batch_idx % 20 == 0:
                    wandb.log({
                        'train/batch_loss': loss.item(),
                        'train/batch_mpjpe': batch_mpjpe_mm if batch_mpjpe_mm is not None else 0,
                        'train/lr': max(g['lr'] for g in optimizer.param_groups),
                        'system/gpu_memory_gb': torch.cuda.memory_allocated() / 1e9
                    })
            
            # Average training metrics
            train_loss /= train_samples
            train_mpjpe /= train_samples
            
            # Validation
            model.eval()
            val_loss = 0
            val_mpjpe = 0
            val_samples = 0
            
            with torch.no_grad():
                for batch in val_loader:
                    if batch['image'].dtype == torch.bfloat16:
                        batch['image'] = batch['image'].float()
                    
                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                        outputs = model(batch)
                        losses = criterion(outputs, batch)
                        loss = losses['total'] if isinstance(losses, dict) else losses
                    
                    batch_size = batch['image'].shape[0]
                    val_samples += batch_size
                    val_loss += loss.item() * batch_size
                    
                    if 'hand_joints' in outputs and 'hand_joints' in batch:
                        mpjpe = torch.norm(outputs['hand_joints'] - batch['hand_joints'], dim=-1).mean()
                        val_mpjpe += mpjpe.item() * 1000 * batch_size
            
            # Average validation metrics
            val_loss /= val_samples
            val_mpjpe /= val_samples
            
            # Update best
            if val_mpjpe < best_val_mpjpe:
                best_val_mpjpe = val_mpjpe
            
            # Update scheduler
            scheduler.step()
            
            # Log epoch metrics
            wandb.log({
                'epoch': epoch,
                'train/loss': train_loss,
                'train/hand_mpjpe': train_mpjpe,
                'val/loss': val_loss,
                'val/hand_mpjpe': val_mpjpe,
                'val/best_mpjpe': best_val_mpjpe
            })
            
            print(f"Epoch {epoch+1}/{config.training.num_epochs}: "
                  f"train_loss={train_loss:.4f}, train_mpjpe={train_mpjpe:.2f}mm, "
                  f"val_loss={val_loss:.4f}, val_mpjpe={val_mpjpe:.2f}mm")
        
        # Log final metrics
        wandb.log({
            'final/best_val_mpjpe': best_val_mpjpe,
            'final/epochs_trained': config.training.num_epochs
        })
        
        # Clean up
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"Training failed: {e}")
        wandb.log({'error': str(e), 'val/hand_mpjpe': 1000.0})
        raise

print("Training function defined!")
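The four `scheduler_type` options behave quite differently over a 7-epoch run. The sketch below reproduces their per-epoch learning rates from the closed-form formulas in the PyTorch scheduler documentation rather than instantiating the scheduler classes. It uses a hypothetical base LR of `1e-4`; the real value comes from `training.learning_rate` in `default_config.yaml`. Because `scheduler.step()` is called once at the end of each epoch, entry `e` is the LR used during epoch `e`.

```python
import math

BASE_LR = 1e-4  # hypothetical; the notebook uses config.training.learning_rate
ETA_MIN = 1e-6  # eta_min used by both cosine options in train()
EPOCHS = 7


def cosine_lr(epoch: int, base_lr: float, t_max: int = EPOCHS) -> float:
    # CosineAnnealingLR(T_max=num_epochs, eta_min=1e-6)
    return ETA_MIN + (base_lr - ETA_MIN) * (1 + math.cos(math.pi * epoch / t_max)) / 2


def warm_restarts_lr(epoch: int, base_lr: float, t_0: int = 10, t_mult: int = 2) -> float:
    # CosineAnnealingWarmRestarts(T_0=10, T_mult=2, eta_min=1e-6):
    # locate t_cur inside the current cycle of length t_i
    t_cur, t_i = epoch, t_0
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return ETA_MIN + (base_lr - ETA_MIN) * (1 + math.cos(math.pi * t_cur / t_i)) / 2


def step_lr(epoch: int, base_lr: float, step_size: int = 3, gamma: float = 0.5) -> float:
    # StepLR(step_size=3, gamma=0.5)
    return base_lr * gamma ** (epoch // step_size)


def exponential_lr(epoch: int, base_lr: float, gamma: float = 0.9) -> float:
    # ExponentialLR(gamma=0.9)
    return base_lr * gamma ** epoch


SCHEDULES = {
    'cosine': cosine_lr,
    'cosine_warmup': warm_restarts_lr,
    'step': step_lr,
    'exponential': exponential_lr,
}

for name, fn in SCHEDULES.items():
    lrs = [fn(e, BASE_LR) for e in range(EPOCHS)]
    print(f"{name:>13}: " + " ".join(f"{lr:.2e}" for lr in lrs))
```

Read this way, `cosine_warmup` is a cosine decay over a 10-epoch horizon, truncated at epoch 7. It has no warmup phase and never restarts. Note also that `eta_min=1e-6` is an absolute floor shared by all parameter groups. If `training.learning_rate` is `1e-4` or lower, the DINOv2 group (0.01x) starts at or below that floor. The cosine options then hold it near 1e-6, or even raise it, instead of decaying it.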

## 4. Initialize and Run Sweep

In [None]:
# Initialize sweep
project_name = 'amt-hyperparameter-sweep'
sweep_id = wandb.sweep(sweep_config, project=project_name)

print(f"Sweep initialized!")
print(f"Sweep ID: {sweep_id}")
print(f"View at: https://wandb.ai/{wandb.api.default_entity}/{project_name}/sweeps/{sweep_id}")

In [None]:
# Run sweep agent
# You can adjust count to run more or fewer experiments
sweep_count = 50  # Number of experiments to run

print(f"Starting sweep agent to run {sweep_count} experiments...")
print("Each experiment will run for 7 epochs with 20k training samples")
print("This will take approximately 10-15 minutes per experiment")

wandb.agent(sweep_id, function=train, count=sweep_count)
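At the estimated 10-15 minutes per run, 50 runs on a single agent take roughly 8.3-12.5 hours. W&B allows several agents (for example, one per GPU, each calling `wandb.agent` with the same `sweep_id`) to pull runs from one sweep. The helper below is a rough estimate of wall-clock time. It assumes runs take equal time and are split evenly across agents.

```python
import math


def sweep_wall_clock_hours(count: int, minutes_per_run: float, num_agents: int = 1) -> float:
    """Rough estimate; assumes equal-length runs split evenly across agents."""
    runs_per_agent = math.ceil(count / num_agents)
    return runs_per_agent * minutes_per_run / 60


for agents in (1, 2, 4):
    low = sweep_wall_clock_hours(50, 10, agents)
    high = sweep_wall_clock_hours(50, 15, agents)
    print(f"{agents} agent(s): {low:.1f}-{high:.1f} h")
```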

## 5. Analyze Results

In [None]:
# After sweep completes, analyze results
api = wandb.Api()
sweep = api.sweep(f"{api.default_entity}/{project_name}/{sweep_id}")

# Get best run
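# best_run() orders runs by the sweep metric's summary value, i.e. the last
# (final-epoch) val/hand_mpjpe each run logged
best_run = sweep.best_run()
print(f"Best run: {best_run.name}")
print(f"Final-epoch validation MPJPE: {best_run.summary['val/hand_mpjpe']:.2f} mm")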
print("\nBest hyperparameters:")
for param, value in best_run.config.items():
    if param.startswith('_'):
        continue
    print(f"  {param}: {value}")
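Because `train()` also logs `val/best_mpjpe`, runs can be re-ranked by their best epoch rather than their final one. The sketch below works on plain dicts so it can be checked offline. In the notebook, the records would be built from `sweep.runs` via `run.name`, `run.state` and `run.summary`; that adaptation is untested. The helper name `rank_runs` is hypothetical, and the demo numbers are illustrative, not sweep results.

```python
from typing import Dict, List


def rank_runs(records: List[Dict], key: str = 'val/best_mpjpe',
              failure_sentinel: float = 1000.0) -> List[Dict]:
    """Sort finished runs by `key` (lower is better).

    Each record looks like {'name': str, 'state': str, 'summary': dict}.
    Runs that did not finish, lack `key`, or carry the failure sentinel that
    train() logs for val/hand_mpjpe are dropped.
    """
    kept = []
    for rec in records:
        summary = rec.get('summary', {})
        if rec.get('state') != 'finished' or key not in summary:
            continue
        if summary.get('val/hand_mpjpe') == failure_sentinel:
            continue
        kept.append(rec)
    return sorted(kept, key=lambda rec: rec['summary'][key])


# Illustrative records only
records = [
    {'name': 'run-a', 'state': 'finished', 'summary': {'val/hand_mpjpe': 21.0, 'val/best_mpjpe': 19.5}},
    {'name': 'run-b', 'state': 'finished', 'summary': {'val/hand_mpjpe': 20.0, 'val/best_mpjpe': 20.0}},
    {'name': 'run-c', 'state': 'failed', 'summary': {'val/hand_mpjpe': 1000.0}},
]
print([r['name'] for r in rank_runs(records)])                         # ['run-a', 'run-b']
print([r['name'] for r in rank_runs(records, key='val/hand_mpjpe')])   # ['run-b', 'run-a']
```

The two orderings disagree whenever a run peaks before its last epoch, which is common in short 7-epoch runs.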

In [None]:
# Visualize parameter importance
import matplotlib.pyplot as plt
import pandas as pd

# Get all runs
runs = sweep.runs
data = []

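for run in runs:
    # Keep finished runs only; failed runs carry the 1000.0 sentinel logged by train()
    if run.state == 'finished' and 'val/hand_mpjpe' in run.summary:
        config = dict(run.config)
        config['val_mpjpe'] = run.summary['val/hand_mpjpe']
        data.append(config)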

df = pd.DataFrame(data)

# Plot correlation of key parameters with performance
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
axes = axes.flatten()

key_params = ['batch_size', 'dropout', 'loss_weight_diversity', 'loss_weight_hand_refined',
              'aug_rotation_range', 'diversity_margin', 'fingertip_weight',
              'scheduler_type', 'per_joint_weighting']

for i, param in enumerate(key_params):
    if param in df.columns:
        ax = axes[i]
        if param in ['scheduler_type', 'per_joint_weighting']:
            # Categorical parameters: one box per value, labelled via tick labels
            # (boxplot's `labels=` argument was renamed `tick_labels` in Matplotlib 3.9)
            param_values = df[param].dropna().unique()
            mpjpe_by_param = [df[df[param] == val]['val_mpjpe'].values for val in param_values]
            ax.boxplot(mpjpe_by_param)
            ax.set_xticks(range(1, len(param_values) + 1))
            ax.set_xticklabels([str(v) for v in param_values])
            ax.set_xlabel(param)
            ax.set_ylabel('Validation MPJPE (mm)')
            ax.set_title(f'{param} vs Performance')
        else:
            # Continuous parameters
            ax.scatter(df[param], df['val_mpjpe'], alpha=0.6)
            ax.set_xlabel(param)
            ax.set_ylabel('Validation MPJPE (mm)')
            ax.set_title(f'{param} vs Performance')
            ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle('Hyperparameter Impact on Performance', y=1.02)
plt.show()

# Print summary statistics
print("\nSummary Statistics:")
print(f"Total runs: {len(df)}")
print(f"Best MPJPE: {df['val_mpjpe'].min():.2f} mm")
print(f"Worst MPJPE: {df['val_mpjpe'].max():.2f} mm")
print(f"Mean MPJPE: {df['val_mpjpe'].mean():.2f} mm")
print(f"Std MPJPE: {df['val_mpjpe'].std():.2f} mm")

# Print best batch size analysis
if 'batch_size' in df.columns:
    print("\nBatch Size Analysis:")
    batch_size_stats = df.groupby('batch_size')['val_mpjpe'].agg(['mean', 'std', 'min', 'count'])
    print(batch_size_stats)
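The scatter plots are easier to compare if each numeric parameter is summarized by a single number. One option is the Spearman rank correlation with `val_mpjpe`. The sketch below computes it as the Pearson correlation of ranks, so it needs only pandas (no SciPy). A positive value means larger settings tend to go with higher (worse) MPJPE. The helper name `spearman_vs_metric` is hypothetical, and the demo frame is synthetic; in the notebook you would pass `df`.

```python
import pandas as pd


def spearman_vs_metric(frame: pd.DataFrame, params, metric: str = 'val_mpjpe') -> pd.Series:
    """Spearman correlation of each numeric parameter with `metric`, sorted by |rho|."""
    cols = [p for p in params if p in frame.columns]
    ranks = frame[cols].rank()                        # average ranks for ties
    rho = ranks.corrwith(frame[metric].rank())        # Pearson on ranks = Spearman
    return rho.reindex(rho.abs().sort_values(ascending=False).index)


# Synthetic demo frame (not sweep results)
demo = pd.DataFrame({
    'dropout':            [0.10, 0.15, 0.20, 0.25, 0.30],
    'fingertip_weight':   [2.0, 1.8, 1.6, 1.4, 1.2],
    'aug_rotation_range': [5.0, 15.0, 10.0, 12.0, 6.0],
    'val_mpjpe':          [18.0, 19.0, 20.0, 21.0, 22.0],
})
print(spearman_vs_metric(demo, ['dropout', 'fingertip_weight', 'aug_rotation_range']))
```

On the real sweep, something like `spearman_vs_metric(df, [p for p in key_params if p not in ('scheduler_type', 'per_joint_weighting')])` would skip the categorical columns. With only a few dozen runs these coefficients are noisy, so treat them as hints rather than conclusions.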

## 6. Export Best Configuration

In [None]:
# Export best configuration for future use
best_config = dict(best_run.config)

# Create optimized configuration file
optimized_config = OmegaConf.load('../configs/default_config.yaml')

# Update with best parameters (the learning rate was not swept and keeps its default)
optimized_config.training.batch_size = best_config['batch_size']
optimized_config.model.dropout = best_config['dropout']

# Update augmentation
optimized_config.data.augmentation.rotation_range = best_config['aug_rotation_range']
optimized_config.data.augmentation.scale_range = [
    best_config['aug_scale_min'],
    best_config['aug_scale_max']
]
optimized_config.data.augmentation.translation_std = best_config['aug_translation_std']
optimized_config.data.augmentation.color_jitter = best_config['aug_color_jitter']
optimized_config.data.augmentation.joint_noise_std = best_config['aug_joint_noise_std']

# Update loss weights
for key in best_config:
    if key.startswith('loss_weight_'):
        loss_key = key.replace('loss_weight_', '')
        optimized_config.loss.loss_weights[loss_key] = best_config[key]

# Update other loss configs
optimized_config.loss.diversity_margin = best_config['diversity_margin']
optimized_config.loss.per_joint_weighting = best_config['per_joint_weighting']
optimized_config.loss.fingertip_weight = best_config['fingertip_weight']

# Save optimized configuration
OmegaConf.save(optimized_config, '../configs/optimized_config.yaml')

print("Optimized configuration saved to configs/optimized_config.yaml")
print(f"scheduler_type is not written to the YAML; best value: {best_config['scheduler_type']}")
print("\nYou can now use this configuration for training:")
print("python train_advanced.py --config configs/optimized_config.yaml")

## Summary

This notebook implemented a comprehensive W&B sweep for hyperparameter optimization of the Advanced Manipulation Transformer, focusing on:

1. **Augmentation parameters**: rotation, scale, translation, color jitter, joint noise
2. **Training and model parameters**: batch size, dropout
3. **Scheduler options**: cosine, cosine with warm restarts (the `cosine_warmup` option), step, exponential
4. **Loss weights**: all major loss components
5. **Loss configuration**: diversity margin, per-joint weighting, fingertip weight

The sweep was configured to run each experiment for:
- 7 epochs
- 20,000 training samples
- 2,000 validation samples

Results are automatically tracked in W&B, and the best configuration is exported for future use.