<a href="https://colab.research.google.com/github/IANGECHUKI176/deeplearning/blob/main/pytorch/convnets/inceptionv3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna

Rethinking the Inception Architecture for Computer Vision
> https://arxiv.org/abs/1512.00567v3

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary

In [None]:
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the basic unit used by every Inception block.

    The convolution is created without a bias term because the BatchNorm that
    follows supplies its own learnable shift. All extra keyword arguments
    (``kernel_size``, ``stride``, ``padding``, ...) are forwarded verbatim to
    ``nn.Conv2d``.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels (also the BatchNorm feature count).
    """

    def __init__(self, ch_in, ch_out, **kwargs):
        super().__init__()
        # Registration order (conv, bn, relu) is kept so state_dict keys and
        # summary listings stay the same.
        self.conv = nn.Conv2d(ch_in, ch_out, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(ch_out)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [None]:
class InceptionA(nn.Module):
    """Inception module with 1x1, 5x5, double-3x3 and pooling branches.

    Every branch uses stride 1 with "same" padding, so the spatial size is
    preserved. The branch outputs are concatenated along the channel axis,
    giving ``64 + 64 + 96 + pool_features`` output channels.

    Args:
        input_channel: number of channels of the input tensor.
        pool_features: output channels of the pooling branch's 1x1 projection.
    """

    def __init__(self, input_channel, pool_features):
        super(InceptionA, self).__init__()
        self.branch1x1 = BasicConv2d(input_channel, 64, kernel_size=1)
        self.branch5x5 = nn.Sequential(
            BasicConv2d(input_channel, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=5, padding=2),
        )
        # Two stacked 3x3 convs cover the same receptive field as one 5x5.
        self.branch3x3 = nn.Sequential(
            BasicConv2d(input_channel, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        )
        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            # The pooling step already aggregates the 3x3 neighbourhood, so
            # only a 1x1 channel projection follows it.
            BasicConv2d(input_channel, pool_features, kernel_size=1),
        )

    def forward(self, x):
        # x -> 1x1 (same)
        branch1x1 = self.branch1x1(x)

        # x -> 1x1 -> 5x5 (same)
        branch5x5 = self.branch5x5(x)

        # x -> 1x1 -> 3x3 -> 3x3 (same)
        branch3x3 = self.branch3x3(x)

        # x -> avgpool -> 1x1 (same)
        branch_pool = self.branch_pool(x)

        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, 1)

Downsampling (grid-size reduction) blocks

In [None]:
class ReductionA(nn.Module):
    """Grid-size reduction block: three parallel stride-2 branches, concatenated.

    Branches (all 3x3 windows with stride 2 and no padding, so an ``H x W``
    input becomes ``(H-3)//2+1 x (W-3)//2+1``):
      * a single 3x3 conv to 384 channels,
      * 1x1 (64) -> 3x3 (96, same) -> 3x3 stride 2 (96),
      * a 3x3 max-pool that keeps all ``in_channels``.
    Output channels: ``384 + 96 + in_channels``.

    From the paper: "We can use two parallel stride 2 blocks: P and C. P is a
    pooling layer (either average or maximum pooling) the activation, both of
    them are stride 2 the filter banks of which are concatenated as in
    figure 10."

    Args:
        in_channels: number of channels of the input tensor.
        **kwargs: accepted for interface compatibility; not used.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3_stack = nn.Sequential(
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        )

        # Max pooling (not average) is the parameter-free downsampling path.
        self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        paths = (self.branch3x3, self.branch3x3_stack, self.branch_pool)
        return torch.cat([path(x) for path in paths], 1)

In [None]:
#Factorizing Convolutions with Large Filter Size
class InceptionB(nn.Module):
    """Inception module with factorized 7x7 convolutions (1x7 / 7x1 pairs).

    From the paper: "one can replace any n x n convolution by a 1 x n
    convolution followed by a n x 1 convolution and the computational cost
    saving increases dramatically as n grows (see figure 6)."

    All branches preserve spatial size and end at 192 channels, so the output
    always has ``4 * 192 = 768`` channels.

    Args:
        in_channels: number of channels of the input tensor.
        channels_7x7: width of the intermediate factorized 7x7 layers.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__()

        def conv_7x1(c_in, c_out):
            # 7-tall, 1-wide kernel; pad only the height to keep "same" size.
            return BasicConv2d(c_in, c_out, kernel_size=(7, 1), padding=(3, 0))

        def conv_1x7(c_in, c_out):
            # 1-tall, 7-wide kernel; pad only the width to keep "same" size.
            return BasicConv2d(c_in, c_out, kernel_size=(1, 7), padding=(0, 3))

        mid = channels_7x7

        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        # one factorized 7x7 (7x1 then 1x7)
        self.branch7x7 = nn.Sequential(
            BasicConv2d(in_channels, mid, kernel_size=1),
            conv_7x1(mid, mid),
            conv_1x7(mid, 192),
        )
        # two stacked factorized 7x7s
        self.branch7x7stack = nn.Sequential(
            BasicConv2d(in_channels, mid, kernel_size=1),
            conv_7x1(mid, mid),
            conv_1x7(mid, mid),
            conv_7x1(mid, mid),
            conv_1x7(mid, 192),
        )
        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        paths = (self.branch1x1, self.branch7x7, self.branch7x7stack, self.branch_pool)
        return torch.cat([path(x) for path in paths], 1)

In [None]:
class ReductionB(nn.Module):
    """Second grid-size reduction block: three parallel stride-2 branches.

    Each branch ends in a 3x3 window with stride 2 and no padding, so an
    ``H x W`` input becomes ``(H-3)//2+1 x (W-3)//2+1``:
      * 1x1 (192) -> 3x3 stride 2 (320),
      * 1x1 (192) -> 1x7 -> 7x1 -> 3x3 stride 2 (192),
      * 3x3 average pool with stride 2, keeping all ``in_channels``.
    Output channels: ``320 + 192 + in_channels``.

    Args:
        in_channels: number of channels of the input tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch3x3 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2),
        )
        self.branch7x7 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )
        self.branch_pool = nn.AvgPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        pieces = [
            self.branch3x3(x),    # 1x1 -> 3x3 (downsample)
            self.branch7x7(x),    # 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample)
            self.branch_pool(x),  # avgpool (downsample)
        ]
        return torch.cat(pieces, 1)

In [None]:
#same
class InceptionC(nn.Module):
    """Inception module with expanded (split 1x3 / 3x1) filter-bank outputs.

    From the paper (figure 7): "Inception modules with expanded the filter
    bank outputs. This architecture is used on the coarsest (8 x 8) grids to
    promote high dimensional representations, as suggested by principle 2 of
    Section 2."

    Every branch preserves spatial size. Output channels:
    ``320 + (384 + 384) + (384 + 384) + 192 = 2048``.

    Args:
        in_channels: number of channels of the input tensor.
    """

    def __init__(self, in_channels):
        super().__init__()
        # x -> 1x1
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        # x -> 1x1, then the result feeds both a 1x3 and a 3x1 conv in parallel
        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # x -> 1x1 -> 3x3, then the result feeds both a 1x3 and a 3x1 conv in parallel
        self.branch3x3stack_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, 192, kernel_size=1),
        )

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        shallow = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            [self.branch3x3_2a(shallow), self.branch3x3_2b(shallow)], 1
        )

        deep = self.branch3x3stack_2(self.branch3x3stack_1(x))
        out_stack = torch.cat(
            [self.branch3x3stack_3a(deep), self.branch3x3stack_3b(deep)], 1
        )

        out_pool = self.branch_pool(x)

        return torch.cat([out_1x1, out_3x3, out_stack, out_pool], 1)

In [None]:
from torch.nn.modules.pooling import MaxPool2d
class InceptionV3(nn.Module):
    """Inception v3 classifier assembled from the blocks defined above.

    Pipeline: stem -> 3x InceptionA -> ReductionA -> 4x InceptionB ->
    ReductionB -> 2x InceptionC -> global average pool -> flatten ->
    dropout -> linear.

    The stem's first three 3x3 convolutions use stride 1 with padding 1, so
    the stem downsamples only through its two stride-2 max-pools. No
    auxiliary classifier head is built.

    Args:
        n_classes: number of output logits.
    """

    def __init__(self, n_classes=100):
        super(InceptionV3, self).__init__()

        self.stem = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, padding=1),
            BasicConv2d(32, 32, kernel_size=3, padding=1),
            BasicConv2d(32, 64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=3, stride=2),
            BasicConv2d(64, 80, kernel_size=1),
            BasicConv2d(80, 192, kernel_size=3),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Spatial size preserved; channels 192 -> 256 -> 288 -> 288.
        self.inception_a = nn.Sequential(
            InceptionA(192, pool_features=32),
            InceptionA(256, pool_features=64),
            InceptionA(288, pool_features=64),
        )
        # Stride-2 downsample; 288 -> 384 + 96 + 288 = 768 channels.
        self.reduction_a = ReductionA(288)

        # Factorized-7x7 modules; 768 channels in and out.
        self.inception_b = nn.Sequential(
            InceptionB(768, channels_7x7=128),
            InceptionB(768, channels_7x7=160),
            InceptionB(768, channels_7x7=160),
            InceptionB(768, channels_7x7=192),
        )
        # Stride-2 downsample; 768 -> 320 + 192 + 768 = 1280 channels.
        self.reduction_b = ReductionB(768)

        # Expanded filter-bank modules; 1280 -> 2048 -> 2048 channels.
        self.inception_c = nn.Sequential(
            InceptionC(1280),
            InceptionC(2048),
        )
        self.adapt_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()  # PyTorch default p=0.5
        self.linear = nn.Linear(2048, n_classes)

    def forward(self, x):
        out = self.stem(x)
        out = self.inception_a(out)
        out = self.reduction_a(out)
        out = self.inception_b(out)
        out = self.reduction_b(out)
        out = self.inception_c(out)

        out = self.adapt_pool(out)
        out = torch.flatten(out, 1)
        # Regularize the pooled features before the classifier; nn.Dropout is
        # a no-op in eval mode, so inference results are unaffected.
        out = self.dropout(out)
        return self.linear(out)

In [None]:
# Build a 10-class model and print a layer-by-layer summary for a single
# 3-channel input (torchsummary adds the batch dimension, shown as -1 below).
# NOTE(review): the paper uses 299x299 inputs; 229 may be a typo for 299 — confirm.
net = InceptionV3(10)
summary(net,(3,229,229))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 229, 229]             864
       BatchNorm2d-2         [-1, 32, 229, 229]              64
              ReLU-3         [-1, 32, 229, 229]               0
       BasicConv2d-4         [-1, 32, 229, 229]               0
            Conv2d-5         [-1, 32, 229, 229]           9,216
       BatchNorm2d-6         [-1, 32, 229, 229]              64
              ReLU-7         [-1, 32, 229, 229]               0
       BasicConv2d-8         [-1, 32, 229, 229]               0
            Conv2d-9         [-1, 64, 229, 229]          18,432
      BatchNorm2d-10         [-1, 64, 229, 229]             128
             ReLU-11         [-1, 64, 229, 229]               0
      BasicConv2d-12         [-1, 64, 229, 229]               0
        MaxPool2d-13         [-1, 64, 114, 114]               0
           Conv2d-14         [-1, 80, 1