In [1]:
import torch
import torch.nn as nn

In [2]:
class AlexNet(nn.Module):
    """AlexNet-style image classifier: five convolutional stages plus an MLP head.

    The convolutional stages are exposed as ``conv1`` .. ``conv5`` and the
    fully connected head as ``classification``; each is an ``nn.Sequential``,
    so ``state_dict`` keys look like ``conv1.0.weight`` and
    ``classification.6.bias``.

    The head's first linear layer takes ``6 * 6 * 256 = 9216`` features, which
    is what ``conv5`` produces for a ``3 x 224 x 224`` input.

    Args:
        num_classes: Number of output scores per sample. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        # Stages are built in conv1 -> conv5 order so parameter initialisation
        # consumes the random number generator in a fixed sequence.
        self.conv1 = self._conv_stage(3, 96, kernel_size=11, stride=4, padding=2, pool=True)
        self.conv2 = self._conv_stage(96, 256, kernel_size=5, padding=2, pool=True)
        self.conv3 = self._conv_stage(256, 384, kernel_size=3, padding=1)
        self.conv4 = self._conv_stage(384, 384, kernel_size=3, padding=1)
        self.conv5 = self._conv_stage(384, 256, kernel_size=3, padding=1, pool=True)

        # Layer positions inside this Sequential are kept stable (Linear layers
        # at indices 1, 4 and 6) so existing checkpoints keep loading.
        self.classification = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=6 * 6 * 256, out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels, kernel_size, stride=1, padding=0, pool=False):
        """Build one Conv2d -> ReLU stage, optionally followed by a 3x3/stride-2 max pool."""
        layers = [
            nn.Conv2d(
                in_channels=in_channels,    # number of input channels
                out_channels=out_channels,  # number of output channels
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.ReLU(),
        ]
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input tensor with 3 channels in dimension 1, e.g.
                ``(batch, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the raw outputs of
            the last linear layer (no softmax is applied).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        # Collapse (C, H, W) into a single feature dimension per sample.
        # torch.flatten copies when needed, so this also works for activations
        # that are not view-compatible (e.g. channels-last layout), where
        # Tensor.view would raise.
        x = torch.flatten(x, 1)
        x = self.classification(x)
        return x
 

In [3]:
if __name__=='__main__':
    # Smoke test: push one random 224x224 RGB image through a freshly
    # initialised network and print the output shape.
    model = AlexNet()

    # Named `dummy_input` rather than `input` so the built-in input() is not
    # shadowed in the notebook's global namespace.
    dummy_input = torch.randn(1, 3, 224, 224)
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        out1 = model(dummy_input)
    print(out1.shape)
    

torch.Size([1, 1000])
