# October 4 - Weight saving and loading mechanism

## Goal : Undertand how the weight saving and loading mechanism works in the case of modular architectures

In [10]:
import torch
from torch import nn

### Using the state_dict for a EresNet-34 VAE

In [2]:
state_dict_path = "/home/akajal/WatChMaL/VAE/dumps/20190929_184731/ENet_best.pth"

# Open the file and read the state dict
with open(state_dict_path, "rb") as f:
    checkpoint = torch.load(f)


In [4]:
print(list(checkpoint.keys()))

['global_step', 'state_dict']


### Read the values stored in the saved state_dict

In [5]:
global_step = checkpoint["global_step"]
state_dict = checkpoint["state_dict"]

In [6]:
print(global_step)

838631


In [7]:
print(state_dict)

OrderedDict([('encoder.conv1.weight', tensor([[[[-5.2940e-01]],

         [[-5.1946e-01]],

         [[-4.6545e-01]],

         [[-1.5377e-01]],

         [[-1.9471e-01]],

         [[-6.9736e-03]],

         [[-6.5698e-02]],

         [[ 8.3288e-01]],

         [[-8.7191e-02]],

         [[-1.3223e-01]],

         [[ 7.3852e-02]],

         [[-3.0275e-02]],

         [[ 3.7372e-01]],

         [[-2.1696e-01]],

         [[ 2.9207e-02]],

         [[ 4.7525e-01]],

         [[-4.6410e-02]],

         [[-1.8610e-01]],

         [[-2.5266e-01]]],


        [[[ 3.3186e-01]],

         [[-5.2408e-02]],

         [[-5.7941e-02]],

         [[ 4.3265e-02]],

         [[ 3.9217e-01]],

         [[ 2.5440e-01]],

         [[ 1.2374e-01]],

         [[ 1.9465e-01]],

         [[ 3.9171e-01]],

         [[ 3.0919e-01]],

         [[ 4.1020e-01]],

         [[ 6.7400e-01]],

         [[-2.0577e-02]],

         [[-1.8177e-01]],

         [[ 4.8402e-02]],

         [[ 2.2106e-01]],

         [[ 3.5

In [8]:
print(list(state_dict.keys()))

['encoder.conv1.weight', 'encoder.bn1.weight', 'encoder.bn1.bias', 'encoder.bn1.running_mean', 'encoder.bn1.running_var', 'encoder.bn1.num_batches_tracked', 'encoder.conv2.weight', 'encoder.bn2.weight', 'encoder.bn2.bias', 'encoder.bn2.running_mean', 'encoder.bn2.running_var', 'encoder.bn2.num_batches_tracked', 'encoder.layer0.conv1.weight', 'encoder.layer0.bn1.weight', 'encoder.layer0.bn1.bias', 'encoder.layer0.bn1.running_mean', 'encoder.layer0.bn1.running_var', 'encoder.layer0.bn1.num_batches_tracked', 'encoder.layer0.conv2.weight', 'encoder.layer0.bn2.weight', 'encoder.layer0.bn2.bias', 'encoder.layer0.bn2.running_mean', 'encoder.layer0.bn2.running_var', 'encoder.layer0.bn2.num_batches_tracked', 'encoder.layer1.0.conv1.weight', 'encoder.layer1.0.bn1.weight', 'encoder.layer1.0.bn1.bias', 'encoder.layer1.0.bn1.running_mean', 'encoder.layer1.0.bn1.running_var', 'encoder.layer1.0.bn1.num_batches_tracked', 'encoder.layer1.0.conv2.weight', 'encoder.layer1.0.bn2.weight', 'encoder.layer1.0

## Exploring the dictionary structure of nn.Module

In [11]:
# Latent Classifier
class Classifier(nn.Module):
    
    # Initializer
    def __init__(self, num_latent_dims=64, num_classes=3):
        super(Classifier, self).__init__()
        
        # Activation functions
        self.relu = nn.ReLU()
        
        # Classifier fully connected layers
        self.cl_fc1 = nn.Linear(num_latent_dims, int(num_latent_dims/2))
        self.cl_fc2 = nn.Linear(int(num_latent_dims/2), int(num_latent_dims/4))
        self.cl_fc3 = nn.Linear(int(num_latent_dims/4), int(num_latent_dims/8))
        self.cl_fc4 = nn.Linear(int(num_latent_dims/8), num_classes)
        
    # Forward
    def forward(self, X):
        
        # Fully-connected layers
        x = self.relu(self.cl_fc1(X))
        x = self.relu(self.cl_fc2(x))
        x = self.relu(self.cl_fc3(x))
        x = self.cl_fc4(x)
        
        return x

In [12]:
# Modular model
class ClassifierNet(nn.Module):
    
    # Initializer
    def __init__(self, num_latent_dims=64, num_classes=3):
        super(ClassifierNet, self).__init__()
        
        # Activation functions
        self.relu = nn.ReLU()
        
        # Add the module to this network
        self.classifier_1 = Classifier(num_latent_dims, num_classes)
        self.classifier_2 = Classifier(num_latent_dims, num_classes)
        
    # Forward pass
    def forward(self, X):
        
        # Pass the input through the classifiers
        x_1 = self.classifier_1(X)
        x_2 = self.classifier_2(X)
        
        return x_1, x_2

In [20]:
clnet = ClassifierNet()
print(list(clnet._modules.keys()))

['relu', 'classifier_1', 'classifier_2']


In [17]:
print(getattr(clnet,'classifier_1'))

Classifier(
  (relu): ReLU()
  (cl_fc1): Linear(in_features=64, out_features=32, bias=True)
  (cl_fc2): Linear(in_features=32, out_features=16, bias=True)
  (cl_fc3): Linear(in_features=16, out_features=8, bias=True)
  (cl_fc4): Linear(in_features=8, out_features=3, bias=True)
)


In [18]:
clnet_cl_1 = getattr(clnet, 'classifier_1')

In [19]:
print(clnet_cl_1.state_dict())

OrderedDict([('cl_fc1.weight', tensor([[ 0.0833, -0.0425,  0.1111,  ..., -0.0045,  0.0240, -0.0502],
        [ 0.0166, -0.0689, -0.0919,  ...,  0.1160,  0.0180, -0.1179],
        [ 0.0971, -0.1095, -0.1014,  ..., -0.0326,  0.1230, -0.0528],
        ...,
        [ 0.0431, -0.0536, -0.0607,  ..., -0.0416, -0.0963,  0.0955],
        [ 0.0013, -0.0647,  0.0186,  ...,  0.0542, -0.0490,  0.0222],
        [ 0.0363, -0.0425, -0.1097,  ...,  0.0014, -0.0285, -0.0102]])), ('cl_fc1.bias', tensor([ 0.0143, -0.0295,  0.0343, -0.0345,  0.0207,  0.0044,  0.0852,  0.0417,
         0.0376,  0.1068,  0.1030, -0.1175, -0.0568, -0.0204, -0.0471, -0.0108,
        -0.0843,  0.0345, -0.1212,  0.0070,  0.1219,  0.1044,  0.0059, -0.1041,
        -0.0813, -0.1129, -0.0590, -0.0871,  0.0264, -0.0855,  0.0473,  0.0013])), ('cl_fc2.weight', tensor([[ 0.1510,  0.0725,  0.0351,  0.0638,  0.0118, -0.1110,  0.0677,  0.1767,
         -0.1668,  0.1076,  0.0448, -0.0178,  0.0614,  0.0572,  0.0748, -0.0482,
         -0.07

In [21]:
model_params = clnet

In [22]:
modules = list(model_params._modules.keys())
state_dict = {module:getattr(model_params, module).state_dict() for module in modules}
torch.save(state_dict, "clnet_weights")

In [23]:
# Open the file and read the state dict
with open("clnet_weights", "rb") as f:
    checkpoint = torch.load(f)

In [25]:
print(list(checkpoint.keys()))

['relu', 'classifier_1', 'classifier_2']


In [27]:
getattr(clnet, 'classifier_1').load_state_dict(checkpoint['classifier_1'], strict=False)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [28]:
getattr(clnet, 'classifier_2').load_state_dict(checkpoint['classifier_1'], strict=False)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [31]:
getattr(clnet, 'classifier_1').load_state_dict(checkpoint['relu'], strict=True)

RuntimeError: Error(s) in loading state_dict for Classifier:
	Missing key(s) in state_dict: "cl_fc1.weight", "cl_fc1.bias", "cl_fc2.weight", "cl_fc2.bias", "cl_fc3.weight", "cl_fc3.bias", "cl_fc4.weight", "cl_fc4.bias". 

In [30]:
getattr(clnet, 'cl')

AttributeError: 'ClassifierNet' object has no attribute 'cl'