# 1. 基础概念

在NLP中，cls（BERT里的[CLS]）是放在句子首位置的一个特殊token，BERT用它的输出作为整句的分类特征；ViT沿用了这一做法，用它的输出作为整张图像的分类特征。
![img](../pic/VIT.png)  
  
ViT的本质是用Transformer encoder对输入的图像块（patch）序列进行编码  
  
  <center>
  
![img](../pic/VIT2.png)
  </center>

## 1.1 整体

In [13]:
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def pair(t):
    """Return ``t`` unchanged if it is already a tuple, otherwise ``(t, t)``.

    Lets sizes be given either as a single int (square) or as (height, width).
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
    
class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is cut into non-overlapping ``patch_size`` patches. Each patch is
    flattened and linearly projected to ``dim``, and a learnable cls token is
    prepended. Learnable position embeddings are added, and the sequence goes
    through a ``depth``-block Transformer. The output at the cls position (or
    the mean over all tokens when ``pool='mean'``) is fed to an MLP head.

    Args:
        image_size: int or (height, width) of the input image, e.g. 224.
        patch_size: int or (height, width) of each patch, e.g. 16.
        num_classes: number of output logits.
        dim: token embedding dimension.
        depth: number of Transformer blocks.
        heads: number of attention heads per block.
        mlp_dim: hidden width of each block's feed-forward layer.
        pool: 'cls' to classify from the cls token, 'mean' to average all tokens.
        channels: number of input image channels.
        dim_head: per-head dimension of the attention projections.
        dropout: dropout rate inside the Transformer blocks.
        emb_dropout: dropout rate applied to the embedded token sequence.

    Raises:
        ValueError: if ``pool`` is not 'cls' or 'mean', or if the image size is
            not divisible by the patch size.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        # Reject unknown pooling modes explicitly: forward() would otherwise treat
        # anything except 'mean' as cls pooling, hiding typos such as 'max'.
        if pool not in {'cls', 'mean'}:
            raise ValueError(f"pool must be 'cls' or 'mean', got {pool!r}")

        image_height, image_width = pair(image_size)  # e.g. 224 x 224
        patch_height, patch_width = pair(patch_size)  # e.g. 16 x 16

        # The Rearrange below needs an exact tiling; fail here rather than at forward time.
        if image_height % patch_height or image_width % patch_width:
            raise ValueError(
                f"image size {(image_height, image_width)} must be divisible by "
                f"patch size {(patch_height, patch_width)}"
            )

        num_patches = (image_height // patch_height) * (image_width // patch_width)  # (224 // 16) ** 2 = 196
        patch_dim = channels * patch_height * patch_width  # values per flattened patch, e.g. 3 * 16 * 16 = 768

        self.to_patch_embedding = nn.Sequential(
            # Split each spatial axis into (number of patches, patch size) and flatten
            # every patch into one vector: (b, c, h*p1, w*p2) -> (b, h*w, p1*p2*c),
            # e.g. (1, 3, 224, 224) -> (1, 196, 768).
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim)  # project each patch vector to dim, e.g. 768 -> 1024
        )

        # One learnable position embedding per token: the cls token plus every patch.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # single learnable cls token
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: tensor of shape (batch, channels, height, width).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.to_patch_embedding(img)  # (b, n, dim) with n patches, e.g. (1, 196, 1024)
        b, n, _ = x.shape
        # Copy the single cls token once per batch element: (1, 1, dim) -> (b, 1, dim).
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)  # (b, n + 1, dim), e.g. (1, 197, 1024)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)  # shape is kept: (b, n + 1, dim)

        # 'mean' averages all tokens; 'cls' takes the output at the cls position (index 0).
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
        
# Demo: 224x224 RGB input, 16x16 patches, 1000 classes.
v = ViT(
    # input geometry
    image_size=224,   # input image height and width
    patch_size=16,    # patch height and width
    # model size
    dim=1024,         # token embedding dimension
    depth=6,          # number of Transformer encoder blocks
    heads=16,         # number of attention heads
    mlp_dim=2048,     # hidden width of each feed-forward layer
    num_classes=1000,
    # regularization
    dropout=0.1,
    emb_dropout=0.1,
)

img = torch.randn((1, 3, 224, 224))  # one random image: (batch, channels, height, width)

preds = v(img)  # logits, shape (1, 1000)

- 需要明确的是：patch embedding加上位置编码（emb+pos）后送入Transformer，输出形状不变——送进去是(1,197,1024)，输出也是(1,197,1024)，因此可以用来接不同的下游任务。
- 由于原始ViT做的是分类任务，cls的用法沿用了BERT：在token序列最前面加一个cls token，用它的输出作为整张图片的特征；也可以设置pool='mean'，对所有token的输出在dim=1上取平均。两种方式得到的都是(1,1024)。  
  
**描述概括黑箱:**
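- 对输入的token序列先用一个Linear扩增出q、k、v三份，再经过nheads个头的缩放点积注意力（scaled dot-product attention），最后将nheads个头的输出拼接起来再做一次Linear，完成MHA（与上面的MHA一模一样）；之后再接一个MLP（输入输出维度不变）。
- q、k、v的形状为(bs,heads,n,dim_head)，其中n为一张图片切出的块数再加上cls（本例为197）。dim_head是每个头的维度，本例中dim_head=64，恰好等于emb_dim/heads=1024/16，但代码里它是独立参数，并不要求等于emb_dim/heads。进行完QK^T之后得到(bs,heads,n,n)，表示n个块两两之间的全局关系；再经过softmax后与V做matmul即得到out。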

## 1.2拆黑箱  
先看Transformer整体：每一层包含一个self-attention和一个MLP，二者的输入都先经过LayerNorm（即PreNorm），并且各自带有残差连接（x = f(norm(x)) + x）。

In [12]:
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm Transformer blocks.

    Each block applies two residual steps in order:
    ``x + attention(norm(x))`` followed by ``x + mlp(norm(x))``.
    Both sublayers map features back to ``dim``, so the output shape equals
    the input shape.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()

        def make_block():
            # Attention is built before the MLP so parameters are created in the
            # same order (and the state_dict layout stays layers.<i>.0 / layers.<i>.1).
            attention = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))
            mlp = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            return nn.ModuleList([attention, mlp])

        self.layers = nn.ModuleList([make_block() for _ in range(depth)])

    def forward(self, x):
        for attention, mlp in self.layers:
            x = x + attention(x)
            x = x + mlp(x)
        return x
        
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # normalizes the trailing ``dim`` features
        self.fn = fn  # wrapped sublayer (attention or MLP in this file)

    def forward(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)

### 1.2.1 attention

1. 注意，这里送进attention的只有一个x；而在NLP的一般写法里，q、k、v是分别送进去的三个输入（见书305页，或上面MHA例子中送进去的X、Y、Y）。所以这里要先把单个x通过一个Linear映射成q、k、v三份，再做注意力，这与多头（nheads）的概念并不冲突。
2. 送进去后，先用chunk把to_qkv的输出(bs,n,3*inner_dim)切成三块，每块为(bs,n,inner_dim)，其中inner_dim=heads*dim_head（本例为16*64=1024）；再用map对每块做rearrange，得到(bs,heads,n,dim_head)，这与上面的MHA也对上了。
3. 经过上面的操作，q、k、v都是(bs,heads,n,dim_head)。此时做attention，注意torch.matmul只对最后两维做矩阵乘法（前面的维度按batch处理）；计算softmax(QK^T/√d)V（d=dim_head）之后，out的形状仍为(bs,heads,n,dim_head)。
4. 使用rearrange(out, 'b h n d -> b n (h d)')将heads个头的输出拼接合并，得到(bs,n,inner_dim)。
5. 最后经过输出的Linear（inner_dim→dim）与Dropout；当heads==1且dim_head==dim时不需要投影，to_out为Identity，这一步被跳过。

In [11]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of each input token.
        heads: number of attention heads.
        dim_head: feature size of each head; together the heads use
            ``heads * dim_head`` features for each of q, k and v.
        dropout: dropout applied after the output projection (only when an
            output projection is used).
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.1):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head whose width already equals dim is returned as-is,
        # without an output projection or dropout.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5  # 1 / sqrt(d_k)

        self.attend = nn.Softmax(dim=-1)
        # One bias-free linear layer produces q, k and v together: dim -> 3 * inner_dim.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        batch, tokens, _ = x.shape
        # (b, n, 3 * inner_dim) -> three tensors of shape (b, n, inner_dim).
        chunks = self.to_qkv(x).chunk(3, dim=-1)
        inner = chunks[0].shape[-1]
        head_dim = inner // self.heads
        # Split features into heads: (b, n, heads * head_dim) -> (b, heads, n, head_dim).
        q, k, v = (c.reshape(batch, tokens, self.heads, head_dim).transpose(1, 2) for c in chunks)
        # Pairwise token similarities per head: (b, heads, n, n).
        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        # Weighted sum of values, then concatenate heads back: (b, n, heads * head_dim).
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, inner)
        return self.to_out(merged)

### 1.2.2 再拆FF  
很常规的MLP：Linear(dim→hidden_dim) → GELU → Dropout → Linear(hidden_dim→dim) → Dropout，输入输出维度不变。

In [10]:
class FeedForward(nn.Module):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps the last dimension ``dim -> hidden_dim -> dim``.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),  # widen
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to dim
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
