# 1. Blocks

## 1.1. BottleNeck

In [51]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class BottleNeck(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        bottleneck=4, kernel_size=3, stride=1, padding='same',
        activation_layer=nn.ReLU,
        **kwargs
    ):
        """Bottleneck block with an option to behave as a linear bottle-neck block.

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            bottleneck (int, optional): The size of bottle neck. Defaults to 4.
            kernel_size (int, optional): The kernel for the middle convolution. Defaults to 3.
            stride (int, optional): The stride for the middle convolution and the shortcut. Defaults to 1.
            padding (str, optional): The padding for the middle convolution. Defaults to 'same'.
            act (nn.Module, optional): It activation layer applied to the block output.
                If it is set to the identity layer, it becomes a linear bottleneck block.
                Defaults to nn.ReLU.
        """
        super().__init__()
        bottleneck_size = int(in_channels*bottleneck)

        self.block = ResidualAdd(
            nn.Sequential(
                # (..., in_channels, ...) -> (..., bottleneck_size, ...)
                ops.Conv2dNormActivation(
                    in_channels, bottleneck_size,
                    kernel_size=1, stride=1, padding='same',
                    **kwargs
                ),

                # (..., bottleneck_size, ...) -> (..., bottleneck_size, ...)
                ops.Conv2dNormActivation(
                    bottleneck_size, bottleneck_size,
                    kernel_size=kernel_size, stride=stride, padding=padding,
                    **kwargs
                ),

                # (..., bottleneck_size, ...) -> (..., out_channels, ...)
                ops.Conv2dNormActivation(
                    bottleneck_size, out_channels,
                    kernel_size=1, stride=1, padding='same',
                    activation_layer=nn.Identity,
                    **kwargs
                ),
            ),
            # (..., out_channels, ...) -> (..., in_channels, ...)
            shortcut=ConvBnReLU2d(
                in_channels, out_channels,
                kernel_size=1, stride=stride,
                **kwargs
            ) if in_channels != out_channels else None
        )
        self.act = activation_layer()
    
    def forward(self, x):
        return self.act(self.block(x))

# in_channels, out_channels = 2, 4
# kernel_size = 3
# H, W = 20, 40
# batch_size = 1
# l = BottleNeck(
#     in_channels, out_channels,
#     kernel_size=kernel_size, stride=2, padding=1,
# )
# x = torch.randn(batch_size, in_channels, H, W)
# out = l(x)
# x.shape, out.shape

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
BottleNeck (BottleNeck)                  [1, 2, 20, 40]       [1, 4, 10, 20]       --                   True
├─ResidualAdd (block)                    [1, 2, 20, 40]       [1, 4, 10, 20]       --                   True
│    └─Sequential (block)                [1, 2, 20, 40]       [1, 4, 10, 20]       --                   True
│    │    └─Conv2dNormActivation (0)     [1, 2, 20, 40]       [1, 8, 20, 40]       32                   True
│    │    └─Conv2dNormActivation (1)     [1, 8, 20, 40]       [1, 8, 10, 20]       592                  True
│    │    └─Conv2dNormActivation (2)     [1, 8, 10, 20]       [1, 4, 10, 20]       40                   True
│    └─ConvBnReLU2d (shortcut)           [1, 2, 20, 40]       [1, 4, 10, 20]       --                   True
│    │    └─Sequential (block)           [1, 2, 20, 40]       [1, 4, 10, 20]       20                   True
├─ReLU (act)  

## 1.4. MobileNet Block

In [50]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class MobileNetBlock(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        bottleneck=4, kernel_size=3, stride=1, padding='same',
        activation_layer=nn.Identity,
        **kwargs
    ):
        """MobileNetBlock is a bottleneck block which applies residual connections
        if in_features == out_features. It uses a bottle_factor of 4.

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            bottleneck (int, optional): The size of bottle neck. Defaults to 4.
            kernel_size (int, optional): The kernel for the middle convolution. Defaults to 3.
            stride (int, optional): The stride for the middle convolution and the shortcut. Defaults to 1.
            padding (str, optional): The padding for the middle convolution. Defaults to 'same'.
            activation_layer (nn.Module, optional): It activation layer applied to the block output.
                If it is set to the identity layer, it becomes a linear bottleneck block.
                Defaults to nn.ReLU.
        """
        super().__init__()
        bottleneck_size = int(in_channels*bottleneck)

        residualOrNot = ResidualAdd if in_channels == out_channels else nn.Sequential
        self.block = (
            residualOrNot(
                nn.Sequential(
                    # (..., in_channels, ...) -> (..., bottleneck_size, ...)
                    ops.Conv2dNormActivation(
                        in_channels, bottleneck_size,
                        kernel_size=1, stride=1, padding='same',
                        **kwargs
                    ),

                    # (..., bottleneck_size, ...) -> (..., bottleneck_size, ...)
                    ops.Conv2dNormActivation(
                        bottleneck_size, bottleneck_size,
                        kernel_size=kernel_size, stride=stride, padding=padding,
                        **kwargs
                    ),

                    # (..., bottleneck_size, ...) -> (..., out_channels, ...)
                    ops.Conv2dNormActivation(
                        bottleneck_size, out_channels,
                        kernel_size=1, stride=1, padding='same',
                        activation_layer=nn.Identity,
                        **kwargs
                    ),
                ),
            )
        )
        self.act = activation_layer()
    
    def forward(self, x):
        return self.act(self.block(x))

# in_channels, out_channels = 2, 4
# kernel_size = 3
# H, W = 20, 40
# batch_size = 1
# l = MobileNetBlock(
#     in_channels, out_channels,
#     kernel_size=kernel_size,
# )
# x = torch.randn(batch_size, in_channels, H, W)
# out = l(x)
# x.shape, out.shape

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
MobileNetBlock (MobileNetBlock)          [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
├─Sequential (block)                     [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
│    └─Sequential (0)                    [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
│    │    └─Conv2dNormActivation (0)     [1, 2, 20, 40]       [1, 8, 20, 40]       32                   True
│    │    └─Conv2dNormActivation (1)     [1, 8, 20, 40]       [1, 8, 20, 40]       592                  True
│    │    └─Conv2dNormActivation (2)     [1, 8, 20, 40]       [1, 4, 20, 40]       40                   True
├─Identity (act)                         [1, 4, 20, 40]       [1, 4, 20, 40]       --                   --
Total params: 664
Trainable params: 664
Non-trainable params: 0
Total mult-adds (M): 0.50
Input size (MB): 0.01
Forward/backw

## 1.5. MBConv

MobileNetV2

In [48]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class MBConv(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        bottleneck=4, kernel_size=3, stride=1, padding='same',
        **kwargs
    ):
        """MBConv is a MobileNetBlock with Depthwise Convolution.
        It replaces ReLU with ReLU6

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            bottleneck (int, optional): The size of bottle neck. Defaults to 4.
            kernel_size (int, optional): The kernel for the middle convolution. Defaults to 3.
            stride (int, optional): The stride for the middle convolution and the shortcut. Defaults to 1.
            padding (str, optional): The padding for the middle convolution. Defaults to 'same'.
        """
        super().__init__()
        bottleneck_size = int(in_channels*bottleneck)

        residualOrNot = ResidualAdd if in_channels == out_channels else nn.Sequential
        self.block = (
            residualOrNot(
                nn.Sequential(
                    # (..., in_channels, ...) -> (..., bottleneck_size, ...)
                    ops.Conv2dNormActivation(
                        in_channels, bottleneck_size,
                        kernel_size=1, stride=1, padding='same',
                        activation_layer=nn.ReLU6,
                        **kwargs
                    ),

                    # (..., bottleneck_size, ...) -> (..., bottleneck_size, ...)
                    ops.Conv2dNormActivation(
                        bottleneck_size, bottleneck_size,
                        kernel_size=kernel_size, stride=stride, padding=padding,
                        groups=bottleneck_size, # Depthwise Convolution
                        activation_layer=nn.ReLU6,
                        **kwargs
                    ),

                    # (..., bottleneck_size, ...) -> (..., out_channels, ...)
                    ops.Conv2dNormActivation(
                        bottleneck_size, out_channels,
                        kernel_size=1, stride=1, padding='same',
                        activation_layer=nn.Identity,
                        **kwargs
                    ),
                ),
            )
        )
    
    def forward(self, x):
        return self.block(x)

# in_channels, out_channels = 2, 4
# kernel_size = 3
# H, W = 20, 40
# batch_size = 1
# l = MBConv(
#     in_channels, out_channels,
#     kernel_size=kernel_size,
# )
# x = torch.randn(batch_size, in_channels, H, W)
# out = l(x)
# x.shape, out.shape

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
MBConv (MBConv)                          [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
├─Sequential (block)                     [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
│    └─Sequential (0)                    [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
│    │    └─Conv2dNormActivation (0)     [1, 2, 20, 40]       [1, 8, 20, 40]       32                   True
│    │    └─Conv2dNormActivation (1)     [1, 8, 20, 40]       [1, 8, 20, 40]       88                   True
│    │    └─Conv2dNormActivation (2)     [1, 8, 20, 40]       [1, 4, 20, 40]       40                   True
Total params: 160
Trainable params: 160
Non-trainable params: 0
Total mult-adds (M): 0.10
Input size (MB): 0.01
Forward/backward pass size (MB): 0.26
Params size (MB): 0.00
Estimated Total Size (MB): 0.26


## 1.6. FusedMBConv

EfficientNetV2

In [47]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class FusedMBConv(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        bottleneck=4, kernel_size=3, stride=1, padding='same',
        **kwargs
    ):
        """FusedMBConv is a MBConv block with fused 1st and 2nd convs.

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            bottleneck (int, optional): The size of bottle neck. Defaults to 4.
            kernel_size (int, optional): The kernel for the middle convolution. Defaults to 3.
            stride (int, optional): The stride for the middle convolution and the shortcut. Defaults to 1.
            padding (str, optional): The padding for the middle convolution. Defaults to 'same'.
        """
        super().__init__()
        bottleneck_size = int(in_channels*bottleneck)

        residualOrNot = ResidualAdd if in_channels == out_channels else nn.Sequential
        self.block = (
            residualOrNot(
                nn.Sequential(
                    # (..., in_channels, ...) -> (..., bottleneck_size, ...)
                    ops.Conv2dNormActivation(
                        in_channels, bottleneck_size,
                        kernel_size=kernel_size, stride=stride, padding=padding,
                        activation_layer=nn.ReLU6,
                        **kwargs
                    ),

                    # (..., bottleneck_size, ...) -> (..., out_channels, ...)
                    ops.Conv2dNormActivation(
                        bottleneck_size, out_channels,
                        kernel_size=1, stride=1, padding='same',
                        activation_layer=nn.Identity,
                        **kwargs
                    ),
                ),
            )
        )
    
    def forward(self, x):
        return self.block(x)

# in_channels, out_channels = 2, 4
# kernel_size = 3
# H, W = 20, 40
# batch_size = 1
# l = FusedMBConv(
#     in_channels, out_channels,
#     kernel_size=kernel_size,
# )
# x = torch.randn(batch_size, in_channels, H, W)
# out = l(x)
# x.shape, out.shape

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
FusedMBConv (FusedMBConv)                [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
├─Sequential (block)                     [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
│    └─Sequential (0)                    [1, 2, 20, 40]       [1, 4, 20, 40]       --                   True
│    │    └─Conv2dNormActivation (0)     [1, 2, 20, 40]       [1, 8, 20, 40]       160                  True
│    │    └─Conv2dNormActivation (1)     [1, 8, 20, 40]       [1, 4, 20, 40]       40                   True
Total params: 200
Trainable params: 200
Non-trainable params: 0
Total mult-adds (M): 0.14
Input size (MB): 0.01
Forward/backward pass size (MB): 0.15
Params size (MB): 0.00
Estimated Total Size (MB): 0.16


## 1.7. SEBlock

In [173]:
import torch
import torch.nn as nn

from torchinfo import summary

class SEBlock(nn.Module):
    def __init__(self, in_channels, squeeze_channels, activation_layer=nn.ReLU):
        """SEBlock is a FusedMBConv block with Squeeze-And-Excitation.

        Args:
            in_channels (int): The number of input channels.
            squeeze_channels (int): The number of channels to squeeze to.
        """
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, squeeze_channels, 1)
        self.activation = activation_layer()
        self.conv2 = nn.Conv2d(squeeze_channels, in_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        inp = x

        # 1. (...) -> (..., in_channels, 1, 1)
        x = self.pool(x)

        # 2. -> (..., squeeze_channels, 1, 1)
        x = self.conv1(x)
        x = self.activation(x)

        # 3. -> (..., in_channels, 1, 1)
        x = self.conv2(x)

        # 4. Scale
        x = self.sigmoid(x)
        x = inp * x

        return x

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = SEBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
SEBlock (SEBlock)                        [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
├─AdaptiveAvgPool2d (pool)               [1, 144, 10, 10]     [1, 144, 1, 1]       --                   --
├─Conv2d (conv1)                         [1, 144, 1, 1]       [1, 6, 1, 1]         870                  True
├─ReLU (activation)                      [1, 6, 1, 1]         [1, 6, 1, 1]         --                   --
├─Conv2d (conv2)                         [1, 6, 1, 1]         [1, 144, 1, 1]       1,008                True
├─Sigmoid (sigmoid)                      [1, 144, 1, 1]       [1, 144, 1, 1]       --                   --
Total params: 1,878
Trainable params: 1,878
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.06
Forward/backward pass size (MB): 0.00
Params size (MB): 0.01
Estimated Total Size (MB): 0.07


## 1.7. EfficientNetBlock

EfficientNetV2

In [174]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class EfficientNetBlock(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        bottleneck=4, kernel_size=3, stride=1, padding='same', squeeze_ratio=4,
        activation_layer=nn.SiLU,
        **kwargs
    ):
        """EfficientNetBlock is a MBConv block with SEBlock.

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            bottleneck (int, optional): The size of bottle neck. Defaults to 4.
            kernel_size (int, optional): The kernel for the middle convolution. Defaults to 3.
            stride (int, optional): The stride for the middle convolution and the shortcut. Defaults to 1.
            padding (str, optional): The padding for the middle convolution. Defaults to 'same'.
        """
        super().__init__()
        self.residual = (in_channels == out_channels and stride == 1)

        modules = nn.ModuleList()
        bottleneck_size = int(in_channels*bottleneck)

        # 1. (..., in_channels, ...) -> (..., bottleneck_size, ...)
        if in_channels != bottleneck_size:
            modules.append(
                ops.Conv2dNormActivation(
                    in_channels, bottleneck_size,
                    kernel_size=1, stride=1, padding='same',
                    activation_layer=activation_layer,
                    **kwargs
                )
            )
        
        # 2. (..., bottleneck_size, ...) -> (..., bottleneck_size, ...)
        modules.append(
            ops.Conv2dNormActivation(
                bottleneck_size, bottleneck_size,
                kernel_size=kernel_size, stride=stride, padding=padding,
                groups=bottleneck_size, # Depthwise Convolution
                activation_layer=activation_layer,
                **kwargs
            )
        )

        # 3. Squeeze and excitation block
        squeeze_channels = max(1, in_channels // squeeze_ratio)
        modules.append(
            SEBlock(bottleneck_size, squeeze_channels, activation_layer=activation_layer)
        )

        # 4. (..., bottleneck_size, ...) -> (..., out_channels, ...)
        modules.append(
            ops.Conv2dNormActivation(
                bottleneck_size, out_channels,
                kernel_size=1, stride=1, padding='same',
                activation_layer=nn.Identity,
                **kwargs
            )
        )
        
        self.block = nn.Sequential(*modules)
    
    def forward(self, x):
        inp = x
        x = self.block(x)

        if self.residual:
            x = x + inp

        return x

# in_channels, out_channels = 32, 4
# kernel_size = 3
# H, W = 20, 40
# batch_size = 1
# l = EfficientNetBlock(
#     in_channels, out_channels,
#     kernel_size=kernel_size, bottleneck=1,
# )
# x = torch.randn(batch_size, in_channels, H, W)
# out = l(x)
# x.shape, out.shape

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
EfficientNetBlock (EfficientNetBlock)    [1, 32, 20, 40]      [1, 4, 20, 40]       --                   True
├─Sequential (block)                     [1, 32, 20, 40]      [1, 4, 20, 40]       --                   True
│    └─Conv2dNormActivation (0)          [1, 32, 20, 40]      [1, 32, 20, 40]      --                   True
│    │    └─Conv2d (0)                   [1, 32, 20, 40]      [1, 32, 20, 40]      288                  True
│    │    └─BatchNorm2d (1)              [1, 32, 20, 40]      [1, 32, 20, 40]      64                   True
│    │    └─SiLU (2)                     [1, 32, 20, 40]      [1, 32, 20, 40]      --                   --
│    └─SEBlock (1)                       [1, 32, 20, 40]      [1, 32, 20, 40]      --                   True
│    │    └─AdaptiveAvgPool2d (pool)     [1, 32, 20, 40]      [1, 32, 1, 1]        --                   --
│    │    └─Conv2d

# 2. Model

* ReLU6 and SiLU usage

## 2.1. EfficientNetConfig

#### TODO
- [x] Pipe activation_layer all the way through.
- [x] Residual connection in SEBlock is broken for stride != 1.
- [x] Round channels to multiple of 8.
- [-] Pipe dropout.
- [-] Configure the minimum number of channels in EfficientNetConfig.
- [x] Adjust variable blocks to have same number of input and output channels.

In [175]:
import math
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class EfficientNetConfig(object):
    BaseResolution = 224

    def __init__(
        self,
        # B4 configuration
        width_mult=1.4, depth_mult=1.8, dropout=0.4, last_channels=1280,

        # Block configuration
        kernel_1=3, kernel_2=5,
        padding_1=1, padding_2=2,
    ):
        self.width_mult = width_mult
        self.depth_mult =depth_mult

        self.dropout = dropout
        self.last_channels = last_channels

        # Conv type 1
        self.kernel_1 = kernel_1
        self.padding_1 = padding_1

        # Conv type 2
        self.kernel_2 = kernel_2
        self.padding_2 = padding_2

        # MBConv configs
        self.block_configs = [
            # (in_channels, out_channels, bottleneck, kernel, padding, stride, layers)
            (32, 16, 1, kernel_1, 'same', 1, 1),
            (16, 24, 6, kernel_1, padding_1, 2, 2),
            (24, 40, 6, kernel_2, padding_2, 2, 2),
            (40, 80, 6, kernel_1, padding_1, 2, 3),
            (80, 112, 6, kernel_2, 'same', 1, 3),
            (112, 192, 6, kernel_2, padding_2, 2, 4),
            (192, 320, 6, kernel_1, 'same', 1, 1),
        ]

    def adjust_channels(self, channels):
        return self.round_to(channels*self.width_mult)
    
    def adjust_depth(self, num_layers):
        return int(math.ceil(num_layers*self.depth_mult))
    
    @staticmethod
    def round_to(v, multiple=8):
        return int(multiple * round(v / multiple))
    
    def _block(self, args, **kwargs):
        in_channels, out_channels, bottleneck, kernel, padding, stride, layers = args

        # 1. Update in_channels and out_channels based on the width_mult
        in_channels = self.adjust_channels(in_channels)
        out_channels = self.adjust_channels(out_channels)

        # 2. Update layers based on depth_mult
        layers = self.adjust_depth(layers)
        
        block = nn.Sequential(
            EfficientNetBlock(
                in_channels, out_channels,
                bottleneck=bottleneck,
                kernel_size=kernel, stride=stride, padding=padding,
                **kwargs
            ),
            *map(
                lambda _: EfficientNetBlock(
                    out_channels, out_channels,
                    bottleneck=bottleneck,
                    kernel_size=kernel, stride=1, padding='same',
                    **kwargs
                ),
                range(layers - 1)
            )
        )
        return block
    
    def make_blocks(self, **kwargs):
        modules = nn.Sequential(
            *map(
                lambda b_config: self._block(b_config,  **kwargs),
                self.block_configs
            )
        )
        return modules

# eff_config = EfficientNetConfig()
# block_config = eff_config.block_configs[0]
# block = eff_config.make_blocks()

# x = torch.randn(1, 48, 224, 224)
# print(summary(
#     model=block, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"],
#     # depth=8,
# ))

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
Sequential (Sequential)                            [1, 48, 224, 224]    [1, 448, 14, 14]     --                   True
├─Sequential (0)                                   [1, 48, 224, 224]    [1, 24, 224, 224]    --                   True
│    └─EfficientNetBlock (0)                       [1, 48, 224, 224]    [1, 24, 224, 224]    --                   True
│    │    └─Sequential (block)                     [1, 48, 224, 224]    [1, 24, 224, 224]    2,940                True
│    └─EfficientNetBlock (1)                       [1, 24, 224, 224]    [1, 24, 224, 224]    --                   True
│    │    └─Sequential (block)                     [1, 24, 224, 224]    [1, 24, 224, 224]    1,206                True
├─Sequential (1)                                   [1, 24, 224, 224]    [1, 32, 112, 112]    --                   True
│    └─EfficientNetBlock (0)               

In [177]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class EfficientNet(nn.Module):
    def __init__(
        self,
        in_channels, out_channels, config,
        activation_layer=nn.SiLU,
        **kwargs
    ):
        """EfficientNet

        Args:
            in_channels (int): The number of input channels.
            out_channels (int): The number of output channels.
            bottle_factor (int, optional): The size of bottle neck. Defaults to 4.
            kernel_size (int, optional): The kernel for the middle convolution. Defaults to 3.
            stride (int, optional): The stride for the middle convolution and the shortcut. Defaults to 1.
            padding (str, optional): The padding for the middle convolution. Defaults to 'same'.
            phi (float, optional): The compound scaling coefficient. Defaults to 1.
        """
        super().__init__()
        self.config = config
        in_mb_channels = config.adjust_channels(32)
        out_mb_channels = config.adjust_channels(320)

        self.model = nn.Sequential(
            ops.Conv2dNormActivation(
                in_channels, in_mb_channels,
                kernel_size=config.kernel_1, padding=config.padding_1,
                stride=2,
                activation_layer=activation_layer,
                **kwargs,
            ),
            self.config.make_blocks(activation_layer=activation_layer, **kwargs),
            ops.Conv2dNormActivation(
                out_mb_channels, config.last_channels,
                kernel_size=1, activation_layer=activation_layer,
                **kwargs,
            ),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(
            # nn.Dropout(p=config.dropout, inplace=True),
            nn.Linear(config.last_channels, out_channels),
        )

    
    def forward(self, x):
        x = self.model(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

# # in_channels, out_channels = 4, 6
# in_channels, out_channels = 3, 6
# eff_config = EfficientNetConfig()
# model = EfficientNet(
#     in_channels,
#     out_channels,
#     eff_config,
# )

# x = torch.randn(1, in_channels, 256, 256)
# print(summary(
#     model=model, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"],
# ))

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
EfficientNet (EfficientNet)                                  [1, 3, 256, 256]     --                   --                   True
├─Sequential (model)                                         [1, 3, 256, 256]     [1, 1280, 1, 1]      --                   True
│    └─Conv2dNormActivation (0)                              [1, 3, 256, 256]     [1, 48, 128, 128]    --                   True
│    │    └─Conv2d (0)                                       [1, 3, 256, 256]     [1, 48, 128, 128]    1,296                True
│    │    └─BatchNorm2d (1)                                  [1, 48, 128, 128]    [1, 48, 128, 128]    96                   True
│    │    └─SiLU (2)                                         [1, 48, 128, 128]    [1, 48, 128, 128]    --                   --
│    └─Sequential (1)                                        [1, 48, 128, 128]    [1, 448, 8, 