In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation, applied elementwise: ``x * sigmoid(x)``.

    Stateless: holds no parameters or buffers, so a single instance can be
    reused at several points in a network.
    """

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [9]:
class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gating.

    Each channel is global-average-pooled to 1x1 ("squeeze"). The pooled
    vector goes through a 1x1-conv bottleneck (ReLU, then sigmoid) to produce
    per-channel gates in (0, 1) ("excitation"). The input is then rescaled by
    those gates.

    Args:
        in_channels: channel count of the input; the output has the same shape.
        reduced_dim: width of the bottleneck between the two 1x1 convs.
    """

    def __init__(self, in_channels, reduced_dim):
        super().__init__()
        # 1x1 convolutions on a 1x1 map behave like per-sample fully connected layers.
        self.fc1 = nn.Conv2d(in_channels, reduced_dim, kernel_size=1)
        self.fc2 = nn.Conv2d(reduced_dim, in_channels, kernel_size=1)

    def forward(self, x):
        squeezed = F.adaptive_avg_pool2d(x, 1)    # (N, C, 1, 1) for (N, C, H, W) input
        hidden = F.relu(self.fc1(squeezed))
        gates = torch.sigmoid(self.fc2(hidden))
        # Gates broadcast over the spatial dims to recalibrate each channel.
        return x * gates

In [12]:
class MBConvBlock(nn.Module):
    """Mobile inverted bottleneck (MBConv) block with squeeze-and-excitation.

    Layout: optional 1x1 expansion conv -> depthwise kxk conv -> SE -> 1x1
    projection conv. BatchNorm follows every conv. Swish follows the expansion
    and depthwise convs; the projection is linear.

    When the block preserves the tensor shape (stride 1 and
    ``in_channels == out_channels``), the input is added back as a residual. In
    that case, during training, the transform branch is dropped per sample with
    probability ``1 - survival_prob`` (stochastic depth). Kept samples are
    scaled by ``1 / survival_prob``, so no rescaling is needed at eval time.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the projection conv.
        expand_ratio: hidden width multiplier; 1 skips the expansion conv.
        stride: stride of the depthwise conv.
        reduction: the SE bottleneck width is ``in_channels // reduction``,
            clamped to at least 1.
        survival_prob: per-sample keep probability for stochastic depth;
            values >= 1 disable it.
        kernel_size: odd depthwise kernel size (default 3). Padding is
            ``kernel_size // 2``, so stride 1 preserves H and W.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, expand_ratio, stride, reduction=4,
                 survival_prob=0.8, kernel_size=3):
        super(MBConvBlock, self).__init__()
        if kernel_size % 2 == 0:
            # An even kernel with k // 2 padding would change H/W and break the residual.
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        self.stride = stride
        self.survival_prob = survival_prob
        # With an odd kernel and k // 2 padding, stride 1 keeps H/W unchanged, so the
        # shape matches the input exactly when the channel counts agree.
        self.use_residual = stride == 1 and in_channels == out_channels
        hidden_dim = in_channels * expand_ratio

        if expand_ratio != 1:
            self.expand_conv = nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False)
            self.bn0 = nn.BatchNorm2d(hidden_dim)

        self.depthwise_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=kernel_size,
                                        stride=stride, padding=kernel_size // 2,
                                        groups=hidden_dim, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        # SE width is relative to the block's *input* channels. Clamp to 1 so small
        # inputs (in_channels < reduction) don't create zero-width convs.
        self.se = SqueezeExcitation(hidden_dim, max(1, in_channels // reduction))
        self.project_conv = nn.Conv2d(hidden_dim, out_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.swish = Swish()

    def forward(self, x):
        identity = x
        if hasattr(self, 'expand_conv'):
            x = self.swish(self.bn0(self.expand_conv(x)))

        x = self.swish(self.bn1(self.depthwise_conv(x)))
        x = self.se(x)
        x = self.bn2(self.project_conv(x))

        if self.use_residual:
            if self.training and self.survival_prob < 1:
                # One keep/drop decision per sample, broadcast over C, H, W. Cast to
                # x.dtype so half-precision activations are not promoted to float32.
                keep = torch.rand(x.size(0), 1, 1, 1, device=x.device).lt(self.survival_prob)
                x = x * keep.to(x.dtype) / self.survival_prob
            x = x + identity
        return x

In [14]:
class EfficientNet(nn.Module):
    """EfficientNet-B0-style image classifier assembled from MBConv blocks.

    Pipeline:
        stem : 3x3 stride-2 conv (3 -> 32 ch) + BatchNorm + Swish
        body : MBConv stages described by the stage table in ``__init__``
        head : 1x1 conv (320 -> 1280 ch) + BatchNorm + Swish,
               global average pool, flatten, linear -> ``num_classes`` logits

    Args:
        num_classes: length of the output logit vector.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # One row per stage: (in_channels, out_channels, num_blocks, expand_ratio, first_stride).
        # Only the first block of a stage uses first_stride and changes the width.
        # NOTE(review): no kernel-size column, so every block uses MBConvBlock's
        # default depthwise kernel.
        stage_config = (
            (32, 16, 1, 1, 1),
            (16, 24, 2, 6, 2),
            (24, 40, 2, 6, 2),
            (40, 80, 3, 6, 1),
            (80, 112, 3, 6, 1),
            (112, 192, 4, 6, 2),
            (192, 320, 1, 6, 1),
        )
        stem_width = stage_config[0][0]     # 32
        body_width = stage_config[-1][1]    # 320
        head_width = 1280

        # Registration order (stem, blocks, head) fixes both the printed layout and
        # the order in which parameters draw from the RNG at construction time.
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_width),
            Swish(),
        )
        self.blocks = self._make_blocks(stage_config)
        self.head = nn.Sequential(
            nn.Conv2d(body_width, head_width, kernel_size=1, bias=False),
            nn.BatchNorm2d(head_width),
            Swish(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(head_width, num_classes),
        )

        self._initialize_weights()

    def _make_blocks(self, settings):
        """Expand the per-stage table into a flat ``nn.Sequential`` of MBConv blocks."""
        layers = [
            MBConvBlock(
                c_in if idx == 0 else c_out,
                c_out,
                expansion,
                first_stride if idx == 0 else 1,
            )
            for c_in, c_out, repeats, expansion, first_stride in settings
            for idx in range(repeats)
        ]
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Re-initialize conv, BatchNorm and linear weights in module order.

        Conv biases (only the SE convs have them) keep PyTorch's default init.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))

    
# Shape smoke test: push one dummy 224x224 RGB image through a 10-class model.
torch.manual_seed(0)  # reproducible weight init and dummy input
model = EfficientNet(num_classes=10)
x = torch.randn(1, 3, 224, 224)

# Check in eval mode without autograd. In train mode, BatchNorm would fold this
# random batch into its running statistics and stochastic depth would drop
# residual branches. Restore the previous mode so later cells see the model as built.
was_training = model.training
model.eval()
with torch.no_grad():
    logits = model(x)
model.train(was_training)

assert logits.shape == (1, 10), logits.shape
print(logits.shape)
print(model)

torch.Size([1, 10])
EfficientNet(
  (stem): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Swish()
  )
  (blocks): Sequential(
    (0): MBConvBlock(
      (depthwise_conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (se): SqueezeExcitation(
        (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
        (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (project_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (swish): Swish()
    )
    (1): MBConvBlock(
      (expand_conv): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn0): Batch