In [1]:
# MQA

import torch
import torch.nn as nn
import math


In [32]:
class MQA(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        
        assert d_model%num_heads==0, "d_model must divisible by num_heads"
        self.head_dim = d_model//num_heads
        
        self.qeury = nn.Linear(d_model, self.d_model)
        self.key = nn.Linear(d_model, self.head_dim) 
        self.val = nn.Linear(d_model, self.head_dim)
        
        self.linear = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        query = self.qeury(x) 
        key = self.key(x).unsqueeze(1) # batch_size, 1,seq_len, head_dim
        val = self.val(x).unsqueeze(1) #batch_size, 1, seq_len, head_dim
        
        Q = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1,2)  # [batch_size, num_heads, seq_length, head_dim]
        print(Q.size())
        print(key.transpose(-2,-1).size())
        # attention
        scores = torch.matmul(Q, key.transpose(-2,-1))/ math.sqrt(d_model)
        print(scores.size())
        
        scores = torch.softmax(scores,dim=-1) # batch_size, seq_len, seq_len
        attn_output = torch.matmul(scores, val) #batch_size, num_heads, seq_len, head_dim
        
        attn_output = attn_output.transpose(1,2).contiguous().view(batch_size, seq_len, self.num_heads*self.head_dim)
        
        attn_output = self.linear(attn_output)
        
        return attn_output, scores

In [33]:
batch_size, seq_len, d_model = 16, 10, 768

mqa = MQA(d_model, 12)

x = torch.randn(batch_size, seq_len, d_model)

output, _ = mqa(x)

print(f"output is {output.size()}")

torch.Size([16, 12, 10, 64])
torch.Size([16, 1, 64, 10])
torch.Size([16, 12, 10, 10])
output is torch.Size([16, 10, 768])


In [47]:
class GQA(nn.Module):
    def __init__(self, d_model, head_dim, num_q_heads, num_kv_groups=None):
        super().__init__()
        self.d_model = d_model
        self.head_dim = head_dim
        self.num_kv_groups = num_kv_groups
        self.num_q_heads = num_q_heads
        
        assert num_q_heads%num_kv_groups==0, "num_q_heads must be divisible by num_kv_groups"
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, num_kv_groups*head_dim)
        self.val = nn.Linear(d_model, num_kv_groups*head_dim)
        
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        Q = self.query(x) # batch_szie, seq_len, d_model
        K = self.key(x)   # batch_size, seq_len, num_kv_groups*head_dim
        V = self.val(x)   # batch_szie, seq_len, num_kv_groups*head_dim

        # Q*KT
        # Q ï¼šbatch_size, num_q_heads, seq_len, head_dim
        # K: batch_size, num_kv_groups, seq_len, head_dim
        Q = Q.view(batch_size, seq_len, self.num_q_heads, head_dim).transpose(1,2)
        K = K.view(batch_size, seq_len, self.num_kv_groups, self.head_dim).transpose(1,2)
        K = torch.repeat_interleave(K, self.num_q_heads//self.num_kv_groups, 1) # batch_size, num_q_heads, seq_len, head_dim
        
        scores = torch.matmul(Q, K.transpose(-2,-1))/math.sqrt(head_dim)
        # batch_size, num_q_heads, seq_len, seq_len
        # V batch_szie, seq_len, num_kv_groups*head_dim
        
        V = V.view(batch_size, seq_len, num_kv_groups, head_dim).transpose(1,2)
        V = torch.repeat_interleave(V, self.num_q_heads//self.num_kv_groups, 1) #batch_size, num_q_heads, seq_len, head_dim
        scores = torch.softmax(scores, dim = -1)
        
        attn_out = torch.matmul(scores, V)
        attn_out = attn_out.transpose(1,2).contiguous().view(batch_size, seq_len, num_q_heads*head_dim)
        
        output = self.out_proj(attn_out)
        
        return output, scores

In [48]:
batch_size, num_kv_groups, seq_len, head_dim = 1,2,5,8
num_q_heads = 4
# batch_size, num_kv_groups, seq_len, head_dim ->
#batch_size, num_q_heads, seq_len, head_dim
x = torch.randn(batch_size, num_kv_groups, seq_len, head_dim)
x[0]

tensor([[[ 1.7242, -0.8726,  1.7843, -0.0254, -0.1278,  0.4328,  0.7598,
          -0.7234],
         [-0.6868, -0.5239,  0.2407,  0.0415, -0.2173, -0.5755, -2.3127,
          -1.5684],
         [-1.2079,  0.7664,  0.7790, -0.9413,  0.3961, -1.3437, -0.6679,
          -0.3673],
         [-1.4134, -1.2172,  0.4095,  0.1899,  0.6765,  0.1250,  1.7811,
           1.0072],
         [ 0.6723,  2.2872, -0.6784, -0.1038, -0.7065, -0.1860,  0.9742,
           0.5806]],

        [[-0.0546, -0.0164, -0.1714,  0.3457,  1.0806,  0.8956,  0.6793,
           1.0495],
         [-0.1510,  0.5340,  0.1760,  1.1213, -1.0191,  0.9075, -0.7738,
          -0.0395],
         [-1.0886, -0.4489, -0.8932, -1.8782,  0.9766, -0.4804, -0.3618,
           0.6617],
         [ 0.7188, -0.2659,  0.0598, -0.6073,  1.0099, -2.1783,  2.0350,
          -1.0818],
         [-0.5871, -1.5166, -0.5091,  0.4130,  0.6491, -0.0533,  2.3273,
          -1.1284]]])

In [49]:
torch.repeat_interleave(x, 2, 1).size()

torch.Size([1, 4, 5, 8])

In [50]:
batch_size, seq_len, d_model = 16, 10, 768
gqa = GQA(d_model, 64 ,12, 4)

x = torch.randn(batch_size, seq_len, d_model)

output, _ = mqa(x)

print(f"output is {output.size()}")

torch.Size([16, 12, 10, 64])
torch.Size([16, 1, 64, 10])
torch.Size([16, 12, 10, 10])
output is torch.Size([16, 10, 768])


In [45]:
768//12

64