In [2]:
import os, sys
import torch
import torchvision
import torch.nn as nn

### Defining the Inception block

In [3]:
class basic_conv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU unit used as the building brick of every inception branch.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of filters produced by the convolution; also the
            number of features normalized by the BatchNorm layer.
        inplace: whether the LeakyReLU overwrites its input tensor. Defaults to
            False (the value the inception branches pass explicitly).
        negative_slope: LeakyReLU slope for negative inputs. Defaults to 0.2,
            the value that was previously hard-coded.
        **kwarg: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``padding``, ``stride``, ``bias``).
    """

    def __init__(self, in_channels, out_channels, inplace=False, negative_slope=0.2, **kwarg):
        super(basic_conv2d, self).__init__()

        # NOTE: BatchNorm re-centres each channel and has its own learnable shift,
        # so the conv bias is redundant here; callers may pass bias=False via **kwarg.
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwarg)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope, inplace=inplace)

    def forward(self, x):
        """Apply convolution, batch normalization and LeakyReLU, in that order."""
        out = self.conv(x)
        out = self.bn(out)
        out = self.leaky_relu(out)

        return out

class inception_block(nn.Module):
    """GoogLeNet-style inception module: four parallel branches concatenated along channels.

    Branches (each preserves the spatial size of its input):
      1. 1x1 conv                          -> ``ch1x1`` channels
      2. 1x1 reduce -> 3x3 conv            -> ``ch3x3`` channels
      3. 1x1 reduce -> 5x5 conv            -> ``ch5x5`` channels
      4. 3x3 max-pool (stride 1) -> 1x1 conv -> ``pool_proj`` channels

    The default widths (64, 96, 128, 16, 32, 32) reproduce the original
    hard-coded configuration, giving 64 + 128 + 32 + 32 = 256 output channels.

    Args:
        input_depth: number of channels of the input tensor.
        ch1x1: output width of branch 1.
        ch3x3_reduce: width of the 1x1 reduction in branch 2.
        ch3x3: output width of branch 2.
        ch5x5_reduce: width of the 1x1 reduction in branch 3.
        ch5x5: output width of branch 3.
        pool_proj: output width of branch 4.

    Attributes:
        out_channels: total number of channels produced by ``forward``; useful
            as the ``input_depth`` of a following block.
    """

    def __init__(self, input_depth, ch1x1=64, ch3x3_reduce=96, ch3x3=128,
                 ch5x5_reduce=16, ch5x5=32, pool_proj=32):
        super(inception_block, self).__init__()
        self.out_channels = ch1x1 + ch3x3 + ch5x5 + pool_proj
        self.branch1, self.branch2, self.branch3, self.branch4 = self.make_branches(
            input_depth=input_depth,
            ch1x1=ch1x1,
            ch3x3_reduce=ch3x3_reduce,
            ch3x3=ch3x3,
            ch5x5_reduce=ch5x5_reduce,
            ch5x5=ch5x5,
            pool_proj=pool_proj,
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate their outputs on dim 1."""
        # Order matters: it fixes which channel slice belongs to which branch.
        out = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]

        return torch.cat(tensors=out, dim=1)

    def make_branches(self, input_depth, ch1x1=64, ch3x3_reduce=96, ch3x3=128,
                      ch5x5_reduce=16, ch5x5=32, pool_proj=32):
        """Build the four branch ``nn.Sequential`` modules.

        Returns:
            Tuple ``(branch1, branch2, branch3, branch4)``.
        """
        branch1 = basic_conv2d(in_channels=input_depth, out_channels=ch1x1, inplace=False, kernel_size=(1,1), padding=(0,0))

        branch2_1x1 = basic_conv2d(in_channels=input_depth, out_channels=ch3x3_reduce, inplace=False, kernel_size=(1,1), padding=(0,0))
        branch2_3x3 = basic_conv2d(in_channels=ch3x3_reduce, out_channels=ch3x3, inplace=False, kernel_size=(3,3), padding=(1,1))

        branch3_1x1 = basic_conv2d(in_channels=input_depth, out_channels=ch5x5_reduce, inplace=False, kernel_size=(1,1), padding=(0,0))
        branch3_5x5 = basic_conv2d(in_channels=ch5x5_reduce, out_channels=ch5x5, inplace=False, kernel_size=(5,5), padding=(2,2))

        # Stride 1 with padding 1 keeps the pooled map the same size as the input,
        # so all four branch outputs can be concatenated.
        branch4_pool = nn.MaxPool2d(kernel_size=(3,3), stride=(1,1), padding=(1,1))
        branch4_1x1 = basic_conv2d(in_channels=input_depth, out_channels=pool_proj, inplace=False, kernel_size=(1,1), padding=(0,0))

        branch1 = nn.Sequential(branch1)
        branch2 = nn.Sequential(branch2_1x1, branch2_3x3)
        branch3 = nn.Sequential(branch3_1x1, branch3_5x5)
        branch4 = nn.Sequential(branch4_pool, branch4_1x1)

        return branch1, branch2, branch3, branch4

### Testing on MNIST

In [4]:
# Load the MNIST training split (downloaded into ./data if not already present)
# and take a single batch to smoke-test the inception block below.
mnist_trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=64)
# Use the iter()/next() builtins rather than calling the dunder methods directly.
x, y = next(iter(mnist_loader))

In [5]:
# MNIST batches carry one (grayscale) channel, so build the block for depth 1.
model = inception_block(1)
model

inception_block(
  (branch1): Sequential(
    (0): basic_conv2d(
      (conv): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
  )
  (branch2): Sequential(
    (0): basic_conv2d(
      (conv): Conv2d(1, 96, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (1): basic_conv2d(
      (conv): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
  )
  (branch3): Sequential(
    (0): basic_conv2d(
      (conv): Conv2d(1, 16, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

In [6]:
out = model(x)
# The four branches contribute 64 + 128 + 32 + 32 = 256 channels and all keep
# the spatial size, so only the channel dimension should change.
assert out.shape == (x.size(0), 256, x.size(2), x.size(3)), out.shape
out.size()

torch.Size([64, 256, 28, 28])