In [1]:
from IPython.display import display, HTML

display(HTML(data="""
<style>
    div#notebook-container    { width: 95%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>
"""))

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch.nn as nn
import torch
from torchvision.models import resnet34
from torchlego.models.resnet import ResNetBasicBlock
from torchlego.utils import ModuleTransfer, Tracker
from torchsummary import summary


In [4]:
resnet34_my = nn.Sequential(
    nn.Conv2d(3,
                      64,
                      kernel_size=7,
                      stride=2,
                      padding=3,
                      bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ResNetBasicBlock(64, 64),
    ResNetBasicBlock(64, 64),
    ResNetBasicBlock(64, 64),
    ResNetBasicBlock(64, 128),
    ResNetBasicBlock(128, 128),
    ResNetBasicBlock(128, 128),
    ResNetBasicBlock(128, 128),
    ResNetBasicBlock(128, 256),
    ResNetBasicBlock(256, 256),
    ResNetBasicBlock(256, 256),
    ResNetBasicBlock(256, 256),
    ResNetBasicBlock(256, 256),
    ResNetBasicBlock(256, 256),
    ResNetBasicBlock(256, 512),
    ResNetBasicBlock(512, 512),
    ResNetBasicBlock(512, 512),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    nn.Flatten(),
    nn.Linear(512, 1000)
    
)

In [5]:
src_model = resnet34(True).eval().cpu()
dest_model = resnet34_my.eval()

x = torch.zeros((1, 3, 224, 244))

src_tr = Tracker(src_model)
dest_tr = Tracker(dest_model)

src_operations = src_tr(x).parametrized
dest_operations = dest_tr(x).parametrized


In [6]:
# for src_op, dest_op in zip(src_operations, dest_operations):
# #   
#     print('-------')
#     print(src_op, '|', dest_op)
#     if 'weight' in dest_op.state_dict().keys():
#         assert dest_op.state_dict()['weight'].sum() != src_op.state_dict()['weight'].sum()


In [7]:
import time 

for src_op, dest_op in zip(src_operations, dest_operations):
#     print(f'Transfered from={src_op} to={dest_op}')
    print('-------')
    print(src_op, '|', dest_op)
    dest_op.load_state_dict(src_op.state_dict())
    if 'weight' in dest_op.state_dict().keys():
        assert dest_op.state_dict()['weight'].sum() == src_op.state_dict()['weight'].sum()


-------
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) | Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
-------
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
-------
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) | Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
-------
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
-------
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) | Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
-------
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [8]:
for src_op, dest_op in zip(src_operations, dest_operations):
    print('-------')
    print(src_op, '|', dest_op)
    if 'weight' in src_op.state_dict().keys():
            print(dest_op.state_dict()['weight'].sum(), dest_op)
            print(src_op.state_dict()['weight'].sum(), src_op)
            assert dest_op.state_dict()['weight'].sum() == src_op.state_dict()['weight'].sum()

-------
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) | Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
tensor(-1.3706) Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
tensor(-1.3706) Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
-------
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) | BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
tensor(16.8200) BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
tensor(16.8200) BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
-------
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) | Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(-82.0177) Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(-82.0177) Conv2d(64,

In [9]:
src_model(x)

tensor([[-5.7581e-01, -2.4327e-01, -2.2948e+00, -1.9053e+00, -2.4730e+00,
         -2.1047e-01, -3.5750e+00, -1.7740e+00, -1.9271e+00, -9.9620e-01,
          1.6648e+00, -4.0727e-01, -1.0489e+00, -1.4315e+00, -1.5815e+00,
         -7.8157e-01, -7.4532e-01, -1.7407e+00, -2.1929e+00, -1.5676e+00,
         -1.9657e+00,  6.2797e-01, -9.6306e-01, -6.2646e-01, -1.6638e+00,
         -2.4488e+00, -1.5220e+00, -1.8061e+00, -2.9531e-01, -9.5689e-01,
         -4.1994e+00, -1.6529e+00, -2.1386e+00, -1.8823e+00, -1.0101e+00,
         -3.7073e+00, -1.9570e+00, -3.5193e+00,  5.8777e-01, -1.9691e+00,
         -7.9854e-01, -1.2132e+00, -3.8929e-01,  1.2133e+00, -9.5485e-01,
         -4.0284e-01, -1.1763e+00, -8.9208e-01, -1.9620e+00, -3.0837e+00,
         -2.4837e+00,  8.9620e-01, -7.8899e-01, -1.2444e+00, -1.4835e+00,
         -2.8464e+00, -2.3781e+00, -3.0220e+00, -3.1995e+00,  1.3249e+00,
         -7.2586e-01, -2.5133e+00,  1.6638e-01,  3.1705e-01, -7.1929e-02,
         -5.3640e-01,  1.0295e-01, -2.

In [10]:
dest_model(x)

tensor([[-5.7581e-01, -2.4327e-01, -2.2948e+00, -1.9053e+00, -2.4730e+00,
         -2.1047e-01, -3.5750e+00, -1.7740e+00, -1.9271e+00, -9.9620e-01,
          1.6648e+00, -4.0727e-01, -1.0489e+00, -1.4315e+00, -1.5815e+00,
         -7.8157e-01, -7.4532e-01, -1.7407e+00, -2.1929e+00, -1.5676e+00,
         -1.9657e+00,  6.2797e-01, -9.6306e-01, -6.2646e-01, -1.6638e+00,
         -2.4488e+00, -1.5220e+00, -1.8061e+00, -2.9531e-01, -9.5689e-01,
         -4.1994e+00, -1.6529e+00, -2.1386e+00, -1.8823e+00, -1.0101e+00,
         -3.7073e+00, -1.9570e+00, -3.5193e+00,  5.8777e-01, -1.9691e+00,
         -7.9854e-01, -1.2132e+00, -3.8929e-01,  1.2133e+00, -9.5485e-01,
         -4.0284e-01, -1.1763e+00, -8.9208e-01, -1.9620e+00, -3.0837e+00,
         -2.4837e+00,  8.9620e-01, -7.8899e-01, -1.2444e+00, -1.4835e+00,
         -2.8464e+00, -2.3781e+00, -3.0220e+00, -3.1995e+00,  1.3249e+00,
         -7.2586e-01, -2.5133e+00,  1.6638e-01,  3.1705e-01, -7.1929e-02,
         -5.3640e-01,  1.0295e-01, -2.