# Llama2 Model PyTorch Implementation
This notebook implements a mini version of the Llama 2 model with only 6 transformer layers.


## 1. RMSNorm

References: \
https://github.com/meta-llama/llama/blob/8fac8befd776bc03242fe7bc2236cdb41b6c609c/llama/model.py#L34\
https://arxiv.org/abs/1910.07467

In [None]:
import torch
import torch.nn as nn

class rmsnorm(nn.Module):
    """Root-mean-square layer normalization (Zhang & Sennrich, 2019).

    Normalizes the last dimension of ``x`` by its root mean square and then
    applies a learnable per-feature gain, as in Llama's ``RMSNorm``::

        y = x / sqrt(mean(x ** 2, dim=-1) + eps) * scale

    Args:
        d_model: Size of the last (feature) dimension.
        eps: Small constant added inside the square root for numerical
            stability.
    """

    def __init__(self, d_model, eps=1e-8):
        # nn.Module.__init__ must run before any nn.Parameter is assigned.
        super().__init__()
        self.eps = eps
        # Gain starts at 1 so the layer is a pure normalization at init; a zero
        # init would zero every output (and the gradient w.r.t. x).
        self.scale = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # x: (..., d_model) -- any number of leading dims is accepted.
        # rsqrt(mean(x^2) + eps) == 1 / RMS(x), with eps inside the root.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.scale

## 2. Rotary Position Embedding
reference:\
https://arxiv.org/abs/2104.09864


In [None]:
class rotary_position(nn.Module):
    """Rotary position embedding (RoPE, Su et al. 2021).

    Feature pairs ``(x[2i], x[2i+1])`` at position ``m`` are rotated by the
    angle ``m * theta_i`` with ``theta_i = base ** (-2i / d_model)``::

        out[2i]   = x[2i]   * cos(m*theta_i) - x[2i+1] * sin(m*theta_i)
        out[2i+1] = x[2i+1] * cos(m*theta_i) + x[2i]   * sin(m*theta_i)

    Args:
        seq_len: Maximum number of positions to precompute.
        d_model: Size of the rotated (last) dimension; must be even.
        base: Frequency base (10000 in the paper).

    Raises:
        ValueError: If ``d_model`` is odd.
    """

    def __init__(self, seq_len, d_model, base=10000.0):
        super().__init__()
        if d_model % 2 != 0:
            raise ValueError(f"d_model must be even for RoPE, got {d_model}")

        # One frequency per feature pair, repeated so both members of a pair
        # share it: [t0, t0, t1, t1, ...] -> shape (d_model,)
        pair_freqs = 1.0 / (base ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model))
        theta = pair_freqs.repeat_interleave(2)
        angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), theta)  # (seq_len, d_model)

        # Registered as buffers so they move with .to(device) / dtype casts;
        # non-persistent because they are fully determined by the arguments.
        self.register_buffer("cos", angles.cos(), persistent=False)
        self.register_buffer("sin", angles.sin(), persistent=False)

    def forward(self, x):
        """Rotate ``x`` of shape ``(..., seq, d_model)`` by positions ``0..seq-1``.

        Raises:
            ValueError: If ``seq`` exceeds the precomputed ``seq_len``.
        """
        seq_len = x.size(-2)
        max_len = self.cos.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds precomputed length {max_len}")
        cos = self.cos[:seq_len].to(x.dtype)
        sin = self.sin[:seq_len].to(x.dtype)
        # Pair-swapped copy [-x1, x0, -x3, x2, ...] supplies the sin terms of
        # the 2-D rotation for every pair.
        rotated = torch.stack((-x[..., 1::2], x[..., 0::2]), dim=-1).flatten(-2)
        return x * cos + rotated * sin

## 3. Group Query Attention
References:\
https://github.com/fkodom/grouped-query-attention-pytorch/blob/main/grouped_query_attention_pytorch/attention.py#L203\
https://arxiv.org/abs/2305.13245v3

In [None]:
class group_query_attention(nn.Module):
    def __init__(self, seq_len, query_heads, kv_heads, d_model, mask=None, causal=True):
        """Grouped-query attention (GQA) with per-head rotary embeddings.

        MHA:

        | (q, q, ..., q) | (q, q, ..., q) | ... | (q, q, ..., q) | -> query_dim = num_heads * head_dim
        | (k, k, ..., k) | (k, k, ..., k) | ... | (k, k, ..., k) | ->   key_dim = num_heads * head_dim
        | (v, v, ..., v) | (v, v, ..., v) | ... | (v, v, ..., v) | -> value_dim = num_heads * head_dim

        =====================================================================
        GQA:

        | (q, ..., q) (q, ..., q) ... | ... | (q, ..., q) (q, ..., q) ... | -> query_dim = query_heads * head_dim
        |      (k/v, ..., k/v)        | ... |      (k/v, ..., k/v)        | ->    kv_dim = kv_heads * head_dim

        There are ``kv_heads`` groups; each group holds
        ``group_nums = query_heads // kv_heads`` query heads that all attend
        with the same key/value head.

        Args:
            seq_len: Maximum sequence length (size of the RoPE and mask tables).
            query_heads: Number of query heads; must divide ``d_model``.
            kv_heads: Number of key/value heads; must divide ``query_heads``.
            d_model: Model (embedding) dimension. ``d_model // query_heads``
                must be even because RoPE rotates feature pairs.
            mask: Optional boolean tensor of shape ``(seq_len, seq_len)``;
                ``mask[i, j]`` is True where query position ``i`` may attend
                to key position ``j``. ANDed with the causal mask when
                ``causal`` is True.
            causal: If True (default), position ``i`` only attends to
                positions ``<= i``, as in a decoder-only model.

        Raises:
            ValueError: If the head counts do not divide evenly or ``mask``
                has the wrong shape.
        """
        super().__init__()
        if d_model % query_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by query_heads ({query_heads})")
        if query_heads % kv_heads != 0:
            raise ValueError(f"query_heads ({query_heads}) must be divisible by kv_heads ({kv_heads})")

        self.seq_len = seq_len
        self.query_heads = query_heads
        self.kv_heads = kv_heads
        self.d_model = d_model
        # Integer division: nn.Linear and view() need ints, "/" gives floats.
        self.group_nums = query_heads // kv_heads  # query heads per K/V head
        self.head_dim = d_model // query_heads

        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, self.head_dim * self.kv_heads)
        self.WV = nn.Linear(d_model, self.head_dim * self.kv_heads)

        self.linear = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)  # normalize over key positions

        # RoPE acts on each head independently, so its width is head_dim.
        self.rope = rotary_position(seq_len, self.head_dim)

        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool)
        if causal:
            allowed = torch.tril(allowed)
        if mask is not None:
            mask = torch.as_tensor(mask, dtype=torch.bool)
            if mask.shape != (seq_len, seq_len):
                raise ValueError(f"mask must have shape ({seq_len}, {seq_len}), got {tuple(mask.shape)}")
            allowed = allowed & mask
        # True = may attend. Buffer so it follows the module across devices.
        self.register_buffer("mask", allowed, persistent=False)

    def split_group_head(self, x, group_nums, head_dim, type="q"):
        """Split projected features into (K/V group, head) axes.

        Args:
            x: Tensor of shape ``(batch, seq_len, x_dim)``.
            group_nums: Number of groups, i.e. the number of K/V heads.
            head_dim: Width of one head.
            type: ``"q"`` for queries; anything else (``"k"``/``"v"``) for
                keys/values.

        Returns:
            For ``"q"``: ``(batch, group_nums, group_dim, seq_len, head_dim)``
            with ``group_dim = x_dim // (head_dim * group_nums)`` query heads
            per group. Otherwise: ``(batch, group_nums, seq_len, head_dim)``.
        """
        batch, seq_len, x_dim = x.size()
        if type == "q":
            group_dim = x_dim // (head_dim * group_nums)
            x = x.view(batch, seq_len, group_nums, group_dim, head_dim)
            return x.permute(0, 2, 3, 1, 4)
        return x.view(batch, seq_len, group_nums, head_dim).transpose(1, 2)

    def concat_head(self, x):
        """Merge head axes back into features.

        Accepts ``(batch, heads..., seq_len, head_dim)`` with one or more head
        axes (4-D per-head or 5-D grouped layout) and returns
        ``(batch, seq_len, prod(heads) * head_dim)``; the inverse of
        ``split_group_head``.
        """
        batch, seq_len = x.size(0), x.size(-2)
        # reshape (not view): the moved tensor is non-contiguous.
        return x.movedim(-2, 1).reshape(batch, seq_len, -1)

    def group_dot_attention(self, q, k, v):
        """Scaled dot-product attention where each K/V head serves a group of queries.

        Args:
            q: ``(batch, groups, group_dim, q_len, head_dim)``
            k: ``(batch, groups, k_len, head_dim)``
            v: ``(batch, groups, k_len, head_dim)``

        Returns:
            ``(output, attention)`` with output
            ``(batch, groups, group_dim, q_len, head_dim)`` and attention
            weights ``(batch, groups, group_dim, q_len, k_len)``. A query row
            whose every key is masked out yields NaN weights.
        """
        scores = torch.einsum("bgdqh,bgkh->bgdqk", q, k) / (q.size(-1) ** 0.5)
        q_len, k_len = scores.shape[-2:]
        allowed = self.mask[:q_len, :k_len]
        scores = scores.masked_fill(~allowed, float("-inf"))
        attention = self.softmax(scores)
        output = torch.einsum("bgdqk,bgkh->bgdqh", attention, v)
        return output, attention

    def _apply_rope(self, x):
        """Apply RoPE to a head-split tensor ``(..., seq_len, head_dim)``."""
        shape = x.shape
        # Flatten leading axes into one batch axis so the rotary module sees 3-D.
        rotated = self.rope(x.reshape(-1, shape[-2], shape[-1]))
        return rotated.reshape(shape)

    def forward(self, q, k, v):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q, k, v: Tensors of shape ``(batch, seq_len, d_model)``; the K/V
                projections map ``d_model`` to ``kv_heads * head_dim``.
                ``seq_len`` must not exceed the value given at construction.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If an input sequence is longer than ``self.seq_len``.
        """
        if max(q.size(1), k.size(1)) > self.seq_len:
            raise ValueError(f"sequence length exceeds the configured seq_len={self.seq_len}")

        # split_group_head's group_nums is the number of groups = kv_heads.
        q = self.split_group_head(self.WQ(q), self.kv_heads, self.head_dim, "q")
        k = self.split_group_head(self.WK(k), self.kv_heads, self.head_dim, "k")
        v = self.split_group_head(self.WV(v), self.kv_heads, self.head_dim, "v")

        q = self._apply_rope(q)
        k = self._apply_rope(k)

        output, _ = self.group_dot_attention(q, k, v)
        return self.linear(self.concat_head(output))
        

## 4. Llama model
![Llama model architecture](images/llama_architecture.png)

For simplicity, we set the number of stacked transformer layers (the "Nx" in the diagram) to 6.

In [None]:
class transformer_layer(nn.Module):
    """One pre-norm Llama decoder block.

    Computes ``h = x + attn(norm1(x))`` followed by ``h + ffn(norm2(h))``.

    NOTE: the reference Llama feed-forward is a gated SwiGLU
    (``w2(silu(w1 x) * w3 x)``); this mini version keeps a plain two-layer
    SiLU MLP.

    Args:
        d_model: Model dimension.
        seq_len: Maximum sequence length supported by attention.
        query_heads: Number of query heads (must divide ``d_model``).
        kv_heads: Number of shared key/value heads (must divide ``query_heads``).
        hidden_dim: Feed-forward hidden width; defaults to ``2 * d_model``.
    """

    def __init__(self, d_model, seq_len=256, query_heads=4, kv_heads=2, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 2 * d_model

        self.rmsnorm = rmsnorm(d_model)  # pre-attention norm
        self.gqa = group_query_attention(seq_len, query_heads, kv_heads, d_model)
        self.ffn_norm = rmsnorm(d_model)  # pre-feed-forward norm
        # nn.Sequential takes modules as positional arguments, not a list.
        self.ffw = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, d_model),
        )

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        h = self.rmsnorm(x)
        # group_query_attention.forward returns a single tensor.
        x = x + self.gqa(h, h, h)  # residual around attention
        x = x + self.ffw(self.ffn_norm(x))  # residual around feed-forward
        return x

In [None]:
import torch.nn as nn

        
class llama_model(nn.Module):
    """Stack of ``transformer_layer`` blocks followed by a final RMSNorm,
    a linear projection and a softmax over the last dimension.

    NOTE(review): the input is expected to be already-embedded features of
    shape ``(batch, seq_len, d_model)`` and the output keeps width
    ``d_model``; a language model would add a token embedding and project to
    the vocabulary size instead. Returning probabilities (rather than logits)
    is kept for compatibility.

    Args:
        d_model: Model dimension.
        n_layers: Number of stacked transformer layers (6 in this notebook).
        **layer_kwargs: Extra keyword arguments forwarded to every
            ``transformer_layer`` (e.g. ``seq_len``, ``query_heads``).
    """

    def __init__(self, d_model, n_layers=6, **layer_kwargs):
        super().__init__()
        # nn.Sequential takes modules as positional arguments, not a list.
        self.transformer_layers = nn.Sequential(
            *(transformer_layer(d_model, **layer_kwargs) for _ in range(n_layers))
        )
        self.rmsnorm = rmsnorm(d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)  # explicit dim: normalize features

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.transformer_layers(x)
        x = self.rmsnorm(x)
        x = self.linear(x)
        return self.softmax(x)
        
        
        
            