In [10]:
import torch.nn as nn
import torch

In [7]:
class ReLUConvBN(nn.Module):
    """ReLU -> Conv2d -> BatchNorm2d block.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: convolution kernel size, forwarded to ``nn.Conv2d``.
        stride: convolution stride, forwarded to ``nn.Conv2d``.
        padding: convolution padding, forwarded to ``nn.Conv2d``.
        affine: if True, the BatchNorm layer has learnable scale and shift.

    The input tensor passed to ``forward`` is not modified.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, affine=True):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            # Not in-place: this ReLU acts directly on the caller's tensor. An
            # in-place ReLU would overwrite a tensor the caller may reuse, and
            # raises on a leaf tensor that requires grad.
            nn.ReLU(inplace=False),
            # No conv bias: it would be redundant with BatchNorm's own shift.
            nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_ch, affine=affine),
        )

    def forward(self, x):
        """Apply ReLU, convolution and batch norm to ``x``."""
        return self.op(x)
    

In [8]:
class SepConv(nn.Module):
    """Two stacked depthwise-separable convolutions.

    Each stage is ReLU -> depthwise conv (``groups=in_ch``) -> pointwise 1x1
    conv -> BatchNorm2d. Only the first depthwise conv applies ``stride``; the
    second always uses stride 1. The channel count stays ``in_ch`` after the
    first stage and becomes ``out_ch`` after the second.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: depthwise kernel size (used by both stages).
        stride: stride of the first depthwise conv.
        padding: padding of both depthwise convs.
        affine: if True, both BatchNorm layers have learnable scale and shift.

    The input tensor passed to ``forward`` is not modified.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            # Not in-place: this ReLU acts directly on the caller's tensor.
            nn.ReLU(inplace=False),
            nn.Conv2d(in_ch, in_ch, kernel_size, stride=stride, padding=padding, groups=in_ch, bias=False),
            nn.Conv2d(in_ch, in_ch, kernel_size=1, padding=0, bias=False),
            # `affine` must be a keyword: BatchNorm2d's 2nd positional parameter is `eps`.
            nn.BatchNorm2d(in_ch, affine=affine),

            # In-place is fine here: the input is an intermediate created inside
            # this block, so overwriting it cannot affect the caller.
            nn.ReLU(inplace=True),
            nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, stride=1, padding=padding, groups=in_ch, bias=False),
            nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_ch, affine=affine),
        )

    def forward(self, x):
        """Apply both separable-conv stages to ``x``."""
        return self.op(x)

In [9]:
class DilConv(nn.Module):
    """Dilated depthwise-separable convolution: ReLU -> dilated depthwise conv
    (``groups=in_ch``) -> pointwise 1x1 conv -> BatchNorm2d.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: depthwise kernel size.
        stride: depthwise stride.
        padding: depthwise padding.
        dilation: depthwise dilation.
        affine: if True, the BatchNorm layer has learnable scale and shift.

    The input tensor passed to ``forward`` is not modified.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, dilation, affine=True):
        super(DilConv, self).__init__()
        self.op = nn.Sequential(
            # Not in-place: this ReLU acts directly on the caller's tensor.
            nn.ReLU(inplace=False),
            # groups=in_ch makes this depthwise, matching SepConv; without it
            # the following 1x1 conv is not a separable factorisation.
            nn.Conv2d(in_ch, in_ch, kernel_size, stride=stride, padding=padding,
                      dilation=dilation, groups=in_ch, bias=False),
            # No bias: it would be redundant with BatchNorm's own shift.
            nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_ch, affine=affine),
        )

    def forward(self, x):
        """Apply the dilated separable convolution to ``x``."""
        return self.op(x)

In [ ]:
class FactorizedReduce(nn.Module):
    """Halve spatial resolution with two offset stride-2 1x1 convolutions.

    ``conv1`` samples even rows/cols, ``conv2`` samples odd rows/cols. Each
    produces ``out_ch // 2`` channels. Their outputs are concatenated along the
    channel dimension and batch-normalised.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels; must be even.
        affine: if True, the BatchNorm layer has learnable scale and shift.

    Raises:
        AssertionError: if ``out_ch`` is odd.

    The output spatial size is ``ceil(H/2) x ceil(W/2)``, for odd H/W as well.
    The input tensor passed to ``forward`` is not modified.
    """

    def __init__(self, in_ch, out_ch, affine=True):
        super(FactorizedReduce, self).__init__()
        assert out_ch % 2 == 0, "out_ch must be even: it is split across two convolutions"
        self.conv1 = nn.Conv2d(in_ch, out_ch // 2, kernel_size=1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(in_ch, out_ch // 2, kernel_size=1, stride=2, padding=0, bias=False)
        # Not in-place: this ReLU acts directly on the caller's tensor.
        self.relu = nn.ReLU(inplace=False)
        # `affine` must be a keyword: BatchNorm2d's 2nd positional parameter is `eps`.
        self.bn = nn.BatchNorm2d(out_ch, affine=affine)

    def forward(self, x):
        """Reduce ``x`` spatially by 2 and map it to ``out_ch`` channels."""
        x = self.relu(x)
        # Zero-pad one row/col at the bottom/right before dropping the first
        # row/col. The shifted tensor then keeps x's size, so both branches
        # agree even for odd H/W. For even sizes the padded row/col is never
        # sampled by the stride-2 conv, so results are unchanged.
        shifted = nn.functional.pad(x, (0, 1, 0, 1))[:, :, 1:, 1:]
        out = torch.cat([self.conv1(x), self.conv2(shifted)], dim=1)
        return self.bn(out)