# RT-TDDFT Model Debugging

This notebook provides tools for debugging the RT-TDDFT ML model:

1. **Model Architecture** - Inspect layers, parameter counts, and shapes
2. **Test Inputs** - Build a synthetic batch with a Hermitian, trace-normalized density matrix
3. **Forward Pass** - Step-by-step execution with hooks; per-layer activation statistics and distributions
4. **Gradient Analysis** - Gradient flow, vanishing/exploding gradients
5. **Weight Analysis** - Distributions, initialization quality
6. **Shape Debugging** - Input/output dimensions at each layer
7. **Memory Profiling** - GPU memory usage
8. **Numerical Stability** - NaN/Inf detection, edge-case inputs, output physics checks
9. **Performance Benchmarking** - Forward-pass latency and throughput

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict, defaultdict
from pathlib import Path
import warnings

%matplotlib inline
plt.rcParams.update({'figure.figsize': (12, 6), 'font.size': 11})

# Device setup: use the first CUDA device when one is visible, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")

if device == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_mem_gb:.1f} GB")

In [None]:
# Import project modules
from src.models import (
    RTTDDFTModel,
    GeometryEncoder,
    DensityEncoder,
    FieldEncoder,
    GeometryConditionedMamba,
    DensityDecoder,
)
from src.utils import build_molecular_graph, close_all

## Configuration

In [None]:
# Architecture hyperparameters forwarded to RTTDDFTModel by create_model.
MODEL_CONFIG = dict(
    latent_dim=256,
    n_mamba_layers=6,
    d_state=16,
    geometry_irreps='32x0e + 16x1o + 8x2e',
    max_ell=2,
)

# Sizes of the synthetic debugging batch built by create_test_batch.
TEST_CONFIG = dict(
    batch_size=2,
    n_atoms=3,
    n_basis=10,
    seq_length=16,
    n_electrons=4,
)

# Optional checkpoint to load instead of building a fresh model; None disables loading.
CHECKPOINT_PATH = None

---
## 1. Model Architecture Analysis

In [None]:
def count_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    ``trainable`` counts only parameters with ``requires_grad=True``.
    """
    total = trainable = 0
    for p in model.parameters():
        count = p.numel()
        total += count
        if p.requires_grad:
            trainable += count
    return total, trainable

def format_params(n):
    """Render a count compactly: ``2_500_000 -> '2.50M'``, ``1234 -> '1.2K'``, ``999 -> '999'``."""
    for threshold, suffix, spec in ((1e6, "M", ".2f"), (1e3, "K", ".1f")):
        if n >= threshold:
            return f"{n / threshold:{spec}}{suffix}"
    return str(n)

def print_model_summary(model, indent=0, max_depth=1):
    """Print a module tree annotated with trainable-parameter counts.

    Args:
        model: Root ``nn.Module``.
        indent: Indentation level (two spaces per level) of the root line.
        max_depth: Number of child levels to expand below the root. The
            default of 1 lists only the direct children.
    """
    prefix = "  " * indent
    _, trainable = count_parameters(model)
    print(f"{prefix}{model.__class__.__name__}: {format_params(trainable)} params")
    _print_children(model, prefix, max_depth)


def _print_children(module, prefix, depth):
    """Recursively print ``module``'s children up to ``depth`` levels deep."""
    if depth <= 0:
        return
    for name, child in module.named_children():
        _, child_trainable = count_parameters(child)
        print(f"{prefix}  └─ {name}: {format_params(child_trainable)}")
        # Grandchildren are indented one tree level further.
        _print_children(child, prefix + "    ", depth - 1)

In [None]:
# Create or load model
def _build_model(config):
    """Instantiate ``RTTDDFTModel`` from the architecture keys of ``config``."""
    return RTTDDFTModel(
        latent_dim=config['latent_dim'],
        n_mamba_layers=config['n_mamba_layers'],
        d_state=config['d_state'],
        geometry_irreps=config['geometry_irreps'],
        max_ell=config['max_ell'],
    )


def create_model(config, checkpoint_path=None):
    """Create a model from ``config`` or load one from a checkpoint.

    Args:
        config: Dict with ``latent_dim``, ``n_mamba_layers``, ``d_state``,
            ``geometry_irreps`` and ``max_ell``, forwarded to ``RTTDDFTModel``.
        checkpoint_path: Optional path to a ``torch.save`` checkpoint. Accepted
            contents: a pickled ``nn.Module``, or a dict whose ``'model'`` entry
            is a pickled ``nn.Module`` (returned as-is) or a state dict (loaded
            into a freshly built model). A non-existent path triggers a warning
            and falls back to building a new model from ``config``.

    Returns:
        The model. Checkpoint contents are loaded with ``map_location='cpu'``.

    Raises:
        ValueError: If the checkpoint is a dict without a ``'model'`` entry.
    """
    if checkpoint_path:
        if Path(checkpoint_path).exists():
            print(f"Loading model from {checkpoint_path}")
            # SECURITY: weights_only=False unpickles arbitrary Python objects. It is
            # required for full-model checkpoints (PyTorch 2.6 changed the default
            # to True). Only load checkpoints from trusted sources.
            ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            if isinstance(ckpt, nn.Module):
                return ckpt
            if 'model' not in ckpt:
                raise ValueError("Checkpoint doesn't contain full model")
            saved = ckpt['model']
            if isinstance(saved, dict):
                # State dict: rebuild the architecture from config, then load weights.
                print("Checkpoint holds a state dict; building model from config")
                model = _build_model(config)
                model.load_state_dict(saved)
                return model
            return saved
        warnings.warn(
            f"Checkpoint {checkpoint_path} not found; creating a new model from config"
        )

    print("Creating new model from config")
    return _build_model(config)

try:
    model = create_model(MODEL_CONFIG, CHECKPOINT_PATH)
    model = model.to(device)
    model.eval()
    print("Model created successfully")
except Exception as e:
    # Deliberate best-effort fallback: any failure (import problem, bad checkpoint,
    # ...) is reported and replaced by a small mock so the rest of the notebook runs.
    print(f"Error creating model: {e}")
    print("\nCreating minimal mock model for demonstration...")

    class MockModel(nn.Module):
        """Lightweight stand-in with the same ``model(batch) -> density`` interface.

        Flattens the real and imaginary parts of ``batch['density']`` (reshaped as
        (B, n_basis, n_basis)), encodes them with an MLP, appends ``batch['field']``
        (3 components per sample), runs one step of a 2-layer LSTM, decodes back to
        a complex (B, n_basis, n_basis) matrix and Hermitianizes it.
        """

        def __init__(self, latent_dim=256, n_basis=10):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Linear(n_basis * n_basis * 2, latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, latent_dim),
            )
            self.dynamics = nn.LSTM(latent_dim + 3, latent_dim, num_layers=2, batch_first=True)
            self.decoder = nn.Sequential(
                nn.Linear(latent_dim, latent_dim),
                nn.ReLU(),
                nn.Linear(latent_dim, n_basis * n_basis * 2),
            )
            self.n_basis = n_basis
            self.latent_dim = latent_dim

        def forward(self, batch):
            rho = batch['density']  # (B, n, n) complex
            field = batch['field']  # (B, 3)

            B = rho.shape[0]
            n = self.n_basis

            # Flatten real/imag parts and encode
            rho_flat = torch.cat([rho.real.reshape(B, -1), rho.imag.reshape(B, -1)], dim=-1)
            z = self.encoder(rho_flat)

            # Combine with field; a length-1 sequence for the LSTM
            z_field = torch.cat([z, field], dim=-1).unsqueeze(1)
            h, _ = self.dynamics(z_field)
            h = h.squeeze(1)

            # Decode into real and imaginary halves
            out = self.decoder(h)
            rho_real = out[:, :n*n].reshape(B, n, n)
            rho_imag = out[:, n*n:].reshape(B, n, n)
            rho_out = rho_real + 1j * rho_imag

            # Make Hermitian
            rho_out = 0.5 * (rho_out + rho_out.conj().transpose(-2, -1))
            return rho_out

    # Use the configured latent size instead of a hard-coded value, and put the mock
    # in eval mode to match the real-model path above.
    model = MockModel(latent_dim=MODEL_CONFIG['latent_dim'], n_basis=TEST_CONFIG['n_basis']).to(device)
    model.eval()
    print("Created mock model")

In [None]:
# Print model summary: banner, parameter totals, then the top-level module tree.
rule = "=" * 60
print(rule)
print("MODEL ARCHITECTURE")
print(rule)

total, trainable = count_parameters(model)
total_str, trainable_str = format_params(total), format_params(trainable)
print(f"\nTotal parameters: {total_str} ({total:,})")
print(f"Trainable parameters: {trainable_str} ({trainable:,})")

print("\nModule hierarchy:")
print_model_summary(model)

In [None]:
# Detailed per-parameter breakdown: qualified name, shape, element count.
rule = "-" * 70
print("\nDetailed Layer Breakdown:")
print(rule)
print(f"{'Layer':<40} {'Shape':<20} {'Params':>10}")
print(rule)

for pname, p in model.named_parameters():
    shape_str = str(list(p.shape))
    print(f"{pname:<40} {shape_str:<20} {p.numel():>10,}")

In [None]:
# Parameter distribution by module, grouped by the first component of each parameter name.
module_params = defaultdict(int)
for pname, p in model.named_parameters():
    module_params[pname.partition('.')[0]] += p.numel()

modules = list(module_params)
params = list(module_params.values())
label_offset = max(params, default=0) * 0.01  # nudge labels just past the bar end

plt.figure(figsize=(10, 6))
plt.barh(modules, params)
plt.xlabel('Number of Parameters')
plt.title('Parameters by Module')
for row, count in enumerate(params):
    plt.text(count + label_offset, row, format_params(count), va='center')
plt.tight_layout()
plt.show()

---
## 2. Create Test Inputs

In [None]:
def create_test_batch(config, device='cpu', seed=None):
    """Create a synthetic batch for model debugging.

    Args:
        config: Dict with ``batch_size``, ``n_basis``, ``n_atoms`` and
            ``n_electrons``.
        device: Device on which all tensors are allocated.
        seed: Optional integer seed for reproducible batches. ``None``
            (default) draws from the global RNG.

    Returns:
        Dict with:
        - ``density``: (B, n, n) complex64, Hermitian positive semi-definite,
          scaled so that Tr(rho @ S) equals ``n_electrons`` for every sample.
        - ``field``: (B, 3) small random vector (std 0.01).
        - ``positions``: (B, n_atoms, 3) atoms on the x axis, 2.0 apart.
        - ``atomic_numbers``: (B, n_atoms) long, all ones.
        - ``overlap``: (B, n, n) identity (an expanded view, not a copy).
        - ``n_electrons``: (B,) tensor filled with ``n_electrons``.
    """
    B = config['batch_size']
    n = config['n_basis']
    n_atoms = config['n_atoms']
    n_elec = config['n_electrons']

    generator = None
    if seed is not None:
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

    # rho = A A^dagger is Hermitian PSD, so its trace is positive and the
    # normalization below is well defined. A merely Hermitianized random matrix
    # can have a near-zero or negative trace, which makes the scaling explode or flip sign.
    a = torch.randn(B, n, n, dtype=torch.complex64, device=device, generator=generator)
    rho = a @ a.conj().transpose(-2, -1)

    # Normalize Tr(rho S) to the electron count, vectorized over the batch.
    overlap = torch.eye(n, dtype=torch.complex64, device=device)
    trace = torch.einsum('bij,ji->b', rho, overlap).real
    rho = rho * (n_elec / trace).view(B, 1, 1)
    # Symmetrize last so the result is exactly Hermitian in floating point.
    rho = 0.5 * (rho + rho.conj().transpose(-2, -1))

    # External field
    field = 0.01 * torch.randn(B, 3, device=device, generator=generator)

    # Geometry (simple linear molecule, 2 Bohr spacing along x)
    positions = torch.zeros(B, n_atoms, 3, device=device)
    positions[:, :, 0] = 2.0 * torch.arange(n_atoms, dtype=positions.dtype, device=device)

    atomic_numbers = torch.ones(B, n_atoms, dtype=torch.long, device=device)

    return {
        'density': rho,
        'field': field,
        'positions': positions,
        'atomic_numbers': atomic_numbers,
        'overlap': overlap.unsqueeze(0).expand(B, -1, -1),
        'n_electrons': torch.tensor([n_elec] * B, device=device),
    }

test_batch = create_test_batch(TEST_CONFIG, device=device)

# Report each entry: shape and dtype for tensors, the raw value otherwise.
print("Test batch created:")
for key, val in test_batch.items():
    detail = f"{val.shape} {val.dtype}" if isinstance(val, torch.Tensor) else f"{val}"
    print(f"  {key}: {detail}")

---
## 3. Forward Pass Debugging

In [None]:
class ForwardHook:
    """Capture the output of every leaf module during a forward pass.

    After ``register(model)`` and a forward call, ``activations`` maps each
    leaf module's qualified name to its detached output tensor. For tuple
    outputs (e.g. LSTM/GRU ``(output, state)``), only the first element is
    kept. Non-tensor outputs are skipped.
    """

    def __init__(self):
        self.activations = OrderedDict()
        self.handles = []

    def hook_fn(self, name):
        """Return a forward hook that stores the module output under ``name``."""
        def hook(module, input, output):
            if isinstance(output, tuple):
                output = output[0] if output else None  # Handle LSTM/GRU outputs
            # Skip None, dicts, nested tuples, ... instead of crashing the forward
            # pass on `.detach()`.
            if isinstance(output, torch.Tensor):
                self.activations[name] = output.detach()
        return hook

    def register(self, model):
        """Attach hooks to every leaf module (one without children) of ``model``."""
        for name, module in model.named_modules():
            if next(module.children(), None) is None:
                self.handles.append(module.register_forward_hook(self.hook_fn(name)))

    def remove(self):
        """Detach all registered hooks."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def clear(self):
        """Discard recorded activations."""
        self.activations = OrderedDict()

In [None]:
# Run forward pass with hooks
forward_hook = ForwardHook()
forward_hook.register(model)

try:
    with torch.no_grad():
        output = model(test_batch)
    print("Forward pass successful!")
    print(f"\nOutput shape: {output.shape}")
    print(f"Output dtype: {output.dtype}")
except Exception as e:
    print(f"Forward pass failed: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Always detach the hooks (even on KeyboardInterrupt) so later cells run without them.
    forward_hook.remove()

In [None]:
# Analyze layer activations: mean/std per captured leaf output (real part for complex tensors).
rule = "-" * 80
print("\nLayer Activations:")
print(rule)
print(f"{'Layer':<40} {'Shape':<20} {'Mean':>10} {'Std':>10}")
print(rule)

for layer_name, act in forward_hook.activations.items():
    if act.numel() == 0:
        continue
    values = (act.real if act.is_complex() else act).float()
    mean = values.mean().item()
    std = values.std().item()
    shape = str(list(act.shape))
    print(f"{layer_name[:40]:<40} {shape:<20} {mean:>10.4f} {std:>10.4f}")

In [None]:
# Plot activation distributions for select layers.
# Layers are selected by module type as well as by name. Auto-generated names
# such as 'encoder.0' (from nn.Sequential) never contain 'linear', so name
# matching alone can miss every layer.
modules_by_name = dict(model.named_modules())
plot_types = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
interesting_layers = [
    name for name in forward_hook.activations
    if isinstance(modules_by_name.get(name), plot_types)
    or any(x in name for x in ['linear', 'Linear', 'conv', 'dense'])
]

if interesting_layers:
    n_plots = min(len(interesting_layers), 6)
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = axes.flatten()

    for idx, name in enumerate(interesting_layers[:n_plots]):
        activation = forward_hook.activations[name]
        if activation.is_complex():
            activation = activation.real
        # Convert to float32 first: NumPy has no bfloat16, so .numpy() would fail on it.
        data = activation.float().cpu().numpy().flatten()

        axes[idx].hist(data, bins=50, edgecolor='black', alpha=0.7)
        # Full qualified name (truncated): the last component alone is often just '0'.
        axes[idx].set_title(name[-30:])
        axes[idx].set_xlabel('Activation')
        axes[idx].set_ylabel('Count')

    for idx in range(n_plots, len(axes)):
        axes[idx].set_visible(False)

    plt.suptitle('Activation Distributions')
    plt.tight_layout()
    plt.show()
else:
    print("No linear/conv layers found for visualization")

---
## 4. Gradient Analysis

In [None]:
class GradientHook:
    """Capture the gradient w.r.t. each leaf module's first output during backward.

    After ``register(model)`` and a ``backward()`` call, ``gradients`` maps the
    qualified name of each leaf module to its detached ``grad_output[0]``.
    Modules whose output gradient is missing are skipped.
    """

    def __init__(self):
        self.gradients = OrderedDict()
        self.handles = []

    def hook_fn(self, name):
        """Return a full-backward hook that records ``grad_output[0]`` under ``name``."""
        def hook(module, grad_input, grad_output):
            # Guard against an empty tuple as well as a None first entry.
            if grad_output and grad_output[0] is not None:
                self.gradients[name] = grad_output[0].detach()
        return hook

    def register(self, model):
        """Attach full backward hooks to every leaf module of ``model``."""
        for name, module in model.named_modules():
            if next(module.children(), None) is None:
                self.handles.append(module.register_full_backward_hook(self.hook_fn(name)))

    def remove(self):
        """Detach all registered hooks."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def clear(self):
        """Discard recorded gradients (mirrors ``ForwardHook.clear``)."""
        self.gradients = OrderedDict()

In [None]:
# Compute gradients
model.train()
# Start from empty .grad so the statistics below reflect only this backward pass.
model.zero_grad(set_to_none=True)
gradient_hook = GradientHook()
gradient_hook.register(model)

try:
    # Forward
    output = model(test_batch)

    # Simple loss (MSE to target)
    target = test_batch['density']
    loss = (output - target).abs().pow(2).mean()

    # Backward
    loss.backward()

    print(f"Loss: {loss.item():.6f}")
    print("Backward pass successful!")
except Exception as e:
    print(f"Backward pass failed: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Always detach hooks and return to eval mode, even if interrupted.
    gradient_hook.remove()
    model.eval()

In [None]:
# Analyze parameter gradients: per-parameter |grad| mean, grad std, and |grad| max.
rule = "-" * 80
print("\nParameter Gradients:")
print(rule)
print(f"{'Parameter':<40} {'Grad Mean':>12} {'Grad Std':>12} {'Grad Max':>12}")
print(rule)

grad_stats = []
for pname, p in model.named_parameters():
    if p.grad is None:
        continue
    g = p.grad.float()
    magnitude = g.abs()
    entry = {
        'name': pname,
        'mean': magnitude.mean().item(),
        'std': g.std().item(),
        'max': magnitude.max().item(),
    }
    grad_stats.append(entry)
    row = " ".join(f"{entry[k]:>12.2e}" for k in ('mean', 'std', 'max'))
    print(f"{pname[:40]:<40} {row}")

# Zero gradients for next experiment
model.zero_grad()

In [None]:
# Check for vanishing/exploding gradients
if grad_stats:
    means = [s['mean'] for s in grad_stats]

    print("\nGradient Health Check:")
    print(f"  Overall gradient magnitude: {np.mean(means):.2e}")

    vanishing = sum(1 for m in means if m < 1e-7)
    exploding = sum(1 for m in means if m > 1e3)
    # NaN fails every comparison above, so count it explicitly; otherwise NaN
    # gradients would fall through to the "healthy" verdict.
    nan_grads = sum(1 for m in means if np.isnan(m))

    print(f"  Vanishing gradients (<1e-7): {vanishing}/{len(means)} layers")
    print(f"  Exploding gradients (>1e3): {exploding}/{len(means)} layers")
    print(f"  NaN gradients: {nan_grads}/{len(means)} layers")

    if nan_grads > 0:
        print("  WARNING: NaN gradients detected!")
    if vanishing > len(means) // 2:
        print("  WARNING: Many vanishing gradients detected!")
    elif vanishing > 0:
        # A minority of vanishing gradients previously produced no message at all.
        print(f"  NOTE: {vanishing} layer(s) have vanishing gradients")
    if exploding > 0:
        print("  WARNING: Exploding gradients detected!")
    if vanishing == 0 and exploding == 0 and nan_grads == 0:
        print("  OK: Gradient magnitudes look healthy")

In [None]:
# Plot gradient flow
if grad_stats:
    plt.figure(figsize=(14, 5))

    # Label bars with the qualified parameter name (last 30 chars). The bare last
    # component ('weight'/'bias') is identical across layers and unreadable as an axis.
    names = [s['name'][-30:] for s in grad_stats]
    means = [s['mean'] for s in grad_stats]

    plt.bar(range(len(means)), means)
    plt.xticks(range(len(names)), names, rotation=45, ha='right')
    plt.yscale('log')
    plt.xlabel('Layer')
    plt.ylabel('Gradient Magnitude (log)')
    plt.title('Gradient Flow Through Network')
    plt.axhline(y=1e-7, color='r', linestyle='--', label='Vanishing threshold')
    plt.legend()
    plt.tight_layout()
    plt.show()

---
## 5. Weight Analysis

In [None]:
# Analyze weight distributions: mean/std/min/max for every parameter whose name contains 'weight'.
rule = "-" * 80
print("Weight Statistics:")
print(rule)
print(f"{'Parameter':<40} {'Mean':>12} {'Std':>12} {'Min':>12} {'Max':>12}")
print(rule)

weight_stats = []
for pname, p in model.named_parameters():
    if 'weight' not in pname:
        continue
    w = p.data.float()
    entry = {
        'name': pname,
        'mean': w.mean().item(),
        'std': w.std().item(),
        'min': w.min().item(),
        'max': w.max().item(),
    }
    weight_stats.append(entry)
    row = " ".join(f"{entry[k]:>12.4f}" for k in ('mean', 'std', 'min', 'max'))
    print(f"{pname[:40]:<40} {row}")

In [None]:
# Plot weight distributions
weight_params = [(name, param) for name, param in model.named_parameters() if 'weight' in name]

if weight_params:
    n_plots = min(len(weight_params), 6)
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = axes.flatten()

    for idx, (name, param) in enumerate(weight_params[:n_plots]):
        data = param.data.cpu().float().numpy().flatten()
        axes[idx].hist(data, bins=50, edgecolor='black', alpha=0.7)
        # Last two name components (e.g. 'encoder.0.weight' -> '0.weight'). Unlike
        # split('.')[-2], this also works for root-level parameters with no '.'.
        axes[idx].set_title('.'.join(name.split('.')[-2:]))
        axes[idx].axvline(x=0, color='r', linestyle='--', alpha=0.5)

    for idx in range(n_plots, len(axes)):
        axes[idx].set_visible(False)

    plt.suptitle('Weight Distributions')
    plt.tight_layout()
    plt.show()

In [None]:
# Check initialization quality
print("\nInitialization Quality Check:")
# Only multi-dimensional weights are checked. 1-D weights (e.g. normalization-layer
# scales, which are commonly initialized to all ones) legitimately have mean ~1 and
# std 0, and would otherwise be reported as false positives.
param_ndim = {pname: p.dim() for pname, p in model.named_parameters()}
n_flagged = 0
for stats in weight_stats:
    name = stats['name']
    if param_ndim.get(name, 2) < 2:
        continue

    # Named `weight_issues` so this cell does not overwrite the global `issues`
    # list produced by the numerical-stability check.
    weight_issues = []

    if abs(stats['mean']) > 0.1:
        weight_issues.append(f"mean={stats['mean']:.3f} (should be ~0)")

    if stats['std'] < 0.01:
        weight_issues.append(f"std={stats['std']:.4f} (too small)")
    elif stats['std'] > 2.0:
        weight_issues.append(f"std={stats['std']:.3f} (too large)")

    if weight_issues:
        n_flagged += 1
        print(f"  {name[:50]}: {', '.join(weight_issues)}")

if n_flagged == 0:
    print("  OK: No initialization issues flagged for multi-dimensional weights")

---
## 6. Shape Debugging

In [None]:
def trace_shapes(model, batch):
    """Record input/output shapes of every named submodule during one forward pass.

    Args:
        model: Module to trace; the root module itself (empty name) is skipped.
        batch: Argument passed as ``model(batch)`` under ``torch.no_grad()``.

    Returns:
        OrderedDict mapping qualified module name to
        ``{'input': [...], 'output': ...}``. Tensors are recorded as shape tuples
        and other values as their ``type``. Tuple outputs become lists. Entries
        are added when each module's forward call finishes.
    """
    shapes = OrderedDict()

    def describe(x):
        return tuple(x.shape) if isinstance(x, torch.Tensor) else type(x)

    def hook(name):
        def fn(module, inputs, output):
            if isinstance(output, tuple):
                output_shapes = [describe(x) for x in output]
            else:
                output_shapes = describe(output)
            shapes[name] = {'input': [describe(x) for x in inputs], 'output': output_shapes}
        return fn

    handles = [
        module.register_forward_hook(hook(name))
        for name, module in model.named_modules()
        if name  # Skip root module
    ]
    try:
        with torch.no_grad():
            model(batch)
    finally:
        # Remove hooks even if the forward pass raises, so they don't leak into later cells.
        for h in handles:
            h.remove()

    return shapes

In [None]:
# Trace shapes and print one row per submodule (input/output columns truncated to 28 chars).
shape_trace = trace_shapes(model, test_batch)

rule = "=" * 100
print("Shape Trace:")
print(rule)
print(f"{'Module':<40} {'Input Shape':<30} {'Output Shape':<30}")
print(rule)

for module_name, io_shapes in shape_trace.items():
    in_col = str(io_shapes['input'])[:28]
    out_col = str(io_shapes['output'])[:28]
    print(f"{module_name[:40]:<40} {in_col:<30} {out_col:<30}")

---
## 7. Memory Profiling

In [None]:
def get_memory_stats():
    """Return CUDA memory usage in GB, or ``None`` when CUDA is unavailable.

    Keys: ``allocated``, ``reserved`` and ``max_allocated`` (peak allocated
    since the last reset), each in bytes divided by 1e9.
    """
    if torch.cuda.is_available():
        gb = 1e9
        return {
            'allocated': torch.cuda.memory_allocated() / gb,
            'reserved': torch.cuda.memory_reserved() / gb,
            'max_allocated': torch.cuda.max_memory_allocated() / gb,
        }
    return None

def print_memory_stats(label=""):
    """Print one line of CUDA memory usage (GB) prefixed by ``label``."""
    stats = get_memory_stats()
    if not stats:
        print(f"{label} - GPU not available")
        return
    print(
        f"{label} - Allocated: {stats['allocated']:.2f} GB, "
        f"Reserved: {stats['reserved']:.2f} GB, "
        f"Max: {stats['max_allocated']:.2f} GB"
    )

In [None]:
if torch.cuda.is_available():
    # Start clean: release cached blocks and reset the peak counter.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print("Memory Profiling:")
    print("-" * 60)

    # Memory already held before this cell's passes (weights, test batch, earlier
    # tensors). It does not scale with batch size, so the per-sample estimate excludes it.
    baseline_bytes = torch.cuda.memory_allocated()
    print_memory_stats("Before forward")

    with torch.no_grad():
        output = model(test_batch)

    print_memory_stats("After forward")

    # With gradients
    model.zero_grad(set_to_none=True)
    model.train()
    try:
        output = model(test_batch)
        loss = output.abs().pow(2).mean()
        loss.backward()
        print_memory_stats("After backward")
    finally:
        model.zero_grad()
        model.eval()

    # Peak minus baseline still includes parameter gradients, which do not scale with
    # batch size, so this per-sample figure is conservative.
    peak_bytes = torch.cuda.max_memory_allocated()
    mem_per_sample = max(peak_bytes - baseline_bytes, 0) / 1e6 / TEST_CONFIG['batch_size']
    print(f"\nEstimated memory per sample: {mem_per_sample:.1f} MB")

    # Estimate max batch size: 80% of device memory (20% headroom) minus the fixed baseline.
    total_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    usable_mem = total_mem * 0.8 - baseline_bytes / 1e9
    if mem_per_sample > 0:
        max_batch = max(int(usable_mem * 1e3 / mem_per_sample), 0)
        print(f"Estimated max batch size: {max_batch}")
    else:
        print("Estimated max batch size: n/a (no per-sample memory measured)")
else:
    print("GPU not available for memory profiling")

---
## 8. Numerical Stability Checks

In [None]:
def check_numerical_stability(model, batch, n_iters=10):
    """Feed the model its own output repeatedly and report NaN/Inf occurrences.

    Runs ``n_iters`` no-grad forward passes in eval mode. After each pass the
    output replaces ``'density'`` in a shallow copy of ``batch`` (the caller's
    dict is not modified), simulating an autoregressive rollout. The model's
    original train/eval mode is restored afterwards.

    Args:
        model: Module called as ``model(batch)``.
        batch: Dict input to the model.
        n_iters: Number of rollout steps.

    Returns:
        List of issue strings such as ``"Iteration 3: NaN in output"``; empty
        when every output was finite.
    """
    was_training = model.training
    model.eval()
    issues = []

    try:
        with torch.no_grad():
            for i in range(n_iters):
                output = model(batch)

                if torch.isnan(output).any():
                    issues.append(f"Iteration {i}: NaN in output")
                if torch.isinf(output).any():
                    issues.append(f"Iteration {i}: Inf in output")

                # Use output as next input (simulate autoregressive)
                batch = {**batch, 'density': output}
    finally:
        model.train(was_training)

    return issues

In [None]:
print("Numerical Stability Check:")
print("-" * 40)

# Roll the model forward from the normal test batch (default 10 iterations).
issues = check_numerical_stability(model, test_batch)
if not issues:
    print("OK: No NaN/Inf detected in 10 iterations")
else:
    print("ISSUES FOUND:")
    print("\n".join(f"  {issue}" for issue in issues))

In [None]:
# Check with edge case inputs: each case swaps only the density in a shallow copy of the test batch.
print("\nEdge Case Testing:")
print("-" * 40)

reference = test_batch['density']
identity = torch.eye(TEST_CONFIG['n_basis'], dtype=torch.complex64, device=device)
edge_cases = [
    ('Zero density', torch.zeros_like(reference)),
    ('Large density', reference * 100),
    ('Small density', reference * 0.001),
    ('Identity density', identity.unsqueeze(0).expand(TEST_CONFIG['batch_size'], -1, -1)),
]

for name, density in edge_cases:
    edge_batch = {**test_batch, 'density': density}

    try:
        with torch.no_grad():
            output = model(edge_batch)
        has_nan = torch.isnan(output).any().item()
        has_inf = torch.isinf(output).any().item()
        verdict = f"FAIL (NaN={has_nan}, Inf={has_inf})" if (has_nan or has_inf) else "OK"
        print(f"  {name}: {verdict}")
    except Exception as e:
        print(f"  {name}: ERROR - {e}")

In [None]:
# Check output physics
print("\nOutput Physics Check:")
print("-" * 40)

with torch.no_grad():
    output = model(test_batch)

# Hermiticity: max |rho - rho^dagger| over the whole batch.
herm_error = (output - output.conj().transpose(-2, -1)).abs().max().item()
print(f"  Hermiticity error: {herm_error:.2e}")

# Electron count Tr(rho S), computed per sample with that sample's own overlap
# matrix rather than reusing sample 0's overlap for the whole batch.
traces = torch.einsum('bij,bji->b', output, test_batch['overlap']).real.tolist()
print(f"  Trace values: {traces}")
print(f"  Expected: {TEST_CONFIG['n_electrons']}")

# Check magnitude
max_mag = output.abs().max().item()
print(f"  Max magnitude: {max_mag:.4f}")

---
## 9. Performance Benchmarking

In [None]:
import time

def benchmark_forward(model, batch, n_warmup=5, n_runs=20):
    """Time no-grad forward passes of ``model(batch)``.

    Switches ``model`` to eval mode (and leaves it there), runs ``n_warmup``
    untimed passes, then ``n_runs`` timed ones. When CUDA is available the
    device is synchronized after warmup and after every timed pass, so
    asynchronous kernels are included in the measurement.

    Returns:
        ``np.ndarray`` of per-run wall-clock times in milliseconds.
    """
    model.eval()
    cuda = torch.cuda.is_available()

    for _ in range(n_warmup):
        with torch.no_grad():
            model(batch)
    if cuda:
        torch.cuda.synchronize()

    elapsed = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        with torch.no_grad():
            model(batch)
        if cuda:
            torch.cuda.synchronize()
        elapsed.append(time.perf_counter() - t0)

    return np.asarray(elapsed) * 1000

In [None]:
print("Performance Benchmark:")
print("-" * 40)

times = benchmark_forward(model, test_batch)

# Summary statistics of the per-run times (ms).
print("Forward pass time:")
for label, value in (("Mean:", times.mean()), ("Std:", times.std()),
                     ("Min:", times.min()), ("Max:", times.max())):
    print(f"  {label:<5} {value:.2f} ms")

throughput = TEST_CONFIG['batch_size'] / (times.mean() / 1000)
print(f"\nThroughput: {throughput:.1f} samples/sec")

In [None]:
# Benchmark different batch sizes.
# Per-size results use sweep-specific names so the single-batch `times` and
# `throughput` from the benchmark above (reported in the summary) are not overwritten.
batch_sizes = [1, 2, 4, 8, 16]
throughputs = []

for bs in batch_sizes:
    try:
        sweep_config = {**TEST_CONFIG, 'batch_size': bs}
        sweep_batch = create_test_batch(sweep_config, device=device)

        sweep_times = benchmark_forward(model, sweep_batch, n_warmup=3, n_runs=10)
        sweep_throughput = bs / (sweep_times.mean() / 1000)
        throughputs.append(sweep_throughput)
        print(f"Batch size {bs}: {sweep_times.mean():.2f} ms, {sweep_throughput:.1f} samples/sec")
    except RuntimeError as e:
        if 'out of memory' not in str(e).lower():
            raise
        print(f"Batch size {bs}: OOM")
        throughputs.append(0)  # 0 marks a failed size; the plot below skips it
        if torch.cuda.is_available():
            # Release cached blocks from the failed attempt before trying the next size.
            torch.cuda.empty_cache()

In [None]:
# Plot throughput vs batch size; sizes recorded as 0 (OOM) are left out.
if throughputs:
    measured = [(bs, tp) for bs, tp in zip(batch_sizes, throughputs) if tp > 0]
    plt.figure(figsize=(10, 5))
    plt.plot([size for size, _ in measured],
             [rate for _, rate in measured],
             'o-', markersize=8)
    plt.xlabel('Batch Size')
    plt.ylabel('Throughput (samples/sec)')
    plt.title('Throughput vs Batch Size')
    plt.grid(True)
    plt.show()

---
## Summary

In [None]:
print("=" * 60)
print("DEBUGGING SUMMARY")
print("=" * 60)

print("\nModel:")
print(f"  Parameters: {format_params(count_parameters(model)[1])}")
print(f"  Device: {device}")

# Re-run one no-grad forward pass so the status reflects the model as it is now.
# The mere existence of an `output` variable doesn't prove a forward pass succeeded,
# because several cells assign it.
try:
    with torch.no_grad():
        model(test_batch)
    forward_status = 'OK'
except Exception as e:
    forward_status = f"FAILED ({type(e).__name__}: {e})"

print("\nForward Pass:")
print(f"  Status: {forward_status}")

# Read section results defensively so the summary still runs when sections were skipped.
nb = globals()

print("\nGradients:")
summary_grad_stats = nb.get('grad_stats')
if summary_grad_stats:
    vanishing = sum(1 for s in summary_grad_stats if s['mean'] < 1e-7)
    print(f"  Vanishing: {vanishing}/{len(summary_grad_stats)} layers")
else:
    print("  Not computed")

print("\nNumerical Stability:")
print(f"  Issues: {len(nb['issues']) if 'issues' in nb else 'Not checked'}")

print("\nPerformance:")
if 'times' in nb and 'throughput' in nb:
    print(f"  Forward: {nb['times'].mean():.2f} ms")
    print(f"  Throughput: {nb['throughput']:.1f} samples/sec")
else:
    print("  Not benchmarked")

In [None]:
# Cleanup: project helper from src.utils (presumably closes open figures — confirm there),
# then release cached CUDA memory when a GPU is present.
close_all()
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.empty_cache()
print("\nDebugging session complete!")