# Attention核心思想

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
import numpy as np

首先我们准备一些数据。

In [2]:
# Toy hyper-parameters for the demo.
vec_size, batch_size, seq_len = 10, 2, 4
# Random toy batch, L2-normalised along the embedding axis so every token
# vector has unit length, which makes the attention effect easier to see.
sents = f.normalize(torch.rand(batch_size, seq_len, vec_size), p=2, dim=-1)
print(sents)

tensor([[[0.2328, 0.1955, 0.4514, 0.3852, 0.4502, 0.1747, 0.1463, 0.0992,
          0.4605, 0.2810],
         [0.3783, 0.4232, 0.1628, 0.2301, 0.2317, 0.3193, 0.4110, 0.1351,
          0.1864, 0.4699],
         [0.4324, 0.4321, 0.1258, 0.3348, 0.3773, 0.2643, 0.0188, 0.4271,
          0.3183, 0.0460],
         [0.0199, 0.4182, 0.1126, 0.6190, 0.1190, 0.0155, 0.3584, 0.3132,
          0.4240, 0.0905]],

        [[0.2952, 0.0060, 0.4770, 0.0710, 0.1280, 0.2444, 0.2332, 0.1486,
          0.7071, 0.1662],
         [0.0251, 0.1355, 0.2461, 0.2914, 0.1460, 0.1670, 0.6920, 0.3603,
          0.3921, 0.1545],
         [0.1063, 0.2384, 0.7215, 0.1637, 0.1687, 0.2931, 0.3835, 0.0703,
          0.1135, 0.3243],
         [0.4006, 0.4957, 0.0089, 0.1899, 0.1000, 0.1672, 0.1317, 0.4992,
          0.0395, 0.5016]]])


## 1. 先用简单的单头 self-attention 来解释 attention 的思想。
attention 的核心思想，简单地说就是：先用 q 和 k 做对比，求出它们的相似度；再根据这个相似度，按比例提取（加权求和）v，组成新的 embedding。

In [3]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into query, key and value vectors by three
    independent linear layers. Every query is scored against every key, the
    scores are scaled by 1/sqrt(d_k) and turned into a probability
    distribution with a softmax, and the output is the probability-weighted
    sum of the values.
    """

    def __init__(self, vec_size):
        """
        Args:
            vec_size: size of each input embedding; Q, K and V keep the
                same size.
        """
        super().__init__()
        # The three projection matrices W_Q, W_K, W_V.
        self.QueryW = nn.Linear(vec_size, vec_size)
        self.KeyW = nn.Linear(vec_size, vec_size)
        self.ValueW = nn.Linear(vec_size, vec_size)
        self.Score2Prob = nn.Softmax(dim=-1)

    def forward(self, sents):
        """Apply self-attention to ``sents``.

        Args:
            sents: tensor of shape [batch_size, seq_len, vec_size]; any extra
                leading dimensions are treated as additional batch axes.

        Returns:
            Tensor with the same shape as ``sents``.
        """
        print('sents.shape: ', sents.shape)
        # Project the input into query / key / value vectors.
        Q = self.QueryW(sents)  # [batch_size, seq_len, vec_size]
        K = self.KeyW(sents)    # [batch_size, seq_len, vec_size]
        V = self.ValueW(sents)  # [batch_size, seq_len, vec_size]
        # Similarity of every query with every key. Dividing by sqrt(d_k)
        # keeps the dot products from growing with the embedding size and
        # pushing the softmax into its saturated region. Transposing the last
        # two axes (rather than axes 1 and 2) also supports extra batch dims.
        Scores = torch.matmul(Q, K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)  # [batch_size, seq_len, seq_len]
        Scores = self.Score2Prob(Scores)  # [batch_size, seq_len, seq_len]
        # New embedding: attention probabilities times V.
        Out = torch.matmul(Scores, V)  # [batch_size, seq_len, vec_size]
        return Out

In [4]:
# Build a single-head attention layer and run it on the toy batch;
# the cell's last expression is echoed by the notebook.
model = SelfAttention(vec_size)
model(sents)

sents.shape:  torch.Size([2, 4, 10])


tensor([[[ 0.0109,  0.1471,  0.1141, -0.3939, -0.1368, -0.3375, -0.3938,
          -0.1718, -0.0318, -0.1074],
         [ 0.0109,  0.1472,  0.1141, -0.3939, -0.1368, -0.3374, -0.3938,
          -0.1718, -0.0318, -0.1073],
         [ 0.0095,  0.1473,  0.1144, -0.3940, -0.1366, -0.3370, -0.3942,
          -0.1720, -0.0334, -0.1068],
         [ 0.0105,  0.1461,  0.1141, -0.3941, -0.1360, -0.3375, -0.3934,
          -0.1725, -0.0324, -0.1070]],

        [[-0.0121,  0.1051,  0.1865, -0.3757, -0.1369, -0.2970, -0.3399,
          -0.2815, -0.1429, -0.2604],
         [-0.0124,  0.1052,  0.1873, -0.3763, -0.1371, -0.2966, -0.3395,
          -0.2797, -0.1408, -0.2595],
         [-0.0176,  0.1046,  0.1850, -0.3746, -0.1357, -0.2978, -0.3419,
          -0.2856, -0.1517, -0.2622],
         [-0.0201,  0.1046,  0.1852, -0.3750, -0.1354, -0.2980, -0.3426,
          -0.2861, -0.1540, -0.2625]]], grad_fn=<UnsafeViewBackward0>)

## 2. 接下来我们尝试自己实现 Multi-head Attention。
和单头相比，多头就是并行的多个单头：每个头各自计算 attention，然后把各头的输出拼接起来，再经过一个线性层映射回原来的维度。

In [5]:
class MyMultiHeadAttention(nn.Module):
    """Hand-written multi-head scaled dot-product self-attention.

    Each of the ``head_num`` heads gets its own ``vec_size``-wide query, key
    and value projection (computed together by one wide linear layer per
    role). The heads attend independently, their outputs are concatenated,
    and ``OutW`` maps the concatenation back to ``vec_size``.
    """

    def __init__(self, vec_size, head_num):
        """
        Args:
            vec_size: size of each input embedding and of each head.
            head_num: number of attention heads.
        """
        super().__init__()
        self.head_num = head_num
        self.vec_size = vec_size
        # One wide projection per role produces all heads at once.
        self.QueryW = nn.Linear(vec_size, vec_size*head_num)
        self.KeyW = nn.Linear(vec_size, vec_size*head_num)
        self.ValueW = nn.Linear(vec_size, vec_size*head_num)
        # Applied to [batch_size*head_num, seq_len, seq_len]: dim=2 is the key axis.
        self.Score2Prob = nn.Softmax(dim=2)
        self.OutW = nn.Linear(vec_size*head_num, vec_size)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape [batch_size, seq_len, vec_size].

        Returns:
            Tensor of shape [batch_size, seq_len, vec_size].
        """
        # Report the shape of the tensor actually passed in, not a global.
        print('sents.shape: ', x.shape)
        batch_size, seq_len, _ = x.shape
        vec_size = self.vec_size
        head_num = self.head_num
        # Compute q, k, v for all heads, then split the last axis into heads.
        Q = self.QueryW(x).view(batch_size, seq_len, head_num, vec_size)
        K = self.KeyW(x).view(batch_size, seq_len, head_num, vec_size)
        V = self.ValueW(x).view(batch_size, seq_len, head_num, vec_size)
        # Move heads next to the batch axis and fold them into it so every
        # head is handled as an independent single-head attention.
        Q = Q.transpose(1,2).contiguous().view(batch_size*head_num, seq_len, vec_size)  # [batch_size*head_num, seq_len, vec_size]
        K = K.transpose(1,2).contiguous().view(batch_size*head_num, seq_len, vec_size)
        V = V.transpose(1,2).contiguous().view(batch_size*head_num, seq_len, vec_size)
        # Scaled q·k similarities, turned into a probability distribution.
        Score = torch.matmul(Q, K.transpose(1,2))
        Score = Score/np.sqrt(vec_size)
        Score = self.Score2Prob(Score)  # [batch_size*head_num, seq_len, seq_len]
        # New embedding per head: probabilities times V.
        Out = torch.matmul(Score,V)  # [batch_size*head_num, seq_len, vec_size]
        Out = Out.view(batch_size, head_num, seq_len, vec_size)  # [batch_size, head_num, seq_len, vec_size]
        # Concatenate the heads and project back to a single embedding.
        Out = Out.transpose(1,2).contiguous().view(batch_size, seq_len, head_num*vec_size)  # [batch_size, seq_len, head_num*vec_size]
        Out = self.OutW(Out)
        return Out

In [6]:
# Build the hand-written 3-head attention layer and run it on the toy batch;
# the cell's last expression is echoed by the notebook.
model = MyMultiHeadAttention(vec_size, 3)
model(sents)

sents.shape:  torch.Size([2, 4, 10])


tensor([[[-0.1450,  0.2247, -0.0268, -0.0157, -0.0866, -0.3739,  0.2823,
          -0.1348,  0.0882,  0.0622],
         [-0.1447,  0.2250, -0.0263, -0.0158, -0.0873, -0.3738,  0.2825,
          -0.1345,  0.0883,  0.0619],
         [-0.1448,  0.2250, -0.0264, -0.0158, -0.0871, -0.3737,  0.2823,
          -0.1346,  0.0884,  0.0620],
         [-0.1447,  0.2248, -0.0263, -0.0156, -0.0872, -0.3739,  0.2824,
          -0.1351,  0.0881,  0.0623]],

        [[-0.1431,  0.2354,  0.0481,  0.0087, -0.1031, -0.3085,  0.2405,
          -0.1824,  0.1316,  0.0426],
         [-0.1426,  0.2348,  0.0483,  0.0087, -0.1037, -0.3087,  0.2404,
          -0.1829,  0.1314,  0.0421],
         [-0.1423,  0.2355,  0.0494,  0.0084, -0.1030, -0.3088,  0.2410,
          -0.1816,  0.1320,  0.0417],
         [-0.1413,  0.2345,  0.0498,  0.0087, -0.1033, -0.3098,  0.2417,
          -0.1814,  0.1326,  0.0405]]], grad_fn=<ViewBackward0>)

## 3. 接下来我们尝试使用 PyTorch 自带的 Multi-head Attention（nn.MultiheadAttention）。

In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built on ``torch.nn.MultiheadAttention``.

    The input is first projected by ``QueryW``/``KeyW``/``ValueW`` to
    ``vec_size*head_num`` features. These are fed through
    ``nn.MultiheadAttention`` (``head_num`` heads, batch-first), and
    ``OutW`` maps the result back to ``vec_size``.

    Note: ``nn.MultiheadAttention`` applies its own Q/K/V input projections,
    so ``QueryW``/``KeyW``/``ValueW`` are an extra linear layer in front of
    those. They are kept to mirror ``MyMultiHeadAttention`` and to preserve
    this module's parameter set.
    """

    def __init__(self, vec_size, head_num):
        """
        Args:
            vec_size: size of each input embedding (and of each head).
            head_num: number of attention heads.
        """
        super().__init__()
        self.head_num = head_num
        self.vec_size = vec_size
        self.QueryW = nn.Linear(vec_size, vec_size*head_num)
        self.KeyW = nn.Linear(vec_size, vec_size*head_num)
        self.ValueW = nn.Linear(vec_size, vec_size*head_num)
        # embed_dim = vec_size*head_num split over head_num heads -> vec_size per head.
        self.Multihead_attn = nn.MultiheadAttention(vec_size*head_num, head_num, batch_first=True)
        self.OutW = nn.Linear(vec_size*head_num, vec_size)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape [batch_size, seq_len, vec_size].

        Returns:
            Tensor of shape [batch_size, seq_len, vec_size].
        """
        # Report the shape of the tensor actually passed in, not a global.
        print('sents.shape: ', x.shape)
        # Pre-projection of q, k, v to the attention's embedding width.
        Q = self.QueryW(x)  # [batch_size, seq_len, vec_size*head_num]
        K = self.KeyW(x)
        V = self.ValueW(x)
        # Multi-head attention; the returned attention weights are not needed.
        Out, _ = self.Multihead_attn(Q, K, V)  # [batch_size, seq_len, vec_size*head_num]
        # Map the concatenated heads back to a single vec_size embedding.
        Out = self.OutW(Out)  # [batch_size, seq_len, vec_size]
        print(Out.shape)
        return Out

In [8]:
# Build the nn.MultiheadAttention-based 3-head layer and run it on the toy
# batch; the cell's last expression is echoed by the notebook.
model = MultiHeadAttention(vec_size, 3)
model(sents)

sents.shape:  torch.Size([2, 4, 10])
torch.Size([2, 4, 10])


tensor([[[ 0.1454, -0.0547,  0.0199, -0.2036,  0.2947,  0.0182,  0.0481,
          -0.0444,  0.0834, -0.2219],
         [ 0.1454, -0.0547,  0.0198, -0.2037,  0.2947,  0.0182,  0.0481,
          -0.0444,  0.0834, -0.2218],
         [ 0.1454, -0.0547,  0.0199, -0.2036,  0.2947,  0.0182,  0.0481,
          -0.0444,  0.0833, -0.2219],
         [ 0.1454, -0.0545,  0.0199, -0.2037,  0.2946,  0.0184,  0.0480,
          -0.0444,  0.0834, -0.2217]],

        [[ 0.1523, -0.0366, -0.0088, -0.1952,  0.2638,  0.0124,  0.0537,
          -0.0112,  0.1072, -0.2007],
         [ 0.1523, -0.0361, -0.0088, -0.1954,  0.2636,  0.0130,  0.0536,
          -0.0112,  0.1068, -0.2004],
         [ 0.1522, -0.0364, -0.0087, -0.1953,  0.2638,  0.0126,  0.0537,
          -0.0112,  0.1071, -0.2005],
         [ 0.1523, -0.0362, -0.0087, -0.1954,  0.2637,  0.0129,  0.0537,
          -0.0113,  0.1068, -0.2005]]], grad_fn=<AddBackward0>)