In [94]:
from nn.blocks.residuals import ResidualAdd
import torch.nn as nn
from collections import OrderedDict
import torch
from torch import Tensor

In [48]:
class Residual(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Identity()
        self.shortcut = nn.Identity()
        self.res_func = None
    def forward(self, x: Tensor) -> Tensor:
        res = x
        if self.shortcut is not None:
            res = self.shortcut(res)
        x = self.block(x)
        if self.res_func is not None:
            x = self.res_func(x, res)
        return x

In [51]:

class ResNetBasicBlock(Residual):
    def __init__(self, in_features, out_features, activation: nn.Module = nn.ReLU()):
        super().__init__()
        self.block = nn.Sequential(OrderedDict(
            {'conv1': nn.Conv2d(in_features, out_features, kernel_size=3, stride=2),
             'bn1': nn.BatchNorm2d(out_features),
             'act1': activation,
             'conv1': nn.Conv2d(out_features, out_features, kernel_size=3),
             'bn1': nn.BatchNorm2d(out_features)}))
        
        self.shortcut = nn.Sequential(OrderedDict({
            'conv': nn.Conv2d(in_features, out_features, kernel_size=1, stride=2, bias=False),
            'bn': nn.BatchNorm2d(out_features)
        }))


In [165]:

class ResNetBasicBlock(nn.Module):
    def __init__(self, in_features, out_features, activation: nn.Module = nn.ReLU(inplace=True)):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.block = ResidualAdd(
            nn.Sequential(OrderedDict(
            {'conv1': nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1, bias=False),
             'bn1': nn.BatchNorm2d(out_features),
             'act1': activation,
             'conv2': nn.Conv2d(out_features, out_features, kernel_size=3, padding=1, bias=False),
             'bn2': nn.BatchNorm2d(out_features),
            })),
             shortcut=nn.Sequential(OrderedDict({
            'conv': nn.Conv2d(in_features, out_features, kernel_size=1, stride=2, bias=False),
            'bn': nn.BatchNorm2d(out_features) if self.should_apply_shortcut else None
        }))
        )
        self.act = activation
        
    def forward(self, x):
        x = self.block(x)
        return x
    
    @property
    def should_apply_shortcut(self):
        return self.in_features != self.out_features


In [166]:
class SEModule(nn.Module):
    def __init__(self, block: nn.Module, features: int, reduction=16):
        super().__init__()
        self.block = block
        self.se = SELayer(features)

    def forward(self, x):
        x = self.block(x)
        x = self.se(x)
        return x

In [167]:
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

In [168]:
res = ResNetBasicBlock(32, 64)
se_res = SEModule(res, features=64)

In [160]:
se_res(torch.rand(1, 32,40,40))

tensor([[[[0.0000, 0.0000, 0.6738,  ..., 0.0000, 0.3328, 0.0000],
          [0.0000, 0.0889, 0.0000,  ..., 0.0000, 0.0000, 0.5379],
          [0.0000, 0.1901, 0.3449,  ..., 0.0000, 0.6459, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.7397, 0.0000, 0.0000],
          [0.0000, 0.0000, 1.1589,  ..., 0.0000, 0.0000, 0.0182],
          [0.0411, 0.5858, 0.0000,  ..., 0.4667, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.7064, 0.0000, 0.0000],
          [0.0000, 1.2824, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.3845,  ..., 0.4785, 0.4482, 0.0000],
          ...,
          [0.0000, 0.6119, 1.6757,  ..., 0.0000, 0.0000, 0.3173],
          [0.0000, 0.0000, 0.0000,  ..., 0.0094, 0.0000, 1.1220],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0109, 0.0000]],

         [[0.0000, 0.8985, 0.0000,  ..., 0.5922, 0.0000, 0.5688],
          [0.1740, 0.7287, 0.0000,  ..., 0.0000, 0.4529, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0