# In [None]:
"""
Combined file with all neural network models adapted to use ModuleDict
and work with NetworkMonitor for tracking activations and gradients.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
import torch.optim as optim



###########################################
# Utility Functions and Custom Layers
###########################################

def get_activation(activation_name):
    """
    Build the activation module registered under the given name.

    The lookup is case-insensitive and every call returns a new module.

    Parameters:
        activation_name (str): One of 'relu', 'leaky_relu', 'tanh', 'sigmoid',
            'gelu', 'elu', 'selu' or 'none'.

    Returns:
        nn.Module: The activation module ('none' maps to nn.Identity).

    Raises:
        ValueError: If the name is not a supported activation.
    """
    # All variants are out-of-place (inplace=False) for compatibility with hooks.
    activations = {
        'relu': lambda: nn.ReLU(inplace=False),
        'leaky_relu': lambda: nn.LeakyReLU(0.1, inplace=False),
        'tanh': lambda: nn.Tanh(),
        'sigmoid': lambda: nn.Sigmoid(),
        'gelu': lambda: nn.GELU(),
        'elu': lambda: nn.ELU(inplace=False),
        'selu': lambda: nn.SELU(inplace=False),
        'none': lambda: nn.Identity(),
    }

    key = activation_name.lower()
    build = activations.get(key)
    if build is None:
        raise ValueError(f"Activation {activation_name} not supported. "
                         f"Choose from: {list(activations.keys())}")

    return build()

def get_normalization(norm_name, num_features):
    """
    Build the normalization layer registered under the given name.

    Only the requested layer is constructed. Building every candidate up front
    would make unrelated requests fail: nn.GroupNorm rejects a channel count
    that is not divisible by its group count, which would break e.g. a 'batch'
    request for such a width.

    Parameters:
        norm_name (str or None): One of 'batch', 'batch2d', 'layer', 'instance',
            'instance2d', 'group' or 'none' (case-insensitive). None means
            "no normalization".
        num_features (int): Number of features/channels to normalize.

    Returns:
        nn.Module or None: The normalization module, nn.Identity for 'none',
        or None when norm_name is None.

    Raises:
        ValueError: If the name is not a supported normalization.
    """
    if norm_name is None:
        return None

    def build_group_norm():
        # Use the largest group count <= 32 that divides num_features, so that
        # widths such as 48 work; this equals min(32, num_features) whenever
        # that value already divides num_features.
        groups = max(1, min(32, num_features))
        while num_features % groups != 0:
            groups -= 1
        return nn.GroupNorm(groups, num_features)

    normalizations = {
        'batch': lambda: nn.BatchNorm1d(num_features),
        'batch2d': lambda: nn.BatchNorm2d(num_features),
        'layer': lambda: nn.LayerNorm(num_features),
        'instance': lambda: nn.InstanceNorm1d(num_features),
        'instance2d': lambda: nn.InstanceNorm2d(num_features, affine=True),
        'group': build_group_norm,
        'none': lambda: nn.Identity(),
    }

    norm_key = str(norm_name).lower()
    build = normalizations.get(norm_key)
    if build is None:
        raise ValueError(f"Normalization {norm_name} not supported. "
                         f"Choose from: {list(normalizations.keys())}")

    return build()


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Reference: "Root Mean Square Layer Normalization",
    https://arxiv.org/abs/1910.07467

    The input is divided by the root mean square taken over its last
    dimension and then multiplied by a learnable per-feature scale.
    """
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        # Learnable gain, initialised to ones.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Mean of squares over the last dimension; eps is added under the root.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        return self.scale * (x / rms)


class PatchEmbedding(nn.Module):
    """
    Image-to-patch embedding for a Vision Transformer.

    A strided convolution (kernel == stride == patch_size) projects each
    patch to an embedding vector. The convolution lives in a ModuleDict
    under the key 'proj'.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Patch count for a square img_size x img_size image.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.layers = nn.ModuleDict({'proj': projection})

    def forward(self, x):
        projected = self.layers['proj'](x)  # (B, E, H', W')
        # The 4-way unpack also enforces that the projection is batched 4-D.
        batch, embed, grid_h, grid_w = projected.shape
        # Collapse the spatial grid into one patch axis, then move it ahead
        # of the embedding axis: (B, E, H'*W') -> (B, N, E).
        tokens = projected.reshape(batch, embed, grid_h * grid_w)
        return tokens.transpose(1, 2)


###########################################
# Network Monitor
###########################################

class NetworkMonitor:
    """
    Records per-module outputs and output-gradients of a model through hooks.

    After register_hooks(), every forward call of a named submodule appends a
    detached CPU copy of its output to ``activations[name]`` and every
    backward pass appends the first output-gradient to
    ``gradients[f"{name}_grad_output"]``. The root module (empty name) is not
    hooked.

    Wherever a ``layers`` / ``layer_names`` filter is accepted, a module is
    selected when any of the given strings occurs in its recorded name, so
    short names such as 'linear_0' select 'layers.linear_0'.
    """

    def __init__(self, model):
        self.model = model
        self.activations = defaultdict(list)
        self.gradients = defaultdict(list)
        self.fwd_hooks = []
        self.bwd_hooks = []

    @staticmethod
    def _first_tensor(value):
        """Return value if it is a tensor, else the first tensor in a list/tuple, else None."""
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, (list, tuple)):
            for item in value:
                if isinstance(item, torch.Tensor):
                    return item
        return None

    @staticmethod
    def _matches_any(name, layers):
        """True when any of the requested layer strings occurs in name."""
        return any(layer in name for layer in layers)

    @staticmethod
    def _print_magnitude_bars(title, magnitudes):
        """Print title followed by one text bar per entry, scaled to the largest value."""
        print(title)
        if not magnitudes:
            # Nothing recorded (or nothing matched the filter): report it
            # instead of failing on max() of an empty sequence.
            print("  (no data recorded)")
            return

        max_name_len = max(len(name) for name in magnitudes)
        max_mag = max(magnitudes.values())

        for name, mag in sorted(magnitudes.items()):
            # All-zero magnitudes would otherwise divide by zero.
            bar_len = int((mag / max_mag) * 40) if max_mag > 0 else 0
            print(f"{name.ljust(max_name_len)} | {'█' * bar_len} {mag:.6f}")

    def _make_forward_hook(self, name):
        """Create the forward hook that stores outputs under name."""
        def hook(module, inputs, output):
            # Modules may return tuples; keep the first tensor and skip
            # outputs that contain no tensor at all.
            tensor = self._first_tensor(output)
            if tensor is not None:
                self.activations[name].append(tensor.clone().detach().cpu())
        return hook

    def _make_backward_hook(self, name):
        """Create the backward hook that stores the first output-gradient under name."""
        def hook(module, grad_input, grad_output):
            if len(grad_output) > 0 and grad_output[0] is not None:
                self.gradients[f"{name}_grad_output"].append(
                    grad_output[0].clone().detach().cpu())
            # Returning nothing leaves grad_input untouched.
        return hook

    def register_hooks(self):
        """Attach forward and full-backward hooks to every named submodule."""
        # Drop hooks from an earlier call so nothing is recorded twice.
        self.remove_hooks()
        for name, module in self.model.named_modules():
            if name == '':  # Skip the root module
                continue
            self.fwd_hooks.append(
                module.register_forward_hook(self._make_forward_hook(name)))
            self.bwd_hooks.append(
                module.register_full_backward_hook(self._make_backward_hook(name)))

    def remove_hooks(self):
        """Detach every hook registered by this monitor."""
        for h in self.fwd_hooks + self.bwd_hooks:
            h.remove()
        self.fwd_hooks = []
        self.bwd_hooks = []

    def clear_data(self):
        """Forget all recorded activations and gradients."""
        self.activations = defaultdict(list)
        self.gradients = defaultdict(list)

    def get_activations(self):
        """Get the recorded activations (a bare tensor when a module fired once, else a list)."""
        return {k: v[0] if len(v) == 1 else v for k, v in self.activations.items()}

    def get_gradients(self):
        """Get the recorded gradients (a bare tensor when recorded once, else a list)."""
        return {k: v[0] if len(v) == 1 else v for k, v in self.gradients.items()}

    def show_data_structure(self):
        """Print shape and mean of every recorded tensor."""
        print("\nACTIVATIONS DATA STRUCTURE:")
        for name, acts in self.activations.items():
            print(f"{name}: List with {len(acts)} tensors")
            for i, act in enumerate(acts):
                print(f"  - Item {i}: shape={act.shape}, mean={act.mean().item():.4f}")

        print("\nGRADIENTS DATA STRUCTURE:")
        for name, grads in self.gradients.items():
            print(f"{name}: List with {len(grads)} tensors")
            for i, g in enumerate(grads):
                print(f"  - Item {i}: shape={g.shape}, mean={g.mean().item():.4f}")

    def print_activation_stats(self, layers=None):
        """Print statistics about activations, optionally restricted to matching layers."""
        activations = self.get_activations()
        print("\nActivation Statistics:")

        # Same substring rule as print_gradient_stats, so one `layers`
        # argument selects the same modules in both reports.
        if layers:
            filtered_activations = {k: v for k, v in activations.items()
                                    if self._matches_any(k, layers)}
        else:
            filtered_activations = activations

        for name, act in filtered_activations.items():
            if isinstance(act, torch.Tensor):
                print(f"{name}: shape={act.shape}, mean={act.mean().item():.6f}, "
                      f"std={act.std().item():.6f}, min={act.min().item():.6f}, "
                      f"max={act.max().item():.6f}")
            else:
                print(f"{name}: {type(act)}")

    def print_gradient_stats(self, layers=None, grad_type="output"):
        """Print statistics about gradients, optionally restricted to matching layers."""
        gradients = self.get_gradients()
        print(f"\nGradient Statistics ({grad_type}):")

        # Filter by layers and gradient type
        keys = [k for k in gradients if f"_grad_{grad_type}" in k]
        if layers:
            keys = [k for k in keys if self._matches_any(k, layers)]

        for key in keys:
            grads = gradients[key]
            if not isinstance(grads, list):
                grads = [grads]

            for i, g in enumerate(grads):
                if g is not None:
                    print(f"{key}[{i}]: shape={g.shape}, mean={g.mean().item():.6f}, "
                          f"std={g.std().item():.6f}, max={g.abs().max().item():.6f}")
                else:
                    print(f"{key}[{i}]: None")

    def visualize_activation_flow(self, layer_names=None):
        """Print a text bar chart of mean absolute activation per module."""
        magnitudes = {}
        for name, act in self.get_activations().items():
            if layer_names:
                if not self._matches_any(name, layer_names):
                    continue
                # A module that fired several times is stored as a list;
                # use its first recording rather than failing.
                tensor = self._first_tensor(act)
            else:
                # Without a filter, only single-tensor entries are shown
                # and the 'input' entry is excluded.
                tensor = act if isinstance(act, torch.Tensor) and name != 'input' else None
            if tensor is not None:
                magnitudes[name] = float(tensor.abs().mean().item())

        self._print_magnitude_bars("\nActivation Magnitude Flow:", magnitudes)

    def visualize_gradient_flow(self, filter_type="output"):
        """Print a text bar chart of mean absolute gradient per module."""
        marker = f"_grad_{filter_type}"
        magnitudes = {}
        for key, grads in self.get_gradients().items():
            if marker not in key:
                continue
            # Use the first recorded gradient for this module.
            tensor = self._first_tensor(grads)
            if tensor is not None:
                # The module name is the part of the key before '_grad_'.
                magnitudes[key.split('_grad_')[0]] = float(tensor.abs().mean().item())

        self._print_magnitude_bars(f"\nGradient Magnitude Flow ({filter_type}):", magnitudes)

###########################################
# Usage Example
###########################################

def test_model_with_monitor(model_name='mlp'):
    """
    Create a model, run one forward/backward pass on random data and print
    the activation and gradient statistics collected by NetworkMonitor.

    Parameters:
        model_name (str): One of 'mlp', 'cnn', 'resnet', 'vit'

    Raises:
        ValueError: If model_name is not one of the supported names.
    """
    # Create model based on model_name
    if model_name == 'mlp':
        model = MLP(
            input_size=784,  # MNIST flattened size
            hidden_sizes=[512, 256],
            output_size=10,
            activation='relu',
            dropout_p=0.2,
            normalization='batch'
        )
        input_shape = (16, 1, 28, 28)  # batch_size, channels, height, width

    elif model_name == 'cnn':
        model = CNN(
            in_channels=3,
            conv_channels=[64, 128, 256],
            kernel_sizes=[3, 3, 3],
            strides=[1, 1, 1],
            paddings=[1, 1, 1],
            fc_hidden_units=[512],
            num_classes=10,
            activation='relu',
            dropout_p=0.2
        )
        input_shape = (16, 3, 32, 32)  # batch_size, channels, height, width

    elif model_name == 'resnet':
        model = ResNet(
            layers=[2, 2, 2, 2],  # ResNet18
            num_classes=10,
            in_channels=3,
            activation='relu',
            dropout_p=0.2
        )
        input_shape = (16, 3, 32, 32)  # batch_size, channels, height, width

    elif model_name == 'vit':
        model = VisionTransformer(
            img_size=32,
            patch_size=4,
            in_channels=3,
            num_classes=10,
            embed_dim=192,
            depth=6,
            n_heads=8
        )
        input_shape = (16, 3, 32, 32)  # batch_size, channels, height, width

    else:
        raise ValueError(f"Unknown model name: {model_name}")

    # Print model structure
    print(f"\n=== {model_name.upper()} Model Structure ===")
    for name, module in model.named_modules():
        if len(name) > 0:
            print(f"{name}: {module.__class__.__name__}")

    # Short names of a few representative layers per model type
    if model_name == 'mlp':
        short_names = [f'linear_{i}' for i in range(len(model.hidden_sizes))]
    elif model_name == 'cnn':
        short_names = [f'conv_{i}' for i in range(3)]
    elif model_name == 'resnet':
        short_names = ['layer1_block0', 'layer2_block0', 'layer3_block0', 'layer4_block0']
    else:  # 'vit'
        short_names = ['patch_embed', 'block_0', 'block_2', 'block_5']

    # The monitor keys its data by the dotted names from named_modules()
    # (e.g. 'layers.linear_0'), so resolve each short name to the full name
    # of the module whose last path component matches it.
    wanted = set(short_names)
    important_layers = [
        name for name, _ in model.named_modules()
        if name and (name in wanted or name.rsplit('.', 1)[-1] in wanted)
    ]

    # Create NetworkMonitor and register hooks
    monitor = NetworkMonitor(model)
    monitor.register_hooks()

    try:
        # Generate dummy data
        x = torch.randn(*input_shape)
        target = torch.randint(0, 10, (input_shape[0],))

        # Define loss function
        criterion = nn.CrossEntropyLoss()

        # Forward pass
        output = model(x)
        loss = criterion(output, target)

        # Backward pass
        loss.backward()

        # Print statistics for selected layers
        print(f"\n=== {model_name.upper()} Activation Statistics ===")
        monitor.print_activation_stats(layers=important_layers)
        monitor.print_gradient_stats(layers=important_layers)

        # Visualize activation and gradient flow
        monitor.visualize_activation_flow()
        monitor.visualize_gradient_flow()
    finally:
        # Always detach the hooks, even if the forward/backward pass fails
        monitor.remove_hooks()


###########################################
#  MLP
###########################################

class MLP(nn.Module):
    def __init__(self,
                 input_size=784,
                 hidden_sizes=None,
                 output_size=10,
                 activation='relu',
                 dropout_p=0.0,
                 normalization=None,
                 norm_after_activation=False,
                 bias=True):
        """
        Fully connected MLP that supports various activations and normalizations.

        All layers are stored in ``self.layers`` (an nn.ModuleDict) in the
        order they are applied, under the keys 'linear_{i}', 'norm_{i}',
        'act_{i}', 'drop_{i}' and 'out'.

        Parameters:
            input_size (int): Dimensionality of input features
            hidden_sizes (list or None): Hidden layer dimensions; None selects
                the default [512, 256, 128]
            output_size (int): Number of output classes
            activation (str): Activation function to use ('relu', 'tanh', 'sigmoid', etc.)
            dropout_p (float): Dropout probability (0 to disable)
            normalization (str): Normalization to use ('batch', 'layer', None)
            norm_after_activation (bool): If True, apply normalization after activation
            bias (bool): Whether to include bias terms in linear layers
        """
        super().__init__()

        # Copy the sizes so the instance never aliases a list owned by the
        # caller (or a shared default).
        hidden_sizes = [512, 256, 128] if hidden_sizes is None else list(hidden_sizes)

        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.norm_after_activation = norm_after_activation

        # Build network using ModuleDict; insertion order is execution order
        self.layers = nn.ModuleDict()
        in_features = input_size

        for i, hidden_size in enumerate(hidden_sizes):
            # Linear layer
            self.layers[f'linear_{i}'] = nn.Linear(in_features, hidden_size, bias=bias)

            if norm_after_activation:
                # Activation first, then normalization
                self.layers[f'act_{i}'] = get_activation(activation)
                if normalization:
                    self.layers[f'norm_{i}'] = get_normalization(normalization, hidden_size)
            else:
                # Normalization first, then activation
                if normalization:
                    self.layers[f'norm_{i}'] = get_normalization(normalization, hidden_size)
                self.layers[f'act_{i}'] = get_activation(activation)

            # Dropout
            if dropout_p > 0:
                self.layers[f'drop_{i}'] = nn.Dropout(dropout_p)

            in_features = hidden_size

        # Output layer
        self.layers['out'] = nn.Linear(in_features, output_size, bias=bias)

    def forward(self, x):
        """
        Forward pass without activation storage.

        Parameters:
            x (torch.Tensor): Input data; anything with more than two
                dimensions is flattened to [batch_size, -1] first

        Returns:
            torch.Tensor: Output logits
        """
        # Flatten input if needed (flatten also handles non-contiguous inputs)
        if x.dim() > 2:
            x = x.flatten(1)

        # Apply every layer in registration order
        for layer in self.layers.values():
            x = layer(x)

        return x


###########################################
#  CNN
###########################################

class CNN(nn.Module):
    def __init__(self,
                 in_channels=3,
                 conv_channels=[64, 128, 256],
                 kernel_sizes=[3, 3, 3],
                 strides=[1, 1, 1],
                 paddings=[1, 1, 1],
                 fc_hidden_units=[512],
                 num_classes=10,
                 input_size=32,
                 activation='relu',
                 dropout_p=0.0,
                 pool_type='max',
                 pool_size=2,
                 use_batchnorm=True,
                 norm_after_activation=False):
        """
        Configurable CNN: conv blocks (conv, optional batch norm, activation,
        optional pooling) followed by fully connected layers.

        All layers are stored in ``self.layers`` (an nn.ModuleDict) under the
        keys 'conv_{i}', 'norm_{i}', 'act_{i}', 'pool_{i}', 'flatten',
        'fc_{i}', 'fc_act_{i}', 'fc_dropout_{i}' and 'output'.

        Parameters:
            in_channels (int): Number of input channels (3 for RGB images)
            conv_channels (list): List of convolutional layer output channels
            kernel_sizes (list): List of kernel sizes for each conv layer
            strides (list): List of stride values for each conv layer
            paddings (list): List of padding values for each conv layer
            fc_hidden_units (list): List of hidden units for fully connected layers
            num_classes (int): Number of output classes
            input_size (int): Height/width of the input images (assumed square)
            activation (str): Activation function to use ('relu', 'tanh', etc.)
            dropout_p (float): Dropout probability (0 to disable)
            pool_type (str): Type of pooling ('max', 'avg', or None)
            pool_size (int): Size of the pooling window
            use_batchnorm (bool): Whether to use batch normalization
            norm_after_activation (bool): Apply normalization after activation
        """
        super().__init__()

        # Check if input lists are of the same length
        assert len(conv_channels) == len(kernel_sizes) == len(strides) == len(paddings), \
            "Convolutional parameters (channels, kernels, strides, paddings) must have the same length"

        self.norm_after_activation = norm_after_activation
        has_pool = pool_type in ['max', 'avg']

        # Create a ModuleDict to store all layers
        self.layers = nn.ModuleDict()

        # Convolutional layers. The spatial size is tracked through every conv
        # and pool so the first FC layer gets the right input width even when
        # a conv changes the resolution (stride > 1, padding != kernel // 2).
        channels = in_channels
        height = width = input_size
        for i, (out_channels, kernel_size, stride, padding) in enumerate(
                zip(conv_channels, kernel_sizes, strides, paddings)):
            # Conv layer
            self.layers[f'conv_{i}'] = nn.Conv2d(channels, out_channels, kernel_size, stride, padding)
            height, width = self._conv_output_hw((height, width), kernel_size, stride, padding)

            # Normalization
            if use_batchnorm:
                self.layers[f'norm_{i}'] = nn.BatchNorm2d(out_channels)

            # Activation
            self.layers[f'act_{i}'] = get_activation(activation)

            # Pooling (kernel == stride == pool_size, so size is floor-divided)
            if pool_type == 'max':
                self.layers[f'pool_{i}'] = nn.MaxPool2d(pool_size, pool_size)
            elif pool_type == 'avg':
                self.layers[f'pool_{i}'] = nn.AvgPool2d(pool_size, pool_size)
            if has_pool:
                pool_h, pool_w = self._pair(pool_size)
                height, width = height // pool_h, width // pool_w

            channels = out_channels

        # Number of features entering the first fully connected layer
        self.flattened_size = channels * height * width

        # Flatten layer
        self.layers['flatten'] = nn.Flatten()

        # Fully connected layers
        fc_input_size = self.flattened_size
        for i, hidden_units in enumerate(fc_hidden_units):
            self.layers[f'fc_{i}'] = nn.Linear(fc_input_size, hidden_units)
            self.layers[f'fc_act_{i}'] = get_activation(activation)

            if dropout_p > 0:
                self.layers[f'fc_dropout_{i}'] = nn.Dropout(dropout_p)

            fc_input_size = hidden_units

        # Output layer
        self.layers['output'] = nn.Linear(fc_input_size, num_classes)

        # Store configuration
        self.num_conv_layers = len(conv_channels)
        self.num_fc_layers = len(fc_hidden_units)
        self.use_batchnorm = use_batchnorm
        self.has_pool = has_pool
        self.dropout_p = dropout_p

    @staticmethod
    def _pair(value):
        """Expand an int to (value, value); pass 2-element sequences through as a tuple."""
        return (value, value) if isinstance(value, int) else tuple(value)

    @staticmethod
    def _conv_output_hw(hw, kernel_size, stride, padding):
        """Spatial (height, width) produced by a dilation-1 Conv2d for input size hw."""
        if padding == 'same':
            # 'same' padding keeps the spatial size unchanged
            return tuple(hw)
        if padding == 'valid':
            padding = 0
        kernel = CNN._pair(kernel_size)
        step = CNN._pair(stride)
        pad = CNN._pair(padding)
        return tuple(
            max((size + 2 * p - k) // s + 1, 0)
            for size, k, s, p in zip(hw, kernel, step, pad)
        )

    def forward(self, x):
        """
        Forward pass without activation storage.

        Parameters:
            x (torch.Tensor): Input data [batch_size, in_channels, height, width]

        Returns:
            torch.Tensor: Output logits
        """
        layers = self.layers

        # Process conv layers
        for i in range(self.num_conv_layers):
            x = layers[f'conv_{i}'](x)

            apply_norm = self.use_batchnorm and f'norm_{i}' in layers

            # Normalization (before activation)
            if apply_norm and not self.norm_after_activation:
                x = layers[f'norm_{i}'](x)

            # Activation
            x = layers[f'act_{i}'](x)

            # Normalization (after activation)
            if apply_norm and self.norm_after_activation:
                x = layers[f'norm_{i}'](x)

            # Pooling
            if self.has_pool and f'pool_{i}' in layers:
                x = layers[f'pool_{i}'](x)

        # Flatten
        x = layers['flatten'](x)

        # FC layers
        for i in range(self.num_fc_layers):
            x = layers[f'fc_{i}'](x)
            x = layers[f'fc_act_{i}'](x)

            if self.dropout_p > 0 and f'fc_dropout_{i}' in layers:
                x = layers[f'fc_dropout_{i}'](x)

        # Output layer
        return layers['output'](x)


###########################################
# ResNet
###########################################

class BasicBlock(nn.Module):
    """
    Basic two-convolution ResNet block with configurable activation and
    optional batch normalization.

    Sub-modules live in ``self.layers`` under 'conv1', 'bn1', 'activation',
    'conv2', 'bn2' and 'downsample' (the 'bn*' and 'downsample' entries exist
    only when enabled / provided). One activation module is used for both
    activation steps.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, activation='relu',
                 use_batchnorm=True, norm_after_activation=False, downsample=None):
        super().__init__()

        self.norm_after_activation = norm_after_activation

        # Convolutions only carry a bias when no batch norm is used.
        conv_bias = not use_batchnorm

        self.layers = nn.ModuleDict()

        # First 3x3 convolution (carries the block's stride) and its norm
        self.layers['conv1'] = nn.Conv2d(in_planes, planes, kernel_size=3,
                                         stride=stride, padding=1, bias=conv_bias)
        if use_batchnorm:
            self.layers['bn1'] = nn.BatchNorm2d(planes)

        # Shared activation
        self.layers['activation'] = get_activation(activation)

        # Second 3x3 convolution (always stride 1) and its norm
        self.layers['conv2'] = nn.Conv2d(planes, planes, kernel_size=3,
                                         stride=1, padding=1, bias=conv_bias)
        if use_batchnorm:
            self.layers['bn2'] = nn.BatchNorm2d(planes)

        # Optional projection applied to the shortcut branch
        if downsample is not None:
            self.layers['downsample'] = downsample

    def forward(self, x):
        layers = self.layers
        activate = layers['activation']
        norm_last = self.norm_after_activation

        # conv1 -> (norm, act) or (act, norm) depending on norm_after_activation
        out = layers['conv1'](x)
        if 'bn1' not in layers:
            out = activate(out)
        elif norm_last:
            out = layers['bn1'](activate(out))
        else:
            out = activate(layers['bn1'](out))

        # conv2, normalized here only in the norm-before-activation layout
        out = layers['conv2'](out)
        if not norm_last and 'bn2' in layers:
            out = layers['bn2'](out)

        # Shortcut branch: projected when a downsample module is present
        shortcut = layers['downsample'](x) if 'downsample' in layers else x

        # Out-of-place addition (no +=), then the final activation
        out = activate(out + shortcut)

        # In the norm-after-activation layout bn2 follows the final activation
        if norm_last and 'bn2' in layers:
            out = layers['bn2'](out)

        return out


class ResNet(nn.Module):
    """
    Configurable ResNet architecture for continual learning experiments.

    All sub-modules live in ``self.layers`` (an nn.ModuleDict). Residual
    blocks are stored under 'layer{stage}_block{index}', with stages
    numbered from 1.
    """
    def __init__(self,
                 block=BasicBlock,
                 layers=None,  # ResNet18 ([2, 2, 2, 2]) by default
                 num_classes=10,
                 in_channels=3,
                 base_channels=64,
                 activation='relu',
                 dropout_p=0.0,
                 use_batchnorm=True,
                 norm_after_activation=False):
        """
        Initialize the ResNet.

        Parameters:
            block (nn.Module): The block type to use (BasicBlock)
            layers (list or None): Number of blocks in each stage; None selects
                [2, 2, 2, 2]. Any number of stages is accepted: stage k
                (0-based) uses base_channels * 2**k channels, and every stage
                after the first downsamples with stride 2.
            num_classes (int): Number of output classes
            in_channels (int): Number of input channels (3 for RGB images)
            base_channels (int): Base number of channels (first layer)
            activation (str): Activation function to use
            dropout_p (float): Dropout probability before final layer
            use_batchnorm (bool): Whether to use batch normalization
            norm_after_activation (bool): Apply normalization after activation
        """
        super().__init__()

        # Copy so the instance never aliases a list owned by the caller.
        layers = [2, 2, 2, 2] if layers is None else list(layers)

        self.use_batchnorm = use_batchnorm
        self.norm_after_activation = norm_after_activation
        self.in_planes = base_channels

        # Use ModuleDict for all layers
        self.layers = nn.ModuleDict()

        # Initial convolutional layer
        self.layers['conv1'] = nn.Conv2d(in_channels, base_channels, kernel_size=3,
                                         stride=1, padding=1, bias=not use_batchnorm)

        # Batch norm after first conv
        if use_batchnorm:
            self.layers['bn1'] = nn.BatchNorm2d(base_channels)

        # Activation
        self.layers['activation'] = get_activation(activation)

        # Residual stages: channel width doubles per stage; the first stage
        # keeps the resolution, later stages halve it.
        for stage_idx, num_blocks in enumerate(layers):
            self._make_layer(block, base_channels * (2 ** stage_idx), num_blocks,
                             stride=1 if stage_idx == 0 else 2,
                             activation=activation, use_batchnorm=use_batchnorm,
                             norm_after_activation=norm_after_activation,
                             layer_name=f'layer{stage_idx + 1}')

        # Global average pooling
        self.layers['avgpool'] = nn.AdaptiveAvgPool2d((1, 1))

        # Flatten operation
        self.layers['flatten'] = nn.Flatten()

        # Dropout if needed
        if dropout_p > 0:
            self.layers['dropout'] = nn.Dropout(dropout_p)

        # Final classifier; self.in_planes is the channel count set by the
        # last stage (base_channels * 8 * block.expansion for four stages).
        self.layers['fc'] = nn.Linear(self.in_planes, num_classes)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Store layer information
        self.num_layers = len(layers)
        self.blocks_per_layer = layers

    def _make_layer(self, block, planes, num_blocks, stride=1, activation='relu',
                    use_batchnorm=True, norm_after_activation=False, layer_name='layer'):
        """
        Register the blocks of one stage under '{layer_name}_block{i}' and
        update self.in_planes to the stage's output channel count.
        """
        out_planes = planes * block.expansion

        # The shortcut needs a 1x1 projection whenever the block changes the
        # resolution or the channel count.
        downsample = None
        if stride != 1 or self.in_planes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=not use_batchnorm)
            )
            if use_batchnorm:
                downsample.add_module('1', nn.BatchNorm2d(out_planes))

        # Add first block with stride and downsample
        self.layers[f'{layer_name}_block0'] = block(
            self.in_planes, planes, stride, activation,
            use_batchnorm, norm_after_activation, downsample
        )

        self.in_planes = out_planes

        # Add remaining blocks (stride 1, no projection)
        for i in range(1, num_blocks):
            self.layers[f'{layer_name}_block{i}'] = block(
                self.in_planes, planes, 1, activation,
                use_batchnorm, norm_after_activation
            )

    def forward(self, x):
        """
        Forward pass without activation storage.

        Parameters:
            x (torch.Tensor): Input data

        Returns:
            torch.Tensor: Output logits
        """
        layers = self.layers
        apply_norm = self.use_batchnorm and 'bn1' in layers

        # Initial conv
        x = layers['conv1'](x)

        # Apply normalization (before activation)
        if apply_norm and not self.norm_after_activation:
            x = layers['bn1'](x)

        # Activation
        x = layers['activation'](x)

        # Apply normalization (after activation)
        if apply_norm and self.norm_after_activation:
            x = layers['bn1'](x)

        # ResNet blocks, stage by stage
        for layer_idx in range(1, self.num_layers + 1):
            for block_idx in range(self.blocks_per_layer[layer_idx - 1]):
                x = layers[f'layer{layer_idx}_block{block_idx}'](x)

        # Global average pooling
        x = layers['avgpool'](x)

        # Flatten
        x = layers['flatten'](x)

        # Dropout if specified
        if 'dropout' in layers:
            x = layers['dropout'](x)

        # Final classifier
        return layers['fc'](x)


###########################################
# ViT Components
###########################################

class Attention(nn.Module):
    """
    Multi-head self-attention over a sequence of token embeddings.

    Sub-modules are registered in ``self.layers`` (an ``nn.ModuleDict``)
    under the keys ``qkv``, ``attn_drop``, ``proj`` and ``proj_drop``.

    Parameters:
        dim (int): Embedding size of each token; must be divisible by n_heads
        n_heads (int): Number of attention heads
        qkv_bias (bool): Whether the fused QKV projection has a bias term
        attn_drop (float): Dropout probability applied to the attention weights
        proj_drop (float): Dropout probability applied after the output projection
    """
    def __init__(self, dim, n_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % n_heads == 0
        self.n_heads = n_heads
        # Attention scores are scaled by 1/sqrt(head_dim)
        self.scale = (dim // n_heads) ** -0.5

        # Registered one by one; insertion order matches the order of use
        self.layers = nn.ModuleDict()
        self.layers['qkv'] = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.layers['attn_drop'] = nn.Dropout(attn_drop)
        self.layers['proj'] = nn.Linear(dim, dim)
        self.layers['proj_drop'] = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Apply self-attention to ``x`` of shape (B, N, C) and return a tensor
        of the same shape.
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.n_heads

        # One projection yields q, k and v; split into (3, B, heads, N, head_dim)
        fused = self.layers['qkv'](x)
        fused = fused.reshape(batch, tokens, 3, self.n_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention weights, normalized over the key axis
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.layers['attn_drop'](scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        projected = self.layers['proj'](merged)
        return self.layers['proj_drop'](projected)


class TransformerMLP(nn.Module):
    """
    Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    Sub-modules are registered in ``self.layers`` under the keys ``fc1``,
    ``act``, ``drop1``, ``fc2`` and ``drop2``. When ``drop`` is not positive
    the two dropout slots hold ``nn.Identity`` instead of ``nn.Dropout``.

    Parameters:
        in_features (int): Size of each input feature vector
        hidden_features (int): Width of the hidden layer
        out_features (int): Size of each output feature vector
        activation (str): Activation name passed to get_activation
        drop (float): Dropout probability used after each linear layer
    """
    def __init__(self, in_features, hidden_features, out_features, 
                 activation='gelu', drop=0.):
        super().__init__()

        def make_dropout():
            # A fresh module per slot so each one is registered separately
            return nn.Dropout(drop) if drop > 0 else nn.Identity()

        self.layers = nn.ModuleDict()
        self.layers['fc1'] = nn.Linear(in_features, hidden_features)
        self.layers['act'] = get_activation(activation)
        self.layers['drop1'] = make_dropout()
        self.layers['fc2'] = nn.Linear(hidden_features, out_features)
        self.layers['drop2'] = make_dropout()

    def forward(self, x):
        """Apply the five registered stages to ``x`` in order."""
        for key in ('fc1', 'act', 'drop1', 'fc2', 'drop2'):
            x = self.layers[key](x)
        return x


class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: ``x + attn(norm1(x))`` followed by
    ``x + mlp(norm2(x))``.

    Sub-modules are registered in ``self.layers`` under the keys ``norm1``,
    ``attn``, ``norm2`` and ``mlp``.

    Parameters:
        dim (int): Token embedding size
        n_heads (int): Number of attention heads
        mlp_ratio (float): Hidden width of the MLP as a multiple of dim
        qkv_bias (bool): Whether the QKV projection has a bias term
        drop (float): Dropout probability for the attention output projection
            and the MLP
        attn_drop (float): Dropout probability for the attention weights
        activation (str): Activation name used inside the MLP
        normalization (str or None): Normalization name passed to
            get_normalization; when that returns None (e.g. for None) the
            norm slots become ``nn.Identity``
    """
    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0., 
                 attn_drop=0., activation='gelu', normalization='layer'):
        super().__init__()
        
        # Use ModuleDict
        self.layers = nn.ModuleDict({
            'norm1': self._make_norm(normalization, dim),
            'attn': Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, 
                             attn_drop=attn_drop, proj_drop=drop),
            'norm2': self._make_norm(normalization, dim),
            'mlp': TransformerMLP(dim, int(dim * mlp_ratio), dim, 
                       activation=activation, drop=drop)
        })

    @staticmethod
    def _make_norm(normalization, dim):
        """Build a normalization layer, substituting identity when none is configured."""
        norm = get_normalization(normalization, dim)
        # get_normalization may return None; a None entry in the ModuleDict
        # would only fail later, in forward(), with "'NoneType' object is not
        # callable". An identity module keeps the block usable without a norm.
        return nn.Identity() if norm is None else norm

    def forward(self, x):
        """Apply the attention and MLP residual branches to ``x``."""
        # Self-attention branch
        norm_x = self.layers['norm1'](x)
        attn_out = self.layers['attn'](norm_x)
        x = x + attn_out
        
        # MLP branch
        norm_x = self.layers['norm2'](x)
        mlp_out = self.layers['mlp'](norm_x)
        x = x + mlp_out
            
        return x


class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) classifier whose components are registered in a
    single ``nn.ModuleDict`` (``self.layers``).

    Registered keys: ``patch_embed``, ``pos_drop``, ``block_0`` ...
    ``block_{depth-1}``, ``norm`` and ``head``. The class token and the
    position embedding are plain ``nn.Parameter`` attributes.
    """
    def __init__(self, 
                 img_size=32, 
                 patch_size=4, 
                 in_channels=3, 
                 num_classes=10, 
                 embed_dim=192,
                 depth=12, 
                 n_heads=8, 
                 mlp_ratio=4., 
                 qkv_bias=True, 
                 drop_rate=0.1,
                 attn_drop_rate=0.0,
                 activation='gelu',
                 normalization='layer'):
        """
        Initialize Vision Transformer.
        
        Parameters:
            img_size (int): Input image size
            patch_size (int): Patch size for splitting image
            in_channels (int): Number of image channels
            num_classes (int): Number of output classes
            embed_dim (int): Embedding dimension
            depth (int): Number of transformer blocks
            n_heads (int): Number of attention heads
            mlp_ratio (float): Ratio for MLP hidden dimension
            qkv_bias (bool): Whether to use bias in QKV projection
            drop_rate (float): Dropout rate
            attn_drop_rate (float): Attention dropout rate
            activation (str): Activation function to use
            normalization (str): Normalization method to use; if
                get_normalization returns None for it, the final ``norm``
                slot is filled with ``nn.Identity``
        """
        super().__init__()
        
        # Use ModuleDict for all components
        self.layers = nn.ModuleDict()
        
        # Patch embedding
        self.layers['patch_embed'] = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.layers['patch_embed'].n_patches

        # Class token and position embeddings (these aren't in ModuleDict as they're parameters).
        # The position embedding has one extra slot for the class token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        
        # Position dropout
        self.layers['pos_drop'] = nn.Dropout(drop_rate)

        # Transformer blocks
        for i in range(depth):
            self.layers[f'block_{i}'] = TransformerBlock(
                embed_dim, n_heads, mlp_ratio, qkv_bias, 
                drop_rate, attn_drop_rate, activation, normalization
            )

        # Final normalization. get_normalization may return None; storing None
        # here would make forward() fail with "'NoneType' object is not
        # callable", so fall back to an identity module.
        final_norm = get_normalization(normalization, embed_dim)
        self.layers['norm'] = nn.Identity() if final_norm is None else final_norm
        
        # Classifier head
        self.layers['head'] = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()
        
        # Save configuration for forward pass
        self.depth = depth

    def _init_weights(self):
        """
        Initialize the embeddings, every ``nn.Linear`` (truncated normal
        weights with std 0.02, zero bias) and every ``nn.LayerNorm`` (unit
        weight, zero bias).
        """
        # Initialize position embedding and class token
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        # Initialize other weights
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Forward pass without activation storage.
        
        Parameters:
            x (torch.Tensor): Input images [batch_size, channels, height, width]
            
        Returns:
            torch.Tensor: Output logits
        """
        # Patch embedding
        x = self.layers['patch_embed'](x)
        
        # Prepend the class token to every sequence in the batch
        B = x.shape[0]
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        
        # Add position embeddings (broadcast over the batch dimension)
        x = x + self.pos_embed
        x = self.layers['pos_drop'](x)

        # Apply transformer blocks
        for i in range(self.depth):
            x = self.layers[f'block_{i}'](x)
        
        # Final normalization
        x = self.layers['norm'](x)
        
        # Extract class token and classify
        x = x[:, 0]  # Use only the cls token for classification
        x = self.layers['head'](x)
            
        return x



def example_test():
    """
    Smoke-test NetworkMonitor on SimpleModel with two random batches.

    Registers the monitor's hooks, runs a forward and backward pass for each
    batch (sizes 2 and 3), prints the monitor's data structure after each
    one, and removes the hooks at the end.
    """
    # NOTE(review): the NetworkMonitor class defined later in this file has no
    # show_data_structure method — confirm which NetworkMonitor is in scope
    # when this function runs.
    model = SimpleModel()
    monitor = NetworkMonitor(model)
    monitor.register_hooks()

    criterion = nn.CrossEntropyLoss()

    # (banner, batch size, class labels, message printed after the pass)
    batches = (
        ("=== BATCH 1 ===", 2, [0, 3],
         "\nAfter batch 1, let's examine the data structure:"),
        ("\n=== BATCH 2 ===", 3, [2, 4, 5],
         "\nAfter batch 2, let's examine the data structure:"),
    )

    for banner, batch_size, labels, summary in batches:
        print(banner)
        inputs = torch.randn(batch_size, 3, 8, 8)
        targets = torch.tensor(labels)

        # Forward and backward pass; the hooks record during both
        loss = criterion(model(inputs), targets)
        loss.backward()

        print(summary)
        monitor.show_data_structure()

    # Clean up
    monitor.remove_hooks()

"""
Example of how to use the adapted models with NetworkMonitor.
"""

def test_model_with_monitor(model_name='mlp'):
    """
    Build one of the models, run a single forward/backward pass on random
    data with a NetworkMonitor attached, and print/plot what was recorded.

    Parameters:
        model_name (str): One of 'mlp', 'cnn', 'resnet', 'vit'

    Raises:
        ValueError: If model_name is not one of the supported names
    """
    # Pick the model class, its constructor arguments and a matching input
    # shape (batch_size, channels, height, width)
    if model_name == 'mlp':
        model_cls = MLP
        config = dict(
            input_size=784,  # MNIST flattened size
            hidden_sizes=[512, 256],
            output_size=10,
            activation='relu',
            dropout_p=0.2,
            normalization='batch',
        )
        input_shape = (16, 1, 28, 28)
    elif model_name == 'cnn':
        model_cls = CNN
        config = dict(
            in_channels=3,
            conv_channels=[64, 128, 256],
            kernel_sizes=[3, 3, 3],
            strides=[1, 1, 1],
            paddings=[1, 1, 1],
            fc_hidden_units=[512],
            num_classes=10,
            activation='relu',
            dropout_p=0.2,
        )
        input_shape = (16, 3, 32, 32)
    elif model_name == 'resnet':
        model_cls = ResNet
        config = dict(
            layers=[2, 2, 2, 2],  # ResNet18
            num_classes=10,
            in_channels=3,
            activation='relu',
            dropout_p=0.2,
        )
        input_shape = (16, 3, 32, 32)
    elif model_name == 'vit':
        model_cls = VisionTransformer
        config = dict(
            img_size=32,
            patch_size=4,
            in_channels=3,
            num_classes=10,
            embed_dim=192,
            depth=6,
            n_heads=8,
        )
        input_shape = (16, 3, 32, 32)
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    model = model_cls(**config)
    title = model_name.upper()

    # List every sub-module (the root module has an empty name and is skipped)
    print(f"\n=== {title} Model Structure ===")
    for name, module in model.named_modules():
        if name:
            print(f"{name}: {module.__class__.__name__}")

    # Attach the monitor before running any data through the model
    monitor = NetworkMonitor(model)
    monitor.register_hooks()

    # Random inputs and random class targets in [0, 10)
    batch_size = input_shape[0]
    x = torch.randn(*input_shape)
    target = torch.randint(0, 10, (batch_size,))

    # One forward and backward pass so both hook kinds fire
    criterion = nn.CrossEntropyLoss()
    loss = criterion(model(x), target)
    loss.backward()

    print(f"\n=== {title} Activation Statistics ===")
    # A few representative layers per model type
    important_layers = []
    if model_name == 'mlp':
        important_layers = [f'linear_{i}' for i in range(len(model.layers) // 3)]
    elif model_name == 'cnn':
        important_layers = [f'conv_{i}' for i in range(3)]
    elif model_name == 'resnet':
        important_layers = ['layer1_block0', 'layer2_block0', 'layer3_block0', 'layer4_block0']
    elif model_name == 'vit':
        important_layers = ['patch_embed', 'block_0', 'block_2', 'block_5']

    # Report on the selected layers, then plot the overall flow
    monitor.show_data_structure()
    monitor.print_activation_stats(layers=important_layers)
    monitor.print_gradient_stats(layers=important_layers)
    monitor.visualize_activation_flow()
    monitor.visualize_gradient_flow()

    # Clean up
    monitor.remove_hooks()


    
# if __name__ == "__main__":
#     # Test each model
#     for model_name in ['mlp', 'cnn', 'resnet', 'vit']:
#         test_model_with_monitor(model_name)


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import time
import os
import random

###########################################
# NetworkMonitor Class with Enhanced Control
###########################################

class NetworkMonitor:
    """
    Record the forward outputs and output gradients of every sub-module of a
    model using PyTorch hooks.

    Recorded tensors are detached CPU copies, appended per module name to
    ``self.activations`` and ``self.gradients`` (both ``defaultdict(list)``).
    The lists keep growing on every pass until ``clear_data()`` is called.

    Parameters:
        model (nn.Module): The model whose named sub-modules are hooked
    """
    def __init__(self, model):
        self.model = model
        self.activations = defaultdict(list)
        self.gradients = defaultdict(list)
        self.fwd_hooks = []
        self.bwd_hooks = []
        self.hooks_active = False

    @staticmethod
    def _primary_tensor(output):
        """Return the tensor to record for a module output, or None if it has none."""
        if isinstance(output, torch.Tensor):
            return output
        # Some modules return a tuple/list; record its first tensor, mirroring
        # the backward hook, which records grad_output[0].
        if isinstance(output, (tuple, list)):
            for item in output:
                if isinstance(item, torch.Tensor):
                    return item
        return None

    def _make_fwd_hook(self, name):
        """Build a forward hook that stores the module's output under ``name``."""
        def hook(module, inputs, output):
            tensor = self._primary_tensor(output)
            if tensor is not None:
                self.activations[name].append(tensor.detach().clone().cpu())
        return hook

    def _make_bwd_hook(self, name):
        """Build a backward hook that stores the gradient w.r.t. the module's output under ``name``."""
        def hook(module, grad_input, grad_output):
            if len(grad_output) > 0 and grad_output[0] is not None:
                self.gradients[name].append(grad_output[0].detach().clone().cpu())
            # Returning None leaves the gradients flowing to the inputs unchanged
        return hook

    def register_hooks(self):
        """Register hooks if not already registered"""
        if self.hooks_active:
            return
        for name, module in self.model.named_modules():
            if name == '':  # Skip the root module
                continue
            self.fwd_hooks.append(
                module.register_forward_hook(self._make_fwd_hook(name)))
            self.bwd_hooks.append(
                module.register_full_backward_hook(self._make_bwd_hook(name)))
        self.hooks_active = True
    
    def remove_hooks(self):
        """Remove hooks if they are registered"""
        if self.hooks_active:
            for h in self.fwd_hooks + self.bwd_hooks:
                h.remove()
            self.fwd_hooks = []
            self.bwd_hooks = []
            self.hooks_active = False
        
    def clear_data(self):
        """Clear all recorded activations and gradients"""
        self.activations = defaultdict(list)
        self.gradients = defaultdict(list)

    def get_latest_activations(self):
        """Get the most recent activations for each layer"""
        # Layers with an empty history are left out
        return {name: acts[-1] for name, acts in self.activations.items() if acts}
    
    def get_latest_gradients(self):
        """Get the most recent gradients for each layer"""
        # Layers with an empty history are left out
        return {name: grads[-1] for name, grads in self.gradients.items() if grads}

###########################################
# Utility Functions
###########################################

def flatten_activations(layer_act):
    """
    Flatten an activation tensor to 2-D ``(rows, features)`` for metric
    calculation, with one column per channel/neuron.

    Handled layouts:
        * ``(B, N, C)`` (transformer tokens): reshaped to ``(B*N, C)``.
        * ``(B, C)`` and ``(B, C, ...)`` with any number of trailing spatial
          dimensions: the channel axis is moved last and the rest is
          flattened, e.g. ``(B, C, H, W)`` -> ``(B*H*W, C)``.
        * Fewer than two dimensions: treated as a single feature and
          returned as a column of shape ``(numel, 1)``.

    Parameters:
        layer_act (torch.Tensor): Activation tensor

    Returns:
        torch.Tensor: 2-D tensor whose columns are channels/neurons
    """
    ndim = layer_act.dim()

    if ndim < 2:
        # No channel axis to index; every value belongs to one feature
        return layer_act.reshape(-1, 1)

    if ndim == 3:
        # Transformer format: each token is already a row of C features
        return layer_act.reshape(-1, layer_act.shape[2])

    # Channel-first layouts of any rank: move axis 1 to the end so every row
    # is one (batch, spatial location) pair. For 2-D input this is a no-op.
    channels_last = layer_act.permute(0, *range(2, ndim), 1)
    return channels_last.reshape(-1, layer_act.shape[1])

###########################################
# Metric Functions 
###########################################

def measure_dead_neurons(layer_act, dead_threshold=0.95):
    """
    Return the fraction of neurons that are (nearly) zero for most inputs.

    A value counts as zero when its magnitude is below 1e-7. A neuron
    (column of the flattened activations) is "dead" when its share of zero
    values is strictly greater than ``dead_threshold``.

    Parameters:
        layer_act (torch.Tensor): Layer activations
        dead_threshold (float): Zero-rate above which a neuron counts as dead

    Returns:
        float: Fraction of dead neurons
    """
    per_location = flatten_activations(layer_act)

    # Per-neuron share of rows where the output is effectively zero
    zero_rate = (per_location.abs() < 1e-7).float().mean(dim=0)

    return (zero_rate > dead_threshold).float().mean().item()

def measure_duplicate_neurons(layer_act, corr_threshold=0.99):
    """
    Return the fraction of neurons that nearly duplicate another neuron.

    Similarity is the cosine similarity between the neurons' output vectors
    (no mean-centering is applied). Only pairs (i, j) with i < j are
    examined, and neuron i is flagged when some later neuron j exceeds the
    threshold, so a pair of duplicates contributes one flagged neuron.

    Parameters:
        layer_act (torch.Tensor): Layer activations
        corr_threshold (float): Cosine similarity above which two neurons
            count as duplicates

    Returns:
        float: Fraction of neurons flagged as duplicates
    """
    # One row per neuron: shape (C, B*H*W)
    per_neuron = flatten_activations(layer_act).t()

    # Unit-length rows turn the matrix product into cosine similarities
    unit_rows = torch.nn.functional.normalize(per_neuron, p=2, dim=1)
    similarity = torch.matmul(unit_rows, unit_rows.t())

    # Strict upper triangle: each unordered pair once, diagonal excluded
    pair_mask = torch.triu(torch.ones_like(similarity), diagonal=1).bool()
    flagged = ((similarity > corr_threshold) & pair_mask).any(dim=1)

    return flagged.float().mean().item()

def measure_effective_rank(layer_act, svd_sample_size=1024):
    """
    Return the effective rank of the flattened activation matrix.

    Effective rank = exp(H(p)), where p_i = sigma_i / sum_j sigma_j are the
    normalized singular values and H is the Shannon entropy. When the matrix
    has more than ``svd_sample_size`` rows, a random subset of rows of that
    size is used. Returns 0.0 when the singular values sum to less than 1e-9.

    Parameters:
        layer_act (torch.Tensor): Layer activations
        svd_sample_size (int): Maximum number of rows passed to the SVD

    Returns:
        float: Effective rank
    """
    acts = flatten_activations(layer_act)

    # Cap the SVD cost by sub-sampling rows at random
    n_rows = acts.shape[0]
    if n_rows > svd_sample_size:
        keep = torch.randperm(n_rows)[:svd_sample_size]
        acts = acts[keep]

    _, singular_values, _ = torch.linalg.svd(acts, full_matrices=False)
    total = singular_values.sum()
    if total < 1e-9:
        return 0.0

    probs = singular_values / total
    # The small epsilon keeps log() finite for zero singular values
    entropy = -(probs * torch.log(probs + 1e-12)).sum()
    return torch.exp(entropy).item()

def measure_stable_rank(layer_act, sample_size=1024, use_gram=True):
    """
    Return a rank estimate of the centered, flattened activation matrix A.

    The value computed is Tr(M)^2 / ||M||_F^2, where M is A^T A or A A^T.
    Both choices give the same result mathematically, namely
    (sum_i sigma_i^2)^2 / sum_i sigma_i^4 over the singular values of A.

    NOTE(review): this is not the classical stable rank
    ||A||_F^2 / ||A||_2^2 (which divides by the largest squared singular
    value); confirm which quantity downstream consumers expect.

    Parameters:
        layer_act: Layer activations tensor
        sample_size: Maximum number of rows used; larger inputs are
            sub-sampled at random
        use_gram: If True, always use the D x D matrix A^T A. If False,
            A^T A is used only when D < N; otherwise the N x N matrix
            A A^T is used.

    Returns:
        Rank estimate (float); 0.0 when ||M||_F^2 is below 1e-9
    """
    acts = flatten_activations(layer_act)

    # N rows (samples), D columns (features); sub-sample rows if needed
    N, D = acts.shape
    if N > sample_size:
        acts = acts[torch.randperm(N)[:sample_size]]
        N = sample_size

    # Center each feature
    acts = acts - acts.mean(dim=0, keepdim=True)

    if use_gram or D < N:
        # D x D second-moment matrix; its trace equals the squared Frobenius norm of A
        second_moment = torch.matmul(acts.t(), acts)
        trace = torch.sum(acts**2).item()
    else:
        # N x N matrix, smaller when there are fewer samples than features
        second_moment = torch.matmul(acts, acts.t())
        trace = torch.trace(second_moment).item()

    squared_frobenius = torch.sum(second_moment**2).item()

    # Avoid division by zero
    if squared_frobenius < 1e-9:
        return 0.0

    return (trace**2) / squared_frobenius

###########################################
# Analysis with Single Monitor
###########################################

def analyze_fixed_batch_with_monitor(model, monitor, fixed_batch, fixed_targets=None, 
                                    criterion=None, dead_threshold=0.95, 
                                    corr_threshold=0.99, device='cpu'):
    """
    Use the provided monitor to analyze metrics for a fixed batch.

    Hooks are registered if needed, the model is switched to eval mode, and
    one forward pass (plus a backward pass when both ``criterion`` and
    ``fixed_targets`` are given) is run. Afterwards the hook state and the
    training/eval flag of every sub-module are restored to what they were
    on entry, even if the pass or a metric raises.
    
    Parameters:
        model: The neural network model
        monitor: The NetworkMonitor instance to use
        fixed_batch: Input data batch
        fixed_targets: Corresponding targets (optional, for backward pass)
        criterion: Loss function (optional, for backward pass)
        dead_threshold: Threshold for determining dead neurons
        corr_threshold: Threshold for determining duplicate neurons
        device: Device to run on (cpu or cuda)
        
    Returns:
        Dictionary mapping layer name to a dict with the keys
        'dead_fraction', 'dup_fraction', 'eff_rank' and 'stable_rank'
    """
    # .to() is a no-op when the tensor is already on the target device
    fixed_batch = fixed_batch.to(device)
    if fixed_targets is not None:
        fixed_targets = fixed_targets.to(device)

    # Remember the state we are about to change so it can be restored
    hooks_were_active = monitor.hooks_active
    training_flags = [(module, module.training) for module in model.modules()]

    # Make sure hooks are registered for the analysis
    monitor.register_hooks()

    # Run the model in eval mode for consistent results
    model.eval()

    try:
        with torch.set_grad_enabled(criterion is not None):
            outputs = model(fixed_batch)

            # Backward pass if criterion is provided.
            # NOTE: this accumulates into the parameters' .grad like any
            # other backward pass; callers should zero gradients before
            # their next optimizer step.
            if criterion is not None and fixed_targets is not None:
                loss = criterion(outputs, fixed_targets)
                loss.backward()

        # Get metrics using the latest activations
        metrics = {}
        for layer_name, act in monitor.get_latest_activations().items():
            # Skip non-tensor activations or certain layers
            if not isinstance(act, torch.Tensor) or \
               layer_name.startswith('dropout') or 'flatten' in layer_name or 'shortcut' in layer_name:
                continue

            metrics[layer_name] = {
                'dead_fraction': measure_dead_neurons(act, dead_threshold),
                'dup_fraction': measure_duplicate_neurons(act, corr_threshold),
                'eff_rank': measure_effective_rank(act),
                'stable_rank': measure_stable_rank(act)
            }
    finally:
        # Restore each sub-module's own flag rather than calling
        # model.train(), which would overwrite sub-modules that were
        # deliberately kept in a different mode.
        for module, was_training in training_flags:
            module.training = was_training

        # Restore previous hook state if hooks weren't active before
        if not hooks_were_active:
            monitor.remove_hooks()

    return metrics

###########################################
# Training and Evaluation Functions
###########################################

def set_seed(seed):
    """
    Seed every random number generator used in this file for reproducibility.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``,
    and sets the cuDNN flags ``deterministic=True`` and ``benchmark=False``.

    Parameters:
        seed (int): Seed value applied to all generators
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_cifar10_data(batch_size=128):
    """
    Load CIFAR10 with standard transformations and build four data loaders.

    The training loader uses random crop and horizontal flip augmentation.
    The two "fixed" loaders serve the first 500 training and test images
    without shuffling and without augmentation, so the same tensors come
    back every time they are iterated.

    Parameters:
        batch_size (int): Batch size of the main train and test loaders
            (the fixed loaders always use batches of 100)

    Returns:
        tuple: (trainloader, testloader, fixed_trainloader, fixed_valloader)
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    # Create fixed batches for consistent metric measurement
    # 1. Training fixed batch: first 500 training images. These are read
    #    through a second, un-augmented view of the training split; taking
    #    them from `trainset` would apply the random crop/flip and yield
    #    different tensors on every iteration.
    fixed_train_source = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=False, transform=transform_test)
    fixed_train_set = Subset(fixed_train_source, list(range(500)))
    fixed_trainloader = DataLoader(fixed_train_set, batch_size=100, shuffle=False)
    
    # 2. Validation fixed batch: first 500 test images
    fixed_val_set = Subset(testset, list(range(500)))
    fixed_valloader = DataLoader(fixed_val_set, batch_size=100, shuffle=False)
    
    return trainloader, testloader, fixed_trainloader, fixed_valloader


def _filter_and_remap_classes(dataset, class_mapping):
    """Return a TensorDataset holding only the mapped classes, with labels remapped."""
    from torch.utils.data import TensorDataset

    images = []
    labels = []
    # Filter on the raw label list first so the image transform only runs
    # for samples that are actually kept.
    for index, original_label in enumerate(dataset.targets):
        if original_label in class_mapping:
            image, _ = dataset[index]
            images.append(image)
            labels.append(class_mapping[original_label])
    return TensorDataset(torch.stack(images), torch.tensor(labels))


def get_cifar10_data_with_class_selection(
        batch_size=128, 
        sample_classes=None):
    """
    Load CIFAR10 (no augmentation), optionally restricted to a subset of
    classes, and build four data loaders.

    When ``sample_classes`` is given (and non-empty), only images of those
    classes are kept and their labels are remapped to 0..k-1 following the
    order of ``sample_classes``; the kept images are materialized in memory
    as a TensorDataset. The two "fixed" loaders serve the first
    ``min(500, len(dataset))`` samples of the train and test sets without
    shuffling.

    Parameters:
        batch_size (int): Batch size used by all four loaders
        sample_classes (list of int or None): Original CIFAR10 class indices
            to keep; None or empty keeps all classes with original labels

    Returns:
        tuple: (trainloader, testloader, fixed_trainloader, fixed_valloader)
    """
    from torch.utils.data import Subset

    # Same deterministic transform for both splits
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Load the full datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    
    # Filter by classes and create subsets with remapped labels
    if sample_classes:
        # Map original class labels to new consecutive indices
        class_mapping = {original_class: i for i, original_class in enumerate(sample_classes)}
        trainset = _filter_and_remap_classes(trainset, class_mapping)
        testset = _filter_and_remap_classes(testset, class_mapping)
    
    # Create data loaders
    trainloader = DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    # Create fixed batches for consistent metric measurement
    # 1. Training fixed batch - use a subset of training data
    fixed_train_set = Subset(trainset, range(min(500, len(trainset))))
    fixed_trainloader = DataLoader(fixed_train_set, batch_size=batch_size, shuffle=False)
    
    # 2. Validation fixed batch - use a subset of test data
    fixed_val_set = Subset(testset, range(min(500, len(testset))))
    fixed_valloader = DataLoader(fixed_val_set, batch_size=batch_size, shuffle=False)
    
    return trainloader, testloader, fixed_trainloader, fixed_valloader

def _print_layer_metrics(title, layer_metrics):
    """
    Print a banner followed by one summary line per monitored layer.

    Parameters:
        title (str): Text shown between the ``===`` markers of the banner
        layer_metrics (dict): Maps layer name -> dict that holds at least the
            keys 'dead_fraction', 'dup_fraction', 'eff_rank', 'stable_rank'
    """
    print(f"\n=== {title} ===")
    for layer_name, metrics in layer_metrics.items():
        print(f"{layer_name:15}: Dead: {metrics['dead_fraction']:8.3f}, "
              f"Dup: {metrics['dup_fraction']:8.3f}, "
              f"EffRank: {metrics['eff_rank']:8.3f}, "
              f"StableRank: {metrics['stable_rank']:8.3f}")


def _append_layer_metrics(metrics_history, layer_metrics):
    """
    Append every metric value of every layer to its running history list.

    Parameters:
        metrics_history (dict): Nested mapping layer -> metric -> list,
            expected to create missing entries on access
        layer_metrics (dict): Maps layer name -> {metric name: value}
    """
    for layer_name, metrics in layer_metrics.items():
        for metric_name, value in metrics.items():
            metrics_history[layer_name][metric_name].append(value)


def _measure_fixed_batches(model, train_monitor, val_monitor, criterion,
                           fixed_train_batch, fixed_train_targets,
                           fixed_val_batch, device):
    """
    Run the fixed-batch analysis once per monitor.

    The training call also passes targets and the loss criterion; the
    validation call passes the inputs only.

    Returns:
        tuple: (train_metrics, val_metrics) as returned by
            analyze_fixed_batch_with_monitor
    """
    train_metrics = analyze_fixed_batch_with_monitor(
        model, train_monitor, fixed_train_batch, fixed_train_targets, criterion, device=device
    )
    val_metrics = analyze_fixed_batch_with_monitor(
        model, val_monitor, fixed_val_batch, device=device
    )
    return train_metrics, val_metrics


def train_with_separate_monitors(model, trainloader, testloader, fixed_trainloader, fixed_valloader,
                                 train_monitor, val_monitor,learning_rate=0.001,
                                 num_epochs=20, metrics_frequency=100, device='cpu'):
    """
    Train the model and periodically measure metrics on fixed batches,
    using separate NetworkMonitor instances for training and validation.

    Metrics are measured before training (recorded as step 0), every
    ``metrics_frequency`` optimizer steps, and once more after the last epoch.
    The final measurement is always recorded, so when the last step is also a
    periodic measurement step that step number appears twice in
    ``step_history``.

    Parameters:
        model (nn.Module): Network to optimize; trained with Adam and
            cross-entropy loss
        trainloader: Iterable of (inputs, targets) batches; must support len()
        testloader: Iterable of (inputs, targets) batches used for the
            per-epoch accuracy
        fixed_trainloader: Loader whose first batch is the fixed training
            measurement batch
        fixed_valloader: Loader whose first batch is the fixed validation
            measurement batch
        train_monitor: Monitor handed to the analysis of the training batch;
            must provide remove_hooks()
        val_monitor: Monitor handed to the analysis of the validation batch;
            must provide remove_hooks()
        learning_rate (float): Learning rate for Adam
        num_epochs (int): Number of passes over trainloader
        metrics_frequency (int): Measure metrics every this many steps;
            0 or None disables the periodic measurements
        device: Device that batches are moved to

    Returns:
        dict: 'train_losses' (mean batch loss per epoch), 'test_accs'
            (accuracy in percent per epoch), 'training_metrics_history' and
            'validation_metrics_history' (layer -> metric -> list of values),
            'step_history' (step index of every measurement), plus the two
            monitors under 'train_monitor' and 'val_monitor'
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Per-epoch statistics
    train_losses = []
    test_accs = []

    # Per-measurement statistics: layer -> metric -> list of values
    training_metrics_history = defaultdict(lambda: defaultdict(list))
    validation_metrics_history = defaultdict(lambda: defaultdict(list))
    step_history = []

    # Fixed batches for consistent metric evaluation (first batch of each loader).
    # The validation targets are not needed by the validation analysis call.
    fixed_train_batch, fixed_train_targets = next(iter(fixed_trainloader))
    fixed_val_batch, _ = next(iter(fixed_valloader))
    fixed_train_batch = fixed_train_batch.to(device)
    fixed_train_targets = fixed_train_targets.to(device)
    fixed_val_batch = fixed_val_batch.to(device)

    def measure():
        return _measure_fixed_batches(
            model, train_monitor, val_monitor, criterion,
            fixed_train_batch, fixed_train_targets, fixed_val_batch, device
        )

    # Baseline metrics (before training)
    print("Measuring baseline metrics before training...")
    train_metrics, val_metrics = measure()
    _print_layer_metrics("Training Batch Metrics (before training)", train_metrics)
    _print_layer_metrics("Validation Batch Metrics (before training)", val_metrics)

    step_history.append(0)
    _append_layer_metrics(training_metrics_history, train_metrics)
    _append_layer_metrics(validation_metrics_history, val_metrics)

    total_steps = 0

    # Training loop
    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # Ensure hooks are NOT active during regular training
        train_monitor.remove_hooks()
        val_monitor.remove_hooks()

        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward + backward + optimize
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            total_steps += 1

            # Measure metrics periodically (skipped entirely when the
            # frequency is 0/None instead of failing on the modulo)
            if metrics_frequency and total_steps % metrics_frequency == 0:
                print(f"\nStep {total_steps}: Measuring metrics...")
                train_metrics, val_metrics = measure()

                step_history.append(total_steps)
                _append_layer_metrics(training_metrics_history, train_metrics)
                _append_layer_metrics(validation_metrics_history, val_metrics)

                _print_layer_metrics(
                    f"Training Batch Metrics (step {total_steps})", train_metrics)
                _print_layer_metrics(
                    f"Validation Batch Metrics (step {total_steps})", val_metrics)

                # NOTE(review): the analysis helper is defined elsewhere; in
                # case it leaves hooks registered or switches the model to
                # eval mode, restore the training state before continuing.
                train_monitor.remove_hooks()
                val_monitor.remove_hooks()
                model.train()

        # Evaluate on test set at the end of each epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        # Calculate epoch statistics; an empty loader yields NaN rather than
        # a ZeroDivisionError
        num_batches = len(trainloader)
        train_loss = running_loss / num_batches if num_batches else float('nan')
        test_acc = 100. * correct / total if total else float('nan')

        # Store epoch statistics
        train_losses.append(train_loss)
        test_accs.append(test_acc)

        # Print epoch summary
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f} | Test Acc: {test_acc:.2f}%')
        print(f'Time: {time.time() - start_time:.2f}s')

    # Final evaluation
    print("\nFinal metrics:")
    train_metrics, val_metrics = measure()
    _print_layer_metrics("Final Training Batch Metrics", train_metrics)
    _print_layer_metrics("Final Validation Batch Metrics", val_metrics)

    # Store final metrics
    step_history.append(total_steps)
    _append_layer_metrics(training_metrics_history, train_metrics)
    _append_layer_metrics(validation_metrics_history, val_metrics)

    # Ensure hooks are removed at the end
    train_monitor.remove_hooks()
    val_monitor.remove_hooks()

    return {
        'train_losses': train_losses,
        'test_accs': test_accs,
        'training_metrics_history': dict(training_metrics_history),
        'validation_metrics_history': dict(validation_metrics_history),
        'step_history': step_history,
        'train_monitor': train_monitor,
        'val_monitor': val_monitor
    }

###########################################
# Visualization Functions
###########################################

def plot_training_curves(history, save_path=None):
    """
    Draw the training loss and the test accuracy in two side-by-side panels.

    Parameters:
        history (dict): Must provide the sequences 'train_losses' and
            'test_accs', plotted against their index (labelled as epoch)
        save_path (str): Directory that receives ``training_curves.png``;
            nothing is written when it is None or empty
    """
    _, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (history key, panel title, y-axis label) for the left and right panel
    panels = (
        ('train_losses', 'Training Loss', 'Loss'),
        ('test_accs', 'Test Accuracy', 'Accuracy (%)'),
    )
    for axis, (history_key, panel_title, y_label) in zip(axes, panels):
        axis.plot(history[history_key])
        axis.set_title(panel_title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)

    plt.tight_layout()

    if save_path:
        plt.savefig(f"{save_path}/training_curves.png", dpi=300, bbox_inches='tight')
    plt.show()

def plot_metric_evolution(history, metric_name, layer_names=None, is_train=True, save_path=None):
    """
    Plot one metric over the recorded measurement steps, one line per layer.

    Parameters:
        history (dict): Must provide 'step_history' plus
            'training_metrics_history' or 'validation_metrics_history'
            (layer -> metric -> list of values)
        metric_name (str): Metric to plot, e.g. 'dead_fraction'
        layer_names (list): Layers to include; all recorded layers when None.
            Names missing from the history are ignored
        is_train (bool): Selects the training history when True, otherwise
            the validation history
        save_path (str): Directory that receives
            ``<training|validation>_<metric_name>.png``; no file is written
            when it is None or empty
    """
    if is_train:
        prefix, history_key = "Training", 'training_metrics_history'
    else:
        prefix, history_key = "Validation", 'validation_metrics_history'
    per_layer = history[history_key]
    steps = history['step_history']

    # Default to every recorded layer, then drop names the history lacks
    candidates = list(per_layer.keys()) if layer_names is None else layer_names
    known_layers = [name for name in candidates if name in per_layer]

    plt.figure(figsize=(12, 6))
    for name in known_layers:
        layer_series = per_layer[name]
        # Layers that never recorded this metric are skipped silently
        if metric_name in layer_series:
            plt.plot(steps, layer_series[metric_name], label=name)

    pretty_metric = metric_name.replace("_", " ").title()
    plt.title(f'{prefix} {pretty_metric} Evolution')
    plt.xlabel('Training Steps')
    plt.ylabel(pretty_metric)
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()

    if save_path:
        plt.savefig(f"{save_path}/{prefix.lower()}_{metric_name}.png", dpi=300, bbox_inches='tight')
    plt.show()

def plot_all_metrics(history, layer_names=None, save_path=None):
    """
    Plot the evolution of every tracked metric, first for the training
    history and then for the validation history.

    Parameters:
        history (dict): History dict forwarded to plot_metric_evolution
        layer_names (list): Layers to include; forwarded unchanged
        save_path (str): Output directory; created if missing. No files are
            written when it is None or empty
    """
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    tracked_metrics = ('dead_fraction', 'dup_fraction', 'eff_rank', 'stable_rank')

    # Training plots come first, validation plots second
    for use_train in (True, False):
        for metric in tracked_metrics:
            plot_metric_evolution(history, metric, layer_names,
                                  is_train=use_train, save_path=save_path)
        
def plot_comparison_metrics(history, metric_names=None, layer_names=None, save_path=None):
    """
    Draw one grouped bar chart per metric that compares the last recorded
    training value with the last recorded validation value of each layer.

    Parameters:
        history (dict): Must provide 'training_metrics_history' and
            'validation_metrics_history' (layer -> metric -> list of values)
        metric_names (list): Metrics to compare; defaults to dead_fraction,
            dup_fraction, eff_rank and stable_rank
        layer_names (list): Layers to compare; defaults to every layer of the
            training history whose name neither starts with 'dropout' nor
            contains 'flatten'
        save_path (str): Directory that receives ``comparison_<metric>.png``;
            no file is written when it is None or empty
    """
    if metric_names is None:
        metric_names = ['dead_fraction', 'dup_fraction', 'eff_rank', 'stable_rank']

    if layer_names is None:
        layer_names = []
        for candidate in history['training_metrics_history'].keys():
            if candidate.startswith('dropout') or 'flatten' in candidate:
                continue
            layer_names.append(candidate)

    last = -1  # only the most recent measurement is compared
    bar_width = 0.35

    for metric in metric_names:
        plt.figure(figsize=(12, 6))

        # Keep the layers that recorded this metric on both sides
        shared_layers = []
        for layer in layer_names:
            if layer not in history['training_metrics_history']:
                continue
            if layer not in history['validation_metrics_history']:
                continue
            if metric not in history['training_metrics_history'][layer]:
                continue
            if metric not in history['validation_metrics_history'][layer]:
                continue
            shared_layers.append(layer)

        train_values = []
        val_values = []
        for layer in shared_layers:
            train_values.append(history['training_metrics_history'][layer][metric][last])
        for layer in shared_layers:
            val_values.append(history['validation_metrics_history'][layer][metric][last])

        # Two bars per layer, centred on the integer tick position
        positions = np.arange(len(shared_layers))
        plt.bar(positions - bar_width/2, train_values, bar_width, label='Training')
        plt.bar(positions + bar_width/2, val_values, bar_width, label='Validation')

        pretty_metric = metric.replace('_', ' ').title()
        plt.xlabel('Layer')
        plt.ylabel(pretty_metric)
        plt.title(f'Comparison of {pretty_metric} Between Training and Validation')
        # Tick labels show only the last dotted component of the layer name
        short_names = [layer.rsplit('.', 1)[-1] for layer in shared_layers]
        plt.xticks(positions, short_names, rotation=45, ha='right')
        plt.legend()
        plt.tight_layout()

        if save_path:
            plt.savefig(f"{save_path}/comparison_{metric}.png", dpi=300, bbox_inches='tight')
        plt.show()

###########################################
# Main Function
###########################################

if __name__ == "__main__":
    # Experiment driver: train a model on a subset of CIFAR10 classes while
    # tracking per-layer metrics on fixed training/validation batches, then
    # plot and summarise the results.

    # Training settings
    seed = 41
    sample_classes = [0, 1]
    # Single source of truth for the loader batch size. 128 is the value the
    # data-loader call effectively used before it was wired to this setting.
    batch_size = 128
    learning_rate = 0.001
    num_epochs = 20
    metrics_frequency = 100  # Measure metrics every 100 steps

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Set random seed for reproducibility
    set_seed(seed)

    num_classes = len(sample_classes)
    print("Loading CIFAR10 dataset...")
    trainloader, testloader, fixed_trainloader, fixed_valloader = get_cifar10_data_with_class_selection(
        batch_size=batch_size,
        sample_classes=sample_classes)

    # Create the model (Vision Transformer). The commented-out MLP and CNN
    # configurations below are alternative architectures for the same run.
    print("Creating model...")
    # model = MLP(
    #     input_size=3*32*32, 
    #     hidden_sizes=[512]*10, 
    #     activation='selu',
    #     # normalization='batch',
    #     norm_after_activation=False,
    #     output_size = num_classes,
    # )
    # model = CNN(
    #     in_channels=3,
    #     conv_channels=[64, 128, 256],
    #     kernel_sizes=[3, 3, 3],
    #     strides=[1, 1, 1],
    #     paddings=[1, 1, 1],
    #     fc_hidden_units=[512],
    #     num_classes=num_classes,
    #     activation='selu',
    #     dropout_p=0.0,
    #     use_batchnorm=True
    # )
    model = VisionTransformer(num_classes=num_classes)

    model = model.to(device)

    # Create separate monitors for training and validation
    train_monitor = NetworkMonitor(model)
    val_monitor = NetworkMonitor(model)

    # Print model architecture
    print("\nModel Architecture:")
    for name, module in model.named_modules():
        if name:  # Skip the root module, whose name is the empty string
            print(f"{name}: {module.__class__.__name__}")

    # Train and evaluate using separate monitors
    print("\nStarting training with separate monitors...")
    history = train_with_separate_monitors(
        model, trainloader, testloader, fixed_trainloader, fixed_valloader,
        train_monitor, val_monitor,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        metrics_frequency=metrics_frequency,
        device=device,
    )

    # Create a directory for saving plots
    results_dir = './results'
    os.makedirs(results_dir, exist_ok=True)

    # Plot results
    print("\nPlotting results...")
    plot_training_curves(history, save_path=results_dir)

    # Plot metrics evolution for all layers
    plot_all_metrics(history, save_path=results_dir)

    # Plot comparison of metrics between training and validation
    plot_comparison_metrics(history, save_path=results_dir)

    # Additional analysis comparing training vs validation metrics
    print("\nComparing final training vs validation metrics:")
    train_history = history['training_metrics_history']
    val_history = history['validation_metrics_history']

    # Get important layers to analyze (can be customized based on model)
    important_layers = [layer_name for layer_name in train_history.keys()
                        if not (layer_name.startswith('dropout') or 'flatten' in layer_name)]

    # (metric key, label used in the printed summary)
    summary_rows = (
        ('dead_fraction', 'Dead neurons'),
        ('dup_fraction', 'Duplicate neurons'),
        ('eff_rank', 'Effective rank'),
        ('stable_rank', 'Stable rank'),
    )

    for layer in important_layers:
        # Layers come from the training history, so only the validation side
        # can be missing
        if layer not in val_history:
            continue
        print(f"\nLayer: {layer}")

        for metric_key, label in summary_rows:
            train_value = train_history[layer][metric_key][-1]
            val_value = val_history[layer][metric_key][-1]
            print(f"  {label}: Train {train_value:.3f} vs Val {val_value:.3f}")

    print("\nDone!")

Using device: cuda
Loading CIFAR10 dataset...
Files already downloaded and verified
Files already downloaded and verified
Creating model...

Model Architecture:
layers: ModuleDict
layers.patch_embed: PatchEmbedding
layers.patch_embed.layers: ModuleDict
layers.patch_embed.layers.proj: Conv2d
layers.pos_drop: Dropout
layers.block_0: TransformerBlock
layers.block_0.layers: ModuleDict
layers.block_0.layers.norm1: LayerNorm
layers.block_0.layers.attn: Attention
layers.block_0.layers.attn.layers: ModuleDict
layers.block_0.layers.attn.layers.qkv: Linear
layers.block_0.layers.attn.layers.attn_drop: Dropout
layers.block_0.layers.attn.layers.proj: Linear
layers.block_0.layers.attn.layers.proj_drop: Dropout
layers.block_0.layers.norm2: LayerNorm
layers.block_0.layers.mlp: TransformerMLP
layers.block_0.layers.mlp.layers: ModuleDict
layers.block_0.layers.mlp.layers.fc1: Linear
layers.block_0.layers.mlp.layers.act: GELU
layers.block_0.layers.mlp.layers.drop1: Dropout
layers.block_0.layers.mlp.layers