In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional, Tuple

In [3]:
class Inceptionv3(nn.Module):
    """Inception block: four parallel branches concatenated along the channel axis.

    Branches (each keeps the spatial size of its input, since every layer uses
    stride 1 with matching padding):

    * ``block1``: 1x1 conv
    * ``block2``: 1x1 conv (reduction) -> 3x3 conv
    * ``block3``: 1x1 conv (reduction) -> 3x3 conv -> 3x3 conv
      (the parameter names say "5x5", but the branch is built from two
      stacked 3x3 convolutions)
    * ``block4``: 3x3 max-pool (stride 1, padding 1) -> 1x1 conv

    Every convolution is followed by ``BatchNorm2d`` and an in-place ``ReLU``.

    Args:
        input_planes: number of channels of the incoming feature map.
        n_channels1x1: output channels of the 1x1 branch.
        n_channels3x3red: channels after the 1x1 reduction in the 3x3 branch.
        n_channels3x3: output channels of the 3x3 branch.
        n_channels5x5red: channels after the 1x1 reduction in the "5x5" branch.
        n_channels5x5: output channels of the "5x5" branch.
        pooling_planes: output channels of the pooling branch.

    The output has ``n_channels1x1 + n_channels3x3 + n_channels5x5 +
    pooling_planes`` channels.
    """

    def __init__(self, input_planes, n_channels1x1, n_channels3x3red, n_channels3x3, n_channels5x5red, n_channels5x5, pooling_planes):
        super().__init__()

        def conv_bn_relu(c_in, c_out, **conv_kwargs):
            # Conv -> BatchNorm -> in-place ReLU, returned as a flat list so the
            # layers can be spliced directly into an nn.Sequential.
            return [
                nn.Conv2d(c_in, c_out, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Branch 1: 1x1 convolution
        self.block1 = nn.Sequential(
            *conv_bn_relu(input_planes, n_channels1x1, kernel_size=1)
        )

        # Branch 2: 1x1 reduction -> 3x3 convolution
        # (the reduction width is a free design choice)
        self.block2 = nn.Sequential(
            *conv_bn_relu(input_planes, n_channels3x3red, kernel_size=1),
            *conv_bn_relu(n_channels3x3red, n_channels3x3, kernel_size=3, padding=1),
        )

        # Branch 3: 1x1 reduction -> two stacked 3x3 convolutions
        self.block3 = nn.Sequential(
            *conv_bn_relu(input_planes, n_channels5x5red, kernel_size=1),
            *conv_bn_relu(n_channels5x5red, n_channels5x5, kernel_size=3, padding=1),
            *conv_bn_relu(n_channels5x5, n_channels5x5, kernel_size=3, padding=1),
        )

        # Branch 4: 3x3 max-pooling -> 1x1 convolution
        self.block4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *conv_bn_relu(input_planes, pooling_planes, kernel_size=1),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate the results on dim 1."""
        branches = (self.block1, self.block2, self.block3, self.block4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [4]:
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (eps=0.001) and in-place ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The convolution is created without a bias term; BatchNorm2d follows it.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU and return the activated tensor."""
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)

In [5]:
class InceptionAux(nn.Module):
    """Auxiliary classifier head.

    Pipeline: adaptive average pool to 4x4 -> 1x1 ``BasicConv2d`` to 128
    channels -> flatten -> ``Linear(2048, 1024)`` -> ReLU -> dropout ->
    ``Linear(1024, num_classes)``. The 2048 input width of ``fc1`` equals
    128 channels * 4 * 4 spatial positions.

    Args:
        in_channels: channels of the feature map fed to this head.
        num_classes: size of the output logit vector.
        dropout: drop probability applied between the two linear layers.
    """

    def __init__(self, in_channels, num_classes, dropout):
        super().__init__()
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return class logits computed from the feature map ``x``."""
        pooled = F.adaptive_avg_pool2d(x, (4, 4))
        features = torch.flatten(self.conv(pooled), 1)
        hidden = F.relu(self.fc1(features), inplace=True)
        return self.fc2(self.dropout(hidden))

In [7]:
class GoogLeNet(nn.Module):
    """GoogLeNet-style image classifier assembled from ``Inceptionv3`` blocks.

    Layout: convolutional stem -> inception 3a, 3b -> max-pool -> inception
    4a..4e -> max-pool -> inception 5a, 5b -> global average pool -> dropout
    -> linear classifier. Two optional auxiliary heads branch off after 4a
    (512 channels) and after 4d (528 channels).

    Args:
        num_classes: number of output logits for the main and auxiliary heads.
        aux_logits: when True, build the two ``InceptionAux`` heads.
        dropout: drop probability before the final linear layer.
        dropout_aux: drop probability inside the auxiliary heads.
    """

    def __init__(self, num_classes=1000, aux_logits=True, dropout=0.2, dropout_aux=0.7):

        super(GoogLeNet, self).__init__()

        self.stem = nn.Sequential(
            BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3),
            # ceil_mode: use ceil (not floor) when computing the output shape
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            BasicConv2d(64, 64, kernel_size=1),
            BasicConv2d(64, 192, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
        )

        # Each block's input width is the sum of the previous block's four
        # branch outputs (e.g. 64 + 128 + 32 + 32 = 256 feeds im2).
        self.im1 = Inceptionv3(192, 64, 96, 128, 16, 32, 32)
        self.im2 = Inceptionv3(256, 128, 128, 192, 32, 96, 64)
        self.im3 = Inceptionv3(480, 192, 96, 208, 16, 48, 64)
        self.im4 = Inceptionv3(512, 160, 112, 224, 24, 64, 64)
        self.im5 = Inceptionv3(512, 128, 128, 256, 24, 64, 64)
        self.im6 = Inceptionv3(512, 112, 144, 288, 32, 64, 64)
        self.im7 = Inceptionv3(528, 256, 160, 320, 32, 128, 128)
        self.im8 = Inceptionv3(832, 256, 160, 320, 32, 128, 128)
        self.im9 = Inceptionv3(832, 384, 192, 384, 48, 128, 128)

        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(1024, num_classes)

        if aux_logits:
            self.aux1 = InceptionAux(512, num_classes, dropout=dropout_aux)
            self.aux2 = InceptionAux(528, num_classes, dropout=dropout_aux)
        else:
            self.aux1 = None
            self.aux2 = None

    def forward(self, x):
        """Compute main and auxiliary logits.

        Args:
            x: input batch with 3 channels (the stem's first conv expects 3).

        Returns:
            Tuple ``(logits, aux2, aux1)``. ``aux1`` / ``aux2`` are ``None``
            unless the auxiliary heads exist and the module is in training mode.
        """
        op = self.stem(x)
        op = self.im1(op)  # 3a
        op = self.im2(op)  # 3b
        # Apply the pooling to the tensor (functional form, no parameters).
        op = F.max_pool2d(op, kernel_size=3, stride=2, ceil_mode=True)

        op = self.im3(op)  # 4a

        # Auxiliary head 1 reads the 512-channel feature map produced by 4a.
        aux1: Optional[torch.Tensor] = None
        if self.aux1 is not None and self.training:
            aux1 = self.aux1(op)

        op = self.im4(op)  # 4b
        op = self.im5(op)  # 4c
        op = self.im6(op)  # 4d

        # Auxiliary head 2 reads the 528-channel feature map produced by 4d.
        aux2: Optional[torch.Tensor] = None
        if self.aux2 is not None and self.training:
            aux2 = self.aux2(op)

        op = self.im7(op)  # 4e
        # NOTE(review): kernel size 2 is kept as originally written (the author
        # questioned whether it should be 3) -- confirm against the reference
        # implementation being targeted.
        op = F.max_pool2d(op, kernel_size=2, stride=2, ceil_mode=True)

        op = self.im8(op)  # 5a
        op = self.im9(op)  # 5b

        op = F.adaptive_avg_pool2d(op, (1, 1))
        op = torch.flatten(op, 1)

        op = self.dropout(op)
        op = self.fc(op)

        # Return the computed logits, not the raw input.
        return op, aux2, aux1
