# AlexNet Feature-Map Extraction

Author: YinTaiChen

In [1]:
import os

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.autograd import Variable

from PIL import Image

In [2]:
# AlexNet with pretrained weights (pretrained=True). The network is never
# run directly; only its parameters are read, via state_dict(), in the
# next cell.
model = torchvision.models.alexnet(
    pretrained=True,
)

In [3]:
# (parameter name, tensor) pairs of the pretrained network, in state_dict
# order. The weight-copy cells index this list positionally and assume its
# first ten entries are the (weight, bias) pairs of the five convolutional
# layers, in network order.
model_filters = list(model.state_dict().items())

In [4]:
# Stand-alone copies of the five convolutional layers of torchvision's
# AlexNet ``features`` stack. Kernel size, stride and padding must match the
# pretrained network; otherwise the copied weights are applied at the wrong
# resolution and the extracted maps are not AlexNet's activations.
conv_1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
conv_2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
conv_3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
conv_4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
conv_5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

# Feature extractor up to the last ReLU, in AlexNet order. Every submodel
# below is a prefix of this list that ends with the ReLU after its conv
# layer. All five therefore share the same conv modules (and weights), and
# their state_dict keys are "0.*", "3.*", "6.*", "8.*", "10.*".
_feature_layers = [
    conv_1, nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2),
    conv_2, nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2),
    conv_3, nn.ReLU(inplace=True),
    conv_4, nn.ReLU(inplace=True),
    conv_5, nn.ReLU(inplace=True),
]

# In [5]:
# Layer labels used in the output file names.
names = ["conv1", "conv2", "conv3", "conv4", "conv5"]

# In [6]:
# Output channel count of each conv layer (= number of feature maps).
channels = [64, 192, 384, 256, 256]

# In [7]:
# conv1 -> ReLU
submodel_1 = nn.Sequential(*_feature_layers[:2])

# In [8]:
# ... -> pool -> conv2 -> ReLU (ends with a ReLU like every other submodel)
submodel_2 = nn.Sequential(*_feature_layers[:5])

# In [9]:
# ... -> pool -> conv3 -> ReLU
submodel_3 = nn.Sequential(*_feature_layers[:8])

# In [11]:
# ... -> conv4 -> ReLU
submodel_4 = nn.Sequential(*_feature_layers[:10])

# In [12]:
# ... -> conv5 -> ReLU
submodel_5 = nn.Sequential(*_feature_layers[:12])

# In [13]:
# submodels[k] produces the activations of conv layer k + 1.
submodels = [submodel_1, submodel_2, submodel_3, submodel_4, submodel_5]

In [14]:
# (parameter name, tensor) pairs of each freshly built submodel, in
# state_dict order. Only the names ([i][0]) are used: they become the keys
# under which the pretrained tensors are loaded.
conv1_filters = list(submodel_1.state_dict().items())
conv2_filters = list(submodel_2.state_dict().items())
conv3_filters = list(submodel_3.state_dict().items())
conv4_filters = list(submodel_4.state_dict().items())
conv5_filters = list(submodel_5.state_dict().items())

In [15]:
# Build the state dict for each submodel. Submodel k ends at conv layer k,
# so it owns 2 * k parameters (a weight and a bias per conv layer). The
# name at position p of the submodel is paired with the pretrained tensor
# at the same position p of ``model_filters``.
_copy_plan = (
    (conv1_filters, 2),
    (conv2_filters, 4),
    (conv3_filters, 6),
    (conv4_filters, 8),
    (conv5_filters, 10),
)
d_1, d_2, d_3, d_4, d_5 = (
    {targets[pos][0]: model_filters[pos][1] for pos in range(count)}
    for targets, count in _copy_plan
)

In [16]:
# Copy the pretrained tensors into each submodel. The conv modules are
# shared between the submodels, so successive calls write the same tensors
# into the same layers.
for _submodel, _state in zip(
    (submodel_1, submodel_2, submodel_3, submodel_4, submodel_5),
    (d_1, d_2, d_3, d_4, d_5),
):
    _submodel.load_state_dict(_state)

In [17]:
# Load the test image as 3-channel RGB. conv_1 expects exactly three input
# channels, while a PNG may be stored as RGBA, greyscale or palette.
# The context manager closes the file once the converted copy is made.
with Image.open("Lenna.png") as _raw_img:
    img = _raw_img.convert("RGB")

In [18]:
# PIL image -> float tensor (C, H, W) in [0, 1]. The tensor is then
# standardised with the ImageNet per-channel mean/std that torchvision's
# pretrained models expect as input.
loader = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [19]:
# Reverse direction: turns a CHW tensor back into a PIL image so that it
# can be written to disk. With mode=None the mode is inferred from the
# tensor.
unloader = transforms.ToPILImage(mode=None)

In [20]:
# Add the leading batch dimension expected by nn.Conv2d:
# (C, H, W) -> (1, C, H, W). No ``Variable`` wrapper is needed, because
# plain tensors carry autograd state since PyTorch 0.4.
# NOTE: ``input`` shadows the builtin; the name is kept because the
# extraction loop reads it.
input = loader(img).unsqueeze(0)

In [21]:
# Save every layer's activations as RGB images. Image j of a layer packs
# the consecutive feature maps j, j+1 and j+2 into its R, G and B planes,
# so a layer with C channels yields C - 2 overlapping images.
os.makedirs("feature_maps", exist_ok=True)


def _to_unit_range(feature_map):
    """Return *feature_map* linearly rescaled to the range [0, 1].

    ToPILImage expects float tensors in [0, 1] (it multiplies by 255 before
    converting to 8-bit). Raw activations are unbounded, so without this
    step large values would not map to valid pixel intensities. A constant
    map comes back as all zeros.
    """
    shifted = feature_map - feature_map.min()
    peak = shifted.max()
    return shifted / peak if peak > 0 else shifted


# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for i, submodel in enumerate(submodels):
        output = submodel(input)
        # output[0] drops the batch dimension; iterating over it yields one
        # 2-D map per channel.
        feature_maps = [_to_unit_range(x) for x in output[0]]
        ch = channels[i]
        for j in range(ch - 2):
            image_tensor = torch.stack(feature_maps[j:j + 3], 0)
            to_save = unloader(image_tensor)
            to_save.save(os.path.join(
                "feature_maps",
                "feature_map_{}_{}.jpg".format(names[i], j),
            ))