In [1]:
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from functools import partial
from typing import Callable, List, Optional
# from pytorch_model_summary import summary
from torchsummary import summary


def _make_divisible(ch, divisor=8, min_ch=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch


class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gate for 1-D feature maps.

    The input is averaged over its last axis down to length 1. The result goes
    through a 1x1-conv bottleneck (``fc1`` -> ReLU -> ``fc2`` -> hard-sigmoid).
    The resulting per-channel gate is multiplied back onto the input.

    Args:
        input_c: number of input (and output) channels.
        squeeze_factor: bottleneck reduction ratio. The hidden width is
            ``_make_divisible(input_c // squeeze_factor, 8)``.
    """

    def __init__(self, input_c: int, squeeze_factor: int = 4):
        super(SqueezeExcitation, self).__init__()
        hidden_c = _make_divisible(input_c // squeeze_factor, 8)
        self.fc1 = nn.Conv1d(input_c, hidden_c, 1)
        self.fc2 = nn.Conv1d(hidden_c, input_c, 1)

    def forward(self, x: Tensor) -> Tensor:
        # Squeeze: global average over the length axis (assumes (N, C, L) input; TODO confirm for unbatched use).
        pooled = F.adaptive_avg_pool1d(x, output_size=1)
        # Excite: bottleneck MLP implemented with 1x1 convolutions.
        hidden = F.relu(self.fc1(pooled), inplace=True)
        gate = F.hardsigmoid(self.fc2(hidden), inplace=True)
        # Broadcast the per-channel gate over the length axis.
        return gate * x


class ConvBNActivation(nn.Sequential):
    """A bias-free ``Conv1d`` followed by a norm layer and an activation.

    Padding is ``(kernel_size - 1) // 2``. For odd kernels with ``stride=1``
    this keeps the sequence length unchanged.

    Args:
        in_planes: input channels of the convolution.
        out_planes: output channels of the convolution (also the norm layer width).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        groups: convolution groups (``groups == in_planes`` gives a depthwise conv).
        norm_layer: factory called as ``norm_layer(out_planes)``; defaults to ``nn.BatchNorm1d``.
        activation_layer: factory called as ``activation_layer(inplace=True)``;
            defaults to ``nn.ReLU6``. It must accept the ``inplace`` keyword
            (``nn.Identity`` ignores it).
    """

    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        make_norm = nn.BatchNorm1d if norm_layer is None else norm_layer
        make_act = nn.ReLU6 if activation_layer is None else activation_layer
        # The conv is built first, then norm, then activation: same construction order as before.
        conv = nn.Conv1d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,  # the following norm layer supplies the affine shift
        )
        norm = make_norm(out_planes)
        act = make_act(inplace=True)
        super(ConvBNActivation, self).__init__(conv, norm, act)


class InvertedResidual(nn.Module):
    """MobileNetV3 bottleneck block for 1-D inputs.

    The layers in ``self.block``, in order:
      1. optional 1x1 expansion to ``exp_size`` channels (only when ``exp_size != in_planes``);
      2. depthwise conv with ``kernel_size`` and ``stride``;
      3. optional squeeze-and-excitation (when ``use_se`` is truthy);
      4. 1x1 linear projection to ``out_planes`` (activation is ``nn.Identity``).

    An identity shortcut is added when ``stride == 1`` and ``in_planes == out_planes``.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel_size: depthwise kernel size.
        exp_size: expanded (hidden) channel count.
        stride: depthwise stride; must be 1 or 2.
        use_se: truthy to insert a ``SqueezeExcitation`` after the depthwise conv.
        use_hs: truthy to use ``nn.Hardswish``, otherwise ``nn.ReLU``.
        norm_layer: norm-layer factory forwarded to every ``ConvBNActivation``.

    Raises:
        ValueError: if ``stride`` is not 1 or 2.
    """

    def __init__(self, in_planes, out_planes, kernel_size, exp_size, stride, use_se, use_hs,
                 norm_layer: Callable[..., nn.Module]):
        super(InvertedResidual, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")

        self.use_res_connect = (stride == 1 and in_planes == out_planes)

        act = nn.Hardswish if use_hs else nn.ReLU
        stages: List[nn.Module] = []

        needs_expansion = exp_size != in_planes
        if needs_expansion:
            stages.append(ConvBNActivation(in_planes, exp_size, kernel_size=1,
                                           norm_layer=norm_layer, activation_layer=act))

        # groups == channels -> depthwise convolution
        stages.append(ConvBNActivation(exp_size, exp_size, kernel_size=kernel_size, stride=stride,
                                       groups=exp_size, norm_layer=norm_layer, activation_layer=act))

        if use_se:
            stages.append(SqueezeExcitation(exp_size))

        # Linear bottleneck: no non-linearity after the projection.
        stages.append(ConvBNActivation(exp_size, out_planes, kernel_size=1,
                                       norm_layer=norm_layer, activation_layer=nn.Identity))

        self.block = nn.Sequential(*stages)
        self.out_channels = out_planes
        self.is_strided = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        out = self.block(x)
        if not self.use_res_connect:
            return out
        # Shapes match here: stride 1 keeps the length, and in/out channels are equal.
        out += x
        return out


class modile_net_v3_small(nn.Module):
    """MobileNetV3-Small style network over 1-D signals.

    The class name (a typo of "mobile") is kept for backward compatibility.

    Args:
        out_planes: width of the final linear layer.
        in_channels: number of input channels (default 4, previously hard-coded).

    Forward input is a tensor such as ``(N, in_channels, L)`` (the demo uses
    ``(1, 4, 625)``). Output has shape ``(N, 1, out_planes)``: the classifier
    result with a singleton axis inserted at dim 1.
    """

    def __init__(self, out_planes, in_channels: int = 4):
        super(modile_net_v3_small, self).__init__()
        norm_layer = partial(nn.BatchNorm1d, eps=0.001, momentum=0.01)
        # Stem (first layer): strided 3-tap conv + hard-swish.
        self.in_1 = nn.Sequential(ConvBNActivation(in_channels,
                                                   16,
                                                   kernel_size=3,
                                                   stride=2,
                                                   norm_layer=norm_layer,
                                                   activation_layer=nn.Hardswish)
                                  )
        # Bottleneck stack (second layer onwards).
        self.in_2 = InvertedResidual(in_planes=16, out_planes=16, kernel_size=3,
                                     exp_size=16, stride=2, use_se=1, use_hs=0, norm_layer=norm_layer)
        self.in_3 = InvertedResidual(in_planes=16, out_planes=24, kernel_size=3,
                                     exp_size=72, stride=2, use_se=0, use_hs=0, norm_layer=norm_layer)
        self.in_4 = InvertedResidual(in_planes=24, out_planes=24, kernel_size=3,
                                     exp_size=88, stride=1, use_se=0, use_hs=0, norm_layer=norm_layer)
        self.in_5 = InvertedResidual(in_planes=24, out_planes=40, kernel_size=5,
                                     exp_size=96, stride=2, use_se=1, use_hs=1, norm_layer=norm_layer)
        self.in_6 = InvertedResidual(in_planes=40, out_planes=40, kernel_size=5,
                                     exp_size=240, stride=1, use_se=1, use_hs=1, norm_layer=norm_layer)
        self.in_7 = InvertedResidual(in_planes=40, out_planes=40, kernel_size=5,
                                     exp_size=240, stride=1, use_se=1, use_hs=1, norm_layer=norm_layer)
        self.in_8 = InvertedResidual(in_planes=40, out_planes=48, kernel_size=5,
                                     exp_size=120, stride=1, use_se=1, use_hs=1, norm_layer=norm_layer)
        self.in_9 = InvertedResidual(in_planes=48, out_planes=48, kernel_size=5,
                                     exp_size=144, stride=1, use_se=1, use_hs=1, norm_layer=norm_layer)
        self.in_10 = InvertedResidual(in_planes=48, out_planes=96, kernel_size=5,
                                      exp_size=288, stride=2, use_se=1, use_hs=1, norm_layer=norm_layer)
        # NOTE(review): the MobileNetV3-Small reference table uses stride 1 for the two
        # 96->96 (exp 576) blocks below. Here stride=2, which further downsamples the
        # signal and disables their residual shortcut. Confirm this is intentional.
        self.in_11 = InvertedResidual(in_planes=96, out_planes=96, kernel_size=5,
                                      exp_size=576, stride=2, use_se=1, use_hs=1, norm_layer=norm_layer)
        self.in_12 = InvertedResidual(in_planes=96, out_planes=96, kernel_size=5,
                                      exp_size=576, stride=2, use_se=1, use_hs=1, norm_layer=norm_layer)
        self.in_13 = ConvBNActivation(96,
                                      576,
                                      kernel_size=1,
                                      norm_layer=norm_layer,
                                      activation_layer=nn.Hardswish)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # Hidden width of the classifier head. Previously _make_divisible(1280) was
        # computed, printed to stdout, and then discarded in favor of 1024. Only the
        # value actually used is kept.
        last_channel = 1024
        self.classifier = nn.Sequential(nn.Linear(576, last_channel),
                                        nn.Hardswish(inplace=True),
                                        nn.Dropout(p=0.2, inplace=True),
                                        nn.Linear(last_channel, out_planes))

    def forward(self, x):
        stages = (self.in_1, self.in_2, self.in_3, self.in_4, self.in_5, self.in_6, self.in_7,
                  self.in_8, self.in_9, self.in_10, self.in_11, self.in_12, self.in_13)
        for stage in stages:
            x = stage(x)
        # Global average over the length axis, then drop it: (N, 576, 1) -> (N, 576).
        x = torch.flatten(self.avgpool(x), 1)
        # Insert a singleton axis at dim 1 -> (N, 1, out_planes).
        return self.classifier(x).unsqueeze(dim=1)


if __name__ == "__main__":
    # Smoke test: run one random sample through the network and print a layer summary.
    sample = torch.randn(1, 4, 625)
    net = modile_net_v3_small(625)
    print("input shape: ", sample.shape)
    prediction = net(sample)
    print("output shape: ", prediction.shape)
    summary(net, (4, 625), device="cpu")

1280
input shape:  torch.Size([1, 4, 625])
output shape:  torch.Size([1, 1, 625])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1              [-1, 16, 313]             192
       BatchNorm1d-2              [-1, 16, 313]              32
         Hardswish-3              [-1, 16, 313]               0
            Conv1d-4              [-1, 16, 157]              48
       BatchNorm1d-5              [-1, 16, 157]              32
              ReLU-6              [-1, 16, 157]               0
            Conv1d-7                 [-1, 8, 1]             136
            Conv1d-8                [-1, 16, 1]             144
 SqueezeExcitation-9              [-1, 16, 157]               0
           Conv1d-10              [-1, 16, 157]             256
      BatchNorm1d-11              [-1, 16, 157]              32
         Identity-12              [-1, 16, 157]               0
 InvertedResidual-13 