## Llama 2 modules

- **Rotary positional embeddings** -> applied to the computed Q and K vectors in the self_attention part
- **Grouped Query Attention** -> Tradeoff between Multi-Query Attention (MQA) and Multi-Head Attention (MHA); balances memory bandwidth requirements and speedup
- **KV Caching** -> For faster computation and better memory management
- **SwiGLU activation function**
- **RMS Norm**

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

### Rotary positional embeddings

In [2]:
def precompute_theta_pos_freq(head_dim, seq_len, theta=10000, device=None):
    """Precompute the complex rotation factors used by rotary embeddings.

    For every position m in [0, seq_len) and every pair index
    i in [1, head_dim/2] the rotation angle is m * theta_i, where

        theta_i = theta ** (-2 * (i - 1) / head_dim)

    Each angle is returned as the unit-magnitude complex number
    cos(angle) + 1j * sin(angle).

    Args:
        head_dim: Size of one attention head; must be even because the
            values are rotated in pairs.
        seq_len: Number of positions to precompute factors for.
        theta: Base of the geometric frequency progression.
        device: Optional device on which to create the tensors. ``None``
            keeps the PyTorch default device (the previous behaviour).

    Returns:
        Complex tensor of shape (seq_len, head_dim // 2).

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    assert head_dim%2 == 0, "Dimension of head must by divisible by 2"

    # Exponent numerators 2*(i-1) for i = 1 .. head_dim/2.
    exponents = torch.arange(0, head_dim, 2, device=device).float()

    # Kept under its own name so the `theta` argument is not shadowed.
    inv_freq = 1.0 / (theta ** (exponents / head_dim))

    # Built directly as float32 so the outer product does not depend on
    # implicit int64 -> float32 promotion.
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)

    # angles[m, i] = m * theta_i, shape (seq_len, head_dim/2).
    angles = torch.outer(positions, inv_freq)

    #                 magnitude                 angle
    return torch.polar(torch.ones_like(angles), angles)


In [3]:
def apply_rotary_embeds(x, freqs_complex):
    """Rotate consecutive value pairs of ``x`` by precomputed complex factors.

    Each pair of adjacent values along the last dimension of ``x`` is read
    as one complex number (real, imaginary) and multiplied by the matching
    entry of ``freqs_complex``; the product is unpacked back into a real
    tensor with the shape and dtype of ``x``.

    Shapes below follow the layout used by this notebook, where H is the
    number of heads (num_heads for Query, num_kv_heads for Key):
        x:             (bsz, seq_len, H, head_dim)
        freqs_complex: (seq_len, head_dim/2)
        result:        (bsz, seq_len, H, head_dim)
    """
    original_shape = x.shape

    # Group the last dimension into (real, imag) pairs, computed in float32:
    # (bsz, seq_len, H, head_dim) -> (bsz, seq_len, H, head_dim/2, 2)
    paired = x.float().reshape(*original_shape[:-1], -1, 2)

    # (bsz, seq_len, H, head_dim/2, 2) -> complex (bsz, seq_len, H, head_dim/2)
    as_complex = torch.view_as_complex(paired)

    # Insert singleton batch and head axes so the factors broadcast:
    # (seq_len, head_dim/2) -> (1, seq_len, 1, head_dim/2)
    rotation = freqs_complex[None, :, None]

    # Complex multiplication by a unit-magnitude factor is a rotation
    # (Figure 1 of the paper), then unpack to real pairs:
    # (bsz, seq_len, H, head_dim/2) -> (bsz, seq_len, H, head_dim/2, 2)
    rotated_pairs = torch.view_as_real(as_complex * rotation)

    # Flatten the pairs back: (..., head_dim/2, 2) -> (..., head_dim),
    # and restore the caller's dtype.
    return rotated_pairs.reshape(*original_shape).type_as(x)