In [1]:
# importing libraries
import torch
import torch.nn as nn
import math
from typing import Tuple, Optional

  cpu = _conversion_method_template(device=torch.device("cpu"))


# Config settings

#### Architectural Dimensions
$\underline{\text{Grouped-Query Attention (GQA)}}$  
This is a technique in which we use fewer Key/Value heads than Query heads. It requires significantly less memory and can generate text much faster, with only a very small impact on overall accuracy.  
> __note:__ The number of Query heads must be perfectly divisible by the number of Key/Value heads.  

*In this case, we use one Key/Value head per 4 Query heads.*

In [2]:

# Core architectural dimensions of the attention layer.
hidden_size = 128  # width of each token embedding (synonymous with embedding dimension)
num_attention_heads = 16  # number of attention Query heads
num_key_value_heads = 4  # number of Key & Value heads [Grouped-Query Attention (GQA)]

# Fail fast on an invalid configuration instead of hitting a shape error later:
# the per-head split and the GQA grouping both rely on exact divisibility.
assert hidden_size % num_attention_heads == 0, \
    'hidden_size must be divisible by num_attention_heads'
assert num_attention_heads % num_key_value_heads == 0, \
    'num_attention_heads must be divisible by num_key_value_heads'

head_dim = hidden_size // num_attention_heads  # dimension of each attention head

#### Positional Embedding Parameters

Instead of traditional positional embeddings, this model uses RoPE to encode the order of tokens. RoPE modifies the Query and Key vectors using rotations, which elegantly injects relative positional information directly into the self-attention calculation.  


$\underline{\text{Rotary Positional Encodings (RoPE)}}$  
* **Relative Position:** The attention score between two tokens becomes sensitive to their relative distance, not their absolute positions.
* **No Trainable Parameters:** Positional information is added via a deterministic function, requiring no extra parameters to be learned.
* **Long Sequence Extrapolation:** RoPE has been shown to be effective at handling sequences longer than the model was trained on.

In [3]:
# RoPE settings
rope_theta = 10_000  # base of the RoPE frequency formula; controls the timescale
max_positional_embeddings = 256  # maximum number of positions RoPE will be calculated for

#### Normalization and Regularization

In [4]:
rms_norm_eps = 1e-5  # epsilon used when normalising the vector embeddings
# Boolean flag (not an int): it is handed to nn.Linear(bias=...) below.
attention_bias = False  # False keeps each projection a Linear layer without an extra bias vector
# Dropout is a probability, so keep it a float; 0.0 disables it (unused here, for simplicity).
attention_dropout = 0.0  # dropout probability for attention weights to prevent overfitting
use_qk_norm = True  # apply L2 normalization on Q & K before attention

#### Sample Input

In [5]:
torch.manual_seed(seed=77)  # fixed seed so every random tensor below is reproducible

<torch._C.Generator at 0x1d9a42b1a50>

In [6]:
# Sample batch: two independent sequences of text, ten tokens each.
batch_size, sequence_length = 2, 10

# Random stand-in for the input token embeddings: [batch_size, sequence_length, hidden_size].
hidden_states = torch.randn((batch_size, sequence_length, hidden_size))

In [7]:
# Show the sample embeddings followed by their shape, one per line.
print(hidden_states, hidden_states.shape, sep='\n')

tensor([[[-0.5043, -0.4161, -0.1364,  ..., -1.5856, -0.4089, -2.8163],
         [ 1.0667, -0.0923,  0.3463,  ...,  0.5123,  1.9678, -1.6733],
         [ 1.2775,  0.2651, -0.5682,  ..., -0.2129, -1.4258, -1.2878],
         ...,
         [ 1.4049, -0.0547, -0.4749,  ...,  2.6301, -0.4774,  0.3909],
         [-0.5966,  0.7187, -0.3401,  ..., -0.5780,  0.9983,  0.6903],
         [-0.4571,  0.7204,  0.3816,  ...,  1.9020, -0.6863,  0.4856]],

        [[-1.8869,  2.0450, -0.3714,  ..., -0.0561,  1.2780, -0.0363],
         [ 0.2985,  1.5429,  1.3085,  ...,  0.2492,  0.6134,  0.5383],
         [-0.2063, -2.8666, -1.4368,  ..., -0.6156, -0.6485,  0.1808],
         ...,
         [-0.2064,  2.1962,  1.2381,  ...,  1.1080, -0.6104,  0.7092],
         [-1.1046, -0.1936,  0.0943,  ..., -0.0681,  0.0745,  1.0041],
         [ 0.8959,  0.1819,  1.3658,  ...,  0.7530, -0.9845, -0.2993]]])
torch.Size([2, 10, 128])


In [8]:
# Positional ids for the tokens (very important for RoPE):
# 0..sequence_length-1 as a single row, then one copy of that row per batch item.
position_ids = torch.arange(sequence_length)[None, :].repeat(batch_size, 1)

In [9]:
# Walk through the construction of position_ids one step at a time.
walkthrough = (
    # step 1: a simple indexing sequence of numbers
    f'torch.arange(0, sequence_length) : \n {torch.arange(0, sequence_length)}',
    # step 2: insert a new leading dimension of size 1 (dim 0)
    f'torch.arange(0, sequence_length).unsqueeze(0) : \n {torch.arange(0, sequence_length).unsqueeze(0)}',
    # step 3: copy the sequence for each item in the batch
    f'torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1) : \n {torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1)}',
)
# A blank line separates consecutive steps.
print(*walkthrough, sep='\n\n')

torch.arange(0, sequence_length) : 
 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

torch.arange(0, sequence_length).unsqueeze(0) : 
 tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1) : 
 tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])


#### Attention Mask
To prevent the model from attending to tokens that come after the current token, each position is allowed to see only itself and the tokens before it.

In [10]:
# Causal mask: start from a square filled with -inf and keep only the part strictly
# above the principal diagonal (diagonal=1); everything else becomes 0.
# The -inf entries are the positions the softmax function will turn into 0.
attention_mask = torch.full((sequence_length, sequence_length), -torch.inf).triu(diagonal=1)
print(attention_mask, attention_mask.shape, sep='\n')

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
torch.Size([10, 10])


* `.unsqueeze(0)` changes the dimension from `[sequence_length, sequence_length]` to `[1, sequence_length, sequence_length]`  

* second `.unsqueeze(0)` changes the dimension from `[1, sequence_length, sequence_length]` to `[1, 1, sequence_length, sequence_length]`   

We do this to match the dimensions with the attention_weights, which has a 4D shape `[batch_size, num_attention_heads, sequence_length, sequence_length]`

In [11]:
# Prepend two singleton axes: [seq, seq] -> [1, 1, seq, seq].
attention_mask = attention_mask[None, None]
print(attention_mask, attention_mask.shape, sep='\n')

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
torch.Size([1, 1, 10, 10])


* `1` keeps the size of the second dimension at 1, so the same attention mask is applied (via broadcasting) across all attention heads.
* `-1` tells `expand` to leave a dimension's size unchanged, so the third and fourth dimensions stay as they are.

In [12]:
# Broadcast the mask over the batch axis; the head axis stays at size 1 and the
# two sequence axes are left untouched (-1).
attention_mask = attention_mask.expand((batch_size, 1, -1, -1))
print(attention_mask, attention_mask.shape, sep='\n')

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -in

### Final Config Check

In [13]:
# Final sanity report: architecture settings first, then the sample input shapes.
config_report = (
    'Configuration',
    f'hidden_size: {hidden_size}',
    f'num_attention_heads: {num_attention_heads}',
    f'num_key_value_heads: {num_key_value_heads}',
    f'head_dim: {head_dim}',
    '',  # blank line between the two sections
    'Sample Input Shapes',
    f'hidden_states: {hidden_states.shape}',
    f'position_ids: {position_ids.shape}',
    f'attention_mask: {attention_mask.shape}',
)
print(*config_report, sep='\n')

Configuration
hidden_size: 128
num_attention_heads: 16
num_key_value_heads: 4
head_dim: 8

Sample Input Shapes
hidden_states: torch.Size([2, 10, 128])
position_ids: torch.Size([2, 10])
attention_mask: torch.Size([2, 1, 10, 10])


# Q, K, V Projections

#### Define Projection Layers
Here, we define the matrices W<sup>q</sup>, W<sup>k</sup>, and W<sup>v</sup>, along with the output projection W<sup>o</sup>.

In [14]:
# Output widths: the Query side spans every attention head, while Key/Value
# only span the (fewer) Key/Value heads.
q_features = num_attention_heads * head_dim
kv_features = num_key_value_heads * head_dim

# The layers are created in the order q, k, v, o so the seeded weight
# initialisation stays the same.
q_proj = nn.Linear(in_features=hidden_size, out_features=q_features, bias=attention_bias)
k_proj = nn.Linear(in_features=hidden_size, out_features=kv_features, bias=attention_bias)
v_proj = nn.Linear(in_features=hidden_size, out_features=kv_features, bias=attention_bias)

# Holds the learned weight matrix often referred to as W_o on paper.
o_proj = nn.Linear(in_features=q_features, out_features=hidden_size, bias=attention_bias)

In [15]:
# Header line, then the repr of each projection layer on its own line.
layer_report = (
    'Projection layers',
    f'q_proj: {q_proj}',
    f'k_proj: {k_proj}',
    f'v_proj: {v_proj}',
    f'o_proj: {o_proj}',
)
print(*layer_report, sep='\n')

Projection layers
q_proj: Linear(in_features=128, out_features=128, bias=False)
k_proj: Linear(in_features=128, out_features=32, bias=False)
v_proj: Linear(in_features=128, out_features=32, bias=False)
o_proj: Linear(in_features=128, out_features=128, bias=False)


#### Project the input on these matrices

In [16]:
# Run the same input through the Q, K and V projections (in that order).
query_states, key_states, value_states = (
    projection(hidden_states) for projection in (q_proj, k_proj, v_proj)
)

In [17]:
# Header line, then the shape of each projected tensor.
projection_report = (
    'Projections shape',
    f'query_states: {query_states.shape}',
    f'key_states: {key_states.shape}',
    f'value_states: {value_states.shape}',
)
print(*projection_report, sep='\n')

Projections shape
query_states: torch.Size([2, 10, 128])
key_states: torch.Size([2, 10, 32])
value_states: torch.Size([2, 10, 32])


* Query: We have 2 sequences, 10 tokens in each sequence, and 128 values to represent each single token.  
* Key & Value: We have 2 sequences, 10 tokens in each sequence, and 32 values to represent each single token.

#### Creating Individual Heads
We now have Q, K, and V. We must divide them into individual heads for multi-head attention.  

Target shape: `[batch_size, num_heads, sequence_length, head_dim]`

`view()` function is used to reshape the tensor *(works only on contiguous tensors)*  
*we use it to split the last dimension into two new dimensions `(num_heads, head_dim)`; `transpose(1, 2)` then moves the heads dimension ahead of the sequence dimension*

In [18]:
def _split_heads(projected, heads):
    """Reshape [batch, seq, heads * head_dim] into [batch, heads, seq, head_dim]."""
    # Split the last dimension into (heads, head_dim) ...
    per_head = projected.view(batch_size, sequence_length, heads, head_dim)
    # ... then swap the sequence and head axes.
    return per_head.transpose(1, 2)


# Queries use every attention head; keys and values use the smaller K/V head count.
query_states = _split_heads(query_states, num_attention_heads)
key_states = _split_heads(key_states, num_key_value_heads)
value_states = _split_heads(value_states, num_key_value_heads)

In [19]:
# Header line, then the per-head shape of each tensor.
head_report = (
    'The individual heads shapes:',
    f'query_states: {query_states.shape}',
    f'key_states: {key_states.shape}',
    f'value_states: {value_states.shape}',
)
print(*head_report, sep='\n')

The individual heads shapes:
query_states: torch.Size([2, 16, 10, 8])
key_states: torch.Size([2, 4, 10, 8])
value_states: torch.Size([2, 4, 10, 8])


#### Calculating the number of Query heads per Key-Value head

In [21]:
# How many Query heads share a single Key/Value head.
num_key_value_groups = num_attention_heads // num_key_value_heads

groups_message = f'Number of Key-Value groups (Q heads per K-V head): {num_key_value_groups}'
print(groups_message)

Number of Key-Value groups (Q heads per K-V head): 4


#### Rotary Positional Embeddings (RoPE)