In [1]:
"""
Combined file with all neural network models adapted to use ModuleDict
and work with NetworkMonitor for tracking activations and gradients.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
import torch.optim as optim



###########################################
# Utility Functions and Custom Layers
###########################################

def get_activation(activation_name):
    """
    Create a new activation module from its (case-insensitive) name.

    Parameters:
        activation_name (str): One of 'relu', 'leaky_relu', 'tanh', 'sigmoid',
            'gelu', 'elu', 'selu' or 'none'.

    Returns:
        nn.Module: A freshly constructed activation module (never shared
        between calls).

    Raises:
        ValueError: If the name is not one of the supported activations.
    """
    # Factories instead of instances: only the requested module is built.
    # In-place variants are disabled so hooks observe unmodified tensors.
    factories = {
        'relu': lambda: nn.ReLU(inplace=False),
        'leaky_relu': lambda: nn.LeakyReLU(0.1, inplace=False),
        'tanh': lambda: nn.Tanh(),
        'sigmoid': lambda: nn.Sigmoid(),
        'gelu': lambda: nn.GELU(),
        'elu': lambda: nn.ELU(inplace=False),
        'selu': lambda: nn.SELU(inplace=False),
        'none': lambda: nn.Identity(),
    }

    key = activation_name.lower()
    if key not in factories:
        raise ValueError(f"Activation {activation_name} not supported. "
                         f"Choose from: {list(factories.keys())}")
    return factories[key]()

def get_normalization(norm_name, num_features):
    """
    Returns the normalization layer based on name.

    Only the requested layer is constructed. Building every candidate eagerly
    made e.g. ``get_normalization('batch', 48)`` fail, because the unused
    ``GroupNorm(32, 48)`` candidate raises when 48 is not divisible by 32.

    Parameters:
        norm_name (str | None): 'batch', 'batch2d', 'layer', 'instance',
            'instance2d', 'group' or 'none' (case-insensitive). None disables
            normalization.
        num_features (int): Number of features/channels for the normalization layer.
            For 'group', the number of groups is the largest divisor of
            num_features that is <= 32. That is 32 when num_features is a
            multiple of 32, and num_features itself when it is <= 32.

    Returns:
        nn.Module or None: A new normalization module, or None if norm_name is None.

    Raises:
        ValueError: If norm_name is not a supported normalization.
    """
    if norm_name is None:
        return None

    def _group_count(channels, max_groups=32):
        # GroupNorm requires channels % groups == 0. Pick the largest valid
        # group count not exceeding max_groups. This equals
        # min(max_groups, channels) whenever that value is itself valid.
        for groups in range(min(max_groups, channels), 0, -1):
            if channels % groups == 0:
                return groups
        return 1

    # Factories, so that only the selected layer is ever instantiated.
    factories = {
        'batch': lambda: nn.BatchNorm1d(num_features),
        'batch2d': lambda: nn.BatchNorm2d(num_features),
        'layer': lambda: nn.LayerNorm(num_features),
        'instance': lambda: nn.InstanceNorm1d(num_features),
        'instance2d': lambda: nn.InstanceNorm2d(num_features, affine=True),
        'group': lambda: nn.GroupNorm(_group_count(num_features), num_features),
        'none': lambda: nn.Identity(),
    }

    norm_key = str(norm_name).lower()
    if norm_key not in factories:
        raise ValueError(f"Normalization {norm_name} not supported. "
                         f"Choose from: {list(factories.keys())}")

    return factories[norm_key]()


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization from the paper:
    "Root Mean Square Layer Normalization"
    https://arxiv.org/abs/1910.07467

    Divides each vector along the last dimension by its root-mean-square and
    multiplies by a learnable per-feature scale (initialised to ones). No
    mean is subtracted and no bias is added.

    Parameters:
        dim (int): Size of the last dimension (length of the scale vector).
        eps (float): Added to the mean square before the square root, for
            numerical stability.
    """
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        root_mean_square = (mean_square + self.eps).sqrt()
        return self.scale * (x / root_mean_square)


class PatchEmbedding(nn.Module):
    """
    Image to Patch Embedding for Vision Transformer.
    Adapted to work with ModuleDict and NetworkMonitor.

    A Conv2d with kernel_size == stride == patch_size projects each
    non-overlapping patch to an embed_dim vector. The resulting grid is
    returned as a token sequence of shape (B, N, embed_dim).
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # Components live in a ModuleDict so hooks see them as 'layers.<key>'.
        self.layers = nn.ModuleDict({
            'proj': nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size),
        })

    def forward(self, x):
        patch_grid = self.layers['proj'](x)
        # Unpacking also rejects anything that is not a 4-D (B, E, H', W') map.
        batch, embed, grid_h, grid_w = patch_grid.shape
        # (B, E, H', W') -> (B, E, N) -> (B, N, E), patches in row-major order.
        tokens = patch_grid.reshape(batch, embed, grid_h * grid_w)
        return tokens.transpose(1, 2)


###########################################
# Network Monitor
###########################################

class NetworkMonitor:
    """
    Record forward activations and backward output-gradients of a model.

    ``register_hooks`` attaches a forward hook and a full backward hook to
    every named submodule of ``model``; the root module itself is skipped.
    Each forward call of a submodule appends a detached CPU copy of its
    output to ``self.activations[name]``. Each backward call appends a
    detached CPU copy of ``grad_output[0]`` to
    ``self.gradients[f"{name}_grad_output"]``. Names are the dotted names
    produced by ``model.named_modules()``.
    """

    def __init__(self, model):
        self.model = model
        self.activations = defaultdict(list)
        self.gradients = defaultdict(list)
        self.fwd_hooks = []
        self.bwd_hooks = []

    @staticmethod
    def _first_tensor(value):
        """Return ``value`` if it is a tensor, else the first tensor inside a
        tuple/list, else None."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, (tuple, list)):
            for item in value:
                if isinstance(item, torch.Tensor):
                    return item
        return None

    @staticmethod
    def _name_matches(name, layers):
        """True if ``name`` equals an entry of ``layers`` or ends with '.<entry>'.

        This lets callers pass either full dotted names ('layers.linear_0')
        or just the trailing component ('linear_0').
        """
        return any(name == layer or name.endswith('.' + layer) for layer in layers)

    @staticmethod
    def _print_bar_chart(title, magnitudes, width=40):
        """Print ``magnitudes`` (name -> float) as name-sorted text bars,
        scaled so the largest value spans ``width`` characters."""
        import math

        print(title)
        if not magnitudes:
            print("(no data recorded)")
            return
        max_name_len = max(len(name) for name in magnitudes)
        max_mag = max(magnitudes.values())
        for name, mag in sorted(magnitudes.items()):
            ratio = mag / max_mag if max_mag else 0.0
            # Non-finite ratios (NaN/inf magnitudes, e.g. exploding gradients)
            # would make int() raise, so they get an empty bar instead.
            bar_len = int(ratio * width) if math.isfinite(ratio) else 0
            print(f"{name.ljust(max_name_len)} | {'█' * bar_len} {mag:.6f}")

    def register_hooks(self):
        """Attach forward and full-backward hooks to every named submodule.

        Hooks previously registered by this monitor are removed first, so
        calling this method twice does not record every tensor twice.
        """
        self.remove_hooks()
        for name, module in self.model.named_modules():
            if name == '':  # Skip the root module
                continue

            # Factories bind `name` at definition time (avoids late binding).
            def make_fwd_hook(name=name):
                def hook(module, input, output):
                    # Modules may return tuples. Keep the first tensor, if any.
                    tensor = self._first_tensor(output)
                    if tensor is not None:
                        # detach() before clone() keeps the copy out of autograd.
                        self.activations[name].append(tensor.detach().clone().cpu())
                return hook

            def make_bwd_hook(name=name):
                def hook(module, grad_input, grad_output):
                    # Only the first output gradient is stored (avoids nested lists).
                    if len(grad_output) > 0 and grad_output[0] is not None:
                        self.gradients[f"{name}_grad_output"].append(
                            grad_output[0].detach().clone().cpu())
                    # Returning None leaves grad_input unchanged.
                return hook

            self.fwd_hooks.append(module.register_forward_hook(make_fwd_hook()))
            self.bwd_hooks.append(module.register_full_backward_hook(make_bwd_hook()))

    def remove_hooks(self):
        """Detach every hook registered by this monitor."""
        for h in self.fwd_hooks + self.bwd_hooks:
            h.remove()
        self.fwd_hooks = []
        self.bwd_hooks = []

    def clear_data(self):
        """Discard all recorded activations and gradients (hooks stay attached)."""
        self.activations = defaultdict(list)
        self.gradients = defaultdict(list)

    def get_activations(self):
        """Get the recorded activations (a single tensor when recorded once, else the list)."""
        return {k: v[0] if len(v) == 1 else v for k, v in self.activations.items()}

    def get_gradients(self):
        """Get the recorded gradients (a single tensor when recorded once, else the list)."""
        return {k: v[0] if len(v) == 1 else v for k, v in self.gradients.items()}

    def show_data_structure(self):
        """Print every recorded tensor's shape and mean, grouped by key."""
        print("\nACTIVATIONS DATA STRUCTURE:")
        for name, acts in self.activations.items():
            print(f"{name}: List with {len(acts)} tensors")
            for i, act in enumerate(acts):
                print(f"  - Item {i}: shape={act.shape}, mean={act.mean().item():.4f}")

        print("\nGRADIENTS DATA STRUCTURE:")
        for name, grads in self.gradients.items():
            print(f"{name}: List with {len(grads)} tensors")
            for i, g in enumerate(grads):
                print(f"  - Item {i}: shape={g.shape}, mean={g.mean().item():.4f}")

    def print_activation_stats(self, layers=None):
        """Print statistics about activations.

        Parameters:
            layers (list[str] | None): Restrict output to these modules. Each
                entry may be a full dotted name or its trailing component
                ('linear_0' matches 'layers.linear_0').
        """
        activations = self.get_activations()
        print("\nActivation Statistics:")

        if layers:
            activations = {k: v for k, v in activations.items()
                           if self._name_matches(k, layers)}

        for name, act in activations.items():
            if isinstance(act, torch.Tensor):
                print(f"{name}: shape={act.shape}, mean={act.mean().item():.6f}, "
                      f"std={act.std().item():.6f}, min={act.min().item():.6f}, "
                      f"max={act.max().item():.6f}")
            else:
                print(f"{name}: {type(act)}")

    def print_gradient_stats(self, layers=None, grad_type="output"):
        """Print statistics about gradients.

        Parameters:
            layers (list[str] | None): Keep only keys containing one of these
                strings as a substring.
            grad_type (str): Suffix selector; keys must contain f"_grad_{grad_type}".
        """
        gradients = self.get_gradients()
        print(f"\nGradient Statistics ({grad_type}):")

        keys = [k for k in gradients.keys() if f"_grad_{grad_type}" in k]
        if layers:
            keys = [k for k in keys if any(layer in k for layer in layers)]

        for key in keys:
            grads = gradients[key]
            if not isinstance(grads, list):
                grads = [grads]

            for i, g in enumerate(grads):
                if g is not None:
                    print(f"{key}[{i}]: shape={g.shape}, mean={g.mean().item():.6f}, "
                          f"std={g.std().item():.6f}, max={g.abs().max().item():.6f}")
                else:
                    print(f"{key}[{i}]: None")

    def visualize_activation_flow(self, layer_names=None):
        """Visualize mean absolute activation per module as text bars.

        Only modules recorded exactly once (a single tensor) are shown.
        ``layer_names`` entries may be full dotted names or trailing components.
        Without ``layer_names``, every module except one named 'input' is shown.
        """
        activations = self.get_activations()

        if layer_names:
            selected = {k: v for k, v in activations.items()
                        if self._name_matches(k, layer_names)}
        else:
            selected = {k: v for k, v in activations.items() if k != 'input'}

        magnitudes = {name: float(act.abs().mean().item())
                      for name, act in selected.items()
                      if isinstance(act, torch.Tensor)}

        self._print_bar_chart("\nActivation Magnitude Flow:", magnitudes)

    def visualize_gradient_flow(self, filter_type="output"):
        """Visualize mean absolute gradient per module as text bars."""
        gradients = self.get_gradients()

        grad_keys = [k for k in gradients.keys() if f"_grad_{filter_type}" in k]

        magnitudes = {}
        for key in grad_keys:
            grads = gradients[key]
            if not isinstance(grads, list):
                grads = [grads]

            # Use the first non-None gradient recorded for this module.
            for g in grads:
                if g is not None:
                    name = key.split('_grad_')[0]
                    magnitudes[name] = float(g.abs().mean().item())
                    break

        self._print_bar_chart(f"\nGradient Magnitude Flow ({filter_type}):", magnitudes)

###########################################
# Usage Example
###########################################

def _resolve_module_names(model, short_names):
    """
    Map short layer names to the full dotted names from ``model.named_modules()``.

    A module matches when its full name, or its last dotted component, is in
    ``short_names`` (e.g. 'linear_0' -> 'layers.linear_0'). If nothing
    matches, the short names are returned unchanged.
    """
    wanted = set(short_names)
    resolved = [
        name
        for name, _ in model.named_modules()
        if name and (name in wanted or name.rsplit('.', 1)[-1] in wanted)
    ]
    return resolved or list(short_names)


def test_model_with_monitor(model_name='mlp'):
    """
    Create and test a model with NetworkMonitor.

    Builds the requested example model and attaches a NetworkMonitor. It runs
    one forward/backward pass on random inputs and random integer targets in
    [0, 10), prints activation/gradient statistics for a few representative
    layers, and always removes the hooks afterwards.

    Parameters:
        model_name (str): One of 'mlp', 'cnn', 'resnet', 'vit'

    Raises:
        ValueError: If model_name is not one of the supported names.
    """
    # Create model based on model_name
    if model_name == 'mlp':
        model = MLP(
            input_size=784,  # MNIST flattened size
            hidden_sizes=[512, 256],
            output_size=10,
            activation='relu',
            dropout_p=0.2,
            normalization='batch'
        )
        input_shape = (16, 1, 28, 28)  # batch_size, channels, height, width

    elif model_name == 'cnn':
        model = CNN(
            in_channels=3,
            conv_channels=[64, 128, 256],
            kernel_sizes=[3, 3, 3],
            strides=[1, 1, 1],
            paddings=[1, 1, 1],
            fc_hidden_units=[512],
            num_classes=10,
            activation='relu',
            dropout_p=0.2
        )
        input_shape = (16, 3, 32, 32)  # batch_size, channels, height, width

    elif model_name == 'resnet':
        model = ResNet(
            layers=[2, 2, 2, 2],  # ResNet18
            num_classes=10,
            in_channels=3,
            activation='relu',
            dropout_p=0.2
        )
        input_shape = (16, 3, 32, 32)  # batch_size, channels, height, width

    elif model_name == 'vit':
        model = VisionTransformer(
            img_size=32,
            patch_size=4,
            in_channels=3,
            num_classes=10,
            embed_dim=192,
            depth=6,
            n_heads=8
        )
        input_shape = (16, 3, 32, 32)  # batch_size, channels, height, width

    else:
        raise ValueError(f"Unknown model name: {model_name}")

    # Print model structure
    print(f"\n=== {model_name.upper()} Model Structure ===")
    for name, module in model.named_modules():
        if len(name) > 0:
            print(f"{name}: {module.__class__.__name__}")

    # Create NetworkMonitor and register hooks
    monitor = NetworkMonitor(model)
    monitor.register_hooks()
    try:
        # Generate dummy data (all example models above have 10 classes)
        x = torch.randn(*input_shape)
        target = torch.randint(0, 10, (input_shape[0],))

        criterion = nn.CrossEntropyLoss()

        # Forward and backward pass
        output = model(x)
        loss = criterion(output, target)
        loss.backward()

        print(f"\n=== {model_name.upper()} Activation Statistics ===")
        # A few representative layers per model, given as short names.
        if model_name == 'mlp':
            short_names = [f'linear_{i}' for i in range(len(model.hidden_sizes))]
        elif model_name == 'cnn':
            short_names = [f'conv_{i}' for i in range(model.num_conv_layers)]
        elif model_name == 'resnet':
            short_names = ['layer1_block0', 'layer2_block0', 'layer3_block0', 'layer4_block0']
        else:  # 'vit' (other names were rejected above)
            short_names = ['patch_embed', 'block_0', 'block_2', 'block_5']

        # Monitor keys are full dotted names (e.g. 'layers.linear_0'), so
        # translate the short names before filtering.
        important_layers = _resolve_module_names(model, short_names)

        monitor.print_activation_stats(layers=important_layers)
        monitor.print_gradient_stats(layers=important_layers)

        monitor.visualize_activation_flow()
        monitor.visualize_gradient_flow()
    finally:
        # Never leave hooks attached, even if the pass or printing fails.
        monitor.remove_hooks()




###########################################
#  MLP
###########################################

class MLP(nn.Module):
    """Fully connected classifier whose layers live in a ModuleDict
    (keys 'linear_i', 'activation_i', 'norm_i', 'dropout_i', 'output')."""

    def __init__(self, 
                 input_size=784, 
                 hidden_sizes=[512, 256, 128], 
                 output_size=10, 
                 activation='relu',
                 dropout_p=0.0,
                 normalization=None,
                 norm_after_activation=False,
                 bias=True):
        """
        Fully connected MLP that supports various activations and normalizations.

        Parameters:
            input_size (int): Dimensionality of input features
            hidden_sizes (list): List of hidden layer dimensions (copied, never mutated)
            output_size (int): Number of output classes
            activation (str): Activation function to use ('relu', 'tanh', 'sigmoid', etc.)
            dropout_p (float): Dropout probability (0 to disable)
            normalization (str): Normalization to use ('batch', 'layer', None)
            norm_after_activation (bool): If True, apply normalization after activation
            bias (bool): Whether to include bias terms in linear layers
        """
        super(MLP, self).__init__()

        self.input_size = input_size
        # Copy so the model never aliases the shared mutable default list or a
        # caller-owned list.
        self.hidden_sizes = list(hidden_sizes)
        self.output_size = output_size
        self.norm_after_activation = norm_after_activation

        # Build network using ModuleDict
        self.layers = nn.ModuleDict()
        in_features = input_size

        for i, hidden_size in enumerate(self.hidden_sizes):
            # Linear layer
            self.layers[f'linear_{i}'] = nn.Linear(in_features, hidden_size, bias=bias)

            # Activation
            self.layers[f'activation_{i}'] = get_activation(activation)

            # Normalization (its position in forward() depends on norm_after_activation)
            if normalization:
                self.layers[f'norm_{i}'] = get_normalization(normalization, hidden_size)

            # Dropout
            if dropout_p > 0:
                self.layers[f'dropout_{i}'] = nn.Dropout(dropout_p)

            in_features = hidden_size

        # Output layer
        self.layers['output'] = nn.Linear(in_features, output_size, bias=bias)

    def forward(self, x):
        """
        Forward pass without activation storage.

        Parameters:
            x (torch.Tensor): Input data with shape [batch_size, input_size].
                Inputs with more than 2 dims are flattened after the batch dim.

        Returns:
            torch.Tensor: Output logits
        """
        # flatten(1) also handles non-contiguous and empty-batch inputs,
        # where view(x.size(0), -1) raises.
        if x.dim() > 2:
            x = x.flatten(1)

        for i in range(len(self.hidden_sizes)):
            x = self.layers[f'linear_{i}'](x)

            # Apply norm before activation if configured that way
            if not self.norm_after_activation and f'norm_{i}' in self.layers:
                x = self.layers[f'norm_{i}'](x)

            x = self.layers[f'activation_{i}'](x)

            # Apply norm after activation if configured that way
            if self.norm_after_activation and f'norm_{i}' in self.layers:
                x = self.layers[f'norm_{i}'](x)

            # Dropout (if present)
            if f'dropout_{i}' in self.layers:
                x = self.layers[f'dropout_{i}'](x)

        # Output layer
        return self.layers['output'](x)


###########################################
#  CNN
###########################################

class CNN(nn.Module):
    """
    Convolutional classifier whose layers live in a ModuleDict:
    N x [conv -> (batchnorm) -> activation -> (pool)] -> flatten ->
    M x [linear -> activation -> (dropout)] -> linear output.
    """

    def __init__(self, 
                 in_channels=3,
                 conv_channels=[64, 128, 256], 
                 kernel_sizes=[3, 3, 3],
                 strides=[1, 1, 1],
                 paddings=[1, 1, 1],
                 fc_hidden_units=[512],
                 num_classes=10, 
                 input_size=32,
                 activation='relu',
                 dropout_p=0.0,
                 pool_type='max',
                 pool_size=2,
                 use_batchnorm=True,
                 norm_after_activation=False):
        """
        CNN with configurable conv layers, activations, and normalizations.

        Parameters:
            in_channels (int): Number of input channels (3 for RGB images)
            conv_channels (list): List of convolutional layer output channels
            kernel_sizes (list): Kernel size per conv layer (int or (h, w))
            strides (list): Stride per conv layer (int or (h, w))
            paddings (list): Padding per conv layer (int, (h, w), or 'same'/'valid')
            fc_hidden_units (list): List of hidden units for fully connected layers
            num_classes (int): Number of output classes
            input_size (int): Height/width of the input images (assumed square)
            activation (str): Activation function to use ('relu', 'tanh', etc.)
            dropout_p (float): Dropout probability (0 to disable)
            pool_type (str): Type of pooling ('max', 'avg', or None)
            pool_size (int): Size (and stride) of the pooling window
            use_batchnorm (bool): Whether to use batch normalization
            norm_after_activation (bool): Apply normalization after activation

        Raises:
            ValueError: If the conv/pool stack shrinks the feature map to zero or
                negative size for the given input_size.
        """
        super(CNN, self).__init__()

        # Check if input lists are of the same length
        assert len(conv_channels) == len(kernel_sizes) == len(strides) == len(paddings), \
            "Convolutional parameters (channels, kernels, strides, paddings) must have the same length"

        self.norm_after_activation = norm_after_activation
        has_pool = pool_type in ['max', 'avg']

        # Create a ModuleDict to store all layers
        self.layers = nn.ModuleDict()

        # Track the real feature-map size through every conv and pool. The
        # first FC layer must match what forward() produces. Counting pools
        # alone was wrong for any conv that changes the spatial size.
        channels = in_channels
        height = width = input_size
        for i, (out_channels, kernel_size, stride, padding) in enumerate(
                zip(conv_channels, kernel_sizes, strides, paddings)):
            # Conv layer
            self.layers[f'conv_{i}'] = nn.Conv2d(channels, out_channels, kernel_size, stride, padding)

            # Normalization
            if use_batchnorm:
                self.layers[f'norm_{i}'] = nn.BatchNorm2d(out_channels)

            # Activation
            self.layers[f'act_{i}'] = get_activation(activation)

            # Pooling
            if pool_type == 'max':
                self.layers[f'pool_{i}'] = nn.MaxPool2d(pool_size, pool_size)
            elif pool_type == 'avg':
                self.layers[f'pool_{i}'] = nn.AvgPool2d(pool_size, pool_size)

            height, width = self._output_hw((height, width), kernel_size, stride, padding)
            if has_pool:
                height, width = self._output_hw((height, width), pool_size, pool_size, 0)
            if height <= 0 or width <= 0:
                raise ValueError(
                    f"Feature map shrinks to {height}x{width} after conv layer {i} "
                    f"for input_size={input_size}; use a larger input or fewer/smaller layers")

            channels = out_channels

        # With no conv layers, `channels` is still in_channels.
        self.flattened_size = channels * height * width

        # Flatten layer
        self.layers['flatten'] = nn.Flatten()

        # Fully connected layers
        fc_input_size = self.flattened_size
        for i, hidden_units in enumerate(fc_hidden_units):
            self.layers[f'fc_{i}'] = nn.Linear(fc_input_size, hidden_units)
            self.layers[f'fc_act_{i}'] = get_activation(activation)

            if dropout_p > 0:
                self.layers[f'fc_dropout_{i}'] = nn.Dropout(dropout_p)

            fc_input_size = hidden_units

        # Output layer
        self.layers['output'] = nn.Linear(fc_input_size, num_classes)

        # Store configuration
        self.num_conv_layers = len(conv_channels)
        self.num_fc_layers = len(fc_hidden_units)
        self.use_batchnorm = use_batchnorm
        self.has_pool = has_pool
        self.dropout_p = dropout_p

    @staticmethod
    def _output_hw(hw, kernel_size, stride, padding):
        """
        Spatial (height, width) after a Conv2d/pooling window with dilation 1
        and floor rounding: (size + 2*padding - kernel) // stride + 1 per axis.

        kernel_size, stride and padding may each be an int or an (h, w) pair.
        padding may also be nn.Conv2d's string form 'same' or 'valid'.
        """
        if padding == 'same':
            return tuple(hw)
        if padding == 'valid':
            padding = 0

        def pair(value):
            return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

        return tuple(
            (size + 2 * pad - kernel) // step + 1
            for size, kernel, step, pad in zip(hw, pair(kernel_size), pair(stride), pair(padding))
        )

    def forward(self, x):
        """
        Forward pass without activation storage.

        Parameters:
            x (torch.Tensor): Input data [batch_size, in_channels, height, width]

        Returns:
            torch.Tensor: Output logits
        """
        # Process conv layers
        for i in range(self.num_conv_layers):
            x = self.layers[f'conv_{i}'](x)

            # Normalization (before activation)
            if self.use_batchnorm and not self.norm_after_activation:
                if f'norm_{i}' in self.layers:
                    x = self.layers[f'norm_{i}'](x)

            x = self.layers[f'act_{i}'](x)

            # Normalization (after activation)
            if self.use_batchnorm and self.norm_after_activation:
                if f'norm_{i}' in self.layers:
                    x = self.layers[f'norm_{i}'](x)

            # Pooling
            if self.has_pool and f'pool_{i}' in self.layers:
                x = self.layers[f'pool_{i}'](x)

        x = self.layers['flatten'](x)

        # FC layers
        for i in range(self.num_fc_layers):
            x = self.layers[f'fc_{i}'](x)
            x = self.layers[f'fc_act_{i}'](x)

            if self.dropout_p > 0 and f'fc_dropout_{i}' in self.layers:
                x = self.layers[f'fc_dropout_{i}'](x)

        return self.layers['output'](x)


###########################################
# ResNet
###########################################

class BasicBlock(nn.Module):
    """
    Two-convolution (3x3 -> 3x3) residual block stored in a ModuleDict.

    With norm_after_activation=False the order is
    conv1 -> bn1 -> act -> conv2 -> bn2 -> (+ shortcut) -> act.
    With norm_after_activation=True, bn1 runs after the first activation and
    bn2 after the final activation, i.e. after the residual addition.
    A single 'activation' module is reused for both activations.
    """
    expansion = 1  # output channels = planes * expansion

    def __init__(self, in_planes, planes, stride=1, activation='relu', 
                 use_batchnorm=True, norm_after_activation=False, downsample=None):
        super(BasicBlock, self).__init__()

        self.norm_after_activation = norm_after_activation
        # Convolution biases are dropped whenever BatchNorm follows.
        conv_bias = not use_batchnorm

        # Modules are created in the same order as before (parameter-init order).
        block_layers = nn.ModuleDict()
        block_layers['conv1'] = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                                          padding=1, bias=conv_bias)
        if use_batchnorm:
            block_layers['bn1'] = nn.BatchNorm2d(planes)
        block_layers['activation'] = get_activation(activation)
        block_layers['conv2'] = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                                          padding=1, bias=conv_bias)
        if use_batchnorm:
            block_layers['bn2'] = nn.BatchNorm2d(planes)
        if downsample is not None:
            # Applied to the block input for the shortcut branch.
            block_layers['downsample'] = downsample
        self.layers = block_layers

    def _norm(self, key, tensor, after_activation):
        """Apply BatchNorm `key` if it exists and this is its configured position."""
        if key in self.layers and bool(self.norm_after_activation) == after_activation:
            return self.layers[key](tensor)
        return tensor

    def forward(self, x):
        act = self.layers['activation']

        out = self._norm('bn1', self.layers['conv1'](x), after_activation=False)
        out = self._norm('bn1', act(out), after_activation=True)

        out = self._norm('bn2', self.layers['conv2'](out), after_activation=False)

        shortcut = self.layers['downsample'](x) if 'downsample' in self.layers else x
        # Deliberately out-of-place addition (not `+=`).
        out = act(out + shortcut)

        return self._norm('bn2', out, after_activation=True)


class ResNet(nn.Module):
    """
    ResNet architecture for continual learning experiments.

    Residual stage i (0-based) has ``base_channels * 2**i`` planes, and every
    stage after the first downsamples with stride 2. Any number of stages is
    supported (one per entry of ``layers``). ``layers=[2, 2, 2, 2]`` with
    BasicBlock gives the ResNet-18 block layout, with a 3x3 stride-1 stem and
    no max-pool.
    """
    def __init__(self, 
                 block=BasicBlock,
                 layers=[2, 2, 2, 2],  # ResNet18 by default
                 num_classes=10,
                 in_channels=3,
                 base_channels=64,
                 activation='relu',
                 dropout_p=0.0,
                 use_batchnorm=True,
                 norm_after_activation=False):
        """
        Initialize the ResNet.

        Parameters:
            block (nn.Module): The block type to use (BasicBlock)
            layers (list): Number of blocks in each stage (one entry per stage)
            num_classes (int): Number of output classes
            in_channels (int): Number of input channels (3 for RGB images)
            base_channels (int): Base number of channels (first layer)
            activation (str): Activation function to use
            dropout_p (float): Dropout probability before final layer
            use_batchnorm (bool): Whether to use batch normalization
            norm_after_activation (bool): Apply normalization after activation

        Raises:
            ValueError: If ``layers`` is empty.
        """
        super(ResNet, self).__init__()

        if len(layers) == 0:
            raise ValueError("layers must contain at least one stage")

        self.use_batchnorm = use_batchnorm
        self.norm_after_activation = norm_after_activation
        self.in_planes = base_channels
        # Copy so the model never aliases the shared mutable default list.
        self.blocks_per_layer = list(layers)
        self.num_layers = len(self.blocks_per_layer)

        # Use ModuleDict for all layers
        self.layers = nn.ModuleDict()

        # Initial convolutional layer
        self.layers['conv1'] = nn.Conv2d(in_channels, base_channels, kernel_size=3, 
                                         stride=1, padding=1, bias=not use_batchnorm)

        # Batch norm after first conv
        if use_batchnorm:
            self.layers['bn1'] = nn.BatchNorm2d(base_channels)

        self.layers['activation'] = get_activation(activation)

        # Residual stages 'layer1', 'layer2', ... (forward() iterates the same
        # count, so the two can no longer disagree).
        for stage_idx, num_blocks in enumerate(self.blocks_per_layer):
            self._make_layer(block, base_channels * 2 ** stage_idx, num_blocks,
                             stride=1 if stage_idx == 0 else 2,
                             activation=activation, use_batchnorm=use_batchnorm,
                             norm_after_activation=norm_after_activation,
                             layer_name=f'layer{stage_idx + 1}')

        # Global average pooling
        self.layers['avgpool'] = nn.AdaptiveAvgPool2d((1, 1))

        self.layers['flatten'] = nn.Flatten()

        # Dropout if needed
        if dropout_p > 0:
            self.layers['dropout'] = nn.Dropout(dropout_p)

        # Final classifier: width of the last stage
        final_planes = base_channels * 2 ** (self.num_layers - 1)
        self.layers['fc'] = nn.Linear(final_planes * block.expansion, num_classes)

        # Initialize weights. The Kaiming gain is computed for ReLU regardless
        # of the configured `activation`.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1, activation='relu', 
                    use_batchnorm=True, norm_after_activation=False, layer_name='layer'):
        """
        Add blocks '<layer_name>_block0' ... '<layer_name>_block{num_blocks-1}'.

        Only block0 receives ``stride`` and, when the stride or channel count
        changes, a 1x1-conv (+ BatchNorm) shortcut projection. Updates
        ``self.in_planes`` to the stage's output channels.
        """
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample_layers = nn.Sequential(
                nn.Conv2d(self.in_planes, planes * block.expansion, 
                          kernel_size=1, stride=stride, bias=not use_batchnorm)
            )

            if use_batchnorm:
                downsample_layers.add_module('1', nn.BatchNorm2d(planes * block.expansion))

            downsample = downsample_layers

        # Add first block with stride and downsample
        self.layers[f'{layer_name}_block0'] = block(
            self.in_planes, planes, stride, activation, 
            use_batchnorm, norm_after_activation, downsample
        )

        self.in_planes = planes * block.expansion

        # Add remaining blocks
        for i in range(1, num_blocks):
            self.layers[f'{layer_name}_block{i}'] = block(
                self.in_planes, planes, 1, activation, 
                use_batchnorm, norm_after_activation
            )

    def forward(self, x):
        """
        Forward pass without activation storage.

        Parameters:
            x (torch.Tensor): Input data

        Returns:
            torch.Tensor: Output logits
        """
        x = self.layers['conv1'](x)

        # Apply normalization (before activation)
        if self.use_batchnorm and not self.norm_after_activation:
            if 'bn1' in self.layers:
                x = self.layers['bn1'](x)

        x = self.layers['activation'](x)

        # Apply normalization (after activation)
        if self.use_batchnorm and self.norm_after_activation:
            if 'bn1' in self.layers:
                x = self.layers['bn1'](x)

        # ResNet blocks, stage by stage
        for layer_idx, num_blocks in enumerate(self.blocks_per_layer, start=1):
            for block_idx in range(num_blocks):
                x = self.layers[f'layer{layer_idx}_block{block_idx}'](x)

        x = self.layers['avgpool'](x)
        x = self.layers['flatten'](x)

        # Dropout if specified
        if 'dropout' in self.layers:
            x = self.layers['dropout'](x)

        return self.layers['fc'](x)


###########################################
# ViT Components
###########################################

class Attention(nn.Module):
    """
    Multi-head self-attention over a token sequence.

    Parameters:
        dim (int): Token embedding size; must be divisible by ``n_heads``
        n_heads (int): Number of attention heads
        qkv_bias (bool): Whether the fused query/key/value projection has a bias
        attn_drop (float): Dropout probability applied to the attention weights
        proj_drop (float): Dropout probability applied after the output projection
    """
    def __init__(self, dim, n_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        self.scale = (dim // n_heads) ** -0.5

        # Sub-modules are kept in a ModuleDict named `layers` (this file's
        # convention). Creation order (qkv before proj) is preserved so that
        # seeded parameter initialisation is unchanged.
        modules = nn.ModuleDict()
        modules['qkv'] = nn.Linear(dim, dim * 3, bias=qkv_bias)
        modules['attn_drop'] = nn.Dropout(attn_drop)
        modules['proj'] = nn.Linear(dim, dim)
        modules['proj_drop'] = nn.Dropout(proj_drop)
        self.layers = modules

    def forward(self, x):
        """
        Apply scaled dot-product self-attention.

        Parameters:
            x (torch.Tensor): Tokens of shape (batch, seq_len, dim)

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``
        """
        batch, seq_len, dim = x.shape
        head_dim = dim // self.n_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q, k, v
        fused = self.layers['qkv'](x)
        fused = fused.reshape(batch, seq_len, 3, self.n_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused.unbind(0)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.layers['attn_drop'](torch.softmax(scores, dim=-1))

        # (B, heads, N, head_dim) -> (B, N, C)
        mixed = torch.matmul(weights, value).transpose(1, 2).reshape(batch, seq_len, dim)
        return self.layers['proj_drop'](self.layers['proj'](mixed))


class TransformerMLP(nn.Module):
    """
    Feed-forward sub-layer of a transformer block.

    Computes ``drop2(fc2(drop1(act(fc1(x)))))``. Both dropout slots become
    ``nn.Identity`` when ``drop`` is not positive.

    Parameters:
        in_features (int): Input feature size
        hidden_features (int): Width of the hidden layer
        out_features (int): Output feature size
        activation (str): Activation name passed to ``get_activation``
        drop (float): Dropout probability after the activation and after ``fc2``
    """
    def __init__(self, in_features, hidden_features, out_features,
                 activation='gelu', drop=0.):
        super().__init__()

        def make_dropout():
            return nn.Dropout(drop) if drop > 0 else nn.Identity()

        # Entries are created in application order (fc1 before fc2), which
        # also keeps seeded parameter initialisation unchanged.
        self.layers = nn.ModuleDict([
            ('fc1', nn.Linear(in_features, hidden_features)),
            ('act', get_activation(activation)),
            ('drop1', make_dropout()),
            ('fc2', nn.Linear(hidden_features, out_features)),
            ('drop2', make_dropout()),
        ])

    def forward(self, x):
        # Apply each sub-module in sequence
        for name in ('fc1', 'act', 'drop1', 'fc2', 'drop2'):
            x = self.layers[name](x)
        return x


class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block.

    Computes ``x = x + attn(norm1(x))`` followed by ``x = x + mlp(norm2(x))``.

    Parameters:
        dim (int): Token embedding size
        n_heads (int): Number of attention heads
        mlp_ratio (float): Hidden width of the MLP as a multiple of ``dim``
        qkv_bias (bool): Whether the attention QKV projection has a bias
        drop (float): Dropout for the attention output projection and the MLP
        attn_drop (float): Dropout on the attention weights
        activation (str): MLP activation name
        normalization (str or None): Name passed to ``get_normalization``.
            ``None`` means no normalization (identity).
    """
    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0.,
                 attn_drop=0., activation='gelu', normalization='layer'):
        super().__init__()

        # Construction order (norm1, attn, norm2, mlp) is kept so seeded
        # parameter initialisation is unchanged.
        self.layers = nn.ModuleDict({
            'norm1': self._make_norm(normalization, dim),
            'attn': Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias,
                              attn_drop=attn_drop, proj_drop=drop),
            'norm2': self._make_norm(normalization, dim),
            'mlp': TransformerMLP(dim, int(dim * mlp_ratio), dim,
                                  activation=activation, drop=drop)
        })

    @staticmethod
    def _make_norm(normalization, dim):
        """Build the norm layer, using an Identity when get_normalization returns None."""
        # get_normalization returns None for normalization=None. ModuleDict would
        # store that None and forward() would then fail calling it.
        norm = get_normalization(normalization, dim)
        return nn.Identity() if norm is None else norm

    def forward(self, x):
        # Self-attention branch with residual connection
        x = x + self.layers['attn'](self.layers['norm1'](x))

        # MLP branch with residual connection
        x = x + self.layers['mlp'](self.layers['norm2'](x))

        return x


class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) classifier built from ModuleDict components.

    Images are split into patch tokens by ``PatchEmbedding``, and a learned
    class token is prepended. Learned position embeddings are added, and the
    sequence passes through ``depth`` ``TransformerBlock``s. The class-token
    output, after a final normalization, is fed to a linear head.
    """
    def __init__(self, 
                 img_size=32, 
                 patch_size=4, 
                 in_channels=3, 
                 num_classes=10, 
                 embed_dim=192,
                 depth=12, 
                 n_heads=8, 
                 mlp_ratio=4., 
                 qkv_bias=True, 
                 drop_rate=0.1,
                 attn_drop_rate=0.0,
                 activation='gelu',
                 normalization='layer'):
        """
        Initialize Vision Transformer.
        
        Parameters:
            img_size (int): Input image size
            patch_size (int): Patch size for splitting image
            in_channels (int): Number of image channels
            num_classes (int): Number of output classes
            embed_dim (int): Embedding dimension
            depth (int): Number of transformer blocks
            n_heads (int): Number of attention heads
            mlp_ratio (float): Ratio for MLP hidden dimension
            qkv_bias (bool): Whether to use bias in QKV projection
            drop_rate (float): Dropout rate
            attn_drop_rate (float): Attention dropout rate
            activation (str): Activation function to use
            normalization (str or None): Normalization method to use; None
                means no final normalization (identity)
        """
        super().__init__()
        
        # All sub-modules live in one ModuleDict
        self.layers = nn.ModuleDict()
        
        # Patch embedding
        self.layers['patch_embed'] = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.layers['patch_embed'].n_patches

        # Class token and position embeddings are bare Parameters (not modules),
        # so they are attributes rather than ModuleDict entries
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        
        # Position dropout
        self.layers['pos_drop'] = nn.Dropout(drop_rate)

        # Transformer blocks
        for i in range(depth):
            self.layers[f'block_{i}'] = TransformerBlock(
                embed_dim, n_heads, mlp_ratio, qkv_bias, 
                drop_rate, attn_drop_rate, activation, normalization
            )

        # Final normalization. get_normalization returns None for
        # normalization=None; store an Identity so forward() can always call it.
        final_norm = get_normalization(normalization, embed_dim)
        self.layers['norm'] = nn.Identity() if final_norm is None else final_norm
        
        # Classifier head
        self.layers['head'] = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()
        
        # Number of blocks iterated in forward()
        self.depth = depth

    def _init_weights(self):
        """Truncated-normal init for embeddings/Linear weights; unit/zero init for LayerNorm."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                # LayerNorm(elementwise_affine=False) has weight/bias set to None
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Forward pass without activation storage.
        
        Parameters:
            x (torch.Tensor): Input images [batch_size, channels, height, width]
            
        Returns:
            torch.Tensor: Output logits
        """
        # Patch embedding
        x = self.layers['patch_embed'](x)
        
        # Prepend the class token to every sequence in the batch
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        
        # Add position embeddings
        x = x + self.pos_embed
        x = self.layers['pos_drop'](x)

        # Apply transformer blocks
        for i in range(self.depth):
            x = self.layers[f'block_{i}'](x)
        
        # Final normalization
        x = self.layers['norm'](x)
        
        # Classify from the class token (position 0)
        return self.layers['head'](x[:, 0])



def example_test():
    """
    Show how NetworkMonitor's stored data evolves over two batches.

    Builds a ``SimpleModel``, registers the monitor's hooks, and runs a
    forward and backward pass on two random batches (sizes 2 and 3). After
    each batch it prints the monitor's data structure. Gradients are not
    zeroed between the batches. Hooks are removed at the end.
    """
    model = SimpleModel()
    monitor = NetworkMonitor(model)
    monitor.register_hooks()

    criterion = nn.CrossEntropyLoss()

    # (banner, input shape, labels, header printed before the data dump)
    batches = (
        ("=== BATCH 1 ===", (2, 3, 8, 8), [0, 3],
         "\nAfter batch 1, let's examine the data structure:"),
        ("\n=== BATCH 2 ===", (3, 3, 8, 8), [2, 4, 5],
         "\nAfter batch 2, let's examine the data structure:"),
    )
    for banner, shape, labels, header in batches:
        print(banner)
        inputs = torch.randn(*shape)
        targets = torch.tensor(labels)
        criterion(model(inputs), targets).backward()
        print(header)
        monitor.show_data_structure()

    monitor.remove_hooks()

"""
Example of how to use the adapted models with NetworkMonitor.
"""

def _build_demo_model(model_name):
    """
    Instantiate the demo model named ``model_name``.

    Returns:
        tuple: (model, input_shape), where input_shape is
        (batch_size, channels, height, width) for the dummy batch.

    Raises:
        ValueError: If ``model_name`` is not 'mlp', 'cnn', 'resnet' or 'vit'.
    """
    if model_name == 'mlp':
        model = MLP(
            input_size=784,  # MNIST flattened size
            hidden_sizes=[512, 256],
            output_size=10,
            activation='relu',
            dropout_p=0.2,
            normalization='batch'
        )
        return model, (16, 1, 28, 28)

    if model_name == 'cnn':
        model = CNN(
            in_channels=3,
            conv_channels=[64, 128, 256],
            kernel_sizes=[3, 3, 3],
            strides=[1, 1, 1],
            paddings=[1, 1, 1],
            fc_hidden_units=[512],
            num_classes=10,
            activation='relu',
            dropout_p=0.2
        )
        return model, (16, 3, 32, 32)

    if model_name == 'resnet':
        model = ResNet(
            layers=[2, 2, 2, 2],  # ResNet18
            num_classes=10,
            in_channels=3,
            activation='relu',
            dropout_p=0.2
        )
        return model, (16, 3, 32, 32)

    if model_name == 'vit':
        model = VisionTransformer(
            img_size=32,
            patch_size=4,
            in_channels=3,
            num_classes=10,
            embed_dim=192,
            depth=6,
            n_heads=8
        )
        return model, (16, 3, 32, 32)

    raise ValueError(f"Unknown model name: {model_name}")


def _select_important_layers(model_name, model):
    """Return a few representative layer names to report for ``model_name``."""
    if model_name == 'mlp':
        # Select the Linear layers by name. Each hidden layer also registers
        # activation/norm/dropout entries, so len(model.layers) // 3 asked
        # for linear layers that do not exist (9 entries -> linear_0..linear_2).
        return [name for name in model.layers if name.startswith('linear_')]
    if model_name == 'cnn':
        return [f'conv_{i}' for i in range(3)]
    if model_name == 'resnet':
        return ['layer1_block0', 'layer2_block0', 'layer3_block0', 'layer4_block0']
    if model_name == 'vit':
        return ['patch_embed', 'block_0', 'block_2', 'block_5']
    return []


def test_model_with_monitor(model_name='mlp'):
    """
    Create and test a model with NetworkMonitor.

    Builds the model, prints its module tree, and runs one forward and
    backward pass on random data with the monitor's hooks registered. It then
    prints/visualizes the recorded activation and gradient statistics. Hooks
    are removed even if any step raises.
    
    Parameters:
        model_name (str): One of 'mlp', 'cnn', 'resnet', 'vit'

    Raises:
        ValueError: If ``model_name`` is unknown.
    """
    model, input_shape = _build_demo_model(model_name)
    
    # Print model structure (the root module has the empty name and is skipped)
    print(f"\n=== {model_name.upper()} Model Structure ===")
    for name, module in model.named_modules():
        if name:
            print(f"{name}: {module.__class__.__name__}")
    
    monitor = NetworkMonitor(model)
    monitor.register_hooks()
    try:
        # Dummy batch and labels for a 10-class problem
        x = torch.randn(*input_shape)
        target = torch.randint(0, 10, (input_shape[0],))
        criterion = nn.CrossEntropyLoss()

        loss = criterion(model(x), target)
        loss.backward()

        print(f"\n=== {model_name.upper()} Activation Statistics ===")
        important_layers = _select_important_layers(model_name, model)

        monitor.show_data_structure()
        monitor.print_activation_stats(layers=important_layers)
        monitor.print_gradient_stats(layers=important_layers)

        monitor.visualize_activation_flow()
        monitor.visualize_gradient_flow()
    finally:
        monitor.remove_hooks()
    

if __name__ == "__main__":
    # Test each model
    for model_name in ['mlp', 'cnn', 'resnet', 'vit']:
        test_model_with_monitor(model_name)


=== MLP Model Structure ===
layers: ModuleDict
layers.linear_0: Linear
layers.activation_0: ReLU
layers.norm_0: BatchNorm1d
layers.dropout_0: Dropout
layers.linear_1: Linear
layers.activation_1: ReLU
layers.norm_1: BatchNorm1d
layers.dropout_1: Dropout
layers.output: Linear

=== MLP Activation Statistics ===

ACTIVATIONS DATA STRUCTURE:
layers.linear_0: List with 1 tensors
  - Item 0: shape=torch.Size([16, 512]), mean=-0.0056
layers.norm_0: List with 1 tensors
  - Item 0: shape=torch.Size([16, 512]), mean=0.0000
layers.activation_0: List with 1 tensors
  - Item 0: shape=torch.Size([16, 512]), mean=0.4078
layers.dropout_0: List with 1 tensors
  - Item 0: shape=torch.Size([16, 512]), mean=0.4028
layers.linear_1: List with 1 tensors
  - Item 0: shape=torch.Size([16, 256]), mean=0.0289
layers.norm_1: List with 1 tensors
  - Item 0: shape=torch.Size([16, 256]), mean=-0.0000
layers.activation_1: List with 1 tensors
  - Item 0: shape=torch.Size([16, 256]), mean=0.4041
layers.dropout_1: List 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import time
import os
import random

# NOTE: NetworkMonitor and CNN are not imported in this cell; they must already be defined (e.g. by an earlier cell). The metric functions are defined below.
from collections import defaultdict

###########################################
# Metric Functions 
###########################################

def measure_dead_neurons(layer_act, dead_threshold=0.95):
    """
    Fraction of neurons (channels) whose output is (nearly) always zero.

    The activation is read as channel-first, ``(B, C, ...)``. Every trailing
    dimension after the channel axis is folded into the sample axis, giving
    a ``(B * prod(...), C)`` matrix whose columns are neurons. A 2-D
    ``(B, C)`` input is used as-is. Any rank >= 2 is accepted, so
    ``(B, C, L)`` works as well as ``(B, C, H, W)``. Note that dim 1 is
    always treated as the channel axis; a channels-last ``(B, N, C)`` tensor
    would be read with N as the channel axis.

    Parameters:
        layer_act (torch.Tensor): Activation tensor with at least 2 dims
        dead_threshold (float): A neuron is dead if it is ~0 for more than
            this fraction of samples

    Returns:
        float: Fraction of neurons classified as dead
    """
    n_channels = layer_act.shape[1]
    if layer_act.dim() > 2:
        # (B, C, *spatial) -> (B, *spatial, C): channels become the last axis
        layer_act = layer_act.movedim(1, -1)
    # reshape (unlike view) also accepts non-contiguous inputs
    rows = layer_act.reshape(-1, n_channels)

    # Per neuron: fraction of samples/locations where the output is ~0
    frac_zero_per_neuron = (rows.abs() < 1e-7).float().mean(dim=0)

    dead_mask = frac_zero_per_neuron > dead_threshold
    return dead_mask.float().mean().item()

def measure_duplicate_neurons(layer_act, corr_threshold=0.99):
    """
    Fraction of neurons (channels) that nearly duplicate some other neuron.

    Two neurons count as duplicates when the cosine similarity of their
    output vectors over all samples/locations exceeds ``corr_threshold``.
    This is uncentered cosine similarity, not Pearson correlation, despite
    the parameter name. Neurons with an all-zero output have similarity 0
    with everything, so they are never duplicates.

    The activation is read as channel-first, ``(B, C, ...)``. Trailing
    dimensions after the channel axis are folded into the sample axis,
    giving ``(B * prod(...), C)``. A 2-D ``(B, C)`` input is used as-is.

    Parameters:
        layer_act (torch.Tensor): Activation tensor with at least 2 dims
        corr_threshold (float): Similarity above which two neurons are duplicates

    Returns:
        float: Fraction of neurons that have at least one duplicate
    """
    n_channels = layer_act.shape[1]
    if layer_act.dim() > 2:
        # (B, C, *spatial) -> (B, *spatial, C): channels become the last axis
        layer_act = layer_act.movedim(1, -1)
    # (samples, C) -> (C, samples): one row per neuron; reshape accepts non-contiguous input
    per_neuron = layer_act.reshape(-1, n_channels).t()

    # Unit-normalise rows so the Gram matrix holds cosine similarities
    per_neuron = torch.nn.functional.normalize(per_neuron, p=2, dim=1)
    similarity = per_neuron @ per_neuron.t()

    # Ignore each neuron's similarity with itself (the diagonal)
    off_diagonal = ~torch.eye(n_channels, dtype=torch.bool, device=similarity.device)
    dup_mask = (similarity > corr_threshold) & off_diagonal

    return dup_mask.any(dim=1).float().mean().item()

def measure_effective_rank(layer_act, svd_sample_size=1024):
    """
    Compute the effective rank of the activation matrix.

    Effective rank = exp(-sum_i p_i * log(p_i)), where
    p_i = sigma_i / sum_j sigma_j over the singular values sigma.

    The activation is read as channel-first, ``(B, C, ...)``. Trailing
    dimensions after the channel axis are folded into the row axis, giving
    ``(B * prod(...), C)``. A 2-D ``(B, C)`` input is used as-is. If there
    are more than ``svd_sample_size`` rows, a random subset of that many rows
    is used (drawn with ``torch.randperm``, i.e. the global CPU RNG).

    Parameters:
        layer_act (torch.Tensor): Activation tensor with at least 2 dims
        svd_sample_size (int): Maximum number of rows passed to the SVD

    Returns:
        float: Effective rank, or 0.0 if the singular values sum to ~0
    """
    n_channels = layer_act.shape[1]
    if layer_act.dim() > 2:
        # (B, C, *spatial) -> (B, *spatial, C): channels become the last axis
        layer_act = layer_act.movedim(1, -1)
    rows = layer_act.reshape(-1, n_channels)

    n_rows = rows.shape[0]
    if n_rows > svd_sample_size:
        idx = torch.randperm(n_rows)[:svd_sample_size]
        rows = rows[idx]

    # Only singular values are needed; skip computing U and V
    singular_values = torch.linalg.svdvals(rows)
    total = singular_values.sum()
    if total < 1e-9:
        return 0.0
    p = singular_values / total
    entropy = -(p * torch.log(p + 1e-12)).sum()
    return torch.exp(entropy).item()

def get_activations_for_batch(model, monitor, x_batch, y_batch=None, criterion=None):
    """
    Run one monitored pass and return the last activation of each hooked layer.

    Clears the monitor's data, then runs a forward pass in eval mode. If both
    ``criterion`` and ``y_batch`` are given, it also computes the loss and
    backpropagates; gradients then accumulate into the parameters' ``.grad``.
    Gradient tracking is enabled only when ``criterion`` is given. The
    model's previous train/eval mode is restored afterwards, even on error,
    so calling this mid-epoch does not leave BatchNorm/Dropout in inference
    mode for the remaining training steps.

    Parameters:
        model (nn.Module): Model with the monitor's hooks registered
        monitor: Object exposing ``clear_data()`` and an ``activations``
            mapping of layer name -> list of recorded tensors
        x_batch (torch.Tensor): Input batch
        y_batch (torch.Tensor, optional): Targets for the loss
        criterion (callable, optional): Loss function; enables the backward pass

    Returns:
        dict: layer name -> last recorded activation. Layers whose list is
        empty are omitted.
    """
    monitor.clear_data()
    was_training = model.training
    model.eval()
    try:
        with torch.set_grad_enabled(criterion is not None):
            output = model(x_batch)
            if criterion is not None and y_batch is not None:
                loss = criterion(output, y_batch)
                loss.backward()
    finally:
        model.train(was_training)
    return {name: acts_list[-1]
            for name, acts_list in monitor.activations.items()
            if len(acts_list) > 0}

def evaluate_layer_metrics(model, monitor, x_batch, y_batch=None, criterion=None,
                           dead_threshold=0.95, corr_threshold=0.99):
    """
    Run a forward (and optionally backward) pass on x_batch (and y_batch),
    then compute for each hooked layer:
      - Fraction of dead neurons
      - Fraction of duplicate neurons
      - Effective rank of the activations

    Layers are skipped when:
      - their recorded activation is not a tensor, or has fewer than 2 dims
        (the metrics need a channel axis at dim 1);
      - their name contains 'dropout', 'flatten' or 'shortcut'.

    Parameters:
        model, monitor, x_batch, y_batch, criterion: see get_activations_for_batch
        dead_threshold (float): Passed to measure_dead_neurons
        corr_threshold (float): Passed to measure_duplicate_neurons

    Returns:
        dict: layer name -> {'dead_fraction', 'dup_fraction', 'eff_rank'}
    """
    final_acts = get_activations_for_batch(model, monitor, x_batch, y_batch, criterion)

    # Substring checks: monitor names are qualified (e.g. "layers.dropout_0"),
    # so a startswith('dropout') test never matched them.
    skip_tokens = ('dropout', 'flatten', 'shortcut')

    results = {}
    for layer_name, act in final_acts.items():
        if not isinstance(act, torch.Tensor) or act.dim() < 2:
            continue
        if any(token in layer_name for token in skip_tokens):
            continue
        results[layer_name] = {
            'dead_fraction': measure_dead_neurons(act, dead_threshold),
            'dup_fraction': measure_duplicate_neurons(act, corr_threshold),
            'eff_rank': measure_effective_rank(act),
        }
    return results

###########################################
# Training and Evaluation Functions
###########################################

def set_seed(seed):
    """
    Seed every random generator the script uses and make cuDNN deterministic.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``.
    Then turns off cuDNN autotuning and requests deterministic kernels.

    Parameters:
        seed (int): Seed shared by all generators
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all,
                   np.random.seed, random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_cifar10_data(batch_size=128):
    """
    Build CIFAR-10 loaders plus two small unshuffled subsets for metric tracking.

    The training split uses random crop (padding 4) and horizontal flip before
    normalisation; the test split is only normalised. Both splits are
    downloaded to ``./data`` if missing.

    Parameters:
        batch_size (int): Batch size of the main train/test loaders

    Returns:
        tuple: (trainloader, testloader, fixed_trainloader, fixed_valloader).
        The two "fixed" loaders serve the first 500 samples of the train and
        test splits in unshuffled batches of 100.

    Note:
        The fixed training subset wraps the augmenting training dataset, so
        the images it yields depend on the RNG state when it is iterated.
        NOTE(review): the std values below are widely copied for CIFAR-10 but
        reportedly differ from the per-channel pixel std of the training set
        (~0.247, 0.243, 0.261); kept unchanged.
    """
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    def load_split(is_train, transform):
        return torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=transform)

    trainset = load_split(True, transform_train)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = load_split(False, transform_test)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Same leading 500 indices for both fixed subsets
    first_500 = list(range(500))
    fixed_trainloader = DataLoader(Subset(trainset, first_500), batch_size=100, shuffle=False)
    fixed_valloader = DataLoader(Subset(testset, first_500), batch_size=100, shuffle=False)

    return trainloader, testloader, fixed_trainloader, fixed_valloader

def _print_layer_metrics(title, layer_metrics):
    """Print ``title`` then one line of dead/duplicate/effective-rank values per layer."""
    print(title)
    for layer_name, metrics in layer_metrics.items():
        print(f"{layer_name:15}: Dead: {metrics['dead_fraction']:8.3f}, " +
              f"Dup: {metrics['dup_fraction']:8.3f}, " +
              f"EffRank: {metrics['eff_rank']:8.3f}")


def _append_layer_metrics(history, layer_metrics):
    """Append each ``layer_metrics[layer][metric]`` value to ``history[layer][metric]``."""
    for layer_name, metrics in layer_metrics.items():
        for metric_name, value in metrics.items():
            history[layer_name][metric_name].append(value)


def _evaluate_on_loader(model, loader, criterion, device, monitor=None):
    """
    Return (mean batch loss, accuracy in %) of ``model`` over ``loader``.

    Runs in eval mode without gradients. When ``monitor`` is given, its
    stored data is cleared after every batch so hooked tensors from the
    whole pass do not accumulate.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if monitor is not None:
                monitor.clear_data()
    return loss_sum / len(loader), 100. * correct / total


def train_and_evaluate(model, trainloader, testloader, fixed_trainloader, fixed_valloader, 
                       monitor, num_epochs=20, metrics_frequency=100, device='cpu', lr=0.001):
    """
    Train ``model`` with Adam and periodically measure layer metrics on fixed batches.

    Layer metrics (dead fraction, duplicate fraction, effective rank) are
    computed on the first batch of ``fixed_trainloader`` and of
    ``fixed_valloader``. They are measured once before training, every
    ``metrics_frequency`` optimizer steps, and once after the last epoch.
    The test set is evaluated at the end of each epoch.

    Parameters:
        model (nn.Module): Model to train, already on ``device``
        trainloader, testloader (DataLoader): Training / test data
        fixed_trainloader, fixed_valloader (DataLoader): Sources of the fixed metric batches
        monitor: NetworkMonitor whose hooks are registered on ``model``
        num_epochs (int): Number of passes over ``trainloader``
        metrics_frequency (int): Optimizer steps between metric measurements.
            0 or None disables the periodic ones; the baseline and final
            measurements still run.
        device (str or torch.device): Device the batches are moved to
        lr (float): Adam learning rate

    Returns:
        dict with keys:
          - 'train_losses', 'train_accs', 'test_losses', 'test_accs': per-epoch values;
          - 'training_metrics_history', 'validation_metrics_history':
            {layer: {metric: [values]}};
          - 'step_history': the step of each measurement.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    train_losses, train_accs = [], []
    test_losses, test_accs = [], []
    training_metrics_history = defaultdict(lambda: defaultdict(list))
    validation_metrics_history = defaultdict(lambda: defaultdict(list))
    step_history = []
    
    # Only the inputs of the first batch of each fixed loader are used
    fixed_train_batch = next(iter(fixed_trainloader))[0].to(device)
    fixed_val_batch = next(iter(fixed_valloader))[0].to(device)

    def measure(step, train_title, val_title):
        # Measure both fixed batches, record them under `step`, and print them
        train_metrics = evaluate_layer_metrics(model, monitor, fixed_train_batch)
        val_metrics = evaluate_layer_metrics(model, monitor, fixed_val_batch)
        step_history.append(step)
        _append_layer_metrics(training_metrics_history, train_metrics)
        _append_layer_metrics(validation_metrics_history, val_metrics)
        _print_layer_metrics(train_title, train_metrics)
        _print_layer_metrics(val_title, val_metrics)
    
    print("Measuring baseline metrics before training...")
    measure(0,
            "\n=== Training Batch Metrics (before training) ===",
            "\n=== Validation Batch Metrics (before training) ===")
    
    total_steps = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
            total_steps += 1

            # The monitor's hooks stay registered during training and record
            # tensors on every pass; drop them so they do not pile up
            # between metric measurements.
            monitor.clear_data()
            
            if metrics_frequency and total_steps % metrics_frequency == 0:
                print(f"\nStep {total_steps}: Measuring metrics...")
                measure(total_steps,
                        f"\n=== Training Batch Metrics (step {total_steps}) ===",
                        f"\n=== Validation Batch Metrics (step {total_steps}) ===")
                # The metric pass runs the model in eval mode
                # (get_activations_for_batch calls model.eval()); resume training mode.
                model.train()
        
        test_loss, test_acc = _evaluate_on_loader(model, testloader, criterion, device, monitor)
        train_loss = running_loss / len(trainloader)
        train_acc = 100. * train_correct / train_total
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | '
              f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')
        print(f'Time: {time.time() - start_time:.2f}s')
    
    print("\nFinal metrics:")
    measure(total_steps,
            "\n=== Final Training Batch Metrics ===",
            "\n=== Final Validation Batch Metrics ===")
    
    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_losses': test_losses,
        'test_accs': test_accs,
        'training_metrics_history': dict(training_metrics_history),
        'validation_metrics_history': dict(validation_metrics_history),
        'step_history': step_history
    }

###########################################
# Visualization Functions
###########################################

def plot_training_curves(history):
    """
    Show per-epoch training loss and test accuracy side by side.

    Parameters:
        history (dict): Must contain the lists 'train_losses' and 'test_accs'
    """
    _, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (series, panel title, y-axis label) for the left and right panels
    panels = (
        (history['train_losses'], 'Training Loss', 'Loss'),
        (history['test_accs'], 'Test Accuracy', 'Accuracy (%)'),
    )
    for ax, (series, title, ylabel) in zip(axes, panels):
        ax.plot(series)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

def _resolve_layer_names(requested, available):
    """
    Map requested layer names onto keys of ``available``.

    A name that is already a key is kept. Otherwise the first key ending in
    ``'.' + name`` is used, so ``'conv_0'`` finds the monitor's qualified
    ``'layers.conv_0'``. Names with no match are dropped. Order is preserved.
    """
    resolved = []
    for name in requested:
        if name in available:
            resolved.append(name)
            continue
        match = next((key for key in available if key.endswith('.' + name)), None)
        if match is not None:
            resolved.append(match)
    return resolved


def plot_metric_evolution(history, metric_name, layer_names=None, is_train=True):
    """
    Plot the evolution of one metric over training steps for selected layers.

    Parameters:
        history (dict): Output of ``train_and_evaluate``
        metric_name (str): 'dead_fraction', 'dup_fraction' or 'eff_rank'
        layer_names (list[str], optional): Layers to plot; None plots all.
            Short names such as 'conv_0' also match qualified history keys
            such as 'layers.conv_0'.
        is_train (bool): Plot the fixed training batch (True) or validation batch
    """
    prefix = "Training" if is_train else "Validation"
    metrics_history = history['training_metrics_history'] if is_train else history['validation_metrics_history']
    steps = history['step_history']
    
    if layer_names is None:
        layer_names = list(metrics_history.keys())
    else:
        layer_names = _resolve_layer_names(layer_names, metrics_history)
    
    pretty_metric = metric_name.replace("_", " ").title()
    plt.figure(figsize=(12, 6))
    plotted = False
    for layer in layer_names:
        if metric_name in metrics_history[layer]:
            plt.plot(steps, metrics_history[layer][metric_name], label=layer)
            plotted = True
    
    plt.title(f'{prefix} {pretty_metric} Evolution')
    plt.xlabel('Training Steps')
    plt.ylabel(pretty_metric)
    if plotted:
        # legend() warns when there are no labelled lines
        plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

def plot_all_metrics(history, layer_names=None):
    """
    Plot every tracked layer metric: all training-batch plots first, then validation.

    Parameters:
        history (dict): Output of ``train_and_evaluate``
        layer_names (list[str], optional): Layers to include; None means all
    """
    for is_train in (True, False):
        for metric in ('dead_fraction', 'dup_fraction', 'eff_rank'):
            plot_metric_evolution(history, metric, layer_names, is_train=is_train)

###########################################
# Main Function
###########################################

def main():
    """Train a CNN on CIFAR-10 with layer monitoring, then plot the results.

    Steps: seed the RNGs, load the data loaders, build the CNN, attach a
    ``NetworkMonitor``, train via ``train_and_evaluate``, detach the monitor,
    and plot the training curves plus per-layer metrics for a few key layers.
    """
    # Set random seed for reproducibility
    set_seed(42)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data loaders
    print("Loading CIFAR10 dataset...")
    trainloader, testloader, fixed_trainloader, fixed_valloader = get_cifar10_data(batch_size=128)

    # Create CNN model
    print("Creating model...")
    model = CNN(
        in_channels=3,
        conv_channels=[64, 128, 256],
        kernel_sizes=[3, 3, 3],
        strides=[1, 1, 1],
        paddings=[1, 1, 1],
        fc_hidden_units=[512],
        num_classes=10,
        activation='relu',
        dropout_p=0.0,
        use_batchnorm=True
    )
    model = model.to(device)

    # Create monitor
    monitor = NetworkMonitor(model)
    monitor.register_hooks()

    try:
        # Print model architecture
        print("\nModel Architecture:")
        for name, module in model.named_modules():
            if len(name) > 0:  # Skip the root module
                print(f"{name}: {module.__class__.__name__}")

        # Training settings
        num_epochs = 5
        metrics_frequency = 100  # Measure metrics every 100 steps

        # Train and evaluate
        print("\nStarting training...")
        history = train_and_evaluate(
            model, trainloader, testloader, fixed_trainloader, fixed_valloader,
            monitor, num_epochs, metrics_frequency, device
        )
    finally:
        # Always detach the monitor, even if training raises or is
        # interrupted, so the model is not left with its hooks attached.
        monitor.remove_hooks()

    # Plot results
    print("\nPlotting results...")
    plot_training_curves(history)

    # Select important layers to visualize. The monitor reports layers by
    # their qualified named_modules() names (e.g. 'layers.conv_0', as printed
    # above), so the same qualified form is used here; bare names like
    # 'conv_0' would not match any recorded layer.
    important_layers = [
        'layers.conv_0',
        'layers.conv_1',
        'layers.conv_2',
        'layers.fc_0',
        'layers.output',
    ]

    # Plot metrics evolution for important layers
    plot_all_metrics(history, important_layers)

    print("\nDone!")

# Run the full train-and-plot pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

Using device: cuda
Loading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified
Creating model...

Model Architecture:
layers: ModuleDict
layers.conv_0: Conv2d
layers.norm_0: BatchNorm2d
layers.act_0: ReLU
layers.pool_0: MaxPool2d
layers.conv_1: Conv2d
layers.norm_1: BatchNorm2d
layers.act_1: ReLU
layers.pool_1: MaxPool2d
layers.conv_2: Conv2d
layers.norm_2: BatchNorm2d
layers.act_2: ReLU
layers.pool_2: MaxPool2d
layers.flatten: Flatten
layers.fc_0: Linear
layers.fc_act_0: ReLU
layers.output: Linear

Starting training...
Measuring baseline metrics before training...

=== Training Batch Metrics (before training) ===
layers.conv_0  : Dead:    0.000, Dup:    0.000, EffRank:   11.315
layers.norm_0  : Dead:    0.000, Dup:    0.000, EffRank:   11.312
layers.act_0   : Dead:    0.000, Dup:    0.000, EffRank:   27.140
layers.pool_0  : Dead:    0.000, Dup:    0.000, EffRank:   25.894
layers.conv_1  : Dead:    0.000, Dup:    0.000, EffRank:   61.011
layer