In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class BasicConv(nn.Module):
    def __init__(self, input_channel, output_channel, *args, **kwargs):
        super().__init__()
        self.flow = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, *args, **kwargs),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.flow(x)
        return x

In [3]:
class Inception(nn.Module):
    def __init__(self, input_channel,
                b1_11_ch,
                b2_11_ch, b2_33_ch,
                b3_11_ch, b3_55_ch,
                b4_11_ch):
        super().__init__()
        self.b1 = BasicConv(input_channel, b1_11_ch, kernel_size=1, stride=1, padding=0)
        self.b2 = nn.Sequential(
            BasicConv(input_channel, b2_11_ch, kernel_size=1, stride=1, padding=0),
            BasicConv(b2_11_ch, b2_33_ch, kernel_size=3, stride=1, padding=1),
        )
        self.b3 = nn.Sequential(
            BasicConv(input_channel, b3_11_ch, kernel_size=1, stride=1, padding=0),
            BasicConv(b3_11_ch, b3_55_ch, kernel_size=5, stride=1, padding=2),
        )
        self.b4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv(input_channel, b4_11_ch, kernel_size=1, stride=1, padding=0)
        )
    def forward(self, x):
        b1 = self.b1(x)
        b2 = self.b2(x)
        b3 = self.b3(x)
        b4 = self.b4(x)
        x = torch.cat([b1, b2, b3, b4], 1)
        return x

In [4]:
# The model includes fewer Inception modules
class SimpleGoogleNet(nn.Module):
    def __init__(self, input_channel=3, num_classes=10):
        super().__init__()
        self.conv_flow = nn.Sequential(
            BasicConv(input_channel, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            BasicConv(64, 64, kernel_size=1, stride=1, padding=0),
            BasicConv(64, 192, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.inception = nn.Sequential(
            Inception(192, 64, 96, 128, 12, 32, 32),
            Inception(256, 384, 192, 384, 48, 128, 128),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes),
        )
    def forward(self, x):
        x = self.conv_flow(x)
        x = self.inception(x)
        x = self.pool(x)
        x = self.classifier(x)
        return x

In [6]:
x = torch.randn(1, 3, 224, 224)
model = SimpleGoogleNet()
model(x)

tensor([[-0.0683, -0.1084, -0.2959,  0.2376,  0.3044, -0.5056,  0.6817, -0.0710,
          0.0541,  0.4644]], grad_fn=<AddmmBackward0>)