In [1]:
# importing libraries
import torch
import torch.nn as nn
import math
from typing import Tuple, Optional

  cpu = _conversion_method_template(device=torch.device("cpu"))


# Config settings

#### Architecural Dimensions
$\underline{\text{Grouped-Query Attention (GQA)}}$  
This is a technique in which we use fewer number of Key/Value heads than the Query heads.  This method requires significantly less memory, and can generate text much faster with a very small impact on the overall accuracy.  
> __note:__ The number of Query heads must be perfectly divisble by the number of Key/Value heads.  

*In this case, we use one Key and Value heads per 4 Query heads*

In [2]:

hidden_size = 128   # synonymous with embeddings dimension
num_attention_heads = 16    # The no. of attention query heads
num_key_value_heads = 4     # The no. of Key & Value heads [Grouped-Query Attention (GQA)]
head_dim = hidden_size // num_attention_heads   # Dimension of each atention head

#### Positional Embedding Parameters

Instead of traditional positional embeddings, this model uses RoPE to encode the order of tokens. RoPE modifies the Query and Key vectors using rotations, which elegantly injects relative positional information directly into the self-attention calculation.  


$\underline{\text{Rotary Positional Encodings (RoPE)}}$  
* **Relative Position:** The attention score between two tokens becomes sensitive to their relative distance, not their absolute positions.
* **No Trainable Parameters:** Positional information is added via a deterministic function, requiring no extra parameters to be learned.
* **Long Sequence Extrapolation:** RoPE has been shown to be effective at handling sequences longer than the model was trained on.

In [3]:
max_positional_embeddings = 256  # max no. of positions to be calculated by RoPE\
rope_theta = 10000  # base for the formula to calculate frequencies for RoPE, controlling the timescale

#### Normalization and Regularization

In [4]:
rms_norm_eps = 1e-5 # to normalise the vector embeddings
attention_bias = 0  # 0 to keep it as a Linear Layer without an extra bias vector
attention_dropout = 0  # Dropout probability for attention weights to prevent overfitting, for simplicity, we won't use that
use_qk_norm = True  # To apply L2 normalization on Q & K before attention

#### Sample Input

In [5]:
torch.manual_seed(77)  # for reproducibility

<torch._C.Generator at 0x2b2c4751a90>

In [6]:
batch_size = 2  # two independent sequences of text
sequence_length = 10  # length of each sequence

hidden_states = torch.randn(batch_size, sequence_length, hidden_size)  # creating sample input token embeddings

In [7]:
print(hidden_states)
print(hidden_states.shape)

tensor([[[-0.5043, -0.4161, -0.1364,  ..., -1.5856, -0.4089, -2.8163],
         [ 1.0667, -0.0923,  0.3463,  ...,  0.5123,  1.9678, -1.6733],
         [ 1.2775,  0.2651, -0.5682,  ..., -0.2129, -1.4258, -1.2878],
         ...,
         [ 1.4049, -0.0547, -0.4749,  ...,  2.6301, -0.4774,  0.3909],
         [-0.5966,  0.7187, -0.3401,  ..., -0.5780,  0.9983,  0.6903],
         [-0.4571,  0.7204,  0.3816,  ...,  1.9020, -0.6863,  0.4856]],

        [[-1.8869,  2.0450, -0.3714,  ..., -0.0561,  1.2780, -0.0363],
         [ 0.2985,  1.5429,  1.3085,  ...,  0.2492,  0.6134,  0.5383],
         [-0.2063, -2.8666, -1.4368,  ..., -0.6156, -0.6485,  0.1808],
         ...,
         [-0.2064,  2.1962,  1.2381,  ...,  1.1080, -0.6104,  0.7092],
         [-1.1046, -0.1936,  0.0943,  ..., -0.0681,  0.0745,  1.0041],
         [ 0.8959,  0.1819,  1.3658,  ...,  0.7530, -0.9845, -0.2993]]])
torch.Size([2, 10, 128])


The job of `position_ids` is to tell the RoPE function the position of each token (is it the 1st, 2nd, 3rd, etc. token?).

In [8]:
# to create positional ids for the tokens (very imp for RoPE)
position_ids = torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1)

In [9]:
# understanding what goes on in each step
print(f'torch.arange(0, sequence_length) : \n {torch.arange(0, sequence_length)}')  # creates a simple indexing sequence of numbers
print()
print(f'torch.arange(0, sequence_length).unsqueeze(0) : \n {torch.arange(0, sequence_length).unsqueeze(0)}')  # Adds a new dimension of size 1 at the specified position (0-row, 1-column)
print()
print(f'torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1) : \n {torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1)}')  # copies the sequence for each item in the batch

torch.arange(0, sequence_length) : 
 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

torch.arange(0, sequence_length).unsqueeze(0) : 
 tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1) : 
 tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])


#### Attention Mask
To prevent the model from attending to the tokens after the current token, it should be able to see only the previous tokens.

In [10]:
# create a square and assign -inf to all the upper triangle positions so the softmax functino will make it 0
# diagonal = 1 specifies that the digonal right above the principal diagonal
attention_mask = torch.triu(torch.ones(sequence_length, sequence_length) * -torch.inf, diagonal=1)
print(attention_mask)
print(attention_mask.shape)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
torch.Size([10, 10])


* `.unsqueeze(0)` changes the dimension from `[sequence_length, sequence_length]` to `[1, sequence_length, sequence_length]`  

* second `.unsqueeze(0)` changes the dimension from `[1, sequence_length, sequence_length]` to `[1, 1, sequence_length, sequence_length]`   

We do this to match the dimensions with the attention_weights, which has a 4D shape `[batch_size, num_attention_heads, sequence_length, sequence_length]`

In [11]:
attention_mask =attention_mask.unsqueeze(0).unsqueeze(0)
print(attention_mask)
print(attention_mask.shape)

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
torch.Size([1, 1, 10, 10])


* `1` indicates that the size of the second dimension is 1. We apply the same attention mask across all attention heads. 
* `-1` indicates to not change the third and fourth dimensions, to remain unchanged

In [12]:
attention_mask = attention_mask.expand(batch_size, 1, -1, -1)
print(attention_mask)
print(attention_mask.shape)

tensor([[[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
          [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -in

### Final Config Check

In [13]:
print('Configuration')
print(f'hidden_size: {hidden_size}')
print(f'num_attention_heads: {num_attention_heads}')
print(f'num_key_value_heads: {num_key_value_heads}')
print(f'head_dim: {head_dim}')
print()
print('Sample Input Shapes')
print(f'hidden_states: {hidden_states.shape}')
print(f'position_ids: {position_ids.shape}')
print(f'attention_mask: {attention_mask.shape}')

Configuration
hidden_size: 128
num_attention_heads: 16
num_key_value_heads: 4
head_dim: 8

Sample Input Shapes
hidden_states: torch.Size([2, 10, 128])
position_ids: torch.Size([2, 10])
attention_mask: torch.Size([2, 1, 10, 10])


# Q, K, V Projections

#### Define Projection Layers
here, we define the matrices W<sup>q</sup>, W<sup>k</sup>, and W<sup>v</sup>

In [14]:
q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias= attention_bias)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias= attention_bias)
v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias= attention_bias)

# contains the learned weight matrix often refferd to as W_o on paper
o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias= attention_bias)

In [15]:
print('Projection layers')
print(f'q_proj: {q_proj}')
print(f'k_proj: {k_proj}')
print(f'v_proj: {v_proj}')
print(f'o_proj: {o_proj}')

Projection layers
q_proj: Linear(in_features=128, out_features=128, bias=False)
k_proj: Linear(in_features=128, out_features=32, bias=False)
v_proj: Linear(in_features=128, out_features=32, bias=False)
o_proj: Linear(in_features=128, out_features=128, bias=False)


#### Project the input on these matrices

In [16]:
query_states = q_proj(hidden_states)
key_states = k_proj(hidden_states)
value_states = v_proj(hidden_states)

In [17]:
print('Projections shape')
print(f'query_states: {query_states.shape}')
print(f'key_states: {key_states.shape}')
print(f'value_states: {value_states.shape}')

Projections shape
query_states: torch.Size([2, 10, 128])
key_states: torch.Size([2, 10, 32])
value_states: torch.Size([2, 10, 32])


* Query: We have 2 sequences, 10 tokens in each sequence, and 128 values to represent each single token.  
* Key & Value: We have 2 sequences, 10 tokens in each sequence, and 32 values to represent each single token.

#### Creating Individual Heads
We have have Q, K, and V. We must divide them into individual heads for multi-head attention.  

Target shape: [batch_size, num_heads, sequence_length, head_size]

`view()` function is used to reshape the tensor *(works only on contiguous tensors)*  
*we used it to split `hidden_size` dimension into two new dimensions `(num_attention_heads, head_dim)`*

In [18]:
query_states = query_states.view(batch_size, sequence_length, num_attention_heads, head_dim).transpose(1,2)
key_states = key_states.view(batch_size, sequence_length, num_key_value_heads, head_dim).transpose(1,2)
value_states = value_states.view(batch_size, sequence_length, num_key_value_heads, head_dim).transpose(1,2)

In [19]:
print('The individual heads shapes:')
print(f'query_states: {query_states.shape}')
print(f'key_states: {key_states.shape}')
print(f'value_states: {value_states.shape}')

The individual heads shapes:
query_states: torch.Size([2, 16, 10, 8])
key_states: torch.Size([2, 4, 10, 8])
value_states: torch.Size([2, 4, 10, 8])


#### Calculating the number of Query heads per Key-Value head

In [20]:
num_key_value_groups = num_attention_heads // num_key_value_heads
print(f'Number of Key-Value groups (Q heads per K-V head): {num_key_value_groups}')

Number of Key-Value groups (Q heads per K-V head): 4


# Rotary Position Embeddings (RoPE)

#### Defining Rotation Calculation Function - `simple_rope_calculation()`

In [21]:
def simple_rope_calculation(dim, max_seq_len, base=10000.0, device=None):
    
    # Please find the breakdown of whats happening in each line below in the next cell

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))  # LINE 1
    t = torch.arange(max_seq_len, device=device).type_as(inv_freq)  # LINE 2
    freqs = torch.outer(t, inv_freq)  # LINE 3
    emb = torch.cat((freqs, freqs), dim=1)  # LINE 4 
    
    # To create rotators for "cos(theta) + i*sin(theta)"
    freqs_cos = emb.cos()  # real part
    freqs_sin = emb.sin()  # imaginery part
    freqs_cis = torch.complex(freqs_cos, freqs_sin)  # LINE 5

    return freqs_cis
    

#### Breakdown of `simple_rope_calculation()`

In [22]:
dim = 8
device = None
base=10000.0
max_seq_len = 10
print(f'variables- dim:{dim}, device:{device}, base:{base}, max_seq_len:{max_seq_len}')

variables- dim:8, device:None, base:10000.0, max_seq_len:10


LINE 1

In [23]:
print('\nCreate a vector')
print(torch.arange(0, dim, 2, device=device).float())

# To normalise, we divide by the dimension
print('\nTo normalise, we divide by the dimension')
print(torch.arange(0, dim, 2, device=device).float() / dim)

# raising base to the power
print('\nRaising base to the power of the normalized vector')
print(base ** (torch.arange(0, dim, 2, device=device).float() / dim))

# divide 1 by all these
print('\nDivide 1 by all these (reciprocal)')
print(1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim)))

# full line
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))


Create a vector
tensor([0., 2., 4., 6.])

To normalise, we divide by the dimension
tensor([0.0000, 0.2500, 0.5000, 0.7500])

Raising base to the power of the normalized vector
tensor([   1.,   10.,  100., 1000.])

Divide 1 by all these (reciprocal)
tensor([1.0000, 0.1000, 0.0100, 0.0010])


LINE 2

In [24]:
t = torch.arange(max_seq_len, device=device).type_as(inv_freq)
print(f'\'t\' shows the position for each token: {t}')

't' shows the position for each token: tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])


LINE 3  

We find the outer product of `k` and `inv_freq`

In [25]:
freqs = torch.outer(t, inv_freq)
print(freqs)
print(freqs.shape)

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03],
        [4.0000e+00, 4.0000e-01, 4.0000e-02, 4.0000e-03],
        [5.0000e+00, 5.0000e-01, 5.0000e-02, 5.0000e-03],
        [6.0000e+00, 6.0000e-01, 6.0000e-02, 6.0000e-03],
        [7.0000e+00, 7.0000e-01, 7.0000e-02, 7.0000e-03],
        [8.0000e+00, 8.0000e-01, 8.0000e-02, 8.0000e-03],
        [9.0000e+00, 9.0000e-01, 9.0000e-02, 9.0000e-03]])
torch.Size([10, 4])


Rows represents the tokens, and the columns represents the pairs of embeddings for each token.  
*Example:*  
* *First token: (No rotation ) 0.0000e+00*
* *The second token:*
    * *first pair will rotate by 1.0000e+00*
    * *Second pair will rotate by 1.0000e-01*
    * ...

LINE 4  

We calulated only 4 frequencies for the rotations, but in reality we have `dim=8` embeddings so we need 8 angles. Hence, we achieve this by replicating the `freqs` matrix

In [26]:
emb = torch.cat((freqs, freqs), dim=1)
print(emb.shape)

torch.Size([10, 8])


LINE 5  

`freqs_cis` frequencies in the form of cosine + i*sine

In [27]:
# To create rotators for "cos(theta) + i*sin(theta)"
freqs_cos = emb.cos()  # real part
freqs_sin = emb.sin()  # imaginery part
freqs_cis = torch.complex(freqs_cos, freqs_sin)
print(freqs_cis)
print(freqs_cis.shape)

tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,
          1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j,
          0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],
        [-0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j,
         -0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j],
        [-0.9900+0.1411j,  0.9553+0.2955j,  0.9996+0.0300j,  1.0000+0.0030j,
         -0.9900+0.1411j,  0.9553+0.2955j,  0.9996+0.0300j,  1.0000+0.0030j],
        [-0.6536-0.7568j,  0.9211+0.3894j,  0.9992+0.0400j,  1.0000+0.0040j,
         -0.6536-0.7568j,  0.9211+0.3894j,  0.9992+0.0400j,  1.0000+0.0040j],
        [ 0.2837-0.9589j,  0.8776+0.4794j,  0.9988+0.0500j,  1.0000+0.0050j,
          0.2837-0.9589j,  0.8776+0.4794j,  0.9988+0.0500j,  1.0000+0.0050j],
        [ 0.9602-0.2794j,  0.8253+0.5646j,  0.9982+0.0600j,  1.0000+0.

#### Defining Rotations Function - `apply_rotary_emb_torch()`

In [28]:
def apply_rotary_emb_torch(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    
    # making sure "freqs_cis" is on the right device
    freqs_cis = freqs_cis.to(xq.device)

    # intead of having all the frequencies, we filter out the ones for the correspondings tokens we need.  
    freqs_cis = freqs_cis[position_ids]  # LINE 6

    freqs_cis = freqs_cis[:, None, :, :]  # LINE 7

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # LINE 8
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    freqs_cis_broadcast = freqs_cis[..., :xq_.shape[-1]]

    # Applying the rotations
    xq_rotated = xq_ * freqs_cis_broadcast
    xk_rotated = xq_ * freqs_cis_broadcast

    # Convert back to real representatinos
    # This flattens the last 2 dimensions back into one.
    # input: [batch, num_heads, seq_length, head_dim / 2 (4), 2] (complex)
    # output: [batch, num_heads, seq_length, 8]
    xq_out = torch.view_as_real(xq_rotated).flatten(3) 
    xk_out = torch.view_as_real(xk_rotated).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


#### Breakdown of `apply_rotary_emb_torch()`

LINE 6

In [29]:
print(f'freqs_cis shape: {freqs_cis.shape}')
print(f'position_ids shape: {position_ids.shape}')
print(f'freqs_cis[position_ids] shape: {freqs_cis[position_ids].shape}')
print(f'final shape: \n{freqs_cis[position_ids][0]}')

freqs_cis shape: torch.Size([10, 8])
position_ids shape: torch.Size([2, 10])
freqs_cis[position_ids] shape: torch.Size([2, 10, 8])
final shape: 
tensor([[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,
          1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
        [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j,
          0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],
        [-0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j,
         -0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j],
        [-0.9900+0.1411j,  0.9553+0.2955j,  0.9996+0.0300j,  1.0000+0.0030j,
         -0.9900+0.1411j,  0.9553+0.2955j,  0.9996+0.0300j,  1.0000+0.0030j],
        [-0.6536-0.7568j,  0.9211+0.3894j,  0.9992+0.0400j,  1.0000+0.0040j,
         -0.6536-0.7568j,  0.9211+0.3894j,  0.9992+0.0400j,  1.0000+0.0040j],
        [ 0.2837-0.9589j,  0.8776+0.4794j,  0.9988+0.0500j,  1.0000+0.0050j,
   

LINE 7
  
The Query and Key tensors have a shape of `[batch, num_heads, seq_len, head_dim]`. So, we add a dimension of `1` to `freqs_cis` to align with Q & K so we can broadcast them with this angles tensor.  

Can use `unsqueeze(1)` too

In [30]:
freqs_cis = freqs_cis[position_ids]
print(freqs_cis.shape)
freqs_cis = freqs_cis[:, None, :, :]
print(freqs_cis.shape)

torch.Size([2, 10, 8])
torch.Size([2, 1, 10, 8])


LINE 8  

Now, we reshapre the Query and Key tensors so that the pairs of numbers are treated as complex numbers **(a+ib)**.  
*Same process for "xk"*

In [31]:
# random xq for testing
xq = torch.randn(batch_size, num_attention_heads, sequence_length, head_dim)
xk = torch.randn(batch_size, num_attention_heads, sequence_length, head_dim)
print(xq)
print(f'{xq.shape}\n')


xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # keeps dimensions execpt the last one same, then splits the last one into two
print(xq)
print(xq.shape)

xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # keeps dimensions execpt the last one same, then splits the last one into two


tensor([[[[ 0.8011, -0.5517,  1.2441,  ..., -0.8717, -0.2942,  0.3337],
          [ 0.1762, -0.1274, -0.8008,  ..., -1.1549, -0.4158, -1.7951],
          [-2.3783, -0.9613, -2.1015,  ...,  1.3007, -1.0858,  0.3678],
          ...,
          [-0.5342,  0.1280, -0.4700,  ...,  0.9925,  0.6600,  1.5047],
          [ 0.6374,  1.0656,  1.0083,  ...,  1.0599,  1.4468, -0.8652],
          [ 0.0680,  0.1107, -1.6959,  ...,  0.9235,  1.0702,  0.6899]],

         [[ 1.5018,  0.9565, -1.0583,  ...,  0.8433, -0.3213,  1.2631],
          [ 0.7975, -0.1302, -0.2862,  ...,  1.8829,  0.0470,  1.5126],
          [-0.1961,  0.0146,  0.3894,  ..., -0.9640,  0.0291, -0.8071],
          ...,
          [-0.4551, -0.5375,  1.8866,  ...,  0.1935, -0.1347, -0.1596],
          [ 1.0800, -2.1215, -1.3935,  ..., -0.1580, -0.9938, -1.1936],
          [ 0.4865,  0.2801, -0.1601,  ...,  0.4223,  0.5414,  0.1586]],

         [[ 0.5211,  0.2147,  1.4402,  ...,  1.8647, -1.8323,  1.3069],
          [ 0.3001,  0.2398, -

*8 has changed to 4 because we divided the deimension into pairs*

LINE 9  

Previsously, we calculated the `freqs_cis` tensor by using `head_dim` sines and cosines. However, xq's last dimension is now 4 (*Because each pair of numbers become a single complex number*).  
To multiply them, we need freq_cis in the same dimesnions too.

In [32]:
print(f'initially: {freqs_cis.shape}')

freqs_cis_broadcast = freqs_cis[..., :xq_.shape[-1]]
print(freqs_cis_broadcast)
print(freqs_cis_broadcast.shape)

initially: torch.Size([2, 1, 10, 8])
tensor([[[[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
          [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],
          [-0.4161+0.9093j,  0.9801+0.1987j,  0.9998+0.0200j,  1.0000+0.0020j],
          [-0.9900+0.1411j,  0.9553+0.2955j,  0.9996+0.0300j,  1.0000+0.0030j],
          [-0.6536-0.7568j,  0.9211+0.3894j,  0.9992+0.0400j,  1.0000+0.0040j],
          [ 0.2837-0.9589j,  0.8776+0.4794j,  0.9988+0.0500j,  1.0000+0.0050j],
          [ 0.9602-0.2794j,  0.8253+0.5646j,  0.9982+0.0600j,  1.0000+0.0060j],
          [ 0.7539+0.6570j,  0.7648+0.6442j,  0.9976+0.0699j,  1.0000+0.0070j],
          [-0.1455+0.9894j,  0.6967+0.7174j,  0.9968+0.0799j,  1.0000+0.0080j],
          [-0.9111+0.4121j,  0.6216+0.7833j,  0.9960+0.0899j,  1.0000+0.0090j]]],


        [[[ 1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j,  1.0000+0.0000j],
          [ 0.5403+0.8415j,  0.9950+0.0998j,  0.9999+0.0100j,  1.0000+0.0010j],

Now, we have the rotations for 4 pairs in each token.  

**NOTE:** *The slicing step is necessary because we duplicated the frequencies in the `simple_rope_calculation()` function. We do this to understand the existing implementation of transformers.*

#### Calculating RoPE Frequencies

In [33]:
# pre-computed frequencies depending solely on the position, and not the token.
freqs_cis = simple_rope_calculation(head_dim, max_positional_embeddings, base=rope_theta, device=hidden_states.device)
print(f'Calculated freqs_cis shape: {freqs_cis.shape}')

Calculated freqs_cis shape: torch.Size([256, 8])


#### Applying RoPE Rotations

In [34]:
query_states_rope, key_states_rope = apply_rotary_emb_torch(query_states, key_states, freqs_cis)

print('Shapes after RoPE:')
print(f'  query_states_rope: {query_states_rope.shape}')
print(f'  key_states_rope: {key_states_rope.shape}')

Shapes after RoPE:
  query_states_rope: torch.Size([2, 16, 10, 8])
  key_states_rope: torch.Size([2, 16, 10, 8])


# Normalization

In [35]:
class SimpleL2Norm(nn.Module):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps # to avoid div by 0 during normalization

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)