In [3]:
import torch
import torch.nn.functional as F
# Random input tensor of shape (batch_size, seq_len, feature_dim)
x = torch.randn(2, 3, 4)
# Number of attention heads and the width of each head
num_heads = 2
head_dim = 2
# feature_dim must EQUAL num_heads * head_dim so that the feature axis can be
# split evenly across the heads (the check below is equality, not "multiple of")
feature_dim = x.size(-1)
assert feature_dim == num_heads * head_dim
# Linear layers that project x to Q, K and V; sized from the input's feature
# dimension instead of a hard-coded 4 so the demo follows the shape of x
linear_q = torch.nn.Linear(feature_dim, feature_dim)
linear_k = torch.nn.Linear(feature_dim, feature_dim)
linear_v = torch.nn.Linear(feature_dim, feature_dim)
# Project x to the Q, K and V tensors
Q = linear_q(x)  # shape (batch_size, seq_len, feature_dim)
K = linear_k(x)  # shape (batch_size, seq_len, feature_dim)
V = linear_v(x)  # shape (batch_size, seq_len, feature_dim)
# Split a Q, K or V tensor into num_heads separate heads
def split_heads(tensor, num_heads):
    """Split the feature axis of a 3-D tensor into attention heads.

    Args:
        tensor: Tensor of shape (batch_size, seq_len, feature_dim).
        num_heads: Number of heads to split the feature axis into.

    Returns:
        A view of shape (batch_size, num_heads, seq_len, head_dim), where
        head_dim is feature_dim // num_heads.
    """
    n_batch, n_tokens, width = tensor.size()
    per_head = width // num_heads
    # First carve the feature axis into (num_heads, head_dim) ...
    grouped = tensor.view(n_batch, n_tokens, num_heads, per_head)
    # ... then move the head axis in front of the sequence axis.
    return grouped.transpose(1, 2)
# Split each projection into heads:
# (batch_size, seq_len, feature_dim) -> (batch_size, num_heads, seq_len, head_dim)
Q, K, V = [split_heads(projection, num_heads) for projection in (Q, K, V)]
# Dot product of every query with every key gives the similarity scores,
# i.e. the raw self-attention weights
raw_weights = Q @ K.transpose(-2, -1)  # shape (batch_size, num_heads, seq_len, seq_len)
# Scale the raw weights by the square root of head_dim
scaled_weights = raw_weights / Q.size(-1) ** 0.5
# Softmax over the key axis turns each row of scores into weights that sum to 1
attn_weights = torch.softmax(scaled_weights, dim=-1)  # shape (batch_size, num_heads, seq_len, seq_len)
# Weighted average of V using the attention weights
attention_outputs = attn_weights @ V  # shape (batch_size, num_heads, seq_len, head_dim)
# Merge the multi-head attention output back into the original layout
def combine_heads(tensor):
    """Merge per-head outputs back into a single feature axis.

    Args:
        tensor: Tensor of shape (batch_size, num_heads, seq_len, head_dim).

    Returns:
        Tensor of shape (batch_size, seq_len, num_heads * head_dim).
    """
    n_batch, n_heads, n_tokens, per_head = tensor.size()
    # Put the head axis next to head_dim; contiguous() is called before
    # view() so the two trailing axes can be flattened into one.
    token_major = tensor.transpose(1, 2).contiguous()
    return token_major.view(n_batch, n_tokens, n_heads * per_head)
# Merge the per-head outputs back into one feature axis
attn_outputs = combine_heads(attention_outputs)  # shape (batch_size, seq_len, feature_dim)
# Final linear transformation of the multi-head attention output; the layer is
# sized from the combined tensor's last dimension instead of a hard-coded 4
out_dim = attn_outputs.size(-1)
linear = torch.nn.Linear(out_dim, out_dim)
attn_outputs = linear(attn_outputs)  # shape (batch_size, seq_len, feature_dim)
print("加权信息：", attn_outputs)

加权信息： tensor([[[ 0.0143, -0.3493, -0.5976,  0.2996],
         [-0.1211, -0.4357, -0.6027,  0.3403],
         [ 0.0272, -0.3429, -0.6007,  0.3015]],

        [[-0.3424, -0.3070, -0.0536,  0.0419],
         [-0.3738, -0.3230, -0.0479,  0.0585],
         [-0.3385, -0.3043, -0.0551,  0.0508]]], grad_fn=<ViewBackward0>)
