<h3>LLAMA 3.1 8B</h3>
Implemented entirely in PyTorch.

In [2]:

import torch
import torch.nn as nn
from torch.nn import functional as F




In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Entire architecture

n_layer = 32  # number of stacked transformer layers


# Input customizations

vocab_sz = 128256  # ~100k English BPE tokens plus ~28k extra multilingual tokens
n_embd = 4096  # model width (embedding dimension)
block_size = 8192  # sequence length (context length)
n_pos_emb = block_size  # number of positions; same as the block size
RoPE_theta = 500000.0  # RoPE base frequency used by Llama 3; RoPE is not applied by the model yet


# Training

batch_size = 128  # sequences processed in parallel


# Transformer Block - FFN

ffn_multplier = 256  # FFN hidden width is rounded up to a multiple of this


# Transformer Block - Attention - Grouped-query attention (GQA)

n_head = 32  # number of query heads
n_head_kv = 8  # number of key/value heads; must divide n_head (each K/V head serves n_head // n_head_kv query heads)
num_key_value_heads = n_head_kv  # alias of n_head_kv, kept so both names stay in sync




In [None]:

# Building blocks of a single transformer layer


class GMQAttention(nn.Module):
    """Causal grouped-query self-attention (GQA).

    ``n_q_heads`` query heads share ``n_kv_heads`` key/value heads: K/V head
    ``j`` is reused by query heads ``j*r .. j*r + r - 1`` where
    ``r = n_q_heads // n_kv_heads``. With ``n_kv_heads == n_q_heads`` this is
    ordinary multi-head attention; with ``n_kv_heads == 1`` it is multi-query.

    Args:
        n_q_heads: number of query heads (defaults to module-level ``n_head``).
        head_size: width of each head (defaults to ``embd_dim // n_q_heads``).
        n_kv_heads: number of key/value heads (defaults to module-level ``n_head_kv``).
        embd_dim: input/output width (defaults to module-level ``n_embd``).

    Raises:
        ValueError: if ``n_q_heads`` is not a multiple of ``n_kv_heads``.

    NOTE(review): rotary position embeddings (RoPE) are not applied here yet.
    """

    def __init__(self, n_q_heads=None, head_size=None, n_kv_heads=None, embd_dim=None):
        super().__init__()
        dim = n_embd if embd_dim is None else embd_dim
        self.n_q_heads = n_head if n_q_heads is None else n_q_heads
        self.n_kv_heads = n_head_kv if n_kv_heads is None else n_kv_heads
        if self.n_q_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_q_heads ({self.n_q_heads}) must be a multiple of "
                f"n_kv_heads ({self.n_kv_heads})"
            )
        self.head_size = dim // self.n_q_heads if head_size is None else head_size
        # How many query heads share each K/V head
        self.n_rep = self.n_q_heads // self.n_kv_heads

        # K and V projections are narrower than Q: that is the memory saving of GQA
        self.wq = nn.Linear(dim, self.n_q_heads * self.head_size, bias=False)
        self.wk = nn.Linear(dim, self.n_kv_heads * self.head_size, bias=False)
        self.wv = nn.Linear(dim, self.n_kv_heads * self.head_size, bias=False)
        self.wo = nn.Linear(self.n_q_heads * self.head_size, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend causally over ``x`` of shape [B, T, embd_dim]; returns the same shape."""
        B, T, _ = x.shape
        # [B, T, H*hs] -> [B, H, T, hs]
        q = self.wq(x).view(B, T, self.n_q_heads, self.head_size).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_size).transpose(1, 2)

        if self.n_rep > 1:
            # Expand each K/V head to the consecutive query heads that share it
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        # Softmax(q k^T / sqrt(hs)) v with a causal mask (no peeking at future tokens)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # [B, H, T, hs] -> [B, T, H*hs] -> output projection
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_q_heads * self.head_size)
        return self.wo(out)


class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``l2(silu(l1_a(x)) * l1_l(x))``.

    The hidden width follows the Llama reference recipe. It starts from
    ``int(2 * 4 * dim / 3)``, which shrinks the classic 4*dim because SwiGLU
    uses three matrices instead of two. That width is scaled by
    ``ffn_dim_multiplier`` and then rounded *up* to a multiple of
    ``multiple_of``. For dim=4096, multiple_of=256 and ffn_dim_multiplier=1.3
    this gives 14336.

    Args:
        embd_dim: input/output width (defaults to module-level ``n_embd``).
        multiple_of: rounding granularity of the hidden width
            (defaults to module-level ``ffn_multplier``).
        ffn_dim_multiplier: extra hidden-width scale; ``None`` disables it.
    """

    def __init__(self, embd_dim=None, multiple_of=None, ffn_dim_multiplier=1.3):
        super().__init__()
        dim = n_embd if embd_dim is None else embd_dim
        step = ffn_multplier if multiple_of is None else multiple_of

        # layer dimensions
        hidden = int(2 * 4 * dim / 3)
        if ffn_dim_multiplier is not None:
            hidden = int(ffn_dim_multiplier * hidden)
        # Round up to the next multiple of `step` (ceil division)
        hidden_embd = step * ((hidden + step - 1) // step)

        self.l1_a = nn.Linear(dim, hidden_embd, bias=False)  # gate branch (goes through SiLU)
        self.l2 = nn.Linear(hidden_embd, dim, bias=False)  # down projection
        self.l1_l = nn.Linear(dim, hidden_embd, bias=False)  # linear (up) branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Two projections of x: one gated by SiLU, one left linear
        x_swish = F.silu(self.l1_a(x))
        x_lin = self.l1_l(x)
        # Element-wise gating, then project back to the model width
        return self.l2(x_swish * x_lin)


class Layer(nn.Module):
    """One pre-norm Llama transformer layer.

    Applies the two residual branches sequentially:
        h   = x + atten(norm1(x))
        out = h + line(norm2(h))

    Args:
        embd_dim: model width (defaults to module-level ``n_embd``).
        n_head: number of query heads; ``None`` lets ``GMQAttention`` use its default.
    """

    def __init__(self, embd_dim=None, n_head=None):
        super().__init__()
        dim = n_embd if embd_dim is None else embd_dim

        # Attention sub-block (pre-norm); Llama 3 uses RMSNorm eps=1e-5.
        # Head size defaults to dim // n_head inside GMQAttention.
        self.norm1 = nn.RMSNorm(dim, eps=1e-5)
        self.atten = GMQAttention(n_head, embd_dim=dim)

        # Feed-forward sub-block (pre-norm)
        self.norm2 = nn.RMSNorm(dim, eps=1e-5)
        self.line = FeedForward(dim)

    def forward(self, idx):
        # Residual around attention: keep the original input and add what attention learned
        idx = idx + self.atten(self.norm1(idx))

        # The FFN sees the attention-updated stream (sequential, not parallel, residual)
        idx = idx + self.line(self.norm2(idx))

        return idx



# Full model: embeddings -> stacked transformer layers -> final norm -> vocab projection

class Llama(nn.Module):
    """Decoder-only Llama-style language model mapping token ids to vocab logits.

    Pipeline: token embedding + learned position embedding -> ``n_layer``
    transformer ``Layer``s -> final RMSNorm -> bias-free linear head over the vocab.

    NOTE(review): Llama proper encodes positions with RoPE inside attention;
    the learned position table here is a placeholder until RoPE is added.
    """

    def __init__(self):
        super().__init__()

        # Embedding table for all tokens in the vocab, n_embd wide
        self.token_embedding = nn.Embedding(vocab_sz, n_embd)

        # Learned position table, one row per position up to block_size.
        # TODO: replace with RoPE (n_pos_emb positions) applied inside attention.
        self.pos_embedding = nn.Embedding(block_size, n_embd)

        # n_layer transformer layers applied one after another
        self.Layers = nn.Sequential(*[Layer(n_embd, n_head=n_head) for _ in range(n_layer)])

        # Final normalization (Llama 3 uses RMSNorm with eps=1e-5)
        self.finalnorm = nn.RMSNorm(n_embd, eps=1e-5)

        # Output head: embeddings -> vocab logits (Llama's output projection has no bias)
        self.finallin = nn.Linear(n_embd, vocab_sz, bias=False)

        # Initialise every Linear/Embedding weight from N(0, 0.02)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Return logits of shape [B, T, vocab_sz] for token ids ``idx`` of shape [B, T].

        ``targets`` is accepted but currently unused (no loss is computed).

        Raises:
            ValueError: if T exceeds the number of learned positions (block_size).
        """
        _, T = idx.shape  # B sequences of T tokens each
        max_T = self.pos_embedding.num_embeddings
        if T > max_T:
            raise ValueError(f"sequence length {T} exceeds the maximum context length {max_T}")

        tok_embds = self.token_embedding(idx)  # [B,T] -> [B,T,C]
        # Build positions on the input's device (not the global `device`) so they
        # always live wherever the model and its input were placed.
        positions = torch.arange(T, device=idx.device)
        pos_embds = self.pos_embedding(positions)  # [T] -> [T,C]; broadcast over B

        x = tok_embds + pos_embds  # embedding done
        x = self.Layers(x)  # [B,T,C] -> [B,T,C]
        x = self.finalnorm(x)  # [B,T,C] -> [B,T,C]
        return self.finallin(x)  # [B,T,C] -> [B,T,vocab_sz]

 






