In [23]:
import torch
import torch.nn as nn

In [24]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [25]:
class InceptionBlock(nn.Module):
    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool):
        super().__init__()

        self.branch1 = ConvBlock(in_channels, out_1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, red_3x3, kernel_size=1),
            ConvBlock(red_3x3, out_3x3, kernel_size=3, padding=1),
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, red_5x5, kernel_size=1),
            ConvBlock(red_5x5, out_5x5, kernel_size=5, padding=2),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channels, out_1x1pool, kernel_size=1),
        )
        
    def forward(self, x):
        # N x Filters x 28 x 28
        return torch.cat([self.branch1(x),
                          self.branch2(x),
                          self.branch3(x),
                          self.branch4(x),
                          ], 1)

In [26]:
class Inception(nn.Module):
    def __init__(self, in_channels=3, num_classes=1000):
        super().__init__()
        self.conv1 = ConvBlock(in_channels=in_channels, out_channels=64,
                               kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))

        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = ConvBlock(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
        self.Inception3a = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
        self.Inception3b = InceptionBlock(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.Inception4a = InceptionBlock(480, 192, 96, 208, 16, 48, 64)
        self.Inception4b = InceptionBlock(512, 160, 112, 224, 24, 64, 64)
        self.Inception4c = InceptionBlock(512, 128, 128, 256, 24, 64, 64)
        self.Inception4d = InceptionBlock(512, 112, 144, 288, 32, 64, 64)
        self.Inception4e = InceptionBlock(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.Inception5a = InceptionBlock(832, 256, 160, 320, 32, 128, 128)
        self.Inception5b = InceptionBlock(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=2)
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.maxpool2(x)
        
        x = self.Inception3a(x)
        x = self.Inception3b(x)
        x = self.maxpool3(x)
        
        x = self.Inception4a(x)
        x = self.Inception4b(x)
        x = self.Inception4c(x)
        x = self.Inception4d(x)
        x = self.Inception4e(x)
        x = self.maxpool4(x)
        
        x = self.Inception5a(x)
        x = self.Inception5b(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x
        
        

In [28]:
x = torch.randn(3, 3, 244, 244)
model = Inception()
model(x)

tensor([[ 0.0323,  0.0296,  0.0362,  ...,  0.2423, -0.1679,  0.3208],
        [ 0.2622,  0.0755,  0.2153,  ...,  0.3396,  0.0836,  0.2530],
        [ 0.1184,  0.0323,  0.0711,  ...,  0.3488, -0.1086,  0.3295]],
       grad_fn=<AddmmBackward0>)