In [3]:
import torch
import torch.nn as nn 

In [4]:
class AlexNet(nn.Module):
    """AlexNet-style image classifier: five conv layers, then three FC layers.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Size of the output logits vector.

    The shape comments below assume a 224x224 input. Other spatial sizes are
    accepted as long as they are large enough to pass through the three
    max-pooling stages: the last feature map is adaptively pooled to 6x6 so
    it always matches what ``fc1`` expects.
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels, 96, kernel_size=11, stride=4, padding=2)  # (96x55x55)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)  # (256x27x27)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)  # (384x13x13)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)  # (384x13x13)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)  # (256x13x13)

        # Fully connected layers; fc1 consumes a flattened 256x6x6 feature map.
        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

        # Parameter-free layers, each shared by every stage that uses it.
        self.flatten = nn.Flatten()
        self.maxpooling = nn.MaxPool2d(kernel_size=3, stride=2)
        self.norm = nn.LocalResponseNorm(size=5, k=2)  # alpha/beta stay at the PyTorch defaults
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Input batch of shape (B, in_channels, H, W).

        Returns:
            Unnormalized logits of shape (B, num_classes).
        """
        x = self.maxpooling(self.norm(self.relu(self.conv1(x))))  # (96x27x27)
        x = self.maxpooling(self.norm(self.relu(self.conv2(x))))  # (256x13x13)
        x = self.relu(self.conv3(x))  # (384x13x13)
        x = self.relu(self.conv4(x))  # (384x13x13)
        x = self.maxpooling(self.relu(self.conv5(x)))  # (256x6x6)
        # Pin the spatial size to the 6x6 that fc1 was sized for. When the map
        # is already 6x6 (e.g. 224 or 227 inputs) every output bin covers a
        # single element, so this is an exact no-op; for other input sizes it
        # avoids a shape-mismatch error in fc1. The functional form is used so
        # the registered module list (and printed model) is unchanged.
        x = nn.functional.adaptive_avg_pool2d(x, (6, 6))
        x = self.flatten(x)  # (9216)
        x = self.dropout(self.relu(self.fc1(x)))  # (4096)
        x = self.dropout(self.relu(self.fc2(x)))  # (4096)
        return self.fc3(x)  # (B, num_classes)

In [5]:
# Smoke test: push a random batch of 64 three-channel 224x224 images through
# the network and print the shape of the resulting logits.
model = AlexNet(in_channels=3, num_classes=100)
x = torch.randn((64, 3, 224, 224))
print(model(x).size())

torch.Size([64, 100])


In [6]:
if __name__ == "__main__":
    # Build the network and print its module listing.
    model = AlexNet(3, 100)
    print(str(model))

AlexNet(
  (conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
  (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=9216, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=4096, bias=True)
  (fc3): Linear(in_features=4096, out_features=100, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (maxpooling): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (norm): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=2)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)


In [7]:
from torchsummary import summary

# Per-layer output shapes and parameter counts for a 3x227x227 input.
# The freshly built model lives on the CPU, but torchsummary defaults to
# device="cuda" and feeds the model a CUDA tensor whenever a GPU is visible,
# which fails with an input/weight type mismatch. Pin the summary to the CPU.
model = AlexNet(3, 100)
summary(model, (3, 227, 227), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]          34,944
              ReLU-2           [-1, 96, 56, 56]               0
 LocalResponseNorm-3           [-1, 96, 56, 56]               0
         MaxPool2d-4           [-1, 96, 27, 27]               0
            Conv2d-5          [-1, 256, 27, 27]         614,656
              ReLU-6          [-1, 256, 27, 27]               0
 LocalResponseNorm-7          [-1, 256, 27, 27]               0
         MaxPool2d-8          [-1, 256, 13, 13]               0
            Conv2d-9          [-1, 384, 13, 13]         885,120
             ReLU-10          [-1, 384, 13, 13]               0
           Conv2d-11          [-1, 384, 13, 13]       1,327,488
             ReLU-12          [-1, 384, 13, 13]               0
           Conv2d-13          [-1, 256, 13, 13]         884,992
             ReLU-14          [-1, 256,