# MetaFormer

Paper: [MetaFormer is Actually What You Need for Vision](https://arxiv.org/abs/2111.11418)

## Introduction
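Transformers have shown great potential in computer vision, and a common belief is that their attention-based **token mixer** module is what makes them competitive. However, models in which attention is replaced by a **spatial MLP** still perform very well. So is it the **general Transformer architecture**, rather than attention, that makes these models effective? To test this, the authors replace the **attention** in a Transformer with a plain **pooling layer** and build the **PoolFormer** model, which still reaches 82.1% top-1 accuracy on ImageNet-1k. This supports the claim that the Transformer architecture itself, rather than attention, is the key to its performance.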

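The paper proposes **MetaFormer**, a **general architecture** abstracted from Transformers in which the token mixer is left unspecified. It uses PoolFormer as a baseline to validate the idea on classification, detection and segmentation. This reproduction only runs experiments on the **classification task**. The models are compared in the figure below: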

![](https://ai-studio-static-online.cdn.bcebos.com/e084fdecc43c4b989783b938f85fe76163d7a9aa69304ef2b6a197a4d6adb43c)

PoolFormer's structure is very simple. It is obtained by replacing the Attention module of a Transformer block with Pooling:

![](https://ai-studio-static-online.cdn.bcebos.com/e8eb1384155f42fab8642779150f8fc67d7b2413673e4e2481fe358690a17d69)
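
The block in the figure can be written out in the paper's notation, as we read it. LayerScale and stochastic depth, which the code below adds, are omitted here. An input $X$ is mapped to:

$$
Y = \mathrm{TokenMixer}(\mathrm{Norm}(X)) + X, \qquad
Z = \sigma\big(\mathrm{Norm}(Y)\,W_1\big)\,W_2 + Y
$$

Here $\sigma$ is the activation (GELU in the code), and $W_1$, $W_2$ are the MLP weights (implemented as 1×1 convolutions in this reproduction). PoolFormer fixes the token mixer to pooling with pool size $K$ ($K=3$ by default):

$$
T'_{:,i,j} = \frac{1}{K \times K} \sum_{p,q=1}^{K} T_{:,\, i+p-\frac{K+1}{2},\, j+q-\frac{K+1}{2}} \;-\; T_{:,i,j}
$$

At the image borders, the Pooling code in section 3 averages only over positions inside the image (`exclusive=True`), so the effective denominator there is smaller than $K \times K$.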

PoolFormer models can be built in several different configurations:

![](https://ai-studio-static-online.cdn.bcebos.com/4a3ef9fc51a844f8a0485da4808ac1b5cf199f56caf4420b87953a6052ad523e)

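PoolFormer has already been reproduced on AI Studio. This reproduction therefore targets **MetaFormer** instead and implements its variants with different token mixers per stage, all at the S12 size.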

## CIFAR-10 dataset

Link: http://www.cs.toronto.edu/~kriz/cifar.html

![](https://ai-studio-static-online.cdn.bcebos.com/15a8790a113d41418d6fc8563aeb4acd10da73b4b8c6488599fa9e7a01cc0833)

**CIFAR-10** is a small dataset of color images of everyday objects. It was collected by Alex Krizhevsky, Vinod Nair and Geoffrey Hinton. It contains RGB images in 10 classes: **airplane**, **automobile**, **bird**, **cat**, **deer**, **dog**, **frog**, **horse**, **ship** and **truck**.

Each image is $32\times 32$ pixels and each class has **6000** images. In total there are **50000** training images and **10000** test images.


## Code reproduction

### 1. Import dependencies

```python
import random
from functools import partial, reduce
from typing import Sequence

import numpy as np
import paddle
import paddle.nn as nn
from paddle.nn import functional as F
from paddle import callbacks
from paddle.vision.datasets import Cifar10
from paddle.optimizer.lr import CosineAnnealingDecay, LinearWarmup
from paddle.vision.transforms import Compose, Normalize, Resize

# shared initializers: truncated normal (std 0.02), constant 0 and constant 1
trunc_normal_ = nn.initializer.TruncatedNormal(std=0.02)
zeros_ = nn.initializer.Constant(value=0.0)
ones_ = nn.initializer.Constant(value=1.0)


def drop_path(x, drop_prob=0.0, training=False):
    """Stochastic depth: randomly drop the residual branch for whole samples."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # one random value per sample, broadcast over the remaining dimensions
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # 1 with probability keep_prob, else 0
    # rescale the kept samples so the expected output is unchanged
    return x / keep_prob * random_tensor


class DropPath(nn.Layer):
    """Layer wrapper around drop_path; only active in training mode."""
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
```
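
To make the per-sample behaviour of `drop_path` concrete, here is a NumPy-only sketch of the same arithmetic. It is an illustration, not the Paddle code path, and the name `drop_path_np` is hypothetical. Each sample in the batch is either zeroed or scaled by `1 / keep_prob`, so the expected value of the output equals the input.

```python
import numpy as np


def drop_path_np(x, drop_prob, rng):
    """Stochastic depth on a batch: zero whole samples, rescale the survivors.

    Hypothetical NumPy helper mirroring drop_path above (training mode only).
    """
    if drop_prob == 0.0:
        return x
    keep_prob = 1.0 - drop_prob
    # one uniform draw per sample, broadcast over all remaining axes
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # floor(keep_prob + U) is 1 with probability keep_prob and 0 otherwise
    mask = np.floor(keep_prob + rng.random(shape))
    # divide by keep_prob so the expected output equals x
    return x / keep_prob * mask


rng = np.random.default_rng(0)
x = np.ones((8, 4, 2, 2), dtype=np.float32)
y = drop_path_np(x, drop_prob=0.25, rng=rng)
per_sample = y.reshape(8, -1)   # each row is either all 0 or all 1 / 0.75
```

Drawing the mask with shape `(B, 1, 1, 1)` rather than the full tensor shape is what makes this "drop the whole branch for a sample" instead of ordinary element-wise dropout.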

### 2. Define the two embedding modules (position embedding and patch embedding)

```python
class AddPositionEmb(nn.Layer):
    """Module to add a learnable position embedding to input features."""
    def __init__(self, dim=384, spatial_shape=[14, 14]):
        super().__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = [spatial_shape]
        assert isinstance(spatial_shape, Sequence), \
            f'"spatial_shape" must be a sequence or int, ' \
            f'got {type(spatial_shape)} instead.'
        if len(spatial_shape) == 1:
            # token layout [B, N, C]
            embed_shape = list(spatial_shape) + [dim]
        else:
            # feature-map layout [B, C, H, W]
            embed_shape = [dim] + list(spatial_shape)
        # self.create_parameter registers the embedding as a trainable parameter
        # (truncated normal init, std 0.02)
        self.pos_embed = self.create_parameter(
            shape=[1] + embed_shape,
            dtype='float32',
            default_initializer=trunc_normal_)

    def forward(self, x):
        return x + self.pos_embed


class PatchEmbed(nn.Layer):
    """
    Patch Embedding that is implemented by a layer of conv.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H/stride, W/stride]
    """
    def __init__(self, patch_size=16, stride=16, padding=0,
                 in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        stride = (stride, stride)
        padding = (padding, padding)
        self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x
```

### 3. Define Pooling
This module is the **core** of the reproduction, yet its implementation is **very simple**:

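```python
class Pooling(nn.Layer):
    """
    Implementation of pooling for PoolFormer
    --kernel_size: pooling size
    """
    def __init__(self, kernel_size=3, **kwargs):
        # extra keyword arguments (e.g. dim) are accepted and ignored
        super().__init__()
        # exclusive=True: padded zeros are not counted when averaging at the borders
        self.pool = nn.AvgPool2D(
            kernel_size, stride=1, padding=kernel_size // 2, exclusive=True)

    def forward(self, x):
        # subtract x because the block already adds it back via the residual connection
        return self.pool(x) - x
```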
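
To check what `Pooling` computes, here is a NumPy sketch of "average over the in-bounds neighbourhood, minus the centre". It assumes a stride-1, `pool_size // 2`-padded average pool that excludes padding from the denominator, as `exclusive=True` requests. The name `pooling_mixer_np` is hypothetical.

```python
import numpy as np


def pooling_mixer_np(x, pool_size=3):
    """AvgPool(x) - x on a [B, C, H, W] array, averaging only over in-bounds pixels."""
    B, C, H, W = x.shape
    r = pool_size // 2
    out = np.empty_like(x)
    for i in range(H):
        for j in range(W):
            # slicing clips at the borders, so padded positions are simply not included
            window = x[:, :, max(i - r, 0):i + r + 1, max(j - r, 0):j + r + 1]
            out[:, :, i, j] = window.mean(axis=(2, 3))
    return out - x


x = np.full((2, 3, 5, 5), 7.0)
print(np.abs(pooling_mixer_np(x)).max())   # 0.0
```

A constant feature map produces all zeros, even at the borders, because padding is excluded from the average. The mixer therefore only contributes the local deviation from the neighbourhood mean, and the residual connection carries `x` itself.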

### 4. Define the Attention mechanism

```python
class Attention(nn.Layer):
    """Attention module that can take tensor with [B, N, C] or [B, C, H, W] as input.
    Modified from:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """
    def __init__(self, dim, head_dim=32, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % head_dim == 0, 'dim should be divisible by head_dim'
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim ** -0.5

        # Paddle's Linear uses bias_attr; False disables the bias
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=None if qkv_bias else False)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        shape = x.shape
        if len(shape) == 4:
            B, C, H, W = shape
            N = H * W
            x = paddle.flatten(x, start_axis=2).transpose([0, 2, 1])  # (B, N, C)
        else:
            B, N, C = shape
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, self.head_dim])
        qkv = qkv.transpose([2, 0, 3, 1, 4])  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # scaling q before the matmul keeps q @ k^T numerically more stable
        attn = (q * self.scale) @ k.transpose([0, 1, 3, 2])
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        if len(shape) == 4:
            x = x.transpose([0, 2, 1]).reshape([B, C, H, W])
        return x
```
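
The shape bookkeeping in `Attention.forward` is the trickiest part of the port. Below is a NumPy sketch of the same flow on a `[B, C, H, W]` map. It assumes bias-free projections and no dropout. The weights are random placeholders, and the helper names are hypothetical. Paddle's `Linear` stores its weight as `[in_features, out_features]`, so a projection is written `t @ W`.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mhsa_np(x, w_qkv, w_proj, head_dim=32):
    """Multi-head self-attention on a [B, C, H, W] feature map (no biases, no dropout)."""
    B, C, H, W = x.shape
    N, heads = H * W, C // head_dim
    t = x.reshape(B, C, N).transpose(0, 2, 1)                          # (B, N, C)
    qkv = (t @ w_qkv).reshape(B, N, 3, heads, head_dim)                # (B, N, 3, h, d)
    q, k, v = qkv.transpose(2, 0, 3, 1, 4)                             # each (B, h, N, d)
    attn = softmax((q * head_dim ** -0.5) @ k.transpose(0, 1, 3, 2))   # (B, h, N, N)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C) @ w_proj   # (B, N, C)
    return out.transpose(0, 2, 1).reshape(B, C, H, W), attn


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 64, 7, 7))
y, attn = mhsa_np(x, rng.standard_normal((64, 192)) * 0.02,
                  rng.standard_normal((64, 64)) * 0.02)
print(y.shape, attn.shape)   # (2, 64, 7, 7) (2, 2, 49, 49)
```

The `(N, N)` attention matrix per head is why Attention is only used in the later, low-resolution stages of the `pppa`/`ppaa` variants.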


### 5. Define LayerNorm and GroupNorm

```python
class LayerNormChannel(nn.Layer):
    """
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, eps=1e-05):
        super().__init__()
        self.weight = paddle.create_parameter(
            shape=[num_channels],
            dtype='float32',
            default_initializer=ones_)
        self.bias = paddle.create_parameter(
            shape=[num_channels],
            dtype='float32',
            default_initializer=zeros_)
        self.eps = eps

    def forward(self, x):
        # mean and variance over the channel axis only, separately for every pixel
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / paddle.sqrt(s + self.eps)
        # per-channel affine transform, broadcast over H and W
        x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
            + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x


class GroupNorm(nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
```
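
The two normalizations differ only in which axes the statistics are taken over. `LayerNormChannel` normalizes each pixel across channels, while `GroupNorm` with one group normalizes each sample across channels and space together. A NumPy sketch without the affine parameters (the function names are hypothetical):

```python
import numpy as np


def layer_norm_channel_np(x, eps=1e-5):
    """Normalize each pixel over the channel axis only (LayerNormChannel without affine)."""
    u = x.mean(axis=1, keepdims=True)
    s = ((x - u) ** 2).mean(axis=1, keepdims=True)
    return (x - u) / np.sqrt(s + eps)


def group_norm_1_np(x, eps=1e-5):
    """GroupNorm with a single group: normalize each sample over (C, H, W) jointly."""
    u = x.mean(axis=(1, 2, 3), keepdims=True)
    s = x.var(axis=(1, 2, 3), keepdims=True)
    return (x - u) / np.sqrt(s + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4)) * 3 + 1
a = layer_norm_channel_np(x)
b = group_norm_1_np(x)
```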

### 6. Define SpatialFc

```python
class SpatialFc(nn.Layer):
    """SpatialFc module that takes features with shape (B, C, *) as input."""
    def __init__(self, spatial_shape=[14, 14], **kwargs):
        super().__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = [spatial_shape]
        assert isinstance(spatial_shape, Sequence), \
            f'"spatial_shape" must be a sequence or int, ' \
            f'got {type(spatial_shape)} instead.'
        # one N x N weight over the N = H * W spatial positions, shared by all channels
        N = reduce(lambda x, y: x * y, spatial_shape)
        self.fc = nn.Linear(N, N, bias_attr=False)

    def forward(self, x):
        # input shape like [B, C, H, W]
        shape = x.shape
        x = paddle.flatten(x, start_axis=2)  # [B, C, H*W]
        x = self.fc(x)                       # [B, C, H*W]
        x = paddle.reshape(x, shape)         # [B, C, H, W]
        return x
```
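
`SpatialFc` mixes tokens with a single fully connected layer over the flattened spatial axis. The same `N × N` matrix is applied to every channel. A NumPy sketch, assuming Paddle's `[in_features, out_features]` weight layout so that the layer computes `t @ weight`. The helper name is hypothetical.

```python
import numpy as np


def spatial_fc_np(x, weight):
    """x: [B, C, H, W]; weight: [N, N] with N = H * W, shared by all channels."""
    B, C, H, W = x.shape
    t = x.reshape(B, C, H * W)               # [B, C, N]
    return (t @ weight).reshape(B, C, H, W)  # mix positions, channel by channel


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 7, 7))
w = rng.standard_normal((49, 49))
y = spatial_fc_np(x, w)
```

The parameter count is $N^2$ regardless of the channel count: 2,401 for a 7×7 map and 38,416 for 14×14. The weight is also tied to one input resolution, which is why `spatial_shape` must match the stage's feature map.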

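### 7. Define the MLP
**Note: this differs from the ViT version. It uses 1×1 convolutions and works directly on [B, C, H, W] feature maps instead of linear layers on [B, N, C] tokens.**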

```python
class Mlp(nn.Layer):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """
    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2D(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2D(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal weights and zero bias for the 1x1 convolutions
        if isinstance(m, nn.Conv2D):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward(self, x):
        x = self.fc1(x)     # (B, C, H, W) --> (B, hidden, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)     # (B, hidden, H, W) --> (B, C_out, H, W)
        x = self.drop(x)
        return x
```
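
A 1×1 convolution is a Linear layer applied independently at every pixel. That is why this `Mlp` can consume `[B, C, H, W]` directly without flattening to tokens. A NumPy sketch using the conv weight layout `[C_out, C_in, 1, 1]` (the helper name is hypothetical):

```python
import numpy as np


def conv1x1_np(x, weight, bias):
    """1x1 convolution: weight [C_out, C_in, 1, 1], x [B, C_in, H, W]."""
    return np.einsum('oc,bchw->bohw', weight[:, :, 0, 0], x) + bias[None, :, None, None]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 3, 3))
w = rng.standard_normal((16, 4, 1, 1))
b = rng.standard_normal(16)
y = conv1x1_np(x, w, b)
# at any single pixel, the result is an ordinary matrix-vector product plus bias
pixel = w[:, :, 0, 0] @ x[0, :, 1, 2] + b
```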

### 8. Assembling the pieces: define MetaFormerBlock

```python
import inspect


class MetaFormerBlock(nn.Layer):
    """
    Implementation of one MetaFormer block.
    --dim: embedding dim
    --token_mixer: token mixer class/partial (built per block) or an already built layer
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop path: Stochastic Depth,
        refer to https://arxiv.org/abs/1603.09382
    --use_layer_scale, --layer_scale_init_value: LayerScale,
        refer to https://arxiv.org/abs/2103.17239
    """
    def __init__(self, dim,
                 token_mixer=nn.Identity,
                 mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=LayerNormChannel,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()

        self.norm1 = norm_layer(dim)
        if isinstance(token_mixer, nn.Layer):
            # already built, parameter-free layer (e.g. nn.Identity()): use as-is
            self.token_mixer = token_mixer
        elif 'dim' in inspect.signature(token_mixer).parameters:
            # mixers such as Attention need the channel dimension
            self.token_mixer = token_mixer(dim=dim)
        else:
            # Pooling, partial(SpatialFc, ...), nn.Identity
            self.token_mixer = token_mixer()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        # one DropPath module for both residual branches (each call draws its own mask)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            # LayerScale: learnable per-channel weights initialised to layer_scale_init_value
            self.layer_scale_1 = self.create_parameter(
                shape=[dim], dtype='float32',
                default_initializer=nn.initializer.Constant(value=layer_scale_init_value))
            self.layer_scale_2 = self.create_parameter(
                shape=[dim], dtype='float32',
                default_initializer=nn.initializer.Constant(value=layer_scale_init_value))

    def forward(self, x):
        if self.use_layer_scale:
            # reshape [C] -> [C, 1, 1] so the scale broadcasts over H and W
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```
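
LayerScale starts at `layer_scale_init_value = 1e-5`, so at initialization every block is very close to the identity map. This tends to make deep stacks easier to train. The NumPy sketch below uses stand-in mixer, MLP and norm functions (all hypothetical) and compares a tiny scale with a scale of 1:

```python
import numpy as np


def metaformer_block_np(x, token_mixer, mlp, norm, ls1, ls2):
    """One MetaFormer block with LayerScale on [B, C, H, W]; ls1 and ls2 have shape [C]."""
    x = x + ls1[None, :, None, None] * token_mixer(norm(x))
    x = x + ls2[None, :, None, None] * mlp(norm(x))
    return x


def channel_norm(t):
    # stand-in for LayerNormChannel without the affine part
    return (t - t.mean(axis=1, keepdims=True)) / (t.std(axis=1, keepdims=True) + 1e-5)


def toy_mixer(t):
    return 5.0 * t          # hypothetical stand-in token mixer


def toy_mlp(t):
    return np.tanh(t)       # hypothetical stand-in MLP


rng = np.random.default_rng(0)
C = 8
x = rng.standard_normal((2, C, 4, 4))
y_small = metaformer_block_np(x, toy_mixer, toy_mlp, channel_norm,
                              np.full(C, 1e-5), np.full(C, 1e-5))
y_one = metaformer_block_np(x, toy_mixer, toy_mlp, channel_norm, np.ones(C), np.ones(C))
print(np.abs(y_small - x).max() < 1e-3, np.abs(y_one - x).max() > 1.0)   # True True
```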


### 9. Stack blocks into a stage (basic_blocks)

```python
def basic_blocks(dim, index, layers, token_mixer=nn.Identity,
                 mlp_ratio=4.,
                 act_layer=nn.GELU, norm_layer=LayerNormChannel,
                 drop_rate=.0, drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
    """
    generate MetaFormer blocks for a stage
    return: MetaFormer blocks
    """
    blocks = []
    for block_idx in range(layers[index]):
        # stochastic depth rate grows linearly with the block's index in the whole network
        block_dpr = drop_path_rate * (
            block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(MetaFormerBlock(
            dim, token_mixer=token_mixer, mlp_ratio=mlp_ratio,
            act_layer=act_layer, norm_layer=norm_layer,
            drop=drop_rate, drop_path=block_dpr,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            ))
    blocks = nn.Sequential(*blocks)

    return blocks
```
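
The drop-path formula in `basic_blocks` indexes blocks globally across all stages. The first block gets rate 0 and the last block gets `drop_path_rate`. A plain-Python sketch of the resulting schedule (the function name is hypothetical):

```python
def drop_path_rates(layers, drop_path_rate):
    """Per-block stochastic-depth rates, using the same formula as basic_blocks."""
    total = sum(layers)
    rates = []
    for index, n in enumerate(layers):
        rates.append([drop_path_rate * (b + sum(layers[:index])) / (total - 1)
                      for b in range(n)])
    return rates


rates = drop_path_rates([2, 2, 6, 2], 0.1)   # the S12 layout used by the factories
print(rates[0][0], rates[-1][-1])
```

Like the original formula, this divides by `sum(layers) - 1`, so it assumes the network has more than one block.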

### 10. Enter MetaFormer: the full model

```python
class MetaFormer(nn.Layer):
    """
    MetaFormer, the main class of our model
    --layers: [x,x,x,x], number of blocks for the 4 stages
    --embed_dims, --mlp_ratios: the embedding dims and mlp ratios for the 4 stages
    --token_mixers: token mixers of different stages
    --norm_layer, --act_layer: define the types of normalization and activation
    --num_classes: number of classes for the image classification
    --in_patch_size, --in_stride, --in_pad: specify the patch embedding
        for the input image
    --downsamples: whether to downsample after each stage
    --down_patch_size --down_stride --down_pad:
        specify the downsample (patch embed.)
    --add_pos_embs: position embedding modules of different stages
    """
    def __init__(self, layers, embed_dims=None,
                 token_mixers=None, mlp_ratios=None,
                 norm_layer=LayerNormChannel, act_layer=nn.GELU,
                 num_classes=1000,
                 in_patch_size=7, in_stride=4, in_pad=2,
                 downsamples=None, down_patch_size=3, down_stride=2, down_pad=1,
                 add_pos_embs=None,
                 drop_rate=0., drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5,
                 **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        n_stages = len(layers)

        # stem: overlapping patch embedding of the input image
        self.patch_embed = PatchEmbed(
            patch_size=in_patch_size, stride=in_stride, padding=in_pad,
            in_chans=3, embed_dim=embed_dims[0])
        if add_pos_embs is None:
            add_pos_embs = [None] * n_stages
        if token_mixers is None:
            token_mixers = [nn.Identity] * n_stages
        if mlp_ratios is None:
            mlp_ratios = [4] * n_stages
        if downsamples is None:
            downsamples = [False] * n_stages

        # main network: the stages, with a downsampling PatchEmbed between them
        network = []
        for i in range(n_stages):
            if add_pos_embs[i] is not None:
                network.append(add_pos_embs[i](embed_dims[i]))
            stage = basic_blocks(embed_dims[i], i, layers,
                                 token_mixer=token_mixers[i], mlp_ratio=mlp_ratios[i],
                                 act_layer=act_layer, norm_layer=norm_layer,
                                 drop_rate=drop_rate,
                                 drop_path_rate=drop_path_rate,
                                 use_layer_scale=use_layer_scale,
                                 layer_scale_init_value=layer_scale_init_value)
            network.append(stage)
            if i >= n_stages - 1:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                # downsampling between two stages
                network.append(
                    PatchEmbed(
                        patch_size=down_patch_size, stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]))

        self.network = nn.LayerList(network)
        self.norm = norm_layer(embed_dims[-1])
        self.head = nn.Linear(
            embed_dims[-1], num_classes) if num_classes > 0 \
            else nn.Identity()

        self.apply(self.cls_init_weights)

    def cls_init_weights(self, m):
        # init for classification: truncated-normal weights, zero bias for Linear layers
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes):
        self.num_classes = num_classes
        self.head = nn.Linear(
            self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_embeddings(self, x):
        return self.patch_embed(x)

    def forward_tokens(self, x):
        for block in self.network:
            x = block(x)
        return x

    def forward(self, x):
        # input embedding
        x = self.forward_embeddings(x)
        # through backbone
        x = self.forward_tokens(x)
        x = self.norm(x)
        # for image classification: global average pooling over H and W, then the head
        cls_out = self.head(x.mean([-2, -1]))
        return cls_out
```

### 11. MetaFormer variants (S12)

```python
# p = Pooling, a = Attention, f = SpatialFc, one letter per stage; id = Identity mixer
def metaformer_id_s12(pretrained=False, **kwargs):
    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 320, 512]
    token_mixers = [nn.Identity] * len(layers)
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    model = MetaFormer(
        layers, embed_dims=embed_dims,
        token_mixers=token_mixers,
        mlp_ratios=mlp_ratios,
        norm_layer=GroupNorm,
        downsamples=downsamples,
        **kwargs)
    return model


def metaformer_pppa_s12_224(pretrained=False, **kwargs):
    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 320, 512]
    add_pos_embs = [None, None, None,
        partial(AddPositionEmb, spatial_shape=[7, 7])]
    token_mixers = [Pooling, Pooling, Pooling, Attention]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    model = MetaFormer(
        layers, embed_dims=embed_dims,
        token_mixers=token_mixers,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        add_pos_embs=add_pos_embs,
        **kwargs)
    return model


def metaformer_ppaa_s12_224(pretrained=False, **kwargs):
    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 320, 512]
    add_pos_embs = [None, None,
        partial(AddPositionEmb, spatial_shape=[14, 14]), None]
    token_mixers = [Pooling, Pooling, Attention, Attention]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    model = MetaFormer(
        layers, embed_dims=embed_dims,
        token_mixers=token_mixers,
        mlp_ratios=mlp_ratios,
        downsamples=downsamples,
        add_pos_embs=add_pos_embs,
        **kwargs)
    return model


def metaformer_pppf_s12_224(pretrained=False, **kwargs):
    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 320, 512]
    token_mixers = [Pooling, Pooling, Pooling,
        partial(SpatialFc, spatial_shape=[7, 7]),
        ]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    model = MetaFormer(
        layers, embed_dims=embed_dims,
        token_mixers=token_mixers,
        mlp_ratios=mlp_ratios,
        norm_layer=GroupNorm,
        downsamples=downsamples,
        **kwargs)
    return model


def metaformer_ppff_s12_224(pretrained=False, **kwargs):
    layers = [2, 2, 6, 2]
    embed_dims = [64, 128, 320, 512]
    token_mixers = [Pooling, Pooling,
        partial(SpatialFc, spatial_shape=[14, 14]),
        partial(SpatialFc, spatial_shape=[7, 7]),
        ]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    model = MetaFormer(
        layers, embed_dims=embed_dims,
        token_mixers=token_mixers,
        mlp_ratios=mlp_ratios,
        norm_layer=GroupNorm,
        downsamples=downsamples,
        **kwargs)
    return model
```
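
The factories hard-code `spatial_shape=[14, 14]` and `[7, 7]` for SpatialFc and AddPositionEmb, and those shapes only fit the feature maps produced by a 224×224 input. The sketch below checks this. It assumes MetaFormer's default stem (`in_patch_size=7, in_stride=4, in_pad=2`) and downsampling (`down_patch_size=3, down_stride=2, down_pad=1`), with a downsampling PatchEmbed before every stage after the first. The helper names are hypothetical.

```python
def conv_out(size, kernel, stride, pad):
    """Output length of a convolution along one axis."""
    return (size + 2 * pad - kernel) // stride + 1


def stage_resolutions(img_size, n_stages=4, stem=(7, 4, 2), down=(3, 2, 1)):
    """Feature-map side length seen by each stage.

    stem = (in_patch_size, in_stride, in_pad);
    down = (down_patch_size, down_stride, down_pad).
    """
    sizes = [conv_out(img_size, *stem)]
    for _ in range(n_stages - 1):
        sizes.append(conv_out(sizes[-1], *down))
    return sizes


print(stage_resolutions(224))  # [56, 28, 14, 7]
print(stage_resolutions(32))   # [8, 4, 2, 1]
```

With 224×224 inputs, stage 3 sees a 14×14 map and stage 4 a 7×7 map, which matches the hard-coded shapes. Raw 32×32 CIFAR-10 images would give 2×2 and 1×1 maps instead, which is why the training code below resizes images to 224.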

```python
net = metaformer_pppf_s12_224()
paddle.summary(net, (1, 3, 224, 224))
```

```
------------------------------------------------------------------------------
   Layer (type)        Input Shape          Output Shape         Param #
    Conv2D-29       [[1, 3, 224, 224]]    [1, 64, 56, 56]         9,472
   Identity-43      [[1, 64, 56, 56]]     [1, 64, 56, 56]           0
   PatchEmbed-5     [[1, 3, 224, 224]]    [1, 64, 56, 56]           0
   GroupNorm-26     [[1, 64, 56, 56]]     [1, 64, 56, 56]          128
   AvgPool2D-11     [[1, 64, 56, 56]]     [1, 64, 56, 56]           0
    Pooling-11      [[1, 64, 56, 56]]     [1, 64, 56, 56]           0
   Identity-45      [[1, 64, 56, 56]]     [1, 64, 56, 56]           0
   GroupNorm-27     [[1, 64, 56, 56]]     [1, 64, 56, 56]          128
    Conv2D-30       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,640
     GELU-13        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0
    Dropout-13      [[1, 64, 56, 56]]     [1, 64, 56, 56]           0
   ...

{'total_params': 11919978, 'trainable_params': 11919978}
```
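
The Param # column in the summary output can be re-derived by hand. This is a quick check that the layers were built as intended. A small sketch (the helper names are hypothetical; the layer names refer to the summary output):

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Weights c_out * c_in * k * k, plus one bias per output channel."""
    return c_out * c_in * k * k + (c_out if bias else 0)


def group_norm_params(c):
    """Per-channel scale and shift."""
    return 2 * c


stem = conv2d_params(3, 64, 7)    # Conv2D-29: stem PatchEmbed, 7x7 kernel
fc1 = conv2d_params(64, 256, 1)   # Conv2D-30: Mlp.fc1, 64 -> 4 * 64 channels
gn = group_norm_params(64)        # GroupNorm-26 / GroupNorm-27
spatial_fc = 49 * 49              # stage-4 SpatialFc: 7x7 tokens, bias_attr=False
print(stem, fc1, gn, spatial_fc)  # 9472 16640 128 2401
```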

### 12. Define data processing and training callbacks

```python
class ToArray(object):
    """Convert an HWC image to a CHW float32 array scaled to [0, 1]."""
    def __call__(self, img):
        img = np.array(img)
        img = np.transpose(img, [2, 0, 1])
        img = img / 255.
        return img.astype('float32')


class RandomApply(object):
    """Apply `transform` with probability p."""
    def __init__(self, transform, p=0.5):
        super().__init__()
        self.p = p
        self.transform = transform

    def __call__(self, img):
        if self.p < random.random():
            return img
        img = self.transform(img)
        return img


class LRSchedulerM(callbacks.LRScheduler):
    """Step the LR scheduler every batch during warmup, then once per epoch."""
    def __init__(self, by_step=False, by_epoch=True, warm_up=True):
        super().__init__(by_step, by_epoch)
        assert by_step ^ warm_up
        self.warm_up = warm_up

    def on_epoch_end(self, epoch, logs=None):
        if self.by_epoch and not self.warm_up:
            if self.model._optimizer and hasattr(
                    self.model._optimizer, '_learning_rate') and isinstance(
                    self.model._optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
                self.model._optimizer._learning_rate.step()

    def on_train_batch_end(self, step, logs=None):
        if self.by_step or self.warm_up:
            if self.model._optimizer and hasattr(
                    self.model._optimizer, '_learning_rate') and isinstance(
                    self.model._optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
                self.model._optimizer._learning_rate.step()
            # leave warmup mode once LinearWarmup has taken warmup_steps steps
            if self.model._optimizer._learning_rate.last_epoch >= self.model._optimizer._learning_rate.warmup_steps:
                self.warm_up = False


# Patch VisualDL so the current learning rate is logged as an extra 'lr' metric.
# This relies on private members of paddle.callbacks.VisualDL (_is_write, _updates).
def _on_train_batch_end(self, step, logs=None):
    logs = logs or {}
    logs['lr'] = self.model._optimizer.get_lr()
    self.train_step += 1
    if self._is_write():
        self._updates(logs, 'train')


def _on_train_begin(self, logs=None):
    self.epochs = self.params['epochs']
    assert self.epochs
    self.train_metrics = self.params['metrics'] + ['lr']
    assert self.train_metrics
    self._is_fit = True
    self.train_step = 0


callbacks.VisualDL.on_train_batch_end = _on_train_batch_end
callbacks.VisualDL.on_train_begin = _on_train_begin
```
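
`ToArray` followed by `Normalize` turns a PIL image into a normalized CHW float array. Below is a NumPy-only sketch of the same arithmetic. The function names are hypothetical, and the sketch assumes that `Normalize` computes `(x - mean) / std` per channel on CHW input, as `paddle.vision.transforms.Normalize` does with its default `data_format='CHW'`.

```python
import numpy as np


def to_array(img_hwc):
    """Same steps as ToArray: HWC uint8 -> CHW float32 in [0, 1]."""
    img = np.transpose(np.asarray(img_hwc), [2, 0, 1]) / 255.0
    return img.astype('float32')


def normalize_chw(img, mean, std):
    """(x - mean) / std per channel on a CHW array."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (img - mean) / std


white = np.full((32, 32, 3), 255, dtype=np.uint8)   # a pure white 32x32 RGB image
arr = to_array(white)
out = normalize_chw(arr, [0.5, 0.5, 0.5], [0.25, 0.25, 0.25])
```

Because `ToArray` already scales pixels to [0, 1], the mean and std passed to `Normalize` must also be on the [0, 1] scale, as `CIFAR_MEAN` and `CIFAR_STD` are.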

### 13. Train the model

```python
model = paddle.Model(metaformer_pppf_s12_224(num_classes=10))
# load a checkpoint to resume training
# model.load('output/metaformer_pppf_s12_224/299.pdparams')
MAX_EPOCH = 300
LR = 0.001
WEIGHT_DECAY = 5e-4
MOMENTUM = 0.9
BATCH_SIZE = 12
# note: this mean matches the commonly quoted CIFAR-100 statistics;
# the CIFAR-10 mean is usually given as (0.4914, 0.4822, 0.4465)
CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR_STD = [0.1942, 0.1918, 0.1958]
DATA_FILE = None  # None: Cifar10 downloads the dataset to its default cache path

model.prepare(
    paddle.optimizer.Momentum(
        # linear warmup from 0 to LR over 2000 steps, then cosine decay over MAX_EPOCH
        learning_rate=LinearWarmup(CosineAnnealingDecay(LR, MAX_EPOCH), 2000, 0., LR),
        momentum=MOMENTUM,
        parameters=model.parameters(),
        weight_decay=WEIGHT_DECAY),
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1, 5)))

# data transforms: resize to 224 so the stage-4 map is 7x7, as SpatialFc([7, 7]) expects
transforms = Compose([
    Resize(size=224),
    ToArray(),
    Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_transforms = Compose([Resize(size=224), ToArray(), Normalize(CIFAR_MEAN, CIFAR_STD)])

# load the training and test sets
train_set = Cifar10(DATA_FILE, mode='train', transform=transforms)
test_set = Cifar10(DATA_FILE, mode='test', transform=val_transforms)

# checkpointing and VisualDL logging
# (named cbs so the imported `callbacks` module is not shadowed)
checkpoint_callback = paddle.callbacks.ModelCheckpoint(save_freq=1, save_dir='output/metaformer_pppf_s12_224')
cbs = [LRSchedulerM(), checkpoint_callback, callbacks.VisualDL('vis_logs/metaformer_pppf_s12_224.log')]

# train the model
model.fit(
    train_set,
    test_set,
    epochs=MAX_EPOCH,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    verbose=1,
    callbacks=cbs,
)
```
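
The learning-rate schedule combines per-batch warmup with per-epoch cosine decay. The sketch below approximates it. It assumes linear warmup from 0 to `LR` over the first 2000 batches, then closed-form cosine decay (minimum 0) indexed by the number of completed epochs, which is what `LRSchedulerM` is written to achieve. It is an approximation, not Paddle's exact scheduler code, and `lr_at` is a hypothetical name.

```python
import math


def lr_at(global_step, epochs_done, base_lr=0.001, warmup_steps=2000, t_max=300):
    """Approximate learning rate under the schedule configured above."""
    if global_step < warmup_steps:
        # linear warmup, stepped once per batch
        return base_lr * global_step / warmup_steps
    # cosine decay, stepped once per epoch
    return 0.5 * base_lr * (1 + math.cos(math.pi * epochs_done / t_max))


steps_per_epoch = math.ceil(50000 / 12)   # CIFAR-10 train images / BATCH_SIZE
print(steps_per_epoch)                    # 4167
```

With `BATCH_SIZE = 12`, warmup therefore ends less than halfway through the first epoch (2000 of 4167 batches). Under these assumptions the rate then decays to about 0 by epoch 300.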

## Summary
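I have recently reproduced several papers in the Transformer area, and I am very grateful to Mr. Li from Baidu for his guidance. The reproduction turned out to be easier than I expected. Implementing a relatively simple architecture like this one helped me understand in depth how the network runs and deepened my understanding of Transformers.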
