# Advanced Manipulation Transformer - Second Stage W&B Sweep

This notebook runs a second-stage Weights & Biases sweep that fine-tunes the top 8 configurations from the first sweep.

**Focus**: Fine-tune only the most critical parameters:
- Weight decay
- Learning rate
- Dropout

**Configuration**:
- Start from top 8 configurations from first sweep
- 7 epochs per run
- 20,000 training samples
- 2,000 validation samples
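
As a rough sketch of what this budget means in optimizer steps: the helper below is hypothetical, and the batch size of 32 is only the value used in the fallback configuration later in the notebook (real runs take it from the first sweep). It also assumes the last partial batch is kept.

```python
import math

def steps_per_run(num_samples: int, batch_size: int, epochs: int,
                  drop_last: bool = False) -> int:
    """Number of batches processed over `epochs` passes through the data."""
    if drop_last:
        per_epoch = num_samples // batch_size
    else:
        per_epoch = math.ceil(num_samples / batch_size)
    return per_epoch * epochs

# Assumed batch size of 32; other base configs will give different counts.
train_steps = steps_per_run(20_000, batch_size=32, epochs=7)
val_batches = steps_per_run(2_000, batch_size=32, epochs=1)
print(f"optimizer steps per run: {train_steps}")
print(f"validation batches per epoch: {val_batches}")
```

Larger batch sizes from the first sweep reduce the number of optimizer steps proportionally, which is worth keeping in mind when comparing learning rates across base configurations.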

## 1. Environment Setup

In [1]:
# Standard imports
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import wandb
from pathlib import Path
from omegaconf import OmegaConf
import warnings
warnings.filterwarnings('ignore')

# Set environment variables
os.environ['DEX_YCB_DIR'] = '/home/n231/231nProjectV2/dex-ycb-toolkit/data'

# Add project root to path
project_root = Path('.').absolute().parent
sys.path.insert(0, str(project_root))

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

PyTorch version: 2.5.0+cu124
CUDA available: True
GPU: NVIDIA H200
GPU Memory: 139.7 GB

## 2. Load Top Configurations from First Sweep


# Load results from first sweep
# Replace with your actual sweep ID from the first stage
FIRST_SWEEP_ID = "c189gpt8"  # Update this with your actual first sweep ID
FIRST_PROJECT_NAME = "amt-hyperparameter-sweep"

# Get top 8 configurations
api = wandb.Api()
try:
    sweep = api.sweep(f"{wandb.api.default_entity}/{FIRST_PROJECT_NAME}/sweeps/{FIRST_SWEEP_ID}")
    runs = list(sweep.runs)
    
    # Sort by validation MPJPE and get top 8
    valid_runs = [r for r in runs if 'val/hand_mpjpe' in r.summary]
    sorted_runs = sorted(valid_runs, key=lambda r: r.summary['val/hand_mpjpe'])
    top_8_runs = sorted_runs[:8]
    
    print(f"Found {len(top_8_runs)} top performing runs:")
    for i, run in enumerate(top_8_runs):
        print(f"  {i+1}. {run.name}: {run.summary['val/hand_mpjpe']:.2f} mm")
    
    # Extract configurations - IMPORTANT: Add learning_rate from default config
    top_8_configs = []
    for run in top_8_runs:
        config = dict(run.config)
        # Remove wandb internal fields
        config = {k: v for k, v in config.items() if not k.startswith('_')}
        # Add learning_rate if not present (use default)
        if 'learning_rate' not in config:
            config['learning_rate'] = 0.001  # Default learning rate
        top_8_configs.append(config)
        
except Exception as e:
    print(f"Could not load first sweep results: {e}")
    print("\nUsing default top configurations for demonstration...")
    
    # Default configurations for demonstration
    top_8_configs = [
        {
            'batch_size': 32,
            'scheduler_type': 'cosine_warmup',
            'aug_rotation_range': 15.0,
            'aug_scale_min': 0.85,
            'aug_scale_max': 1.15,
            'aug_translation_std': 0.05,
            'aug_color_jitter': 0.2,
            'aug_joint_noise_std': 0.005,
            'loss_weight_hand_coarse': 1.0,
            'loss_weight_hand_refined': 1.2,
            'loss_weight_object_position': 1.0,
            'loss_weight_object_rotation': 0.5,
            'loss_weight_contact': 0.3,
            'loss_weight_physics': 0.1,
            'loss_weight_diversity': 0.01,
            'loss_weight_reprojection': 0.5,
            'diversity_margin': 0.01,
            'per_joint_weighting': True,
            'fingertip_weight': 1.5,
            'learning_rate': 0.001  # Added default learning rate
        }
    ] * 8  # Duplicate for demonstration

print(f"\nLoaded {len(top_8_configs)} configurations for second stage sweep")

# Print sample config to verify
if top_8_configs:
    print("\nSample configuration keys:")
    for key in sorted(top_8_configs[0].keys()):
        print(f"  {key}")
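
The selection logic in the cell above can be checked offline. The following is a minimal sketch that uses plain dictionaries in place of W&B run objects; `select_top_k`, `clean_config`, and the sample runs are hypothetical. It additionally skips non-finite metric values, on the assumption that a diverged run could leave a NaN in its summary.

```python
import math
from typing import Any

def select_top_k(runs: list[dict[str, Any]], metric: str, k: int) -> list[dict[str, Any]]:
    """Return the k runs with the lowest finite value of `metric`."""
    def has_metric(run: dict[str, Any]) -> bool:
        value = run.get("summary", {}).get(metric)
        return isinstance(value, (int, float)) and math.isfinite(value)

    valid = [r for r in runs if has_metric(r)]
    return sorted(valid, key=lambda r: r["summary"][metric])[:k]

def clean_config(config: dict[str, Any], default_lr: float = 1e-3) -> dict[str, Any]:
    """Drop internal keys (leading underscore) and fill in a default learning rate."""
    cleaned = {k: v for k, v in config.items() if not k.startswith("_")}
    cleaned.setdefault("learning_rate", default_lr)
    return cleaned

# Hypothetical stand-ins for runs from the first sweep.
runs = [
    {"name": "run-a", "summary": {"val/hand_mpjpe": 31.5},
     "config": {"batch_size": 32, "_wandb": {}}},
    {"name": "run-b", "summary": {}, "config": {"batch_size": 64}},  # never logged the metric
    {"name": "run-c", "summary": {"val/hand_mpjpe": 27.0},
     "config": {"batch_size": 16, "learning_rate": 5e-4}},
]

top = select_top_k(runs, "val/hand_mpjpe", k=2)
configs = [clean_config(r["config"]) for r in top]
for run, cfg in zip(top, configs):
    print(run["name"], cfg)
```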

## 3. Define Second Stage Sweep Configuration

In [ ]:
# Define second stage sweep configuration
# Focus on fine-tuning weight decay, learning rate, and dropout
sweep_config_second = {
    'method': 'bayes',  # Bayesian optimization
    'metric': {
        'name': 'val/hand_mpjpe',
        'goal': 'minimize'
    },
    'parameters': {
        # Index of top 8 configuration to use
        'config_index': {
            'values': list(range(8))  # 0-7 for top 8 configs
        },
        
        # Fine-tune critical parameters
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-3
        },
        
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 1e-5,
            'max': 1e-2
        },
        
        'dropout': {
            'distribution': 'uniform',
            'min': 0.05,
            'max': 0.4
        }
    }
}

print("Second stage sweep configuration:")
print(f"Method: {sweep_config_second['method']}")
print(f"Metric: {sweep_config_second['metric']['name']} ({sweep_config_second['metric']['goal']})")
print(f"\nParameters to fine-tune:")
print("- Learning rate: log-uniform [1e-5, 1e-3]")
print("- Weight decay: log-uniform [1e-5, 1e-2]")
print("- Dropout: uniform [0.05, 0.4]")
print(f"\nTesting on top {len(top_8_configs)} configurations from first sweep")
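
To illustrate what a log-uniform range implies, here is a small sketch that draws values whose logarithm is uniform between the bounds. It only illustrates the distribution; it is not how the Bayesian search picks points, and `sample_log_uniform` is a hypothetical helper.

```python
import math
import random

def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw x in [low, high] such that log(x) is uniformly distributed."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

rng = random.Random(0)
lrs = [sample_log_uniform(1e-5, 1e-3, rng) for _ in range(10_000)]

# Half of the draws should fall below the geometric midpoint 1e-4,
# even though 1e-4 is only about 9% of the way along the linear range.
below_midpoint = sum(lr < 1e-4 for lr in lrs) / len(lrs)
print(f"min={min(lrs):.2e} max={max(lrs):.2e} share below 1e-4: {below_midpoint:.2f}")
```

This is why learning rate and weight decay, which span two to three orders of magnitude here, use a log-uniform range, while dropout on [0.05, 0.4] uses a plain uniform one.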

## 4. Define Second Stage Training Function

In [2]:
def train_second_stage():
    # Initialize wandb
    run = wandb.init()
    
    # Get base configuration from top 8
    config_idx = wandb.config.config_index
    base_config = top_8_configs[config_idx].copy()
    
    # Load default configuration
    config = OmegaConf.load('../configs/default_config.yaml')
    
    # Apply base configuration from first sweep
    config.training.batch_size = base_config['batch_size']
    config.data.augmentation.rotation_range = base_config['aug_rotation_range']
    config.data.augmentation.scale_range = [
        base_config['aug_scale_min'],
        base_config['aug_scale_max']
    ]
    config.data.augmentation.translation_std = base_config['aug_translation_std']
    config.data.augmentation.color_jitter = base_config['aug_color_jitter']
    config.data.augmentation.joint_noise_std = base_config['aug_joint_noise_std']
    
    # Apply loss weights from base config
    config.loss.loss_weights.hand_coarse = base_config['loss_weight_hand_coarse']
    config.loss.loss_weights.hand_refined = base_config['loss_weight_hand_refined']
    config.loss.loss_weights.object_position = base_config['loss_weight_object_position']
    config.loss.loss_weights.object_rotation = base_config['loss_weight_object_rotation']
    config.loss.loss_weights.contact = base_config['loss_weight_contact']
    config.loss.loss_weights.physics = base_config['loss_weight_physics']
    config.loss.loss_weights.diversity = base_config['loss_weight_diversity']
    config.loss.loss_weights.reprojection = base_config['loss_weight_reprojection']
    
    config.loss.diversity_margin = base_config['diversity_margin']
    config.loss.per_joint_weighting = base_config.get('per_joint_weighting', True)
    config.loss.fingertip_weight = base_config.get('fingertip_weight', 1.5)
    
    # Apply base learning rate if it exists
    if 'learning_rate' in base_config:
        config.training.learning_rate = base_config['learning_rate']
    
    # Override with second stage sweep parameters
    config.training.learning_rate = wandb.config.learning_rate
    config.training.weight_decay = wandb.config.weight_decay
    config.model.dropout = wandb.config.dropout
    
    # Fixed parameters
    config.training.num_epochs = 7  # Fixed at 7 epochs
    config.training.use_wandb = True
    config.training.use_amp = True
    config.training.use_bf16 = True
    
    # Log which base config we're using
    wandb.log({
        'base_config_index': config_idx,
        'base_batch_size': base_config['batch_size'],
        'base_scheduler': base_config['scheduler_type'],
        'base_learning_rate': base_config.get('learning_rate', 'not_specified')
    })
    
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Import components
    from models.unified_model import UnifiedManipulationTransformer
    from training.losses import ComprehensiveLoss
    from data.gpu_cached_dataset import create_gpu_cached_dataloaders
    from solutions.mode_collapse import ModeCollapsePreventionModule
    from optimizations.pytorch_native_optimization import PyTorchNativeOptimizer
    
    try:
        # Create GPU-cached dataloaders
        print(f"Creating dataloaders with config {config_idx}: batch_size={config.training.batch_size}...")
        gpu_config = {
            'gpu_max_samples': 20000,
            'gpu_max_samples_val': 2000,
            'gpu_cache_path': './gpu_cache_sweep_second',
            'batch_size': config.training.batch_size,
            'use_bfloat16': config.training.use_bf16,
            'preload_dinov2': False
        }
        
        train_loader, val_loader = create_gpu_cached_dataloaders(gpu_config)
        print(f"Dataloaders created: {len(train_loader)} train batches, {len(val_loader)} val batches")
        
        # Create model
        model = UnifiedManipulationTransformer(config.model)
        
        # Apply mode collapse prevention
        mode_collapse_config = {
            'noise_std': 0.01,
            'drop_path_rate': 0.1,
            'mixup_alpha': 0.2
        }
        model = ModeCollapsePreventionModule.wrap_model(model, mode_collapse_config)
        
        # Initialize weights
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                if hasattr(module, 'weight') and module.weight is not None:
                    nn.init.xavier_uniform_(module.weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    nn.init.constant_(module.bias, 0.01)
        
        for name, module in model.named_modules():
            if 'dinov2' not in name or 'encoder.layer.' not in name:
                init_weights(module)
        
        # Apply optimizations (DISABLE torch.compile to avoid errors)
        native_optimizer = PyTorchNativeOptimizer()
        model = native_optimizer.optimize_model(model, {'use_compile': False})
        model = model.to(device)
        
        # Create optimizer with parameter groups
        dinov2_params = []
        encoder_params = []
        decoder_params = []
        other_params = []
        
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if 'dinov2' in name:
                dinov2_params.append(param)
            elif 'decoder' in name:
                decoder_params.append(param)
            elif 'encoder' in name:
                encoder_params.append(param)
            else:
                other_params.append(param)
        
        param_groups = []
        if dinov2_params:
            param_groups.append({
                'params': dinov2_params,
                'lr': config.training.learning_rate * 0.01,
                'name': 'dinov2'
            })
        if encoder_params:
            param_groups.append({
                'params': encoder_params,
                'lr': config.training.learning_rate * 0.5,
                'name': 'encoders'
            })
        if decoder_params:
            param_groups.append({
                'params': decoder_params,
                'lr': config.training.learning_rate,
                'name': 'decoders'
            })
        if other_params:
            param_groups.append({
                'params': other_params,
                'lr': config.training.learning_rate,
                'name': 'other'
            })
        
        # Use weight decay from sweep
        optimizer = torch.optim.AdamW(param_groups, weight_decay=config.training.weight_decay, fused=True)
        
        # Use scheduler from base config
        scheduler_type = base_config['scheduler_type']
        if scheduler_type == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config.training.num_epochs, eta_min=1e-6
            )
        elif scheduler_type == 'cosine_warmup':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=10, T_mult=2, eta_min=1e-6
            )
        elif scheduler_type == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=3, gamma=0.5
            )
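        elif scheduler_type == 'exponential':
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=0.9
            )
        else:
            # Fail fast instead of hitting a NameError at scheduler.step()
            raise ValueError(f"Unknown scheduler_type: {scheduler_type!r}")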
        
        # Loss function
        criterion = ComprehensiveLoss(config.loss)
        
        # Training loop
        best_val_mpjpe = float('inf')
        
        print(f"Starting training with LR={config.training.learning_rate:.2e}, WD={config.training.weight_decay:.2e}, Dropout={config.model.dropout:.2f}")
        
        for epoch in range(config.training.num_epochs):
            # Update loss epoch
            criterion.set_epoch(epoch)
            
            # Train
            model.train()
            train_loss = 0
            train_mpjpe = 0
            train_samples = 0
            
            for batch_idx, batch in enumerate(train_loader):
                # Convert BFloat16 images to Float32 for DINOv2
                if batch['image'].dtype == torch.bfloat16:
                    batch['image'] = batch['image'].float()
                
                optimizer.zero_grad()
                
                # Forward pass
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    outputs = model(batch)
                    losses = criterion(outputs, batch)
                    loss = losses['total'] if isinstance(losses, dict) else losses
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                
                batch_size = batch['image'].shape[0]
                train_loss += loss.item() * batch_size
                train_samples += batch_size
                
                # Calculate MPJPE (scaled by 1000 to report mm); None when joints are unavailable
                batch_mpjpe_mm = None
                if 'hand_joints' in outputs and 'hand_joints' in batch:
                    with torch.no_grad():
                        mpjpe = torch.norm(outputs['hand_joints'] - batch['hand_joints'], dim=-1).mean()
                        batch_mpjpe_mm = mpjpe.item() * 1000
                        train_mpjpe += batch_mpjpe_mm * batch_size
                
                # Log batch metrics
                if batch_idx % 20 == 0:
                    wandb.log({
                        'train/batch_loss': loss.item(),
                        'train/batch_mpjpe': batch_mpjpe_mm if batch_mpjpe_mm is not None else 0,
                        # LR of the first param group (the DINOv2 group when it exists)
                        'train/lr': optimizer.param_groups[0]['lr'],
                        'train/weight_decay': config.training.weight_decay,
                        'train/dropout': config.model.dropout,
                        'system/gpu_memory_gb': torch.cuda.memory_allocated() / 1e9
                    })
            
            # Average training metrics
            train_loss /= train_samples
            train_mpjpe /= train_samples
            
            # Validation
            model.eval()
            val_loss = 0
            val_mpjpe = 0
            val_samples = 0
            
            with torch.no_grad():
                for batch in val_loader:
                    if batch['image'].dtype == torch.bfloat16:
                        batch['image'] = batch['image'].float()
                    
                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                        outputs = model(batch)
                        losses = criterion(outputs, batch)
                        loss = losses['total'] if isinstance(losses, dict) else losses
                    
                    batch_size = batch['image'].shape[0]
                    val_samples += batch_size
                    val_loss += loss.item() * batch_size
                    
                    if 'hand_joints' in outputs and 'hand_joints' in batch:
                        mpjpe = torch.norm(outputs['hand_joints'] - batch['hand_joints'], dim=-1).mean()
                        val_mpjpe += mpjpe.item() * 1000 * batch_size
            
            # Average validation metrics
            val_loss /= val_samples
            val_mpjpe /= val_samples
            
            # Update best
            if val_mpjpe < best_val_mpjpe:
                best_val_mpjpe = val_mpjpe
            
            # Update scheduler
            scheduler.step()
            
            # Log epoch metrics
            wandb.log({
                'epoch': epoch,
                'train/loss': train_loss,
                'train/hand_mpjpe': train_mpjpe,
                'val/loss': val_loss,
                'val/hand_mpjpe': val_mpjpe,
                'val/best_mpjpe': best_val_mpjpe
            })
            
            print(f"Epoch {epoch+1}/{config.training.num_epochs}: "
                  f"train_loss={train_loss:.4f}, train_mpjpe={train_mpjpe:.2f}mm, "
                  f"val_loss={val_loss:.4f}, val_mpjpe={val_mpjpe:.2f}mm")
        
        # Log final metrics
        wandb.log({
            'final/best_val_mpjpe': best_val_mpjpe,
            'final/base_config_index': config_idx,
            'final/learning_rate': config.training.learning_rate,
            'final/weight_decay': config.training.weight_decay,
            'final/dropout': config.model.dropout
        })
        
        print(f"\nTraining completed! Best validation MPJPE: {best_val_mpjpe:.2f} mm")
        
        # Clean up
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"Training failed: {e}")
        import traceback
        traceback.print_exc()
        wandb.log({'error': str(e), 'val/hand_mpjpe': 1000.0})
        raise

print("Second stage training function defined!")

Second stage training function defined!
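
The per-group learning rates and their schedule can be tabulated without a GPU. The sketch below mirrors the name-based grouping order used in the training function and evaluates the standard cosine-annealing expression for the first cycle, using the `T_0=10` and `eta_min=1e-6` values from the `cosine_warmup` branch. `group_for`, `cosine_lr`, and the parameter names are hypothetical, and the base learning rate of 1e-4 is just an example value from the sweep range.

```python
import math

# LR multipliers applied to each parameter group in the training function.
GROUP_SCALE = {"dinov2": 0.01, "encoders": 0.5, "decoders": 1.0, "other": 1.0}

def group_for(param_name: str) -> str:
    """Assign a parameter to a group; 'dinov2' wins, then 'decoder', then 'encoder'."""
    if "dinov2" in param_name:
        return "dinov2"
    if "decoder" in param_name:
        return "decoders"
    if "encoder" in param_name:
        return "encoders"
    return "other"

def cosine_lr(base_lr: float, epoch: int, t_0: int = 10, eta_min: float = 1e-6) -> float:
    """Cosine-annealed LR after `epoch` scheduler steps, valid while epoch < t_0."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_0)) / 2

# Hypothetical parameter names, chosen only to exercise each branch.
for name in ["dinov2.encoder.layer.0.weight", "hand_decoder.fc.weight",
             "object_encoder.proj.bias", "fusion.weight"]:
    print(f"{name:32s} -> {group_for(name)}")

base_lr = 1e-4  # example value inside the sweep range [1e-5, 1e-3]
for epoch in range(7):
    lrs = {g: cosine_lr(base_lr * s, epoch) for g, s in GROUP_SCALE.items()}
    print(epoch, {g: f"{lr:.2e}" for g, lr in lrs.items()})
```

Under these assumptions a 7-epoch run never completes the 10-epoch cycle, so the decoder learning rate ends at roughly a third of its starting value rather than at `eta_min`, and no restart occurs. Note also that a DINOv2 group whose scaled learning rate starts below `eta_min` would be annealed upward toward `eta_min` by this expression.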


In [3]:
def train_second_stage():
    # Initialize wandb
    run = wandb.init()
    
    # Get base configuration from top 8
    config_idx = wandb.config.config_index
    base_config = top_8_configs[config_idx].copy()
    
    # Load default configuration
    config = OmegaConf.load('../configs/default_config.yaml')
    
    # Apply base configuration from first sweep
    config.training.batch_size = base_config['batch_size']
    config.data.augmentation.rotation_range = base_config['aug_rotation_range']
    config.data.augmentation.scale_range = [
        base_config['aug_scale_min'],
        base_config['aug_scale_max']
    ]
    config.data.augmentation.translation_std = base_config['aug_translation_std']
    config.data.augmentation.color_jitter = base_config['aug_color_jitter']
    config.data.augmentation.joint_noise_std = base_config['aug_joint_noise_std']
    
    # Apply loss weights from base config
    config.loss.loss_weights.hand_coarse = base_config['loss_weight_hand_coarse']
    config.loss.loss_weights.hand_refined = base_config['loss_weight_hand_refined']
    config.loss.loss_weights.object_position = base_config['loss_weight_object_position']
    config.loss.loss_weights.object_rotation = base_config['loss_weight_object_rotation']
    config.loss.loss_weights.contact = base_config['loss_weight_contact']
    config.loss.loss_weights.physics = base_config['loss_weight_physics']
    config.loss.loss_weights.diversity = base_config['loss_weight_diversity']
    config.loss.loss_weights.reprojection = base_config['loss_weight_reprojection']
    
    config.loss.diversity_margin = base_config['diversity_margin']
    config.loss.per_joint_weighting = base_config['per_joint_weighting']
    config.loss.fingertip_weight = base_config['fingertip_weight']
    
    # Override with second stage sweep parameters
    config.training.learning_rate = wandb.config.learning_rate
    config.training.weight_decay = wandb.config.weight_decay
    config.model.dropout = wandb.config.dropout
    
    # Fixed parameters
    config.training.num_epochs = 7  # Fixed at 7 epochs, matching the stated configuration
    config.training.use_wandb = True
    config.training.use_amp = True
    config.training.use_bf16 = True
    
    # Log which base config we're using
    wandb.log({
        'base_config_index': config_idx,
        'base_batch_size': base_config['batch_size'],
        'base_scheduler': base_config['scheduler_type']
    })
    
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Import components
    from models.unified_model import UnifiedManipulationTransformer
    from training.losses import ComprehensiveLoss
    from data.gpu_cached_dataset import create_gpu_cached_dataloaders
    from solutions.mode_collapse import ModeCollapsePreventionModule
    from optimizations.pytorch_native_optimization import optimize_for_h200
    
    try:
        # Create GPU-cached dataloaders
        print(f"Creating dataloaders with config {config_idx}: batch_size={config.training.batch_size}...")
        gpu_config = {
            'gpu_max_samples': 20000,
            'gpu_max_samples_val': 2000,
            'gpu_cache_path': './gpu_cache_sweep_second',
            'batch_size': config.training.batch_size,
            'use_bfloat16': config.training.use_bf16,
            'preload_dinov2': False
        }
        
        train_loader, val_loader = create_gpu_cached_dataloaders(gpu_config)
        
        # Create model
        model = UnifiedManipulationTransformer(config.model)
        
        # Apply mode collapse prevention
        mode_collapse_config = {
            'noise_std': 0.01,
            'drop_path_rate': 0.1,
            'mixup_alpha': 0.2
        }
        model = ModeCollapsePreventionModule.wrap_model(model, mode_collapse_config)
        
        # Initialize weights
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                if hasattr(module, 'weight') and module.weight is not None:
                    nn.init.xavier_uniform_(module.weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    nn.init.constant_(module.bias, 0.01)
        
        for name, module in model.named_modules():
            if 'dinov2' not in name or 'encoder.layer.' not in name:
                init_weights(module)
        
        # Apply optimizations
        model = optimize_for_h200(model, compile_mode='reduce-overhead')
        model = model.to(device)
        
        # Create optimizer with parameter groups
        dinov2_params = []
        encoder_params = []
        decoder_params = []
        other_params = []
        
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if 'dinov2' in name:
                dinov2_params.append(param)
            elif 'decoder' in name:
                decoder_params.append(param)
            elif 'encoder' in name:
                encoder_params.append(param)
            else:
                other_params.append(param)
        
        param_groups = []
        if dinov2_params:
            param_groups.append({
                'params': dinov2_params,
                'lr': config.training.learning_rate * 0.01,
                'name': 'dinov2'
            })
        if encoder_params:
            param_groups.append({
                'params': encoder_params,
                'lr': config.training.learning_rate * 0.5,
                'name': 'encoders'
            })
        if decoder_params:
            param_groups.append({
                'params': decoder_params,
                'lr': config.training.learning_rate,
                'name': 'decoders'
            })
        if other_params:
            param_groups.append({
                'params': other_params,
                'lr': config.training.learning_rate,
                'name': 'other'
            })
        
        # Use weight decay from sweep
        optimizer = torch.optim.AdamW(param_groups, weight_decay=config.training.weight_decay, fused=True)
        
        # Use scheduler from base config
        scheduler_type = base_config['scheduler_type']
        if scheduler_type == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config.training.num_epochs, eta_min=1e-6
            )
        elif scheduler_type == 'cosine_warmup':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=10, T_mult=2, eta_min=1e-6
            )
        elif scheduler_type == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=3, gamma=0.5
            )
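        elif scheduler_type == 'exponential':
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=0.9
            )
        else:
            # Fail fast instead of hitting a NameError at scheduler.step()
            raise ValueError(f"Unknown scheduler_type: {scheduler_type!r}")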
        
        # Loss function
        criterion = ComprehensiveLoss(config.loss)
        
        # Training loop
        best_val_mpjpe = float('inf')
        
        for epoch in range(config.training.num_epochs):
            # Update loss epoch
            criterion.set_epoch(epoch)
            
            # Train
            model.train()
            train_loss = 0
            train_mpjpe = 0
            train_samples = 0
            
            for batch_idx, batch in enumerate(train_loader):
                # Convert BFloat16 images to Float32 for DINOv2
                if batch['image'].dtype == torch.bfloat16:
                    batch['image'] = batch['image'].float()
                
                optimizer.zero_grad()
                
                # Forward pass
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    outputs = model(batch)
                    losses = criterion(outputs, batch)
                    loss = losses['total'] if isinstance(losses, dict) else losses
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                
                batch_size = batch['image'].shape[0]
                train_loss += loss.item() * batch_size
                train_samples += batch_size
                
                # Calculate MPJPE (scaled by 1000 to report mm); None when joints are unavailable
                batch_mpjpe_mm = None
                if 'hand_joints' in outputs and 'hand_joints' in batch:
                    with torch.no_grad():
                        mpjpe = torch.norm(outputs['hand_joints'] - batch['hand_joints'], dim=-1).mean()
                        batch_mpjpe_mm = mpjpe.item() * 1000
                        train_mpjpe += batch_mpjpe_mm * batch_size
                
                # Log batch metrics
                if batch_idx % 20 == 0:
                    wandb.log({
                        'train/batch_loss': loss.item(),
                        'train/batch_mpjpe': batch_mpjpe_mm if batch_mpjpe_mm is not None else 0,
                        # LR of the first param group (the DINOv2 group when it exists)
                        'train/lr': optimizer.param_groups[0]['lr'],
                        'train/weight_decay': config.training.weight_decay,
                        'train/dropout': config.model.dropout,
                        'system/gpu_memory_gb': torch.cuda.memory_allocated() / 1e9
                    })
            
            # Average training metrics
            train_loss /= train_samples
            train_mpjpe /= train_samples
            
            # Validation
            model.eval()
            val_loss = 0
            val_mpjpe = 0
            val_samples = 0
            
            with torch.no_grad():
                for batch in val_loader:
                    if batch['image'].dtype == torch.bfloat16:
                        batch['image'] = batch['image'].float()
                    
                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                        outputs = model(batch)
                        losses = criterion(outputs, batch)
                        loss = losses['total'] if isinstance(losses, dict) else losses
                    
                    batch_size = batch['image'].shape[0]
                    val_samples += batch_size
                    val_loss += loss.item() * batch_size
                    
                    if 'hand_joints' in outputs and 'hand_joints' in batch:
                        mpjpe = torch.norm(outputs['hand_joints'] - batch['hand_joints'], dim=-1).mean()
                        val_mpjpe += mpjpe.item() * 1000 * batch_size
            
            # Average validation metrics
            val_loss /= val_samples
            val_mpjpe /= val_samples
            
            # Update best
            if val_mpjpe < best_val_mpjpe:
                best_val_mpjpe = val_mpjpe
            
            # Update scheduler
            scheduler.step()
            
            # Log epoch metrics
            wandb.log({
                'epoch': epoch,
                'train/loss': train_loss,
                'train/hand_mpjpe': train_mpjpe,
                'val/loss': val_loss,
                'val/hand_mpjpe': val_mpjpe,
                'val/best_mpjpe': best_val_mpjpe
            })
            
            print(f"Epoch {epoch+1}/{config.training.num_epochs}: "
                  f"train_loss={train_loss:.4f}, train_mpjpe={train_mpjpe:.2f}mm, "
                  f"val_loss={val_loss:.4f}, val_mpjpe={val_mpjpe:.2f}mm")
        
        # Log final metrics
        wandb.log({
            'final/best_val_mpjpe': best_val_mpjpe,
            'final/base_config_index': config_idx,
            'final/learning_rate': config.training.learning_rate,
            'final/weight_decay': config.training.weight_decay,
            'final/dropout': config.model.dropout
        })
        
        # Clean up
        torch.cuda.empty_cache()
        
    except Exception as e:
        print(f"Training failed: {e}")
        wandb.log({'error': str(e), 'val/hand_mpjpe': 1000.0})
        raise

print("Second stage training function defined!")

Second stage training function defined!
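
For reference, the metric bookkeeping in the training loop can be reproduced in plain Python. This sketch assumes joint coordinates are in metres (which the factor of 1000 and the "mm" labels suggest) and uses hypothetical helper names. It computes a batch MPJPE and then the sample-weighted epoch average that the loop accumulates via `value * batch_size`.

```python
import math

Joint = tuple[float, float, float]

def batch_mpjpe_mm(pred: list[list[Joint]], target: list[list[Joint]]) -> float:
    """Mean Euclidean error over all joints of all samples, converted to mm."""
    errors = [
        math.dist(p, t)
        for pred_joints, target_joints in zip(pred, target)
        for p, t in zip(pred_joints, target_joints)
    ]
    return 1000.0 * sum(errors) / len(errors)

def epoch_average(batch_values: list[float], batch_sizes: list[int]) -> float:
    """Weight each batch mean by its size so a short final batch is not over-counted."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

# One sample with two joints: a 3-4-5 mm offset and an exact match.
pred = [[(0.003, 0.004, 0.0), (0.1, 0.2, 0.3)]]
target = [[(0.0, 0.0, 0.0), (0.1, 0.2, 0.3)]]
mpjpe = batch_mpjpe_mm(pred, target)
print(f"batch MPJPE: {mpjpe:.2f} mm")

# A full batch of 32 and a final batch of 8 with different mean errors.
weighted = epoch_average([10.0, 20.0], [32, 8])
print(f"epoch MPJPE: {weighted:.2f} mm")
```

A plain mean of the two batch values would give 15 mm, whereas the weighted average reflects that most samples sit in the first batch.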


## 5. Initialize and Run Second Stage Sweep

In [4]:
# Initialize second stage sweep
project_name = 'amt-second-stage-sweep'
sweep_id = wandb.sweep(sweep_config_second, project=project_name)

print(f"Second stage sweep initialized!")
print(f"Sweep ID: {sweep_id}")
print(f"View at: https://wandb.ai/{wandb.api.default_entity}/{project_name}/sweeps/{sweep_id}")

In [None]:
# Run second stage sweep
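# Total number of runs for this agent. The Bayesian search picks config_index itself,
# so runs are not split evenly across the 8 base configurations.
sweep_count = 8  # raise to 40 for about 5 runs per base config on average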

print(f"Starting second stage sweep agent to run {sweep_count} experiments...")
print(f"Testing {len(top_8_configs)} base configurations with fine-tuning")
print("Each experiment will run for 7 epochs with 20k training samples")
print("This will take approximately 10-15 minutes per experiment")

wandb.agent(sweep_id, function=train_second_stage, count=sweep_count)

## 6. Analyze Second Stage Results

In [None]:
# Analyze results from second stage
api = wandb.Api()
sweep = api.sweep(f"{wandb.api.default_entity}/{project_name}/sweeps/{sweep_id}")

# Get best run
best_run = sweep.best_run()
print(f"Best run: {best_run.name}")
print(f"Best validation MPJPE: {best_run.summary['val/hand_mpjpe']:.2f} mm")
print(f"Base config index: {best_run.config['config_index']}")
print("\nOptimized parameters:")
print(f"  Learning rate: {best_run.config['learning_rate']:.2e}")
print(f"  Weight decay: {best_run.config['weight_decay']:.2e}")
print(f"  Dropout: {best_run.config['dropout']:.3f}")
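
One caveat when reading `val/hand_mpjpe` from a run summary: assuming the default W&B behaviour in which a summary key holds the most recently logged value, it reflects the final epoch rather than the best one, while the training function tracks the best epoch separately under `val/best_mpjpe`. The sketch below uses hypothetical per-epoch histories to show how the two rankings can disagree.

```python
# Hypothetical per-epoch validation MPJPE (mm) for two runs.
histories = {
    "run-a": [40.0, 30.0, 33.0],  # best at epoch 2, slightly worse at the end
    "run-b": [45.0, 34.0, 31.0],  # still improving at the last epoch
}

def summarize(history: list[float]) -> dict[str, float]:
    """Return the last logged value and the best (lowest) value of a metric."""
    return {"last": history[-1], "best": min(history)}

summaries = {name: summarize(h) for name, h in histories.items()}
winner_by_last = min(summaries, key=lambda n: summaries[n]["last"])
winner_by_best = min(summaries, key=lambda n: summaries[n]["best"])

print(f"ranked by last value: {winner_by_last}")
print(f"ranked by best value: {winner_by_best}")
```

If the two rankings differ in practice, comparing runs on `val/best_mpjpe` or `final/best_val_mpjpe` may be the more stable choice for short 7-epoch runs.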

In [None]:
# Visualize results by base configuration
import matplotlib.pyplot as plt
import pandas as pd

# Get all runs
runs = sweep.runs
data = []

for run in runs:
    if 'val/hand_mpjpe' in run.summary:
        config = dict(run.config)
        config['val_mpjpe'] = run.summary['val/hand_mpjpe']
        data.append(config)

df = pd.DataFrame(data)

# Plot results by base configuration
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

# 1. Performance by base config
ax = axes[0]
config_performance = df.groupby('config_index')['val_mpjpe'].agg(['mean', 'min', 'std', 'count'])
x = config_performance.index
ax.bar(x, config_performance['mean'], yerr=config_performance['std'], capsize=5)
ax.scatter(x, config_performance['min'], color='red', s=100, marker='*', label='Best')
ax.set_xlabel('Base Configuration Index')
ax.set_ylabel('Validation MPJPE (mm)')
ax.set_title('Performance by Base Configuration')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Learning rate vs performance
ax = axes[1]
ax.scatter(df['learning_rate'], df['val_mpjpe'], alpha=0.6, c=df['config_index'], cmap='tab10')
ax.set_xlabel('Learning Rate')
ax.set_ylabel('Validation MPJPE (mm)')
ax.set_title('Learning Rate vs Performance')
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

# 3. Weight decay vs performance
ax = axes[2]
ax.scatter(df['weight_decay'], df['val_mpjpe'], alpha=0.6, c=df['config_index'], cmap='tab10')
ax.set_xlabel('Weight Decay')
ax.set_ylabel('Validation MPJPE (mm)')
ax.set_title('Weight Decay vs Performance')
ax.set_xscale('log')
ax.grid(True, alpha=0.3)

# 4. Dropout vs performance
ax = axes[3]
ax.scatter(df['dropout'], df['val_mpjpe'], alpha=0.6, c=df['config_index'], cmap='tab10')
ax.set_xlabel('Dropout')
ax.set_ylabel('Validation MPJPE (mm)')
ax.set_title('Dropout vs Performance')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle('Second Stage Fine-tuning Results', y=1.02)
plt.show()

# Print summary statistics
print("\nSummary Statistics:")
print(f"Total runs: {len(df)}")
print(f"Best MPJPE: {df['val_mpjpe'].min():.2f} mm")
print(f"Mean MPJPE: {df['val_mpjpe'].mean():.2f} mm")
print(f"Std MPJPE: {df['val_mpjpe'].std():.2f} mm")

# Best parameters for each base config
print("\nBest parameters for each base configuration:")
for config_idx in sorted(df['config_index'].unique()):
    config_df = df[df['config_index'] == config_idx]
    if len(config_df) > 0:
        best_idx = config_df['val_mpjpe'].idxmin()
        best_row = config_df.loc[best_idx]
        print(f"\nConfig {config_idx}: {best_row['val_mpjpe']:.2f} mm")
        print(f"  LR: {best_row['learning_rate']:.2e}, WD: {best_row['weight_decay']:.2e}, Dropout: {best_row['dropout']:.3f}")
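
One caveat for the statistics above: the training function's error handler logs `val/hand_mpjpe` as 1000.0 for a failed run, so crashed runs would inflate the per-configuration mean and standard deviation. A minimal sketch of filtering them out before building the DataFrame follows. It operates on a list of dicts like `data`; the helper name `drop_failed_runs` and the example records are made up for illustration.

```python
FAILED_RUN_SENTINEL = 1000.0  # value logged by the training function's error handler

def drop_failed_runs(records, metric_key='val_mpjpe', sentinel=FAILED_RUN_SENTINEL):
    """Return only the records whose metric is below the failure sentinel."""
    return [r for r in records if r[metric_key] < sentinel]

# Made-up records for illustration only
example_records = [
    {'config_index': 0, 'val_mpjpe': 21.5},
    {'config_index': 0, 'val_mpjpe': 1000.0},  # crashed run
    {'config_index': 1, 'val_mpjpe': 19.8},
]

clean_records = drop_failed_runs(example_records)
print(f"kept {len(clean_records)} of {len(example_records)} runs")
```

In the analysis cell, the same call could be applied to `data` just before `pd.DataFrame(data)`.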

## 7. Export Final Optimized Configuration

In [None]:
# Export the final best configuration
best_config = dict(best_run.config)
best_base_config = top_8_configs[best_config['config_index']]

# Create final optimized configuration
final_config = OmegaConf.load('../configs/default_config.yaml')

# Apply all parameters from best configuration
# From base config
final_config.training.batch_size = best_base_config['batch_size']
final_config.data.augmentation.rotation_range = best_base_config['aug_rotation_range']
final_config.data.augmentation.scale_range = [
    best_base_config['aug_scale_min'],
    best_base_config['aug_scale_max']
]
final_config.data.augmentation.translation_std = best_base_config['aug_translation_std']
final_config.data.augmentation.color_jitter = best_base_config['aug_color_jitter']
final_config.data.augmentation.joint_noise_std = best_base_config['aug_joint_noise_std']

# Loss weights
for key in best_base_config:
    if key.startswith('loss_weight_'):
        loss_key = key.replace('loss_weight_', '')
        final_config.loss.loss_weights[loss_key] = best_base_config[key]

# Loss configuration
final_config.loss.diversity_margin = best_base_config['diversity_margin']
final_config.loss.per_joint_weighting = best_base_config['per_joint_weighting']
final_config.loss.fingertip_weight = best_base_config['fingertip_weight']

# From second stage optimization
final_config.training.learning_rate = best_config['learning_rate']
final_config.training.weight_decay = best_config['weight_decay']
final_config.model.dropout = best_config['dropout']

# Save final configuration
OmegaConf.save(final_config, '../configs/final_optimized_config.yaml')

print("Final optimized configuration saved to configs/final_optimized_config.yaml")
print("\nFinal optimized parameters:")
print(f"  Learning rate: {final_config.training.learning_rate:.2e}")
print(f"  Weight decay: {final_config.training.weight_decay:.2e}")
print(f"  Dropout: {final_config.model.dropout:.3f}")
print(f"  Batch size: {final_config.training.batch_size}")
print(f"  Scheduler: {best_base_config['scheduler_type']}")
print("\nYou can now use this configuration for final training:")
print("python train_advanced.py --config configs/final_optimized_config.yaml")
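
The loss-weight loop in the export cell above maps flat sweep keys of the form `loss_weight_<name>` onto nested config entries. The same mapping on plain dictionaries can be sketched as follows. The helper name `extract_loss_weights` and the example keys are hypothetical; the real key names come from the first-stage sweep.

```python
def extract_loss_weights(flat_config, prefix='loss_weight_'):
    """Collect entries whose key starts with `prefix`, with the prefix stripped."""
    return {
        key[len(prefix):]: value
        for key, value in flat_config.items()
        if key.startswith(prefix)
    }

# Hypothetical flat configuration, for illustration only
example_base_config = {
    'batch_size': 32,
    'loss_weight_example_a': 1.0,
    'loss_weight_example_b': 0.01,
}

print(extract_loss_weights(example_base_config))
```

Slicing off the prefix rather than calling `str.replace` avoids touching a second occurrence of the prefix elsewhere in the key.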

## Summary

This notebook implemented a second stage W&B sweep for fine-tuning the Advanced Manipulation Transformer:

**Stage 1** (Previous notebook):
- Broad search across all hyperparameters
- Identified top 8 configurations

**Stage 2** (This notebook):
- Started from top 8 configurations
- Fine-tuned only critical parameters:
  - Learning rate: log-uniform [1e-5, 1e-3]
  - Weight decay: log-uniform [1e-5, 1e-2]
  - Dropout: uniform [0.05, 0.4]
- Ran 40 experiments (5 per base configuration)
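
To make the log-uniform ranges above concrete: a log-uniform draw samples the logarithm of the value uniformly, so each decade of the range is equally likely. The sketch below uses only the standard library and is an illustration of the distribution, not W&B's own sampler; `sample_log_uniform` is a hypothetical helper.

```python
import math
import random

def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

rng = random.Random(0)  # fixed seed so the illustration is repeatable
learning_rates = [sample_log_uniform(1e-5, 1e-3, rng) for _ in range(1000)]

# About half of the draws should fall below the geometric midpoint 1e-4
fraction_below = sum(lr < 1e-4 for lr in learning_rates) / len(learning_rates)
print(f"fraction below 1e-4: {fraction_below:.2f}")
```

Under a plain uniform distribution on the same interval, only about 9% of draws would fall below 1e-4, which is why log-uniform sampling is the usual choice for learning rate and weight decay.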

**Results** (available once the sweep and analysis cells have been run):
- Best validation MPJPE for each base configuration after fine-tuning
- Best learning rate, weight decay, and dropout found for each base configuration
- Final optimized configuration, exported to `configs/final_optimized_config.yaml`, combining the best base config with its fine-tuned parameters

The two-stage approach is intended to allow:
1. Efficient exploration of the large hyperparameter space
2. Focused optimization of the most critical parameters
3. Potentially better final performance than a single-stage search