# Vision Transformer
> 尝试自己实现 Transformer 的架构，并提升自己的代码能力  
> 参考本仓库：D，当然也可以自己先试一下，有不明白的地方再去研究  
> 不对，有一个问题，那为什么我不直接去尝试实现 ViT 呢？都是 Transformer 架构，并且还更贴近 CV，是我所需要的  
> 好吧，转换目标，手撕 ViT

![ViT](./ViT.jpg)


# 1. 预处理部分
分为 patch 划分，线性映射以及位置嵌入

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class pre_proces(nn.Module):
    """Convert an image batch into the token sequence consumed by the encoder.

    Steps: split each image into non-overlapping square patches, flatten every
    patch, project it linearly, prepend a learnable class token, and add a
    learnable position embedding.

    Args:
        image_size: Side length used to derive the patch count, which is
            ``(image_size // patch_size) ** 2``.
        patch_size: Side length of one square patch.
        patch_dim: Input feature size of the linear projection; with the
            flattening done in ``forward`` this must equal
            ``channels * patch_size ** 2``.
        dim: Output size of the projection, i.e. the token embedding size.
    """

    def __init__(self, image_size, patch_size, patch_dim, dim):
        super().__init__()
        self.patch_size = patch_size
        self.dim = dim
        patches_per_side = image_size // patch_size
        self.patch_num = patches_per_side ** 2
        self.linear_embedding = nn.Linear(patch_dim, dim)
        # Both parameters carry a leading axis of size 1 so they line up with
        # (B, L, C) tokens; the extra "+ 1" position belongs to the class token.
        self.position_embedding = nn.Parameter(torch.randn(1, self.patch_num + 1, dim))
        self.CLS_token = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, x):
        """Embed images ``x`` of shape (B, C, H, W) into (B, patch_num + 1, dim) tokens.

        H and W must be multiples of ``patch_size``, and the resulting number
        of patches must equal ``patch_num`` for the position embedding to
        broadcast.
        """
        side = self.patch_size
        # (B, C, H, W) -> (B, L, patch_dim): one flattened row per patch.
        patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=side, p2=side)
        tokens = self.linear_embedding(patches)
        batch = tokens.shape[0]
        # One copy of the class token per sample (this is the CLS token, not
        # the position embedding).
        cls_tokens = self.CLS_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        # position_embedding has batch size 1 and broadcasts over B.
        return tokens + self.position_embedding


## 1.1 验证

In [18]:
# Smoke test: a 128x128 RGB image cut into 16x16 patches gives 64 patches,
# each flattened to 3 * 16 * 16 = 768 values.
p = pre_proces(image_size=128, patch_size=16, patch_dim=768, dim=768)
image = torch.randn(1, 3, 128, 128)  # named `image` so the builtin `input` is not shadowed
p(image)


tensor([[[ 0.7700,  0.6970, -0.8516,  ...,  1.6258,  1.0387,  1.2238],
         [-1.7675, -0.4237,  0.3249,  ..., -0.3300,  1.7388, -0.6130],
         [ 1.3180,  1.3085,  0.8253,  ..., -0.0889,  0.6063,  1.9750],
         ...,
         [-1.1643,  0.4448,  0.3943,  ...,  0.7199, -0.1004, -1.0573],
         [-0.3817, -1.6991, -2.5210,  ..., -0.1832,  1.2500,  0.4725],
         [ 1.2402,  1.4359,  0.1627,  ..., -0.4160, -0.3571, -0.1215]]],
       grad_fn=<AddBackward0>)

# 2. Transformer Block
接下来要构建每一个 Transformer block 了，可以先从一个小块开始

## 2.1 MultiHead self Attention
构建自注意力层

In [54]:
class Multihead_self_attention(nn.Module):
    """Pre-norm multi-head self-attention over a (B, L, dim) token sequence.

    Args:
        heads: Number of attention heads.
        head_dim: Feature size of each head.
        dim: Token embedding size on input and output.

    The concatenated head width ``heads * head_dim`` does not have to equal
    ``dim``; ``to_output`` projects it back to ``dim``.
    """

    def __init__(self, heads, head_dim, dim):
        super().__init__()
        self.head_dim = head_dim
        self.heads = heads
        self.inner_dim = heads * head_dim  # width of all heads concatenated
        self.scale = head_dim ** -0.5      # 1 / sqrt(head_dim) applied to the logits
        # A single projection yields Q, K and V side by side along the last axis.
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3)
        self.to_output = nn.Linear(self.inner_dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return the attention output with the same shape as ``x``.

        The residual connection is not applied here.
        """
        normed = self.norm(x)  # pre-norm
        projected = self.to_qkv(normed)
        # Split into Q, K, V, then move the head axis forward: (B, H, L, head_dim).
        q, k, v = (
            rearrange(part, 'b l (h dim) -> b h l dim', dim=self.head_dim)
            for part in projected.chunk(3, dim=-1)
        )
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.softmax(scores)
        mixed = torch.matmul(weights, v)  # (B, H, L, head_dim)
        # Concatenate the heads back into one feature axis.
        merged = rearrange(mixed, 'b h l dim -> b l (h dim)')
        return self.to_output(merged)


### 2.1.1 测试自注意力层

In [53]:
# Smoke test: 4 tokens of width 768 through 8 heads of size 64.
MHA = Multihead_self_attention(heads=8, head_dim=64, dim=768)
tokens = torch.randn(1, 4, 768)  # named `tokens` so the builtin `input` is not shadowed
MHA(tokens)


tensor([[[ 0.1356,  0.0105,  0.5154,  ...,  0.1114, -0.0470,  0.3496],
         [ 0.0756, -0.1498,  0.5160,  ...,  0.0313, -0.0332,  0.2123],
         [ 0.1368, -0.0630,  0.4745,  ...,  0.0991, -0.0957,  0.3188],
         [ 0.1227, -0.0846,  0.5227,  ...,  0.0340, -0.0341,  0.2845]]],
       grad_fn=<ViewBackward0>)

## 2.2 MLP
构建自注意力层后面的 FeedForward 模块

In [55]:
class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, mlp_dim) -> GELU -> Linear(mlp_dim, dim).

    Args:
        dim: Feature size of the input and of the output.
        mlp_dim: Width of the hidden layer.
    """

    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply the MLP over the last axis of ``x``; the residual is not added here."""
        hidden = self.fc1(self.norm(x))
        activated = F.gelu(hidden)
        return self.fc2(activated)


### 2.2.1 测试 MLP 

In [43]:
# Smoke test: the MLP must hand back the shape it was given.
ff = FeedForward(dim=768, mlp_dim=1024)
x = torch.randn(1, 4, 768)
ff(x).shape


torch.Size([1, 4, 768])

## 2.3 Transformer block
建立残差连接，构建 Transformer

In [56]:
class Transformer_block(nn.Module):
    """One encoder layer: self-attention then feed-forward, each with a residual.

    Both sub-layers normalise their own input, so only the skip connections
    are added here.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads.
        head_dim: Feature size of each attention head.
        mlp_dim: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, dim, heads, head_dim, mlp_dim):
        super().__init__()
        self.MHA = Multihead_self_attention(heads=heads, head_dim=head_dim, dim=dim)
        self.FeedForward = FeedForward(dim=dim, mlp_dim=mlp_dim)

    def forward(self, x):
        """Return the transformed tokens; the shape matches ``x``."""
        after_attention = self.MHA(x) + x
        after_mlp = self.FeedForward(after_attention) + after_attention
        return after_mlp


### 2.3.1 测试 Transformer Block

In [58]:
# Smoke test: one encoder layer must keep the (1, 4, 768) shape.
transformer_block = Transformer_block(dim=768, heads=8, head_dim=64, mlp_dim=1024)
# `x` stays bound to the block output; a later cell feeds it to the full encoder.
x = transformer_block(torch.randn(1, 4, 768))
x.shape

torch.Size([1, 4, 768])

# 3. 组装 ViT
开始组装 ViT，将上面的各个模块进行整合

## 3.1 Transformer
组成 ViT 的主体部分，也就是 Transformer


In [59]:
class Transformer(nn.Module):
    """A stack of ``depth`` identically configured ``Transformer_block`` layers.

    Args:
        dim: Token embedding size.
        heads: Number of attention heads per layer.
        head_dim: Feature size of each attention head.
        mlp_dim: Hidden width of each feed-forward sub-layer.
        depth: Number of layers.
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, depth):
        super().__init__()
        self.layers = nn.ModuleList([
            Transformer_block(dim=dim, heads=heads, head_dim=head_dim, mlp_dim=mlp_dim)
            for _ in range(depth)
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden



### 3.1.1 验证组合完成的 Transformer

In [65]:
# Smoke test for the 6-layer encoder; `x` is the tensor left behind by an earlier cell.
transformer = Transformer(dim=768, heads=8, head_dim=64, mlp_dim=1024, depth=6)
transformer(x)

tensor([[[ 0.9251,  0.8963,  0.5717,  ..., -1.4688, -0.6899,  2.5642],
         [-0.7377,  1.8619,  1.9917,  ..., -0.4586, -0.7757,  0.4690],
         [-0.9505, -1.5973,  1.3902,  ..., -0.8011,  0.7190, -0.2873],
         [ 1.0755,  1.0339,  2.0910,  ..., -1.0413, -2.8090,  1.7156]]],
       grad_fn=<AddBackward0>)

## 3.2 ViT
构建 ViT

In [68]:
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding -> encoder -> class-token head.

    Args:
        image_size: Side length of the (square) input images.
        channels: Number of image channels.
        patch_size: Side length of one square patch.
        dim: Token embedding size.
        heads: Number of attention heads per encoder layer.
        head_dim: Feature size of each attention head.
        mlp_dim: Hidden width of each feed-forward sub-layer.
        depth: Number of encoder layers.
        num_class: Number of output classes.
    """

    def __init__(self, image_size, channels, patch_size, dim, heads, head_dim, mlp_dim, depth, num_class):
        super().__init__()
        # A flattened patch holds channels * patch_size^2 values.
        self.to_patch_embedding = pre_proces(
            image_size=image_size,
            patch_size=patch_size,
            patch_dim=channels * patch_size ** 2,
            dim=dim,
        )
        self.transformer = Transformer(dim=dim, heads=heads, head_dim=head_dim, mlp_dim=mlp_dim, depth=depth)
        self.MLP_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_class)
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, return_logits=False):
        """Classify a batch of images.

        Args:
            x: Image batch of shape (B, channels, H, W).
            return_logits: When True, return the raw head output instead of
                probabilities. Use this with losses that apply their own
                (log-)softmax, such as ``nn.CrossEntropyLoss``; feeding them
                the default output would apply softmax twice.

        Returns:
            Tensor of shape (B, num_class): softmax probabilities by default,
            unnormalised logits when ``return_logits`` is True.
        """
        tokens = self.to_patch_embedding(x)
        encoded = self.transformer(tokens)
        # The class token was prepended, so it sits at sequence index 0.
        cls_state = encoded[:, 0, :]
        logits = self.MLP_head(cls_state)
        if return_logits:
            return logits
        return self.softmax(logits)


### 3.2.1 测试 ViT

In [70]:
# Smoke test: 64x64 RGB input, 16x16 patches, 4 classes -> one row of 4 probabilities.
vit = ViT(
    image_size=64,
    channels=3,
    patch_size=16,
    dim=768,
    heads=8,
    head_dim=64,
    mlp_dim=1024,
    depth=6,
    num_class=4,
)
x = torch.randn(1, 3, 64, 64)
vit(x)

tensor([[0.1439, 0.2126, 0.4134, 0.2301]], grad_fn=<SoftmaxBackward0>)