# Advanced Manipulation Transformer - Barebones Debug Version

This notebook provides a minimal implementation without advanced features for easier debugging:
- No FlashAttention, FP8, or memory optimizations
- No dedicated mode-collapse prevention modules (the loss keeps only a small diversity term)
- Simple data loading without prefetching
- Basic loss functions only
- Small dataset subset
- CPU-friendly options

## 1. Basic Setup

In [None]:
# Minimal imports
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from tqdm.notebook import tqdm

# Point the DexYCB toolkit at its data directory. `setdefault` keeps a value
# that is already exported in the environment; the machine-specific path below
# is only a fallback for when DEX_YCB_DIR is unset.
os.environ.setdefault('DEX_YCB_DIR', '/home/n231/231nProjectV2/dex-ycb-toolkit/data')

# Add project root to path. The project root is taken to be the parent of the
# current working directory.
project_root = Path('.').absolute().parent
# Re-running this cell must not pile up duplicate entries at the front of sys.path
if sys.path[:1] != [str(project_root)]:
    sys.path.insert(0, str(project_root))

# Simple matplotlib setup
%matplotlib inline

# Report the runtime environment and select the compute device:
# CUDA when PyTorch can see a GPU, otherwise the CPU.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Minimal Configuration

In [None]:
# Simple configuration dictionary. Everything a reader might want to tune for
# a debug run lives here; later cells read their settings from `config`.
config = {
    # Reproducibility
    'seed': 42,  # Seed applied to the NumPy and PyTorch RNGs below

    # Data
    'data_root': '../../dex-ycb-toolkit',
    'batch_size': 4,  # Small batch size
    'num_samples': 1000,  # Use only 1000 samples for debugging
    'num_workers': 0,  # No multiprocessing for debugging

    # Model - smaller sizes for debugging
    'hidden_dim': 256,  # Much smaller than production
    'num_heads': 8,
    'num_layers': 4,  # Fewer layers
    'dropout': 0.1,
    'num_refinement_steps': 1,  # Minimal refinement

    # Training
    'num_epochs': 5,  # Just a few epochs
    'learning_rate': 1e-3,
    'weight_decay': 0.01,
    'print_freq': 10,

    # No optimizations
    'use_amp': False,
    'use_ema': False,
    'gradient_checkpointing': False,

    # Output
    'output_dir': 'outputs/debug_run'
}

# Seed the random number generators so that weight initialisation, shuffling
# and dropout are repeatable across "Restart & Run All".
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])

# Create output directory
os.makedirs(config['output_dir'], exist_ok=True)
print(f"Output directory: {config['output_dir']}")

## 3. Simple Dataset (No Augmentation)

In [None]:
# Import minimal dataset
from torch.utils.data import Dataset, DataLoader
from data.enhanced_dexycb import EnhancedDexYCBDataset

# Create simple dataset wrapper that limits samples
class DebugDataset(Dataset):
    """View onto the first ``max_samples`` items of ``base_dataset``.

    Args:
        base_dataset: Any object supporting ``len()`` and integer indexing.
        max_samples: Upper bound on the number of exposed items. The effective
            size is clamped to the range ``[0, len(base_dataset)]``.
    """

    def __init__(self, base_dataset, max_samples=1000):
        self.base_dataset = base_dataset
        # Clamp at zero as well, so __len__ can never return a negative value
        self.max_samples = max(0, min(max_samples, len(base_dataset)))
        print(f"Using {self.max_samples} samples from {len(base_dataset)} total")

    def __len__(self):
        """Return the number of items exposed by this view."""
        return self.max_samples

    def __getitem__(self, idx):
        """Return item ``idx`` of the view.

        Negative indices count from the end of the view, not of the base
        dataset.

        Raises:
            IndexError: If ``idx`` lies outside the view. Without this check,
                plain iteration (which stops on IndexError) would run on into
                the rest of the base dataset.
        """
        if idx < 0:
            idx += self.max_samples
        if not 0 <= idx < self.max_samples:
            raise IndexError(
                f"index out of range for DebugDataset of size {self.max_samples}"
            )
        return self.base_dataset[idx]

# Create datasets without augmentation
print("Loading datasets...")
train_dataset_full = EnhancedDexYCBDataset(
    dexycb_root=config['data_root'],
    split='train',  # split name is 'train' (formerly 's0_train')
    sequence_length=1,
    augment=False  # No augmentation for debugging
)

val_dataset_full = EnhancedDexYCBDataset(
    dexycb_root=config['data_root'],
    split='val',  # split name is 'val' (formerly 's0_val')
    sequence_length=1,
    augment=False
)

# Wrap with debug dataset. The validation subset is a tenth of the training
# subset, but at least one sample is requested: with num_samples < 10 the
# integer division would otherwise ask for an empty validation set.
train_dataset = DebugDataset(train_dataset_full, config['num_samples'])
val_dataset = DebugDataset(val_dataset_full, max(1, config['num_samples'] // 10))

# Simple dataloaders: only the training loader shuffles; pinned memory is
# disabled on both for debugging.
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
    pin_memory=False  # Disable for debugging
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config['num_workers'],
    pin_memory=False
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
# Check data format: pull one training batch and describe every entry.
# NOTE(review): assumes the collated batch is a dict - confirm against the dataset.
sample_batch = next(iter(train_loader))
print("Sample batch contents:")
for key, value in sample_batch.items():
    if not isinstance(value, torch.Tensor):
        print(f"  {key}: {type(value)}")
        continue

    print(f"  {key}: {value.shape} ({value.dtype}) - device: {value.device}")
    if value.numel() == 0:
        # min()/max()/mean() raise on tensors without elements
        print("    (empty tensor)")
    elif value.is_floating_point():
        # mean() is only defined for floating point tensors
        print(f"    min: {value.min():.3f}, max: {value.max():.3f}, mean: {value.mean():.3f}")
    else:
        # For integer tensors, just show min/max
        print(f"    min: {value.min()}, max: {value.max()}")

## 4. Simplified Model (No Advanced Features)

In [None]:
# Create a simplified version of the model for debugging
class SimpleManipulationModel(nn.Module):
    """Simplified model without advanced features for debugging.

    Pipeline: 16x16 patch embedding (strided convolution) -> learned positional
    embedding -> transformer encoder -> mean pooling over patches -> MLP head
    that regresses 21 joints with 3 coordinates each.

    Args:
        config: Mapping providing ``hidden_dim``, ``num_heads``, ``num_layers``
            and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config['hidden_dim']
        self.num_heads = config['num_heads']
        self.num_layers = config['num_layers']

        # Simple patch embedding (no DINOv2)
        self.patch_embed = nn.Conv2d(3, self.hidden_dim, kernel_size=16, stride=16)
        # Positional embeddings are learned on a 14 x 14 grid (196 patches, i.e.
        # a 224 x 224 input); other grid sizes are served by interpolation.
        self.pos_grid_size = 14
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.pos_grid_size ** 2, self.hidden_dim) * 0.02
        )

        # Simple transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=self.num_heads,
            dim_feedforward=self.hidden_dim * 4,
            dropout=config['dropout'],
            activation='gelu',
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.num_layers)

        # Simple output heads
        self.hand_head = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.GELU(),
            nn.Dropout(config['dropout']),
            nn.Linear(self.hidden_dim, 21 * 3)  # 21 joints x 3D
        )

        # Initialize weights (applied to every submodule, encoder included)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-init Linear weights, Kaiming-init Conv2d weights, zero biases."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
            # Keep the bias handling consistent with the Linear branch
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _pos_embed_for_grid(self, grid_h, grid_w):
        """Return positional embeddings shaped [1, grid_h * grid_w, hidden_dim].

        The learned table is returned unchanged for the native 14 x 14 grid and
        bilinearly resized for any other patch grid.
        """
        side = self.pos_grid_size
        if (grid_h, grid_w) == (side, side):
            return self.pos_embed

        # [1, side*side, C] -> [1, C, side, side] so F.interpolate sees a 2-D map
        table = self.pos_embed.reshape(1, side, side, self.hidden_dim).permute(0, 3, 1, 2)
        table = F.interpolate(table, size=(grid_h, grid_w), mode='bilinear', align_corners=False)
        # Back to the row-major token layout produced by flatten(2) in forward()
        return table.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, self.hidden_dim)

    def forward(self, batch):
        """Predict hand joints from ``batch['image']``.

        Args:
            batch: Mapping whose ``'image'`` entry is a [B, 3, H, W] tensor.

        Returns:
            dict with ``'hand_joints'`` ([B, 21, 3]) and ``'features'``
            ([B, hidden_dim], the pooled encoder output, kept for debugging).
        """
        # Get image
        x = batch['image']  # [B, 3, H, W]
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # [B, hidden_dim, H/16, W/16]
        grid_h, grid_w = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, hidden_dim]

        # Add positional embedding (resized when the input is not 224 x 224)
        x = x + self._pos_embed_for_grid(grid_h, grid_w)

        # Transformer encoder
        x = self.encoder(x)  # [B, num_patches, hidden_dim]

        # Global pooling
        x = x.mean(dim=1)  # [B, hidden_dim]

        # Predict hand joints
        hand_joints = self.hand_head(x)  # [B, 21*3]
        hand_joints = hand_joints.view(B, 21, 3)  # [B, 21, 3]

        # Simple output dictionary
        outputs = {
            'hand_joints': hand_joints,
            'features': x  # For debugging
        }

        return outputs

# Build the debug model and place it on the selected device
print("Creating simplified model...")
model = SimpleManipulationModel(config).to(device)

# Parameter count, and the memory it occupies at 4 bytes per parameter
total_params = sum(map(torch.Tensor.numel, model.parameters()))
model_size_mb = total_params * 4 / 1024**2
print(f"Total parameters: {total_params:,}")
print(f"Model size: {model_size_mb:.2f} MB")

In [None]:
# Smoke-test the model: one forward pass in eval mode on the sample batch
print("Testing forward pass...")
model.eval()

# Move tensor entries to the target device; anything else is passed through
sample_batch_gpu = {}
for key, value in sample_batch.items():
    sample_batch_gpu[key] = value.to(device) if isinstance(value, torch.Tensor) else value

with torch.no_grad():
    outputs = model(sample_batch_gpu)

# Shape and value range of every tensor the model returned
print("Output shapes:")
for key, value in outputs.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(f"  {key}: {value.shape}")
    print(f"    min: {value.min():.3f}, max: {value.max():.3f}, mean: {value.mean():.3f}")

## 5. Simple Loss Function

In [None]:
# Simple MPJPE loss
class SimpleLoss(nn.Module):
    """Joint-regression loss with a small diversity regulariser.

    ``forward`` returns a dict with:
      * ``'joint_mse'`` - MSE between predicted and ground-truth joints
        (only when both inputs carry ``'hand_joints'``),
      * ``'mpjpe'``     - mean per-joint Euclidean error, computed without
        gradients and reported for monitoring only,
      * ``'diversity'`` - penalty applied when the per-batch std of the
        predictions falls below 0.01 (only for batches of more than one sample),
      * ``'total'``     - sum of the optimised terms (``joint_mse`` and
        ``diversity``). The monitoring metric ``mpjpe`` is NOT part of it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, outputs, targets):
        """Compute the loss terms for one batch.

        Args:
            outputs: Mapping of model outputs; ``'hand_joints'`` is used if present.
            targets: Mapping of ground truth; ``'hand_joints'`` is used if present.

        Returns:
            dict of scalar tensors as described in the class docstring.
        """
        losses = {}
        # Only the terms collected here are summed into the total
        objective_terms = []

        # Initialize pred_joints to None
        pred_joints = None

        # Hand joint loss (MPJPE)
        if 'hand_joints' in outputs and 'hand_joints' in targets:
            pred_joints = outputs['hand_joints']
            gt_joints = targets['hand_joints']

            # Simple L2 loss
            joint_loss = F.mse_loss(pred_joints, gt_joints)
            losses['joint_mse'] = joint_loss
            objective_terms.append(joint_loss)

            # MPJPE for monitoring; kept out of `objective_terms` on purpose
            with torch.no_grad():
                mpjpe = torch.norm(pred_joints - gt_joints, dim=-1).mean()
                losses['mpjpe'] = mpjpe

        # Simple diversity loss to prevent collapse
        if pred_joints is not None and pred_joints.shape[0] > 1:
            # Spread of the predictions across the batch dimension
            pred_std = pred_joints.std(dim=0).mean()
            diversity_loss = torch.relu(0.01 - pred_std)  # Penalize if std < 0.01
            losses['diversity'] = diversity_loss * 0.1  # Small weight
            objective_terms.append(losses['diversity'])

        # Total loss. With no terms at all, fall back to a zero scalar tensor so
        # that callers can still use `.item()` on the result.
        if objective_terms:
            total_loss = sum(objective_terms)
        else:
            total_loss = torch.zeros(())
        losses['total'] = total_loss

        return losses

# Create loss function
criterion = SimpleLoss()
print("Loss function created")

## 6. Simple Training Loop

In [None]:
# Adam over all model parameters, hyper-parameters taken from the config
adam_kwargs = {
    'lr': config['learning_rate'],
    'weight_decay': config['weight_decay'],
}
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)

# Cosine annealing of the learning rate, one period spanning the whole run
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'])

print(f"Optimizer: Adam with lr={config['learning_rate']}")
print("Scheduler: CosineAnnealingLR")

In [None]:
# Training history: one independent list per tracked quantity, appended to
# once per epoch by the main training loop.
history = {
    metric: []
    for metric in ('train_loss', 'train_mpjpe', 'val_loss', 'val_mpjpe', 'lr')
}

# Training function
def train_epoch(model, loader, criterion, optimizer, device, print_freq=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Module called as ``model(batch)``.
        loader: Iterable of dict batches; tensor values are moved to ``device``.
        criterion: Callable ``criterion(outputs, batch)`` returning a dict with
            a ``'total'`` tensor and, optionally, an ``'mpjpe'`` tensor.
        optimizer: Optimizer over the model's parameters.
        device: Device the batches are moved to.
        print_freq: Print progress every ``print_freq`` batches. ``None``
            (default) uses ``config['print_freq']``; ``0`` disables printing.

    Returns:
        ``(mean_loss, mean_mpjpe)`` over the batches; both are NaN when
        ``loader`` yields no batch.
    """
    if print_freq is None:
        print_freq = config['print_freq']  # notebook-level default

    model.train()
    total_loss = 0.0
    total_mpjpe = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(tqdm(loader, desc="Training")):
        # Move to device
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()}

        # Forward pass
        outputs = model(batch)
        losses = criterion(outputs, batch)

        # Backward pass
        optimizer.zero_grad()
        losses['total'].backward()

        # Gradient clipping (max total norm 1.0)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        # Update metrics. 'mpjpe' is optional, so fall back to a plain float
        # instead of calling .item() on a missing entry.
        loss_value = losses['total'].item()
        mpjpe_value = losses['mpjpe'].item() if 'mpjpe' in losses else 0.0
        total_loss += loss_value
        total_mpjpe += mpjpe_value
        num_batches += 1

        # Print progress
        if print_freq and batch_idx % print_freq == 0:
            print(f"  Batch {batch_idx}/{len(loader)}: "
                  f"Loss={loss_value:.4f}, "
                  f"MPJPE={mpjpe_value:.2f}mm")

    if num_batches == 0:
        # Nothing was processed; signal that instead of dividing by zero
        return float('nan'), float('nan')
    return total_loss / num_batches, total_mpjpe / num_batches

# Validation function
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module called as ``model(batch)``; switched to eval mode here.
        loader: Iterable of dict batches; tensor values are moved to ``device``.
        criterion: Callable ``criterion(outputs, batch)`` returning a dict with
            a ``'total'`` tensor and, optionally, an ``'mpjpe'`` tensor.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, mean_mpjpe)`` over the batches; both are NaN when
        ``loader`` yields no batch.
    """
    model.eval()
    total_loss = 0.0
    total_mpjpe = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()}

            outputs = model(batch)
            losses = criterion(outputs, batch)

            total_loss += losses['total'].item()
            if 'mpjpe' in losses:
                total_mpjpe += losses['mpjpe'].item()
            num_batches += 1

    if num_batches == 0:
        # An empty loader has no meaningful average; avoid dividing by zero
        return float('nan'), float('nan')
    return total_loss / num_batches, total_mpjpe / num_batches

In [None]:
# Main training loop: train, validate, step the LR schedule, log, checkpoint.
print(f"\nStarting training for {config['num_epochs']} epochs...\n")

for epoch in range(config['num_epochs']):
    print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
    print("=" * 50)

    # Train
    train_loss, train_mpjpe = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_mpjpe = validate(model, val_loader, criterion, device)

    # Read the learning rate BEFORE stepping the scheduler, so the value that
    # is logged is the one this epoch was trained with. Reading it afterwards
    # logs the next epoch's rate, which ends at 0 for cosine annealing with
    # T_max == num_epochs and cannot be drawn on the log-scale LR plot.
    epoch_lr = optimizer.param_groups[0]['lr']

    # Update scheduler
    scheduler.step()

    # Save history
    history['train_loss'].append(train_loss)
    history['train_mpjpe'].append(train_mpjpe)
    history['val_loss'].append(val_loss)
    history['val_mpjpe'].append(val_mpjpe)
    history['lr'].append(epoch_lr)

    # Print summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, MPJPE: {train_mpjpe:.2f}mm")
    print(f"  Val Loss: {val_loss:.4f}, MPJPE: {val_mpjpe:.2f}mm")
    print(f"  Learning Rate: {epoch_lr:.6f}")

    # Save checkpoint. The scheduler state is included so that a resumed run
    # continues the LR schedule instead of restarting it.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': config
    }
    torch.save(checkpoint, f"{config['output_dir']}/checkpoint_epoch_{epoch+1}.pth")

print("\nTraining completed!")

## 7. Plot Training History

In [None]:
# Three side-by-side panels: loss, MPJPE and learning rate per epoch
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# The loss and MPJPE panels share one layout: a train curve and a val curve
curve_panels = (
    (axes[0], 'train_loss', 'val_loss', 'Loss', 'Training Loss'),
    (axes[1], 'train_mpjpe', 'val_mpjpe', 'MPJPE (mm)', 'Mean Per Joint Position Error'),
)
for ax, train_key, val_key, y_label, panel_title in curve_panels:
    ax.plot(history[train_key], label='Train')
    ax.plot(history[val_key], label='Val')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True)

# Learning rate panel, drawn on a logarithmic y axis
ax = axes[2]
ax.plot(history['lr'])
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.set_yscale('log')
ax.grid(True)

plt.tight_layout()
plt.savefig(f"{config['output_dir']}/training_curves.png")
plt.show()

# Best (lowest) per-epoch MPJPE on each split
best_train_mpjpe = min(history['train_mpjpe'])
best_val_mpjpe = min(history['val_mpjpe'])
print("\nFinal Results:")
print(f"  Best Train MPJPE: {best_train_mpjpe:.2f}mm")
print(f"  Best Val MPJPE: {best_val_mpjpe:.2f}mm")

## 8. Debug Model Predictions

In [None]:
# Collect predictions on the first validation batches and measure how much
# they vary from sample to sample.
print("Checking prediction diversity...")
model.eval()

num_batches = 5
all_predictions = []

with torch.no_grad():
    for i, batch in enumerate(val_loader):
        if i >= num_batches:
            break

        # Tensors go to the model's device; other entries are left untouched
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
                 for k, v in batch.items()}

        outputs = model(batch)
        all_predictions.append(outputs['hand_joints'].cpu())

# Stack the per-batch predictions along the sample dimension
all_predictions = torch.cat(all_predictions, dim=0)
print(f"Total predictions: {all_predictions.shape}")

# Mean and standard deviation over the sample dimension
pred_mean = all_predictions.mean(dim=0)
pred_std = all_predictions.std(dim=0)

print("\nPrediction statistics:")
for stat_label, stat_value in (
    ("Mean std across joints", pred_std.mean()),
    ("Min std", pred_std.min()),
    ("Max std", pred_std.max()),
):
    print(f"  {stat_label}: {stat_value:.6f}")

# Two panels: per-joint spread and the overall value distribution
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Std per joint, averaged over the coordinate axis
ax = axes[0]
joint_std = pred_std.mean(dim=1)
ax.bar(range(21), joint_std)
ax.set(xlabel='Joint Index', ylabel='Std Dev', title='Prediction Diversity per Joint')
ax.grid(True, axis='y')

# Histogram of every predicted coordinate
ax = axes[1]
ax.hist(all_predictions.flatten().numpy(), bins=50, alpha=0.7, density=True)
ax.set(xlabel='Predicted Value', ylabel='Density', title='Distribution of All Predictions')
ax.grid(True, axis='y')

plt.tight_layout()
plt.show()

# A near-zero spread means the model predicts almost the same pose for every input
if not pred_std.mean() < 0.001:
    print("\n✓ Good prediction diversity")
else:
    print("\n⚠️ WARNING: Possible mode collapse detected! Predictions have very low diversity.")

## 9. Visualize Sample Predictions

In [None]:
# Visualize some predictions: input image on the left, predicted (red) and
# ground-truth (green) joints overlaid on the right.
print("Visualizing predictions...")
model.eval()

# Get one batch
vis_batch = next(iter(val_loader))
vis_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in vis_batch.items()}

with torch.no_grad():
    outputs = model(vis_batch)

# Visualize first 4 samples
num_vis = min(4, vis_batch['image'].shape[0])
# squeeze=False keeps `axes` two-dimensional even when there is a single row
fig, axes = plt.subplots(num_vis, 2, figsize=(8, 4*num_vis), squeeze=False)

for i in range(num_vis):
    # Input image, rescaled to [0, 1] for display
    ax = axes[i, 0]
    img = vis_batch['image'][i].cpu().numpy().transpose(1, 2, 0)
    value_range = img.max() - img.min()
    if value_range > 0:
        img = (img - img.min()) / value_range
    else:
        # Constant image: min-max scaling would divide by zero
        img = np.zeros_like(img)
    ax.imshow(img)
    ax.set_title(f"Sample {i+1} - Input")
    ax.axis('off')

    # Predictions vs GT
    ax = axes[i, 1]
    ax.imshow(img)

    # Simple projection assuming normalized coordinates: x/y are scaled by half
    # the image size and shifted to the image centre (112 px for 224 x 224).
    # NOTE(review): whether the joints really are in normalized image
    # coordinates is not visible here - confirm against the dataset.
    img_h, img_w = img.shape[:2]
    half_extent = np.array([img_w / 2, img_h / 2])

    # Plot predicted joints (red)
    pred_joints = outputs['hand_joints'][i].cpu().numpy()
    pred_2d = pred_joints[:, :2] * half_extent + half_extent
    ax.scatter(pred_2d[:, 0], pred_2d[:, 1], c='red', s=30, alpha=0.8, label='Pred')

    # Plot GT joints (green)
    if 'hand_joints' in vis_batch:
        gt_joints = vis_batch['hand_joints'][i].cpu().numpy()
        gt_2d = gt_joints[:, :2] * half_extent + half_extent
        ax.scatter(gt_2d[:, 0], gt_2d[:, 1], c='green', s=30, alpha=0.8, label='GT')

        # Compute error: mean Euclidean distance over the joints
        mpjpe = np.mean(np.linalg.norm(pred_joints - gt_joints, axis=1))
        ax.set_title(f"Sample {i+1} - MPJPE: {mpjpe:.1f}mm")
    else:
        ax.set_title(f"Sample {i+1} - Predictions")

    ax.axis('off')
    ax.legend()

plt.tight_layout()
plt.savefig(f"{config['output_dir']}/predictions.png")
plt.show()

## 10. Debug Analysis

In [None]:
# Check gradient flow: one forward/backward pass on a training batch, then
# inspect the gradient norm of every parameter.
print("Analyzing gradient flow...")

# Do one forward-backward pass
model.train()
batch = next(iter(train_loader))
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()}

# Gradients from the last training step are still stored on the parameters
# (the training loop clears them before backward, not after the step).
# Clear them first, otherwise backward() adds to them and the norms below
# describe two batches.
optimizer.zero_grad()

outputs = model(batch)
losses = criterion(outputs, batch)
losses['total'].backward()

# Check gradients
grad_norms = {}
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()
        grad_norms[name] = grad_norm

# Print gradient statistics
print(f"\nGradient statistics:")
grad_values = list(grad_norms.values())
print(f"  Mean gradient norm: {np.mean(grad_values):.6f}")
print(f"  Max gradient norm: {np.max(grad_values):.6f}")
print(f"  Min gradient norm: {np.min(grad_values):.6f}")

# Check for vanishing/exploding gradients
vanishing = sum(1 for v in grad_values if v < 1e-6)
exploding = sum(1 for v in grad_values if v > 100)
print(f"\n  Parameters with vanishing gradients (<1e-6): {vanishing}/{len(grad_values)}")
print(f"  Parameters with exploding gradients (>100): {exploding}/{len(grad_values)}")

# Plot gradient norms
plt.figure(figsize=(12, 4))
names = list(grad_norms.keys())
values = list(grad_norms.values())
plt.bar(range(len(names)), values)
plt.yscale('log')
plt.xlabel('Layer')
plt.ylabel('Gradient Norm (log scale)')
plt.title('Gradient Norms by Layer')
plt.xticks(range(len(names)), names, rotation=90)
plt.tight_layout()
plt.show()

# Clear gradients so this diagnostic pass does not leak into later training
optimizer.zero_grad()

In [None]:
# Check for dead neurons: record the output of every Linear / Conv2d layer on
# one batch and report the share of output entries that are (near) zero.
print("\nChecking for dead neurons...")

# Get activations
model.eval()
activations = {}

def hook_fn(name):
    """Build a forward hook that stores the module's output under ``name``."""
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

# Register hooks on key layers. try/finally guarantees that the hooks are
# removed even if the forward pass raises; otherwise they would stay attached
# to the model and fire on every later forward call.
hooks = []
try:
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Run forward pass
    with torch.no_grad():
        model(batch)
finally:
    # Remove hooks
    for hook in hooks:
        hook.remove()

# Analyze activations
for name, act in activations.items():
    if act.dim() > 1:
        # Check for dead neurons (always zero). The maximum is taken over the
        # first dimension only, so for activations with more than two
        # dimensions every remaining position is counted separately.
        dead_neurons = (act.abs().max(dim=0)[0] < 1e-6).float().mean().item()
        print(f"{name}: {dead_neurons*100:.1f}% dead neurons")

print("\nDebug analysis complete!")

## 11. Save Final Report

In [None]:
# Generate debug report from the results of the training, diversity and
# gradient-analysis cells.
report = f"""
# Debug Training Report

## Configuration
- Model: Simplified transformer
- Hidden dim: {config['hidden_dim']}
- Layers: {config['num_layers']}
- Batch size: {config['batch_size']}
- Learning rate: {config['learning_rate']}
- Epochs: {config['num_epochs']}
- Training samples: {len(train_dataset)}
- Validation samples: {len(val_dataset)}

## Results
- Final train MPJPE: {history['train_mpjpe'][-1]:.2f}mm
- Final val MPJPE: {history['val_mpjpe'][-1]:.2f}mm
- Best val MPJPE: {min(history['val_mpjpe']):.2f}mm
- Prediction diversity (std): {pred_std.mean():.6f}

## Gradient Analysis
- Mean gradient norm: {np.mean(grad_values):.6f}
- Parameters with vanishing gradients: {vanishing}
- Parameters with exploding gradients: {exploding}

## Issues Found
"""

# Add any issues
if pred_std.mean() < 0.001:
    report += "- ⚠️ Mode collapse detected\n"
if vanishing > len(grad_values) * 0.5:
    report += "- ⚠️ Many vanishing gradients\n"
if exploding > 0:
    report += "- ⚠️ Some exploding gradients\n"

print(report)

# Save report. The text can contain non-ASCII characters (the warning signs
# above), so the encoding is fixed to UTF-8 instead of relying on the platform
# default, which may be unable to encode them.
with open(f"{config['output_dir']}/debug_report.txt", 'w', encoding='utf-8') as f:
    f.write(report)

print(f"\nAll outputs saved to: {config['output_dir']}/")