In [23]:
import torch.nn as nn
import torch

In [55]:
class HSBlock(nn.Module):
    """Hierarchical-split block.

    The input's ``w * split`` channels are cut into ``split`` groups of ``w``
    channels each:

    * group 0 is passed straight to the output;
    * group 1 is halved -- the first half goes to the output, the second half
      is carried forward;
    * every later group is concatenated with the carried tensor, run through
      Conv3x3 -> BatchNorm -> ReLU (channel count preserved), and halved again
      the same way;
    * the final carried tensor is appended last.

    Because the convs preserve channels and every split partitions its input,
    the output has the same ``w * split`` channels as the input.

    Args:
        w: channels per group (>= 1).
        split: number of groups (>= 1). With ``split == 1`` the block is an
            identity on its (possibly downsampled) input.
        stride: spatial stride (>= 1). When > 1 the input is downsampled once
            with a 3x3 average pool (padding 1) before splitting, so all groups
            and the output share one resolution, ``ceil(H / stride)``.

    Raises:
        ValueError: if ``w``, ``split`` or ``stride`` is < 1.
    """

    def __init__(self, w:int, split:int, stride:int=1) -> None:
        super(HSBlock, self).__init__()
        if w < 1 or split < 1 or stride < 1:
            raise ValueError(f'w({w}), split({split}) and stride({stride}) must all be >= 1')
        self.w = w
        self.split = split
        self.stride = stride
        # Downsample once up front: striding only the convolved groups would
        # leave pass-through groups at full resolution and break torch.cat.
        self.pool = (nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
                     if stride > 1 else nn.Identity())
        # Build the convs here (not in forward) so their weights are
        # registered parameters: trained, saved, and moved with .to(device).
        self.convs = nn.ModuleList()
        carry = w - w // 2  # channels carried out of group 1 (upper half)
        for _ in range(2, split):
            ch = carry + w  # carried half + the next raw group
            self.convs.append(nn.Sequential(
                nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(ch),
                nn.ReLU(inplace=True),
            ))
            carry = ch - ch // 2

    def forward(self, x):
        """Apply the hierarchical split; output channels equal input channels."""
        channels = x.shape[1]
        assert channels == self.w*self.split, f'input channels({channels}) is not equal to w({self.w})*split({self.split})'
        x = self.pool(x)
        groups = torch.split(x, self.w, dim=1)
        # Local list: an instance-level list would accumulate across calls.
        outputs = [groups[0]]
        carry = None
        for s in range(1, self.split):
            if carry is None:
                part = groups[s]
            else:
                part = self.convs[s - 2](torch.cat([carry, groups[s]], dim=1))
            head, carry = self._split(part)
            outputs.append(head)
        if carry is not None:  # None only when split == 1
            outputs.append(carry)
        return torch.cat(outputs, dim=1)

    def _split(self, x):
        """Split ``x`` along channels into (floor half, ceil half)."""
        channels = x.shape[1] // 2
        return x[:, 0:channels, :, :], x[:, channels:, :, :]

class BottleNeck(nn.Module):
    """Residual bottleneck: 1x1 conv -> HSBlock -> 1x1 conv, plus a shortcut.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        split: number of HSBlock groups. The group width is
            ``w = max(2 ** (split - 2), 1)``, so the hidden width is ``w * split``.
        stride: spatial stride of the whole block. It is applied exactly once,
            in the first 1x1 conv, mirroring the single-strided shortcut so the
            residual and shortcut shapes match at the addition.
    """

    def __init__(self, in_channels:int, out_channels:int, split:int, stride:int=1) -> None:
        super(BottleNeck, self).__init__()
        self.w = max(2**(split-2), 1)
        hidden = self.w * split
        self.residual_function = nn.Sequential(
            # The only strided layer on the residual path; striding again in
            # HSBlock or the last 1x1 conv would over-downsample vs. the shortcut.
            nn.Conv2d(in_channels, hidden, kernel_size=1, stride=stride),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            HSBlock(self.w, split),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels)
        )
        # Identity shortcut unless resolution or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        return torch.relu(self.residual_function(x) + self.shortcut(x))

In [56]:
# Smoke test: a stride-1 bottleneck with equal in/out channels must preserve
# the input shape. Seed first so the random weight init is reproducible.
torch.manual_seed(0)
x = torch.ones(1, 3, 28, 28)
bn = BottleNeck(3, 3, 5)
y = bn(x)
assert y.shape == x.shape, f'unexpected output shape {tuple(y.shape)}'
y.shape

torch.Size([1, 3, 28, 28])