In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvBlock(nn.Module):
    """A 2-D convolution immediately followed by an in-place ReLU.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of filters produced by the convolution.
        kernel_size: convolution kernel size (int or tuple), passed to ``nn.Conv2d``.
        stride: convolution stride. Default: 1.
        padding: zero-padding added to each spatial side. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # inplace=True overwrites the convolution's output tensor rather than
        # allocating a new one for the activation.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        features = self.conv(x)
        return self.relu(features)
    
class InceptionModule(nn.Module):
    """GoogLeNet-style Inception block: four parallel branches concatenated on dim 1.

    Each branch keeps the input's spatial size (1x1 conv; 3x3 conv with
    padding 1; 5x5 conv with padding 2; 3x3 max-pool with stride 1 and
    padding 1), so their outputs can be stacked along the channel dimension.

    Args:
        in_channels: channels of the incoming feature map.
        f1: output channels of the plain 1x1 branch.
        f3r: channels of the 1x1 reduction feeding the 3x3 conv.
        f3: output channels of the 3x3 branch.
        f5r: channels of the 1x1 reduction feeding the 5x5 conv.
        f5: output channels of the 5x5 branch.
        pool_proj: output channels of the 1x1 projection after max-pooling.

    The result has ``f1 + f3 + f5 + pool_proj`` channels.
    """

    def __init__(self, in_channels, f1, f3r, f3, f5r, f5, pool_proj):
        super().__init__()
        # branch 1: 1x1 convolution
        self.branch1 = ConvBlock(in_channels, f1, 1)

        # branch 2: 1x1 channel reduction, then 3x3 convolution
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, f3r, 1),
            ConvBlock(f3r, f3, 3, padding=1),
        )

        # branch 3: 1x1 channel reduction, then 5x5 convolution
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, f5r, 1),
            ConvBlock(f5r, f5, 5, padding=2),
        )

        # branch 4: size-preserving 3x3 max-pool, then 1x1 projection
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, pool_proj, 1),
        )

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1.

        NOTE(review): dim 1 is the channel axis only for batched (N, C, H, W)
        input — presumably the intended use; confirm for unbatched tensors.
        """
        paths = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([path(x) for path in paths], dim=1)

In [None]:
class GoogleNet(nn.Module):
    """Partial GoogLeNet (Inception v1): convolutional stem plus inception (3a)/(3b).

    The stem's max-pool layers use ``ceil_mode=True``. For a 224x224 input this
    yields 56x56 after the first pool and 28x28 after the second, matching
    GoogLeNet's published layer sizes. Floor mode, the PyTorch default, would
    give 55x55 and 27x27.

    Only the stem and inception (3a)/(3b) are built. ``forward`` returns the
    480-channel feature map produced by inception (3b).
    """

    def __init__(self):
        super(GoogleNet, self).__init__()
        self.pre_layers = nn.Sequential(
            ConvBlock(3, 64, kernel_size=7, stride=2, padding=3),
            # ceil_mode=True: 112 -> 56 (floor mode would give 55)
            nn.MaxPool2d(3, stride=2, ceil_mode=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75),
            ConvBlock(64, 64, kernel_size=1),
            ConvBlock(64, 192, kernel_size=3, padding=1),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75),
            # ceil_mode=True: 56 -> 28 (floor mode would give 27)
            nn.MaxPool2d(3, stride=2, ceil_mode=True),
        )

        # inception (3a): 192 in -> 64 + 128 + 32 + 32 = 256 out channels
        self.inception_3a = InceptionModule(192, 64, 96, 128, 16, 32, 32)
        # inception (3b): 256 in -> 128 + 192 + 96 + 64 = 480 out channels
        self.inception_3b = InceptionModule(256, 128, 128, 192, 32, 96, 64)
        # TODO: remaining GoogLeNet stages (inception 4a onward, pooling,
        # classifier) are not implemented yet.

    def forward(self, x):
        """Apply the stem and inception (3a)/(3b).

        Args:
            x: image batch; presumably (N, 3, H, W) — the first conv expects
                3 input channels.

        Returns:
            The inception (3b) feature map with 480 channels
            (N, 480, 28, 28 for a 224x224 input).
        """
        x = self.pre_layers(x)
        x = self.inception_3a(x)
        x = self.inception_3b(x)
        return x


In [None]:
# Smoke test: push one random 224x224 RGB image through the network and show
# the shape of the resulting feature map.
torch.manual_seed(0)  # reproducible weights and input
net = GoogleNet()
net.eval()
input_tensor = torch.randn(1, 3, 224, 224)  # Example input tensor
with torch.no_grad():  # inference only: skip building the autograd graph
    output = net(input_tensor)
# inception (3b) concatenates 128 + 192 + 96 + 64 = 480 channels
assert output.shape[:2] == (1, 480), output.shape
output.shape
