# 1. Blocks

## 1.1. Conv1dNormActivation

In [2]:
import torch
import torch.nn as nn
import torchvision.ops as ops

class Conv1dNormActivation(nn.Module):
    """Conv1d -> optional normalization -> optional activation, usable as a 1D ``conv_block``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Convolution kernel size.
        norm_layer (callable, optional): Factory called as ``norm_layer(out_channels)``.
            Pass ``None`` to skip normalization. Defaults to ``nn.BatchNorm1d``.
        activation_layer (callable, optional): Factory called with no arguments.
            Pass ``None`` to skip the activation. Defaults to ``nn.ReLU``.
        **kwargs: Forwarded unchanged to ``nn.Conv1d`` (e.g. ``stride``, ``padding``,
            ``groups``, ``bias``). When ``padding`` is not given, ``nn.Conv1d``'s own
            default applies.
    """

    def __init__(
        self,
        in_channels, out_channels, kernel_size,
        norm_layer=nn.BatchNorm1d,
        activation_layer=nn.ReLU,
        **kwargs
    ):
        super().__init__()

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, **kwargs)
        # Compare with None explicitly instead of relying on truthiness: modules that
        # define __len__ (e.g. containers) can be falsy even when they are present.
        self.norm = norm_layer(out_channels) if norm_layer is not None else None
        self.activation = activation_layer() if activation_layer is not None else None

    def forward(self, x):
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

# in_channels, out_channels = 144, 64
# x = torch.randn(1, in_channels, 10)
# l = Conv1dNormActivation(
#     in_channels, out_channels,
#     kernel_size=3, padding='same',
# )

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

## 1.2. ChannelSeBlock

In [30]:
import torch
import torch.nn as nn

from torchinfo import summary

class ChannelSeBlock(nn.Module):
    def __init__(
        self,
        in_channels, squeeze_channels,
        activation_layer=nn.ReLU,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,
    ):
        """Channel-wise Squeeze-and-Excitation gate.

        The input is pooled to one value per channel, squeezed to
        ``squeeze_channels`` by a kernel-size-1 conv + activation, expanded back
        to ``in_channels`` by a second kernel-size-1 conv, and the input is
        multiplied by the sigmoid of the result.

        Args:
            in_channels (int): Number of channels of the input (and output).
            squeeze_channels (int): Width of the bottleneck between the two convs.
            activation_layer (callable): Activation applied after the squeeze conv.
            conv_block (callable): Conv factory; must accept ``norm_layer``,
                ``activation_layer`` and ``bias`` keyword arguments.
            pool_block (callable): Pooling factory, called as ``pool_block(1)``.
        """
        super().__init__()

        self.pool = pool_block(1)
        # Both convs are built without normalization and with bias=False.
        shared = dict(norm_layer=None, bias=False)
        self.conv1 = conv_block(
            in_channels, squeeze_channels, 1,
            activation_layer=activation_layer, **shared,
        )
        self.conv2 = conv_block(
            squeeze_channels, in_channels, 1,
            activation_layer=None, **shared,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # pool -> squeeze -> expand: one excitation logit per channel.
        logits = self.conv2(self.conv1(self.pool(x)))
        # The pooled gate broadcasts back over the remaining (spatial) dims.
        return x * self.sigmoid(logits)

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = ChannelSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
ChannelSeBlock (ChannelSeBlock)          [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
├─AdaptiveAvgPool2d (pool)               [1, 144, 10, 10]     [1, 144, 1, 1]       --                   --
├─Conv2dNormActivation (conv1)           [1, 144, 1, 1]       [1, 6, 1, 1]         --                   True
│    └─Conv2d (0)                        [1, 144, 1, 1]       [1, 6, 1, 1]         870                  True
│    └─ReLU (1)                          [1, 6, 1, 1]         [1, 6, 1, 1]         --                   --
├─Conv2dNormActivation (conv2)           [1, 6, 1, 1]         [1, 144, 1, 1]       --                   True
│    └─Conv2d (0)                        [1, 6, 1, 1]         [1, 144, 1, 1]       864                  True
├─Sigmoid (sigmoid)                      [1, 144, 1, 1]       [1, 144, 1, 1]       --                   --
Total params: 1,734


## 1.3. SpatialSeBlock

In [31]:
import torch
import torch.nn as nn

from torchinfo import summary

class SpatialSeBlock(nn.Module):
    def __init__(
        self,
        in_channels, squeeze_channels,
        conv_block=ops.Conv2dNormActivation,
        **kwargs,
    ):
        """Spatial Squeeze-and-Excitation gate.

        A kernel-size-1 conv collapses the channels to a single map; its sigmoid
        rescales every channel of the input at each position.

        Args:
            in_channels (int): Number of channels of the input (and output).
            squeeze_channels (int): Unused; accepted so the block shares the
                constructor signature of the other SE blocks.
            conv_block (callable): Conv factory; must accept ``norm_layer``,
                ``activation_layer`` and ``bias`` keyword arguments.
            **kwargs: Ignored (e.g. ``activation_layer`` / ``pool_block`` passed by MBConv).
        """
        super().__init__()

        # Pass bias=True explicitly. torchvision's Conv2dNormActivation only treats
        # bias=None as "use a bias when there is no norm"; a factory that forwards
        # the value straight to nn.Conv* (e.g. Conv1dNormActivation) would read
        # None as False and silently drop the bias.
        self.conv = conv_block(
            in_channels, 1, 1,
            norm_layer=None, activation_layer=None, bias=True,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 1. Save input -> (..., in_channels, H, W)
        inp = x

        # 2. -> (..., 1, H, W)
        x = self.conv(x)

        # 3. Scale: the single-channel gate broadcasts over all input channels.
        x = self.sigmoid(x)
        x = inp * x

        return x

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = SpatialSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
SpatialSeBlock (SpatialSeBlock)          [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
├─Conv2dNormActivation (conv)            [1, 144, 10, 10]     [1, 1, 10, 10]       --                   True
│    └─Conv2d (0)                        [1, 144, 10, 10]     [1, 1, 10, 10]       145                  True
├─Sigmoid (sigmoid)                      [1, 1, 10, 10]       [1, 1, 10, 10]       --                   --
Total params: 145
Trainable params: 145
Non-trainable params: 0
Total mult-adds (M): 0.01
Input size (MB): 0.06
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.06


## 1.4. SpatialChannelSeBlock

In [32]:
import torch
import torch.nn as nn

from torchinfo import summary

class SpatialChannelSeBlock(nn.Module):
    def __init__(self, in_channels, squeeze_channels, **kwargs):
        """Concurrent spatial and channel Squeeze-and-Excitation.

        The same input is gated independently by a ``SpatialSeBlock`` and a
        ``ChannelSeBlock``, and the two gated tensors are summed.

        Args:
            in_channels (int): Number of channels of the input (and output).
            squeeze_channels (int): Bottleneck width, forwarded to both branches.
            **kwargs: Extra constructor arguments forwarded unchanged to both branches.
        """
        super().__init__()

        branch_args = dict(in_channels=in_channels, squeeze_channels=squeeze_channels, **kwargs)
        self.spatial_se = SpatialSeBlock(**branch_args)
        self.channel_se = ChannelSeBlock(**branch_args)

    def forward(self, x):
        # Spatially-gated view + channel-gated view of the same input.
        return self.spatial_se(x) + self.channel_se(x)

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = SpatialChannelSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
SpatialChannelSeBlock (SpatialChannelSeBlock)      [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
├─SpatialSeBlock (spatial_se)                      [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
│    └─Conv2dNormActivation (conv)                 [1, 144, 10, 10]     [1, 1, 10, 10]       --                   True
│    │    └─Conv2d (0)                             [1, 144, 10, 10]     [1, 1, 10, 10]       145                  True
│    └─Sigmoid (sigmoid)                           [1, 1, 10, 10]       [1, 1, 10, 10]       --                   --
├─ChannelSeBlock (channel_se)                      [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
│    └─AdaptiveAvgPool2d (pool)                    [1, 144, 10, 10]     [1, 144, 1, 1]       --                   --
│    └─Conv2dNormActivation (conv1)            

## 1.5. HSeBlock

In [35]:
import torch
import torch.nn as nn

from torchinfo import summary

class HSeBlock(nn.Module):
    def __init__(
        self,
        in_channels, squeeze_channels,
        activation_layer=nn.ReLU,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,
    ):
        """Height-wise Squeeze-and-Excitation gate.

        The input is averaged over its last dimension (the width W for
        ``(..., C, H, W)`` input). A kernel-size-1 conv then collapses the channels
        to a single ``(..., 1, H, 1)`` map. Its sigmoid rescales the input,
        broadcast over channels and width.

        Args:
            in_channels (int): Number of channels of the input (and output).
            squeeze_channels (int): Unused; kept for a signature shared with the other SE blocks.
            activation_layer (callable): Unused; kept for the same reason.
            conv_block (callable): Conv factory; must accept ``norm_layer``,
                ``activation_layer`` and ``bias`` keyword arguments.
            pool_block (callable): Unused; the width is averaged with ``Tensor.mean``.
        """
        super().__init__()

        # Pass bias=True explicitly. torchvision's Conv2dNormActivation only treats
        # bias=None as "use a bias when there is no norm"; a factory that forwards
        # the value straight to nn.Conv* would read None as False and drop the bias.
        self.conv = conv_block(
            in_channels, 1, 1,
            norm_layer=None, activation_layer=None, bias=True,
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        inp = x

        # 1. -> (..., in_channels, H, 1)
        x = x.mean(dim=-1, keepdim=True)

        # 2. -> (..., 1, H, 1)
        x = self.conv(x)

        # 3. Scale: the gate broadcasts over channels and the last dimension.
        x = self.sigmoid(x)
        x = inp * x

        return x

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = HSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
HSeBlock (HSeBlock)                      [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
├─Conv2dNormActivation (conv)            [1, 144, 10, 1]      [1, 1, 10, 1]        --                   True
│    └─Conv2d (0)                        [1, 144, 10, 1]      [1, 1, 10, 1]        145                  True
├─Sigmoid (sigmoid)                      [1, 1, 10, 1]        [1, 1, 10, 1]        --                   --
Total params: 145
Trainable params: 145
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.06
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.06


In [44]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class HBADeSeBlock(nn.Module):
    def __init__(
        self,
        in_channels, squeeze_channels,
        activation_layer=nn.ReLU,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,
    ):
        """Height-wise Squeeze-and-Excitation gate with a transposed-conv bottleneck.

        Steps:
            1. Average the input over its last (width) dimension.
            2. A stride-(2, 1) transposed conv doubles the height and reduces the
               channels to ``squeeze_channels``; the activation is applied.
            3. A stride-(2, 1) conv halves the height back and restores ``in_channels``.
            4. The sigmoid of that map rescales the input, broadcast over the width.

        Args:
            in_channels (int): Number of channels of the input (and output).
            squeeze_channels (int): Channel count inside the bottleneck.
            activation_layer (callable): Activation between the two convs. Defaults to nn.ReLU.
            conv_block (callable): Factory for the down-sampling conv; called with
                ``norm_layer=None`` and ``activation_layer=None``.
            pool_block (callable): Unused; the width is averaged with ``Tensor.mean``.

        Note: the up-sampling layer is always ``nn.ConvTranspose2d``, so the block
        expects 2D feature maps.
        """
        super().__init__()

        # (3, 1) kernels act along H only.
        # H -> (H - 1) * 2 - 2 * 1 + 3 + 1 = 2H.
        self.deconv = nn.ConvTranspose2d(
            in_channels, squeeze_channels, kernel_size=(3, 1),
            stride=(2, 1), padding=(1, 0), output_padding=(1, 0),
        )
        self.activation = activation_layer()
        # 2H -> floor((2H + 2 - 3) / 2) + 1 = H.
        self.conv = conv_block(
            squeeze_channels, in_channels, (3, 1),
            stride=(2, 1), padding=(1, 0),
            norm_layer=None, activation_layer=None,
        )

    def forward(self, x):
        # (..., C, H, W) -> (..., C, H, 1)
        profile = x.mean(dim=-1, keepdim=True)
        # -> (..., squeeze, 2H, 1) -> (..., C, H, 1)
        gate = self.conv(self.activation(self.deconv(profile)))
        return x * gate.sigmoid()

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = HBADeSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
HBADeSeBlock (HBADeSeBlock)              [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
├─ConvTranspose2d (deconv)               [1, 144, 10, 1]      [1, 6, 20, 1]        2,598                True
├─ReLU (activation)                      [1, 6, 20, 1]        [1, 6, 20, 1]        --                   --
├─Conv2dNormActivation (conv)            [1, 6, 20, 1]        [1, 144, 10, 1]      --                   True
│    └─Conv2d (0)                        [1, 6, 20, 1]        [1, 144, 10, 1]      2,736                True
Total params: 5,334
Trainable params: 5,334
Non-trainable params: 0
Total mult-adds (M): 0.08
Input size (MB): 0.06
Forward/backward pass size (MB): 0.01
Params size (MB): 0.02
Estimated Total Size (MB): 0.09


In [36]:
import torch
import torch.nn as nn

from torchinfo import summary

class HBAChannelHSeBlock(nn.Module):
    def __init__(
        self,
        in_channels, squeeze_channels,
        **kwargs
    ):
        """Sum of a channel-wise SE gate and a height-wise SE gate.

        The same input is gated independently by a ``ChannelSeBlock`` and an
        ``HSeBlock``, and the two gated tensors are added.

        Args:
            in_channels (int): Number of channels of the input (and output).
            squeeze_channels (int): Bottleneck width, forwarded to both branches.
            **kwargs: Extra constructor arguments forwarded unchanged to both branches.
        """
        super().__init__()

        branch_args = dict(in_channels=in_channels, squeeze_channels=squeeze_channels, **kwargs)
        self.channel_se = ChannelSeBlock(**branch_args)
        self.h_se = HSeBlock(**branch_args)

    def forward(self, x):
        # Channel-gated view + height-gated view of the same input.
        return self.channel_se(x) + self.h_se(x)

# in_channels = 144
# x = torch.randn(1, in_channels, 10, 10)
# l = HBAChannelHSeBlock(in_channels, squeeze_channels=6)

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
HBAChannelHSeBlock (HBAChannelHSeBlock)  [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
├─ChannelSeBlock (channel_se)            [1, 144, 10, 10]     [1, 144, 10, 10]     --                   True
│    └─AdaptiveAvgPool2d (pool)          [1, 144, 10, 10]     [1, 144, 1, 1]       --                   --
│    └─Conv2dNormActivation (conv1)      [1, 144, 1, 1]       [1, 6, 1, 1]         --                   True
│    │    └─Conv2d (0)                   [1, 144, 1, 1]       [1, 6, 1, 1]         870                  True
│    │    └─ReLU (1)                     [1, 6, 1, 1]         [1, 6, 1, 1]         --                   --
│    └─Conv2dNormActivation (conv2)      [1, 6, 1, 1]         [1, 144, 1, 1]       --                   True
│    │    └─Conv2d (0)                   [1, 6, 1, 1]         [1, 144, 1, 1]       864                  True
│    └─Sigmoid (si

## 1.7. MBConv

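Mobile inverted bottleneck (MBConv) block with squeeze-and-excitation, as used in EfficientNet (and EfficientNetV2).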

In [19]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class MBConv(nn.Module):
    def __init__(
        self,
        in_channels, out_channels,
        bottleneck=4, kernel_size=3, stride=1, padding='same', squeeze_ratio=4,
        activation_layer=nn.SiLU,
        se_block=ChannelSeBlock,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,
        **kwargs
    ):
        """Mobile inverted-bottleneck (MBConv) block with a squeeze-and-excitation stage.

        Layout: [1x1 expansion] -> depthwise conv -> SE -> 1x1 projection, plus an
        identity shortcut when channels and stride allow it.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck (int | float, optional): Expansion factor; the hidden width is
                ``int(in_channels * bottleneck)``. Defaults to 4.
            kernel_size (int, optional): Kernel of the depthwise convolution. Defaults to 3.
            stride (int, optional): Stride of the depthwise convolution. Defaults to 1.
            padding (str | int, optional): Padding of the depthwise convolution. Defaults to 'same'.
            squeeze_ratio (int, optional): The SE bottleneck width is
                ``max(1, in_channels // squeeze_ratio)``. Defaults to 4.
            activation_layer (callable, optional): Activation for the expansion,
                depthwise and SE stages. Defaults to ``nn.SiLU``.
            se_block (callable, optional): SE factory, called as
                ``se_block(hidden, squeeze, activation_layer=..., conv_block=..., pool_block=...)``.
            conv_block (callable, optional): Conv-norm-activation factory.
            pool_block (callable, optional): Pooling factory, forwarded to ``se_block``.
            **kwargs: Forwarded to every ``conv_block`` call (not to ``se_block``).
        """
        super().__init__()
        # Identity shortcut only when the block keeps both channel count and resolution.
        self.residual = (in_channels == out_channels and stride == 1)

        hidden = int(in_channels * bottleneck)
        layers = []

        # Expansion: omitted when it would not change the channel count.
        if hidden != in_channels:
            layers.append(conv_block(
                in_channels, hidden,
                kernel_size=1, stride=1, padding='same',
                activation_layer=activation_layer,
                **kwargs
            ))

        # Depthwise convolution: one group per channel.
        layers.append(conv_block(
            hidden, hidden,
            kernel_size=kernel_size, stride=stride, padding=padding,
            groups=hidden,
            activation_layer=activation_layer,
            **kwargs
        ))

        # Squeeze-and-excitation over the hidden channels.
        layers.append(se_block(
            hidden, max(1, in_channels // squeeze_ratio),
            activation_layer=activation_layer,
            conv_block=conv_block,
            pool_block=pool_block,
        ))

        # Projection back to out_channels, with no activation (nn.Identity).
        layers.append(conv_block(
            hidden, out_channels,
            kernel_size=1, stride=1, padding='same',
            activation_layer=nn.Identity,
            **kwargs
        ))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out + x if self.residual else out

# in_channels, out_channels = 32, 4
# kernel_size = 3
# batch_size = 1
# l = MBConv(
#     in_channels, out_channels,
#     kernel_size=kernel_size, bottleneck=1,
#     squeeze_ratio=in_channels//6
# )
# x = torch.randn(batch_size, in_channels, 20, 20)
# out = l(x)
# x.shape, out.shape

# print(summary(
#     model=l, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

# 2. Model

* Activation choice: ReLU6 vs. SiLU (`MBConv` and `EfficientNet` default to `nn.SiLU`).

## 2.1. EfficientNetConfig

#### TODO
- [x] Pipe activation_layer all the way through.
- [x] Residual connection in SEBlock is broken for stride != 1.
- [x] Round channels to multiple of 8.
- [-] Pipe dropout.
- [-] Configure the minimum number of channels in EfficientNetConfig.
- [x] Adjust variable blocks to have same number of input and output channels.

In [25]:
import math
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class EfficientNetConfig(object):
    """Stage layout and compound-scaling rules for an EfficientNet backbone.

    ``make_blocks`` turns ``block_configs`` into one ``nn.Sequential`` stage per
    entry. Channels are scaled by ``width_mult`` (rounded to a multiple of 8);
    layer counts are scaled by ``depth_mult`` (rounded up).
    """

    def __init__(
        self,
        # B4 configuration
        width_mult=1.4, depth_mult=1.8, dropout=0.4, last_channels=1280,

        # Block configuration
        kernel_1=3, kernel_2=5,

        # Blocks
        se_block=ChannelSeBlock,
        conv_block=ops.Conv2dNormActivation,
        pool_block=nn.AdaptiveAvgPool2d,
    ):
        """
        Args:
            width_mult (float): Multiplier for every stage's channel counts.
            depth_mult (float): Multiplier for every stage's layer count.
            dropout (float): Stored only; not used by this class.
            last_channels (int): Stored as-is (not width-scaled here).
            kernel_1 (int): Kernel of the first and last stages.
            kernel_2 (int): Kernel of the fifth stage.
            se_block, conv_block, pool_block (callable): Factories forwarded to every MBConv.
        """
        self.width_mult = width_mult
        self.depth_mult = depth_mult

        self.dropout = dropout
        self.last_channels = last_channels

        # Conv type 1
        self.kernel_1 = kernel_1

        # Conv type 2
        self.kernel_2 = kernel_2

        # Blocks
        self.se_block = se_block
        self.conv_block = conv_block
        self.pool_block = pool_block

        # MBConv configs. kernel_1/kernel_2 are only used by stride-1 stages with
        # 'same' padding. Stride-2 stages pair a fixed kernel with explicit
        # padding (kernel // 2), because PyTorch convs reject padding='same'
        # for strides other than 1.
        self.block_configs = [
            # (in_channels, out_channels, bottleneck, squeeze_channels, kernel, padding, stride, layers)
            (32, 16, 1, None, kernel_1, 'same', 1, 1),
            (16, 24, 6, None, 3, 1, 2, 2),
            (24, 40, 6, None, 5, 2, 2, 2),
            (40, 80, 6, None, 3, 1, 2, 3),
            (80, 112, 6, None, kernel_2, 'same', 1, 3),
            (112, 192, 6, 6, 5, 2, 2, 4),
            (192, 320, 6, 6, kernel_1, 'same', 1, 1),
        ]

    def adjust_channels(self, channels):
        """Scale ``channels`` by ``width_mult`` and round to a multiple of 8."""
        return self.round_to(channels * self.width_mult)

    def adjust_depth(self, num_layers):
        """Scale ``num_layers`` by ``depth_mult``, rounding up."""
        return int(math.ceil(num_layers * self.depth_mult))

    @staticmethod
    def round_to(v, multiple=8, min_value=None):
        """Round ``v`` to the nearest multiple of ``multiple``, never below ``min_value``.

        ``min_value`` defaults to ``multiple``, so a small scaled width can no
        longer collapse to 0 channels. Python's ``round`` rounds exact halves
        to the even neighbour.
        """
        if min_value is None:
            min_value = multiple
        return max(min_value, int(multiple * round(v / multiple)))

    @staticmethod
    def get_squeeze_ratio(channels, squeeze_channels, default=4):
        """Return the MBConv ``squeeze_ratio`` that yields about ``squeeze_channels``.

        Falls back to ``default`` when ``squeeze_channels`` is falsy (None/0).
        The ratio is clamped to at least 1; otherwise MBConv's
        ``in_channels // squeeze_ratio`` would divide by zero whenever
        ``squeeze_channels > channels``.
        """
        if not squeeze_channels:
            return default
        return max(1, channels // squeeze_channels)

    def _block(self, args, **kwargs):
        """Build one stage: an MBConv that may change width/stride, then ``layers - 1`` repeats."""
        in_channels, out_channels, bottleneck, squeeze_channels, kernel, padding, stride, layers = args

        # 1. Update in_channels and out_channels based on the width_mult
        in_channels = self.adjust_channels(in_channels)
        out_channels = self.adjust_channels(out_channels)

        # 2. Update layers based on depth_mult
        layers = self.adjust_depth(layers)

        def make_mbconv(layer_in, layer_stride, layer_padding):
            return MBConv(
                layer_in, out_channels,
                bottleneck=bottleneck,
                squeeze_ratio=self.get_squeeze_ratio(layer_in, squeeze_channels),
                kernel_size=kernel, stride=layer_stride, padding=layer_padding,
                se_block=self.se_block,
                conv_block=self.conv_block,
                pool_block=self.pool_block,
                **kwargs
            )

        # Only the first layer may change channels/resolution. The repeats keep
        # the shape (stride 1, 'same' padding), so they get MBConv's residual shortcut.
        first = make_mbconv(in_channels, stride, padding)
        repeats = [make_mbconv(out_channels, 1, 'same') for _ in range(layers - 1)]
        return nn.Sequential(first, *repeats)

    def make_blocks(self, **kwargs):
        """Return an ``nn.Sequential`` with one stage per entry of ``block_configs``.

        ``kwargs`` are forwarded to every MBConv (e.g. ``activation_layer``).
        """
        return nn.Sequential(*(self._block(cfg, **kwargs) for cfg in self.block_configs))

# Build the MBConv stages for the default configuration and print a per-layer summary.
# To try a 1D variant (untested here), construct the config with
# conv_block=Conv1dNormActivation, pool_block=nn.AdaptiveAvgPool1d and feed a
# (N, C, L) tensor instead.
eff_config = EfficientNetConfig()
block = eff_config.make_blocks()

# 48 channels = eff_config.adjust_channels(32): the stem width at width_mult=1.4.
x = torch.randn(1, 48, 224, 224)
print(summary(
    model=block,
    input_data=x,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
    depth=5,
))

Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
Sequential (Sequential)                                 [1, 48, 224, 224]    [1, 448, 14, 14]     --                   True
├─Sequential (0)                                        [1, 48, 224, 224]    [1, 24, 224, 224]    --                   True
│    └─MBConv (0)                                       [1, 48, 224, 224]    [1, 24, 224, 224]    --                   True
│    │    └─Sequential (block)                          [1, 48, 224, 224]    [1, 24, 224, 224]    --                   True
│    │    │    └─Conv2dNormActivation (0)               [1, 48, 224, 224]    [1, 48, 224, 224]    --                   True
│    │    │    │    └─Conv2d (0)                        [1, 48, 224, 224]    [1, 48, 224, 224]    432                  True
│    │    │    │    └─BatchNorm2d (1)                   [1, 48, 224, 224]    [1, 48, 224, 224]    96                   True
│  

In [6]:
import torch
import torch.nn as nn
import torchvision.ops as ops

from torchinfo import summary

class EfficientNet(nn.Module):
    def __init__(
        self,
        in_channels, out_channels, config,
        activation_layer=nn.SiLU,
        **kwargs
    ):
        """EfficientNet classifier: stem conv -> MBConv stages -> 1x1 head conv -> pooling -> linear.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output logits.
            config (EfficientNetConfig): Supplies the width scaling
                (``adjust_channels``), the stage layout (``block_configs`` /
                ``make_blocks``), ``last_channels`` and the ``conv_block`` /
                ``pool_block`` factories.
            activation_layer (callable, optional): Activation used by the stem,
                every MBConv stage and the head conv. Defaults to ``nn.SiLU``.
            **kwargs: Forwarded to every ``config.conv_block`` call and to
                ``config.make_blocks``.
        """
        super().__init__()
        self.config = config
        # Derive the stem/head widths from the stage layout instead of hard-coding
        # 32/320, so a custom block_configs keeps stem -> stages -> head consistent.
        in_mb_channels = config.adjust_channels(config.block_configs[0][0])
        out_mb_channels = config.adjust_channels(config.block_configs[-1][1])

        self.model = nn.Sequential(
            # Stem: 3x3, stride-2 conv into the first stage's width.
            config.conv_block(
                in_channels, in_mb_channels,
                kernel_size=3, padding=1, stride=2,
                activation_layer=activation_layer,
                **kwargs,
            ),
            config.make_blocks(activation_layer=activation_layer, **kwargs),
            # Head: 1x1 conv from the last stage's width to last_channels.
            config.conv_block(
                out_mb_channels, config.last_channels,
                kernel_size=1, activation_layer=activation_layer,
                **kwargs,
            ),
            config.pool_block(1),
        )
        self.classifier = nn.Sequential(
            # TODO: config.dropout is not applied yet.
            # nn.Dropout(p=config.dropout, inplace=True),
            nn.Linear(config.last_channels, out_channels),
        )

    def forward(self, x):
        x = self.model(x)
        # Assumes pool_block(1) reduced every spatial dim to 1: (N, C, 1, ...) -> (N, C).
        x = x.flatten(1)
        return self.classifier(x)

# # in_channels, out_channels = 4, 6
# in_channels, out_channels = 3, 6
# eff_config = EfficientNetConfig()
# model = EfficientNet(
#     in_channels,
#     out_channels,
#     eff_config,
# )

# x = torch.randn(1, in_channels, 128, 128)
# print(summary(
#     model=model, 
#     input_data=x,
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"],
#     depth=6
# ))