# 手写self-attention四重境界

### step1:了解公式


$$
\begin{align}
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V. \tag{1}
\end{align}
$$

### step2:开始写代码

In [2]:
### 第一重境界： 简化版本

import math
import torch
import torch.nn as nn

class SelfAttentionv1(nn.Module):
    """Minimal single-head self-attention: softmax(Q K^T / sqrt(d)) V.

    Q, K and V each come from their own ``nn.Linear(hidden_dim, hidden_dim)``.

    NOTE(review): the default ``hidden_dim=728`` may be a typo for 768
    (BERT-base hidden size) — confirm before relying on the default.
    """

    def __init__(self, hidden_dim: int = 728):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()
        self.hidden_dim = hidden_dim

        # Separate query / key / value projections (created in this order).
        self.Q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.K_proj = nn.Linear(hidden_dim, hidden_dim)
        self.V_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Attend over the sequence axis of ``X``.

        Args:
            X: [batch_size, seq_len, hidden_dim]

        Returns:
            [batch_size, seq_len, hidden_dim]
        """
        queries, keys, values = self.Q_proj(X), self.K_proj(X), self.V_proj(X)

        # Scaled dot-product scores: [batch_size, seq_len, seq_len]
        scale = math.sqrt(self.hidden_dim)
        scores = queries @ keys.transpose(-2, -1) / scale

        # Normalise over the key axis, then take the weighted sum of the values.
        weights = scores.softmax(dim=-1)
        return weights @ values

# Smoke test: 3 sequences of 2 tokens each, hidden size 4.
batch_size, seq_len, hidden_dim = 3, 2, 4
X = torch.randn(batch_size, seq_len, hidden_dim)
net = SelfAttentionv1(hidden_dim=hidden_dim)
res = net(X)
print(res.shape)

torch.Size([3, 2, 4])


### 第二重：效率优化（合并 QKV 投影）

In [None]:
class SelfAttentionv2(nn.Module):
    """Single-head self-attention with one fused QKV projection.

    Instead of three separate ``nn.Linear`` layers, a single
    ``nn.Linear(dim, 3 * dim)`` produces Q, K and V in one matmul; the result
    is then split along the last axis.

    Args:
        dim: feature size of the input (last dimension of ``x``).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Fused projection: output columns are laid out as [Q | K | V].
        self.proj = nn.Linear(dim, dim * 3)

    def forward(self, x):
        """Apply scaled dot-product self-attention.

        Args:
            x: [batch, seq_len, dim]

        Returns:
            [batch, seq_len, dim]
        """
        # qkv: [batch, seq_len, dim * 3]
        qkv = self.proj(x)
        Q, K, V = torch.split(qkv, self.dim, dim=-1)

        # attention_weight: [batch, seq_len, seq_len]
        attention_weight = torch.softmax(
            Q @ K.transpose(-2, -1) / math.sqrt(self.dim), dim=-1
        )
        output = attention_weight @ V
        return output


### 第三重：加入一些细节（dropout、attention mask、输出投影）

In [13]:
# Compared with v2, this version adds:
# 1. dropout on the attention weights (applied after softmax)
# 2. an optional attention_mask
# 3. an (optional) output projection

class SelfAttentionv3(nn.Module):
    """Single-head self-attention with dropout, masking and an output projection.

    Args:
        dim: feature size of the input (last dimension of ``x``).
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, dropout_rate=0.1):
        super().__init__()
        self.dim = dim
        # Fused Q/K/V projection: output columns are laid out as [Q | K | V].
        self.proj = nn.Linear(dim, dim * 3)

        self.attention_dropout = nn.Dropout(dropout_rate)

        # Optional output projection applied after weighting V.
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x, attention_mask=None):
        """Apply masked scaled dot-product self-attention.

        Args:
            x: [batch_size, seq_len, dim]
            attention_mask: optional tensor broadcastable to
                [batch_size, seq_len, seq_len]; positions equal to 0 (or False)
                are excluded from attention.

        Returns:
            [batch_size, seq_len, dim]
        """
        qkv = self.proj(x)
        Q, K, V = torch.split(qkv, self.dim, dim=-1)

        # Raw scores: [batch_size, seq_len, seq_len]
        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Use the most negative finite value of the score dtype instead of a
            # hard-coded -1e20: -1e20 is not representable in float16 (so
            # masked_fill raises under fp16/autocast), while -inf would turn a
            # fully-masked row into NaN after softmax.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, torch.finfo(attention_weight.dtype).min
            )

        attention_weight = torch.softmax(attention_weight, dim=-1)
        # Demo output: lets the notebook show that masked columns get weight 0.
        print("attention_weight shape:", attention_weight.shape)
        print("attention_weight :", attention_weight)
        # Dropout is applied to the normalised attention weights.
        attention_weight = self.attention_dropout(attention_weight)
        # Weighted sum of V, followed by the output projection.
        output = self.output_proj(torch.matmul(attention_weight, V))

        return output

# x shape: [batch_size, seq_len, dim]
x = torch.randn(3, 4, 2)

# Key-padding mask [batch_size, seq_len]: 0 marks a key position to ignore.
mask = torch.tensor(
    [
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ]
)
# The mask only has to be broadcastable to the score shape
# [batch_size, seq_len, seq_len]. unsqueeze(1) gives [batch_size, 1, seq_len];
# expand (a view, no copy) makes every query row share its sequence's key mask.
# seq_len is taken from x instead of being hard-coded.
mask = mask.unsqueeze(1).expand(-1, x.size(1), -1)
# mask shape: torch.Size([3, 4]) -> expanded: torch.Size([3, 4, 4])

net = SelfAttentionv3(x.size(-1))
net(x, mask)

        

attention_weight shape: torch.Size([3, 4, 4])
attention_weight : tensor([[[0.3425, 0.3293, 0.3282, 0.0000],
         [0.3375, 0.3199, 0.3426, 0.0000],
         [0.3619, 0.3724, 0.2658, 0.0000],
         [0.3705, 0.3969, 0.2326, 0.0000]],

        [[0.3423, 0.6577, 0.0000, 0.0000],
         [0.3595, 0.6405, 0.0000, 0.0000],
         [0.4004, 0.5996, 0.0000, 0.0000],
         [0.4761, 0.5239, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


tensor([[[0.2458, 0.7761],
         [0.2472, 0.7549],
         [0.2281, 0.7387],
         [0.2187, 0.7188]],

        [[0.5555, 1.4392],
         [0.5634, 1.4581],
         [0.5824, 1.5031],
         [0.6175, 1.5865]],

        [[0.1588, 0.5139],
         [0.5770, 1.5795],
         [0.5770, 1.5795],
         [0.5770, 1.5795]]], grad_fn=<ViewBackward0>)

### 第四重：面试写法

In [20]:
class SelfAttentionInterview(nn.Module):
    """Single-head self-attention with separate Q/K/V projections, optional
    masking and dropout on the attention weights.

    Args:
        dim: feature size of the input (last dimension of ``x``).
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, dim: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.dim = dim

        # Separate query / key / value projections.
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

        self.attention_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Apply masked scaled dot-product self-attention.

        Args:
            x: [batch_size, seq_len, dim]
            attention_mask: optional tensor broadcastable to
                [batch_size, seq_len, seq_len]; positions equal to 0 (or False)
                are excluded from attention.

        Returns:
            [batch_size, seq_len, dim]
        """
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # [batch_size, seq_len, seq_len]
        attention_weight = Q @ K.transpose(-2, -1) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Most negative finite value of the score dtype rather than -1e20:
            # -1e20 overflows float16 (masked_fill raises), and -inf would make a
            # fully-masked row NaN after softmax.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )

        attention_weight = torch.softmax(attention_weight, dim=-1)
        # Demo output: the notebook prints the weights to show masked columns are 0.
        print(attention_weight)
        attention_weight = self.attention_dropout(attention_weight)
        # [batch_size, seq_len, dim]
        output = attention_weight @ V
        return output

# x shape: [batch_size, seq_len, dim]
x = torch.randn(3, 3, 4)

# Key-padding mask [batch_size, seq_len]: 0 marks a key position to ignore.
mask = torch.tensor(
    [
        [1, 0, 1],
        [1, 0, 0],
        [0, 0, 1],
    ]
)

# [batch_size, seq_len] -> [batch_size, seq_len, seq_len]: every query row
# shares its sequence's key mask. seq_len is taken from x, not hard-coded.
mask = mask.unsqueeze(1).expand(-1, x.size(1), -1)
net = SelfAttentionInterview(x.size(-1))
output = net(x, mask)
print(output)
        

tensor([[[0.5791, 0.0000, 0.4209],
         [0.4932, 0.0000, 0.5068],
         [0.5426, 0.0000, 0.4574]],

        [[1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 1.0000]]], grad_fn=<SoftmaxBackward0>)
tensor([[[-0.1486,  0.1214, -0.6503, -0.2755],
         [-0.0382,  0.0863, -0.6434, -0.2326],
         [-0.3743,  0.1591, -0.3713, -0.2634]],

        [[ 0.8472, -0.2117, -0.3991,  0.0823],
         [ 0.8472, -0.2117, -0.3991,  0.0823],
         [ 0.8472, -0.2117, -0.3991,  0.0823]],

        [[-0.0645,  0.2409, -0.2858, -0.4543],
         [-0.0645,  0.2409, -0.2858, -0.4543],
         [-0.0645,  0.2409, -0.2858, -0.4543]]], grad_fn=<UnsafeViewBackward0>)
