In [1]:
import os, sys
import torch
import torchvision
import torch.nn as nn

### Defining the blocks

In [10]:
### Here we take InceptionResnet-A block as an example ###
# Reference: https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/inceptionresnetv2.py
class Basic_conv2d_ReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    ``in_channels``/``out_channels`` plus any extra keyword arguments
    (``kernel_size``, ``stride``, ``padding``, ...) configure the convolution;
    ``inplace`` is handed to the ReLU.
    """

    def __init__(self, in_channels, out_channels, inplace, **conv_kwargs):
        super().__init__()
        # Registration order (conv, bn, relu) determines the state_dict and repr order.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            **conv_kwargs,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
    
class Basic_conv2d_LeakyReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and BatchNorm features).
        inplace: passed to ``nn.LeakyReLU``.
        negative_slope: LeakyReLU slope for negative inputs (default 0.2).
        **kwarg: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``, ``bias``).
    """
    def __init__(self, in_channels, out_channels, inplace, negative_slope=0.2, **kwarg):
        super(Basic_conv2d_LeakyReLU, self).__init__()

        # NOTE: Conv2d keeps PyTorch's default bias=True; because BatchNorm follows
        # immediately, callers may pass bias=False to drop that redundant parameter.
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwarg)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope, inplace=inplace)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.leaky_relu(out)

        return out
    
"""A block"""
class InceptionResnetA_Block(nn.Module):
    def __init__(self, in_channels, out_channels, scale=1.0):
        super(InceptionResnetA_Block, self).__init__()
        
        self.scale = scale
        
        self.branch1 = Basic_conv2d_LeakyReLU(in_channels=in_channels,
                                              out_channels=32,
                                              inplace=False,
                                              kernel_size=(1,1), 
                                              stride=(1,1),
                                              padding=(0,0))
        
        self.branch2 = nn.Sequential(
            Basic_conv2d_LeakyReLU(in_channels, 32, kernel_size=(1,1), inplace=False, stride=1, padding=0),
            Basic_conv2d_LeakyReLU(32, 32, kernel_size=(3,3), inplace=False, stride=1, padding=1)
        )
        
        self.branch3 = nn.Sequential(
            Basic_conv2d_LeakyReLU(in_channels, 32, kernel_size=(1,1), inplace=False, stride=1, padding=0),
            Basic_conv2d_LeakyReLU(32, 48, kernel_size=(3,3), inplace=False, stride=1, padding=1),
            Basic_conv2d_LeakyReLU(48, 64, kernel_size=(3,3), inplace=False, stride=1, padding=1)
        )
        
        self.linear = nn.Conv2d(32+32+64, out_channels, kernel_size=(1,1), stride=1, padding=0)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=False)
        
    def forward(self, x):
        x_shortcut = x
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        
        out = torch.cat((x1, x2, x3), dim=1)
        out = self.linear(out)
        
        out = out*self.scale + x_shortcut
        out = self.leaky_relu(out)
        
        return out

### Testing on MNIST

In [4]:
# Download (if needed) the MNIST training split and convert images to tensors.
mnist_trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=64)
# Take only the first batch; it is used to smoke-test the block's output shape.
x, y = next(iter(mnist_loader))

In [11]:
# in_channels=1 matches the single-channel MNIST batch `x` from the loader cell.
model = InceptionResnetA_Block(
    in_channels=1,
    out_channels=384,
)
model

InceptionResnetA_Block(
  (branch1): Basic_conv2d_LeakyReLU(
    (conv): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leaky_relu): LeakyReLU(negative_slope=0.2)
  )
  (branch2): Sequential(
    (0): Basic_conv2d_LeakyReLU(
      (conv): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (1): Basic_conv2d_LeakyReLU(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
  )
  (branch3): Sequential(
    (0): Basic_conv2d_LeakyReLU(
      (conv): Conv2d(1, 32, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [13]:
out = model(x)
# The block keeps batch and spatial size and emits one map per output channel of
# its final 1x1 mixing conv; fail loudly on Run All if that ever changes.
assert out.shape == (x.shape[0], model.linear.out_channels, *x.shape[2:])
out.size()

torch.Size([64, 384, 28, 28])

### References
1. [Inception_V3: Paper](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf)
2. [Inception_V3: Review Article](https://medium.com/@sh.tsang/review-inception-v3-1st-runner-up-image-classification-in-ilsvrc-2015-17915421f77c)
3. [Inception_V4 and Inception_ResNet: Paper](https://arxiv.org/pdf/1602.07261.pdf)
4. [Inception_V4 and Inception_ResNet: Review Article-1](https://towardsdatascience.com/review-inception-v4-evolved-from-googlenet-merged-with-resnet-idea-image-classification-5e8c339d18bc)
5. [Inception_V4 and Inception_ResNet: Review Article-2](https://zhuanlan.zhihu.com/p/32888084)
6. [InceptionResNet_V2: Review Article](https://lizonghang.github.io/2018/05/23/Inception-ResNet-v2/)
7. [InceptionResNet_V2: PyTorch Implementation](https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/inceptionresnetv2.py)