In [16]:
import warnings
warnings.filterwarnings("ignore")

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch

name = './../model/LabVGG16.pth'
dictname = name[:-4] + '_dict' + name[-4:]
print(dictname)

net = torch.load(name,map_location=torch.device('cpu'))
dictnet = net.state_dict()
torch.save(dictnet, dictname)
count=0
for key in dictnet.keys():
    print("{}. {}   {}".format(count,key,dictnet[key].shape))
    count+=1

./../model/LabVGG16_dict.pth
0. conv1.0.weight   torch.Size([64, 3, 3, 3])
1. conv1.0.bias   torch.Size([64])
2. conv1.1.weight   torch.Size([64])
3. conv1.1.bias   torch.Size([64])
4. conv1.1.running_mean   torch.Size([64])
5. conv1.1.running_var   torch.Size([64])
6. conv1.1.num_batches_tracked   torch.Size([])
7. conv1.3.weight   torch.Size([64, 64, 3, 3])
8. conv1.3.bias   torch.Size([64])
9. pool1.0.weight   torch.Size([64])
10. pool1.0.bias   torch.Size([64])
11. pool1.0.running_mean   torch.Size([64])
12. pool1.0.running_var   torch.Size([64])
13. pool1.0.num_batches_tracked   torch.Size([])
14. conv2.0.weight   torch.Size([128, 64, 3, 3])
15. conv2.0.bias   torch.Size([128])
16. conv2.1.weight   torch.Size([128])
17. conv2.1.bias   torch.Size([128])
18. conv2.1.running_mean   torch.Size([128])
19. conv2.1.running_var   torch.Size([128])
20. conv2.1.num_batches_tracked   torch.Size([])
21. conv2.3.weight   torch.Size([128, 128, 3, 3])
22. conv2.3.bias   torch.Size([128])
23. poo

LAB VGG16: Trying to import up to pool2

In [21]:
import torch
import torch.nn as nn

def l2normalize(v, eps = 1e-12):
    return v / (v.norm() + eps)


class LabVGG16(nn.Module):
    def __init__(self, in_dim = 3, num_classes = 1000):
        super(LabVGG16, self).__init__()
        # feature extraction part
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels = in_dim, out_channels = 64, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)
        )
        self.pool1 = nn.Sequential(
            nn.ReLU(inplace = False),
            nn.MaxPool2d(kernel_size = 2, stride = 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1, padding = 1)
        )
        self.pool2 = nn.Sequential(
            nn.ReLU(inplace = False),
            nn.MaxPool2d(kernel_size = 2, stride = 2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1),
            #nn.ReLU(inplace = True),
            #nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1, padding = 1)
        )
        self.pool3 = nn.Sequential(
            nn.ReLU(inplace = False),
            nn.MaxPool2d(kernel_size = 2, stride = 2)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1),
            #nn.ReLU(inplace = True),
            #nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1)
        )
        self.pool4 = nn.Sequential(
            nn.ReLU(inplace = False),
            nn.MaxPool2d(kernel_size = 2, stride = 2)
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1),
            #nn.ReLU(inplace = True),
            #nn.Conv2d(in_channels = 512, out_channels = 512, kernel_size = 3, stride = 1, padding = 1)
        )
        """
        self.pool5 = nn.Sequential(
            nn.ReLU(inplace = False),
            nn.AdaptiveAvgPool2d((7, 7))
        )
        # classification part
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace = True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace = True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ) """

    def forward(self, x):                                   # shape: [B, 3, 224, 224]
        conv1 = self.conv1(x)                               # shape: [B, 64, 224, 224]
        pool1 = self.pool1(conv1)                           # shape: [B, 64, 112, 112]
        conv2 = self.conv2(pool1)                           # shape: [B, 128, 112, 112]
        pool2 = self.pool2(conv2)                           # shape: [B, 128, 56, 56]
        conv3 = self.conv3(pool2)                           # shape: [B, 256, 56, 56]
        pool3 = self.pool3(conv3)                           # shape: [B, 256, 28, 28]
        conv4 = self.conv4(pool3)                           # shape: [B, 512, 28, 28]
        pool4 = self.pool4(conv4)                           # shape: [B, 512, 14, 14]
        conv5 = self.conv5(pool4)                           # shape: [B, 512, 14, 14]
        """ pool5 = self.pool5(conv5)                           # shape: [B, 512, 7, 7]
        pool5 = pool5.view(x.size(0), -1)                   # shape: [B, 512 * 7 * 7]
        x = self.classifier(pool5)                          # shape: [B, 1000]
        return x """
        return conv5

net= LabVGG16()
print(net)

LabVGG16(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (pool1): Sequential(
    (0): ReLU()
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (pool2): Sequential(
    (0): ReLU()
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (pool3): Sequential(
    (0): ReLU()
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode

In [8]:
# Load the state dictionary
state_dict = torch.load("./../model/LabVGG16_dict.pth", map_location=torch.device('cpu'))

In [15]:
count=0
for key in state_dict.keys():
    print("{}. {}   {}".format(count,key,state_dict[key].shape))
    count+=1

0. conv1.0.weight   torch.Size([64, 3, 3, 3])
1. conv1.0.bias   torch.Size([64])
2. conv1.1.weight   torch.Size([64])
3. conv1.1.bias   torch.Size([64])
4. conv1.1.running_mean   torch.Size([64])
5. conv1.1.running_var   torch.Size([64])
6. conv1.1.num_batches_tracked   torch.Size([])
7. conv1.3.weight   torch.Size([64, 64, 3, 3])
8. conv1.3.bias   torch.Size([64])
9. pool1.0.weight   torch.Size([64])
10. pool1.0.bias   torch.Size([64])
11. pool1.0.running_mean   torch.Size([64])
12. pool1.0.running_var   torch.Size([64])
13. pool1.0.num_batches_tracked   torch.Size([])
14. conv2.0.weight   torch.Size([128, 64, 3, 3])
15. conv2.0.bias   torch.Size([128])
16. conv2.1.weight   torch.Size([128])
17. conv2.1.bias   torch.Size([128])
18. conv2.1.running_mean   torch.Size([128])
19. conv2.1.running_var   torch.Size([128])
20. conv2.1.num_batches_tracked   torch.Size([])
21. conv2.3.weight   torch.Size([128, 128, 3, 3])
22. conv2.3.bias   torch.Size([128])
23. pool2.0.weight   torch.Size([128

In [24]:
first_x_keys = list(state_dict.keys())[:86]

extracted_dict = {key: state_dict[key] for key in first_x_keys if not any(["conv3.4" in key,"conv3.6" in key, "conv4.4" in key,"conv4.6" in key, "conv5.4" in key,"conv5.6" in key])}
count=0
for key in extracted_dict.keys():
    print("{}. {},      {}".format(count,key,state_dict[key].shape))
    count+=1

0. conv1.0.weight,      torch.Size([64, 3, 3, 3])
1. conv1.0.bias,      torch.Size([64])
2. conv1.1.weight,      torch.Size([64])
3. conv1.1.bias,      torch.Size([64])
4. conv1.1.running_mean,      torch.Size([64])
5. conv1.1.running_var,      torch.Size([64])
6. conv1.1.num_batches_tracked,      torch.Size([])
7. conv1.3.weight,      torch.Size([64, 64, 3, 3])
8. conv1.3.bias,      torch.Size([64])
9. pool1.0.weight,      torch.Size([64])
10. pool1.0.bias,      torch.Size([64])
11. pool1.0.running_mean,      torch.Size([64])
12. pool1.0.running_var,      torch.Size([64])
13. pool1.0.num_batches_tracked,      torch.Size([])
14. conv2.0.weight,      torch.Size([128, 64, 3, 3])
15. conv2.0.bias,      torch.Size([128])
16. conv2.1.weight,      torch.Size([128])
17. conv2.1.bias,      torch.Size([128])
18. conv2.1.running_mean,      torch.Size([128])
19. conv2.1.running_var,      torch.Size([128])
20. conv2.1.num_batches_tracked,      torch.Size([])
21. conv2.3.weight,      torch.Size([12

In [25]:
net.load_state_dict(extracted_dict, strict=False)

_IncompatibleKeys(missing_keys=['conv1.2.weight', 'conv1.2.bias', 'conv2.2.weight', 'conv2.2.bias', 'conv3.2.weight', 'conv3.2.bias', 'conv4.2.weight', 'conv4.2.bias', 'conv5.2.weight', 'conv5.2.bias'], unexpected_keys=['conv1.3.weight', 'conv1.3.bias', 'conv1.1.weight', 'conv1.1.bias', 'conv1.1.running_mean', 'conv1.1.running_var', 'conv1.1.num_batches_tracked', 'pool1.0.weight', 'pool1.0.bias', 'pool1.0.running_mean', 'pool1.0.running_var', 'pool1.0.num_batches_tracked', 'conv2.3.weight', 'conv2.3.bias', 'conv2.1.weight', 'conv2.1.bias', 'conv2.1.running_mean', 'conv2.1.running_var', 'conv2.1.num_batches_tracked', 'pool2.0.weight', 'pool2.0.bias', 'pool2.0.running_mean', 'pool2.0.running_var', 'pool2.0.num_batches_tracked', 'conv3.3.weight', 'conv3.3.bias', 'conv3.1.weight', 'conv3.1.bias', 'conv3.1.running_mean', 'conv3.1.running_var', 'conv3.1.num_batches_tracked', 'pool3.0.weight', 'pool3.0.bias', 'pool3.0.running_mean', 'pool3.0.running_var', 'pool3.0.num_batches_tracked', 'conv4

Already working: Load entire dictionary

In [8]:
# To load entire state dictionary
net.load_state_dict(torch.load("./../model/LabVGG16_dict.pth",map_location=torch.device('cpu')))

RuntimeError: Error(s) in loading state_dict for LabVGG16:
	Missing key(s) in state_dict: "conv1.2.weight", "conv1.2.bias", "conv2.2.weight", "conv2.2.bias". 
	Unexpected key(s) in state_dict: "conv3.0.weight", "conv3.0.bias", "conv3.1.weight", "conv3.1.bias", "conv3.1.running_mean", "conv3.1.running_var", "conv3.1.num_batches_tracked", "conv3.3.weight", "conv3.3.bias", "conv3.4.weight", "conv3.4.bias", "conv3.4.running_mean", "conv3.4.running_var", "conv3.4.num_batches_tracked", "conv3.6.weight", "conv3.6.bias", "pool3.0.weight", "pool3.0.bias", "pool3.0.running_mean", "pool3.0.running_var", "pool3.0.num_batches_tracked", "conv4.0.weight", "conv4.0.bias", "conv4.1.weight", "conv4.1.bias", "conv4.1.running_mean", "conv4.1.running_var", "conv4.1.num_batches_tracked", "conv4.3.weight", "conv4.3.bias", "conv4.4.weight", "conv4.4.bias", "conv4.4.running_mean", "conv4.4.running_var", "conv4.4.num_batches_tracked", "conv4.6.weight", "conv4.6.bias", "pool4.0.weight", "pool4.0.bias", "pool4.0.running_mean", "pool4.0.running_var", "pool4.0.num_batches_tracked", "conv5.0.weight", "conv5.0.bias", "conv5.1.weight", "conv5.1.bias", "conv5.1.running_mean", "conv5.1.running_var", "conv5.1.num_batches_tracked", "conv5.3.weight", "conv5.3.bias", "conv5.4.weight", "conv5.4.bias", "conv5.4.running_mean", "conv5.4.running_var", "conv5.4.num_batches_tracked", "conv5.6.weight", "conv5.6.bias", "pool5.0.weight", "pool5.0.bias", "pool5.0.running_mean", "pool5.0.running_var", "pool5.0.num_batches_tracked", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias", "conv1.3.weight", "conv1.3.bias", "conv1.1.weight", "conv1.1.bias", "conv1.1.running_mean", "conv1.1.running_var", "conv1.1.num_batches_tracked", "pool1.0.weight", "pool1.0.bias", "pool1.0.running_mean", "pool1.0.running_var", "pool1.0.num_batches_tracked", "conv2.3.weight", "conv2.3.bias", "conv2.1.weight", "conv2.1.bias", "conv2.1.running_mean", "conv2.1.running_var", "conv2.1.num_batches_tracked", "pool2.0.weight", "pool2.0.bias", "pool2.0.running_mean", "pool2.0.running_var", "pool2.0.num_batches_tracked". 