## 1. Scaled Dot-Product Attention
$$ \text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$  
其中 Q 的形状为 (bs, n, d)，K 为 (bs, m, d)，V 为 (bs, m, v)；公式中的 $d_k$ 即为 d。
要确保 Q 与 K 的最后一维长度相同（均为 d），否则无法计算 $QK^T$。

In [6]:
import torch
from torch import nn
import math


In [7]:
# Scaled dot-product attention; queries and keys must share the same last-dimension size d.
class DotProducAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    Dropout is applied to the attention weights before they are multiplied
    with the values (it is only active in training mode).

    Args:
        dropout: dropout probability for the attention weights.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, keys, values):
        """Attend ``query`` over ``keys`` / ``values``.

        Args:
            query: tensor of shape (..., n, d).
            keys: tensor of shape (..., m, d); last dim must equal query's.
            values: tensor of shape (..., m, v).
            Leading batch dimensions must match or broadcast (torch.matmul rules).

        Returns:
            Tensor of shape (..., n, v).
        """
        d = query.shape[-1]
        # matmul + transposing the last two axes supports any number of leading
        # batch dims (e.g. (bs, heads, n, d)); for 3-D inputs it equals torch.bmm.
        scores = torch.matmul(query, keys.transpose(-2, -1)) / math.sqrt(d)
        attention_weight = nn.functional.softmax(scores, dim=-1)
        return torch.matmul(self.dropout(attention_weight), values)



In [8]:

# Random queries against all-ones keys: every score in a row is identical,
# so the softmax weights are uniform and each output is the mean of `values`.
query = torch.normal(0, 1, (2, 1, 2))  # (bs=2, n=1, d=2)
keys = torch.ones((2, 10, 2))  # (bs=2, m=10, d=2)
# Rows 0..39 laid out as a (10, 4) table, stacked twice along a new batch axis -> (2, 10, 4).
values = torch.arange(40, dtype=torch.float32).reshape(10, 4).repeat(2, 1, 1)
attention = DotProducAttention(dropout=0.1)
attention.eval()  # disable dropout so the printed result is deterministic
print(attention(query, keys, values))

tensor([[[18., 19., 20., 21.]],

        [[18., 19., 20., 21.]]])


## 2. Multi-Head Attention (MHA)

In [9]:
def transpose_qkv(X, num_heads):
    """Split the hidden dimension into heads and fold the heads into the batch axis.

    Args:
        X: tensor of shape (bs, seq_len, num_hiddens); num_hiddens must be
            divisible by ``num_heads``.
        num_heads: number of attention heads.

    Returns:
        Tensor of shape (bs * num_heads, seq_len, num_hiddens / num_heads),
        where row ``b * num_heads + h`` holds head ``h`` of batch item ``b``.
    """
    batch, seq_len = X.shape[0], X.shape[1]
    # (bs, seq_len, num_hiddens) -> (bs, seq_len, heads, head_dim)
    per_head = X.reshape(batch, seq_len, num_heads, -1)
    # -> (bs, heads, seq_len, head_dim), then merge bs and heads into one axis
    return per_head.permute(0, 2, 1, 3).flatten(0, 1)

def transpose_output(X, num_heads):
    """Reverse the head split: unfold heads from the batch axis and concatenate them.

    Args:
        X: tensor of shape (bs * num_heads, seq_len, head_dim).
        num_heads: number of attention heads.

    Returns:
        Tensor of shape (bs, seq_len, num_heads * head_dim).
    """
    seq_len, head_dim = X.shape[1], X.shape[2]
    # (bs * heads, seq_len, head_dim) -> (bs, heads, seq_len, head_dim)
    grouped = X.reshape(-1, num_heads, seq_len, head_dim)
    # -> (bs, seq_len, heads, head_dim), then concatenate the heads' features
    return grouped.transpose(1, 2).flatten(2)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``DotProducAttention``.

    Queries, keys and values are each projected (without bias) to
    ``num_hiddens`` features and split into ``num_heads`` heads. Each head is
    attended independently with scaled dot-product attention. The heads are
    then concatenated and passed through a final bias-free linear projection
    ``W_o``.

    Args:
        key_size: feature size of the incoming keys.
        query_size: feature size of the incoming queries.
        value_size: feature size of the incoming values.
        num_hiddens: projected feature size; must be divisible by ``num_heads``
            because ``transpose_qkv`` splits it evenly across heads.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(
        self, key_size, query_size, value_size, num_hiddens, num_heads, dropout
    ):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProducAttention(dropout)
        # Per-input projections into the shared hidden space.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=False)
        # Output projection applied after the heads are concatenated
        # (the original note points to p.295 of the textbook for this layer).
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=False)

    def forward(self, queries, keys, values):
        """Compute multi-head attention.

        Args:
            queries: (bs, n_q, query_size).
            keys: (bs, n_kv, key_size).
            values: (bs, n_kv, value_size).

        Returns:
            Tensor of shape (bs, n_q, num_hiddens).
        """
        heads = self.num_heads
        # Project, then fold heads into the batch axis:
        # (bs * heads, seq_len, num_hiddens / heads) for each of q, k, v.
        q, k, v = (
            transpose_qkv(project(x), heads)
            for project, x in (
                (self.W_q, queries),
                (self.W_k, keys),
                (self.W_v, values),
            )
        )
        # Per-head attention: (bs * heads, n_q, num_hiddens / heads).
        attended = self.attention(q, k, v)
        # Concatenate heads back to (bs, n_q, num_hiddens) and mix them.
        return self.W_o(transpose_output(attended, heads))


In [10]:

num_hiddens, num_heads = 100, 5
# All inputs already have num_hiddens features, so every input size equals num_hiddens.
attention = MultiHeadAttention(
    key_size=num_hiddens,
    query_size=num_hiddens,
    value_size=num_hiddens,
    num_hiddens=num_hiddens,
    num_heads=num_heads,
    dropout=0.2,
)
attention.eval()  # disable dropout for this shape check

batch_size, num_queries, num_kv = 2, 4, 6
X = torch.ones((batch_size, num_queries, num_hiddens))  # queries
Y = torch.ones((batch_size, num_kv, num_hiddens))  # shared keys and values
result = attention(X, Y, Y)
print(result.shape)  # (batch_size, num_queries, num_hiddens)

torch.Size([2, 4, 100])
