In [12]:
import torch
import torch.nn.functional as F

# Fix the RNG seed so the random inputs (and every printed output in this
# notebook) are reproducible on Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# 1. Create two input tensors: x1 plays the query role, x2 is used as both
#    keys (step 2) and values (step 4).
x1 = torch.randn(2, 3, 4)  # shape (batch_size, seq_len1, feature_dim)
x2 = torch.randn(2, 5, 4)  # shape (batch_size, seq_len2, feature_dim)

# 2. Raw attention scores: dot product of every x1 vector with every x2 vector.
raw_weights = torch.bmm(x1, x2.transpose(1, 2))  # shape (batch_size, seq_len1, seq_len2)

# 2.5. Scale the scores by sqrt(feature_dim), as in standard scaled
#      dot-product attention, so their magnitude does not grow with the
#      feature dimension before the softmax.
d = x1.size(-1)  # feature_dim
scaled_weights = raw_weights / (d ** 0.5)

# 3. Normalize the *scaled* scores with softmax over the last (seq_len2) axis,
#    so each row of attention weights sums to 1.
attn_weights = F.softmax(scaled_weights, dim=-1)  # shape (batch_size, seq_len1, seq_len2)

# 4. Weighted sum of the x2 vectors using the attention weights.
attn_output = torch.bmm(attn_weights, x2)  # shape (batch_size, seq_len1, feature_dim)

print("x1:", x1[0])
print("x2:", x2[0])

x1: tensor([[-0.1168,  1.2077,  1.6158, -0.1350],
        [-0.2722,  1.2552,  0.0547,  1.9958],
        [-0.2110,  0.8115,  0.3279, -0.0909]])
x2: tensor([[-2.1248,  0.1422, -1.3766, -0.1223],
        [-0.1991, -0.7151, -1.1571,  0.1581],
        [-0.2669, -0.4177, -0.7891, -0.4896],
        [-1.4461, -0.8789, -0.2778, -0.7230],
        [-0.1207, -0.0939,  0.6572, -0.5005]])


In [27]:
# Analyze the attention scores of the first batch item.
# A larger dot product means the two vectors are more similar -- but the dot
# product also grows with vector length, so it mixes direction and magnitude.
print("raw_weight:", raw_weights[0])
# For each x1 row, the index of the x2 row with the largest dot product.
max_idx = raw_weights[0].argmax(dim=1)
print("max_idx:", max_idx)
print("第一批次相似的向量")
for i, idx in enumerate(max_idx):
    print("x1:", x1[0, i], "x2:", x2[0, idx])

# Cross-check with cosine similarity, which ignores magnitude and compares
# direction only; its argmax can differ from the dot-product argmax above.
# Broadcasting (seq_len1, 1, dim) against (1, seq_len2, dim) gives a
# (seq_len1, seq_len2) similarity matrix.
cos_sim_matrix = F.cosine_similarity(x1[0].unsqueeze(1), x2[0].unsqueeze(0), dim=2)
print("cos_sim_matrix:", cos_sim_matrix)
# Take the argmax of the cosine matrix itself (previously this re-used
# raw_weights, so it could never show a disagreement).
cos_max_idx = cos_sim_matrix.argmax(dim=1)
print("cos_max_idx:", cos_max_idx)

raw_weight: tensor([[-1.7878, -2.7313, -1.6822, -1.2437,  1.0302],
        [ 0.4375, -0.5911, -1.4719, -2.1677, -1.0480],
        [ 0.1234, -0.9321, -0.4969, -0.4335,  0.2103]])
max_idx: tensor([4, 0, 4])
第一批次相似的向量
x1: tensor([-0.1168,  1.2077,  1.6158, -0.1350]) x2; tensor([-0.1207, -0.0939,  0.6572, -0.5005])
x1: tensor([-0.2722,  1.2552,  0.0547,  1.9958]) x2; tensor([-2.1248,  0.1422, -1.3766, -0.1223])
x1: tensor([-0.2110,  0.8115,  0.3279, -0.0909]) x2; tensor([-0.1207, -0.0939,  0.6572, -0.5005])
cos_sim_matrix: tensor([[-0.3478, -0.9747, -0.7891, -0.3300,  0.6055],
        [ 0.0726, -0.1800, -0.5890, -0.4906, -0.5255],
        [ 0.0537, -0.7444, -0.5216, -0.2574,  0.2766]])
max_idx: tensor([4, 0, 4])


In [28]:
# Inspect the softmax-normalized attention weights of the first batch item.
print(f"attn_weights: {attn_weights[0]}")

attn_weights: tensor([[0.1237, 0.0772, 0.1304, 0.1624, 0.5063],
        [0.3662, 0.2190, 0.1410, 0.0996, 0.1743],
        [0.2424, 0.1430, 0.1778, 0.1835, 0.2532]])


In [29]:
# Inspect the attention-weighted sum (the attention output) of the first batch item.
print(f"attn_output: {attn_output[0]}")

attn_output: tensor([[-0.6090, -0.2823, -0.0749, -0.4376],
        [-1.0244, -0.2672, -0.7819, -0.2384],
        [-0.8870, -0.3271, -0.5241, -0.3535]])
