In [68]:
import torch
import torch.nn as nn
# Import the utility module and Setup the path
import notebook_utils
notebook_utils.setup_path()


import torch
import torch.nn as nn

from typing import Union, List, Dict, Set

# Type alias for the normalization layers handled by clone_normalization and
# validate_normalization_cloning. A Union (rather than a TypeVar) is used so the
# alias can annotate both arguments and return values directly.
NormalizationLayer = Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm]

# Type alias for the activation modules handled by clone_activation and
# validate_activation_cloning. clone_module keeps a tuple with the same members
# for its isinstance dispatch.
ActivationFunction = Union[nn.ReLU, nn.Sigmoid, nn.Tanh, nn.SELU, nn.GELU, nn.SiLU, nn.ELU, 
                         nn.LeakyReLU, nn.PReLU, nn.Threshold, nn.Softmax, nn.LogSoftmax, 
                         nn.Softplus, nn.Softmin, nn.Hardsigmoid, nn.Hardswish, nn.Softshrink, 
                         nn.Hardshrink, nn.Softsign, nn.GLU, nn.CELU, nn.Identity]


class CloneAwareFlatten(nn.Module):
    """
    Flatten that keeps the copies of a cloned channel adjacent in the flattened output.

    The cloning helpers in this file place the ``expansion`` copies of source unit ``k``
    at consecutive indices ``expansion * k + d``. For a ``[batch, channels, h, w]``
    feature map whose channels follow that layout (``a, a', b, b', ...``), a plain
    ``nn.Flatten`` is channel-major and yields
    ``[a(0,0), a(0,1), ..., a'(0,0), a'(0,1), ..., b(0,0), ...]``, which separates the
    copies ``a`` and ``a'`` by ``h * w`` positions.

    This module instead yields
    ``[a(0,0), a'(0,0), a(0,1), a'(0,1), ..., b(0,0), b'(0,0), ...]`` so that flattened
    feature ``expansion * k + d`` is copy ``d`` of source flattened feature ``k``.

    The clone-aware ordering is applied only when the input is 4D, the flatten range
    covers exactly dims 1..3, and the channel count is divisible by ``expansion``.
    Every other case behaves like ``torch.flatten(x, start_dim, end_dim)``.

    Args:
        start_dim: First dimension to flatten (may be negative).
        end_dim: Last dimension to flatten (may be negative).
        expansion: Number of adjacent copies per source channel. Defaults to 2.

    Raises:
        ValueError: If ``expansion`` is smaller than 1.
    """
    def __init__(self, start_dim=1, end_dim=-1, expansion=2):
        super().__init__()
        if expansion < 1:
            raise ValueError(f"expansion must be a positive integer, got {expansion}")
        self.start_dim = start_dim
        self.end_dim = end_dim
        self.expansion = expansion

    def forward(self, x):
        ndim = x.dim()
        # Normalise negative dims so that e.g. start_dim=-3 on a 4D input is
        # recognised as the channel dimension.
        start_dim = self.start_dim if self.start_dim >= 0 else ndim + self.start_dim
        end_dim = self.end_dim if self.end_dim >= 0 else ndim + self.end_dim

        clone_aware = (
            ndim == 4
            and start_dim == 1
            and end_dim == 3
            and x.shape[1] % self.expansion == 0
        )
        if not clone_aware:
            # Standard flattening that honours both start_dim and end_dim.
            return torch.flatten(x, self.start_dim, self.end_dim)

        batch_size, channels, height, width = x.shape
        groups = channels // self.expansion

        # [batch, channels, h, w] -> [batch, groups, expansion, h, w]
        grouped = x.reshape(batch_size, groups, self.expansion, height, width)

        # Move the copy axis last so the copies of one spatial position are adjacent:
        # [batch, groups, expansion, h, w] -> [batch, groups, h, w, expansion]
        interleaved = grouped.permute(0, 1, 3, 4, 2)

        # [batch, groups, h, w, expansion] -> [batch, groups * h * w * expansion]
        return interleaved.reshape(batch_size, -1)

def clone_linear(src_module: nn.Linear, cloned_module: nn.Linear):
    """
    Copy the parameters of a Linear layer into a wider Linear layer.

    Every source weight is replicated into an ``out_expansion x in_expansion`` block of
    the cloned weight (copies of a unit sit at adjacent indices) and divided by the
    input expansion factor. Biases are replicated without scaling.

    Args:
        src_module: Source layer.
        cloned_module: Target layer whose feature counts are integer multiples of the
            source's. Modified in place.

    Returns:
        The same ``cloned_module`` object.

    Raises:
        ValueError: If the target dimensions are not integer multiples of the source's.
    """
    small_in, small_out = src_module.in_features, src_module.out_features
    big_in, big_out = cloned_module.in_features, cloned_module.out_features

    # Reject targets whose sizes are not whole multiples of the source sizes.
    if big_in % small_in != 0 or big_out % small_out != 0:
        raise ValueError(f"Linear module dimensions are not integer multiples: "
                         f"{small_in}→{big_in}, {small_out}→{big_out}")

    in_factor = big_in // small_in
    out_factor = big_out // small_out

    print(f"Cloning Linear module: {small_in}→{big_in}, {small_out}→{big_out}, in expansion: {in_factor}, out expansion: {out_factor}")

    # Repeating each row/column in place produces the same layout as writing the
    # scaled source into every [j::out_factor, i::in_factor] strided slice.
    scaled = src_module.weight.data / in_factor
    tiled = scaled.repeat_interleave(out_factor, dim=0).repeat_interleave(in_factor, dim=1)
    cloned_module.weight.data.copy_(tiled)

    # Bias is only transferred when both layers carry one; it is not rescaled.
    if src_module.bias is not None and cloned_module.bias is not None:
        cloned_module.bias.data.copy_(src_module.bias.data.repeat_interleave(out_factor))
    return cloned_module


def clone_conv1d(src_module: nn.Conv1d, cloned_module: nn.Conv1d):
    """
    Copy the parameters of a Conv1d layer into a wider Conv1d layer.

    The source kernel is written into every ``[j::out_expansion, i::in_expansion]``
    strided slice of the cloned weight, divided by the input expansion factor.
    Biases are replicated without scaling.

    Args:
        src_module: Source layer.
        cloned_module: Target layer whose channel counts are integer multiples of the
            source's and whose kernel size matches. Modified in place.

    Returns:
        The same ``cloned_module`` object.

    Raises:
        ValueError: If the channel counts are not integer multiples of the source's,
            or if the kernel sizes differ.
    """
    # Get module dimensions
    src_in_channels = src_module.in_channels
    src_out_channels = src_module.out_channels
    cloned_in_channels = cloned_module.in_channels
    cloned_out_channels = cloned_module.out_channels

    # Validate before deriving or printing expansion factors, so no misleading
    # diagnostics are emitted for an invalid pair (same order as clone_linear).
    if cloned_in_channels % src_in_channels != 0 or cloned_out_channels % src_out_channels != 0:
        raise ValueError(f"Conv1d module dimensions are not integer multiples: "
                         f"{src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}")

    # A size-1 source kernel would otherwise be silently broadcast across a larger
    # cloned kernel by the slice assignment below.
    if src_module.kernel_size != cloned_module.kernel_size:
        raise ValueError(f"Conv1d kernel sizes do not match: "
                         f"{src_module.kernel_size} vs {cloned_module.kernel_size}")

    # Calculate expansion factors
    in_expansion = cloned_in_channels // src_in_channels
    out_expansion = cloned_out_channels // src_out_channels

    print(f"Cloning Conv1d module: {src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}, in expansion: {in_expansion}, out expansion: {out_expansion}")

    # Each output sums over in_expansion copies of every input channel, so the
    # weights are divided by in_expansion. The scaled tensor is computed once.
    scaled_weight = src_module.weight.data / in_expansion
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion, :] = scaled_weight

    # Clone the bias if present (no scaling needed for bias)
    if src_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = src_module.bias.data
    return cloned_module

    
def clone_conv2d(src_module: nn.Conv2d, cloned_module: nn.Conv2d):
    """
    Copy the parameters of a Conv2d layer into a wider Conv2d layer.

    The source kernel is written into every ``[j::out_expansion, i::in_expansion]``
    strided slice of the cloned weight, divided by the input expansion factor.
    Biases are replicated without scaling.

    Args:
        src_module: Source layer.
        cloned_module: Target layer whose channel counts are integer multiples of the
            source's and whose kernel size matches. Modified in place.

    Returns:
        The same ``cloned_module`` object.

    Raises:
        ValueError: If the channel counts are not integer multiples of the source's,
            or if the kernel sizes differ.
    """
    # Get module dimensions
    src_in_channels = src_module.in_channels
    src_out_channels = src_module.out_channels
    cloned_in_channels = cloned_module.in_channels
    cloned_out_channels = cloned_module.out_channels

    # Validate before deriving or printing expansion factors, so no misleading
    # diagnostics are emitted for an invalid pair (same order as clone_linear).
    if cloned_in_channels % src_in_channels != 0 or cloned_out_channels % src_out_channels != 0:
        raise ValueError(f"Conv2d module dimensions are not integer multiples: "
                         f"{src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}")

    # A 1x1 source kernel would otherwise be silently broadcast across a larger
    # cloned kernel by the slice assignment below.
    if src_module.kernel_size != cloned_module.kernel_size:
        raise ValueError(f"Conv2d kernel sizes do not match: "
                         f"{src_module.kernel_size} vs {cloned_module.kernel_size}")

    # Calculate expansion factors
    in_expansion = cloned_in_channels // src_in_channels
    out_expansion = cloned_out_channels // src_out_channels

    print(f"Cloning Conv2d module: {src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}, in expansion: {in_expansion}, out expansion: {out_expansion}")

    # Each output sums over in_expansion copies of every input channel, so the
    # weights are divided by in_expansion. The scaled tensor is computed once.
    scaled_weight = src_module.weight.data / in_expansion
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion, :, :] = scaled_weight

    # Clone the bias if present (no scaling needed for bias)
    if src_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = src_module.bias.data
    return cloned_module
    

def clone_normalization(
    src_module: NormalizationLayer, 
    cloned_module: NormalizationLayer,
) -> NormalizationLayer:
    """
    Replicate a normalization layer's parameters and buffers into a wider layer.

    Affine weight/bias and (when present) running mean/variance are written into
    every ``[i::expansion]`` strided slice of the target, so copies of a feature
    sit at adjacent indices. ``num_batches_tracked`` is copied as-is.

    Args:
        src_module: Source normalization layer.
        cloned_module: Target layer of the same type. Modified in place.

    Returns:
        The same ``cloned_module`` object.
    """
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"

    # Configuration flags are compared only when both layers expose them.
    both_have_affine = hasattr(src_module, 'affine') and hasattr(cloned_module, 'affine')
    if both_have_affine:
        assert src_module.affine == cloned_module.affine, "Affine property must match"

    if isinstance(src_module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        both_track = hasattr(src_module, 'track_running_stats') and hasattr(cloned_module, 'track_running_stats')
        if both_track:
            assert src_module.track_running_stats == cloned_module.track_running_stats, "Track running stats property must match"

    # Affine parameters: the bias is only handled when a weight is present.
    has_weights = (
        hasattr(src_module, 'weight')
        and src_module.weight is not None
        and cloned_module.weight is not None
    )
    if has_weights:
        factor = cloned_module.weight.data.shape[0] // src_module.weight.data.shape[0]
        for offset in range(factor):
            cloned_module.weight.data[offset::factor] = src_module.weight.data
            copy_bias = (
                hasattr(src_module, 'bias')
                and src_module.bias is not None
                and cloned_module.bias is not None
            )
            if copy_bias:
                cloned_module.bias.data[offset::factor] = src_module.bias.data

    # Running statistics (BatchNorm layers that track them).
    src_has_stats = hasattr(src_module, 'running_mean') and src_module.running_mean is not None
    if src_has_stats and hasattr(cloned_module, 'running_mean') and cloned_module.running_mean is not None:
        factor = cloned_module.running_mean.data.shape[0] // src_module.running_mean.data.shape[0]
        for offset in range(factor):
            cloned_module.running_mean.data[offset::factor] = src_module.running_mean.data
            cloned_module.running_var.data[offset::factor] = src_module.running_var.data

    # Batch counter (BatchNorm layers that track it).
    src_has_counter = hasattr(src_module, 'num_batches_tracked') and src_module.num_batches_tracked is not None
    if src_has_counter and hasattr(cloned_module, 'num_batches_tracked') and cloned_module.num_batches_tracked is not None:
        cloned_module.num_batches_tracked.data.copy_(src_module.num_batches_tracked.data)

    return cloned_module
    
    
def clone_embedding(src_module: nn.Embedding, cloned_module: nn.Embedding):
    """
    Copy an Embedding table into a larger Embedding table.

    The embedding weight has shape ``[num_embeddings, embedding_dim]``. The source
    table is written into every ``[i::num_expansion, j::dim_expansion]`` strided
    slice of the cloned table, matching the layout checked by
    ``validate_embedding_cloning``. No scaling is applied.

    Args:
        src_module: Source embedding.
        cloned_module: Target embedding whose ``num_embeddings`` and
            ``embedding_dim`` are integer multiples of the source's. Modified in place.

    Returns:
        The same ``cloned_module`` object.

    Raises:
        ValueError: If the target dimensions are not integer multiples of the source's.
    """
    # Get module dimensions
    src_num_embeddings = src_module.num_embeddings
    src_embedding_dim = src_module.embedding_dim
    cloned_num_embeddings = cloned_module.num_embeddings
    cloned_embedding_dim = cloned_module.embedding_dim

    # Validate before deriving or printing expansion factors.
    if cloned_num_embeddings % src_num_embeddings != 0 or cloned_embedding_dim % src_embedding_dim != 0:
        raise ValueError(f"Embedding module dimensions are not integer multiples: "
                         f"{src_num_embeddings}→{cloned_num_embeddings}, {src_embedding_dim}→{cloned_embedding_dim}")

    # Calculate expansion factors
    num_expansion = cloned_num_embeddings // src_num_embeddings
    dim_expansion = cloned_embedding_dim // src_embedding_dim

    print(f"Cloning Embedding module: {src_num_embeddings}→{cloned_num_embeddings}, {src_embedding_dim}→{cloned_embedding_dim}, num expansion: {num_expansion}, dim expansion: {dim_expansion}")

    # Rows (dim 0) index embeddings and columns (dim 1) index embedding features,
    # so the row stride is num_expansion and the column stride is dim_expansion.
    for i in range(num_expansion):
        for j in range(dim_expansion):
            cloned_module.weight.data[i::num_expansion, j::dim_expansion] = src_module.weight.data

    return cloned_module


def clone_activation(src_module: ActivationFunction, cloned_module: ActivationFunction) -> ActivationFunction:
    """
    Transfer an activation module's configuration (and PReLU weights) to its clone.

    Args:
        src_module: Source activation module.
        cloned_module: Target activation module of the same type. Modified in place.

    Returns:
        The same ``cloned_module`` object.
    """
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"

    # Plain configuration attributes, checked in order; the first matching entry wins.
    config_attrs = (
        (nn.LeakyReLU, ("negative_slope",)),
        ((nn.ELU, nn.CELU), ("alpha",)),
        (nn.Threshold, ("threshold", "value")),
        ((nn.Softmax, nn.LogSoftmax), ("dim",)),
        ((nn.Hardshrink, nn.Softshrink), ("lambd",)),
        (nn.GLU, ("dim",)),
    )
    for kinds, attr_names in config_attrs:
        if isinstance(src_module, kinds):
            for attr in attr_names:
                setattr(cloned_module, attr, getattr(src_module, attr))
            return cloned_module

    # PReLU carries learnable slopes that may need broadcasting or tiling.
    if isinstance(src_module, nn.PReLU):
        n_src = src_module.num_parameters
        n_dst = cloned_module.num_parameters
        if n_dst > 1 and n_src == 1:
            # A single shared slope is broadcast to every target channel.
            cloned_module.weight.data.fill_(src_module.weight.data[0])
        elif n_dst > 1 and n_src > 1:
            # Per-channel slopes are written into every strided slice.
            factor = n_dst // n_src
            for offset in range(factor):
                cloned_module.weight.data[offset::factor] = src_module.weight.data
        else:
            # Remaining cases are copied directly.
            cloned_module.weight.data.copy_(src_module.weight.data)
        return cloned_module

    # Fallback for any other activation exposing a weight: copy only on equal shapes.
    src_weight = getattr(src_module, 'weight', None)
    dst_weight = getattr(cloned_module, 'weight', None)
    if src_weight is not None and dst_weight is not None:
        if dst_weight.data.shape == src_weight.data.shape:
            dst_weight.data.copy_(src_weight.data)

    return cloned_module


def clone_dropout(src_module: nn.Dropout, cloned_module: nn.Dropout):
    """
    Check that two dropout modules share the same probability.

    Nothing is copied. A warning is printed when the probability is positive.

    Returns:
        The same ``cloned_module`` object.
    """
    drop_prob = cloned_module.p
    assert drop_prob == src_module.p, "Dropout probability must match"
    if drop_prob > 0:
        # Active dropout is flagged to the user as making the clone imperfect.
        print(f"Warning: Dropout probability is set to {drop_prob}, cloning is not perfect")
    return cloned_module

def clone_flatten(src_module: nn.Flatten) -> CloneAwareFlatten:
    """
    Build a CloneAwareFlatten that mirrors a standard Flatten's dimension range.

    Args:
        src_module: Source nn.Flatten module

    Returns:
        A new CloneAwareFlatten module with the same start_dim and end_dim
    """
    flatten_range = {
        "start_dim": src_module.start_dim,
        "end_dim": src_module.end_dim,
    }
    return CloneAwareFlatten(**flatten_range)


def is_parameter_free(module: nn.Module) -> bool:
    """Return True when the module (including its submodules) has no parameters."""
    return not any(True for _ in module.parameters())


def clone_parameter_free(src_module: nn.Module, cloned_module: nn.Module) -> nn.Module:
    """
    Check that two modules share a type and carry no parameters.

    Nothing is copied, because a parameter-free module has no weights to transfer.

    Returns:
        The same ``cloned_module`` object.
    """
    expected_type = type(src_module)
    assert isinstance(cloned_module, expected_type), "Cloned module must be of the same type as source module"

    requirements = (
        (src_module, "Source module must be parameter free"),
        (cloned_module, "Cloned module must be parameter free"),
    )
    for module, message in requirements:
        assert is_parameter_free(module), message

    return cloned_module


# Validation functions

def validate_activation_cloning(src_module: ActivationFunction, cloned_module: ActivationFunction):
    """
    Assert that an activation module's configuration was carried over to its clone.

    Args:
        src_module: Source activation module.
        cloned_module: Cloned activation module of the same type.

    Returns:
        True when every check passes; a failing check raises AssertionError.
    """
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"

    # (module types, ((attribute, failure message), ...)); the first matching entry wins.
    config_checks = (
        (nn.LeakyReLU, (("negative_slope", "LeakyReLU negative_slope does not match"),)),
        ((nn.ELU, nn.CELU), (("alpha", "Alpha parameter does not match"),)),
        (nn.Threshold, (("threshold", "Threshold value does not match"),
                        ("value", "Replacement value does not match"))),
        ((nn.Softmax, nn.LogSoftmax), (("dim", "Dimension parameter does not match"),)),
        ((nn.Hardshrink, nn.Softshrink), (("lambd", "Lambda parameter does not match"),)),
        (nn.GLU, (("dim", "Dimension parameter does not match"),)),
    )
    for kinds, checks in config_checks:
        if isinstance(src_module, kinds):
            for attr, message in checks:
                assert getattr(src_module, attr) == getattr(cloned_module, attr), message
            break
    else:
        # No configuration entry matched: PReLU is the only weighted case checked.
        if isinstance(src_module, nn.PReLU):
            n_src = src_module.num_parameters
            n_dst = cloned_module.num_parameters
            if n_dst > 1 and n_src == 1:
                # Every cloned slope must equal the single source slope.
                assert torch.all(cloned_module.weight.data == src_module.weight.data[0])
            elif n_dst > 1 and n_src > 1:
                factor = n_dst // n_src
                for offset in range(factor):
                    assert torch.allclose(cloned_module.weight.data[offset::factor], src_module.weight.data)

    print("Passed all tests")
    return True


def validate_dropout_cloning(src_module: nn.Dropout, cloned_module: nn.Dropout):
    """Assert both dropout modules use the same probability; return True on success."""
    probabilities_match = cloned_module.p == src_module.p
    assert probabilities_match, "Dropout probability must match"
    print("Passed all tests")
    return True


def validate_embedding_cloning(src_module: nn.Embedding, cloned_module: nn.Embedding):
    """
    Assert that every strided slice of the cloned table equals the source table.

    Rows are strided by the ``num_embeddings`` ratio and columns by the
    ``embedding_dim`` ratio.

    Returns:
        True when every slice matches; a mismatch raises AssertionError.
    """
    row_factor = cloned_module.num_embeddings // src_module.num_embeddings
    col_factor = cloned_module.embedding_dim // src_module.embedding_dim
    reference = src_module.weight.data
    for row_offset in range(row_factor):
        rows = cloned_module.weight.data[row_offset::row_factor]
        for col_offset in range(col_factor):
            assert torch.allclose(rows[:, col_offset::col_factor], reference)
    print("Passed all tests")
    return True


def validate_normalization_cloning(src_module: NormalizationLayer, cloned_module: NormalizationLayer):
    """
    Assert that a normalization layer's parameters and running stats were cloned.

    For each of weight, bias, running_mean and running_var that the source carries,
    every ``[i::expansion]`` strided slice of the cloned tensor must be close to the
    source tensor. A tensor present on the source but missing (``None``) on the clone
    fails with an AssertionError instead of an AttributeError.

    Args:
        src_module: Source normalization layer.
        cloned_module: Cloned normalization layer of the same type.

    Returns:
        True when every check passes (consistent with the other validators that
        return True); a failing check raises AssertionError.
    """
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"

    src_weight = getattr(src_module, 'weight', None)
    if src_weight is not None:
        # hasattr() is true even when the attribute is None (e.g. affine=False),
        # so presence on the clone is checked against None explicitly.
        cloned_weight = getattr(cloned_module, 'weight', None)
        assert cloned_weight is not None, "Cloned module has no weight but source module does"
        expansion = cloned_weight.data.shape[0] // src_weight.data.shape[0]

        src_bias = getattr(src_module, 'bias', None)
        cloned_bias = getattr(cloned_module, 'bias', None)
        if src_bias is not None:
            assert cloned_bias is not None, "Cloned module has no bias but source module does"

        for i in range(expansion):
            assert torch.allclose(cloned_weight.data[i::expansion], src_weight.data)
            if src_bias is not None:
                assert torch.allclose(cloned_bias.data[i::expansion], src_bias.data)

    # Check running stats for BatchNorm layers; skipped when either side does not
    # track them.
    src_mean = getattr(src_module, 'running_mean', None)
    cloned_mean = getattr(cloned_module, 'running_mean', None)
    if src_mean is not None and cloned_mean is not None:
        src_var = getattr(src_module, 'running_var', None)
        cloned_var = getattr(cloned_module, 'running_var', None)
        expansion = cloned_mean.data.shape[0] // src_mean.data.shape[0]
        for i in range(expansion):
            assert torch.allclose(cloned_mean.data[i::expansion], src_mean.data)
            if src_var is not None:
                assert cloned_var is not None, "Cloned module has no running_var but source module does"
                assert torch.allclose(cloned_var.data[i::expansion], src_var.data)

    print("Passed all tests")
    return True
    

def validate_linear_cloning(src_module: nn.Linear, cloned_module: nn.Linear):
    """
    Assert that a Linear layer was cloned with the expected scaled, strided layout.

    Every ``[j::out_expansion, i::in_expansion]`` slice of the cloned weight must be
    close to ``src_module.weight / in_expansion``, and every ``[j::out_expansion]``
    slice of the cloned bias must be close to the source bias. The bias check is
    skipped when the source has no bias.

    Args:
        src_module: Source layer.
        cloned_module: Cloned (wider) layer.

    Returns:
        True when every check passes (consistent with the other validators that
        return True); a failing check raises AssertionError.
    """
    in_expansion = cloned_module.in_features // src_module.in_features
    out_expansion = cloned_module.out_features // src_module.out_features

    expected_weight = src_module.weight.data / in_expansion
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion], expected_weight)

    # Layers built with bias=False expose bias as None; dereferencing .data on it
    # would raise AttributeError, so the bias is only compared when present.
    if src_module.bias is not None:
        assert cloned_module.bias is not None, "Cloned module has no bias but source module does"
        for j in range(out_expansion):
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], src_module.bias.data)

    print("Passed all tests")
    return True
    
    
def validate_conv1d_cloning(src_module: nn.Conv1d, cloned_module: nn.Conv1d):
    """
    Assert that a Conv1d layer was cloned with the expected scaled, strided layout.

    Every ``[j::out_expansion, i::in_expansion, :]`` slice of the cloned weight must
    be close to ``src_module.weight / in_expansion``, and every ``[j::out_expansion]``
    slice of the cloned bias must be close to the source bias. The bias check is
    skipped when the source has no bias.

    Args:
        src_module: Source layer.
        cloned_module: Cloned (wider) layer.

    Returns:
        True when every check passes (consistent with the other validators that
        return True); a failing check raises AssertionError.
    """
    in_expansion = cloned_module.in_channels // src_module.in_channels
    out_expansion = cloned_module.out_channels // src_module.out_channels

    expected_weight = src_module.weight.data / in_expansion
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion, :], expected_weight)

    # Layers built with bias=False expose bias as None; dereferencing .data on it
    # would raise AttributeError, so the bias is only compared when present.
    if src_module.bias is not None:
        assert cloned_module.bias is not None, "Cloned module has no bias but source module does"
        for j in range(out_expansion):
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], src_module.bias.data)

    print("Passed all tests")
    return True
    

def validate_conv2d_cloning(src_module: nn.Conv2d, cloned_module: nn.Conv2d):
    """
    Assert that a Conv2d layer was cloned with the expected scaled, strided layout.

    Every ``[j::out_expansion, i::in_expansion, :, :]`` slice of the cloned weight
    must be close to ``src_module.weight / in_expansion``, and every
    ``[j::out_expansion]`` slice of the cloned bias must be close to the source bias.
    The bias check is skipped when the source has no bias.

    Args:
        src_module: Source layer.
        cloned_module: Cloned (wider) layer.

    Returns:
        True when every check passes (consistent with the other validators that
        return True); a failing check raises AssertionError.
    """
    in_expansion = cloned_module.in_channels // src_module.in_channels
    out_expansion = cloned_module.out_channels // src_module.out_channels

    expected_weight = src_module.weight.data / in_expansion
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion, :, :], expected_weight)

    # Layers built with bias=False expose bias as None; dereferencing .data on it
    # would raise AttributeError, so the bias is only compared when present.
    if src_module.bias is not None:
        assert cloned_module.bias is not None, "Cloned module has no bias but source module does"
        for j in range(out_expansion):
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], src_module.bias.data)

    print("Passed all tests")
    return True



def clone_module(
    src_module: nn.Module, 
    cloned_module: nn.Module,
) -> bool:
    """
    Dispatch to the cloning routine that matches the source module's type.

    Args:
        src_module: Source module with smaller dimensions
        cloned_module: Target module with larger dimensions

    Returns:
        bool: True if a handler (or the Flatten / parameter-free pass-through)
        applied, False if the module type is unsupported
    """
    normalization_kinds = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm)
    activation_kinds = (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.SELU, nn.GELU, nn.SiLU, nn.ELU, nn.LeakyReLU,
                        nn.PReLU, nn.Threshold, nn.Softmax, nn.LogSoftmax, nn.Softplus, nn.Softmin,
                        nn.Hardsigmoid, nn.Hardswish, nn.Softshrink, nn.Hardshrink, nn.Softsign,
                        nn.GLU, nn.CELU, nn.Identity)

    # Checked in order; the first entry whose types match the source module is used.
    handlers = (
        (nn.Linear, clone_linear),
        (nn.Conv1d, clone_conv1d),
        (nn.Conv2d, clone_conv2d),
        (normalization_kinds, clone_normalization),
        (nn.Embedding, clone_embedding),
        (activation_kinds, clone_activation),
        (nn.Dropout, clone_dropout),
    )
    for kinds, handler in handlers:
        if isinstance(src_module, kinds):
            handler(src_module, cloned_module)
            return True

    if isinstance(src_module, nn.Flatten):
        # Flatten is handled separately
        return True

    if is_parameter_free(src_module) and is_parameter_free(cloned_module):
        clone_parameter_free(src_module, cloned_module)
        return True

    print(f"Unsupported module type: {type(src_module)}")
    return False



def clone_model(src_model: nn.Module, cloned_model: nn.Module) -> nn.Module:
    """
    Clone parameters from a source model into a wider model with the same module names.

    Every nn.Flatten inside the cloned model is first swapped for a
    CloneAwareFlatten; afterwards each module of the source model is cloned into
    the module found under the same name in the cloned model.

    Args:
        src_model: Source model with smaller dimensions
        cloned_model: Target model with larger dimensions (modified in place)

    Returns:
        cloned_model: The target model with cloned parameters
    """
    # Snapshot the Flatten locations first so the module tree is not mutated
    # while it is being traversed.
    flatten_sites = [
        (path, candidate)
        for path, candidate in cloned_model.named_modules()
        if isinstance(candidate, nn.Flatten)
    ]
    for path, flatten in flatten_sites:
        # Split "a.b.c" into the owner path "a.b" and the attribute name "c".
        owner_path, _, attr_name = path.rpartition('.')
        owner = cloned_model.get_submodule(owner_path) if owner_path else cloned_model

        replacement = CloneAwareFlatten(
            start_dim=flatten.start_dim,
            end_dim=flatten.end_dim
        )
        setattr(owner, attr_name, replacement)
        print(f"Replaced Flatten with CloneAwareFlatten at {path}")

    # Walk the source model and clone each module into its same-named counterpart.
    for name, src_module in src_model.named_modules():
        counterpart = cloned_model.get_submodule(name)
        print(f"Cloning module {name}")
        clone_module(src_module, counterpart)

    return cloned_model

/home/amir/Codes/NN-dynamic-scaling already in Python path


In [5]:
# Notebook cell: clone a small MLP into one with doubled hidden widths and compare
# the outputs of both models on the same random batch.
from src.models import MLP, CNN, ResNet, VisionTransformer

# Source model and a target whose hidden sizes are exactly 2x the source's;
# input and output sizes are identical.
src_model = MLP(input_size=10, output_size=2, hidden_sizes=[64, 32,], activation="relu", dropout_p=0.0)
cloned_model = MLP(input_size=10, output_size=2, hidden_sizes=[64*2, 32*2,], activation="relu", dropout_p=0.0)
cloned_model = clone_model(src_model, cloned_model)

# Same input batch through both models.
x = torch.randn(32, 10)
y1 = src_model(x)
y2 = cloned_model(x)
print("Output from source model:", y1)
print("Output from cloned model:", y2)

# Element-wise difference, shown as the cell's result.
y1-y2


Cloning module 
Unsupported module type: <class 'src.models.mlp.MLP'>
Cloning module layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.linear_0
Cloning Linear module: 10→10, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.act_0
Cloning module layers.linear_1
Cloning Linear module: 64→128, 32→64, in expansion: 2, out expansion: 2
Cloning module layers.act_1
Cloning module layers.out
Cloning Linear module: 32→64, 2→2, in expansion: 2, out expansion: 1
Output from source model: tensor([[-0.1834,  0.2274],
        [-0.2082,  0.2723],
        [-0.1581,  0.1333],
        [-0.2022,  0.2726],
        [-0.2721,  0.2446],
        [-0.3425,  0.3018],
        [-0.1492,  0.2978],
        [-0.2177,  0.3049],
        [-0.2437,  0.1399],
        [-0.1879,  0.1989],
        [-0.2049,  0.2483],
        [-0.1018,  0.1314],
        [-0.0530,  0.1803],
        [-0.1930,  0.3193],
        [-0.0641,  0.1352],
        [-0.1829,  0.2519],
    

tensor([[-1.4901e-08, -5.9605e-08],
        [ 1.4901e-08, -2.9802e-08],
        [ 1.4901e-08, -1.4901e-08],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -2.9802e-08],
        [ 0.0000e+00, -5.9605e-08],
        [-1.4901e-08,  5.9605e-08],
        [ 2.9802e-08,  0.0000e+00],
        [ 5.9605e-08, -1.4901e-08],
        [ 2.9802e-08,  0.0000e+00],
        [-1.4901e-08,  0.0000e+00],
        [ 7.4506e-09, -2.9802e-08],
        [ 2.2352e-08,  0.0000e+00],
        [ 4.4703e-08,  0.0000e+00],
        [-1.4901e-08,  4.4703e-08],
        [ 0.0000e+00,  2.9802e-08],
        [ 0.0000e+00, -2.9802e-08],
        [-2.9802e-08,  5.9605e-08],
        [-2.9802e-08, -2.9802e-08],
        [ 2.9802e-08, -5.9605e-08],
        [-2.9802e-08, -5.9605e-08],
        [ 2.9802e-08, -7.4506e-09],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  7.4506e-09],
        [ 0.0000e+00,  2.2352e-08],
        [-4.4703e-08, -2.9802e-08],
        [-1.4901e-08, -2.9802e-08],
        [ 0.0000e+00, -1.490

In [82]:
# Notebook cell: spot-check the cloned weights of the first two MLP linear layers.
w1, w2 = src_model.layers.linear_0.weight.data, cloned_model.layers.linear_0.weight.data
b1, b2 = src_model.layers.linear_0.bias.data, cloned_model.layers.linear_0.bias.data
w1.shape, w2.shape

# linear_0: compares only the even output rows / bias entries against the source,
# without any scaling of the weights.
torch.allclose(w2[::2,:], w1),torch.allclose(b2[::2], b1)

w1, w2 = src_model.layers.linear_1.weight.data, cloned_model.layers.linear_1.weight.data
b1, b2 = src_model.layers.linear_1.bias.data, cloned_model.layers.linear_1.bias.data

# linear_1: all four stride-2 weight slices are compared with the source weights
# halved; the bias is compared on the even entries only.
torch.allclose(w2[::2,::2], w1/2),torch.allclose(w2[1::2,::2], w1/2),torch.allclose(w2[::2,1::2], w1/2),torch.allclose(w2[1::2,1::2], w1/2),torch.allclose(b2[::2], b1)

(True, True, True, True, True)

In [98]:
# Notebook cell: run the same batch through both MLPs layer by layer and check
# that the final activations of the clone match the source.
h1, h2 = src_model.layers.linear_0(x), cloned_model.layers.linear_0(x)
h1, h2 = src_model.layers.act_0(h1), cloned_model.layers.act_0(h2)
h1, h2 = src_model.layers.linear_1(h1), cloned_model.layers.linear_1(h2)
h1, h2 = src_model.layers.act_1(h1), cloned_model.layers.act_1(h2)
h1, h2 = src_model.layers.out(h1), cloned_model.layers.out(h2)

# Always compare. With equal widths the expansion factor is 1 and the whole tensor
# is compared; previously the check was skipped in that case and the success
# message printed without verifying anything.
assert h2.shape[1] % h1.shape[1] == 0, "Output dimensions are not integer multiples"
expansion_factor = h2.shape[1] // h1.shape[1]
for i in range(expansion_factor):
    assert torch.allclose(h2[:,i::expansion_factor], h1, atol=1e-5), f"Output mismatch at expansion factor {i}"
print("All activations are close ")

All activations are close 


In [97]:
# Notebook cell: spot-check the cloned output layer of the MLP.
w1, w2 = src_model.layers.out.weight.data, cloned_model.layers.out.weight.data
b1, b2 = src_model.layers.out.bias.data, cloned_model.layers.out.bias.data

w1.shape, w2.shape, b1.shape, b2.shape
# Both stride-2 column slices are compared with the source weights halved;
# the bias is compared directly against the source bias.
torch.allclose(w2[:,::2],w1/2),torch.allclose(w2[:,1::2],w1/2), torch.allclose(b2, b1)

(True, True, True)

In [86]:
from src.models import MLP, CNN, ResNet, VisionTransformer

# Build a CNN and a second CNN whose conv widths are doubled, clone the first
# into the second, then compare their outputs on one random batch.
base_widths = [64, 128, 256]
cnn_cfg = dict(in_channels=3, num_classes=2, activation="relu", dropout_p=0.0)

src_model = CNN(conv_channels=list(base_widths), **cnn_cfg)
cloned_model = CNN(conv_channels=[width * 2 for width in base_widths], **cnn_cfg)

cloned_model = clone_model(src_model, cloned_model)

x = torch.randn(32, 3, 32, 32)
y1 = src_model(x)
y2 = cloned_model(x)

(torch.allclose(y1, y2, atol=1e-3), y1.shape, y2.shape)

Replaced Flatten with CloneAwareFlatten at layers.flatten
Cloning module 
Unsupported module type: <class 'src.models.cnn.CNN'>
Cloning module layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.conv_0
Cloning Conv2d module: 3→3, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.norm_0
Cloning module layers.act_0
Cloning module layers.pool_0
Cloning module layers.conv_1
Cloning Conv2d module: 64→128, 128→256, in expansion: 2, out expansion: 2
Cloning module layers.norm_1
Cloning module layers.act_1
Cloning module layers.pool_1
Cloning module layers.conv_2
Cloning Conv2d module: 128→256, 256→512, in expansion: 2, out expansion: 2
Cloning module layers.norm_2
Cloning module layers.act_2
Cloning module layers.pool_2
Cloning module layers.flatten
Cloning module layers.fc_0
Cloning Linear module: 4096→8192, 512→512, in expansion: 2, out expansion: 1
Cloning module layers.fc_act_0
Cloning module layers.fc_out
Cloning Linear modu

(True, torch.Size([32, 2]), torch.Size([32, 2]))

In [85]:


# First conv layer: fetch weights and biases of the source and the clone.
src_conv, dup_conv = src_model.layers.conv_0, cloned_model.layers.conv_0
w1, b1 = src_conv.weight.data, src_conv.bias.data
w2, b2 = dup_conv.weight.data, dup_conv.bias.data

(w1.shape, w2.shape, b1.shape, b2.shape)
# Even and odd output filters of the clone are each compared with the source
# filters (the result is not displayed because it is not the cell's last expression).
(
    torch.allclose(w2[::2, :, :, :], w1),
    torch.allclose(w2[1::2, :, :, :], w1),
    torch.allclose(b2[::2], b1),
)

# Run the same input through both networks, one named layer at a time, in the
# order: three conv/norm/act/pool stages, flatten, then the fully connected part.
cnn_layer_order = (
    "conv_0", "norm_0", "act_0", "pool_0",
    "conv_1", "norm_1", "act_1", "pool_1",
    "conv_2", "norm_2", "act_2", "pool_2",
    "flatten",
    "fc_0", "fc_act_0", "fc_out",
)
h1, h2 = x, x
for layer_name in cnn_layer_order:
    h1 = getattr(src_model.layers, layer_name)(h1)
    h2 = getattr(cloned_model.layers, layer_name)(h2)

# Report whether the final activations agree; when the widths differ, the even
# and odd columns of the clone are compared with the source separately.
if h1.shape[1] == h2.shape[1]:
    print("Same shape: ")
    print(h1.shape, h2.shape, torch.allclose(h2, h1, atol=1e-3))
else:
    print("Different shape: ")
    print(
        h1.shape,
        h2.shape,
        torch.allclose(h2[:, ::2], h1, atol=1e-3),
        torch.allclose(h2[:, 1::2], h1, atol=1e-3),
    )

Same shape: 
torch.Size([32, 2]) torch.Size([32, 2]) True


In [90]:
# Build a ResNet and a second one with twice the base width, clone the first
# into the second, then compare their outputs on one random batch.
base_width = 64
resnet_cfg = dict(in_channels=3, num_classes=2, activation="relu", dropout_p=0.0)

src_model = ResNet(base_channels=base_width, **resnet_cfg)
cloned_model = ResNet(base_channels=base_width * 2, **resnet_cfg)

cloned_model = clone_model(src_model, cloned_model)

x = torch.randn(32, 3, 32, 32)
y1 = src_model(x)
y2 = cloned_model(x)

(torch.allclose(y1, y2, atol=1e-3), y1.shape, y2.shape)

Replaced Flatten with CloneAwareFlatten at layers.flatten
Cloning module 
Unsupported module type: <class 'src.models.resnet.ResNet'>
Cloning module layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.conv1
Cloning Conv2d module: 3→3, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.bn1
Cloning module layers.activation
Cloning module layers.layer1_block0
Unsupported module type: <class 'src.models.resnet.BasicBlock'>
Cloning module layers.layer1_block0.layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.layer1_block0.layers.conv1
Cloning Conv2d module: 64→128, 64→128, in expansion: 2, out expansion: 2
Cloning module layers.layer1_block0.layers.bn1
Cloning module layers.layer1_block0.layers.activation
Cloning module layers.layer1_block0.layers.conv2
Cloning Conv2d module: 64→128, 64→128, in expansion: 2, out expansion: 2
Cloning module layers.layer1_block0.layers.bn2
Clonin

(True, torch.Size([32, 2]), torch.Size([32, 2]))

In [96]:
# Build a source ViT and a twice-as-wide target ViT from one shared config so
# that everything except embed_dim is guaranteed to match. patch_size is
# pinned for both models; relying on the class default for only one of them
# would let the two silently diverge if that default ever changes.
vit_cfg = dict(
    in_channels=3,
    num_classes=2,
    patch_size=4,
    depth=2,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    activation="relu",
)

src_model = VisionTransformer(embed_dim=64, **vit_cfg)
cloned_model = VisionTransformer(embed_dim=64 * 2, **vit_cfg)

cloned_model = clone_model(src_model, cloned_model)

# Compare the two models on one random batch.
# NOTE(review): presumably clone_model only handles sub-modules, so parameters
# held directly on the model (e.g. cls_token / pos_embed) may need to be copied
# separately before this comparison can succeed - confirm.
x = torch.randn(32, 3, 32, 32)
y1 = src_model(x)
y2 = cloned_model(x)
torch.allclose(y1, y2, atol=1e-3), y1.shape, y2.shape

Cloning module 
Unsupported module type: <class 'src.models.vit.VisionTransformer'>
Cloning module layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.patch_embed
Unsupported module type: <class 'src.models.vit.PatchEmbedding'>
Cloning module layers.patch_embed.layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.patch_embed.layers.proj
Cloning Conv2d module: 3→3, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.pos_drop
Cloning module layers.block_0
Unsupported module type: <class 'src.models.vit.TransformerBlock'>
Cloning module layers.block_0.layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.block_0.layers.norm1
Cloning module layers.block_0.layers.attn
Unsupported module type: <class 'src.models.vit.Attention'>
Cloning module layers.block_0.layers.attn.layers
Unsupported module type: <class 'torch.nn.modules.container.

(False, torch.Size([32, 2]), torch.Size([32, 2]))

In [117]:
# Show the names of the entries held in src_model.layers.
src_model.layers.keys()

dict_keys(['patch_embed', 'pos_drop', 'block_0', 'block_1', 'norm', 'head'])

In [171]:
# Manual, stage-by-stage replay of the ViT forward pass on both models.
# For reference, the model's own forward does:
#     x = self.layers['patch_embed'](x)
#     cls_token = self.cls_token.expand(B, -1, -1)      # B = x.shape[0]
#     x = torch.cat((cls_token, x), dim=1)
#     x = x + self.pos_embed
#     x = self.layers['pos_drop'](x)
#     x = self.layers[f'block_{i}'](x)  for i in range(self.depth)
#     x = self.layers['norm'](x)
#     x = x[:, 0]                                        # CLS token for classification
#     x = self.layers['head'](x)

# Write the source cls_token / pos_embed values into both the even and the odd
# positions of the cloned model's last dimension.
for param_name in ("cls_token", "pos_embed"):
    src_values = getattr(src_model, param_name).data
    dup_values = getattr(cloned_model, param_name).data
    dup_values[:, :, ::2] = src_values
    dup_values[:, :, 1::2] = src_values

x = torch.randn(32, 3, 32, 32)

# Patch embedding.
h1 = src_model.layers.patch_embed(x)
h2 = cloned_model.layers.patch_embed(x)

# Prepend the class token, then add the positional embedding.
cls_token1 = src_model.cls_token.expand(h1.shape[0], -1, -1)
cls_token2 = cloned_model.cls_token.expand(h2.shape[0], -1, -1)
h1 = torch.cat((cls_token1, h1), dim=1)
h2 = torch.cat((cls_token2, h2), dim=1)
h1 = h1 + src_model.pos_embed
h2 = h2 + cloned_model.pos_embed

# Dropout, the two transformer blocks and the final norm.
for stage_name in ("pos_drop", "block_0", "block_1", "norm"):
    h1 = getattr(src_model.layers, stage_name)(h1)
    h2 = getattr(cloned_model.layers, stage_name)(h2)

# Intermediate check (not displayed: it is not the cell's last expression).
(
    h1.shape,
    h2.shape,
    torch.allclose(h2[:, :, ::2], h1, atol=1e-2),
    torch.allclose(h2[:, :, 1::2], h1, atol=1e-2),
)

# Keep only the first token of every sequence.
h1, h2 = h1[:, 0], h2[:, 0]

# Intermediate check (not displayed either).
(
    h1.shape,
    h2.shape,
    torch.allclose(h2[:, ::2], h1, atol=1e-2),
    torch.allclose(h2[:, 1::2], h1, atol=1e-2),
)

# Classification head.
h1 = src_model.layers.head(h1)
h2 = cloned_model.layers.head(h2)

(h1.shape, h2.shape, torch.allclose(h1, h2, atol=1e-2))

(torch.Size([32, 2]), torch.Size([32, 2]), True)

In [169]:
# Display the element-wise difference between the tensors currently bound to h1 and h2.
h1 - h2

tensor([[-5.7024e-04,  1.7179e-04],
        [-5.5595e-04,  9.0450e-05],
        [-3.0534e-04, -3.2416e-04],
        [-3.9641e-04, -3.0972e-05],
        [-5.5676e-04, -2.8667e-04],
        [-2.2934e-04, -2.0367e-04],
        [-2.9618e-04,  2.3894e-05],
        [-4.8099e-04,  3.1775e-04],
        [-3.3120e-04, -4.3871e-04],
        [-4.3447e-04, -3.1921e-04],
        [-4.9408e-04,  6.2302e-05],
        [-3.1758e-04,  2.8926e-04],
        [-6.9366e-04, -2.2227e-04],
        [-3.5624e-04, -1.4901e-04],
        [-3.2962e-04, -4.1284e-05],
        [-2.6935e-04, -8.6188e-05],
        [-3.1736e-04, -1.9766e-05],
        [-3.6038e-04,  6.0596e-05],
        [-4.3157e-04, -1.6320e-04],
        [-5.0400e-04, -1.3572e-04],
        [-7.1657e-04, -3.5712e-04],
        [-2.0284e-04, -9.4675e-05],
        [-6.5736e-04,  8.7515e-05],
        [-3.7079e-04, -5.9351e-05],
        [-4.4346e-04, -9.6247e-05],
        [-2.1453e-04,  1.4697e-04],
        [-4.2505e-04, -1.2357e-04],
        [-5.6989e-04,  3.044