In [1]:
import torch
from torch import nn

In [2]:
class InvertedResidual(nn.Module):
    """MobileNetV2-style inverted residual block.

    Layout: 1x1 pointwise expansion -> 3x3 depthwise conv -> 1x1 linear
    projection (no activation after the projection). The identity shortcut
    is added only when the block preserves the input shape, i.e. when
    ``stride == 1`` and ``inp == outp``; in any other configuration ``x`` and
    ``self.conv(x)`` have different shapes and cannot be summed.

    Args:
        inp: number of input channels.
        outp: number of output channels.
        stride: stride of the depthwise 3x3 convolution.
        expand_ratio: channel expansion factor; the hidden width is
            ``round(inp * expand_ratio)``.

    Input is expected in NCHW layout, as required by ``nn.Conv2d``.
    """

    def __init__(self, inp, outp, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        # The shortcut is only shape-compatible when both the spatial size
        # (stride 1) and the channel count are preserved.
        self.use_res_connect = stride == 1 and inp == outp
        hid_dim = round(inp * expand_ratio)
        self.conv = nn.Sequential(
            # 1x1 pointwise expansion.
            # NOTE(review): torchvision's MobileNetV2 omits this stage when
            # expand_ratio == 1; it is kept unconditionally here so the
            # module's parameter layout stays unchanged for existing callers.
            nn.Conv2d(inp, hid_dim, 1, 1, 0, bias=False),
            nn.BatchNorm2d(hid_dim),
            # ReLU6's first positional argument is `inplace`, not a channel
            # count, so pass it explicitly by name.
            nn.ReLU6(inplace=True),
            # 3x3 depthwise conv: groups == channels, one filter per channel.
            nn.Conv2d(hid_dim, hid_dim, 3, stride, 1, groups=hid_dim, bias=False),
            nn.BatchNorm2d(hid_dim),
            nn.ReLU6(inplace=True),
            # 1x1 linear projection back down to `outp` (no activation).
            nn.Conv2d(hid_dim, outp, 1, 1, 0, bias=False),
            nn.BatchNorm2d(outp),
        )

    def forward(self, x):
        """Apply the block, adding the identity shortcut only when shapes match."""
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out

In [5]:
# Stride-1 block with equal in/out channels (24 -> 24) and a 6x hidden expansion.
mobileNet_v2_block = InvertedResidual(inp=24, outp=24, stride=1, expand_ratio=6)

In [6]:
# Display the module repr to inspect the layer stack built by InvertedResidual.
mobileNet_v2_block

InvertedResidual(
  (conv): Sequential(
    (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
    (3): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
    (4): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU6(inplace=True)
    (6): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (7): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [7]:
# Seed the RNG so the random input (and thus `out`) is reproducible on re-run.
torch.manual_seed(0)
data = torch.randn(1, 24, 56, 56)  # NCHW: batch 1, 24 channels, 56x56
out = mobileNet_v2_block(data)
# Stride 1 with equal in/out channels must preserve the input shape.
assert out.shape == data.shape, f"unexpected output shape {tuple(out.shape)}"
out.shape

torch.Size([1, 24, 56, 56])