In [6]:
import torch
from sympy.codegen.fnodes import dimension
from torch import nn

In [22]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Each of the query, key and value inputs is projected by its own bias-free
    linear layer. The projection is split into ``n_head`` heads of size
    ``d_model // n_head``, each head attends independently, and the heads are
    merged back and passed through a final bias-free linear projection
    (``self.concat``).

    Args:
        d_model: Size of the last (feature) dimension of every input and of the
            output. Must be divisible by ``n_head``.
        n_head: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        # Fail fast here instead of with an opaque `view` error inside forward().
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # Output projection applied after the heads are concatenated back together.
        self.concat = nn.Linear(d_model, d_model, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x):
        """Reshape (batch, seq, d_model) -> (batch, n_head, seq, d_model // n_head)."""
        batch, seq, _ = x.size()
        return x.view(batch, seq, self.n_head, self.d_model // self.n_head).transpose(1, 2)

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Attend from the query encodings over the key/value encodings.

        Args:
            encodings_for_q: (batch, q_len, d_model) query-side inputs.
            encodings_for_k: (batch, k_len, d_model) key-side inputs. ``k_len``
                may differ from ``q_len`` (cross-attention).
            encodings_for_v: (batch, k_len, d_model) value-side inputs; must have
                the same length as ``encodings_for_k``.
            mask: Optional tensor broadcastable to (batch, n_head, q_len, k_len).
                Positions where ``mask == 0`` receive (effectively) zero
                attention weight.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch, q_len, _ = encodings_for_q.size()
        n_d = self.d_model // self.n_head  # per-head feature size

        # Each input is split using its own sequence length, so key/value
        # sequences may be longer or shorter than the query sequence.
        Q = self._split_heads(self.W_q(encodings_for_q))  # (batch, n_head, q_len, n_d)
        K = self._split_heads(self.W_k(encodings_for_k))  # (batch, n_head, k_len, n_d)
        V = self._split_heads(self.W_v(encodings_for_v))  # (batch, n_head, k_len, n_d)

        scaled_sims = Q @ K.transpose(-2, -1) / (n_d ** 0.5)  # (batch, n_head, q_len, k_len)
        if mask is not None:
            # finfo(...).min is representable in every float dtype; a literal like
            # -1e9 overflows float16 and makes masked_fill raise.
            scaled_sims = scaled_sims.masked_fill(
                mask == 0, torch.finfo(scaled_sims.dtype).min
            )
        attention_percent = self.softmax(scaled_sims)
        attention_scores = attention_percent @ V  # (batch, n_head, q_len, n_d)
        # Merge heads: (batch, n_head, q_len, n_d) -> (batch, q_len, d_model).
        attention_scores = (
            attention_scores.transpose(1, 2).contiguous().view(batch, q_len, self.d_model)
        )
        return self.concat(attention_scores)


In [27]:
# Smoke test: self-attention on random input must preserve the input shape.
torch.manual_seed(42)
x = torch.rand(128, 32, 512)  # (batch_size, seq_len, d_model)
multiHeadAttention = MultiHeadAttention(d_model=512, n_head=8)
# The original cell called the misspelled name `mutiHeadAttention`, which
# raises NameError on a fresh kernel.
output = multiHeadAttention(x, x, x)
assert output.shape == x.shape
print("多头注意力的输出大小：", output.shape)

多头注意力的输出大小： torch.Size([128, 32, 512])
