# SurgicalTheater Demo

This notebook demonstrates the core functionality of SurgicalTheater for memory-efficient model validation.

## Features Demonstrated
- Memory-efficient model modifications
- Perfect state restoration
- Gradient flow preservation
- Custom modification functions
- Exception safety

In [None]:
# Setup
import torch
import torch.nn as nn
import torchvision
from surgical_theater import SurgicalTheater

# Pick the compute device: use CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Report the environment so the outputs below can be interpreted.
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## Demo 1: Memory-Efficient Model Validation

This demo uses ResNet-18 to compare the extra memory reported by SurgicalTheater with the cost of a traditional `deepcopy` of the model.

In [None]:
# Create ResNet-18 model (weights=None: randomly initialised, nothing is downloaded)
model = torchvision.models.resnet18(weights=None).to(device)
total_params = sum(p.numel() for p in model.parameters())
# Use each parameter's real element size instead of assuming 4-byte float32,
# so the estimate stays correct if the model is ever cast to another dtype.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
param_memory_mb = param_bytes / (1024 * 1024)

print(f"Model parameters: {total_params:,}")
print(f"Model memory: {param_memory_mb:.1f} MB")
# NOTE(review): the 2x figure is original + copy combined; the extra cost of
# the copy alone is 1x, which is what efficiency_ratio below compares against.
print(f"Traditional deepcopy would need: {param_memory_mb * 2:.1f} MB")
print()

# Demonstrate memory-efficient validation
with SurgicalTheater(model, track_memory=True) as theater:
    # A validation pass must not change the model: eval mode stops BatchNorm
    # from updating its running statistics, and no_grad avoids keeping
    # activations alive for a backward pass that never happens.
    model.eval()
    # Allocate the batch directly on the target device (no CPU -> device copy).
    x = torch.randn(2, 3, 224, 224, device=device)
    with torch.no_grad():
        output = model(x)

    # Read the tracked value while the context is still open.
    delta_memory = theater.total_delta_memory_mb
    # Guard against division by zero when no extra memory was recorded.
    efficiency_ratio = param_memory_mb / delta_memory if delta_memory > 0 else float('inf')

    print(f"✅ Forward pass successful: {output.shape}")
    print(f"✅ SurgicalTheater delta memory: {delta_memory:.2f} MB")
    print(f"✅ Memory efficiency: {efficiency_ratio:.1f}x better than deepcopy")

## Demo 2: Perfect State Restoration

This demo checks that the parameter values, the `requires_grad` flags, and the train/eval mode are all restored once the `SurgicalTheater` context exits.

In [None]:
# Create a simple model for testing
model = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
).to(device)

model.train()

# Store original state.
# detach() before clone() so the snapshots are plain tensors that are not
# attached to the autograd graph of the live parameters.
original_weights = [p.detach().clone() for p in model.parameters()]
original_requires_grad = [p.requires_grad for p in model.parameters()]
original_training_mode = model.training

print(f"Original training mode: {original_training_mode}")
print(f"Original requires_grad: {all(original_requires_grad)}")
print()

# Test with SurgicalTheater
# NOTE(review): "scale" with factor=0.8 presumably multiplies the weights by
# 0.8 for the duration of the context — confirm against the SurgicalTheater API.
with SurgicalTheater(model, modification_type="scale", factor=0.8):
    # Switch to eval mode inside context; the check after the context expects
    # the original mode to be back.
    model.eval()

    # Run validation (batch allocated directly on the target device)
    x = torch.randn(32, 64, device=device)
    output = model(x)

    # Compute gradients to test gradient flow
    loss = output.sum()
    loss.backward()

    print(f"✅ Validation output shape: {output.shape}")
    print(f"✅ Gradients computed: {all(p.grad is not None for p in model.parameters())}")
    print(f"✅ Model in eval mode: {not model.training}")

# The gradients above were computed while the weights were temporarily
# modified; drop them so they cannot leak into a later optimizer step.
model.zero_grad()

# Verify restoration (values compared within a small absolute tolerance)
weights_restored = all(torch.allclose(orig, curr, atol=1e-6)
                      for orig, curr in zip(original_weights, model.parameters()))
grad_flags_restored = [p.requires_grad for p in model.parameters()] == original_requires_grad
training_mode_restored = model.training == original_training_mode

print()
print("After context exit:")
print(f"✅ Weights restored: {weights_restored}")
print(f"✅ requires_grad restored: {grad_flags_restored}")
print(f"✅ Training mode restored: {training_mode_restored}")

## Demo 3: Custom Modification Functions

Shows how to create custom modification functions for specialized use cases.

In [None]:
# Custom modification function
def attention_scaling(param, temperature=1.0):
    """Return the delta that replaces a weight matrix by its scaled row softmax.

    For a 2-D ``param`` the values are divided by ``temperature``, passed
    through a softmax over the last dimension, and multiplied by the size of
    that dimension; the function returns that result minus ``param``. For a
    tensor of any other rank (e.g. a bias vector) the delta is all zeros.
    """
    # Guard clause: anything that is not a matrix is left untouched.
    if param.dim() != 2:
        return torch.zeros_like(param)

    row_width = param.shape[-1]
    attn = torch.softmax(param / temperature, dim=-1)
    # Each softmax row sums to 1, so multiplying by the row width gives the
    # rescaled entries a mean of 1 per row before the delta is taken.
    return attn * row_width - param

# Create model
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 1)
).to(device)

# Detached snapshots: plain tensors that are not tied to the autograd graph.
original_weights = [p.detach().clone() for p in model.parameters()]

# Test different temperature values
temperatures = [0.5, 1.0, 2.0]
results = []

# Use one fixed batch for every temperature so that differences in the printed
# statistics come from the temperature, not from a freshly sampled input.
x = torch.randn(16, 32).to(device)

for temp in temperatures:
    # NOTE(review): extra keyword arguments (temperature=...) are presumably
    # forwarded to modification_fn — confirm against the SurgicalTheater API.
    with SurgicalTheater(model, modification_type="custom",
                        modification_fn=attention_scaling, temperature=temp):
        output = model(x)
        results.append(output.std().item())

    print(f"Temperature {temp}: Output std = {results[-1]:.4f}")

# Verify weights are restored (within a small absolute tolerance)
weights_restored = all(torch.allclose(orig, curr, atol=1e-6)
                      for orig, curr in zip(original_weights, model.parameters()))

print("\n✅ Custom modifications completed")
print(f"✅ Weights restored after all modifications: {weights_restored}")

## Demo 4: Exception Safety

Demonstrates that SurgicalTheater safely restores model state even when exceptions occur.

In [None]:
# Create model
model = nn.Linear(10, 5).to(device)
# Detached snapshots: plain tensors that are not tied to the autograd graph.
original_weight = model.weight.detach().clone()
original_bias = model.bias.detach().clone()

print("Testing exception safety...")

# Test that weights are restored even when exceptions occur
try:
    with SurgicalTheater(model, modification_type="scale", factor=2.0):
        # Verify weights are modified
        modified = not torch.allclose(model.weight, original_weight)
        print(f"✅ Weights modified inside context: {modified}")

        # Intentionally cause an exception so the context exits through
        # its error path.
        raise ValueError("Intentional test exception")

except ValueError as e:
    print(f"✅ Exception caught: {e}")

# Check if weights are still restored despite the exception
weight_restored = torch.allclose(model.weight, original_weight, atol=1e-6)
bias_restored = torch.allclose(model.bias, original_bias, atol=1e-6)

print(f"\n✅ Weight restored after exception: {weight_restored}")
print(f"✅ Bias restored after exception: {bias_restored}")
# Only claim success when both checks actually passed; an unconditional
# message would hide a failed restoration.
if weight_restored and bias_restored:
    print("✅ Exception safety verified!")
else:
    print("❌ Exception safety check failed: parameters were not restored")

## Demo 5: Practical Training Integration

Shows how SurgicalTheater integrates into real training loops for frequent validation.

In [None]:
# Build a small MLP regressor (20 -> 50 -> 20 -> 1) for the toy training run.
layers = [
    nn.Linear(20, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 1),
]
model = nn.Sequential(*layers).to(device)

# Loss and optimiser for the regression task.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Random stand-in data: 100 training samples and 20 validation samples.
train_x, train_y = torch.randn(100, 20).to(device), torch.randn(100, 1).to(device)
val_x, val_y = torch.randn(20, 20).to(device), torch.randn(20, 1).to(device)

print("Training with frequent validation using SurgicalTheater:")
print()

# One full-batch optimisation step per epoch, each followed by validation.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()

    # --- training step ---
    optimizer.zero_grad()
    train_output = model(train_x)
    train_loss = criterion(train_output, train_y)
    train_loss.backward()
    optimizer.step()

    # --- validation step inside a SurgicalTheater context ---
    with SurgicalTheater(model, track_memory=True) as theater:
        model.eval()
        with torch.no_grad():
            val_output = model(val_x)
            val_loss = criterion(val_output, val_y)

        # Read the tracked memory figure before the context closes.
        delta_memory = theater.total_delta_memory_mb

    # The context is expected to put the model back into training mode;
    # stop immediately if it did not.
    assert model.training, "Model should be back in training mode"

    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Delta Memory: {delta_memory:.2f} MB")

print("\n✅ Training completed with frequent validation!")
print("✅ No memory overhead from validation")
print("✅ Model state perfectly preserved throughout training")

## Summary

SurgicalTheater provides:

- **Memory Efficiency**: Dramatically reduces memory usage compared to traditional model copying
- **Perfect Restoration**: Guarantees exact weight and state restoration after modifications
- **Gradient Safety**: Preserves gradient flow and training state
- **Custom Modifications**: Flexible API for specialized use cases
- **Exception Safety**: Robust error handling ensures reliability
- **Training Integration**: Seamlessly integrates into existing training loops

**Use Cases:**
- LoRA/PEFT training with frequent validation
- Model experimentation and hyperparameter testing
- Reinforcement learning with reward hacking prevention
- Budget hardware training with larger models
- Research experiments requiring temporary model modifications

SurgicalTheater is production-ready and provides the foundation for memory-efficient ML workflows.