In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.modules

In [11]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    The network has two stages:

    * ``convnet`` - five convolution layers with ReLU activations, local
      response normalisation after the first two convolutions, and three
      3x3 / stride-2 max-pooling layers.
    * ``fcnet`` - two dropout-regularised 4096-unit fully connected layers
      followed by a linear layer that emits ``number_of_classes`` scores.

    Args:
        number_of_classes: Width of the final linear layer, i.e. the number
            of scores produced per input image.
    """

    # Spatial size (height, width) of the feature map that ``fcnet`` expects.
    _FEATURE_MAP_SIZE = (6, 6)

    def __init__(self, number_of_classes: int):
        super().__init__()

        self.convnet = nn.Sequential(
            # Stage 1: 11x11 convolution with stride 4 and no padding.
            nn.Conv2d(3, 96, 11, 4, 0),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, k=2),
            nn.MaxPool2d((3, 3), 2),

            # Stage 2: 5x5 convolution; padding 2 keeps the spatial size.
            nn.Conv2d(96, 256, 5, 1, 2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, k=2),
            nn.MaxPool2d((3, 3), 2),

            # Stages 3-5: 3x3 convolutions; padding 1 keeps the spatial size.
            nn.Conv2d(256, 384, 3, 1, 1),
            nn.ReLU(),

            nn.Conv2d(384, 384, 3, 1, 1),
            nn.ReLU(),

            nn.Conv2d(384, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((3, 3), 2),
        )

        pooled_h, pooled_w = self._FEATURE_MAP_SIZE
        self.fcnet = nn.Sequential(
            nn.Dropout(0.5),
            # The preceding max-pooling leaves a 6x6 map with 256 channels.
            nn.Linear(pooled_h * pooled_w * 256, 4096),
            nn.ReLU(),

            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),

            nn.Linear(4096, number_of_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a batch of images.

        Args:
            x: Image batch of shape ``(N, 3, H, W)``. ``H`` and ``W`` must be
                large enough to pass through the convolution and pooling
                stack.

        Returns:
            Tensor of shape ``(N, number_of_classes)`` holding the raw
            (un-normalised) scores of the last linear layer.
        """
        x = self.convnet(x)
        # The conv stack only yields a 6x6 map for inputs of about 227x227
        # (a 224x224 image gives 5x5, which used to break the first Linear).
        # Adaptive pooling is an identity on a 6x6 map and resamples any
        # other size to 6x6, so existing 227x227 results are unchanged.
        x = F.adaptive_avg_pool2d(x, self._FEATURE_MAP_SIZE)
        x = torch.flatten(x, 1)
        return self.fcnet(x)

# Instantiate the network with 1000 output classes and print its layer layout.
alexnet = AlexNet(number_of_classes=1000)
print(alexnet)

AlexNet(
  (convnet): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU()
    (2): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=2)
    (3): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): ReLU()
    (6): LocalResponseNorm(5, alpha=0.0001, beta=0.75, k=2)
    (7): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fcnet): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=Tru