In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
import PIL
from PIL import Image
from torch.utils import data as D
from torch.utils.data.sampler import SubsetRandomSampler
import random

In [2]:
# Training / data-split configuration.
batch_size = 16
validation_ratio = 0.1   # presumably the fraction held out for validation by a later split cell — confirm
random_seed = 10

# Seed every RNG the notebook imports so that weight initialisation (and any
# later shuffling that relies on the global generators) is reproducible
# under Restart Kernel -> Run All.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [3]:
class depthwise_separable_conv(nn.Module):
    """Depthwise-separable 2-D convolution: a per-channel spatial conv followed by a 1x1 conv.

    Args:
        nin: number of input channels (also the number of depthwise groups).
        nout: number of output channels produced by the pointwise 1x1 conv.
        kernel_size, stride, padding: forwarded to the depthwise conv only.
        bias: whether both convolutions carry a bias term (default False).
    """

    def __init__(self, nin, nout, kernel_size, stride, padding, bias=False):
        super().__init__()
        # groups=nin -> every input channel gets its own spatial filter.
        self.depthwise = nn.Conv2d(
            nin,
            nin,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=nin,
            bias=bias,
        )
        # 1x1 conv mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

In [4]:
class _DSPPB(nn.Module):
    def __init__(self, channel):
        super(_DSPPB, self).__init__()
        
        self.conv1 = nn.Conv2d(channel, channel, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1, dilation=1)
        self.conv4 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=2, dilation=2)
        self.pool = nn.AvgPool2d(kernel_size=3, stride=3)
        
        self.conv = nn.Conv2d(5*channel, channel, 1, bias=False)
        self.bn = nn.BatchNorm2d(channel)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x4 = self.conv4(x)
        x5 = self.pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        #x = x1 + x2 + x3 + x4 + x5
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x

In [130]:
class Xception(nn.Module):
    """Xception-style encoder with a multi-scale DSPP decoder for dense prediction.

    The encoder follows Xception's entry / middle / exit flow layout, built
    from ``depthwise_separable_conv`` layers with residual shortcuts (there are
    no normalisation layers in the encoder). Five encoder stages are refined by
    ``_DSPPB`` blocks and fused coarse-to-fine with bilinear upsampling and 1x1
    convolutions. The final map goes through a sigmoid.

    Args:
        input_channel: number of channels of the input image.
        num_classes: number of output channels (one sigmoid map each).

    Input-size constraints, derived from the layer arithmetic below:
      * H and W must be divisible by 16. Four stride-2 stages are undone by
        four 2x upsamplings; otherwise ``torch.cat`` fails on a size mismatch.
      * H and W must be at least 48, because the pooling branch of
        ``_DSPPB`` uses a 3x3 / stride-3 window on the H/16 feature map.
    """

    def __init__(self, input_channel, num_classes=10):
        super(Xception, self).__init__()

        # Entry Flow
        # Multi-dilation stem: three parallel 3x3 convs (dilation 1, 2, 3, with
        # padding equal to dilation so the spatial size is kept), concatenated
        # to 48 channels in forward().
        self.conv1 = nn.Conv2d(input_channel, 16, kernel_size=3, dilation=1, padding=1)
        self.conv2 = nn.Conv2d(input_channel, 16, kernel_size=3, dilation=2, padding=2)
        self.conv3 = nn.Conv2d(input_channel, 16, kernel_size=3, dilation=3, padding=3)

        # In-place ReLU is safe here: its input is the freshly concatenated stem
        # output, which nothing else reads.
        self.entry_flow_1 = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(True)
        )

        # Stride-2 separable block, 48 -> 96 channels, with a strided 1x1 shortcut.
        self.entry_flow_2 = nn.Sequential(
            depthwise_separable_conv(48, 96, 3, 1, 1),
            nn.ReLU(True),
            depthwise_separable_conv(96, 96, 3, 1, 1),
            nn.ReLU(True),
            depthwise_separable_conv(96, 96, 3, 2, 1)
        )

        self.entry_flow_2_residual = nn.Conv2d(48, 96, kernel_size=1, stride=2, padding=0)

        # The leading ReLU of each block below must NOT be in-place. The block's
        # input tensor is also consumed by its residual shortcut and kept as a
        # decoder skip feature. With ReLU(True), the block would overwrite that
        # tensor before the shortcut reads it, because Python evaluates the left
        # operand of `+` first. Both the shortcut and the skip would then
        # silently see post-activation values.
        self.entry_flow_3 = nn.Sequential(
            nn.ReLU(),
            depthwise_separable_conv(96, 192, 3, 1, 1),
            nn.ReLU(True),
            depthwise_separable_conv(192, 192, 3, 1, 1),
            nn.ReLU(True),
            depthwise_separable_conv(192, 192, 3, 2, 1)
        )

        self.entry_flow_3_residual = nn.Conv2d(96, 192, kernel_size=1, stride=2, padding=0)

        self.entry_flow_4 = nn.Sequential(
            nn.ReLU(),
            depthwise_separable_conv(192, 768, 3, 1, 1),
            nn.ReLU(True),
            depthwise_separable_conv(768, 768, 3, 1, 1),
            nn.ReLU(True),
            depthwise_separable_conv(768, 768, 3, 2, 1)
        )

        self.entry_flow_4_residual = nn.Conv2d(192, 768, kernel_size=1, stride=2, padding=0)

        # Middle Flow (single block with an identity shortcut in forward()).
        self.middle_flow = nn.Sequential(
            nn.ReLU(),
            depthwise_separable_conv(768, 768, 3, 1, 1),

            nn.ReLU(True),
            depthwise_separable_conv(768, 768, 3, 1, 1),

            nn.ReLU(True),
            depthwise_separable_conv(768, 768, 3, 1, 1)
        )

        # Exit Flow
        self.exit_flow_1 = nn.Sequential(
            nn.ReLU(),
            depthwise_separable_conv(768, 768, 3, 1, 1),

            nn.ReLU(True),
            depthwise_separable_conv(768, 768, 3, 1, 1),

            nn.ReLU(True),
            depthwise_separable_conv(768, 768, 3, 2, 1)
        )
        self.exit_flow_1_residual = nn.Conv2d(768, 768, kernel_size=1, stride=2, padding=0)

        self.exit_flow_2 = nn.Sequential(
            nn.ReLU(True),
            depthwise_separable_conv(768, 768, 3, 1, 1),
            nn.ReLU(True),

            depthwise_separable_conv(768, 768, 3, 1, 1),
            nn.ReLU(True)
        )

        # Decoder: one DSPP block per encoder stage (channel counts match the stage).
        self.dilated_block4 = _DSPPB(768)
        self.dilated_block3 = _DSPPB(768)
        self.dilated_block2 = _DSPPB(192)
        self.dilated_block1 = _DSPPB(96)
        self.dilated_block0 = _DSPPB(48)

        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Upsample 2x and project to the next (finer) stage's channel count.
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(768, 192, 1)
        )
        self.upsample3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(192, 96, 1)
        )
        self.upsample4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(96, 48, 1)
        )
        # 4x jump from the H/4 fusion output straight to full resolution.
        self.upsample5 = nn.Sequential(
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
            nn.Conv2d(192, 48, 1)
        )

        # 1x1 fusion convs: halve the channels after concatenating skip + upsampled path.
        self.conv4 = nn.Conv2d(2 * 768, 768, 1)
        self.conv5 = nn.Conv2d(2 * 192, 192, 1)
        self.conv6 = nn.Conv2d(2 * 96, 96, 1)
        self.conv7 = nn.Conv2d(2 * 48, 48, 1)

        # Head over three full-resolution 48-channel maps (x8, x9, x10 in forward()).
        self.conv_last = nn.Conv2d(3 * 48, num_classes, 1)

    def forward(self, x):
        """Return per-pixel sigmoid outputs of shape (N, num_classes, H, W).

        Args:
            x: image batch of shape (N, input_channel, H, W); see the class
                docstring for the constraints on H and W.
        """
        # --- Encoder ----------------------------------------------------------
        # Comments give (spatial size, channels) relative to the input H x W.
        x = torch.cat([self.conv1(x), self.conv2(x), self.conv3(x)], dim=1)

        stage0 = self.entry_flow_1(x)                                               # H,    48
        stage1 = self.entry_flow_2(stage0) + self.entry_flow_2_residual(stage0)     # H/2,  96
        stage2 = self.entry_flow_3(stage1) + self.entry_flow_3_residual(stage1)     # H/4,  192
        entry_out = self.entry_flow_4(stage2) + self.entry_flow_4_residual(stage2)  # H/8,  768
        stage3 = self.middle_flow(entry_out) + entry_out                            # H/8,  768
        exit_out1 = self.exit_flow_1(stage3) + self.exit_flow_1_residual(stage3)    # H/16, 768
        stage4 = self.exit_flow_2(exit_out1)                                        # H/16, 768

        # --- Decoder: coarse-to-fine fusion of DSPP-refined skip features -----
        x = self.upsample1(self.dilated_block4(stage4))   # H/8, 768
        y = self.dilated_block3(stage3)
        x = self.conv4(torch.cat([y, x], dim=1))          # H/8, 768

        x = self.upsample2(x)                             # H/4, 192
        y = self.dilated_block2(stage2)
        x6 = self.conv5(torch.cat([y, x], dim=1))         # H/4, 192

        x = self.upsample3(x6)                            # H/2, 96
        y = self.dilated_block1(stage1)
        x7 = self.conv6(torch.cat([y, x], dim=1))         # H/2, 96

        # upsample4(x7) feeds both the stage0 fusion and the final head;
        # compute it once instead of re-running the upsample + 1x1 conv.
        x10 = self.upsample4(x7)                          # H, 48
        y = self.dilated_block0(stage0)
        x8 = self.conv7(torch.cat([y, x10], dim=1))       # H, 48

        x9 = self.upsample5(x6)                           # H, 48
        out = self.conv_last(torch.cat([x8, x9, x10], dim=1))
        return torch.sigmoid(out)

In [128]:
def summary(model):
    """Print the total and trainable parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Prints two lines and returns None.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    print(f"total param: {total}\ntrainable param: {trainable}")

In [129]:
# Smoke test: build the network, report its size, and check that a single
# 256x256 RGB image maps to a same-sized single-channel output map.
model = Xception(3, 1)
summary(model)

# One random 3-channel 256x256 image with a batch dimension of 1.
# (Named `dummy_input` to avoid shadowing the builtin `input`.)
dummy_input = torch.rand(1, 3, 256, 256)
print(f"input shape {dummy_input.shape}")

model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    output = model(dummy_input)
print(f"output shape {output.shape}")
assert output.shape == (1, 1, 256, 256), "decoder must restore the input resolution"

total param: 49016545
trainable param: 49016545
input shape torch.Size([1, 3, 256, 256])
output shape torch.Size([1, 1, 256, 256])
