In [None]:
import tensorflow as tf
from functools import partial

In [None]:
# Call Keras' clear_session() up front, before any of the models below are created.
tf.keras.backend.clear_session()

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import torchvision
# Global switch: ask PyTorch to use deterministic algorithms (set here, before any model is built).
torch.use_deterministic_algorithms(True)
from tqdm import tqdm
import torch.optim as optim

import os
import time
import copy
from torchvision import datasets, models, transforms

# TensorFlow Implementation

In [None]:
# Shared Conv2D factory used by the TensorFlow blocks below.
# Defaults: 3x3 kernel, stride 1, "same" padding, ReLU activation, no bias term.
# Call sites override individual keyword arguments (e.g. filters, kernel_size).
ConvBlock = partial(
    tf.keras.layers.Conv2D,
    use_bias=False,
    activation="relu",
    padding="same",
    strides=1,
    kernel_size=3,
)

class InceptionBlock(tf.keras.layers.Layer):
    """Inception module: four parallel branches concatenated into one tensor.

    Every convolution is created through ``ConvBlock`` and therefore uses
    "same" padding, a ReLU activation and no bias. The branches are:

      1. 1x1 convolution (``out_1x1`` filters)
      2. 1x1 reduction (``red_3x3``) followed by a 3x3 convolution (``out_3x3``)
      3. 1x1 reduction (``red_5x5``) followed by a 5x5 convolution (``out_5x5``)
      4. 3x3 max-pool with stride 1 followed by a 1x1 convolution (``out_1x1pool``)

    Args:
        in_channels: Accepted but not used by this layer.
        out_1x1: Filters of the 1x1 branch.
        red_3x3: Filters of the 1x1 reduction in front of the 3x3 convolution.
        out_3x3: Filters of the 3x3 convolution.
        red_5x5: Filters of the 1x1 reduction in front of the 5x5 convolution.
        out_5x5: Filters of the 5x5 convolution.
        out_1x1pool: Filters of the 1x1 convolution after the max-pool.
        strides: Accepted but not used by this layer.
        activation: Resolved with ``tf.keras.activations.get`` and stored on
            ``self.activation``; the branch convolutions themselves use the
            activation configured in ``ConvBlock``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = tf.keras.activations.get(activation)
        # Instantiate the merge layer once here instead of building a fresh
        # Concatenate layer on every forward pass inside call().
        self.concat = tf.keras.layers.Concatenate()

        self.branch1_layers = [
            ConvBlock(filters=out_1x1, kernel_size=1),
        ]

        self.branch2_layers = [
            ConvBlock(filters=red_3x3, kernel_size=1),
            ConvBlock(filters=out_3x3, kernel_size=(3, 3), padding="same"),
        ]

        self.branch3_layers = [
            ConvBlock(filters=red_5x5, kernel_size=1),
            ConvBlock(filters=out_5x5, kernel_size=5, padding="same"),
        ]

        self.branch4_layers = [
            tf.keras.layers.MaxPool2D(pool_size=3, strides=1, padding="same"),
            ConvBlock(filters=out_1x1pool, kernel_size=1, padding="same"),
        ]

        # Branch order here fixes the order of the concatenated outputs.
        self.main_layers = [self.branch1_layers, self.branch2_layers, self.branch3_layers, self.branch4_layers]

    def call(self, inputs):
        """Run every branch on ``inputs`` and concatenate the branch outputs."""
        outputs = []
        for branch in self.main_layers:
            Z = inputs
            for layer in branch:
                Z = layer(Z)
            outputs.append(Z)

        return self.concat(outputs)

In [None]:
class Inception_Aux(tf.keras.layers.Layer):
    """Auxiliary classifier head that returns raw class logits.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 ``ConvBlock`` with 128 filters
    -> flatten -> Dense(1024) -> activation -> Dropout(0.7)
    -> BatchNormalization -> Dense(num_classes).

    Args:
        filters: Accepted but not used by this layer.
        num_classes: Width of the final Dense layer (number of logits).
        strides: Accepted but not used by this layer.
        activation: Activation applied after the 1024-unit Dense layer,
            resolved with ``tf.keras.activations.get``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, num_classes, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = tf.keras.activations.get(activation)

        # NOTE(review): the BatchNormalization after Dropout has no counterpart
        # in the PyTorch InceptionAux of this notebook; kept as-is here.
        self.main_layers = [
            tf.keras.layers.AveragePooling2D(pool_size=5, strides=3),
            ConvBlock(filters=128, kernel_size=1),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1024),
            self.activation,
            tf.keras.layers.Dropout(rate=0.7),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dense(num_classes),
        ]

    def call(self, inputs):
        """Apply the head to ``inputs`` and return the unactivated logits."""
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)

        # Return the logits untouched: applying the hidden activation (ReLU by
        # default) here would clamp every negative class logit to zero.
        return Z

In [None]:
def createGoogleNet(training, aux_logits=True, num_classes=1000):
  """Build the GoogLeNet functional Keras model for 224x224x3 inputs.

  Args:
    training: When truthy together with ``aux_logits``, the two auxiliary
      classifier heads are built and exposed as additional model outputs.
    aux_logits: Whether auxiliary heads may be attached at all.
    num_classes: Number of logits produced by the main head and by each
      auxiliary head.

  Returns:
    A ``tf.keras.Model`` whose outputs are ``[fc1, aux1, aux2]`` when the
    auxiliary heads are enabled and ``[fc1]`` otherwise. ``fc1`` has shape
    ``(batch, num_classes)``.
  """
  use_aux = bool(aux_logits and training)

  input_ = tf.keras.layers.Input(shape=(224, 224, 3))

  # Stem. The convolutions carry an explicit ReLU so the stem is not a purely
  # linear conv/pool stack (the PyTorch stem in this notebook also applies ReLU).
  conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="same", activation="relu")(input_)
  maxpool1 = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(conv1)
  conv2 = tf.keras.layers.Conv2D(filters=192, kernel_size=3, strides=1, padding="same", activation="relu")(maxpool1)
  maxpool2 = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(conv2)

  # Argument order: in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
  inception3a = InceptionBlock(192, 64, 96, 128, 16, 32, 32)(maxpool2)
  inception3b = InceptionBlock(256, 128, 128, 192, 32, 96, 64)(inception3a)
  maxpool3 = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(inception3b)

  inception4a = InceptionBlock(480, 192, 96, 208, 16, 48, 64)(maxpool3)

  # First auxiliary head branches off inception4a.
  if use_aux:
    aux1 = Inception_Aux(512, num_classes)(inception4a)

  inception4b = InceptionBlock(512, 160, 112, 224, 24, 64, 64)(inception4a)
  inception4c = InceptionBlock(512, 128, 128, 256, 24, 64, 64)(inception4b)
  inception4d = InceptionBlock(512, 112, 144, 288, 32, 64, 64)(inception4c)

  # Second auxiliary head branches off inception4d.
  if use_aux:
    aux2 = Inception_Aux(528, num_classes)(inception4d)

  inception4e = InceptionBlock(528, 256, 160, 320, 32, 128, 128)(inception4d)
  maxpool4 = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(inception4e)

  inception5a = InceptionBlock(832, 256, 160, 320, 32, 128, 128)(maxpool4)
  inception5b = InceptionBlock(832, 384, 192, 384, 48, 128, 128)(inception5a)

  avgpool = tf.keras.layers.AveragePooling2D(pool_size=7, strides=1)(inception5b)
  # Drop the 1x1 spatial axes left by the pooling; without this the Dense layer
  # below yields (batch, 1, 1, num_classes) instead of (batch, num_classes),
  # which would not match the shape produced by the auxiliary heads.
  flat = tf.keras.layers.Flatten()(avgpool)
  dropout = tf.keras.layers.Dropout(rate=0.4)(flat)
  fc1 = tf.keras.layers.Dense(num_classes)(dropout)

  outputs = [fc1, aux1, aux2] if use_aux else [fc1]
  return tf.keras.Model(inputs=input_, outputs=outputs)


In [None]:
# Build the variant without auxiliary heads (training=False) and print its layer summary.
googleNet_Tf = createGoogleNet(training=False)
googleNet_Tf.summary()

In [None]:
# Feed a single random 224x224x3 sample through the model and display the output shape.
x = tf.random.uniform(shape=[1, 224, 224, 3])
googleNet_Tf(x).shape

TensorShape([1, 1, 1, 1000])

# PyTorch Implementation
#### Note: batch-norm layers are not part of the main architecture, but `conv_block` below applies `nn.BatchNorm2d` after every convolution

In [None]:
class conv_block(nn.Module):
    """Convolution followed by batch normalization and ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution; also the
            feature count of the batch-norm layer.
        **kwargs: Passed straight to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Submodule names and creation order are kept stable on purpose.
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        """Return ``relu(batchnorm(conv(x)))``."""
        convolved = self.conv(x)
        normalized = self.batchnorm(convolved)
        return self.relu(normalized)



In [None]:
class Inception_block(nn.Module):
    """Inception module with four parallel branches joined on dimension 1.

    Branches:
      * ``branch1``: 1x1 conv (``out_1x1`` channels)
      * ``branch2``: 1x1 conv (``red_3x3``) -> 3x3 conv, padding 1 (``out_3x3``)
      * ``branch3``: 1x1 conv (``red_5x5``) -> 5x5 conv, padding 2 (``out_5x5``)
      * ``branch4``: 3x3 max-pool, stride 1, padding 1 -> 1x1 conv (``out_1x1pool``)
    """

    def __init__(
        self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
    ):
        super().__init__()

        # Plain 1x1 branch.
        self.branch1 = conv_block(in_channels, out_1x1, kernel_size=1)

        # 1x1 reduction, then 3x3.
        reduce_3x3 = conv_block(in_channels, red_3x3, kernel_size=1)
        expand_3x3 = conv_block(red_3x3, out_3x3, kernel_size=(3, 3), padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, expand_3x3)

        # 1x1 reduction, then 5x5.
        reduce_5x5 = conv_block(in_channels, red_5x5, kernel_size=1)
        expand_5x5 = conv_block(red_5x5, out_5x5, kernel_size=5, padding=2)
        self.branch3 = nn.Sequential(reduce_5x5, expand_5x5)

        # Size-preserving max-pool, then 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        pool_proj = conv_block(in_channels, out_1x1pool, kernel_size=1)
        self.branch4 = nn.Sequential(pool, pool_proj)

    def forward(self, x):
        """Apply all four branches to ``x`` and concatenate along dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [None]:
class InceptionAux(nn.Module):
    """Auxiliary classifier head producing ``num_classes`` logits.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 ``conv_block`` to 128 channels
    -> flatten -> Linear(2048, 1024) -> ReLU -> Dropout(0.7)
    -> Linear(1024, num_classes).

    The first linear layer is fixed at 2048 input features, so the flattened
    conv output must have exactly that many elements per sample.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.7)
        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = conv_block(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)

    def forward(self, x):
        """Return the auxiliary logits for feature map ``x``."""
        reduced = self.conv(self.pool(x))
        # Collapse everything except the batch dimension.
        flat = reduced.reshape(reduced.shape[0], -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)

In [None]:
class GoogLeNet(nn.Module):
    """GoogLeNet assembled from ``conv_block`` and ``Inception_block`` modules.

    Args:
        aux_logits: Whether to create the two ``InceptionAux`` heads. When
            false, ``aux1`` and ``aux2`` are set to ``None``.
        num_classes: Size of the final linear layer and of each auxiliary head.

    ``forward`` returns ``(aux1, aux2, logits)`` when the auxiliary heads are
    enabled and the module is in training mode; otherwise it returns only the
    main logits.
    """

    def __init__(self, aux_logits=True, num_classes=1000):
        super().__init__()
        assert aux_logits in (True, False)
        self.aux_logits = aux_logits

        # --- Stem -----------------------------------------------------------
        self.conv1 = conv_block(
            in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3
        )
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = conv_block(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # --- Inception stages -----------------------------------------------
        # Argument order: in_channels, out_1x1, red_3x3, out_3x3, red_5x5,
        # out_5x5, out_1x1pool
        self.inception3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception_block(832, 384, 192, 384, 48, 128, 128)

        # --- Main classifier ------------------------------------------------
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(1024, num_classes)

        # --- Auxiliary classifiers (optional) -------------------------------
        if not self.aux_logits:
            self.aux1 = self.aux2 = None
        else:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

    def forward(self, x):
        """Run the network; see the class docstring for the return structure."""
        # The auxiliary heads only run when they exist and we are training.
        use_aux = self.aux_logits and self.training

        stem_to_4a = (
            self.conv1, self.maxpool1, self.conv2, self.maxpool2,
            self.inception3a, self.inception3b, self.maxpool3,
            self.inception4a,
        )
        for stage in stem_to_4a:
            x = stage(x)

        # Auxiliary classifier 1 taps the output of inception4a.
        if use_aux:
            aux1 = self.aux1(x)

        for stage in (self.inception4b, self.inception4c, self.inception4d):
            x = stage(x)

        # Auxiliary classifier 2 taps the output of inception4d.
        if use_aux:
            aux2 = self.aux2(x)

        tail = (
            self.inception4e, self.maxpool4,
            self.inception5a, self.inception5b,
            self.avgpool,
        )
        for stage in tail:
            x = stage(x)

        # Flatten to (batch, features) before the classifier.
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(self.dropout(x))

        if use_aux:
            return aux1, aux2, x
        return x

In [None]:
BATCH_SIZE = 5
x = torch.randn(BATCH_SIZE, 3, 224, 224)
model = GoogLeNet(aux_logits=True, num_classes=1000)
# Run the forward pass once and reuse the result for both the print and the
# check, rather than pushing the batch through the whole network twice.
# With aux_logits=True and the model in training mode, forward returns
# (aux1, aux2, main_logits), so index 2 is the main classifier output.
main_logits = model(x)[2]
print(main_logits.shape)
assert main_logits.shape == torch.Size([BATCH_SIZE, 1000])

torch.Size([5, 1000])
