# 实现多头注意力

⚠️核心：多组查询 $W_q$ ，键 $W_k$，值 $W_v$ 权重矩阵。

## 方案一 - 顺序处理
堆叠多个CausalAttention模块来构建Multi-head Attention，

### 前置

In [1]:
# 引入 CausalAttention

import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # register_buffer 不需要手动确保这些张量与模型参数在同一设备上，从而避免设备不匹配错误
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
        
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        # 将keys的第二维和第三维进行交换
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        context_vec = attn_weights @ values
        return context_vec

### 代码 - MultiHeadAttentionWrapper

In [2]:
# 实现 MultiHeadAttentionWrapper 类

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, 
                dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )
        
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
        

#### 测试

In [3]:
# 测试：

from sympy import content


inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts   (x^3)
    [0.22, 0.58, 0.33], # with     (x^4)
    [0.77, 0.25, 0.10], # one      (x^5)
    [0.05, 0.80, 0.55]] # step     (x^6)   
)
batch = torch.stack((inputs, inputs), dim=0)

torch.manual_seed(123)
context_length = batch.shape[1]

d_in, d_out = 3, 1
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
content_vecs = mha(batch)

print(content_vecs)


tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)


## 练习3.2返回二维嵌入向量

### 题目

更改 MultiHeadAttentionWrapper (..., num_heads=2) 调用的输入参数，使输出上下文向量为二维而不是四维，同时保持 num_heads=2 的设置。提示：你不需要修改类的实现、你只需要更改其他一个输入参数。

### 解决

In [4]:
# 只需将上面代码块3修改
d_in, d_out = 3, 2
# 修改为：
d_in, d_out = 3, 1

## 方案二 - 并行处理

将 MutilHeadAttentionWrapper 和 CausalAttention合并成一个类

### 代码 - MutilHeadAttention

In [5]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads" # A
        self.d_out = d_out
        self.num_heads = num_heads # A
        self.head_dim = d_out // num_heads # A
        
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # B
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
        
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        # 将[(b, num_tokens, d_out)] 重塑为 [(b, num_tokens, num_heads, num_dim)]
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        
        # 将[(b, num_tokens, num_heads, num_dim)] 转置为 [(b, num_heads, num_tokens, num_dim)]
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        
        # [(b, num_heads, num_tokens, num_dim)] @ [(b, num_heads, num_dim, num_tokens)]
        # => attn_scores: [(b, num_heads, num_tokens, num_tokens)]
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # [(b, num_heads, num_tokens, num_tokens)] @ [(b, num_heads, num_tokens, num_dim)]
        # =>[(b, num_heads, num_tokens, num_dim)] 再转置
        # => context_vec: [(b, num_tokens, num_heads, num_dim)]
        context_vec = (attn_weights @ values).transpose(1, 2)
        # [(b, num_tokens, num_heads, num_dim)] 重塑为 [(b, num_tokens, d_out)]
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # F
        context_vec = self.out_proj(context_vec) # F
        return context_vec

#### exp

* 上面关键关键是 d_out 维度分割 为 head_dim 和 num_heads

说明attn_scores = queries @ keys.transpose(2, 3)及操作：

In [6]:
# 例子：[(b, num_heads, num_tokens, num_dim)]

a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573], #A
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],
 
                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

print(a.shape)

torch.Size([1, 2, 3, 4])


In [7]:
print(a @ a.transpose(2, 3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In [8]:
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("first: ", first_head.shape)
print(first_res)

second_head = a[0, 0, :, :]
second_res = second_head @ second_head.T
print("second: ", second_head.shape)
print(second_res)

first:  torch.Size([3, 4])
tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
second:  torch.Size([3, 4])
tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])


#### 测试 

In [9]:
torch.manual_seed(123)
bath_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)


## 练习3.3初始化GPT-2规模的MultiHeadAttention

最小 GPT-2 模型：12 个注意力头，输入和输出均768 维，1024 个 Token 的上下文长度

### 解决

In [10]:
mha = MultiHeadAttention(d_in=768, d_out=768, context_length=1024, dropout=0.0, num_heads=12)