<a href="https://colab.research.google.com/github/Quillbolt/colabnotebook/blob/main/mxnet_compare.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision


class SqueezeAndExcite(nn.Module):
    """Channel re-weighting block: global average pool -> two Linear layers -> rescale.

    The pooled per-channel descriptor goes through a bottleneck of
    ``in_channels // divide`` units and is expanded to ``out_channels`` gate
    values that multiply the input feature map channel-wise.

    Note: both Linear layers are followed by ReLU6, so the gate values lie in
    [0, 6] (there is no sigmoid here).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of gate values produced; ``forward`` reshapes the
            gates to the input's channel count, so it has to equal that count.
        divide: reduction ratio of the bottleneck.
    """

    def __init__(self, in_channels, out_channels, divide=4):
        super().__init__()
        bottleneck = in_channels // divide
        self.pool = nn.AdaptiveAvgPool2d(1)
        excite_layers = [
            nn.Linear(in_features=in_channels, out_features=bottleneck),
            nn.ReLU6(inplace=True),
            nn.Linear(in_features=bottleneck, out_features=out_channels),
            nn.ReLU6(inplace=True),
        ]
        self.SEblock = nn.Sequential(*excite_layers)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gates (x must be 4-D)."""
        batch, channels, _, _ = x.size()
        squeezed = self.pool(x).view(batch, -1)
        gates = self.SEblock(squeezed).view(batch, channels, 1, 1)
        return gates * x


def Conv1x1BNReLU(in_channels,out_channels):
    """Build a bias-free 1x1 convolution followed by BatchNorm2d and ReLU6.

    Returns:
        nn.Sequential with the three layers in that order.
    """
    pointwise = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(pointwise, norm, activation)

def Conv3x3BNReLU(in_channels,out_channels,stride):
    """Build a bias-free 3x3 convolution (padding 1) followed by BatchNorm2d and ReLU6.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        stride: stride of the 3x3 convolution.

    Returns:
        nn.Sequential with the three layers in that order.
    """
    spatial = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                        kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(spatial, norm, activation)


def VarGConv(in_channels,out_channels,kernel_size,stride,S):
    """Build a variable-group convolution followed by BatchNorm2d and PReLU.

    The convolution is split into ``in_channels // S`` groups (so every group
    sees ``S`` input channels), uses "same"-style padding of
    ``kernel_size // 2`` and has no bias.
    """
    group_count = in_channels // S
    grouped_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                             padding=kernel_size // 2, groups=group_count, bias=False)
    return nn.Sequential(grouped_conv, nn.BatchNorm2d(out_channels), nn.PReLU())

def VarGPointConv(in_channels, out_channels,stride,S,isRelu):
    """Build a grouped 1x1 convolution followed by BatchNorm2d and an optional PReLU.

    The 1x1 convolution uses ``in_channels // S`` groups and no bias.  When
    ``isRelu`` is falsy the third slot is an empty ``nn.Sequential`` (a no-op),
    so the returned container always has three children.
    """
    pointwise = nn.Conv2d(in_channels, out_channels, 1, stride, padding=0,
                          groups=in_channels // S, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    if isRelu:
        tail = nn.PReLU()
    else:
        tail = nn.Sequential()
    return nn.Sequential(pointwise, norm, tail)

class VarGBlock_S1(nn.Module):
    """Residual VarG block that keeps the channel count (identity shortcut).

    Main path: two (VarGConv -> VarGPointConv) pairs, each expanding to
    ``2 * in_plances`` channels and projecting back to ``in_plances``, followed
    by a squeeze-and-excite block.  The result is added to the input and passed
    through a PReLU.

    Args:
        in_plances: channel count of both the input and the output.
        kernel_size: kernel size of the two VarGConv layers.
        stride: forwarded to all four convolutions; the identity shortcut only
            lines up with the main path when this is 1.
        S: channels per group for the variable-group convolutions.
    """

    def __init__(self, in_plances,kernel_size, stride=1, S=8):
        super(VarGBlock_S1, self).__init__()
        plances = 2 * in_plances
        self.varGConv1 = VarGConv(in_plances, plances, kernel_size, stride, S)
        self.varGPointConv1 = VarGPointConv(plances, in_plances, stride, S, isRelu=True)
        self.varGConv2 = VarGConv(in_plances, plances, kernel_size, stride, S)
        self.varGPointConv2 = VarGPointConv(plances, in_plances, stride, S, isRelu=False)
        self.se =  SqueezeAndExcite(in_plances,in_plances)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``prelu(x + main_path(x))`` without modifying ``x``."""
        residual = self.varGPointConv1(self.varGConv1(x))
        residual = self.varGPointConv2(self.varGConv2(residual))
        residual = self.se(residual)
        # Out-of-place add: the previous `out = x; out += ...` wrote into the
        # input tensor itself, which both clobbered the caller's data and
        # invalidated the tensor saved by varGConv1 for the backward pass.
        return self.prelu(x + residual)

class VarGBlock_S2(nn.Module):
    """Down-sampling VarG block that doubles the channel count.

    Two parallel (VarGConv -> VarGPointConv) branches are summed, refined by a
    third pair (expanding to ``4 * in_plances`` and projecting back to
    ``2 * in_plances``), then added to a convolutional shortcut and passed
    through a PReLU.

    Args:
        in_plances: input channel count; the output has ``2 * in_plances``.
        kernel_size: kernel size of the VarGConv layers.
        stride: stride of the first VarGConv in both branches and the shortcut.
        S: channels per group for the variable-group convolutions.
    """

    def __init__(self, in_plances,kernel_size, stride=2, S=8):
        super().__init__()
        out_plances = 2 * in_plances

        def strided_pair(with_act):
            # VarGConv (strided) followed by a stride-1 point convolution.
            return nn.Sequential(
                VarGConv(in_plances, out_plances, kernel_size, stride, S),
                VarGPointConv(out_plances, out_plances, 1, S, isRelu=with_act),
            )

        # Registration order is kept: branch1, branch2, block_3, shortcut, prelu.
        self.varGConvBlock_branch1 = strided_pair(True)
        self.varGConvBlock_branch2 = strided_pair(True)
        self.varGConvBlock_3 = nn.Sequential(
            VarGConv(out_plances, out_plances*2, kernel_size, 1, S),
            VarGPointConv(out_plances*2, out_plances, 1, S, isRelu=False),
        )
        self.shortcut = strided_pair(False)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``prelu(shortcut(x) + block_3(branch1(x) + branch2(x)))``."""
        skip = self.shortcut(x)
        left = self.varGConvBlock_branch1(x)
        right = self.varGConvBlock_branch2(x)
        fused = self.varGConvBlock_3(left + right)
        skip += fused
        return self.prelu(skip)


class HeadBlock(nn.Module):
    """Stride-2 residual head block that keeps the channel count.

    Main path: two (VarGConv -> VarGPointConv) pairs, the first VarGConv with
    stride 2.  Shortcut: one stride-2 (VarGConv -> VarGPointConv) pair without
    a trailing activation.  The two paths are summed; no activation follows.

    Args:
        in_plances: channel count of both the input and the output.
        kernel_size: kernel size of the VarGConv layers.
        S: channels per group for the variable-group convolutions.
    """

    def __init__(self, in_plances, kernel_size, S=8):
        super().__init__()

        main_layers = [
            VarGConv(in_plances, in_plances, kernel_size, 2, S),
            VarGPointConv(in_plances, in_plances, 1, S, isRelu=True),
            VarGConv(in_plances, in_plances, kernel_size, 1, S),
            VarGPointConv(in_plances, in_plances, 1, S, isRelu=False),
        ]
        self.varGConvBlock = nn.Sequential(*main_layers)

        skip_layers = [
            VarGConv(in_plances, in_plances, kernel_size, 2, S),
            VarGPointConv(in_plances, in_plances, 1, S, isRelu=False),
        ]
        self.shortcut = nn.Sequential(*skip_layers)

    def forward(self, x):
        """Return ``shortcut(x) + varGConvBlock(x)``."""
        skip = self.shortcut(x)
        skip += self.varGConvBlock(x)
        return skip


class TailEmbedding(nn.Module):
    """Embedding head: 1x1 expansion to 1024, 7x7 grouped conv, 1x1 grouped conv to 512, Linear.

    The 7x7 convolution has no padding and the Linear layer takes 512 input
    features, so the incoming feature map has to be 7x7 for the flatten in
    ``forward`` to produce 512 values per sample.

    Args:
        in_plances: channel count of the incoming feature map.
        plances: size of the output embedding.
        S: channels per group for the 7x7 convolution.
    """

    def __init__(self, in_plances, plances=512, S=8):
        super().__init__()
        expand = Conv1x1BNReLU(in_plances, 1024)
        spatial = nn.Conv2d(1024, 1024, 7, 1, padding=0, groups=1024 // S,
                            bias=False)
        project = nn.Conv2d(1024, 512, 1, 1, padding=0, groups=512, bias=False)
        self.embedding = nn.Sequential(expand, spatial, project)

        self.fc = nn.Linear(in_features=512,out_features=plances)

    def forward(self, x):
        """Return the ``(batch, plances)`` embedding for feature map ``x``."""
        features = self.embedding(x)
        flattened = features.view(features.size(0),-1)
        return self.fc(flattened)


class VarGFaceNet(nn.Module):
    """PyTorch VarGFaceNet: stem conv, head block, three VarG stages, embedding tail.

    Channel widths are 40 -> 80 -> 160 -> 320.  The head and each stage halve
    the spatial size, and the tail needs a 7x7 feature map, so a 3-channel
    input of e.g. 112x112 is expected.

    Args:
        num_classes: size of the vector produced by the tail (passed to
            ``TailEmbedding`` as its output width).
    """

    def __init__(self, num_classes=512):
        super(VarGFaceNet, self).__init__()

        self.conv1 = Conv3x3BNReLU(3, 40, 1)
        self.head = HeadBlock(40,3)
        self.stage2 = nn.Sequential(
            VarGBlock_S2(40,3,2),
            VarGBlock_S1(80, 3, 1),
            VarGBlock_S1(80, 3, 1),
        )
        self.stage3 = nn.Sequential(
            VarGBlock_S2(80, 3, 2),
            VarGBlock_S1(160, 3, 1),
            VarGBlock_S1(160, 3, 1),
            VarGBlock_S1(160, 3, 1),
            VarGBlock_S1(160, 3, 1),
            VarGBlock_S1(160, 3, 1),
            VarGBlock_S1(160, 3, 1),
        )
        self.stage4 = nn.Sequential(
            VarGBlock_S2(160, 3, 2),
            VarGBlock_S1(320, 3, 1),
            VarGBlock_S1(320, 3, 1),
            VarGBlock_S1(320, 3, 1),
        )

        self.tail = TailEmbedding(320,num_classes)

    def init_params(self):
        """Re-initialise weights in place (not called by ``__init__``).

        Conv2d weights get Kaiming-normal init; Linear and BatchNorm2d weights
        are set to 1; every bias that exists is set to 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # The convolutions in this model are built with bias=False, so
                # m.bias is None and must be skipped instead of initialised.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.Linear, nn.BatchNorm2d)):
                # NOTE(review): constant 1 for Linear weights makes all output
                # units of a layer identical - kept as in the original, confirm
                # whether a random init was intended.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the full network and return the tail's output."""
        x = self.conv1(x)
        x = self.head(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        out= self.tail(x)
        return out

In [None]:
    # Smoke test: build the network, print its structure and push one random
    # 112x112 RGB-shaped tensor through it to check the output shape.
    model = VarGFaceNet()
    print(model)

    # Named `dummy_input` (not `input`) so the builtin input() is not shadowed
    # for the rest of the session.
    dummy_input = torch.randn(1, 3, 112, 112)
    out = model(dummy_input)
    print(out.shape)

VarGFaceNet(
  (conv1): Sequential(
    (0): Conv2d(3, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (head): HeadBlock(
    (varGConvBlock): Sequential(
      (0): Sequential(
        (0): Conv2d(40, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=5, bias=False)
        (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
      )
      (1): Sequential(
        (0): Conv2d(40, 40, kernel_size=(1, 1), stride=(1, 1), groups=5, bias=False)
        (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
      )
      (2): Sequential(
        (0): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=5, bias=False)
        (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, tra

In [None]:
!pip install mxnet

Collecting mxnet
[?25l  Downloading https://files.pythonhosted.org/packages/29/bb/54cbabe428351c06d10903c658878d29ee7026efbe45133fd133598d6eb6/mxnet-1.7.0.post1-py2.py3-none-manylinux2014_x86_64.whl (55.0MB)
[K     |████████████████████████████████| 55.0MB 77kB/s 
[?25hCollecting graphviz<0.9.0,>=0.8.1
  Downloading https://files.pythonhosted.org/packages/53/39/4ab213673844e0c004bed8a0781a0721a3f6bb23eb8854ee75c236428892/graphviz-0.8.4-py2.py3-none-any.whl
Installing collected packages: graphviz, mxnet
  Found existing installation: graphviz 0.10.1
    Uninstalling graphviz-0.10.1:
      Successfully uninstalled graphviz-0.10.1
Successfully installed graphviz-0.8.4 mxnet-1.7.0.post1


In [None]:
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

'''
Author: Horizon Robotics Inc.
The company is committed to being the global leader in edge AI platforms.
The model implemented in this script runs at ~200 fps on the Sunrise 2.
Sunrise 2 is the second generation of an embedded AI chip designed by Horizon Robotics,
targeting to empower AIoT devices by AI.

Implemented the following paper:
Mengjia Yan, Mengao Zhao, Zining Xu, Qian Zhang, Guoli Wang, Zhizhong Su. "VarGFaceNet: An Efficient Variable Group Convolutional Neural Network for Lightweight Face Recognition" (https://arxiv.org/abs/1910.04985)

'''

import mxnet as mx


def Act(data, act_type, name):
    """Apply an activation symbol to ``data``.

    ``'prelu'`` is routed to ``mx.sym.LeakyReLU`` with ``act_type='prelu'``;
    every other ``act_type`` is passed to ``mx.symbol.Activation`` unchanged.
    """
    if act_type != 'prelu':
        return mx.symbol.Activation(data=data, act_type=act_type, name=name)
    return mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)


def get_setting_params(**kwargs):
    """Collect the network settings, filling in defaults for missing keys.

    Only the nine known setting names are copied into the result; any other
    keyword argument is ignored.

    Returns:
        dict with the keys ``bn_mom``, ``bn_eps``, ``fix_gamma``,
        ``use_global_stats``, ``workspace``, ``act_type``, ``use_se``,
        ``se_ratio`` and ``group_base`` (in that order).
    """
    defaults = (
        # batch-norm parameters
        ('bn_mom', 0.9),
        ('bn_eps', 2e-5),
        ('fix_gamma', False),
        ('use_global_stats', False),
        # network-level settings
        ('workspace', 512),
        ('act_type', 'prelu'),
        ('use_se', True),
        ('se_ratio', 4),
        ('group_base', 8),
    )
    return {key: kwargs.get(key, fallback) for key, fallback in defaults}


def se_block(data, num_filter, setting_params, name):
    """Squeeze-and-excite: rescale ``data`` channel-wise by a learned sigmoid gate.

    Global average pooling is followed by a 1x1 convolution down to
    ``num_filter // se_ratio`` channels, the configured activation, a 1x1
    convolution back to ``num_filter`` channels and a sigmoid.  The gate is
    broadcast-multiplied with ``data``.
    """
    reduction = setting_params['se_ratio']
    inner_act = setting_params['act_type']

    squeezed = mx.sym.Pooling(data=data,
                              global_pool=True,
                              pool_type='avg',
                              name=name + '_se_pool1')

    # squeeze: reduce to num_filter // se_ratio channels
    reduced = mx.sym.Convolution(data=squeezed,
                                 num_filter=num_filter // reduction,
                                 kernel=(1, 1),
                                 stride=(1, 1),
                                 pad=(0, 0),
                                 name=name + "_se_conv1")
    reduced = Act(data=reduced, act_type=inner_act, name=name + '_se_act1')

    # excite: expand back to num_filter channels and squash to (0, 1)
    expanded = mx.sym.Convolution(data=reduced,
                                  num_filter=num_filter,
                                  kernel=(1, 1),
                                  stride=(1, 1),
                                  pad=(0, 0),
                                  name=name + "_se_conv2")
    gate = mx.symbol.Activation(data=expanded,
                                act_type='sigmoid',
                                name=name + "_se_sigmoid")
    return mx.symbol.broadcast_mul(data, gate)


def separable_conv2d(data,
                     in_channels,
                     out_channels,
                     kernel,
                     pad,
                     setting_params,
                     stride=(1, 1),
                     factor=1,
                     bias=False,
                     bn_dw_out=True,
                     act_dw_out=True,
                     bn_pw_out=True,
                     act_pw_out=True,
                     dilate=1,
                     name=None):
    """Variable-group convolution followed by a 1x1 pointwise convolution.

    The first convolution uses ``in_channels / group_base`` groups and
    produces ``in_channels * factor`` channels; the second is an ungrouped 1x1
    convolution to ``out_channels``.  Each convolution is optionally followed
    by BatchNorm and the configured activation.

    Args:
        data: input symbol.
        in_channels: input channel count; must be a multiple of ``group_base``.
        out_channels: channel count after the pointwise convolution.
        kernel, pad, stride: geometry of the grouped convolution.
        setting_params: dict as returned by ``get_setting_params``.
        factor: channel multiplier of the grouped convolution.
        bias: whether both convolutions have a bias term.
        bn_dw_out, act_dw_out: apply BatchNorm / activation after the grouped conv.
        bn_pw_out, act_pw_out: apply BatchNorm / activation after the pointwise conv.
        dilate: dilation of the grouped convolution (same in both directions).
        name: prefix for all symbol names created here.

    Returns:
        The output symbol of the pointwise stage.
    """
    bn_kwargs = dict(momentum=setting_params['bn_mom'],
                     eps=setting_params['bn_eps'],
                     fix_gamma=setting_params['fix_gamma'],
                     use_global_stats=setting_params['use_global_stats'])
    workspace = setting_params['workspace']
    group_base = setting_params['group_base']
    activation = setting_params['act_type']
    assert in_channels % group_base == 0
    without_bias = not bias

    # grouped ("depthwise") stage
    grouped = mx.sym.Convolution(data=data,
                                 num_filter=int(in_channels * factor),
                                 kernel=kernel,
                                 pad=pad,
                                 stride=stride,
                                 no_bias=without_bias,
                                 num_group=int(in_channels / group_base),
                                 dilate=(dilate, dilate),
                                 workspace=workspace,
                                 name=name + '_conv2d_depthwise')
    if bn_dw_out:
        grouped = mx.sym.BatchNorm(data=grouped,
                                   name=name + '_conv2d_depthwise_bn',
                                   **bn_kwargs)
    if act_dw_out:
        grouped = Act(data=grouped,
                      act_type=activation,
                      name=name + '_conv2d_depthwise_act')

    # pointwise stage
    pointwise = mx.sym.Convolution(data=grouped,
                                   num_filter=out_channels,
                                   kernel=(1, 1),
                                   stride=(1, 1),
                                   pad=(0, 0),
                                   num_group=1,
                                   no_bias=without_bias,
                                   workspace=workspace,
                                   name=name + '_conv2d_pointwise')
    if bn_pw_out:
        pointwise = mx.sym.BatchNorm(data=pointwise,
                                     name=name + '_conv2d_pointwise_bn',
                                     **bn_kwargs)
    if act_pw_out:
        pointwise = Act(data=pointwise,
                        act_type=activation,
                        name=name + '_conv2d_pointwise_act')
    return pointwise


def vargnet_block(data,
                  n_out_ch1,
                  n_out_ch2,
                  n_out_ch3,
                  setting_params,
                  factor=2,
                  dim_match=True,
                  multiplier=1,
                  kernel=(3, 3),
                  stride=(1, 1),
                  dilate=1,
                  with_dilate=False,
                  name=None):
    """Residual unit made of two separable convolutions, optional SE and a shortcut.

    Args:
        data: input symbol.
        n_out_ch1, n_out_ch2, n_out_ch3: base channel counts of the input, the
            first separable conv output and the block output; each is scaled
            by ``multiplier`` and truncated to int.
        setting_params: dict as returned by ``get_setting_params``.
        factor: channel multiplier inside every separable convolution.
        dim_match: if true the shortcut is ``data`` itself, otherwise a
            separable convolution (without final activation) is used.
        kernel, stride, dilate: geometry of the first separable convolution
            (and of the shortcut convolution); the second one uses stride 1.
        with_dilate: if true the stride is forced to (1, 1).
        name: prefix for all symbol names created here.

    Returns:
        Activation applied to (main path + shortcut).
    """
    use_se = setting_params['use_se']
    activation = setting_params['act_type']

    ch_in = int(n_out_ch1 * multiplier)
    ch_mid = int(n_out_ch2 * multiplier)
    ch_out = int(n_out_ch3 * multiplier)

    pad_h = ((kernel[0] - 1) * dilate + 1) // 2
    pad_w = ((kernel[1] - 1) * dilate + 1) // 2
    pad = (pad_h, pad_w)

    if with_dilate:
        stride = (1, 1)

    # arguments shared by every separable convolution in this block
    common = dict(kernel=kernel,
                  pad=pad,
                  setting_params=setting_params,
                  factor=factor,
                  bias=False,
                  dilate=dilate)

    if dim_match:
        shortcut = data
    else:
        shortcut = separable_conv2d(data=data,
                                    in_channels=ch_in,
                                    out_channels=ch_out,
                                    stride=stride,
                                    act_pw_out=False,
                                    name=name + '_shortcut',
                                    **common)

    first = separable_conv2d(data=data,
                             in_channels=ch_in,
                             out_channels=ch_mid,
                             stride=stride,
                             name=name + '_sep1_data',
                             **common)
    second = separable_conv2d(data=first,
                              in_channels=ch_mid,
                              out_channels=ch_out,
                              stride=(1, 1),
                              act_pw_out=False,
                              name=name + '_sep2_data',
                              **common)

    if use_se:
        second = se_block(data=second,
                          num_filter=ch_out,
                          setting_params=setting_params,
                          name=name)

    summed = second + shortcut
    return Act(data=summed, act_type=activation, name=name + '_out_data_act')


def vargnet_branch_merge_block(data,
                               n_out_ch1,
                               n_out_ch2,
                               n_out_ch3,
                               setting_params,
                               factor=2,
                               dim_match=False,
                               multiplier=1,
                               kernel=(3, 3),
                               stride=(2, 2),
                               dilate=1,
                               with_dilate=False,
                               name=None):
    """Down-sampling unit: two parallel separable convs are summed, then refined.

    Both branches read ``data``; their sum is activated, sent through a third
    separable convolution (stride 1, no final activation), added to the
    shortcut and activated again.  No SE block is used here.

    Args:
        data: input symbol.
        n_out_ch1, n_out_ch2, n_out_ch3: base channel counts of the input, the
            branch outputs and the block output; each is scaled by
            ``multiplier`` and truncated to int.
        setting_params: dict as returned by ``get_setting_params``.
        factor: channel multiplier inside every separable convolution.
        dim_match: if true the shortcut is ``data`` itself, otherwise a
            separable convolution (without final activation) is used.
        kernel, stride, dilate: geometry of the branch and shortcut convolutions.
        with_dilate: if true the stride is forced to (1, 1).
        name: prefix for all symbol names created here.

    Returns:
        Activation applied to (merged path + shortcut).
    """
    activation = setting_params['act_type']

    ch_in = int(n_out_ch1 * multiplier)
    ch_mid = int(n_out_ch2 * multiplier)
    ch_out = int(n_out_ch3 * multiplier)

    pad_h = ((kernel[0] - 1) * dilate + 1) // 2
    pad_w = ((kernel[1] - 1) * dilate + 1) // 2
    pad = (pad_h, pad_w)

    if with_dilate:
        stride = (1, 1)

    # arguments shared by every separable convolution in this block
    common = dict(kernel=kernel,
                  pad=pad,
                  setting_params=setting_params,
                  factor=factor,
                  bias=False,
                  dilate=dilate,
                  act_pw_out=False)

    if dim_match:
        shortcut = data
    else:
        shortcut = separable_conv2d(data=data,
                                    in_channels=ch_in,
                                    out_channels=ch_out,
                                    stride=stride,
                                    name=name + '_shortcut',
                                    **common)

    branch_a = separable_conv2d(data=data,
                                in_channels=ch_in,
                                out_channels=ch_mid,
                                stride=stride,
                                name=name + '_sep1_data_branch',
                                **common)
    branch_b = separable_conv2d(data=data,
                                in_channels=ch_in,
                                out_channels=ch_mid,
                                stride=stride,
                                name=name + '_sep2_data_branch',
                                **common)
    merged = branch_a + branch_b
    merged = Act(data=merged, act_type=activation, name=name + '_sep1_data_act')

    refined = separable_conv2d(data=merged,
                               in_channels=ch_mid,
                               out_channels=ch_out,
                               stride=(1, 1),
                               name=name + '_sep2_data',
                               **common)

    summed = refined + shortcut
    return Act(data=summed, act_type=activation, name=name + '_out_data_act')


def add_vargnet_conv_block(data,
                           stage,
                           units,
                           in_channels,
                           out_channels,
                           setting_params,
                           kernel=(3, 3),
                           stride=(2, 2),
                           multiplier=1,
                           factor=2,
                           dilate=1,
                           with_dilate=False,
                           name=None):
    """Append one stage: a branch-merge unit followed by ``units - 1`` residual units.

    Unit 1 changes the channel count from ``in_channels`` to ``out_channels``
    and uses ``stride``; the remaining units keep ``out_channels`` and use
    stride (1, 1) with an identity shortcut.  Units are named
    ``<name>_stage_<stage>_unit_<k>`` with k starting at 1.
    """
    assert stage >= 2, 'stage is {}, stage must be set >=2'.format(stage)

    # settings shared by every unit of the stage
    shared = dict(setting_params=setting_params,
                  factor=factor,
                  multiplier=multiplier,
                  kernel=kernel,
                  dilate=dilate,
                  with_dilate=with_dilate)

    body = vargnet_branch_merge_block(data=data,
                                      n_out_ch1=in_channels,
                                      n_out_ch2=out_channels,
                                      n_out_ch3=out_channels,
                                      dim_match=False,
                                      stride=stride,
                                      name=name + '_stage_{}_unit_1'.format(stage),
                                      **shared)
    for unit_no in range(2, units + 1):
        body = vargnet_block(data=body,
                             n_out_ch1=out_channels,
                             n_out_ch2=out_channels,
                             n_out_ch3=out_channels,
                             dim_match=True,
                             stride=(1, 1),
                             name=name + '_stage_{}_unit_{}'.format(stage, unit_no),
                             **shared)
    return body


def add_head_block(data,
                   num_filter,
                   setting_params,
                   multiplier,
                   head_pooling=False,
                   kernel=(3, 3),
                   stride=(2, 2),
                   pad=(1, 1),
                   name=None):
    """Build the network head: conv + BatchNorm + activation, then down-sampling.

    The first convolution produces ``int(num_filter * multiplier)`` channels.
    With ``head_pooling`` a 3x3 stride-2 max pooling follows; otherwise a
    stride-2 ``vargnet_block`` with a convolutional shortcut and ``factor=1``
    is used.
    """
    bn_kwargs = dict(momentum=setting_params['bn_mom'],
                     eps=setting_params['bn_eps'],
                     fix_gamma=setting_params['fix_gamma'],
                     use_global_stats=setting_params['use_global_stats'])
    workspace = setting_params['workspace']
    activation = setting_params['act_type']
    channels = int(num_filter * multiplier)

    stem = mx.sym.Convolution(data=data,
                              num_filter=channels,
                              kernel=kernel,
                              pad=pad,
                              stride=stride,
                              no_bias=True,
                              num_group=1,
                              workspace=workspace,
                              name=name + '_conv1')
    stem = mx.sym.BatchNorm(data=stem, name=name + '_conv1_bn', **bn_kwargs)
    stem = Act(data=stem, act_type=activation, name=name + '_conv1_act')

    if head_pooling:
        return mx.symbol.Pooling(data=stem,
                                 kernel=(3, 3),
                                 stride=(2, 2),
                                 pad=(1, 1),
                                 pool_type='max',
                                 name=name + '_max_pooling')

    return vargnet_block(data=stem,
                         n_out_ch1=num_filter,
                         n_out_ch2=num_filter,
                         n_out_ch3=num_filter,
                         setting_params=setting_params,
                         factor=1,
                         dim_match=False,
                         multiplier=multiplier,
                         kernel=kernel,
                         stride=(2, 2),
                         dilate=1,
                         with_dilate=False,
                         name=name + '_head_pooling')


def add_emb_block(data,
                  input_channels,
                  last_channels,
                  emb_size,
                  setting_params,
                  bias=False,
                  name=None):
    """Build the embedding head and return the ``fc1`` BatchNorm symbol.

    Steps: an optional 1x1 conv + BN + activation to ``last_channels`` (only
    when ``input_channels != last_channels``), a 7x7 grouped convolution with
    BN, a 1x1 convolution to ``last_channels // 2`` with BN and activation, a
    fully connected layer (``pre_fc1``) to ``emb_size`` and a final BatchNorm
    (``fc1``).
    """
    bn_kwargs = dict(momentum=setting_params['bn_mom'],
                     eps=setting_params['bn_eps'],
                     fix_gamma=setting_params['fix_gamma'],
                     use_global_stats=setting_params['use_global_stats'])
    workspace = setting_params['workspace']
    activation = setting_params['act_type']
    group_base = setting_params['group_base']
    without_bias = not bias

    # bring the feature map to last_channels if it is not there already
    if input_channels != last_channels:
        data = mx.sym.Convolution(data=data,
                                  num_filter=last_channels,
                                  kernel=(1, 1),
                                  pad=(0, 0),
                                  stride=(1, 1),
                                  no_bias=without_bias,
                                  workspace=workspace,
                                  name=name + '_convx')
        data = mx.sym.BatchNorm(data=data, name=name + '_convx_bn', **bn_kwargs)
        data = Act(data=data, act_type=activation, name=name + '_convx_act')

    # 7x7 grouped convolution without padding
    spatial = mx.sym.Convolution(data=data,
                                 num_filter=last_channels,
                                 num_group=int(last_channels / group_base),
                                 kernel=(7, 7),
                                 pad=(0, 0),
                                 stride=(1, 1),
                                 no_bias=without_bias,
                                 workspace=workspace,
                                 name=name + '_convx_depthwise')
    spatial = mx.sym.BatchNorm(data=spatial,
                               name=name + '_convx_depthwise_bn',
                               **bn_kwargs)

    # 1x1 convolution halving the channel count
    projected = mx.sym.Convolution(data=spatial,
                                   num_filter=last_channels // 2,
                                   kernel=(1, 1),
                                   pad=(0, 0),
                                   stride=(1, 1),
                                   no_bias=without_bias,
                                   workspace=workspace,
                                   name=name + '_convx_pointwise')
    projected = mx.sym.BatchNorm(data=projected,
                                 name=name + '_convx_pointwise_bn',
                                 **bn_kwargs)
    projected = Act(data=projected,
                    act_type=activation,
                    name=name + '_convx_pointwise_act')

    embedding = mx.sym.FullyConnected(data=projected, num_hidden=emb_size, name='pre_fc1')
    return mx.sym.BatchNorm(data=embedding, name='fc1', **bn_kwargs)


def get_symbol(**kwargs):
    """Assemble the VarGFaceNet symbol graph and return the embedding symbol.

    Recognised keyword arguments besides those of ``get_setting_params``:
    ``multiplier`` (default 1.25), ``emb_size`` (512), ``factor`` (2) and
    ``head_pooling`` (False).

    The input variable ``data`` is shifted by 127.5 and scaled by 0.0078125
    before entering the head block, three stages (3, 7 and 4 units) and the
    embedding block.
    """
    setting_params = get_setting_params(**kwargs)
    multiplier = kwargs.get('multiplier', 1.25)
    emb_size = kwargs.get('emb_size', 512)
    factor = kwargs.get('factor', 2)
    head_pooling = kwargs.get('head_pooling', False)

    filter_list = [32, 64, 128, 256]
    last_channels = 1024
    # one entry per stage: (stage number, unit count, dilation, with_dilate)
    stage_specs = [
        (2, 3, 1, False),
        (3, 7, 1, False),
        (4, 4, 1, False),
    ]

    net = mx.sym.Variable(name='data')
    net = mx.sym.identity(data=net, name='id')
    net = net - 127.5
    net = net * 0.0078125

    net = add_head_block(data=net,
                         num_filter=filter_list[0],
                         setting_params=setting_params,
                         multiplier=multiplier,
                         head_pooling=head_pooling,
                         kernel=(3, 3),
                         stride=(1, 1),
                         pad=(1, 1),
                         name="vargface_head")

    channel_pairs = zip(filter_list, filter_list[1:])
    for (stage_no, unit_count, dilation, keep_resolution), (ch_in, ch_out) in zip(stage_specs, channel_pairs):
        net = add_vargnet_conv_block(data=net,
                                     stage=stage_no,
                                     units=unit_count,
                                     in_channels=ch_in,
                                     out_channels=ch_out,
                                     setting_params=setting_params,
                                     kernel=(3, 3),
                                     stride=(2, 2),
                                     multiplier=multiplier,
                                     factor=factor,
                                     dilate=dilation,
                                     with_dilate=keep_resolution,
                                     name="vargface")

    return add_emb_block(data=net,
                         input_channels=filter_list[3],
                         last_channels=last_channels,
                         emb_size=emb_size,
                         setting_params=setting_params,
                         bias=False,
                         name='embed')


# When run as a script, just build the symbol graph with default settings;
# the returned symbol is discarded.
if __name__ == '__main__':
    get_symbol()

In [None]:
    # Build the default symbol graph and plot it for a 1x3x112x112 input.
    net = get_symbol()
    digraph = mx.viz.plot_network(net, shape={'data': (1, 3, 112, 112)},
                                  node_attrs={'fixedsize': 'false'})
    # Render the plot; the recorded cell output shows it returned 'plot.gv.pdf'.
    digraph.view()

'plot.gv.pdf'

In [None]:
import torch
import torch.nn as nn

In [None]:

class MLP(nn.Module):
    """Stack of Linear layers where each ReLU is placed *before* its Linear.

    ``layer_sizes[0]`` is the dimension of the input and ``layer_sizes[-1]``
    is the dimension of the output.  Layer ``i`` maps ``layer_sizes[i]`` to
    ``layer_sizes[i + 1]``.  Every Linear except the last is preceded by a
    ReLU; the last Linear is preceded by one only when ``final_relu`` is true.
    The network therefore always ends with a Linear layer, exposed as
    ``last_linear``.

    Args:
        layer_sizes: sequence of at least two sizes (converted with ``int``).
        final_relu: whether a ReLU is inserted in front of the last Linear.
    """

    def __init__(self, layer_sizes, final_relu=False):
        super().__init__()
        layer_list = []
        layer_sizes = [int(x) for x in layer_sizes]
        num_layers = len(layer_sizes) - 1
        final_relu_layer = num_layers if final_relu else num_layers - 1
        for i, (input_size, curr_size) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            if i < final_relu_layer:
                # Not in-place: the first ReLU acts directly on the tensor
                # passed to forward(), and an in-place op would overwrite the
                # caller's data.
                layer_list.append(nn.ReLU(inplace=False))
            layer_list.append(nn.Linear(input_size, curr_size))
        self.net = nn.Sequential(*layer_list)
        self.last_linear = self.net[-1]

    def forward(self, x):
        """Apply the layer stack to ``x`` without modifying ``x``."""
        return self.net(x)

In [None]:
# One 128 -> 64 layer; with final_relu=True a ReLU is placed in front of it.
# This rebinds `net`, which held the MXNet symbol in an earlier cell.
net = MLP([128, 64],final_relu=True)

In [None]:
# Bare expression: the notebook shows the module's repr as the cell output.
net

MLP(
  (net): Sequential(
    (0): ReLU(inplace=True)
    (1): Linear(in_features=128, out_features=64, bias=True)
  )
  (last_linear): Linear(in_features=128, out_features=64, bias=True)
)