搭建 ViT 模型的相关代码

In [None]:
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin
import platform

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs

from .modeling_resnet import ResNetV2


# Module-level logger (used when load_from has to resize position embeddings).
logger = logging.getLogger(__name__)

In [None]:
# Per-encoder-block parameter scope names, presumably matching the layout of
# the pretrained (Google/flax) ViT checkpoint — confirm against the .npz keys.
ATTENTION_Q, ATTENTION_K, ATTENTION_V, ATTENTION_OUT = (
    "MultiHeadDotProductAttention_1/" + leaf for leaf in ("query", "key", "value", "out")
)
FC_0, FC_1 = ("MlpBlock_3/Dense_%d" % idx for idx in (0, 1))
ATTENTION_NORM, MLP_NORM = "LayerNorm_0", "LayerNorm_2"

In [None]:
# Convert NumPy checkpoint arrays to torch tensors.
def np2th(weights, conv=False):
    """Wrap a NumPy array as a torch tensor.

    Args:
        weights: NumPy array to convert.
        conv: if True, treat ``weights`` as a conv kernel laid out
            (H, W, in, out) and permute it to PyTorch's (out, in, H, W).

    Returns:
        A tensor sharing memory with the (possibly transposed) array.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)

# Swish activation: x * sigmoid(x).
def swish(x):
    """Return ``x * sigmoid(x)`` elementwise."""
    gate = torch.sigmoid(x)
    return x.mul(gate)

# Activation-function registry: maps a config name to its callable.
ACT2FN = dict(
    gelu=torch.nn.functional.gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)

通过模型原理介绍可知，ViT 模型以 Transformer 为基础，因此需要先搭建前馈网络（FFN/MLP）、多头注意力机制等组件，再将它们与图像预处理（分块）、嵌入编码层等拼接起来，得到一个完整的 ViT 模型。

对于图像编码，本例使用 ViT-B_16：图像分块编码（patch embedding）由一个卷积层实现，卷积核大小为 16×16，步长同样为 16，因此每个互不重叠的 16×16 图像块被映射为一个 token。

In [None]:
# Embeddings: patch embedding + learnable class token + learnable position embedding.
class Embeddings(nn.Module):
    """Turn an image batch into a token sequence for the Transformer encoder.

    The image (or, for hybrid models, the ResNet feature map) is cut into
    non-overlapping patches by a ``Conv2d`` whose kernel equals its stride;
    each patch becomes one ``hidden_size`` token. A learnable class token is
    prepended and a learnable position embedding is added.

    Args:
        config: model config; reads ``patches`` (``"size"`` or ``"grid"``),
            ``hidden_size``, ``transformer["dropout_rate"]`` and, for hybrid
            models, ``resnet.num_layers`` / ``resnet.width_factor``.
        img_size: input image size, an int or an (H, W) pair.
        in_channels: channels of the input image (replaced by the ResNet
            output width for hybrid models).
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        # Patch size and number of patches.
        if config.patches.get("grid") is not None:
            # Hybrid model: patches are cut from the ResNet feature map. The
            # "// 16" here and "width * 16" below assume the backbone
            # downsamples the image by 16 — NOTE(review): confirm for ResNetV2.
            grid_size = config.patches["grid"]
            patch_size = (
                img_size[0] // 16 // grid_size[0],
                img_size[1] // 16 // grid_size[1],
            )
            # One feature-map patch covers patch_size * 16 image pixels, so
            # count patches in image space. Counting (img // 16) ** 2 is only
            # right when patch_size == 1 and otherwise makes the position
            # embedding larger than the conv output (e.g. 448px, 14x14 grid).
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (
                img_size[1] // patch_size_real[1]
            )
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        # Hybrid model: a ResNetV2 backbone produces the feature map to patchify.
        if self.hybrid:
            self.hybrid_model = ResNetV2(
                block_units=config.resnet.num_layers,
                width_factor=config.resnet.width_factor,
            )
            in_channels = self.hybrid_model.width * 16
        # Patch embedding: kernel == stride, so each patch maps to one vector.
        self.patch_embeddings = Conv2d(
            in_channels=in_channels,
            out_channels=config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # Learnable position embedding: one slot per patch plus one for the class token.
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, n_patches + 1, config.hidden_size)
        )
        # Learnable class token, placed at position 0 of the sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Map images ``(B, C, H, W)`` to tokens ``(B, n_patches + 1, hidden_size)``."""
        B = x.shape[0]
        # Broadcast the class token over the batch: (B, 1, hidden_size).
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)  # (B, hidden_size, H', W')
        x = x.flatten(2)  # (B, hidden_size, H' * W')
        x = x.transpose(-1, -2)  # (B, H' * W', hidden_size)
        # Prepend the class token, whose final output feeds the classifier.
        x = torch.cat((cls_tokens, x), dim=1)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

In [None]:
# Multi-head self-attention (e.g. 12 heads for ViT-B_16).
# q, k, v of shape (B, N, hidden) are split into heads -> (B, heads, N, head_dim);
# the pairwise scaled q·k scores are (B, heads, N, N); softmax turns them into
# attention probabilities, which weight v and are merged back to (B, N, hidden).
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: model config; reads ``hidden_size``,
            ``transformer["num_heads"]`` and
            ``transformer["attention_dropout_rate"]``.
        vis: if True, ``forward`` also returns the attention probabilities
            (before dropout); otherwise it returns ``None`` in their place.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis  # whether to expose attention maps for visualisation

        self.num_attention_heads = config.transformer["num_heads"]
        if config.hidden_size % self.num_attention_heads != 0:
            # Fail early instead of silently truncating the head size and
            # crashing later with an obscure shape mismatch.
            raise ValueError(
                "hidden_size (%d) must be divisible by num_heads (%d)"
                % (config.hidden_size, self.num_attention_heads)
            )
        self.attention_head_size = config.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Projections W_q, W_k, W_v.
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        # Output projection W_o: concatenated heads -> hidden_size.
        self.out = Linear(self.all_head_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        # NOTE(review): the output projection also uses attention_dropout_rate,
        # and Block adds no other dropout after attention, so dropout_rate is
        # never applied on this branch — confirm this matches the reference.
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Reshape ``(B, N, all_head_size)`` to ``(B, heads, N, head_size)``."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)  # (B, N, heads, head_size)
        return x.permute(0, 2, 1, 3)  # (B, heads, N, head_size)

    def forward(self, hidden_states):
        """Return ``(attention_output, weights)``; output has the input's shape."""
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Pairwise similarities, (B, heads, N, N), scaled by sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # Merge heads: (B, N, heads, head_size) -> (B, N, all_head_size).
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


In [None]:
# Position-wise feed-forward sublayer of a Transformer block:
# Linear -> GELU -> Dropout -> Linear -> Dropout.
class Mlp(nn.Module):
    """Two-layer feed-forward network applied independently to every token.

    Args:
        config: model config; reads ``hidden_size``,
            ``transformer["mlp_dim"]`` and ``transformer["dropout_rate"]``.
    """

    def __init__(self, config):
        super(Mlp, self).__init__()
        width, inner = config.hidden_size, config.transformer["mlp_dim"]
        self.fc1 = Linear(width, inner)
        self.fc2 = Linear(inner, width)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights first, then near-zero normal biases
        # (this order fixes the random-number consumption order).
        pair = (self.fc1, self.fc2)
        for lin in pair:
            nn.init.xavier_uniform_(lin.weight)
        for lin in pair:
            nn.init.normal_(lin.bias, std=1e-6)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        expanded = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(expanded))

In [None]:
# Transformer encoder block (pre-norm): LayerNorm -> attention -> residual,
# then LayerNorm -> MLP -> residual.
class Block(nn.Module):
    """One pre-norm Transformer encoder block.

    Args:
        config: model config; reads ``hidden_size`` plus the fields used by
            ``Mlp`` and ``Attention``.
        vis: forwarded to ``Attention``; when True, ``forward`` also returns
            the attention probabilities, otherwise ``None``.
    """

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Return ``(output, attention_weights)``; output has ``x``'s shape."""
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h  # residual around attention

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h  # residual around the MLP
        return x, weights

    def load_from(self, weights, n_block):
        """Copy this block's parameters from a flat pretrained checkpoint.

        ``VisionTransformer.load_from`` calls this for every Block with the
        block's child name; without this method that call raises
        AttributeError.

        Args:
            weights: mapping from checkpoint key to NumPy array. Keys are
                assumed to be ``Transformer/encoderblock_{n_block}/<scope>/<param>``
                with ``<scope>`` one of the module-level ATTENTION_*, FC_* or
                *_NORM names — NOTE(review): confirm against your checkpoint.
            n_block: index of this block in the encoder (e.g. ``"0"``).
        """
        root = "Transformer/encoderblock_%s" % n_block

        def fetch(scope, param):
            # Checkpoint keys always use "/", so build them explicitly rather
            # than with os.path.join, which would produce "\\" on Windows.
            return np2th(weights["%s/%s/%s" % (root, scope, param)])

        hidden = self.hidden_size
        with torch.no_grad():
            # Attention kernels are flattened to (in, out) and transposed to
            # PyTorch's (out, in); this assumes all_head_size == hidden_size.
            for linear, scope in (
                (self.attn.query, ATTENTION_Q),
                (self.attn.key, ATTENTION_K),
                (self.attn.value, ATTENTION_V),
                (self.attn.out, ATTENTION_OUT),
            ):
                linear.weight.copy_(fetch(scope, "kernel").view(hidden, hidden).t())
                linear.bias.copy_(fetch(scope, "bias").view(-1))

            # MLP kernels are stored (in, out); PyTorch Linear wants (out, in).
            for linear, scope in ((self.ffn.fc1, FC_0), (self.ffn.fc2, FC_1)):
                linear.weight.copy_(fetch(scope, "kernel").t())
                linear.bias.copy_(fetch(scope, "bias").view(-1))

            for norm, scope in (
                (self.attention_norm, ATTENTION_NORM),
                (self.ffn_norm, MLP_NORM),
            ):
                norm.weight.copy_(fetch(scope, "scale"))
                norm.bias.copy_(fetch(scope, "bias"))

In [None]:
# Transformer encoder: a stack of Blocks (count = transformer["num_layers"])
# followed by a final LayerNorm.
class Encoder(nn.Module):
    """Stack of pre-norm Transformer blocks with a final LayerNorm.

    Args:
        config: model config; reads ``hidden_size`` and
            ``transformer["num_layers"]`` (plus what ``Block`` reads).
        vis: if True, ``forward`` collects every block's attention weights.
    """

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        # Register the (still empty) block list before encoder_norm so the
        # submodule / state_dict order stays "layer.*" then "encoder_norm.*".
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        # Each Block is freshly constructed, so deep-copying it before
        # appending would only waste time and memory.
        self.layer.extend(
            Block(config, vis) for _ in range(config.transformer["num_layers"])
        )

    def forward(self, hidden_states):
        """Return ``(encoded, attn_weights)``.

        ``attn_weights`` holds one entry per block when ``vis`` is True and
        is an empty list otherwise.
        """
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


In [None]:
# Transformer backbone: Embeddings followed by the Encoder.
class Transformer(nn.Module):
    """Image -> token embeddings -> encoded tokens.

    Args:
        config: model config shared with ``Embeddings`` and ``Encoder``.
        img_size: input image size passed to ``Embeddings``.
        vis: passed to ``Encoder``; enables returning attention weights.
    """

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        """Return ``(encoded, attn_weights)`` as produced by the encoder."""
        tokens = self.embeddings(input_ids)
        return self.encoder(tokens)


In [None]:
# VisionTransformer = Transformer (embeddings + encoder) + a linear head.
# The input is patch-embedded and position-encoded, passed through the L
# encoder blocks, and the output at token 0 (the class token) is classified.
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    Args:
        config: model config (see CONFIGS); also reads ``config.classifier``.
        img_size: input image size (int or (H, W)).
        num_classes: number of output logits.
        zero_head: if True, ``load_from`` zero-initialises the head instead of
            copying the checkpoint's head weights.
        vis: if True, the encoder also returns per-layer attention weights.
    """

    def __init__(
        self, config, img_size=224, num_classes=21843, zero_head=False, vis=False
    ):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        """Return the cross-entropy loss if ``labels`` is given, else ``(logits, attn_weights)``."""
        x, attn_weights = self.transformer(x)
        logits = self.head(x[:, 0])  # classify from the class-token output

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        return logits, attn_weights

    @staticmethod
    def _resize_posemb(posemb, ntok_new, classifier):
        """Bilinearly resample a position-embedding table to a new token count.

        Args:
            posemb: array of shape (1, ntok_old, hidden).
            ntok_new: total number of tokens in the target table.
            classifier: if ``"token"``, the first entry is the class-token
                embedding; it is kept as is and excluded from the grid.

        Returns:
            ``(resized, gs_old, gs_new)``: the resized (1, ., hidden) NumPy
            array and the old/new side lengths of the square patch grid.
        """
        posemb = np.asarray(posemb)
        if classifier == "token":
            posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
            ntok_new -= 1
        else:
            posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

        # The patch grid is assumed to be square.
        gs_old = int(np.sqrt(len(posemb_grid)))
        gs_new = int(np.sqrt(ntok_new))
        posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

        zoom = (gs_new / gs_old, gs_new / gs_old, 1)
        posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
        posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
        resized = np.concatenate([posemb_tok, posemb_grid], axis=1)
        return resized, gs_old, gs_new

    def load_from(self, weights):
        """Load parameters from a flat pretrained checkpoint (key -> NumPy array).

        If the checkpoint's position embedding has a different length (a
        different image size), it is resampled to fit this model.
        """
        with torch.no_grad():
            if self.zero_head:
                nn.init.zeros_(self.head.weight)
                nn.init.zeros_(self.head.bias)
            else:
                self.head.weight.copy_(np2th(weights["head/kernel"]).t())
                self.head.bias.copy_(np2th(weights["head/bias"]).t())

            embeddings = self.transformer.embeddings
            embeddings.patch_embeddings.weight.copy_(
                np2th(weights["embedding/kernel"], conv=True)
            )
            embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            embeddings.cls_token.copy_(np2th(weights["cls"]))

            encoder_norm = self.transformer.encoder.encoder_norm
            encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb_raw = weights["Transformer/posembed_input/pos_embedding"]
            posemb = np2th(posemb_raw)
            posemb_new = embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                posemb_new.copy_(posemb)
            else:
                logger.info(
                    "load_pretrained: resized variant: %s to %s",
                    posemb.size(),
                    posemb_new.size(),
                )
                resized, gs_old, gs_new = self._resize_posemb(
                    posemb_raw, posemb_new.size(1), self.classifier
                )
                logger.info("load_pretrained: grid-size from %s to %s", gs_old, gs_new)
                posemb_new.copy_(np2th(resized))

            # Encoder children are the Block list and encoder_norm (which has
            # no children); each Block is loaded by its index name.
            for _, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if embeddings.hybrid:
                hybrid = embeddings.hybrid_model
                hybrid.root.conv.weight.copy_(
                    np2th(weights["conv_root/kernel"], conv=True)
                )
                hybrid.root.gn.weight.copy_(np2th(weights["gn_root/scale"]).view(-1))
                hybrid.root.gn.bias.copy_(np2th(weights["gn_root/bias"]).view(-1))

                for bname, block in hybrid.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)


In [None]:
# Registry of named model configurations (name -> config from models.configs).
CONFIGS = dict(
    [
        ("ViT-B_16", configs.get_b16_config()),
        ("ViT-B_32", configs.get_b32_config()),
        ("ViT-L_16", configs.get_l16_config()),
        ("ViT-L_32", configs.get_l32_config()),
        ("ViT-H_14", configs.get_h14_config()),
        ("R50-ViT-B_16", configs.get_r50_b16_config()),
        ("testing", configs.get_testing()),
    ]
)
# Config fields read by the model code in this file:
# patches: patch settings — "size" (patch size) for plain ViT, or "grid"
#     (patch grid on the ResNet feature map) for hybrid models
# hidden_size: token embedding width
# transformer: Transformer settings
# transformer.mlp_dim: inner width of the MLP
# transformer.num_heads: number of attention heads
# transformer.num_layers: number of encoder blocks
# transformer.attention_dropout_rate: dropout on attention probabilities
# transformer.dropout_rate: dropout on embeddings and MLP outputs
# classifier: "token" if position 0 holds the class token (checked when
#     resizing position embeddings)
# resnet.num_layers / resnet.width_factor: ResNetV2 backbone (hybrid only)
# representation_size: not read anywhere in this file
#
# vis is not a config field: it is a constructor flag (VisionTransformer,
# Transformer, Encoder, Block, Attention) that makes forward return the
# attention weights for visualisation.