In [None]:
# import necessary libraries
import torch
import torch.nn as nn

In [2]:
# Architecture follows Krizhevsky et al., 2012 ("Overall Architecture" / "Overlapping Pooling"):
# - overlapping max pooling (3x3 window, stride 2) after blocks 1, 2 and 5,
# - ReLU after every conv or fc layer except the last fc layer, which emits raw logits,
# - dropout with p=0.5 before the first and second fc layers,
# - response-normalization layers follow the (ReLU'd) first and second convolutional layers.
# To see more, read the paper https://proceedings.neurips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf
class AlexNet(nn.Module):
    """AlexNet image classifier (Krizhevsky, Sutskever & Hinton, 2012).

    Args:
        in_channels: number of channels of the input images (3 for RGB).
        num_classes: number of logits produced for each image.
        hidden_features: width of the two hidden fully-connected layers
            (4096 in the paper).
        dropout: dropout probability used before the first two
            fully-connected layers (0.5 in the paper).

    The forward pass takes a batch of shape ``(N, in_channels, H, W)`` (the
    paper uses 224x224 images) and returns ``(N, num_classes)`` raw logits;
    no softmax is applied, so the output is suitable for
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self, in_channels, num_classes, hidden_features=4096, dropout=0.5):
        super().__init__()
        # Spatial sizes in the comments below are for a 224x224 input.
        self.net = nn.Sequential(
            # first block: 224x224 -> 55x55, pooled to 27x27
            nn.Conv2d(in_channels, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # second block: 27x27, pooled to 13x13
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # third block: 13x13
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # fourth block: 13x13
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # fifth block: 13x13, pooled to 6x6
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Parameter-free; the identity for 224x224 inputs (features are
        # already 6x6). For other input sizes it resizes the feature map so the
        # flattened size always matches the first fc layer.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, hidden_features),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            # final layer emits raw logits: no ReLU, softmax is left to the loss
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Map a ``(N, in_channels, H, W)`` batch to ``(N, num_classes)`` logits."""
        features = self.net(x)
        features = self.avgpool(features)
        features = torch.flatten(features, 1)  # keep the batch dim, flatten C*H*W
        return self.classifier(features)

In [3]:
# create AlexNet
def alexnet(in_channels=3, num_classes=1000):
    """Build an ``AlexNet`` model.

    Args:
        in_channels: number of channels of the input images (default 3, RGB).
        num_classes: number of output logits (default 1000).

    Returns:
        A freshly constructed ``AlexNet`` module.
    """
    model = AlexNet(in_channels=in_channels, num_classes=num_classes)
    return model

In [4]:
# test the net architecture
def test():
    """Smoke-test the AlexNet architecture on a random batch.

    Builds the default model (3 input channels, 1000 classes), feeds it a
    batch of two random 3x224x224 images, prints the output shape and checks
    that one logit vector per image is produced.

    Raises:
        AssertionError: if the output shape is not ``(2, 1000)``.
    """
    batch_size, num_classes = 2, 1000
    x = torch.rand(batch_size, 3, 224, 224)
    net = alexnet(in_channels=3, num_classes=num_classes)
    # Only shapes are checked: eval() disables dropout and no_grad() avoids
    # building an autograd graph over the model's weights.
    net.eval()
    with torch.no_grad():
        y = net(x)
    print(y.shape)
    assert y.shape == (batch_size, num_classes), f"unexpected output shape {tuple(y.shape)}"

In [5]:
# Run the smoke test; it prints the output shape (torch.Size([2, 1000]) for the default model).
test()

torch.Size([2, 1000])
