# Vision Transformer (ViT): Model Principles and Line-by-Line PyTorch Implementation

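From the Bilibili series by uploader deep_thoughts, 【PyTorch Source-Code Tutorials and Walkthroughs of Reproduced Cutting-Edge AI Algorithms】

P_28_Vision Transformer (ViT): model principles and line-by-line PyTorch implementation: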

https://www.bilibili.com/video/BV1cS4y1M7wo/?spm_id_from=pageDriver&vd_source=18e91d849da09d846f771c89a366ed40

***Paper***

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale:

https://arxiv.org/pdf/2010.11929.pdf

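## Characteristics of the Transformer Model

* Makes no built-in prior assumptions (e.g., local correlation as in CNNs, or ordered, sequential modeling as in RNNs)
* The core computation is self-attention, whose cost grows quadratically with the number of tokens
* The amount of training data required is inversely related to how much inductive bias is built in: the less inductive bias, the more data the model needs

Inductive bias: experience that humans have summarized through induction and then built into the process of designing a model (for example, the locality and weight sharing of convolution kernels).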
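The "quadratic complexity" point can be made concrete with a minimal single-head scaled dot-product self-attention sketch. It omits masking and the multi-head split. The function name `self_attention` and the random projection matrices are illustrative assumptions, not code from the video. The score matrix has shape N × N, so its size grows with the square of the token count N.

```python
import torch
import torch.nn.functional as F

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention (illustrative sketch).

    x: (batch, N, d) token sequence; w_q, w_k, w_v: (d, d) projection matrices.
    Returns the output (batch, N, d) and the attention weights (batch, N, N).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    d = q.shape[-1]
    scores = q @ k.transpose(-1, -2) / d ** 0.5   # (batch, N, N): one score per token pair
    attn = F.softmax(scores, dim=-1)              # each row sums to 1
    return attn @ v, attn

d = 8
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
for n in (16, 64, 256):
    x = torch.randn(1, n, d)
    out, attn = self_attention(x, w_q, w_k, w_v)
    # the number of attention scores is N*N, i.e. it grows quadratically with N
    print(n, tuple(out.shape), attn.numel())
# prints: 16 (1, 16, 8) 256 / 64 (1, 64, 8) 4096 / 256 (1, 256, 8) 65536
```

Quadrupling the number of tokens multiplies the size of the score matrix by sixteen. This is why ViT works on patches rather than individual pixels.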

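## Transformer Usage Types

* Encoder only: BERT, classification tasks, non-streaming tasks, **Vision Transformer (ViT)**
* Decoder only: the GPT series, language modeling, autoregressive generation tasks, streaming tasks
* Encoder-Decoder: machine translation, speech recognition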
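One practical way to see the encoder-only / decoder-only split: the self-attention blocks are the same, but a decoder-only model adds a causal (upper-triangular "no-look-ahead") mask. Under that mask, token i can only attend to tokens 0..i, which is what allows autoregressive, streaming generation. The sketch below uses `nn.TransformerEncoder` in both roles and builds the mask by hand with `torch.triu`. The sizes are arbitrary, dropout is disabled so the comparison is deterministic, and it assumes PyTorch ≥ 1.9 for `batch_first`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 8, 5
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=2, dim_feedforward=16,
                                   dropout=0.0, batch_first=True)
stack = nn.TransformerEncoder(layer, num_layers=2).eval()

# Boolean causal mask: True marks key positions a query is NOT allowed to attend to.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

x = torch.randn(1, seq_len, d_model)
x_changed = x.clone()
x_changed[:, -1] += 1.0  # perturb only the last token

with torch.no_grad():
    # Encoder-only (bidirectional): every position can see the last token.
    enc_a, enc_b = stack(x), stack(x_changed)
    # Decoder-only style (causal): earlier positions cannot see the last token.
    dec_a = stack(x, mask=causal_mask)
    dec_b = stack(x_changed, mask=causal_mask)

# expected False: earlier outputs change because attention is bidirectional
print(torch.allclose(enc_a[:, :-1], enc_b[:, :-1]))
# expected True: earlier outputs are unaffected by a future token
print(torch.allclose(dec_a[:, :-1], dec_b[:, :-1], atol=1e-6))
```

ViT needs no such mask. It classifies a whole image at once, so every patch may attend to every other patch. That is why it falls in the encoder-only group.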

## Vision Transformer (ViT)

* DNN perspective
  * image2patch
  * patch2embedding
* CNN perspective
  * 2D convolution over image
  * flatten the output feature map
* class token embedding
* position embedding
  * interpolation at inference time (when the input resolution differs from the training resolution)
* Transformer Encoder
* classification head
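The outline above can be assembled into one compact module. This is a minimal sketch under assumptions of our own, not the paper's reference implementation:

* The class name `MiniViT` and all hyperparameter defaults are hypothetical.
* The position embedding is a learnable 1D table sized to the training resolution.
* The encoder uses pre-norm layers with GELU (`norm_first=True`, PyTorch ≥ 1.10).
* The class is predicted from the final CLS token after a LayerNorm.
* The interpolation step is left out.

```python
import torch
import torch.nn as nn

class MiniViT(nn.Module):
    """Hypothetical minimal ViT: patch embedding -> CLS token -> position embedding
    -> Transformer encoder -> classification head."""

    def __init__(self, image_size=32, patch_size=4, in_chans=3, model_dim=64,
                 depth=2, num_heads=4, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # CNN perspective: a conv with kernel_size == stride == patch_size
        self.patch_embed = nn.Conv2d(in_chans, model_dim, kernel_size=patch_size, stride=patch_size)
        # one learnable CLS token shared by every sample in the batch
        self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        # learnable position embedding for CLS + all patches
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, model_dim) * 0.02)
        layer = nn.TransformerEncoderLayer(model_dim, num_heads, dim_feedforward=4 * model_dim,
                                           activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(model_dim)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, images):                        # images: (bs, C, H, W)
        x = self.patch_embed(images)                  # (bs, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)              # (bs, num_patches, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)                # (bs, num_patches + 1, D)
        x = x + self.pos_embed                        # broadcast over the batch
        x = self.encoder(x)
        return self.head(self.norm(x[:, 0]))          # logits from the CLS position

model = MiniViT()
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

Using the convolution for the patch embedding and reading the prediction off the CLS position mirror the "CNN perspective" and "class token embedding" items above.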

## Code Implementation

***Related documentation***

torch.nn.Unfold official documentation: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html

torch.tile official documentation: https://pytorch.org/docs/stable/generated/torch.tile.html

torch.nn.TransformerEncoder official documentation: https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html

## Step 1: convert the image into a sequence of embedding vectors

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def image2emb_naive(image, patch_size, weight):
    # image shape: bs*channel*h*w
    # unfold -> bs*patch_depth*num_patches; transpose -> bs*num_patches*patch_depth
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight  # bs*num_patches*model_dim
    return patch_embedding

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs*oc*oh*ow
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)  # bs*(oh*ow)*oc
    return patch_embedding

# test code for image2emb
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
patch_depth = patch_size*patch_size*ic
image = torch.randn(bs, ic, image_h, image_w)
weight = torch.randn(patch_depth, model_dim)  # model_dim is the number of output channels; patch_depth is the kernel area times the number of input channels

patch_embedding_naive = image2emb_naive(image, patch_size, weight)  # embedding via patch splitting (unfold + matmul)
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))  # oc*ic*kh*kw

patch_embedding_conv = image2emb_conv(image, kernel, patch_size)  # embedding via 2D convolution
print(patch_embedding_naive)
print(patch_embedding_conv)
```

```
tensor([[[-5.3705,  5.9496, -3.7997,  1.7626, -6.0725,  6.5959,  3.7233,
          -3.2709],
         [ 9.6118,  4.9762, -0.4314,  4.9131,  1.5909,  1.2257,  8.3413,
          -7.1560],
         [10.8133, 11.2609,  3.9394,  0.3752, 10.3314,  3.6254, -6.2153,
           0.1659],
         [ 1.4727, -5.2586,  4.5698, -3.8365,  5.6321, -4.3481, -8.4019,
           4.7010]]])
tensor([[[-5.3705,  5.9496, -3.7997,  1.7626, -6.0725,  6.5959,  3.7233,
          -3.2709],
         [ 9.6118,  4.9762, -0.4314,  4.9131,  1.5909,  1.2257,  8.3413,
          -7.1560],
         [10.8133, 11.2609,  3.9394,  0.3752, 10.3314,  3.6254, -6.2153,
           0.1659],
         [ 1.4727, -5.2586,  4.5698, -3.8365,  5.6321, -4.3481, -8.4019,
           4.7010]]])
```
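The two printed tensors above are identical: unfold-then-matmul and a strided convolution compute the same patch embedding. The same check can be written with the module versions of both paths (`nn.Unfold` + `nn.Linear` versus `nn.Conv2d`), copying one weight into the other and comparing with `torch.allclose` instead of by eye. The sizes follow the test code above but with a batch of 2. The bias term is our addition and does not appear in the functional version.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, ic, image_h, image_w = 2, 3, 8, 8
patch_size, model_dim = 4, 8
patch_depth = patch_size * patch_size * ic

unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
linear = nn.Linear(patch_depth, model_dim)                                    # "DNN perspective"
conv = nn.Conv2d(ic, model_dim, kernel_size=patch_size, stride=patch_size)   # "CNN perspective"

# Copy the linear weights into the conv. nn.Linear stores its weight as (out, in), and
# Unfold lays out each patch channel-major (c, kh, kw), so a plain reshape suffices.
with torch.no_grad():
    conv.weight.copy_(linear.weight.reshape(model_dim, ic, patch_size, patch_size))
    conv.bias.copy_(linear.bias)

image = torch.randn(bs, ic, image_h, image_w)
emb_dnn = linear(unfold(image).transpose(1, 2))       # (bs, num_patches, model_dim)
emb_cnn = conv(image).flatten(2).transpose(1, 2)      # (bs, num_patches, model_dim)

print(emb_dnn.shape)                                  # torch.Size([2, 4, 8])
print(torch.allclose(emb_dnn, emb_cnn, atol=1e-5))    # True
```

In practice the convolution form is usually preferred. It needs no explicit unfold buffer and maps directly onto optimized conv kernels.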


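## Step 2: prepare the CLS token embedding

```python
# learnable CLS token, one per sample in this toy example
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
# prepend it to the patch sequence: bs*(1 + num_patches)*model_dim
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
print(token_embedding.shape)
```

```
torch.Size([1, 5, 8])
```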
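Above, the CLS token is created with a leading `bs` dimension, which ties its shape to one particular batch size. A common alternative, consistent with the paper's single learnable embedding prepended to every sequence, is to store one token of shape (1, 1, model_dim) as an `nn.Parameter` and `expand` it to the batch at run time. The class name `ClsTokenPrepender` is hypothetical.

```python
import torch
import torch.nn as nn

class ClsTokenPrepender(nn.Module):
    """Prepend one shared, learnable CLS token to every sequence in the batch."""

    def __init__(self, model_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))

    def forward(self, patch_embedding):               # (bs, num_patches, model_dim)
        bs = patch_embedding.shape[0]
        # expand returns a view (no copy); gradients from all samples are summed into the one token
        cls = self.cls_token.expand(bs, -1, -1)
        return torch.cat([cls, patch_embedding], dim=1)

prepend = ClsTokenPrepender(model_dim=8)
tokens = prepend(torch.randn(3, 4, 8))
print(tokens.shape)  # torch.Size([3, 5, 8])
```

With this layout, the same module handles any batch size, and the token's parameter count stays at `model_dim`.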


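## Step 3: add the position embedding

```python
max_num_token = 16

# learnable position table: one row per position, up to max_num_token positions
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
# ...
# (mask handling omitted)
# ...
print(position_embedding_table[:seq_len].shape)
# repeat the first seq_len rows across the batch: (seq_len, model_dim) -> (bs, seq_len, model_dim)
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
print(position_embedding.shape)

# out-of-place add avoids an in-place op on a tensor tracked by autograd
# (broadcasting would also work here without torch.tile)
token_embedding = token_embedding + position_embedding
```

```
torch.Size([5, 8])
torch.Size([1, 5, 8])
```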
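The outline's "interpolation at inference time" step matters when the model runs on a higher resolution than it was trained on. With the same patch size, the image then yields more patches than there are learned positions. The paper handles this by 2D-interpolating the pre-trained position embeddings according to their location in the original image. The sketch below assumes a table of shape (1, 1 + g*g, D), one CLS position followed by a g×g patch grid, unlike the 1D `max_num_token` table above. The function name `resize_pos_embed` and the choice of bicubic mode are our assumptions.

```python
import math

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_grid):
    """Resize a (1, 1 + g*g, D) position table to (1, 1 + new_grid**2, D).

    The CLS position is kept as-is; only the patch-grid part is interpolated.
    """
    cls_pos, grid_pos = pos_embed[:, :1], pos_embed[:, 1:]
    old_grid = math.isqrt(grid_pos.shape[1])
    dim = grid_pos.shape[-1]
    # (1, g*g, D) -> (1, D, g, g) so that F.interpolate treats D as channels
    grid_pos = grid_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    grid_pos = F.interpolate(grid_pos, size=(new_grid, new_grid),
                             mode="bicubic", align_corners=False)
    # back to (1, new_grid*new_grid, D), row-major like the patch sequence
    grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_pos, grid_pos], dim=1)

# e.g. trained on 8x8 images with 4x4 patches (2x2 grid), evaluated on 16x16 images (4x4 grid)
pos_embed = torch.randn(1, 1 + 2 * 2, 8)
new_pos_embed = resize_pos_embed(pos_embed, new_grid=4)
print(new_pos_embed.shape)  # torch.Size([1, 17, 8])
```

Interpolating on the 2D grid, rather than stretching the flattened 1D sequence, keeps each patch's embedding tied to its spatial neighbours.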
