In [7]:
from torchvision import models
from torchsummary import summary
import torch
import torch.nn as nn


In [8]:
vgg = models.vgg16()
summary(vgg,(3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256,

In [9]:
class VGG(nn.Module):
    def __init__(self, num_classes=1000):
        super(VGG, self).__init__()
        self.features = self._make_layers([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'])
        self.classifier = nn.Sequential(
            nn.Linear(in_features = 512 * 7 * 7, out_features = 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
    def forward(self, x):
        x = self.features(x)
        # x = x.view(x.size(0), -1)
        x = torch.flatten(x,1)
        x = self.classifier(x)
        return x
    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU()]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


In [10]:
vgg_customized = VGG(num_classes=1000)
# print(vgg_model)
summary(vgg_customized,(3,224,224))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256,

In [12]:
vgg_customized.features[0].weight = vgg_customized.features[2].weight
# if vgg_customized.features[0].weight == vgg_customized.features[2].weight:
#     print("Weight is Shifted.")
print(vgg_customized.features[0].weight)

Parameter containing:
tensor([[[[ 9.0325e-03,  1.1660e-02, -2.1144e-02],
          [ 3.8024e-03, -2.1520e-02,  1.1370e-02],
          [-3.3302e-03,  2.9250e-02, -1.3387e-02]],

         [[ 2.5487e-03, -3.0706e-02, -4.2310e-03],
          [ 3.1586e-02,  1.8315e-02,  1.2440e-03],
          [-1.0033e-02,  1.2105e-02, -8.7955e-03]],

         [[-2.5549e-02,  2.8382e-02,  2.6996e-03],
          [-3.0011e-03, -2.2791e-02,  1.9533e-02],
          [-1.9263e-02, -2.7983e-02,  2.9161e-02]],

         ...,

         [[ 8.1424e-03, -3.4856e-02,  3.1137e-03],
          [ 2.2051e-02, -4.6729e-03, -6.6014e-03],
          [ 2.1069e-02, -2.2440e-02,  3.3568e-02]],

         [[-3.7910e-02, -1.2924e-02,  9.1662e-03],
          [-2.5973e-02, -2.8424e-02,  6.2381e-03],
          [-2.7338e-02,  3.9031e-02, -3.1442e-03]],

         [[-3.3275e-02, -3.5623e-02, -3.5867e-02],
          [ 3.9003e-02,  2.4669e-02, -3.0423e-02],
          [ 7.8718e-03, -2.5453e-02, -3.6895e-03]]],


        [[[-1.5978e-02, -3.1128