In [1]:
import torch
from torch import nn
from torchinfo import summary

In [27]:
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block used by every stage below.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the convolution.
        kernel_size: int or (h, w) tuple, forwarded to ``nn.Conv2d``.
        stride: convolution stride (default 1).
        padding: convolution padding (default 0, i.e. "valid").
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0):
        super().__init__()

        # The conv has no bias: the BatchNorm right after it supplies the shift.
        conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channel)
        act = nn.ReLU()

        # Kept as one Sequential named `conv_blk` so state_dict keys stay
        # "conv_blk.0.*" (conv) and "conv_blk.1.*" (batch norm).
        self.conv_blk = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.conv_blk(x)

In [39]:
class Stem(nn.Module):
    """Stem network used at the start of ResnetV2.

    For a (N, in_channel, 299, 299) input it produces a (N, 384, 35, 35)
    feature map (shapes as reported by the torchinfo summary in this notebook).
    The layout follows the Inception-v4 / Inception-ResNet-v2 stem: a plain
    conv stack followed by three "split and concatenate" stages.

    Args:
        in_channel: number of input image channels. Defaults to 3 (RGB), the
            value that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, in_channel=3):
        super().__init__()

        # 3x3/2 valid -> 3x3 valid -> 3x3 same: 299 -> 149 -> 147 -> 147, 64 ch.
        self.conv1 = nn.Sequential(
            BasicConv2d(in_channel, 32, 3, stride=2),
            BasicConv2d(32, 32, 3),
            BasicConv2d(32, 64, 3, padding=1)
        )

        # Split 1: 3x3/2 max-pool (64 ch) || 3x3/2 conv (96 ch) -> 160 ch at 73x73.
        self.inception1_1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.inception1_2 = BasicConv2d(64, 96, kernel_size=3, stride=2)

        # Split 2: two conv towers ending in a valid 3x3 (96 ch each) -> 192 ch at 71x71.
        self.inception2_1 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3)
        )
        # The 7x1 / 1x7 pair is padded so only the final 3x3 shrinks the grid.
        self.inception2_2 = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1),
            BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(64, 96, kernel_size=3)
        )

        # Split 3: 3x3/2 conv (192 ch) || 3x3/2 max-pool (192 ch) -> 384 ch at 35x35.
        self.inception3_1 = BasicConv2d(192, 192, 3, stride=2)
        self.inception3_2 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x = self.conv1(x)

        x = [self.inception1_1(x), self.inception1_2(x)]
        x = torch.cat(x, dim=1)

        x = [self.inception2_1(x), self.inception2_2(x)]
        x = torch.cat(x, dim=1)

        x = [self.inception3_1(x), self.inception3_2(x)]
        x = torch.cat(x, dim=1)

        return x

# model = Stem()
# summary(model, (2, 3, 299, 299))

# input_names = ['Input']
# output_names = ['Output']

# x = torch.randn(1, 3, 299, 299).to('cuda')
# torch.onnx.export(model, x, 'Stem.onnx', input_names=input_names, output_names=output_names, training=torch.onnx.TrainingMode.TRAINING)

In [71]:
class InceptionResnetA(nn.Module):
    """Inception-ResNet-A residual block (applied on the 35x35 grid in ResnetV2).

    Three parallel towers (32 + 32 + 64 = 128 channels) are concatenated and
    projected to 384 channels by a 1x1 conv + BN without activation. The
    input goes through its own 1x1 conv + BN projection. The two are summed
    and passed through ReLU. All convs are padded, so spatial size is preserved.

    Args:
        in_channel: channels of the incoming feature map (default 384).
    """

    def __init__(self, in_channel=384):
        super().__init__()

        # Tower 1: a single 1x1 conv.
        self.branch1 = BasicConv2d(in_channel, 32, 1)

        # Tower 2: 1x1 reduction, then one padded 3x3.
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, 32, 1),
            BasicConv2d(32, 32, kernel_size=3, padding=1),
        )

        # Tower 3: 1x1 reduction, then two padded 3x3 convs (32 -> 48 -> 64).
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, 32, 1),
            BasicConv2d(32, 48, 3, padding=1),
            BasicConv2d(48, 64, 3, padding=1),
        )

        # 1x1 projection of the 128 concatenated channels, no activation.
        self.branch_merge = nn.Sequential(
            nn.Conv2d(128, 384, 1, bias=False),
            nn.BatchNorm2d(384),
        )

        # Projection shortcut so the sum works for any in_channel.
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channel, 384, 1, bias=False),
            nn.BatchNorm2d(384),
        )

        self.relu = nn.ReLU()

    def forward(self, x):
        towers = (self.branch1, self.branch2, self.branch3)
        merged = self.branch_merge(torch.cat([tower(x) for tower in towers], dim=1))
        return self.relu(self.shortcut(x) + merged)

# model = InceptionResnetA(in_channel=384)
# summary(model, (1, 384, 35, 35))

In [72]:
class ReductionA(nn.Module):
    """Reduction-A: halves the grid (35x35 -> 17x17 in ResnetV2).

    Concatenates three branches:
    - a 3x3/2 max-pool of the input (in_channel channels);
    - a 3x3/2 conv (384 channels);
    - a 1x1 -> 3x3 -> 3x3/2 conv tower (384 channels).

    The output therefore has in_channel + 768 channels (1152 for the default
    in_channel=384).
    """

    def __init__(self, in_channel=384):
        super().__init__()

        self.branch1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branch2 = BasicConv2d(in_channel, 384, 3, stride=2)
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, 256, 1),
            BasicConv2d(256, 256, 3, padding=1),
            BasicConv2d(256, 384, 3, stride=2),
        )

    def forward(self, x):
        pooled = self.branch1(x)
        strided = self.branch2(x)
        tower = self.branch3(x)
        return torch.cat((pooled, strided, tower), dim=1)
    
# model = ReductionA()
# summary(model, (1, 384, 35, 35))

In [73]:
class InceptionResnetB(nn.Module):
    """Inception-ResNet-B residual block (applied on the 17x17 grid in ResnetV2).

    Two towers (192 + 192 = 384 channels) are concatenated and projected to
    1154 channels by a 1x1 conv + BN. The input is projected to 1154 channels
    by its own 1x1 conv + BN shortcut. The sum is passed through ReLU, the
    same as in InceptionResnetA. The 1x7 / 7x1 convs are padded, so spatial
    size is preserved.

    Args:
        in_channel: channels of the incoming feature map. ResnetV2 passes 1152
            for the first block after ReductionA and 1154 for the rest.
    """

    def __init__(self, in_channel):
        super().__init__()

        self.branch1 = BasicConv2d(in_channel, 192, 1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, 128, 1),
            BasicConv2d(128, 160, (1, 7), padding=(0, 3)),
            BasicConv2d(160, 192, (7, 1), padding=(3, 0))
        )
        # NOTE(review): these 1x1 convs keep bias=True (InceptionResnetA uses
        # bias=False). The bias is redundant before BatchNorm but is kept so
        # existing state_dicts still load.
        self.branch_merge = nn.Sequential(
            nn.Conv2d(384, 1154, 1),
            nn.BatchNorm2d(1154)
        )

        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channel, 1154, 1),
            nn.BatchNorm2d(1154)
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = [self.branch1(x), self.branch2(x)]
        residual = torch.cat(residual, dim=1)
        residual = self.branch_merge(residual)

        shortcut = self.shortcut(x)
        # Activation after the residual addition. self.relu was created in
        # __init__ but previously never applied, unlike in InceptionResnetA.
        output = self.relu(residual + shortcut)

        return output

# model = InceptionResnetB(1152)
# summary(model, (2, 1152, 17, 17))

# model = InceptionResnetB(1154)
# summary(model, (2, 1154, 17, 17))

In [74]:
class ReductionB(nn.Module):
    """Reduction-B: halves the grid (17x17 -> 8x8 in ResnetV2).

    Concatenates four branches:
    - a 3x3/2 max-pool of the input (in_channel channels);
    - 1x1 -> 3x3/2 towers producing 384 and 288 channels;
    - a 1x1 -> 3x3 -> 3x3/2 tower producing 320 channels.

    The output therefore has in_channel + 992 channels (2146 for the default
    in_channel=1154).
    """

    def __init__(self, in_channel=1154):
        super().__init__()

        self.branch1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, 256, 1),
            BasicConv2d(256, 384, 3, stride=2),
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, 256, 1),
            BasicConv2d(256, 288, 3, stride=2),
        )
        self.branch4 = nn.Sequential(
            BasicConv2d(in_channel, 256, 1),
            BasicConv2d(256, 288, 3, padding=1),
            BasicConv2d(288, 320, 3, stride=2),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat(tuple(branch(x) for branch in branches), dim=1)

# model = ReductionB()
# summary(model, (2, 1154, 17, 17))

In [75]:
class InceptionResnetC(nn.Module):
    """Inception-ResNet-C residual block (applied on the 8x8 grid in ResnetV2).

    Two towers (192 + 256 = 448 channels) are concatenated and projected to
    2048 channels by a 1x1 conv + BN. The input is projected to 2048 channels
    by its own 1x1 conv + BN shortcut. The sum is passed through ReLU, the
    same as in InceptionResnetA. The 1x3 / 3x1 convs are padded, so spatial
    size is preserved.

    Args:
        in_channel: channels of the incoming feature map. ResnetV2 passes 2146
            for the first block after ReductionB and 2048 for the rest.
    """

    def __init__(self, in_channel):
        super().__init__()

        self.branch1 = BasicConv2d(in_channel, 192, 1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, 192, 1),
            BasicConv2d(192, 224, (1, 3), padding=(0, 1)),
            BasicConv2d(224, 256, (3, 1), padding=(1, 0))
        )
        # NOTE(review): these 1x1 convs keep bias=True (InceptionResnetA uses
        # bias=False). The bias is redundant before BatchNorm but is kept so
        # existing state_dicts still load.
        self.branch_merge = nn.Sequential(
            nn.Conv2d(448, 2048, 1),
            nn.BatchNorm2d(2048)
        )

        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channel, 2048, 1),
            nn.BatchNorm2d(2048)
        )
        # Parameter-free, so adding it does not change the state_dict.
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = [self.branch1(x), self.branch2(x)]
        residual = torch.cat(residual, dim=1)
        residual = self.branch_merge(residual)

        shortcut = self.shortcut(x)
        # Activation after the residual addition, consistent with InceptionResnetA.
        output = self.relu(residual + shortcut)

        return output
    
# model = InceptionResnetC(2146)
# summary(model, (2, 2146, 8, 8))

# model = InceptionResnetC(2048)
# summary(model, (2, 2048, 8, 8))

In [80]:
class ResnetV2(nn.Module):
    """Inception-ResNet-v2 style classifier assembled from the blocks in this notebook.

    Pipeline for a (N, 3, 299, 299) input:
        Stem             -> (N, 384, 35, 35)
        A x InceptionResnetA
        ReductionA       -> (N, 1152, 17, 17)
        B x InceptionResnetB (first block maps 1152 -> 1154 channels)
        ReductionB       -> (N, 2146, 8, 8)
        C x InceptionResnetC (first block maps 2146 -> 2048 channels)
        global average pool -> dropout(0.2) -> Linear(2048, num_class)

    Args:
        A: number of InceptionResnetA blocks (>= 0).
        B: number of InceptionResnetB blocks (>= 1).
        C: number of InceptionResnetC blocks (>= 1).
        num_class: number of output logits.

    Raises:
        ValueError: if A < 0, B < 1 or C < 1. The first B and C blocks perform
            the channel change that later layers depend on, so at least one of
            each is required. Previously B=0 / C=0 silently built one block
            and a negative A silently built none.
    """

    def __init__(self, A, B, C, num_class=1000):
        super().__init__()

        if A < 0:
            raise ValueError(f"A must be >= 0, got {A}")
        if B < 1:
            raise ValueError(f"B must be >= 1, got {B}")
        if C < 1:
            raise ValueError(f"C must be >= 1, got {C}")

        self.Stem = Stem()

        self.InceptionResnetA = nn.Sequential(*[InceptionResnetA() for _ in range(A)])
        self.ReductionA = ReductionA()

        # First block adapts 1152 -> 1154 channels; the rest keep 1154.
        self.InceptionResnetB = nn.Sequential(InceptionResnetB(1152), *[InceptionResnetB(1154) for _ in range(B-1)])
        self.ReductionB = ReductionB()

        # First block adapts 2146 -> 2048 channels; the rest keep 2048.
        self.InceptionResnetC = nn.Sequential(InceptionResnetC(2146), *[InceptionResnetC(2048) for _ in range(C-1)])
        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(2048, num_class)
        )

    def forward(self, x):
        x = self.Stem(x)

        x = self.InceptionResnetA(x)
        x = self.ReductionA(x)

        x = self.InceptionResnetB(x)
        x = self.ReductionB(x)

        x = self.InceptionResnetC(x)
        x = self.GlobalAvgPool(x)
        # (N, 2048, 1, 1) -> (N, 2048) for the linear classifier.
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

# 5 x A, 10 x B, 5 x C residual blocks. torchinfo runs a (2, 3, 299, 299)
# dummy batch to report per-layer output shapes and parameter counts.
model = ResnetV2(A=5, B=10, C=5, num_class=1000)
summary(model, input_size=(2, 3, 299, 299))

Layer (type:depth-idx)                             Output Shape              Param #
ResnetV2                                           [2, 1000]                 --
├─Stem: 1-1                                        [2, 384, 35, 35]          --
│    └─Sequential: 2-1                             [2, 64, 147, 147]         --
│    │    └─BasicConv2d: 3-1                       [2, 32, 149, 149]         928
│    │    └─BasicConv2d: 3-2                       [2, 32, 147, 147]         9,280
│    │    └─BasicConv2d: 3-3                       [2, 64, 147, 147]         18,560
│    └─MaxPool2d: 2-2                              [2, 64, 73, 73]           --
│    └─BasicConv2d: 2-3                            [2, 96, 73, 73]           --
│    │    └─Sequential: 3-4                        [2, 96, 73, 73]           55,488
│    └─Sequential: 2-4                             [2, 96, 71, 71]           --
│    │    └─BasicConv2d: 3-5                       [2, 64, 73, 73]           10,368
│    │    └─BasicCo

In [82]:
input_names = ['Input']
output_names = ['Output']

# Trace with a dummy input on the same device as the model. `model` was built
# on the CPU and never moved, so a hard-coded .to('cuda') input caused a
# device mismatch during tracing, and failed outright without a GPU.
device = next(model.parameters()).device
x = torch.randn(1, 3, 299, 299, device=device)
# TrainingMode.TRAINING exports BatchNorm/Dropout with training behaviour;
# use torch.onnx.TrainingMode.EVAL for an inference graph.
torch.onnx.export(model, x, 'ResnetV2.onnx', input_names=input_names, output_names=output_names, training=torch.onnx.TrainingMode.TRAINING)