In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class ShapeTracerCNN(nn.Module):
    """Small two-stage CNN that reports tensor shapes as data flows through it.

    Architecture:
        conv1 (1 -> 16 channels, 3x3, padding 1) -> ReLU -> 2x2 max-pool
        conv2 (16 -> 32 channels, 3x3, padding 1) -> ReLU -> 2x2 max-pool
        flatten -> fc1 (32 * 7 * 7 -> 128) -> ReLU -> fc2 (128 -> 10)

    The padded convolutions keep the spatial size and each pool halves it, so
    ``fc1`` matches a 28x28 single-channel input (28 -> 14 -> 7).

    Args:
        verbose: When True (the default), ``forward`` prints the tensor shape
            after each stage. Pass False to run the network silently.
    """

    def __init__(self, verbose=True):
        super().__init__()
        self.verbose = verbose
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def _trace(self, label, tensor):
        """Print ``label`` followed by ``tensor.shape`` when tracing is enabled."""
        if self.verbose:
            print(label, tensor.shape)

    def forward(self, x):
        """Run the network on ``x`` and return the 10 raw output scores per sample.

        Args:
            x: Input tensor whose trailing dimensions are (1, 28, 28).

        Returns:
            Tensor of shape (batch, 10); no softmax is applied.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to 7x7,
                so the flattened features do not match ``fc1``.
        """
        self._trace("Input:", x)
        x = self.pool(F.relu(self.conv1(x)))
        self._trace("After conv1 + pool:", x)
        x = self.pool(F.relu(self.conv2(x)))
        self._trace("After conv2 + pool:", x)
        # Collapse only the (channels, height, width) axes. A hard-coded
        # view(-1, 32 * 7 * 7) would let the batch axis soak up any spatial-size
        # mismatch and silently return the wrong number of rows; this way fc1
        # rejects a wrongly sized input instead.
        x = torch.flatten(x, start_dim=-3)
        # Restore a leading batch axis (a no-op for batched input).
        x = x.reshape(-1, x.shape[-1])
        self._trace("After flatten:", x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
# Build the network with freshly initialised weights (on the default device).
model = ShapeTracerCNN()

# Push one random single-channel 28x28 sample through the model so that
# forward() reports the tensor shape after each stage.
dummy_input = torch.randn((1, 1, 28, 28))
output = model(dummy_input)

Input: torch.Size([1, 1, 28, 28])
After conv1 + pool: torch.Size([1, 16, 14, 14])
After conv2 + pool: torch.Size([1, 32, 7, 7])
After flatten: torch.Size([1, 1568])


In [8]:
pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
Note: you may need to restart the kernel to use updated packages.


In [9]:
from torchsummary import summary

# torchsummary defaults to device="cuda" and, when a GPU is visible, feeds CUDA
# tensors to the model. Pass the device the weights actually live on so the
# summary also works for a CPU-resident model on a GPU machine.
summary(
    model,
    input_size=(1, 28, 28),
    device=next(model.parameters()).device.type,
)

Input: torch.Size([2, 1, 28, 28])
After conv1 + pool: torch.Size([2, 16, 14, 14])
After conv2 + pool: torch.Size([2, 32, 7, 7])
After flatten: torch.Size([2, 1568])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 28, 28]             160
         MaxPool2d-2           [-1, 16, 14, 14]               0
            Conv2d-3           [-1, 32, 14, 14]           4,640
         MaxPool2d-4             [-1, 32, 7, 7]               0
            Linear-5                  [-1, 128]         200,832
            Linear-6                   [-1, 10]           1,290
Total params: 206,922
Trainable params: 206,922
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.18
Params size (MB): 0.79
Estimated Total Size (MB): 0.97
----------------------------------------------------------------
