# 1. Blocks

## 1.1. ChannelSeBlock

In [1]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class ChannelSeBlock(nn.Module):
    def __init__(
        self,
        in_channels, squeeze_channels,
        activation_layer=nn.ReLU,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,
    ):
        """Channel-wise Squeeze-And-Excitation gate.

        The input is pooled to a single position, passed through two
        1-sized convolutions (in_channels -> squeeze_channels -> in_channels)
        and a sigmoid; the result rescales the input channel by channel.

        Args:
            in_channels (int): The number of input channels.
            squeeze_channels (int): The number of channels to squeeze to.
            activation_layer (callable, optional): Activation handed to the
                first (squeezing) convolution. Defaults to nn.ReLU.
            conv_block (callable, optional): Factory used for both convolutions.
                It is called as ``conv_block(in, out, 1, activation_layer=...,
                norm_layer=None, bias=False)``.
            pool_block (callable, optional): Factory for the pooling layer;
                it is called with the target output size ``1``.
        """
        super().__init__()

        # Neither convolution uses a norm layer or a bias term.
        plain_conv = dict(norm_layer=None, bias=False)

        self.pool = pool_block(1)
        self.conv1 = conv_block(
            in_channels, squeeze_channels, 1,
            activation_layer=activation_layer,
            **plain_conv,
        )
        self.conv2 = conv_block(
            squeeze_channels, in_channels, 1,
            activation_layer=None,
            **plain_conv,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze: (..., in_channels, H, W) -> (..., in_channels, 1, 1)
        squeezed = self.pool(x)

        # Excite: -> (..., squeeze_channels, 1, 1) -> (..., in_channels, 1, 1)
        excited = self.conv2(self.conv1(squeezed))

        # Per-channel gate in (0, 1), broadcast over the spatial positions.
        gate = self.sigmoid(excited)
        return x * gate

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = ChannelSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

## 1.2. SpatialSeBlock

In [2]:
import torch
import torch.nn as nn

from torchinfo import summary

class SpatialSeBlock(nn.Module):
    def __init__(
        self,
        in_channels, squeeze_channels,
        conv_block=ops.Conv2dNormActivation,
        **kwargs,
    ):
        """Spatial Squeeze-And-Excitation gate.

        A single 1-sized convolution collapses the channels into one map,
        which after a sigmoid rescales every spatial position of the input.

        Args:
            in_channels (int): The number of input channels.
            squeeze_channels (int): Accepted so the signature matches the
                other SE blocks; not used by this block.
            conv_block (callable, optional): Factory for the convolution.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # NOTE(review): bias=None (not False) is passed here. With the default
        # torchvision Conv2dNormActivation this presumably enables the bias
        # because norm_layer is None -- confirm this is intended.
        self.conv = conv_block(
            in_channels, 1, 1,
            activation_layer=None, norm_layer=None, bias=None,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (..., in_channels, H, W) -> (..., 1, H, W)
        attention = self.sigmoid(self.conv(x))

        # The single-channel map broadcasts across all input channels.
        return x * attention

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = SpatialSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

## 1.3. SpatialChannelSeBlock

In [3]:
import torch
import torch.nn as nn

from torchinfo import summary

class SpatialChannelSeBlock(nn.Module):
    def __init__(self, in_channels, squeeze_channels, **kwargs):
        """Squeeze-And-Excitation block that sums a spatial and a channel gate.

        Args:
            in_channels (int): The number of input channels.
            squeeze_channels (int): The number of channels to squeeze to.
            **kwargs: Forwarded unchanged to both SpatialSeBlock and
                ChannelSeBlock.
        """
        super().__init__()

        # Both branches are built from exactly the same arguments.
        branch_args = dict(
            in_channels=in_channels,
            squeeze_channels=squeeze_channels,
            **kwargs,
        )
        self.spatial_se = SpatialSeBlock(**branch_args)
        self.channel_se = ChannelSeBlock(**branch_args)

    def forward(self, x):
        # Each branch rescales the same input; the results are added.
        return self.spatial_se(x) + self.channel_se(x)

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = SpatialChannelSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

## 1.4. MBConv

In [4]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class MBConv(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        bottleneck=4, kernel_size=3, stride=1, padding='same', squeeze_ratio=4,
        activation_layer=nn.SiLU,
        se_block=ChannelSeBlock,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,
        **kwargs
    ):
        """MBConv block: expand -> depthwise conv -> SE -> project.

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            bottleneck (int, optional): Expansion factor; the hidden width is
                ``int(in_channels * bottleneck)``. Defaults to 4.
            kernel_size (int, optional): The kernel for the middle convolution. Defaults to 3.
            stride (int, optional): The stride for the middle convolution. Defaults to 1.
            padding (str, optional): The padding for the middle convolution. Defaults to 'same'.
            squeeze_ratio (int, optional): The SE block squeezes to
                ``max(1, in_channels // squeeze_ratio)`` channels. Defaults to 4.
            activation_layer (callable, optional): Activation for the expand
                and depthwise convolutions and for the SE block.
            se_block (callable, optional): Squeeze-and-excitation factory.
            conv_block (callable, optional): Convolution factory, also handed
                to ``se_block``.
            pool_block (callable, optional): Pooling factory handed to ``se_block``.
            **kwargs: Forwarded to every ``conv_block`` call made directly by
                this block (not to ``se_block``).
        """
        super().__init__()

        # The shortcut is only added when the block keeps channels and stride.
        self.residual = (stride == 1 and in_channels == out_channels)

        hidden_channels = int(in_channels*bottleneck)
        layers = []

        # 1. Expansion, skipped when it would not change the channel count.
        #    (..., in_channels, ...) -> (..., hidden_channels, ...)
        if hidden_channels != in_channels:
            expand = conv_block(
                in_channels, hidden_channels,
                kernel_size=1, stride=1, padding='same',
                activation_layer=activation_layer,
                **kwargs
            )
            layers.append(expand)

        # 2. Depthwise convolution: one group per channel.
        depthwise = conv_block(
            hidden_channels, hidden_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            groups=hidden_channels,
            activation_layer=activation_layer,
            **kwargs
        )
        layers.append(depthwise)

        # 3. Squeeze and excitation; its width derives from in_channels,
        #    not from the expanded width.
        squeeze_channels = max(1, in_channels // squeeze_ratio)
        excitation = se_block(
            hidden_channels, squeeze_channels,
            activation_layer=activation_layer,
            conv_block=conv_block,
            pool_block=pool_block,
        )
        layers.append(excitation)

        # 4. Projection (..., hidden_channels, ...) -> (..., out_channels, ...)
        #    with nn.Identity in place of an activation.
        project = conv_block(
            hidden_channels, out_channels,
            kernel_size=1, stride=1, padding='same',
            activation_layer=nn.Identity,
            **kwargs
        )
        layers.append(project)

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out + x if self.residual else out

# in_channels, out_channels = 32, 4
# kernel_size = 3
# batch_size = 1
# l = MBConv(
#     in_channels, out_channels,
#     kernel_size=kernel_size, bottleneck=1,
# )
# x = torch.randn(batch_size, in_channels, 20, 20)
# out = l(x)
# x.shape, out.shape

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

# 2. FinNet

## 2.1. FinNetConfig

In [5]:
import math
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class FinNetConfig(object):
    """Hyper-parameters and block factory for FinNet.

    Holds the width/depth multipliers, the building-block factories and the
    table of MBConv stages, and builds the per-fin stacks plus the shared
    extension from that table.
    """

    def __init__(
        self,
        # Efficient B4 configuration
        width_mult=1.4, depth_mult=1.8, dropout=0.4, last_channels=1280,

        # Block configuration
        kernel_1=3, kernel_2=5,

        # Blocks
        se_block=ChannelSeBlock,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,

        # Fin configuration
        fins=4,
        fin_depth=4,
        fin_output=False,
    ):
        """Store the configuration and build the MBConv stage table.

        Args:
            width_mult (float, optional): Multiplier applied to channel counts.
            depth_mult (float, optional): Multiplier applied to layer counts.
            dropout (float, optional): Stored as-is; not used by this class.
            last_channels (int, optional): Stored as-is; not used by this class.
            kernel_1 (int, optional): Kernel size of stages 1 and 7.
            kernel_2 (int, optional): Kernel size of stage 5.
            se_block (callable, optional): Squeeze-and-excitation factory.
            conv_block (callable, optional): Convolution factory.
            pool_block (callable, optional): Pooling factory.
            fins (int, optional): Number of parallel fin stacks.
            fin_depth (int, optional): Number of leading stages placed in each
                fin; the remaining stages form the extension.
            fin_output (bool, optional): Stored as-is for the model to read.
        """
        self.width_mult = width_mult
        self.depth_mult = depth_mult

        self.dropout = dropout
        self.last_channels = last_channels

        # Conv type 1
        self.kernel_1 = kernel_1

        # Conv type 2
        self.kernel_2 = kernel_2

        # Blocks
        self.se_block = se_block
        self.conv_block = conv_block
        self.pool_block = pool_block

        # Fin configuration
        self.fins = fins
        self.fin_depth = fin_depth
        self.fin_output = fin_output

        # MBConv configs
        self.block_configs = [
            # (in_channels, out_channels, bottleneck, squeeze_channels, kernel, padding, stride, layers)
            (32, 16, 1, None, kernel_1, 'same', 1, 1),
            (16, 24, 6, None, 3, 1, 2, 2),
            (24, 40, 6, None, 5, 2, 2, 2),
            (40, 80, 6, None, 3, 1, 2, 3),
            (80, 112, 6, None, kernel_2, 'same', 1, 3),
            (112, 192, 6, None, 5, 2, 2, 4),
            (192, 320, 6, None, kernel_1, 'same', 1, 1),
        ]

    def adjust_channels(self, channels):
        """Scale ``channels`` by ``width_mult`` and round to a multiple of 8."""
        return self.round_to(channels*self.width_mult)

    def adjust_depth(self, num_layers):
        """Scale ``num_layers`` by ``depth_mult``, rounding up."""
        return int(math.ceil(num_layers*self.depth_mult))

    @staticmethod
    def round_to(v, multiple=8):
        """Round ``v`` to the nearest multiple of ``multiple``, at least one multiple.

        The lower bound keeps very small widths from rounding to 0, which
        would produce layers with zero channels.
        """
        return max(multiple, int(multiple * round(v / multiple)))

    @staticmethod
    def get_squeeze_ratio(channels, squeeze_channels, default=4):
        """Return the ratio that squeezes ``channels`` to ``squeeze_channels``.

        Falls back to ``default`` when ``squeeze_channels`` is falsy. The ratio
        is floored at 1 so that a squeeze width larger than ``channels`` cannot
        yield 0, which MBConv would then divide by.
        """
        if not squeeze_channels:
            return default
        return max(1, channels // squeeze_channels)

    def _block(self, args, **kwargs):
        """Build one stage (a Sequential of MBConv layers) from a config tuple.

        Args:
            args (tuple): One entry of ``block_configs``.
            **kwargs: Forwarded to every MBConv.
        """
        in_channels, out_channels, bottleneck, squeeze_channels, kernel, padding, stride, layers = args

        # 1. Update in_channels and out_channels based on the width_mult
        in_channels = self.adjust_channels(in_channels)
        out_channels = self.adjust_channels(out_channels)

        # 2. Update layers based on depth_mult
        layers = self.adjust_depth(layers)

        def make_layer(layer_in, layer_stride, layer_padding):
            return MBConv(
                layer_in, out_channels,
                bottleneck=bottleneck,
                squeeze_ratio=self.get_squeeze_ratio(layer_in, squeeze_channels),
                kernel_size=kernel, stride=layer_stride, padding=layer_padding,
                se_block=self.se_block,
                conv_block=self.conv_block,
                pool_block=self.pool_block,
                **kwargs
            )

        # 3. Only the first layer changes channels and applies the configured
        #    stride/padding; the rest keep the shape (stride 1, 'same').
        return nn.Sequential(
            make_layer(in_channels, stride, padding),
            *(make_layer(out_channels, 1, 'same') for _ in range(layers - 1)),
        )

    def make_blocks(self, **kwargs):
        """Build the fin stacks and the shared extension.

        Args:
            **kwargs: Forwarded to every MBConv via ``_block``.

        Returns:
            tuple: ``(fin_modules, extension)`` where ``fin_modules`` is an
            nn.ModuleList of ``fins`` Sequentials built from the first
            ``fin_depth`` stage configs, and ``extension`` is a Sequential
            built from the remaining ones.
        """
        fin_configs = self.block_configs[:self.fin_depth]
        extension_configs = self.block_configs[self.fin_depth:]

        # 1. Create 'fins' independent fin modules up to fin_depth.
        fin_modules = nn.ModuleList(
            nn.Sequential(*(self._block(cfg, **kwargs) for cfg in fin_configs))
            for _ in range(self.fins)
        )

        # 2. Create fin extension
        extension = nn.Sequential(
            *(self._block(cfg, **kwargs) for cfg in extension_configs)
        )
        return fin_modules, extension

# fin_config = FinNetConfig(
#     # se_block=ChannelSeBlock,
#     # conv_block=Conv1dNormActivation,
#     # pool_block=nn.AdaptiveAvgPool1d
    
# )
# block_config = fin_config.block_configs[0]
# fin_modules, extension = fin_config.make_blocks()

# x = torch.randn(1, 48, 224, 224)
# print(summary(
#     model=fin_modules[0], 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"],
# ))

## 2.2. FinNet

In [10]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class FinNet(nn.Module):
    def __init__(
        self,
        in_channels, out_channels, config,
        activation_layer=nn.SiLU,
        **kwargs
    ):
        """FinNet: parallel "fin" stacks whose averaged output feeds a shared head.

        The input channels are split into ``config.fins`` groups. Each group
        passes through its own initial convolution and fin stack; the fin
        outputs are averaged, run through the shared extension, pooled and
        classified by a linear layer.

        Args:
            in_channels (int): Total number of input channels. Each fin
                receives ``in_channels // config.fins`` of them.
                NOTE(review): if ``in_channels`` is not a multiple of
                ``config.fins`` the trailing channels appear to be dropped
                silently in ``forward`` -- confirm callers never rely on that.
            out_channels (int): The number of outputs of the classifier.
            config: Configuration object providing ``fins``, ``fin_output``,
                ``conv_block``, ``pool_block``, ``adjust_channels`` and
                ``make_blocks`` (e.g. a FinNetConfig).
            activation_layer (callable, optional): Activation for the initial
                convolutions. Defaults to nn.SiLU.
            **kwargs: Forwarded to ``config.conv_block`` for the initial
                convolutions.
        """
        super().__init__()
        self.in_channels = in_channels // config.fins
        self.config = config

        in_mb_channels = config.adjust_channels(32)
        out_mb_channels = config.adjust_channels(320)

        self.fins, self.extension = config.make_blocks()
        self.fin_init_convs = nn.ModuleList(
            config.conv_block(
                self.in_channels, in_mb_channels,
                kernel_size=3, padding=1, stride=2,
                activation_layer=activation_layer,
                **kwargs,
            )
            for _ in range(config.fins)
        )

        self.pool = config.pool_block(1)
        self.classifier = nn.Linear(out_mb_channels, out_channels)

    def forward(self, x):
        """Run the network.

        Returns:
            The classifier output, or the tuple ``(output, fin_outputs)`` when
            ``config.fin_output`` is truthy, where ``fin_outputs`` is the list
            of per-fin feature maps taken before averaging.
        """
        # 1. Split x into fins
        xs = torch.split(x, self.in_channels, dim=1)

        # 2. Apply top-level conv
        xs = [conv(chunk) for chunk, conv in zip(xs, self.fin_init_convs)]

        # 3. Run fin modules
        xs = [fin(feat) for feat, fin in zip(xs, self.fins)]

        # 4. Combine fin outputs
        x = torch.stack(xs, dim=1)
        x = x.mean(dim=1)

        # 5. Apply extension
        x = self.extension(x)
        x = self.pool(x)

        # 6. Classification head
        x = x.flatten(1)
        x = self.classifier(x)

        # The parentheses are required: without them the conditional binds
        # only to the second tuple element and a tuple is always returned.
        return (x, xs) if self.config.fin_output else x

# in_channels, out_channels = 4, 6
# Build a FinNet with the default configuration and print its layer summary.
in_channels, out_channels = 16, 6
fin_config = FinNetConfig()
model = FinNet(in_channels, out_channels, fin_config)

# One random sample with 128x128 spatial size.
x = torch.randn(1, in_channels, 128, 128)

print(summary(
    model,
    input_data=x,
    row_settings=["var_names"],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
))

Layer (type (var_name))                                           Input Shape          Output Shape         Param #              Trainable
FinNet (FinNet)                                                   [1, 16, 128, 128]    [1, 6]               --                   True
├─ModuleList (fin_init_convs)                                     --                   --                   --                   True
│    └─Conv2dNormActivation (0)                                   [1, 4, 128, 128]     [1, 48, 64, 64]      --                   True
│    │    └─Conv2d (0)                                            [1, 4, 128, 128]     [1, 48, 64, 64]      1,728                True
│    │    └─BatchNorm2d (1)                                       [1, 48, 64, 64]      [1, 48, 64, 64]      96                   True
│    │    └─SiLU (2)                                              [1, 48, 64, 64]      [1, 48, 64, 64]      --                   --
│    └─Conv2dNormActivation (1)                            

In [7]:
class Props(dict):
    """Dictionary whose keys are also readable and writable as attributes.

    The instance uses itself as its attribute dictionary, so ``p.key`` and
    ``p["key"]`` always refer to the same entry. Reading a name that is
    neither a key nor a regular attribute yields ``None`` instead of raising.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Share storage between item access and attribute access.
        self.__dict__ = self

    def __getattr__(self, name):
        # Only reached after normal lookup has failed, so existing keys,
        # dict methods and class attributes are unaffected.
        if name.startswith("__") and name.endswith("__"):
            # Missing dunder names must stay missing: libraries probe for
            # optional protocols with hasattr/getattr, and answering None
            # would make every protocol look implemented.
            raise AttributeError(name)
        return None

# Demo: "x" was supplied as a key; "y" was not and therefore reads as None.
model_config = Props(x=2)
# Bare tuple expression so the notebook displays both values.
model_config.x, model_config.y

(2, None)

In [16]:
import torch.nn as nn

class HBAFinNet(nn.Module):
    """Thin wrapper that builds a FinNet from a small attribute-style config.

    The wrapped FinNet uses SpatialChannelSeBlock as its SE block and width
    and depth multipliers of 1.0.
    """

    def __init__(self, config):
        """
        Args:
            config: Object exposing ``in_channels``, ``out_channels`` and
                ``fin_output`` as attributes.
        """
        super().__init__()
        self.config = config

        backbone_config = FinNetConfig(
            # B0
            width_mult=1., depth_mult=1.,

            se_block=SpatialChannelSeBlock,

            # Fin configuration
            fin_output=config.fin_output,
        )
        self.finnet = FinNet(
            config.in_channels, config.out_channels, backbone_config,
        )

    def forward(self, x):
        # Pure delegation: whatever FinNet returns is returned unchanged.
        return self.finnet(x)

# Build an HBAFinNet from an attribute-style config and run one forward pass.
config = Props(in_channels=16, out_channels=6, fin_output=True)
model = HBAFinNet(config)

# One random sample with 256x256 spatial size.
x = torch.randn(1, config.in_channels, 256, 256)

# print(summary(
#     model=model, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"],
# ))
output = model(x)