In [4]:
import math
import torch
import torch.nn as nn

In [5]:
# omit attention_mask, attention_dropout
class GroupQueryAttention(nn.Module):
    """Grouped-query self-attention without masking or dropout.

    The query projection produces ``head_num`` heads, while the key and
    value projections produce only ``nums_key_value_head`` heads. Each
    key/value head is shared by ``head_num // nums_key_value_head``
    consecutive query heads. With ``nums_key_value_head == head_num`` this
    is ordinary multi-head attention; with ``nums_key_value_head == 1``
    every query head shares a single key/value head.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
            Must be divisible by ``head_num``.
        head_num: Number of query heads.
        nums_key_value_head: Number of key/value heads (i.e. the number
            of query-head groups). Must divide ``head_num``.
    """

    def __init__(self, hidden_dim: int, head_num: int, nums_key_value_head: int):
        super().__init__()
        assert hidden_dim % head_num == 0, (
            "hidden_dim must be divisible by head_num"
        )
        # Query heads are split into nums_key_value_head equally sized groups.
        assert head_num % nums_key_value_head == 0, (
            "head_num must be divisible by nums_key_value_head"
        )

        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num
        self.nums_key_value_head = nums_key_value_head
        # Number of query heads that share one key/value head.
        self.group_size = head_num // nums_key_value_head

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)

        # Key/value projections are narrower: one head_dim slice per group.
        self.k_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, nums_key_value_head * self.head_dim)

        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X, attention_mask=None):
        """Apply grouped-query self-attention to ``X``.

        Args:
            X: Tensor of shape ``(batch, seq, hidden_dim)``.
            attention_mask: Accepted for interface compatibility but
                currently NOT applied; every position attends to every
                other position.

        Returns:
            Tensor of shape ``(batch, seq, hidden_dim)``.
        """
        batch_size, seq, _ = X.size()

        # q, k, v projections
        q = self.q_proj(X)
        k = self.k_proj(X)
        v = self.v_proj(X)

        # Split the feature dimension into heads:
        #   q    -> (batch, seq, head_num, head_dim)
        #   k, v -> (batch, seq, nums_key_value_head, head_dim)
        q = q.view(batch_size, seq, self.head_num, self.head_dim)
        k = k.view(batch_size, seq, self.nums_key_value_head, self.head_dim)
        v = v.view(batch_size, seq, self.nums_key_value_head, self.head_dim)

        # Move the head axis ahead of the sequence axis so that matmul
        # operates per head: (batch, heads, seq, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand k and v along the head axis so each query head has a
        # matching key/value head. repeat_interleave keeps the copies of
        # one kv head adjacent, so query heads [g*group_size, (g+1)*group_size)
        # all use kv head g.
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # (batch, head_num, seq, head_dim) @ (batch, head_num, head_dim, seq)
        # -> (batch, head_num, seq, seq), scaled by sqrt(head_dim).
        attention_score = (q @ k.transpose(2, 3)) / math.sqrt(self.head_dim)

        attention_weight = torch.softmax(attention_score, dim=-1)

        # (batch, head_num, seq, seq) @ (batch, head_num, seq, head_dim)
        # -> (batch, head_num, seq, head_dim)
        output = attention_weight @ v

        # Merge heads back to (batch, seq, hidden_dim). transpose returns a
        # non-contiguous view, and view() requires contiguous memory.
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq, -1)
        final_output = self.o_proj(output)

        return final_output


        

