# Overview

Llama2, like the original Llama model, is based on the transformer architecture introduced by Google, with several improvements. Llama's improvements include:

* RMSNorm pre-normalization, inspired by GPT-3
* SwiGLU activation function, inspired by Google's PaLM
* Rotary positional embeddings (RoPE), inspired by GPT-Neo

Llama was trained with the [AdamW](https://arxiv.org/abs/1711.05101) optimizer. Llama2's primary differences from Llama are increased context length (4096 vs. 2048 tokens) and [grouped-query attention (GQA)](https://arxiv.org/abs/2305.13245) instead of standard multi-head attention in the two larger models. GQA generalizes [multi-query attention (MQA)](https://arxiv.org/abs/1911.02150), in which all query heads share a single key/value head.
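Multi-head, multi-query, and grouped-query attention differ only in how many key/value heads the query heads share. The following is a minimal PyTorch sketch under simplifying assumptions: the query/key/value projections and the causal mask are omitted, and the helper names `repeat_kv` and `grouped_query_attention` are illustrative, not taken from a specific codebase.

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, seq_len, n_kv_heads, head_dim) to
    (batch, seq_len, n_kv_heads * n_rep, head_dim) so that each group of
    n_rep consecutive query heads sees the same key/value head."""
    if n_rep == 1:
        return x
    b, s, n_kv, d = x.shape
    return x[:, :, :, None, :].expand(b, s, n_kv, n_rep, d).reshape(b, s, n_kv * n_rep, d)


def grouped_query_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q: (batch, seq_len, n_heads, head_dim); k, v: (batch, seq_len, n_kv_heads, head_dim).

    n_kv_heads == n_heads -> multi-head attention
    n_kv_heads == 1       -> multi-query attention
    otherwise             -> grouped-query attention
    (no causal mask in this sketch)
    """
    n_heads, n_kv_heads = q.shape[2], k.shape[2]
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    n_rep = n_heads // n_kv_heads
    k = repeat_kv(k, n_rep)
    v = repeat_kv(v, n_rep)
    # Move heads before the sequence axis: (batch, n_heads, seq_len, head_dim)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2)  # back to (batch, seq_len, n_heads, head_dim)


torch.manual_seed(0)
q = torch.randn(2, 5, 8, 16)  # 8 query heads
k = torch.randn(2, 5, 2, 16)  # 2 key/value heads -> 2 groups of 4 query heads
v = torch.randn(2, 5, 2, 16)
out = grouped_query_attention(q, k, v)
print(out.shape)  # torch.Size([2, 5, 8, 16])
```

Setting the key/value tensors to one head gives MQA, and giving them as many heads as the queries gives ordinary multi-head attention; GQA trades between the smaller key/value cache of MQA and the quality of multi-head attention.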


# Components to implement

* RMSNorm
* SwiGLU
* RoPE
* Transformer architecture with multi-query attention (MQA)

In [None]:
%%capture
!pip install transformers==4.37.2
!pip install datasets==2.17.0
!pip install sentencepiece==0.1.99

# Root Mean Square Layer Normalization(RMSNorm)

Llama2 normalizes the input of each transformer sub-layer instead of its output (pre-normalization), and it does so with RMSNorm, a simplification of Layer Normalization (LayerNorm). The motivation for RMSNorm is the computational overhead of LayerNorm, which makes training slow and expensive. RMSNorm achieves performance comparable to LayerNorm while reducing running time. LayerNorm has two properties:

**Re-centring**

It makes the model insensitive to shift noise in both inputs and weights.

**Re-scaling**

It keeps the output representations intact when both inputs and weights are randomly scaled. The RMSNorm paper argues that most of LayerNorm's benefit comes from re-scaling invariance rather than re-centring.
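To make the two properties concrete, here is a small sketch comparing a bare LayerNorm with a bare RMSNorm (no learnable gain or bias) on a random tensor that stands in for a layer's summed inputs; scaling the inputs or weights scales these summed inputs. The functions `layer_norm` and `rms_norm` are illustrative helpers, not library APIs, and the tolerance is an arbitrary choice.

```python
import torch


def layer_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Re-centring (subtract the mean) and re-scaling (divide by the std)
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Re-scaling only: divide by the root mean square, no mean subtraction
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms


torch.manual_seed(0)
x = torch.randn(4, 8)  # stand-in for summed inputs: 4 rows of 8 features
scale, shift = 3.0, 5.0

# Re-scaling invariance: both normalizations ignore a global scale factor
print(torch.allclose(layer_norm(scale * x), layer_norm(x), atol=1e-4))  # True
print(torch.allclose(rms_norm(scale * x), rms_norm(x), atol=1e-4))      # True
# Re-centring invariance: only LayerNorm ignores a constant shift
print(torch.allclose(layer_norm(x + shift), layer_norm(x), atol=1e-4))  # True
print(torch.allclose(rms_norm(x + shift), rms_norm(x), atol=1e-4))      # False
```

RMSNorm gives up shift invariance because it never subtracts the mean, which is exactly the computation it saves.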

RMSNorm keeps only re-scaling invariance: it regularizes the summed inputs simply according to their root mean square (RMS) statistic, without subtracting the mean.
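Written out, for a vector of summed inputs $\mathbf{a} \in \mathbb{R}^n$ with a learnable gain $\mathbf{g}$ (the `weight` parameter of the `RMSnorm` class), RMSNorm computes the following. The $\epsilon$ term is the small constant commonly added for numerical stability (the `eps` argument); it is an implementation detail rather than part of the core definition.

$$
\operatorname{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2 + \epsilon},
\qquad
\bar{a}_i = \frac{a_i}{\operatorname{RMS}(\mathbf{a})}\, g_i
$$

For comparison, LayerNorm also subtracts the mean $\mu = \frac{1}{n}\sum_{i=1}^{n} a_i$ and adds a bias $\mathbf{b}$:

$$
\bar{a}_i = \frac{a_i - \mu}{\sigma}\, g_i + b_i,
\qquad
\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n} (a_i - \mu)^2}
$$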

In [None]:
import torch
from torch import nn

class RMSnorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # small constant for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-feature gain g
    
    def _norm(self, x:torch.Tensor):
        # (m, seq_len, dim) * (m, seq_len, 1): the per-token scale broadcasts over dim