# Day 10: ResNet V2 - Identity Mappings in Deep Residual Networks 🎯

Welcome to Day 10 of 30 Papers in 30 Days!

Today we're exploring **ResNet V2** (He et al., 2016), the follow-up that refined the original ResNet. By moving batch normalization and ReLU in front of each convolution (pre-activation), ResNet V2 makes it possible to train networks over 1000 layers deep!

## What You'll Learn

1. **Pre-activation**: Why moving BN and ReLU before convolution matters
2. **Perfect Identity Paths**: Creating clean gradient highways
3. **Ultra-deep Networks**: Training 1000+ layer models successfully
4. **Ablation Studies**: Understanding each component's contribution
5. **Implementation**: Building ResNet V2 blocks from scratch
6. **Comparison**: ResNet vs ResNet V2 performance

## The Big Idea (in 30 seconds)

**Original ResNet**: `output = ReLU(x + F(x))`, where `F` = Conv → BN → ReLU → Conv → BN

**ResNet V2**: `output = x + F(x)`, where `F` = BN → ReLU → Conv → BN → ReLU → Conv

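**Magic**: Putting BN and ReLU inside the residual branch, before each convolution, leaves the skip path as a pure identity, so gradients have a clean route back through the network!

**Result**: Networks with 1000+ layers become trainable, and in the paper's experiments pre-activation models reach lower error than post-activation models of the same depth!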

Let's explore the power of pre-activation! 🚀

In [None]:
# Setup and imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Add current directory to path
sys.path.append('.')

# Import our ResNet V2 implementation
from implementation import PreActResNet, PreActBlock, BottleneckPreActBlock
from visualization import ResNetV2Visualizer
from train_minimal import train_resnet_v2, create_synthetic_dataset

# Set up device and seeds
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)
np.random.seed(42)

print(f"🔥 Using device: {device}")
print("✅ All imports successful!")
print("🎯 Ready to explore pre-activation!")

## Part 1: Understanding Pre-activation

The key insight of ResNet V2: **move batch normalization and ReLU before convolution** instead of after. This creates a clean identity path for gradients.

### Why This Matters

**Original ResNet (post-activation)**:
```
x → Conv → BN → ReLU → Conv → BN → (+) → ReLU → output
                                    ↑
                                    x
```

**ResNet V2 (pre-activation)**:
```
x → BN → ReLU → Conv → BN → ReLU → Conv → (+) → output
                                           ↑
                                           x (clean!)
```

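Notice: in V2, the identity connection `x` reaches the output untouched, with no ReLU after the addition! (When a block changes the number of channels or the spatial resolution, a 1×1 projection convolution replaces the pure identity.)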
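To make the "no final ReLU" point concrete, here is a small self-contained check. It is a sketch: `PreActUnit`, `PostActUnit`, and `input_gradient` are hypothetical names, not part of `implementation.py`. We zero the last convolution of every residual branch so each branch outputs exactly 0, stack 10 units, and inspect the gradient of `output.sum()` with respect to the input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PreActUnit(nn.Module):
    """Hypothetical pre-activation unit: x + F(x), F = BN-ReLU-Conv-BN-ReLU-Conv."""

    def __init__(self, c):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(c)
        self.conv1 = nn.Conv2d(c, c, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return x + out  # no activation after the addition


class PostActUnit(nn.Module):
    """Hypothetical post-activation unit: ReLU(x + F(x)), F = Conv-BN-ReLU-Conv-BN."""

    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(x + out)  # this ReLU sits on the shortcut path


def input_gradient(unit_cls, depth=10, channels=8):
    """Hypothetical helper: d(output.sum())/d(input) when every branch outputs exactly 0."""
    torch.manual_seed(0)
    units = [unit_cls(channels) for _ in range(depth)]
    for u in units:
        nn.init.zeros_(u.conv2.weight)  # zero last conv, so the residual branch output is 0
    model = nn.Sequential(*units)
    x = torch.randn(2, channels, 8, 8, requires_grad=True)
    model(x).sum().backward()
    return x, x.grad


x_pre, g_pre = input_gradient(PreActUnit)
x_post, g_post = input_gradient(PostActUnit)
print(f"Pre-act:  fraction of input-gradient entries equal to 1: {(g_pre == 1).float().mean().item():.3f}")
print(f"Post-act: fraction of input-gradient entries equal to 0: {(g_post == 0).float().mean().item():.3f}")
```

With every branch silenced, the pre-activation stack returns the gradient unchanged (all ones), while the post-activation stack zeroes it wherever the input was negative, which is about half the entries for Gaussian inputs.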

Let's visualize this difference.

In [None]:
# Visualize the architectural difference
def visualize_preactivation_difference():
    """Compare original ResNet vs ResNet V2 block structure."""
    
    print("🏗️ Comparing ResNet vs ResNet V2 Architecture...")
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    
    # Original ResNet (post-activation)
    ax1.text(0.5, 0.95, 'Input x', ha='center', fontsize=14, weight='bold', color='blue')
    ax1.arrow(0.5, 0.92, 0, -0.05, head_width=0.04, head_length=0.02, fc='blue')
    
    # Main path
    y = 0.82
    ax1.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightblue', edgecolor='black', linewidth=2))
    ax1.text(0.5, y+0.04, 'Conv', ha='center', va='center', fontsize=11, weight='bold')
    ax1.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax1.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightgreen', edgecolor='black', linewidth=2))
    ax1.text(0.5, y+0.04, 'BN', ha='center', va='center', fontsize=11, weight='bold')
    ax1.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax1.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightyellow', edgecolor='black', linewidth=2))
    ax1.text(0.5, y+0.04, 'ReLU', ha='center', va='center', fontsize=11, weight='bold')
    ax1.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax1.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightblue', edgecolor='black', linewidth=2))
    ax1.text(0.5, y+0.04, 'Conv', ha='center', va='center', fontsize=11, weight='bold')
    ax1.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax1.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightgreen', edgecolor='black', linewidth=2))
    ax1.text(0.5, y+0.04, 'BN', ha='center', va='center', fontsize=11, weight='bold')
    ax1.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    # Skip connection (identity, but it then passes through the ReLU applied after the addition)
    ax1.arrow(0.15, 0.95, 0, -0.65, head_width=0, head_length=0, fc='red', ec='red', linestyle='--', linewidth=3)
    ax1.add_patch(plt.Rectangle((0.08, 0.32), 0.14, 0.06, fill=True, facecolor='pink', edgecolor='red', linewidth=2))
    ax1.text(0.15, 0.35, 'ReLU\nafter +', ha='center', va='center', fontsize=9, weight='bold', color='red')
    
    # Addition
    y -= 0.11
    ax1.add_patch(plt.Circle((0.5, y+0.04), 0.05, fill=True, facecolor='orange', edgecolor='black', linewidth=2))
    ax1.text(0.5, y+0.04, '+', ha='center', va='center', fontsize=16, weight='bold')
    ax1.arrow(0.5, y-0.01, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    # Final ReLU
    y -= 0.11
    ax1.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightyellow', edgecolor='black', linewidth=2))
    ax1.text(0.5, y+0.04, 'ReLU', ha='center', va='center', fontsize=11, weight='bold')
    ax1.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    ax1.text(0.5, 0.05, 'Output', ha='center', fontsize=14, weight='bold', color='blue')
    
    ax1.set_xlim(0, 1)
    ax1.set_ylim(0, 1)
    ax1.axis('off')
    ax1.set_title('Original ResNet (Post-Activation)\n❌ ReLU after the addition alters the identity path', 
                  fontsize=13, weight='bold', color='darkred')
    
    # ResNet V2 (pre-activation)
    ax2.text(0.5, 0.95, 'Input x', ha='center', fontsize=14, weight='bold', color='blue')
    ax2.arrow(0.5, 0.92, 0, -0.05, head_width=0.04, head_length=0.02, fc='blue')
    
    # Main path with pre-activation
    y = 0.82
    ax2.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightgreen', edgecolor='black', linewidth=2))
    ax2.text(0.5, y+0.04, 'BN', ha='center', va='center', fontsize=11, weight='bold')
    ax2.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax2.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightyellow', edgecolor='black', linewidth=2))
    ax2.text(0.5, y+0.04, 'ReLU', ha='center', va='center', fontsize=11, weight='bold')
    ax2.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax2.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightblue', edgecolor='black', linewidth=2))
    ax2.text(0.5, y+0.04, 'Conv', ha='center', va='center', fontsize=11, weight='bold')
    ax2.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax2.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightgreen', edgecolor='black', linewidth=2))
    ax2.text(0.5, y+0.04, 'BN', ha='center', va='center', fontsize=11, weight='bold')
    ax2.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax2.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightyellow', edgecolor='black', linewidth=2))
    ax2.text(0.5, y+0.04, 'ReLU', ha='center', va='center', fontsize=11, weight='bold')
    ax2.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    y -= 0.11
    ax2.add_patch(plt.Rectangle((0.35, y), 0.3, 0.08, fill=True, facecolor='lightblue', edgecolor='black', linewidth=2))
    ax2.text(0.5, y+0.04, 'Conv', ha='center', va='center', fontsize=11, weight='bold')
    ax2.arrow(0.5, y, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    # Skip connection (clean identity!)
    ax2.arrow(0.15, 0.95, 0, -0.77, head_width=0.04, head_length=0.02, fc='green', ec='green', 
              linestyle='--', linewidth=4)
    ax2.text(0.08, 0.55, 'Clean\nIdentity!', ha='center', va='center', fontsize=11, 
             weight='bold', color='green', bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.7))
    
    # Addition (no final activation!)
    y -= 0.11
    ax2.add_patch(plt.Circle((0.5, y+0.04), 0.05, fill=True, facecolor='orange', edgecolor='black', linewidth=2))
    ax2.text(0.5, y+0.04, '+', ha='center', va='center', fontsize=16, weight='bold')
    ax2.arrow(0.5, y-0.01, 0, -0.03, head_width=0.04, head_length=0.01, fc='blue')
    
    ax2.text(0.5, 0.05, 'Output', ha='center', fontsize=14, weight='bold', color='blue')
    
    ax2.set_xlim(0, 1)
    ax2.set_ylim(0, 1)
    ax2.axis('off')
    ax2.set_title('ResNet V2 (Pre-Activation)\n✅ Perfect identity path for gradients!', 
                  fontsize=13, weight='bold', color='darkgreen')
    
    plt.tight_layout()
    plt.show()
    
    print("\n💡 Key Differences:")
    print("  1. ✅ Pre-activation: BN and ReLU BEFORE conv (not after)")
    print("  2. ✅ Clean identity: x flows directly to output")
    print("  3. ✅ No final ReLU: Output can be negative (important!)")
    print("  4. ✅ Better gradients: Direct path for backpropagation")
    print("\n🎯 Result: Better optimization, especially for ultra-deep networks!")

visualize_preactivation_difference()

## Part 2: Building Pre-Activation Blocks

Let's implement pre-activation blocks and compare them with original ResNet blocks.
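The notebook imports `PreActBlock` from `implementation.py`, which isn't shown here. As a rough, self-contained sketch of what such a block typically looks like (the name `SketchPreActBlock` is hypothetical, and its details may differ from `implementation.py`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchPreActBlock(nn.Module):
    """Hypothetical pre-activation basic block: shortcut + F(x), F = BN-ReLU-Conv-BN-ReLU-Conv."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        # Pure identity when shapes match; 1x1 projection when channels or resolution change
        self.projection = None
        if stride != 1 or in_channels != out_channels:
            self.projection = nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False)

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # Assumption: the projection reads the pre-activated input (a common choice);
        # the identity case uses the raw x so the skip path stays untouched
        shortcut = x if self.projection is None else self.projection(pre)
        out = self.conv1(pre)
        out = self.conv2(F.relu(self.bn2(out)))
        return out + shortcut  # no ReLU after the addition


x = torch.randn(2, 16, 32, 32)
same = SketchPreActBlock(16, 16)
down = SketchPreActBlock(16, 32, stride=2)
print(same(x).shape, down(x).shape)
```

The first block keeps the shape and uses a pure identity shortcut; the second halves the resolution and doubles the channels, so it needs the projection.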

In [None]:
# Build and explore pre-activation blocks
def explore_preact_blocks():
    """Build and understand pre-activation residual blocks."""
    
    print("🔨 Building Pre-Activation Blocks...")
    
    # Create a pre-activation block
    preact_block = PreActBlock(64, 64, stride=1)
    
    print("\n📐 Pre-Activation Block Structure:")
    print(preact_block)
    
    # Test forward pass
    x = torch.randn(2, 64, 32, 32)
    
    print(f"\n🧪 Testing forward pass:")
    print(f"Input shape: {list(x.shape)}")
    
    # Manual trace
    identity = x
    
    # Pre-activation path
    out = preact_block.bn1(x)
    print(f"After BN1: {list(out.shape)}")
    
    out = F.relu(out)
    print(f"After ReLU1: {list(out.shape)}")
    
    out = preact_block.conv1(out)
    print(f"After Conv1: {list(out.shape)}")
    
    out = preact_block.bn2(out)
    print(f"After BN2: {list(out.shape)}")
    
    out = F.relu(out)
    print(f"After ReLU2: {list(out.shape)}")
    
    out = preact_block.conv2(out)
    print(f"After Conv2: {list(out.shape)}")
    
    # Add identity (clean!)
    out += identity
    print(f"After adding identity: {list(out.shape)}")
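    print("✅ Notice: No final activation! Negative values in the sum pass through unchanged.")
    
    # Compare with actual forward pass
    output = preact_block(x)
    print(f"\nActual forward pass output: {list(output.shape)}")
    # Assumes PreActBlock.forward performs exactly the steps traced above
    print(f"Manual trace matches forward(): {torch.allclose(out, output, atol=1e-5)}")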
    
    # Test gradient flow
    print("\n🌊 Testing gradient flow...")
    x_test = torch.randn(1, 64, 32, 32, requires_grad=True)
    output = preact_block(x_test)
    loss = output.sum()
    loss.backward()
    
    print(f"Gradient magnitude at input: {x_test.grad.abs().mean().item():.6f}")
    print("✅ With loss = output.sum(), the identity path alone contributes a gradient of 1 per input element.")
    
    return preact_block

preact_block = explore_preact_blocks()

## Part 3: Training Ultra-Deep Networks

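The real power of ResNet V2 is depth: the paper trains networks with 1000+ layers. Training that deep is too slow for a notebook demo, so here we check something cheaper: whether gradients still reach the input of increasingly deep stacks of pre-activation blocks at initialization.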
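For reference, the CIFAR models in the ResNet papers count depth as weighted layers: one 3×3 stem convolution, three stages of `n` blocks each, and the final fully connected layer. Basic blocks contribute 2 layers and bottleneck blocks 3, so depth is 6n + 2 or 9n + 2. A small helper (the name `cifar_resnet_depth` is hypothetical) makes the arithmetic explicit:

```python
def cifar_resnet_depth(blocks_per_stage: int, bottleneck: bool = False, stages: int = 3) -> int:
    """Hypothetical helper: weighted layers in a CIFAR-style ResNet (stem conv + blocks + final FC)."""
    layers_per_block = 3 if bottleneck else 2
    return stages * blocks_per_stage * layers_per_block + 2


for n, bottleneck in [(18, False), (18, True), (111, True)]:
    kind = "bottleneck" if bottleneck else "basic"
    print(f"n={n:3d}, {kind:10s} -> ResNet-{cifar_resnet_depth(n, bottleneck)}")
```

This prints the familiar ResNet-110, ResNet-164, and ResNet-1001. The next cell counts depth more loosely: `depth // 2` stacked blocks of 2 convolutions each, with no stem or classifier.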

In [None]:
# Train networks of different depths
def test_depth_scalability():
    """Test how ResNet V2 handles extreme depth."""
    
    print("🏔️ Testing Depth Scalability with ResNet V2...")
    
    # Create networks of increasing depth: plain stacks of 64-channel PreActBlocks
    # (no stem, downsampling, or classifier); we only measure the input gradient
    depths = [20, 50, 110, 200]
    
    results = {}
    
    print("\n🔬 Testing different depths...")
    
    for depth in depths:
        print(f"\n📊 Depth: {depth} layers")
        
        # Create a simple ultra-deep network
        layers = []
        for i in range(depth // 2):  # Each block has 2 conv layers
            layers.append(PreActBlock(64, 64, stride=1))
        
        model = nn.Sequential(*layers).to(device)
        
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        print(f"  Parameters: {total_params:,}")
        
        # Test gradient flow
        # Create x directly on the device so it stays a leaf tensor and x.grad gets populated
        x = torch.randn(1, 64, 32, 32, device=device, requires_grad=True)
        output = model(x)
        loss = output.sum()
        loss.backward()
        
        # Measure gradient health
        grad_norm = x.grad.norm().item()
        results[depth] = {
            'params': total_params,
            'grad_norm': grad_norm
        }
        
        print(f"  Gradient norm at input: {grad_norm:.6f}")
        print(f"  {'✅ Gradients healthy' if grad_norm > 0.001 else '❌ Gradients vanishing'}!")
    
    # Visualize results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    depths_list = list(results.keys())
    param_counts = [results[d]['params'] / 1e6 for d in depths_list]
    grad_norms = [results[d]['grad_norm'] for d in depths_list]
    
    # Parameter scaling
    ax1.plot(depths_list, param_counts, 'b-o', linewidth=2, markersize=10)
    ax1.set_xlabel('Network Depth (layers)', fontsize=12)
    ax1.set_ylabel('Parameters (Millions)', fontsize=12)
    ax1.set_title('Parameter Count vs Depth', fontsize=14, weight='bold')
    ax1.grid(True, alpha=0.3)
    
    # Gradient health
    ax2.plot(depths_list, grad_norms, 'g-s', linewidth=2, markersize=10)
    ax2.axhline(y=0.001, color='r', linestyle='--', label='Vanishing threshold')
    ax2.set_xlabel('Network Depth (layers)', fontsize=12)
    ax2.set_ylabel('Gradient Norm', fontsize=12)
    ax2.set_title('Gradient Health vs Depth', fontsize=14, weight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_yscale('log')
    
    plt.tight_layout()
    plt.show()
    
    print("\n🎯 Key Insight:")
    print("  The identity shortcuts keep the input gradient from vanishing as depth grows.")
    print("  Caveat: this cell runs one backward pass at initialization; it neither trains the")
    print("  networks nor tests the original post-activation design.")
    print("  The paper trains a 1001-layer pre-activation ResNet on CIFAR successfully.")

test_depth_scalability()

## Part 4: ResNet vs ResNet V2 Comparison

Let's directly compare original ResNet with ResNet V2 on the same task.
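The training loop in the next cell tracks the global gradient norm: the L2 norm of all parameter gradients taken together as one vector. A compact, hypothetical helper for the same quantity (`global_grad_norm` is not a PyTorch function):

```python
import torch
import torch.nn as nn


def global_grad_norm(model: nn.Module) -> float:
    """Hypothetical helper: L2 norm over all parameter gradients, as if they were one flat vector."""
    grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    return torch.cat(grads).norm(2).item()


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
loss = model(torch.randn(5, 4)).pow(2).sum()
loss.backward()
# Same quantity computed the way the training loop does it: sqrt of the sum of squared per-parameter norms
manual = sum(p.grad.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
print(f"{global_grad_norm(model):.6f} vs {manual:.6f}")
```

If you prefer a built-in, passing `max_norm=float('inf')` to `torch.nn.utils.clip_grad_norm_` should return the same total norm without rescaling anything.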

In [None]:
# Compare ResNet vs ResNet V2
def compare_resnet_versions():
    """Compare training dynamics of ResNet vs ResNet V2."""
    
    print("⚔️ ResNet vs ResNet V2 Showdown...")
    
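    # Create synthetic dataset (random images and random labels: this measures how fast
    # each model can fit the data, not how well it generalizes)
    print("\n📦 Creating dataset...")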
    X = torch.randn(800, 3, 32, 32)
    y = torch.randint(0, 10, (800,))
    dataset = torch.utils.data.TensorDataset(X, y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    
    # Original ResNet Block (post-activation)
    class OriginalResBlock(nn.Module):
        def __init__(self, channels):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)
        
        def forward(self, x):
            identity = x
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out += identity
            out = F.relu(out)  # Post-activation!
            return out
    
    # Build comparable networks
    class OriginalResNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
            self.blocks = nn.Sequential(*[OriginalResBlock(64) for _ in range(8)])
            self.fc = nn.Linear(64, 10)
        
        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = self.blocks(x)
            x = F.adaptive_avg_pool2d(x, 1)
            x = torch.flatten(x, 1)
            return self.fc(x)
    
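    class PreActResNetModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
            self.blocks = nn.Sequential(*[PreActBlock(64, 64) for _ in range(8)])
            # Pre-activation blocks leave their output un-normalized, so the paper adds
            # one extra BN-ReLU after the last block, before pooling
            self.bn_final = nn.BatchNorm2d(64)
            self.fc = nn.Linear(64, 10)
        
        def forward(self, x):
            x = self.conv1(x)
            x = self.blocks(x)
            x = F.relu(self.bn_final(x))
            x = F.adaptive_avg_pool2d(x, 1)
            x = torch.flatten(x, 1)
            return self.fc(x)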
    
    models = {
        'Original ResNet': OriginalResNet().to(device),
        'ResNet V2': PreActResNetModel().to(device)
    }
    
    results = {}
    
    for name, model in models.items():
        print(f"\n🏋️ Training {name}...")
        
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()
        
        losses = []
        grad_norms = []
        
        for epoch in range(15):
            epoch_loss = 0
            epoch_grads = []
            
            for batch_x, batch_y in dataloader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                
                optimizer.zero_grad()
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                
                # Measure gradient norm
                total_norm = 0
                for p in model.parameters():
                    if p.grad is not None:
                        total_norm += p.grad.detach().norm(2).item() ** 2
                epoch_grads.append(total_norm ** 0.5)
                
                optimizer.step()
                epoch_loss += loss.item()
            
            avg_loss = epoch_loss / len(dataloader)
            avg_grad = np.mean(epoch_grads)
            
            losses.append(avg_loss)
            grad_norms.append(avg_grad)
            
            if (epoch + 1) % 3 == 0:
                print(f"  Epoch {epoch+1}: Loss={avg_loss:.4f}, Grad={avg_grad:.4f}")
        
        results[name] = {
            'losses': losses,
            'grads': grad_norms
        }
    
    # Plot comparison
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    colors = {'Original ResNet': '#e74c3c', 'ResNet V2': '#2ecc71'}
    
    # Training loss
    for name, data in results.items():
        ax1.plot(data['losses'], label=name, color=colors[name], 
                linewidth=2.5, marker='o', markersize=6)
    
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Training Loss', fontsize=12)
    ax1.set_title('Training Loss Comparison', fontsize=14, weight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    
    # Gradient norms
    for name, data in results.items():
        ax2.plot(data['grads'], label=name, color=colors[name],
                linewidth=2.5, marker='s', markersize=6)
    
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Gradient Norm', fontsize=12)
    ax2.set_title('Gradient Flow Comparison', fontsize=14, weight='bold')
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n🏆 Final Results:")
    for name, data in results.items():
        print(f"  {name}: Loss = {data['losses'][-1]:.4f}")
    
    print("\n💡 What to look for:")
    print("  - With only 8 blocks on random labels, the two curves may be close; try a few seeds")
    print("  - In the paper, the benefit of pre-activation grows with depth (largest at 1001 layers)")
    print("  - Compare how steady the gradient-norm curves stay across epochs")

compare_resnet_versions()

## Part 5: Ablation Study - What Makes Pre-Activation Work?

Let's test different design choices to understand which components matter most.
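The paper's own ablation includes variants that the next cell leaves out, such as moving the ReLU to the end of the residual branch, just before the addition. The paper argues this hurts because the branch output becomes non-negative, while a residual function should be able to take negative values too: each unit can then only increase the signal it passes along. A quick sketch (the class name `ReLUBeforeAdditionBlock` is hypothetical) confirms the non-negativity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReLUBeforeAdditionBlock(nn.Module):
    """Hypothetical sketch of the 'ReLU before addition' variant: x + ReLU(BN(Conv(ReLU(BN(Conv(x))))))."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))  # branch output is clamped to >= 0
        return x + out


torch.manual_seed(0)
block = ReLUBeforeAdditionBlock(8)
x = torch.randn(4, 8, 16, 16)
with torch.no_grad():
    residual = block(x) - x  # what the unit adds to its input
print(f"smallest residual value: {residual.min().item():.4f}")
```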

In [None]:
# Ablation study on pre-activation design
def ablation_study():
    """Test different architectural variants."""
    
    print("🔬 Ablation Study: What Makes Pre-Activation Work?")
    
    # Different block variants
    class Variant1(nn.Module):
        """Full pre-activation: BN-ReLU-Conv-BN-ReLU-Conv"""
        def __init__(self, channels):
            super().__init__()
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
        
        def forward(self, x):
            out = self.conv1(F.relu(self.bn1(x)))
            out = self.conv2(F.relu(self.bn2(out)))
            return out + x
    
    class Variant2(nn.Module):
        """ReLU-BN-Conv (wrong order)"""
        def __init__(self, channels):
            super().__init__()
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
        
        def forward(self, x):
            out = self.conv1(self.bn1(F.relu(x)))
            out = self.conv2(self.bn2(F.relu(out)))
            return out + x
    
    class Variant3(nn.Module):
        """Only BN before conv (no ReLU)"""
        def __init__(self, channels):
            super().__init__()
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
        
        def forward(self, x):
            out = self.conv1(self.bn1(x))
            out = self.conv2(self.bn2(out))
            return out + x
    
    variants = {
        'Full Pre-Act (BN-ReLU-Conv)': Variant1,
        'Wrong Order (ReLU-BN-Conv)': Variant2,
        'Only BN (no ReLU)': Variant3,
    }
    
    print("\n🧪 Testing variants on gradient flow...")
    
    results = {}
    
    for name, BlockClass in variants.items():
        # Create simple network
        model = nn.Sequential(*[BlockClass(64) for _ in range(10)]).to(device)
        
        # Test gradient flow
        # Create x directly on the device so it stays a leaf tensor and x.grad gets populated
        x = torch.randn(1, 64, 32, 32, device=device, requires_grad=True)
        output = model(x)
        loss = output.sum()
        loss.backward()
        
        grad_norm = x.grad.norm().item()
        results[name] = grad_norm
        
        print(f"  {name}: Gradient norm = {grad_norm:.6f}")
    
    # Visualize results
    fig, ax = plt.subplots(figsize=(12, 6))
    
    variant_names = list(results.keys())
    grad_values = list(results.values())
    
    colors = ['#2ecc71', '#e74c3c', '#f39c12']
    bars = ax.bar(variant_names, grad_values, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
    
    ax.set_ylabel('Gradient Norm', fontsize=12)
    ax.set_title('Ablation Study: Impact of Design Choices', fontsize=14, weight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add value labels
    for bar, val in zip(bars, grad_values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.4f}',
                ha='center', va='bottom', fontsize=11, weight='bold')
    
    plt.xticks(rotation=15, ha='right')
    plt.tight_layout()
    plt.show()
    
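    print("\n💡 How to read this:")
    print("  1. The identity shortcut carries most of the gradient, so all three variants may look similar here")
    print("  2. A single gradient norm at initialization is only a rough proxy; the paper's ablation compares test error after full training")
    print("  3. ⚠️  'Only BN' has no nonlinearity, so each residual branch is affine at inference time and cannot learn nonlinear features")
    print("  4. 🎯 In the paper, full pre-activation (BN-ReLU-Conv) gave the lowest error among the variants tested")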

ablation_study()

## Part 6: Visualizing Identity Mapping Quality

Let's visualize how clean the identity mappings are in ResNet V2 compared to original ResNet.
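A useful baseline for the plots in the next cell: even if the residual branch output were exactly zero, the original block would return ReLU(x), so |output − input| = max(−x, 0) elementwise. For standard-normal inputs its mean is 1/√(2π) ≈ 0.399. The measured original-ResNet deviation also includes F(x), so it won't match this value exactly. A quick Monte Carlo check of the formula:

```python
import math

import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)
# With F(x) = 0 the original block outputs ReLU(x), so the deviation is max(-x, 0)
empirical = (torch.relu(x) - x).abs().mean().item()
analytic = 1 / math.sqrt(2 * math.pi)
print(f"empirical: {empirical:.4f}, analytic: {analytic:.4f}")
```

A pre-activation block with F(x) = 0 would return x exactly, giving a deviation of 0.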

In [None]:
# Visualize identity mapping quality
def visualize_identity_mapping():
    """See how well identity mappings are preserved."""
    
    print("🔍 Analyzing Identity Mapping Quality...")
    
    # Create blocks
    original_block = OriginalResBlock(64).to(device).eval()
    preact_block = PreActBlock(64, 64).to(device).eval()
    
    # Random test input (the same tensor goes through both blocks)
    x = torch.randn(1, 64, 32, 32).to(device)
    
    # Forward pass
    with torch.no_grad():
        original_out = original_block(x)
        preact_out = preact_block(x)
    
    # Measure how much output differs from input
    original_diff = (original_out - x).abs().mean().item()
    preact_diff = (preact_out - x).abs().mean().item()
    
    print(f"\n📊 Identity Mapping Quality:")
    print(f"  Original ResNet: |output - input| = {original_diff:.6f}")
    print(f"  ResNet V2: |output - input| = {preact_diff:.6f}")
    
    # Test with multiple random inputs
    num_tests = 100
    original_diffs = []
    preact_diffs = []
    
    for _ in range(num_tests):
        x = torch.randn(1, 64, 32, 32).to(device)
        
        with torch.no_grad():
            original_out = original_block(x)
            preact_out = preact_block(x)
        
        original_diffs.append((original_out - x).abs().mean().item())
        preact_diffs.append((preact_out - x).abs().mean().item())
    
    # Visualize distribution
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Histogram
    ax1.hist(original_diffs, bins=30, alpha=0.6, color='red', label='Original ResNet', edgecolor='black')
    ax1.hist(preact_diffs, bins=30, alpha=0.6, color='green', label='ResNet V2', edgecolor='black')
    ax1.set_xlabel('|Output - Input|', fontsize=12)
    ax1.set_ylabel('Frequency', fontsize=12)
    ax1.set_title('Identity Mapping Deviation', fontsize=14, weight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    
    # Box plot
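    ax2.boxplot([original_diffs, preact_diffs],
                patch_artist=True,
                boxprops=dict(facecolor='lightblue', alpha=0.7),
                medianprops=dict(color='red', linewidth=2))
    # Label ticks separately: boxplot's `labels` argument was renamed `tick_labels` in Matplotlib 3.9
    ax2.set_xticklabels(['Original\nResNet', 'ResNet V2'])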
    ax2.set_ylabel('|Output - Input|', fontsize=12)
    ax2.set_title('Identity Preservation Comparison', fontsize=14, weight='bold')
    ax2.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n📈 Statistics:")
    print(f"  Original ResNet - Mean: {np.mean(original_diffs):.6f}, Std: {np.std(original_diffs):.6f}")
    print(f"  ResNet V2 - Mean: {np.mean(preact_diffs):.6f}, Std: {np.std(preact_diffs):.6f}")
    
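    print("\n💡 Insight:")
    print("  The post-addition ReLU zeroes every negative entry of x + F(x), so the original block")
    print("  cannot return x unchanged even if F(x) = 0; the pre-activation block returns x + F(x) as is.")
    print("  (Measured at random initialization; trained weights will change these numbers.)")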

# Module-level OriginalResBlock for this demo (the copy in Part 4 was local to compare_resnet_versions)
class OriginalResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
    
    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        out = F.relu(out)
        return out

visualize_identity_mapping()

## Part 7: Your Turn to Experiment!

Now it's your turn! Try different modifications and experiments with ResNet V2.

### Suggested Experiments:

1. **Ultra-Deep Networks**: Build a 500+ layer network and test training
2. **Hybrid Designs**: Mix pre-activation and post-activation blocks
3. **Different Normalizations**: Try Group Norm, Layer Norm instead of Batch Norm
4. **Activation Functions**: Test different activations (GELU, Swish, etc.)
5. **Skip Connection Variants**: Try different skip connection patterns

Use the cell below for your experiments!

In [None]:
# Your experiment cell
def my_resnetv2_experiment():
    """Design your own ResNet V2 experiment!"""
    
    print("🔬 Your Custom ResNet V2 Experiment")
    
    # TODO: Design your experiment here!
    # Ideas:
    # - Test extremely deep networks (1000+ layers)
    # - Compare different normalization techniques
    # - Experiment with activation functions
    # - Test gradient flow at various depths
    
    # Example: Test ultra-deep network
    print("\n🏔️ Testing ultra-deep ResNet V2...")
    
    # The 1000-layer case stacks 500 blocks and can be slow and memory-hungry on CPU
    depths = [100, 200, 500, 1000]
    
    for depth in depths:
        # Create ultra-deep network
        blocks = [PreActBlock(64, 64) for _ in range(depth // 2)]
        model = nn.Sequential(*blocks).to(device)
        
        # Test gradient flow
        # Create x directly on the device so it stays a leaf tensor and x.grad gets populated
        x = torch.randn(1, 64, 32, 32, device=device, requires_grad=True)
        output = model(x)
        loss = output.sum()
        loss.backward()
        
        grad_norm = x.grad.norm().item()
        
        print(f"  {depth} layers: Gradient norm = {grad_norm:.6f}")
        print(f"    {'✅ Healthy!' if grad_norm > 0.001 else '❌ Vanishing'}")
    
    print("\nüí° Your turn: Modify this cell to explore ResNet V2!")

# Run your experiment
my_resnetv2_experiment()

## Conclusions and Takeaways

🎉 **Congratulations!** You've mastered ResNet V2 and the power of pre-activation!

### Key Insights Discovered:

1. **Pre-Activation Design**: Moving BN and ReLU before convolution creates clean identity paths
2. **Perfect Identity Mapping**: `x` flows directly to output without transformation
3. **Ultra-Deep Networks**: Enables training 1000+ layer networks successfully
4. **Better Optimization**: Cleaner gradients lead to faster, more stable training
5. **Simple Change, Big Impact**: Just rearranging layers improves results, with the largest gains in the deepest networks

### Why Pre-Activation Matters:

- **Clean Gradients**: Direct path for backward propagation
- **Information Flow**: The signal from any block reaches every later block through the identity shortcuts
- **Optimization**: Deep networks become easier to optimize because gradients no longer pass through post-addition ReLUs
- **Scalability**: Handles extreme depth better than the original post-activation design

### The Magic Formula:

**ResNet V2 Block**:
```python
output = Conv(ReLU(BN(Conv(ReLU(BN(x)))))) + x
```

**Key**: Identity `x` is added AFTER all transformations, with NO final activation!

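### Modern Impact:

The same principle (normalize and activate inside the residual branch, keep the skip path clean) appears in many later architectures:
- 🧠 Transformer architectures (Pre-LayerNorm: `x + Sublayer(LayerNorm(x))`)
- 🎨 Diffusion models (U-Net residual blocks that apply normalization and activation before each convolution)
- 🎯 EfficientNets (MBConv blocks add the shortcut after a linear projection, with no activation after the sum)
- 🚀 Vision Transformers (pre-LayerNorm residual blocks)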
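For comparison, a Pre-LN Transformer sublayer follows the same pattern as a pre-activation block: normalize inside the branch, leave the skip path alone. A minimal sketch of the feed-forward sublayer (the class name `PreLNFeedForward` is hypothetical and the sizes are arbitrary):

```python
import torch
import torch.nn as nn


class PreLNFeedForward(nn.Module):
    """Hypothetical Pre-LN sublayer: x + MLP(LayerNorm(x)), the Transformer analogue of x + F(x)."""

    def __init__(self, d_model=64, d_hidden=256):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))

    def forward(self, x):
        return x + self.mlp(self.norm(x))  # skip path untouched, nothing applied after the sum


tokens = torch.randn(2, 10, 64)  # (batch, sequence, d_model)
print(PreLNFeedForward()(tokens).shape)
```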

### The Research Lesson:

Sometimes the best improvements come from **simplifying and perfecting** existing ideas rather than adding complexity. ResNet V2 shows that careful architectural choices matter immensely!

### Next Steps:

1. **Explore Wide ResNets**: Increasing width vs depth
2. **Try ResNeXt**: Grouped convolutions + residuals
3. **Study DenseNet**: Connecting all layers
4. **Apply to Projects**: Use pre-activation in your own architectures

The pre-activation principle: **prepare, transform, connect** - is now a fundamental pattern in deep learning architecture design! 🎯🧠✨