/home/amir/Codes/NN-dynamic-scaling already in Python path


Added /home/amir/Codes/NN-dynamic-scaling to Python path


In [None]:
import torch
import torch.nn as nn
# Import the utility module and Setup the path
import notebook_utils
notebook_utils.setup_path()

from src.models.mlp import MLP

def clone_model_parameters(src_model, cloned_model):
    """
    Clone parameters from a smaller model to a larger model using a module-based approach.
    
    For linear layers, weights are scaled by 1/n (where n is the input expansion factor)
    to ensure equivalent functionality after cloning.
    
    Args:
        src_model: Source model with smaller dimensions
        cloned_model: Target model with larger dimensions
        
    Returns:
        cloned_model: The target model with cloned parameters
    """
    # First verify model structures
    src_modules = {name: module for name, module in src_model.named_modules() if isinstance(module, nn.Linear)}
    cloned_modules = {name: module for name, module in cloned_model.named_modules() if isinstance(module, nn.Linear)}
    
    # Check if modules match
    if set(src_modules.keys()) != set(cloned_modules.keys()):
        raise ValueError("Source and cloned models have different module structures")
    
    # Process each module individually
    for name, src_module in src_modules.items():
        cloned_module = cloned_modules[name]
        
        # Get module dimensions
        src_in_features = src_module.in_features
        src_out_features = src_module.out_features
        cloned_in_features = cloned_module.in_features
        cloned_out_features = cloned_module.out_features
        
        # Calculate expansion factors
        in_expansion = cloned_in_features // src_in_features
        out_expansion = cloned_out_features // src_out_features
        
        print(f"Cloning module {name}: {src_in_features}→{cloned_in_features}, {src_out_features}→{cloned_out_features}, in expansion: {in_expansion}, out expansion: {out_expansion}")
        
        
        # Verify expansion factors are valid
        if cloned_in_features % src_in_features != 0 or cloned_out_features % src_out_features != 0:
            raise ValueError(f"Module {name} dimensions are not integer multiples: "
                             f"{src_in_features}→{cloned_in_features}, {src_out_features}→{cloned_out_features}")
        
        # Clone the weights with proper scaling
        for i in range(in_expansion):
            for j in range(out_expansion):
                cloned_module.weight.data[j::out_expansion, i::in_expansion] = src_module.weight.data / in_expansion
    
        
        # Clone the bias if present (no scaling needed for bias)
        if src_module.bias is not None and cloned_module.bias is not None:
            # cloned the bias vector 
            for j in range(out_expansion):
                cloned_module.bias.data[j::out_expansion] = src_module.bias.data
    
    # For non-linear modules (if any), copy parameters without scaling
    for name, param in src_model.named_parameters():
        # Skip parameters that belong to linear layers (already handled)
        if any(module_name in name for module_name in src_modules.keys()):
            continue
        
        if name in cloned_model.state_dict():
            cloned_param = cloned_model.state_dict()[name]
            src_shape = torch.tensor(param.shape)
            cloned_shape = torch.tensor(cloned_param.shape)
            
            # If shapes match, directly copy
            if tuple(src_shape) == tuple(cloned_shape):
                cloned_param.copy_(param)
                print(f"Parameter {name} copied directly (dimensions match)")
            else:
                # For other parameters that need expansion but no scaling
                non_matching_dims = src_shape != cloned_shape
                
                # Create indices for blockwise expansion
                indices = []
                expansion_info = []
                
                for i, (s1, s2) in enumerate(zip(src_shape, cloned_shape)):
                    if s1 == s2:
                        indices.append(torch.arange(s2))
                    else:
                        expansion_factor = s2 // s1
                        expansion_info.append(f"dim {i}: {expansion_factor}x")
                        indices.append(torch.div(torch.arange(s2), expansion_factor, rounding_mode='floor'))
                
                # Create the grid and copy
                grid = torch.meshgrid(*indices, indexing='ij')
                cloned_param.copy_(param[grid])
                
                if expansion_info:
                    print(f"Parameter {name} cloned with blockwise expansion: {', '.join(expansion_info)}")
    
    return cloned_model
    
    
def test_model_cloning(src_model, cloned_model):
    passed = 0
    total = 0
    for name, module in src_model.named_modules():
        if isinstance(module, nn.Linear):
            # print(f"source module {name}: {module.in_features}→{module.out_features},  cloned module {cloned_model.get_submodule(name).in_features}→{cloned_model.get_submodule(name).out_features}")
            module2 = cloned_model.get_submodule(name)
            in_expansion = module2.in_features // module.in_features
            out_expansion = module2.out_features // module.out_features
            # print(f"Expansion factors (outxin): {out_expansion}x{in_expansion}")
            for j in range(out_expansion):
                for i in range(in_expansion):
                    passed += torch.allclose(module2.weight.data[j::out_expansion, i::in_expansion], module.weight.data/in_expansion)
                    passed += torch.allclose(module2.bias.data[j::out_expansion], module.bias.data)
                    total += 2
    
    print(f"Passed {passed} out of {total} tests")
    return passed == total  

def test_activation_clonign(src_model, cloned_model):
    from src.utils.monitor import NetworkMonitor

    src_monitor = NetworkMonitor(src_model, )
    cloned_monitor = NetworkMonitor(cloned_model, )
    src_monitor.register_hooks()
    cloned_monitor.register_hooks()


    d = src_model.input_size
    x = torch.randn(10, d)
    src_model(x)
    cloned_model(x)

    acts, acts2 = src_monitor.get_latest_activations(), cloned_monitor.get_latest_activations()

    passed = 0
    total = 0
    for key in acts.keys():
        a1, a2 = acts[key], acts2[key]
        if a1.shape[1] != a2.shape[1]:
            diffs = (a1 - a2[:,::2])
        else: #only for the last layer
            diffs = (a1 - a2)
        # print(f"Diff for {key}: {diffs.abs().max().item()}")
        passed += diffs.abs().max().item() < 1e-5
        total += 1
    print(f"Passed {passed} out of {total} tests")
    return passed == total

# src_model = MLP(input_size=10, output_size=2, hidden_sizes=[64, 32, 16], activation="relu", dropout_p=0.0)
# cloned_model = MLP(input_size=10, output_size=2, hidden_sizes=[64*2, 32*2, 16*2], activation="relu", dropout_p=0.0)

# cloned_model = clone_model_parameters(src_model, cloned_model)
# test1(src_model, cloned_model)
# test2(src_model, cloned_model)


# Example usage
if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Create source and target models with random weights
    src_model = MLP(input_size=10, output_size=2, hidden_sizes=[64, 32, 16], activation="relu", dropout_p=0.0)
    cloned_model = MLP(input_size=10, output_size=2, hidden_sizes=[64*2, 32*2, 16*2], activation="relu", dropout_p=0.0)
    
    # Clone the parameters
    cloned_model = clone_model_parameters(src_model, cloned_model)
    
    # Test the cloning with a functional test
    success = test_model_cloning(src_model, cloned_model)
    print("\nModel cloning test:", "PASSED" if success else "FAILED")

    
    
    success = test_activation_clonign(src_model, cloned_model)
    print("\nActivation cloning test:", "PASSED" if success else "FAILED")
    

/home/amir/Codes/NN-dynamic-scaling already in Python path
Cloning module layers.linear_0: 10→10, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.linear_1: 64→128, 32→64, in expansion: 2, out expansion: 2
Cloning module layers.linear_2: 32→64, 16→32, in expansion: 2, out expansion: 2
Cloning module layers.out: 16→32, 2→2, in expansion: 2, out expansion: 1
Passed 24 out of 24 tests

Model cloning test: PASSED
Passed 7 out of 7 tests

Activation cloning test: PASSED


In [86]:
from src.models import CNN
model = CNN()
model.layers.conv_0.weight.data.shape
for k,v in model.layers.norm_2.named_parameters():
    print(k,v.shape, v[:5])
model.layers.norm_2

weight torch.Size([256]) tensor([1., 1., 1., 1., 1.], grad_fn=<SliceBackward0>)
bias torch.Size([256]) tensor([0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)


BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [96]:
from src.models import MLP, CNN, ResNet, VisionTransformer

model = ResNet()
model

ResNet(
  (layers): ModuleDict(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU()
    (layer1_block0): BasicBlock(
      (layers): ModuleDict(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer1_block1): BasicBlock(
      (layers): ModuleDict(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU(

In [100]:
l = nn.BatchNorm1d(10)
for k,v in l.named_parameters():
    print(k,v.shape, v[:5])

weight torch.Size([10]) tensor([1., 1., 1., 1., 1.], grad_fn=<SliceBackward0>)
bias torch.Size([10]) tensor([0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)


Cloning module layers.linear_0: 10→10, 64→128, in expansion: 1, out expansion: 2
Cloning module layers.linear_1: 64→128, 32→64, in expansion: 2, out expansion: 2
Cloning module layers.linear_2: 32→64, 16→32, in expansion: 2, out expansion: 2
Cloning module layers.out: 16→32, 2→2, in expansion: 2, out expansion: 1
Diff for layers.linear_0: 0.0
Diff for layers.act_0: 0.0
Diff for layers.linear_1: 1.7881393432617188e-07
Diff for layers.act_1: 1.7881393432617188e-07
Diff for layers.linear_2: 5.960464477539063e-08
Diff for layers.act_2: 5.960464477539063e-08
Diff for layers.out: 0.08214077353477478


In [61]:
src_model.layers.linear_0.bias.data[0], cloned_model.layers.linear_0.bias.data[0]

(tensor(-0.0097), tensor(-0.2224))

In [37]:
for key in acts.keys():
    if acts[key].shape[0] != acts2[key].shape[0]:
        print(f"case 1: Diff max for {key}: {(acts[key] - acts2[key][::2]).abs().max().item()}")
    elif acts[key].shape[1] != acts2[key].shape[1] and acts[key].shape[0] == acts2[key].shape[0]:
        print(f"case 2: Diff max for {key}: {(acts[key] - acts2[key][:,::2]).abs().max().item()}")
    elif acts[key].shape[1] != acts2[key].shape[1] and acts[key].shape[0] != acts2[key].shape[0]:
        print(f"case 2: Diff max for {key}: {(acts[key] - acts2[key][::2,::2]).abs().max().item()}")
    else:
        diffs = acts[key]- acts2[key]
        print (f"case 3: Diff max for {key}: {diffs.abs().max().item()}")    
    

case 2: Diff max for layers.linear_0: 2.2091572284698486
case 2: Diff max for layers.act_0: 1.2357994318008423
case 2: Diff max for layers.linear_1: 1.0886716842651367
case 2: Diff max for layers.act_1: 0.828102171421051
case 2: Diff max for layers.linear_2: 0.52223801612854
case 2: Diff max for layers.act_2: 0.35224783420562744
case 3: Diff max for layers.out: 0.4210308790206909


In [24]:
src_model.layers.linear_1.weight.data- cloned_model.layers.linear_1.weight.data[::2,::2]/2

tensor([[ 0.0235, -0.0760,  0.0396,  ...,  0.0529,  0.0249, -0.0878],
        [-0.0603,  0.0927,  0.0358,  ..., -0.0336,  0.0566, -0.0045],
        [-0.0821, -0.0516, -0.0679,  ...,  0.0044,  0.0684,  0.0292],
        ...,
        [-0.0473,  0.0655,  0.0621,  ..., -0.0619, -0.0801,  0.0128],
        [-0.0593,  0.0851, -0.0510,  ...,  0.0862, -0.0646, -0.0052],
        [-0.0688,  0.0260,  0.0799,  ...,  0.0811,  0.0772,  0.0347]])

In [None]:
lin = nn.Linear(2, 4)
lin2 = nn.Linear(4, 12)
m, n = torch.tensor(lin2.weight.data.shape)//torch.tensor(lin.weight.data.shape)
m, n = m.item(), n.item()
for j in range(m):
    lin2.bias.data[j::m] = lin.bias.data[:]
    for i in range(n):
        lin2.weight.data[j::m, i::n] = lin.weight.data / n

x = torch.randn(5, 2)
x2 = torch.randn(5, 2*n)
for j in range(n):
    x2[:,j::n] = x
x2[:,::2] == x, x2[:,1::2] == x # this holdds 

y = lin(x)
y2 = lin2(x2)

y2[:,::m] - y

# lin2.weight.data[::2,::2] = lin.weight.data

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-5.9605e-08,  2.9802e-08,  0.0000e+00, -2.9802e-08],
        [ 1.4901e-08,  2.9802e-08,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
       grad_fn=<SubBackward0>)

In [80]:
y.shape, x.shape, y2.shape, x2.shape

(torch.Size([5, 4]),
 torch.Size([5, 2]),
 torch.Size([5, 12]),
 torch.Size([5, 4]))

In [None]:
y2[:,::m] - y


tensor([[-5.9605e-08,  0.0000e+00,  5.9605e-08,  2.9802e-08],
        [ 0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00],
        [ 2.9802e-08,  0.0000e+00,  0.0000e+00,  2.9802e-08],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-5.9605e-08,  0.0000e+00,  1.4901e-08,  0.0000e+00]],
       grad_fn=<SubBackward0>)