In [1]:
# EfficientNet (b0-b7) model definition in PyTorch

In [2]:
import torch
from torch import nn

In [11]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:

class Conv(nn.Module):
    """2-D convolution optionally followed by batch norm and ReLU6.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.
        stride: stride forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.
        use_norm: when true, apply ``nn.BatchNorm2d(out_channels)`` after the
            convolution.
        use_activation: when true, apply ``nn.ReLU6`` as the last step.
        use_grps: when true, the convolution uses ``groups=out_channels``
            (a depthwise convolution when ``in_channels == out_channels``);
            otherwise ``groups=1``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3),
                 stride=(1, 1), padding=1, use_norm=True,
                 use_activation=True, use_grps=False):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=out_channels if use_grps else 1,
        )
        # The optional sub-modules are only created when requested, so a
        # disabled stage contributes no parameters or buffers.
        if use_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        if use_activation:
            self.relu = nn.ReLU6()

    def forward(self, x):
        """Run conv -> (batch norm) -> (ReLU6) on ``x`` and return the result."""
        out = self.conv1(x)
        if self.use_norm:
            out = self.norm(out)
        return self.relu(out) if self.use_activation else out

In [40]:
class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gate.

    The input is globally average-pooled to one value per channel, passed
    through two 1x1 convolutions (SiLU in between) and a sigmoid, and the
    resulting per-channel gate is multiplied onto the input.

    The 1x1 convolutions are plain ``nn.Conv2d`` layers without batch norm or
    ReLU6: batch norm on the pooled 1x1 map fails in training mode when the
    batch holds a single sample, and a ReLU6 ahead of the sigmoid would clamp
    the gate to values >= 0.5 so channels could never be suppressed.

    Args:
        in_channels: channels of the input (and of the returned tensor).
        out_channels: channels of the intermediate (bottleneck) layer.
    """

    def __init__(self ,
                 in_channels ,
                 out_channels):
        super(SqueezeExcitation , self).__init__()

        self.adp_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels , out_channels , kernel_size=1)
        self.silu = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels , in_channels , kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self , x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        # No copy of the input is needed: none of the ops below are in-place.
        gate = self.adp_avg_pool(x)
        gate = self.silu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        # gate has spatial size 1x1 and broadcasts over the input's H and W.
        return gate * x

In [41]:
class Inverted_Res_Block(nn.Module):
    """Inverted residual block followed by squeeze-and-excitation.

    Pipeline: 1x1 expansion conv -> 3x3 grouped conv (carries the stride) ->
    1x1 projection conv -> squeeze-and-excitation. The input is added back to
    the result when the block keeps both the channel count and the spatial
    size (stride 1).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        t: expansion factor; the hidden width is ``in_channels * t``.
        stride: stride of the 3x3 grouped conv, as an int or an (h, w) pair.
        reduction: divisor applied to ``in_channels`` to size the
            squeeze-and-excitation bottleneck.
    """

    def __init__(self ,
                 in_channels ,
                 out_channels ,
                 t ,
                 stride = (1 , 1) ,
                 reduction = 0.4):
        super(Inverted_Res_Block , self).__init__()

        # stride may be an int or a pair; comparing a pair against the int 1
        # is always False, so normalise before deciding on the skip connection.
        strides = (stride , stride) if isinstance(stride , int) else tuple(stride)
        self.use_residual = in_channels == out_channels and all(s == 1 for s in strides)

        hidden_dim = in_channels * t
        self.conv1 = Conv(in_channels , hidden_dim , kernel_size=(1 , 1) , stride=(1 , 1) , padding=0)
        self.conv2 = Conv(hidden_dim , hidden_dim , stride=stride , use_grps=True)
        self.conv3 = Conv(hidden_dim , out_channels , kernel_size=(1 , 1) , stride=(1 , 1) , padding=0)

        # NOTE(review): with the default reduction of 0.4 this is 2.5x
        # in_channels, i.e. wider than the input - confirm that is intended.
        reduced_dim = int(in_channels / reduction)
        self.squeeze_excitation = SqueezeExcitation(out_channels , reduced_dim)

    def forward(self , x):
        """Apply the block to ``x``; adds ``x`` back when the skip is enabled."""
        # The sub-modules are not in-place, so the input can be reused for the
        # skip connection without copying it first.
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.squeeze_excitation(out)
        if self.use_residual:
            return out + x
        return out

In [42]:

# Baseline stage table: one row per stage of inverted-residual blocks.
# Efficient_Net multiplies `channels` by the width factor and `repeats` by the
# depth factor of the selected version.
config = [
    # expand_ratio, channels, repeats, stride, kernel_size
    [1, 16, 1, 1, 3],
    [6, 24, 2, 2, 3],
    [6, 40, 2, 2, 5],
    [6, 80, 3, 2, 3],
    [6, 112, 3, 1, 5],
    [6, 192, 4, 2, 5],
    [6, 320, 1, 1, 3],
]

# Per-version settings keyed by version name. `phi_value` is the exponent
# applied to the depth and width bases in Efficient_Net._get_scale_params;
# `resolution` is the side length of the square input image.
# NOTE(review): `drop_rate` is presumably a dropout probability - confirm
# where it is meant to be applied.
phi_values = {
    # tuple of: (phi_value, resolution, drop_rate)
    "b0": (0, 224, 0.2),  
    "b1": (0.5, 240, 0.2),
    "b2": (1, 260, 0.3),
    "b3": (2, 300, 0.3),
    "b4": (3, 380, 0.4),
    "b5": (4, 456, 0.4),
    "b6": (5, 528, 0.5),
    "b7": (6, 600, 0.5),
}

In [51]:
class Efficient_Net(nn.Module):
    """EfficientNet-style classifier built from a stage table.

    Structure: stem conv (stride 2) -> stages of ``Inverted_Res_Block`` ->
    1x1 head conv -> global average pool -> linear classifier.

    Args:
        in_channels: channels of the input image, used by the stem conv.
        version: key into ``phi_values`` selecting the scaling exponent.
        out_channels_model: number of outputs of the final linear layer.
        config: stage table; each row is
            ``(expand_ratio, channels, repeats, stride, kernel_size)``.
        phi_values: mapping ``version -> (phi_value, resolution, drop_rate)``.
    """

    def __init__(self ,
                 in_channels ,
                 version ,
                 out_channels_model ,
                 config=config,
                 phi_values=phi_values):
        super(Efficient_Net , self).__init__()

        self.config = config
        self.phi_values = phi_values
        depth_factor , width_factor = self._get_scale_params(version)
        out_channels = int(1280 * width_factor)

        self.layers = self._get_layers(depth_factor , width_factor , out_channels , in_channels)

        self.cls = nn.Linear(out_channels , out_channels_model)
        self.adp_avg_pool = nn.AdaptiveAvgPool2d(1)

    def _get_layers(self , depth_factor , width_factor , out_channels_model , in_channels=3):
        """Build the stem, all scaled stages and the 1x1 head as a ModuleList.

        Args:
            depth_factor: multiplier applied to each stage's repeat count.
            width_factor: multiplier applied to each stage's channel count.
            out_channels_model: output channels of the final 1x1 head conv.
            in_channels: channels of the input image fed to the stem conv.
        """
        layers = nn.ModuleList()

        channels = int(32 * width_factor)
        layers.append(Conv(in_channels , channels , stride=(2 , 2)))
        in_channels = channels
        # Iterate the table stored on the instance so that a custom `config`
        # passed to the constructor is actually honoured.
        for stage in self.config:
            # kernel_size is part of each row but is not forwarded:
            # Inverted_Res_Block takes no kernel-size argument.
            t , out_channels , repeats , stride , _kernel_size = stage
            # Scale the width and round down to a multiple of 4.
            out_channels = 4 * int(int(out_channels * width_factor)/4)
            repeats = int(repeats * depth_factor)

            for block_idx in range(repeats):
                layers.append(
                    Inverted_Res_Block(
                        in_channels ,
                        out_channels ,
                        t ,
                        # Only the first block of a stage downsamples; the
                        # remaining repeats keep the resolution, which also
                        # lets them use their residual connection.
                        stride if block_idx == 0 else 1
                    )
                )
                in_channels = out_channels
        layers.append(Conv(in_channels , out_channels_model , kernel_size=(1 , 1) , stride=(1 , 1) , padding=0))
        return layers

    def _get_scale_params(self , version , alpha=1.2 , beta=1.1):
        """Return ``(depth_factor, width_factor)`` = ``(alpha**phi, beta**phi)``.

        ``phi`` is the first entry of ``self.phi_values[version]``; an unknown
        version raises ``KeyError``.
        """
        # NOTE(review): resolution and drop_rate are read here but not used
        # by the model itself.
        phi_value , _resolution , _drop_rate = self.phi_values[version]
        depth_factor = alpha ** phi_value
        width_factor = beta ** phi_value
        return depth_factor , width_factor

    def forward(self , x):
        """Run the feature layers, pool globally and classify."""
        for layer in self.layers:
            x = layer(x)
        x = self.adp_avg_pool(x)
        # The pooled map has spatial size 1x1; collapse it to (batch, channels).
        x = self.cls(x.flatten(1))
        return x

In [55]:
def test(version="b7", num_examples=4, num_classes=10):
    """Smoke-test Efficient_Net: check the output shape for a random batch.

    Args:
        version: key into ``phi_values``; also selects the input resolution.
        num_examples: batch size of the random input.
        num_classes: number of outputs requested from the model.

    Raises:
        AssertionError: if the output is not ``(num_examples, num_classes)``.
    """
    phi, res, drop_rate = phi_values[version]
    x = torch.randn((num_examples, 3, res, res)).to(device)
    model = Efficient_Net(
        version=version,
        out_channels_model=num_classes,
        in_channels = 3
    ).to(device)

    # This is only a shape check: run in eval mode and skip autograd so the
    # intermediate activations are not kept in memory.
    model.eval()
    with torch.no_grad():
        out = model(x)

    assert out.shape == (num_examples, num_classes), out.shape
    print(out.shape)

test()

torch.Size([4, 10])
