In [1]:
import torch
from torch import nn

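## Advantages of convolutional layers

A convolutional layer reuses one small kernel at every spatial position, so it needs far fewer trainable parameters than a fully-connected layer applied to the same 8×8 input. The cells below compare the parameter counts of the two approaches.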

In [22]:
class SimpleMLP(nn.Module):
    """Fully-connected baseline: 64 input features -> 64 hidden units -> 10 outputs.

    ``in_features`` is written as 8*8*1, i.e. sized for a flattened 8x8
    single-channel image. ``forward`` does not flatten anything itself, so the
    input's last dimension must already be 64.

    No activation sits between the two layers, so together they compute a
    single affine map. Parameter total: (64*64 + 64) + (64*10 + 10) = 4810.
    """

    def __init__(self):
        super().__init__()
        # Registration order (l1, then l2) fixes parameter order and the
        # sequence in which the random initialisation is drawn.
        self.l1 = nn.Linear(in_features=8 * 8 * 1, out_features=64)
        self.l2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        hidden = self.l1(x)
        return self.l2(hidden)

In [23]:
class ConvModel(nn.Module):
    """Single strided convolution followed by a linear classifier head.

    For an 8x8 single-channel input, conv1 (1 -> 4 channels, 2x2 kernel,
    stride 2) yields a (4, 4, 4) feature map, i.e. 4*4*4 = 64 features, which
    l1 maps to 10 outputs. Parameter total: (1*4*2*2 + 4) + (64*10 + 10) = 670.

    Expected input is (N, 1, 8, 8). An unbatched (1, 8, 8) tensor is also
    accepted (as nn.Conv2d allows) and gives a (1, 10) output.
    """

    def __init__(self):
        # Zero-argument super() is bound to this class object. The two-argument
        # form looks ``ConvModel`` up by global name at call time, which breaks
        # in a notebook once a later cell redefines ConvModel.
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=4,
                               kernel_size=2,
                               stride=2)
        self.l1 = nn.Linear(64, 10)

    def forward(self, x):
        features = self.conv1(x)
        if features.dim() == 3:
            # Unbatched (C, H, W) input: add the batch axis so the result is
            # (1, 10), the same shape view(-1, 64) produced for this case.
            features = features.unsqueeze(0)
        # Flatten everything except the batch axis. view(-1, 64) would silently
        # turn one oversized image into several "samples"; a wrongly sized
        # input now fails loudly in l1 instead.
        return self.l1(features.flatten(start_dim=1))

In [24]:
# Fully-connected baseline for the parameter-count comparison.
model1 = SimpleMLP()

In [25]:
# Single-convolution model built from the first ConvModel definition; this
# instance keeps that class even after ConvModel is redefined further down.
model2 = ConvModel()

In [26]:
def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Parameters with ``requires_grad == False`` (frozen) are skipped. A model
    with no parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [27]:
# Compare trainable-parameter totals of the fully-connected and conv models.
print(f"Fully-connected model parameters: {count_parameters(model1)}")
print(f"Convolutional model parameters: {count_parameters(model2)}")

Fully-connected model parameters: 4810
Convolutional model parameters: 670


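## Going deeper

The next cell redefines `ConvModel` as a stack of three convolutions. `model2`, created earlier, keeps the single-convolution version; the deeper model is instantiated later as `model3`.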

In [28]:
class ConvModel(nn.Module):
    """Three stacked convolutions followed by a linear classifier head.

    Shapes are (channels, height, width) for an 8x8 single-channel input,
    using out = (in - kernel_size) // stride + 1 (no padding):

        input    (1, 8, 8)
        conv1 -> (4, 7, 7)   kernel 2, stride 1
        conv2 -> (2, 6, 6)   kernel 2, stride 1
        conv3 -> (2, 4, 4)   kernel 3, stride 1 (default)
        flatten -> 2*4*4 = 32 features -> l1 -> 10 outputs

    No activation is applied between layers. Expected input is (N, 1, 8, 8).
    An unbatched (1, 8, 8) tensor is also accepted (as nn.Conv2d allows) and
    gives a (1, 10) output.
    """

    def __init__(self):
        # Zero-argument super() is bound to this class object rather than to
        # whatever the global name ``ConvModel`` refers to when __init__ runs.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 2, stride=1)
        self.conv2 = nn.Conv2d(4, 2, 2, stride=1)
        self.conv3 = nn.Conv2d(2, 2, 3)
        self.l1 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if x.dim() == 3:
            # Unbatched (C, H, W) input: add the batch axis so the result is
            # (1, 10), the same shape view(-1, 32) produced for this case.
            x = x.unsqueeze(0)
        # Flatten all but the batch axis. view(-1, 32) would silently split an
        # oversized input into extra "samples"; it now fails in l1 instead.
        return self.l1(x.flatten(start_dim=1))

In [29]:
# model2 is still the single-convolution instance created before ConvModel was
# redefined above, so these totals repeat the earlier comparison unchanged.
print(f"Fully-connected model parameters: {count_parameters(model1)}")
print(f"Convolutional model parameters: {count_parameters(model2)}")

Fully-connected model parameters: 4810
Convolutional model parameters: 670


In [32]:
# Three-convolution model from the most recent ConvModel definition.
model3 = ConvModel()

In [33]:
# Trainable-parameter total of the deeper convolutional model.
count_parameters(model3)

422

In [34]:
def calculate_output(input_shape, kernel_size, padding, stride, dilation=1):
    """Spatial output size of a convolution along one dimension.

    Uses the same formula as ``torch.nn.Conv2d``::

        out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1

    Args:
        input_shape: input size along the dimension (e.g. height or width).
        kernel_size: kernel size along the same dimension.
        padding: zero-padding added to *each* side of the input.
        stride: step between successive kernel applications (non-zero).
        dilation: spacing between kernel elements. The default 1 reduces to
            ``(in + 2*padding - kernel_size) // stride + 1``.

    Returns:
        The output size (an ``int`` for integer arguments). A value below 1
        means the (dilated) kernel does not fit in the padded input.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    # Floor division matches PyTorch: a trailing partial window is dropped, so
    # e.g. a 7-wide input with kernel 2 and stride 2 gives 3, not 3.5.
    return (input_shape + 2 * padding - effective_kernel) // stride + 1

In [38]:
# 8-wide input, kernel size 2, padding 1 on each side, stride 1.
calculate_output(8, 2, 1, 1)

9.0