In [None]:
# Conv2d(in_channels, out_channels, kernel_size,
#        stride=1, padding=0, dilation=1, groups=1, bias=True)
# in_channels  : number of channels in the input image (3 for an RGB image, 1 for grayscale).
# out_channels : number of channels (feature maps) produced by the convolution.
# kernel_size  : size of the convolution kernel (also called the filter).

# MaxPool2d(kernel_size, stride=None, padding=0, dilation=1)
# stride defaults to kernel_size when left as None.

import torch.nn as nn

class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier (single-tower variant).

    Follows the layer layout of Krizhevsky et al. (2012): five convolutional
    layers (96, 256, 384, 384, 256 channels) with overlapping 3x3/stride-2
    max pooling after layers 1, 2 and 5, followed by three fully connected
    layers (4096, 4096, num_classes).

    Args:
        num_classes: Number of output classes. Defaults to 10.
        in_channels: Number of channels in the input image. Defaults to 1
            (grayscale); pass 3 for RGB input.

    Input is a ``(N, in_channels, H, W)`` tensor. The layer sizes are chosen
    for 224x224 or 227x227 images (6x6 feature maps before the classifier);
    an adaptive average pool lets any ``H, W >= 63`` work, so e.g. 28x28
    MNIST images must be resized first.

    Output is ``(N, num_classes)`` class probabilities (softmax over dim 1).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self._build_net()

    def _build_net(self):
        """Create the feature extractor (layer1-layer5) and classifier (layer6-layer8)."""
        # Conv1: 11x11 stride-4 kernel shrinks 224 -> 55; overlapping pool -> 27.
        self.layer1 = nn.Sequential(
            nn.Conv2d(self.in_channels, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Conv2: 5x5, stride 1, padding 2 keeps 27x27; pool -> 13.
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Conv3 and conv4: 3x3, stride 1, padding 1 keep 13x13. No pooling here:
        # pooling (or striding) after every layer would shrink the map to nothing.
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

        # Conv5: keeps 13x13; final pool -> 6x6.
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Pins the classifier input at 256x6x6 for any input resolution; it is
        # an identity for 224/227 inputs, whose feature map is already 6x6.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.lin1 = nn.Linear(256 * 6 * 6, 4096, bias=True)
        nn.init.xavier_uniform_(self.lin1.weight)
        self.layer6 = nn.Sequential(
            self.lin1,
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

        self.lin2 = nn.Linear(4096, 4096, bias=True)
        nn.init.xavier_uniform_(self.lin2.weight)
        self.layer7 = nn.Sequential(
            self.lin2,
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

        # NOTE: outputs probabilities. If this model is trained with
        # nn.CrossEntropyLoss (which applies log-softmax itself), drop the
        # Softmax and return raw logits instead.
        self.layer8 = nn.Sequential(
            nn.Linear(4096, self.num_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, num_classes) probabilities."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.avgpool(out)
        out = out.flatten(1)  # (N, 256, 6, 6) -> (N, 9216)
        out = self.layer6(out)
        out = self.layer7(out)
        return self.layer8(out)
        