### Additions in llama-2 model compared ot the original transformer:- 

1. Rotary Positional Embeddings
2. SwiGLU activation function
3. RMSProp
4. KV Caching 
5. Grouped Query attention 

In [11]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

## RMS Norm

In [None]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (bsz, seq_len, dim) * (bsz, seq_len, 1) = (bsz, seq_len, dim)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (dim) * (bsz, seq_len, dim) = (bsz, seq_len, dim)
        return self.weight * self._norm(x.float()).type_as(x)

## Rotary Embeddings for positional encoding

In [46]:
import torch


def precompute_theta_pos_freq(head_dim, seq_len, theta = 10000):

    assert head_dim%2 == 0, "Dimension of head must by divissible by 2"

    # theta_i = 10000^(-2(i-1)/dim) for i = [1,2,3.....dim/2]

    theta_numerator = torch.arange(0, head_dim, 2).float()

    theta = 1.0/ (theta **(theta_numerator/head_dim))

    m = torch.arange(seq_len)

    freqs = torch.outer(m,theta).float()

    #                                   magnitude       angle
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs) 

    return freqs_complex


In [47]:
def apply_rotary_embeds(x, freqs_complex):

    # Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number
    # Two consecutive values will become a single complex number


    # H -> no.of heads; can be num_heads for Query and num_kv_heads for Key

    # (bsz, seq_len, H , head_dim) -> (bsz, seq_len, H, head_dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))


    # Reshape the freqs_complex tensor to match the shape of the x_complex tensor. 
    # (seq_len, head_dim/2) --> (1, seq_len, 1, head_dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)


    # Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor
    # Which results in the rotation of the complex number as shown in the Figure 1 of the paper
    # (bsz, seq_len, H, head_dim/2) * (1, seq_len, 1, head_dim/2) = (bsz, seq_len, H, head_dim/2)
    x_rotated = x_complex * freqs_complex


    # Convert the complex number back to the real number
    # (bsz, seq_len, H, head_dim/2) -> (bsz, seq_len, H, head_dim/2, 2)
    x_out = torch.view_as_real(x_rotated)


    # (bsz, seq_len, H, head_dim/2, 2) -> (bsz, seq_len, H, head_dim)
    x_out = x_out.reshape(*x.shape)
    

    return x_out.type_as(x)

In [70]:
freqs_complex = precompute_theta_pos_freq(head_dim = 2, seq_len = 3, theta = 10000)
freqs_complex

# head_dim = embed_dim // n_heads
# here, 8//4 = 2

tensor([[ 1.0000+0.0000j],
        [ 0.5403+0.8415j],
        [-0.4161+0.9093j]])

In [71]:
x = torch.rand(1,3,4,2)

# batch_size, seq_len, n_kv_heads, head_dim
x.shape, x

(torch.Size([1, 3, 4, 2]),
 tensor([[[[0.8417, 0.5510],
           [0.0214, 0.3450],
           [0.3619, 0.2223],
           [0.6088, 0.2576]],
 
          [[0.0726, 0.3607],
           [0.2889, 0.7267],
           [0.5482, 0.7990],
           [0.6860, 0.6757]],
 
          [[0.6538, 0.7353],
           [0.9691, 0.1316],
           [0.7166, 0.7213],
           [0.0201, 0.1098]]]]))

In [72]:
apply_rotary_embeds(x, freqs_complex)

tensor([[[[ 0.8417,  0.5510],
          [ 0.0214,  0.3450],
          [ 0.3619,  0.2223],
          [ 0.6088,  0.2576]],

         [[-0.2643,  0.2560],
          [-0.4555,  0.6357],
          [-0.3761,  0.8930],
          [-0.1979,  0.9423]],

         [[-0.9407,  0.2885],
          [-0.5230,  0.8264],
          [-0.9541,  0.3514],
          [-0.1082, -0.0274]]]])

## Grouped Query Attention

Without KV caching

In [62]:
# Function to expand the vector 'x' for grouped query attention

def repeat_kv(x, n_rep):

    batch_size, seq_len, n_kv_heads, head_dim = x.shape

    if n_rep == 1:
        return x
    
    else:
        # (bsz, seq_len, n_kv_heads, 1, head_dim)
        # --> (bsz, seq_len, n_kv_heads, n_rep, head_dim)
        # --> (bsz, seq_len, n_kv_heads * n_rep, head_dim)
        return (
            x[:, :, :, None, :]
            .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
            .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
        )


# Grouped query attention 
def GQ_attention_fwd(x, n_heads, n_kv_heads, embed_dim):

    n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads

    n_heads_q = n_heads

    n_rep = n_heads_q//n_kv_heads

    head_dim = embed_dim//n_heads

    Wq = nn.Linear(embed_dim, n_heads*head_dim, bias=False)
    Wk = nn.Linear(embed_dim, n_kv_heads*head_dim, bias=False)
    Wv = nn.Linear(embed_dim, n_kv_heads*head_dim, bias=False)
    Wo = nn.Linear(n_heads*head_dim, embed_dim, bias=False)

    batch_size, seq_len, _ = x.shape

    # (bsz, seq_len, embed_dim)
    xq = Wq(x)

    # (bsz, seq_len, h_kv * head_dim)
    xk = Wk(x)
    xv = Wv(x)

    # (bsz, seq_len, n_heads, head_dim)
    xq = xq.view(batch_size, seq_len, n_heads_q, head_dim)

    # (bsz, seq_len, h_kv, head_dim)
    xk = xk.view(batch_size, seq_len, n_kv_heads, head_dim)
    xv = xv.view(batch_size, seq_len, n_kv_heads, head_dim)


    print("Before applying Rotary embeddings :- ")
    print("Q = ", xq)
    print(xq.shape)
    print()
    print("K = ", xk)
    print(xk.shape)
    print()

    ##################################
    ### Applying rotary embeddings ###

    freqs_complex = precompute_theta_pos_freq(head_dim = head_dim, seq_len = seq_len, theta = 10000)

    # (bsz, seq_len, n_heads, head_dim) -> (bsz, seq_len, n_heads, head_dim)
    xq = apply_rotary_embeds(xq, freqs_complex)

    # (bsz, seq_len, n_kv_heads, head_dim) -> (bsz, seq_len, n_kv_heads, head_dim)
    xk = apply_rotary_embeds(xk, freqs_complex)

    #####################################

    print("After applying Rotary embeddings :- ")
    print("Q = ", xq)
    print(xq.shape)
    print()
    print("K = ", xk)
    print(xk.shape)
    print()


    keys = repeat_kv(xk, n_rep)
    values = repeat_kv(xv, n_rep)

    print("Keys and Values after repeating for GQA")
    print("keys = ",keys.shape,keys)
    print()
    print("values = ",values.shape,values)
    print()

    xq = xq.transpose(1, 2)

    # (bsz, n_heads, seq_len, head_dim)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)

    # (bsz, n_heads, seq_len_q, head_dim) MATMUL (bsz, n_heads, head_dim, seq_len) -> (bsz, n_heads, seq_len_q, seq_len)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)

    # (bsz, n_heads, seq_len_q, seq_len)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    print("Attention scores = ", scores)
    print()

    # (bsz, n_heads, seq_len_q, seq_len) MATMUL (bsz, n_heads, seq_len, head_dim) -> (bsz, n_heads, seq_len_q, head_dim)
    output = torch.matmul(scores, values)

    # ((bsz, n_heads, seq_len_q, head_dim) -> (bsz, seq_len_q, dim)
    output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
    print("Attention values = ",output)
    print()

    # (bsz, seq_len_q, dim)
    return Wo(output)






In [32]:
x = torch.rand(1,3,8)

# batch_size, seq_len, embedding_dim
x.shape, x

(torch.Size([1, 3, 8]),
 tensor([[[0.3885, 0.6270, 0.3206, 0.4409, 0.3074, 0.2796, 0.3976, 0.3765],
          [0.2976, 0.8556, 0.0485, 0.8291, 0.5157, 0.7897, 0.8054, 0.7738],
          [0.3322, 0.6595, 0.4512, 0.7347, 0.6833, 0.5190, 0.5468, 0.0774]]]))

In [63]:
atten_values = GQ_attention_fwd(x, n_heads = 4, n_kv_heads = 2, embed_dim = 8)

Before applying Rotary embeddings :- 
Q =  tensor([[[[-0.1299,  0.2407],
          [-0.1696,  0.0881],
          [-0.1522, -0.0535],
          [-0.5082, -0.1717]],

         [[ 0.0579,  0.1978],
          [-0.2157,  0.2064],
          [-0.0879,  0.0266],
          [-0.4180,  0.3014]],

         [[-0.0530,  0.4617],
          [-0.1664,  0.1750],
          [-0.4898, -0.0712],
          [-0.6477,  0.2386]]]], grad_fn=<ViewBackward0>)
torch.Size([1, 3, 4, 2])

K =  tensor([[[[-0.1565, -0.1516],
          [-0.3590,  0.3951]],

         [[ 0.1555, -0.3181],
          [-0.1116,  0.2460]],

         [[-0.0209, -0.1935],
          [-0.2571,  0.4130]]]], grad_fn=<ViewBackward0>)
torch.Size([1, 3, 2, 2])

After applying Rotary embeddings :- 
Q =  tensor([[[[-0.1299,  0.2407],
          [-0.1696,  0.0881],
          [-0.1522, -0.0535],
          [-0.5082, -0.1717]],

         [[-0.1351,  0.1556],
          [-0.2903, -0.0700],
          [-0.0698, -0.0596],
          [-0.4795, -0.1889]],

         [

In [53]:
# Projected attetion values
atten_values

tensor([[[ 0.0505,  0.1723, -0.0181, -0.3009, -0.0628, -0.0622, -0.0380,
          -0.1810],
         [ 0.0535,  0.1749, -0.0165, -0.3008, -0.0634, -0.0638, -0.0403,
          -0.1846],
         [ 0.0523,  0.1754, -0.0178, -0.2965, -0.0640, -0.0633, -0.0405,
          -0.1850]]], grad_fn=<UnsafeViewBackward0>)

In [37]:
# repeat_kv :- Operations and intermediate outputs

import torch

batch_size = 1
seq_len = 3
n_kv_heads = 2
n_rep = 2
head_dim = 4

# random 4D tensor
x = torch.rand(batch_size, seq_len, n_kv_heads, head_dim)

result = (
    x[:, :, :, None, :]
    .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
    .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
)

# Displaying shapes and values
print("Original Tensor Shape:", x.shape)
print()

print("Original Tensor :", x)
print()

print("x[:, :, :, None, :] = ", x[:, :, :, None, :])
print()

print(".expand = ", x[:, :, :, None, :].expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim))
print()

print("Result Tensor Values:\n", result)


Original Tensor Shape: torch.Size([1, 3, 2, 4])

Original Tensor : tensor([[[[0.3553, 0.5873, 0.9951, 0.3988],
          [0.3021, 0.7420, 0.4973, 0.3956]],

         [[0.6649, 0.3215, 0.3432, 0.2998],
          [0.2185, 0.9958, 0.5237, 0.6992]],

         [[0.1545, 0.1401, 0.7353, 0.0985],
          [0.5105, 0.2776, 0.4774, 0.7729]]]])

x[:, :, :, None, :] =  tensor([[[[[0.3553, 0.5873, 0.9951, 0.3988]],

          [[0.3021, 0.7420, 0.4973, 0.3956]]],


         [[[0.6649, 0.3215, 0.3432, 0.2998]],

          [[0.2185, 0.9958, 0.5237, 0.6992]]],


         [[[0.1545, 0.1401, 0.7353, 0.0985]],

          [[0.5105, 0.2776, 0.4774, 0.7729]]]]])

.expand =  tensor([[[[[0.3553, 0.5873, 0.9951, 0.3988],
           [0.3553, 0.5873, 0.9951, 0.3988]],

          [[0.3021, 0.7420, 0.4973, 0.3956],
           [0.3021, 0.7420, 0.4973, 0.3956]]],


         [[[0.6649, 0.3215, 0.3432, 0.2998],
           [0.6649, 0.3215, 0.3432, 0.2998]],

          [[0.2185, 0.9958, 0.5237, 0.6992],
           [0.2

In [1]:
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)

            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
            
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj


In [73]:
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


NameError: name 'LlamaConfig' is not defined

### LLama Configurations 

- vocab_size (int, optional, defaults to 32000) — Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the inputs_ids passed when calling LlamaModel

- hidden_size (int, optional, defaults to 4096) — Dimension of the hidden representations.

- intermediate_size (int, optional, defaults to 11008) — Dimension of the MLP representations.

- num_hidden_layers (int, optional, defaults to 32) — Number of hidden layers in the Transformer decoder.

- num_attention_heads (int, optional, defaults to 32) — Number of attention heads for each attention layer in the Transformer decoder.

- num_key_value_heads (int, optional) — This is the number of key_value heads that should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, the model will use Multi Head Attention (MHA), if num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to num_attention_heads`.

- hidden_act (str or function, optional, defaults to "silu") — The non-linear activation function (function or string) in the decoder.

- max_position_embeddings (int, optional, defaults to 2048) — The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, Llama 2 up to 4096, CodeLlama up to 16384.

- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

- rms_norm_eps (float, optional, defaults to 1e-06) — The epsilon used by the rms normalization layers.

- use_cache (bool, optional, defaults to True) — Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.

- pad_token_id (int, optional) — Padding token id.
- bos_token_id (int, optional, defaults to 1) — Beginning of stream token id.
- eos_token_id (int, optional, defaults to 2) — End of stream token id.

- pretraining_tp (int, optional, defaults to 1) — Experimental feature. Tensor parallelism rank used during pretraining. Please refer to this document to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to this issue.

- tie_word_embeddings (bool, optional, defaults to False) — Whether to tie weight embeddings

- rope_theta (float, optional, defaults to 10000.0) — The base period of the RoPE embeddings.
- rope_scaling (Dict, optional) — Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is {"type": strategy name, "factor": scaling factor}. When using this flag, don’t update max_position_embeddings to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions.

- attention_bias (bool, defaults to False, optional, defaults to False) — Whether to use a bias in the query, key, value and output projection layers during self-attention.
- attention_dropout (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
