In [1]:
import math
from typing import Optional, Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

In [2]:
device = "cuda" if torch.cuda.is_available else "cpu"

# RoPE

In [3]:
def calculate_angles(theeta, dim, seq_len):
    pos = 1/theeta**(torch.arange(0, dim, 2, device=device).float()/dim)
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, pos)
    unit_vecs = torch.polar(torch.ones_like(freqs), freqs).to(device)
    return unit_vecs

def brodcast(unit_vecs, x):
    # print(f"unit_vecs.shape: {unit_vecs.shape}\nX.shape: {x.shape[1], x.shape[-1]}")
    assert unit_vecs.shape == (x.shape[1], x.shape[-1])
    n_dim = x.ndim
    shape = [d if i == 1 or i == n_dim-1 else 1 for i,d in enumerate(x.shape)]
    return unit_vecs.view(*shape)

def RoPE(W_Q, W_K, unit_vecs):
    complex_W_Q = torch.view_as_complex(W_Q.float().reshape(*W_Q.shape[:-1], -1, 2)).to(device)
    complex_W_K = torch.view_as_complex(W_K.float().reshape(*W_K.shape[:-1], -1, 2)).to(device)
    # print(complex_W_Q.shape)
    pos = brodcast(unit_vecs, complex_W_K).to(device)
    embedded_W_Q = torch.view_as_real(complex_W_Q * pos).float().flatten(3)
    embedded_W_K = torch.view_as_real(complex_W_K * pos).float().flatten(3)
    return embedded_W_Q, embedded_W_K

In [4]:
W_Q = torch.rand((1,6,2,4))
complex_W_Q = torch.view_as_complex(W_Q.float().reshape(*W_Q.shape[:-1], -1, 2))
complex_W_Q.shape


torch.Size([1, 6, 2, 2])

In [5]:
DIM = 3072 # Keep as 3072 to match checkpoint
FFN_DIM = 8192 # Change from 14336 to 8192 to match checkpoint feed-forward layers
N_LAYERS = 28 # This should match your checkpoint's number of layers (0-27 = 28 layers)
N_HEADS = 24 # Change from 32 - calculated as DIM/HEAD_DIM where HEAD_DIM=128
N_KV_HEADS = 8 # Keep as 8 (matches checkpoint key-value head dimension 1024/128)
VOCAB_SIZE = 128256 # Keep same - matches checkpoint
NORM_EPS = 1e-5 # Keep same
ROPE_THETA = 500000 # Keep same
MAX_BATCH_SIZE = 4 # Keep same
MAX_SEQ_LEN = 128 # Keep same
N_KV_HEAD_REP = N_HEADS // N_KV_HEADS # Now 24 // 8 = 3
HEAD_DIM = DIM // N_HEADS # Now 3072 // 24 = 128

# ATTENTION

In [6]:
freq = calculate_angles(10000, HEAD_DIM, 2).to(device)
freq.shape

torch.Size([2, 64])

In [7]:
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(DIM, N_HEADS * HEAD_DIM, bias=False).to(device)
        self.W_K = nn.Linear(DIM, N_KV_HEADS * HEAD_DIM, bias=False).to(device)
        self.W_V = nn.Linear(DIM, N_KV_HEADS * HEAD_DIM, bias=False).to(device)

        self.CACHE_K = torch.zeros(
            (MAX_BATCH_SIZE, MAX_SEQ_LEN, N_KV_HEADS, HEAD_DIM)
        )
        self.CACHE_V = torch.zeros(
            (MAX_BATCH_SIZE, MAX_SEQ_LEN, N_KV_HEADS, HEAD_DIM)
        )

        self.wo = nn.Linear(DIM, DIM).to(device)


    def forward(self,x, freq=None, start_pos=0, mask=None):
        bhz, seq_len, _ = x.shape

        query = self.W_Q(x).view(bhz, seq_len, N_HEADS, HEAD_DIM).to(device)
        key = self.W_K(x).view(bhz, seq_len, N_KV_HEADS, HEAD_DIM)
        value = self.W_V(x).view(bhz, seq_len, N_KV_HEADS, HEAD_DIM)

        query, key = RoPE(query, key, freq)

        self.CACHE_K.to(device)
        self.CACHE_V.to(device)

        self.CACHE_K[:bhz, start_pos:start_pos+seq_len] = key
        self.CACHE_V[:bhz, start_pos:start_pos+seq_len] = value

        keys = self.CACHE_K[:bhz, :start_pos+seq_len]
        values = self.CACHE_V[:bhz, :start_pos+seq_len]

        keys = torch.repeat_interleave(input=keys, repeats=N_KV_HEAD_REP, dim=-2)
        values = torch.repeat_interleave(input=values, repeats=N_KV_HEAD_REP, dim=-2)

        queries = query.transpose(1,2).to(device)
        keys = keys.transpose(1,2).to(device)
        values = values.transpose(1,2).to(device)
        
        out = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)
        out = out.transpose(1,2).contiguous().view(bhz, seq_len, -1)

        return self.wo(out)

In [8]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim, norm_eps):
        super().__init__()
        self.norm_eps = norm_eps
        self.weight = nn.Parameter(torch.ones(dim)).to(device)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps)
    
    def forward(self, x):
        out = self._norm(x.float()).type_as(x).to(device)
        return out * self.weight

In [9]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()

        self.w1 = nn.Linear(DIM, FFN_DIM, bias=False).to(device)
        self.w3 = nn.Linear(DIM, FFN_DIM, bias=False).to(device)
        self.w2 = nn.Linear(FFN_DIM, DIM, bias=False).to(device)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

In [10]:
class Transformer_Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.Attention_Norm = RMSNorm(dim=DIM, norm_eps=NORM_EPS)
        self.FFN_Norm = RMSNorm(dim=DIM, norm_eps=NORM_EPS)
        self.Attention = Attention()
        self.FeedForward = FeedForward()
    def forward(self, x, freq, start_pos, mask):
        shortcut = x
        x = self.Attention_Norm(x)
        x = self.Attention(x, freq, start_pos, mask)
        x = x.to(shortcut.device) + shortcut

        shortcut = x
        x = self.FFN_Norm(x.to(device))
        x = self.FeedForward(x)
        x = x.to(shortcut.device) + shortcut

        return x.to(device)

In [11]:
class LLAMA_3(nn.Module):
    def __init__(self):
        super().__init__()

        self.tok_embedding = nn.Embedding(VOCAB_SIZE, DIM)
        self.layers = nn.ModuleList()
        for _ in range(N_LAYERS):
            self.layers.append(Transformer_Block())
        self.norm = RMSNorm(DIM, NORM_EPS)
        self.output = nn.Linear(DIM, VOCAB_SIZE, bias=False).to(device)

        self.freq = calculate_angles(
            ROPE_THETA,
            HEAD_DIM,
            MAX_SEQ_LEN * 2
        )
    
    def forward(self, tokens, start_pos):
        bhz, seq_len = tokens.shape
        x = self.tok_embedding(tokens)
        freq = self.freq[start_pos : start_pos+seq_len]

        mask = None
        if seq_len > 1:
            mask = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device) 
            mask = torch.triu(mask, diagonal=1).to(device)

        for layer in self.layers:
            x = layer(x, freq, start_pos, mask )
        x = self.norm(x)
        x = self.output(x).float()
        
        return x
        

In [12]:
llama = LLAMA_3().to(device)

In [13]:
total_params = sum(p.numel() for p in llama.parameters())
print(f"Total Number of parameters: {total_params:,}")

Total Number of parameters: 3,606,663,168


In [14]:
total_size_bytes = total_params * 4 #A
total_size_mb = total_size_bytes / (1024 * 1024) #B
print(f"Total size of the model: {total_size_mb/1024} GB")

Total size of the model: 13.435867309570312 GB


In [17]:
def Generate_Text(model, idx, max_tokens, context_size, start_pos):
    for _ in range(max_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.inference_mode():
            logits = model(idx_cond, start_pos)

            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.argmax(probs, dim=-1, keepdim=True)
            idx = torch.cat((idx, next_idx), dim=-1)
    return idx

In [18]:
inp = torch.rand((1,4)).long().to(device)

f = Generate_Text(llama, inp, 4, 4, 0)
f

tensor([[    0,     0,     0,     0, 41203, 11803, 23389, 12947]],
       device='cuda:0')