In [9]:
import numpy as np
from numpy.random import randn

n = 32 # 32个序列
d = 256 # 256维度
x = randn(d, n)
x.shape


(256, 32)

In [11]:
# Q, K, V参数矩阵, (256, 256)
wq, wk, wv = randn(d, d), randn(d, d), randn(d, d)
wq.shape, wk.shape, wv.shape

((256, 256), (256, 256), (256, 256))

In [13]:
# q, k, v (256, 32)
q, k, v = wq@x, wk@x, wv@x
q.shape, k.shape, v.shape

((256, 32), (256, 32), (256, 32))

In [16]:
# 点乘法相似度计算
A = k.T @ q
A.shape,v.shape, A

((32, 32),
 (256, 32),
 array([[ -732.66891071, -8472.4176827 ,  3711.77805307, ...,
         10197.57032438,  5222.92549552,  6951.3394169 ],
        [-1308.87817759,  7367.19530841,  1674.67180354, ...,
         -3569.54412008,  3847.23156104,  2188.1039832 ],
        [ 3839.50251184,  -363.10734661, -3172.05425493, ...,
          8739.95859637,  2033.54640443,  6338.81982138],
        ...,
        [ -242.12397722, -3290.2433743 ,  3510.43187706, ...,
         -1255.62576572, -2406.5724773 ,  2572.68847053],
        [ 4978.2472669 ,  1629.91830642, -4672.44925723, ...,
          1690.32594984,   595.42400959, -6710.26318199],
        [-5299.44241512,  2422.6729047 , -8947.90701009, ...,
          2428.86382326, -2648.97409254,  1740.87328814]]))

计算Q与K之间的点乘，然后为了防止其结果过大，会除以一个尺度标度 
$\sqrt{d_{k}}$, $d_{k}$为一个query和key向量的维度

In [17]:
A /= np.sqrt(d)
A

array([[ -45.79180692, -529.52610517,  231.98612832, ...,  637.34814527,
         326.43284347,  434.45871356],
       [ -81.8048861 ,  460.44970678,  104.66698772, ..., -223.09650751,
         240.45197256,  136.75649895],
       [ 239.96890699,  -22.69420916, -198.25339093, ...,  546.24741227,
         127.09665028,  396.17623884],
       ...,
       [ -15.13274858, -205.64021089,  219.40199232, ...,  -78.47661036,
        -150.41077983,  160.79302941],
       [ 311.14045418,  101.86989415, -292.02807858, ...,  105.64537187,
          37.2140006 , -419.39144887],
       [-331.21515095,  151.41705654, -559.24418813, ...,  151.80398895,
        -165.56088078,  108.80458051]])

In [18]:
def softmax(x):
    e_x = np.exp(x - np.max(x))  # 防溢出
    return e_x / e_x.sum(axis=0)


In [21]:
A_hat = softmax(A)
A_hat

array([[0.00000000e+000, 0.00000000e+000, 2.57991337e-171, ...,
        1.00000000e+000, 6.50832245e-067, 1.00000000e+000],
       [0.00000000e+000, 1.00000000e+000, 1.31100673e-226, ...,
        0.00000000e+000, 2.96791377e-104, 5.12354962e-130],
       [1.82280441e-126, 0.00000000e+000, 0.00000000e+000, ...,
        2.72555130e-040, 1.74928396e-153, 2.36664150e-017],
       ...,
       [0.00000000e+000, 0.00000000e+000, 8.83861666e-177, ...,
        0.00000000e+000, 0.00000000e+000, 1.40767853e-119],
       [1.47962031e-095, 1.86537458e-156, 0.00000000e+000, ...,
        1.21456146e-231, 0.00000000e+000, 0.00000000e+000],
       [0.00000000e+000, 6.14929427e-135, 0.00000000e+000, ...,
        1.35161247e-211, 0.00000000e+000, 3.71712365e-142]])

In [23]:
output = v @ A_hat
output, output.shape

(array([[ 27.58908341,  18.75863463, -25.88234845, ...,   2.6919639 ,
         -44.38314028,   2.6919639 ],
        [ -7.98291035,  21.47241536,  26.31114057, ..., -11.05100869,
          11.36196103, -11.05100869],
        [ 19.50827166,  15.27661364,   5.86700709, ...,   0.40453767,
           8.61082295,   0.40453767],
        ...,
        [ 12.94247054,   2.67552343,  36.90241521, ...,  -4.05490688,
         -15.62833438,  -4.05490688],
        [ 17.37422847, -27.89108878,  -6.05616558, ...,  -8.31898505,
          18.16583656,  -8.31898505],
        [ -5.33219367,  22.46682346, -14.92224288, ...,  -8.46626958,
          38.38477197,  -8.46626958]]),
 (256, 32))

In [30]:
from math import sqrt
import torch
import torch.nn as nn

In [86]:
class Self_Attention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self,input_dim,dim_k,dim_v):
        super(Self_Attention,self).__init__()
        self.q = nn.Linear(input_dim,dim_k)
        self.k = nn.Linear(input_dim,dim_k)
        self.v = nn.Linear(input_dim,dim_v)
        self._norm_fact = 1 / sqrt(dim_k)
        
    
    def forward(self,x):
        print(self.q.weight.shape)
        print(self.k.weight.shape)
        print(self.v.weight.shape)
        Q = self.q(x) # Q: batch_size * seq_len * dim_k
        K = self.k(x) # K: batch_size * seq_len * dim_k
        V = self.v(x) # V: batch_size * seq_len * dim_v
        print(Q.shape, K.shape , V.shape)
         
        # K.permute(0,2,1)# 将K的维度索引1和维度索引2交换位置, 也就是转置
        # torch.bmm# 两个tensor的矩阵乘法        
        atten = nn.Softmax(dim=-1)(torch.bmm(Q,K.permute(0,2,1))) * self._norm_fact # Q * K.T() # batch_size * seq_len * seq_len
        
        output = torch.bmm(atten,V) # Q * K.T() * V # batch_size * seq_len * dim_v
        
        return output


注意：此处的randn 第一个为样本的个数，第二个为序列的个数，第三个为单个序列的维度也就是input_dim

In [87]:
x = torch.randn(1, 32, 64)
print(x.shape)

torch.Size([1, 32, 64])


In [90]:
self_attention = Self_Attention(input_dim=64,dim_k=d,dim_v=d * 2)
o = self_attention(x)
o.shape

torch.Size([256, 64])
torch.Size([256, 64])
torch.Size([512, 64])
torch.Size([1, 32, 256]) torch.Size([1, 32, 256]) torch.Size([1, 32, 512])


torch.Size([1, 32, 512])