# In [1]:
import torch
from torch import nn

# In [5]:
class Resnet(nn.Module):
    """Stack of residual stages preceded by a single stem convolution.

    Stage ``j`` holds ``block_counts[j]`` ``ResBlock`` instances that all emit
    ``channels[j]`` feature maps. The first block of every stage uses stride 2
    (so each stage halves the spatial resolution); the remaining blocks use
    stride 1.

    Args:
        block_counts: Number of residual blocks in each stage.
        channels: Output width of each stage; must have the same length as
            ``block_counts``. ``channels[0]`` is also the stem's output width.
        se: Forwarded to every ``ResBlock`` (``se``).
        squeeze_ratio: Forwarded to every ``ResBlock`` as ``reduction_ratio``.
        groups: Forwarded to every ``ResBlock`` (``groups``).
        d_modification: Forwarded to every ``ResBlock`` (``d_modification``).

    Raises:
        ValueError: If ``block_counts`` and ``channels`` differ in length.
    """

    def __init__(self, block_counts=(2, 2, 2, 2), channels=(64, 96, 128, 256),
                 se=True, squeeze_ratio=4, groups=1, d_modification=True):
        super().__init__()

        if len(block_counts) != len(channels):
            raise ValueError(
                f"block_counts and channels must have the same length, "
                f"got {len(block_counts)} and {len(channels)}"
            )

        # NOTE(review): the stem's kernel size and stride were left unspecified;
        # a 3x3, stride-1 convolution is assumed here -- confirm the intended stem.
        self.initial_conv = ConvBlock(kernel_size=3, in_channels=3, out_channels=channels[0])

        blocks = []
        # Width of the tensor entering the next block; starts at the stem's output.
        in_channels = channels[0]
        for count, out_channels in zip(block_counts, channels):
            for i in range(count):
                blocks.append(
                    ResBlock(in_channels=in_channels, out_channels=out_channels,
                             stride=2 if i == 0 else 1,
                             se=se, reduction_ratio=squeeze_ratio, groups=groups,
                             d_modification=d_modification)
                )
                in_channels = out_channels

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the stem convolution, then every residual block in order."""
        return self.network(self.initial_conv(x))

    
    
class ResBlock(nn.Module):
    """Residual block computing ``main_flow(x) + skip_flow(x)``.

    Main flow:
        * ``bottleneck=False``: two ``kernel_size`` convolutions; the first one
          carries ``stride`` and ``groups``.
        * ``bottleneck=True``: 1x1 reduce -> ``kernel_size`` convolution
          (carrying ``stride`` and ``groups``) -> 1x1 expand, with an inner
          width of ``out_channels // bottleneck_ratio``.
        The last convolution has no activation, and an ``SEBlock`` is appended
        when ``se`` is true.

    Skip flow:
        * Identity when neither the channel count nor the resolution changes.
        * Otherwise a 1x1 projection without activation. With
          ``d_modification`` the downsampling is done by average pooling ahead
          of a stride-1 projection; without it the projection itself is strided.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        kernel_size: Kernel size of the spatial convolutions in the main flow.
        stride: Spatial stride of the block.
        bottleneck: Select the three-convolution bottleneck layout.
        bottleneck_ratio: Divisor applied to ``out_channels`` to obtain the
            bottleneck's inner width.
        se: Append an ``SEBlock`` to the main flow.
        reduction_ratio: Passed to ``SEBlock`` as ``squeeze_ratio``.
        groups: Group count of the strided spatial convolution.
        d_modification: Use average pooling + stride-1 projection on the skip
            path instead of a strided projection.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 bottleneck=False, bottleneck_ratio=4,
                 se=True, reduction_ratio=4, groups=1, d_modification=True):
        # Must run before any submodule is assigned to ``self``.
        super().__init__()

        if bottleneck:
            mid_channels = out_channels // bottleneck_ratio
            main_flow = [
                # conv -> bn -> relu
                ConvBlock(kernel_size=1, in_channels=in_channels, out_channels=mid_channels),
                # conv -> bn -> relu
                ConvBlock(kernel_size=kernel_size, in_channels=mid_channels,
                          out_channels=mid_channels, stride=stride, groups=groups),
                # conv -> bn
                ConvBlock(kernel_size=1, in_channels=mid_channels, out_channels=out_channels,
                          post_activation=False),
            ]
        else:
            main_flow = [
                # conv -> bn -> relu
                ConvBlock(kernel_size=kernel_size, in_channels=in_channels,
                          out_channels=out_channels, stride=stride, groups=groups),
                # conv -> bn
                ConvBlock(kernel_size=kernel_size, in_channels=out_channels,
                          out_channels=out_channels, post_activation=False),
            ]

        if se:
            # The gate acts on the main flow's output, hence ``out_channels``.
            main_flow.append(SEBlock(in_channels=out_channels, squeeze_ratio=reduction_ratio))

        self.main_flow = nn.Sequential(*main_flow)

        if in_channels != out_channels or stride != 1:
            if d_modification:
                if stride != 1:
                    # ceil_mode keeps the pooled size equal to the strided
                    # convolution's output size when the input size is not a
                    # multiple of ``stride``; count_include_pad=False stops the
                    # implicit padding of the last window from skewing the mean.
                    downsample = nn.AvgPool2d(kernel_size=stride, stride=stride,
                                              ceil_mode=True, count_include_pad=False)
                else:
                    downsample = nn.Identity()
                skip_flow = nn.Sequential(
                    downsample,
                    ConvBlock(kernel_size=1, in_channels=in_channels, out_channels=out_channels,
                              post_activation=False),
                )
            else:
                skip_flow = ConvBlock(kernel_size=1, stride=stride, post_activation=False,
                                      in_channels=in_channels, out_channels=out_channels)
        else:
            skip_flow = nn.Identity()

        self.skip_flow = skip_flow

    def forward(self, x):
        """Return the sum of the main branch and the skip branch."""
        # NOTE(review): no activation is applied after the sum -- confirm
        # whether a trailing ReLU is intended.
        return self.main_flow(x) + self.skip_flow(x)

    
class ConvBlock(nn.Sequential):
    """Bias-free ``Conv2d`` -> ``BatchNorm2d`` -> optional ``ReLU``.

    Args:
        kernel_size: Side of the square convolution kernel. The convolution is
            padded by ``kernel_size // 2`` on every side.
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution (and normalised by
            the batch norm).
        stride: Stride of the convolution.
        groups: Number of convolution groups.
        post_activation: When true the block ends with ``ReLU``; otherwise an
            ``Identity`` takes its place so the block always has three children.
    """

    def __init__(self, kernel_size, in_channels, out_channels, stride=1, groups=1, post_activation=True):
        # nn.Sequential expects each module as its own positional argument;
        # handing it a single list raises a TypeError.
        super().__init__(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=kernel_size // 2, bias=False, groups=groups),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU() if post_activation else nn.Identity(),
        )


class SEBlock(nn.Sequential):
    """Squeeze-and-excitation gate that rescales every channel of its input.

    The sequential children reduce the input to one value per channel
    (global average pool + flatten), pass it through a two-layer perceptron
    (``Linear`` -> ``ReLU`` -> ``Linear``) and squash the result with a
    sigmoid. ``forward`` multiplies the input by those per-channel gates.

    Args:
        in_channels: Channels of the input tensor (and number of gates).
        squeeze_ratio: Divisor applied to ``in_channels`` to obtain the hidden
            width of the perceptron; the hidden width is never below 1.
    """

    def __init__(self, in_channels, squeeze_ratio=4):
        # Keep at least one hidden unit when in_channels < squeeze_ratio.
        hidden_features = max(1, in_channels // squeeze_ratio)
        # nn.Sequential expects each module as its own positional argument.
        super().__init__(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(in_features=in_channels, out_features=hidden_features),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=in_channels),
            nn.Sigmoid(),
        )
        # Plain attributes are set after Module.__init__ has run.
        self.in_channels = in_channels

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gates."""
        # Run the sequential children via the parent class; calling ``self(x)``
        # here would re-enter this method and recurse without end.
        gates = super().forward(x)
        # Children yield (N, C); add singleton spatial axes so the product
        # broadcasts over height and width.
        return x * gates.reshape(gates.shape[0], gates.shape[1], 1, 1)