In [2]:
import torch
from torch import nn
from torchinfo import summary

In [4]:
class BasicConv2d(nn.Module):
    """Convolution, batch normalisation and ReLU bundled into a single layer.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d`` (int or pair).
        stride: Stride forwarded to ``nn.Conv2d``. Defaults to 1.
        padding: Padding forwarded to ``nn.Conv2d``. Defaults to 0.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0):
        super().__init__()

        conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
            padding=padding,
        )
        norm = nn.BatchNorm2d(out_channel)
        activation = nn.ReLU()

        # Positions inside the Sequential (0: conv, 1: norm, 2: ReLU) determine
        # the parameter names, so the order must stay conv -> norm -> ReLU.
        self.conv_blk = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.conv_blk(x)

![Inception-v4 stem architecture](./images/incpetion_v4/inception_v4_stem.png)

In [48]:
class Stem(nn.Module):
    """Input stem: three plain convolutions followed by three two-branch stages.

    Each two-branch stage feeds the same tensor to both of its branches and
    joins the two results along the channel axis, so the channel count goes
    3 -> 64 -> 160 (64 + 96) -> 192 (96 + 96) -> 384 (192 + 192).
    With a 299x299 input the output feature map is 384 x 35 x 35.
    """

    def __init__(self):
        super().__init__()

        # Plain convolutional prefix: 3 -> 32 -> 32 -> 64 channels.
        prefix_layers = [
            BasicConv2d(3, 32, 3, stride=2),
            BasicConv2d(32, 32, 3),
            BasicConv2d(32, 64, 3, padding=1),
        ]
        self.conv1 = nn.Sequential(*prefix_layers)

        # Stage 1: stride-2 max-pool next to a stride-2 3x3 convolution.
        self.inception1_1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.inception1_2 = BasicConv2d(64, 96, kernel_size=3, stride=2)

        # Stage 2: a short 1x1 -> 3x3 path next to a path that inserts a
        # 7x1 / 1x7 pair before its final 3x3 convolution.
        short_path = [
            BasicConv2d(160, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3),
        ]
        self.inception2_1 = nn.Sequential(*short_path)

        factorised_path = [
            BasicConv2d(160, 64, kernel_size=1),
            BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(64, 96, kernel_size=3),
        ]
        self.inception2_2 = nn.Sequential(*factorised_path)

        # Stage 3: stride-2 3x3 convolution next to a stride-2 max-pool.
        self.inception3_1 = BasicConv2d(192, 192, 3, stride=2)
        self.inception3_2 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Run the prefix, then each branch pair, concatenating on dim 1."""
        branch_pairs = (
            (self.inception1_1, self.inception1_2),
            (self.inception2_1, self.inception2_2),
            (self.inception3_1, self.inception3_2),
        )

        features = self.conv1(x)
        for first_branch, second_branch in branch_pairs:
            features = torch.cat(
                [first_branch(features), second_branch(features)], dim=1
            )

        return features

# model = Stem()
# summary(model, (2, 3, 299, 299))

![Inception-v4 architecture diagram](./images/incpetion_v4/inception_v4_diagram.png)

In [30]:
class InceptionA(nn.Module):
    """Four-branch Inception-A block.

    All branches keep the spatial size (stride 1, 3x3 layers padded by 1) and
    each one ends with 96 channels, so the concatenated output always has
    4 * 96 = 384 channels.

    Args:
        in_channel: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Branch 1: 3x3 average pooling followed by a 1x1 projection.
        pool_path = [
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channel, 96, kernel_size=1),
        ]
        self.branch1 = nn.Sequential(*pool_path)

        # Branch 2: a single 1x1 convolution.
        self.branch2 = BasicConv2d(in_channel, 96, kernel_size=1)

        # Branch 3: 1x1 reduction followed by one 3x3 convolution.
        single_3x3_path = [
            BasicConv2d(in_channel, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
        ]
        self.branch3 = nn.Sequential(*single_3x3_path)

        # Branch 4: 1x1 reduction followed by two stacked 3x3 convolutions.
        double_3x3_path = [
            BasicConv2d(in_channel, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        ]
        self.branch4 = nn.Sequential(*double_3x3_path)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

# model = InceptionA(384)
# summary(model, (2, 384, 35, 35))

In [31]:
class ReductionA(nn.Module):
    """Down-sampling block placed after the Inception-A stack.

    Every branch ends in an unpadded kernel-3, stride-2 operation, so all
    three outputs share the same reduced spatial size. The concatenated
    output has ``in_channel + 384 + 256`` channels.

    Args:
        in_channel: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Branch 1: max-pool only, channel count unchanged.
        self.branch1 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Branch 2: one stride-2 3x3 convolution.
        self.branch2 = BasicConv2d(in_channel, 384, kernel_size=3, stride=2)

        # Branch 3: 1x1 reduction, padded 3x3, then a stride-2 3x3.
        conv_path = [
            BasicConv2d(in_channel, 192, kernel_size=1),
            BasicConv2d(192, 224, kernel_size=3, padding=1),
            BasicConv2d(224, 256, kernel_size=3, stride=2),
        ]
        self.branch3 = nn.Sequential(*conv_path)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], dim=1)
    
# model = ReductionA(384)
# summary(model, (2, 384, 35, 35))

In [38]:
class InceptionB(nn.Module):
    """Four-branch Inception-B block built from 1x7 / 7x1 convolution pairs.

    All layers use stride 1 with padding that keeps the spatial size. The
    branches end with 128, 384, 256 and 256 channels, so the concatenated
    output always has 1024 channels.

    Args:
        in_channel: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Branch 1: 3x3 average pooling followed by a 1x1 projection.
        pool_path = [
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channel, 128, kernel_size=1),
        ]
        self.branch1 = nn.Sequential(*pool_path)

        # Branch 2: a single 1x1 convolution.
        self.branch2 = BasicConv2d(in_channel, 384, kernel_size=1)

        # Branch 3: 1x1 reduction followed by one 1x7 / 7x1 pair.
        single_pair_path = [
            BasicConv2d(in_channel, 192, kernel_size=1),
            BasicConv2d(192, 224, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), padding=(3, 0)),
        ]
        self.branch3 = nn.Sequential(*single_pair_path)

        # Branch 4: 1x1 reduction followed by two 1x7 / 7x1 pairs.
        double_pair_path = [
            BasicConv2d(in_channel, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 224, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(224, 224, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), padding=(3, 0)),
        ]
        self.branch4 = nn.Sequential(*double_pair_path)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

# model = InceptionB(1024)
# summary(model, (2, 1024, 17, 17))

In [40]:
class ReductionB(nn.Module):
    """Down-sampling block placed after the Inception-B stack.

    Every branch ends in an unpadded kernel-3, stride-2 operation, so all
    three outputs share the same reduced spatial size. The concatenated
    output has ``in_channel + 192 + 320`` channels.

    Args:
        in_channel: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Branch 1: max-pool only, channel count unchanged.
        self.branch1 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Branch 2: 1x1 reduction followed by a stride-2 3x3 convolution.
        short_path = [
            BasicConv2d(in_channel, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        ]
        self.branch2 = nn.Sequential(*short_path)

        # Branch 3: 1x1 reduction, a 1x7 / 7x1 pair, then a stride-2 3x3.
        long_path = [
            BasicConv2d(in_channel, 256, kernel_size=1),
            BasicConv2d(256, 256, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(256, 320, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(320, 320, kernel_size=3, stride=2),
        ]
        self.branch3 = nn.Sequential(*long_path)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], dim=1)

# model = ReductionB(1024)
# summary(model, (2, 1024, 17, 17))

In [45]:
class InceptionC(nn.Module):
    """Four-branch Inception-C block whose last two branches fork at the end.

    Branches 3 and 4 each compute a shared trunk and then split into a
    1x3 and a 3x1 convolution whose outputs are concatenated (256 + 256).
    Together with the two 256-channel branches this gives
    256 + 256 + 512 + 512 = 1536 output channels. All layers use stride 1
    with padding that keeps the spatial size.

    Args:
        in_channel: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Branch 1: 3x3 average pooling followed by a 1x1 projection.
        pool_path = [
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channel, 256, kernel_size=1),
        ]
        self.branch1 = nn.Sequential(*pool_path)

        # Branch 2: a single 1x1 convolution.
        self.branch2 = BasicConv2d(in_channel, 256, kernel_size=1)

        # Branch 3: 1x1 trunk, then a parallel 1x3 / 3x1 fork.
        self.branch3_1 = BasicConv2d(in_channel, 384, kernel_size=1)
        self.branch3_2_1 = BasicConv2d(384, 256, kernel_size=(1, 3), padding=(0, 1))
        self.branch3_2_2 = BasicConv2d(384, 256, kernel_size=(3, 1), padding=(1, 0))

        # Branch 4: 1x1 -> 1x3 -> 3x1 trunk, then a parallel 3x1 / 1x3 fork.
        trunk_layers = [
            BasicConv2d(in_channel, 384, kernel_size=1),
            BasicConv2d(384, 448, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(448, 512, kernel_size=(3, 1), padding=(1, 0)),
        ]
        self.branch4_1 = nn.Sequential(*trunk_layers)
        self.branch4_2_1 = BasicConv2d(512, 256, kernel_size=(3, 1), padding=(1, 0))
        self.branch4_2_2 = BasicConv2d(512, 256, kernel_size=(1, 3), padding=(0, 1))

    def forward(self, x):
        """Evaluate the four branches and concatenate their outputs on dim 1."""
        pooled = self.branch1(x)
        pointwise = self.branch2(x)

        # Branch 3: the shared trunk feeds both fork arms.
        trunk3 = self.branch3_1(x)
        forked3 = torch.cat(
            (self.branch3_2_1(trunk3), self.branch3_2_2(trunk3)), dim=1
        )

        # Branch 4: same fork pattern on a deeper trunk.
        trunk4 = self.branch4_1(x)
        forked4 = torch.cat(
            (self.branch4_2_1(trunk4), self.branch4_2_2(trunk4)), dim=1
        )

        return torch.cat((pooled, pointwise, forked3, forked4), dim=1)

# model = InceptionC(1536)
# summary(model, (2, 1536, 8, 8))

In [59]:
class InceptionV4(nn.Module):
    """Full Inception-v4 classifier assembled from the blocks defined above.

    Pipeline: Stem -> N x InceptionA -> ReductionA -> N x InceptionB ->
    ReductionB -> N x InceptionC -> global average pool -> dropout -> linear.
    The channel widths passed to the blocks are fixed at 384, 1024 and 1536.

    Args:
        inception_num_list: Indexable of three repeat counts; positions 0, 1
            and 2 give the number of InceptionA, InceptionB and InceptionC
            blocks respectively.
        num_classes: Size of the final linear layer's output. Defaults to 1000.
    """

    def __init__(self, inception_num_list, num_classes=1000):
        super().__init__()

        def stack(block_cls, channels, repeats):
            # Chain `repeats` fresh instances of the same block type.
            return nn.Sequential(*(block_cls(channels) for _ in range(repeats)))

        self.Stem = Stem()
        self.InceptionA = stack(InceptionA, 384, inception_num_list[0])
        self.ReductionA = ReductionA(384)
        self.InceptionB = stack(InceptionB, 1024, inception_num_list[1])
        self.ReductionB = ReductionB(1024)
        self.InceptionC = stack(InceptionC, 1536, inception_num_list[2])

        # Classification head: collapse each channel to one value, then
        # dropout and a linear layer producing unnormalised class scores.
        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        head_layers = [
            nn.Dropout(p=0.2),
            nn.Linear(1536, num_classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for ``x``."""
        feature_stages = (
            self.Stem,
            self.InceptionA,
            self.ReductionA,
            self.InceptionB,
            self.ReductionB,
            self.InceptionC,
        )
        for stage in feature_stages:
            x = stage(x)

        pooled = self.GlobalAvgPool(x)
        flat = torch.flatten(pooled, start_dim=1)
        return self.fc(flat)

# Build the network with 4 Inception-A, 7 Inception-B and 3 Inception-C blocks,
# then print a per-layer summary for a batch of two 3x299x299 inputs.
model = InceptionV4([4, 7, 3], num_classes=1000)
summary(model, input_size=(2, 3, 299, 299))

Layer (type:depth-idx)                             Output Shape              Param #
InceptionV4                                        [2, 1000]                 --
├─Stem: 1-1                                        [2, 384, 35, 35]          --
│    └─Sequential: 2-1                             [2, 64, 147, 147]         --
│    │    └─BasicConv2d: 3-1                       [2, 32, 149, 149]         960
│    │    └─BasicConv2d: 3-2                       [2, 32, 147, 147]         9,312
│    │    └─BasicConv2d: 3-3                       [2, 64, 147, 147]         18,624
│    └─MaxPool2d: 2-2                              [2, 64, 73, 73]           --
│    └─BasicConv2d: 2-3                            [2, 96, 73, 73]           --
│    │    └─Sequential: 3-4                        [2, 96, 73, 73]           55,584
│    └─Sequential: 2-4                             [2, 96, 71, 71]           --
│    │    └─BasicConv2d: 3-5                       [2, 64, 73, 73]           10,432
│    │    └─BasicCo