# Video-to-Manipulation Transformer: GPU-Only Stage 1 Training

This notebook implements the fastest GPU-only training approach for maximum performance on H200.

**Key Features:**
- Entire dataset cached in GPU memory (zero CPU-GPU transfers)
- Large batch sizes (1024+)
- BFloat16 mixed precision
- Compiled models for better performance
- No DataLoader overhead

**Requirements:**
- H200 GPU with 140GiB memory
- PyTorch 2.0+ for torch.compile
- CUDA 12.0+

In [2]:
# Setup and imports
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.cuda.amp import autocast
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
from datetime import datetime
import subprocess
from typing import Dict, Optional

# Set environment
os.environ['DEX_YCB_DIR'] = '/home/n231/231nProjectV2/dex-ycb-toolkit/data'

# CUDA optimizations for H200
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Add project root to path
project_root = os.path.abspath('.')
sys.path.insert(0, project_root)

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

Using device: cuda
GPU: NVIDIA H200
Memory: 150.0 GB


In [3]:
# Import our modules
from models.encoders.hand_encoder import HandPoseEncoder
from models.encoders.object_encoder import ObjectPoseEncoder
from models.encoders.contact_encoder import ContactDetectionEncoder
from data.gpu_preprocessing import GPUVideoPreprocessor

print("✓ All modules imported successfully")

✓ All modules imported successfully


In [4]:
# Configuration
config = {
    # GPU-only dataset settings
    'max_samples_train': 300000,  # Adjust based on GPU memory
    'max_samples_val': 20000,
    'batch_size': 1024,  # Large batch for H200
    'image_size': (224, 224),
    'patch_size': 16,
    'cache_path': 'gpu_cache',
    'dtype': torch.bfloat16,  # Use bfloat16 to fit more samples
    
    # Model settings - scale up for H200
    'hand_hidden_dim': 2048,
    'object_hidden_dim': 2048,
    'contact_hidden_dim': 1024,
    'hand_layers': 12,
    'object_layers': 12,
    'contact_layers': 8,
    
    # Training settings
    'learning_rate': 2e-3,
    'num_epochs': 5,
    'grad_clip': 1.0,
    'log_interval': 10,
    'val_interval': 50
}

print("Configuration loaded")
print(f"Will cache {config['max_samples_train']} training samples in GPU memory")

Configuration loaded
Will cache 300000 training samples in GPU memory


In [5]:
# Fixed GPU-Only Dataset Implementation
import cv2
from pathlib import Path

class GPUOnlyDataset:
    """Dataset that lives entirely on GPU memory"""
    
    def __init__(self, split='s0_train', max_samples=50000, image_size=(224, 224),
                 device='cuda', dtype=torch.float32, cache_path=None):
        self.split = split
        self.max_samples = max_samples
        self.image_size = image_size
        self.device = device
        self.dtype = dtype
        self.cache_path = cache_path
        
        # Check cache
        cache_file = f"{cache_path}/{split}_gpu_cache.pt" if cache_path else None
        if cache_path and os.path.exists(cache_file):
            print(f"Loading cached GPU dataset from {cache_file}...")
            self.data = torch.load(cache_file, map_location=device)
            self.num_samples = len(self.data['color'])
        else:
            print(f"Building GPU dataset for {split}...")
            self._build_dataset()
            if cache_path:
                os.makedirs(cache_path, exist_ok=True)
                torch.save(self.data, cache_file)
                print(f"Saved cache to {cache_file}")
        
        print(f"✓ GPU dataset ready with {self.num_samples} samples")
        print(f"  Memory usage: {torch.cuda.memory_allocated()/1e9:.1f} GB")
    
    def _build_dataset(self):
        """Build dataset directly on GPU"""
        from dex_ycb_toolkit.factory import get_dataset
        dex_dataset = get_dataset(self.split)
        
        num_samples = min(len(dex_dataset), self.max_samples)
        self.num_samples = num_samples
        
        # Pre-allocate GPU tensors
        print(f"Allocating GPU memory for {num_samples} samples...")
        self.data = {
            'color': torch.zeros((num_samples, 3, *self.image_size), 
                               device=self.device, dtype=self.dtype),
            'hand_joints_3d': torch.full((num_samples, 21, 3), -1.0,
                                       device=self.device, dtype=self.dtype),
            'hand_joints_2d': torch.full((num_samples, 21, 2), -1.0,
                                       device=self.device, dtype=self.dtype),
            'hand_pose': torch.zeros((num_samples, 51),  # Fixed: MANO pose is 51D
                                   device=self.device, dtype=self.dtype),
            'object_poses': torch.zeros((num_samples, 10, 3, 4),
                                      device=self.device, dtype=self.dtype),
            'ycb_ids': torch.zeros((num_samples, 10),
                                 device=self.device, dtype=torch.long),
            'num_objects': torch.zeros((num_samples,),
                                     device=self.device, dtype=torch.long),
            'has_hand': torch.zeros((num_samples,), device=self.device, dtype=torch.bool),
        }
        
        # Load data
        print("Loading and preprocessing data...")
        for i in tqdm(range(num_samples), desc="Loading samples"):
            try:
                sample = dex_dataset[i]
                
                # Load image
                img = cv2.imread(sample['color_file'])
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, self.image_size)
                img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
                self.data['color'][i] = img_tensor.to(self.device, dtype=self.dtype)
                
                # Load labels
                labels = np.load(sample['label_file'])
                
                # Hand data
                if 'joint_3d' in labels and labels['joint_3d'].shape[0] > 0:
                    joints_3d = torch.from_numpy(labels['joint_3d'][0])
                    self.data['hand_joints_3d'][i] = joints_3d.to(self.device, dtype=self.dtype)
                    self.data['has_hand'][i] = True
                
                if 'joint_2d' in labels and labels['joint_2d'].shape[0] > 0:
                    joints_2d = torch.from_numpy(labels['joint_2d'][0])
                    self.data['hand_joints_2d'][i] = joints_2d.to(self.device, dtype=self.dtype)
                
                if 'pose_m' in labels and labels['pose_m'].shape[0] > 0:
                    pose = torch.from_numpy(labels['pose_m'][0])
                    # Handle different pose dimensions
                    if pose.shape[0] == 48:
                        # Pad to 51 if needed
                        pose = F.pad(pose, (0, 3), value=0)
                    self.data['hand_pose'][i, :pose.shape[0]] = pose.to(self.device, dtype=self.dtype)
                
                # Object data
                if 'pose_y' in labels and len(labels['pose_y']) > 0:
                    obj_poses = labels['pose_y']
                    num_objs = min(len(obj_poses), 10)
                    if num_objs > 0:
                        obj_tensor = torch.from_numpy(obj_poses[:num_objs])
                        self.data['object_poses'][i, :num_objs] = obj_tensor.to(self.device, dtype=self.dtype)
                    self.data['num_objects'][i] = num_objs
                
                # YCB IDs
                ycb_ids = sample.get('ycb_ids', [])
                if ycb_ids:
                    num_ids = min(len(ycb_ids), 10)
                    self.data['ycb_ids'][i, :num_ids] = torch.tensor(ycb_ids[:num_ids], 
                                                                    device=self.device, dtype=torch.long)
            except Exception as e:
                print(f"Error loading sample {i}: {e}")
                continue
                
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        """Return a sample - already on GPU!"""
        return {
            'color': self.data['color'][idx],
            'hand_joints_3d': self.data['hand_joints_3d'][idx],
            'hand_joints_2d': self.data['hand_joints_2d'][idx],
            'hand_pose': self.data['hand_pose'][idx],
            'object_poses': self.data['object_poses'][idx],
            'ycb_ids': self.data['ycb_ids'][idx],
            'num_objects': self.data['num_objects'][idx],
            'has_hand': self.data['has_hand'][idx],
        }


class GPUBatchGenerator:
    """Generate batches directly from GPU memory - zero copy"""
    
    def __init__(self, dataset, batch_size=256, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_samples = len(dataset)
        
    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size
    
    def __iter__(self):
        # Create indices
        indices = torch.arange(self.num_samples, device='cuda')
        if self.shuffle:
            indices = indices[torch.randperm(self.num_samples, device='cuda')]
        
        # Generate batches
        for start_idx in range(0, self.num_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, self.num_samples)
            batch_indices = indices[start_idx:end_idx]
            
            # Create batch - everything stays on GPU
            batch = {}
            for key in self.dataset.data:
                if isinstance(self.dataset.data[key], torch.Tensor):
                    batch[key] = self.dataset.data[key][batch_indices]
            
            yield batch

In [6]:
# Create GPU-only datasets
print("Creating GPU-only datasets...")
print("First run will be slow (loading), subsequent runs will use cache")

# Clear GPU memory
torch.cuda.empty_cache()
torch.cuda.synchronize()

train_dataset = GPUOnlyDataset(
    split='s0_train',
    max_samples=config['max_samples_train'],
    image_size=config['image_size'],
    device='cuda',
    dtype=config['dtype'],
    cache_path=config['cache_path']
)

val_dataset = GPUOnlyDataset(
    split='s0_val',
    max_samples=config['max_samples_val'],
    image_size=config['image_size'],
    device='cuda',
    dtype=config['dtype'],
    cache_path=config['cache_path']
)

# Create batch generators
train_loader = GPUBatchGenerator(train_dataset, batch_size=config['batch_size'], shuffle=True)
val_loader = GPUBatchGenerator(val_dataset, batch_size=config['batch_size']//2, shuffle=False)

print(f"\n✓ Datasets ready:")
print(f"  Train: {len(train_dataset)} samples, {len(train_loader)} batches")
print(f"  Val: {len(val_dataset)} samples, {len(val_loader)} batches")
print(f"  GPU Memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")

Creating GPU-only datasets...
First run will be slow (loading), subsequent runs will use cache
Building GPU dataset for s0_train...
Allocating GPU memory for 300000 samples...
Loading and preprocessing data...


Loading samples:   0%|          | 0/300000 [00:00<?, ?it/s]

Saved cache to gpu_cache/s0_train_gpu_cache.pt
✓ GPU dataset ready with 300000 samples
  Memory usage: 90.5 GB
Building GPU dataset for s0_val...
Allocating GPU memory for 20000 samples...
Loading and preprocessing data...


Loading samples:   0%|          | 0/20000 [00:00<?, ?it/s]

Saved cache to gpu_cache/s0_val_gpu_cache.pt
✓ GPU dataset ready with 20000 samples
  Memory usage: 96.5 GB

✓ Datasets ready:
  Train: 300000 samples, 293 batches
  Val: 20000 samples, 40 batches
  GPU Memory: 96.5 GB


In [10]:
# Create models
print("Creating scaled-up models for H200...")

patch_dim = 3 * config['patch_size'] * config['patch_size']

# Large models for H200
hand_encoder = HandPoseEncoder(
    input_dim=patch_dim,
    hidden_dim=config['hand_hidden_dim'],
    num_layers=config['hand_layers'],
    num_heads=32,
    mlp_dim=8192,
    dropout=0.1
).to(device)

object_encoder = ObjectPoseEncoder(
    input_dim=patch_dim,
    hidden_dim=config['object_hidden_dim'],
    num_layers=config['object_layers'],
    num_heads=32,
    mlp_dim=8192,
    dropout=0.1,
    max_objects=10
).to(device)

contact_encoder = ContactDetectionEncoder(
    input_dim=patch_dim,
    hidden_dim=config['contact_hidden_dim'],
    num_layers=config['contact_layers'],
    num_heads=32,
    mlp_dim=4096,
    dropout=0.1
).to(device)

# Compile models for better performance
if hasattr(torch, 'compile'):
    print("Compiling models with torch.compile...")
    hand_encoder = torch.compile(hand_encoder, mode='max-autotune')
    object_encoder = torch.compile(object_encoder, mode='max-autotune')
    contact_encoder = torch.compile(contact_encoder, mode='max-autotune')

# Count parameters
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_params(hand_encoder) + count_params(object_encoder) + count_params(contact_encoder)
print(f"\nModel parameters:")
print(f"  Hand: {count_params(hand_encoder)/1e6:.1f}M")
print(f"  Object: {count_params(object_encoder)/1e6:.1f}M")
print(f"  Contact: {count_params(contact_encoder)/1e6:.1f}M")
print(f"  Total: {total_params/1e6:.1f}M")

Creating scaled-up models for H200...
Compiling models with torch.compile...

Model parameters:
  Hand: 612.4M
  Object: 610.5M
  Contact: 108.4M
  Total: 1331.3M


In [11]:
# Create GPU preprocessor and optimizers
gpu_preprocessor = GPUVideoPreprocessor(
    image_size=config['image_size'],
    patch_size=config['patch_size'],
    normalize=True,
    device='cuda'
).to(device)

# Optimizers
optimizer_hand = optim.AdamW(hand_encoder.parameters(), lr=config['learning_rate'])
optimizer_object = optim.AdamW(object_encoder.parameters(), lr=config['learning_rate'])
optimizer_contact = optim.AdamW(contact_encoder.parameters(), lr=config['learning_rate'])

print("✓ Preprocessor and optimizers ready")

✓ Preprocessor and optimizers ready


In [12]:
# Training functions
def train_epoch(epoch):
    """GPU-only training - everything stays on GPU"""
    hand_encoder.train()
    object_encoder.train()
    contact_encoder.train()
    
    total_loss = 0
    num_batches = 0
    epoch_start = time.time()
    
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["num_epochs"]}')
    
    for batch_idx, batch in enumerate(progress_bar):
        # Create patches on GPU
        with torch.no_grad():
            patches = gpu_preprocessor(batch['color'])
        
        # Zero gradients
        optimizer_hand.zero_grad(set_to_none=True)
        optimizer_object.zero_grad(set_to_none=True)
        optimizer_contact.zero_grad(set_to_none=True)
        
        # Forward passes with autocast
        with autocast(device_type='cuda', dtype=torch.bfloat16):
            # Hand encoder
            hand_output = hand_encoder(patches)
            hand_gt = batch['hand_joints_3d']
            valid_hands = batch['has_hand']
            
            if valid_hands.any():
                hand_loss = F.mse_loss(hand_output['joints_3d'][valid_hands], hand_gt[valid_hands])
            else:
                hand_loss = torch.tensor(0.0, device='cuda')
            
            # Object encoder
            object_output = object_encoder(patches, object_ids=batch['ycb_ids'])
            object_loss = torch.tensor(0.0, device='cuda')
            
            valid_objects = batch['num_objects'] > 0
            if valid_objects.any():
                # Get object positions
                object_positions_gt = batch['object_poses'][:, :, :3, 3]
                num_pred = min(object_output['positions'].shape[1], 10)
                
                # Compute loss for each sample with objects
                for i in torch.where(valid_objects)[0]:
                    n_obj = batch['num_objects'][i].item()
                    if n_obj > 0 and n_obj <= num_pred:
                        pred = object_output['positions'][i, :n_obj]
                        gt = object_positions_gt[i, :n_obj]
                        object_loss = object_loss + F.mse_loss(pred, gt)
                
                if valid_objects.sum() > 0:
                    object_loss = object_loss / valid_objects.sum()
            
            # Contact encoder (no supervision)
            contact_output = contact_encoder(
                hand_output['features'].detach(),
                object_output['features'].detach()
            )
            
            # Total loss
            total_batch_loss = hand_loss + object_loss
        
        # Backward
        total_batch_loss.backward()
        
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(hand_encoder.parameters(), config['grad_clip'])
        torch.nn.utils.clip_grad_norm_(object_encoder.parameters(), config['grad_clip'])
        torch.nn.utils.clip_grad_norm_(contact_encoder.parameters(), config['grad_clip'])
        
        # Update
        optimizer_hand.step()
        optimizer_object.step()
        optimizer_contact.step()
        
        # Metrics
        total_loss += total_batch_loss.item()
        num_batches += 1
        
        # Update progress bar
        if batch_idx % 5 == 0:
            gpu_mem = torch.cuda.memory_allocated() / 1e9
            elapsed = time.time() - epoch_start
            samples_per_sec = (batch_idx + 1) * config['batch_size'] / elapsed
            
            progress_bar.set_postfix({
                'loss': f'{total_batch_loss.item():.4f}',
                'gpu': f'{gpu_mem:.1f}GB',
                'speed': f'{samples_per_sec:.0f}/s'
            })
    
    return total_loss / max(num_batches, 1)


def validate():
    """Fast validation"""
    hand_encoder.eval()
    object_encoder.eval()
    
    total_loss = 0
    total_mpjpe = 0
    num_valid = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            if batch_idx >= 10:  # Quick validation
                break
            
            # Preprocess
            patches = gpu_preprocessor(batch['color'])
            
            # Forward
            with autocast(device_type='cuda', dtype=torch.bfloat16):
                hand_output = hand_encoder(patches)
            
            # Compute metrics
            valid_hands = batch['has_hand']
            if valid_hands.any():
                hand_gt = batch['hand_joints_3d'][valid_hands]
                hand_pred = hand_output['joints_3d'][valid_hands]
                
                loss = F.mse_loss(hand_pred, hand_gt)
                mpjpe = (hand_pred - hand_gt).norm(dim=-1).mean()
                
                total_loss += loss.item() * valid_hands.sum().item()
                total_mpjpe += mpjpe.item() * valid_hands.sum().item()
                num_valid += valid_hands.sum().item()
    
    return {
        'loss': total_loss / max(num_valid, 1),
        'mpjpe': total_mpjpe / max(num_valid, 1)
    }

In [13]:
# GPU monitoring utilities
def print_gpu_stats():
    """Print current GPU statistics"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        
        print("\n" + "="*60)
        print("GPU Statistics:")
        print(f"  Memory: {allocated:.1f} / {total:.1f} GB ({allocated/total*100:.1f}%)")
        print(f"  Reserved: {reserved:.1f} GB")
        
        # Get utilization and power
        try:
            result = subprocess.run([
                'nvidia-smi', '--query-gpu=utilization.gpu,power.draw,power.limit', 
                '--format=csv,noheader,nounits'
            ], capture_output=True, text=True)
            
            if result.returncode == 0:
                gpu_util, power_draw, power_limit = result.stdout.strip().split(', ')
                print(f"  Utilization: {gpu_util}%")
                print(f"  Power: {power_draw}W / {power_limit}W")
                
                # Performance indicator
                if float(gpu_util) > 80:
                    print("  Status: ✓ Excellent GPU utilization")
                elif float(gpu_util) > 50:
                    print("  Status: ⚠️ Moderate GPU utilization")
                else:
                    print("  Status: ✗ Low GPU utilization")
        except:
            pass
        print("="*60 + "\n")

In [14]:
# Training loop
print("Starting GPU-Only Training...")
print(f"Configuration:")
print(f"  Batch size: {config['batch_size']}")
print(f"  Epochs: {config['num_epochs']}")
print(f"  Samples cached in GPU: {config['max_samples_train']}")
print("-" * 60)

# History
history = {
    'train_loss': [],
    'val_loss': [],
    'val_mpjpe': [],
    'throughput': [],
    'gpu_util': []
}

# Initial GPU stats
print_gpu_stats()

# Training
best_val_loss = float('inf')
for epoch in range(config['num_epochs']):
    epoch_start = time.time()
    
    # Train
    train_loss = train_epoch(epoch)
    history['train_loss'].append(train_loss)
    
    # Validate
    val_metrics = validate()
    history['val_loss'].append(val_metrics['loss'])
    history['val_mpjpe'].append(val_metrics['mpjpe'])
    
    # Calculate throughput
    epoch_time = time.time() - epoch_start
    samples_processed = len(train_loader) * config['batch_size']
    throughput = samples_processed / epoch_time
    history['throughput'].append(throughput)
    
    # Get GPU utilization
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=utilization.gpu', 
                               '--format=csv,noheader,nounits'], 
                              capture_output=True, text=True)
        if result.returncode == 0:
            gpu_util = float(result.stdout.strip())
            history['gpu_util'].append(gpu_util)
    except:
        history['gpu_util'].append(0)
    
    # Print summary
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{config['num_epochs']} Summary:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_metrics['loss']:.4f}")
    print(f"  Val MPJPE: {val_metrics['mpjpe']*1000:.2f} mm")
    print(f"  Throughput: {throughput:.0f} samples/s")
    print(f"  Time: {epoch_time:.1f}s")
    
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        print("  ✓ New best validation loss!")
    
    print_gpu_stats()

print("\n✓ Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Average throughput: {np.mean(history['throughput']):.0f} samples/s")
print(f"Average GPU utilization: {np.mean(history['gpu_util']):.1f}%")

Starting GPU-Only Training...
Configuration:
  Batch size: 1024
  Epochs: 5
  Samples cached in GPU: 300000
------------------------------------------------------------

GPU Statistics:
  Memory: 101.9 / 150.0 GB (67.9%)
  Reserved: 107.2 GB
  Utilization: 0%
  Power: 114.21W / 700.00W
  Status: ✗ Low GPU utilization



Epoch 1/5:   0%|          | 0/293 [00:00<?, ?it/s]

  with autocast(device_type='cuda', dtype=torch.bfloat16):


TypeError: autocast.__init__() got an unexpected keyword argument 'device_type'

In [None]:
# Plot results
plt.figure(figsize=(20, 5))

# Loss curves
plt.subplot(1, 4, 1)
plt.plot(history['train_loss'], label='Train Loss', marker='o')
plt.plot(history['val_loss'], label='Val Loss', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True)

# MPJPE
plt.subplot(1, 4, 2)
plt.plot([x*1000 for x in history['val_mpjpe']], label='Val MPJPE', marker='o', color='green')
plt.xlabel('Epoch')
plt.ylabel('MPJPE (mm)')
plt.title('Hand Pose Error')
plt.legend()
plt.grid(True)

# Throughput
plt.subplot(1, 4, 3)
plt.plot(history['throughput'], label='Throughput', marker='o', color='orange')
plt.axhline(y=np.mean(history['throughput']), color='red', linestyle='--', label='Average')
plt.xlabel('Epoch')
plt.ylabel('Samples/s')
plt.title('Training Speed')
plt.legend()
plt.grid(True)

# GPU Utilization
plt.subplot(1, 4, 4)
plt.plot(history['gpu_util'], label='GPU Util %', marker='o', color='purple')
plt.axhline(y=80, color='green', linestyle='--', label='Target (80%)')
plt.xlabel('Epoch')
plt.ylabel('GPU Utilization %')
plt.title('GPU Usage')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig('gpu_only_training_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPerformance Summary:")
print(f"  Peak throughput: {max(history['throughput']):.0f} samples/s")
print(f"  Peak GPU utilization: {max(history['gpu_util']):.1f}%")
print(f"  Final MPJPE: {history['val_mpjpe'][-1]*1000:.2f} mm")

In [None]:
# Save models
checkpoint_dir = 'checkpoints/gpu_only'
os.makedirs(checkpoint_dir, exist_ok=True)

# Save with all training info
torch.save({
    'model_state_dict': hand_encoder.state_dict(),
    'optimizer_state_dict': optimizer_hand.state_dict(),
    'config': config,
    'history': history,
    'best_val_loss': best_val_loss
}, os.path.join(checkpoint_dir, 'hand_encoder.pth'))

torch.save({
    'model_state_dict': object_encoder.state_dict(),
    'optimizer_state_dict': optimizer_object.state_dict(),
}, os.path.join(checkpoint_dir, 'object_encoder.pth'))

torch.save({
    'model_state_dict': contact_encoder.state_dict(),
    'optimizer_state_dict': optimizer_contact.state_dict(),
}, os.path.join(checkpoint_dir, 'contact_encoder.pth'))

print(f"✓ Models saved to {checkpoint_dir}")
print(f"  Best validation loss: {best_val_loss:.4f}")

## Tips for Maximum Performance

### 1. Increase Cache Size
```python
# For H200 with 140GB:
'max_samples_train': 100000,  # ~100GB with bfloat16
'max_samples_train': 150000,  # Use most of GPU memory
```

### 2. Larger Batch Sizes
```python
'batch_size': 2048,  # Or even 4096
```

### 3. Monitor in Real-Time
```bash
# In another terminal:
watch -n 0.5 nvidia-smi
```

### 4. Expected Performance
- **GPU Utilization**: 85-95%
- **Memory Usage**: 50-120GB
- **Throughput**: 10,000+ samples/s
- **Power Draw**: 500-700W

### 5. Troubleshooting
- If GPU util is low: Increase batch size
- If OOM: Reduce cache size or use bfloat16
- If slow: Check torch.compile is working