# Grouped Query Attention

### In this notebook we will compare Grouped-Query Attention with two other ways of handling attention heads: Multi-Head Attention and Multi-Query Attention.

Authored by: Michael Haidar


### Grouped-Query Attention (GQA) provides three major advantages relative to Multi-Head Attention (MHA) and Multi-Query Attention (MQA):
- Efficiency:
    - GQA strikes a balance between flexibility and efficiency. By using fewer key-value heads than MHA, GQA reduces memory consumption and computation cost, which is crucial for large-scale models.
- Parameter sharing:
    - Sharing key-value heads within groups still allows some diversity in the attention patterns, because the query projections remain unique for each head. This means that GQA can capture useful relationships without needing as many parameters as MHA.
- Scalability:
    - For very large models with many heads (e.g., 64 or 128 heads in transformer-based large language models), GQA reduces memory and computational demands while still maintaining adequate model performance.
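
The parameter-sharing point can be made concrete with a little arithmetic. The sketch below is illustrative only: `kv_projection_params` is a hypothetical helper, and it assumes the key and value projections are sized to the number of key-value heads (`num_kv_heads * d_k` output features each, with a bias term), using the same toy dimensions as this notebook.

```python
def kv_projection_params(d_model, num_heads, num_kv_heads, bias=True):
    """Count the weights in the key and value projections of one attention layer.

    Hypothetical helper: assumes each projection maps d_model inputs to
    num_kv_heads * d_k outputs, where d_k = d_model // num_heads.
    """
    d_k = d_model // num_heads
    out_features = num_kv_heads * d_k
    per_projection = d_model * out_features + (out_features if bias else 0)
    return 2 * per_projection  # one projection for K, one for V


d_model, num_heads, group_size = 16, 4, 2
mha_params = kv_projection_params(d_model, num_heads, num_kv_heads=num_heads)
gqa_params = kv_projection_params(d_model, num_heads, num_kv_heads=num_heads // group_size)
mqa_params = kv_projection_params(d_model, num_heads, num_kv_heads=1)
print(mha_params, gqa_params, mqa_params)  # 544 272 136
```

Under these assumptions, the key-value parameter count scales linearly with the number of key-value heads, while the query and output projections are unchanged.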

In [6]:
import time
import tensorflow as tf
from tensorflow.keras import layers, Model
import tracemalloc
class style():
  RED = '\033[31m'
  GREEN = '\033[32m'
  BLUE = '\033[34m'
  RESET = '\033[0m'

In [29]:
# Define model parameters
batch_size = 8
seq_len = 5
d_model = 16  # Hidden size of model
num_heads = 4  # Number of attention heads
d_k = d_model // num_heads  # Dimensionality per head
group_size = 2  # Number of heads in each group (for GQA)

In [38]:
# Scaled dot-product attention
def scaled_dot_product_attention(q, k, v):
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # [batch_size, num_heads, seq_len_q, seq_len_k]
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # [batch_size, num_heads, seq_len_q, seq_len_k]
    output = tf.matmul(attention_weights, v)  # [batch_size, num_heads, seq_len_q, depth_v]
    return output

## Multi-Headed Attention (MHA)
- In MHA, each head has its own independent query, key, and value projections. 
    - This allows each head to attend to the input sequence in its own unique way, giving the model more flexibility to learn different types of relationships between tokens.
- Because every head forms its own attention scores and output, MHA can represent a larger space of attention patterns than GQA or MQA. 
    - The model can learn a different attention pattern for each head, covering a wide variety of input-output mappings.

In [None]:

# Multi-Head Attention (MHA): Independent key-value pairs per head
class MHA(layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MHA, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Key, value, and query layers (unique per head)
        self.w_q = layers.Dense(d_model)
        self.w_k = layers.Dense(d_model)
        self.w_v = layers.Dense(d_model)
        
        # Output linear layer
        self.dense = layers.Dense(d_model)
    
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_k))
        return tf.transpose(x, perm=[0, 2, 1, 3])  # [batch_size, num_heads, seq_len, d_k]
    
    def call(self, query, key, value):
        print('=================MHA==================')
        batch_size = tf.shape(query)[0]
        
        # Compute query, key, value projections independently for each head
        Q = self.split_heads(self.w_q(query), batch_size)
        K = self.split_heads(self.w_k(key), batch_size)
        V = self.split_heads(self.w_v(value), batch_size)
        
        # Print shapes (Optional, for checking differences)
        print(f"MHA Query Shape (Q): {Q.shape}")
        print(f"{style.RED}MHA Key Shape (K): {K.shape}")
        print(f"MHA Value Shape (V): {V.shape}"+ style.RESET)
        
        attention_output = scaled_dot_product_attention(Q, K, V)
        
        # Concatenate attention output
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention_output, (batch_size, -1, self.num_heads * self.d_k))
        
        # Final linear layer
        output = self.dense(concat_attention)
        return output
 

## Grouped-Query Attention (GQA)
- In GQA, multiple heads share key-value pairs within groups, meaning that the attention mechanism is constrained by these shared parameters. 
    - The heads within a group will compute attention scores based on the same key-value projections, although each head can still have its own query projections.
- This reduces the model's degrees of freedom, because there are fewer independent key-value projections than in MHA. 
    - As a result, the space of possible attention outputs is more constrained than in MHA: heads in the same group can differ only through their queries.
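
The sharing described above can be checked numerically. The following is a minimal NumPy sketch, not the implementation used in this notebook: the helper names (`softmax`, `attention`, `grouped_attention`) are made up for illustration, and random arrays stand in for the projected queries, keys, and values. It shows that broadcasting one key-value head across the query heads of its group gives the same result as explicitly repeating the keys and values for every head.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q, k, v):
    # Scaled dot-product attention over the last two axes
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(k.shape[-1])
    return softmax(scores) @ v


def grouped_attention(q, k, v, group_size):
    """q: [B, H, T, d]; k, v: [B, G, T, d] with H == G * group_size."""
    b, h, t, d = q.shape
    g = k.shape[1]
    assert h == g * group_size, "num_heads must equal num_groups * group_size"
    q = q.reshape(b, g, group_size, t, d)  # consecutive heads form a group
    k = k[:, :, None]                      # [B, G, 1, T, d], broadcast over the group
    v = v[:, :, None]
    out = attention(q, k, v)               # [B, G, group_size, T, d]
    return out.reshape(b, h, t, d)


rng = np.random.default_rng(0)
q = rng.normal(size=(8, 4, 5, 4))  # 4 query heads
k = rng.normal(size=(8, 2, 5, 4))  # 2 key heads (one per group)
v = rng.normal(size=(8, 2, 5, 4))  # 2 value heads (one per group)

out = grouped_attention(q, k, v, group_size=2)
# Reference: copy each K/V head to every query head in its group, then attend per head
reference = attention(q, np.repeat(k, 2, axis=1), np.repeat(v, 2, axis=1))
print(out.shape, np.allclose(out, reference))
```

The grouping here assigns consecutive query heads to the same group, which matches the slicing `Q[:, i * group_size:(i + 1) * group_size]` used in the `GQA` layer.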


In [None]:

# Grouped-Query Attention (GQA)
class GQA(layers.Layer):
    def __init__(self, d_model, num_heads, group_size):
        super(GQA, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.group_size = group_size
        self.num_groups = num_heads // group_size
        
        # Shared key and value layers: one K/V head per group, so each
        # projection outputs num_groups * d_k features instead of d_model
        self.w_k = layers.Dense(self.num_groups * self.d_k)
        self.w_v = layers.Dense(self.num_groups * self.d_k)
        
        # Query layer (unique per head)
        self.w_q = layers.Dense(d_model)
        
        # Output linear layer
        self.dense = layers.Dense(d_model)
    
    def split_heads(self, x, batch_size, num_heads):
        """Split the last dimension into (num_heads, d_k)."""
        x = tf.reshape(x, (batch_size, -1, num_heads, self.d_k))
        return tf.transpose(x, perm=[0, 2, 1, 3])  # [batch_size, num_heads, seq_len, d_k]
    
    def call(self, query, key, value):
        print('=================GQA==================')
        batch_size = tf.shape(query)[0]
        
        # Compute query projections (one per head)
        Q = self.split_heads(self.w_q(query), batch_size, self.num_heads)
        print(f"GQA Query Shape (Q): {Q.shape}")
        
        # Compute key and value projections (one per group)
        K = self.split_heads(self.w_k(key), batch_size, self.num_groups)
        V = self.split_heads(self.w_v(value), batch_size, self.num_groups)
        print(f"{style.GREEN}GQA Key Shape (K): {K.shape}")
        print(f"GQA Value Shape (V): {V.shape}" + style.RESET)
        
        outputs = []
        print('----------------Groups--------------')
        for i in range(self.num_groups):
            Q_group = Q[:, i * self.group_size:(i + 1) * self.group_size, :, :]
            K_group = K[:, i:i+1, :, :]  # Shared key within the group
            V_group = V[:, i:i+1, :, :]  # Shared value within the group
            print(f"GQA Group {i+1} Query Shape: {Q_group.shape}, Key Shape: {K_group.shape}, Value Shape: {V_group.shape}")
            group_output = scaled_dot_product_attention(Q_group, K_group, V_group)
            outputs.append(group_output)
        
        # Concatenate outputs from all groups
        concat_attention = tf.concat(outputs, axis=1)
        concat_attention = tf.transpose(concat_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(concat_attention, (batch_size, -1, self.num_heads * self.d_k))
        
        # Final linear layer
        output = self.dense(concat_attention)
        print('------------------------------------')
        return output

## Multi-Query Attention (MQA)

- MQA reduces all key and value heads to a single key head and a single value head, shared by every query head.
- With H query heads, this shrinks the key-value cache, and therefore the amount of data that needs to be loaded, by a factor of H compared with MHA.
- MQA can lead to quality degradation.
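
The factor-of-H reduction can be sketched with simple arithmetic. `kv_cache_bytes` below is a hypothetical helper, and it assumes the cache stores one key tensor and one value tensor per key-value head in 4-byte floats; the sizes are the toy values used in this notebook.

```python
def kv_cache_bytes(batch_size, seq_len, num_kv_heads, d_k,
                   num_layers=1, bytes_per_value=4):
    """Approximate key-value cache size in bytes (hypothetical helper).

    The leading factor of 2 accounts for storing both keys and values.
    """
    return 2 * num_layers * batch_size * seq_len * num_kv_heads * d_k * bytes_per_value


batch_size, seq_len, num_heads, d_k = 8, 5, 4, 4
mha_cache = kv_cache_bytes(batch_size, seq_len, num_kv_heads=num_heads, d_k=d_k)
gqa_cache = kv_cache_bytes(batch_size, seq_len, num_kv_heads=2, d_k=d_k)
mqa_cache = kv_cache_bytes(batch_size, seq_len, num_kv_heads=1, d_k=d_k)

print(mha_cache, gqa_cache, mqa_cache)  # 5120 2560 1280
print(mha_cache // mqa_cache)           # 4, i.e. the number of heads H
```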

In [30]:
class MQA(layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MQA, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Shared key and value layers: a single K/V head of width d_k
        self.w_k = layers.Dense(self.d_k)
        self.w_v = layers.Dense(self.d_k)
        
        # Query layer (unique per head)
        self.w_q = layers.Dense(d_model)
        
        # Output linear layer
        self.dense = layers.Dense(d_model)
    
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, d_k)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.d_k))
        return tf.transpose(x, perm=[0, 2, 1, 3])  # [batch_size, num_heads, seq_len, d_k]
    
    def call(self, query, key, value):
        print('=================MQA==================')
        batch_size = tf.shape(query)[0]
        
        # Compute query projections per head and a single shared key/value head
        Q = self.split_heads(self.w_q(query), batch_size)
        K = tf.expand_dims(self.w_k(key), axis=1)  # Shared key across all heads: [batch_size, 1, seq_len, d_k]
        V = tf.expand_dims(self.w_v(value), axis=1)  # Shared value across all heads: [batch_size, 1, seq_len, d_k]
        
        # Print shapes of query, key, and value
        print(f"MQA Query Shape (Q): {Q.shape}")
        print(f"{style.BLUE}MQA Key Shape (K): {K.shape}" +style.RESET)
        print(f"{style.BLUE}MQA Value Shape (V): {V.shape}"+ style.RESET)
        
        # Repeat K and V across heads for multi-query attention
        K = tf.tile(K, [1, self.num_heads, 1, 1])
        V = tf.tile(V, [1, self.num_heads, 1, 1])
        
        attention_output = scaled_dot_product_attention(Q, K, V)
        
        # Concatenate attention output
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention_output, (batch_size, -1, self.num_heads * self.d_k))
        
        # Final linear layer
        output = self.dense(concat_attention)
        return output


In [31]:
# Create random inputs for testing all three attention mechanisms
input_query = tf.random.normal((batch_size, seq_len, d_model))
input_key = tf.random.normal((batch_size, seq_len, d_model))
input_value = tf.random.normal((batch_size, seq_len, d_model))

In [32]:
# Initialize the MHA, GQA, and MQA layers
mha_layer = MHA(d_model, num_heads)
gqa_layer = GQA(d_model, num_heads, group_size)
mqa_layer = MQA(d_model, num_heads)

## Result

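### We can see below that MHA has the highest traced memory usage and MQA the lowest, with GQA in between. GQA strikes a middle ground between them, allowing for more efficient inference than MHA.

In these runs GQA is also the slowest of the three, most likely because this implementation loops over the groups in Python. MHA's much larger memory figure probably includes one-time setup cost, since it is the first layer called. `tracemalloc` traces allocations made through Python's memory allocators, so the memory numbers are best read as rough indicators.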

### Note that the Key and Value shapes are [batch_size, num_groups, seq_len, d_k]. For GQA, num_groups = num_heads / group_size = 4 / 2 = 2, so K and V have 2 heads where MHA has 4 and MQA has 1.
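
The shape rule above can be written as a small check. `expected_kv_shape` is a hypothetical helper, not part of the notebook's layers. It also illustrates that, under this formulation, MHA and MQA are the two extremes of GQA: `group_size = 1` gives one key-value head per query head, and `group_size = num_heads` gives a single shared head.

```python
def expected_kv_shape(batch_size, seq_len, d_model, num_heads, group_size):
    """Return the expected [batch_size, num_groups, seq_len, d_k] shape of K and V."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    if num_heads % group_size != 0:
        raise ValueError("num_heads must be divisible by group_size")
    d_k = d_model // num_heads
    num_groups = num_heads // group_size
    return (batch_size, num_groups, seq_len, d_k)


mha_shape = expected_kv_shape(8, 5, 16, 4, group_size=1)  # every head has its own K/V
gqa_shape = expected_kv_shape(8, 5, 16, 4, group_size=2)  # two heads share each K/V
mqa_shape = expected_kv_shape(8, 5, 16, 4, group_size=4)  # all heads share one K/V
print(mha_shape, gqa_shape, mqa_shape)  # (8, 4, 5, 4) (8, 2, 5, 4) (8, 1, 5, 4)
```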

In [33]:
# Function to track memory usage and compute time
def track_memory_and_time(layer, query, key, value, name):
    tracemalloc.start()  # Start tracking memory
    start_time = time.time()  # Start tracking time
    layer_output = layer(query, key, value)
    elapsed_time = time.time() - start_time
    current, peak = tracemalloc.get_traced_memory()  # Get memory usage
    tracemalloc.stop()
    #print(f"{name} Output Shape: {layer_output.shape}")
    print(f"Time taken by {name}: {elapsed_time:.6f} seconds")
    print(f"Current memory usage: {current / 10**6:.6f} MB; Peak: {peak / 10**6:.6f} MB\n")

# Track memory and compute time for MHA, MQA, and GQA
track_memory_and_time(mha_layer, input_query, input_key, input_value, "MHA")
track_memory_and_time(gqa_layer, input_query, input_key, input_value, "GQA")
track_memory_and_time(mqa_layer, input_query, input_key, input_value, "MQA")


MHA Query Shape (Q): (8, 4, 5, 4)
MHA Key Shape (K): (8, 4, 5, 4)
MHA Value Shape (V): (8, 4, 5, 4)
Time taken by MHA: 0.023144 seconds
Current memory usage: 2.301817 MB; Peak: 2.775021 MB

GQA Query Shape (Q): (8, 4, 5, 4)
GQA Key Shape (K): (8, 2, 5, 4)
GQA Value Shape (V): (8, 2, 5, 4)
----------------Groups--------------
GQA Group 1 Query Shape: (8, 2, 5, 4), Key Shape: (8, 1, 5, 4), Value Shape: (8, 1, 5, 4)
GQA Group 2 Query Shape: (8, 2, 5, 4), Key Shape: (8, 1, 5, 4), Value Shape: (8, 1, 5, 4)
------------------------------------
Time taken by GQA: 0.037420 seconds
Current memory usage: 0.030244 MB; Peak: 0.039383 MB

MQA Query Shape (Q): (8, 4, 5, 4)
MQA Key Shape (K): (8, 1, 5, 4)
MQA Value Shape (V): (8, 1, 5, 4)
Time taken by MQA: 0.023154 seconds
Current memory usage: 0.025914 MB; Peak: 0.034013 MB
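
Single-call timings like the ones above are noisy, and the first call to a layer also pays for building its weights. As a hedged sketch of a steadier measurement, the hypothetical `benchmark` helper below runs a few warm-up calls and then reports the median and best time over several repeats. It is demonstrated on a plain Python workload so that it runs on its own.

```python
import statistics
import time


def benchmark(fn, warmup=2, repeats=10):
    """Time fn() after warm-up calls; return (median, best) in seconds."""
    for _ in range(warmup):
        fn()  # warm-up runs absorb one-time setup cost
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings), min(timings)


# Stand-in workload; any zero-argument callable works here
median_s, best_s = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"median: {median_s:.6f} s, best: {best_s:.6f} s")
```

In this notebook the callable could be, for example, `lambda: mha_layer(input_query, input_key, input_value)`. The `print` calls inside each layer's `call` method add to the measured time, so they are worth removing before comparing layers.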



In [35]:
# Compute outputs for MHA, GQA, and MQA
# Measure compute time for MHA

start_time = time.time()
mha_output = mha_layer(input_query, input_key, input_value)
mha_time = time.time() - start_time

# Measure compute time for GQA
start_time = time.time()
gqa_output = gqa_layer(input_query, input_key, input_value)
gqa_time = time.time() - start_time

# Measure compute time for MQA
start_time = time.time()
mqa_output = mqa_layer(input_query, input_key, input_value)
mqa_time = time.time() - start_time


MHA Query Shape (Q): (8, 4, 5, 4)
MHA Key Shape (K): (8, 4, 5, 4)
MHA Value Shape (V): (8, 4, 5, 4)
GQA Query Shape (Q): (8, 4, 5, 4)
GQA Key Shape (K): (8, 2, 5, 4)
GQA Value Shape (V): (8, 2, 5, 4)
----------------Groups--------------
GQA Group 1 Query Shape: (8, 2, 5, 4), Key Shape: (8, 1, 5, 4), Value Shape: (8, 1, 5, 4)
GQA Group 2 Query Shape: (8, 2, 5, 4), Key Shape: (8, 1, 5, 4), Value Shape: (8, 1, 5, 4)
------------------------------------
MQA Query Shape (Q): (8, 4, 5, 4)
MQA Key Shape (K): (8, 1, 5, 4)
MQA Value Shape (V): (8, 1, 5, 4)


In [37]:
print(f"MHA Output Shape: {mha_output.shape}")
print(f"GQA Output Shape: {gqa_output.shape}")
print(f"MQA Output Shape: {mqa_output.shape}")

MHA Output Shape: (8, 5, 16)
GQA Output Shape: (8, 5, 16)
MQA Output Shape: (8, 5, 16)


In [36]:
print("========= Process Time ===========")
print(f"MHA: {mha_time:.6f} seconds")
print(f"GQA: {gqa_time:.6f} seconds")
print(f"MQA: {mqa_time:.6f} seconds")

MHA: 0.009213 seconds
GQA: 0.013903 seconds
MQA: 0.005585 seconds
