In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Definition of the LeNet Image

class LeNetSeq(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(1,6,5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(6,16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Flatten(),

            # Fully connected layer
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84, 10),
            nn.ReLU(),
        )




LeNetModell = LeNetSeq()
print(LeNetModell)
        




LeNetSeq(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=400, out_features=120, bias=True)
  (8): ReLU()
  (9): Linear(in_features=120, out_features=84, bias=True)
  (10): ReLU()
  (11): Linear(in_features=84, out_features=10, bias=True)
  (12): ReLU()
)


In [9]:
from torchinfo import summary

summary(LeNetModell, input_size = (16,1,32,32))

Layer (type:depth-idx)                   Output Shape              Param #
LeNetSeq                                 [16, 10]                  --
├─Conv2d: 1-1                            [16, 6, 28, 28]           156
├─ReLU: 1-2                              [16, 6, 28, 28]           --
├─MaxPool2d: 1-3                         [16, 6, 14, 14]           --
├─Conv2d: 1-4                            [16, 16, 10, 10]          2,416
├─ReLU: 1-5                              [16, 16, 10, 10]          --
├─MaxPool2d: 1-6                         [16, 16, 5, 5]            --
├─Flatten: 1-7                           [16, 400]                 --
├─Linear: 1-8                            [16, 120]                 48,120
├─ReLU: 1-9                              [16, 120]                 --
├─Linear: 1-10                           [16, 84]                  10,164
├─ReLU: 1-11                             [16, 84]                  --
├─Linear: 1-12                           [16, 10]                  850
├─