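# AlexNet Modification

Author: YinTaiChen

This notebook truncates AlexNet to its first convolution layer. It then initializes that layer with torchvision's pretrained weights, copying only the `state_dict` entries whose keys both networks share.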

In [1]:
import torch.nn as nn
import torchvision.models as models

## Source code of AlexNet (same layer layout as torchvision's pretrained model)

In [2]:
class AlexNet(nn.Module):
    """AlexNet image classifier with a convolutional trunk and an MLP head.

    The position of every layer inside ``features`` and ``classifier``
    determines the ``state_dict`` key names (``features.0.weight``,
    ``classifier.6.bias`` and so on). Keep the order unchanged if weights
    are to be loaded by name from a pretrained copy.

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

        # Convolutional trunk. The trailing comments give each module's index,
        # i.e. the number that appears in its state_dict key.
        trunk = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # 0
            nn.ReLU(inplace=True),                                  # 1
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 2
            nn.Conv2d(64, 192, kernel_size=5, padding=2),           # 3
            nn.ReLU(inplace=True),                                  # 4
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 5
            nn.Conv2d(192, 384, kernel_size=3, padding=1),          # 6
            nn.ReLU(inplace=True),                                  # 7
            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # 8
            nn.ReLU(inplace=True),                                  # 9
            nn.Conv2d(256, 256, kernel_size=3, padding=1),          # 10
            nn.ReLU(inplace=True),                                  # 11
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 12
        ]
        self.features = nn.Sequential(*trunk)

        # Fully connected head. Its input width is the trunk output flattened
        # to 256 * 6 * 6 values per sample.
        head = [
            nn.Dropout(),                        # 0
            nn.Linear(256 * 6 * 6, 4096),        # 1
            nn.ReLU(inplace=True),               # 2
            nn.Dropout(),                        # 3
            nn.Linear(4096, 4096),               # 4
            nn.ReLU(inplace=True),               # 5
            nn.Linear(4096, num_classes),        # 6
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Map a batch of images to class logits.

        The flatten size is hard-coded to 256 * 6 * 6. The trunk yields 6x6
        maps for 224x224 inputs, for example; other spatial sizes may fail
        in ``view``.
        """
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), 256 * 6 * 6)
        return self.classifier(flat)

## Override AlexNet to keep only the first convolution layer

In [3]:
class myAlexNet(AlexNet):
    """AlexNet truncated to its first convolution layer.

    Only ``features.0`` is built, so the ``state_dict`` holds exactly
    ``features.0.weight`` and ``features.0.bias``. Those are the same key
    names the full AlexNet uses for its first conv layer, which lets the
    pretrained tensors be copied across by name.

    Args:
        num_classes: accepted for signature compatibility with ``AlexNet``.
            It is unused because no classifier head is built.
    """

    def __init__(self, num_classes=1000):
        # Deliberately skip AlexNet.__init__. Starting the MRO lookup *after*
        # AlexNet reaches nn.Module.__init__ for myAlexNet instances, so the
        # full trunk and classifier are never created.
        super(AlexNet, self).__init__()
        first_conv = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
        self.features = nn.Sequential(first_conv)

    def forward(self, x):
        """Return the raw output of the single convolution (no ReLU/pool)."""
        return self.features(x)

## Download pre-trained AlexNet

In [4]:
# Download ImageNet-pretrained AlexNet from torchvision.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights`
# enum. Older versions lack the enum, so fall back to the legacy keyword there.
if hasattr(models, "AlexNet_Weights"):
    alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    alexnet = models.alexnet(pretrained=True)

## Get the pretrained parameters of AlexNet

In [5]:
# Name -> tensor mapping of the pretrained network's weights and biases.
parameters = alexnet.state_dict()

## Keys of the pretrained parameters (the state_dict is a dictionary)

In [6]:
# One key per line, in the state_dict's own order.
print("\n".join(parameters))

features.0.weight
features.0.bias
features.3.weight
features.3.bias
features.6.weight
features.6.bias
features.8.weight
features.8.bias
features.10.weight
features.10.bias
classifier.1.weight
classifier.1.bias
classifier.4.weight
classifier.4.bias
classifier.6.weight
classifier.6.bias


## Check the first two items in parameters (the first conv layer's weight and bias)

In [7]:
# First conv layer's weight. Given the Conv2d(3, 64, kernel_size=11) layer it
# should have shape (64, 3, 11, 11); it is displayed one (out, in) slice at a time.
parameters['features.0.weight']


(0 ,0 ,.,.) = 
  0.1186  0.0941  0.0954  ...   0.0558  0.0216  0.0500
  0.0749  0.0389  0.0530  ...   0.0257 -0.0113  0.0042
  0.0754  0.0388  0.0549  ...   0.0436  0.0102  0.0133
           ...             ⋱             ...          
  0.0932  0.1037  0.0675  ...  -0.2028 -0.1284 -0.1122
  0.0435  0.0649  0.0362  ...  -0.2025 -0.1138 -0.1072
  0.0474  0.0625  0.0248  ...  -0.1184 -0.0956 -0.0839

(0 ,1 ,.,.) = 
 -0.0726 -0.0580 -0.0807  ...  -0.0006 -0.0253  0.0255
 -0.0690 -0.0676 -0.0764  ...  -0.0040 -0.0304  0.0105
 -0.0995 -0.0856 -0.1052  ...  -0.0266 -0.0228  0.0066
           ...             ⋱             ...          
 -0.1512 -0.0887 -0.0967  ...   0.3085  0.1810  0.0843
 -0.1431 -0.0757 -0.0722  ...   0.2042  0.1645  0.0952
 -0.0859 -0.0401 -0.0515  ...   0.1635  0.1482  0.1020

(0 ,2 ,.,.) = 
 -0.0236 -0.0021 -0.0278  ...   0.0399 -0.0071  0.0322
  0.0003  0.0225  0.0089  ...   0.0188 -0.0142  0.0183
  0.0054  0.0294  0.0003  ...   0.0121 -0.0025  0.0084
           ...   

In [8]:
# First conv layer's bias: one value per output channel.
parameters['features.0.bias']


-0.9705
-2.8070
-0.0371
-0.0795
-0.1159
 0.0252
-0.0752
-1.4181
 1.6454
-0.0990
-0.0161
-0.1282
-0.0658
-0.0345
-0.0743
-1.2977
-0.0505
 0.0121
-0.1013
-1.1887
-0.1380
-0.0492
-0.0789
-0.0405
-0.0958
-0.0705
-1.9374
-0.0850
-0.1388
-0.1968
-0.1279
-2.0095
-0.0476
-0.0604
-0.0351
-0.3843
-2.7823
 0.6605
-0.1655
-2.1293
 0.0543
-0.0274
-0.1703
-0.0593
-0.4215
-1.9394
-1.2094
 0.0153
-0.1081
-0.0248
-0.1503
-1.8516
-0.0928
-0.0177
-0.0700
-0.0582
-0.0630
-0.0721
-1.2678
-0.1176
-0.0441
-0.3259
 0.0507
-0.0146
[torch.FloatTensor of size 64]

## Create an instance of the overridden AlexNet

In [9]:
# Truncated network: only the first conv layer, freshly initialised (not yet pretrained).
mynet = myAlexNet()

## Keys of its state_dict

In [10]:
# One key per line; only the first conv layer's entries should appear.
print("\n".join(mynet.state_dict()))

features.0.weight
features.0.bias


## Keep only the pretrained entries whose keys also exist in the overridden AlexNet

In [11]:
# Compute the target key set once. Calling mynet.state_dict() inside the
# comprehension condition rebuilt the whole dictionary for every pretrained entry.
mynet_keys = set(mynet.state_dict())
pretrained_dict = {
    k: v for k, v in parameters.items() if k in mynet_keys
}

## Load it into the overridden AlexNet

In [12]:
# Copy the filtered pretrained tensors into mynet, matched by key name.
mynet.load_state_dict(pretrained_dict)

In [13]:
# Verify the partial load. mynet must hold exactly the pretrained keys, and
# each tensor must equal its pretrained counterpart. Then display the state.
loaded_state = mynet.state_dict()
assert set(loaded_state) == set(pretrained_dict), "key mismatch after load"
assert all(loaded_state[k].equal(v) for k, v in pretrained_dict.items()), \
    "loaded tensors differ from the pretrained ones"
loaded_state

OrderedDict([('features.0.weight', 
              (0 ,0 ,.,.) = 
                0.1186  0.0941  0.0954  ...   0.0558  0.0216  0.0500
                0.0749  0.0389  0.0530  ...   0.0257 -0.0113  0.0042
                0.0754  0.0388  0.0549  ...   0.0436  0.0102  0.0133
                         ...             ⋱             ...          
                0.0932  0.1037  0.0675  ...  -0.2028 -0.1284 -0.1122
                0.0435  0.0649  0.0362  ...  -0.2025 -0.1138 -0.1072
                0.0474  0.0625  0.0248  ...  -0.1184 -0.0956 -0.0839
              
              (0 ,1 ,.,.) = 
               -0.0726 -0.0580 -0.0807  ...  -0.0006 -0.0253  0.0255
               -0.0690 -0.0676 -0.0764  ...  -0.0040 -0.0304  0.0105
               -0.0995 -0.0856 -0.1052  ...  -0.0266 -0.0228  0.0066
                         ...             ⋱             ...          
               -0.1512 -0.0887 -0.0967  ...   0.3085  0.1810  0.0843
               -0.1431 -0.0757 -0.0722  ...   0.2042  0.1645  0