# Attention核心思想
这里尝试用一个简单的单头 self-attention 来解释 attention 的核心思想。

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as f

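attention 的核心思想，简单地说就是：先用 q 和 k 做对比，求出它们的相似度；再用 softmax 把相似度转成概率分布，按照这个比例对 v 加权求和，组成新的 embedding，即 $\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(QK^\top)\,V$（标准 Transformer 中还会先把 $QK^\top$ 除以 $\sqrt{d_k}$ 做缩放）。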

In [62]:
class SelfAttention(nn.Module):
    """Minimal single-head self-attention layer for illustration.

    Q, K and V are produced from the same input by three separate linear
    projections; the Q·K similarity scores are turned into a probability
    distribution with softmax and used to take a weighted sum of V.

    Args:
        vec_size: size of each input vector (also the size of Q, K and V).
        scaled: if True, divide the scores by sqrt(vec_size) before softmax
            (scaled dot-product attention, as in the Transformer). Defaults
            to False, which keeps the original unscaled behaviour.
        verbose: if True (default), print intermediate shapes and scores
            during ``forward`` for teaching purposes.
    """

    def __init__(self, vec_size, *, scaled=False, verbose=True):
        super().__init__()
        # One weight matrix (plus bias) each for Q, K and V.
        self.QueryW = nn.Linear(vec_size, vec_size)
        self.KeyW = nn.Linear(vec_size, vec_size)
        self.ValueW = nn.Linear(vec_size, vec_size)
        self.Score2Prob = nn.Softmax(dim=-1)
        self.scaled = scaled
        self.verbose = verbose

    def forward(self, sents):
        """Apply self-attention to ``sents``.

        Args:
            sents: tensor of shape ``[..., seq_len, vec_size]``, e.g.
                ``[batch_size, seq_len, vec_size]`` or an unbatched
                ``[seq_len, vec_size]``.

        Returns:
            Tensor with the same shape as ``sents``; each position is a
            probability-weighted mix of the V vectors of all positions.
        """
        if self.verbose:
            print('sents', sents.shape)
        # Project the input to query, key and value vectors.
        Q = self.QueryW(sents)  # [..., seq_len, vec_size]
        K = self.KeyW(sents)    # [..., seq_len, vec_size]
        V = self.ValueW(sents)  # [..., seq_len, vec_size]
        # Similarity of every query with every key. Transpose the last two
        # dims (not dims 1 and 2) so unbatched inputs and inputs with extra
        # leading dims work too; matmul needs no contiguous copy.
        Scores = torch.matmul(Q, K.transpose(-2, -1))  # [..., seq_len, seq_len]
        if self.scaled:
            # Keep score magnitudes independent of the vector size so the
            # softmax does not saturate for large vec_size.
            Scores = Scores / K.size(-1) ** 0.5
        if self.verbose:
            print(Scores.shape)
            print(Scores)
        # Turn each row of scores into a probability distribution over keys.
        Scores = self.Score2Prob(Scores)  # [..., seq_len, seq_len]
        if self.verbose:
            print(Scores)
            print(Scores.shape)
        # New embedding: probability-weighted sum of the value vectors.
        Out = torch.matmul(Scores, V)  # [..., seq_len, vec_size]
        if self.verbose:
            print('Out', Out.shape)
        return Out

In [60]:
# Toy hyper-parameters for the demo.
vec_size = 10
batch_size = 2
seq_len = 4
# Random input batch with every vector rescaled to unit L2 length along the
# last dim; unit vectors make the printed results easier to read.
sents = f.normalize(torch.rand(batch_size, seq_len, vec_size), p=2, dim=-1)
print(sents)

tensor([[[0.1696, 0.0125, 0.1379, 0.3478, 0.3790, 0.2276, 0.4528, 0.2270,
          0.4609, 0.4082],
         [0.0771, 0.2745, 0.1747, 0.4133, 0.1001, 0.1052, 0.4344, 0.0607,
          0.4888, 0.5148],
         [0.3007, 0.0279, 0.0811, 0.3120, 0.4506, 0.3386, 0.4638, 0.2450,
          0.0686, 0.4554],
         [0.4082, 0.3927, 0.4035, 0.1330, 0.3023, 0.3567, 0.0814, 0.0258,
          0.1728, 0.4929]],

        [[0.3552, 0.3260, 0.0076, 0.2821, 0.1292, 0.4619, 0.3386, 0.4913,
          0.2997, 0.1095],
         [0.4076, 0.1469, 0.2115, 0.5363, 0.1666, 0.5930, 0.0947, 0.2785,
          0.0444, 0.1096],
         [0.0576, 0.4700, 0.0063, 0.1463, 0.3560, 0.2613, 0.4805, 0.0120,
          0.3639, 0.4427],
         [0.4564, 0.5436, 0.0129, 0.1896, 0.2353, 0.3753, 0.2794, 0.4229,
          0.0830, 0.0102]]])


In [63]:
# Build the attention layer for the sample data and show its sub-modules.
model = SelfAttention(vec_size)
print(model)
# One forward pass; the layer prints its intermediate shapes and scores.
model(sents)

SelfAttention(
  (QueryW): Linear(in_features=10, out_features=10, bias=True)
  (KeyW): Linear(in_features=10, out_features=10, bias=True)
  (ValueW): Linear(in_features=10, out_features=10, bias=True)
  (Score2Prob): Softmax(dim=-1)
)
sents torch.Size([2, 4, 10])
torch.Size([2, 4, 4])
tensor([[[1.0000, 0.8921, 0.9020, 0.6954],
         [0.8921, 1.0000, 0.7390, 0.7076],
         [0.9020, 0.7390, 1.0000, 0.7453],
         [0.6954, 0.7076, 0.7453, 1.0000]],

        [[1.0000, 0.8352, 0.7078, 0.9250],
         [0.8352, 1.0000, 0.5001, 0.7811],
         [0.7078, 0.5001, 1.0000, 0.6655],
         [0.9250, 0.7811, 0.6655, 1.0000]]])
tensor([[[0.2823, 0.2535, 0.2560, 0.2082],
         [0.2629, 0.2929, 0.2256, 0.2186],
         [0.2626, 0.2231, 0.2897, 0.2245],
         [0.2263, 0.2290, 0.2378, 0.3068]],

        [[0.2839, 0.2408, 0.2120, 0.2634],
         [0.2603, 0.3069, 0.1862, 0.2466],
         [0.2433, 0.1977, 0.3259, 0.2332],
         [0.2692, 0.2331, 0.2076, 0.2901]]])
torch.Size([2, 4,