In [2]:
import torch
from torch import nn
from torchinfo import summary

In [3]:
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block used by every Inception module.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
        bias: whether the convolution learns an additive bias. Defaults to
            ``True`` to keep the original parameter layout. Because
            ``BatchNorm2d`` subtracts the per-channel batch mean right after the
            convolution, that bias is redundant, and ``bias=False`` drops it.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        # Attribute name and layer order are kept so state_dict keys
        # (conv_blk.0.*, conv_blk.1.*) stay stable.
        self.conv_blk = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding, bias=bias),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        """Apply conv, batch norm and ReLU to a 4-D (N, C, H, W) batch ``x``."""
        return self.conv_blk(x)

---
Inception-A 모듈
- Branch 1:
	- 1x1 합성곱: 64채널
	- 3x3 합성곱: 96채널
	- 3x3 합성곱: 96채널
- Branch 2:
	- 1x1 합성곱: 48채널
	- 3x3 합성곱: 64채널
- Branch 3:
	- 평균 풀링(Average Pooling)
	- 1x1 합성곱: 64채널
- Branch 4:
	- 1x1 합성곱: 64채널

In [4]:
class InceptionA(nn.Module):
    """Inception-A block: four parallel branches concatenated on the channel axis.

    Output channels: 96 (1x1 -> 3x3 -> 3x3) + 64 (1x1 -> 3x3)
    + 64 (avg-pool -> 1x1) + 64 (1x1) = 288. Every conv/pool uses stride 1
    with "same" padding, so H and W are unchanged.

    Args:
        in_channel: channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Double 3x3 stack: reduce to 64 with a 1x1, then two 3x3 convs at 96.
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channel, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        )
        # Single 3x3: reduce to 48 with a 1x1, then one 3x3 conv to 64.
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channel, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=3, padding=1),
        )
        # Pool projection: a stride-1 3x3 average pool keeps H and W.
        self.branch_3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channel, 64, kernel_size=1),
        )
        # Plain 1x1 projection.
        self.branch_4 = BasicConv2d(in_channel, 64, kernel_size=1)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them along dim 1."""
        branches = (self.branch_1, self.branch_2, self.branch_3, self.branch_4)
        return torch.cat([branch(x) for branch in branches], dim=1)

# Shape check on a 288-channel 35x35 map; expected output [1, 288, 35, 35].
model = InceptionA(in_channel=288)
summary(model, input_size=(1, 288, 35, 35))

Layer (type:depth-idx)                   Output Shape              Param #
InceptionA                               [1, 288, 35, 35]          --
├─Sequential: 1-1                        [1, 96, 35, 35]           --
│    └─BasicConv2d: 2-1                  [1, 64, 35, 35]           --
│    │    └─Sequential: 3-1              [1, 64, 35, 35]           18,624
│    └─BasicConv2d: 2-2                  [1, 96, 35, 35]           --
│    │    └─Sequential: 3-2              [1, 96, 35, 35]           55,584
│    └─BasicConv2d: 2-3                  [1, 96, 35, 35]           --
│    │    └─Sequential: 3-3              [1, 96, 35, 35]           83,232
├─Sequential: 1-2                        [1, 64, 35, 35]           --
│    └─BasicConv2d: 2-4                  [1, 48, 35, 35]           --
│    │    └─Sequential: 3-4              [1, 48, 35, 35]           13,968
│    └─BasicConv2d: 2-5                  [1, 64, 35, 35]           --
│    │    └─Sequential: 3-5              [1, 64, 35, 35]           27

---
Inception-B 모듈
- Branch 1:
	- 1x1 합성곱: (128, 160, 160, 192) 채널
	- 1x7 합성곱: (128, 160, 160, 192) 채널
	- 7x1 합성곱: (128, 160, 160, 192) 채널
	- 1x7 합성곱: (128, 160, 160, 192) 채널
	- 7x1 합성곱: 192채널
- Branch 2:
	- 1x1 합성곱: (128, 160, 160, 192) 채널
	- 1x7 합성곱: (128, 160, 160, 192) 채널
	- 7x1 합성곱: 192채널
- Branch 3:
	- 평균 풀링(Average Pooling)
	- 1x1 합성곱: 192채널
- Branch 4:
	- 1x1 합성곱: 192채널

In [5]:
class InceptionB(nn.Module):
    """Inception-B block built from factorized 7x7 convolutions (1x7 / 7x1 pairs).

    Args:
        in_channel: channels of the incoming feature map.
        channel_7: width of the intermediate 1x1 / 1x7 / 7x1 convolutions.

    Every branch ends in 192 channels, so the output has 4 * 192 = 768
    channels. The paddings keep H and W unchanged.
    """

    def __init__(self, in_channel, channel_7):
        super().__init__()

        # Keyword sets for the two halves of a factorized 7x7 convolution.
        row_conv = {"kernel_size": (1, 7), "padding": (0, 3)}
        col_conv = {"kernel_size": (7, 1), "padding": (3, 0)}

        # Two stacked factorized 7x7s.
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channel, channel_7, kernel_size=1),
            BasicConv2d(channel_7, channel_7, **row_conv),
            BasicConv2d(channel_7, channel_7, **col_conv),
            BasicConv2d(channel_7, channel_7, **row_conv),
            BasicConv2d(channel_7, 192, **col_conv),
        )
        # A single factorized 7x7.
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channel, channel_7, kernel_size=1),
            BasicConv2d(channel_7, channel_7, **row_conv),
            BasicConv2d(channel_7, 192, **col_conv),
        )
        # Pool projection.
        self.branch_3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channel, 192, kernel_size=1),
        )
        # Plain 1x1 projection.
        self.branch_4 = BasicConv2d(in_channel, 192, kernel_size=1)

    def forward(self, x):
        """Concatenate the four branch outputs along the channel dimension."""
        outputs = [
            self.branch_1(x),
            self.branch_2(x),
            self.branch_3(x),
            self.branch_4(x),
        ]
        return torch.cat(outputs, dim=1)

# Shape check on a 768-channel 17x17 map; expected output [1, 768, 17, 17].
model = InceptionB(in_channel=768, channel_7=128)
summary(model, input_size=(1, 768, 17, 17))


Layer (type:depth-idx)                   Output Shape              Param #
InceptionB                               [1, 768, 17, 17]          --
├─Sequential: 1-1                        [1, 192, 17, 17]          --
│    └─BasicConv2d: 2-1                  [1, 128, 17, 17]          --
│    │    └─Sequential: 3-1              [1, 128, 17, 17]          98,688
│    └─BasicConv2d: 2-2                  [1, 128, 17, 17]          --
│    │    └─Sequential: 3-2              [1, 128, 17, 17]          115,072
│    └─BasicConv2d: 2-3                  [1, 128, 17, 17]          --
│    │    └─Sequential: 3-3              [1, 128, 17, 17]          115,072
│    └─BasicConv2d: 2-4                  [1, 128, 17, 17]          --
│    │    └─Sequential: 3-4              [1, 128, 17, 17]          115,072
│    └─BasicConv2d: 2-5                  [1, 192, 17, 17]          --
│    │    └─Sequential: 3-5              [1, 192, 17, 17]          172,608
├─Sequential: 1-2                        [1, 192, 17, 17]    

---
Inception-C 모듈
- Branch 1:
	- 1x1 합성곱: 448채널
	- 3x3 합성곱: 384채널
	- 두 개의 병렬 비대칭 합성곱 (3x3을 1x3과 3x1로 분해, 두 출력을 채널 방향으로 이어붙임):
		- 첫 번째 1x3 합성곱: 384채널
		- 두 번째 3x1 합성곱: 384채널
- Branch 2:
	- 1x1 합성곱: 384채널
	- 두 개의 병렬 비대칭 합성곱 (3x3을 1x3과 3x1로 분해, 두 출력을 채널 방향으로 이어붙임):
		- 첫 번째 1x3 합성곱: 384채널
		- 두 번째 3x1 합성곱: 384채널
- Branch 3:
	- 평균 풀링(Average Pooling)
	- 1x1 합성곱: 192채널
- Branch 4:
	- 1x1 합성곱: 320채널

In [6]:
class InceptionC(nn.Module):
    """Inception-C block with expanded filter banks (parallel 1x3 / 3x1 heads).

    Output channels: 2 * 384 (branch 1) + 2 * 384 (branch 2) + 192 (pool)
    + 320 (1x1) = 2048. The paddings keep H and W unchanged.

    Args:
        in_channel: channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        horizontal = {"kernel_size": (1, 3), "padding": (0, 1)}
        vertical = {"kernel_size": (3, 1), "padding": (1, 0)}

        # Branch 1: 1x1 -> 3x3 stem, then parallel 1x3 and 3x1 heads.
        self.branch_1_1 = nn.Sequential(
            BasicConv2d(in_channel, 448, kernel_size=1),
            BasicConv2d(448, 384, kernel_size=3, padding=1),
        )
        self.branch_1_2_1 = BasicConv2d(384, 384, **horizontal)
        self.branch_1_2_2 = BasicConv2d(384, 384, **vertical)

        # Branch 2: 1x1 stem, then parallel 1x3 and 3x1 heads.
        self.branch_2_1 = BasicConv2d(in_channel, 384, kernel_size=1)
        self.branch_2_2_1 = BasicConv2d(384, 384, **horizontal)
        self.branch_2_2_2 = BasicConv2d(384, 384, **vertical)

        # Pool projection.
        self.branch_3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channel, 192, kernel_size=1),
        )
        # Plain 1x1 projection.
        self.branch_4 = BasicConv2d(in_channel, 320, kernel_size=1)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the six outputs on dim 1."""
        stem_1 = self.branch_1_1(x)
        heads_1 = [self.branch_1_2_1(stem_1), self.branch_1_2_2(stem_1)]

        stem_2 = self.branch_2_1(x)
        heads_2 = [self.branch_2_2_1(stem_2), self.branch_2_2_2(stem_2)]

        tails = [self.branch_3(x), self.branch_4(x)]
        return torch.cat(heads_1 + heads_2 + tails, dim=1)

# Shape check on a 1280-channel 8x8 map; expected output [1, 2048, 8, 8].
model = InceptionC(in_channel=1280)
summary(model, input_size=(1, 1280, 8, 8))

Layer (type:depth-idx)                   Output Shape              Param #
InceptionC                               [1, 2048, 8, 8]           --
├─Sequential: 1-1                        [1, 384, 8, 8]            --
│    └─BasicConv2d: 2-1                  [1, 448, 8, 8]            --
│    │    └─Sequential: 3-1              [1, 448, 8, 8]            574,784
│    └─BasicConv2d: 2-2                  [1, 384, 8, 8]            --
│    │    └─Sequential: 3-2              [1, 384, 8, 8]            1,549,440
├─BasicConv2d: 1-2                       [1, 384, 8, 8]            --
│    └─Sequential: 2-3                   [1, 384, 8, 8]            --
│    │    └─Conv2d: 3-3                  [1, 384, 8, 8]            442,752
│    │    └─BatchNorm2d: 3-4             [1, 384, 8, 8]            768
│    │    └─ReLU: 3-5                    [1, 384, 8, 8]            --
├─BasicConv2d: 1-3                       [1, 384, 8, 8]            --
│    └─Sequential: 2-4                   [1, 384, 8, 8]            

---
Reduction-A 모듈
- Branch 1:
	- 1x1 합성곱: 64 채널
	- 3x3 합성곱: 96 채널
	- 3x3 합성곱 (stride=2): 96 채널
- Branch 2:
	- 1x1 합성곱: 64 채널
	- 3x3 합성곱 (stride=2): 384 채널
- Branch 3:
	- 최대 풀링 (Max Pooling, 3x3, stride=2): 입력 채널 수를 그대로 유지

In [7]:
class ReductionA(nn.Module):
    """Grid-reduction block: H and W shrink to (size - 3) // 2 + 1.

    Output channels: 96 (conv stack) + 384 (conv) + in_channel (the max-pool
    branch passes the input channels through).

    Args:
        in_channel: channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # 1x1 -> 3x3 (same size) -> 3x3 stride 2.
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channel, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        )
        # 1x1 -> 3x3 stride 2.
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channel, 64, kernel_size=1),
            BasicConv2d(64, 384, kernel_size=3, stride=2),
        )
        # Parameter-free downsampling.
        self.branch_3 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Downsample ``x`` through all three branches and concatenate on dim 1."""
        pieces = (self.branch_1(x), self.branch_2(x), self.branch_3(x))
        return torch.cat(pieces, dim=1)

# Shape check: 35x35x288 -> expected output [1, 768, 17, 17].
model = ReductionA(in_channel=288)
summary(model, input_size=(1, 288, 35, 35))

Layer (type:depth-idx)                   Output Shape              Param #
ReductionA                               [1, 768, 17, 17]          --
├─Sequential: 1-1                        [1, 96, 17, 17]           --
│    └─BasicConv2d: 2-1                  [1, 64, 35, 35]           --
│    │    └─Sequential: 3-1              [1, 64, 35, 35]           18,624
│    └─BasicConv2d: 2-2                  [1, 96, 35, 35]           --
│    │    └─Sequential: 3-2              [1, 96, 35, 35]           55,584
│    └─BasicConv2d: 2-3                  [1, 96, 17, 17]           --
│    │    └─Sequential: 3-3              [1, 96, 17, 17]           83,232
├─Sequential: 1-2                        [1, 384, 17, 17]          --
│    └─BasicConv2d: 2-4                  [1, 64, 35, 35]           --
│    │    └─Sequential: 3-4              [1, 64, 35, 35]           18,624
│    └─BasicConv2d: 2-5                  [1, 384, 17, 17]          --
│    │    └─Sequential: 3-5              [1, 384, 17, 17]          22

---
Reduction-B 모듈
- Branch 1:
	- 1x1 합성곱: 192 채널
	- 3x3 합성곱: 192 채널
	- 3x3 합성곱 (stride=2): 192 채널
- Branch 2:
	- 1x1 합성곱: 192 채널
	- 3x3 합성곱 (stride=2): 320 채널
- Branch 3:
	- 최대 풀링 (Max Pooling, 3x3, stride=2): 입력 채널 수를 그대로 유지

In [8]:
class ReductionB(nn.Module):
    """Grid-reduction block: H and W shrink to (size - 3) // 2 + 1.

    Output channels: 192 (conv stack) + 320 (conv) + in_channel (the max-pool
    branch passes the input channels through).

    Args:
        in_channel: channels of the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # 1x1 -> 3x3 (same size) -> 3x3 stride 2.
        self.branch_1 = nn.Sequential(
            BasicConv2d(in_channel, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=3, padding=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )
        # 1x1 -> 3x3 stride 2.
        self.branch_2 = nn.Sequential(
            BasicConv2d(in_channel, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2),
        )
        # Parameter-free downsampling.
        self.branch_3 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Downsample ``x`` through all three branches and concatenate on dim 1."""
        parts = []
        for branch in (self.branch_1, self.branch_2, self.branch_3):
            parts.append(branch(x))
        return torch.cat(parts, dim=1)

# Shape check: 17x17x768 -> expected output [1, 1280, 8, 8].
model = ReductionB(in_channel=768)
summary(model, input_size=(1, 768, 17, 17))

Layer (type:depth-idx)                   Output Shape              Param #
ReductionB                               [1, 1280, 8, 8]           --
├─Sequential: 1-1                        [1, 192, 8, 8]            --
│    └─BasicConv2d: 2-1                  [1, 192, 17, 17]          --
│    │    └─Sequential: 3-1              [1, 192, 17, 17]          148,032
│    └─BasicConv2d: 2-2                  [1, 192, 17, 17]          --
│    │    └─Sequential: 3-2              [1, 192, 17, 17]          332,352
│    └─BasicConv2d: 2-3                  [1, 192, 8, 8]            --
│    │    └─Sequential: 3-3              [1, 192, 8, 8]            332,352
├─Sequential: 1-2                        [1, 320, 8, 8]            --
│    └─BasicConv2d: 2-4                  [1, 192, 17, 17]          --
│    │    └─Sequential: 3-4              [1, 192, 17, 17]          148,032
│    └─BasicConv2d: 2-5                  [1, 320, 8, 8]            --
│    │    └─Sequential: 3-5              [1, 320, 8, 8]          

In [34]:
class InceptionAux(nn.Module):
    """Auxiliary classifier head.

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: number of output logits.

    The flattened size ``128*5*5`` assumes a 17x17 input. A 5x5 average pool
    with stride 3 then yields (17 - 5) // 3 + 1 = 5 per side. Other input sizes
    will not match the first Linear layer.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.avgpool1 = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv1 = BasicConv2d(in_channels, 128, kernel_size=1)
        # fc[0] / fc[1] keep their original indices so existing state_dict
        # keys (fc.0.*, fc.1.*) still load. The nonlinearity between them is
        # applied in forward() rather than inserted here, which would shift
        # the index of the second Linear.
        self.fc = nn.Sequential(
            nn.Linear(128*5*5, 1024),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        """Return auxiliary logits of shape (N, num_classes)."""
        x = self.avgpool1(x)
        x = self.conv1(x)
        x = torch.flatten(x, 1)
        # Without a nonlinearity the two Linear layers compose into a single
        # affine map, making the 1024-unit hidden layer pointless.
        x = torch.relu(self.fc[0](x))
        x = self.fc[1](x)
        return x

In [35]:
class InceptionV3(nn.Module):
    """Inception-v3 classifier assembled from BasicConv2d, Inception-A/B/C,
    Reduction-A/B and the auxiliary head.

    Args:
        num_classes: number of output logits for both the main and the
            auxiliary classifier.
        dropout: dropout probability applied to the pooled 2048-d features
            before the final Linear layer (default 0.3, the previously
            hard-coded value).

    ``forward`` returns ``(logits, aux_logits)``; ``aux_logits`` is ``None``
    whenever the module is not in training mode.
    """

    def __init__(self, num_classes=1000, dropout=0.3):
        super().__init__()

        # Stem, spatial sizes for a 299x299 input: 299 -> 149 -> 147 -> 147.
        self.conv_blk_1 = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3),
            BasicConv2d(32, 64, kernel_size=3, padding=1)
        )

        # 147 -> 73
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

        # 73 -> 71 -> 35 -> 35, ending at the 288 channels Inception-A expects.
        self.conv_blk_2 = nn.Sequential(
            BasicConv2d(64, 80, kernel_size=3),
            BasicConv2d(80, 192, kernel_size=3, stride=2),
            BasicConv2d(192, 288, kernel_size=3, padding=1)
        )

        # 35x35x288. NOTE(review): torchvision's Inception3 stacks three
        # Inception-A blocks at this stage; confirm two is intended.
        self.inception_a = nn.Sequential(
            InceptionA(288),
            InceptionA(288)
        )
        # 35x35x288 -> 17x17x768
        self.reduction_a = ReductionA(288)

        # 17x17x768; the factorized-7x7 width grows 128 -> 160 -> 160 -> 192.
        self.inception_b = nn.Sequential(
            InceptionB(768, 128),
            InceptionB(768, 160),
            InceptionB(768, 160),
            InceptionB(768, 192)
        )
        # Auxiliary head on the 17x17x768 map; only evaluated in training mode.
        self.inception_aux = InceptionAux(768, num_classes)
        # 17x17x768 -> 8x8x1280
        self.reduction_b = ReductionB(768)

        # 8x8x1280 -> 8x8x2048 -> 8x8x2048
        self.inception_c = nn.Sequential(
           InceptionC(1280),
           InceptionC(2048)
        )

        self.GlobalAvgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # 'relu' gives the same gain (sqrt(2)) as the previous implicit
                # default ('leaky_relu' with a=0); it just states the intent.
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                # Bias-free layers have .bias set to None; constant_ on None crashes.
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        """Return ``(logits, aux_logits)``; ``aux_logits`` is None unless training."""
        x = self.conv_blk_1(x)
        x = self.maxpool(x)
        x = self.conv_blk_2(x)

        x = self.inception_a(x)
        x = self.reduction_a(x)

        x = self.inception_b(x)
        aux = self.inception_aux(x) if self.training else None
        x = self.reduction_b(x)

        x = self.inception_c(x)

        x = self.GlobalAvgPool(x)
        x = self.dropout(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x, aux
    
# Build the full network and inspect per-layer output shapes for one 299x299 RGB image.
# NOTE(review): torchinfo's summary() moves the model to CUDA when it is available
# (its `device` argument auto-detects), so `model` may end up on the GPU after this cell.
model = InceptionV3(num_classes=1000)
summary(model, input_size=(1, 3, 299, 299))

Layer (type:depth-idx)                             Output Shape              Param #
InceptionV3                                        [1, 1000]                 4,401,512
├─Sequential: 1-1                                  [1, 64, 147, 147]         --
│    └─BasicConv2d: 2-1                            [1, 32, 149, 149]         --
│    │    └─Sequential: 3-1                        [1, 32, 149, 149]         960
│    └─BasicConv2d: 2-2                            [1, 32, 147, 147]         --
│    │    └─Sequential: 3-2                        [1, 32, 147, 147]         9,312
│    └─BasicConv2d: 2-3                            [1, 64, 147, 147]         --
│    │    └─Sequential: 3-3                        [1, 64, 147, 147]         18,624
├─MaxPool2d: 1-2                                   [1, 64, 73, 73]           --
├─Sequential: 1-3                                  [1, 288, 35, 35]          --
│    └─BasicConv2d: 2-4                            [1, 80, 71, 71]           --
│    │    └─Sequenti

In [36]:
# Training mode: the forward pass also returns the auxiliary logits.
# The input is created on the same device as the model's parameters instead of a
# hard-coded 'cuda'. The model is not explicitly moved to a device in this
# notebook, so hard-coding a device relies on hidden state and fails on CPU-only machines.
model.train()
device = next(model.parameters()).device
x = torch.randn(1, 3, 299, 299, device=device)
pred_y, aux = model(x)

print(pred_y.shape, aux.shape)

torch.Size([1, 1000]) torch.Size([1, 1000])


In [37]:
# Eval mode: the auxiliary head is skipped, so `aux` is None.
# Match the input to the model's device rather than hard-coding 'cuda', and
# disable autograd since this is pure inference.
model.eval()
device = next(model.parameters()).device
x = torch.randn(1, 3, 299, 299, device=device)
with torch.no_grad():
    pred_y, aux = model(x)

print(pred_y.shape, aux)

torch.Size([1, 1000]) None
