<a href="https://colab.research.google.com/github/IANGECHUKI176/deeplearning/blob/main/pytorch/convnets/inceptionv4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi

Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning
> https://arxiv.org/abs/1602.07261

In [208]:
import torch
import torch.nn as nn
from torchsummary import summary

In [209]:
class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch normalisation and ReLU.

    Every extra keyword argument (``kernel_size``, ``stride``, ``padding``,
    ...) is forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        # Attribute names and registration order are kept (conv, bn, relu) so
        # the printed structure and the state-dict keys stay the same.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))

In [210]:
# Build a sample block so the notebook echoes its layer structure below.
BasicConv2d(3,32,kernel_size = 3)

BasicConv2d(
  (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)

In [211]:
class InceptionStem(nn.Module):
    """Input stem of the pure Inception-v4 and Inception-ResNet-v2 networks.

    Implements Figure 3 of the paper: a trunk of three 3x3 convolutions
    followed by three two-branch stages whose outputs are concatenated along
    the channel axis (160, then 192, then 384 channels).
    """

    def __init__(self, in_channels):
        super().__init__()

        # Trunk: three stacked 3x3 convolutions ending with 64 channels.
        self.conv1 = nn.Sequential(
            BasicConv2d(in_channels, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3),
            BasicConv2d(32, 64, kernel_size=3, padding=1),
        )

        # Stage 1: stride-2 conv (96) beside a max-pool (64) -> 160 channels.
        self.branch3x3_conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
        self.branch3x3_pool = nn.MaxPool2d(3, stride=2)

        # Stage 2: a factorised 7x7 path and a short 3x3 path, 96 channels
        # each -> 192 channels.
        self.branch7x7a = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1),
            BasicConv2d(64, 64, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(64, 64, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(64, 96, kernel_size=3),
        )
        self.branch7x7b = nn.Sequential(
            BasicConv2d(160, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3),
        )

        # Stage 3: max-pool (192) beside a stride-2 conv (192) -> 384 channels.
        # Note that ``branchpool_b`` is a convolution despite its name.
        self.branchpool_a = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branchpool_b = BasicConv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        """Run the stem on ``x`` and return the 384-channel feature map."""
        trunk = self.conv1(x)

        stage1 = torch.cat(
            [self.branch3x3_conv(trunk), self.branch3x3_pool(trunk)], 1
        )

        # The short path is listed first in the concatenation.
        short_path = self.branch7x7b(stage1)
        long_path = self.branch7x7a(stage1)
        stage2 = torch.cat([short_path, long_path], 1)

        return torch.cat(
            [self.branchpool_a(stage2), self.branchpool_b(stage2)], 1
        )

In [212]:
class InceptionA(nn.Module):
    """Inception-A block of Inception-v4 (Figure 4, 35 x 35 grid module).

    Four parallel branches, each producing 96 channels at the input's spatial
    size, are concatenated into a 384-channel output.
    """

    def __init__(self, in_channels):
        super().__init__()

        # 1x1 bottleneck followed by two 3x3 convolutions.
        self.branch3x3stack = nn.Sequential(
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        )

        # 1x1 bottleneck followed by a single 3x3 convolution.
        self.branch3x3 = nn.Sequential(
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
        )

        # Plain 1x1 projection.
        self.branch1x1 = BasicConv2d(in_channels, 96, kernel_size=1)

        # Size-preserving average pooling followed by a 1x1 projection.
        self.branchpool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, 96, kernel_size=1),
        )

    def forward(self, x):
        """Apply all four branches to ``x`` and concatenate along channels."""
        branches = (
            self.branch3x3stack,
            self.branch3x3,
            self.branch1x1,
            self.branchpool,
        )
        return torch.cat([branch(x) for branch in branches], 1)

In [213]:
class ReductionA(nn.Module):
    """35 x 35 to 17 x 17 reduction module (Figure 7 of the paper).

    ``k``, ``l``, ``m`` and ``n`` are the filter-bank sizes from Table 1 of
    the paper. The module exposes ``output_channels`` (``in_channels + n + m``)
    so that callers can size the following layer.
    """

    def __init__(self, in_channels, k, l, m, n):
        super().__init__()

        # 1x1 (k) -> 3x3 (l) -> stride-2 3x3 (m).
        self.branch3x3stack = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2),
        )
        # Single stride-2 3x3 convolution (n).
        self.branch3x3 = BasicConv2d(in_channels, n, kernel_size=3, stride=2)
        # Stride-2 max-pool keeps the input channel count.
        self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.output_channels = in_channels + n + m

    def forward(self, x):
        """Concatenate the stacked-conv, conv and pool branches of ``x``."""
        branches = (self.branch3x3stack, self.branch3x3, self.branchpool)
        return torch.cat([branch(x) for branch in branches], 1)

In [214]:

#"""Figure 5. The schema for 17 × 17 grid modules of the pure Inception-v4 network.
#This is the Inception-B block of Figure 9."""
class InceptionB(nn.Module):
    """Inception-B block of Inception-v4 (Figure 5, 17 x 17 grid module).

    Four parallel, size-preserving branches (256 + 256 + 384 + 128 channels)
    are concatenated into a 1024-channel output.
    """

    def __init__(self, in_channels):
        super().__init__()

        # 1x1 bottleneck followed by four alternating 1x7 / 7x1 convolutions.
        self.branch7x7stack = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 224, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(224, 224, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), padding=(3, 0)),
        )

        # 1x1 bottleneck followed by one 1x7 / 7x1 pair.
        self.branch7x7 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 224, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), padding=(3, 0)),
        )

        # Plain 1x1 projection.
        self.branch1x1 = BasicConv2d(in_channels, 384, kernel_size=1)

        # Size-preserving average pooling followed by a 1x1 projection.
        self.branchpool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, 128, kernel_size=1),
        )

    def forward(self, x):
        """Apply all four branches to ``x`` and concatenate along channels."""
        branch7x7stack = self.branch7x7stack(x)
        branch7x7 = self.branch7x7(x)
        branch1x1 = self.branch1x1(x)
        branchpool = self.branchpool(x)
        return torch.cat([branch7x7stack, branch7x7, branch1x1, branchpool], 1)

In [215]:
class ReductionB(nn.Module):
    """17 x 17 to 8 x 8 grid-reduction module of Inception-v4 (Figure 8).

    The output concatenates 320 + 192 convolution channels with the
    max-pooled input, i.e. ``in_channels + 512`` channels in total.
    """

    def __init__(self, in_channels):
        super().__init__()

        # 1x1 -> 7x1 -> 1x7 -> stride-2 3x3.
        self.branch7x7stack = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size=1),
            BasicConv2d(256, 256, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(256, 320, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(320, 320, kernel_size=3, stride=2),
        )

        # 1x1 -> stride-2 3x3.
        self.branch3x3 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )

        # Stride-2 max-pool keeps the input channel count.
        self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Concatenate the 7x7-stack, 3x3 and pool branches of ``x``."""
        branches = (self.branch7x7stack, self.branch3x3, self.branchpool)
        return torch.cat([branch(x) for branch in branches], 1)

In [216]:
class InceptionC(nn.Module):
    """Inception-C block of Inception-v4 (Figure 6, 8 x 8 grid module).

    Branches (all size-preserving):

    * stacked: 1x1 (384) -> 1x3 (448) -> 3x1 (512), then split into a 1x3 and
      a 3x1 convolution of 256 channels each;
    * split 3x3: 1x1 (384), then split into a 1x3 and a 3x1 convolution of
      256 channels each;
    * 1x1 (256);
    * average pooling -> 1x1 (256).

    The concatenated output has 512 + 512 + 256 + 256 = 1536 channels.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Figure 6 specifies a 384-filter 1x1 bottleneck ahead of the
        # 1x3 (448) / 3x1 (512) pair; the previous 256-filter bottleneck did
        # not match the referenced schema.
        self.branch3x3stack = nn.Sequential(
            BasicConv2d(in_channels, 384, kernel_size=1),
            BasicConv2d(384, 448, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(448, 512, kernel_size=(3, 1), padding=(1, 0)),
        )
        self.branch3x3stacka = BasicConv2d(512, 256, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3stackb = BasicConv2d(512, 256, kernel_size=(3, 1), padding=(1, 0))

        self.branch1x1 = BasicConv2d(in_channels, 256, kernel_size=1)

        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3a = BasicConv2d(384, 256, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3b = BasicConv2d(384, 256, kernel_size=(3, 1), padding=(1, 0))

        self.branchpool = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, 256, kernel_size=1),
        )

    def forward(self, x):
        """Apply all branches to ``x`` and concatenate along channels."""
        # Stacked branch: shared trunk, then parallel 1x3 / 3x1 heads.
        stack_trunk = self.branch3x3stack(x)
        branch3x3stack_concat = torch.cat(
            [self.branch3x3stacka(stack_trunk), self.branch3x3stackb(stack_trunk)], 1
        )

        # Split 3x3 branch: shared 1x1, then parallel 1x3 / 3x1 heads.
        split_trunk = self.branch3x3(x)
        branch3x3_concat = torch.cat(
            [self.branch3x3a(split_trunk), self.branch3x3b(split_trunk)], 1
        )

        branch1x1 = self.branch1x1(x)
        branchpool = self.branchpool(x)

        return torch.cat(
            [branch3x3stack_concat, branch3x3_concat, branch1x1, branchpool], 1
        )

In [217]:
class InceptionV4(nn.Module):
    """Overall Inception-v4 network (Figure 9 of the paper).

    Args:
        A: number of Inception-A blocks.
        B: number of Inception-B blocks.
        C: number of Inception-C blocks.
        k, l, m, n: filter-bank sizes handed to ``ReductionA``.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, A, B, C, k=192, l=224, m=256, n=384, num_classes=10):
        super().__init__()

        self.stem = InceptionStem(3)
        self.inception_a = self._generate_inception_module(384, 384, A, InceptionA)
        self.reduction_a = ReductionA(384, k, l, m, n)
        # The first Inception-B block consumes whatever ReductionA emits.
        self.inception_b = self._generate_inception_module(
            self.reduction_a.output_channels, 1024, B, InceptionB
        )
        self.reduction_b = ReductionB(1024)
        self.inception_c = self._generate_inception_module(1536, 1536, C, InceptionC)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Drop probability 0.2, i.e. keep probability 0.8.
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(1536, num_classes)

    def _generate_inception_module(self, input_channels, output_channels, block_num, block):
        """Stack ``block_num`` instances of ``block`` into an ``nn.Sequential``.

        The first instance receives ``input_channels``; every later one
        receives ``output_channels``. Sub-modules are named
        ``<BlockClassName>_<index>``.
        """
        stacked = nn.Sequential()
        width = input_channels
        for index in range(block_num):
            stacked.add_module("{}_{}".format(block.__name__, index), block(width))
            width = output_channels
        return stacked

    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        feature_stages = (
            self.stem,
            self.inception_a,
            self.reduction_a,
            self.inception_b,
            self.reduction_b,
            self.inception_c,
        )
        out = x
        for stage in feature_stages:
            out = stage(out)

        out = self.dropout(self.pool(out))
        # Flatten everything but the batch dimension before the classifier.
        flat = out.view(out.size(0), -1)
        return self.linear(flat)

In [218]:
def inceptionv4():
    """Build Inception-v4 with 4 Inception-A, 7 Inception-B and 3 Inception-C blocks."""
    return InceptionV4(A=4, B=7, C=3)


# Instantiate the network and print its per-layer summary for a 3x299x299 input.
net = inceptionv4()
summary(net, (3, 299, 299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
       BasicConv2d-4         [-1, 32, 149, 149]               0
            Conv2d-5         [-1, 32, 147, 147]           9,216
       BatchNorm2d-6         [-1, 32, 147, 147]              64
              ReLU-7         [-1, 32, 147, 147]               0
       BasicConv2d-8         [-1, 32, 147, 147]               0
            Conv2d-9         [-1, 64, 147, 147]          18,432
      BatchNorm2d-10         [-1, 64, 147, 147]             128
             ReLU-11         [-1, 64, 147, 147]               0
      BasicConv2d-12         [-1, 64, 147, 147]               0
           Conv2d-13           [-1, 96, 73, 73]          55,296
      BatchNorm2d-14           [-1, 96,

Inception-ResNet

In [219]:
class InceptionResnet_A(nn.Module):
    """Inception-ResNet-A block of Inception-ResNet-v2 (Figure 16, 35 x 35 grid).

    Three branches (64 + 32 + 32 channels) are concatenated, projected to 384
    channels and added to a 1x1 projection of the input; the sum goes through
    batch normalisation and ReLU.
    """

    def __init__(self, in_channels):
        super().__init__()

        # 1x1 (32) -> 3x3 (48) -> 3x3 (64).
        self.branch3x3stack = nn.Sequential(
            BasicConv2d(in_channels, 32, kernel_size=1),
            BasicConv2d(32, 48, kernel_size=3, padding=1),
            BasicConv2d(48, 64, kernel_size=3, padding=1),
        )
        # 1x1 (32) -> 3x3 (32).
        self.branch3x3 = nn.Sequential(
            BasicConv2d(in_channels, 32, kernel_size=1),
            BasicConv2d(32, 32, kernel_size=3, padding=1),
        )
        # Plain 1x1 (32).
        self.branch1x1 = BasicConv2d(in_channels, 32, kernel_size=1)

        # Projections that bring both summands to 384 channels.
        self.reduction1x1 = BasicConv2d(128, 384, kernel_size=1)
        self.shortcut = BasicConv2d(in_channels, 384, kernel_size=1)
        self.bn = nn.BatchNorm2d(384)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(residual + shortcut))`` for the input ``x``."""
        mixed = torch.cat(
            [self.branch3x3stack(x), self.branch3x3(x), self.branch1x1(x)], 1
        )
        # NOTE(review): no residual scaling is applied here, whereas
        # InceptionResnet_B / InceptionResnet_C multiply the residual by 0.1
        # -- confirm this difference is intended.
        residual = self.reduction1x1(mixed)
        shortcut = self.shortcut(x)
        return self.relu(self.bn(residual + shortcut))

In [220]:
class InceptionResnet_B(nn.Module):
    """Inception-ResNet-B block of Inception-ResNet-v2 (Figure 17, 17 x 17 grid).

    Two branches (192 + 192 channels) are concatenated, projected to 1154
    channels, scaled, and added to a 1x1 projection of the input; the sum goes
    through batch normalisation and ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        scale: factor applied to the residual before the addition. The paper
            reports picking values between 0.1 and 0.3; the default of 0.1
            keeps the previous hard-coded behaviour.
    """

    def __init__(self, in_channels, scale=0.1):
        super().__init__()

        # 1x1 (128) -> 1x7 (160) -> 7x1 (192).
        self.branch7x7 = nn.Sequential(
            BasicConv2d(in_channels, 128, kernel_size=1),
            BasicConv2d(128, 160, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(160, 192, kernel_size=(7, 1), padding=(3, 0)),
        )

        # Plain 1x1 (192).
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        # Projections that bring both summands to 1154 channels.
        self.reduction1x1 = BasicConv2d(384, 1154, kernel_size=1)
        self.shortcut = BasicConv2d(in_channels, 1154, kernel_size=1)
        self.bn = nn.BatchNorm2d(1154)
        self.relu = nn.ReLU(inplace=True)

        self.scale = scale

    def forward(self, x):
        """Return ``relu(bn(scale * residual + shortcut))`` for the input ``x``."""
        branch7x7 = self.branch7x7(x)
        branch1x1 = self.branch1x1(x)
        all_concat = torch.cat([branch7x7, branch1x1], 1)

        # "In general we picked some scaling factors between 0.1 and 0.3 to
        # scale the residuals before their being added to the accumulated
        # layer activations (cf. Figure 20)."
        residual = self.reduction1x1(all_concat) * self.scale
        shortcut = self.shortcut(x)
        return self.relu(self.bn(residual + shortcut))

In [221]:
class InceptionResnet_C(nn.Module):
    """Inception-ResNet-C block of Inception-ResNet-v2 (Figure 19, 8 x 8 grid).

    Two branches (256 + 192 channels) are concatenated, projected to 2048
    channels, scaled, and added to a 1x1 projection of the input; the sum goes
    through batch normalisation and ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        scale: factor applied to the residual before the addition. The paper
            reports picking values between 0.1 and 0.3; the default of 0.1
            keeps the previous hard-coded behaviour.
    """

    def __init__(self, in_channels, scale=0.1):
        super().__init__()

        # 1x1 (192) -> 1x3 (224) -> 3x1 (256).
        self.branch3x3 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 224, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(224, 256, kernel_size=(3, 1), padding=(1, 0)),
        )

        # Plain 1x1 (192).
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        # Projections that bring both summands to 2048 channels.
        self.reduction1x1 = BasicConv2d(448, 2048, kernel_size=1)
        self.shortcut = BasicConv2d(in_channels, 2048, kernel_size=1)
        self.bn = nn.BatchNorm2d(2048)
        self.relu = nn.ReLU(inplace=True)

        self.scale = scale

    def forward(self, x):
        """Return ``relu(bn(scale * residual + shortcut))`` for the input ``x``."""
        branch3x3 = self.branch3x3(x)
        branch1x1 = self.branch1x1(x)
        all_concat = torch.cat([branch3x3, branch1x1], 1)

        # "In general we picked some scaling factors between 0.1 and 0.3 to
        # scale the residuals before their being added to the accumulated
        # layer activations (cf. Figure 20)."
        residual = self.reduction1x1(all_concat) * self.scale
        shortcut = self.shortcut(x)
        return self.relu(self.bn(residual + shortcut))

In [222]:
class InceptionResNetReductionA(nn.Module):
    """35 x 35 to 17 x 17 reduction module (Figure 7 of the paper).

    ``k``, ``l``, ``m`` and ``n`` are the filter-bank sizes from Table 1 of
    the paper. The module exposes ``output_channels`` (``in_channels + n + m``)
    so that callers can size the following layer.
    """

    def __init__(self, in_channels, k, l, m, n):
        super().__init__()

        # 1x1 (k) -> 3x3 (l) -> stride-2 3x3 (m).
        self.branch3x3stack = nn.Sequential(
            BasicConv2d(in_channels, k, kernel_size=1),
            BasicConv2d(k, l, kernel_size=3, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2),
        )

        # Single stride-2 3x3 convolution (n).
        self.branch3x3 = BasicConv2d(in_channels, n, kernel_size=3, stride=2)
        # Stride-2 max-pool keeps the input channel count.
        self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.output_channels = in_channels + n + m

    def forward(self, x):
        """Concatenate the stacked-conv, conv and pool branches of ``x``."""
        branches = (self.branch3x3stack, self.branch3x3, self.branchpool)
        return torch.cat([branch(x) for branch in branches], 1)

In [223]:
class InceptionResNetReductionB(nn.Module):
    """17 x 17 to 8 x 8 grid-reduction module (Figure 18 of the paper).

    The paper's caption attributes this module to the "wider
    Inception-ResNet-v1 network"; the original author of this notebook
    believes that is a typo for Inception-ResNet-v2.

    The output concatenates the max-pooled input with three convolutional
    branches, giving ``in_channels + 384 + 288 + 320`` channels.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Stride-2 max-pool keeps the input channel count.
        self.branchpool = nn.MaxPool2d(3, stride=2)

        # 1x1 (256) -> stride-2 3x3 (384).
        self.branch3x3a = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        )
        # 1x1 (256) -> stride-2 3x3 (288).
        self.branch3x3b = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2),
        )
        # 1x1 (256) -> 3x3 (288) -> stride-2 3x3 (320).
        self.branch3x3stack = nn.Sequential(
            BasicConv2d(in_channels, 256, kernel_size=1),
            BasicConv2d(256, 288, kernel_size=3, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2),
        )

        # The reduction1x1 / shortcut / bn / relu layers that used to be
        # registered here were never called in forward(); they only added
        # dead parameters, so they are no longer created.

    def forward(self, x):
        """Concatenate the pool branch and the three conv branches of ``x``."""
        branchpool = self.branchpool(x)
        branch3x3a = self.branch3x3a(x)
        branch3x3b = self.branch3x3b(x)
        branch3x3stack = self.branch3x3stack(x)

        return torch.cat([branchpool, branch3x3a, branch3x3b, branch3x3stack], 1)


In [224]:
class InceptionResNetV2(nn.Module):
    """Inception-ResNet-v2 classifier assembled from the blocks in this file.

    Args:
        A: number of Inception-ResNet-A blocks.
        B: number of Inception-ResNet-B blocks.
        C: number of Inception-ResNet-C blocks.
        k, l, m, n: filter-bank sizes handed to ``InceptionResNetReductionA``.
        n_classes: size of the final linear layer's output.
    """

    def __init__(self, A, B, C, k=256, l=256, m=384, n=384, n_classes=10):
        super().__init__()

        self.stem = InceptionStem(3)
        self.inception_a = self._generate_inception_module(384, 384, A, InceptionResnet_A)
        self.reduction_a = InceptionResNetReductionA(384, k, l, m, n)
        # The first B block consumes whatever the reduction module emits.
        self.inception_b = self._generate_inception_module(
            self.reduction_a.output_channels, 1154, B, InceptionResnet_B
        )
        self.reduction_b = InceptionResNetReductionB(1154)
        self.inception_c = self._generate_inception_module(2146, 2048, C, InceptionResnet_C)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Drop probability 0.2, i.e. keep probability 0.8.
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(2048, n_classes)

    @staticmethod
    def _generate_inception_module(input_channels, output_channels, block_num, block):
        """Stack ``block_num`` instances of ``block`` into an ``nn.Sequential``.

        The first instance receives ``input_channels``; every later one
        receives ``output_channels``. Sub-modules are named
        ``<BlockClassName>_<index>``.
        """
        stacked = nn.Sequential()
        width = input_channels
        for index in range(block_num):
            stacked.add_module("{}_{}".format(block.__name__, index), block(width))
            width = output_channels
        return stacked

    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        feature_stages = (
            self.stem,
            self.inception_a,
            self.reduction_a,
            self.inception_b,
            self.reduction_b,
            self.inception_c,
        )
        out = x
        for stage in feature_stages:
            out = stage(out)

        out = self.dropout(self.pool(out))
        # Flatten everything but the batch dimension before the classifier.
        flat = out.view(out.size(0), -1)
        return self.linear(flat)

In [225]:
def inception_resnet_v2():
    """Build Inception-ResNet-v2 with 5 A blocks, 10 B blocks and 5 C blocks."""
    return InceptionResNetV2(A=5, B=10, C=5)


# Instantiate the network and print its per-layer summary for a 3x224x224 input.
resnet = inception_resnet_v2()
summary(resnet, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 111, 111]             864
       BatchNorm2d-2         [-1, 32, 111, 111]              64
              ReLU-3         [-1, 32, 111, 111]               0
       BasicConv2d-4         [-1, 32, 111, 111]               0
            Conv2d-5         [-1, 32, 109, 109]           9,216
       BatchNorm2d-6         [-1, 32, 109, 109]              64
              ReLU-7         [-1, 32, 109, 109]               0
       BasicConv2d-8         [-1, 32, 109, 109]               0
            Conv2d-9         [-1, 64, 109, 109]          18,432
      BatchNorm2d-10         [-1, 64, 109, 109]             128
             ReLU-11         [-1, 64, 109, 109]               0
      BasicConv2d-12         [-1, 64, 109, 109]               0
           Conv2d-13           [-1, 96, 54, 54]          55,296
      BatchNorm2d-14           [-1, 96,