<a href="https://colab.research.google.com/github/Muhammad-ali-aren/GoogLeNet-Inception-v1-/blob/main/GoogleLeNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
from torchsummary import summary

In [3]:
class Inception(nn.Module):
    """Inception module: four parallel branches concatenated along channels.

    Every branch preserves the spatial size (H x W) of its input (1x1 convs,
    3x3 conv with padding 1, 5x5 conv with padding 2, 3x3/stride-1 max-pool
    with padding 1), so the output has shape
    (N, c1 + c3 + c5 + pool_proj, H, W). Every branch ends in a ReLU.

    Args:
        in_channels: channels of the input feature map.
        c1: output channels of the 1x1 branch.
        c3_reduce: 1x1 bottleneck channels before the 3x3 convolution.
        c3: output channels of the 3x3 branch.
        c5_reduce: 1x1 bottleneck channels before the 5x5 convolution.
        c5: output channels of the 5x5 branch.
        pool_proj: output channels of the 1x1 projection after max-pooling.
    """

    def __init__(self, in_channels, c1, c3_reduce, c3, c5_reduce, c5, pool_proj):
        super().__init__()
        # Branch 1: 1x1 conv -> c1 x H x W. Kept as a bare Conv2d (its ReLU is
        # applied in forward) so state_dict keys remain "branch1.weight/bias".
        self.branch1 = nn.Conv2d(in_channels, c1, kernel_size=1)

        # Branch 2: 1x1 reduce -> 3x3 conv -> c3 x H x W.
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, c3_reduce, kernel_size=1),  # c3_reduce x H x W
            nn.ReLU(inplace=True),
            nn.Conv2d(c3_reduce, c3, kernel_size=3, padding=1),  # c3 x H x W
            nn.ReLU(inplace=True),
        )
        # Branch 3: 1x1 reduce -> 5x5 conv -> c5 x H x W.
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, c5_reduce, kernel_size=1),  # c5_reduce x H x W
            nn.ReLU(inplace=True),
            nn.Conv2d(c5_reduce, c5, kernel_size=5, padding=2),  # c5 x H x W
            nn.ReLU(inplace=True),
        )

        # Branch 4: 3x3 max-pool (stride 1) -> 1x1 projection -> pool_proj x H x W.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),  # in_channels x H x W
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),  # pool_proj x H x W
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the four branches to ``x`` and concatenate them on dim 1."""
        # Branch 1 previously had no activation, unlike the other branches;
        # apply ReLU here so every branch is rectified.
        x1 = torch.relu(self.branch1(x))
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)
        return torch.cat([x1, x2, x3, x4], 1)


In [56]:
class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 conv to 128 channels -> ReLU
    -> flatten -> Linear(4*4*128, 1024) -> ReLU -> Dropout(0.7)
    -> Linear(1024, output_classes).

    ``fc1`` expects 4 * 4 * 128 features, so the pooled map must be 4 x 4.
    With a 5/3 average pool, that means spatial inputs of 14 to 16 pixels,
    for example the 14 x 14 maps of a 224 x 224 network input.
    """

    def __init__(self, in_channels, output_classes):
        super().__init__()
        reduced_channels = 128
        hidden_units = 1024
        # Registration order (avg_pool, conv, fc1, dropout, fc2) is kept so the
        # parameter order and state_dict layout are unchanged.
        self.avg_pool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=reduced_channels, kernel_size=1)
        self.fc1 = nn.Linear(4 * 4 * reduced_channels, hidden_units)
        self.dropout = nn.Dropout(0.7)
        self.fc2 = nn.Linear(hidden_units, output_classes)

    def forward(self, input):
        """Return auxiliary logits of shape (N, output_classes) for ``input``."""
        pooled = self.avg_pool(input)
        features = self.conv(pooled).relu().flatten(1)
        hidden = self.fc1(features).relu()
        return self.fc2(self.dropout(hidden))

In [52]:
class GoogleLeNet(nn.Module):
    """GoogLeNet (Inception v1) image classifier with two auxiliary heads.

    Args:
        training_mode: When True (default), ``forward`` returns the tuple
            ``(out, aux1, aux2)``. The auxiliary logits are computed only while
            the module is in training mode (``self.training``); in eval mode
            they are ``None``. When False, ``forward`` returns only ``out`` and
            the auxiliary heads are never run.
        num_classes: Number of output classes for the main classifier and both
            auxiliary heads. Defaults to 1000, the previously hard-coded value.

    The shape comments below assume a 3 x 224 x 224 input.
    NOTE(review): the auxiliary heads appear sized for the 14 x 14 maps that a
    224 x 224 input produces at stage 4. Confirm this before training on other
    input sizes with the auxiliary heads enabled.
    """

    def __init__(self, training_mode=True, num_classes=1000):
        super().__init__()
        # Stem. The convolutions are stored without activation modules. The
        # ReLU after each Conv2d is applied in forward() so the Sequential
        # indices, and hence the state_dict keys, stay unchanged.
        self.initial_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),  # 64 x 112 x 112
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 64 x 56 x 56
            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1),  # 192 x 56 x 56
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 192 x 28 x 28
        )
        self.training_mode = training_mode
        self.inception3a = Inception(in_channels=192, c1=64, c3_reduce=96, c3=128, c5_reduce=16, c5=32, pool_proj=32)  # 256 x 28 x 28
        self.inception3b = Inception(in_channels=256, c1=128, c3_reduce=128, c3=192, c5_reduce=32, c5=96, pool_proj=64)  # 480 x 28 x 28
        self.maxpool_A3b = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 480 x 14 x 14
        self.inception4a = Inception(in_channels=480, c1=192, c3_reduce=96, c3=208, c5_reduce=16, c5=48, pool_proj=64)  # 512 x 14 x 14
        # "Axu" is a typo for "Aux"; the names are kept because they appear in
        # state_dict keys.
        self.Axu1 = InceptionAux(in_channels=512, output_classes=num_classes)
        self.inception4b = Inception(in_channels=512, c1=160, c3_reduce=112, c3=224, c5_reduce=24, c5=64, pool_proj=64)  # 512 x 14 x 14
        self.inception4c = Inception(in_channels=512, c1=128, c3_reduce=128, c3=256, c5_reduce=24, c5=64, pool_proj=64)  # 512 x 14 x 14
        self.inception4d = Inception(in_channels=512, c1=112, c3_reduce=144, c3=288, c5_reduce=32, c5=64, pool_proj=64)  # 528 x 14 x 14
        self.Axu2 = InceptionAux(in_channels=528, output_classes=num_classes)
        self.inception4e = Inception(in_channels=528, c1=256, c3_reduce=160, c3=320, c5_reduce=32, c5=128, pool_proj=128)  # 832 x 14 x 14
        self.maxpool_A4e = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 832 x 7 x 7
        self.inception5a = Inception(in_channels=832, c1=256, c3_reduce=160, c3=320, c5_reduce=32, c5=128, pool_proj=128)  # 832 x 7 x 7
        self.inception5b = Inception(in_channels=832, c1=384, c3_reduce=192, c3=384, c5_reduce=48, c5=128, pool_proj=128)  # 1024 x 7 x 7
        # Global average pool over the whole final map: identical to the former
        # AvgPool2d(kernel_size=7) for the 7 x 7 map of a 224 input, and still
        # yields 1024 x 1 x 1 for other input sizes.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # 1024 x 1 x 1
        self.dropout = nn.Dropout(0.4)
        self.linear = nn.Linear(1 * 1 * 1024, num_classes)

    def forward(self, input):
        """Run the network on a batch ``input`` of shape (N, 3, H, W).

        Returns ``(out, aux1, aux2)`` when ``training_mode`` is True (the aux
        entries are ``None`` outside training mode), otherwise just ``out``.
        """
        x = input
        for layer in self.initial_layers:
            x = layer(x)
            # Rectify each stem convolution (see note in __init__).
            if isinstance(layer, nn.Conv2d):
                x = torch.relu(x)
        # Run the auxiliary heads only when their output is returned and the
        # module is training; otherwise they would be wasted work.
        use_aux = self.training_mode and self.training
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool_A3b(x)
        x = self.inception4a(x)
        aux1 = self.Axu1(x) if use_aux else None
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        aux2 = self.Axu2(x) if use_aux else None
        x = self.inception4e(x)
        x = self.maxpool_A4e(x)
        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avg_pool(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        out = self.linear(x)
        if self.training_mode:
            return out, aux1, aux2
        return out

In [62]:
# Build the network (training mode by default) and print a per-layer summary
# for a single 3 x 224 x 224 input. torchsummary's `summary` defaults to
# device="cuda" and builds CUDA inputs whenever a GPU is available. The model
# here stays on the CPU, so request CPU inputs explicitly to avoid a device
# mismatch.
model = GoogleLeNet()
summary(model, input_size=(3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
         MaxPool2d-2           [-1, 64, 56, 56]               0
            Conv2d-3          [-1, 192, 56, 56]         110,784
         MaxPool2d-4          [-1, 192, 28, 28]               0
            Conv2d-5           [-1, 64, 28, 28]          12,352
            Conv2d-6           [-1, 96, 28, 28]          18,528
              ReLU-7           [-1, 96, 28, 28]               0
            Conv2d-8          [-1, 128, 28, 28]         110,720
              ReLU-9          [-1, 128, 28, 28]               0
           Conv2d-10           [-1, 16, 28, 28]           3,088
             ReLU-11           [-1, 16, 28, 28]               0
           Conv2d-12           [-1, 32, 28, 28]          12,832
             ReLU-13           [-1, 32, 28, 28]               0
        MaxPool2d-14          [-1, 192,

In [63]:

# Smoke test: push one random 3 x 224 x 224 image through the model (still in
# training mode, so the auxiliary heads run) and print the shapes of
# aux1, aux2 and the main logits, in that order.
x = torch.randn(1, 3, 224, 224)
out, aux1, aux2 = model(x)
print(*(t.shape for t in (aux1, aux2, out)))

torch.Size([1, 1000]) torch.Size([1, 1000]) torch.Size([1, 1000])
