# In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable

# In [2]:
def set_trainable(model, requires_grad):
    """Set ``requires_grad`` on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        requires_grad: Value assigned to each parameter's ``requires_grad``
            attribute; ``False`` freezes the parameters, ``True`` unfreezes.
    """
    for weight in model.parameters():
        weight.requires_grad = requires_grad

# In [None]:
class VGG16(nn.Module):
    """VGG-style convolutional feature extractor.

    The network is five stages, ``conv1`` .. ``conv5``. Each stage is two
    3x3 convolutions (padding 1, so spatial size is preserved) followed by
    a 2x2 max-pool that halves height and width. Stage output channels are
    64, 128, 128, 256 and 512. The parameters of ``conv1`` and ``conv2``
    are frozen (``requires_grad=False``) at construction time.

    NOTE(review): canonical VGG16 applies a ReLU after every convolution
    and uses three convolutions in stages 3-5 (widths 256/512/512). This
    class keeps the layer layout that was already declared here -- confirm
    whether the simplified layout is intended before loading real VGG16
    weights.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._make_stage(3, 64)
        self.conv2 = self._make_stage(64, 128)

        # conv1 and conv2 are not trained.
        set_trainable(self.conv1, requires_grad=False)
        set_trainable(self.conv2, requires_grad=False)

        # conv2 emits 128 channels and conv4 consumes 128, so conv3 must
        # map 128 -> 128 for the stages to chain.
        self.conv3 = self._make_stage(128, 128)
        self.conv4 = self._make_stage(128, 256)
        self.conv5 = self._make_stage(256, 512)

    @staticmethod
    def _make_stage(in_channels, out_channels):
        """Build one stage: two 3x3 convs (padding 1) and a 2x2 max-pool."""
        # nn.Conv2d has no ``same_padding`` argument; for a 3x3 kernel with
        # stride 1, padding=1 keeps the spatial size unchanged.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.MaxPool2d(2),
        )

    def forward(self, im_data):
        """Run ``im_data`` through conv1..conv5 and return the result.

        Args:
            im_data: Input tensor with 3 channels (the first convolution is
                ``Conv2d(3, 64, ...)``).

        Returns:
            The output of ``conv5``: 512 channels, with height and width
            halved (floored) once per stage by the five max-pools.
        """
        x = self.conv1(im_data)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    def load_from_npz(self, params):
        """Copy weights from ``params`` into this module in place.

        Each state-dict entry named ``conv<i>.<j>.weight`` / ``.bias`` is
        filled from ``params['conv<i>_<j+1>/weights:0']`` or
        ``params['conv<i>_<j+1>/biases:0']`` respectively.

        Args:
            params: Mapping from those keys to NumPy arrays (presumably the
                result of ``np.load`` on an ``.npz`` file -- verify against
                caller). Weight arrays are permuted with ``(3, 2, 0, 1)``
                before copying, i.e. they are assumed to be stored as
                (kH, kW, in, out) -- TODO confirm.

        Raises:
            KeyError: If an expected key is missing from ``params``.
            RuntimeError: If an array's shape does not match the target
                tensor (raised by ``Tensor.copy_``).
        """
        for name, val in self.state_dict().items():
            # Names look like 'conv1.0.weight': stage, index within the
            # Sequential, and parameter kind.
            parts = name.split('.')
            i = int(parts[0][len('conv'):])
            j = int(parts[1]) + 1  # npz layer indices are 1-based
            ptype = 'weights' if parts[-1] == 'weight' else 'biases'
            key = 'conv{}_{}/{}:0'.format(i, j, ptype)
            param = torch.from_numpy(params[key])
            if ptype == 'weights':
                param = param.permute(3, 2, 0, 1)
            # state_dict() tensors share storage with the parameters, so
            # the in-place copy updates the module itself.
            val.copy_(param)