In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along the channel axis.

    Branches (all preserve the spatial size of the input):
      1. 1x1 conv
      2. 1x1 conv (reduce) -> 3x3 conv, padding 1
      3. 1x1 conv (reduce) -> 5x5 conv, padding 2
      4. 3x3 max pool, stride 1, padding 1 -> 1x1 conv (projection)

    Every convolution is followed by a ReLU. The activations are applied
    functionally, so the module's parameters / ``state_dict`` keys are exactly
    the six convolution layers defined in ``__init__``.

    Args:
        in_channels: number of channels of the input tensor.
        out_1x1: output channels of the 1x1 branch.
        out_3x3_reduce: channels after the 1x1 reduction in the 3x3 branch.
        out_3x3: output channels of the 3x3 branch.
        out_5x5_reduce: channels after the 1x1 reduction in the 5x5 branch.
        out_5x5: output channels of the 5x5 branch.
        out_pool_proj: output channels of the pooling branch.
    """

    def __init__(self, in_channels, out_1x1, out_3x3_reduce, out_3x3, out_5x5_reduce, out_5x5, out_pool_proj):
        super(Inception, self).__init__()

        # 1x1 conv branch
        self.branch1x1 = nn.Conv2d(in_channels, out_1x1, kernel_size=1)

        # 1x1 conv -> 3x3 conv branch
        self.branch3x3_1 = nn.Conv2d(in_channels, out_3x3_reduce, kernel_size=1)
        self.branch3x3_2 = nn.Conv2d(out_3x3_reduce, out_3x3, kernel_size=3, padding=1)

        # 1x1 conv -> 5x5 conv branch
        self.branch5x5_1 = nn.Conv2d(in_channels, out_5x5_reduce, kernel_size=1)
        self.branch5x5_2 = nn.Conv2d(out_5x5_reduce, out_5x5, kernel_size=5, padding=2)

        # 3x3 pool -> 1x1 conv branch
        self.branch_pool = nn.Conv2d(in_channels, out_pool_proj, kernel_size=1)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate their outputs on dim 1.

        Args:
            x: batched 4-D tensor of shape (N, in_channels, H, W); the
                concatenation on dim 1 relies on dim 1 being the channel axis.

        Returns:
            Tensor of shape (N, out_1x1 + out_3x3 + out_5x5 + out_pool_proj, H, W).
        """
        # Without a nonlinearity after each conv, the reduce -> conv stacks
        # would collapse into a single linear map, so ReLU follows every conv.
        branch1x1 = F.relu(self.branch1x1(x))

        branch3x3 = F.relu(self.branch3x3_1(x))
        branch3x3 = F.relu(self.branch3x3_2(branch3x3))

        branch5x5 = F.relu(self.branch5x5_1(x))
        branch5x5 = F.relu(self.branch5x5_2(branch5x5))

        # stride 1 + padding 1 keeps H x W unchanged so the branch can be concatenated
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = F.relu(self.branch_pool(branch_pool))

        outputs = [branch1x1, branch3x3, branch5x5, branch_pool]
        return torch.cat(outputs, 1)

In [11]:
class GoogleNet(nn.Module):
    """GoogLeNet-style image classifier built from stacked ``Inception`` blocks.

    Layout: conv/pool stem -> Inception 3a-3b -> pool -> Inception 4a-4e ->
    pool -> Inception 5a-5b -> global average pool -> dropout -> linear.

    NOTE(review): compared with the published GoogLeNet this model has no 1x1
    reduction conv before ``conv2``, no local response normalisation and no
    auxiliary classifiers - confirm that this simplification is intentional.

    Args:
        num_classes: size of the final linear layer's output (default 1000).
    """

    def __init__(self, num_classes=1000):
        super(GoogleNet, self).__init__()

        # Stem: 3 -> 64 -> 192 channels, downsampling by 2 three times overall
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Each block's in_channels equals the summed branch widths of the previous one.
        self.Inception3a = Inception(192, 64, 96, 128, 16, 32, 32)      # -> 256
        self.Inception3b = Inception(256, 128, 128, 192, 32, 96, 64)    # -> 480

        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.Inception4a = Inception(480, 192, 96, 208, 16, 48, 64)     # -> 512
        self.Inception4b = Inception(512, 160, 112, 224, 24, 64, 64)    # -> 512
        self.Inception4c = Inception(512, 128, 128, 256, 24, 64, 64)    # -> 512
        self.Inception4d = Inception(512, 112, 144, 288, 32, 64, 64)    # -> 528
        self.Inception4e = Inception(528, 256, 160, 320, 32, 128, 128)  # -> 832

        # padding=1 like the other pools: a 14x14 map becomes 7x7 (padding=0
        # gave 6x6 and raised an error on feature maps smaller than 3x3).
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.Inception5a = Inception(832, 256, 160, 320, 32, 128, 128)  # -> 832
        self.Inception5b = Inception(832, 384, 192, 384, 48, 128, 128)  # -> 1024

        # Global average pooling makes the classifier head independent of input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, num_classes) with unnormalised scores (no softmax).
        """
        # Stem convolutions are rectified; ReLU is applied functionally so no
        # extra submodules (and no extra state_dict keys) are introduced.
        x = F.relu(self.conv1(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv2(x))
        x = self.maxpool2(x)

        x = self.Inception3a(x)
        x = self.Inception3b(x)
        x = self.maxpool3(x)

        x = self.Inception4a(x)
        x = self.Inception4b(x)
        x = self.Inception4c(x)
        x = self.Inception4d(x)
        x = self.Inception4e(x)
        x = self.maxpool4(x)

        x = self.Inception5a(x)
        x = self.Inception5b(x)

        x = self.avgpool(x)       # (N, 1024, 1, 1)
        x = torch.flatten(x, 1)   # (N, 1024)
        x = self.dropout(x)
        x = self.fc(x)

        return x
        

In [12]:
model = GoogleNet(num_classes=1000)
print(model)

# Smoke test with dummy data: one random 3x224x224 image should yield one row
# of 1000 class scores.
data = torch.randn(1, 3, 224, 224)

# Only the output shape is inspected, so skip building the autograd graph.
# The model's train/eval mode is left untouched; dropout does not affect the shape.
with torch.no_grad():
    print(model(data).shape)

GoogleNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (Inception3a): Inception(
    (branch1x1): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
    (branch3x3_1): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
    (branch3x3_2): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (branch5x5_1): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
    (branch5x5_2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (branch_pool): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (Inception3b): Inception(
    (branch1x1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
    (branch3x3_1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
    (bra