In [3]:
import torch
import torch.nn as nn
# Import the utility module and Setup the path
import notebook_utils
notebook_utils.setup_path()

from src.models.mlp import MLP

Added /home/amir/Codes/NN-dynamic-scaling to Python path


In [5]:


def clone_model_parameters(src_model, cloned_model):
    """
    Clone parameters from a smaller model to a larger model using a module-based approach.
    
    For linear layers, weights are scaled by 1/n (where n is the input expansion factor)
    to ensure equivalent functionality after cloning.
    
    Args:
        src_model: Source model with smaller dimensions
        cloned_model: Target model with larger dimensions
        
    Returns:
        cloned_model: The target model with cloned parameters
    """
    # First verify model structures
    src_modules = {name: module for name, module in src_model.named_modules() if isinstance(module, nn.Linear)}
    cloned_modules = {name: module for name, module in cloned_model.named_modules() if isinstance(module, nn.Linear)}
    
    # Check if modules match
    if set(src_modules.keys()) != set(cloned_modules.keys()):
        raise ValueError("Source and cloned models have different module structures")
    
    # Process each module individually
    for name, src_module in src_modules.items():
        cloned_module = cloned_modules[name]
        
        # Get module dimensions
        src_in_features = src_module.in_features
        src_out_features = src_module.out_features
        cloned_in_features = cloned_module.in_features
        cloned_out_features = cloned_module.out_features
        
        # Calculate expansion factors
        in_expansion = cloned_in_features // src_in_features
        out_expansion = cloned_out_features // src_out_features
        
        print(f"Cloning module {name}: {src_in_features}→{cloned_in_features}, {src_out_features}→{cloned_out_features}, in expansion: {in_expansion}, out expansion: {out_expansion}")
        
        
        # Verify expansion factors are valid
        if cloned_in_features % src_in_features != 0 or cloned_out_features % src_out_features != 0:
            raise ValueError(f"Module {name} dimensions are not integer multiples: "
                             f"{src_in_features}→{cloned_in_features}, {src_out_features}→{cloned_out_features}")
        
        # Clone the weights with proper scaling
        for i in range(in_expansion):
            for j in range(out_expansion):
                cloned_module.weight.data[j::out_expansion, i::in_expansion] = src_module.weight.data / in_expansion
    
        
        # Clone the bias if present (no scaling needed for bias)
        if src_module.bias is not None and cloned_module.bias is not None:
            # cloned the bias vector 
            for j in range(out_expansion):
                cloned_module.bias.data[j::out_expansion] = src_module.bias.data
    
    # For non-linear modules (if any), copy parameters without scaling
    for name, param in src_model.named_parameters():
        # Skip parameters that belong to linear layers (already handled)
        if any(module_name in name for module_name in src_modules.keys()):
            continue
        
        if name in cloned_model.state_dict():
            cloned_param = cloned_model.state_dict()[name]
            src_shape = torch.tensor(param.shape)
            cloned_shape = torch.tensor(cloned_param.shape)
            
            # If shapes match, directly copy
            if tuple(src_shape) == tuple(cloned_shape):
                cloned_param.copy_(param)
                print(f"Parameter {name} copied directly (dimensions match)")
            else:
                # For other parameters that need expansion but no scaling
                non_matching_dims = src_shape != cloned_shape
                
                # Create indices for blockwise expansion
                indices = []
                expansion_info = []
                
                for i, (s1, s2) in enumerate(zip(src_shape, cloned_shape)):
                    if s1 == s2:
                        indices.append(torch.arange(s2))
                    else:
                        expansion_factor = s2 // s1
                        expansion_info.append(f"dim {i}: {expansion_factor}x")
                        indices.append(torch.div(torch.arange(s2), expansion_factor, rounding_mode='floor'))
                
                # Create the grid and copy
                grid = torch.meshgrid(*indices, indexing='ij')
                cloned_param.copy_(param[grid])
                
                if expansion_info:
                    print(f"Parameter {name} cloned with blockwise expansion: {', '.join(expansion_info)}")
    
    return cloned_model
    
    
def validate_model_cloning(src_model, cloned_model):
    passed = 0
    total = 0
    for name, module in src_model.named_modules():
        if isinstance(module, nn.Linear):
            # print(f"source module {name}: {module.in_features}→{module.out_features},  cloned module {cloned_model.get_submodule(name).in_features}→{cloned_model.get_submodule(name).out_features}")
            module2 = cloned_model.get_submodule(name)
            in_expansion = module2.in_features // module.in_features
            out_expansion = module2.out_features // module.out_features
            # print(f"Expansion factors (outxin): {out_expansion}x{in_expansion}")
            for j in range(out_expansion):
                for i in range(in_expansion):
                    passed += torch.allclose(module2.weight.data[j::out_expansion, i::in_expansion], module.weight.data/in_expansion)
                    passed += torch.allclose(module2.bias.data[j::out_expansion], module.bias.data)
                    total += 2
    
    print(f"Passed {passed} out of {total} tests")
    return passed == total  

def validate_activation_clonign(src_model, cloned_model):
    from src.utils.monitor import NetworkMonitor

    src_monitor = NetworkMonitor(src_model, )
    cloned_monitor = NetworkMonitor(cloned_model, )
    src_monitor.register_hooks()
    cloned_monitor.register_hooks()


    d = src_model.input_size
    x = torch.randn(10, d)
    src_model(x)
    cloned_model(x)

    acts, acts2 = src_monitor.get_latest_activations(), cloned_monitor.get_latest_activations()

    passed = 0
    total = 0
    for key in acts.keys():
        a1, a2 = acts[key], acts2[key]
        if a1.shape[1] != a2.shape[1]:
            diffs = (a1 - a2[:,::2])
        else: #only for the last layer
            diffs = (a1 - a2)
        # print(f"Diff for {key}: {diffs.abs().max().item()}")
        passed += diffs.abs().max().item() < 1e-5
        total += 1
    print(f"Passed {passed} out of {total} tests")
    return passed == total

# src_model = MLP(input_size=10, output_size=2, hidden_sizes=[64, 32, 16], activation="relu", dropout_p=0.0)
# cloned_model = MLP(input_size=10, output_size=2, hidden_sizes=[64*2, 32*2, 16*2], activation="relu", dropout_p=0.0)

# cloned_model = clone_model_parameters(src_model, cloned_model)
# test1(src_model, cloned_model)
# test2(src_model, cloned_model)


# Example usage
if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Create source and target models with random weights
    src_model = MLP(input_size=10, output_size=2, hidden_sizes=[64, 32, 16], activation="relu", dropout_p=0.0)
    cloned_model = MLP(input_size=10, output_size=2, hidden_sizes=[64*2, 32*2, 16*2], activation="relu", dropout_p=0.0)
    
    # Clone the parameters
    cloned_model = clone_model_parameters(src_model, cloned_model)
    
    # Test the cloning with a functional test
    success = validate_model_cloning(src_model, cloned_model)
    print("\nModel cloning test:", "PASSED" if success else "FAILED")

    
    
    success = validate_activation_clonign(src_model, cloned_model)
    print("\nActivation cloning test:", "PASSED" if success else "FAILED")
    

Cloning module layers.linear_0: 10→10, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.linear_1: 64→128, 32→64, in expansion: 2, out expansion: 2
Cloning module layers.linear_2: 32→64, 16→32, in expansion: 2, out expansion: 2
Cloning module layers.out: 16→32, 2→2, in expansion: 2, out expansion: 1
Passed 24 out of 24 tests

Model cloning test: PASSED
Passed 7 out of 7 tests

Activation cloning test: PASSED


In [86]:
from src.models import CNN
model = CNN()
model.layers.conv_0.weight.data.shape
for k,v in model.layers.norm_2.named_parameters():
    print(k,v.shape, v[:5])
model.layers.norm_2

weight torch.Size([256]) tensor([1., 1., 1., 1., 1.], grad_fn=<SliceBackward0>)
bias torch.Size([256]) tensor([0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)


BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [110]:
from src.models import MLP, CNN, ResNet, VisionTransformer

model = VisionTransformer()
model

VisionTransformer(
  (layers): ModuleDict(
    (patch_embed): PatchEmbedding(
      (layers): ModuleDict(
        (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      )
    )
    (pos_drop): Dropout(p=0.1, inplace=False)
    (block_0): TransformerBlock(
      (layers): ModuleDict(
        (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (layers): ModuleDict(
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.1, inplace=False)
          )
        )
        (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (mlp): TransformerMLP(
          (layers): ModuleDict(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (act): GELU(approximate='none')
            (drop1): Dropout(p=0.1, inplace=False)

In [None]:
import torch
import torch.nn as nn

from typing import Union, List, Dict, Set

# Replace TypeVar with Union for proper type handling
NormalizationLayer = Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm]

# Define activation function types properly
ActivationFunction = Union[nn.ReLU, nn.Sigmoid, nn.Tanh, nn.SELU, nn.GELU, nn.SiLU, nn.ELU, 
                         nn.LeakyReLU, nn.PReLU, nn.Threshold, nn.Softmax, nn.LogSoftmax, 
                         nn.Softplus, nn.Softmin, nn.Hardsigmoid, nn.Hardswish, nn.Softshrink, 
                         nn.Hardshrink, nn.Softsign, nn.GLU, nn.CELU, nn.Identity]

def clone_linear(src_module: nn.Linear, cloned_module: nn.Linear):
    # Get module dimensions
    src_in_features = src_module.in_features
    src_out_features = src_module.out_features
    cloned_in_features = cloned_module.in_features
    cloned_out_features = cloned_module.out_features
    
    # Verify expansion factors are valid
    if cloned_in_features % src_in_features != 0 or cloned_out_features % src_out_features != 0:
        raise ValueError(f"Linear module dimensions are not integer multiples: "
                         f"{src_in_features}→{cloned_in_features}, {src_out_features}→{cloned_out_features}")
        
    # Calculate expansion factors
    in_expansion = cloned_in_features // src_in_features
    out_expansion = cloned_out_features // src_out_features
    
    print(f"Cloning Linear module: {src_in_features}→{cloned_in_features}, {src_out_features}→{cloned_out_features}, in expansion: {in_expansion}, out expansion: {out_expansion}")
    
    # Clone the weights with proper scaling
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion] = src_module.weight.data / in_expansion
    
    # Clone the bias if present (no scaling needed for bias)
    if src_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = src_module.bias.data
    return cloned_module


def clone_conv1d(src_module: nn.Conv1d, cloned_module: nn.Conv1d):
    # Get module dimensions
    src_in_channels = src_module.in_channels
    src_out_channels = src_module.out_channels
    cloned_in_channels = cloned_module.in_channels
    cloned_out_channels = cloned_module.out_channels
    # Calculate expansion factors
    in_expansion = cloned_in_channels // src_in_channels
    out_expansion = cloned_out_channels // src_out_channels
    
    print(f"Cloning Conv1d module: {src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}, in expansion: {in_expansion}, out expansion: {out_expansion}")
    
    # Verify expansion factors are valid
    if cloned_in_channels % src_in_channels != 0 or cloned_out_channels % src_out_channels != 0:
        raise ValueError(f"Conv1d module dimensions are not integer multiples: "
                         f"{src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}")
    
    # Clone the weights with proper scaling
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion, :] = src_module.weight.data / in_expansion
    
    # Clone the bias if present (no scaling needed for bias)
    if src_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = src_module.bias.data
    return cloned_module

    
def clone_conv2d(src_module: nn.Conv2d, cloned_module: nn.Conv2d):
    # Get module dimensions
    src_in_channels = src_module.in_channels
    src_out_channels = src_module.out_channels
    cloned_in_channels = cloned_module.in_channels
    cloned_out_channels = cloned_module.out_channels
    # Calculate expansion factors
    in_expansion = cloned_in_channels // src_in_channels
    out_expansion = cloned_out_channels // src_out_channels
    
    print(f"Cloning Conv2d module: {src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}, in expansion: {in_expansion}, out expansion: {out_expansion}")
    
    # Verify expansion factors are valid
    if cloned_in_channels % src_in_channels != 0 or cloned_out_channels % src_out_channels != 0:
        raise ValueError(f"Conv2d module dimensions are not integer multiples: "
                         f"{src_in_channels}→{cloned_in_channels}, {src_out_channels}→{cloned_out_channels}")
    
    # Clone the weights with proper scaling
    for i in range(in_expansion):
        for j in range(out_expansion):
            cloned_module.weight.data[j::out_expansion, i::in_expansion, :, :] = src_module.weight.data / in_expansion
    
    # Clone the bias if present (no scaling needed for bias)
    if src_module.bias is not None and cloned_module.bias is not None:
        for j in range(out_expansion):
            cloned_module.bias.data[j::out_expansion] = src_module.bias.data
    return cloned_module
    

def clone_normalization(
    src_module: NormalizationLayer, 
    cloned_module: NormalizationLayer,
) -> NormalizationLayer:
    """Clone normalization layer parameters with proper handling of different types."""
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"
    
    # Check properties that exist for the specific normalization type
    if hasattr(src_module, 'affine') and hasattr(cloned_module, 'affine'):
        assert src_module.affine == cloned_module.affine, "Affine property must match"
    
    # Handle BatchNorm-specific properties
    if isinstance(src_module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        if hasattr(src_module, 'track_running_stats') and hasattr(cloned_module, 'track_running_stats'):
            assert src_module.track_running_stats == cloned_module.track_running_stats, "Track running stats property must match"
    
    # Clone weights and biases
    if hasattr(src_module, 'weight') and src_module.weight is not None and cloned_module.weight is not None:
        expansion = cloned_module.weight.data.shape[0] // src_module.weight.data.shape[0] 
        for i in range(expansion):
            cloned_module.weight.data[i::expansion] = src_module.weight.data
            if hasattr(src_module, 'bias') and src_module.bias is not None and cloned_module.bias is not None:
                cloned_module.bias.data[i::expansion] = src_module.bias.data
    
    # Clone running stats for BatchNorm layers
    if hasattr(src_module, 'running_mean') and src_module.running_mean is not None:
        if hasattr(cloned_module, 'running_mean') and cloned_module.running_mean is not None:
            expansion = cloned_module.running_mean.data.shape[0] // src_module.running_mean.data.shape[0]
            for i in range(expansion):
                cloned_module.running_mean.data[i::expansion] = src_module.running_mean.data
                cloned_module.running_var.data[i::expansion] = src_module.running_var.data
    
    # Clone num_batches_tracked for BatchNorm layers
    if hasattr(src_module, 'num_batches_tracked') and src_module.num_batches_tracked is not None:
        if hasattr(cloned_module, 'num_batches_tracked') and cloned_module.num_batches_tracked is not None:
            cloned_module.num_batches_tracked.data.copy_(src_module.num_batches_tracked.data)
    
    return cloned_module
    
    
def clone_embedding(src_module: nn.Embedding, cloned_module: nn.Embedding):
    # Get module dimensions
    src_num_embeddings = src_module.num_embeddings
    src_embedding_dim = src_module.embedding_dim
    cloned_num_embeddings = cloned_module.num_embeddings
    cloned_embedding_dim = cloned_module.embedding_dim
    
    # Calculate expansion factors
    num_expansion = cloned_num_embeddings // src_num_embeddings
    dim_expansion = cloned_embedding_dim // src_embedding_dim
    
    print(f"Cloning Embedding module: {src_num_embeddings}→{cloned_num_embeddings}, {src_embedding_dim}→{cloned_embedding_dim}, num expansion: {num_expansion}, dim expansion: {dim_expansion}")
    
    # Verify expansion factors are valid
    if cloned_num_embeddings % src_num_embeddings != 0 or cloned_embedding_dim % src_embedding_dim != 0:
        raise ValueError(f"Embedding module dimensions are not integer multiples: "
                         f"{src_num_embeddings}→{cloned_num_embeddings}, {src_embedding_dim}→{cloned_embedding_dim}")
    
    # Clone the weights with proper scaling
    for i in range(num_expansion):
        for j in range(dim_expansion):
            cloned_module.weight.data[j::dim_expansion, i::num_expansion] = src_module.weight.data 
    
    return cloned_module


def clone_activation(src_module: ActivationFunction, cloned_module: ActivationFunction) -> ActivationFunction:
    """Clone activation function parameters, handling configuration parameters properly."""
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"
    
    # Handle configuration parameters for different activation types
    if isinstance(src_module, nn.LeakyReLU):
        cloned_module.negative_slope = src_module.negative_slope
    
    elif isinstance(src_module, (nn.ELU, nn.CELU)):
        cloned_module.alpha = src_module.alpha
    
    elif isinstance(src_module, nn.Threshold):
        cloned_module.threshold = src_module.threshold
        cloned_module.value = src_module.value
    
    elif isinstance(src_module, (nn.Softmax, nn.LogSoftmax)):
        cloned_module.dim = src_module.dim
    
    elif isinstance(src_module, (nn.Hardshrink, nn.Softshrink)):
        cloned_module.lambd = src_module.lambd
        
    elif isinstance(src_module, nn.GLU):
        cloned_module.dim = src_module.dim
    
    # Handle PReLU specifically (has learnable parameters)
    elif isinstance(src_module, nn.PReLU):
        if src_module.num_parameters == 1 and cloned_module.num_parameters > 1:
            # If source is a single parameter, broadcast to all channels
            cloned_module.weight.data.fill_(src_module.weight.data[0])
        elif src_module.num_parameters > 1 and cloned_module.num_parameters > 1:
            # Channel-wise parameters need proper expansion
            expansion = cloned_module.num_parameters // src_module.num_parameters
            for i in range(expansion):
                cloned_module.weight.data[i::expansion] = src_module.weight.data
        else:
            # Direct copy if dimensions match
            cloned_module.weight.data.copy_(src_module.weight.data)
    
    # Handle other parameterized activation functions if they have weights
    # This is a general catch-all for any other activation function with parameters
    elif hasattr(src_module, 'weight') and hasattr(cloned_module, 'weight'):
        if src_module.weight is not None and cloned_module.weight is not None:
            if cloned_module.weight.data.shape == src_module.weight.data.shape:
                cloned_module.weight.data.copy_(src_module.weight.data)
    
    return cloned_module


def clone_dropout(src_module: nn.Dropout, cloned_module: nn.Dropout):
    """Clone dropout module parameters."""
    assert cloned_module.p == src_module.p, "Dropout probability must match"
    # Print warning if dropout p > 0
    if cloned_module.p > 0:
        print(f"Warning: Dropout probability is set to {cloned_module.p}, cloning is not perfect")
    return cloned_module


def is_parameter_free(module: nn.Module) -> bool:
    """Check if a module has no parameters."""
    return len(list(module.parameters())) == 0


def clone_parameter_free(src_module: nn.Module, cloned_module: nn.Module) -> nn.Module:
    """Clone a parameter-free module."""
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"
    assert is_parameter_free(src_module), "Source module must be parameter free"
    assert is_parameter_free(cloned_module), "Cloned module must be parameter free"
    
    # For parameter-free modules, there's no need to copy weights
    # Just make sure they're of the same type, which we've already checked
    return cloned_module


# Validation functions

def validate_activation_cloning(src_module: ActivationFunction, cloned_module: ActivationFunction):
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"
    
    # Validate configuration parameters for different activation types
    if isinstance(src_module, nn.LeakyReLU):
        assert src_module.negative_slope == cloned_module.negative_slope, "LeakyReLU negative_slope does not match"
    
    elif isinstance(src_module, (nn.ELU, nn.CELU)):
        assert src_module.alpha == cloned_module.alpha, "Alpha parameter does not match"
    
    elif isinstance(src_module, nn.Threshold):
        assert src_module.threshold == cloned_module.threshold, "Threshold value does not match"
        assert src_module.value == cloned_module.value, "Replacement value does not match"
    
    elif isinstance(src_module, (nn.Softmax, nn.LogSoftmax)):
        assert src_module.dim == cloned_module.dim, "Dimension parameter does not match"
    
    elif isinstance(src_module, (nn.Hardshrink, nn.Softshrink)):
        assert src_module.lambd == cloned_module.lambd, "Lambda parameter does not match"
        
    elif isinstance(src_module, nn.GLU):
        assert src_module.dim == cloned_module.dim, "Dimension parameter does not match"
    
    # Validate PReLU parameters
    elif isinstance(src_module, nn.PReLU):
        if src_module.num_parameters == 1 and cloned_module.num_parameters > 1:
            # All elements should be equal to the single parameter
            assert torch.all(cloned_module.weight.data == src_module.weight.data[0])
        elif src_module.num_parameters > 1 and cloned_module.num_parameters > 1:
            expansion = cloned_module.num_parameters // src_module.num_parameters
            for i in range(expansion):
                assert torch.allclose(cloned_module.weight.data[i::expansion], src_module.weight.data)
    
    print("Passed all tests")
    return True


def validate_dropout_cloning(src_module: nn.Dropout, cloned_module: nn.Dropout):
    assert cloned_module.p == src_module.p, "Dropout probability must match"
    print("Passed all tests")
    return True


def validate_embedding_cloning(src_module: nn.Embedding, cloned_module: nn.Embedding):
    num_expansion = cloned_module.num_embeddings // src_module.num_embeddings
    dim_expansion = cloned_module.embedding_dim // src_module.embedding_dim
    for j in range(num_expansion):
        for i in range(dim_expansion):
            assert torch.allclose(cloned_module.weight.data[j::num_expansion, i::dim_expansion], src_module.weight.data)
    print("Passed all tests")
    return True


def validate_normalization_cloning(src_module: NormalizationLayer, cloned_module: NormalizationLayer):
    assert isinstance(cloned_module, type(src_module)), "Cloned module must be of the same type as source module"
    
    if hasattr(src_module, 'weight') and src_module.weight is not None and hasattr(cloned_module, 'weight'):
        expansion = cloned_module.weight.data.shape[0] // src_module.weight.data.shape[0] 
        for i in range(expansion):
            assert torch.allclose(cloned_module.weight.data[i::expansion], src_module.weight.data)
            
            if hasattr(src_module, 'bias') and src_module.bias is not None and hasattr(cloned_module, 'bias'):
                assert torch.allclose(cloned_module.bias.data[i::expansion], src_module.bias.data)
    
    # Check running stats for BatchNorm layers
    if hasattr(src_module, 'running_mean') and src_module.running_mean is not None:
        if hasattr(cloned_module, 'running_mean') and cloned_module.running_mean is not None:
            expansion = cloned_module.running_mean.data.shape[0] // src_module.running_mean.data.shape[0]
            for i in range(expansion):
                assert torch.allclose(cloned_module.running_mean.data[i::expansion], src_module.running_mean.data)
                assert torch.allclose(cloned_module.running_var.data[i::expansion], src_module.running_var.data)
    
    print("Passed all tests")
    

def validate_linear_cloning(src_module: nn.Linear, cloned_module: nn.Linear):
    in_expansion = cloned_module.in_features // src_module.in_features
    out_expansion = cloned_module.out_features // src_module.out_features
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion], src_module.weight.data/in_expansion)
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], src_module.bias.data)
    print("Passed all tests")
    
    
def validate_conv1d_cloning(src_module: nn.Conv1d, cloned_module: nn.Conv1d):
    in_expansion = cloned_module.in_channels // src_module.in_channels
    out_expansion = cloned_module.out_channels // src_module.out_channels
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion, :], src_module.weight.data/in_expansion)
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], src_module.bias.data)
    print("Passed all tests")
    

def validate_conv2d_cloning(src_module: nn.Conv2d, cloned_module: nn.Conv2d):
    in_expansion = cloned_module.in_channels // src_module.in_channels
    out_expansion = cloned_module.out_channels // src_module.out_channels
    for j in range(out_expansion):
        for i in range(in_expansion):
            assert torch.allclose(cloned_module.weight.data[j::out_expansion, i::in_expansion, :, :], src_module.weight.data/in_expansion)
            assert torch.allclose(cloned_module.bias.data[j::out_expansion], src_module.bias.data)
    print("Passed all tests")


def clone_module(
    src_module: nn.Module, 
    cloned_module: nn.Module,
) -> bool:
    """
    Clone parameters from a source module to a cloned module.
    
    Args:
        src_module: Source module with smaller dimensions
        cloned_module: Target module with larger dimensions
        
    Returns:
        bool: True if cloning was successful, False otherwise
    """
    success = True
    
    # Define normalization and activation types inline for easier checking
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm)
    activation_types = (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.SELU, nn.GELU, nn.SiLU, nn.ELU, nn.LeakyReLU, 
                       nn.PReLU, nn.Threshold, nn.Softmax, nn.LogSoftmax, nn.Softplus, nn.Softmin, 
                       nn.Hardsigmoid, nn.Hardswish, nn.Softshrink, nn.Hardshrink, nn.Softsign, 
                       nn.GLU, nn.CELU, nn.Identity)
    
    if isinstance(src_module, nn.Linear):
        clone_linear(src_module, cloned_module)
    elif isinstance(src_module, nn.Conv1d):
        clone_conv1d(src_module, cloned_module)
    elif isinstance(src_module, nn.Conv2d):
        clone_conv2d(src_module, cloned_module)
    elif isinstance(src_module, norm_types):
        clone_normalization(src_module, cloned_module)
    elif isinstance(src_module, nn.Embedding):
        clone_embedding(src_module, cloned_module)
    elif isinstance(src_module, activation_types):
        clone_activation(src_module, cloned_module)
    elif isinstance(src_module, nn.Dropout):
        clone_dropout(src_module, cloned_module)
    elif is_parameter_free(src_module) and is_parameter_free(cloned_module):
        clone_parameter_free(src_module, cloned_module)
    else:
        success = False
        print(f"Unsupported module type: {type(src_module)}")
    
    return success


def clone_model(src_model: nn.Module, cloned_model: nn.Module) -> nn.Module:
    """
    Clone parameters from a source model to a cloned model.
    
    Args:
        src_model: Source model with smaller dimensions
        cloned_model: Target model with larger dimensions
        
    Returns:
        cloned_model: The target model with cloned parameters
    """
    # Keep track of processed modules to avoid double-processing
    processed_modules = set()
    
    # Get direct children modules only, to avoid processing nested modules multiple times
    src_direct_children = {name: module for name, module in src_model.named_children()}
    
    for name, src_child in src_direct_children.items():
        try:
            cloned_child = cloned_model.get_submodule(name)
            print(f"Cloning module {name}")
            
            # Clone the current module
            success = clone_module(src_child, cloned_child)
            if not success:
                print(f"Warning: Failed to clone module {name} directly")
                
                # If direct cloning failed, try recursively cloning its children
                recursive_success = clone_model(src_child, cloned_child)
                
        except AttributeError:
            print(f"Warning: Could not find matching module for {name} in the cloned model")
            
    return cloned_model


def clone_model(src_model: nn.Module, cloned_model: nn.Module) -> nn.Module:
    """
    Clone parameters from a source model to a cloned model.
    
    Args:
        src_model: Source model with smaller dimensions
        cloned_model: Target model with larger dimensions
        
    Returns:
        cloned_model: The target model with cloned parameters
    """
    
    # Process each module individually
    for name, src_module in src_model.named_modules():
        cloned_module = cloned_model.get_submodule(name)
        print(f"Cloning module {name}")
        clone_module(src_module, cloned_module)
    
    return cloned_model

In [63]:
from src.models import MLP, CNN, ResNet, VisionTransformer

src_model = MLP(input_size=10, output_size=2, hidden_sizes=[64, 32,], activation="relu", dropout_p=0.0)
cloned_model = MLP(input_size=10, output_size=2, hidden_sizes=[64*2, 32*2,], activation="relu", dropout_p=0.0)
cloned_model = clone_model(src_model, cloned_model)

x = torch.randn(32, 10)
y1 = src_model(x)
y2 = cloned_model(x)
print("Output from source model:", y1)
print("Output from cloned model:", y2)


Cloning module 
Unsupported module type: <class 'src.models.mlp.MLP'>
Cloning module layers
Unsupported module type: <class 'torch.nn.modules.container.ModuleDict'>
Cloning module layers.linear_0
Cloning Linear module: 10→10, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.act_0
Cloning module layers.linear_1
Cloning Linear module: 64→128, 32→64, in expansion: 2, out expansion: 2
Cloning module layers.act_1
Cloning module layers.out
Cloning Linear module: 32→64, 2→2, in expansion: 2, out expansion: 1
Output from source model: tensor([[-0.3410, -0.0723],
        [-0.3435, -0.1194],
        [-0.2754, -0.0096],
        [-0.2427, -0.0972],
        [-0.2637, -0.0710],
        [-0.1828, -0.1549],
        [-0.3514, -0.0417],
        [-0.2227, -0.0717],
        [-0.2692, -0.0477],
        [-0.1762, -0.0253],
        [-0.2311, -0.0928],
        [-0.1939, -0.0128],
        [-0.2574, -0.1233],
        [-0.3244, -0.2309],
        [-0.1472, -0.0440],
        [-0.2799, -0.0553],
    

In [82]:
w1, w2 = src_model.layers.linear_0.weight.data, cloned_model.layers.linear_0.weight.data
b1, b2 = src_model.layers.linear_0.bias.data, cloned_model.layers.linear_0.bias.data
w1.shape, w2.shape

torch.allclose(w2[::2,:], w1),torch.allclose(b2[::2], b1)

w1, w2 = src_model.layers.linear_1.weight.data, cloned_model.layers.linear_1.weight.data
b1, b2 = src_model.layers.linear_1.bias.data, cloned_model.layers.linear_1.bias.data

torch.allclose(w2[::2,::2], w1/2),torch.allclose(w2[1::2,::2], w1/2),torch.allclose(w2[::2,1::2], w1/2),torch.allclose(w2[1::2,1::2], w1/2),torch.allclose(b2[::2], b1)

(True, True, True, True, True)

In [98]:
h1, h2 = src_model.layers.linear_0(x), cloned_model.layers.linear_0(x)
h1, h2 = src_model.layers.act_0(h1), cloned_model.layers.act_0(h2)
h1, h2 = src_model.layers.linear_1(h1), cloned_model.layers.linear_1(h2)
h1, h2 = src_model.layers.act_1(h1), cloned_model.layers.act_1(h2)
h1, h2 = src_model.layers.out(h1), cloned_model.layers.out(h2)
h1.shape, h2.shape
if h1.shape[1] != h2.shape[1]:
    assert h2.shape[1] % h1.shape[1] == 0, "Output dimensions are not integer multiples"
    expansion_factor = h2.shape[1] // h1.shape[1]
    for i in range(expansion_factor):
        assert torch.allclose(h2[:,i::expansion_factor], h1, atol=1e-5), f"Output mismatch at expansion factor {i}"
print("All activations are close ")

All activations are close 


In [97]:
w1, w2 = src_model.layers.out.weight.data, cloned_model.layers.out.weight.data
b1, b2 = src_model.layers.out.bias.data, cloned_model.layers.out.bias.data

w1.shape, w2.shape, b1.shape, b2.shape
torch.allclose(w2[:,::2],w1/2),torch.allclose(w2[:,1::2],w1/2), torch.allclose(b2, b1)

(True, True, True)

tensor([[-5.9605e-08,  0.0000e+00,  5.9605e-08,  2.9802e-08],
        [ 0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00],
        [ 2.9802e-08,  0.0000e+00,  0.0000e+00,  2.9802e-08],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-5.9605e-08,  0.0000e+00,  1.4901e-08,  0.0000e+00]],
       grad_fn=<SubBackward0>)