In [10]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention where queries, keys and values share one input.

    The input is projected three times (query / key / value), each projection
    is split into ``num_heads`` heads of width ``hidden_dim // num_heads``,
    scaled dot-product attention is computed per head, and the re-joined heads
    are passed through a final linear layer.

    Args:
        hidden_dim: Width of the last dimension of both the input and the output.
        num_heads: Number of attention heads; must divide ``hidden_dim`` evenly.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert self.hidden_dim % self.num_heads == 0, "隐层维度必须能被头数整除."
        self.head_dim = self.hidden_dim // self.num_heads

        # Input projections; every head's slice lives in one hidden_dim-wide layer.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        # Output projection applied after the heads are concatenated again.
        self.fc_out = nn.Linear(hidden_dim, hidden_dim)

    def _to_heads(self, projected, bsz, length):
        """Reshape (bsz, length, hidden_dim) into (bsz, num_heads, length, head_dim)."""
        per_head = projected.view(bsz, length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_dim).
            mask: Optional tensor used with ``masked_fill`` on the
                (batch, num_heads, seq_len, seq_len) score tensor; entries equal
                to 0 mark positions whose score is replaced by ``-inf`` before
                the softmax.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        bsz, length, _ = x.size()

        q = self._to_heads(self.q_proj(x), bsz, length)
        k = self._to_heads(self.k_proj(x), bsz, length)
        v = self._to_heads(self.v_proj(x), bsz, length)

        # Dot product of every query position with every key position, per head.
        logits = torch.einsum("bhqd,bhkd->bhqk", q, k) / math.sqrt(self.head_dim)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float("-inf"))
        attn = logits.softmax(dim=-1)

        # Attention-weighted sum of the values, still laid out per head.
        context = torch.einsum("bhqk,bhkd->bhqd", attn, v)

        # Move heads next to head_dim and flatten them back into hidden_dim.
        merged = context.transpose(1, 2).contiguous().view(bsz, length, -1)
        return self.fc_out(merged)


def test_mha():
    """Smoke-test MultiHeadAttention on a random batch and print the output shape.

    Builds a random input of shape (batch_size, seq_length, hidden_dim), runs it
    through an unmasked MultiHeadAttention, checks that the output has the same
    shape as the input, and prints that shape.
    """
    batch_size = 8
    seq_length = 512
    hidden_dim = 128
    num_heads = 8
    x = torch.randn((batch_size, seq_length, hidden_dim))
    # Size the module from hidden_dim instead of repeating the literal, so the
    # input and the layer cannot drift apart when the value is edited.
    mha = MultiHeadAttention(hidden_dim, num_heads)
    out = mha(x)
    assert out.shape == x.shape, f"unexpected output shape {tuple(out.shape)}"
    print(out.shape)


# Run the multi-head attention smoke test; it prints the output tensor's shape.
test_mha()


torch.Size([8, 512, 128])


In [15]:
# Multi-Query Attention
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention: per-head queries, one shared key/value head.

    Queries are projected to ``hidden_dim`` and split into ``num_heads`` heads,
    while keys and values are projected to a single head of width
    ``hidden_dim // num_heads`` that every query head attends over.

    Args:
        hidden_dim: Width of the last dimension of both the input and the output.
        num_heads: Number of query heads; must divide ``hidden_dim`` evenly.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert self.hidden_dim % self.num_heads == 0, "隐层维度必须能被头数整除"
        self.head_dim = self.hidden_dim // self.num_heads

        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        # Keys and values are only head_dim wide: one head shared by all queries.
        self.k_proj = nn.Linear(self.hidden_dim, self.head_dim)
        self.v_proj = nn.Linear(self.hidden_dim, self.head_dim)

        self.fc_out = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x, mask=None):
        """Apply multi-query self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_dim).
            mask: Optional tensor used with ``masked_fill`` on the
                (batch, num_heads, seq_len, seq_len) score tensor; entries equal
                to 0 mark positions whose score is replaced by ``-inf`` before
                the softmax.

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim).
        """
        batch_size, seq_len, _ = x.size()

        query = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shape (batch, 1, seq_len, head_dim): the size-1 head axis is broadcast
        # against the num_heads axis of the queries inside einsum.
        key = self.k_proj(x).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)

        scores = torch.einsum("bnih,bnjh->bnij", query, key) / math.sqrt(self.head_dim)
        if mask is not None:
            # masked_fill must be called on the scores tensor; the previous
            # torch.masked_fill(mask == 0, -inf) omitted the input tensor and
            # raised a TypeError whenever a mask was supplied.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        output = torch.einsum("bnij,bnjh->bnih", weights, value)

        # Move heads next to head_dim and flatten them back into hidden_dim.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.fc_out(output)


def test_mqa():
    """Smoke-test MultiQueryAttention on a random batch and print the output shape.

    Builds a random input of shape (batch_size, seq_len, hidden_dim), runs it
    through an unmasked MultiQueryAttention, checks that the output has the same
    shape as the input, and prints that shape.
    """
    batch_size = 8
    seq_len = 512
    hidden_dim = 128
    num_heads = 16
    x = torch.randn((batch_size, seq_len, hidden_dim))
    # Size the module from hidden_dim instead of repeating the literal, so the
    # input and the layer cannot drift apart when the value is edited.
    mqa = MultiQueryAttention(hidden_dim, num_heads)
    out = mqa(x)
    assert out.shape == x.shape, f"unexpected output shape {tuple(out.shape)}"
    print(out.shape)


# Run the multi-query attention smoke test; it prints the output tensor's shape.
test_mqa()

torch.Size([8, 512, 128])
