In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    """Classic LeNet-5 convolutional network (tanh activations, average pooling).

    Architecture: two 5x5 convolutions (6 then 16 filters), each followed by
    tanh and 2x2 average pooling, then three fully connected layers
    (400 -> 120 -> 84 -> ``num_classes``) with tanh between them.

    ``fc1`` expects ``16 * 5 * 5 = 400`` features, which is what a 32x32 input
    produces after the two conv/pool stages (32 -> 28 -> 14 -> 10 -> 5).
    Inputs whose spatial size does not reduce to 5x5 will fail at ``fc1``.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        in_channels: Number of channels in the input images. Defaults to 1
            (grayscale), which matches the original architecture.

    Input:  tensor of shape ``(N, in_channels, 32, 32)``.
    Output: tensor of shape ``(N, num_classes)`` holding raw logits (no softmax).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        # Feature extractor. No padding, so every 5x5 conv trims 4 pixels from
        # height and width; every 2x2/stride-2 pool halves them.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1, padding=0)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)

        # Classifier head: 16 channels * 5 * 5 spatial positions = 400 inputs.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map an ``(N, in_channels, 32, 32)`` batch to ``(N, num_classes)`` logits."""
        x = torch.tanh(self.conv1(x))   # (N, 6, 28, 28) for a 32x32 input
        x = self.pool1(x)               # (N, 6, 14, 14)
        x = torch.tanh(self.conv2(x))   # (N, 16, 10, 10)
        x = self.pool2(x)               # (N, 16, 5, 5)
        x = torch.flatten(x, 1)         # (N, 400); keep the batch dimension
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)              # raw logits, no softmax


In [2]:
# Smoke test: a batch of four random single-channel 32x32 images should come
# out as one row of logits per image, 10 classes by default.
model = LeNet5()
x = torch.randn(4, 1, 32, 32)
y = model(x)
assert y.shape == (4, 10), f"unexpected output shape {tuple(y.shape)}"
print('output shape:', y.shape)

output shape: torch.Size([4, 10])


In [3]:
def shape_trace(model, x):
    """Replay LeNet5's forward pass one step at a time, printing each shape.

    The steps mirror ``LeNet5.forward``: conv1, tanh, pool1, conv2, tanh,
    pool2, flatten, then fc1 / fc2 / fc3 with tanh in between. A line of the
    form ``after <step>: <shape>`` is printed after every step, so the reader
    can watch the spatial size shrink.

    Args:
        model: object exposing ``conv1``, ``pool1``, ``conv2``, ``pool2``,
            ``fc1``, ``fc2`` and ``fc3`` callables (e.g. a ``LeNet5``).
        x: input batch tensor, e.g. ``(N, 1, 32, 32)``.

    Returns:
        The tensor produced by ``model.fc3``.
    """
    print("input:", x.shape)

    # (label, op): a string op names a submodule that is looked up on `model`
    # only when its step runs; any other op is applied to the tensor directly.
    steps = (
        ("conv1", "conv1"),
        ("tanh1", torch.tanh),
        ("pool1", "pool1"),
        ("conv2", "conv2"),
        ("tanh2", torch.tanh),
        ("pool2", "pool2"),
        ("flatten", lambda t: torch.flatten(t, 1)),
        ("fc1", "fc1"),
        ("tanh3", torch.tanh),
        ("fc2", "fc2"),
        ("tanh4", torch.tanh),
        ("fc3", "fc3"),
    )

    current = x
    for label, op in steps:
        layer = getattr(model, op) if isinstance(op, str) else op
        current = layer(current)
        print(f"after {label}:", current.shape)
    return current

# Trace a fresh model on a batch of two to show how a 32x32 input shrinks to
# the 16*5*5 = 400 features that fc1 expects. Only shapes are inspected, so
# autograd bookkeeping is disabled. The result gets a real name instead of
# `_`, which IPython uses for its output history.
model = LeNet5()
x = torch.randn(2, 1, 32, 32)
with torch.no_grad():
    trace_logits = shape_trace(model, x)
assert trace_logits.shape == (2, 10), f"unexpected output shape {tuple(trace_logits.shape)}"


input: torch.Size([2, 1, 32, 32])
after conv1: torch.Size([2, 6, 28, 28])
after tanh1: torch.Size([2, 6, 28, 28])
after pool1: torch.Size([2, 6, 14, 14])
after conv2: torch.Size([2, 16, 10, 10])
after tanh2: torch.Size([2, 16, 10, 10])
after pool2: torch.Size([2, 16, 5, 5])
after flatten: torch.Size([2, 400])
after fc1: torch.Size([2, 120])
after tanh3: torch.Size([2, 120])
after fc2: torch.Size([2, 84])
after tanh4: torch.Size([2, 84])
after fc3: torch.Size([2, 10])
