<a href="https://colab.research.google.com/github/Parinita-Jain/Applied_CV_using_DL/blob/main/4_2_Building_your_own_Inception_block_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# importing torch library
import torch
import torch.nn as nn

![Inception block architecture](https://drive.google.com/uc?id=1TFsgpU85odiPUff0gomm5POKaXOUonwr)

In [None]:
# defining the inception module 
class Inception(nn.Module):
    """Inception block with dimensionality reduction (GoogLeNet style).

    The block runs four parallel branches on the same input and concatenates
    their outputs along the channel dimension:

    1. a 1x1 convolution,
    2. a 1x1 reduction followed by a 3x3 convolution (padding 1),
    3. a 1x1 reduction followed by a 5x5 convolution (padding 2),
    4. a 3x3 max pool (stride 1, padding 1) followed by a 1x1 projection.

    Every layer uses stride 1 and padding that preserves height and width, so
    the branch outputs differ only in their channel count and can be
    concatenated.

    Args:
        in_channels: Number of channels of the input tensor.
        ch1x1: Output channels of branch 1.
        ch3x3red: Output channels of the 1x1 reduction in branch 2.
        ch3x3: Output channels of branch 2.
        ch5x5red: Output channels of the 1x1 reduction in branch 3.
        ch5x5: Output channels of branch 3.
        pool_proj: Output channels of branch 4.

    The defaults reproduce the original fixed configuration
    (16 + 24 + 24 + 3 = 67 output channels), so ``Inception(in_channels=3)``
    builds exactly the same layers (and state_dict keys) as before.

    Note:
        No activation is applied between or after the convolutions, so each
        branch is a composition of affine maps (plus a max pool in branch 4).
        GoogLeNet applies a ReLU after every convolution; add one if this
        block is used for training rather than for shape exploration.
    """

    def __init__(
        self,
        in_channels,
        ch1x1=16,
        ch3x3red=16,
        ch3x3=24,
        ch5x5red=16,
        ch5x5=24,
        pool_proj=3,
    ):
        super().__init__()

        # total number of channels produced by forward()
        self.out_channels = ch1x1 + ch3x3 + ch5x5 + pool_proj

        # first branch: 1x1 convolution
        self.branch1 = nn.Conv2d(in_channels, ch1x1, kernel_size=1)

        # second branch: 1x1 reduction, then 3x3 conv (padding=1 keeps H and W)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )

        # third branch: 1x1 reduction, then 5x5 conv (padding=2 keeps H and W)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
        )

        # fourth branch: 3x3 max pool, then 1x1 projection.
        # With stride=1 and padding=1 the pooled map keeps H and W; ceil_mode
        # only changes how the output size is rounded, which cannot matter
        # when the stride is 1.
        # ceil_mode: https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
        )

    def _forward(self, x):
        """Run the four branches on ``x`` and return their outputs as a list.

        The list is ordered ``[branch1, branch2, branch3, branch4]``. Every
        element keeps the batch and spatial dimensions of ``x`` and differs
        only in its channel count.
        """
        return [
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x),
        ]

    def forward(self, x):
        """Return the branch outputs concatenated along the channel dimension.

        Args:
            x: Input of shape ``(N, in_channels, H, W)`` or, unbatched,
                ``(in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)`` (or
            ``(out_channels, H, W)`` for unbatched input).
        """
        outputs = self._forward(x)
        # dim=-3 is the channel axis for both batched (N, C, H, W) and
        # unbatched (C, H, W) input; a fixed dim=1 would pick H for the latter.
        return torch.cat(outputs, dim=-3)

In [None]:
# build a block for 3-channel input and display its layer structure
Inception(3)

Inception(
  (branch1): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
  (branch2): Sequential(
    (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(16, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (branch3): Sequential(
    (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(16, 24, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (branch4): Sequential(
    (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
    (1): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [None]:
# one random input: batch of 1, 3 channels, 224 x 224 spatial size
random_image = torch.randn((1, 3, 224, 224))
random_image.size()

torch.Size([1, 3, 224, 224])

In [None]:
# passing the image to the inception block and generating the output;
# calling the module itself (rather than .forward) also runs any registered hooks
output = Inception(in_channels=random_image.shape[1])(random_image)

In [None]:
# size of the concatenated output: its channel dimension is the sum of the
# four branch widths, while batch and spatial dimensions match the input
output.size()

torch.Size([1, 67, 224, 224])

In [None]:
# keep the four branch outputs separate (before concatenation) to inspect them;
# note: this is a freshly initialised block, not the one that produced `output`
branch_output = Inception(random_image.size(1))._forward(random_image)

In [None]:
# number of outputs returned by _forward: one per branch of the block
len(branch_output)

4

In [None]:
# first branch (1 x 1 conv): size of its output tensor
branch_output[0].size()

torch.Size([1, 16, 224, 224])

In [None]:
# second branch (1 x 1 reduction, then 3 x 3 conv): size of its output tensor
branch_output[1].size()

torch.Size([1, 24, 224, 224])

In [None]:
# third branch (1 x 1 reduction, then 5 x 5 conv): size of its output tensor
branch_output[2].size()

torch.Size([1, 24, 224, 224])

In [None]:
# fourth branch (3 x 3 max pool, then 1 x 1 projection): size of its output tensor
branch_output[3].size()

torch.Size([1, 3, 224, 224])