<a href="https://colab.research.google.com/github/IANGECHUKI176/deeplearning/blob/main/pytorch/convnets/Xception.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


François Chollet

Xception: Deep Learning with Depthwise Separable Convolutions

[Original paper (arXiv:1610.02357)](https://arxiv.org/abs/1610.02357)

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary

In [None]:
class SeparableConv2d(nn.Module):
    """Depthwise separable 2-D convolution.

    A per-channel (depthwise) convolution is followed by a 1x1 (pointwise)
    convolution that mixes channels. Neither convolution has a bias term.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the pointwise conv.
        kernel_size: Kernel size of the depthwise convolution.
        **kwargs: Extra ``nn.Conv2d`` arguments (e.g. ``padding``, ``stride``),
            forwarded to the depthwise convolution only.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        # groups == in_channels gives every input channel its own filter.
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            bias=False,
            **kwargs,
        )
        # The 1x1 convolution recombines the per-channel responses.
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        return self.pointwise(self.depthwise(x))

In [None]:
#blk = SeparableConv2d(3,32,3)

In [None]:
class EntryFlow(nn.Module):
    """Entry flow of Xception.

    Two plain 3x3 conv/BN/ReLU stems (3 -> 32 -> 64 channels) are followed by
    three residual blocks (64 -> 128 -> 256 -> 728 channels). Each residual
    block stacks two separable convolutions and a stride-2 max-pool, and is
    summed with a stride-2 1x1 conv shortcut, so each block halves the
    spatial size.

    NOTE: the stem convolutions here use stride 1 with padding 1, so they keep
    the input resolution; the paper's first stem conv uses stride 2.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Block 1: no leading ReLU because conv2 already ends with one.
        self.conv3_residual = nn.Sequential(
            SeparableConv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            SeparableConv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.conv3_shortcut = nn.Sequential(
            nn.Conv2d(64, 128, 1, stride=2, bias=False),
            nn.BatchNorm2d(128),
        )

        # Block 2. The leading ReLU must NOT be in-place: its input tensor is
        # also fed to the shortcut branch, and an in-place op would overwrite
        # it so the shortcut would see rectified activations.
        self.conv4_residual = nn.Sequential(
            nn.ReLU(inplace=False),
            SeparableConv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            SeparableConv2d(256, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.conv4_shortcut = nn.Sequential(
            nn.Conv2d(128, 256, 1, stride=2, bias=False),
            nn.BatchNorm2d(256),
        )

        # Block 3. Same non-in-place leading ReLU as block 2, plus the ReLU
        # between the two separable convs so the block matches block 2's
        # ReLU -> SepConv -> BN -> ReLU -> SepConv -> BN -> pool layout.
        self.conv5_residual = nn.Sequential(
            nn.ReLU(inplace=False),
            SeparableConv2d(256, 728, 3, padding=1),
            nn.BatchNorm2d(728),
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 728, 3, padding=1),
            nn.BatchNorm2d(728),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.conv5_shortcut = nn.Sequential(
            nn.Conv2d(256, 728, 1, stride=2, bias=False),
            nn.BatchNorm2d(728),
        )

    def forward(self, x):
        """Run the stems and the three residual blocks.

        Args:
            x: Input tensor with 3 channels (the first conv expects 3).

        Returns:
            Feature map with 728 channels.
        """
        out = self.conv1(x)
        out = self.conv2(out)

        residual = self.conv3_residual(out)
        shortcut = self.conv3_shortcut(out)
        out = residual + shortcut

        residual = self.conv4_residual(out)
        shortcut = self.conv4_shortcut(out)
        out = residual + shortcut

        residual = self.conv5_residual(out)
        shortcut = self.conv5_shortcut(out)
        out = residual + shortcut
        return out

In [None]:
#blk1 = EntryFlow()
#summary(blk1,(3,224,224))

In [None]:
class MiddleFlowBlock(nn.Module):
    """One residual block of the Xception middle flow.

    Three ReLU -> SeparableConv2d(728, 728, 3) -> BatchNorm stages form the
    residual branch; the shortcut is the identity. Channel count (728) and
    spatial size are preserved (3x3 kernels with padding 1, no striding).
    """

    def __init__(self):
        super(MiddleFlowBlock, self).__init__()
        # An empty Sequential returns its input unchanged (identity shortcut).
        self.shortcut = nn.Sequential()

        # The first ReLU must NOT be in-place: it receives the block input
        # `x`, which the identity shortcut also uses. An in-place ReLU would
        # overwrite `x`, so the block would add relu(x) instead of x.
        self.conv1 = nn.Sequential(
            nn.ReLU(inplace=False),
            SeparableConv2d(728, 728, 3, padding=1),
            nn.BatchNorm2d(728),
        )
        # Later ReLUs only touch tensors private to the residual branch, so
        # in-place is safe there.
        self.conv2 = nn.Sequential(
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 728, 3, padding=1),
            nn.BatchNorm2d(728),
        )
        self.conv3 = nn.Sequential(
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 728, 3, padding=1),
            nn.BatchNorm2d(728),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the three-stage residual branch.

        Args:
            x: Feature map with 728 channels.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = self.conv1(x)
        residual = self.conv2(residual)
        residual = self.conv3(residual)

        shortcut = self.shortcut(x)

        return shortcut + residual



In [None]:
class MiddleFlow(nn.Module):
    """Middle flow: a stack of identical blocks applied one after another.

    Args:
        block: Zero-argument callable (typically a module class) that
            returns a fresh ``nn.Module`` each time it is called.
        num_blocks: How many blocks to stack. Defaults to 8, the value that
            was previously hard-coded.
    """

    def __init__(self, block, num_blocks=8):
        super(MiddleFlow, self).__init__()
        # Call `block()` once per repetition so every block owns its weights.
        self.blocks = nn.Sequential(*(block() for _ in range(num_blocks)))

    def forward(self, x):
        """Pass ``x`` through every block in order and return the result."""
        return self.blocks(x)

In [None]:
class ExitFlow(nn.Module):
    """Exit flow of Xception.

    A down-sampling residual block (728 -> 1024 channels) is followed by two
    separable conv/BN/ReLU stages (1024 -> 1536 -> 2048 channels) and a
    global average pool to a 1x1 spatial size.
    """

    def __init__(self):
        super(ExitFlow, self).__init__()

        # The leading ReLU must NOT be in-place: the block input is also fed
        # to `self.shortcut`, and an in-place op would overwrite it so the
        # shortcut would see rectified activations.
        self.residual = nn.Sequential(
            nn.ReLU(inplace=False),
            SeparableConv2d(728, 728, 3, padding=1),
            nn.BatchNorm2d(728),
            nn.ReLU(inplace=True),
            SeparableConv2d(728, 1024, 3, padding=1),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # NOTE(review): unlike the entry-flow shortcuts this conv keeps its
        # bias even though BatchNorm follows. Left as is so the parameter set
        # (and existing checkpoints) stay unchanged.
        self.shortcut = nn.Sequential(
            nn.Conv2d(728, 1024, 1, stride=2),
            nn.BatchNorm2d(1024),
        )
        self.conv = nn.Sequential(
            SeparableConv2d(1024, 1536, 3, padding=1),
            nn.BatchNorm2d(1536),
            nn.ReLU(inplace=True),
            SeparableConv2d(1536, 2048, 3, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Apply the residual block, the final convs and global pooling.

        Args:
            x: Feature map with 728 channels.

        Returns:
            Tensor with 2048 channels and 1x1 spatial size.
        """
        residual = self.residual(x)
        shortcut = self.shortcut(x)
        out = residual + shortcut
        out = self.conv(out)
        out = self.avg_pool(out)
        return out

In [None]:
class Xception(nn.Module):
    """Xception classifier: entry flow, middle flow, exit flow, linear head.

    Args:
        n_classes: Number of output classes (size of the final linear layer).
    """

    def __init__(self, n_classes=10):
        super(Xception, self).__init__()
        self.entry_flow = EntryFlow()
        self.middle_flow = MiddleFlow(MiddleFlowBlock)
        self.exit_flow = ExitFlow()

        # The exit flow ends in a global average pool over 2048 channels.
        self.fc = nn.Linear(2048, n_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Input batch with 3 channels.

        Returns:
            Tensor of shape ``(batch, n_classes)`` with unnormalised scores.
        """
        out = self.entry_flow(x)
        out = self.middle_flow(out)
        out = self.exit_flow(out)
        # Collapse (batch, 2048, 1, 1) to (batch, 2048) for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        # The original fell off the end here and returned None.
        return out

In [None]:
# Build a 10-class Xception and print a per-layer summary for a 3x224x224 input.
blk = Xception(n_classes=10)
summary(blk, input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 224, 224]             864
       BatchNorm2d-2         [-1, 32, 224, 224]              64
              ReLU-3         [-1, 32, 224, 224]               0
            Conv2d-4         [-1, 64, 224, 224]          18,432
       BatchNorm2d-5         [-1, 64, 224, 224]             128
              ReLU-6         [-1, 64, 224, 224]               0
            Conv2d-7         [-1, 64, 224, 224]             576
            Conv2d-8        [-1, 128, 224, 224]           8,192
   SeparableConv2d-9        [-1, 128, 224, 224]               0
      BatchNorm2d-10        [-1, 128, 224, 224]             256
             ReLU-11        [-1, 128, 224, 224]               0
           Conv2d-12        [-1, 128, 224, 224]           1,152
           Conv2d-13        [-1, 128, 224, 224]          16,384
  SeparableConv2d-14        [-1, 128, 2