### This notebook illustrates an implementation of the `Inception` (v1) model in PyTorch.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

### The Inception model is large, so we build it from reusable sub-blocks; this lets us implement it in a modular way.

### The `ConvBlock` module is a `convolutional` layer followed by `batch normalization`; a `ReLU` activation is applied after the batch norm.

In [2]:
class ConvBlock(nn.Module):
    """Convolution -> batch normalisation -> ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: filters produced by the convolution; also the number of
            features normalised by the batch-norm layer.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (``kernel_size``,
            ``stride``, ``padding``, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return F.relu(normalised)

## Define the Inception block. 
![inception model](../inputs/inception_building_block.png)
***
### The basic idea behind `Inception` is to run several convolutional branches in parallel, each with a different kernel size, and concatenate their outputs along the channel dimension.

### Please refer to the _README.md_ file for a visual representation of an Inception block as well as of the whole model architecture.

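### There are three standard convolution kernel sizes:
- 1 by 1
- 3 by 3
- 5 by 5

The 3 by 3 and 5 by 5 branches are preceded by a 1 by 1 "reduction" convolution. A fourth branch applies a 3 by 3 max-pool followed by a 1 by 1 convolution.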

In [3]:
class InceptionBlock(nn.Module):
    """One Inception module: four parallel branches concatenated on channels.

    Every branch keeps the spatial size of its input:
        1. 1x1 conv                                    -> ``out_1x1`` channels
        2. 1x1 conv to ``red_3x3``, then 3x3 conv      -> ``out_3x3`` channels
        3. 1x1 conv to ``red_5x5``, then 5x5 conv      -> ``out_5x5`` channels
        4. 3x3 max-pool (stride 1), then 1x1 conv      -> ``out_pool`` channels

    The output therefore has ``out_1x1 + out_3x3 + out_5x5 + out_pool``
    channels.
    """

    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_pool):
        super().__init__()

        # Plain 1x1 projection.
        self.branch1 = ConvBlock(in_channels, out_1x1, kernel_size=1)

        # 1x1 reduction keeps the 3x3 convolution cheap; padding=1 keeps H/W.
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, red_3x3, kernel_size=1, padding=0),
            ConvBlock(red_3x3, out_3x3, kernel_size=3, padding=1),
        )

        # Same idea with a 5x5 kernel; padding=2 keeps H/W.
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, red_5x5, kernel_size=1),
            ConvBlock(red_5x5, out_5x5, kernel_size=5, padding=2),
        )

        # Stride-1, padded pooling leaves H/W unchanged before the projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            ConvBlock(in_channels, out_pool, kernel_size=1),
        )

    def forward(self, x):
        outputs = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]
        return torch.cat(outputs, dim=1)

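## Add an auxiliary classifier.
The auxiliary classifier is attached to intermediate feature maps and is only evaluated while the model is in training mode.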

In [4]:
class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 ``ConvBlock`` to 128 channels
    -> flatten -> Linear(2048, 1024) + ReLU -> Dropout(p=0.7)
    -> Linear(1024, num_classes).

    ``fc1`` expects exactly 2048 = 128 * 4 * 4 flattened features, so the
    pooled map must be 4x4 (e.g. a 14x14 input map).

    Args:
        in_channels: channels of the feature map this head is attached to.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(p=0.7)
        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = ConvBlock(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        pooled = self.pool(x)
        reduced = self.conv(pooled)
        flat = reduced.reshape(reduced.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

### Most of the model's building blocks are abstracted as `ConvBlock` or `InceptionBlock`, so the full model below only needs to stack them.

In [5]:
class InceptionV1(nn.Module):
    """Inception v1 image classifier.

    Architecture: a convolutional stem (``conv1`` 7x7/2 and ``conv2`` 3x3, each
    followed by a stride-2 max-pool), nine ``InceptionBlock`` stages
    (3a-3b, 4a-4e, 5a-5b) separated by stride-2 max-pooling, then global
    average pooling, dropout and a linear classifier.

    Args:
        aux_logits: if True, attach two ``InceptionAux`` heads: one after
            ``inception4a`` (512 channels) and one after ``inception4d``
            (528 channels). They are only evaluated in training mode.
        num_classes: number of logits produced by every classifier head.

    ``forward`` returns ``(aux1, aux2, logits)`` when ``aux_logits`` is True
    and the module is in training mode; otherwise it returns only ``logits``
    of shape ``(batch, num_classes)``.

    Global pooling uses ``nn.AdaptiveAvgPool2d((1, 1))``. For a 224x224 input
    the final map is 7x7, and this is identical to ``AvgPool2d(kernel_size=7)``.
    It also lets eval / ``aux_logits=False`` inference work at other
    resolutions. The auxiliary heads still hard-code 2048 input features, so
    training with ``aux_logits=True`` expects 224x224-like inputs.
    """

    def __init__(self, aux_logits = True, num_classes = 1_000):
        super(InceptionV1, self).__init__()
        self.aux_logits = aux_logits

        # Stem. Layers are created in the original order so that seeded
        # initialisation produces the same weights as before.
        self.conv1 = ConvBlock(
            in_channels = 3,
            out_channels = 64,
            kernel_size = (7, 7),
            stride = (2, 2),
            padding = (3, 3)
        )
        self.conv2 = ConvBlock(64, 192, kernel_size = 3, stride = 1, padding = 1)

        # Shared stride-2 pool used between stages (parameter-free, safe to reuse).
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        # Stage 3: 192 -> 256 -> 480 channels.
        self.inception3a = InceptionBlock(in_channels = 192, out_1x1 = 64, red_3x3 = 96, out_3x3 = 128, red_5x5 = 16, out_5x5 = 32, out_pool = 32)
        self.inception3b = InceptionBlock(in_channels = 256, out_1x1 = 128, red_3x3 = 128, out_3x3 = 192, red_5x5 = 32, out_5x5 = 96, out_pool = 64)

        # Stage 4: 480 -> 512 -> 512 -> 512 -> 528 -> 832 channels.
        self.inception4a = InceptionBlock(in_channels = 480, out_1x1 = 192, red_3x3 = 96, out_3x3 = 208, red_5x5 = 16, out_5x5 = 48, out_pool = 64)
        self.inception4b = InceptionBlock(in_channels = 512, out_1x1 = 160, red_3x3 = 112, out_3x3 = 224, red_5x5 = 24, out_5x5 = 64, out_pool = 64)
        self.inception4c = InceptionBlock(in_channels = 512, out_1x1 = 128, red_3x3 = 128, out_3x3 = 256, red_5x5 = 24, out_5x5 = 64, out_pool = 64)
        self.inception4d = InceptionBlock(in_channels = 512, out_1x1 = 112, red_3x3 = 144, out_3x3 = 288, red_5x5 = 32, out_5x5 = 64, out_pool = 64)
        self.inception4e = InceptionBlock(in_channels = 528, out_1x1 = 256, red_3x3 = 160, out_3x3 = 320, red_5x5 = 32, out_5x5 = 128, out_pool = 128)

        # Stage 5: 832 -> 832 -> 1024 channels.
        self.inception5a = InceptionBlock(in_channels = 832, out_1x1 = 256, red_3x3 = 160, out_3x3 = 320, red_5x5 = 32, out_5x5 = 128, out_pool = 128)
        self.inception5b = InceptionBlock(in_channels = 832, out_1x1 = 384, red_3x3 = 192, out_3x3 = 384, red_5x5 = 48, out_5x5 = 128, out_pool = 128)

        # Global average pool -> (batch, 1024, 1, 1), independent of input size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p = 0.4)
        self.fc = nn.Linear(1024, num_classes)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)
        else:
            self.aux1 = self.aux2 = None

    def forward(self, x):
        # Auxiliary heads only contribute while training.
        use_aux = self.aux_logits and self.training

        # Stem.
        x = self.maxpool(self.conv1(x))
        x = self.maxpool(self.conv2(x))

        # Stage 3.
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool(x)

        # Stage 4, with the auxiliary heads tapped after 4a and 4d.
        x = self.inception4a(x)
        aux1 = self.aux1(x) if use_aux else None

        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        aux2 = self.aux2(x) if use_aux else None

        x = self.inception4e(x)
        x = self.maxpool(x)

        # Stage 5 and classifier head.
        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        if use_aux:
            return aux1, aux2, x
        return x

In [6]:
model = InceptionV1()

# torchsummary defaults to device="cuda" and builds CUDA dummy inputs whenever a
# GPU is available. That clashes with a model that still lives on the CPU, so
# pass the model's own device ("cpu" or "cuda") explicitly.
summary(model, (3, 224, 224), device=next(model.parameters()).device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
         ConvBlock-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5          [-1, 192, 56, 56]         110,784
       BatchNorm2d-6          [-1, 192, 56, 56]             384
         ConvBlock-7          [-1, 192, 56, 56]               0
         MaxPool2d-8          [-1, 192, 28, 28]               0
            Conv2d-9           [-1, 64, 28, 28]          12,352
      BatchNorm2d-10           [-1, 64, 28, 28]             128
        ConvBlock-11           [-1, 64, 28, 28]               0
           Conv2d-12           [-1, 96, 28, 28]          18,528
      BatchNorm2d-13           [-1, 96, 28, 28]             192
        ConvBlock-14           [-1, 96,

In [7]:
# Seed so the random smoke-test input is reproducible across runs.
torch.manual_seed(0)
test_input = torch.randn(2, 3, 224, 224)

# In training mode (with aux_logits=True) the model returns the two auxiliary
# logits as well as the main logits; make that mode explicit rather than implied.
model.train()
aux1, aux2, output = model(test_input)

# Every head should produce one row of num_classes logits per input image.
assert aux1.shape == aux2.shape == output.shape == (2, 1_000)
print(output.shape)

torch.Size([2, 1000])
