In [0]:
# ! pip install mxnet-cu100

In [0]:
from mxnet import gluon, nd, init
from mxnet.gluon import nn, model_zoo
import mxnet as mx, numpy as np

class CAM(nn.HybridBlock):
  """Channel attention block.

  The input is reduced by global average pooling and by global max pooling.
  Both pooled results go through the same pair of bias-free 1x1 convolutions
  (ReLU in between), are summed, and passed through a sigmoid.

  Parameters
  ----------
  num_channels : int
      Output channels of the second 1x1 convolution.
  ratio : int
      Reduction factor; the first 1x1 convolution has
      ``num_channels // ratio`` output channels.
  """

  def __init__(self, num_channels, ratio, **kwargs):
    super(CAM, self).__init__(**kwargs)
    reduced_channels = num_channels // ratio
    with self.name_scope():
      self.avg_pool = nn.GlobalAvgPool2D()
      self.max_pool = nn.GlobalMaxPool2D()
      self.conv1 = nn.Conv2D(reduced_channels, 1, use_bias=False)
      self.conv2 = nn.Conv2D(num_channels, 1, use_bias=False)

  def _bottleneck(self, F, pooled):
    """Apply the shared conv1 -> ReLU -> conv2 stack to a pooled tensor."""
    return self.conv2(F.relu(self.conv1(pooled)))

  def hybrid_forward(self, F, X):
    """Return the sigmoid of the summed average- and max-pooled branches."""
    avg_branch = self._bottleneck(F, self.avg_pool(X))
    max_branch = self._bottleneck(F, self.max_pool(X))
    return F.sigmoid(avg_branch + max_branch)


class SAM(nn.HybridBlock):
  """Spatial attention block.

  The mean and the max of the input along axis 1 are concatenated along that
  axis, passed through a single bias-free one-output-channel convolution and
  a sigmoid.

  Parameters
  ----------
  kernel_size : int, default 7
      Convolution kernel size; only 3 and 7 are accepted.
  """

  def __init__(self, kernel_size=7, **kwargs):
    super(SAM, self).__init__(**kwargs)
    with self.name_scope():
      self.kernel_size = kernel_size
      assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
      # Half the kernel size: 3 for a 7x7 kernel, 1 for a 3x3 kernel.
      self.padding = self.kernel_size // 2

      self.conv = nn.Conv2D(
          1,
          kernel_size=self.kernel_size,
          padding=self.padding,
          use_bias=False)

  def hybrid_forward(self, F, X):
    """Return the sigmoid-activated convolution of the [mean, max] maps."""
    axis1_mean = F.mean(X, axis=1, keepdims=True)
    axis1_max = F.max(X, axis=1, keepdims=True)
    stacked = F.concat(axis1_mean, axis1_max, dim=1)
    return F.sigmoid(self.conv(stacked))


class BCAM(nn.HybridBlock):
  """3x3 conv + BatchNorm + ReLU followed by channel and spatial attention.

  The convolved features are multiplied (with broadcasting) by the output of
  ``CAM`` and then by the output of ``SAM``; a final ReLU is applied. No
  residual connection is added.

  Parameters
  ----------
  num_channels : int
      Output channels of the 3x3 convolution; also passed to ``CAM``.
  ratio : int
      Reduction ratio passed to ``CAM``.
  """

  def __init__(self, num_channels, ratio, **kwargs):
    super(BCAM, self).__init__(**kwargs)
    with self.name_scope():
      self.num_channels = num_channels
      self.ratio = ratio
      self.conv1 = nn.Conv2D(
          self.num_channels,
          kernel_size=3,
          strides=1,
          padding=1,
          use_bias=False)
      self.bn1 = nn.BatchNorm()
      self.cam = CAM(self.num_channels, self.ratio)
      self.sam = SAM()

  def hybrid_forward(self, F, X):
    """Return the attention-weighted, ReLU-activated features."""
    features = F.relu(self.bn1(self.conv1(X)))
    channel_weights = self.cam(features)
    features = F.broadcast_mul(channel_weights, features)
    spatial_weights = self.sam(features)
    features = F.broadcast_mul(spatial_weights, features)
    return F.relu(features)

class BCAM1(nn.HybridBlock):
  """Channel then spatial attention applied directly to the input.

  Unlike ``BCAM`` there is no convolution or batch normalisation in front:
  the input is multiplied (with broadcasting) by the output of ``CAM``, then
  by the output of ``SAM``, and a final ReLU is applied. No residual
  connection is added.

  Parameters
  ----------
  num_channels : int
      Channel count passed to ``CAM``.
  ratio : int
      Reduction ratio passed to ``CAM``.
  """

  def __init__(self, num_channels, ratio, **kwargs):
    super(BCAM1, self).__init__(**kwargs)
    with self.name_scope():
      self.num_channels = num_channels
      self.ratio = ratio
      self.cam = CAM(self.num_channels, self.ratio)
      self.sam = SAM()

  def hybrid_forward(self, F, X):
    """Return the attention-weighted, ReLU-activated input."""
    channel_weights = self.cam(X)
    weighted = F.broadcast_mul(channel_weights, X)
    spatial_weights = self.sam(weighted)
    weighted = F.broadcast_mul(spatial_weights, weighted)
    return F.relu(weighted)

class Parallel(nn.HybridBlock):
  """Five-convolution branch with an additive shortcut.

  The branch is conv/BN/ReLU four times, then conv/BN, then the shortcut is
  added and a final ReLU applied. When ``extend`` is true the shortcut (and
  the branch input) is first passed through a 1x1 convolution, a max pooling
  layer and a batch normalisation.

  Parameters
  ----------
  ptype : {'left', 'right'}
      Selects the kernel-size / channel layout:
      'left' uses kernels (1, 3, 3, 3, 5), 'right' uses (3, 5, 5, 5, 1).
  num_channels : sequence of int
      Indexed at positions 0, 1 and 2 for the channel counts of the branch.
  extend : bool, default True
      Whether to create and apply the conv6/maxpool/bn6 projection.
  """

  def __init__(self, ptype, num_channels, extend=True, **kwargs):
    super(Parallel, self).__init__(**kwargs)
    assert ptype in ('left', 'right'), 'wrong type!'
    self.extend = extend
    with self.name_scope():
      if ptype == 'left':
        self._build_layers(
            ((num_channels[0], 1),
             (num_channels[0], 3),
             (num_channels[1], 3),
             (num_channels[2], 3),
             (num_channels[2], 5)),
            num_channels[2],
            extend)
      elif ptype == 'right':
        self._build_layers(
            ((num_channels[0], 3),
             (num_channels[0], 5),
             (num_channels[1], 5),
             (num_channels[2], 5),
             (num_channels[0], 1)),
            num_channels[0],
            extend)

  def _build_layers(self, conv_specs, projection_channels, extend):
    """Create conv1..conv5 / bn1..bn5 and, if requested, the projection.

    ``conv_specs`` is a sequence of ``(channels, kernel_size)`` pairs. Each
    convolution uses stride 1 and padding ``kernel_size // 2``. Layers are
    created in a fixed order so the generated parameter names stay stable.
    """
    for position, (channels, kernel) in enumerate(conv_specs, start=1):
      setattr(self, 'conv%d' % position,
              nn.Conv2D(channels, kernel_size=kernel, strides=1,
                        padding=kernel // 2))
      setattr(self, 'bn%d' % position, nn.BatchNorm())
    if extend:
      self.maxpool = nn.MaxPool2D()
      self.conv6 = nn.Conv2D(projection_channels, kernel_size=1, strides=1)
      self.bn6 = nn.BatchNorm()

  def hybrid_forward(self, F, X):
    """Run the branch and return ReLU(branch + shortcut)."""
    shortcut = X
    if self.extend:
      shortcut = self.bn6(self.maxpool(self.conv6(X)))

    out = shortcut
    activated_pairs = (
        (self.conv1, self.bn1),
        (self.conv2, self.bn2),
        (self.conv3, self.bn3),
        (self.conv4, self.bn4),
    )
    for conv, bn in activated_pairs:
      out = F.relu(bn(conv(out)))
    # The last convolution is not activated before the shortcut is added.
    out = self.bn5(self.conv5(out))

    return F.relu(out + shortcut)



In [0]:
def bilinear_kernel(in_channels, out_channels, kernel_size):
  """Build a bilinear-interpolation weight tensor.

  Returns an ``nd.array`` of shape
  ``(in_channels, out_channels, kernel_size, kernel_size)`` and dtype float32
  that is zero everywhere except at the paired positions
  ``(i, i)`` for ``i`` in ``range(in_channels)`` / ``range(out_channels)``,
  which hold the 2-D bilinear filter.
  """
  factor = (kernel_size + 1) // 2
  # Odd kernels are centred on a pixel, even kernels between two pixels.
  center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5

  rows, cols = np.ogrid[:kernel_size, :kernel_size]
  row_profile = 1 - abs(rows - center) / factor
  col_profile = 1 - abs(cols - center) / factor
  filt = row_profile * col_profile

  weight = np.zeros(
      (in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
  weight[range(in_channels), range(out_channels), :, :] = filt
  return nd.array(weight)

def res50_bcam():
  """Assemble a ResNet-50 v2 based network with one BCAM block.

  Layout of the returned ``HybridSequential`` (by index):
    0      pretrained features[:5]
    1      pretrained features[5]
    2      BCAM(256, 8)
    3, 4   pretrained features[6], features[7]
    5      pretrained features[8:11]
    6-8    1x1 Conv2D(2) / BatchNorm / ReLU
    9      Conv2DTranspose(2, kernel 64, padding 16, stride 32)

  The BCAM block, the 1x1 convolution and its BatchNorm are initialised with
  ``Normal(sigma=0.01)``; the transposed convolution is initialised with a
  constant bilinear kernel.
  """
  backbone = model_zoo.vision.resnet50_v2(pretrained=True).features
  net = nn.HybridSequential()

  # New layers are created in the same order as they are added so that the
  # automatically generated layer names do not change.
  stem = backbone[:5]
  attention = BCAM(256, 8)
  tail = backbone[8:11]
  head_conv = nn.Conv2D(2, kernel_size=1)
  head_bn = nn.BatchNorm()
  head_act = nn.Activation('relu')
  upsample = nn.Conv2DTranspose(2, kernel_size=64, padding=16, strides=32)

  net.add(stem,
          backbone[5], attention,
          backbone[6],
          backbone[7],
          tail,
          head_conv, head_bn, head_act,
          upsample)

  for fresh_block in (attention, head_conv, head_bn):
    fresh_block.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
  upsample.initialize(init=init.Constant(bilinear_kernel(2, 2, 64)))
  print('res50_bcam loaded...')

  return net

class res34_bcam(nn.HybridSequential):
  """ResNet-34 v2 feature stages, each followed by a BCAM1 attention block.

  Children are registered in this order: ``preoperation``
  (pretrained features[0:5]), then ``residualK`` / ``bcamK`` for K = 1..4
  (pretrained features[5..8] paired with ``BCAM1`` of width 64, 128, 256,
  512), ``bn1`` (pretrained features[9]), two 1x1 convolutions (64 and 2
  output channels) and a transposed convolution ``upsample``.
  """

  def __init__(self, **kwargs):
    super(res34_bcam, self).__init__(**kwargs)
    backbone = model_zoo.vision.resnet34_v2(pretrained=True).features
    with self.name_scope():
      self.preoperation = backbone[0:5]
      # Stage K uses backbone index 4 + K; registration order is preserved.
      for stage, width in enumerate((64, 128, 256, 512), start=1):
        setattr(self, 'residual%d' % stage, backbone[4 + stage])
        setattr(self, 'bcam%d' % stage, BCAM1(width, 8))
      self.bn1 = backbone[9]
      self.conv1 = nn.Conv2D(64, kernel_size=1)
      self.conv2 = nn.Conv2D(2, kernel_size=1)
      self.upsample = nn.Conv2DTranspose(
          2, kernel_size=64, padding=16, strides=32)

  def hybrid_forward(self, F, X):
    """Run the backbone with attention, the 1x1 head and the upsampling."""
    out = self.preoperation(X)
    stages = (
        (self.residual1, self.bcam1),
        (self.residual2, self.bcam2),
        (self.residual3, self.bcam3),
        (self.residual4, self.bcam4),
    )
    for residual, attention in stages:
      out = attention(residual(out))

    out = F.relu(self.bn1(out))
    out = F.relu(self.conv1(out))
    out = F.relu(self.conv2(out))
    return self.upsample(out)


class res34_bcam_parallel(nn.HybridSequential):
  """ResNet-34 v2 stages flanked by two ``Parallel`` branches per stage.

  In every stage the input is fed to a 'left' ``Parallel`` branch, the
  pretrained residual stage and a 'right' ``Parallel`` branch. The three
  outputs are concatenated along axis 1, concatenated again with a skip
  tensor, and fused by a ``BCAM`` block. For stage 1 the skip tensor is the
  stage input itself; for stages 2-4 it is a stride-2 3x3 convolution of the
  previous stage output. The head is pretrained features[9] (``bn1``), two
  1x1 convolutions, a BatchNorm and a transposed convolution.
  """

  def __init__(self, **kwargs):
    super(res34_bcam_parallel, self).__init__(**kwargs)
    backbone = model_zoo.vision.resnet34_v2(pretrained=True).features
    with self.name_scope():
      self.preoperation = backbone[0:5]

      # Stage 1: the Parallel branches are built with extend=False.
      self.residual1 = backbone[5]
      self.parallel1_0 = Parallel('left', (256, 128, 64), False)
      self.parallel1_1 = Parallel('right', (64, 128, 256), False)
      self.bcam1 = BCAM(64, 8)

      # Stages 2-4 share one pattern; creation order is kept so the
      # automatically generated parameter names are unchanged.
      for stage, width in enumerate((128, 256, 512), start=2):
        setattr(self, 'conv%d' % (stage - 1),
                nn.Conv2D(width, kernel_size=3, strides=2, padding=1))
        setattr(self, 'residual%d' % stage, backbone[4 + stage])
        setattr(self, 'parallel%d_0' % stage,
                Parallel('left', (4 * width, 2 * width, width)))
        setattr(self, 'parallel%d_1' % stage,
                Parallel('right', (width, 2 * width, 4 * width)))
        setattr(self, 'bcam%d' % stage, BCAM(width, 8))

      self.bn1 = backbone[9]
      self.conv4 = nn.Conv2D(64, kernel_size=1)
      self.conv5 = nn.Conv2D(2, kernel_size=1)
      self.bn3 = nn.BatchNorm()
      self.upsample = nn.Conv2DTranspose(
          2, kernel_size=64, padding=16, strides=32)

  def _fuse_stage(self, F, features, skip, left, residual, right, attention):
    """Run the three branches on ``features`` and fuse them with ``skip``."""
    left_out = left(features)
    residual_out = residual(features)
    right_out = right(features)
    branches = F.concat(left_out, residual_out, right_out, dim=1)
    return attention(F.concat(branches, skip, dim=1))

  def hybrid_forward(self, F, X):
    """Run the four fused stages, the 1x1 head and the upsampling."""
    stem_out = self.preoperation(X)
    out = self._fuse_stage(
        F, stem_out, stem_out,
        self.parallel1_0, self.residual1, self.parallel1_1, self.bcam1)

    # The skip convolution is evaluated before the branches of its stage.
    out = self._fuse_stage(
        F, out, self.conv1(out),
        self.parallel2_0, self.residual2, self.parallel2_1, self.bcam2)
    out = self._fuse_stage(
        F, out, self.conv2(out),
        self.parallel3_0, self.residual3, self.parallel3_1, self.bcam3)
    out = self._fuse_stage(
        F, out, self.conv3(out),
        self.parallel4_0, self.residual4, self.parallel4_1, self.bcam4)

    out = F.relu(self.bn1(out))
    out = F.relu(self.conv4(out))
    out = F.relu(self.bn3(self.conv5(out)))
    return self.upsample(out)

In [0]:
def get_net(model='res34_bcam'):
  """Build one of the ResNet-34 based networks and initialise its new layers.

  Parameters
  ----------
  model : str, default 'res34_bcam'
      Either 'res34_bcam' or 'res34_bcam_parallel'.

  Returns
  -------
  The constructed network. Its last child is initialised with a constant
  bilinear kernel; every other child whose name contains the network prefix
  and does not contain 'resnet' is re-initialised with Xavier. Children
  whose name contains 'resnet' are left untouched (presumably the pretrained
  backbone layers -- selection is purely name-based).

  Raises
  ------
  ValueError
      If ``model`` is not one of the supported names. Previously this fell
      through and failed later with an unbound ``net`` variable.
  """
  supported_models = ('res34_bcam', 'res34_bcam_parallel')
  if model not in supported_models:
    raise ValueError(
        'unknown model %r; expected one of: %s'
        % (model, ', '.join(supported_models)))

  if model == 'res34_bcam':
    net = res34_bcam()
  else:
    net = res34_bcam_parallel()

  print(model, ' loading')
  # Look the final (upsampling) child up once instead of on every iteration.
  upsample = net[-1]
  for layer in net:
    if layer is upsample:
      layer.initialize(init.Constant(bilinear_kernel(2, 2, 64)),
                       force_reinit=True)
    elif net.prefix in layer.name and 'resnet' not in layer.name:
      layer.initialize(init=init.Xavier(), force_reinit=True)

  print(model, ' initialize finished')

  return net