<a href="https://colab.research.google.com/github/ajw1587/Pytorch_Study/blob/main/23_GoogLeNet_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
!pip install torch
!pip install torchvision



In [34]:
import torch
import torch.nn as nn

In [35]:
class conv_block(nn.Module):
  """Convolution unit: Conv2d, then BatchNorm2d, then ReLU.

  Any extra keyword arguments (kernel_size, stride, padding, ...) are passed
  through unchanged to ``nn.Conv2d``.
  """

  def __init__(self, in_channels, out_channels, **kwargs):
    super().__init__()
    self.relu = nn.ReLU()
    self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    self.batchnorm = nn.BatchNorm2d(out_channels)

  def forward(self, x):
    out = self.conv(x)
    out = self.batchnorm(out)
    return self.relu(out)

In [36]:
class Inception_block(nn.Module):
  """Four parallel branches whose outputs are concatenated along dim 1.

  Branches:
    1. 1x1 conv                          -> ``out_1x1`` channels
    2. 1x1 reduction -> 3x3 conv (pad 1) -> ``out_3x3`` channels
    3. 1x1 reduction -> 5x5 conv (pad 2) -> ``out_5x5`` channels
    4. 3x3 max-pool (stride 1, pad 1) -> 1x1 conv -> ``out_1x1pool`` channels

  Every branch uses stride 1 and "same" padding, so the concatenated output
  keeps the input's spatial size and has
  ``out_1x1 + out_3x3 + out_5x5 + out_1x1pool`` channels.
  """

  def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool):
    super().__init__()

    self.branch1 = conv_block(in_channels, out_1x1, kernel_size = 1)

    reduce_3x3 = conv_block(in_channels, red_3x3, kernel_size = 1)
    conv_3x3 = conv_block(red_3x3, out_3x3, kernel_size = 3, stride = 1, padding = 1)
    self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

    reduce_5x5 = conv_block(in_channels, red_5x5, kernel_size = 1)
    conv_5x5 = conv_block(red_5x5, out_5x5, kernel_size = 5, stride = 1, padding = 2)
    self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

    pool = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)
    pool_proj = conv_block(in_channels, out_1x1pool, kernel_size = 1)
    self.branch4 = nn.Sequential(pool, pool_proj)

  def forward(self, x):
    branches = (self.branch1, self.branch2, self.branch3, self.branch4)
    return torch.cat([branch(x) for branch in branches], dim = 1)

In [37]:
class InceptionAux(nn.Module):
  """Auxiliary classifier head applied to an intermediate feature map.

  Pipeline: 5x5 average pool (stride 3), 1x1 conv to 128 channels,
  flatten, Linear(2048, 1024), ReLU, Dropout (default p), and
  Linear(1024, num_classes).

  The first Linear expects 2048 = 128 * 4 * 4 features. That requires a 4x4
  map after the pooling step, which a 14x14 input produces. Other input sizes
  will fail at that layer.
  """

  def __init__(self, in_channels, num_classes):
    super().__init__()

    self.conv = nn.Sequential(
        nn.AvgPool2d(kernel_size = 5, stride = 3),
        conv_block(in_channels, 128, kernel_size = 1),
    )

    hidden = nn.Linear(2048, 1024)
    head = nn.Linear(1024, num_classes)
    self.fc = nn.Sequential(hidden, nn.ReLU(), nn.Dropout(), head)

  def forward(self, x):
    features = self.conv(x)
    flat = features.view(features.shape[0], -1)
    return self.fc(flat)

In [38]:
class GoogLeNet(nn.Module):
  """GoogLeNet (Inception v1) image classifier.

  Args:
    in_channels: number of channels of the input images.
    num_classes: number of outputs of the final classifier (and of the
      auxiliary heads).
    aux_logits: if True, build two auxiliary classifiers. During training
      they tap the outputs of inception4a and inception4d.
    init_weights: if True, re-initialise Conv2d, BatchNorm2d and Linear
      parameters with ``_initialize_weights``.

  Raises:
    ValueError: if ``aux_logits`` is not True/False.
  """

  def __init__(self, in_channels = 3, num_classes = 1000, aux_logits = True, init_weights = True):
    super(GoogLeNet, self).__init__()
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if aux_logits not in (True, False):
      raise ValueError(f"aux_logits must be True or False, got {aux_logits!r}")
    self.aux_logits = aux_logits

    # Stem.
    self.conv1 = conv_block(in_channels = in_channels, out_channels = 64, kernel_size = (7, 7),
                            stride = (2, 2), padding = (3, 3))
    self.maxpool1 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
    self.conv2 = conv_block(64, 192, kernel_size = 3, stride = 1, padding = 1)
    self.maxpool2 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

    # In this order: in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
    self.inception3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
    self.inception3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
    self.maxpool3 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

    # aux1 taps the output of inception4a (512 channels).
    self.inception4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
    self.inception4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
    self.inception4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
    # aux2 taps the output of inception4d (528 channels).
    self.inception4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
    self.inception4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
    self.maxpool4 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
    self.inception5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
    self.inception5b = Inception_block(832, 384, 192, 384, 48, 128, 128)

    # Global average pool to 1x1. This is identical to AvgPool2d(7) on the
    # 7x7 map a 224x224 input produces, and also works for other input sizes.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.dropout = nn.Dropout(p = 0.4)
    # Honour num_classes (previously hard-coded to 1000).
    self.fc1 = nn.Linear(1024, num_classes)

    if self.aux_logits:
      self.aux1 = InceptionAux(512, num_classes)
      self.aux2 = InceptionAux(528, num_classes)
    else:
      self.aux1 = self.aux2 = None

    if init_weights:
      self._initialize_weights()

  def forward(self, x, *, return_aux = False):
    """Run the network.

    Args:
      x: input batch, presumably (N, in_channels, H, W). The auxiliary
        heads' fixed Linear(2048, ...) only fits inputs that give a 14x14
        map at inception4a (e.g. 224x224), so other sizes need eval mode
        or ``aux_logits=False``.
      return_aux: keyword-only. If True, return ``(logits, aux1, aux2)``.
        The aux entries are None unless the model is in training mode and
        was built with ``aux_logits=True``.

    Returns:
      The logits of shape (N, num_classes), or the tuple described above
      when ``return_aux`` is True.
    """
    aux1 = aux2 = None

    x = self.conv1(x)
    x = self.maxpool1(x)
    x = self.conv2(x)
    x = self.maxpool2(x)

    x = self.inception3a(x)
    x = self.inception3b(x)
    x = self.maxpool3(x)

    x = self.inception4a(x)
    if self.aux_logits and self.training:
      aux1 = self.aux1(x)

    x = self.inception4b(x)
    x = self.inception4c(x)
    x = self.inception4d(x)
    if self.aux_logits and self.training:
      aux2 = self.aux2(x)

    x = self.inception4e(x)
    x = self.maxpool4(x)

    x = self.inception5a(x)
    x = self.inception5b(x)
    x = self.avgpool(x)
    x = x.reshape(x.shape[0], -1)
    x = self.dropout(x)
    x = self.fc1(x)

    # Default keeps the original single-tensor return; aux outputs are opt-in.
    if return_aux:
      return x, aux1, aux2
    return x

  def _initialize_weights(self):
    """Re-initialise parameters of every submodule.

    Initialisation by layer type:
      - Conv2d: Kaiming-normal weights (fan_out, relu) and zero bias.
      - BatchNorm2d: weight 1, bias 0.
      - Linear: N(0, 0.01) weights and zero bias.
    """
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
          nn.init.constant_(m.bias, 0)
      # These branches were previously nested inside the Conv2d branch and
      # therefore never ran.
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
      elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
          nn.init.constant_(m.bias, 0)

In [39]:
if __name__ == '__main__':
  # Smoke test: push a batch of three random 3x224x224 images through a
  # freshly built model and show the output shape.
  dummy_batch = torch.randn(3, 3, 224, 224)
  net = GoogLeNet()
  logits = net(dummy_batch)
  print(logits.shape)

torch.Size([3, 1000])
