# **GoogLeNet**
此份程式碼會介紹如何使用 PyTorch 建構 GoogLeNet 的模型架構，以及訓練的方式。

<img src="https://hackmd.io/_uploads/rkWu7ywIp.png" height=800/>
- [source paper](https://arxiv.org/abs/1409.4842)

## 匯入套件

In [None]:
# PyTorch 相關套件
import torch
import torch.nn as nn

## GoogLeNet Architecture

<img src="https://hackmd.io/_uploads/HJT6mkwI6.png" width=1000/>

- [source paper](https://arxiv.org/abs/1409.4842)

In [None]:
class BasicConv2d(nn.Module):
    """2-D convolution followed by an in-place ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``, ``bias``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # No normalization layer follows this convolution, so keep
        # nn.Conv2d's default ``bias=True``: it is the only additive offset
        # before the ReLU. (Disabling the bias is the usual choice only when a
        # BatchNorm layer follows.) Callers may still pass ``bias=False``
        # explicitly through **kwargs.
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        # In-place ReLU is safe here: the conv output is a fresh tensor that
        # is only referenced by the local ``x``.
        return self.act(x)


class InceptionBlock(nn.Module):
    """Inception module: four parallel branches concatenated on the channel axis.

    Branches, in concatenation order:
      * ``path1`` - 1x1 convolution.
      * ``path2`` - 1x1 channel reduction, then a 3x3 convolution.
      * ``path3`` - 1x1 channel reduction, then a 5x5 convolution.
      * ``path4`` - 3x3 max pooling (stride 1, padding 1), then a 1x1 projection.

    Every branch keeps the spatial size, so the output carries
    ``filters_1x1 + filters_3x3 + filters_5x5 + filters_pooling`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        filters_1x1: output channels of the 1x1 branch.
        filters_3x3_reduce: channels of the 1x1 reduction before the 3x3 conv.
        filters_3x3: output channels of the 3x3 branch.
        filters_5x5_reduce: channels of the 1x1 reduction before the 5x5 conv.
        filters_5x5: output channels of the 5x5 branch.
        filters_pooling: output channels of the 1x1 projection after pooling.
    """

    def __init__(self, in_channels, filters_1x1, filters_3x3_reduce,
                 filters_3x3, filters_5x5_reduce, filters_5x5,
                 filters_pooling):
        super().__init__()

        def same_conv(c_in, c_out, k):
            # 'same' padding keeps H and W unchanged so the branch outputs
            # can be concatenated along the channel dimension.
            return BasicConv2d(c_in, c_out, kernel_size=k, padding='same')

        self.path1 = same_conv(in_channels, filters_1x1, 1)
        self.path2 = nn.Sequential(
            same_conv(in_channels, filters_3x3_reduce, 1),
            same_conv(filters_3x3_reduce, filters_3x3, 3),
        )
        self.path3 = nn.Sequential(
            same_conv(in_channels, filters_5x5_reduce, 1),
            same_conv(filters_5x5_reduce, filters_5x5, 5),
        )
        self.path4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            same_conv(in_channels, filters_pooling, 1),
        )

    def forward(self, x):
        branches = (self.path1, self.path2, self.path3, self.path4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [None]:
class AuxiliaryClassifier(nn.Module):
    """Auxiliary classification head attached to an intermediate feature map.

    Pipeline: adaptive average pooling to a 4x4 grid -> 1x1 conv to 128
    channels -> flatten (128 * 4 * 4 = 2048 features) -> FC(1024) + ReLU ->
    dropout -> FC(num_classes), producing ``(batch, num_classes)`` logits.

    Args:
        in_channels: channels of the intermediate feature map.
        num_classes: number of output logits.
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(self, in_channels, num_classes, dropout=0.7):
        super().__init__()
        pooled_hw = (4, 4)
        reduced_channels = 128
        hidden_units = 1024
        # 128 channels on a 4x4 grid -> 2048 flattened features.
        flat_features = reduced_channels * pooled_hw[0] * pooled_hw[1]

        self.pool = nn.AdaptiveAvgPool2d(pooled_hw)
        self.conv = BasicConv2d(in_channels,
                                reduced_channels,
                                kernel_size=1,
                                padding='same')
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(in_features=flat_features, out_features=hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_units, out_features=num_classes),
        )

    def forward(self, x):
        for stage in (self.pool, self.conv, self.flatten, self.fc):
            x = stage(x)
        return x

In [None]:
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) with two auxiliary classifiers.

    Takes 3-channel image batches. ``forward`` returns the tuple
    ``(logits, aux1_logits, aux2_logits)``, each of shape
    ``(batch, num_classes)``. Note that the auxiliary heads are evaluated on
    every forward call, in both ``train()`` and ``eval()`` mode.

    Args:
        num_classes: number of classes for the main and auxiliary heads.
    """

    def __init__(self, num_classes):
        super().__init__()

        def downsample(kernel_size=3):
            # Stride-2 max pool; ceil_mode=True rounds the output size up.
            return nn.MaxPool2d(kernel_size, stride=2, ceil_mode=True)

        # Stem: 7x7/2 conv -> pool -> 1x1 conv -> 3x3 conv -> pool.
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = downsample()
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = downsample()

        # InceptionBlock args:
        #   (in, 1x1, 3x3 reduce, 3x3, 5x5 reduce, 5x5, pool proj)
        # Trailing comments give the concatenated output channels.
        self.inception3a = InceptionBlock(192, 64, 96, 128, 16, 32, 32)  # 256
        self.inception3b = InceptionBlock(256, 128, 128, 192, 32, 96, 64)  # 480
        self.maxpool3 = downsample()

        self.inception4a = InceptionBlock(480, 192, 96, 208, 16, 48, 64)  # 512
        self.inception4b = InceptionBlock(512, 160, 112, 224, 24, 64, 64)  # 512
        self.inception4c = InceptionBlock(512, 128, 128, 256, 24, 64, 64)  # 512
        self.inception4d = InceptionBlock(512, 112, 144, 288, 32, 64, 64)  # 528
        self.inception4e = InceptionBlock(528, 256, 160, 320, 32, 128, 128)  # 832
        self.maxpool4 = downsample(2)

        self.inception5a = InceptionBlock(832, 256, 160, 320, 32, 128, 128)  # 832
        self.inception5b = InceptionBlock(832, 384, 192, 384, 48, 128, 128)  # 1024

        # Auxiliary heads read the outputs of inception4a (512 ch) and
        # inception4d (528 ch).
        self.aux1 = AuxiliaryClassifier(512, num_classes)
        self.aux2 = AuxiliaryClassifier(528, num_classes)

        # Main head: global average pool -> flatten -> dropout -> linear.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        stem = (self.conv1, self.maxpool1, self.conv2, self.conv3,
                self.maxpool2)
        for layer in stem:
            x = layer(x)

        for layer in (self.inception3a, self.inception3b, self.maxpool3,
                      self.inception4a):
            x = layer(x)
        aux1 = self.aux1(x)  # taps the inception4a output

        for layer in (self.inception4b, self.inception4c, self.inception4d):
            x = layer(x)
        aux2 = self.aux2(x)  # taps the inception4d output

        for layer in (self.inception4e, self.maxpool4, self.inception5a,
                      self.inception5b):
            x = layer(x)

        return self.classifier(x), aux1, aux2

In [None]:
# Smoke test: push one random 224x224 RGB image through the network and
# print the shapes of the main output and the two auxiliary outputs.
model = GoogLeNet(10)

inputs = torch.randn(1, 3, 224, 224)
outputs = model(inputs)
print(*(out.shape for out in outputs))

In [None]:
# Fake dataset: 10 random 224x224 RGB images with integer labels in [0, 10).
train_images = torch.randn(10, 3, 224, 224)
train_labels = torch.randint(0, 10, (10, ))
train_dataset = torch.utils.data.TensorDataset(train_images, train_labels)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=2, shuffle=True)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Multiple outputs, multiple losses: the main head has weight 1.0 and each
# auxiliary head is down-weighted to 0.3; the weighted losses are summed.
loss_weights = (1.0, 0.3, 0.3)

for x, y in train_loader:
    optimizer.zero_grad()
    outputs = model(x)
    loss = sum(weight * loss_fn(logits, y)
               for weight, logits in zip(loss_weights, outputs))
    loss.backward()
    optimizer.step()