In [1]:
import torch
from torch import nn


bn_eps = 0.0001       # BatchNorm2d epsilon used by every Conv block
bn_momentum = 0.03    # BatchNorm2d running-stats momentum used by every Conv block


class Conv(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and an activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size (assumed to be an int, since
            the default padding uses ``kernel_size // 2``).
        stride: convolution stride (default 1).
        padding: explicit padding; ``None`` uses ``kernel_size // 2``, which
            preserves spatial size for odd kernels at stride 1.
        acti: ``"relu"``, ``"leaky"`` (default), or ``None`` for no activation.

    Raises:
        ValueError: if ``acti`` is not one of the supported values.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=None, acti="leaky"):
        super().__init__()
        if padding is None:
            padding = kernel_size // 2

        # No conv bias: the following BatchNorm's shift makes it redundant.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, eps=bn_eps, momentum=bn_momentum)

        if acti == "relu":
            self.acti = nn.ReLU(inplace=True)
        elif acti == "leaky":
            # 0.1015625 == 13/128, an exactly representable binary fraction close
            # to 0.1 (presumably chosen for fixed-point/quantized deployment -- TODO confirm).
            self.acti = nn.LeakyReLU(0.1015625, inplace=True)
        elif acti is None:
            self.acti = nn.Identity()
        else:
            # Fail fast here; otherwise self.acti would be missing and forward()
            # would only fail later with an AttributeError.
            raise ValueError(
                f"unsupported activation {acti!r}; expected 'relu', 'leaky' or None"
            )

        # Set to True (together with ``fused_conv``) by code outside this class.
        self.fused = False

    def forward(self, x):
        """Apply conv -> BN -> activation (or the fused conv in eval mode)."""
        # NOTE(review): ``fused_conv`` is never created in this class; the fused
        # path presumably relies on external code that folds BN into the conv and
        # assigns ``fused_conv`` -- TODO confirm.
        if self.fused and not self.training:
            return self.acti(self.fused_conv[0](x))
        return self.acti(self.bn(self.conv(x)))

class Focus(nn.Module):
    """Space-to-depth rearrangement by strided slicing (no learnable weights).

    Splits every 2x2 spatial patch into four sub-images and stacks them along
    the channel axis. A batched input of shape ``(N, C, H, W)`` becomes
    ``(N, 4*C, H/2, W/2)``; unbatched ``(C, H, W)`` input becomes
    ``(4*C, H/2, W/2)``.

    Channel blocks are ordered by the (row, col) offset inside each patch:
    (0, 0), (1, 0), (0, 1), (1, 1). H and W are expected to be even; with an
    odd size the four slices differ in shape and ``torch.cat`` raises.
    """

    def forward(self, x):
        # Concatenate on dim -3, which is the channel axis for both batched and
        # unbatched input (dim 1 would be the height axis for 3-D input).
        return torch.cat(
            (
                x[..., ::2, ::2],    # even rows, even cols
                x[..., 1::2, ::2],   # odd rows,  even cols
                x[..., ::2, 1::2],   # even rows, odd cols
                x[..., 1::2, 1::2],  # odd rows,  odd cols
            ),
            dim=-3,
        )

In [2]:
class SplitSpatial(nn.Module):
    """Space-to-depth built from fixed (non-trainable) 2x2, stride-2 convolutions.

    Produces the same rearrangement as slicing each 2x2 patch into four
    sub-images and stacking them on the channel axis:
    ``(N, C, H, W) -> (N, 4*C, H/2, W/2)`` (unbatched ``(C, H, W)`` also works).
    ``conv1``..``conv4`` each copy one pixel of every patch, at (row, col)
    offsets (0, 0), (1, 0), (0, 1), (1, 1) respectively.

    Caveats versus pure slicing:
      * Each conv is a dense ``in_ch x in_ch`` conv whose off-diagonal weights
        are zero, so an inf/NaN input value turns into NaN in every output
        channel of that position (``0 * inf == NaN``).
      * On CUDA with TF32 convolutions enabled, results may not be bit-exact.

    Args:
        in_ch: number of input channels.
    """

    # (row, col) offset inside each 2x2 patch selected by conv1..conv4.
    _OFFSETS = ((0, 0), (1, 0), (0, 1), (1, 1))

    def __init__(self, in_ch):
        super().__init__()
        eye = torch.eye(in_ch)
        for idx, (row, col) in enumerate(self._OFFSETS, start=1):
            conv = nn.Conv2d(in_ch, in_ch, kernel_size=2, stride=2, bias=False).requires_grad_(False)
            with torch.no_grad():
                # Output channel i copies input channel i at (row, col); all other taps are 0.
                conv.weight.zero_()
                conv.weight[:, :, row, col] = eye
            # setattr on nn.Module registers the submodule under conv1..conv4,
            # keeping the same attribute names and state_dict keys.
            setattr(self, f"conv{idx}", conv)

    def forward(self, x):
        # dim -3 is the channel axis for both batched and unbatched input.
        return torch.cat((self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)), dim=-3)
    

In [3]:
# Check that the fixed-weight SplitSpatial reproduces Focus's slicing exactly.
torch.manual_seed(42)
tens = torch.rand(1, 3, 320, 320)
default = Focus()
own = SplitSpatial(in_ch=3)

def_out = default(tens)
own_out = own(tens)

# Fail loudly on Restart & Run All instead of relying on eyeballing the output.
assert def_out.shape == own_out.shape == (1, 12, 160, 160), (def_out.shape, own_out.shape)
assert torch.equal(def_out, own_out), "SplitSpatial output differs from Focus"

def_out.shape, own_out.shape, torch.equal(def_out, own_out)

(torch.Size([1, 12, 160, 160]), torch.Size([1, 12, 160, 160]), True)