In [2]:
import torch
import math
from torch import nn, autograd, stack, cat
from torch.autograd import Variable
import torch.utils.model_zoo as model_zoo
# Download URLs of the VGG checkpoints hosted on download.pytorch.org, keyed
# by model name. Only the 'vgg16' entry is read by the code in this notebook.
vgg_urls = dict(
    vgg11='https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    vgg13='https://download.pytorch.org/models/vgg13-c768596a.pth',
    vgg16='https://download.pytorch.org/models/vgg16-397923af.pth',
    vgg19='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    vgg11_bn='https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    vgg13_bn='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    vgg16_bn='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    vgg19_bn='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
)

In [26]:
class partialVGG(nn.Module):
    """Leading part of a VGG-16 style convolutional stack, cut at ``target_layer``.

    Layers are numbered from 0 in creation order. Every convolution is
    followed by its own ReLU entry and every ``'M'`` in the configuration is
    a 2x2 max-pool, so the numbering starts 0=conv, 1=relu, 2=conv, 3=relu,
    4=maxpool, 5=conv, ...  The module keeps layers ``0 .. target_layer``
    inclusive, hence ``forward`` returns the output of layer ``target_layer``.

    The layers live in ``self.features`` (an ``nn.ModuleList``), so parameter
    names have the form ``features.<index>.weight`` / ``features.<index>.bias``.

    Args:
        target_layer: zero-based index of the last layer to keep. A value
            past the end of the configuration keeps the whole stack.

    Raises:
        ValueError: if ``target_layer`` is negative.
    """

    # Output channels of each 3x3 convolution; 'M' stands for a max-pool.
    _CFG = (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
            512, 512, 512, 'M', 512, 512, 512, 'M')

    def __init__(self, target_layer):
        super(partialVGG, self).__init__()
        if target_layer < 0:
            raise ValueError(
                "target_layer must be >= 0, got {!r}".format(target_layer))
        self.target_layer = target_layer
        self.features = self.make_layers()
        self._initialize_weights()

    def forward(self, x):
        """Apply the kept layers in order and return the last layer's output."""
        for layer in self.features:
            x = layer(x)
        return x

    def make_layers(self, batch_norm=False):
        """Create layers ``0 .. self.target_layer`` of the configuration.

        Args:
            batch_norm: accepted for interface compatibility; currently
                ignored (no batch-norm layers are ever created).

        Returns:
            nn.ModuleList holding exactly ``target_layer + 1`` layers, or the
            whole stack when ``target_layer`` exceeds the last layer index.
        """
        keep = int(self.target_layer) + 1
        layers = []
        in_channels = 3
        for v in self._CFG:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = v
            # Stop as soon as enough layers exist so no unused convolution
            # is allocated or initialised.
            if len(layers) >= keep:
                break
        # A convolution and its ReLU are appended together, so the list is
        # one entry too long when target_layer points at the convolution
        # itself. Trim so the model really ends at the requested layer.
        return nn.ModuleList(layers[:keep])

    def _initialize_weights(self):
        """Re-initialise parameters in place, by module type.

        Conv2d: weights ~ N(0, sqrt(2 / (kh * kw * out_channels))), bias 0.
        BatchNorm2d: weight 1, bias 0.  Linear: weights ~ N(0, 0.01), bias 0.
        Only Conv2d modules are created by ``make_layers``; the other two
        branches have no effect unless such modules are added elsewhere.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
                
def get_partialVGG(target_layer):
    """Build a ``partialVGG`` cut at ``target_layer`` and load VGG-16 weights into it.

    The checkpoint at ``vgg_urls['vgg16']`` is fetched with
    ``model_zoo.load_url``; only entries whose key also exists in the model's
    own state dict are copied, every other model entry keeps its initial value.

    Args:
        target_layer: zero-based index of the last layer to keep.

    Returns:
        tuple: ``(model, num_features)`` where ``num_features`` is
        ``VGG_features_sizes(target_layer)``.

    Raises:
        IndexError: if ``target_layer`` is outside the layer table used by
            ``VGG_features_sizes``. Raised before anything is built or
            downloaded.
    """
    # Look the size up first: it is cheap and rejects an out-of-range
    # target_layer before the model is built and the checkpoint is fetched.
    num_features = VGG_features_sizes(target_layer)

    model = partialVGG(target_layer)
    model_dict = model.state_dict()

    pretrained_dict = model_zoo.load_url(vgg_urls['vgg16'])

    # 1. filter out checkpoint entries the truncated model does not have
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the merged state dict
    model.load_state_dict(model_dict)
    return model, num_features

In [27]:
def VGG_features_sizes(target_layer, im_size=224):
    """Return the number of activations produced by one layer of the VGG-16 stack.

    Layers are indexed from 0 with one entry per convolution, one per ReLU
    and one per max-pool (0=conv, 1=relu, 2=conv, 3=relu, 4=maxpool, ...).
    The size of a layer is ``side * side * channels`` where ``side`` starts
    at ``im_size`` and is floor-halved by every max-pool, and ``channels`` is
    the output width of the most recent convolution.

    Args:
        target_layer: index into the layer table; plain list indexing is
            used, so negative values count from the end.
        im_size: side length of the (square) input image. Defaults to 224.

    Returns:
        int: feature count of layer ``target_layer``.

    Raises:
        IndexError: if ``target_layer`` is outside the table (31 entries).
    """
    layers = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
              512, 512, 512, 'M', 512, 512, 512, 'M']
    side = im_size
    channels = 3  # input channels; replaced by the first convolution
    num_features = []
    for layer in layers:
        if layer == 'M':
            # Pooling halves the spatial side and keeps the channel count.
            side //= 2
            num_features.append(side * side * channels)
        else:
            channels = layer
            size = side * side * channels
            # One entry for the convolution and one for its ReLU; the ReLU
            # does not change the shape.
            num_features.extend([size, size])
    return num_features[target_layer]

In [28]:
#VGG_features_sizes()

In [29]:
#vgg, features = get_partialVGG(23)

0	CONV : 3211264 features
1	RELU : 3211264 features
2	CONV : 3211264 features
3	RELU : 3211264 features
4	MAXP : 802816 features
5	CONV : 1605632 features
6	RELU : 1605632 features
7	CONV : 1605632 features
8	RELU : 1605632 features
9	MAXP : 401408 features
10	CONV : 802816 features
11	RELU : 802816 features
12	CONV : 802816 features
13	RELU : 802816 features
14	CONV : 802816 features
15	RELU : 802816 features
16	MAXP : 200704 features
17	CONV : 401408 features
18	RELU : 401408 features
19	CONV : 401408 features
20	RELU : 401408 features
21	CONV : 401408 features
22	RELU : 401408 features
23	MAXP : 100352 features
24	CONV : 100352 features
25	RELU : 100352 features
26	CONV : 100352 features
27	RELU : 100352 features
28	CONV : 100352 features
29	RELU : 100352 features
30	MAXP : 25088 features
