# ResNet50-NAM: Reproducing a New Way of Computing Attention

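Paper: https://arxiv.org/abs/2111.12419
## Introduction
Attention mechanisms have become very popular in recent years; they help a neural network suppress less salient features along the channel or spatial dimensions. Much previous work focuses on capturing salient features with attention operators, and these methods successfully exploit the mutual information between different dimensions of the features. However, they do not consider the contributing factors of the weights, which can further suppress insignificant features. The paper therefore aims to **improve attention by using the contributing factors of the weights**. It uses the scaling factor of **Batch Normalization** to represent the importance of the weights, which **avoids adding the fully connected and convolutional layers used in SE, BAM and CBAM**. The result is a new attention mechanism: the Normalization-based Attention Module (**NAM**).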

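## Method
The NAM proposed in the paper is a lightweight and efficient attention mechanism. It adopts the module integration of CBAM and redesigns the channel and spatial attention submodules, so that a NAM module can be embedded at the end of each network block; for residual networks, it is embedded at the end of the residual structure. The channel attention submodule uses the scaling factor from Batch Normalization, as in Eq. (1). The scaling factor reflects how much each channel varies and therefore indicates the channel's importance. Intuitively, $\gamma$ is the learnable scale that BN applies after normalization (not the batch variance itself), and it sets the spread of the channel's output: a channel with a larger $|\gamma|$ varies more strongly, is taken to carry richer information and is more important, while a channel that barely varies carries little information and matters less.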
![](https://ai-studio-static-online.cdn.bcebos.com/b397bce30f4e4662bb68b7c5a975393773e5df9b1f404b22840330e22c82fae8)
![](https://ai-studio-static-online.cdn.bcebos.com/fbbd2c9dfc864689a2fb3a073eaa21bdcb9953ce89544162927185217e0aa2d8)

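where $\mu_B$ and $\sigma_B$ are the mean and standard deviation of mini-batch $B$, and $\gamma$ and $\beta$ are the trainable affine transformation parameters (scale and shift); see [Batch Normalization](https://arxiv.org/pdf/1502.03167.pdf). The channel attention submodule is shown in Figure (1) and Eq. (2):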
![](https://ai-studio-static-online.cdn.bcebos.com/86b757ab40f3460ba3430d7f7f71622d5d2c29622a5c4420921bfe033700b22b)
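where $M_c$ denotes the output features and $\gamma$ is the per-channel scaling factor, so the weight of each channel is obtained as $W_\gamma = \gamma_i / \sum_{j=0} \gamma_j$. A BN scaling factor is also applied along the spatial dimension to measure the importance of each pixel, which the paper calls **pixel normalization**. The pixel attention submodule is shown in Figure (2) and Eq. (3):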
![](https://ai-studio-static-online.cdn.bcebos.com/2b24da52d2614ebca2c387564fd8920d93289777228d460b9a81b0a8c4df152a)
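
The channel attention of Eq. (1) and Eq. (2) can be sketched in plain Python on a tiny input. This is an illustrative sketch, not the notebook's Paddle implementation: the function name `channel_attention`, the `[C][N]` nested-list layout (batch and spatial positions flattened per channel) and the sample values are assumptions. As in the notebook's code, absolute values of $\gamma$ are used for the weights and the sigmoid gate is multiplied by the input.

```python
import math

def channel_attention(feature, gamma, beta, eps=1e-5):
    """feature: [C][N] activations per channel; gamma, beta: BN scale and shift per channel."""
    total = sum(abs(g) for g in gamma)
    weights = [abs(g) / total for g in gamma]  # W_gamma, sums to 1
    gated = []
    for c, values in enumerate(feature):
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        row = []
        for v in values:
            # Eq. (1): batch normalization with scale gamma and shift beta
            bn = gamma[c] * (v - mean) / math.sqrt(var + eps) + beta[c]
            # Eq. (2): sigmoid of the weighted BN output
            gate = 1.0 / (1.0 + math.exp(-weights[c] * bn))
            # Gate the original input, as the notebook's Channel_Att does
            row.append(gate * v)
        gated.append(row)
    return weights, gated

# Two channels with four positive activations each (made-up values)
feature = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 1.5, 1.5]]
weights, gated = channel_attention(feature, gamma=[3.0, 1.0], beta=[0.0, 0.0])
print(weights)
```

The channel with the larger $|\gamma|$ receives the larger weight, so its gate responds more sharply to the normalized activation.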

To suppress unimportant features, the authors add a regularization term to the loss function, as shown in Eq. (4).
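
Since the image for Eq. (4) is not reproduced here, the following is only a sketch of how such a regularized loss could look. It assumes, based on my reading of the paper, that the penalty is an L1 norm on the channel scaling factors $\gamma$ and the pixel-normalization scaling factors $\lambda$, balanced by a coefficient $p$. The function name `nam_loss` and all numeric values are hypothetical.

```python
def nam_loss(task_loss, gammas, lambdas, p):
    """Task loss plus an assumed L1 penalty on the BN scaling factors."""
    def l1(values):
        return sum(abs(v) for v in values)
    # p balances the sparsity penalty against the task loss
    return task_loss + p * l1(gammas) + p * l1(lambdas)

# Made-up numbers: 2.0 + 0.1 * (0.5 + 1.5) + 0.1 * 1.0
example = nam_loss(2.0, gammas=[0.5, -1.5], lambdas=[1.0], p=0.1)
print(example)
```

Note that the training cell in this notebook uses plain cross-entropy with weight decay and does not add this penalty.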




## Dataset: Cifar100
Link: http://www.cs.toronto.edu/~kriz/cifar.html

![](https://ai-studio-static-online.cdn.bcebos.com/7baff8d84907478294ca778385fca688741d14db13d841c39f0125c41f0a0ac4)


The CIFAR100 dataset has 100 classes. Each class contains 600 color images of size $32\times 32$, of which 500 are used for training and 100 for testing.
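
As a quick check of the split sizes implied by these numbers (the variable names are illustrative):

```python
NUM_CLASSES = 100
TRAIN_PER_CLASS = 500
TEST_PER_CLASS = 100

train_images = NUM_CLASSES * TRAIN_PER_CLASS  # size of the training split
test_images = NUM_CLASSES * TEST_PER_CLASS    # size of the test split
values_per_image = 32 * 32 * 3                # height x width x RGB channels

print(train_images, test_images, values_per_image)
```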

## Code reproduction

### 1. Import dependencies

In [48]:
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
from paddle.nn import functional as F

from paddle.utils.download import get_weights_path_from_url
import pickle
import numpy as np
from paddle import callbacks
from paddle.vision.transforms import (
    ToTensor, RandomHorizontalFlip, RandomResizedCrop, SaturationTransform, Compose,
    HueTransform, BrightnessTransform, ContrastTransform, RandomCrop, Normalize, RandomRotation
)
from paddle.vision.datasets import Cifar100
from paddle.io import DataLoader
from paddle.optimizer.lr import CosineAnnealingDecay, MultiStepDecay, LinearWarmup
import random

### 2. Define the NAM attention module

In [49]:
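class Channel_Att(nn.Layer):
    """NAM channel attention: gates each channel by its normalized BN scale."""

    def __init__(self, channels=3, t=16):
        # t is unused here; it is kept so the signature stays unchanged
        super(Channel_Att, self).__init__()
        self.channels = channels
        self.bn2 = nn.BatchNorm2D(self.channels)

    def forward(self, x):
        residual = x
        x = self.bn2(x)
        # W_gamma: |gamma_i| / sum_j |gamma_j|, one weight per channel
        weight_bn = self.bn2.weight.abs() / paddle.sum(self.bn2.weight.abs())
        # Move channels last (NCHW -> NHWC) so the [C] weights broadcast per channel
        x = x.transpose([0, 2, 3, 1])
        x = paddle.multiply(weight_bn, x)
        x = x.transpose([0, 3, 1, 2])  # back to NCHW
        # Sigmoid gate applied to the original input
        x = F.sigmoid(x) * residual

        return x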

class Att(nn.Layer):
    def __init__(self, channels=3, out_channels=None, no_spatial=True):
        super(Att, self).__init__()
        self.Channel_Att = Channel_Att(channels)
  
    def forward(self, x):
        x_out1=self.Channel_Att(x)

        return x_out1  

### 3. Define the ResNet network with NAM attention
This code is based on the implementation in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas); the number of classes is set to 100.

In [50]:
__all__ = []
model_urls = {
    'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
                 'cf548f46534aa3560945be4b95cd11c4'),
    'resnet34': ('https://paddle-hapi.bj.bcebos.com/models/resnet34.pdparams',
                 '8d2275cf8706028345f78ac0e1d31969'),
    'resnet50': ('https://paddle-hapi.bj.bcebos.com/models/resnet50.pdparams',
                 'ca6f485ee1ab0492d38f323885b0ad80'),
    'resnet101': ('https://paddle-hapi.bj.bcebos.com/models/resnet101.pdparams',
                  '02f35f034ca3858e1e54d4036443c92d'),
    'resnet152': ('https://paddle-hapi.bj.bcebos.com/models/resnet152.pdparams',
                  '7ad16a2f1e7333859ff986138630fd7a'),
}

class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2D

        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")

        self.conv1 = nn.Conv2D(
            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        self.nam = Att(planes)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.nam(out)
        out += identity
        out = self.relu(out)

        return out

class BottleneckBlock(nn.Layer):

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BottleneckBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2D
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2D(
            width,
            width,
            3,
            padding=dilation,
            stride=stride,
            groups=groups,
            dilation=dilation,
            bias_attr=False)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2D(
            width, planes * self.expansion, 1, bias_attr=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
        self.nam = Att(planes*4)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.nam(out)
        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Layer):
    """ResNet model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        Block (BasicBlock|BottleneckBlock): block module of model.
        depth (int): layers of resnet, default: 50.
        num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
                            will not be defined. Default: 100.
        with_pool (bool): use pool before the last fc layer or not. Default: True.

    Examples:
        .. code-block:: python

            from paddle.vision.models import ResNet
            from paddle.vision.models.resnet import BottleneckBlock, BasicBlock

            resnet50 = ResNet(BottleneckBlock, 50)

            resnet18 = ResNet(BasicBlock, 18)

    """

    def __init__(self, block, depth, num_classes=100, with_pool=True):
        super(ResNet, self).__init__()
        layer_cfg = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3]
        }
        layers = layer_cfg[depth]
        self.num_classes = num_classes
        self.with_pool = with_pool
        self._norm_layer = nn.BatchNorm2D

        self.inplanes = 64
        self.dilation = 1

        self.conv1 = nn.Conv2D(
            3,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias_attr=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        if with_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))

        if num_classes > 0:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    1,
                    stride=stride,
                    bias_attr=False),
                norm_layer(planes * block.expansion), )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, 1, 64,
                  previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.with_pool:
            x = self.avgpool(x)

        if self.num_classes > 0:
            x = paddle.flatten(x, 1)
            x = self.fc(x)

        return x

def _resnet(arch, Block, depth, pretrained, **kwargs):
    model = ResNet(Block, depth, **kwargs)
    if pretrained:
        assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
            arch)
        weight_path = get_weights_path_from_url(model_urls[arch][0],
                                                model_urls[arch][1])

        param = paddle.load(weight_path)
        model.set_dict(param)

    return model

def resnet50(pretrained=False, **kwargs):
    """ResNet 50-layer model

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet

    Examples:
        .. code-block:: python

            from paddle.vision.models import resnet50

            # build model
            model = resnet50()

            # build model and load imagenet pretrained weight
            # model = resnet50(pretrained=True)
    """
    return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)

In [69]:
net = resnet50()
paddle.summary(net, (1,3,224,224))

--------------------------------------------------------------------------------
    Layer (type)         Input Shape          Output Shape         Param #    
    Conv2D-1962       [[1, 3, 224, 224]]   [1, 64, 112, 112]        9,408     
  BatchNorm2D-2458   [[1, 64, 112, 112]]   [1, 64, 112, 112]         256      
      ReLU-630       [[1, 64, 112, 112]]   [1, 64, 112, 112]          0       
    MaxPool2D-38     [[1, 64, 112, 112]]    [1, 64, 56, 56]           0       
    Conv2D-1964       [[1, 64, 56, 56]]     [1, 64, 56, 56]         4,096     
  BatchNorm2D-2460    [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-631        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
    Conv2D-1965       [[1, 64, 56, 56]]     [1, 64, 56, 56]        36,864     
  BatchNorm2D-2461    [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
    Conv2D-1966       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-2462    [[1, 256, 56, 56]]    [1, 25

{'total_params': 23826468, 'trainable_params': 23659812}
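
The configuration above can be cross-checked with a little arithmetic. The sketch below (variable names are illustrative) derives the depth of 50 from the `[3, 4, 6, 3]` bottleneck layout and counts what NAM adds: one `BatchNorm2D` per block over `planes * 4` channels, each with a learnable scale and shift plus a running mean and variance per channel.

```python
blocks_per_stage = [3, 4, 6, 3]      # layer_cfg[50]
stage_planes = [64, 128, 256, 512]   # planes passed to _make_layer
expansion = 4                        # BottleneckBlock.expansion

num_blocks = sum(blocks_per_stage)
# Three convolutions per bottleneck block, plus the stem conv and the final fc layer
depth = 3 * num_blocks + 2

# Each block holds one NAM module whose BatchNorm2D spans planes * expansion channels
nam_channels = sum(n * p * expansion for n, p in zip(blocks_per_stage, stage_planes))
nam_learnable = 2 * nam_channels     # scale (gamma) and shift (beta)
nam_running_stats = 2 * nam_channels # running mean and variance

print(depth, num_blocks, nam_channels, nam_learnable)
```

The 16 NAM modules add only about 30 thousand learnable values to a model with roughly 23.8 million parameters, which illustrates why the module is described as lightweight.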

### 4. Custom data preprocessing

In [51]:
class ToArray(object):
    def __call__(self, img):
        img = np.array(img)
        img = np.transpose(img, [2, 0, 1])
        img = img / 255.
        return img.astype('float32')

class RandomApply(object):
    def __init__(self, transform, p=0.5):
        super().__init__()
        self.p = p
        self.transform = transform
        

    def __call__(self, img):
        if self.p < random.random():
            return img
        img = self.transform(img)
        return img
                                                                                                                    
class LRSchedulerM(callbacks.LRScheduler):
    """Steps the LR scheduler once per batch during warmup, then once per epoch."""

    def __init__(self, by_step=False, by_epoch=True, warm_up=True):
        super().__init__(by_step, by_epoch)
        assert by_step ^ warm_up
        self.warm_up = warm_up

    def on_epoch_end(self, epoch, logs=None):
        # After warmup, advance the scheduler once per epoch
        if self.by_epoch and not self.warm_up:
            if self.model._optimizer and hasattr(
                self.model._optimizer, '_learning_rate') and isinstance(
                    self.model._optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
                self.model._optimizer._learning_rate.step()

    def on_train_batch_end(self, step, logs=None):
        # During warmup (or when by_step is set), advance the scheduler every batch
        if self.by_step or self.warm_up:
            if self.model._optimizer and hasattr(
                self.model._optimizer, '_learning_rate') and isinstance(
                    self.model._optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
                self.model._optimizer._learning_rate.step()
            # Warmup ends once the step counter reaches warmup_steps
            # (this assumes the scheduler is a LinearWarmup, which defines warmup_steps)
            if self.model._optimizer._learning_rate.last_epoch >= self.model._optimizer._learning_rate.warmup_steps:
                self.warm_up = False

def _on_train_batch_end(self, step, logs=None):
    logs = logs or {}
    logs['lr'] = self.model._optimizer.get_lr()
    self.train_step += 1
    if self._is_write():
        self._updates(logs, 'train')

def _on_train_begin(self, logs=None):
    self.epochs = self.params['epochs']
    assert self.epochs
    self.train_metrics = self.params['metrics'] + ['lr']
    assert self.train_metrics
    self._is_fit = True
    self.train_step = 0

callbacks.VisualDL.on_train_batch_end = _on_train_batch_end
callbacks.VisualDL.on_train_begin = _on_train_begin
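
The `RandomApply` wrapper defined in this cell applies its transform with probability `p` and otherwise returns the image unchanged. The sketch below restates the class without the Paddle dependency and measures how often the transform fires; the sign-flipping "transform", the seed and the sample count are made up for illustration.

```python
import random

class RandomApply:
    """Apply `transform` with probability p, otherwise return the input unchanged."""

    def __init__(self, transform, p=0.5):
        self.transform = transform
        self.p = p

    def __call__(self, img):
        # random.random() is uniform on [0, 1), so this skips with probability 1 - p
        if self.p < random.random():
            return img
        return self.transform(img)

random.seed(0)
trials = 10_000
maybe_negate = RandomApply(lambda x: -x, p=0.3)
applied = sum(1 for _ in range(trials) if maybe_negate(1) == -1)
rate = applied / trials  # empirical fraction of calls where the transform fired

always_negate = RandomApply(lambda x: -x, p=1.0)
print(rate, always_negate(1))
```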

### 5. Train the model on Cifar100
The dataset is loaded with Paddle's built-in Cifar100 dataset API.

In [None]:
model = paddle.Model(resnet50())
# Load a checkpoint
# model.load('output/ResNet50-NAM/299.pdparams')
MAX_EPOCH = 300
LR = 0.001
WEIGHT_DECAY = 5e-4
MOMENTUM = 0.9
BATCH_SIZE = 128
CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR_STD = [0.1942, 0.1918, 0.1958]
DATA_FILE = './data/data76994/cifar-100-python.tar.gz'

model.prepare(
    paddle.optimizer.Momentum(
        learning_rate=LinearWarmup(CosineAnnealingDecay(LR, MAX_EPOCH), 2000, 0., LR),
        momentum=MOMENTUM,
        parameters=model.parameters(),
        weight_decay=WEIGHT_DECAY),
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1,5)))

# Define the data augmentation pipeline
transforms = Compose([
    RandomCrop(32, padding=4),
    RandomApply(BrightnessTransform(0.1)),
    RandomApply(ContrastTransform(0.1)),
    RandomHorizontalFlip(),
    RandomRotation(15),
    ToArray(),
    Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_transforms = Compose([ToArray(), Normalize(CIFAR_MEAN, CIFAR_STD)])

# Load the training and test datasets
train_set = Cifar100(DATA_FILE, mode='train', transform=transforms)
test_set = Cifar100(DATA_FILE, mode='test', transform=val_transforms)

# Checkpointing and training visualization
checkpoint_callback = paddle.callbacks.ModelCheckpoint(save_freq=1, save_dir='output/ResNet50-NAM')
# Named train_callbacks so the imported `callbacks` module is not shadowed when the cell is re-run
train_callbacks = [LRSchedulerM(), checkpoint_callback, callbacks.VisualDL('vis_logs/resnet50_nam.log')]

# Train the model
model.fit(
    train_set,
    test_set,
    epochs=MAX_EPOCH,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    verbose=1,
    callbacks=train_callbacks,
)
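
The learning-rate schedule used above combines a 2000-step linear warmup with cosine annealing over `MAX_EPOCH` epochs. The sketch below reproduces the numbers in plain Python. It assumes the closed-form cosine annealing formula with a minimum learning rate of 0 and that the last partial batch of each epoch is kept; the helper names `warmup_lr` and `cosine_lr` are illustrative and not part of Paddle.

```python
import math

LR = 0.001
MAX_EPOCH = 300
BATCH_SIZE = 128
WARMUP_STEPS = 2000
TRAIN_IMAGES = 50_000

# Batches per epoch when the last partial batch is kept
steps_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
# Warmup is counted in batches, so it spans a little over five epochs
warmup_epochs = WARMUP_STEPS / steps_per_epoch

def warmup_lr(step):
    """Linear ramp from 0 to LR over WARMUP_STEPS batches."""
    return LR * step / WARMUP_STEPS

def cosine_lr(t, eta_min=0.0):
    """Cosine annealing from LR to eta_min; t counts scheduler steps after warmup."""
    return eta_min + 0.5 * (LR - eta_min) * (1 + math.cos(math.pi * t / MAX_EPOCH))

print(steps_per_epoch, round(warmup_epochs, 2))
print(warmup_lr(1000), cosine_lr(150))
```

Because `LRSchedulerM` steps the scheduler per batch only while warming up and per epoch afterwards, `t` in `cosine_lr` advances once per epoch in this setup.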

### 6. Evaluate the trained model

In [None]:
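models = paddle.Model(resnet50())
# This loads the `1.pdparams` checkpoint from early in training;
# point it at a later checkpoint to evaluate a fully trained model
models.load('output/ResNet50-NAM/1.pdparams')
# Register the loss and metrics so that evaluate() reports them
models.prepare(
    loss=paddle.nn.CrossEntropyLoss(),
    metrics=paddle.metric.Accuracy(topk=(1, 5)))

result = models.evaluate(test_set, verbose=1)

print(result)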
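
The training cell tracks top-1 and top-5 accuracy via `paddle.metric.Accuracy(topk=(1,5))`. As a reminder of what a top-k metric measures, here is a small standalone sketch; the function `topk_accuracy` and the three-class scores are made up for illustration and do not come from the model.

```python
def topk_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        hits += label in ranked[:k]
    return hits / len(labels)

scores = [
    [0.1, 0.7, 0.2],  # best class 1, label 1: hit at k=1
    [0.5, 0.3, 0.2],  # best class 0, label 1: hit only at k=2
    [0.2, 0.3, 0.5],  # best class 2, label 0: miss at k=1 and k=2
]
labels = [1, 1, 0]

top1 = topk_accuracy(scores, labels, 1)
top2 = topk_accuracy(scores, labels, 2)
print(top1, top2)
```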

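### Summary
This reproduction is a small and simple project, a topic from the Baidu AI Talent program. It only requires reproducing the NAM structure and making minor modifications to ResNet50, and it is also my first paper reproduction. The work is mainly based on another author's PyTorch reproduction on GitHub, rewritten for the PaddlePaddle framework. Special thanks to instructor Li at Baidu for the guidance!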