In [None]:
# MHA: multi-head attention (every query head has its own K/V head)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MHA(nn.Module):
    """Multi-head self-attention.

    Queries, keys and values are all projected from the same input. The
    ``d_model`` features are split into ``num_heads`` heads of width
    ``dk = d_model // num_heads``; each head runs scaled dot-product attention
    on its own slice, the head outputs are concatenated back to ``d_model``
    features, and ``Wo`` mixes them.

    Args:
        d_model: Feature width of the input and output; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.h = num_heads
        self.dk = d_model // num_heads

        # Fused per-head projections: each maps d_model -> num_heads * dk.
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        # Output projection applied to the concatenated heads.
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model); returns the same shape.

        No mask is applied, so every position attends to every position.
        """
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq_len, d_model) -> (batch, h, seq_len, dk)
            return t.reshape(batch, seq_len, self.h, self.dk).permute(0, 2, 1, 3)

        q, k, v = (split_heads(proj(x)) for proj in (self.Wq, self.Wk, self.Wv))

        logits = torch.matmul(q, k.transpose(-1, -2)) / (self.dk ** 0.5)
        weights = logits.softmax(dim=-1)
        context = torch.matmul(weights, v)  # (batch, h, seq_len, dk)

        # Concatenate heads: (batch, h, seq_len, dk) -> (batch, seq_len, h * dk)
        return self.Wo(context.permute(0, 2, 1, 3).flatten(2))


In [None]:
# MQA: multi-query attention (all query heads share a single K/V head)
import torch
import torch.nn as nn
import torch.nn.functional as F

class MQA(nn.Module):
    """Multi-query self-attention.

    There are ``num_heads`` query heads of width ``dk = d_model // num_heads``,
    but only ONE key head and ONE value head (each of width ``dk``), shared by
    every query head. Head outputs are concatenated back to ``d_model``
    features and mixed by ``Wo``.

    Args:
        d_model: Feature width of the input and output; must be divisible by
            ``num_heads``.
        num_heads: Number of query heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.h = num_heads
        self.dk = d_model // num_heads

        self.Wq = nn.Linear(d_model, d_model)   # fused projection for all query heads
        self.Wk = nn.Linear(d_model, self.dk)   # single shared key head
        self.Wv = nn.Linear(d_model, self.dk)   # single shared value head
        self.Wo = nn.Linear(d_model, d_model)   # output projection over concatenated heads

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model); returns the same shape.

        No mask is applied, so every position attends to every position.
        """
        batch, seq_len, _ = x.shape

        # (batch, seq_len, d_model) -> (batch, h, seq_len, dk)
        queries = self.Wq(x).reshape(batch, seq_len, self.h, self.dk).permute(0, 2, 1, 3)

        # (batch, seq_len, dk) -> (batch, 1, seq_len, dk): the size-1 head axis
        # broadcasts the shared K/V against every query head in the matmuls below.
        keys = self.Wk(x)[:, None]
        values = self.Wv(x)[:, None]

        logits = torch.matmul(queries, keys.transpose(-1, -2)) / (self.dk ** 0.5)
        weights = logits.softmax(dim=-1)
        context = torch.matmul(weights, values)  # (batch, h, seq_len, dk)

        # Concatenate heads: (batch, h, seq_len, dk) -> (batch, seq_len, h * dk)
        return self.Wo(context.permute(0, 2, 1, 3).flatten(2))


In [None]:
# GQA: grouped-query attention (query heads are split into groups; each group shares one K/V head)
import torch
import torch.nn as nn
import torch.nn.functional as F

class GQA(nn.Module):
    """Grouped-query self-attention.

    The ``num_heads`` query heads (width ``dk = d_model // num_heads``) are
    split into ``num_kv_heads`` contiguous groups of
    ``num_heads // num_kv_heads`` heads; every query head in a group attends
    with that group's single key/value head. Head outputs are concatenated
    back to ``d_model`` features and mixed by ``Wo``.

    Special cases:
        * ``num_kv_heads == 1`` (the default) is multi-query attention.
        * ``num_kv_heads == num_heads`` is ordinary multi-head attention.

    Args:
        d_model: Feature width of the input and output; must be divisible by
            ``num_heads``.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; must be positive and divide
            ``num_heads``.
    """

    def __init__(self, d_model, num_heads, num_kv_heads=1):
        super().__init__()
        assert d_model % num_heads == 0
        assert num_kv_heads > 0 and num_heads % num_kv_heads == 0

        self.h = num_heads
        self.kv_h = num_kv_heads
        self.dk = d_model // num_heads

        self.Wq = nn.Linear(d_model, d_model)                 # all query heads, fused
        self.Wk = nn.Linear(d_model, num_kv_heads * self.dk)  # num_kv_heads key heads
        self.Wv = nn.Linear(d_model, num_kv_heads * self.dk)  # num_kv_heads value heads

        self.Wo = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, d_model); returns the same shape.

        No mask is applied, so every position attends to every position.
        """
        B, N, D = x.shape
        group = self.h // self.kv_h  # query heads per K/V head

        # (B, h, N, dk) -> (B, kv_h, group, N, dk): query head i lands in K/V group i // group.
        Q = self.Wq(x).reshape(B, N, self.h, self.dk).transpose(1, 2)
        Q = Q.reshape(B, self.kv_h, group, N, self.dk)

        # (B, kv_h, 1, N, dk): the size-1 axis broadcasts each K/V head across its
        # group, so K/V are never physically repeated.
        K = self.Wk(x).reshape(B, N, self.kv_h, self.dk).transpose(1, 2).unsqueeze(2)
        V = self.Wv(x).reshape(B, N, self.kv_h, self.dk).transpose(1, 2).unsqueeze(2)

        scores = (Q @ K.transpose(-2, -1)) / (self.dk ** 0.5)  # (B, kv_h, group, N, N)
        A = F.softmax(scores, dim=-1)

        out = A @ V  # (B, kv_h, group, N, dk)
        # Undo the grouping, then concatenate heads: -> (B, N, h * dk)
        out = out.reshape(B, self.h, N, self.dk).transpose(1, 2).reshape(B, N, D)

        return self.Wo(out)
