# In [0]:
from mxnet import gluon, nd
from mxnet.gluon import nn
import mxnet as mx

class Residual(nn.Block):
  '''
  Residual unit whose shortcut is re-weighted by attention before the add.

  Main branch: conv1 -> bn1 -> relu -> conv2 -> bn2 (both convs are 3x3
  with padding 1). A SAM map is taken from the output of each conv stage,
  the two maps are summed and fed to CAM, and the CAM output multiplies
  the shortcut. The unit returns relu(main_branch + weighted_shortcut).

  num_channels: output channels of every convolution in the unit.
  change_shape: when True, conv1 uses stride 2 and the shortcut goes
    through a 1x1 conv with stride 2 (conv3) so both branches can be added.
  kwargs: forwarded unchanged to nn.Block.
  '''
  def __init__(self, num_channels, change_shape=False, **kwargs):
    super(Residual, self).__init__(**kwargs)
    self.change_shape = change_shape
    strides = 2 if change_shape else 1

    # First 3x3 stage; this is the only 3x3 conv that may be strided.
    self.conv1 = nn.Conv2D(
        num_channels, kernel_size=3, strides=strides, padding=1)
    self.bn1 = nn.BatchNorm()

    # Second 3x3 stage, always at the default stride.
    self.conv2 = nn.Conv2D(num_channels, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm()

    # Shortcut projection, created only when the main branch changes shape.
    if change_shape:
      self.conv3 = nn.Conv2D(num_channels, kernel_size=1, strides=strides)

    # One SAM per conv stage, then a CAM (ratio 4) over their sum.
    self.sam1 = SAM()
    self.sam2 = SAM()
    self.cam = CAM(num_channels, 4)

  def forward(self, X):
    '''
    Run the unit on X and return the activated sum of both branches.
    '''
    stage_one = nd.relu(self.bn1(self.conv1(X)))
    stage_two = self.bn2(self.conv2(stage_one))

    # Identity shortcut unless a projection was built in __init__.
    shortcut = self.conv3(X) if self.change_shape else X

    # The attention weight is derived from both stages and applied to the
    # shortcut (not to the main branch).
    sam_sum = self.sam1(stage_one) + self.sam2(stage_two)
    shortcut = self.cam(sam_sum) * shortcut
    return nd.relu(stage_two + shortcut)


class CAM(nn.Block):
  '''
  CAM attention block (presumably "channel attention module").

  The input is reduced with global average pooling and with global max
  pooling. Each pooled result goes through the same two 1x1 convolutions
  with a ReLU in between; the two outputs are added and squashed with a
  sigmoid.

  num_channels: channels produced by the second 1x1 conv, i.e. by the block.
  ratio: the first 1x1 conv produces num_channels // ratio channels.
  kwargs: forwarded unchanged to nn.Block.
  '''
  def __init__(self, num_channels, ratio, **kwargs):
    super(CAM, self).__init__(**kwargs)
    reduced_channels = num_channels // ratio
    self.avg_pool = nn.GlobalAvgPool2D()
    self.max_pool = nn.GlobalMaxPool2D()
    self.conv1 = nn.Conv2D(reduced_channels, kernel_size=1)
    self.conv2 = nn.Conv2D(num_channels, kernel_size=1)

  def _shared_convs(self, pooled):
    # conv1 -> relu -> conv2, reused for both pooled inputs.
    return self.conv2(nd.relu(self.conv1(pooled)))

  def forward(self, X):
    '''
    Return sigmoid(convs(avg_pool(X)) + convs(max_pool(X))).
    '''
    from_avg = self._shared_convs(self.avg_pool(X))
    from_max = self._shared_convs(self.max_pool(X))
    return nd.sigmoid(from_avg + from_max)


class SAM(nn.Block):
  '''
  SAM attention block (presumably "spatial attention module").

  The input is pooled twice with 3x3 / stride-2 / padding-1 windows (once
  average, once max), the two results are concatenated along dim 1, and a
  single-output-channel stride-2 convolution followed by a sigmoid
  produces the returned map.

  kernel_size: kernel of the final convolution; must be 3 or 7. The
    padding is 3 for kernel 7 and 1 for kernel 3.
  kwargs: forwarded unchanged to nn.Block.

  Raises AssertionError if kernel_size is neither 3 nor 7.
  '''
  def __init__(self, kernel_size=7, **kwargs):
    super(SAM, self).__init__(**kwargs)
    # Explicit raise rather than `assert`: an assert is removed under
    # `python -O`, which would let an unsupported kernel size through with
    # padding 1. AssertionError is kept so callers see the same exception
    # type as before.
    if kernel_size not in (3, 7):
      raise AssertionError('kernel size must be 3 or 7')
    self.kernel_size = kernel_size
    self.padding = 3 if self.kernel_size == 7 else 1
    self.avg_pool = nn.AvgPool2D(pool_size=3, strides=2, padding=1)
    self.max_pool = nn.MaxPool2D(pool_size=3, strides=2, padding=1)
    self.conv = nn.Conv2D(1, kernel_size=self.kernel_size, strides=2, padding=self.padding)

  def forward(self, X):
    '''
    Return sigmoid(conv(concat(avg_pool(X), max_pool(X)))).
    '''
    X_avg = self.avg_pool(X)
    X_max = self.max_pool(X)
    # Stack both pooled views along dim 1 before the convolution.
    Y = nd.concat(X_avg, X_max, dim=1)
    Y = nd.sigmoid(self.conv(Y))
    return Y
  
    
def resnet_block(num_channels, num_residuals, first_block=False):
  '''
  Build one stage: an nn.Sequential holding num_residuals Residual units.

  num_channels: channel count handed to every Residual in the stage.
  num_residuals: how many Residual units to stack.
  first_block: when False (the default), the first unit of the stage is
    created with change_shape=True; when True, no unit changes shape.

  Returns the populated nn.Sequential.
  '''
  stage = nn.Sequential()
  for index in range(num_residuals):
    # Only the leading unit of a non-first stage changes shape.
    leads_with_shape_change = index == 0 and not first_block
    stage.add(Residual(num_channels, change_shape=leads_with_shape_change))
  return stage




def Resnet_18(num_classes=10):
  '''
  Build the attention ResNet-18 as an nn.Sequential.

  Layout: a stem (7x7 stride-2 conv, batch norm, ReLU, 3x3 stride-2 max
  pool), four resnet_block stages with 64/128/256/512 channels and two
  Residual units each, then global average pooling and a Dense classifier.

  num_classes: number of units in the final Dense layer. Defaults to 10,
    the value that used to be hard-coded.

  Returns the assembled (uninitialised) network. Prints a one-line
  message when called.
  '''
  print('loading model resnet_18_attention')
  net = nn.Sequential()
  net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
          nn.BatchNorm(), nn.Activation('relu'),
          nn.MaxPool2D(pool_size=3, strides=2, padding=1))

  # The first stage keeps the stem's shape; every later stage starts with
  # a shape-changing Residual.
  net.add(resnet_block(64, 2, first_block=True),
          resnet_block(128, 2),
          resnet_block(256, 2),
          resnet_block(512, 2))

  net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))
  # Alternative head, left disabled: a 1x1 conv followed by a stride-32
  # transposed conv instead of pooling + Dense.
  # net.add(nn.Conv2D(num_classes, kernel_size=1),
  #       nn.Conv2DTranspose(num_classes, kernel_size=64, padding=16,
  #                          strides=32))

  return net


def Resnet_34(num_classes=10):
  '''
  Build the attention ResNet-34 as an nn.Sequential.

  Layout: a stem (7x7 stride-2 conv, batch norm, ReLU, 3x3 stride-2 max
  pool), four resnet_block stages with 64/128/256/512 channels holding
  3/4/6/3 Residual units, then global average pooling and a Dense
  classifier.

  num_classes: number of units in the final Dense layer. Defaults to 10,
    the value that used to be hard-coded.

  Returns the assembled (uninitialised) network. Prints a one-line
  message when called.
  '''
  print('loading model resnet_34_attention')
  net = nn.Sequential()
  net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),
          nn.BatchNorm(), nn.Activation('relu'),
          nn.MaxPool2D(pool_size=3, strides=2, padding=1))

  # The first stage keeps the stem's shape; every later stage starts with
  # a shape-changing Residual.
  net.add(resnet_block(64, 3, first_block=True),
          resnet_block(128, 4),
          resnet_block(256, 6),
          resnet_block(512, 3))

  net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))

  return net