In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### nn.Module
- Base class for neural network modules

In [10]:
class BasicClassifier(nn.Module):
    """Image classifier: two conv/batch-norm/ReLU stages and a linear head.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of output logits.

    The first linear layer expects a flattened 28x28x64 feature map. Both
    convolutions are unpadded 3x3 with default stride, so each one trims two
    pixels per spatial dimension (i.e. 32x32 inputs yield 28x28 maps).
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        # Stage 1: in_ch -> 32 feature maps.
        self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=(3, 3))
        self.bn1 = nn.BatchNorm2d(num_features=32)

        # Stage 2: 32 -> 64 feature maps.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.bn2 = nn.BatchNorm2d(num_features=64)

        # Classification head (no activation between the two linear layers).
        self.fc1 = nn.Linear(28 * 28 * 64, 1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch ``x``."""
        stage1 = F.relu(self.bn1(self.conv1(x)))
        stage2 = F.relu(self.bn2(self.conv2(stage1)))

        # Collapse every non-batch dimension into one feature vector.
        flat = stage2.view(stage2.size(0), -1)
        return self.fc2(self.fc1(flat))

Problems with the above approach
- If we need to add layers, we have to declare them in the `__init__` method and also add the corresponding calls in the `forward` method
- Also, if we want to use the same set of (Conv, BatchNorm, ReLU) layers in another Module, we need to write them again

### nn.Sequential
- Let's reduce the number of lines in the `forward` method by wrapping the layers in an `nn.Sequential` container

In [11]:
class BasicClassifier(nn.Module):
    """Same classifier as before, with the layers grouped into containers.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of output logits.

    Each conv block is ``Conv2d -> BatchNorm2d -> ReLU``; the head is two
    stacked linear layers fed with a flattened 28x28x64 feature map.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        # First block: in_ch -> 32 channels.
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_ch, 32, kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
        )
        # Second block: 32 -> 64 channels.
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=(3, 3)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
        )
        # Linear head producing the class logits.
        self.fc_block = nn.Sequential(
            nn.Linear(28 * 28 * 64, 1024),
            nn.Linear(in_features=1024, out_features=num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch ``x``."""
        features = self.conv_block_2(self.conv_block_1(x))

        # Collapse every non-batch dimension before the linear head.
        flat = features.view(features.size(0), -1)
        return self.fc_block(flat)


Problems with the above approach
- `conv_block_1` and `conv_block_2` have the same layers. We can write a function that returns an `nn.Sequential` to simplify this further

In [12]:
def conv_bn(in_ch, out_ch, *args, **kwargs):
    """Build a ``Conv2d -> BatchNorm2d -> ReLU`` block.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels of the convolution; also the
            number of features normalised by the batch-norm layer.
        *args: Extra positional arguments forwarded to ``nn.Conv2d`` after
            the channel counts (e.g. ``kernel_size``, ``stride``).
        **kwargs: Extra keyword arguments forwarded to ``nn.Conv2d``.

    Returns:
        nn.Sequential: The three layers in order.
    """
    # The channel counts are passed positionally so that anything in *args
    # lands on the following Conv2d parameters. Passing them as keywords would
    # make the first forwarded positional collide with ``in_channels``.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, *args, **kwargs),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )

In [13]:
class BasicClassifier(nn.Module):
    """Classifier whose conv stages are built with the ``conv_bn`` helper.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of output logits.

    The linear head expects a flattened 28x28x64 feature map.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        # Use the caller-supplied channel count rather than a literal 3, so
        # that non-RGB inputs (e.g. grayscale) are actually supported.
        self.conv_block_1 = conv_bn(in_ch, 32, kernel_size=(3, 3))
        self.conv_block_2 = conv_bn(32, 64, kernel_size=(3, 3))
        self.fc_block = nn.Sequential(
            nn.Linear(in_features=28 * 28 * 64, out_features=1024),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch ``x``."""
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)

        x = x.view(x.shape[0], -1)  # flatten every non-batch dimension
        x = self.fc_block(x)
        return x

- Merge the two conv blocks into a single sequential container

In [14]:
class BasicClassifier(nn.Module):
    """Classifier with both conv blocks merged into one ``encoder`` container.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of output logits.

    The linear head expects a flattened 28x28x64 feature map.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        # The first block takes the caller-supplied channel count instead of
        # a literal 3, so the ``in_ch`` argument is no longer ignored.
        self.encoder = nn.Sequential(
            conv_bn(in_ch, 32, kernel_size=(3, 3)),
            conv_bn(32, 64, kernel_size=(3, 3)),
        )
        self.fc_block = nn.Sequential(
            nn.Linear(in_features=28 * 28 * 64, out_features=1024),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch ``x``."""
        x = self.encoder(x)

        x = x.view(x.shape[0], -1)  # flatten every non-batch dimension
        x = self.fc_block(x)
        return x

- We have hardcoded the number of in and out channels for the conv layers. Can we make it better?

In [15]:
class BasicClassifier(nn.Module):
    """Classifier whose encoder is generated from a list of channel sizes.

    Args:
        in_ch: Number of channels of the input tensor.
        num_classes: Number of output logits.

    ``enc_sizes`` lists the channel count at every encoder boundary; each
    consecutive pair becomes one ``conv_bn`` block with a 3x3 kernel.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        self.enc_sizes = [in_ch, 32, 64]

        # Pair every size with its successor: (in_ch, 32), (32, 64).
        channel_pairs = zip(self.enc_sizes[:-1], self.enc_sizes[1:])
        stages = []
        for src_ch, dst_ch in channel_pairs:
            stages.append(conv_bn(src_ch, dst_ch, kernel_size=(3, 3)))
        self.encoder = nn.Sequential(*stages)

        self.fc_block = nn.Sequential(
            nn.Linear(28 * 28 * 64, 1024),
            nn.Linear(in_features=1024, out_features=num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch ``x``."""
        encoded = self.encoder(x)

        # Collapse every non-batch dimension before the linear head.
        flat = encoded.view(encoded.size(0), -1)
        return self.fc_block(flat)

- We can pass the