## Reference
[Going Deeper with Convolutions](https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable

In [None]:
class Inception(nn.Module):
    """GoogLeNet Inception module: four parallel branches concatenated on channels.

    Branches, in output-channel order:
      1. 1x1 conv                                     -> ``n1x1`` channels
      2. 1x1 reduction to ``n3x3pre``, then 3x3 conv  -> ``n3x3`` channels
      3. 1x1 reduction to ``n5x5pre``, then 5x5 conv  -> ``n5x5`` channels
      4. 3x3 max-pool (stride 1), then 1x1 conv       -> ``pool_features`` channels

    Every branch is padded so that it keeps the input's height and width. The
    output therefore has ``n1x1 + n3x3 + n5x5 + pool_features`` channels.
    """

    def __init__(self, in_channels, n1x1, n3x3pre, n3x3, n5x5pre, n5x5, pool_features):
        super(Inception, self).__init__()

        def reduce_then_conv(mid_channels, out_channels, size):
            # A cheap 1x1 bottleneck shrinks the channel count before the
            # larger kernel; padding of size // 2 keeps H and W unchanged.
            return nn.Sequential(
                BasicConv2d(in_channels, mid_channels, kernel_size=1),
                BasicConv2d(mid_channels, out_channels, kernel_size=size, padding=size // 2),
            )

        # Construction order (1x1, 3x3, 5x5, pool) is significant: it fixes
        # both the parameter registration order and the RNG draw order.
        self.branch1x1 = BasicConv2d(in_channels, n1x1, kernel_size=1)
        self.branch3x3 = reduce_then_conv(n3x3pre, n3x3, 3)
        self.branch5x5 = reduce_then_conv(n5x5pre, n5x5, 5)
        # Max-pool with stride 1 and padding 1 preserves H and W; the 1x1 conv
        # then projects the pooled features to ``pool_features`` channels.
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_features, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1x1, self.branch3x3, self.branch5x5, self.branch_pool)
        # Concatenate along dim 1 (channels), in the order listed in the docstring.
        return torch.cat([branch(x) for branch in branches], 1)

class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) variant sized for small 32x32 images such as CIFAR-10.

    The stem is two stride-1 3x3 ``BasicConv2d`` layers with no early
    downsampling, and this variant has no auxiliary classifiers. Spatial size
    is halved twice, once before the 4x stage and once before the 5x stage. A
    32x32 input therefore reaches the head as 1024 x 8 x 8 features. The head
    uses adaptive average pooling, so other input sizes also produce 1024
    features per sample.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(GoogLeNet, self).__init__()
        self.conv1 = BasicConv2d(3, 32, kernel_size=3, padding=1)
        # The original version has a max-pool here, but CIFAR-10 images are
        # too small to downsample this early.
        self.conv2 = BasicConv2d(32, 192, kernel_size=3, padding=1)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)

        # Stateless pooling module, reused after stage 3 and after stage 4.
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        # NOTE: the attribute names of the last two stage-4 modules are
        # swapped relative to the paper. ``inception4e`` (512 -> 528 channels)
        # runs first and ``inception4d`` (528 -> 832) runs second; forward()
        # calls them in that order. The names are kept as-is so that existing
        # state_dict checkpoints still load.
        self.inception4e = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4d = Inception(528, 256, 160, 320, 32, 128, 128)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Adaptive pooling to 1x1 is identical to AvgPool2d(8) on the 8x8 maps
        # a 32x32 input produces, and also accepts other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.drop = nn.Dropout()
        # inception5b outputs 384 + 384 + 128 + 128 = 1024 channels.
        self.linear = nn.Linear(1024, num_classes)

        # Use the in-place initializers; the non-underscore names are deprecated.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=1e-2)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        # Assumes x is (N, 3, H, W). For 32x32 inputs: stage 3 at 32x32,
        # stage 4 at 16x16, stage 5 at 8x8.
        output = self.conv1(x)
        output = self.conv2(output)
        output = self.inception3a(output)
        output = self.inception3b(output)
        output = self.maxpool(output)
        output = self.inception4a(output)
        output = self.inception4b(output)
        output = self.inception4c(output)
        output = self.inception4e(output)
        output = self.inception4d(output)
        output = self.maxpool(output)
        output = self.inception5a(output)
        output = self.inception5b(output)
        output = self.avgpool(output)
        output = self.drop(output)
        output = torch.flatten(output, 1)
        return self.linear(output)
            
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d (eps=0.001) -> in-place ReLU.

    The convolution is built with ``bias=False`` because the BatchNorm that
    follows supplies its own learned shift. Any extra keyword arguments
    (``kernel_size``, ``padding``, ...) are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **conv_options):
        super().__init__()
        # Submodules are registered in the order conv, bn, relu.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return self.relu(normalized)
    
    

# net = GoogLeNet()
# print(net)
# x = torch.randn(64,3,32,32)
# y = net(Variable(x))
# print(y.size())