In [1]:
from imports import *
from attention import BaseMultiHeadedAttention

In [6]:
class BasePreTrainedGroupedQueryAttention(BaseMultiHeadedAttention):
    """
    Converts a *trained* multi-headed attention layer to Grouped Query Attention (GQA),
    as done in https://arxiv.org/pdf/2305.13245v3.pdf.

    Queries come from the base class unchanged. The key and value heads produced by the
    base class are mean-pooled within each group of ``num_heads // num_groups`` consecutive
    heads. Each pooled head is then repeated back into its group's positions, so downstream
    attention still sees ``num_heads`` key/value heads.

    Args:
        embed_dim: model embedding size, forwarded to ``BaseMultiHeadedAttention``.
        num_heads: number of query heads, forwarded to ``BaseMultiHeadedAttention``.
        num_groups: number of key/value groups; must be positive and divide ``num_heads``.
            ``num_groups == 1`` makes every head share one pooled key/value head
            (multi-query); ``num_groups == num_heads`` leaves keys/values unchanged.

    NOTE(review): the paper mean-pools the key/value *projection weights*. Pooling the
    projected activations (as done here) is equivalent only if the base class builds
    key/value with an affine map and nothing nonlinear in between — confirm.
    """
    def __init__(self, embed_dim, num_heads, num_groups):
        # Validate before building the base layer so a bad config fails fast, and so that
        # num_groups <= 0 is rejected explicitly instead of hitting a modulo-by-zero.
        assert num_groups > 0, "num_groups must be positive"
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        super().__init__(embed_dim, num_heads)
        self.num_groups = num_groups

    def _group(self, x):
        """
        Mean-pool heads within each group, then repeat so the head count is unchanged.

        Shapes: B,H,S,E -> B,G,H//G,S,E --mean over dim 2--> B,G,S,E
                        --repeat_interleave(H//G, dim=1)--> B,H,S,E

        ``reshape`` splits the head axis in logical (row-major) order, so group g holds the
        consecutive heads g*(H//G) .. (g+1)*(H//G)-1. ``repeat_interleave`` writes each pooled
        head back to exactly those positions. This works for non-contiguous inputs too,
        because ``reshape`` copies when a view is impossible.
        """
        B, H, S, E = x.shape
        heads_per_group = H // self.num_groups
        pooled = x.reshape(B, self.num_groups, heads_per_group, S, E).mean(dim=2)
        return pooled.repeat_interleave(heads_per_group, dim=1)

    def construct_query_key_value(self, x):
        """
        Build query/key/value with the base class, then group keys and values.

        Returns:
            (query, key, value). query is exactly what the base class returned; key and
            value keep their 4-D B,H,S,E shape, but all heads within a group carry the same
            mean-pooled values.
        """
        query, key, value = super().construct_query_key_value(x)
        return query, self._group(key), self._group(value)

In [8]:
from attention.test import AttentionTestCase

# Convert-from-pretrained GQA: 8 query heads pooled into 2 key/value groups of 4 heads each.
testing = AttentionTestCase(
    BasePreTrainedGroupedQueryAttention,
    B=2, S=3, E=32,
    model_kwargs={'embed_dim': 32, 'num_groups': 2, 'num_heads': 8},
).run()

Forward test passed!
Attention mask test passed!
KV cache test passed!


In [9]:
# Query: B,H,S,E   (H query heads)
# Key/Value: B,G,S,E   (G shared key/value heads), repeated along the head axis to B,H,S,E

In [12]:
class BaseGroupedQueryAttention(BaseMultiHeadedAttention):
    """
    Grouped Query Attention (GQA) for training from scratch.

    This does not mean-pool the heads of a pretrained layer. Instead, keys and values are
    projected directly to ``num_groups`` heads (the approach used by Mistral). Each key/value
    head is then shared by ``num_heads // num_groups`` consecutive query heads.

    Args:
        embed_dim: model embedding size, forwarded to ``BaseMultiHeadedAttention``.
        num_heads: number of query heads.
        num_groups: number of key/value heads; must be positive and divide ``num_heads``.
            ``num_groups == 1`` is multi-query attention; ``num_groups == num_heads``
            gives one key/value head per query head.
    """
    def __init__(self, embed_dim, num_heads, num_groups):
        # Validate first; otherwise an invalid num_groups would already have been used to
        # size the key/value projections, and num_groups == 0 would fail on modulo-by-zero.
        assert num_groups > 0, "num_groups must be positive"
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        # Set before super().__init__: init_qkvo_proj reads self.num_groups (presumably it
        # is invoked from the base constructor — confirm).
        self.num_groups = num_groups
        super().__init__(embed_dim, num_heads)

    def init_qkvo_proj(self):
        """Create full-width query/output projections and key/value projections sized for num_groups heads."""
        head_dim = self.embed_dim // self.num_heads
        kv_dim = self.num_groups * head_dim
        # Creation order (query, key, value, output) is kept so parameter init and
        # state_dict ordering match earlier checkpoints.
        self.query_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.key_proj = nn.Linear(self.embed_dim, kv_dim)
        self.value_proj = nn.Linear(self.embed_dim, kv_dim)
        self.output_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def construct_query_key_value(self, x):
        """
        Project x to queries (num_heads heads) and keys/values (num_groups heads).
        Then repeat each key/value head so every query head has a matching one.

        Assumes _split_head(t, n) returns (B, n, S, head_dim), i.e. heads on dim 1 — TODO confirm.

        NOTE(review): keys/values are expanded to num_heads heads here, so any KV cache that
        stores these outputs holds num_heads heads and loses GQA's memory saving.
        """
        query = self._split_head(self.query_proj(x), self.num_heads)
        key = self._split_head(self.key_proj(x), self.num_groups)
        value = self._split_head(self.value_proj(x), self.num_groups)

        # Share key/value head g with query heads g*k .. g*k + k - 1, where k = heads_per_group.
        heads_per_group = self.num_heads // self.num_groups
        key = key.repeat_interleave(heads_per_group, dim=1)
        value = value.repeat_interleave(heads_per_group, dim=1)
        return query, key, value

In [18]:
from attention.test import AttentionTestCase

# num_groups=1: all 8 query heads share a single key/value head (multi-query attention).
# num_groups=4: each pair of query heads shares a key/value head (grouped-query attention).
for case_index, num_groups in enumerate((1, 4)):
    if case_index > 0:
        print()
    testing = AttentionTestCase(
        BaseGroupedQueryAttention,
        B=2, S=3, E=32,
        model_kwargs={'embed_dim': 32, 'num_groups': num_groups, 'num_heads': 8},
    ).run()

Forward test passed!
Attention mask test passed!
KV cache test passed!

Forward test passed!
Attention mask test passed!
KV cache test passed!
