In [1]:
import torch.nn as nn
import torch
import torch.utils.model_zoo as model_zoo
from torch.hub import load_state_dict_from_url
import re
from collections import OrderedDict
from utils import save_net,load_net
from itertools import islice 
model_url = "https://download.pytorch.org/models/vgg16-397923af.pth"

In [2]:
def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate,dilation = d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

In [3]:
class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat  = [512, 512, 512,256,128,64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat,in_channels = 512,dilation = True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            self._initialize_weights()
            mod =  torch.hub.load_state_dict_from_url(model_url)
        
            print ("lengt front end",len(self.frontend.state_dict().items()))
            print("length model vgg",len(mod.items())) 
                        
            i = 1
            for k, v in mod.items():
                print ("type k",type(k))
                print ("type v",type(v))
                print ("k=", k)
                #k = re.sub('group\d+\.', '', k)
                k = re.sub('(features.|classfier.)', '', k)
                print ("k.sub =",k)
                v = v.data
                print ("type v.data",type(v))
               
                if i in range(len(self.frontend.state_dict().items())):

                    print ("tensor v",v)
                    print ("shape tensor v",v.shape)
                    print ("shape front end",self.frontend.state_dict()[k].shape)
                    self.frontend.state_dict()[k].copy_(v) 
                    print ("tensor %d",k,self.frontend.state_dict()[k])
                    
                i+=1
                    
            print ("lengt front end",len(self.frontend.state_dict().items()))
            print("model items",self.frontend.state_dict().items())
           
    def forward(self,x):
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    
   

In [4]:
model = CSRNet()

lengt front end 20
length model vgg 32
type k <class 'str'>
type v <class 'torch.Tensor'>
k= features.0.weight
k.sub = 0.weight
type v.data <class 'torch.Tensor'>
tensor v tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],
          [-5.8312e-01,  3.5655e-01,  7.6566e-01],
          [-6.9022e-01, -4.8019e-02,  4.8409e-01]],

         [[ 1.7548e-01,  9.8630e-03, -8.1413e-02],
          [ 4.4089e-02, -7.0323e-02, -2.6035e-01],
          [ 1.3239e-01, -1.7279e-01, -1.3226e-01]],

         [[ 3.1303e-01, -1.6591e-01, -4.2752e-01],
          [ 4.7519e-01, -8.2677e-02, -4.8700e-01],
          [ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],


        [[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],
          [-4.2805e-01, -2.4349e-01,  2.4628e-01],
          [-2.5066e-01,  1.4177e-01, -5.4864e-03]],

         [[-1.4076e-01, -2.1903e-01,  1.5041e-01],
          [-8.4127e-01, -3.5176e-01,  5.6398e-01],
          [-2.4194e-01,  5.1928e-01,  5.3915e-01]],

         [[-3.1432e-01, -3.7048e-01, -1.3094e-01],


tensor %d 14.weight tensor([[[[-0.0155,  0.0034, -0.0208],
          [-0.0154,  0.0106, -0.0012],
          [-0.0046,  0.0178, -0.0036]],

         [[-0.0206,  0.0066, -0.0181],
          [ 0.0116, -0.0038, -0.0226],
          [ 0.0356, -0.0056, -0.0036]],

         [[ 0.0001, -0.0145,  0.0259],
          [ 0.0346,  0.0227,  0.0142],
          [ 0.0486,  0.0350,  0.0236]],

         ...,

         [[ 0.0004, -0.0091, -0.0099],
          [-0.0138,  0.0159,  0.0132],
          [-0.0383,  0.0032,  0.0217]],

         [[-0.0250, -0.0385, -0.0223],
          [-0.0498, -0.0494, -0.0421],
          [-0.0433, -0.0373, -0.0203]],

         [[ 0.0202,  0.0256,  0.0071],
          [ 0.0002,  0.0084,  0.0243],
          [ 0.0096,  0.0068,  0.0299]]],


        [[[ 0.0035, -0.0297, -0.0180],
          [-0.0055, -0.0111, -0.0139],
          [-0.0057, -0.0050,  0.0180]],

         [[ 0.0206,  0.0164,  0.0047],
          [ 0.0073, -0.0121, -0.0166],
          [-0.0284, -0.0446, -0.0417]],

         [[

In [5]:
x = torch.rand((1,3,255,255))

In [6]:
model(x).shape

torch.Size([1, 1, 31, 31])