In [None]:
import torch
import torch.nn as nn

class InvertedResidual(nn.Module):
    """MobileNetV2-style inverted residual block.

    Layout: an optional 1x1 expansion conv, then a 3x3 depthwise conv, then a
    1x1 linear projection. Each conv is followed by BatchNorm. ReLU6 follows
    only the expansion and depthwise stages; the projection has no activation
    (the "linear bottleneck"). An identity shortcut is added when the block
    preserves both spatial size (stride 1) and channel count.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the depthwise conv (spatial downsampling factor).
        expansion_factor: multiplier for the hidden width. May be an int or a
            float; the hidden width is rounded to the nearest integer. A value
            of exactly 1 skips the expansion conv.
    """

    def __init__(self, in_channels, out_channels, stride, expansion_factor):
        super().__init__()
        # Round so non-integer expansion factors (e.g. 1.5) still give a valid
        # integer channel count; integer factors are unaffected.
        hidden_dim = int(round(in_channels * expansion_factor))
        # The shortcut is only shape-compatible when neither H/W nor C change.
        self.use_res_connect = (stride == 1 and in_channels == out_channels)

        layers = []
        # Expansion phase (pointwise). Skipped for expansion_factor == 1, in
        # which case hidden_dim == in_channels and the depthwise conv reads
        # the input directly.
        if expansion_factor != 1:
            layers.append(nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False))
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(nn.ReLU6(inplace=True))

        # Depthwise convolution: groups == channels, so each channel is
        # filtered independently; padding=1 keeps H/W when stride == 1.
        layers.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride,
                                padding=1, groups=hidden_dim, bias=False))
        layers.append(nn.BatchNorm2d(hidden_dim))
        layers.append(nn.ReLU6(inplace=True))

        # Projection phase (linear bottleneck): no activation afterwards.
        layers.append(nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=False))
        layers.append(nn.BatchNorm2d(out_channels))

        self.conv = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) for convs; identity affine for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Apply the block; adds the input back when use_res_connect is set."""
        if self.use_res_connect:
            return x + self.conv(x)
        return self.conv(x)


In [7]:

# Test run: a stride-2 block maps 32 -> 16 channels and halves H and W.
torch.manual_seed(0)  # reproducible input and weight init on Restart & Run All
x = torch.randn(1, 32, 112, 112)
block = InvertedResidual(32, 16, stride=2, expansion_factor=6)
out = block(x)
assert out.shape == (1, 16, 56, 56)
print(out.shape)

# Residual path: stride 1 with matching channels keeps the input shape.
res_block = InvertedResidual(32, 32, stride=1, expansion_factor=6)
assert res_block.use_res_connect
assert res_block(x).shape == x.shape


torch.Size([1, 16, 56, 56])
