In [1]:
import torch
import torch.nn as nn

  cpu = _conversion_method_template(device=torch.device("cpu"))


### Grouped Query Attention (GQA)

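A more compute- and parameter-**efficient** variant of multi-head attention. Instead of giving every query head its own key and value projection, GQA uses fewer key/value heads (`num_kv_groups`) and lets each group of query heads share one of them. This shrinks the key/value projections and the KV cache.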


**Shared Buffers**

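- A helper class used alongside RoPE that caches the causal attention mask and the RoPE `sin`/`cos` tables. They are computed once per configuration and then reused by every attention layer built with the same settings, which improves efficiency.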


In [2]:
from rope import precompute_rope_params_llama3 as precompute_rope_params
class SharedBuffers:
    """Process-wide cache for the causal mask and the RoPE cos/sin tables.

    Every caller asking for the same (context_length, head_dim, rope_base,
    freq_config, dtype) receives the *same* tensor objects, so they are built once.
    """

    # Maps a hashable cache key (built by _cache_key) to a (mask, cos, sin) tuple.
    _buffers = {}

    @staticmethod
    def _cache_key(context_length, head_dim, rope_base, freq_config, dtype):
        """Build a hashable cache key for get_buffers.

        freq_config is keyed by its (name, value) pairs rather than its bare values:
        two configs holding the same numbers under different names must not collide,
        key order must not matter, and an empty dict must still be hashable.
        """
        config_key = None if freq_config is None else frozenset(freq_config.items())
        return (context_length, head_dim, rope_base, config_key, dtype)

    @staticmethod
    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):
        """Return the cached (mask, cos, sin) for these settings, building them on first use.

        Args:
            context_length: side length of the square causal mask; also forwarded to
                precompute_rope_params.
            head_dim: per-head feature size, forwarded to precompute_rope_params.
            rope_base: RoPE base, forwarded to precompute_rope_params.
            freq_config: optional dict of RoPE frequency settings, or None; forwarded
                to precompute_rope_params unchanged.
            dtype: if not None, cos and sin are cast to this dtype (the mask is not).

        Returns:
            (mask, cos, sin). mask holds 1s strictly above the diagonal and 0s elsewhere.
        """
        key = SharedBuffers._cache_key(context_length, head_dim, rope_base, freq_config, dtype)

        if key not in SharedBuffers._buffers:
            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)
            if dtype is not None:
                cos = cos.to(dtype)
                sin = sin.to(dtype)
            SharedBuffers._buffers[key] = (mask, cos, sin)

        return SharedBuffers._buffers[key]

In [3]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention: projections plus the shared mask/RoPE buffers.

    num_heads query heads share num_kv_groups key/value heads, so each key/value
    head is used by group_size = num_heads // num_kv_groups query heads.
    num_kv_groups == 1 is multi-query attention; num_kv_groups == num_heads is
    standard multi-head attention. Relative to plain multi-head attention, the new
    arguments are num_kv_groups, rope_base and rope_config.

    Only __init__ is defined here; the forward computation is walked through by
    hand in the notebook cells that follow.
    """

    def __init__(self, d_in, d_out, context_length, num_heads, num_kv_groups,
                 rope_base=10_000, rope_config=None, dtype=None):
        """
        Args:
            d_in: input feature size.
            d_out: output feature size; must be divisible by num_heads.
            context_length: size of the causal mask / RoPE tables fetched from SharedBuffers.
            num_heads: number of query heads.
            num_kv_groups: number of key/value heads; must divide num_heads.
            rope_base: RoPE base, forwarded to SharedBuffers.get_buffers.
            rope_config: optional RoPE frequency config, forwarded to SharedBuffers.get_buffers.
            dtype: dtype of the linear layers, also forwarded to SharedBuffers.get_buffers
                (None keeps PyTorch's default dtype and skips the cos/sin cast).
        """
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        head_dim = d_out // num_heads
        # Width of the key/value projections; narrower than d_out unless num_kv_groups == num_heads.
        kv_width = num_kv_groups * head_dim

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        # Keep this creation order (key, value, query, out): nn.Linear draws its initial
        # weights from the global RNG, so reordering would change seeded initialization.
        self.W_key = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        # The same tensor objects are handed to every layer with identical settings.
        mask, cos, sin = SharedBuffers.get_buffers(context_length, head_dim, rope_base, rope_config, dtype)
        for buffer_name, tensor in (("mask", mask), ("cos", cos), ("sin", sin)):
            self.register_buffer(buffer_name, tensor)

#### Stepping through the forward pass


Create a fake input batch (random vectors standing in for token embeddings).


In [4]:
# Settings
SEED = 123              # fixed RNG seed so "Restart & Run All" reproduces the same tensors and weights
torch.manual_seed(SEED)

batch_size = 2
context_length = 3000   # tokens per example; also used as the GQA buffer size in the next cell
max_context_len = 8192  # NOTE(review): not referenced by the cells shown here; presumably the model's max context
embed_dim = 4096
num_heads = 32

# Create the batch inputs: random vectors standing in for embedded tokens
x_batch = torch.randn((batch_size, context_length, embed_dim))
print(f"Batch Shape: {x_batch.shape}")
print(f"There are {x_batch.shape[0]} examples of size {x_batch.shape[1]}x{x_batch.shape[2]} in the batch")

Batch Shape: torch.Size([2, 3000, 4096])
There are 2 exmaples of size 3000x4096 in the batch


Let's initialize all of the layers using the `GroupedQueryAttention` class we created.


In [5]:
# Number of shared key/value heads. Tunable, but it must divide num_heads evenly
# (GroupedQueryAttention asserts this).
num_kv_groups = 8

# Square projection (d_in == d_out); RoPE base of 500k instead of the class default of 10k.
gqa = GroupedQueryAttention(
    num_heads=num_heads,
    num_kv_groups=num_kv_groups,
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_length,
    rope_base=500_000,
)

Now that we have all of our layers, let's step through the forward pass to see what is happening.


In [6]:
# First get the dimensions of the input, in this case it is x_batch
b, num_tokens, d_in = x_batch.shape
x_batch.shape

torch.Size([2, 3000, 4096])

**Project the input into the Q, K, and V matrices**


In [7]:
# Project the same input three times. Keys/values come out narrower than queries:
# num_kv_groups * head_dim columns instead of d_out.
queries, keys, values = (proj(x_batch) for proj in (gqa.W_query, gqa.W_key, gqa.W_value))

In [8]:
# Compare the input width with the query projection's output width (d_out).
print(f"Input Shape: {x_batch.size()}")
print(f"Output Shape: {queries.size()}")

Input Shape: torch.Size([2, 3000, 4096])
Output Shape: torch.Size([2, 3000, 4096])


**Reshape the query, key, and value matrices**


In [9]:
# Split the last (d_out) axis into num_heads chunks of head_dim features:
# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
queries = queries.view((b, num_tokens, gqa.num_heads, gqa.head_dim))

In [10]:
print(f"Output Shape: {queries.shape}")
# Read the sizes from the layer instead of hard-coding them. The embedding is split into
# num_heads heads of head_dim features each (4096 -> 32 x 128 with these settings).
print(f"See how we split the embedding dimension {gqa.d_out} into {gqa.num_heads} heads x {gqa.head_dim} dims")

Output Shape: torch.Size([2, 3000, 32, 128])
See how we split the embedding dimension 4096 into two 32 x 128


Now let's do the same for the keys and the values. Notice that here we use the number of key/value groups instead of the number of heads. Keys and values therefore get only `num_kv_groups` heads. This shrinks the K/V projections and reduces the load on the KV cache, making the attention layer much more efficient. Choosing the number of groups is a trade-off:
If the number of groups equals the number of heads, we get no efficiency gain, but model quality is typically best. This is the same as standard multi-head attention ([Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)).

If the number of groups is 1, we get the largest efficiency gain, but some loss in quality is often observed. This is the same as multi-query attention ([Shazeer, 2019](https://arxiv.org/abs/1911.02150)).


In [11]:
# Keys and values only have num_kv_groups heads, so their last axis splits accordingly:
# (b, num_tokens, num_kv_groups * head_dim) -> (b, num_tokens, num_kv_groups, head_dim)
keys, values = (t.view(b, num_tokens, gqa.num_kv_groups, gqa.head_dim) for t in (keys, values))

**Transpose so the head axis comes before the token axis (each head then holds its own `num_tokens x head_dim` matrix)**


In [12]:
# Swap the token and head axes so each head holds its own (num_tokens, head_dim) matrix:
#   keys, values -> (batch, num_kv_groups, num_tokens, head_dim)
#   queries      -> (batch, num_heads, num_tokens, head_dim)
keys, values, queries = (t.transpose(2, 1) for t in (keys, values, queries))

**Apply RoPE**


In [13]:
from rope import compute_rope
# Apply the rotary position embedding (cached cos/sin tables) to keys and queries;
# values are left unrotated.
keys, queries = (compute_rope(t, gqa.cos, gqa.sin) for t in (keys, queries))

**Expand the compressed key and value groups to match the number of heads**

This is what is happening here:

Before `repeat_interleave` along `dim=1` (the key/value head axis):

- [K1, K2]

After `repeat_interleave` (each key/value head is repeated `group_size` times, so neighbouring query heads in the same group share it):

- [K1, K1, K2, K2]

If we used regular repeat instead of repeat_interleave, we'd get:

- [K1, K2, K1, K2]


In [14]:
# Repeat every K/V head group_size times in place along the head axis (dim=1), so
# query head h ends up paired with K/V head h // group_size.
# Result: (b, num_heads, num_tokens, head_dim)
keys, values = (t.repeat_interleave(gqa.group_size, dim=1) for t in (keys, values))

**Self attention with causal mask**


In [15]:
# After the repeat: (b, num_heads, num_tokens, head_dim)
keys.size()

torch.Size([2, 32, 3000, 128])

In [16]:
# Raw (unscaled) dot products for every head: (b, num_heads, num_tokens, num_tokens).
# With the settings above that is 2 * 32 * 3000 * 3000 float32 values (~2.3 GB).
attn_scores = torch.matmul(queries, keys.transpose(-2, -1))

In [18]:
# Scores of head 0 for the first example (rows: query positions, columns: key positions)
attn_scores[0, 0, :, :]

tensor([[ 4.3431, -3.0201, -3.2463,  ...,  0.4197,  4.4118, -4.0584],
        [ 0.8311,  1.7444,  1.4355,  ...,  0.3652, -3.0394, -0.8232],
        [ 0.9977, -1.2324, -3.9204,  ...,  3.1144, -2.0908,  4.2852],
        ...,
        [-1.2618,  4.5581, -3.4784,  ...,  5.6611, -2.2241, -0.0371],
        [ 3.5333, -1.4920, -5.0984,  ..., -1.1776, -0.7741,  2.7010],
        [ 1.8170, -0.0663,  0.4966,  ..., -1.7697,  5.8875, -0.6811]],
       grad_fn=<SliceBackward0>)

**Apply attention mask**


In [19]:
# Crop the cached causal mask to the current sequence length. Its 1s sit strictly above
# the diagonal (future positions) and become True after the cast.
mask_bool = gqa.mask[:num_tokens, :num_tokens].bool()

# Overwrite future positions with -inf so softmax assigns them zero weight. Done in place
# to avoid a second copy of this large tensor. The head shown in the previous cell is a
# view of the same storage, so its displayed values now change too.
attn_scores.masked_fill_(mask_bool, float("-inf"))

tensor([[[[  4.3431,     -inf,     -inf,  ...,     -inf,     -inf,     -inf],
          [  0.8311,   1.7444,     -inf,  ...,     -inf,     -inf,     -inf],
          [  0.9977,  -1.2324,  -3.9204,  ...,     -inf,     -inf,     -inf],
          ...,
          [ -1.2618,   4.5581,  -3.4784,  ...,   5.6611,     -inf,     -inf],
          [  3.5333,  -1.4920,  -5.0984,  ...,  -1.1776,  -0.7741,     -inf],
          [  1.8170,  -0.0663,   0.4966,  ...,  -1.7697,   5.8875,  -0.6811]],

         [[ -5.8828,     -inf,     -inf,  ...,     -inf,     -inf,     -inf],
          [ -4.4849,  -1.9317,     -inf,  ...,     -inf,     -inf,     -inf],
          [ -0.7706,   0.5707,   5.5048,  ...,     -inf,     -inf,     -inf],
          ...,
          [  3.9763,   0.8077,   1.6192,  ...,  -1.0558,     -inf,     -inf],
          [  3.5011,  -2.4071, -11.2116,  ...,  -1.1011,   7.8819,     -inf],
          [  0.8048,  -0.5970,  -3.8474,  ...,   5.7352,  -3.4377,   0.6053]],

         [[ -5.6287,     -inf,

**Normalize the scores into attention weights (scaled softmax)**


In [20]:
# Scaled dot-product attention: divide by sqrt(head_dim) before the softmax.
# Check the shape invariant *before* the expensive softmax that relies on it.
assert keys.shape[-1] == gqa.head_dim
attn_weights = torch.softmax(attn_scores / gqa.head_dim**0.5, dim=-1)