In [1]:
import torch
import torch.nn as nn

In [None]:
class self_attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are produced by one shared linear layer, split
    into ``n_heads`` heads, attended independently, then concatenated and
    passed through an output projection.

    Args:
        input_dim: Size of the last dimension of the input; also the output size.
        n_heads: Number of attention heads. Must divide ``input_dim`` evenly.
        qkv_bias: Whether the fused q/k/v linear layer has a bias term.
        qk_scale: Factor applied to the raw q·k scores before the softmax.
            A falsy value (``None`` or ``0``) selects ``head_dim ** -0.5``.
        attn_dropout: Dropout probability applied to the attention weights.
        proj_dropout: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``input_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, input_dim, n_heads=8, qkv_bias=False, qk_scale=None, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        if input_dim % n_heads != 0:
            # Without this, the reshape in forward() would fail with an
            # obscure size-mismatch error.
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = input_dim // n_heads
        # Calculate q, k, v using a single FC.
        self.qkv = nn.Linear(input_dim, input_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(input_dim, input_dim)
        self.proj_dropout = nn.Dropout(proj_dropout)
        self.qk_scale = qk_scale or self.head_dim ** (-0.5)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, L)`` where ``L == input_dim``.

        Returns:
            Tensor of shape ``(B, N, L)``.
        """
        B, N, L = x.shape  # (batch size, number of tokens, embedding size)

        # (B, N, 3 * L) -> (B, N, 3, n_heads, head_dim) -> (3, B, n_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, n_heads, N, head_dim)

        # (B, n_heads, N, head_dim) @ (B, n_heads, head_dim, N) -> (B, n_heads, N, N)
        # The scale must be applied BEFORE the softmax so every row of the
        # attention matrix still sums to 1.
        attn_score = ((q @ k.transpose(-2, -1)) * self.qk_scale).softmax(dim=-1)
        attn_score = self.attn_drop(attn_score)

        # (B, n_heads, N, N) @ (B, n_heads, N, head_dim) -> (B, N, n_heads, head_dim) -> (B, N, L)
        weighted_sum = (attn_score @ v).transpose(1, 2).reshape(B, N, L)
        proj = self.proj(weighted_sum)
        proj = self.proj_dropout(proj)

        return proj

In [5]:
class Attention(nn.Module):
    """Multi-head self-attention with a fused q/k/v projection.

    Args:
        dim: Embedding size of each token (last dimension of the input).
        num_heads: Number of attention heads the embedding is split into.
        qkv_bias: Whether the fused q/k/v linear layer has a bias term.
        qk_scale: Multiplier for the q·k similarity scores; a falsy value
            falls back to ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout probability for the attention weights.
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads  # embedding size handled by each head
        # NOTE: the scale can be set manually to stay compatible with
        # weights trained under a different scaling convention.
        self.scale = qk_scale or per_head ** -0.5

        # One linear layer yields q, k and v at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # Mixes the concatenated head outputs into the final result.
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended representation of ``x``, same shape as ``x``."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (batch, tokens, 3 * channels) -> (batch, tokens, 3, heads, per_head)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # -> (3, batch, heads, tokens, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Index instead of tuple-unpacking the tensor to keep TorchScript happy.
        q, k, v = packed[0], packed[1], packed[2]

        # Transposing the last two axes of k turns the batched product into
        # (tokens, per_head) @ (per_head, tokens) per head:
        # result is (batch, heads, tokens, tokens).
        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Softmax over the last axis normalises each row of the score matrix.
        weights = self.attn_drop(scores.softmax(dim=-1))

        # (batch, heads, tokens, per_head) -> (batch, tokens, heads, per_head)
        # -> (batch, tokens, channels)
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

# Smoke-test cell: build a random batch and inspect a layer's output shape.
a = torch.rand((4, 196, 384))  # random input of shape (4, 196, 384)
t = Attention(384, 4, False)  # dim=384, num_heads=4, qkv_bias=False; NOTE(review): never called in this cell
t1 = nn.Linear(384, 196)  # maps the last dimension from 384 to 196
t1(a).shape  # bare expression; only displayed when run interactively

torch.Size([4, 196, 196])