In [1]:
import os, sys
import torch
import torchvision
import torch.nn as nn

### Building blocks: depthwise conv, Conv-BN-ReLU6, and MBConv with squeeze-and-excitation

In [110]:
#Reference: https://github.com/lukemelas/EfficientNet-PyTorch
class depthwise_conv(nn.Module):
    """Depthwise 2-D convolution wrapper.

    Every input channel is convolved with its own ``kernels_per_layer`` filters
    (the inner conv uses ``groups=in_channels``), so the output has
    ``kernels_per_layer * in_channels`` channels.

    Args:
        in_channels: number of input channels.
        kernels_per_layer: depth multiplier (filters per input channel).
        out_channels: NOTE(review): currently unused -- the output width is always
            ``kernels_per_layer * in_channels``.
        groups: NOTE(review): only validated (``in_channels`` must be divisible by it);
            the conv itself always uses ``groups=in_channels``.
        **kwarg: forwarded to ``nn.Conv2d`` (must include ``kernel_size``).

    Raises:
        AssertionError: if ``in_channels`` is not divisible by ``groups``.
    """

    def __init__(self, in_channels, kernels_per_layer, out_channels, groups, **kwarg):
        super().__init__()

        assert in_channels % groups == 0, "Groups Error: in_channels should be divisible by groups"

        expanded_channels = kernels_per_layer * in_channels
        self.depthwise_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=expanded_channels,
            groups=in_channels,
            **kwarg,
        )

    def forward(self, x):
        """Apply the depthwise convolution to ``x``."""
        return self.depthwise_conv(x)
    
class Basic_conv2d_ReLU6(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU6 unit.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm width).
        inplace: whether ReLU6 overwrites its input tensor.
        **kwarg: forwarded to ``nn.Conv2d`` (must include ``kernel_size``).
    """

    def __init__(self, in_channels, out_channels, inplace, **kwarg):
        super().__init__()

        # Registration order (conv, bn, relu6) mirrors the forward pass.
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **kwarg)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu6 = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        """Return ``relu6(bn(conv(x)))``; outputs are clipped to [0, 6]."""
        return self.relu6(self.bn(self.conv(x)))
        
class MBConvBlock(nn.Module):
    """Inverted-residual (MBConv) block with optional squeeze-and-excitation (SE).

    Main branch: 1x1 expansion conv -> BN -> ReLU6 -> 3x3 depthwise conv (stride
    ``strides``) -> BN -> ReLU6 -> 1x1 linear projection -> BN. When enabled, an SE
    gate rescales the projected output channel-wise. A residual connection is added
    only for unit stride; if the channel counts differ, the input is first projected
    with a 1x1 conv (``match_conv``).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        expand_ratio: hidden width multiplier; hidden = round(in_channels * expand_ratio).
        se_ratio: SE bottleneck width as a fraction of ``out_channels``; must be in (0, 1].
        strides: stride of the depthwise conv, an int or a pair such as (1, 1) / [1, 1].
        bias: whether the three main-branch convs use a bias term.
        inplace: forwarded to the ReLU6 / ReLU activations.
        use_squeeze_excitation: whether to apply the SE gate.

    Raises:
        AssertionError: if ``se_ratio`` is outside (0, 1].
    """

    def __init__(self, in_channels, out_channels, expand_ratio, se_ratio, strides, bias, inplace, use_squeeze_excitation):
        super(MBConvBlock, self).__init__()

        assert 0 < se_ratio <= 1, 'Squeeze ratio for squeeze and excitation should be within (0,1]'

        self.strides = strides
        # Normalise scalar / tuple / list strides so e.g. [1, 1] also enables the residual.
        stride_pair = tuple(strides) if isinstance(strides, (tuple, list)) else (strides, strides)
        self.use_residual = stride_pair == (1, 1)
        self.use_squeeze_excitation = use_squeeze_excitation
        hidden_dim = round(in_channels * expand_ratio)

        self.inverted_residual_block = nn.Sequential(
            # pw: expansion
            nn.Conv2d(in_channels=in_channels, out_channels=hidden_dim, kernel_size=(1, 1), stride=1, padding=0, bias=bias),
            nn.BatchNorm2d(num_features=hidden_dim),
            nn.ReLU6(inplace=inplace),

            # dw: groups == channels makes this a per-channel (depthwise) conv
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(3, 3),
                      stride=strides, padding=1, bias=bias, groups=hidden_dim),
            nn.BatchNorm2d(num_features=hidden_dim),
            nn.ReLU6(inplace=inplace),

            # pw: linear projection (no activation after the final BN)
            nn.Conv2d(in_channels=hidden_dim, out_channels=out_channels, kernel_size=(1, 1),
                      stride=1, padding=0, bias=bias),
            nn.BatchNorm2d(num_features=out_channels)
        )

        # 1x1 projection so the shortcut matches out_channels. Created regardless of
        # stride (as before) so existing state_dicts keep loading; forward() only
        # uses it when the residual is active.
        if in_channels != out_channels:
            self.match_conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=1, padding=0)

        if self.use_squeeze_excitation:
            # Clamp to >= 1 so a small out_channels * se_ratio cannot produce a
            # zero-width bottleneck layer.
            squeeze_channels = max(1, int(out_channels * se_ratio))

            self.global_avgPooling = nn.AdaptiveAvgPool2d((1, 1))

            self.se_net = nn.Sequential(
                nn.Linear(in_features=out_channels, out_features=squeeze_channels),
                nn.ReLU(inplace=inplace),
                nn.Linear(in_features=squeeze_channels, out_features=out_channels),
                nn.Sigmoid()
            )

    def forward(self, x):
        """Apply the block.

        Args:
            x: batched input; assumed (N, in_channels, H, W) -- TODO confirm unbatched
                inputs are never passed.

        Returns:
            Tensor with ``out_channels`` channels; its spatial size is set by the
            depthwise conv's stride (padding=1, kernel 3x3).
        """
        out = self.inverted_residual_block(x)
        if self.use_squeeze_excitation:
            out_se = self.global_avgPooling(out)      # (N, C, 1, 1) for batched input
            # Flatten from dim 1 rather than torch.squeeze: squeeze also removes the
            # batch dim when N == 1 (and the channel dim when C == 1), which broke
            # the view() below.
            out_se = torch.flatten(out_se, 1)         # (N, C)
            out_se = self.se_net(out_se)
            out_se = out_se.view(out_se.size(0), out_se.size(1), 1, 1)
            out = out * out_se

        if self.use_residual:
            if out.size(1) != x.size(1):
                out = self.match_conv(x) + out
            else:
                out = x + out

        return out

### Smoke test on MNIST

In [111]:
# Load the MNIST training split as tensors and take one batch to smoke-test the block.
mnist_trainset = torchvision.datasets.MNIST(root='./data',
                                            train=True,
                                            download=True,
                                            transform=torchvision.transforms.ToTensor())
# shuffle=False is the DataLoader default, made explicit: the smoke-test batch is the
# same first 64 samples on every re-run.
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=64, shuffle=False)
# Idiomatic iterator protocol instead of calling __iter__/__next__ directly.
x, y = next(iter(mnist_loader))

In [116]:
# MNIST images are single-channel, so the block widens 1 -> 64 channels
# (hidden width round(1 * 4) = 4) with unit stride and an SE gate.
model = MBConvBlock(
    in_channels=1,
    out_channels=64,
    expand_ratio=4,
    se_ratio=0.25,
    strides=(1, 1),
    bias=False,
    inplace=True,
    use_squeeze_excitation=True,
)

model

MBConvBlock(
  (inverted_residual_block): Sequential(
    (0): Conv2d(1, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace)
    (3): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)
    (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU6(inplace)
    (6): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (match_conv): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
  (global_avgPooling): AdaptiveAvgPool2d(output_size=(1, 1))
  (se_net): Sequential(
    (0): Linear(in_features=64, out_features=16, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=16, out_features=64, bias=True)
    (3): Sigmoid()
  )
)

In [115]:
# Push one MNIST batch through the block and inspect the resulting shape.
out = model(x)
out.shape

torch.Size([64, 64, 1, 1])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64, 1, 1])


torch.Size([64, 64, 14, 14])

### References

1. [EfficientNet: Google AI Blog announcement](https://ai.googleblog.com/2019/05/efficientnet-improving-accuracy-and.html)
2. [EfficientNet: PyTorch implementation](https://github.com/lukemelas/EfficientNet-PyTorch)
3. [EfficientNet: Review Article](https://mc.ai/%E8%AB%96%E6%96%87%E7%AD%86%E8%A8%98-ef%EF%AC%81cient-net-rethinking/)
4. [EfficientNet: Discussion Forum](https://forums.fast.ai/t/efficientnet/46978)
5. [SENet: Paper](https://arxiv.org/pdf/1709.01507.pdf)
6. [SENet: Review Article](https://blog.csdn.net/evan123mg/article/details/80058077)