## 注意力机制Attention:

### 简介:

这一部分是 Transformer 模型的核心部分,以下部分逐步给出实现过程中可能用到的一些矩阵运算的原理，以下代码均不需要大家实现，希望大家阅读代码以及下列文档中的信息:

https://arxiv.org/abs/1706.03762

https://jalammar.github.io/illustrated-transformer/

理解 Attention 的运行机制以及实现过程的数学技巧，完成最后的主文件中的 HeadAttention(), MultiHeadAttention() 部分。

我们虚构一组输入数据的 Embedding 用于这部分讲解：

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

B, T, C = (
    1,
    8,
    16,
)  ## B: batch size 一次训练的数据量, T: context length 前文 token 数, C: embedding length 隐变量长度
inputData = torch.rand(size=(B, T, C))

for i in range(T):
    print(f"Embedding of {i}th position:\n {inputData[0,i]}")

Embedding of 0th position:
 tensor([0.0655, 0.2146, 0.0163, 0.3274, 0.7019, 0.7841, 0.5895, 0.9965, 0.9581,
        0.3506, 0.0609, 0.5496, 0.2838, 0.9680, 0.0972, 0.4693])
Embedding of 1th position:
 tensor([0.2820, 0.8082, 0.6518, 0.1029, 0.0392, 0.7936, 0.0991, 0.6869, 0.3117,
        0.4507, 0.4244, 0.6839, 0.4689, 0.1584, 0.1111, 0.4141])
Embedding of 2th position:
 tensor([0.9979, 0.0802, 0.0767, 0.3044, 0.4179, 0.7948, 0.1601, 0.4493, 0.1846,
        0.0096, 0.0117, 0.1438, 0.1879, 0.3382, 0.5580, 0.6590])
Embedding of 3th position:
 tensor([0.4356, 0.5603, 0.8388, 0.7373, 0.7282, 0.4072, 0.1054, 0.7673, 0.1046,
        0.9179, 0.4355, 0.4042, 0.3301, 0.5042, 0.0901, 0.0131])
Embedding of 4th position:
 tensor([0.7821, 0.6227, 0.4416, 0.5655, 0.7603, 0.3025, 0.7934, 0.2387, 0.9791,
        0.4235, 0.2423, 0.2274, 0.6070, 0.3381, 0.7119, 0.4682])
Embedding of 5th position:
 tensor([0.5184, 0.2964, 0.6380, 0.2912, 0.9893, 0.5858, 0.3418, 0.2993, 0.0957,
        0.9476, 0.5580, 0.6

Attention 从直观上可以理解为对前文各个位置信息的融合以获得当前语境所需的信息。一个最简单的融合方式为对前文 Embedding 加权求和作为当前位置的信息。

我们计算第 i 个位置的融合后的 embedding:

假设前 i 个位置的 embedding 的权重相同，均为 1/i，即更新后第 i 个位置 embedding 为前文所有位置 embedding 的平均值：

In [2]:
def Attention_version1(contextEmbeddings):
    for i in range(T):
        context_embeddings = contextEmbeddings[0, : i + 1, :]  ## shape [i+1, C]
        new_embedding_for_i = torch.mean(context_embeddings, dim=0)
        contextEmbeddings[0, i] = new_embedding_for_i
    return contextEmbeddings


print(
    "Embedding of Data after aggregate context embedding:\n",
    Attention_version1(inputData),
)

Embedding of Data after aggregate context embedding:
 tensor([[[0.0655, 0.2146, 0.0163, 0.3274, 0.7019, 0.7841, 0.5895, 0.9965,
          0.9581, 0.3506, 0.0609, 0.5496, 0.2838, 0.9680, 0.0972, 0.4693],
         [0.1738, 0.5114, 0.3341, 0.2151, 0.3706, 0.7888, 0.3443, 0.8417,
          0.6349, 0.4006, 0.2427, 0.6168, 0.3763, 0.5632, 0.1042, 0.4417],
         [0.4124, 0.2688, 0.1424, 0.2823, 0.4968, 0.7893, 0.3646, 0.7625,
          0.5925, 0.2536, 0.1051, 0.4367, 0.2827, 0.6231, 0.2531, 0.5233],
         [0.2718, 0.3888, 0.3329, 0.3905, 0.5744, 0.6923, 0.3510, 0.8420,
          0.5726, 0.4807, 0.2110, 0.5018, 0.3182, 0.6647, 0.1362, 0.3619],
         [0.3411, 0.4013, 0.2535, 0.3562, 0.5808, 0.6714, 0.4885, 0.7363,
          0.7475, 0.3818, 0.1724, 0.4665, 0.3736, 0.6314, 0.2605, 0.4529],
         [0.2972, 0.3469, 0.2862, 0.3105, 0.6189, 0.7186, 0.4133, 0.7464,
          0.6002, 0.4692, 0.2250, 0.5353, 0.3899, 0.6461, 0.2498, 0.5065],
         [0.2874, 0.3724, 0.2035, 0.2853, 0.5497, 0.

我们将上述的 mean 操作换为等价的矩阵运算，以 i=3 为例：

new_embedding_for_3 = torch.mean(contextEmbeddings[0,:3+1],dim=0)

等价于(@ 是矩阵乘法):

new_embedding_for_3 = contextEmbeddings[0] @ torch.tensor([1/4,1/4,1/4,1/4,0,0,0,0])

In [3]:
def Attention_version2(contextEmbeddings):
    for i in range(T):
        weight = torch.cat(
            (torch.ones(i + 1) / (i + 1), torch.zeros(T - i - 1, dtype=torch.float)),
            dim=0,
        )
        contextEmbeddings[0, i] = weight @ contextEmbeddings[0]
    return contextEmbeddings


print(
    "Attention_version1 equivalent to Attention_version2: ",
    torch.all(Attention_version1(inputData) == Attention_version2(inputData)).item(),
)

Attention_version1 equivalent to Attention_version2:  True


接下来我们用矩阵运算进一步简化上述运算，移除其中的 for 循环:

其中 weight = torch.tril(torch.ones(T,T)) 得到:

[[1., 0., 0., 0., 0., 0., 0., 0.],

 [1., 1., 0., 0., 0., 0., 0., 0.],

 [1., 1., 1., 0., 0., 0., 0., 0.],

 [1., 1., 1., 1., 0., 0., 0., 0.],

 [1., 1., 1., 1., 1., 0., 0., 0.],

 [1., 1., 1., 1., 1., 1., 0., 0.],

 [1., 1., 1., 1., 1., 1., 1., 0.],

 [1., 1., 1., 1., 1., 1., 1., 1.]]

表示前文的求和权重相同都为一。

weight = weight.masked_fill(weight==0,float("-inf"))

weight = F.softmax(weight)

这两行用于归一化 weight，即每一次加权求和的权重和为 1，具体详见 Softmax 公式，我们可得到：

[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],

[0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],

[0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],

[0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],

[0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],

[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],

[0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],

[0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]]


In [4]:
def Attention_version3(contextEmbeddings):
    B, T, C = contextEmbeddings.shape
    weight = torch.tril(torch.ones(T, T))
    print("weight of context embeddings:\n", weight)
    weight = weight.masked_fill(weight == 0, float("-inf"))
    weight = F.softmax(weight, dim=1)
    print("weight of context embeddings after regularization:\n", weight)
    contextEmbeddings[0] = weight @ contextEmbeddings[0]
    return contextEmbeddings


print(
    "Attention_version1 equivalent to Attention_version3: ",
    torch.all(Attention_version1(inputData) == Attention_version3(inputData)).item(),
)

weight of context embeddings:
 tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
weight of context embeddings after regularization:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.125

最后，我们确定计算 weight 的方法，上述三个版本的 weight 都是假定所有前文信息的重要程度相同，在大语言模型中，我们希望有一个灵活的方式计算前文信息对应当前语境的重要程度，为此 Transformer 引入了 Query，Key，Value:

其中 Query 可以理解为当前语境对于前文信息的需求，Key 可以理解为前文包含信息的索引，Value 为前文所包含的信息。

Query 和 Key 用来计算信息融合的 weight.

如何计算 Query 和 Key，并用他们计算 weight 对 Value 加权求和是这次实验的重点内容，这里不能给出大家具体代码，希望大家参见 Attention is All you need 原论文以及助教提供的文档最后的参考链接学习这部分。

利于 Query 和 Key 得出的是信息相关性，我们需要遮盖住下文的信息 (生成第 i 个 token 时，只可以使用 0 到 i-1 处的信息)，并且要对相关性归一化使之可以作为 weight。这里利于 Attension_version3() 中的结论给出如何对计算出来的相关性加掩码和归一化:

In [5]:
def weight_mask_and_normalization(weight):
    tril = torch.tril(torch.ones_like(weight))
    weight = weight.masked_fill(tril == 0, float("-inf"))
    weight = F.softmax(weight, dim=-1)
    return weight


weight = torch.rand(T, T)
print("weight before mask and normalize:\n", weight)
print("weight after mask and normalize:\n", weight_mask_and_normalization(weight))

weight before mask and normalize:
 tensor([[0.2582, 0.1002, 0.8570, 0.1501, 0.4511, 0.2914, 0.3238, 0.6036],
        [0.7463, 0.0219, 0.5533, 0.7125, 0.5138, 0.1145, 0.0069, 0.8642],
        [0.0614, 0.6251, 0.5986, 0.7518, 0.6477, 0.6475, 0.9583, 0.1678],
        [0.6510, 0.4634, 0.1915, 0.0537, 0.3881, 0.1949, 0.8625, 0.9028],
        [0.8041, 0.2990, 0.1533, 0.0212, 0.7180, 0.6434, 0.6996, 0.6003],
        [0.9653, 0.7097, 0.1601, 0.9392, 0.7792, 0.0873, 0.8680, 0.8393],
        [0.3106, 0.9449, 0.7335, 0.0592, 0.8797, 0.6614, 0.1452, 0.6447],
        [0.1547, 0.2637, 0.7468, 0.5537, 0.4040, 0.7567, 0.7452, 0.3971]])
weight after mask and normalize:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6736, 0.3264, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2238, 0.3932, 0.3829, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3321, 0.2753, 0.2098, 0.1828, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2857, 0.1724, 0.1491, 0.1306, 0