# LeNet

![](img/lenet.svg)

![](img/lenet.png)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6 , kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        
        # fc = Full Connection
        self.fc1 = nn.Linear(in_features=16*5*5, out_features=120)
        self.fc2 = nn.Linear(in_features=120   , out_features=84)
        self.fc3 = nn.Linear(in_features=86    , out_features=10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, (2, 2))
        
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2) # equiv to (2, 2) as above
        
        x = x.view(-1, self.count_flatten_features(x)) # flatten
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        
    
    def count_flatten_features(self, X):
        # https://discuss.pytorch.org/t/understand-nn-module/8416/5
        # 去除掉 batch 维,例如25x3x32x32 -> 3x32x32
        
        # .shape 是 .size() 的别名
        # https://github.com/pytorch/pytorch/issues/5544
        size = X.size()[1:]
        count_features = 1
        for features in size:
            count_features *= features
        return count_features
    

net = LeNet()
print(net)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=86, out_features=10, bias=True)
)


In [3]:
params = list(net.parameters()) # 模型的可学习参数
print('params_len =', len(params))
print('param0_size =', params[0].size())

params_len = 10
param0_size = torch.Size([6, 1, 5, 5])
