## AlexNet in PyTorch

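Import the necessary libraries. `torchsummary` is a third-party package (installable with `pip install torchsummary`) that prints a layer-by-layer summary of a model.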

In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

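Define the AlexNet class: five convolutional layers (with ReLU activations, local response normalization, and max pooling) followed by a three-layer fully connected classifier with dropout.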

In [2]:
class AlexNet(nn.Module):
    """AlexNet image classifier (Krizhevsky et al., 2012), single-tower variant.

    ``self.layers`` is the convolutional feature extractor: for a 3x227x227
    input it produces a 256x6x6 feature map (see the per-layer shape comments).
    ``self.mlp`` is the fully connected classifier that maps the flattened
    256*6*6 features to ``num_classes`` logits. Input sizes that do not reduce
    to a 6x6 map will fail at the first ``nn.Linear``.

    Args:
        num_classes: Number of output logits (default 1000, the number of
            ImageNet classes).
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 96, 11, 4),  # Conv2d-1: 96 kernels (11x11), stride 4 -> 96x55x55
            nn.ReLU(),
            nn.LocalResponseNorm(5, 0.0001, 0.75, 2),  # LRN-1: size 5, alpha 1e-4, beta 0.75, k 2
            nn.MaxPool2d(3, 2),  # MaxPool2d-1: 3x3 window, stride 2 -> 96x27x27
            nn.Conv2d(96, 256, 5, padding=2),  # Conv2d-2: 256 kernels (5x5), padding 2 -> 256x27x27
            nn.ReLU(),
            nn.LocalResponseNorm(5, 0.0001, 0.75, 2),  # LRN-2: same settings as LRN-1
            nn.MaxPool2d(3, 2),  # MaxPool2d-2: 3x3 window, stride 2 -> 256x13x13
            nn.Conv2d(256, 384, 3, padding=1),  # Conv2d-3: 384 kernels (3x3), padding 1
            nn.ReLU(),
            nn.Conv2d(384, 384, 3, padding=1),  # Conv2d-4: 384 kernels (3x3), padding 1
            nn.ReLU(),
            nn.Conv2d(384, 256, 3, padding=1),  # Conv2d-5: 256 kernels (3x3), padding 1
            nn.ReLU(),
            nn.MaxPool2d(3, 2),  # MaxPool2d-3: 3x3 window, stride 2 -> 256x6x6
        )
        self.mlp = nn.Sequential(
            nn.Dropout(),  # Dropout-1 (default p=0.5) for regularization
            nn.Linear(256 * 6 * 6, 4096),  # Linear-1: 256 channels * 6 height * 6 width inputs
            nn.ReLU(),
            nn.Dropout(),  # Dropout-2 (default p=0.5) for regularization
            nn.Linear(4096, 4096),  # Linear-2
            nn.ReLU(),
            nn.Linear(4096, num_classes),  # Linear-3: one logit per class
        )

        self._initialize_bias()

    def _initialize_bias(self):
        """Initialize the convolutional layers' weights and biases.

        Every ``nn.Conv2d`` in ``self.layers`` gets weights drawn from a
        zero-mean normal distribution with std 0.01 and a zero bias. The biases
        of the 2nd, 4th and 5th convolutional layers are then set to 1, as in
        the original AlexNet paper. Linear layers keep PyTorch's default init.

        NOTE(review): the paper also sets the fully connected hidden-layer
        biases to 1 and draws all weights from N(0, 0.01). This implementation
        only does that for the conv layers; confirm whether that is intended.
        """
        convs = [m for m in self.layers if isinstance(m, nn.Conv2d)]
        for conv in convs:
            nn.init.normal_(conv.weight, 0, 0.01)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)
        # Select by convolution order rather than by Sequential index. The
        # previous hard-coded indices (4, 8, 10) hit conv2, conv3 and conv4,
        # wrongly setting conv3 to 1 and leaving conv5 at 0.
        for conv_idx in (1, 3, 4):  # 0-based positions of conv2, conv4, conv5
            nn.init.constant_(convs[conv_idx].bias, 1)

    def forward(self, t):
        """Return logits of shape (batch, num_classes) for the image batch ``t``."""
        t = self.layers(t)  # convolutional feature extractor
        t = torch.flatten(t, 1)  # (batch, C, H, W) -> (batch, C*H*W) for the MLP
        return self.mlp(t)

We create an instance of the AlexNet model, then create a random input tensor of shape 10x3x227x227, i.e. a batch of 10 images of size 227x227 with 3 color channels. We pass this tensor through the model and print the output shape, which should be 10x1000: one row per image in the batch and one column per class.

In [3]:
torch.manual_seed(0)  # make the weight init and the random input reproducible
alexnet = AlexNet()

input_tensor = torch.randn(10, 3, 227, 227)

# Shape inspection only: no_grad avoids building an autograd graph and
# keeping activations around for a backward pass that never happens.
with torch.no_grad():
    output_tensor = alexnet(input_tensor)

# Turn the expected (batch, num_classes) shape into a checked invariant.
assert output_tensor.shape == (10, 1000), output_tensor.shape
print(output_tensor.shape)

torch.Size([10, 1000])


Use `torchsummary` to get a detailed, layer-by-layer view of the model architecture: each layer's output shape and parameter count.

In [4]:
# torchsummary defaults to device="cuda" and builds CUDA inputs whenever a GPU
# is available, but `alexnet` was created on the CPU and never moved. Pass the
# device explicitly so the summary also runs on GPU machines.
summary(alexnet, (3, 227, 227), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 55, 55]          34,944
              ReLU-2           [-1, 96, 55, 55]               0
 LocalResponseNorm-3           [-1, 96, 55, 55]               0
         MaxPool2d-4           [-1, 96, 27, 27]               0
            Conv2d-5          [-1, 256, 27, 27]         614,656
              ReLU-6          [-1, 256, 27, 27]               0
 LocalResponseNorm-7          [-1, 256, 27, 27]               0
         MaxPool2d-8          [-1, 256, 13, 13]               0
            Conv2d-9          [-1, 384, 13, 13]         885,120
             ReLU-10          [-1, 384, 13, 13]               0
           Conv2d-11          [-1, 384, 13, 13]       1,327,488
             ReLU-12          [-1, 384, 13, 13]               0
           Conv2d-13          [-1, 256, 13, 13]         884,992
             ReLU-14          [-1, 256,