In [1]:
import torch 
import torch.nn as nn

In [9]:
class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutional layers followed by three linear layers.

    The first linear layer is sized for a 6x6x256 feature map, which is what the
    convolution/pooling stack produces for inputs shaped ``[batch, 3, 227, 227]``.
    The forward pass returns the raw output of the last linear layer (no softmax).

    Args:
        num_classes: Number of output scores produced by the final linear
            layer. Defaults to 1000.
    """

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.c1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=0)
        self.c2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.c3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.c4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.c5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)

        self.fc1 = nn.Linear(6*6*256, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

        # Parameter-free layers are created once and reused at several points
        # of the forward pass.
        # NOTE(review): only `size` is set, so PyTorch's default k=1.0 applies;
        # the AlexNet paper reports k=2 -- confirm which one is intended.
        self.localnorm = nn.LocalResponseNorm(size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

        self.init_weight()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input tensor; the shape comments below assume
                ``[batch, 3, 227, 227]``.

        Returns:
            Tensor of shape ``[batch, num_classes]`` (output of ``fc3``).
        """
        # x shape: [batch, 3, 227, 227]
        x = self.relu(self.c1(x))
        # x shape: [batch, 96, 55, 55]
        x = self.maxpool(self.localnorm(x))
        # x shape: [batch, 96, 27, 27]
        x = self.relu(self.c2(x))
        # x shape: [batch, 256, 27, 27]
        x = self.maxpool(self.localnorm(x))
        # x shape: [batch, 256, 13, 13]
        x = self.relu(self.c3(x))
        # x shape: [batch, 384, 13, 13]
        x = self.relu(self.c4(x))
        # x shape: [batch, 384, 13, 13]
        x = self.maxpool(self.relu(self.c5(x)))
        # x shape: [batch, 256, 6, 6]
        x = torch.flatten(x, 1)
        # x shape: [batch, 256*6*6]
        x = self.relu(self.dropout(self.fc1(x)))
        # x shape: [batch, 4096]
        x = self.relu(self.dropout(self.fc2(x)))
        # x shape: [batch, 4096]
        x = self.fc3(x)
        # x shape: [batch, num_classes]
        return x

    def init_weight(self):
        """Initialise every conv/linear layer in place.

        Weights are drawn from a normal distribution with mean 0 and std 0.01.
        Biases follow the scheme described in the AlexNet paper: constant 1 for
        the second, fourth and fifth convolutional layers and for the two hidden
        fully connected layers, constant 0 for the remaining layers.
        """
        # Layers are named explicitly rather than selected by their position in
        # `self.modules()`: that iterator yields the model itself first, which
        # shifts every index by one, and any added submodule would silently
        # change which layers receive which bias.
        bias_by_layer = (
            (self.c1, 0.0),
            (self.c2, 1.0),
            (self.c3, 0.0),
            (self.c4, 1.0),
            (self.c5, 1.0),
            (self.fc1, 1.0),
            (self.fc2, 1.0),
            (self.fc3, 0.0),
        )
        for layer, bias_value in bias_by_layer:
            nn.init.constant_(layer.bias, bias_value)
            nn.init.normal_(layer.weight, mean=0.0, std=0.01)

In [11]:
# Smoke test: push one random 227x227 RGB image through a freshly built
# network and print the shape of the result.
net = AlexNet()
dummy_batch = torch.rand(1, 3, 227, 227)
print(net(dummy_batch).shape)

torch.Size([1, 96, 55, 55])
torch.Size([1, 1000])


In [12]:
# Print the model's registered submodules and their settings as a quick
# sanity check of the architecture.
print(net)

AlexNet(
  (c1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
  (c2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (c3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (c4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (c5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=9216, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=4096, bias=True)
  (fc3): Linear(in_features=4096, out_features=1000, bias=True)
  (localnorm): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=1.0)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU(inplace=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
