In [1]:
import torch
from torch import nn
from torchvision import models

In [2]:
class SeparaConv(nn.Module):
    """Depthwise-separable 2-D convolution.

    A 3x3 depthwise convolution (one filter per input channel, ``groups=in_channels``)
    is followed by a 1x1 pointwise convolution that mixes channels. Both steps use
    stride 1, and the depthwise step pads by 1, so the spatial size is preserved.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the pointwise step.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Per-channel spatial filtering; padding=1 with a 3x3 kernel keeps H and W.
        self.depth = nn.Conv2d(
            in_channels, in_channels, 3, stride=1, padding=1, groups=in_channels
        )
        # Channel mixing with a 1x1 kernel (no spatial change).
        self.point = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.point(self.depth(x))
    
class SeparaBlock(nn.Module):
    """Residual block of three stages plus a shortcut.

    Main path:
      * ``layer1``: ``SeparaConv(in_channels -> out_channels)``, preceded by a ReLU
        unless ``first_layer`` is True.
      * ``layer2``: ReLU + ``SeparaConv(out_channels -> out_channels)``.
      * ``layer3``: a 3x3 stride-2 max-pool when ``downsample`` is True, otherwise
        ReLU + another ``SeparaConv(out_channels -> out_channels)``.

    Shortcut (stored in ``self.size``; empty means identity):
      * ``downsample=True``: 1x1 stride-2 convolution ``in_channels -> out_channels``.
        Its output spatial size matches the padded 3x3/stride-2 max-pool.
      * ``downsample=False`` and ``in_channels != out_channels``: 1x1 stride-1
        projection, so the residual sum has matching channel counts. Previously this
        configuration crashed at forward time.
      * otherwise: identity.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        downsample: halve the spatial size (rounding up) with max-pool + strided shortcut.
        first_layer: omit the leading ReLU of ``layer1``, e.g. when the previous
            module already ends in a ReLU.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 downsample=True,
                 first_layer=False):
        super().__init__()
        if first_layer:
            self.layer1 = SeparaConv(in_channels, out_channels)
        else:
            # Must stay non-inplace: forward() reuses the untouched input `x` as the residual.
            self.layer1 = nn.Sequential(nn.ReLU(),
                                        SeparaConv(in_channels, out_channels))

        self.layer2 = nn.Sequential(nn.ReLU(),
                                    SeparaConv(out_channels, out_channels))
        # Shortcut projection(s). The attribute name is kept as `size` so state_dict
        # keys ("size.0.weight", ...) stay compatible.
        self.size = nn.ModuleList()
        if downsample:
            self.layer3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.size.append(nn.Conv2d(in_channels, out_channels,
                                       kernel_size=1, padding=0, stride=2))
        else:
            self.layer3 = nn.Sequential(nn.ReLU(),
                                        SeparaConv(out_channels, out_channels))
            if in_channels != out_channels:
                # An identity shortcut would have the wrong channel count; project it.
                self.size.append(nn.Conv2d(in_channels, out_channels,
                                           kernel_size=1, padding=0, stride=1))

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        residual = x
        for layer in self.size:
            residual = layer(residual)
        return out + residual

In [3]:
class Entry(nn.Module):
    """Entry flow: a convolutional stem followed by three downsampling residual blocks.

    Channel progression: 3 -> 32 -> 64 (stem) -> 128 -> 256 -> 728.
    Each ``SeparaBlock`` here uses the default ``downsample=True`` (stride 2).
    For example, a ``(N, 3, 299, 299)`` input gives a ``(N, 728, 19, 19)`` output.

    The previously defined but never used ``self.conv`` layer has been removed. It
    added dead parameters to the model and optimizer, and tripped DDP's
    unused-parameter check.
    """
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # No padding: each spatial dimension shrinks by 2.
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
        )
        # The stem already ends in a ReLU, so block1 skips its leading ReLU.
        self.block1 = SeparaBlock(64, 128, first_layer=True)
        self.block2 = SeparaBlock(128, 256)
        self.block3 = SeparaBlock(256, 728)

    def forward(self, x):
        out = self.stem(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        return out

In [6]:
class MiddleBlock(nn.Module):
    """Middle flow: a stack of ``nums`` non-downsampling residual blocks.

    Each block is ``SeparaBlock(in_channels, in_channels, downsample=False)``.
    That block keeps both the channel count and the spatial size, so the output
    has the same shape as the input.

    Args:
        in_channels: channel count of the incoming feature map.
        nums: number of residual blocks to stack (0 makes this an identity).
    """
    def __init__(self, in_channels, nums):
        super().__init__()
        # Build from `in_channels` instead of a hard-coded width, which used to
        # silently ignore the argument.
        self.blocks = nn.ModuleList(
            [SeparaBlock(in_channels, in_channels, downsample=False)
             for _ in range(nums)]
        )

    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        return out
class ExitBlock(nn.Module):
    """Exit flow and classifier head.

    ``block1`` is a downsampling ``SeparaBlock`` (728 -> 1024 channels).
    ``block2`` widens the features to 1536 and then 2048 channels with separable
    convolutions, global-average-pools them to 1x1, and projects them to
    ``n_classes`` with a 1x1 convolution.

    For a batched 4-D input the result has shape ``(N, n_classes, 1, 1)``. It is
    not flattened, so callers that need ``(N, n_classes)`` logits must flatten it.

    Args:
        n_classes: number of output channels of the final 1x1 convolution.
    """
    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        self.block1 = SeparaBlock(728, 1024)

        # Keep this exact order: Sequential indices define state_dict keys
        # (block2.0.*, block2.2.*, block2.5.*).
        # In-place ReLU is safe here: each one acts on a fresh conv output.
        head_layers = [
            SeparaConv(1024, 1536),
            nn.ReLU(inplace=True),
            SeparaConv(1536, 2048),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(2048, n_classes, kernel_size=1, padding=0, stride=1),
        ]
        self.block2 = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.block2(self.block1(x))
class MyXcep(nn.Module):
    """Xception-like classifier: entry flow -> middle flow -> exit flow.

    The output is whatever ``ExitBlock`` returns: a ``(N, n_classes, 1, 1)``
    tensor for a batched 4-D input.

    Args:
        n_classes: number of output classes, forwarded to ``ExitBlock``.
    """

    def __init__(self, n_classes):
        super().__init__()
        # Construction order (entry, mid, ext) fixes parameter-init order and the
        # state_dict layout; keep it.
        self.entry = Entry()
        self.mid = MiddleBlock(in_channels=728, nums=8)  # eight 728-channel residual blocks
        self.ext = ExitBlock(n_classes)

    def forward(self, x):
        features = self.entry(x)
        for stage in (self.mid, self.ext):
            features = stage(features)
        return features

In [7]:
# Seed first so the randomly initialised weights in the exported file are
# reproducible across Restart-&-Run-All.
torch.manual_seed(0)
test = MyXcep(1000)
# Dummy input used only to trace the graph for export.
x = torch.randn(1, 3, 299, 299)
test.eval()
# Sanity-check the forward pass before exporting. The head is a 1x1 conv after
# global pooling, so the output stays 4-D.
with torch.no_grad():
    out = test(x)
assert out.shape == (1, 1000, 1, 1), out.shape
torch.onnx.export(test, x, 'myxc.onnx')
