# In [20]:
# nn.Conv2d(in_channels, out_channels, kernel_size,
#           stride=1, padding=0, dilation=1, groups=1, bias=True)
# in_channels  : number of channels in the input image (3 for an RGB image).
# out_channels : number of channels produced by the convolution.
# kernel_size  : size of the convolution kernel (also called a filter).

# nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1)
# stride defaults to kernel_size when left as None.

import torch
import torch.nn as nn

class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    The model is built from three stages, created in ``_build_net``:

    * ``feature``    -- five convolution + ReLU layers with three max-pools;
    * ``avgpool``    -- adaptive average pooling to a fixed 13x13 grid, so
      the classifier input size does not depend on the input resolution;
    * ``classifier`` -- three fully connected layers with dropout that map
      the pooled features to ``num_classes`` logits.

    Args:
        num_classes: number of logits produced by the last linear layer.
            Defaults to 1000, the value the original code hard-coded.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self._build_net(num_classes)

    def _build_net(self, num_classes=1000):
        """Create the ``feature``, ``avgpool`` and ``classifier`` submodules.

        Args:
            num_classes: output size of the final linear layer.
        """
        # Spatial sizes in the comments below assume a 224x224 input.
        self.feature = nn.Sequential(
            # conv1: 3 -> 96 channels, 224 -> 55; max-pool 55 -> 27.
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # conv2: padding=2 makes the 5x5 convolution size-preserving
            # (27 -> 27), as in AlexNet; max-pool 27 -> 13.
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # conv3-conv5: 3x3 convolutions with padding=1 keep 13x13.
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            # Final max-pool: 13 -> 6.
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # NOTE(review): for a 224x224 input the feature map is 6x6 here, so
        # pooling to 13x13 repeats values and inflates the first linear layer
        # (torchvision's AlexNet pools to 6x6). Kept at 13x13 so the
        # classifier's parameter shapes stay unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((13, 13))

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 13 * 13, 4096),
            nn.ReLU(),

            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),

            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: input batch, presumably shaped (N, 3, H, W); the first conv
                layer requires 3 input channels. H and W must be at least 63
                so that every max-pool receives a map of at least 3x3.

        Returns:
            Raw (un-normalized) logits of shape (N, num_classes).
        """
        # Each stage must consume the previous stage's output, not ``x``.
        out = self.feature(x)
        out = self.avgpool(out)
        # Flatten everything except the batch dimension: (N, 256 * 13 * 13).
        out = torch.flatten(out, 1)
        return self.classifier(out)