## Multi Head Attention
Multi-head attention is the attention mechanism introduced in the seminal paper [Vaswani et al. (2017), "Attention Is All You Need"](https://arxiv.org/abs/1706.03762). It splits the hidden dimension into several heads, each of which computes its own queries, keys and values and can therefore learn a different attention pattern. It is used in transformer models such as BERT and GPT.
![Multi Head Attention](./assets/mha.jpg)
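
For reference, the computation can be written as follows for self-attention on an input $X$. The symbols are those of the paper: $d_k$ is the per-head dimension (`head_dim` in the code), $h$ is the number of heads, and $W_i^Q$, $W_i^K$, $W_i^V$, $W^O$ are the learned projection matrices.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

\mathrm{head}_i = \mathrm{Attention}\!\left(X W_i^Q,\; X W_i^K,\; X W_i^V\right), \qquad i = 1, \dots, h

\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
```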



In [1]:
import torch
import torch.nn as nn

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_heads,
    ):
        super().__init__()
        # Each head gets an equal slice of the hidden dimension.
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.keys = nn.Linear(hidden_dim, hidden_dim)
        self.values = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        b, s, d = x.shape
        query = self.query(x)
        keys = self.keys(x)
        values = self.values(x)

        # Split the hidden dimension into heads: (b, s, num_heads, head_dim)
        query = query.view(b, s, self.num_heads, self.head_dim)
        keys = keys.view(b, s, self.num_heads, self.head_dim)
        values = values.view(b, s, self.num_heads, self.head_dim)

        # Move the head axis forward: (b, num_heads, s, head_dim)
        query = query.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)

        # Scaled dot-product attention weights: (b, num_heads, s, s)
        atten_score = torch.softmax((query@keys.transpose(2,3))/(self.head_dim**0.5), dim=-1)

        # Weighted sum of values, then merge the heads back: (b, s, hidden_dim)
        out = atten_score@values
        out = out.transpose(1,2)
        out = out.contiguous().view(b,s,d)
        return self.output(out)
        

In [8]:
layer = MultiHeadAttention(256, 8)
x = torch.randn((128, 64, 256))
layer(x).shape

torch.Size([128, 64, 256])
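
As a sanity check on the core computation, the sketch below compares a hand-written scaled dot-product attention with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`. It assumes PyTorch 2.0 or newer; the helper name `manual_attention` and the tensor sizes are illustrative and not part of the notebook.

```python
import torch
import torch.nn.functional as F

def manual_attention(q, k, v):
    """Scaled dot-product attention on (batch, heads, seq, head_dim) tensors."""
    scale = q.shape[-1] ** 0.5
    weights = torch.softmax(q @ k.transpose(-2, -1) / scale, dim=-1)
    return weights @ v

torch.manual_seed(0)
# Illustrative sizes: batch 2, 8 heads, 16 positions, head_dim 32.
q, k, v = (torch.randn(2, 8, 16, 32) for _ in range(3))

manual = manual_attention(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)  # assumes PyTorch >= 2.0

# Largest element-wise difference; expected to be at floating-point noise level.
max_diff = (manual - fused).abs().max().item()
print(max_diff)
```

The two results should agree up to floating-point error, since the fused function applies the same `1 / sqrt(head_dim)` scaling by default.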

## Group Query Attention
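Grouped-query attention (GQA) was introduced in the paper [Ainslie et al. (2023), "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"](https://arxiv.org/abs/2305.13245). Several query heads share a single key/value head, so it needs smaller key and value projections and a smaller key/value cache than Multi Head Attention, which mainly speeds up autoregressive decoding.
It is used in efficient large models such as Llama 2 70B and Llama 3.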
![Group Query Attention](./assets/GQA.jpg)
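
To make the saving concrete, here is a rough back-of-the-envelope sketch of the key/value cache size during decoding. The function name `kv_cache_bytes` and the configuration (32 layers, 32 query heads, head dimension 128, 4096 tokens, 2 bytes per value) are hypothetical, chosen only for illustration and not taken from any particular model.

```python
def kv_cache_bytes(num_layers, kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Bytes needed to cache keys and values for every layer and position."""
    # Factor 2: one cached tensor for keys and one for values.
    return 2 * num_layers * kv_heads * head_dim * seq_len * batch_size * bytes_per_value

# Hypothetical configuration, for illustration only.
config = dict(num_layers=32, head_dim=128, seq_len=4096)
query_heads = 32

mha_bytes = kv_cache_bytes(kv_heads=query_heads, **config)  # one K/V head per query head
gqa_bytes = kv_cache_bytes(kv_heads=8, **config)            # 4 query heads share each K/V head
mqa_bytes = kv_cache_bytes(kv_heads=1, **config)            # all query heads share one K/V head

for name, size in [("MHA", mha_bytes), ("GQA", gqa_bytes), ("MQA", mqa_bytes)]:
    print(f"{name}: {size / 2**20:.0f} MiB")
```

Under these assumptions the cache shrinks by exactly the factor `query_heads / kv_heads`, while the number of query heads stays the same.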

In [9]:
import torch
import torch.nn as nn

In [52]:
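class GroupQueryAttention(nn.Module):
    def __init__(
        self,
        hidden_dim,
        query_heads,
        kv_heads
    ):
        super().__init__()
        assert hidden_dim % query_heads == 0, "hidden_dim must be divisible by query_heads"
        assert query_heads % kv_heads == 0, "query_heads must be divisible by kv_heads"
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // query_heads
        self.q_heads = query_heads
        self.kv_heads = kv_heads
        # Number of query heads that share one key/value head.
        self.group_size = query_heads // kv_heads
        self.query = nn.Linear(hidden_dim, self.q_heads*self.head_dim)
        self.keys = nn.Linear(hidden_dim, self.kv_heads*self.head_dim)
        self.values = nn.Linear(hidden_dim, self.kv_heads*self.head_dim)
        # Every query head produces its own output, so the concatenation has q_heads*head_dim features.
        self.output = nn.Linear(self.q_heads*self.head_dim, hidden_dim)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        b, s, d = x.shape
        query = self.query(x)
        keys = self.keys(x)
        values = self.values(x)

        # Consecutive query heads form one group: (b, s, kv_heads, group_size, head_dim)
        query = query.view(b, s, self.kv_heads, self.group_size, self.head_dim)
        keys = keys.view(b, s, self.kv_heads, self.head_dim)
        values = values.view(b, s, self.kv_heads, self.head_dim)

        # query: (b, kv_heads, group_size, s, head_dim)
        query = query.permute(0, 2, 3, 1, 4)
        # keys/values: (b, kv_heads, 1, s, head_dim), broadcast over the group axis
        keys = keys.transpose(1, 2).unsqueeze(2)
        values = values.transpose(1, 2).unsqueeze(2)

        # Scaled dot-product attention per query head: (b, kv_heads, group_size, s, s)
        atten_score = torch.softmax((query@keys.transpose(3,4))/(self.head_dim**0.5), dim=-1)
        out = atten_score@values
        # Merge all query heads back: (b, s, q_heads*head_dim)
        out = out.permute(0, 3, 1, 2, 4).contiguous().view(b, s, self.q_heads*self.head_dim)
        return self.output(out)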
        


In [53]:
model = GroupQueryAttention(256, 8, 2)
x = torch.randn((16, 128, 256))
out = model(x)
out.shape

torch.Size([16, 128, 256])
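
A common alternative way to express the key/value sharing is to repeat each key/value head once per query head in its group and then run ordinary per-head attention. The sketch below does this on raw tensors rather than inside a module; the function name `grouped_attention` and the sizes are illustrative, and it assumes that consecutive query heads belong to the same group.

```python
import torch

def grouped_attention(q, k, v):
    """q: (batch, q_heads, seq, head_dim); k, v: (batch, kv_heads, seq, head_dim)."""
    group_size = q.shape[1] // k.shape[1]
    # Repeat each K/V head so that consecutive query heads share the same K/V head.
    k = k.repeat_interleave(group_size, dim=1)
    v = v.repeat_interleave(group_size, dim=1)
    weights = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return weights @ v

torch.manual_seed(0)
q = torch.randn(2, 8, 16, 32)  # 8 query heads
k = torch.randn(2, 2, 16, 32)  # 2 key heads
v = torch.randn(2, 2, 16, 32)  # 2 value heads
out = grouped_attention(q, k, v)
print(out.shape)
```

Repeating the heads materialises extra copies of the keys and values, so it is simpler to read but uses more memory than broadcasting over a group axis.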

## Multi Query Attention
It has the smallest key/value cache and is typically the fastest to decode of the three attention mechanisms, at the cost of some quality. It was introduced in the paper [Shazeer (2019), "Fast Transformer Decoding: One Write-Head is All You Need"](https://arxiv.org/abs/1911.02150).
It is a special case of Group Query Attention in which the number of key heads and value heads is 1.
![Multi Query Attention](./assets/MQA.jpg)
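
As a small illustration of the special-case relationship, the sketch below counts the parameters in the key and value projections for the three settings used in this notebook (`hidden_dim=256`, 8 query heads). The helper `kv_projection_params` is hypothetical and assumes the projections are `nn.Linear` layers with bias, producing `kv_heads * head_dim` features each.

```python
def kv_projection_params(hidden_dim, query_heads, kv_heads):
    """Weights plus biases of the key and value linear layers combined."""
    head_dim = hidden_dim // query_heads
    out_features = kv_heads * head_dim
    per_projection = hidden_dim * out_features + out_features  # weight + bias
    return 2 * per_projection  # one projection for keys, one for values

# kv_heads = 8 is Multi Head Attention, 2 is Group Query Attention, 1 is Multi Query Attention.
for kv_heads in (8, 2, 1):
    print(kv_heads, kv_projection_params(256, 8, kv_heads))
```

The count falls in proportion to the number of key/value heads, while the query projection is the same size in all three cases.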

In [54]:
model = GroupQueryAttention(256, 8, 1)
x = torch.randn(16, 128, 256)
out = model(x)
out.shape

torch.Size([16, 128, 256])