In self-attention, the goal is to compute a context vector $z^{(i)}$ for each element $x^{(i)}$ of the input sequence.

Suppose we have the following input sentence, with each token already embedded as a 3-dimensional vector:

In [1]:
import torch
inputs = torch.tensor(
	[[0.43, 0.15, 0.89], # Your (x^1)
	 [0.55, 0.87, 0.66], # journey (x^2)
	 [0.57, 0.85, 0.64], # starts (x^3)
	 [0.22, 0.58, 0.33], # with (x^4)
	 [0.77, 0.25, 0.10], # one (x^5)
	 [0.05, 0.80, 0.55]] # step (x^6)
)

1. Compute the attention scores $\omega$: the dot product between the query (here the second token, $x^{(2)}$) and every input token:

In [2]:
query = inputs[1]
attn_scores_1 = inputs @ query
print(attn_scores_1)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
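
To make the matrix-vector product concrete, here is a minimal sketch that computes the same scores with an explicit loop of dot products. The names `scores_loop` and `scores_matmul` are illustrative and do not appear in the notebook.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
query = inputs[1]

# Explicit loop: one dot product per input token
scores_loop = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    scores_loop[i] = torch.dot(x_i, query)

# Vectorized form used in the cell above
scores_matmul = inputs @ query
print(torch.allclose(scores_loop, scores_matmul))
```

The loop is only for illustration; the `@` form is what you would normally write.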


2. Compute the attention weights $\alpha$ by normalizing the scores with softmax so that they sum to 1:

In [3]:
attn_weights_1 = torch.softmax(attn_scores_1, dim=0)
print(attn_weights_1)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
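
As a rough illustration of what `torch.softmax` does here, the sketch below writes the normalization out by hand. `softmax_naive` is a hypothetical helper, not part of the notebook; subtracting the maximum is a common trick that leaves the result unchanged while avoiding overflow in `exp`.

```python
import torch

# Attention scores copied from the printed output above
scores = torch.tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

def softmax_naive(x):
    # Shifting by the max does not change the result but keeps exp() from overflowing
    shifted = x - x.max()
    exp_x = torch.exp(shifted)
    return exp_x / exp_x.sum()

weights = softmax_naive(scores)
print(weights)
print(weights.sum())
```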


3. Compute the context vector $z$ as the attention-weighted sum of all input vectors:

In [4]:
context_vec_1 = attn_weights_1 @ inputs
print(context_vec_1)

tensor([0.4419, 0.6515, 0.5683])
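
The product `attn_weights_1 @ inputs` is a weighted sum $z^{(2)} = \sum_i \alpha_{2i} x^{(i)}$. A small sketch that spells the sum out, with `context_loop` as an illustrative name:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
attn_weights_1 = torch.softmax(inputs @ inputs[1], dim=0)

# Accumulate alpha_i * x_i over all tokens
context_loop = torch.zeros(inputs.shape[1])
for alpha_i, x_i in zip(attn_weights_1, inputs):
    context_loop += alpha_i * x_i

print(context_loop)
```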


Compute the context vectors for all input tokens at once:

In [5]:
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=-1)
context_vecs = attn_weights @ inputs
print(context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
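
A few sanity checks can help confirm what `dim=-1` does in the batched version. This is a sketch; the variable names `row_sums` and `single` are introduced here for illustration only.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

attn_scores = inputs @ inputs.T                     # (6, 6): row i holds the scores for query i
attn_weights = torch.softmax(attn_scores, dim=-1)   # normalize along each row
context_vecs = attn_weights @ inputs

# Each row of the weight matrix sums to 1
row_sums = attn_weights.sum(dim=-1)
print(row_sums)

# Row 1 reproduces the single-query computation for the second token
single = torch.softmax(inputs @ inputs[1], dim=0) @ inputs
print(torch.allclose(context_vecs[1], single))
```

Without trainable weights the score matrix is symmetric, since the same vectors act as both queries and keys; the weights are not, because each row is normalized separately.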


Next we introduce trainable weights. First define three trainable matrices, one each for the queries, keys, and values:

In [None]:
d_in, d_out = inputs.shape[1], 2
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

Compute the query for the second token:

In [7]:
x_1 = inputs[1]
query_1 = x_1 @ W_query
print(query_1)

tensor([0.4306, 1.4551], grad_fn=<SqueezeBackward4>)


Compute the keys and values for all input tokens:

In [8]:
keys = inputs @ W_key
values = inputs @ W_value
print(keys, values)

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]], grad_fn=<MmBackward0>) tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]], grad_fn=<MmBackward0>)


Compute the attention scores of the query against all keys:

In [9]:
attn_scores_1 = query_1 @ keys.T
print(attn_scores_1)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
       grad_fn=<SqueezeBackward4>)


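Compute the attention weights. The attention scores are scaled by dividing them by the square root of the key dimension before softmax is applied: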

In [10]:
d_key = keys.shape[-1] # 2
attn_weights_1 = torch.softmax(attn_scores_1 / d_key ** 0.5, dim=-1)
print(attn_weights_1)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
       grad_fn=<SoftmaxBackward0>)
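
One common motivation for the $\sqrt{d_k}$ scaling is that dot products of higher-dimensional vectors tend to have a larger spread, which pushes softmax toward near one-hot outputs. The sketch below illustrates this under an assumption that does not come from the notebook: query and key components are drawn as independent standard normal values.

```python
import torch

torch.manual_seed(0)
num_samples = 10_000
stds = {}

for d in (2, 64, 512):
    # Assumption: components are independent with zero mean and unit variance
    q = torch.randn(num_samples, d)
    k = torch.randn(num_samples, d)
    dots = (q * k).sum(dim=-1)  # one dot product per sample
    raw_std = dots.std().item()
    scaled_std = (dots / d ** 0.5).std().item()
    stds[d] = (raw_std, scaled_std)
    print(d, raw_std, scaled_std)
```

Under this assumption the raw standard deviation grows roughly like $\sqrt{d}$, while the scaled one stays close to 1 regardless of the dimension.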


Compute the context vector as the weighted sum of the values:

In [11]:
context_vec_1 = attn_weights_1 @ values
print(context_vec_1)

tensor([0.3061, 0.8210], grad_fn=<SqueezeBackward4>)


Wrap the self-attention mechanism in a module that uses linear layers:

In [12]:
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        
        d_key = keys.shape[-1]
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / d_key ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        
        return context_vec

Pass the 6-token input through this self-attention layer:

In [13]:
torch.manual_seed(123)
sa = SelfAttention(d_in, d_out)
print(sa(inputs))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
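
The outputs differ from the manual `nn.Parameter` version partly because `nn.Linear` uses its own weight initialization rather than `torch.rand`. It also stores its weight in transposed form, with shape `(d_out, d_in)`. A small sketch of the relationship between the two formulations, where `layer` and `x` are illustrative stand-ins:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2, bias=False)
x = torch.rand(6, 3)

out_layer = layer(x)
out_manual = x @ layer.weight.T  # weight has shape (d_out, d_in), hence the transpose

print(layer.weight.shape)  # torch.Size([2, 3])
print(torch.allclose(out_layer, out_manual))
```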


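Next we modify standard self-attention to build a causal attention mechanism, in which each token can attend only to itself and earlier tokens.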

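The first approach starts from standard attention: use a mask to set the upper-right part of the attention weight matrix (everything above the diagonal) to 0, then renormalize each row.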

In [14]:
queries = sa.W_query(inputs)
keys = sa.W_key(inputs)
attn_scores = queries @ keys.T
d_key = keys.shape[-1]
attn_weights = torch.softmax(attn_scores / d_key ** 0.5, dim=-1)

sequence_length = attn_scores.shape[0]
mask = torch.tril(torch.ones(sequence_length, sequence_length))
masked_simple = attn_weights * mask  # zero out the weights above the diagonal

row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_attn_weights = masked_simple / row_sums  # renormalize so each row sums to 1
print(masked_attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)


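Because softmax maps $-\infty$ to 0, we can obtain the same masked attention weights in fewer steps by filling the positions above the diagonal with $-\infty$ before applying softmax: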

In [15]:
mask = torch.triu(torch.ones(sequence_length, sequence_length), diagonal=1)
masked_attn_scores = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked_attn_weights = torch.softmax(masked_attn_scores / d_key ** 0.5, dim=-1)
print(masked_attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)
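
The two masking strategies should agree for any score matrix, not only this one. The sketch below checks that on random scores; `approach_1` and `approach_2` are names introduced here for illustration.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 4)
n = scores.shape[0]

# Approach 1: softmax, zero out the future positions, renormalize each row
weights = torch.softmax(scores, dim=-1)
lower = torch.tril(torch.ones(n, n))
masked = weights * lower
approach_1 = masked / masked.sum(dim=-1, keepdim=True)

# Approach 2: fill future positions with -inf, then a single softmax
upper = torch.triu(torch.ones(n, n), diagonal=1).bool()
approach_2 = torch.softmax(scores.masked_fill(upper, float("-inf")), dim=-1)

print(torch.allclose(approach_1, approach_2, atol=1e-6))
```

The two agree because renormalizing divides out the masked terms from the softmax denominator, which is exactly what happens when those terms contribute `exp(-inf) = 0`.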


Apply a dropout mask after the attention weights have been computed:

In [16]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
print(dropout(masked_attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6380, 0.6816, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5090, 0.5085, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4120, 0.0000, 0.3869, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.3413, 0.3308, 0.3249, 0.0000]],
       grad_fn=<MulBackward0>)
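
The value 2.0000 in the first row comes from how `nn.Dropout` works in training mode: surviving entries are scaled by $1/(1-p)$ so that the expected value stays the same. In evaluation mode dropout does nothing. A minimal sketch on a tensor of ones:

```python
import torch

torch.manual_seed(0)
dropout = torch.nn.Dropout(p=0.5)
x = torch.ones(4, 4)

dropout.train()
y_train = dropout(x)  # each entry is either 0 or 1 / (1 - 0.5) = 2
print(y_train)

dropout.eval()
y_eval = dropout(x)   # identity in evaluation mode
print(torch.equal(y_eval, x))
```

As a consequence, rows of the attention weight matrix generally no longer sum to 1 after dropout during training.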


Putting it all together, implement a causal attention module with dropout:

In [17]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        tmp = torch.ones(context_length, context_length)
        self.register_buffer('mask', torch.triu(tmp, diagonal=1))
        
    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        d_key = keys.shape[-1]
        
        attn_scores = queries @ keys.transpose(1, 2)
        mask = self.mask.bool()[:num_tokens, :num_tokens] # type: ignore
        attn_scores.masked_fill_(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / d_key ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        context_vecs = attn_weights @ values
        return context_vecs

In [19]:
torch.manual_seed(123)
batch = torch.stack((inputs, inputs), dim=0)
sequence_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, sequence_length, 0.0)
context_vecs = ca(batch)
print("batch.shape:", batch.shape, "\ncontext_vecs.shape:", context_vecs.shape)

batch.shape: torch.Size([2, 6, 3]) 
context_vecs.shape: torch.Size([2, 6, 2])
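
`CausalAttention` stores its mask with `register_buffer` rather than as a plain attribute or a parameter. The sketch below shows what that implies, using a hypothetical stand-in module named `WithMask`: the buffer is not trainable, yet it is part of the module's state and follows the module across devices with `.to(...)`.

```python
import torch
import torch.nn as nn

class WithMask(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(3, 2)
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask)

m = WithMask(4)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(param_names)   # only the projection's weight and bias
print(buffer_names)  # the mask
print(state_keys)    # parameters and the mask
```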


Next we extend the causal attention module to multiple heads, i.e. multi-head attention:

In [18]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, sequence_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, sequence_length, dropout, qkv_bias)
            for _ in range(num_heads)]
        )
        
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [20]:
torch.manual_seed(123)
sequence_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, sequence_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs.shape)

torch.Size([2, 6, 4])
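
The last dimension is 4 because the wrapper concatenates the per-head outputs along the feature axis, giving `num_heads * d_out` features per token. A small sketch with random stand-ins for the two head outputs (`head_0` and `head_1` are illustrative names):

```python
import torch

torch.manual_seed(0)
# Stand-ins for the outputs of two heads, each of shape (batch, num_tokens, d_out)
head_0 = torch.rand(2, 6, 2)
head_1 = torch.rand(2, 6, 2)

combined = torch.cat([head_0, head_1], dim=-1)
print(combined.shape)  # torch.Size([2, 6, 4])

# The first d_out features of every token come from head 0, the rest from head 1
print(torch.equal(combined[..., :2], head_0))
```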


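The implementation can be made more efficient by computing all heads in parallel with a single set of projection matrices: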

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        tmp = torch.ones(context_length, context_length)
        self.register_buffer('mask', torch.triu(tmp, diagonal=1))
        
    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
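        # Split the feature axis into heads:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        
        # Move the head axis forward: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        d_key = keys.shape[-1]  # equals head_dim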
        
        attn_scores = queries @ keys.transpose(2, 3)
        mask = self.mask.bool()[:num_tokens, :num_tokens] # type: ignore
        attn_scores.masked_fill_(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / d_key ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
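        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vecs = (attn_weights @ values).transpose(1, 2)
        # Merge the heads back into one d_out-sized vector per token
        context_vecs = context_vecs.contiguous().view(batch_size, num_tokens, self.d_out)
        context_vecs = self.out_proj(context_vecs)  # learned combination of the head outputs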
        return context_vecs
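
The `view` and `transpose` calls are the core of the parallel version, so it may help to trace them on a small tensor. The sketch below uses illustrative sizes and a tensor filled with `arange` values so that the layout is easy to follow. It shows only the reshaping, not the attention computation itself.

```python
import torch

batch_size, num_tokens, num_heads, head_dim = 2, 6, 2, 2
d_out = num_heads * head_dim

x = torch.arange(batch_size * num_tokens * d_out, dtype=torch.float32)
x = x.view(batch_size, num_tokens, d_out)

# Split the features into heads and move the head axis in front of the token axis
heads = x.view(batch_size, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 2, 6, 2])

# Head 0 sees the first head_dim features of each token, head 1 the next head_dim
print(x[0, 0], heads[0, 0, 0], heads[0, 1, 0])

# Undo the split: transpose back, then flatten the heads into d_out features
merged = heads.transpose(1, 2).contiguous().view(batch_size, num_tokens, d_out)
print(torch.equal(merged, x))
```

In the module, `.contiguous()` is required before `view`: there the transposed tensor is the result of a matrix product and is no longer laid out contiguously in memory. In this toy round trip it is harmless.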