# Attention核心思想

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import numpy as np

首先我们准备一些数据。

In [3]:
# Toy hyper-parameters: vector width, number of sequences, sequence length.
vec_size, batch_size, seq_len = 10, 2, 4
# Random demo batch of shape [batch_size, seq_len, vec_size]; every vector is
# rescaled to unit L2 norm along the last axis so similarities are easy to read.
sents = f.normalize(torch.rand((batch_size, seq_len, vec_size)), p=2, dim=-1)
print(sents)

tensor([[[0.1126, 0.0947, 0.2096, 0.3071, 0.4642, 0.0592, 0.2444, 0.3948,
          0.4472, 0.4533],
         [0.0581, 0.4792, 0.1842, 0.1569, 0.1049, 0.3199, 0.4926, 0.2041,
          0.2906, 0.4759],
         [0.0434, 0.3689, 0.5747, 0.2188, 0.0666, 0.0442, 0.3289, 0.4632,
          0.3774, 0.1112],
         [0.5800, 0.2434, 0.2349, 0.2005, 0.3672, 0.0239, 0.2227, 0.0588,
          0.3605, 0.4366]],

        [[0.2391, 0.3953, 0.0980, 0.2839, 0.1098, 0.3597, 0.3866, 0.3801,
          0.4072, 0.3085],
         [0.0019, 0.1059, 0.1363, 0.2362, 0.3724, 0.2176, 0.3953, 0.2270,
          0.7115, 0.1198],
         [0.3452, 0.4973, 0.2242, 0.4729, 0.2477, 0.2316, 0.4679, 0.0854,
          0.0620, 0.1210],
         [0.0453, 0.2274, 0.1182, 0.3428, 0.3818, 0.2016, 0.2029, 0.4490,
          0.4221, 0.4554]]])


## 1. 先用简单的单头self-attention来解释attention的思想。
attention的核心思想，简单地说就是：先用q和k做点积，求出两者的相似性；再用softmax把相似性转成概率分布；然后按照这个比例对v加权求和，组成新的embedding。

In [6]:
class SelfAttention(nn.Module):
    """Single-head self-attention (plain dot product, no 1/sqrt(d) scaling).

    Q, K and V come from three separate learned linear projections of the
    input. softmax(Q @ K^T) gives, for every query position, a probability
    distribution over key positions. That distribution is used to take a
    weighted sum of the rows of V.
    """

    def __init__(self, vec_size):
        """
        Args:
            vec_size: width of each input vector; Q, K and V keep this width.
        """
        super().__init__()
        # One projection matrix each for Q, K and V.
        self.QueryW = nn.Linear(vec_size, vec_size)
        self.KeyW = nn.Linear(vec_size, vec_size)
        self.ValueW = nn.Linear(vec_size, vec_size)
        # Normalize over the key axis (last dim) so each query's weights sum to 1.
        self.Score2Prob = nn.Softmax(dim=-1)

    def forward(self, sents):
        '''
        Args:
            sents: [batch_size, seq_len, vec_size]. An unbatched
                [seq_len, vec_size] tensor, or one with extra leading
                dimensions, is also accepted.

        Returns:
            A tensor with the same shape as ``sents``.
        '''
        print('sents.shape: ', sents.shape)
        # Project the input into query / key / value vectors.
        Q = self.QueryW(sents)  # [..., seq_len, vec_size]
        K = self.KeyW(sents)    # [..., seq_len, vec_size]
        V = self.ValueW(sents)  # [..., seq_len, vec_size]
        # Swap the last two axes (rather than hard-coding axes 1 and 2) so the
        # score matmul works for any number of leading dimensions.
        Scores = torch.matmul(Q, K.transpose(-2, -1))  # [..., seq_len, seq_len]
        Scores = self.Score2Prob(Scores)  # [..., seq_len, seq_len]
        # New embedding = attention weights applied to V.
        Out = torch.matmul(Scores, V)  # [..., seq_len, vec_size]
        return Out

In [7]:
# Build the single-head layer and apply it to the demo batch; as the last
# expression of the cell, the returned tensor is shown as the cell output.
model = SelfAttention(vec_size)
model(sents)

sents.shape:  torch.Size([2, 4, 10])


tensor([[[-0.0399,  0.0413,  0.2227,  0.0572,  0.1893, -0.1025, -0.2460,
           0.0029, -0.2069, -0.5122],
         [-0.0416,  0.0439,  0.2222,  0.0568,  0.1875, -0.1060, -0.2449,
           0.0014, -0.2077, -0.5153],
         [-0.0428,  0.0429,  0.2235,  0.0567,  0.1883, -0.1074, -0.2472,
           0.0032, -0.2050, -0.5153],
         [-0.0387,  0.0411,  0.2214,  0.0565,  0.1897, -0.1001, -0.2447,
           0.0017, -0.2089, -0.5110]],

        [[-0.1623,  0.0629,  0.1961, -0.0138,  0.2406, -0.1183, -0.2788,
          -0.0576, -0.1563, -0.4873],
         [-0.1616,  0.0633,  0.1948, -0.0140,  0.2412, -0.1179, -0.2779,
          -0.0574, -0.1575, -0.4875],
         [-0.1626,  0.0630,  0.1964, -0.0141,  0.2405, -0.1184, -0.2791,
          -0.0575, -0.1563, -0.4873],
         [-0.1623,  0.0631,  0.1963, -0.0133,  0.2408, -0.1190, -0.2786,
          -0.0577, -0.1558, -0.4876]]], grad_fn=<UnsafeViewBackward0>)

## 2.接下来我们尝试Multi-head Attention。
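和单头相比，多头就是并行地做多个单头attention：每个头有自己的q、k、v投影，各自得到新的embedding。然后把所有头的结果拼接起来，经过一个线性层映射回原来的维度。另外，这里还把相似性分数除以了$\sqrt{d}$（$d$为每个头的向量维度，即`vec_size`），避免点积过大导致softmax过于尖锐。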

In [25]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Each of the ``head_num`` heads gets its own ``vec_size``-wide slice of
    the Q/K/V projections. All heads are computed by one wide Linear layer
    per role. Scores are scaled by 1/sqrt(vec_size) before the softmax. The
    per-head outputs are concatenated and projected back to ``vec_size``.
    """

    def __init__(self, vec_size, head_num):
        """
        Args:
            vec_size: width of each input vector and of each head.
            head_num: number of attention heads.
        """
        super().__init__()
        self.head_num = head_num
        self.vec_size = vec_size
        # Each layer emits head_num * vec_size features: one vec_size slice per head.
        self.QueryW = nn.Linear(vec_size, vec_size*head_num)
        self.KeyW = nn.Linear(vec_size, vec_size*head_num)
        self.ValueW = nn.Linear(vec_size, vec_size*head_num)
        # Scores are [batch*heads, seq_len, seq_len]; dim=2 is the key axis.
        self.Score2Prob = nn.Softmax(dim=2)
        # Maps the concatenated heads back to vec_size.
        self.OutW = nn.Linear(vec_size*head_num, vec_size)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, vec_size]

        Returns:
            [batch_size, seq_len, vec_size]
        """
        # Report the shape of the tensor actually passed in (not a global).
        print('sents.shape: ', x.shape)
        batch_size, seq_len, _ = x.shape
        vec_size = self.vec_size
        head_num = self.head_num
        # Compute q, k, v for all heads at once, then split the last axis per head.
        Q = self.QueryW(x).view(batch_size, seq_len, head_num, vec_size)
        K = self.KeyW(x).view(batch_size, seq_len, head_num, vec_size)
        V = self.ValueW(x).view(batch_size, seq_len, head_num, vec_size)
        # Move heads next to batch and fold them in, so every head is an
        # independent single-head attention in one batched matmul.
        Q = Q.transpose(1,2).contiguous().view(batch_size*head_num, seq_len, vec_size)  # [batch_size*head_num, seq_len, vec_size]
        K = K.transpose(1,2).contiguous().view(batch_size*head_num, seq_len, vec_size)
        V = V.transpose(1,2).contiguous().view(batch_size*head_num, seq_len, vec_size)
        Score = torch.matmul(Q, K.transpose(1,2))  # [batch_size*head_num, seq_len, seq_len]
        # Scale by sqrt(d) so large dot products don't saturate the softmax.
        Score = Score/np.sqrt(vec_size)
        Score = self.Score2Prob(Score)  # [batch_size*head_num, seq_len, seq_len]
        # New embedding per head = attention weights applied to V.
        Out = torch.matmul(Score,V)  # [batch_size*head_num, seq_len, vec_size]
        Out = Out.view(batch_size, head_num, seq_len, vec_size)  # [batch_size, head_num, seq_len, vec_size]
        # Concatenate the heads along the feature axis and project back.
        Out = Out.transpose(1,2).contiguous().view(batch_size, seq_len, head_num*vec_size)  # [batch_size, seq_len, head_num*vec_size]
        Out = self.OutW(Out)
        return Out

In [26]:
# Build a 3-head attention layer and apply it to the demo batch; as the last
# expression of the cell, the returned tensor is shown as the cell output.
model = MultiHeadAttention(vec_size, 3)
model(sents)

sents.shape:  torch.Size([2, 4, 10])


tensor([[[ 0.1886, -0.3139,  0.0119, -0.4636, -0.1918,  0.0862,  0.0148,
          -0.3379, -0.0632,  0.2887],
         [ 0.1890, -0.3137,  0.0117, -0.4632, -0.1927,  0.0860,  0.0143,
          -0.3378, -0.0633,  0.2893],
         [ 0.1888, -0.3136,  0.0120, -0.4637, -0.1925,  0.0863,  0.0144,
          -0.3378, -0.0633,  0.2892],
         [ 0.1890, -0.3136,  0.0120, -0.4639, -0.1925,  0.0863,  0.0143,
          -0.3378, -0.0628,  0.2891]],

        [[ 0.1472, -0.3800, -0.0130, -0.4557, -0.2222,  0.0819, -0.0080,
          -0.3445, -0.0653,  0.2745],
         [ 0.1473, -0.3795, -0.0122, -0.4558, -0.2221,  0.0819, -0.0080,
          -0.3437, -0.0651,  0.2745],
         [ 0.1473, -0.3800, -0.0133, -0.4558, -0.2224,  0.0822, -0.0079,
          -0.3446, -0.0656,  0.2748],
         [ 0.1472, -0.3798, -0.0129, -0.4557, -0.2221,  0.0818, -0.0079,
          -0.3443, -0.0651,  0.2745]]], grad_fn=<ViewBackward0>)