## Llama 2 modules

- **Rotary positional embeddings** -> applied to the computed Q and K vectors in the self_attention part
- **Grouped Query Attention** -> Tradeoff between Mutli-Query attention and MHA, balances memory bandwidth requirements and speedup
- **KV Caching** -> For faster computation and better memory management
- **SwiGLU activation function**
- **RMS Norm**

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

### Rotary positional embeddings

In [143]:
def precompute_theta_pos_freq(head_dim, seq_len, theta = 10000):

    assert head_dim%2 == 0, "Dimension of head must by divisible by 2"

    # theta_i = 10000^(-2(i-1)/dim) for i = [1,2,3.....dim/2]

    theta_numerator = torch.arange(0, head_dim, 2).float()

    theta = 1.0/ (theta **(theta_numerator/head_dim))

    m = torch.arange(seq_len)

    freqs = torch.outer(m,theta).float()

    #                                   magnituda       angle
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs) 

    return freqs_complex


In [144]:
def apply_rotary_embeds(x, freqs_complex):

    # Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number
    # Two consecutive values will become a single complex number


    # H -> no.of heads; can be num_heads for Query and num_kv_heads for Key

    # (bsz, seq_len, H , head_dim) -> (bsz, seq_len, H, head_dim/2)

    # (bsz, seq_len, H , head_dim) -> (bsz, seq_len, H, head_dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))


    # Reshape the freqs_complex tensor to match the shape of the x_complex tensor. 
    # (seq_len, head_dim/2) --> (1, seq_len, 1, head_dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)


    # (bsz, seq_len, H, head_dim/2) * (1, seq_len, 1, head_dim/2) = (bsz, seq_len, H, head_dim/2)
    x_rotated = x_complex * freqs_complex


    # Convert the complex number back to the real number
    # (bsz, seq_len, H, head_dim/2) -> (bsz, seq_len, H, head_dim/2, 2)
    x_out = torch.view_as_real(x_rotated)


    # (bsz, seq_len, H, head_dim/2, 2) -> (bsz, seq_len, H, head_dim)
    x_out = x_out.reshape(*x.shape)
    

    return x_out.type_as(x)

In [5]:
class SelfAttention(nn.Module):

    def __init__(self, n_heads ,embed_dim, ):
        super().__init__()

        # Indicates the number of heads for the Queries
        self.num_heads = n_heads

        # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for
        self.head_dim = embed_dim // n_heads

        self.wq = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(embed_dim, n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, embed_dim, bias=False)


    def precompute_theta_pos_freq(self, head_dim, seq_len, theta = 10000):

        assert head_dim%2 == 0, "Dimension of head must by divisible by 2"

        # theta_i = 10000^(-2(i-1)/dim) for i = [1,2,3.....dim/2]

        theta_numerator = torch.arange(0, head_dim, 2).float()

        theta = 1.0/ (theta **(theta_numerator/head_dim))

        m = torch.arange(seq_len)

        freqs = torch.outer(m,theta).float()

        #                                   magnituda       angle
        freqs_complex = torch.polar(torch.ones_like(freqs), freqs) 

        return freqs_complex
    

    def apply_rotary_embeds(self, x, freqs_complex):

        # Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number
        # Two consecutive values will become a single complex number


        # H -> no.of heads; can be num_heads for Query and num_kv_heads for Key

        # (bsz, seq_len, H , head_dim) -> (bsz, seq_len, H, head_dim/2)
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))


        # (seq_len, head_dim/2) --> (1, seq_len, 1, head_dim/2)
        freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)



        # (bsz, seq_len, H, head_dim/2) * (1, seq_len, 1, head_dim/2) = (bsz, seq_len, H, head_dim/2)
        x_rotated = x_complex * freqs_complex


        # Convert the complex number back to the real number
        # (bsz, seq_len, H, head_dim/2) -> (bsz, seq_len, H, head_dim/2, 2)
        x_out = torch.view_as_real(x_rotated)

        # (bsz, seq_len, H, head_dim/2, 2) -> (bsz, seq_len, H, head_dim)
        x_out = x_out.reshape(*x.shape)
        
        return x_out.type_as(x)



    def forward(
        self,
        x,
        freqs_complex,
        attn_mask
    ):                                # While inferencing:- 
        bsz, seq_len, _ = x.shape   # (bsz, 1, embed_dim)

        Q = self.wq(x)
        K = self.wk(x)
        V = self.wv(x)


        if attn_mask is not None:
        # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (seq_len, seq_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, seq_len, seq_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")




        Q = Q.view(bsz, seq_len, self.n_heads, self.head_dim)
        K = K.view(bsz, seq_len, self.n_heads, self.head_dim)
        V = V.view(bsz, seq_len, self.n_heads, self.head_dim)

        Q = self.apply_rotary_embeddings(Q, freqs_complex)
        K = self.apply_rotary_embeddings(K, freqs_complex)

        # (bsz, 1, H_Q, Head_Dim) -> (bsz, H_Q, 1, Head_Dim)
        Q = Q.transpose(1, 2)
        # (bsz, Seq_Len_KV, H_Q, Head_Dim) -> (bsz, H_Q, Seq_Len_KV, Head_Dim)
        K = K.transpose(1, 2)
        # (bsz, Seq_Len_KV, H_Q, Head_Dim) -> (bsz, H_Q, Seq_Len_KV, Head_Dim)
        V = V.transpose(1, 2)

        # (bsz, H_Q, 1, Head_Dim) @ (bsz, H_Q, Head_Dim, Seq_Len_KV) -> (bsz, H_Q, 1, Seq_Len_KV)
        scores = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_mask is not None:
            scores = scores + attn_mask 

        # (bsz, H_Q, 1, Seq_Len_KV) -> (bsz, H_Q, 1, Seq_Len_KV)
        scores = F.softmax(scores.float(), dim=-1).type_as(Q)

        # (bsz, H_Q, 1, Seq_Len) @ (bsz, H_Q, Seq_Len_KV, Head_Dim) -> (bsz, H_Q, 1, Head_Dim)
        output = torch.matmul(scores, V)
        # (bsz, H_Q, 1, Head_Dim) -> (bsz, 1, H_Q, Head_Dim) -> (bsz, 1, Dim)
        output = (output.transpose(1, 2).contiguous().view(bsz, seq_len, -1))
        return self.wo(output) # (bsz, 1, Dim) -> (bsz, 1, Dim)

        

In [145]:
# freqs_complex = precompute_thaeta_pos_freq(head_dim = 4, seq_len = 4, theta = 10000)

### Huggingface implementation

In [129]:
# def rotate_half(x):
#     """Rotates half the hidden dims of the input."""
#     x1 = x[..., : x.shape[-1] // 2]
#     x2 = x[..., x.shape[-1] // 2 :]
    
#     return torch.cat((-x2, x1), dim=-1)

# def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
#     """Applies Rotary Position Embedding to the query and key tensors.

#     Args:
#         q (`torch.Tensor`): The query tensor.
#         k (`torch.Tensor`): The key tensor.
#         cos (`torch.Tensor`): The cosine part of the rotary embedding.
#         sin (`torch.Tensor`): The sine part of the rotary embedding.
#         position_ids (`torch.Tensor`):
#             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
#             used to pass offsetted position ids when working with a KV-cache.
#         unsqueeze_dim (`int`, *optional*, defaults to 1):
#             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
#             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
#             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
#             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
#             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
#             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
#     Returns:
#         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
#     """
#     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
#     sin = sin[position_ids].unsqueeze(unsqueeze_dim)

#     q_embed = (q * cos) + (rotate_half(q) * sin)
#     k_embed = (k * cos) + (rotate_half(k) * sin)

#     return q_embed, k_embed

## RoPE implementation (LLama2 implementation)

In [134]:
def get_sin_cos(dim, seq_len, max_seq_len, base = 10000):

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

    t = torch.arange(max_seq_len, dtype=torch.int64).type_as(inv_freq)

    freqs = torch.outer(t, inv_freq)
    
    # Uses a different permutation in order to obtain the same calculation
    emb = torch.cat((freqs, freqs), dim=-1)    

    return  emb.cos()[:seq_len], emb.sin()[:seq_len]


In [147]:
dim = 4

max_seq_len = 2048
seq_len = 4

num_heads = 2
num_kv_heads = num_heads

cos, sin = get_sin_cos(dim, seq_len, max_seq_len, base = 10000)

In [148]:
cos, sin, cos.shape, sin.shape

(tensor([[ 1.0000,  1.0000,  1.0000,  1.0000],
         [ 0.5403,  0.9999,  0.5403,  0.9999],
         [-0.4161,  0.9998, -0.4161,  0.9998],
         [-0.9900,  0.9996, -0.9900,  0.9996]]),
 tensor([[0.0000, 0.0000, 0.0000, 0.0000],
         [0.8415, 0.0100, 0.8415, 0.0100],
         [0.9093, 0.0200, 0.9093, 0.0200],
         [0.1411, 0.0300, 0.1411, 0.0300]]),
 torch.Size([4, 4]),
 torch.Size([4, 4]))

In [149]:
def rotate_half(x):

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]

    # print(x1.shape, x2.shape)
    # x1 = x[ : , : x.shape[-1] // 2]
    # x2 = x[ : , x.shape[-1] // 2 :]

    return torch.cat((-x2, x1), dim=-1)

In [150]:
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # print("HLAF ROT SHAPES = ")
    # print((rotate_half(q).shape,   sin.shape))
    # print((rotate_half(k).shape , sin.shape))

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

In [139]:
import torch

# query =   torch.tensor([[[-0.0437,  0.0253, -0.0413,  0.0103],
#          [-0.0459, -0.0178,  0.0348, -0.0220],
#          [ 0.0150, -0.0052, -0.0289,  0.0082],
#          [ 0.0364, -0.0503,  0.0622, -0.0376]]])

# key =  torch.tensor([[[ 0.0332,  0.0166,  0.0118, -0.0467],
#          [ 0.0217,  0.0474, -0.0396, -0.0136],
#          [ 0.0500,  0.0030,  0.0491, -0.0210],
#          [ 0.0077,  0.0465, -0.0323,  0.0301]]])

# value =  torch.tensor([[[ 0.0046, -0.0319, -0.0447, -0.0426],
#          [-0.0408, -0.0341, -0.0696, -0.0315],
#          [ 0.0233,  0.0008, -0.0235, -0.0119],
#          [-0.0386, -0.0073, -0.0242,  0.0268]]])


query =  torch.tensor([[[ 0.0504,  0.0288,  0.0344,  0.0388, -0.0436, -0.0319, -0.0429,
          -0.0098],
         [ 0.1113,  0.0207,  0.0186, -0.0005, -0.0560, -0.0844,  0.0346,
           0.0186],
         [-0.1298, -0.0652, -0.0048,  0.0499,  0.0185,  0.0948, -0.0852,
          -0.0162],
         [-0.0612,  0.0323, -0.0151,  0.0323,  0.0426,  0.0375, -0.0656,
           0.0103]]])

key = torch.tensor([[[ 2.4886e-02,  7.6303e-02, -5.6346e-02, -2.4803e-02,  1.1542e-01,
          -7.3159e-02, -3.0394e-02, -5.5170e-02],
         [ 1.7044e-02, -5.8096e-02,  1.6804e-02, -2.4497e-03,  8.4145e-02,
          -3.4248e-02, -3.9195e-02, -2.2824e-02],
         [-5.4661e-02, -7.4460e-03,  8.9877e-03,  1.3472e-02, -5.5859e-02,
          -8.4246e-05, -1.3683e-02, -1.8631e-03],
         [-2.3927e-02,  5.0750e-02, -6.1120e-03, -5.3097e-03, -2.8093e-02,
           1.1956e-01, -3.7372e-02, -3.3173e-02]]])

value = torch.tensor([[[-0.0186,  0.0089,  0.0408, -0.0983, -0.0562, -0.1053, -0.0689,
          -0.0754],
         [-0.0603,  0.0383,  0.0621,  0.0291,  0.0002, -0.0553, -0.0520,
          -0.1110],
         [ 0.0389, -0.0035, -0.0284, -0.0129, -0.0393,  0.0243, -0.0281,
           0.0394],
         [-0.0880, -0.1116, -0.0291, -0.1508, -0.0056, -0.0027,  0.0898,
           0.1074]]])



bsz, q_len, _ = query.shape

# num_heads = 1

# head_dim = 4

query = query.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key = key.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value = value.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)


query.shape, key.shape, value.shape


(torch.Size([1, 1, 4, 4]), torch.Size([1, 1, 4, 4]), torch.Size([1, 1, 4, 4]))

In [140]:
cos = cos.unsqueeze(0)
sin = sin.unsqueeze(0)
cos.shape, sin.shape


(torch.Size([1, 4, 4]), torch.Size([1, 4, 4]))

In [141]:
apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=1)

(tensor([[[[-0.0437,  0.0253, -0.0413,  0.0103],
           [-0.0541, -0.0176, -0.0198, -0.0222],
           [ 0.0200, -0.0054,  0.0257,  0.0081],
           [-0.0448, -0.0491, -0.0564, -0.0391]]]]),
 tensor([[[[ 0.0332,  0.0166,  0.0118, -0.0467],
           [ 0.0450,  0.0475, -0.0031, -0.0131],
           [-0.0655,  0.0034,  0.0250, -0.0209],
           [-0.0031,  0.0456,  0.0331,  0.0315]]]]))

In [142]:
def LlamaRMSNorm(hidden_states, wt, variance_epsilon = 1e-6):

    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return wt * hidden_states



In [None]:
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


In [42]:


class FeedForward(nn.Module):
    def __init__(
        self,
        embed_dim, 

    ):
        super().__init__()

        hidden_dim = 4 * embed_dim
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round the hidden_dim to the nearest multiple of the multiple_of parameter
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        swish = F.silu(self.w1(x))
        # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim)
        x_V = self.w3(x)
        # (B, Seq_Len, Hidden_Dim) * (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Hidden_Dim)
        x = swish * x_V
        # (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Dim)
        x = self.w2(x)
        return x


In [None]:
class TransformerBlock(nn.Module):

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            bias=bias, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)




    def __init__(self, n_heads, embed_dim, norm_eps = 1e-6, ):
        super().__init__()

        self.n_heads = n_heads
        self.embed_dim = embed_dim
        self.head_embed_dim = embed_dim // n_heads

        self.attention = SelfAttention()
        self.feed_forward = FeedForward()

        # Normalization BEFORE the attention block
        self.attention_norm = RMSNorm(embed_dim, eps=norm_eps)
        # Normalization BEFORE the feed forward block
        self.ffn_norm = RMSNorm(embed_dim, eps=norm_eps)
    
    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        # (B, Seq_Len, embed_dim) + (B, Seq_Len, embed_dim) --> (B, Seq_Len, embed_dim)
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_complex
        )
        # (B, Seq_Len, embed_dim) + (B, Seq_Len, embed_dim) --> (B, Seq_Len, embed_dim)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out