In [1]:
import torch
from torch import nn
from torchinfo import summary

In [18]:
class BasicConv2d(nn.Module):
    """Convolution -> batch normalisation -> ReLU, bundled as one unit.

    The convolution is created without a bias term; the BatchNorm2d that
    follows it carries its own learnable shift.

    Args:
        in_channel: number of channels in the input feature map.
        out_channel: number of channels produced by the convolution.
        kernel_size: convolution kernel size (forwarded to ``nn.Conv2d``).
        stride: convolution stride. Defaults to 1.
        padding: zero-padding added by the convolution. Defaults to 0.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0):
        super().__init__()

        conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=out_channel)
        activation = nn.ReLU()

        # Order inside the Sequential fixes the state_dict keys
        # (layer.0.* for the conv, layer.1.* for the batch norm).
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply conv, batch norm and ReLU to ``x`` and return the result."""
        return self.layer(x)
    
class InceptionModule(nn.Module):
    """Four parallel branches whose outputs are concatenated along channels.

    Every branch uses stride 1 with padding chosen to match its kernel
    (1x1/pad 0, 3x3/pad 1, 5x5/pad 2, 3x3 max-pool/pad 1), so all four
    outputs share the input's spatial size and can be concatenated.
    The result has ``ch_1x1 + ch_3x3 + ch_5x5 + pool_proj`` channels.

    Args:
        in_channel: channels of the incoming feature map.
        ch_1x1: output channels of the plain 1x1 branch.
        ch_3x3_red: channels of the 1x1 reduction before the 3x3 conv.
        ch_3x3: output channels of the 3x3 branch.
        ch_5x5_red: channels of the 1x1 reduction before the 5x5 conv.
        ch_5x5: output channels of the 5x5 branch.
        pool_proj: output channels of the 1x1 conv after the max-pool.
    """

    def __init__(self, in_channel, ch_1x1, ch_3x3_red, ch_3x3, ch_5x5_red, ch_5x5, pool_proj):
        super().__init__()

        # Branch 1: 1x1 convolution only.
        self.branch_1 = BasicConv2d(in_channel, ch_1x1, kernel_size=1)

        # Branch 2: 1x1 reduction followed by a 3x3 convolution.
        reduce_3x3 = BasicConv2d(in_channel, ch_3x3_red, kernel_size=1)
        conv_3x3 = BasicConv2d(ch_3x3_red, ch_3x3, kernel_size=3, padding=1)
        self.branch_2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 reduction followed by a 5x5 convolution.
        reduce_5x5 = BasicConv2d(in_channel, ch_5x5_red, kernel_size=1)
        conv_5x5 = BasicConv2d(ch_5x5_red, ch_5x5, kernel_size=5, padding=2)
        self.branch_3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: size-preserving max-pool followed by a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        projection = BasicConv2d(in_channel, pool_proj, kernel_size=1)
        self.branch_4 = nn.Sequential(pool, projection)

    def forward(self, x):
        """Run ``x`` through all four branches and concatenate on dim 1."""
        branches = (self.branch_1, self.branch_2, self.branch_3, self.branch_4)
        return torch.cat([branch(x) for branch in branches], dim=1)

class InceptionAux(nn.Module):
    """Auxiliary classifier head: average-pool, 1x1 conv, two linear layers.

    The first linear layer expects exactly 2048 flattened features. The 1x1
    conv emits 128 channels, so the pooled map must have 16 spatial positions
    (e.g. a 14x14 input becomes 4x4 after the 5x5 / stride-3 average pool).
    Inputs with other spatial sizes will fail at the linear layer.

    Args:
        in_channel: channels of the feature map fed to this head.
        num_classes: size of the output logits vector.
    """

    def __init__(self, in_channel, num_classes):
        super().__init__()

        self.AvgPool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channel, out_channel=128, kernel_size=1)

        hidden = nn.Linear(2048, 1024)
        classifier = nn.Linear(1024, num_classes)
        self.fc = nn.Sequential(hidden, nn.ReLU(), classifier)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        pooled = self.AvgPool(x)
        features = self.conv(pooled)
        # Keep the batch dimension, collapse channels and spatial dims.
        flat = features.flatten(start_dim=1)
        return self.fc(flat)

In [57]:
class InceptionV1(nn.Module):
    """Inception-v1 style image classifier with two auxiliary heads.

    The network is a stem of three conv blocks, nine ``InceptionModule``
    stages separated by stride-2 max-pools, global average pooling, dropout
    and a final linear classifier. Two ``InceptionAux`` heads branch off after
    stages 4a and 4d and are evaluated only in training mode.

    The auxiliary heads contain a fixed-size linear layer, so in training mode
    the feature maps at 4a/4d must have the spatial size those heads expect
    (this holds for 224x224 inputs, as used in this notebook).

    Args:
        num_classes: number of output classes for the main and auxiliary
            classifiers. Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        # Stem.
        self.conv_1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool_1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv_2a = BasicConv2d(64, 64, kernel_size=1)
        self.conv_2b = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool_2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 3. Each module's in_channel equals the sum of the previous
        # module's four branch widths (e.g. 64 + 128 + 32 + 32 = 256).
        self.inception_3a = InceptionModule(192, 64, 96, 128, 16, 32, 32)
        self.inception_3b = InceptionModule(256, 128, 128, 192, 32, 96, 64)
        self.maxpool_3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 4, with auxiliary heads attached to the outputs of 4a and 4d.
        self.inception_4a = InceptionModule(480, 192, 96, 208, 16, 48, 64)
        self.aux1 = InceptionAux(512, num_classes)
        self.inception_4b = InceptionModule(512, 160, 112, 224, 24, 64, 64)
        self.inception_4c = InceptionModule(512, 128, 128, 256, 24, 64, 64)
        self.inception_4d = InceptionModule(512, 112, 144, 288, 32, 64, 64)
        self.aux2 = InceptionAux(528, num_classes)
        self.inception_4e = InceptionModule(528, 256, 160, 320, 32, 128, 128)
        self.maxpool_4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 5.
        self.inception_5a = InceptionModule(832, 256, 160, 320, 32, 128, 128)
        self.inception_5b = InceptionModule(832, 384, 192, 384, 48, 128, 128)

        # Classifier head.
        self.GlobalAvgPooling = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(1024, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise conv and linear weights; zero any bias that exists."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.01)
            else:
                continue
            # Convolutions built by BasicConv2d use bias=False, which leaves
            # module.bias as None; initialising it unconditionally would fail.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch with 3 channels, shape ``(batch, 3, H, W)``.

        Returns:
            A tuple ``(logits, aux2, aux1)``. Note the order: the head
            attached after 4d (``aux2``) comes before the one attached after
            4a (``aux1``). Both auxiliary entries are ``None`` when the module
            is in eval mode.
        """
        x = self.conv_1(x)
        x = self.maxpool_1(x)

        x = self.conv_2a(x)
        x = self.conv_2b(x)
        x = self.maxpool_2(x)

        x = self.inception_3a(x)
        x = self.inception_3b(x)
        x = self.maxpool_3(x)

        x = self.inception_4a(x)
        # Auxiliary heads are only used in training mode.
        aux1 = self.aux1(x) if self.training else None
        x = self.inception_4b(x)
        x = self.inception_4c(x)
        x = self.inception_4d(x)
        aux2 = self.aux2(x) if self.training else None

        x = self.inception_4e(x)
        x = self.maxpool_4(x)

        x = self.inception_5a(x)
        x = self.inception_5b(x)

        x = self.GlobalAvgPooling(x)
        x = self.dropout(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x, aux2, aux1

# Build the network with its default 1000-class head and show a per-layer
# summary for a batch of 32 RGB images at 224x224.
model = InceptionV1()
summary_input_shape = (32, 3, 224, 224)
summary(model, summary_input_shape)

Layer (type:depth-idx)                        Output Shape              Param #
InceptionV1                                   [32, 1000]                6,380,240
├─BasicConv2d: 1-1                            [32, 64, 112, 112]        --
│    └─Sequential: 2-1                        [32, 64, 112, 112]        --
│    │    └─Conv2d: 3-1                       [32, 64, 112, 112]        9,472
│    │    └─BatchNorm2d: 3-2                  [32, 64, 112, 112]        128
│    │    └─ReLU: 3-3                         [32, 64, 112, 112]        --
├─MaxPool2d: 1-2                              [32, 64, 56, 56]          --
├─BasicConv2d: 1-3                            [32, 64, 56, 56]          --
│    └─Sequential: 2-2                        [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-4                       [32, 64, 56, 56]          4,160
│    │    └─BatchNorm2d: 3-5                  [32, 64, 56, 56]          128
│    │    └─ReLU: 3-6                         [32, 64, 56, 56]          --
├─Bas

In [58]:
# Training mode: the auxiliary heads are active, so all three outputs are tensors.
model.train()

# Create the probe batch on the device the model's parameters actually live on,
# rather than hard-coding 'cuda'. No visible cell moves the model explicitly,
# and a hard-coded device fails on CPU-only machines.
device = next(model.parameters()).device
pred_y, aux2, aux1 = model(torch.randn(32, 3, 224, 224, device=device))

print(pred_y.shape)
print(aux2.shape)
print(aux1.shape)

torch.Size([32, 1000])
torch.Size([32, 1000])
torch.Size([32, 1000])


In [59]:
# Eval mode: the auxiliary heads are skipped, so aux2 and aux1 come back as None.
model.eval()

# Create the probe batch on the device the model's parameters actually live on,
# rather than hard-coding 'cuda'.
device = next(model.parameters()).device

# No gradients are needed for an inference-only shape check.
with torch.no_grad():
    pred_y, aux2, aux1 = model(torch.randn(32, 3, 224, 224, device=device))

print(pred_y.shape)
print(aux2)
print(aux1)

torch.Size([32, 1000])
None
None
