# Loading Checkpoint Files - Complete Guide

This notebook shows all the different ways to load saved checkpoint files in the Advanced Manipulation Transformer project.

## Method 1: Load Checkpoint for Evaluation Only

Use this when you just want to evaluate a trained model without continuing training.

In [None]:
import torch
import sys
import os

# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath('./'))))

from Advanced_Manipulation_Transformer.models.unified_model import UnifiedManipulationTransformer

# Method 1: Load checkpoint for evaluation
def load_model_for_evaluation(checkpoint_path, device='cuda'):
    """
    Build a model from a checkpoint file and switch it to evaluation mode.

    Args:
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        device: Used both as ``map_location`` while reading the file and as
            the device the model is moved to afterwards.

    Returns:
        Tuple ``(model, checkpoint)``: the model in eval mode and the object
        returned by ``torch.load``.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # The model is configured from the 'model' section of the stored config;
    # an empty dict is used when either level is missing.
    model_config = checkpoint.get('config', {}).get('model', {})
    model = UnifiedManipulationTransformer(model_config)

    # Regular weights: the first known key wins, otherwise the checkpoint
    # object itself is treated as the state dict.
    weights = checkpoint
    for key in ('model_state_dict', 'state_dict'):
        if key in checkpoint:
            weights = checkpoint[key]
            break
    model.load_state_dict(weights)

    # EMA weights, when stored, are loaded on top of the regular weights.
    if 'ema_state_dict' in checkpoint:
        print("Loading EMA weights for evaluation")
        model.load_state_dict(checkpoint['ema_state_dict'])

    model = model.to(device)
    model.eval()

    # Report training progress recorded in the checkpoint, if any.
    summary_fields = (
        ('epoch', "Loaded checkpoint from epoch"),
        ('metrics', "Checkpoint metrics"),
    )
    for key, label in summary_fields:
        if key in checkpoint:
            print(f"{label}: {checkpoint[key]}")

    return model, checkpoint

# Example usage: point `checkpoint_path` at an existing checkpoint file, then
# uncomment the call below to load it.
checkpoint_path = 'checkpoints/best.pth'  # Update with your path
# model, checkpoint_info = load_model_for_evaluation(checkpoint_path)

## Method 2: Resume Training from Checkpoint

Use this when you want to continue training from where you left off.

In [None]:
from Advanced_Manipulation_Transformer.training.trainer import ManipulationTrainer

def resume_training_from_checkpoint(checkpoint_path, config):
    """
    Build a model and a trainer from ``config`` and restore a checkpoint.

    Args:
        checkpoint_path: Path handed to ``trainer.load_checkpoint``.
        config: Mapping with a 'model' entry (passed to the model
            constructor) and a 'training' entry (passed to the trainer).

    Returns:
        The ``ManipulationTrainer`` after ``load_checkpoint`` has run.
    """
    trainer = ManipulationTrainer(
        model=UnifiedManipulationTransformer(config['model']),
        config=config['training'],
        device='cuda',
    )

    # Restoring is delegated entirely to the trainer; presumably this covers
    # model, optimizer and scheduler state - verify in the trainer.
    trainer.load_checkpoint(checkpoint_path)
    return trainer

# Example usage
# config = {...}  # Your training config
# trainer = resume_training_from_checkpoint('checkpoints/latest.pth', config)

## Method 3: Command Line - Resume Training

The easiest way to resume training is using the command line with Hydra config overrides:

In [None]:
# Command line examples for resuming training.
# This cell only prints the commands as text; nothing is executed here.
# The doubled backslashes in the literal print as single `\` characters,
# i.e. shell line continuations in the last example.

print("""# Resume from latest checkpoint:
python train_advanced.py checkpoint.resume_from=outputs/experiment_name/checkpoints/latest.pth

# Resume from best checkpoint:
python train_advanced.py checkpoint.resume_from=outputs/experiment_name/checkpoints/best.pth

# Resume from specific epoch:
python train_advanced.py checkpoint.resume_from=outputs/experiment_name/checkpoints/epoch_20.pth

# Resume with different settings:
python train_advanced.py \\
    checkpoint.resume_from=checkpoints/best.pth \\
    training.learning_rate=5e-4 \\
    training.batch_size=64
""")

## Method 4: Load Specific Components

Sometimes you only want to load certain parts of the checkpoint.

In [None]:
def load_partial_checkpoint(checkpoint_path, model=None, load_components=('model',)):
    """
    Load only specific components from a checkpoint.

    The checkpoint is read onto CUDA when it is available and onto the CPU
    otherwise, so the helper also works on machines without a GPU.

    Args:
        checkpoint_path: Path to checkpoint file
        model: Model instance that receives the weights. When None, a
            ``UnifiedManipulationTransformer`` is built from the 'model'
            section of the checkpoint's config.
        load_components: Iterable of components to load (default: only
            'model'):
            - 'model': Model weights
            - 'ema': EMA model weights
            - 'optimizer': Optimizer state
            - 'scheduler': Scheduler state

    Returns:
        Dict that always contains 'epoch', 'global_step' and 'metrics'
        (defaulting to 0, 0 and {}), plus, for every requested component that
        is present in the checkpoint: 'model' (the model with weights
        loaded), 'ema_weights', 'optimizer_state', 'scheduler_state'.
        Requested components missing from the checkpoint are skipped.
    """
    # A hard-coded 'cuda' would make torch.load fail on CPU-only machines.
    map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    loaded = {}

    # Model weights: only build/touch a model when there is something to load.
    if 'model' in load_components and 'model_state_dict' in checkpoint:
        if model is None:
            config = checkpoint.get('config', {})
            model = UnifiedManipulationTransformer(config.get('model', {}))
        model.load_state_dict(checkpoint['model_state_dict'])
        loaded['model'] = model

    # Raw state entries are handed back untouched for the caller to apply.
    passthrough = (
        ('ema', 'ema_state_dict', 'ema_weights'),
        ('optimizer', 'optimizer_state_dict', 'optimizer_state'),
        ('scheduler', 'scheduler_state_dict', 'scheduler_state'),
    )
    for component, source_key, result_key in passthrough:
        if component in load_components and source_key in checkpoint:
            loaded[result_key] = checkpoint[source_key]

    # Additional info
    loaded['epoch'] = checkpoint.get('epoch', 0)
    loaded['global_step'] = checkpoint.get('global_step', 0)
    loaded['metrics'] = checkpoint.get('metrics', {})

    return loaded

# Example: Load only model weights
# components = load_partial_checkpoint('checkpoints/best.pth', load_components=['model'])

## Method 5: Load for Fine-tuning

Use this when you want to fine-tune a pre-trained model on new data.

In [None]:
def load_for_finetuning(checkpoint_path, freeze_backbone=True, freeze_layers=12):
    """
    Load a model for fine-tuning, optionally freezing backbone parameters.

    The model is loaded through ``load_model_for_evaluation`` (with its
    default device) and then switched back to training mode.

    Args:
        checkpoint_path: Path to the checkpoint file.
        freeze_backbone: When True, parameters whose name contains
            'image_encoder.dinov2' are candidates for freezing.
        freeze_layers: Backbone parameters inside ``blocks.<i>`` are frozen
            when ``i < freeze_layers``. Backbone parameters that do not sit
            in a numbered block are always frozen.

    Returns:
        The model in training mode, with ``requires_grad`` disabled on the
        frozen parameters.
    """
    model, _ = load_model_for_evaluation(checkpoint_path)

    # Set to training mode
    model.train()

    if freeze_backbone:
        frozen_params = 0
        for name, param in model.named_parameters():
            if 'image_encoder.dinov2' not in name:
                continue

            # Extract the block index from names like '...blocks.<i>....'
            layer_num = None
            if 'blocks.' in name:
                try:
                    layer_num = int(name.split('blocks.')[1].split('.')[0])
                except ValueError:
                    # Non-numeric segment after 'blocks.': handled like a
                    # backbone parameter outside any numbered block.
                    layer_num = None

            # Freeze if it's in the first N layers (or not in a block at all)
            if layer_num is None or layer_num < freeze_layers:
                param.requires_grad = False
                frozen_params += param.numel()

        print(f"Froze {frozen_params:,} parameters in backbone")

    # Count trainable parameters
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    # Guard the percentage against a model without parameters.
    percent = 100 * trainable / total if total else 0.0
    print(f"Trainable parameters: {trainable:,} / {total:,} ({percent:.1f}%)")

    return model

# Example: Load and freeze first 12 layers
# model = load_for_finetuning('checkpoints/best.pth', freeze_backbone=True, freeze_layers=12)

## Method 6: Load and Convert Checkpoints

Sometimes you need to convert between different checkpoint formats.

In [None]:
def convert_checkpoint_format(input_path, output_path, target_format='standard'):
    """
    Convert between different checkpoint formats.

    Formats:
        - 'standard': the loaded checkpoint is saved again unchanged.
        - 'minimal': only 'model_state_dict' and 'config' are kept.
        - 'inference': like 'minimal', but 'model_state_dict' holds the EMA
          weights when the input has an 'ema_state_dict'.

    When the input has no 'model_state_dict' key, the whole loaded object is
    used as the state dict. A missing 'config' becomes an empty dict.

    Args:
        input_path: Path to input checkpoint
        output_path: Path to save converted checkpoint
        target_format: 'standard', 'minimal', or 'inference'

    Raises:
        ValueError: If ``target_format`` is not one of the supported names.
            Raised before anything is read or written.
    """
    supported_formats = ('standard', 'minimal', 'inference')
    if target_format not in supported_formats:
        # Previously an unknown format wrote nothing but still reported success.
        raise ValueError(
            f"Unknown target_format {target_format!r}; "
            f"expected one of {supported_formats}"
        )

    # Load original checkpoint
    checkpoint = torch.load(input_path, map_location='cpu')

    if target_format == 'standard':
        converted = checkpoint
    else:
        # Use EMA weights for inference if available, otherwise regular weights
        if target_format == 'inference' and 'ema_state_dict' in checkpoint:
            weights = checkpoint['ema_state_dict']
        else:
            weights = checkpoint.get('model_state_dict', checkpoint)
        converted = {
            'model_state_dict': weights,
            'config': checkpoint.get('config', {}),
        }

    torch.save(converted, output_path)

    print(f"Converted checkpoint saved to: {output_path}")
    print(f"Original size: {os.path.getsize(input_path) / 1e6:.1f} MB")
    print(f"New size: {os.path.getsize(output_path) / 1e6:.1f} MB")

# Example: Convert to minimal checkpoint for sharing
# convert_checkpoint_format('checkpoints/best.pth', 'checkpoints/best_minimal.pth', 'minimal')

## Method 7: Inspect Checkpoint Contents

Before building a model from a checkpoint, you might want to inspect what the file contains.

In [None]:
def inspect_checkpoint(checkpoint_path):
    """
    Print a human-readable summary of a checkpoint file.

    The whole file is deserialized onto the CPU with ``torch.load``; no model
    object is constructed. Sections are printed only for keys that exist:
    'epoch', 'global_step', 'metrics', 'model_state_dict' (tensor count,
    total element count, first five entries) and 'config' (up to five keys
    of its 'model', 'training' and 'data' sections).

    Args:
        checkpoint_path: Path of the checkpoint file to summarize.
    """
    ckpt = torch.load(checkpoint_path, map_location='cpu')

    # Header and top-level keys
    print(f"Checkpoint: {checkpoint_path}")
    print("="*50)
    print(f"\nCheckpoint keys: {list(ckpt.keys())}")

    # Training progress
    if 'epoch' in ckpt:
        print(f"\nEpoch: {ckpt['epoch']}")
    if 'global_step' in ckpt:
        print(f"Global step: {ckpt['global_step']}")

    # Recorded metrics, one per line
    if 'metrics' in ckpt:
        print("\nMetrics:")
        for metric_name, metric_value in ckpt['metrics'].items():
            print(f"  {metric_name}: {metric_value}")

    # Weight statistics
    if 'model_state_dict' in ckpt:
        weights = ckpt['model_state_dict']
        print(f"\nModel parameters: {len(weights)} tensors")

        element_count = 0
        for tensor in weights.values():
            element_count += tensor.numel()
        print(f"Total parameters: {element_count:,}")

        print("\nFirst 5 layers:")
        for layer_name, tensor in list(weights.items())[:5]:
            print(f"  {layer_name}: {tensor.shape}")

    # Config overview; the heading is printed even for non-dict configs
    if 'config' not in ckpt:
        return
    print("\nConfig summary:")
    config = ckpt['config']
    if not isinstance(config, dict):
        return
    for section in ('model', 'training', 'data'):
        if section in config:
            print(f"  {section}: {list(config[section].keys())[:5]}...")

# Example usage
# inspect_checkpoint('checkpoints/best.pth')

## Common Checkpoint Locations

The Advanced Manipulation Transformer saves checkpoints in these locations by default:

In [None]:
# Default checkpoint locations.
# This cell prints a static reference list only; it does not read the
# filesystem, so the paths shown are not checked for existence.
print("""Common checkpoint locations:

1. Latest checkpoint (auto-saved):
   outputs/<experiment_name>/checkpoints/latest.pth

2. Best checkpoint (lowest validation loss):
   outputs/<experiment_name>/checkpoints/best.pth

3. Periodic checkpoints:
   outputs/<experiment_name>/checkpoints/epoch_10.pth
   outputs/<experiment_name>/checkpoints/epoch_20.pth
   ...

4. Custom checkpoints directory:
   Set with: checkpoint.checkpoint_dir=path/to/dir

5. Pre-trained models (if provided):
   pretrained/dinov2_hand_pose.pth
   pretrained/amt_dexycb.pth
""")

# List available checkpoints
import glob

def list_available_checkpoints(checkpoint_dir='outputs'):
    """
    Print every ``.pth`` file located in a ``checkpoints`` folder under
    ``checkpoint_dir``, sorted by path and annotated with its size in MB.

    Prints a "No checkpoints found" line when nothing matches.
    """
    pattern = "{}/**/checkpoints/*.pth".format(checkpoint_dir)
    found = glob.glob(pattern, recursive=True)

    if not found:
        print(f"No checkpoints found in {checkpoint_dir}")
        return

    print(f"Found {len(found)} checkpoints:")
    for path in sorted(found):
        megabytes = os.path.getsize(path) / 1e6
        print(f"  {path} ({megabytes:.1f} MB)")

# Example: List all checkpoints
# list_available_checkpoints()

## Complete Example: Load and Evaluate

Here's a complete example that loads a checkpoint and runs evaluation:

In [None]:
def load_and_evaluate(checkpoint_path, test_loader=None):
    """
    Complete example: load a checkpoint and run the model on one batch.

    Args:
        checkpoint_path: Path passed to ``load_model_for_evaluation``.
        test_loader: Iterable of batches fed to the model as-is. When None,
            a single random dummy batch of size 4 is created on CUDA.

    Returns:
        Tuple ``(model, all_predictions)``. ``all_predictions`` is a list
        with one dict of NumPy arrays for the first batch ('hand_joints',
        'object_positions', 'object_rotations' and, when the model outputs
        it, 'contact_points'). The list is empty when ``test_loader`` yields
        no batches.
    """
    # Load model
    print(f"Loading checkpoint from: {checkpoint_path}")
    model, _ = load_model_for_evaluation(checkpoint_path)

    # Create dummy test data if no loader provided
    if test_loader is None:
        print("\nCreating dummy test data...")
        batch_size = 4
        dummy_batch = {
            'color': torch.randn(batch_size, 3, 224, 224).cuda(),
            'hand_joints_3d': torch.randn(batch_size, 21, 3).cuda(),
            'object_pose': torch.randn(batch_size, 3, 4).cuda(),
            'camera_intrinsics': torch.eye(3).unsqueeze(0).repeat(batch_size, 1, 1).cuda()
        }
        test_loader = [dummy_batch]  # Single batch for testing

    # Run evaluation
    print("\nRunning evaluation...")
    model.eval()

    required_outputs = ('hand_joints', 'object_positions', 'object_rotations')
    all_predictions = []
    with torch.no_grad():
        for batch in test_loader:
            # Forward pass
            outputs = model(batch)

            # Extract predictions as NumPy arrays on the host
            predictions = {
                key: outputs[key].cpu().numpy() for key in required_outputs
            }
            if 'contact_points' in outputs:
                predictions['contact_points'] = outputs['contact_points'].cpu().numpy()

            all_predictions.append(predictions)
            break  # Just one batch for demo

    print("\nPrediction shapes:")
    if all_predictions:
        for key, value in all_predictions[-1].items():
            print(f"  {key}: {value.shape}")
    else:
        # An empty loader used to crash here on an unbound local variable.
        print("  (no batches were evaluated)")

    return model, all_predictions

# Example usage
# model, predictions = load_and_evaluate('checkpoints/best.pth')

## Troubleshooting Common Issues

In [None]:
def safe_load_checkpoint(checkpoint_path, device='cuda'):
    """
    Load a checkpoint with ``torch.load`` and print a diagnosis on failure.

    Recoverable situations are handled by loading onto the CPU instead:
    a CUDA device was requested but CUDA is not available, or the load ran
    out of CUDA memory. Every other error is reported and re-raised.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: ``map_location`` for ``torch.load``.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        FileNotFoundError: If the file does not exist (after printing the
            checkpoints found by ``list_available_checkpoints``).
        RuntimeError: For load errors other than CUDA out-of-memory.
        Exception: Any other error raised by ``torch.load``.
    """
    # Mapping to CUDA on a machine without it fails inside torch.load, so
    # switch to the CPU up front instead of failing.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("✗ CUDA not available - loading on CPU")
        device = 'cpu'

    try:
        # Try loading normally
        checkpoint = torch.load(checkpoint_path, map_location=device)

    except RuntimeError as e:
        message = str(e)
        if "CUDA out of memory" in message:
            print("✗ CUDA OOM - trying CPU load")
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            print("✓ Loaded on CPU")
            return checkpoint

        if "Missing key(s)" in message:
            # NOTE(review): this message looks like a load_state_dict error
            # rather than a torch.load one - confirm this branch is reachable.
            print("✗ Model architecture mismatch")
            print("  Try loading with strict=False:")
            print("  model.load_state_dict(checkpoint['model_state_dict'], strict=False)")
        else:
            print(f"✗ Error loading checkpoint: {e}")
        raise

    except FileNotFoundError:
        print(f"✗ Checkpoint not found: {checkpoint_path}")
        print("  Available checkpoints:")
        list_available_checkpoints()
        raise

    except Exception as e:
        print(f"✗ Unexpected error: {type(e).__name__}: {e}")
        raise

    print("✓ Checkpoint loaded successfully")
    return checkpoint

# Example with error handling
# checkpoint = safe_load_checkpoint('checkpoints/best.pth')