# Inception : Going deeper with convolutions
## Deep Dive

### Step 0 : Import libraries

In [None]:
# basic
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# basic-torch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import Dataset,DataLoader

from torchvision.models import googlenet, GoogLeNet_Weights
from torchvision import datasets



from utils import accuracy_topk

from tqdm import tqdm


### Step 1 : Model architecture

In [None]:
class BasicConv2d(nn.Module):
    """Convolution followed by batch normalisation and a ReLU activation.

    The convolution is created without a bias term and the batch-norm layer
    uses ``eps=0.001``. The sub-modules are registered as ``conv`` and ``bn``,
    so state-dict keys have the form ``conv.weight``, ``bn.weight``, ...

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size passed to ``nn.Conv2d`` (default 3).
        stride: Stride passed to ``nn.Conv2d`` (default 1).
        padding: Padding passed to ``nn.Conv2d`` (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        normalised = self.bn(self.conv(x))
        return F.relu(normalised)


In [None]:
class InceptionModule(nn.Module):
    """Four parallel branches whose outputs are concatenated along channels.

    Branches (all built from ``BasicConv2d``):
        * ``branch1``: 1x1 conv.
        * ``branch2``: 1x1 reduction conv followed by a padded 3x3 conv.
        * ``branch3``: 1x1 reduction conv followed by a padded 3x3 conv. The
          parameters are named ``*5`` but the kernel used is 3x3 (presumably
          to stay compatible with the torchvision checkpoint layout; confirm
          before changing it).
        * ``branch4``: 3x3 stride-1 max pool followed by a 1x1 conv.

    Args:
        in_channels: Channels of the input feature map.
        out_channels1: Output channels of ``branch1``.
        out_channels3: Output channels of ``branch2``.
        out_channels5: Output channels of ``branch3``.
        mid_channels3: Channels after the 1x1 reduction in ``branch2``.
        mid_channels5: Channels after the 1x1 reduction in ``branch3``.
        out_channels_pool: Output channels of ``branch4``.
    """

    def __init__(
        self,
        in_channels,
        out_channels1,
        out_channels3,
        out_channels5,
        mid_channels3,
        mid_channels5,
        out_channels_pool,
    ):
        super().__init__()

        self.branch1 = BasicConv2d(in_channels, out_channels1, kernel_size=1)

        reduce_b2 = BasicConv2d(in_channels, mid_channels3, kernel_size=1)
        expand_b2 = BasicConv2d(mid_channels3, out_channels3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_b2, expand_b2)

        reduce_b3 = BasicConv2d(in_channels, mid_channels5, kernel_size=1)
        expand_b3 = BasicConv2d(mid_channels5, out_channels5, kernel_size=3, padding=1)
        self.branch3 = nn.Sequential(reduce_b3, expand_b3)

        # Stride 1 with padding 1 keeps the spatial size, so all four branch
        # outputs can be concatenated along the channel axis.
        pool_b4 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        project_b4 = BasicConv2d(in_channels, out_channels_pool, kernel_size=1)
        self.branch4 = nn.Sequential(pool_b4, project_b4)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.concat([branch(x) for branch in branches], dim=1)

In [None]:
class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: average pool to 4x4 -> 1x1 ``BasicConv2d`` -> flatten ->
    ``fc1`` -> ReLU -> dropout (p=0.7) -> ``fc2``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the 1x1 convolution
            (``MyInception`` passes 128, which gives ``fc1`` 2048 inputs).
        num_classes: Number of logits produced by ``fc2`` (default 1000).
    """

    # Spatial size the feature map is pooled to before the 1x1 convolution.
    _POOLED_HW = 4

    def __init__(self, in_channels=3, out_channels=1000, num_classes=1000):
        super().__init__()
        # The GoogLeNet auxiliary head uses *average* pooling down to a 4x4
        # map. The adaptive variant (the one torchvision's reference head
        # uses) pins the output to 4x4 for any input resolution, so the
        # flattened width fed to fc1 is always out_channels * 16.
        self.pool = nn.AdaptiveAvgPool2d((self._POOLED_HW, self._POOLED_HW))
        self.conv = BasicConv2d(in_channels, out_channels, kernel_size=1, stride=1)
        # Derive the flattened width from out_channels instead of hard-coding
        # 2048, so values other than 128 do not crash in forward().
        flat_features = out_channels * self._POOLED_HW * self._POOLED_HW
        self.fc1 = nn.Linear(flat_features, 1024, bias=True)
        self.fc2 = nn.Linear(1024, num_classes, bias=True)
        self.dropout = nn.Dropout(p=0.7, inplace=False)

    def forward(self, x):
        """Return the auxiliary logits, shape ``(batch, num_classes)``."""
        x = self.pool(x)
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # (batch, out_channels * 4 * 4)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

In [None]:
class MyInception(nn.Module):
    """GoogLeNet (Inception v1) with two auxiliary classifier heads.

    Attribute names (``conv1``, ``inception3a``, ``aux1``, ``fc``, ...) are
    the state-dict keys the notebook loads from torchvision's ``googlenet``
    with ``strict=True``, so they must not be renamed.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Number of logits produced by the main classifier
            (default 1000).

    ``forward`` always returns the tuple ``(aux1_logits, aux2_logits,
    main_logits)``, in both train and eval mode.

    NOTE(review): no input re-normalisation is performed here. torchvision's
    pretrained GoogLeNet is believed to enable ``transform_input``; confirm
    if exact accuracy parity with it is required.
    """

    def __init__(self,in_channels=3,out_channels=1000):
        super().__init__()

        # Stem. For a 224x224 input: 224 -> 112 -> 56 -> 56 -> 56 -> 28.
        self.conv1 = BasicConv2d(in_channels,64,kernel_size=7,stride=2,padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3,stride=2,dilation=1,ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64,kernel_size=1)
        # padding=1 keeps the 56x56 resolution through the 3x3 convolution.
        # Without it every later feature map is one or two pixels too small
        # (27/13/6 instead of 28/14/7).
        self.conv3 =  BasicConv2d(64, 192,kernel_size=3,padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3,stride=2,dilation=1,ceil_mode=True)

        # Inception stage 3 (28x28 for a 224x224 input).
        self.inception3a= InceptionModule(
            in_channels=192,out_channels1=64,
            mid_channels3=96,out_channels3=128,
            mid_channels5 = 16,out_channels5 =32,out_channels_pool=32)
        self.inception3b=InceptionModule(
            in_channels=256,out_channels1=128,
            mid_channels3=128,out_channels3=192,
            mid_channels5 = 32,out_channels5 =96,out_channels_pool=64)

        self.maxpool3 = nn.MaxPool2d(kernel_size=3,stride=2,dilation=1,ceil_mode=True)

        # Inception stage 4 (14x14 for a 224x224 input).
        self.inception4a=InceptionModule(
            in_channels=480,out_channels1=192,
            mid_channels3=96,out_channels3=208,
            mid_channels5 = 16,out_channels5 =48,out_channels_pool=64)
        self.inception4b= InceptionModule(
            in_channels=512,out_channels1=160,
            mid_channels3=112,out_channels3=224,
            mid_channels5 = 24,out_channels5 =64,out_channels_pool=64)
        self.inception4c=  InceptionModule(
            in_channels=512,out_channels1=128,
            mid_channels3=128,out_channels3=256,
            mid_channels5 = 24,out_channels5 =64,out_channels_pool=64)
        self.inception4d=  InceptionModule(
            in_channels=512,out_channels1=112,
            mid_channels3=144,out_channels3=288,
            mid_channels5 = 32,out_channels5 =64,out_channels_pool=64)
        self.inception4e=  InceptionModule(
            in_channels=528,out_channels1=256,
            mid_channels3=160,out_channels3=320,
            mid_channels5 = 32,out_channels5 =128,out_channels_pool=128)

        # NOTE(review): torchvision's reference GoogLeNet is believed to use
        # kernel_size=2 for this pool. Both give 7x7 from 14x14; confirm
        # which is wanted if exact numerical parity matters.
        self.maxpool4 =nn.MaxPool2d(kernel_size=3,stride=2,dilation=1,ceil_mode=True)

        # Inception stage 5 (7x7 for a 224x224 input).
        self.inception5a= InceptionModule(
            in_channels=832,out_channels1=256,
            mid_channels3=160,out_channels3=320,
            mid_channels5 = 32,out_channels5 =128,out_channels_pool=128)
        self.inception5b= InceptionModule(
            in_channels=832,out_channels1=384,
            mid_channels3=192,out_channels3=384,
            mid_channels5 = 48,out_channels5 =128,out_channels_pool=128)

        # Main classifier: global average pool (7x7x1024 -> 1024), dropout, fc.
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.dropout = nn.Dropout(p=0.2, inplace=False)
        self.fc = nn.Linear(1024,out_channels,bias=True)

        # Auxiliary heads fed from inception4a (512 ch) and inception4d (528 ch).
        self.aux1 = InceptionAux(512,128)
        self.aux2 = InceptionAux(528,128)


    def forward(self,x):
        """Return ``(aux1_logits, aux2_logits, main_logits)`` for batch ``x``."""
        x = self.conv1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)

        x = self.maxpool3(x)

        x = self.inception4a(x)
        out1 = self.aux1(x)  # first auxiliary head taps the 4a output
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        out2 = self.aux2(x)  # second auxiliary head taps the 4d output
        x = self.inception4e(x)

        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x= self.avgpool(x)
        x = x.view(x.size(0),-1)
        x = self.dropout(x)
        return out1,out2,self.fc(x)






In [None]:
# Sanity check: push one dummy image through the network and show the shape
# of each of the three outputs (aux1, aux2, main classifier).
# Input layout: [batch, channel, H, W]
img = torch.zeros((1,3,224,224))

instance = MyInception(3,1000)
# One forward pass is enough to read all three shapes; no_grad avoids
# building an autograd graph that is never used.
with torch.no_grad():
    aux1_out, aux2_out, main_out = instance(img)
aux1_out.shape, aux2_out.shape, main_out.shape

In [None]:
# Build the hand-written network, then copy in the parameters of torchvision's
# pretrained GoogLeNet (created with its auxiliary heads so the key sets match).
myGoogLeNet = MyInception(in_channels=3, out_channels=1000)

weights = GoogLeNet_Weights.IMAGENET1K_V1
tv_model = googlenet(aux_logits=True, weights=weights)
state = tv_model.state_dict()

# strict=True makes load_state_dict raise on any key mismatch, so reaching the
# prints below means both lists are empty.
load_report = myGoogLeNet.load_state_dict(state, strict=True)
missing = load_report.missing_keys
unexpected = load_report.unexpected_keys
print("missing:", missing)
print("unexpected:", unexpected)

In [None]:

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Preprocessing pipeline that ships with the pretrained weights.
weights = GoogLeNet_Weights.IMAGENET1K_V1
val_tfms = weights.transforms()

# Validation images are read from ./val, one sub-directory per class.
val_ds = datasets.ImageFolder(root="./val", transform=val_tfms)
print(len(val_ds))

val_loader = DataLoader(
    dataset=val_ds,
    batch_size=256,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
)

@torch.no_grad()
def eval_imagenet(model, loader):
    """Evaluate ``model`` on ``loader`` and return ``(top1, top5)`` in percent.

    The model is switched to eval mode and each batch is moved to the
    module-level ``device``. When the model returns a tuple or list, its last
    element is used as the logits.

    Args:
        model: Network to evaluate; the caller is expected to have moved it
            to ``device`` already.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Tuple ``(top1_percent, top5_percent)``.
    """
    model.eval()
    hits_top1, hits_top5, n_seen = 0.0, 0.0, 0

    for images, labels in tqdm(loader, desc="ImageNet val"):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        outputs = model(images)
        if isinstance(outputs, (tuple, list)):
            logits = outputs[-1]
        else:
            logits = outputs

        # presumably accuracy_topk returns per-batch correct *counts* as
        # tensors (they are summed, then divided by the sample total);
        # verify against utils.accuracy_topk
        batch_top1, batch_top5 = accuracy_topk(logits, labels, topk=(1, 5))
        hits_top1 += batch_top1.item()
        hits_top5 += batch_top5.item()
        n_seen += labels.size(0)

    return 100.0 * hits_top1 / n_seen, 100.0 * hits_top5 / n_seen


# Rebuild the network from scratch and reload the pretrained parameters so
# this cell can be run on its own.
myGoogLeNet = MyInception(in_channels=3, out_channels=1000)

weights = GoogLeNet_Weights.IMAGENET1K_V1
tv_model = googlenet(aux_logits=True, weights=weights)
state = tv_model.state_dict()

load_report = myGoogLeNet.load_state_dict(state, strict=True)
missing, unexpected = load_report.missing_keys, load_report.unexpected_keys

# Module.to() moves the parameters in place and returns the same module.
myGoogLeNet.to(device)
t1, t5 = eval_imagenet(myGoogLeNet, val_loader)
print(f"Top-1: {t1:.3f}% | Top-5: {t5:.3f}%")
