In [1]:
import torch
from torch import nn

class SE(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input is global-average-pooled to 1x1, passed through a two-layer
    1x1-conv bottleneck (SiLU in between) and a sigmoid, and the resulting
    per-channel gate is multiplied back onto the input.

    Args:
        c: Number of channels of the input (and of the gate).
        r: Reduction ratio; the bottleneck has ``max(1, c // r)`` channels.
    """

    def __init__(self, c, r=4):
        super().__init__()
        # ``c // r`` is 0 whenever c < r, which would leave the bottleneck
        # with no channels at all; always keep at least one.
        hidden = max(1, c // r)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c, hidden, 1),
            nn.SiLU(),
            nn.Conv2d(hidden, c, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the learned gate.

        Assumes ``x`` is laid out as (N, C, H, W) with C == ``c``.
        """
        # scale has spatial size 1x1 and broadcasts over H and W.
        scale = self.fc(x)
        return x * scale

In [2]:
class FusedMBConv(nn.Module):
    """Fused-MBConv block.

    With ``expand_ratio != 1`` the block is a 3x3 conv that expands
    ``in_c`` to ``in_c * expand_ratio`` channels (carrying the stride),
    followed by a linear 1x1 projection to ``out_c``.

    With ``expand_ratio == 1`` there is nothing to expand, so the block is a
    single 3x3 conv + BN + SiLU mapping ``in_c`` straight to ``out_c``; this
    conv also carries the stride, so the block still downsamples when asked.

    A residual connection is added when ``stride == 1`` and
    ``in_c == out_c``.

    Args:
        in_c: Input channels.
        out_c: Output channels.
        expand_ratio: Channel expansion factor of the 3x3 conv.
        stride: Stride of the 3x3 conv.
    """

    def __init__(self, in_c, out_c, expand_ratio, stride):
        super().__init__()
        hidden = in_c * expand_ratio

        if expand_ratio != 1:
            # fused: expansion happens in the 3x3 conv
            self.expand = nn.Sequential(
                nn.Conv2d(in_c, hidden, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.SiLU()
            )
            # linear (no activation) 1x1 projection back down to out_c
            self.project = nn.Sequential(
                nn.Conv2d(hidden, out_c, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_c)
            )
        else:
            # No expansion: keep the 3x3 spatial conv (and its stride) and let
            # it produce out_c directly; a separate projection is not needed.
            self.expand = nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(out_c),
                nn.SiLU()
            )
            self.project = nn.Identity()

        # The skip connection is only shape-compatible in this case.
        self.use_res = (stride == 1 and in_c == out_c)

    def forward(self, x):
        """Apply the block; assumes ``x`` is (N, in_c, H, W)."""
        out = self.expand(x)
        out = self.project(out)
        if self.use_res:
            out = out + x
        return out


In [3]:
class MBConv(nn.Module):
    """MBConv block: 1x1 expand -> 3x3 depthwise -> SE -> 1x1 project.

    A residual connection is added when ``stride == 1`` and
    ``in_c == out_c``.

    Args:
        in_c: Input channels.
        out_c: Output channels.
        expand_ratio: Channel expansion factor; expected to be a positive
            integer, since it is used directly as a channel multiplier.
        stride: Stride of the depthwise 3x3 conv.
    """

    def __init__(self, in_c, out_c, expand_ratio, stride):
        super().__init__()
        hidden = in_c * expand_ratio

        # Pointwise expansion to the wider hidden width.
        self.pre = nn.Sequential(
            nn.Conv2d(in_c, hidden, 1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.SiLU()
        )

        # Depthwise 3x3 (groups == channels); this is where the stride is applied.
        self.dw = nn.Sequential(
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.SiLU()
        )

        # The SE bottleneck is sized from the block's *input* channels
        # (in_c // 4), not from the expanded width. SE divides the channel
        # count it is given by r, so r = 4 * expand_ratio yields
        # hidden // (4 * expand_ratio) == in_c // 4.
        self.se = SE(hidden, r=4 * expand_ratio)

        # Linear (no activation) pointwise projection to out_c.
        self.project = nn.Sequential(
            nn.Conv2d(hidden, out_c, 1, bias=False),
            nn.BatchNorm2d(out_c)
        )

        # The skip connection is only shape-compatible in this case.
        self.use_res = (stride == 1 and in_c == out_c)

    def forward(self, x):
        """Apply the block; assumes ``x`` is (N, in_c, H, W)."""
        out = self.pre(x)
        out = self.dw(out)
        out = self.se(out)
        out = self.project(out)
        if self.use_res:
            out = out + x
        return out


In [4]:
class EfficientNetV2S(nn.Module):
    """Image classifier assembled from a stem, a stack of FusedMBConv / MBConv
    stages, a 1x1 conv head, global average pooling and a linear classifier.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        stem_width = 24
        head_width = 1280

        # Stem: strided 3x3 conv on a 3-channel input.
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.SiLU(),
        )

        # One row per stage:
        # (block kind, in channels, out channels, expand ratio,
        #  stride of the first block, number of blocks)
        stage_specs = (
            ("fused", 24, 24, 1, 1, 2),
            ("fused", 24, 48, 4, 2, 4),
            ("fused", 48, 64, 4, 2, 4),
            ("mbconv", 64, 128, 4, 2, 6),
            ("mbconv", 128, 160, 6, 1, 9),
            ("mbconv", 160, 256, 6, 2, 15),
        )

        # Body: only the first block of a stage changes width and applies the
        # stage stride; the remaining ones keep out channels and stride 1.
        stages = []
        for kind, c_in, c_out, expansion, first_stride, depth in stage_specs:
            block_cls = FusedMBConv if kind == "fused" else MBConv
            for idx in range(depth):
                leading = idx == 0
                stages.append(
                    block_cls(
                        c_in if leading else c_out,
                        c_out,
                        expansion,
                        first_stride if leading else 1,
                    )
                )
        self.blocks = nn.Sequential(*stages)

        # Head: 1x1 conv from the last stage's width up to the feature width.
        last_width = stage_specs[-1][2]
        self.head = nn.Sequential(
            nn.Conv2d(last_width, head_width, 1, bias=False),
            nn.BatchNorm2d(head_width),
            nn.SiLU(),
        )

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(head_width, num_classes)

    def forward(self, x):
        """Return the classifier output for ``x``.

        Assumes ``x`` is a batch laid out as (N, 3, H, W); the result then has
        one row per sample and ``num_classes`` columns.
        """
        features = self.head(self.blocks(self.stem(x)))
        pooled = torch.flatten(self.pool(features), 1)
        return self.fc(pooled)
