In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [4]:
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the basic unit used by the Inception blocks.

    Args:
        in_channels: number of input channels.
        out_channels: number of conv filters (and BatchNorm features).
        kernel_size: conv kernel size, forwarded to ``nn.Conv2d``.
        padding: conv zero-padding; with a 3x3 kernel and stride 1,
            ``padding=1`` preserves the spatial size.
        stride: conv stride (default 1, matching the previous behaviour).
        bias: whether the conv has a bias term. Defaults to True for
            compatibility with existing checkpoints. The BatchNorm that follows
            re-centres each channel, so the bias is largely redundant and
            ``bias=False`` saves parameters.
        eps: BatchNorm epsilon (default 1e-3, matching the previous behaviour).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, bias=True, eps=1e-3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels, eps=eps)

    def forward(self, x):
        """Apply conv, batch norm, then ReLU; the input tensor ``x`` is not modified."""
        # In-place ReLU is safe: it overwrites the fresh BN output, not ``x``.
        return F.relu(self.bn(self.conv(x)), inplace=True)

In [6]:
class InceptionV2(nn.Module):
    """Four-branch Inception block; branch outputs are concatenated along channels.

    Every branch uses stride 1 with "same" padding, so the spatial size is kept:
      * ``branch_1x1``: 1x1 conv -> 96 channels.
      * ``branch_2x3``: 1x1 reduce to 48, then one 3x3 conv -> 64 channels.
        (Despite the name this is 1x1 -> 3x3; the attribute name is kept so
        existing ``state_dict`` keys still load.)
      * ``branch_3x3``: 1x1 reduce to 64, then two stacked 3x3 convs -> 96
        channels (two 3x3 convs together see a 5x5 receptive field).
      * ``branch_pooling``: 3x3 average pool (padding excluded from the
        average), then 1x1 conv -> ``pool_features`` channels.

    Args:
        in_channels: channels of the input tensor (default 192, as before).
        pool_features: output channels of the pooling branch (default 64, as before).

    The output has ``96 + 64 + 96 + pool_features`` channels (320 with the
    defaults), also exposed as ``self.out_channels``.
    """

    def __init__(self, in_channels=192, pool_features=64):
        super().__init__()
        self.branch_1x1 = BasicConv2d(in_channels, 96, 1)
        self.branch_2x3 = nn.Sequential(
            BasicConv2d(in_channels, 48, 1),
            BasicConv2d(48, 64, 3, padding=1),
        )
        self.branch_3x3 = nn.Sequential(
            BasicConv2d(in_channels, 64, 1),
            BasicConv2d(64, 96, 3, padding=1),
            BasicConv2d(96, 96, 3, padding=1),
        )
        self.branch_pooling = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(in_channels, pool_features, 1),
        )
        self.out_channels = 96 + 64 + 96 + pool_features

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on dim 1 (channels).

        Assumes a batched NCHW input with ``in_channels`` channels.
        """
        branches = (
            self.branch_1x1(x),
            self.branch_2x3(x),
            self.branch_3x3(x),
            self.branch_pooling(x),
        )
        return torch.cat(branches, dim=1)

In [7]:
# Seed before construction so the randomly initialised weights are identical
# on every Restart & Run All.
torch.manual_seed(0)
inception = InceptionV2()

In [10]:
# Dummy NCHW batch: 1 sample, 192 channels (InceptionV2's input width), 32x32.
# A dedicated seeded generator makes the input reproducible without touching
# the global RNG state.
data = torch.randn(1, 192, 32, 32, generator=torch.Generator().manual_seed(0))

In [11]:
# Shape check only: eval mode keeps BatchNorm running stats untouched and
# no_grad skips building an autograd graph, so re-running this cell is idempotent.
inception.eval()
with torch.no_grad():
    out = inception(data)
# 96 + 64 + 96 + 64 concatenated channels; spatial size is preserved.
assert out.shape == (1, 320, 32, 32), out.shape
out.shape

torch.Size([1, 320, 32, 32])