In [10]:
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    
    def __init__(self):
        super(AlexNet, self).__init__()
        
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4,  padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
            nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
            nn.Conv2d(128, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        )
        
        self.fc_unit = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 5)
        )
        
        tmp = torch.randn(2, 3, 224, 224)
        out = self.conv_unit(tmp)
        
        print('conv out:', out.shape)
        
        self._initialize_weights()
        
        
    def forward(self, x):
        batchsz = x.size(0)
        x = self.conv_unit(x)
        x = x.view(batchsz, 128 * 6 * 6)
        logits = self.fc_unit(x)
        
        return logits
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

if __name__ == "__main__":
    tmp = torch.randn(2, 3, 224, 224)
    net = AlexNet()
    out = net(tmp)
    print('alex out:', out.shape)

conv out: torch.Size([2, 128, 6, 6])
alex out: torch.Size([2, 5])
