In [2]:
import torch
import torch.nn as nn
from torchvision.ops import StochasticDepth

In [None]:
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation channel gate for 2D feature maps (N, C, H, W).

    The input is global-average-pooled per channel and passed through a 1x1 conv
    bottleneck (``in_dim -> squeeze_dim -> in_dim``) with SiLU in between. A
    sigmoid then turns the result into a per-channel gate in (0, 1), which
    multiplies the input.

    Args:
        in_dim: number of input (and output) channels.
        reduction_ratio: bottleneck width is ``in_dim // reduction_ratio``,
            clamped to at least 1.
        use_residual: if True, return ``x * gate + x`` instead of ``x * gate``.
    """

    def __init__(self, in_dim, reduction_ratio=16, use_residual=False) -> None:
        super(SqueezeExcitation, self).__init__()

        self.use_residual = use_residual

        # Clamp so narrow inputs (in_dim < reduction_ratio) still get a
        # 1-channel bottleneck. A zero-width one would leave the gate with no
        # input-dependent path.
        squeeze_dim = max(1, in_dim // reduction_ratio)

        self.squeeze = nn.AdaptiveAvgPool2d(1)

        self.excitation = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=squeeze_dim, kernel_size=1, stride=1),
            nn.SiLU(),  # nn.ReLU() is the classic SE choice
            nn.Conv2d(in_channels=squeeze_dim, out_channels=in_dim, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the learned gate (plus ``x`` if residual)."""
        gate = self.excitation(self.squeeze(x))  # (N, C, 1, 1), broadcast over H, W
        out = gate * x

        if self.use_residual:
            out = out + x

        return out

In [None]:
class MBConv(nn.Module):
    """Mobile inverted-bottleneck block (2D).

    Pipeline: 1x1 expansion conv -> depthwise KxK conv -> Squeeze-Excitation
    -> 1x1 projection conv. An identity skip ``x + block(x)`` is added when
    ``in_dim == out_dim and stride == 1``.

    Args:
        in_dim: input channels.
        out_dim: output channels.
        expasion_ratio: hidden width multiplier,
            ``hidden_dim = int(in_dim * expasion_ratio)``. The spelling is kept
            as-is for existing keyword callers.
        kernel_size: depthwise kernel size. Padding is ``kernel_size // 2``,
            which preserves spatial size at stride 1 for odd kernels.
        stride: stride of the depthwise conv.
    """

    def __init__(self, in_dim, out_dim, expasion_ratio=6, kernel_size=3, stride=1) -> None:
        super(MBConv, self).__init__()
        self.use_residual = in_dim == out_dim and stride == 1
        hidden_dim = int(in_dim * expasion_ratio)
        padding = kernel_size // 2  # kernel 3 -> padding 1, kernel 5 -> padding 2

        self.expasion = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=hidden_dim, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU()
        )

        # groups=hidden_dim makes this a per-channel (depthwise) convolution.
        self.dwise = nn.Sequential(
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=kernel_size, stride=stride, padding=padding, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU()
        )

        # The SE gate acts on the depthwise output, which has hidden_dim
        # channels, so it must be sized by hidden_dim (not in_dim).
        self.se = SqueezeExcitation(hidden_dim, reduction_ratio=16)

        # Linear (no activation) projection back down to out_dim.
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels=hidden_dim, out_channels=out_dim, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_dim)
        )

    def forward(self, x):
        """Apply the block; adds ``x`` back when input and output shapes match by construction."""
        h = self.expasion(x)
        h = self.dwise(h)
        h = self.se(h)
        h = self.projection(h)

        if self.use_residual:
            h = h + x

        return h
        

In [None]:
# efficientNetV1_b0 model 
class efficientNetV1(nn.Module):
    """EfficientNet-B0-style 2D image classifier built from ``MBConv`` blocks.

    Channel widths are derived from ``hidden_dim`` (the stem width). With the
    default of 32, the seven MBConv stages output 16, 24, 40, 80, 112, 192 and
    320 channels, and the final 1x1 conv outputs 1280.

    Args:
        hidden_dim: output channels of the stem convolution.
        num_classes: number of output logits.

    ``forward`` takes a batch of 3-channel images and returns logits of shape
    ``(N, num_classes, 1, 1)``, because the classifier is a 1x1 conv applied to
    the pooled features. Flatten them before e.g. ``nn.CrossEntropyLoss``.
    """

    def __init__(self, hidden_dim=32, num_classes=1000) -> None:
        super(efficientNetV1, self).__init__()

        # Per-stage output widths (values in comments are for hidden_dim=32).
        # Each stage's input width is the previous stage's output width. It is
        # chained explicitly rather than re-derived by inverting integer
        # divisions, which can disagree for stems not divisible by 4.
        dim1 = hidden_dim // 2      # 16
        dim2 = hidden_dim * 3 // 4  # 24
        dim3 = dim2 * 5 // 3        # 40
        dim4 = dim3 * 2             # 80
        dim5 = dim4 * 7 // 5        # 112
        dim6 = dim5 * 12 // 7       # 192
        dim7 = dim6 * 5 // 3        # 320
        head_dim = dim7 * 4         # 1280

        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU()
        )

        self.mbconv1 = MBConv(in_dim=hidden_dim, out_dim=dim1, expasion_ratio=1, kernel_size=3, stride=1)
        self.mbconv2 = self._make_stage(dim1, dim2, kernel_size=3, stride=2, num_layers=2)
        self.mbconv3 = self._make_stage(dim2, dim3, kernel_size=5, stride=2, num_layers=2)
        self.mbconv4 = self._make_stage(dim3, dim4, kernel_size=3, stride=2, num_layers=3)
        self.mbconv5 = self._make_stage(dim4, dim5, kernel_size=5, stride=1, num_layers=3)
        self.mbconv6 = self._make_stage(dim5, dim6, kernel_size=5, stride=2, num_layers=4)
        self.mbconv7 = MBConv(in_dim=dim6, out_dim=dim7, expasion_ratio=6, kernel_size=3, stride=1)

        self.last_conv = nn.Sequential(
            nn.Conv2d(in_channels=dim7, out_channels=head_dim, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(head_dim),
            nn.SiLU()
        )

        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_channels=head_dim, out_channels=num_classes, kernel_size=1, stride=1)

    @staticmethod
    def _make_stage(in_dim, out_dim, kernel_size, stride, num_layers, expasion_ratio=6):
        """Build one stage.

        The first MBConv changes width and (optionally) resolution. It is
        followed by ``num_layers - 1`` width-preserving stride-1 MBConvs.
        """
        layers = [MBConv(in_dim=in_dim, out_dim=out_dim, expasion_ratio=expasion_ratio,
                         kernel_size=kernel_size, stride=stride)]
        layers += [MBConv(in_dim=out_dim, out_dim=out_dim, expasion_ratio=expasion_ratio,
                          kernel_size=kernel_size, stride=1)
                   for _ in range(num_layers - 1)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (N, num_classes, 1, 1) for a batch of 3-channel images."""
        h = self.init_conv(x)

        h = self.mbconv1(h)
        h = self.mbconv2(h)
        h = self.mbconv3(h)
        h = self.mbconv4(h)
        h = self.mbconv5(h)
        h = self.mbconv6(h)
        h = self.mbconv7(h)

        h = self.last_conv(h)

        p = self.pooling(h)

        out = self.fc(p)

        return out





In [3]:
class SqueezeExcitation3D(nn.Module):
    """Squeeze-and-Excitation channel gate for 3D feature maps (N, C, D, H, W).

    Each channel is global-average-pooled to one value and passed through a
    1x1x1 conv bottleneck (``in_dim -> sqz_dim -> in_dim``) with SiLU in
    between. The result is squashed by a sigmoid and used to rescale the
    input channel-wise.

    Args:
        in_dim: number of input (and output) channels.
        sqz_dim: width of the bottleneck between the two 1x1x1 convolutions.
    """

    def __init__(self, in_dim, sqz_dim) -> None:
        super().__init__()

        # Attribute names are kept stable so state_dict keys do not change.
        self.pool = nn.AdaptiveAvgPool3d(output_size=1)
        self.fc1 = nn.Conv3d(in_dim, sqz_dim, kernel_size=1, stride=1)
        self.fc2 = nn.Conv3d(sqz_dim, in_dim, kernel_size=1, stride=1)
        self.act = nn.SiLU()
        self.scale_act = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate in (0, 1)."""
        gate = self.scale_act(self.fc2(self.act(self.fc1(self.pool(x)))))
        return x * gate



In [None]:
class MBConv3D(nn.Module):
    """3D mobile inverted-bottleneck block without a built-in skip connection.

    Pipeline: [optional 1x1x1 expansion] -> grouped KxKxK conv
    -> Squeeze-Excitation -> 1x1x1 projection (BatchNorm, no activation).
    Any residual connection / stochastic depth must be added by the caller.

    Args:
        in_dim: input channels.
        hidden_dim: channels of the grouped conv and the SE gate.
        out_dim: output channels of the projection.
        kernel_size, stride, padding: parameters of the grouped conv.
        scale: if True, first expand ``in_dim -> hidden_dim`` with a 1x1x1 conv,
            then apply a depthwise conv (``groups=hidden_dim``). If False, the
            grouped conv itself maps ``in_dim -> hidden_dim`` with
            ``groups=in_dim``, so ``hidden_dim`` must be a multiple of ``in_dim``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, kernel_size, stride, padding, scale=True) -> None:
        super(MBConv3D, self).__init__()

        self.scale = scale

        if self.scale:
            self.bottleneck = nn.Sequential(
                nn.Conv3d(in_dim, hidden_dim, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm3d(hidden_dim),
                nn.SiLU()
            )

            self.conv1 = nn.Sequential(
                nn.Conv3d(hidden_dim, hidden_dim, kernel_size=kernel_size, stride=stride, padding=padding, groups=hidden_dim, bias=False),
                nn.BatchNorm3d(hidden_dim),
                nn.SiLU()
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv3d(in_dim, hidden_dim, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_dim, bias=False),
                nn.BatchNorm3d(hidden_dim),
                nn.SiLU()
            )

        # SE bottleneck width: 8 for the 32-channel case, otherwise hidden_dim // 24
        # (equal to in_dim // 4 when hidden_dim == 6 * in_dim). Clamped to >= 1 so
        # narrow blocks do not get a zero-width bottleneck, which would leave the
        # gate with no input-dependent path.
        sqz_dim = 8 if hidden_dim == 32 else max(1, hidden_dim // 24)
        self.SqueezeExcitation = SqueezeExcitation3D(hidden_dim, sqz_dim)

        self.conv2 = nn.Sequential(
            nn.Conv3d(hidden_dim, out_dim, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm3d(out_dim)
        )

    def forward(self, x):
        """Apply the block. No residual is added here."""
        h = self.bottleneck(x) if self.scale else x
        h = self.conv1(h)
        h = self.SqueezeExcitation(h)
        return self.conv2(h)



In [None]:
class efficientNet3D(nn.Module):
    """EfficientNet-B0-style classifier for 3-channel volumetric input (N, 3, D, H, W).

    Each of ``conv2`` .. ``conv8`` is a Sequential of alternating
    (MBConv3D, StochasticDepth) pairs. The drop probabilities rise linearly
    from 0 to 0.1875 across the 16 blocks. ``forward`` adds an identity skip
    around every block whose output shape equals its input shape, and applies
    the paired StochasticDepth only to that residual branch.

    Args:
        num_classes: number of output logits.

    ``forward`` returns logits of shape (N, num_classes).
    """

    def __init__(self, num_classes) -> None:
        super(efficientNet3D, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm3d(32),
            nn.SiLU()
        )

        self.conv2 = nn.Sequential(
            MBConv3D(in_dim=32, hidden_dim=32, out_dim=16, kernel_size=3, stride=1, padding=1, scale=False),
            StochasticDepth(p=0.0, mode='row')
        )

        self.conv3 = nn.Sequential(
            MBConv3D(in_dim=16, hidden_dim=96, out_dim=24, kernel_size=3, stride=2, padding=1, scale=True),
            StochasticDepth(p=0.0125, mode='row'),
            MBConv3D(in_dim=24, hidden_dim=144, out_dim=24, kernel_size=3, stride=1, padding=1, scale=True),
            StochasticDepth(p=0.025, mode='row')
        )

        self.conv4 = nn.Sequential(
            MBConv3D(in_dim=24, hidden_dim=144, out_dim=40, kernel_size=5, stride=2, padding=2, scale=True),
            StochasticDepth(p=0.0375, mode='row'),
            MBConv3D(in_dim=40, hidden_dim=240, out_dim=40, kernel_size=5, stride=1, padding=2, scale=True),
            StochasticDepth(p=0.05, mode='row'),
        )

        self.conv5 = nn.Sequential(
            MBConv3D(in_dim=40, hidden_dim=240, out_dim=80, kernel_size=3, stride=2, padding=1, scale=True),
            StochasticDepth(p=0.0625, mode='row'),
            MBConv3D(in_dim=80, hidden_dim=480, out_dim=80, kernel_size=3, stride=1, padding=1, scale=True),
            StochasticDepth(p=0.075, mode='row'),
            MBConv3D(in_dim=80, hidden_dim=480, out_dim=80, kernel_size=3, stride=1, padding=1, scale=True),
            StochasticDepth(p=0.0875, mode='row'),
        )

        self.conv6 = nn.Sequential(
            MBConv3D(in_dim=80, hidden_dim=480, out_dim=112, kernel_size=5, stride=1, padding=2, scale=True),
            StochasticDepth(p=0.1, mode='row'),
            MBConv3D(in_dim=112, hidden_dim=672, out_dim=112, kernel_size=5, stride=1, padding=2, scale=True),
            StochasticDepth(p=0.1125, mode='row'),
            MBConv3D(in_dim=112, hidden_dim=672, out_dim=112, kernel_size=5, stride=1, padding=2, scale=True),
            StochasticDepth(p=0.125, mode='row'),
        )

        self.conv7 = nn.Sequential(
            MBConv3D(in_dim=112, hidden_dim=672, out_dim=192, kernel_size=5, stride=2, padding=2, scale=True),
            StochasticDepth(p=0.1375, mode='row'),
            MBConv3D(in_dim=192, hidden_dim=1152, out_dim=192, kernel_size=5, stride=1, padding=2, scale=True),
            StochasticDepth(p=0.15, mode='row'),
            MBConv3D(in_dim=192, hidden_dim=1152, out_dim=192, kernel_size=5, stride=1, padding=2, scale=True),
            StochasticDepth(p=0.1625, mode='row'),
            MBConv3D(in_dim=192, hidden_dim=1152, out_dim=192, kernel_size=5, stride=1, padding=2, scale=True),
            StochasticDepth(p=0.175, mode='row'),
        )

        # 192 -> 1152 is a 6x expansion like every other scale=True block, so it
        # needs the 1x1x1 expansion followed by a true depthwise conv.
        self.conv8 = nn.Sequential(
            MBConv3D(in_dim=192, hidden_dim=1152, out_dim=320, kernel_size=3, stride=1, padding=1, scale=True),
            StochasticDepth(p=0.1875, mode='row')
        )

        self.conv9 = nn.Sequential(
            nn.Conv3d(320, 1280, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm3d(1280),
            nn.SiLU()
        )

        self.pool = nn.AdaptiveAvgPool3d(output_size=1)
        self.drop = nn.Dropout(p=0.2)
        self.clf  = nn.Linear(1280, out_features=num_classes)

    @staticmethod
    def _run_stage(stage, x):
        """Run the (MBConv3D, StochasticDepth) pairs of ``stage`` over ``x``."""
        layers = list(stage)
        for block, drop_path in zip(layers[0::2], layers[1::2]):
            out = block(x)
            if out.shape == x.shape:
                # Residual block: stochastic depth drops only the residual branch,
                # so the identity path always carries the input forward.
                x = x + drop_path(out)
            else:
                # Shape-changing block (stride 2 or new width): no skip is possible,
                # and dropping it would zero whole samples, so it is always kept.
                x = out
        return x

    def forward(self, x):
        """Return logits of shape (N, num_classes) for input (N, 3, D, H, W)."""
        x = self.conv1(x)

        for stage in (self.conv2, self.conv3, self.conv4, self.conv5,
                      self.conv6, self.conv7, self.conv8):
            x = self._run_stage(stage, x)

        x = self.conv9(x)

        # Pool gives (N, 1280, 1, 1, 1). nn.Linear acts on the last dim, so
        # flatten to (N, 1280) first.
        p = torch.flatten(self.pool(x), 1)
        p = self.drop(p)

        out = self.clf(p)

        return out

        