# Vision Transformer (ViT): Model Principles and Line-by-Line PyTorch Implementation

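Source: Bilibili creator deep_thoughts, from the series "PyTorch Source Code Tutorials and Reproductions of Cutting-Edge AI Algorithms".

Episode P_28, Vision Transformer (ViT): model principles and line-by-line PyTorch implementation: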

https://www.bilibili.com/video/BV1cS4y1M7wo/?spm_id_from=pageDriver&vd_source=18e91d849da09d846f771c89a366ed40

***Paper***

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale:

https://arxiv.org/pdf/2010.11929.pdf

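## Characteristics of the Transformer Model

* No built-in prior assumptions (e.g., local correlation or sequential ordering)
* The core computation is self-attention, whose cost grows quadratically with the sequence length
* The amount of training data required is inversely related to how much inductive bias is built into the model

Inductive bias: experience that humans have summarized through induction and carried into the design of a model.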
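To make the quadratic cost concrete, the sketch below counts patch tokens for a 224×224 image (an assumed input size, commonly used with ViT) at several patch sizes and builds the full attention-weight matrix for each. It is a simplified single-head attention with no Q/K/V projections, and the helper names `num_patches` and `attention_weights` are hypothetical.

```python
import torch

def num_patches(image_size, patch_size):
    # A square image split into non-overlapping square patches (CLS token not counted)
    return (image_size // patch_size) ** 2

def attention_weights(x):
    # x: (batch, seq_len, dim); single head, no Q/K/V projections (illustration only)
    d = x.shape[-1]
    scores = x @ x.transpose(-1, -2) / d ** 0.5   # (batch, seq_len, seq_len)
    return torch.softmax(scores, dim=-1)

for patch_size in (32, 16, 8):
    n = num_patches(224, patch_size)
    w = attention_weights(torch.randn(1, n, 64))
    print(patch_size, n, tuple(w.shape), w.numel())
```

For patch sizes 32, 16 and 8 this gives 49, 196 and 784 tokens and 2401, 38416 and 614656 attention weights: each halving of the patch side quadruples the token count and multiplies the size of the attention matrix by 16.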

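## Ways Transformers Are Used

* Encoder only: BERT, classification tasks, non-streaming tasks, **Vision Transformer (ViT)**
* Decoder only: GPT series, language modeling, autoregressive generation tasks, streaming tasks
* Encoder-Decoder: machine translation, speech recognition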
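One way to see the encoder-only vs. decoder-only distinction in PyTorch: the same `nn.TransformerEncoder` stack behaves decoder-only (GPT-style) when given a causal mask that hides later positions. The sketch below uses arbitrary toy sizes, perturbs only the last token, and checks which earlier outputs change. Building the mask with `torch.triu` is one common choice, assumed here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 16, 6
# dropout=0.0 keeps the forward pass deterministic for the comparison below
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dropout=0.0, batch_first=True)
stack = nn.TransformerEncoder(layer, num_layers=2)

# Float mask: -inf above the diagonal blocks attention to later positions
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

x = torch.randn(1, seq_len, d_model)
x_changed = x.clone()
x_changed[:, -1] += 1.0  # perturb only the last token

with torch.no_grad():
    bidir, bidir_changed = stack(x), stack(x_changed)
    causal = stack(x, mask=causal_mask)
    causal_changed = stack(x_changed, mask=causal_mask)

# Encoder-only style: every position attends to the last token, so earlier outputs change
print(torch.allclose(bidir[:, :-1], bidir_changed[:, :-1]))
# Decoder-only style: earlier positions never see the last token
print(torch.allclose(causal[:, :-1], causal_changed[:, :-1], atol=1e-6))
```

It should print `False` then `True`: without the mask every position attends to the whole sequence (as in BERT or ViT); with it, position t only sees positions up to t, which is what makes autoregressive, streaming generation possible.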

## Vision Transformer (ViT)

* DNN perspective
  * image2patch
  * patch2embedding
* CNN perspective
  * 2D convolution over image
  * flatten the output feature map
* class token embedding
* position embedding
  * interpolation at inference time
* Transformer Encoder
* classification head
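On interpolating the position embedding: when the input resolution changes, the number of patches changes and the learned position-embedding table no longer fits. For fine-tuning at higher resolution, the ViT paper performs 2D interpolation of the pre-trained position embeddings according to their location in the image. The sketch below is one way to do that, assuming the CLS position is row 0 of the table (as in the step 2/3 code), square patch grids, and bicubic resizing; `resize_pos_embedding` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def resize_pos_embedding(pos_emb, old_grid, new_grid):
    # pos_emb: (1 + old_grid*old_grid, dim); row 0 is assumed to be the CLS position
    cls_pos, patch_pos = pos_emb[:1], pos_emb[1:]
    dim = patch_pos.shape[-1]
    # (N, dim) -> (1, dim, old_grid, old_grid): treat dim as channels of a 2D grid
    grid = patch_pos.reshape(old_grid, old_grid, dim).permute(2, 0, 1).unsqueeze(0)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    # Back to row-major patch order: (new_grid*new_grid, dim)
    patch_pos = grid.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=0)

# Toy sizes: an 8x8 image with 4x4 patches gives a 2x2 grid; a 16x16 input would give 4x4
pos = torch.randn(1 + 2 * 2, 8)
new_pos = resize_pos_embedding(pos, old_grid=2, new_grid=4)
print(new_pos.shape)
```

The CLS row is copied unchanged, since it has no spatial location to interpolate.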

## Code Implementation

***Related Resources***

torch.nn.Unfold official documentation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html

torch.tile official documentation: https://pytorch.org/docs/stable/generated/torch.tile.html

torch.nn.TransformerEncoder official documentation: https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html

## step1 convert image to embedding vector sequence

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def image2emb_naive(image, patch_size, weight):
    # image shape: bs*channel*h*w
    # unfold -> (bs, ic*patch_size*patch_size, num_patches); transpose -> (bs, num_patches, patch_depth)
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight  # (bs, num_patches, model_dim)
    return patch_embedding

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs*oc*oh*ow
    bs, oc, oh, ow = conv_output.shape
    # Flatten the output feature map: (bs, oc, oh*ow) -> (bs, num_patches, oc)
    patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
    return patch_embedding

# test code for image2emb
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
patch_depth = patch_size*patch_size*ic
image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)  # model_dim is the number of output channels; patch_depth is the kernel area times the number of input channels

patch_embedding_naive = image2emb_naive(image, patch_size, weight)  # embedding via patch splitting (unfold + matmul)
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))  # oc*ic*kh*kw

patch_embedding_conv = image2emb_conv(image, kernel, patch_size)  # embedding via 2D convolution
print(patch_embedding_naive)
print(patch_embedding_conv)
```

```text
tensor([[[ -7.2940,  11.0761,  -8.8822,  -5.9179,  -1.0398,   0.8863,  -9.5755,
            8.2882],
         [ -4.0738,   9.6549,  -5.3832,  -4.6684, -16.4855,  -2.5323, -17.7172,
            0.0493],
         [ -7.7384,   7.3792,   4.9279,   2.4082,   2.2541,   5.4289,   4.0024,
           -7.1342],
         [  3.4040,  -0.5489,  11.0364,  -8.9861,   8.5651,   3.1418,   3.6907,
            5.6323]]])
tensor([[[ -7.2940,  11.0760,  -8.8822,  -5.9179,  -1.0398,   0.8863,  -9.5755,
            8.2882],
         [ -4.0738,   9.6549,  -5.3832,  -4.6684, -16.4855,  -2.5323, -17.7172,
            0.0493],
         [ -7.7384,   7.3792,   4.9279,   2.4082,   2.2541,   5.4289,   4.0024,
           -7.1342],
         [  3.4040,  -0.5489,  11.0364,  -8.9861,   8.5651,   3.1418,   3.6907,
            5.6323]]])
```

The two results agree up to floating-point rounding (e.g., 11.0761 vs. 11.0760).
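The equivalence above relies on how `F.unfold` lays out each patch. The small check below uses a toy 1×2×4×4 image holding the values 0 to 31 (chosen only for readability) to show that each column of the unfold output is one patch flattened channel first, then row, then column, and that patches are ordered row by row across the image. Because a conv weight of shape `(oc, ic, kh, kw)` flattens in the same order, `weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))` turns the matmul weight into an equivalent conv kernel.

```python
import torch
import torch.nn.functional as F

# A 1x2x4x4 image with distinct values, so the ordering is easy to read
image = torch.arange(32, dtype=torch.float32).reshape(1, 2, 4, 4)
patches = F.unfold(image, kernel_size=2, stride=2)
print(patches.shape)  # torch.Size([1, 8, 4]) -> (bs, ic*kh*kw, num_patches)

# Column 0 is the top-left 2x2 patch, flattened channel first, then row, then column
first_patch = patches[0, :, 0]
assert torch.equal(first_patch, image[0, :, :2, :2].reshape(-1))

# Column 1 is the next patch to the right (patches go row by row)
assert torch.equal(patches[0, :, 1], image[0, :, :2, 2:].reshape(-1))
```

Column 0 holds 0, 1, 4, 5 from channel 0 followed by 16, 17, 20, 21 from channel 1.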


## step2 prepare CLS token embedding

```python
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
print(token_embedding.shape)
```

```text
torch.Size([1, 5, 8])
```
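In the code above the CLS embedding is a fresh random tensor of shape `(bs, 1, model_dim)`, which is enough for checking shapes. In a trainable model it is usually a single learnable vector shared by every image in the batch. A minimal sketch, assuming an `nn.Module` wrapper; `ClsTokenPrepender` is a hypothetical name:

```python
import torch
import torch.nn as nn

class ClsTokenPrepender(nn.Module):  # hypothetical helper, for illustration
    def __init__(self, model_dim):
        super().__init__()
        # One learnable vector shared by every image in the batch
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))

    def forward(self, patch_embedding):
        bs = patch_embedding.shape[0]
        cls = self.cls_token.expand(bs, -1, -1)  # (bs, 1, model_dim), a view, no copy
        return torch.cat([cls, patch_embedding], dim=1)

prepend = ClsTokenPrepender(model_dim=8)
out = prepend(torch.randn(2, 4, 8))
print(out.shape)  # torch.Size([2, 5, 8])
```

Because `expand` returns a view, every sample uses, and back-propagates into, the same parameter.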


## step3 add position embedding

```python
max_num_token = 16

position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
# ...
# masking is ignored here
# ...
print(position_embedding_table[:seq_len].shape)
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
print(position_embedding.shape)

token_embedding += position_embedding
```

```text
torch.Size([5, 8])
torch.Size([1, 5, 8])
```
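`torch.tile` is used above to copy the position embedding once per sample. When `reps` has more entries than the tensor has dimensions, the tensor is treated as if size-1 dimensions were prepended, which is how `(5, 8)` becomes `(1, 5, 8)`. Broadcasting gives the same sum without materializing the copies; the sketch below (arbitrary toy sizes, batch of 3) checks that both agree.

```python
import torch

bs, seq_len, model_dim = 3, 5, 8
table = torch.randn(16, model_dim)          # plays the role of position_embedding_table
tokens = torch.randn(bs, seq_len, model_dim)

# tile: (5, 8) is treated as (1, 5, 8), then repeated bs times along dim 0
tiled = torch.tile(table[:seq_len], (bs, 1, 1))
print(tiled.shape)  # torch.Size([3, 5, 8])

# Broadcasting: (1, 5, 8) is expanded against (3, 5, 8) without copying
out_tile = tokens + tiled
out_broadcast = tokens + table[:seq_len].unsqueeze(0)
print(torch.equal(out_tile, out_broadcast))  # True
```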


## step4 pass embedding to Transformer Encoder

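```python
# batch_first=True because token_embedding is (bs, seq_len, model_dim);
# the default (batch_first=False) would read dim 0 as the sequence dimension
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)  # (bs, seq_len, model_dim)
```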
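`nn.TransformerEncoderLayer` defaults to `batch_first=False`, i.e. it expects `(seq_len, bs, d_model)`. Passing a `(bs, seq_len, d_model)` tensor like `token_embedding` without `batch_first=True` raises no error, but dimension 1 is then read as the batch, so each token becomes its own length-1 sequence and tokens never attend to one another. The sketch below (toy sizes, `nhead=2`, dropout disabled; `cls_output` is a hypothetical helper) changes every patch token and checks whether the output at the CLS position reacts.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
x = torch.randn(1, 5, d_model)   # (bs, seq_len, d_model), like token_embedding
x_changed = x.clone()
x_changed[:, 1:] += 1.0          # change every token except the CLS position 0

def cls_output(batch_first, inp):
    torch.manual_seed(1)         # identical weights on every call
    layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2, dropout=0.0,
                                       batch_first=batch_first)
    return layer(inp)[:, 0, :]

# batch_first=True: dim 1 is the sequence, so the CLS output depends on the patch tokens
print(torch.allclose(cls_output(True, x), cls_output(True, x_changed)))
# batch_first=False: dim 1 is read as the batch, so "position 0" only ever sees itself
print(torch.allclose(cls_output(False, x), cls_output(False, x_changed)))
```

It should print `False` then `True`: only with `batch_first=True` does the CLS output depend on the patch tokens.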

## step5 do classification

```python
num_classes = 10
label = torch.randint(num_classes, (bs,))

cls_token_output = encoder_output[:, 0, :]  # encoder output at the CLS position, (bs, model_dim)
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)
```

```text
tensor(3.0347, grad_fn=<NllLossBackward0>)
```
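Putting steps 1 to 5 together, a compact module might look like the sketch below. It is not the paper's reference model: `TinyViT` is a hypothetical name, the defaults reuse the toy sizes from this note, the patch embedding uses `nn.Conv2d` (the CNN perspective from step 1), and the CLS token and position embedding are `nn.Parameter`s so an optimizer can train them.

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):  # hypothetical name; a minimal sketch, not the paper's model
    def __init__(self, image_size=8, patch_size=4, in_channels=3, model_dim=8,
                 num_heads=8, num_layers=6, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # Step 1: conv with kernel_size = stride = patch_size embeds each patch
        self.patch_embed = nn.Conv2d(in_channels, model_dim,
                                     kernel_size=patch_size, stride=patch_size)
        # Step 2: one learnable CLS token shared by the batch
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))
        # Step 3: learnable position embedding for CLS + patches
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, model_dim))
        # Step 4: Transformer encoder on (bs, seq_len, model_dim)
        layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Step 5: linear classification head on the CLS output
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, image):
        x = self.patch_embed(image)                       # (bs, model_dim, oh, ow)
        x = x.flatten(2).transpose(1, 2)                  # (bs, num_patches, model_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)   # (bs, 1, model_dim)
        x = torch.cat([cls, x], dim=1) + self.pos_embed   # (bs, num_patches + 1, model_dim)
        x = self.encoder(x)
        return self.head(x[:, 0])                         # logits from the CLS position

model = TinyViT()
images = torch.randn(2, 3, 8, 8)
labels = torch.randint(10, (2,))
logits = model(images)
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()
print(logits.shape)  # torch.Size([2, 10])
```

The ViT paper's encoder blocks apply LayerNorm before attention and the MLP and use GELU; passing `norm_first=True, activation="gelu"` to `nn.TransformerEncoderLayer` is one way to move closer to that design.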
