In [40]:
import torch
import torch.nn as nn
# Import the utility module and Setup the path
import notebook_utils
notebook_utils.setup_path()

from src.models.layers import TransformerBatchNorm
from src.utils.monitor import NetworkMonitor

from typing import Union, List, Dict, Set

# Replace TypeVar with Union for proper type handling
NormalizationLayer = Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm,  nn.GroupNorm, TransformerBatchNorm]

# Define activation function types properly
ActivationFunction = Union[nn.ReLU, nn.Sigmoid, nn.Tanh, nn.SELU, nn.GELU, nn.SiLU, nn.ELU, 
                         nn.LeakyReLU, nn.PReLU, nn.Threshold, nn.Softmax, nn.LogSoftmax, 
                         nn.Softplus, nn.Softmin, nn.Hardsigmoid, nn.Hardswish, nn.Softshrink, 
                         nn.Hardshrink, nn.Softsign, nn.GLU, nn.CELU, nn.Identity]


class CloneAwareFlatten(nn.Module):
    """
    A custom flatten module that ensures duplicated features remain adjacent when flattening
    convolutional feature maps.
    
    When cloning channels in convolutional layers, the standard nn.Flatten would arrange features
    as [a(0,0), a'(0,0), b(0,0), b'(0,0), ...] where features are grouped by spatial position.
    
    This module rearranges to keep all spatial positions of the same channel together:
    [a(0,0), a'(0,0), a(0,1), a'(0,1), ..., b(0,0), b'(0,0), ...] ensuring duplicated
    features remain adjacent.
    """
    def __init__(self, start_dim=1, end_dim=-1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim
        
    def forward(self, x):
        # Standard flattening for non-starting dimensions or non-4D tensors
        if x.dim() != 4 or self.start_dim > 1:
            start_dim = self.start_dim if self.start_dim >= 0 else x.dim() + self.start_dim
            end_dim = self.end_dim if self.end_dim >= 0 else x.dim() + self.end_dim
            
            shape = x.shape
            new_shape = list(shape[:start_dim]) + [-1]
            return x.reshape(*new_shape)
        
        # Special handling for 4D tensors with channel duplication
        batch_size, channels, height, width = x.shape
        
        # If channels are not even, use standard flattening
        if channels % 2 != 0:
            return x.reshape(batch_size, -1)
        
        half_channels = channels // 2
        
        # Step 1: Reshape to separate duplicated channels
        # [batch, channels, h, w] -> [batch, half_channels, 2, h, w]
        x_reshaped = x.view(batch_size, half_channels, 2, height, width)
        
        # Step 2: Permute to get the desired order
        # [batch, half_channels, 2, h, w] -> [batch, half_channels, h, w, 2]
        x_permuted = x_reshaped.permute(0, 1, 3, 4, 2)
        
        # Step 3: Flatten
        # [batch, half_channels, h, w, 2] -> [batch, half_channels * h * w * 2]
        return x_permuted.reshape(batch_size, -1)

def clone_linear(base_module: nn.Linear, cloned_module: nn.Linear):
    # Get module dimensions
    base_in_features = base_module.in_features
    base_out_features = base_module.out_features
    cloned_in_features = cloned_module.in_features
    cloned_out_features = cloned_module.out_features
    
    # Verify expansion factors are valid
    if cloned_in_features % base_in_features != 0 or cloned_out_features % base_out_features != 0:
        raise ValueError(f"Linear module dimensions are not integer multiples: "
                         f"{base_in_features}→{cloned_in_features}, {base_out_features}→{cloned_out_features}")
        
    # Calculate expansion factors
    in_expansion = cloned_in_features // base_in_features
    out_expansion = cloned_out_features // base_out_features
    
    print(f"Cloning Linear module: {base_in_features}→{cloned_in_features}, {base_out_features}→{cloned_out_features}, in expansion: {in_expansion}, out expansion: {out_expansion}")
    
    # Clone the weights with proper scaling
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion] = base_module.weight.data / in_expansion
    
    # Clone the bias if present (no scaling needed for bias)
    if base_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = base_module.bias.data
    return cloned_module


def clone_conv1d(base_module: nn.Conv1d, cloned_module: nn.Conv1d):
    # Get module dimensions
    base_in_channels = base_module.in_channels
    base_out_channels = base_module.out_channels
    cloned_in_channels = cloned_module.in_channels
    cloned_out_channels = cloned_module.out_channels
    # Calculate expansion factors
    in_expansion = cloned_in_channels // base_in_channels
    out_expansion = cloned_out_channels // base_out_channels
    
    print(f"Cloning Conv1d module: {base_in_channels}→{cloned_in_channels}, {base_out_channels}→{cloned_out_channels}, in expansion: {in_expansion}, out expansion: {out_expansion}")
    
    # Verify expansion factors are valid
    if cloned_in_channels % base_in_channels != 0 or cloned_out_channels % base_out_channels != 0:
        raise ValueError(f"Conv1d module dimensions are not integer multiples: "
                         f"{base_in_channels}→{cloned_in_channels}, {base_out_channels}→{cloned_out_channels}")
    
    # Clone the weights with proper scaling
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion, :] = base_module.weight.data / in_expansion
    
    # Clone the bias if present (no scaling needed for bias)
    if base_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = base_module.bias.data
    return cloned_module

    
def clone_conv2d(base_module: nn.Conv2d, cloned_module: nn.Conv2d):
    # Get module dimensions
    base_in_channels = base_module.in_channels
    base_out_channels = base_module.out_channels
    cloned_in_channels = cloned_module.in_channels
    cloned_out_channels = cloned_module.out_channels
    # Calculate expansion factors
    in_expansion = cloned_in_channels // base_in_channels
    out_expansion = cloned_out_channels // base_out_channels
    
    print(f"Cloning Conv2d module: {base_in_channels}→{cloned_in_channels}, {base_out_channels}→{cloned_out_channels}, in expansion: {in_expansion}, out expansion: {out_expansion}")
    
    # Verify expansion factors are valid
    if cloned_in_channels % base_in_channels != 0 or cloned_out_channels % base_out_channels != 0:
        raise ValueError(f"Conv2d module dimensions are not integer multiples: "
                         f"{base_in_channels}→{cloned_in_channels}, {base_out_channels}→{cloned_out_channels}")
    
    # Clone the weights with proper scaling
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion, :, :] = base_module.weight.data / in_expansion
    
    # Clone the bias if present (no scaling needed for bias)
    if base_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = base_module.bias.data
    return cloned_module
    

def clone_normalization(
    base_module: NormalizationLayer, 
    cloned_module: NormalizationLayer,
) -> NormalizationLayer:
    """Clone normalization layer parameters with proper handling of different types."""
    assert isinstance(cloned_module, type(base_module)), "Cloned module must be of the same type as base module"
    
    # Check properties that exist for the specific normalization type
    if hasattr(base_module, 'affine') and hasattr(cloned_module, 'affine'):
        assert base_module.affine == cloned_module.affine, "Affine property must match"
    
    # Handle BatchNorm-specific properties
    if isinstance(base_module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        if hasattr(base_module, 'track_running_stats') and hasattr(cloned_module, 'track_running_stats'):
            assert base_module.track_running_stats == cloned_module.track_running_stats, "Track running stats property must match"
    
    # Clone weights and biases
    if hasattr(base_module, 'weight') and base_module.weight is not None and cloned_module.weight is not None:
        expansion = cloned_module.weight.data.shape[0] // base_module.weight.data.shape[0] 
        for i in range(expansion):
            cloned_module.weight.data[i::expansion] = base_module.weight.data
            if hasattr(base_module, 'bias') and base_module.bias is not None and cloned_module.bias is not None:
                cloned_module.bias.data[i::expansion] = base_module.bias.data
    
    # Clone running stats for BatchNorm layers
    if hasattr(base_module, 'running_mean') and base_module.running_mean is not None:
        if hasattr(cloned_module, 'running_mean') and cloned_module.running_mean is not None:
            expansion = cloned_module.running_mean.data.shape[0] // base_module.running_mean.data.shape[0]
            for i in range(expansion):
                cloned_module.running_mean.data[i::expansion] = base_module.running_mean.data
                cloned_module.running_var.data[i::expansion] = base_module.running_var.data
    
    # Clone num_batches_tracked for BatchNorm layers
    if hasattr(base_module, 'num_batches_tracked') and base_module.num_batches_tracked is not None:
        if hasattr(cloned_module, 'num_batches_tracked') and cloned_module.num_batches_tracked is not None:
            cloned_module.num_batches_tracked.data.copy_(base_module.num_batches_tracked.data)
    
    return cloned_module
    
    
def clone_embedding(base_module: nn.Embedding, cloned_module: nn.Embedding):
    # Get module dimensions
    base_num_embeddings = base_module.num_embeddings
    base_embedding_dim = base_module.embedding_dim
    cloned_num_embeddings = cloned_module.num_embeddings
    cloned_embedding_dim = cloned_module.embedding_dim
    
    # Calculate expansion factors
    num_expansion = cloned_num_embeddings // base_num_embeddings
    dim_expansion = cloned_embedding_dim // base_embedding_dim
    
    print(f"Cloning Embedding module: {base_num_embeddings}→{cloned_num_embeddings}, {base_embedding_dim}→{cloned_embedding_dim}, num expansion: {num_expansion}, dim expansion: {dim_expansion}")
    
    # Verify expansion factors are valid
    if cloned_num_embeddings % base_num_embeddings != 0 or cloned_embedding_dim % base_embedding_dim != 0:
        raise ValueError(f"Embedding module dimensions are not integer multiples: "
                         f"{base_num_embeddings}→{cloned_num_embeddings}, {base_embedding_dim}→{cloned_embedding_dim}")
    
    # Clone the weights with proper scaling
    for i in range(num_expansion):
        for j in range(dim_expansion):
            cloned_module.weight.data[j::dim_expansion, i::num_expansion] = base_module.weight.data 
    
    return cloned_module


def clone_activation(base_module: ActivationFunction, cloned_module: ActivationFunction) -> ActivationFunction:
    """Clone activation function parameters, handling configuration parameters properly."""
    assert isinstance(cloned_module, type(base_module)), "Cloned module must be of the same type as base module"
    
    # Handle configuration parameters for different activation types
    if isinstance(base_module, nn.LeakyReLU):
        cloned_module.negative_slope = base_module.negative_slope
    
    elif isinstance(base_module, (nn.ELU, nn.CELU)):
        cloned_module.alpha = base_module.alpha
    
    elif isinstance(base_module, nn.Threshold):
        cloned_module.threshold = base_module.threshold
        cloned_module.value = base_module.value
    
    elif isinstance(base_module, (nn.Softmax, nn.LogSoftmax)):
        cloned_module.dim = base_module.dim
    
    elif isinstance(base_module, (nn.Hardshrink, nn.Softshrink)):
        cloned_module.lambd = base_module.lambd
        
    elif isinstance(base_module, nn.GLU):
        cloned_module.dim = base_module.dim
    
    # Handle PReLU specifically (has learnable parameters)
    elif isinstance(base_module, nn.PReLU):
        if base_module.num_parameters == 1 and cloned_module.num_parameters > 1:
            # If base is a single parameter, broadcast to all channels
            cloned_module.weight.data.fill_(base_module.weight.data[0])
        elif base_module.num_parameters > 1 and cloned_module.num_parameters > 1:
            # Channel-wise parameters need proper expansion
            expansion = cloned_module.num_parameters // base_module.num_parameters
            for i in range(expansion):
                cloned_module.weight.data[i::expansion] = base_module.weight.data
        else:
            # Direct copy if dimensions match
            cloned_module.weight.data.copy_(base_module.weight.data)
    
    # Handle other parameterized activation functions if they have weights
    # This is a general catch-all for any other activation function with parameters
    elif hasattr(base_module, 'weight') and hasattr(cloned_module, 'weight'):
        if base_module.weight is not None and cloned_module.weight is not None:
            if cloned_module.weight.data.shape == base_module.weight.data.shape:
                cloned_module.weight.data.copy_(base_module.weight.data)
    
    return cloned_module


def clone_dropout(base_module: nn.Dropout, cloned_module: nn.Dropout):
    """Clone dropout module parameters."""
    assert cloned_module.p == base_module.p, "Dropout probability must match"
    # Print warning if dropout p > 0
    if cloned_module.p > 0:
        print(f"Warning: Dropout probability is set to {cloned_module.p}, cloning is not perfect")
    return cloned_module

def clone_flatten(base_module: nn.Flatten) -> CloneAwareFlatten:
    """
    Clone parameters from a standard Flatten and return a new CloneAwareFlatten.
    
    Args:
        base_module: base nn.Flatten module
        
    Returns:
        A new CloneAwareFlatten module with the same parameters
    """
    return CloneAwareFlatten(
        start_dim=base_module.start_dim,
        end_dim=base_module.end_dim
    )


def is_parameter_free(module: nn.Module) -> bool:
    """Check if a module has no parameters."""
    return len(list(module.parameters())) == 0


def clone_parameter_free(base_module: nn.Module, cloned_module: nn.Module) -> nn.Module:
    """Clone a parameter-free module."""
    assert isinstance(cloned_module, type(base_module)), "Cloned module must be of the same type as base module"
    assert is_parameter_free(base_module), "base module must be parameter free"
    assert is_parameter_free(cloned_module), "Cloned module must be parameter free"
    
    # For parameter-free modules, there's no need to copy weights
    # Just make sure they're of the same type, which we've already checked
    return cloned_module


# Validation functions

def validate_activation_cloning(base_module: ActivationFunction, cloned_module: ActivationFunction):
    assert isinstance(cloned_module, type(base_module)), "Cloned module must be of the same type as base module"
    
    # Validate configuration parameters for different activation types
    if isinstance(base_module, nn.LeakyReLU):
        assert base_module.negative_slope == cloned_module.negative_slope, "LeakyReLU negative_slope does not match"
    
    elif isinstance(base_module, (nn.ELU, nn.CELU)):
        assert base_module.alpha == cloned_module.alpha, "Alpha parameter does not match"
    
    elif isinstance(base_module, nn.Threshold):
        assert base_module.threshold == cloned_module.threshold, "Threshold value does not match"
        assert base_module.value == cloned_module.value, "Replacement value does not match"
    
    elif isinstance(base_module, (nn.Softmax, nn.LogSoftmax)):
        assert base_module.dim == cloned_module.dim, "Dimension parameter does not match"
    
    elif isinstance(base_module, (nn.Hardshrink, nn.Softshrink)):
        assert base_module.lambd == cloned_module.lambd, "Lambda parameter does not match"
        
    elif isinstance(base_module, nn.GLU):
        assert base_module.dim == cloned_module.dim, "Dimension parameter does not match"
    
    # Validate PReLU parameters
    elif isinstance(base_module, nn.PReLU):
        if base_module.num_parameters == 1 and cloned_module.num_parameters > 1:
            # All elements should be equal to the single parameter
            assert torch.all(cloned_module.weight.data == base_module.weight.data[0])
        elif base_module.num_parameters > 1 and cloned_module.num_parameters > 1:
            expansion = cloned_module.num_parameters // base_module.num_parameters
            for i in range(expansion):
                assert torch.allclose(cloned_module.weight.data[i::expansion], base_module.weight.data)
    
    print("Passed all tests")
    return True


def validate_dropout_cloning(base_module: nn.Dropout, cloned_module: nn.Dropout):
    assert cloned_module.p == base_module.p, "Dropout probability must match"
    print("Passed all tests")
    return True


def validate_embedding_cloning(base_module: nn.Embedding, cloned_module: nn.Embedding):
    num_expansion = cloned_module.num_embeddings // base_module.num_embeddings
    dim_expansion = cloned_module.embedding_dim // base_module.embedding_dim
    for j in range(num_expansion):
        for i in range(dim_expansion):
            assert torch.allclose(cloned_module.weight.data[j::num_expansion, i::dim_expansion], base_module.weight.data)
    print("Passed all tests")
    return True


def validate_normalization_cloning(base_module: NormalizationLayer, cloned_module: NormalizationLayer):
    assert isinstance(cloned_module, type(base_module)), "Cloned module must be of the same type as base module"
    
    if hasattr(base_module, 'weight') and base_module.weight is not None and hasattr(cloned_module, 'weight'):
        expansion = cloned_module.weight.data.shape[0] // base_module.weight.data.shape[0] 
        for i in range(expansion):
            assert torch.allclose(cloned_module.weight.data[i::expansion], base_module.weight.data)
            
            if hasattr(base_module, 'bias') and base_module.bias is not None and hasattr(cloned_module, 'bias'):
                assert torch.allclose(cloned_module.bias.data[i::expansion], base_module.bias.data)
    
    # Check running stats for BatchNorm layers
    if hasattr(base_module, 'running_mean') and base_module.running_mean is not None:
        if hasattr(cloned_module, 'running_mean') and cloned_module.running_mean is not None:
            expansion = cloned_module.running_mean.data.shape[0] // base_module.running_mean.data.shape[0]
            for i in range(expansion):
                assert torch.allclose(cloned_module.running_mean.data[i::expansion], base_module.running_mean.data)
                assert torch.allclose(cloned_module.running_var.data[i::expansion], base_module.running_var.data)
    
    print("Passed all tests")
    

def validate_linear_cloning(base_module: nn.Linear, cloned_module: nn.Linear):
    in_expansion = cloned_module.in_features // base_module.in_features
    out_expansion = cloned_module.out_features // base_module.out_features
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion], base_module.weight.data/in_expansion)
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], base_module.bias.data)
    print("Passed all tests")
    
    
def validate_conv1d_cloning(base_module: nn.Conv1d, cloned_module: nn.Conv1d):
    in_expansion = cloned_module.in_channels // base_module.in_channels
    out_expansion = cloned_module.out_channels // base_module.out_channels
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion, :], base_module.weight.data/in_expansion)
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], base_module.bias.data)
    print("Passed all tests")
    

def validate_conv2d_cloning(base_module: nn.Conv2d, cloned_module: nn.Conv2d):
    in_expansion = cloned_module.in_channels // base_module.in_channels
    out_expansion = cloned_module.out_channels // base_module.out_channels
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion, :, :], base_module.weight.data/in_expansion)
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], base_module.bias.data)
    print("Passed all tests")



def clone_module(
    base_module: nn.Module, 
    cloned_module: nn.Module,
) -> bool:
    """
    Clone parameters from a base module to a cloned module.
    
    Args:
        base_module: base module with smaller dimensions
        cloned_module: Target module with larger dimensions
        
    Returns:
        bool: True if cloning was successful, False otherwise
    """
    success = True
    
    # Define normalization and activation types inline for easier checking
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm)
    activation_types = (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.SELU, nn.GELU, nn.SiLU, nn.ELU, nn.LeakyReLU, 
                       nn.PReLU, nn.Threshold, nn.Softmax, nn.LogSoftmax, nn.Softplus, nn.Softmin, 
                       nn.Hardsigmoid, nn.Hardswish, nn.Softshrink, nn.Hardshrink, nn.Softsign, 
                       nn.GLU, nn.CELU, nn.Identity)
    
    if isinstance(base_module, nn.Linear):
        clone_linear(base_module, cloned_module)
    elif isinstance(base_module, nn.Conv1d):
        clone_conv1d(base_module, cloned_module)
    elif isinstance(base_module, nn.Conv2d):
        clone_conv2d(base_module, cloned_module)
    elif isinstance(base_module, norm_types):
        clone_normalization(base_module, cloned_module)
    elif isinstance(base_module, nn.Embedding):
        clone_embedding(base_module, cloned_module)
    elif isinstance(base_module, activation_types):
        clone_activation(base_module, cloned_module)
    elif isinstance(base_module, nn.Dropout):
        clone_dropout(base_module, cloned_module)
    elif isinstance(base_module, nn.Flatten):
        pass # Flatten is handled separately
    elif is_parameter_free(base_module) and is_parameter_free(cloned_module):
        clone_parameter_free(base_module, cloned_module)
    else:
        success = False
        print(f"Unsupported module type: {type(base_module)}")
    
    return success



def model_clone(base_model: nn.Module, cloned_model: nn.Module) -> nn.Module:
    """
    Clone parameters from a base model to a cloned model.
    
    Args:
        base_model: base model with smaller dimensions
        cloned_model: Target model with larger dimensions
        
    Returns:
        cloned_model: The target model with cloned parameters
    """
    # First, replace all Flatten modules with CloneAwareFlatten
    for name, module in list(cloned_model.named_modules()):
        if isinstance(module, nn.Flatten):
            parent_name = '.'.join(name.split('.')[:-1])
            module_name = name.split('.')[-1]
            
            # Find parent module to modify
            if parent_name:
                parent = cloned_model.get_submodule(parent_name)
            else:
                parent = cloned_model
                
            # Create and replace with CloneAwareFlatten
            setattr(parent, module_name, CloneAwareFlatten(
                start_dim=module.start_dim,
                end_dim=module.end_dim
            ))
            print(f"Replaced Flatten with CloneAwareFlatten at {name}")
    
    # Second, handle direct parameters of the model (not within modules)
    for name, param in list(base_model.named_parameters(recurse=False)):
        if hasattr(cloned_model, name):
            base_param = getattr(base_model, name)
            cloned_param = getattr(cloned_model, name)
            
            # Check if dimensions differ and can be expanded
            if base_param.shape != cloned_param.shape:
                # For embedding dimensions (typically last dimension in transformers)
                if len(base_param.shape) >= 2 and base_param.shape[:-1] == cloned_param.shape[:-1]:
                    base_dim = base_param.shape[-1]
                    cloned_dim = cloned_param.shape[-1]
                    
                    if cloned_dim % base_dim == 0:
                        expansion = cloned_dim // base_dim
                        # Duplicate across embedding dimension
                        for i in range(expansion):
                            cloned_param.data[..., i::expansion] = base_param.data
                        print(f"Cloned parameter {name} with embedding expansion {expansion}")
                    else:
                        print(f"Warning: Parameter {name} dimensions don't match and can't be expanded automatically")
                else:
                    print(f"Warning: Parameter {name} shapes don't match and can't be expanded automatically")
            else:
                # Exact shape match, just copy
                cloned_param.data.copy_(base_param.data)
                print(f"Cloned parameter {name} with direct copy")
    
    # Finally, process module parameters
    for name, base_module in base_model.named_modules():
        try:
            cloned_module = cloned_model.get_submodule(name)
            print(f"Cloning module {name}")
            clone_module(base_module, cloned_module)
        except AttributeError:
            print(f"Warning: Could not find matching module for {name}")
    
    return cloned_model



def test_activation_cloning(base_model, cloned_model, input, target, tolerance=1e-3, check_equality=False):
    criterion = nn.CrossEntropyLoss()
    base_monitor = NetworkMonitor(base_model,)
    cloned_monitor = NetworkMonitor(cloned_model,)
    base_monitor.register_hooks()
    cloned_monitor.register_hooks()

    y1 = base_model(input)
    y2 = cloned_model(input)
    l1 = criterion(y1, target)
    l2 = criterion(y2, target)
    l1.backward()
    l2.backward()
    if check_equality:
        assert torch.allclose(y1, y2,atol=tolerance), "Outputs do not match after cloning"

    un_explained_vars = []
    for act_type in ['forward', 'backward']:
        if act_type == 'forward':
            base_acts = base_monitor.get_latest_activations()
            clone_acts = cloned_monitor.get_latest_activations()
        elif act_type == 'backward':
            base_acts = base_monitor.get_latest_gradients()
            clone_acts = cloned_monitor.get_latest_gradients()

        for key, a1 in base_acts.items():
            a2 = clone_acts[key]
            s1, s2 = torch.tensor(a1.shape), torch.tensor(a2.shape)
            print(f"key: {key}, a1: {a1.shape}, a2: {a2.shape}")
            i = (s1 != s2).nonzero()
            if len(i)==0:
                if check_equality:
                    assert torch.allclose(a1, a2, atol=tolerance), f"Activations for {key} do not match"
            elif len(i)==1:
                i = i[0][0]
                expansion = a2.shape[i] // a1.shape[i]
                # check expansion depending on the dimension 
                for j in range(expansion):
                    print(f"mismatch dim: {i}, checking slice: {j}, expansion: {expansion}")
                    slices = []
                    if i==0:
                        slice = a2[j::expansion]
                    elif i==1:
                        slice = a2[:, j::expansion]
                    elif i==2:
                        slice = a2[:, :, j::expansion]
                    elif i==3:
                        slice = a2[:, :, :, j::expansion]
                    slices.append(slice)
                    if check_equality:
                        assert torch.allclose(slice, a1, atol=tolerance), f"Activations for {key} do not match"
                slices = torch.stack(slices)
                if slices.shape[0]>1:
                    print(f"slices for {key} shape = {slices.shape}")
                    std, rms = slices.std(dim=0), ((slices**2).mean(dim=0)**0.5)
                    unexplained = (std/rms).mean().item()
                    print(f"unexplained variance for {key} is {unexplained}")
                    un_explained_vars[f'{key}_{act_type}'] = unexplained
                    assert unexplained<tolerance, f"unexplained variance is higher than the threshold {tolerance}"

            elif len(i)>1:
                assert False, f"Activations for {key} more than one dimension mismatch, this is unexpected behavior"
                    
            print(f"All {act_type} activations match after cloning up to tolerance {tolerance}")
            return un_explained_vars
    
def test_various_models_cloning(normalization='none', drpout_p=0.0, activation='relu', tolerance=1e-3,check_equality=False):
    from src.models import MLP, CNN, ResNet, VisionTransformer
    
    # generate random input and targets
    x_flat = torch.randn(32, 10) # for MLP 
    x = torch.randn(32, 3, 32, 32)
    y = torch.randint(0, 2, (32,))
    
    
    base_model = MLP(input_size=10, output_size=2, hidden_sizes=[64, 32,], activation=activation, dropout_p=drpout_p, normalization=normalization)
    cloned_model = MLP(input_size=10, output_size=2, hidden_sizes=[64*2, 32*2,], activation=activation, dropout_p=drpout_p, normalization=normalization)
    cloned_model = model_clone(base_model, cloned_model)
    test_activation_cloning(base_model, cloned_model, x_flat, y, tolerance=tolerance, check_equality=check_equality)
    
    base_model = CNN(in_channels=3, num_classes=2, conv_channels=[64, 128, 256], activation=activation, dropout_p=drpout_p, normalization=normalization)
    cloned_model = CNN(in_channels=3, num_classes=2, conv_channels=[64*2, 128*2, 256*2], activation=activation, dropout_p=drpout_p, normalization=normalization)
    cloned_model = model_clone(base_model, cloned_model)
    test_activation_cloning(base_model, cloned_model, x, y, tolerance=tolerance, check_equality=check_equality)

    base_model = ResNet(in_channels=3, num_classes=2, base_channels=64, activation=activation, dropout_p=drpout_p, normalization=normalization)
    cloned_model = ResNet(in_channels=3, num_classes=2, base_channels=64*2, activation=activation, dropout_p=drpout_p, normalization=normalization)
    cloned_model = model_clone(base_model, cloned_model)
    test_activation_cloning(base_model, cloned_model, x, y, tolerance=tolerance, check_equality=check_equality)

    base_model = VisionTransformer(
        in_channels=3, 
        num_classes=2, 
        embed_dim=64, 
        depth=2, 
        dropout_p=drpout_p,
        attn_drop_rate=drpout_p,
        activation=activation,)

    cloned_model = VisionTransformer(
        in_channels=3, 
        num_classes=2, 
        patch_size=4, 
        embed_dim=64*2, 
        depth=2, 
        dropout_p=drpout_p,
        attn_drop_rate=drpout_p,
        activation=activation,)

    cloned_model = model_clone(base_model, cloned_model)

    test_activation_cloning(base_model, cloned_model, x, y, tolerance=tolerance)
    
# if __name__ == "__main__":
#     # Test the cloning functionality with various models
#     for activation in ['relu', 'tanh', 'gelu']:
#         for normalization in ['none', 'layer', 'batch']:
#             test_various_models_cloning(activation=activation, normalization=normalization,drpout_p=0.0, tolerance=1e-8, check_equality=False)    
# if __name__ == "__main__":
#     test_various_models_cloning(activation=activation, normalization=normalization,drpout_p=0.0, tolerance=0.1)
    

/home/amir/Codes/NN-dynamic-scaling already in Python path


In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm import tqdm
import os

# Import the utility module and Setup the path
import notebook_utils
notebook_utils.setup_path()

# Import your cloning functions
from src.models import MLP, CNN, ResNet, VisionTransformer
from src.utils.monitor import NetworkMonitor
# Import the cloning function defined in your code

# For models that require flattened input (like MLP)
class ModelWrapper(nn.Module):
    """Wrapper to handle input reshaping for MLP models."""
    def __init__(self, model, flatten=False):
        super(ModelWrapper, self).__init__()
        self.model = model
        self.flatten = flatten
        
    def forward(self, x):
        if self.flatten:
            batch_size = x.size(0)
            x = x.view(batch_size, -1)
        return self.model(x)

def load_cifar10(batch_size=128):
    """Load and prepare CIFAR-10 dataset with data augmentation."""
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    # Load datasets
    trainset = torchvision.datasets.CIFAR10(
        root='../data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(
        root='../data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck')
    
    return trainloader, testloader, classes

def train_epoch(model, trainloader, criterion, optimizer, device):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    progress_bar = tqdm(trainloader, desc="Training")
    for inputs, targets in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Track statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        progress_bar.set_postfix({
            'loss': running_loss / (progress_bar.n + 1),
            'acc': 100. * correct / total
        })
    
    return running_loss / len(trainloader), 100. * correct / total

def evaluate(model, testloader, criterion, device):
    """Evaluate the model on the test set."""
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    return test_loss / len(testloader), 100. * correct / total

def create_model(model_type, expanded=False, activation='relu', normalization='batch', dropout_p=0.0):
    """Create a model based on the specified type and configuration."""
    if model_type == 'mlp':
        if expanded:
            return MLP(input_size=3*32*32, output_size=10, hidden_sizes=[512*2, 256*2], 
                      activation=activation, dropout_p=dropout_p, normalization=normalization)
        else:
            return MLP(input_size=3*32*32, output_size=10, hidden_sizes=[512, 256], 
                      activation=activation, dropout_p=dropout_p, normalization=normalization)
    
    elif model_type == 'cnn':
        if expanded:
            return CNN(in_channels=3, num_classes=10, conv_channels=[64*2, 128*2, 256*2], 
                      activation=activation, dropout_p=dropout_p, normalization=normalization)
        else:
            return CNN(in_channels=3, num_classes=10, conv_channels=[64, 128, 256], 
                      activation=activation, dropout_p=dropout_p, normalization=normalization)
    
    elif model_type == 'resnet':
        if expanded:
            return ResNet(in_channels=3, num_classes=10, base_channels=64*2, 
                         activation=activation, dropout_p=dropout_p, normalization=normalization)
        else:
            return ResNet(in_channels=3, num_classes=10, base_channels=64, 
                         activation=activation, dropout_p=dropout_p, normalization=normalization)
    
    elif model_type == 'vit':
        if expanded:
            return VisionTransformer(in_channels=3, num_classes=10, embed_dim=64*2, depth=2, 
                                    dropout_p=dropout_p, attn_drop_rate=dropout_p, activation=activation)
        else:
            return VisionTransformer(in_channels=3, num_classes=10, embed_dim=64, depth=2, 
                                    dropout_p=dropout_p, attn_drop_rate=dropout_p, activation=activation)
    
    else:
        raise ValueError(f"Unknown model type: {model_type}")

def run_cloning_experiment(model_type='cnn', num_epochs=5, activation='relu', normalization='batch', 
                     dropout_p=0.0, lr=0.01, weight_decay=5e-4, batch_size=128, 
                     validate_cloning_every=2, tolerance=1e-3):
    """
    Run the complete cloning experiment in a notebook environment.
    
    Args:
        model_type: Type of model ('mlp', 'cnn', 'resnet', 'vit')
        num_epochs: Number of epochs to train each phase
        activation: Activation function to use
        normalization: Normalization type
        dropout_p: Dropout probability
        lr: Learning rate
        weight_decay: Weight decay for optimizer
        batch_size: Batch size for training
        validate_cloning_every: Check cloning property every N epochs
        tolerance: Tolerance for activation similarity check
    
    Returns:
        dict: Results dictionary containing training metrics and cloning validation results
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load CIFAR-10
    trainloader, testloader, classes = load_cifar10(batch_size)
    
    # Set up criterion
    criterion = nn.CrossEntropyLoss()
    
    # Results tracking
    results = {
        'base_model': {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [], 'epoch_times': []},
        'cloned_model': {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [], 'epoch_times': []},
        'scratch_model': {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [], 'epoch_times': []}
    }
    
    # 1. Train base model
    print(f"\n{'='*20} Training base {model_type.upper()} model {'='*20}")
    base_model = create_model(model_type, expanded=False, activation=activation, 
                            normalization=normalization, dropout_p=dropout_p)
    
    # Wrap MLP model to handle input reshaping
    needs_flatten = model_type == 'mlp'
    base_model = ModelWrapper(base_model, flatten=needs_flatten).to(device)
    
    optimizer = optim.SGD(base_model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        start_time = time.time()
        
        train_loss, train_acc = train_epoch(base_model, trainloader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(base_model, testloader, criterion, device)
        
        epoch_time = time.time() - start_time
        
        results['base_model']['train_loss'].append(train_loss)
        results['base_model']['train_acc'].append(train_acc)
        results['base_model']['test_loss'].append(test_loss)
        results['base_model']['test_acc'].append(test_acc)
        results['base_model']['epoch_times'].append(epoch_time)
        
        print(f"Base model - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%, Time: {epoch_time:.2f}s")
        
        scheduler.step()
    
    # 2. Clone and continue training
    print(f"\n{'='*20} Training cloned {model_type.upper()} model {'='*20}")
    expanded_model = create_model(model_type, expanded=True, activation=activation, 
                                normalization=normalization, dropout_p=dropout_p)
    
    # Extract the inner model for cloning if wrapped
    inner_base_model = base_model.model if needs_flatten else base_model
    
    # Clone the model
    cloned_model_inner = model_clone(inner_base_model, expanded_model)
    cloned_model = ModelWrapper(cloned_model_inner, flatten=needs_flatten).to(device)
    
    # Create an identical model for reference to verify cloning properties
    reference_model = create_model(model_type, expanded=False, activation=activation, 
                                 normalization=normalization, dropout_p=dropout_p)
    reference_model.load_state_dict(inner_base_model.state_dict())
    reference_model = reference_model.to(device)
    
    # Validate initial cloning
    print("Validating initial cloning properties...")
    # Get a small batch for validation
    val_inputs, val_targets = next(iter(testloader))
    val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
    
    # Cloning validation results
    cloning_validation_results = {'epochs': [], 'unexplained_variance': []}
    
    # Initial validation
    if needs_flatten:
        inner_model = cloned_model.model
        val_inputs_flattened = val_inputs.view(val_inputs.size(0), -1)
        try:
            unexplained_var = test_activation_cloning(
                reference_model, inner_model, val_inputs_flattened, val_targets, 
                tolerance=tolerance, check_equality=False, 
            )
            cloning_validation_results['epochs'].append(0)
            cloning_validation_results['unexplained_variance'].append(unexplained_var)
            print(f"Initial cloning validation passed! Unexplained variance: {unexplained_var:.6f}")
        except Exception as e:
            print(f"Initial cloning validation failed: {e}")
    else:
        try:
            unexplained_var = test_activation_cloning(
                reference_model, cloned_model, val_inputs, val_targets, 
                tolerance=tolerance, check_equality=False, 
            )
            cloning_validation_results['epochs'].append(0)
            cloning_validation_results['unexplained_variance'].append(unexplained_var)
            print(f"Initial cloning validation passed! Unexplained variance: {unexplained_var:.6f}")
        except Exception as e:
            print(f"Initial cloning validation failed: {e}")
    
    optimizer = optim.SGD(cloned_model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        start_time = time.time()
        
        train_loss, train_acc = train_epoch(cloned_model, trainloader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(cloned_model, testloader, criterion, device)
        
        epoch_time = time.time() - start_time
        
        results['cloned_model']['train_loss'].append(train_loss)
        results['cloned_model']['train_acc'].append(train_acc)
        results['cloned_model']['test_loss'].append(test_loss)
        results['cloned_model']['test_acc'].append(test_acc)
        results['cloned_model']['epoch_times'].append(epoch_time)
        
        print(f"Cloned model - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%, Time: {epoch_time:.2f}s")
        
        # Periodically validate cloning properties
        if (epoch + 1) % validate_cloning_every == 0 or epoch == num_epochs - 1:
            print(f"Validating cloning properties after epoch {epoch+1}...")
            # Get a fresh batch for validation
            val_inputs, val_targets = next(iter(testloader))
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            
            try:
                # Need to handle MLP differently due to input flattening
                if needs_flatten:
                    inner_model = cloned_model.model
                    val_inputs_flattened = val_inputs.view(val_inputs.size(0), -1)
                    unexplained_vars = test_activation_cloning(
                        reference_model, inner_model, val_inputs_flattened, val_targets, 
                        tolerance=tolerance, check_equality=False, 
                    )
                else:
                    unexplained_vars = test_activation_cloning(
                        reference_model, cloned_model, val_inputs, val_targets, 
                        tolerance=tolerance, check_equality=False, 
                    )
                
                cloning_validation_results['epochs'].append(epoch + 1)
                cloning_validation_results['unexplained_variance'].append(unexplained_vars)
                print(f"Cloning validation passed! Unexplained variance: {sum(unexplained_vars)/len(unexplained_vars):.6f}")
            except Exception as e:
                print(f"Cloning validation failed: {e}")
        
        scheduler.step()
    
    # Save cloned model and cloning validation results
    cloned_model_path = os.path.join(save_path, f"{model_type}_cloned_model.pth")
    torch.save(cloned_model.state_dict(), cloned_model_path)
    print(f"Cloned model saved to {cloned_model_path}")
    
    # Save cloning validation results
    cloning_results_path = os.path.join(save_path, f"{model_type}_cloning_validation.pth")
    torch.save(cloning_validation_results, cloning_results_path)
    
    # Plot cloning validation results
    if cloning_validation_results['epochs']:
        plt.figure(figsize=(10, 6))
        epochs = cloning_validation_results['epochs']
        unexplained_var = cloning_validation_results['unexplained_variance']
        
        # Filter out None values if any validation failed
        valid_points = [(e, v) for e, v in zip(epochs, unexplained_var) if v is not None]
        if valid_points:
            valid_epochs, valid_vars = zip(*valid_points)
            plt.plot(valid_epochs, valid_vars, 'b-o', label='Unexplained Variance')
            plt.axhline(y=tolerance, color='r', linestyle='--', label=f'Tolerance ({tolerance})')
            plt.title(f'Cloning Property Validation - {model_type.upper()}')
            plt.xlabel('Epoch')
            plt.ylabel('Unexplained Variance')
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend()
            plt.savefig(os.path.join(save_path, f"{model_type}_cloning_validation.png"))
            plt.close()
    
    # 3. Train from scratch for 2*num_epochs
    print(f"\n{'='*20} Training expanded {model_type.upper()} model from scratch {'='*20}")
    scratch_model_inner = create_model(model_type, expanded=True, activation=activation, 
                                     normalization=normalization, dropout_p=dropout_p)
    scratch_model = ModelWrapper(scratch_model_inner, flatten=needs_flatten).to(device)
    
    optimizer = optim.SGD(scratch_model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2*num_epochs)
    
    for epoch in range(2*num_epochs):
        print(f"\nEpoch {epoch+1}/{2*num_epochs}")
        start_time = time.time()
        
        train_loss, train_acc = train_epoch(scratch_model, trainloader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(scratch_model, testloader, criterion, device)
        
        epoch_time = time.time() - start_time
        
        results['scratch_model']['train_loss'].append(train_loss)
        results['scratch_model']['train_acc'].append(train_acc)
        results['scratch_model']['test_loss'].append(test_loss)
        results['scratch_model']['test_acc'].append(test_acc)
        results['scratch_model']['epoch_times'].append(epoch_time)
        
        print(f"Scratch model - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%, Time: {epoch_time:.2f}s")
        
        scheduler.step()
    
    # Save scratch model
    scratch_model_path = os.path.join(save_path, f"{model_type}_scratch_model.pth")
    torch.save(scratch_model.state_dict(), scratch_model_path)
    print(f"Scratch model saved to {scratch_model_path}")
    
    # Save results
    results_path = os.path.join(save_path, f"{model_type}_results.pth")
    torch.save(results, results_path)
    
    # Plot and save results
    plot_results(results, model_type, num_epochs, save_path)
    
    return results

def plot_results(results, model_type, num_epochs):
    """Plot training and testing curves for all models."""
    plt.figure(figsize=(18, 15))
    
    # Set up x-axis for each model
    epochs_base = list(range(1, num_epochs + 1))
    epochs_cloned = list(range(num_epochs + 1, 2 * num_epochs + 1))
    epochs_scratch = list(range(1, 2 * num_epochs + 1))
    
    # Combined epochs for full timeline view
    epochs_combined = list(range(1, 2 * num_epochs + 1))
    
    # Plot train loss
    plt.subplot(3, 2, 1)
    plt.plot(epochs_base, results['base_model']['train_loss'], 'b-', label='Base Model')
    plt.plot(epochs_cloned, results['cloned_model']['train_loss'], 'r-', label='Cloned Model')
    plt.plot(epochs_scratch, results['scratch_model']['train_loss'], 'g-', label='From Scratch')
    plt.title(f'Training Loss - {model_type.upper()}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    # Plot train accuracy
    plt.subplot(3, 2, 2)
    plt.plot(epochs_base, results['base_model']['train_acc'], 'b-', label='Base Model')
    plt.plot(epochs_cloned, results['cloned_model']['train_acc'], 'r-', label='Cloned Model')
    plt.plot(epochs_scratch, results['scratch_model']['train_acc'], 'g-', label='From Scratch')
    plt.title(f'Training Accuracy - {model_type.upper()}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    # Plot test loss
    plt.subplot(3, 2, 3)
    plt.plot(epochs_base, results['base_model']['test_loss'], 'b-', label='Base Model')
    plt.plot(epochs_cloned, results['cloned_model']['test_loss'], 'r-', label='Cloned Model')
    plt.plot(epochs_scratch, results['scratch_model']['test_loss'], 'g-', label='From Scratch')
    plt.title(f'Test Loss - {model_type.upper()}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    # Plot test accuracy
    plt.subplot(3, 2, 4)
    plt.plot(epochs_base, results['base_model']['test_acc'], 'b-', label='Base Model')
    plt.plot(epochs_cloned, results['cloned_model']['test_acc'], 'r-', label='Cloned Model')
    plt.plot(epochs_scratch, results['scratch_model']['test_acc'], 'g-', label='From Scratch')
    plt.title(f'Test Accuracy - {model_type.upper()}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    # Plot epoch times
    plt.subplot(3, 2, 5)
    plt.plot(epochs_base, results['base_model']['epoch_times'], 'b-', label='Base Model')
    plt.plot(epochs_cloned, results['cloned_model']['epoch_times'], 'r-', label='Cloned Model')
    plt.plot(epochs_scratch, results['scratch_model']['epoch_times'], 'g-', label='From Scratch')
    plt.title(f'Epoch Training Time - {model_type.upper()}')
    plt.xlabel('Epoch')
    plt.ylabel('Time (seconds)')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    # Combined test accuracy plot (base → cloned vs. scratch)
    plt.subplot(3, 2, 6)
    # Combine base and cloned for continuous line
    combined_acc = results['base_model']['test_acc'] + results['cloned_model']['test_acc']
    plt.plot(epochs_combined, combined_acc, 'b-', label='Base → Cloned')
    plt.plot(epochs_scratch, results['scratch_model']['test_acc'], 'g-', label='From Scratch')
    plt.axvline(x=num_epochs, color='r', linestyle='--', label='Cloning Point')
    plt.title(f'Test Accuracy Comparison - {model_type.upper()}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Plot cloning validation results if available
    if 'cloning_validation' in results and results['cloning_validation']['epochs']:
        plt.figure(figsize=(10, 6))
        epochs = results['cloning_validation']['epochs']
        unexplained_var = results['cloning_validation']['unexplained_variance']
        
        plt.plot(epochs, unexplained_var, 'b-o', label='Unexplained Variance')
        plt.title(f'Cloning Property Validation - {model_type.upper()}')
        plt.xlabel('Epoch')
        plt.ylabel('Unexplained Variance')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend()
        plt.show()
    
    # Print summary
    print("\nExperiment Summary:")
    print(f"{'Model Type':15} {'Final Acc':12} {'Best Acc':12} {'Final Loss':12} {'Avg Time/Epoch':15}")
    print("-" * 70)
    
    final_results = {
        'Base Model': {
            'Final Test Acc': results['base_model']['test_acc'][-1],
            'Best Test Acc': max(results['base_model']['test_acc']),
            'Final Test Loss': results['base_model']['test_loss'][-1],
            'Avg Epoch Time': np.mean(results['base_model']['epoch_times'])
        },
        'Cloned Model': {
            'Final Test Acc': results['cloned_model']['test_acc'][-1],
            'Best Test Acc': max(results['cloned_model']['test_acc']),
            'Final Test Loss': results['cloned_model']['test_loss'][-1],
            'Avg Epoch Time': np.mean(results['cloned_model']['epoch_times'])
        },
        'Scratch Model': {
            'Final Test Acc': results['scratch_model']['test_acc'][-1],
            'Best Test Acc': max(results['scratch_model']['test_acc']),
            'Final Test Loss': results['scratch_model']['test_loss'][-1],
            'Avg Epoch Time': np.mean(results['scratch_model']['epoch_times'])
        }
    }
    
    for model, metrics in final_results.items():
        print(f"{model:15} {metrics['Final Test Acc']:12.2f} {metrics['Best Test Acc']:12.2f} "
              f"{metrics['Final Test Loss']:12.4f} {metrics['Avg Epoch Time']:15.2f}")


/home/amir/Codes/NN-dynamic-scaling already in Python path


In [None]:
# Example usage in a notebook
results = run_cloning_experiment(
    model_type='mlp',  # Options: 'mlp', 'cnn', 'resnet', 'vit'
    num_epochs=5,
    activation='relu',
    normalization='batch',
    dropout_p=0.0,
    validate_cloning_every=1,
    tolerance=1e-3
)

Using device: cuda
Files already downloaded and verified


Files already downloaded and verified


Epoch 1/5


Training: 100%|██████████| 391/391 [00:07<00:00, 54.38it/s, loss=1.77, acc=36.2]


Base model - Train Loss: 1.7612, Train Acc: 36.22%, Test Loss: 1.5666, Test Acc: 43.66%, Time: 9.44s

Epoch 2/5


Training: 100%|██████████| 391/391 [00:07<00:00, 54.27it/s, loss=1.62, acc=42.5]


Base model - Train Loss: 1.5917, Train Acc: 42.46%, Test Loss: 1.4822, Test Acc: 46.78%, Time: 9.50s

Epoch 3/5


Training: 100%|██████████| 391/391 [00:07<00:00, 54.59it/s, loss=1.54, acc=45.4]


Base model - Train Loss: 1.5153, Train Acc: 45.41%, Test Loss: 1.4124, Test Acc: 49.16%, Time: 9.46s

Epoch 4/5


Training: 100%|██████████| 391/391 [00:07<00:00, 54.07it/s, loss=1.45, acc=48.1]


Base model - Train Loss: 1.4479, Train Acc: 48.13%, Test Loss: 1.3661, Test Acc: 50.86%, Time: 9.51s

Epoch 5/5


Training:  98%|█████████▊| 385/391 [00:06<00:00, 64.82it/s, loss=1.43, acc=49.4]