# Custom Inception Model from scratch

Author: Tharindu Yakkala

References:
Going Deeper with Convolutions
https://arxiv.org/abs/1409.4842

In [54]:
import torch
from torch import nn
import torchvision

In [90]:
"""
Building Blocks of the Inception Layer
"""

### 1x1 Convolution block -> 3x3 Conv
class Conv1x1_3x3(nn.Module):
    def __init__(self, input_channels:int , mid_channels: int, output_channels:int) -> torch.Tensor:
        """Instantiate a Convblock that takes image through a 1x1 Conv
        and reduced channels to mid_channels, then through a 3x3 Conv2d
        block which outputs 'output_channels'

        Args:
            input_channels (int): Image imput channels.
            mid_channels (int): mid channels for the 1x1 conv.
            output_channels (int): out channels for the 3x3 conv.
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, output_channels, x, y)
        """
        super().__init__()
        self.conv_block3x = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=mid_channels, kernel_size=1),
            nn.Conv2d(in_channels=mid_channels, out_channels=output_channels, kernel_size=3, padding=1)
        )
    
    def forward(self, x):
        return self.conv_block3x(x)
    

### 1x1 Convolution block -> 5x5 Conv
class Conv1x1_5x5(nn.Module):
    def __init__(self, input_channels: int, mid_channels: int, output_channels: int) -> torch.Tensor:
        """Instiantiate a convblock, it takes an input image of input_channels
        into a 1x1 conv and reduces channels to 'mid_channels' then into a
        5x5 conv that outputs 'output_channels'

        Args:
            input_channels (int): Image input channels.
            mid_channels (int): 1x1 conv output channels.
            output_channels (int): 5x5 conv output channels.

        Returns:
            torch.Tensor: Conv block output of shape (batch_size, output_channels, x, y).
        """
        super().__init__()
        self.conv_block5x = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=mid_channels, kernel_size=1),
            nn.Conv2d(in_channels=mid_channels, out_channels=output_channels, kernel_size=5, padding=2)
        )
    
    def forward(self, x):
        return self.conv_block5x(x)

class Pool3x3_Conv1x1(nn.Module):
    def __init__(self, input_channels: int, output_channels:int) -> torch.Tensor:
        """Instantiate a convblock that takes takes image through a MaxPool 3x3 and 
        maintains same image shape, then into a 1x1 conv block that outputs 'output_channels'

        Args:
            input_channels (int): Image input channels.
            output_channels (int): Image output channels.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, output_channels, x, y)
        """
        super().__init__()
        self.pool_conv = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=1, stride=1, padding=0)
        )
    
    def forward(self, x):
        return self.pool_conv(x)
        
class Inception_block(nn.Module):
    def __init__(self, input_channels: int, x1_out: int, x3_mid: int, x3_out: int, x5_mid: int, x5_out: int, pool_out: int) -> torch.Tensor:
        """Instantiate an inception block that contains 4 neural networks in one.
        Args:
            input_channels (int): Input channels of image.
            x1_out (int): Output channels of the 1x1 conv (block1).
            x3_mid (int): Mid channels of the 1x1conv_3x3conv (block2).
            x3_out (int): Output channels of the 1x1conv_3x3conv (block2).
            x5_mid (int): Mid channels of the 1x1Conv_5x5Conv (block3).
            x5_out (int): Output channels of the 1x1Conv_5x5Conv (block3).
            pool_out (int): Output channels of the 3x3Pool_1x1Conv (block4).

        Returns:
            torch.Tensor: Tensor of shape (batch_size, x1_out + x3_out + x5_out + pool_out, x, y)
        """
        super().__init__()
        self.block1 = nn.Conv2d(in_channels=input_channels, out_channels=x1_out, kernel_size=1, stride=1, padding=0)
        self.block2 = Conv1x1_3x3(input_channels=input_channels, mid_channels=x3_mid, output_channels=x3_out)
        self.block3 = Conv1x1_5x5(input_channels=input_channels, mid_channels=x5_mid, output_channels=x5_out)
        self.block4 = Pool3x3_Conv1x1(input_channels=input_channels, output_channels=pool_out)
    
    def forward(self, x):
        block1_out = self.block1(x)
        block2_out = self.block2(x)
        block3_out = self.block3(x)
        block4_out = self.block4(x)
        
        # contatination on channel dim, bring the inner neural nets and combine them.
        return torch.concat([block1_out, block2_out, block3_out, block4_out], dim=1)
        

In [89]:
"""
The mini-inception model
"""

class inception_mini(nn.Module):
    def __init__(self, image_channels: int) -> torch.Tensor:
        """Instantiate a custom inception model for binary classification.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, 1), output logits.
        """
        super().__init__()
        
        # first conv layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=image_channels, out_channels=192,  kernel_size=3, stride=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
        )
        
        # first inception layer, consisting of 4 layers in one.
        self.inception1 = Inception_block(input_channels=192,
                                          x1_out=64,
                                          x3_mid=96,
                                          x3_out=128,
                                          x5_mid=16,
                                          x5_out=32,
                                          pool_out=32)
        
        # second inception layer, also 4 layers in one.
        self.inception2 = Inception_block(input_channels=256,
            x1_out=64,
            x3_mid=96,
            x3_out=128,
            x5_mid=16,
            x5_out=32,
            pool_out=32
        ) 
        
        # adding bottleneck layer to further optimize speed
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=64, kernel_size=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
        )
        
        self.adaptive = nn.Sequential(
              nn.AdaptiveAvgPool2d((1,1)),
              nn.Flatten(start_dim=1)
        )
             
        self.out = nn.Sequential(
            nn.Linear(in_features=64, out_features=10)
        )
    
    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.inception1(x1)
        x1 = self.inception2(x1)
        x1 = self.bottleneck(x1)
        x1 = self.adaptive(x1)
        return self.out(x1)

In [112]:
model = inception_mini(image_channels=1) #channels=1 for MNIST

In [92]:
try:
    from torchinfo import summary
except:
    !pip install torchinfo
    from torchinfo import summary

In [93]:
summary(model, input_size=(32, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
inception_mini                           [32, 10]                  --
├─Sequential: 1-1                        [32, 192, 13, 13]         --
│    └─Conv2d: 2-1                       [32, 192, 26, 26]         1,920
│    └─BatchNorm2d: 2-2                  [32, 192, 26, 26]         384
│    └─ReLU: 2-3                         [32, 192, 26, 26]         --
│    └─MaxPool2d: 2-4                    [32, 192, 13, 13]         --
├─Inception_block: 1-2                   [32, 256, 13, 13]         --
│    └─Conv2d: 2-5                       [32, 64, 13, 13]          12,352
│    └─Conv1x1_3x3: 2-6                  [32, 128, 13, 13]         --
│    │    └─Sequential: 3-1              [32, 128, 13, 13]         129,248
│    └─Conv1x1_5x5: 2-7                  [32, 32, 13, 13]          --
│    │    └─Sequential: 3-2              [32, 32, 13, 13]          15,920
│    └─Pool3x3_Conv1x1: 2-8              [32, 32, 13, 13]          -

In [98]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms

In [113]:
transform = transforms.Compose([transforms.ToTensor()])
train = MNIST(root="./data_mnist", train=True, download=True, transform=transform)
test = MNIST(root="./data_mnist", train=False, download=True, transform=transform)


In [114]:
train_dataloader = DataLoader(train, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test, batch_size=32, shuffle=True)

In [127]:
# test one batch before any training
get_batch = next(iter(train_dataloader))
try:
    with torch.inference_mode():
        pred_logits = model(get_batch[0])
        pred_proba = torch.softmax(pred_logits, dim=1)
        pred = torch.argmax(pred_proba, dim=1)
        print(pred)
except Exception as e:
    print(e,"\nInput batch didn't pass through model correctly")

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5])
