In [1]:
# 参考资料: https://www.jiqizhixin.com/articles/2024-02-16
# 一、文本输入转为向量嵌入
# 分词、转为token下标
import torch
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split(' ')))}
s_index = [dc[s] for s in sentence.replace(',', '').split(' ')]
s_index = torch.tensor(s_index)
s_index

tensor([0, 4, 5, 2, 1, 3])

In [2]:
# 对下标进行embed
vocab_size = 50_000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embeded_sentence = embed(s_index).detach()
embeded_sentence

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])

In [5]:
# 二、自注意力机制
# 初始化q、k、v的权重矩阵
torch.manual_seed(123)
d = embeded_sentence.shape[1]
d_q, d_k, d_v = 2, 2, 4

w_q = torch.nn.Parameter(torch.rand(d, d_q))
w_k = torch.nn.Parameter(torch.rand(d, d_k))
w_v = torch.nn.Parameter(torch.rand(d, d_v))
print(embeded_sentence @ w_q)

tensor([[ 0.0327, -0.2112],
        [ 0.5667,  1.8269],
        [-0.0152, -0.7982],
        [-0.1037,  0.2902],
        [-0.0375,  0.5085],
        [-0.2816, -1.3567]], grad_fn=<MmBackward0>)


In [16]:
# 计算q、k、v
x_2 = embeded_sentence[2]
q_2 = x_2 @ w_q
k_2 = x_2 @ w_k
v_2 = x_2 @ w_v
print(q_2)
querys = embeded_sentence @ w_q
keys = embeded_sentence @ w_k
values = embeded_sentence @ w_v
print(querys)
print(keys)
print(values)

tensor([-0.0152, -0.7982], grad_fn=<SqueezeBackward4>)
tensor([[ 0.0327, -0.2112],
        [ 0.5667,  1.8269],
        [-0.0152, -0.7982],
        [-0.1037,  0.2902],
        [-0.0375,  0.5085],
        [-0.2816, -1.3567]], grad_fn=<MmBackward0>)
tensor([[-0.0823, -0.3031],
        [ 0.5295,  1.7355],
        [-0.2991, -0.7295],
        [ 0.1420,  0.2291],
        [ 0.1920,  0.6467],
        [-0.4788, -0.5835]], grad_fn=<MmBackward0>)
tensor([[-0.2546, -0.2608, -0.1544, -0.2801],
        [ 0.6612,  1.8972,  1.0963,  1.8106],
        [-0.8598, -0.6161, -0.5940, -0.9455],
        [ 0.5932,  0.0981,  0.2741,  0.4151],
        [ 0.5605,  0.5645,  0.3676,  0.6429],
        [-1.2107, -0.4929, -1.0081, -1.4031]], grad_fn=<MmBackward0>)


In [10]:
# 注意力权重ω(i,j) 是查询和键序列之间的点积 ω(i,j) = q⁽ⁱ⁾ k⁽ʲ⁾
omega_24 = q_2.dot(keys[4])
print(omega_24)

tensor(-0.5191, grad_fn=<DotBackward0>)


In [19]:
# 例：计算第三个词对整个序列的注意力权重 w, omega
omega_2 = q_2 @ keys.T
omega_2

tensor([ 0.2432, -1.3934,  0.5869, -0.1851, -0.5191,  0.4730],
       grad_fn=<SqueezeBackward4>)

In [20]:
# 归一化
import torch.nn.functional as F
attention_w_2 = F.softmax(omega_2/d_k ** 0.5, dim=0)
print(attention_w_2)

tensor([0.1965, 0.0618, 0.2506, 0.1452, 0.1146, 0.2312],
       grad_fn=<SoftmaxBackward0>)


In [21]:
# 使用归一化后的注意力，计算上下文向量嵌入
context_vec_2 = attention_w_2 @ values
context_vec_2

tensor([-0.3542, -0.1234, -0.2627, -0.3706], grad_fn=<SqueezeBackward4>)

In [25]:
# 将自注意力融合为一个类
import torch.nn as nn
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.w_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out_v))
    
    def forward(self, x):
        keys = x @ self.w_key
        queries = x @ self.w_query
        values = x @ self.w_value
        attn_scores = queries @ keys.T
        # 得到归一化的，每个token彼此之间的注意力值，seq_length * seq_length
        attn_weights = torch.softmax(attn_scores/self.d_out_kq ** 0.5, dim=-1)
        # 得到在每一个value维度上，每个token使用自己与其他token的注意力 @ 该维度的value , seq_length * d_v
        context_vec = attn_weights @ values
        return context_vec

In [26]:
# 测试，结果中的第三行与上文计算的上下文嵌入一致
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 4

sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embeded_sentence))

tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
        [ 0.5313,  1.3607,  0.7891,  1.3110],
        [-0.3542, -0.1234, -0.2627, -0.3706],
        [ 0.0071,  0.3345,  0.0969,  0.1998],
        [ 0.1008,  0.4780,  0.2021,  0.3674],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)


In [31]:
# 三、多头注意力
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(d_in, d_out_kq, d_out_v) for _ in range(num_heads)])
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [33]:
# 举例测试多头注意力机制
torch.manual_seed(123)

# 单头注意力，输出维度为seq_len * d_v
d_in, d_out_kq, d_out_v = 3, 2, 1
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embeded_sentence))

# 多头注意力，在最后一个维度拼接，输出维度为 seq_len * (d_v * num_heads)
torch.manual_seed(123)
mha = MultiHeadAttentionWrapper(d_in, d_out_kq, d_out_v, num_heads=4)
context_vecs = mha(embeded_sentence)
print(context_vecs)

tensor([[-0.0185],
        [ 0.4003],
        [-0.1103],
        [ 0.0668],
        [ 0.1180],
        [-0.1827]], grad_fn=<MmBackward0>)
tensor([[-0.0185,  0.0170,  0.1999, -0.0860],
        [ 0.4003,  1.7137,  1.3981,  1.0497],
        [-0.1103, -0.1609,  0.0079, -0.2416],
        [ 0.0668,  0.3534,  0.2322,  0.1008],
        [ 0.1180,  0.6949,  0.3157,  0.2807],
        [-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)


In [35]:
# 四、交叉注意力(从selfattention的基础上改)
import torch.nn as nn
class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.w_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out_v))
    
    def forward(self, x1, x2):
        queries = x2 @ self.w_query
        
        keys = x1 @ self.w_key
        values = x1 @ self.w_value
        attn_scores = queries @ keys.T
        # 得到归一化的，每个x2的token对每个x1的token之间的注意力值，输出维度x2_seq_length * x1_seq_length
        attn_weights = torch.softmax(attn_scores/self.d_out_kq ** 0.5, dim=-1)
        # 得到在每一个value维度上，每个x2的token对每个x1的token之间的注意力值 @ x1的该token在该维度的value 
        # 输出维度为x2_seq_length * d_v，代表从value的不同维度上（角度/语境）对应不同注意力值，计算上下文的嵌入
        context_vec = attn_weights @ values
        return context_vec
        

In [36]:
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 4
cat = CrossAttention(d_in, d_out_kq, d_out_v)

x1 = embeded_sentence
x2 = torch.rand(8, d_in)
print(x1)
print(x2)

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
tensor([[0.2745, 0.6584, 0.2775],
        [0.8573, 0.8993, 0.0390],
        [0.9268, 0.7388, 0.7179],
        [0.7058, 0.9156, 0.4340],
        [0.0772, 0.3565, 0.1479],
        [0.5331, 0.4066, 0.2318],
        [0.4545, 0.9737, 0.4606],
        [0.5159, 0.4220, 0.5786]])


In [37]:
context_vecs = cat(x1, x2)
print(context_vecs)
print(context_vecs.shape)

tensor([[0.2628, 0.7515, 0.3963, 0.6775],
        [0.3689, 0.9600, 0.5367, 0.9030],
        [0.4914, 1.2517, 0.7219, 1.2023],
        [0.4381, 1.1187, 0.6384, 1.0672],
        [0.0906, 0.4545, 0.1880, 0.3441],
        [0.2374, 0.7029, 0.3635, 0.6248],
        [0.4167, 1.0701, 0.6070, 1.0166],
        [0.3376, 0.8998, 0.4955, 0.8371]], grad_fn=<MmBackward0>)
torch.Size([8, 4])


In [38]:
# 五、掩码自注意力（因果自注意力，Causal self-attention）
# 回顾自注意力
torch.manual_seed(123)

d_in, d_out_kq, d_out_v = 3, 2, 4

w_q = torch.nn.Parameter(torch.rand(d_in, d_out_kq))
w_k = torch.nn.Parameter(torch.rand(d_in, d_out_kq))
w_v = torch.nn.Parameter(torch.rand(d_in, d_out_v))

x = embeded_sentence

q = x @ w_q
k = x @ w_k
atten_scores = q @ k.T
print(atten_scores)

tensor([[ 0.0613, -0.3491,  0.1443, -0.0437, -0.1303,  0.1076],
        [-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],
        [ 0.2432, -1.3934,  0.5869, -0.1851, -0.5191,  0.4730],
        [-0.0794,  0.4487, -0.1807,  0.0518,  0.1677, -0.1197],
        [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216, -0.2787],
        [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],
       grad_fn=<MmBackward0>)


In [40]:
atten_weights = torch.softmax(atten_scores/d_out_kq**0.5, dim=1)
atten_weights

tensor([[0.1772, 0.1326, 0.1879, 0.1645, 0.1547, 0.1831],
        [0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
        [0.1965, 0.0618, 0.2506, 0.1452, 0.1146, 0.2312],
        [0.1505, 0.2187, 0.1401, 0.1651, 0.1793, 0.1463],
        [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.1231],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<SoftmaxBackward0>)

In [41]:
# 使用torch.tril构建简单掩码矩阵
block_size = atten_scores.shape[0]
mask_simple = torch.tril(torch.ones(block_size, block_size))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [42]:
# 使用*构建掩码注意力
masked_atten = atten_weights * mask_simple
masked_atten

tensor([[0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0386, 0.6870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1965, 0.0618, 0.2506, 0.0000, 0.0000, 0.0000],
        [0.1505, 0.2187, 0.1401, 0.1651, 0.0000, 0.0000],
        [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<MulBackward0>)

In [48]:
# 再度进行归一化，逐行、注意力保持和为1
row_sums = masked_atten.sum(dim=1, keepdim=True)
masked_atten_norm = masked_atten / row_sums
print(masked_atten_norm)

# 使用掩码注意力计算上下文嵌入
v = x @ w_v
print(v)
masked_context_vec = masked_atten_norm @ v
print(masked_context_vec)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
        [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
        [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<DivBackward0>)
tensor([[-0.2546, -0.2608, -0.1544, -0.2801],
        [ 0.6612,  1.8972,  1.0963,  1.8106],
        [-0.8598, -0.6161, -0.5940, -0.9455],
        [ 0.5932,  0.0981,  0.2741,  0.4151],
        [ 0.5605,  0.5645,  0.3676,  0.6429],
        [-1.2107, -0.4929, -1.0081, -1.4031]], grad_fn=<MmBackward0>)
tensor([[-0.2546, -0.2608, -0.1544, -0.2801],
        [ 0.6124,  1.7823,  1.0298,  1.6994],
        [-0.4415, -0.1738, -0.2191, -0.3539],
        [ 0.1242,  0.4529,  0.2647,  0.4297],
        [ 0.2848,  0.6142,  0.3719,  0.6158],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)


In [55]:
# 六、掩码自注意力更高效的实现方法
# 将上述 【计算注意力分数-》softmax权重-》掩码注意力-》归一化】 的过程使用 【负无穷掩码-》softmax】的方法实现
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
print(mask.bool())
masked_atten = atten_scores.masked_fill(mask.bool(), -torch.inf)
print(masked_atten)
masked_atten_soft = torch.softmax(masked_atten, dim=1)
print(masked_atten_soft)

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])
tensor([[ 0.0613,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.6004,  3.4707,    -inf,    -inf,    -inf,    -inf],
        [ 0.2432, -1.3934,  0.5869,    -inf,    -inf,    -inf],
        [-0.0794,  0.4487, -0.1807,  0.0518,    -inf,    -inf],
        [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216,    -inf],
        [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],
       grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0168, 0.9832, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3839, 0.0747, 0.5414, 0.0000, 0.0000, 0.0000],
        [0.2110, 0.3578, 0.1907, 0.2406, 0.0000, 0.0000],
        [0.1338, 0.3688, 0.1086,