# Inception Net / GoogLeNet (PyTorch Implementation)
Paper: Szegedy et al., "Going Deeper with Convolutions" — https://arxiv.org/pdf/1409.4842.pdf

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

![Screenshot%202022-05-19%20at%2012.12.12%20PM.png](attachment:Screenshot%202022-05-19%20at%2012.12.12%20PM.png)

In [2]:
class ConvBlock(nn.Module):
    '''
        Basic convolutional unit: Conv2d -> BatchNorm2d -> ReLU.

        in_channels: number of channels of the incoming feature map
        out_channels: number of filters of the convolution (and BatchNorm features)
        kernel_size, stride, padding: passed straight through to nn.Conv2d
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # The attribute names below are the state_dict keys ("conv.*", "bnorm.*"),
        # so they are kept stable for checkpoint compatibility.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding)
        self.bnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Convolve, normalise, then rectify — output is element-wise non-negative.
        return self.relu(self.bnorm(self.conv(x)))

In [3]:
class InceptionBlock(nn.Module):
    '''
        Inception module (GoogLeNet / Inception v1): four parallel branches applied to
        the same input, whose outputs are concatenated along the channel dimension (dim 1).

        Every branch uses stride 1 and padding that preserves height/width, so the output
        keeps the input's spatial size and has no_1x1 + no_3x3 + no_5x5 + pool_proj channels.
    '''
    def __init__(self, in_channels, no_1x1, no_3x3_red, no_3x3, no_5x5_red, no_5x5, pool_proj,
                 verbose=True):
        '''
            in_channels: The number of channels in the input of the current layer
            no_1x1: The number of filters for the 1x1 convolutions (branch 1)
            no_3x3_red: The number of filters for the 1x1 "reduce" convolution before the 3x3 convolution
            no_3x3: The number of filters in the 3x3 convolution
            no_5x5_red: The number of filters for the 1x1 "reduce" convolution before the 5x5 convolution
            no_5x5: The number of filters in the 5x5 convolution
            pool_proj: The number of filters in the 1x1 projection applied after the 3x3 max pooling
            verbose: If True (default, keeps the original behaviour) forward() prints the shape of
                     every branch output and of the concatenation; pass False for training loops.
        '''
        super().__init__()
        self.verbose = verbose

        # Branch 1: 1x1 convolution
        self.conv1x1 = ConvBlock(in_channels=in_channels, out_channels=no_1x1, kernel_size=(1, 1),
                                 stride=(1, 1), padding=(0, 0))

        # Branch 2: 1x1 reduce -> 3x3 convolution (padding 1 keeps H, W)
        self.conv3x3 = nn.Sequential(
            ConvBlock(in_channels=in_channels, out_channels=no_3x3_red, kernel_size=(1, 1), stride=(1, 1),
                      padding=(0, 0)),
            ConvBlock(in_channels=no_3x3_red, out_channels=no_3x3, kernel_size=(3, 3), stride=(1, 1),
                      padding=(1, 1))
        )

        # Branch 3: 1x1 reduce -> 5x5 convolution (padding 2 keeps H, W)
        self.conv5x5 = nn.Sequential(
            ConvBlock(in_channels=in_channels, out_channels=no_5x5_red, kernel_size=(1, 1), stride=(1, 1),
                      padding=(0, 0)),
            ConvBlock(in_channels=no_5x5_red, out_channels=no_5x5, kernel_size=(5, 5), stride=(1, 1),
                      padding=(2, 2))
        )

        # Branch 4: 3x3 max pool (stride 1, padding 1 keeps H, W) -> 1x1 projection
        self.maxpool = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 3), stride=1, padding=1),
            ConvBlock(in_channels=in_channels, out_channels=pool_proj, kernel_size=(1, 1), stride=(1, 1),
                      padding=(0, 0))
        )

    def forward(self, x):
        '''
            Run the four branches on x and concatenate their outputs along dim 1.
        '''
        # Each branch ends in a ConvBlock, whose forward() finishes with ReLU, so the
        # branch outputs are already non-negative — no extra ReLU is needed here.
        branches = (self.conv1x1, self.conv3x3, self.conv5x5, self.maxpool)
        outputs = []
        for idx, branch in enumerate(branches, start=1):
            out = branch(x)
            if self.verbose:
                print(f'    branch {idx}: {out.shape}')
            outputs.append(out)
        y = torch.cat(outputs, dim=1)
        if self.verbose:
            print(f'    concat: {y.shape}')
        return y

![Screenshot%202022-05-19%20at%2012.12.31%20PM.png](attachment:Screenshot%202022-05-19%20at%2012.12.31%20PM.png)

In [4]:
class GoogLeNet(nn.Module):
    '''
        GoogLeNet (Inception v1) image classifier.

        Input shape: (batch, in_channels, H, W); designed for H = W = 224, where the last
        inception stage yields a 1024-channel 7x7 feature map. Global average pooling is
        adaptive, so inputs other than 224x224 also reduce to one 1024-d vector per image.
        Output shape: (batch, out_classes) raw logits (no softmax is applied).

        NOTE(review): compared with the paper's architecture table this model omits the
        auxiliary classifiers and the 1x1 "reduce" convolution before conv2, and relies on
        BatchNorm (inside ConvBlock) instead of local response normalisation — confirm intended.
    '''

    # Order in which forward() applies the feature-extraction stages; each name is also
    # the label printed in verbose mode.
    _STAGES = (
        'conv1', 'maxpool1', 'conv2', 'maxpool2',
        'inception3a', 'inception3b', 'maxpool3',
        'inception4a', 'inception4b', 'inception4c', 'inception4d', 'inception4e', 'maxpool4',
        'inception5a', 'inception5b',
        'avgpool', 'dropout',
    )

    def __init__(self, in_channels=3, out_classes=1000, verbose=True):
        '''
            in_channels: number of channels of the input images (previously ignored; now honoured)
            out_classes: number of output logits (previously ignored; now honoured)
            verbose: if True (default, keeps the original behaviour) forward() prints the
                     shape after every stage
        '''
        super().__init__()
        self.verbose = verbose
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=(2, 2),
                               padding=(3, 3))
        self.maxpool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)
        self.conv2 = ConvBlock(in_channels=64, out_channels=192, kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)
        self.inception3a = InceptionBlock(in_channels=192, no_1x1=64, no_3x3_red=96, no_3x3=128,
                                          no_5x5_red=16, no_5x5=32, pool_proj=32)
        self.inception3b = InceptionBlock(in_channels=256, no_1x1=128, no_3x3_red=128, no_3x3=192,
                                          no_5x5_red=32, no_5x5=96, pool_proj=64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)
        self.inception4a = InceptionBlock(in_channels=480, no_1x1=192, no_3x3_red=96, no_3x3=208,
                                          no_5x5_red=16, no_5x5=48, pool_proj=64)
        self.inception4b = InceptionBlock(in_channels=512, no_1x1=160, no_3x3_red=112, no_3x3=224,
                                          no_5x5_red=24, no_5x5=64, pool_proj=64)
        self.inception4c = InceptionBlock(in_channels=512, no_1x1=128, no_3x3_red=128, no_3x3=256,
                                          no_5x5_red=24, no_5x5=64, pool_proj=64)
        self.inception4d = InceptionBlock(in_channels=512, no_1x1=112, no_3x3_red=144, no_3x3=288,
                                          no_5x5_red=32, no_5x5=64, pool_proj=64)
        self.inception4e = InceptionBlock(in_channels=528, no_1x1=256, no_3x3_red=160, no_3x3=320,
                                          no_5x5_red=32, no_5x5=128, pool_proj=128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)
        self.inception5a = InceptionBlock(in_channels=832, no_1x1=256, no_3x3_red=160, no_3x3=320,
                                          no_5x5_red=32, no_5x5=128, pool_proj=128)
        self.inception5b = InceptionBlock(in_channels=832, no_1x1=384, no_3x3_red=192, no_3x3=384,
                                          no_5x5_red=48, no_5x5=128, pool_proj=128)
        # Global average pooling to 1x1. For a 7x7 map this equals the original
        # AvgPool2d(7, stride=1), but it also works for other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.dropout = nn.Dropout(p=0.4)
        self.linear = nn.Linear(in_features=1024, out_features=out_classes)

    def _trace(self, label, x):
        # Print the shape reached after a named stage, only in verbose mode.
        if self.verbose:
            print(f'- x after {label}: {x.shape}')

    def forward(self, x):
        '''
            x: input batch, assumed (batch, in_channels, H, W).
            Returns the (batch, out_classes) logits.
        '''
        if self.verbose:
            print(f'x: {x.shape}')
        for name in self._STAGES:
            x = getattr(self, name)(x)
            self._trace(name, x)
        # flatten(x, 1) keeps the batch dimension intact; view(-1, 1024) would silently
        # fold any leftover spatial positions into the batch dimension.
        x = torch.flatten(x, start_dim=1)
        self._trace('reshape', x)
        x = self.linear(x)
        self._trace('linear', x)
        return x

In [5]:
# Fixed seed so the random input batch (and any random weight init run after it) is reproducible.
torch.manual_seed(0)
X = torch.randn((10, 3, 224, 224))
X.shape

torch.Size([10, 3, 224, 224])

In [6]:
model = GoogLeNet()
# This is only a shape check:
# - eval() stops BatchNorm from updating its running statistics on random data and disables dropout.
# - no_grad() skips building the autograd graph.
# - model(X) (rather than model.forward(X)) goes through nn.Module.__call__, so registered hooks run.
model.eval()
with torch.no_grad():
    y = model(X)
assert y.shape == (X.shape[0], 1000)

x: torch.Size([10, 3, 224, 224])
- x after conv1: torch.Size([10, 64, 112, 112])
- x after maxpool1: torch.Size([10, 64, 56, 56])
- x after conv2: torch.Size([10, 192, 56, 56])
- x after maxpool1: torch.Size([10, 192, 28, 28])
    branch 1: torch.Size([10, 64, 28, 28])
    branch 2: torch.Size([10, 128, 28, 28])
    branch 3: torch.Size([10, 32, 28, 28])
    branch 4: torch.Size([10, 32, 28, 28])
    concat: torch.Size([10, 256, 28, 28])
- x after inception3a: torch.Size([10, 256, 28, 28])
    branch 1: torch.Size([10, 128, 28, 28])
    branch 2: torch.Size([10, 192, 28, 28])
    branch 3: torch.Size([10, 96, 28, 28])
    branch 4: torch.Size([10, 64, 28, 28])
    concat: torch.Size([10, 480, 28, 28])
- x after inception3b: torch.Size([10, 480, 28, 28])
- x after maxpool3: torch.Size([10, 480, 14, 14])
    branch 1: torch.Size([10, 192, 14, 14])
    branch 2: torch.Size([10, 208, 14, 14])
    branch 3: torch.Size([10, 48, 14, 14])
    branch 4: torch.Size([10, 64, 14, 14])
    concat: 