In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input tensor;
    each of the ``num_heads`` heads attends over the whole sequence using a
    ``hidden_dim // num_heads`` wide slice of the projections.

    Args:
        hidden_dim: Width of the input and output features. Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        AssertionError: If ``hidden_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert self.hidden_dim % self.num_heads == 0, "隐层维度必须能被头数整除."
        self.head_dim = self.hidden_dim // self.num_heads

        # Full-width projections: every head gets its own q/k/v slice.
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)

        # Mixes the concatenated head outputs back into the model width.
        self.fc_out = nn.Linear(hidden_dim, hidden_dim)

    def _split_heads(self, projected):
        """Reshape (batch, seq, hidden) into (batch, heads, seq, head_dim)."""
        n_batch, n_tokens, _ = projected.size()
        per_head = projected.view(n_batch, n_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq, hidden_dim).
            mask: Optional tensor broadcastable to (batch, heads, seq, seq);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        n_batch, n_tokens, _ = x.size()

        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        # i indexes query positions, j indexes key positions.
        scale = math.sqrt(self.head_dim)
        attn_logits = torch.einsum("bnih,bnjh->bnij", q, k) / scale
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(attn_logits, dim=-1)
        context = torch.einsum("bnij,bnjh->bnih", attn, v)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        merged = context.transpose(1, 2).contiguous().view(n_batch, n_tokens, -1)
        return self.fc_out(merged)


def test_mha():
    """Smoke-test MultiHeadAttention on a random batch and print the output shape."""
    hidden_dim = 128
    # (batch, sequence length, hidden width)
    inputs = torch.randn((8, 512, hidden_dim))
    attention = MultiHeadAttention(hidden_dim, 8)
    print(attention(inputs).shape)


# Run the smoke test; the recorded cell output is torch.Size([8, 512, 128]).
test_mha()


torch.Size([8, 512, 128])


In [15]:
# Multi-Query Attention
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention: many query heads, one shared key/value head.

    The query projection keeps the full ``hidden_dim`` width (one slice per
    head), while the key and value projections are only ``head_dim`` wide and
    are shared by every query head.

    Args:
        hidden_dim: Width of the input and output features. Must be
            divisible by ``num_heads``.
        num_heads: Number of query heads.

    Raises:
        AssertionError: If ``hidden_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert self.hidden_dim % self.num_heads == 0, "隐层维度必须能被头数整除"
        self.head_dim = self.hidden_dim // self.num_heads

        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        # A single key/value head, shared across all query heads.
        self.k_proj = nn.Linear(self.hidden_dim, self.head_dim)
        self.v_proj = nn.Linear(self.hidden_dim, self.head_dim)

        self.fc_out = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x, mask=None):
        """Apply multi-query self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq, hidden_dim).
            mask: Optional tensor broadcastable to (batch, heads, seq, seq);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch_size, seq_len, _ = x.size()

        # (batch, heads, seq, head_dim)
        query = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (batch, seq, head_dim): no head axis, the einsums below share it
        # across every query head explicitly.
        key = self.k_proj(x)
        value = self.v_proj(x)

        scores = torch.einsum("bnih,bjh->bnij", query, key) / math.sqrt(self.head_dim)
        if mask is not None:
            # Must be the tensor method on `scores`: torch.masked_fill(mask == 0, ...)
            # would treat the boolean mask as the input and fail.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        output = torch.einsum("bnij,bjh->bnih", weights, value)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.fc_out(output)


def test_mqa():
    """Smoke-test MultiQueryAttention on a random batch and print the output shape."""
    hidden_dim = 128
    # (batch, sequence length, hidden width)
    inputs = torch.randn((8, 512, hidden_dim))
    attention = MultiQueryAttention(hidden_dim, 16)
    print(attention(inputs).shape)


# Run the smoke test; the recorded cell output is torch.Size([8, 512, 128]).
test_mqa()

torch.Size([8, 512, 128])


In [8]:
# Grouped-Query Attention
class GroupQueryAttention(nn.Module):
    """Grouped-query self-attention.

    The ``num_heads`` query heads are split into ``num_groups`` contiguous
    groups of ``head_per_group`` heads. Each group shares one key head and
    one value head, so the key/value projections are
    ``num_groups * head_dim`` wide.

    Note: the key/value projection width is ``num_groups * head_dim``
    (one K/V head per group), so parameter shapes differ from versions of
    this class that sized them by ``head_per_group``.

    Args:
        hidden_dim: Width of the input and output features. Must be
            divisible by ``num_heads``.
        num_heads: Number of query heads. Must be divisible by ``num_groups``.
        num_groups: Number of key/value heads (query-head groups).

    Raises:
        AssertionError: If either divisibility requirement is violated.
    """

    def __init__(self, hidden_dim, num_heads, num_groups):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        assert self.hidden_dim % self.num_heads == 0, "隐层维度必须被头数整除"
        self.head_dim = self.hidden_dim // self.num_heads
        self.num_groups = num_groups
        assert self.num_heads % self.num_groups == 0, "头数必须被组数整除"
        self.head_per_group = self.num_heads // self.num_groups

        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
        # One key head and one value head per group.
        self.k_proj = nn.Linear(self.hidden_dim, self.num_groups * self.head_dim)
        self.v_proj = nn.Linear(self.hidden_dim, self.num_groups * self.head_dim)

        self.fc_out = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x, mask=None):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq, hidden_dim).
            mask: Optional tensor; positions where ``mask == 0`` are excluded
                from attention. Either a 4-D per-head mask of shape
                (batch, num_heads, seq, seq), or any tensor broadcastable to
                (batch, head_per_group, seq, seq) such as (seq, seq) or
                (batch, 1, seq, seq).

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch_size, seq_len, _ = x.size()
        # Multi-head query: (batch, num_heads, seq, head_dim)
        query = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Grouped key and value: (batch, num_groups, seq, head_dim)
        key = self.k_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)

        # A mask that carries an explicit per-head axis has to be sliced to
        # the heads of the group being processed.
        per_head_mask = (
            mask is not None
            and self.num_groups > 1
            and mask.dim() == 4
            and mask.size(1) == self.num_heads
        )

        # Attend group by group so only one group's score matrix is alive at
        # a time.
        scale = math.sqrt(self.head_dim)
        grouped_out = []
        for group_idx in range(self.num_groups):
            heads = slice(group_idx * self.head_per_group, (group_idx + 1) * self.head_per_group)
            q_group = query[:, heads]        # (batch, head_per_group, seq, head_dim)
            k_group = key[:, group_idx]      # (batch, seq, head_dim)
            v_group = value[:, group_idx]    # (batch, seq, head_dim)

            scores = torch.einsum("bnih,bjh->bnij", q_group, k_group) / scale
            if mask is not None:
                group_mask = mask[:, heads] if per_head_mask else mask
                scores = scores.masked_fill(group_mask == 0, float("-inf"))
            weights = torch.softmax(scores, dim=-1)

            grouped_out.append(torch.einsum("bnij,bjh->bnih", weights, v_group))

        # Stack groups back along the head axis, then move the sequence axis
        # in front of the heads before flattening; flattening without the
        # transpose would mix features across token positions.
        output = torch.cat(grouped_out, dim=1)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.fc_out(output)


def test_gqa():
    """Smoke-test GroupQueryAttention on a random batch and print the output shape."""
    hidden_dim = 512
    # (batch, sequence length, hidden width)
    inputs = torch.randn((8, 1024, hidden_dim))
    attention = GroupQueryAttention(hidden_dim, 32, 4)
    result = attention(inputs)
    print(result.shape)


# Run the smoke test; the recorded cell output is torch.Size([8, 1024, 512]).
test_gqa()


torch.Size([8, 1024, 512])
