In [1]:

import torch
from torch import nn
import torch.nn.functional as f
import math

In [5]:
x = torch.rand(128, 32, 512)
d_model = 512
n_head = 8

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_model = d_model

        # TODO self.w_q = nn.Linear(d_model, d_model)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, q, k, v, mask=None):
        batch, time, dimension = q.shape
        n_d = self.d_model//self.n_head
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        q = q.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3)
        k = k.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3)

        score = q @ k.transpose(2,3)/math.sqrt(n_d)

        # 是否使用掩码
        if mask is not None:
            score = score.masked_fill(mask==0, -10000)
        score = self.softmax(score)@v
        # 将加权值重新排列
        score = score.permute(0, 2, 1, 3).contiguous().view(batch, time, dimension)
        output = self.w_combine(score)
        return output

In [14]:
# 实际应用
attention = MultiHeadAttention(d_model, n_head)

In [15]:
out = attention(x, x, x)
print(out)

tensor([[[-5.7451e-01, -5.1952e-02, -4.4980e-02,  ..., -1.2093e+00,
           5.3894e-01, -1.5803e+00],
         [-5.1599e-01, -1.0537e-01, -1.0344e-01,  ..., -1.2566e+00,
           4.8568e-01, -1.5943e+00],
         [-6.2014e-01, -5.2003e-02, -8.1265e-02,  ..., -1.2228e+00,
           5.4199e-01, -1.5605e+00],
         ...,
         [-6.3748e-01, -8.8462e-02, -1.0768e-01,  ..., -1.2628e+00,
           6.0105e-01, -1.5460e+00],
         [-5.8561e-01, -4.9528e-02, -3.4395e-02,  ..., -1.1627e+00,
           6.0745e-01, -1.5965e+00],
         [-5.5127e-01, -6.0567e-03,  4.3767e-02,  ..., -1.1436e+00,
           5.2303e-01, -1.5795e+00]],

        [[-5.2106e-01, -3.0776e-03, -4.0464e-01,  ..., -1.2116e+00,
           4.3165e-01, -1.6808e+00],
         [-5.6568e-01,  3.1164e-02, -3.6846e-01,  ..., -1.1646e+00,
           4.6545e-01, -1.5898e+00],
         [-5.5567e-01,  1.8965e-02, -3.3534e-01,  ..., -1.1327e+00,
           5.3946e-01, -1.6865e+00],
         ...,
         [-5.9939e-01,  3