## Inception 
![](./imgs/inception.svg)

In [5]:
from torch import nn
from torch.nn import functional as F
import torch
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along the channel axis.

    ``out_channels`` is read by index as ``[c1, (c2_mid, c2), (c3_mid, c3), c4]``:

    * ``net_0``: 1x1 conv to ``c1`` channels.
    * ``net_1``: 1x1 conv to ``c2_mid``, then 3x3 conv (padding 1) to ``c2``.
    * ``net_2``: 1x1 conv to ``c3_mid``, then 5x5 conv (padding 2) to ``c3``.
    * ``net_3``: 3x3 max pool (stride 1, padding 1), then 1x1 conv to ``c4``.

    Every convolution is followed by a ReLU. All kernels use "same" padding at
    stride 1, so each branch keeps the input's height and width. The output
    therefore has ``c1 + c2 + c3 + c4`` channels.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: per-branch channel specification as described above.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        c1 = out_channels[0]
        c2_mid, c2 = out_channels[1][0], out_channels[1][1]
        c3_mid, c3 = out_channels[2][0], out_channels[2][1]
        c4 = out_channels[3]

        # Branch 0: a single 1x1 convolution.
        self.net_0 = nn.Sequential(
            nn.Conv2d(in_channels, c1, kernel_size=1),
            nn.ReLU(),
        )
        # Branch 1: 1x1 channel reduction feeding a 3x3 convolution.
        self.net_1 = nn.Sequential(
            nn.Conv2d(in_channels, c2_mid, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(c2_mid, c2, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Branch 2: 1x1 channel reduction feeding a 5x5 convolution.
        self.net_2 = nn.Sequential(
            nn.Conv2d(in_channels, c3_mid, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(c3_mid, c3, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        # Branch 3: stride-1 max pooling followed by a 1x1 projection.
        self.net_3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, c4, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, X):
        """Apply the four branches to ``X`` in order and concatenate on dim 1."""
        branches = (self.net_0, self.net_1, self.net_2, self.net_3)
        return torch.cat([branch(X) for branch in branches], dim=1)
        


![](./imgs/inception-full.svg)

In [6]:
class GoogleNet(nn.Module):
    """GoogLeNet-style classifier assembled from five convolutional stages.

    Layout (each stage ends by halving height and width with a stride-2 pool):

    * ``net_0``: 7x7 stride-2 convolution + 3x3 stride-2 max pool.
    * ``net_1``: 1x1 conv (64) + 3x3 conv (192) + 3x3 stride-2 max pool.
    * ``net_2``: two ``Inception`` blocks (192 -> 256 -> 480 channels) + max pool.
    * ``net_3``: five ``Inception`` blocks (480 -> ... -> 832 channels) + max pool.
    * ``net_4``: two ``Inception`` blocks (832 -> 1024 channels), global
      average pooling and flattening.

    A final ``nn.Linear`` maps the 1024 pooled features to class logits.

    Args:
        in_channels: number of channels in the input images. Defaults to 1,
            the value this model previously hard-coded.
        num_classes: number of output logits. Defaults to 10, the value this
            model previously hard-coded.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Stem: only the first convolution depends on the input channel count.
        net_0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        net_1 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Each Inception's in_channels equals the previous block's c1 + c2 + c3 + c4.
        net_2 = nn.Sequential(
            Inception(192, [64, (96, 128), (16, 32), 32]),
            Inception(256, [128, (128, 192), (32, 96), 64]),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        net_3 = nn.Sequential(
            Inception(480, [192, (96, 208), (16, 48), 64]),
            Inception(512, [160, (112, 224), (24, 64), 64]),
            Inception(512, [128, (128, 256), (24, 64), 64]),
            Inception(512, [112, (144, 288), (32, 64), 64]),
            Inception(528, [256, (160, 320), (32, 128), 128]),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Global average pooling makes the head independent of the input's H and W.
        net_4 = nn.Sequential(
            Inception(832, [256, (160, 320), (32, 128), 128]),
            Inception(832, [384, (192, 384), (48, 128), 128]),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        # 384 + 384 + 128 + 128 = 1024 features come out of the last Inception block.
        self.net = nn.Sequential(
            net_0, net_1, net_2, net_3, net_4, nn.Linear(1024, num_classes)
        )

    def forward(self, X):
        """Return class logits for an image batch.

        Args:
            X: image batch, expected as (batch, in_channels, H, W).

        Returns:
            Logits of shape (batch, num_classes).
        """
        return self.net(X)

    





In [7]:
# Dummy batch holding one single-channel 96x96 image, values uniform in [0, 1).
X = torch.rand(1, 1, 96, 96)

In [8]:
# Build the model and push the dummy batch through it. Because the call is the
# cell's last expression, the notebook displays the resulting logits.
net = GoogleNet()
net(X)

torch.Size([1, 192, 12, 12])
torch.Size([1, 64, 12, 12]) torch.Size([1, 128, 12, 12]) torch.Size([1, 32, 12, 12]) torch.Size([1, 32, 12, 12])
torch.Size([1, 256, 12, 12])
torch.Size([1, 128, 12, 12]) torch.Size([1, 192, 12, 12]) torch.Size([1, 96, 12, 12]) torch.Size([1, 64, 12, 12])
torch.Size([1, 480, 6, 6])
torch.Size([1, 192, 6, 6]) torch.Size([1, 208, 6, 6]) torch.Size([1, 48, 6, 6]) torch.Size([1, 64, 6, 6])
torch.Size([1, 512, 6, 6])
torch.Size([1, 160, 6, 6]) torch.Size([1, 224, 6, 6]) torch.Size([1, 64, 6, 6]) torch.Size([1, 64, 6, 6])
torch.Size([1, 512, 6, 6])
torch.Size([1, 128, 6, 6]) torch.Size([1, 256, 6, 6]) torch.Size([1, 64, 6, 6]) torch.Size([1, 64, 6, 6])
torch.Size([1, 512, 6, 6])
torch.Size([1, 112, 6, 6]) torch.Size([1, 288, 6, 6]) torch.Size([1, 64, 6, 6]) torch.Size([1, 64, 6, 6])
torch.Size([1, 528, 6, 6])
torch.Size([1, 256, 6, 6]) torch.Size([1, 320, 6, 6]) torch.Size([1, 128, 6, 6]) torch.Size([1, 128, 6, 6])
torch.Size([1, 832, 3, 3])
torch.Size([1, 256, 3

tensor([[-0.0393,  0.0269,  0.0090,  0.0067, -0.0145,  0.0080, -0.0187, -0.0218,
          0.0191, -0.0196]], grad_fn=<AddmmBackward0>)