# In [48]:
import torch
from torch import nn

# In [49]:
# Ensure every layer has a channel count divisible by 8.
def make_divisible(v, divisor, min_value=None):
    """Round ``v`` to a nearby multiple of ``divisor``.

    The value is first rounded half-up to the nearest multiple of
    ``divisor`` and clamped from below by ``min_value``. If that result is
    more than 10% smaller than ``v``, one extra ``divisor`` is added so the
    layer is not shrunk too aggressively.

    Args:
        v: Target value; may be a float (e.g. ``channels * width_mult``).
        divisor: The multiple the result should be divisible by.
        min_value: Lower bound for the result; ``None`` means ``divisor``.

    Returns:
        The adjusted value (a multiple of ``divisor`` whenever ``min_value``
        is itself a multiple of ``divisor``).
    """
    if min_value is None:
        min_value = divisor
    # Round half-up to the nearest multiple of divisor, never below min_value.
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Do not let rounding down remove more than 10% of the requested value.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# SiLU (Swish) activation function.
class SiLU(nn.Module):
    """Sigmoid Linear Unit (Swish): ``f(x) = x * sigmoid(x)``, element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x

class SE_Layer(nn.Module):
    """Squeeze-and-Excitation gate that rescales each channel of its input.

    Args:
        in_channels: Used only to size the bottleneck; the hidden width is
            ``make_divisible(in_channels // reduction, 8)``.
        out_channels: Channel count of the tensor passed to ``forward``;
            it is both the input and output width of the gating MLP.
        reduction: Divisor applied to ``in_channels`` for the bottleneck.
    """

    def __init__(self, in_channels, out_channels, reduction=4):
        super().__init__()
        squeezed = make_divisible(in_channels // reduction, 8)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(out_channels, squeezed),
            SiLU(),
            nn.Linear(squeezed, out_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # The 4-way unpack requires a rank-4 input: (batch, channels, height, width).
        batch, channels, _, _ = x.size()
        # Squeeze: global average pool down to one value per channel.
        pooled = self.avgpool(x).view(batch, channels)
        # Excitation: per-channel gates in (0, 1) from the sigmoid, broadcast over H and W.
        gates = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gates

#Convolution layers
def conv3x3(in_channels, out_channels, stride):
    """Return a 3x3 Conv2d (padding 1, no bias) followed by BatchNorm2d and SiLU.

    ``stride`` is forwarded to the convolution, so ``stride=2`` downsamples
    the spatial dimensions.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=1, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    return nn.Sequential(conv, norm, SiLU())
def conv1x1(in_channels, out_channels):
    """Return a pointwise Conv2d (1x1, stride 1, no padding, no bias) -> BatchNorm2d -> SiLU stack."""
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                  padding=0, bias=False),
        nn.BatchNorm2d(out_channels),
        SiLU(),
    ]
    return nn.Sequential(*layers)

class MBConv(nn.Module):
    """Inverted-residual block with two variants selected by ``se``.

    * ``se`` truthy: 1x1 expansion -> 3x3 depthwise conv (carries ``stride``)
      -> SE gate -> 1x1 linear projection.
    * ``se`` falsy (fused-MBConv): one full 3x3 conv performs the expansion
      and the ``stride``, then a 1x1 linear projection.

    The expanded width is ``round(in_channels * expand_ratio)``. An identity
    shortcut is added only when ``stride == 1`` and
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio, se):
        super().__init__()
        hidden = round(in_channels * expand_ratio)
        self.residual = stride == 1 and in_channels == out_channels
        if se:
            stages = [
                # pointwise expansion
                nn.Conv2d(in_channels, hidden, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden),
                SiLU(),
                # depthwise 3x3: one group per channel
                nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False),
                nn.BatchNorm2d(hidden),
                SiLU(),
                SE_Layer(in_channels, hidden),
            ]
        else:
            stages = [
                # fused expansion: a full (non-grouped) 3x3 conv
                nn.Conv2d(in_channels, hidden, 3, stride, 1, bias=False),
                nn.BatchNorm2d(hidden),
                SiLU(),
            ]
        # Linear 1x1 projection back to out_channels (no activation afterwards).
        stages += [
            nn.Conv2d(hidden, out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.residual else out

# In [50]:
class EffNetV2(nn.Module):
    """EfficientNetV2 network built from a per-stage configuration table.

    Args:
        cfgs: Rows of ``(t, c, n, s, se)``:
            - ``t``: expansion ratio.
            - ``c``: base output channels.
            - ``n``: number of blocks in the stage.
            - ``s``: stride of the stage's first block.
            - ``se``: SE flag; truthy selects the depthwise MBConv with SE,
              falsy selects fused-MBConv.
        num_classes: Output size of the final linear classifier.
        width_mult: Channel multiplier, rounded via ``make_divisible(..., 8)``.
            The 1792-channel head is only scaled when ``width_mult > 1.0``.
    """

    def __init__(self, cfgs, num_classes=1000, width_mult=1.0):
        super().__init__()
        self.cfgs = cfgs
        channels = make_divisible(24 * width_mult, 8)
        # Stem: 3 input channels, stride-2 3x3 conv.
        layers = [conv3x3(3, channels, 2)]
        for expand, base_width, repeats, first_stride, use_se in self.cfgs:
            width = make_divisible(base_width * width_mult, 8)
            for idx in range(repeats):
                # Only the first block of each stage applies the stage stride.
                stride = first_stride if idx == 0 else 1
                layers.append(MBConv(channels, width, stride, expand, use_se))
                channels = width
        self.features = nn.Sequential(*layers)
        if width_mult > 1.0:
            head_width = make_divisible(1792 * width_mult, 8)
        else:
            head_width = 1792
        self.conv = conv1x1(channels, head_width)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(head_width, num_classes)

    def forward(self, x):
        feats = self.conv(self.features(x))
        pooled = self.avgpool(feats)
        # Flatten (batch, C, 1, 1) -> (batch, C) for the linear classifier.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

# In [51]:
def effnetv2_m(**kwargs):
    """Build an EfficientNetV2-M model.

    All keyword arguments (e.g. ``num_classes``, ``width_mult``) are passed
    through to ``EffNetV2`` unchanged.
    """
    # Columns: expand ratio t, channels c, repeats n, first-block stride s, SE flag.
    # SE flag 0 -> fused-MBConv stage, 1 -> depthwise MBConv with SE.
    stage_table = [
        [1, 24, 3, 1, 0],
        [4, 48, 5, 2, 0],
        [4, 80, 5, 2, 0],
        [4, 160, 7, 2, 1],
        [6, 176, 14, 1, 1],
        [6, 304, 18, 2, 1],
        [6, 512, 5, 1, 1],
    ]
    return EffNetV2(stage_table, **kwargs)


def effnetv2_l(**kwargs):
    """Build an EfficientNetV2-L model.

    All keyword arguments (e.g. ``num_classes``, ``width_mult``) are passed
    through to ``EffNetV2`` unchanged.
    """
    # Columns: expand ratio t, channels c, repeats n, first-block stride s, SE flag.
    # SE flag 0 -> fused-MBConv stage, 1 -> depthwise MBConv with SE.
    stage_table = [
        [1, 32, 4, 1, 0],
        [4, 64, 7, 2, 0],
        [4, 96, 7, 2, 0],
        [4, 192, 10, 2, 1],
        [6, 224, 19, 1, 1],
        [6, 384, 25, 2, 1],
        [6, 640, 7, 1, 1],
    ]
    return EffNetV2(stage_table, **kwargs)


def effnetv2_xl(**kwargs):
    """Build an EfficientNetV2-XL model.

    All keyword arguments (e.g. ``num_classes``, ``width_mult``) are passed
    through to ``EffNetV2`` unchanged.
    """
    # Columns: expand ratio t, channels c, repeats n, first-block stride s, SE flag.
    # SE flag 0 -> fused-MBConv stage, 1 -> depthwise MBConv with SE.
    stage_table = [
        [1, 32, 4, 1, 0],
        [4, 64, 8, 2, 0],
        [4, 96, 8, 2, 0],
        [4, 192, 16, 2, 1],
        [6, 256, 24, 1, 1],
        [6, 512, 32, 2, 1],
        [6, 640, 8, 1, 1],
    ]
    return EffNetV2(stage_table, **kwargs)