In [2]:
import torch

def check_weights_same(model1, model2, include_buffers=False):
    """
    Check if the weights of two PyTorch models are exactly the same.

    Tensors are paired positionally, in the order yielded by
    ``parameters()`` (and ``buffers()`` when requested); parameter names
    are not consulted.

    Args:
    model1: The first PyTorch model.
    model2: The second PyTorch model.
    include_buffers: If True, also compare the tensors yielded by
        ``buffers()`` (for example BatchNorm running statistics).
        Defaults to False, which compares parameters only.

    Returns:
    bool: True if all weights are the same, False otherwise. Models with a
    different number of tensors, or with a pair of tensors whose shapes
    differ, are reported as different.
    """
    tensors1 = list(model1.parameters())
    tensors2 = list(model2.parameters())
    if include_buffers:
        tensors1 += list(model1.buffers())
        tensors2 += list(model2.buffers())

    # zip() stops at the shorter input, so without this check a model whose
    # parameters are a prefix of the other's would be reported as identical.
    if len(tensors1) != len(tensors2):
        return False

    for tensor1, tensor2 in zip(tensors1, tensors2):
        # Compare shapes first: ne() broadcasts, so a (1,) tensor could
        # otherwise "match" a (3,) tensor, and incompatible shapes would
        # raise instead of returning False.
        if tensor1.shape != tensor2.shape:
            return False
        # detach() replaces the legacy .data access; the comparison itself
        # is element-wise and exact (no tolerance).
        if tensor1.detach().ne(tensor2.detach()).any():
            return False
    return True

# Example usage:
# model1 = YourModelClass(*args, **kwargs)
# model2 = YourModelClass(*args, **kwargs)
# are_same = check_weights_same(model1, model2)
# print("Models have the same weights:", are_same)


In [3]:
import torch
from ofa.imagenet_classification.elastic_nn.networks import OFAResNets
from CNN_Pruning_Engine.Models.Resnet import ResNet101

# model1 is the OFA network that will receive the weights; model2 is the
# reference ResNet101 that provides them (100 is presumably the class count,
# matching n_classes above -- confirm against ResNet101's signature).
model1 = OFAResNets(
    n_classes=100,
    bn_param=(0.1, 1e-5),
    dropout_rate=0.1,
    depth_list=1,
    expand_ratio_list=1,
    width_mult_list=1.0,
)
model2 = ResNet101(100)

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a machine without CUDA; load_state_dict then copies the values into model2.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
weight = torch.load(
    "weights/Model@ResNet101_ACC@79.89.pt", map_location="cpu"
)["state_dict"]
model2.load_state_dict(weight)

<All keys matched successfully>

In [4]:
# Copy the stem convolution and its batch norm from the reference ResNet101
# (model2) into the OFA network (model1).
model1.input_stem[0].conv.conv.load_state_dict(model2.conv1.state_dict())
model1.input_stem[0].bn.bn.load_state_dict(model2.bn1.state_dict())

# Flatten the four stages of model2 into a single list so its blocks can be
# paired positionally with model1.blocks.
layersx = [
    model2.layer1,
    model2.layer2,
    model2.layer3,
    model2.layer4,
]
model2_blocks = [layer for layers in layersx for layer in layers]

# Fail loudly on a depth mismatch instead of converting only part of the
# network: indexing by model1's block count would silently skip extra model2
# blocks, or raise IndexError halfway through if model1 had more.
if len(model1.blocks) != len(model2_blocks):
    raise ValueError(
        f"Block count mismatch: model1 has {len(model1.blocks)} blocks, "
        f"model2 has {len(model2_blocks)}"
    )

for ofa_block, ref_block in zip(model1.blocks, model2_blocks):
    # (destination OFA layer, source conv, source batch norm) triples.
    conv_bn_pairs = [
        (ofa_block.conv1, ref_block.conv1, ref_block.bn1),
        (ofa_block.conv2, ref_block.conv2, ref_block.bn2),
        (ofa_block.conv3, ref_block.conv3, ref_block.bn3),
    ]
    # Only blocks whose reference shortcut is non-empty carry a conv + bn
    # pair to copy into the OFA block's downsample branch.
    if len(ref_block.shortcut) > 0:
        conv_bn_pairs.append(
            (ofa_block.downsample, ref_block.shortcut[0], ref_block.shortcut[1])
        )

    for ofa_layer, ref_conv, ref_bn in conv_bn_pairs:
        ofa_layer.conv.conv.load_state_dict(ref_conv.state_dict())
        ofa_layer.bn.bn.load_state_dict(ref_bn.state_dict())

# Kept as the last expression so the cell still displays the load result.
model1.classifier.linear.linear.load_state_dict(model2.linear.state_dict())

<All keys matched successfully>

In [5]:
# Persist the converted OFA ResNet weights, wrapped under a "state_dict" key.
resnet_ofa_checkpoint = {"state_dict": model1.state_dict()}
torch.save(resnet_ofa_checkpoint, "ResNet101OFA_ACC@79.89.pt")

In [6]:
import torch
from ofa.imagenet_classification.elastic_nn.networks import OFAMobileNetV2
from CNN_Pruning_Engine.Models.Mobilenetv2 import MobileNetV2
# model1 is the OFA network that will receive the weights; model2 is the
# reference MobileNetV2 that provides them.
model1 = OFAMobileNetV2(
    num_classes=100,
    bn_param=(0.1, 1e-5),
    dropout_rate=0.1,
    depth_list=1,
    expand_ratio_list=1,
    width_mult_list=1.0,
)
model2 = MobileNetV2(num_classes=100)

# This checkpoint is passed to load_state_dict as-is (no "state_dict" wrapper
# key is unpacked here). map_location="cpu" lets a checkpoint saved from GPU
# tensors load on a machine without CUDA.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
weight = torch.load("weights/Model@Mobilenetv2_ACC@79.32.pt", map_location="cpu")
model2.load_state_dict(weight)


<All keys matched successfully>

In [7]:
# Stem, head and classifier share attribute names between the two models, so
# they are copied one-to-one.
model1.conv1.load_state_dict(model2.conv1.state_dict())
model1.bn1.load_state_dict(model2.bn1.state_dict())
model1.conv2.load_state_dict(model2.conv2.state_dict())
model1.bn2.load_state_dict(model2.bn2.state_dict())
model1.linear.load_state_dict(model2.linear.state_dict())

# Fail loudly on a depth mismatch instead of converting only part of the
# network: indexing by model1's layer count would silently skip extra model2
# layers, or raise IndexError halfway through if model1 had more.
if len(model1.layers) != len(model2.layers):
    raise ValueError(
        f"Layer count mismatch: model1 has {len(model1.layers)} layers, "
        f"model2 has {len(model2.layers)}"
    )

for ofa_layer, ref_layer in zip(model1.layers, model2.layers):
    # (destination OFA sub-layer, source conv, source batch norm) triples.
    for ofa_part, ref_conv, ref_bn in (
        (ofa_layer.inverted_bottleneck, ref_layer.conv1, ref_layer.bn1),
        (ofa_layer.depth_conv, ref_layer.conv2, ref_layer.bn2),
        (ofa_layer.point_linear, ref_layer.conv3, ref_layer.bn3),
    ):
        ofa_part.conv.conv.load_state_dict(ref_conv.state_dict())
        ofa_part.bn.bn.load_state_dict(ref_bn.state_dict())

    # Only layers whose reference shortcut is non-empty have shortcut modules
    # to copy; both models index them as shortcut[0] and shortcut[1].
    if len(ref_layer.shortcut) > 0:
        ofa_layer.shortcut[0].load_state_dict(ref_layer.shortcut[0].state_dict())
        ofa_layer.shortcut[1].load_state_dict(ref_layer.shortcut[1].state_dict())

In [8]:
# Sanity check: compare the converted OFA model against the reference model,
# pairing tensors in parameters() order. Only parameters are compared by this
# call, so buffers (e.g. BatchNorm running statistics) are not verified here.
check_weights_same(model1=model1, model2=model2)

True

In [9]:
# Persist the converted OFA MobileNetV2 weights, wrapped under a "state_dict" key.
mobilenet_ofa_checkpoint = {"state_dict": model1.state_dict()}
torch.save(mobilenet_ofa_checkpoint, "MobilenetV2OFA_ACC@79.32.pt")