# GPU-Cached Training Example

This notebook demonstrates how to:
1. Disable torch.compile to avoid graph break warnings
2. Use GPU-cached datasets to keep 100GB+ of training data resident in GPU memory
3. Maximize training performance on H200 GPU

## 1. Setup and Disable torch.compile

In [None]:
# Essential imports
import os
import sys
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Add the parent directory to the import path (presumably where the project-local
# modules imported by this notebook live). Guarded so that re-running this cell
# does not append a duplicate entry each time.
if '..' not in sys.path:
    sys.path.append('..')

# IMPORTANT: Disable torch.compile for debugging
from disable_compile_for_debug import disable_torch_compile
disable_torch_compile()
print("✓ Debugging mode enabled - no more graph break warnings!")

# Dataset location: respect a DEX_YCB_DIR that is already exported in the
# environment and only fall back to the original author's path when it is unset.
os.environ.setdefault('DEX_YCB_DIR', '/home/n231/231nProjectV2/dex-ycb-toolkit/data')

# Device setup: prefer CUDA, fall back to CPU when no GPU is visible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# GPU info (only queried when CUDA is actually available)
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Total memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Configuration

In [None]:
# All tunable settings for the GPU-cached training run live in this one mapping.
config = dict(
    # --- GPU caching ---
    gpu_max_samples=100000,         # training samples to cache (author's estimate: ~100GB)
    gpu_max_samples_val=10000,      # validation samples to cache (author's estimate: ~10GB)
    gpu_cache_path='./gpu_cache',   # cache directory
    use_bfloat16=True,              # store data / run the model in bfloat16 to fit more data
    preload_dinov2=False,           # pre-extract DINOv2 features

    # --- Model ---
    hidden_dim=512,                 # model hidden dimension
    num_heads=16,
    num_layers=8,
    dropout=0.1,
    num_refinement_steps=2,

    # --- Training ---
    batch_size=256,                 # starting value; replaced later by the batch-size search
    num_epochs=20,
    learning_rate=1e-4,
    weight_decay=0.01,

    # --- Output ---
    output_dir='outputs/gpu_cached_run',
)

# Make sure the output directory exists before anything tries to write into it
os.makedirs(config['output_dir'], exist_ok=True)

## 3. Create GPU-Cached Datasets

In [None]:
# Import GPU-cached dataset
from data.gpu_cached_dataset import GPUCachedDataset, GPUDataLoader

print("Creating GPU-cached datasets...")
print(f"Target memory usage: {config['gpu_max_samples'] / 1000:.0f} GB")
print("Note: First run will be slow (preprocessing), subsequent runs will be instant\n")

# Constructor arguments that are identical for the train and validation splits
shared_dataset_kwargs = dict(
    image_size=(224, 224),
    device='cuda',
    dtype=torch.bfloat16 if config['use_bfloat16'] else torch.float32,
    cache_path=config['gpu_cache_path'],
    normalize=True,
    load_dinov2_features=config['preload_dinov2'],
)

# Train split first, then validation; only the split name and sample cap differ
train_dataset = GPUCachedDataset(
    split='train',
    max_samples=config['gpu_max_samples'],
    **shared_dataset_kwargs,
)
val_dataset = GPUCachedDataset(
    split='val',
    max_samples=config['gpu_max_samples_val'],
    **shared_dataset_kwargs,
)

print("\n✓ Datasets loaded!")
for split_label, split_dataset in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"  {split_label}: {len(split_dataset)} samples ({split_dataset._get_memory_usage():.1f} GB)")
print(f"  Total GPU memory used: {train_dataset._get_memory_usage() + val_dataset._get_memory_usage():.1f} GB")

## 4. Find Optimal Batch Size

In [None]:
# Simple test model for batch size optimization
class SimpleTestModel(nn.Module):
    """Minimal conv -> global-average-pool -> linear network.

    It exists only so the batch-size search has something cheap to run a
    forward/backward pass through; it predicts 21 three-component joints.
    """

    def __init__(self, hidden_dim=512):
        super().__init__()
        # Non-overlapping 16x16 patches projected to `hidden_dim` channels
        self.conv = nn.Conv2d(3, hidden_dim, kernel_size=16, stride=16)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(hidden_dim, 21 * 3)

    def forward(self, batch):
        """Map `batch['image']` to `{'hand_joints': tensor of shape (-1, 21, 3)}`."""
        feature_map = self.conv(batch['image'])
        # The pooled map has trailing 1x1 spatial dims; index them away
        pooled = self.pool(feature_map)[..., 0, 0]
        return {'hand_joints': self.fc(pooled).view(-1, 21, 3)}

# Build the throw-away probe model on the target device; match the dataset
# precision when bfloat16 is enabled so the forward pass sees consistent dtypes
test_model = SimpleTestModel(config['hidden_dim']).to(device=device)
test_model = test_model.to(dtype=torch.bfloat16) if config['use_bfloat16'] else test_model

# Find optimal batch size
def find_optimal_batch_size(model, dataset, initial=64, maximum=2048):
    """Find the largest batch size that survives one forward/backward pass.

    Starting at `initial`, the batch size is doubled after every successful
    probe until a probe runs out of memory or the next size would exceed
    `maximum`.

    Args:
        model: Module called as `model(batch)`; must return a dict containing
            a `'hand_joints'` tensor.
        dataset: Dataset handed to `GPUDataLoader` for each probe.
        initial: First batch size to try.
        maximum: Largest batch size that will be tried (inclusive).

    Returns:
        The largest batch size that completed a probe. If no probe succeeded
        (the first one ran out of memory, or `initial > maximum`), a warning
        is printed and `initial` is returned, matching the original fallback.

    Raises:
        RuntimeError: Any runtime error other than an out-of-memory condition
            is propagated unchanged.
    """
    print("Finding optimal batch size...")
    batch_size = initial
    best_batch_size = None

    while batch_size <= maximum:
        print(f"Testing batch size: {batch_size}", end="")
        loader = batch = outputs = loss = None
        hit_oom = False
        used_gb = 0.0
        try:
            loader = GPUDataLoader(dataset, batch_size=batch_size)
            batch = next(iter(loader))

            # Forward and backward
            outputs = model(batch)
            loss = outputs['hand_joints'].mean()
            loss.backward()

            # Measure while the batch, outputs and gradients are still alive
            used_gb = torch.cuda.memory_allocated() / 1e9
        except RuntimeError as e:
            # CUDA OOM surfaces as a RuntimeError whose message mentions it;
            # anything else is a genuine failure and must not be swallowed.
            if "out of memory" not in str(e):
                raise
            hit_oom = True
        finally:
            # Drop every reference from this probe *before* emptying the cache,
            # on both the success and the OOM path, so the memory can actually
            # be released and later cells start from a clean allocator state.
            loader = batch = outputs = loss = None
            model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()

        if hit_oom:
            print(" ✗ OOM")
            break

        best_batch_size = batch_size
        print(f" ✓ ({used_gb:.1f} GB used)")
        batch_size *= 2

    if best_batch_size is None:
        # Keep the original return value, but make it visible that the
        # returned size was never validated.
        print(f"⚠ No batch size completed a training step; falling back to initial={initial}")
        best_batch_size = initial

    return best_batch_size

# Probe the training set and adopt the resulting batch size for the rest of the run
optimal_batch_size = find_optimal_batch_size(test_model, train_dataset)
print(f"\nOptimal batch size: {optimal_batch_size}")
config.update(batch_size=optimal_batch_size)

# The probe model is no longer needed: drop it and release cached GPU memory
del test_model
torch.cuda.empty_cache()

## 5. Create Dataloaders

In [None]:
# Both loaders share the batch size chosen by the batch-size search
loader_batch_size = config['batch_size']

# Training loader: shuffled, incomplete final batch dropped
train_loader = GPUDataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True, drop_last=True)

# Validation loader: fixed order, every sample kept
val_loader = GPUDataLoader(val_dataset, batch_size=loader_batch_size, shuffle=False, drop_last=False)

print("Dataloaders created:")
print(f"  Batch size: {config['batch_size']}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

## 6. Benchmark Performance

In [None]:
import time

def benchmark_dataloader(loader, num_batches=50):
    """Measure how fast `loader` can deliver batches.

    Up to five warm-up batches are consumed first, then up to `num_batches`
    batches are timed while a trivial reduction is run on each image tensor.
    Throughput is computed from the batches and samples that were *actually*
    processed, so a loader shorter than `num_batches` is not overstated.

    Args:
        loader: Iterable yielding dict batches with an `'image'` tensor.
        num_batches: Maximum number of batches to time.

    Returns:
        Measured throughput in samples per second (float).
    """
    print(f"Benchmarking {num_batches} batches...")

    # Synchronising is only meaningful (and only legal) when CUDA is present;
    # the notebook falls back to CPU when no GPU is visible.
    use_cuda_sync = torch.cuda.is_available()

    # Warmup
    for i, batch in enumerate(loader):
        if i >= 5:
            break
        _ = batch['image'].mean()

    if use_cuda_sync:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time()
    start_time = time.perf_counter()

    batches_done = 0
    samples_done = 0
    for i, batch in enumerate(loader):
        if i >= num_batches:
            break
        # Simulate computation
        _ = batch['image'].mean()
        if use_cuda_sync:
            torch.cuda.synchronize()
        batches_done += 1
        samples_done += batch['image'].shape[0]

    # Clamp to avoid a division by zero when nothing measurable happened
    elapsed = max(time.perf_counter() - start_time, 1e-9)

    samples_per_sec = samples_done / elapsed
    batches_per_sec = batches_done / elapsed

    print(f"  Time: {elapsed:.2f}s")
    print(f"  Throughput: {samples_per_sec:,.0f} samples/sec")
    print(f"  Batches/sec: {batches_per_sec:.1f}")

    return samples_per_sec

# Time the training loader and print the headline numbers
throughput = benchmark_dataloader(train_loader)
separator = '=' * 50
print(f"\n{separator}")
allocated_gb = torch.cuda.memory_allocated() / 1e9
print(f"GPU Memory Usage: {allocated_gb:.1f} GB")
print(f"Expected training speed: {throughput:,.0f} samples/sec")

## 7. Create Model

In [None]:
# Model selection switch
USE_SIMPLE_MODEL = True  # Set False to use full UnifiedManipulationModel

if not USE_SIMPLE_MODEL:
    # Full project model
    from models.unified_model import UnifiedManipulationModel
    model = UnifiedManipulationModel(
        hidden_dim=config['hidden_dim'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        dropout=config['dropout'],
        num_refinement_steps=config['num_refinement_steps'],
        use_sigma_reparam=False  # Disable for debugging
    )
    print("Using full UnifiedManipulationModel")

else:
    class SimpleManipulationModel(nn.Module):
        """Small ViT-style regressor: patch embedding -> transformer encoder -> joint head.

        Reads `hidden_dim`, `num_heads`, `num_layers` and `dropout` from the
        supplied config mapping and predicts 21 three-component hand joints.
        """

        def __init__(self, config):
            super().__init__()
            width = config['hidden_dim']
            self.hidden_dim = width

            # Non-overlapping 16x16 patches; the positional table holds 196
            # entries, i.e. a 14x14 patch grid (224x224 inputs).
            self.patch_embed = nn.Conv2d(3, width, kernel_size=16, stride=16)
            self.pos_embed = nn.Parameter(torch.randn(1, 196, width) * 0.02)

            # Transformer encoder stack operating on (batch, tokens, width)
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=width,
                    nhead=config['num_heads'],
                    dim_feedforward=width * 4,
                    dropout=config['dropout'],
                    batch_first=True,
                ),
                config['num_layers'],
            )

            # Regression head: pooled token features -> 21 joints x 3 values
            self.hand_head = nn.Linear(width, 21 * 3)

        def forward(self, batch):
            """Return joint predictions for `batch['image']` under both output keys."""
            images = batch['image']

            # (B, C, H', W') -> (B, H'*W', C) token sequence plus positions
            tokens = self.patch_embed(images).flatten(2).transpose(1, 2)
            tokens = self.encoder(tokens + self.pos_embed)

            # Mean-pool over tokens, then regress the joints
            joints = self.hand_head(tokens.mean(dim=1)).view(images.shape[0], 21, 3)

            # The simple model has no refinement stage, so both keys share one tensor
            return {
                'hand_joints': joints,
                'hand_joints_refined': joints,
            }

    model = SimpleManipulationModel(config)
    print("Using simple model for testing")

# Move to GPU with appropriate dtype
model = model.to(device)
if config['use_bfloat16']:
    model = model.to(dtype=torch.bfloat16)

# Enable gradient checkpointing if available
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()
    print("✓ Gradient checkpointing enabled")

# Model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Size each parameter by its real storage width instead of assuming 4-byte
# float32: after the bfloat16 cast above every element occupies 2 bytes.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")
print(f"  Model size: {param_bytes / 1024**3:.2f} GB")

## 8. Training Setup

In [None]:
# Loss function
class SimpleLoss(nn.Module):
    """Joint-regression loss with an accompanying MPJPE metric.

    `forward` returns a dict with:
      * `'joint_loss'` - MSE between predicted and ground-truth joints
        (only when both sides provide `'hand_joints'`),
      * `'mpjpe'`      - mean Euclidean norm of the per-joint error, computed
        without gradient tracking (metric only),
      * `'total'`      - sum of every entry except `'mpjpe'`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, outputs, targets):
        """Compute the loss dict from model `outputs` and the `targets` batch."""
        losses = {}

        # Hand joint loss
        if 'hand_joints' in outputs and 'hand_joints' in targets:
            pred = outputs['hand_joints']
            gt = targets['hand_joints']

            # Zero out invalid joints on both sides so they contribute no error.
            # NOTE(review): the means below still divide by *all* joints, so
            # masked joints dilute the values - confirm that this is intended.
            if 'hand_joints_valid' in targets:
                valid = targets['hand_joints_valid'].unsqueeze(-1)
                pred = pred * valid
                gt = gt * valid

            # MSE loss. `nn.functional` is used because the notebook never
            # imports `torch.nn.functional as F`; the original `F.mse_loss`
            # raised NameError on the first training step.
            losses['joint_loss'] = nn.functional.mse_loss(pred, gt)

            # MPJPE
            with torch.no_grad():
                losses['mpjpe'] = torch.norm(pred - gt, dim=-1).mean()

        # Total loss (the metric is excluded from optimisation)
        losses['total'] = sum(v for k, v in losses.items() if k != 'mpjpe')

        return losses

# Loss, optimizer and learning-rate schedule for the run
criterion = SimpleLoss()

adamw_kwargs = {
    'lr': config['learning_rate'],
    'weight_decay': config['weight_decay'],
}
optimizer = torch.optim.AdamW(model.parameters(), **adamw_kwargs)

# Cosine annealing with a period equal to the configured number of epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'])

print("Training setup complete")

## 9. Training Loop

In [None]:
# Per-epoch metric series; each key starts with its own empty list
history = {
    metric_name: []
    for metric_name in (
        'train_loss',
        'train_mpjpe',
        'val_loss',
        'val_mpjpe',
        'gpu_memory',
        'throughput',
    )
}

# Training function
def train_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over `loader`.

    Args:
        model: Module called as `model(batch)`.
        loader: Iterable of batches; each batch is also passed to `criterion`
            as the targets.
        criterion: Callable returning a dict with a `'total'` loss tensor and,
            optionally, an `'mpjpe'` metric tensor.
        optimizer: Optimizer stepping the model parameters.

    Returns:
        `(mean_total_loss, mean_mpjpe)` averaged over the processed batches.

    Raises:
        ValueError: If `loader` yields no batches (previously an opaque
            ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    total_mpjpe = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc="Training")
    for batch in pbar:
        # Forward pass
        outputs = model(batch)
        losses = criterion(outputs, batch)

        # Backward pass
        optimizer.zero_grad()
        losses['total'].backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # Read each scalar back exactly once; every .item() is a host sync.
        # The metric is optional, so fall back to a plain float instead of
        # calling .item() on an int default (which raised AttributeError).
        loss_value = losses['total'].item()
        mpjpe_value = losses['mpjpe'].item() if 'mpjpe' in losses else 0.0

        # Update metrics
        total_loss += loss_value
        total_mpjpe += mpjpe_value
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'mpjpe': f"{mpjpe_value:.2f}mm"
        })

    if num_batches == 0:
        raise ValueError(
            "train_epoch received an empty loader "
            "(is batch_size larger than the dataset with drop_last=True?)"
        )

    return total_loss / num_batches, total_mpjpe / num_batches

# Validation function
@torch.no_grad()
def validate(model, loader, criterion):
    """Evaluate `model` over `loader` without gradient tracking.

    Args:
        model: Module called as `model(batch)`; switched to eval mode here.
        loader: Iterable of batches; each batch is also passed to `criterion`
            as the targets.
        criterion: Callable returning a dict with a `'total'` loss tensor and,
            optionally, an `'mpjpe'` metric tensor.

    Returns:
        `(mean_total_loss, mean_mpjpe)` averaged over the processed batches.

    Raises:
        ValueError: If `loader` yields no batches (previously an opaque
            ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    total_mpjpe = 0.0
    num_batches = 0

    for batch in tqdm(loader, desc="Validation"):
        outputs = model(batch)
        losses = criterion(outputs, batch)

        total_loss += losses['total'].item()
        # The metric is optional; batches without it contribute zero
        if 'mpjpe' in losses:
            total_mpjpe += losses['mpjpe'].item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("validate received an empty loader; nothing to evaluate")

    return total_loss / num_batches, total_mpjpe / num_batches

# Main training loop: train, validate, step the schedule, log, checkpoint
import time

print(f"\nStarting training for {config['num_epochs']} epochs...\n")

best_val_mpjpe = float('inf')

for epoch in range(config['num_epochs']):
    print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
    print("=" * 50)

    # Snapshot allocated GPU memory at the start of the epoch
    gpu_mem = torch.cuda.memory_allocated() / 1e9
    history['gpu_memory'].append(gpu_mem)
    print(f"GPU Memory: {gpu_mem:.1f} GB")

    # Timed training pass
    start_time = time.time()
    train_loss, train_mpjpe = train_epoch(model, train_loader, criterion, optimizer)
    train_time = time.time() - start_time

    # Throughput is based on the nominal samples per epoch (batches x batch size)
    samples_per_sec = len(train_loader) * config['batch_size'] / train_time
    history['throughput'].append(samples_per_sec)

    # Validation pass, then advance the learning-rate schedule
    val_loss, val_mpjpe = validate(model, val_loader, criterion)
    scheduler.step()

    # Record this epoch's metrics
    epoch_metrics = (
        ('train_loss', train_loss),
        ('train_mpjpe', train_mpjpe),
        ('val_loss', val_loss),
        ('val_mpjpe', val_mpjpe),
    )
    for metric_name, metric_value in epoch_metrics:
        history[metric_name].append(metric_value)

    # Epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, MPJPE: {train_mpjpe:.2f}mm")
    print(f"  Val Loss: {val_loss:.4f}, MPJPE: {val_mpjpe:.2f}mm")
    print(f"  Throughput: {samples_per_sec:,.0f} samples/sec")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Checkpoint whenever validation MPJPE improves
    if val_mpjpe < best_val_mpjpe:
        best_val_mpjpe = val_mpjpe
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_mpjpe': val_mpjpe,
            'config': config,
        }
        torch.save(checkpoint, f"{config['output_dir']}/best_model.pth")
        print("  ✓ Saved best model")

print("\nTraining completed!")
print(f"Best validation MPJPE: {best_val_mpjpe:.2f}mm")

## 10. Plot Results

In [None]:
# Plot training curves on a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# One row per panel: (axes, [(history key, legend label or None)], y label, title)
panel_specs = [
    (axes[0, 0], [('train_loss', 'Train'), ('val_loss', 'Val')], 'Loss', 'Training Loss'),
    (axes[0, 1], [('train_mpjpe', 'Train'), ('val_mpjpe', 'Val')], 'MPJPE (mm)', 'Mean Per Joint Position Error'),
    (axes[1, 0], [('gpu_memory', None)], 'GPU Memory (GB)', 'GPU Memory Usage'),
    (axes[1, 1], [('throughput', None)], 'Samples/sec', 'Training Throughput'),
]

for panel_ax, series, y_label, panel_title in panel_specs:
    for history_key, legend_label in series:
        if legend_label is None:
            panel_ax.plot(history[history_key])
        else:
            panel_ax.plot(history[history_key], label=legend_label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True)

# The train/val panels carry a legend for their two curves
for panel_ax in (axes[0, 0], axes[0, 1]):
    panel_ax.legend()

# Mark the 100GB memory target on the GPU-memory panel
memory_ax = axes[1, 0]
memory_ax.axhline(y=100, color='r', linestyle='--', label='Target: 100GB')
memory_ax.legend()

plt.tight_layout()
plt.savefig(f"{config['output_dir']}/training_curves.png", dpi=150)
plt.show()

# Print summary
print("\nTraining Summary:")
best_train_mpjpe = min(history['train_mpjpe'])
best_val_mpjpe_seen = min(history['val_mpjpe'])
mean_gpu_memory = np.mean(history['gpu_memory'])
mean_throughput = np.mean(history['throughput'])
print(f"  Best Train MPJPE: {best_train_mpjpe:.2f}mm")
print(f"  Best Val MPJPE: {best_val_mpjpe_seen:.2f}mm")
print(f"  Average GPU Memory: {mean_gpu_memory:.1f} GB")
print(f"  Average Throughput: {mean_throughput:,.0f} samples/sec")

## 11. Save Training Report

In [None]:
# Generate training report
import json

# Gather the figures once so the report text stays readable
train_data_gb = train_dataset._get_memory_usage()
val_data_gb = val_dataset._get_memory_usage()
report_total_params = sum(p.numel() for p in model.parameters())
report_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Use each parameter's real element size: the model may have been cast to
# bfloat16 (2 bytes per element), so assuming 4-byte float32 would double
# the reported figure.
report_param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

report = f"""
# GPU-Cached Training Report

## Configuration
- GPU: {torch.cuda.get_device_name()}
- Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB
- Dataset samples: {config['gpu_max_samples']} train, {config['gpu_max_samples_val']} val
- Batch size: {config['batch_size']}
- Learning rate: {config['learning_rate']}
- Epochs: {config['num_epochs']}
- Using BFloat16: {config['use_bfloat16']}

## Results
- Best Train MPJPE: {min(history['train_mpjpe']):.2f}mm
- Best Val MPJPE: {min(history['val_mpjpe']):.2f}mm
- Final Train MPJPE: {history['train_mpjpe'][-1]:.2f}mm
- Final Val MPJPE: {history['val_mpjpe'][-1]:.2f}mm

## Performance
- GPU Memory Usage: {np.mean(history['gpu_memory']):.1f} GB average
- Peak GPU Memory: {max(history['gpu_memory']):.1f} GB
- Average Throughput: {np.mean(history['throughput']):,.0f} samples/sec
- Peak Throughput: {max(history['throughput']):,.0f} samples/sec

## Dataset Loading
- Train dataset memory: {train_data_gb:.1f} GB
- Val dataset memory: {val_data_gb:.1f} GB
- Total dataset memory: {train_data_gb + val_data_gb:.1f} GB

## Model
- Total parameters: {report_total_params:,}
- Trainable parameters: {report_trainable_params:,}
- Model memory: {report_param_bytes / 1e9:.2f} GB
"""

print(report)

# Save report (explicit encoding so the output does not depend on the locale)
with open(f"{config['output_dir']}/training_report.txt", 'w', encoding='utf-8') as f:
    f.write(report)

# Save configuration
with open(f"{config['output_dir']}/config.json", 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2)

print(f"\nAll outputs saved to: {config['output_dir']}/")