In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from pathlib import Path
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Sequence,
    TypedDict,
    Union,
)


In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# RoPE

In [None]:
def calculate_angles(theeta, dim, seq_len, device=None):
    """Precompute rotary-embedding rotation factors as unit complex numbers.

    Args:
        theeta: Base of the geometric frequency progression.
        dim: Size of the dimension being rotated. One frequency is produced
            for every second index, i.e. ``len(range(0, dim, 2))`` in total.
        seq_len: Number of positions to precompute, starting at position 0.
        device: Device on which the table is built. When ``None`` (default)
            CUDA is used if it is available, otherwise the CPU, so the
            function no longer depends on a notebook-level ``device`` global.

    Returns:
        Complex tensor of shape ``(seq_len, len(range(0, dim, 2)))`` whose
        entry ``[p, i]`` has magnitude 1 and angle
        ``p / theeta ** (2 * i / dim)``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # One inverse frequency per rotated pair of components.
    inv_freq = 1 / theeta ** (torch.arange(0, dim, 2, device=device).float() / dim)
    positions = torch.arange(seq_len, device=device)
    # Angle for every (position, frequency) combination.
    freqs = torch.outer(positions, inv_freq)
    # polar(1, angle) == cos(angle) + 1j * sin(angle)
    return torch.polar(torch.ones_like(freqs), freqs)

def brodcast(unit_vecs, x):
    """Reshape ``unit_vecs`` so that it broadcasts against ``x``.

    ``unit_vecs`` must have shape ``(x.shape[1], x.shape[-1])``. The returned
    view keeps those two sizes on axis 1 and on the last axis and has size 1
    on every other axis of ``x``.
    """
    last_axis = x.ndim - 1
    assert unit_vecs.shape == (x.shape[1], x.shape[-1])

    target_shape = []
    for axis, size in enumerate(x.shape):
        keep = axis in (1, last_axis)
        target_shape.append(size if keep else 1)
    return unit_vecs.view(*target_shape)

def RoPE(W_Q, W_K, unit_vecs):
    """Apply rotary position factors to query and key tensors.

    The last axis of each input is read as consecutive (real, imaginary)
    pairs, multiplied by ``unit_vecs`` (reshaped for broadcasting against the
    key tensor) and converted back to real pairs. Both results are float32
    with dimensions 3 and onwards flattened back into one axis.

    Returns:
        Tuple ``(rotated_query, rotated_key)``.
    """
    def _to_complex(t):
        # (..., 2k) -> (..., k, 2) -> complex (..., k)
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _to_complex(W_Q)
    k_complex = _to_complex(W_K)

    rotation = brodcast(unit_vecs, k_complex)

    rotated = [
        torch.view_as_real(c * rotation).float().flatten(3)
        for c in (q_complex, k_complex)
    ]
    return rotated[0], rotated[1]

In [None]:
# Hyperparameters read by the model classes below (looked up by key).
CONFIGURATIONS = dict(
    DIM=3072,
    FFN_DIM=8192,
    N_LAYERS=28,
    N_HEADS=24,
    N_KV_HEADS=8,
    VOCAB_SIZE=128256,
    NORM_EPS=1e-5,
    ROPE_THETA=500000,
    MAX_BATCH_SIZE=4,
    MAX_SEQ_LEN=6000,
    # Number of query heads that share one key/value head: N_HEADS // N_KV_HEADS.
    N_KV_HEAD_REP=24 // 8,
    HEAD_DIM=128,
)

# LAYERS

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, shared KV heads and a KV cache.

    Queries use ``n_heads`` heads; keys and values use ``n_kv_heads`` heads that
    are repeated ``n_kv_heads_reps`` times before attention, so
    ``n_kv_heads * n_kv_heads_reps`` should equal ``n_heads``.

    Args:
        dim: Width of the input and output hidden states.
        n_heads: Number of query heads.
        head_dim: Width of each head.
        n_kv_heads: Number of key/value heads.
        n_kv_heads_reps: How many times each key/value head is repeated.
        max_batch_size: Batch capacity of the KV cache.
        max_seq_len: Sequence capacity of the KV cache.
    """

    def __init__(self, dim, n_heads, head_dim, n_kv_heads, n_kv_heads_reps, max_batch_size, max_seq_len):
        super().__init__()
        self.W_Q = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.W_K = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.W_V = nn.Linear(dim, n_kv_heads * head_dim, bias=False)

        # Per-layer key/value caches, written at [batch, position].
        self.register_buffer('CACHE_K', torch.zeros(
            (max_batch_size, max_seq_len, n_kv_heads, head_dim))
        )
        self.register_buffer('CACHE_V', torch.zeros(
            (max_batch_size, max_seq_len, n_kv_heads, head_dim))
        )

        # Output projection. It consumes the concatenated heads
        # (n_heads * head_dim features) and, like the other projections, has no
        # bias: a default-initialised bias would stay random after weight
        # loading whenever the checkpoint provides no `wo.bias` tensor.
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.n_kv_heads_reps = n_kv_heads_reps


    def forward(self,x, freq=None, start_pos=0, mask=None):
        """Attend over the cached prefix plus the tokens in ``x``.

        Args:
            x: Hidden states of shape ``(batch, seq_len, dim)``.
            freq: Rotary factors for the positions covered by ``x``; passed to
                ``RoPE`` together with the query and key tensors.
            start_pos: Cache position at which the first token of ``x`` is stored.
            mask: Optional attention mask forwarded to
                ``F.scaled_dot_product_attention``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        bhz, seq_len, _ = x.shape

        query = self.W_Q(x).view(bhz, seq_len, self.n_heads, self.head_dim)
        key = self.W_K(x).view(bhz, seq_len, self.n_kv_heads, self.head_dim)
        value = self.W_V(x).view(bhz, seq_len, self.n_kv_heads, self.head_dim)

        query, key = RoPE(query, key, freq)

        # Store the new keys/values, then read back everything up to the
        # current position so earlier tokens are attended to as well.
        self.CACHE_K[:bhz, start_pos:start_pos+seq_len] = key
        self.CACHE_V[:bhz, start_pos:start_pos+seq_len] = value

        keys = self.CACHE_K[:bhz, :start_pos+seq_len]
        values = self.CACHE_V[:bhz, :start_pos+seq_len]

        # Expand the KV heads so each query head has a matching key/value head.
        keys = torch.repeat_interleave(input=keys, repeats=self.n_kv_heads_reps, dim=-2)
        values = torch.repeat_interleave(input=values, repeats=self.n_kv_heads_reps, dim=-2)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        queries = query.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)

        out = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)
        # Merge the heads back into a single feature axis.
        out = out.transpose(1,2).contiguous().view(bhz, seq_len, -1)

        return self.wo(out)

In [None]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-feature gain."""

    def __init__(self, dim, norm_eps):
        super().__init__()
        self.norm_eps = norm_eps
        # Gain starts at 1 so the layer initially only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis (``norm_eps`` added for stability)."""
        mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.norm_eps)
        return x * inv_rms

    def forward(self, x):
        """Normalise in float32, cast back to the type of ``x``, then apply the gain."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

In [None]:
class FeedForward(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))`` without biases."""

    def __init__(self, dim, ffn_dim):
        super().__init__()

        # Two parallel projections up to ffn_dim, one projection back to dim.
        self.w1 = nn.Linear(in_features=dim, out_features=ffn_dim, bias=False)
        self.w3 = nn.Linear(in_features=dim, out_features=ffn_dim, bias=False)
        self.w2 = nn.Linear(in_features=ffn_dim, out_features=dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)

In [None]:
class Transformer_Block(nn.Module):
    """Decoder block: normalised attention and normalised feed-forward, each added back to its input."""

    def __init__(self, config):
        super().__init__()
        dim = config["DIM"]
        eps = config["NORM_EPS"]

        self.Attention_Norm = RMSNorm(dim=dim, norm_eps=eps)
        self.FFN_Norm = RMSNorm(dim=dim, norm_eps=eps)
        self.Attention = Attention(
            dim=dim,
            n_heads=config["N_HEADS"],
            head_dim=config["HEAD_DIM"],
            n_kv_heads=config["N_KV_HEADS"],
            n_kv_heads_reps=config["N_KV_HEAD_REP"],
            max_batch_size=config["MAX_BATCH_SIZE"],
            max_seq_len=config["MAX_SEQ_LEN"],
        )
        self.FeedForward = FeedForward(dim=dim, ffn_dim=config["FFN_DIM"])

    def forward(self, x, freq, start_pos, mask):
        """Run attention then feed-forward, each with a residual connection."""
        attn_out = self.Attention(self.Attention_Norm(x), freq, start_pos, mask)
        hidden = attn_out + x

        ffn_out = self.FeedForward(self.FFN_Norm(hidden))
        return ffn_out + hidden

In [None]:
class LLAMA_3(nn.Module):
    """Decoder-only transformer: token embedding, ``N_LAYERS`` blocks, final norm and output head.

    Args:
        config: Mapping providing ``VOCAB_SIZE``, ``DIM``, ``N_LAYERS``,
            ``NORM_EPS``, ``ROPE_THETA``, ``HEAD_DIM`` and ``MAX_SEQ_LEN``,
            plus the keys read by ``Transformer_Block``.
    """

    def __init__(self, config):
        super().__init__()

        self.tok_embedding = nn.Embedding(config["VOCAB_SIZE"], config["DIM"])
        self.layers = nn.ModuleList()
        for _ in range(config["N_LAYERS"]):
            self.layers.append(Transformer_Block(config))
        self.norm = RMSNorm(config["DIM"], config["NORM_EPS"])
        self.output = nn.Linear(config["DIM"], config["VOCAB_SIZE"], bias=False)

        # Rotary factors for 2 * MAX_SEQ_LEN positions, sliced per call in forward().
        self.register_buffer(
            'freq',
            calculate_angles(
                config["ROPE_THETA"],
                config["HEAD_DIM"],
                config["MAX_SEQ_LEN"] * 2
            )
        )

    def reset_cache(self):
        """Zero the ``CACHE_K`` / ``CACHE_V`` buffers of every submodule that has them."""
        for module in self.modules():
            if hasattr(module, "CACHE_K"):
                module.CACHE_K.zero_()
            if hasattr(module, "CACHE_V"):
                module.CACHE_V.zero_()


    def forward(self, tokens, start_pos):
        """Compute logits for ``tokens`` placed at cache positions ``start_pos`` onwards.

        Args:
            tokens: Integer tensor of shape ``(batch, seq_len)``.
            start_pos: Position of ``tokens[:, 0]`` in the sequence / KV cache.

        Returns:
            Float tensor of logits with shape ``(batch, seq_len, VOCAB_SIZE)``.
        """
        _, seq_len = tokens.shape
        x = self.tok_embedding(tokens)
        freq = self.freq[start_pos : start_pos+seq_len]

        mask = None
        if seq_len > 1:
            # Causal mask among the new tokens: -inf strictly above the diagonal.
            mask = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)
            if start_pos > 0:
                # Attention also sees `start_pos` cached keys, so the mask must
                # be (seq_len, start_pos + seq_len); every new token may attend
                # to the whole cached prefix (zeros = not masked).
                mask = torch.hstack(
                    [torch.zeros((seq_len, start_pos), device=tokens.device), mask]
                )

        for layer in self.layers:
            x = layer(x, freq, start_pos, mask )
        x = self.norm(x)
        x = self.output(x).float()

        return x

In [None]:
# Build the model from CONFIGURATIONS and move it to `device`.
# At this point the parameters still hold their default initialisation.
llama = LLAMA_3(config=CONFIGURATIONS).to(device)

In [None]:
# Number of elements across all registered parameters (buffers are not included).
total_params = sum(map(torch.numel, llama.parameters()))
print(f"Total Number of parameters: {total_params:,}")

In [None]:
# Use each parameter's real element size instead of assuming 4 bytes (float32),
# so the figure stays correct if the model is cast to another dtype.
# Buffers (e.g. the KV caches) are not counted here.
total_size_bytes = sum(p.numel() * p.element_size() for p in llama.parameters())
total_size_mb = total_size_bytes / (1024 * 1024)
print(f"Total size of the model: {total_size_mb/1024} GB")

In [None]:
def Generate_Text(model, idx, max_tokens, context_size, start_pos):
    """Greedily extend ``idx`` by ``max_tokens`` tokens.

    On every step the last ``context_size`` tokens of ``idx`` are passed to
    ``model`` together with the unchanged ``start_pos``; the highest-probability
    token at the final position is appended along the last axis.

    Returns:
        ``idx`` with ``max_tokens`` extra columns.
    """
    for _step in range(max_tokens):
        window = idx[:, -context_size:]
        with torch.inference_mode():
            last_logits = model(window, start_pos)[:, -1, :]
            choice = last_logits.softmax(dim=-1).argmax(dim=-1, keepdim=True)
            idx = torch.cat((idx, choice), dim=-1)
    return idx

# Tokenizer

In [None]:
class Tokenizer:
    """tiktoken-based BPE tokenizer with 256 special tokens appended after the base vocabulary."""

    no_reserved_special_tokens = 256
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path):
        """Load the BPE ranks at ``model_path`` and build the encoding."""
        mergeable_ranks = load_tiktoken_bpe(model_path)
        first_special_id = len(mergeable_ranks)

        named_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ]
        filler_tokens = [
            f"<|reserved_special_token_{n}|>"
            for n in range(5, self.no_reserved_special_tokens - 5)
        ]

        # Special tokens receive consecutive ids right after the base vocabulary.
        self.special_tokens = {
            token: token_id
            for token_id, token in enumerate(named_tokens + filler_tokens, start=first_special_id)
        }

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens
        )
        self.n_words = self.model.n_vocab
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
        ):
        """Encode ``s`` to token ids, optionally adding the BOS and/or EOS id.

        The text is processed in chunks of at most 400,000 characters, and each
        chunk is further split so that no run of only-whitespace or
        only-non-whitespace characters exceeds 25,000 characters.
        """
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000
        MAX_NO_WHITESPACES_CHARS = 25_000

        t: List[int] = []
        for chunk_start in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
            chunk = s[chunk_start : chunk_start + TIKTOKEN_MAX_ENCODE_CHARS]
            for piece in self._split_whitespaces_or_nonwhitespaces(
                chunk, MAX_NO_WHITESPACES_CHARS
            ):
                piece_ids = self.model.encode(
                    piece,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
                t.extend(piece_ids)

        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """Turn a sequence of token ids back into text."""
        return self.model.decode(cast(List[int], t))

    def _split_whitespaces_or_nonwhitespaces(self,
        s: str, max_consecutive_slice_len: int
    ):
        """Yield pieces of ``s`` so no same-kind character run exceeds ``max_consecutive_slice_len``.

        "Kind" is whitespace versus non-whitespace. A piece is cut whenever the
        current run would grow past the limit; the remainder is yielded last
        (an empty string for empty input).
        """
        run_len = 0
        run_is_space = len(s) > 0 and s[0].isspace()
        piece_start = 0

        for pos, ch in enumerate(s):
            ch_is_space = ch.isspace()

            if run_is_space != ch_is_space:
                # Kind changed: a new run starts at this character.
                run_len = 1
                run_is_space = ch_is_space
                continue

            run_len += 1
            if run_len > max_consecutive_slice_len:
                yield s[piece_start:pos]
                piece_start = pos
                run_len = 1
        yield s[piece_start:]

In [None]:
# Build the tokenizer from the BPE rank file stored next to the model weights.
tok = Tokenizer("Weights/3B-instruct/original/tokenizer.model")

In [None]:
# Smoke test of the generation loop: encode a short prompt, greedily add
# 5 tokens (context window 10, start position 0) and decode the result.
z = torch.tensor(
    [tok.encode(s="my name is sarthak", bos=False, eos=False)],
    dtype=torch.long,
    device=device,
)
d = Generate_Text(llama, z, 5, 10, 0).squeeze(dim=0)
tok.decode(d.tolist())

In [None]:
from helper_functions import load_consolidated_pth_weights, fetch_context_from_web, clean_serp_data

In [None]:
# Path of the consolidated checkpoint. Named `weights_path` rather than `dir`
# so the builtin `dir()` is not shadowed for the rest of the session.
weights_path = "./Weights/3B-instruct/original/consolidated.00.pth"
params = load_consolidated_pth_weights(weights_path)

In [None]:
# List the tensor names available in the checkpoint.
print(list(params.keys()))

In [None]:
import torch

def weight_injector(my_model, params):
    """Copy checkpoint tensors from ``params`` into ``my_model`` in place.

    Every tensor is first moved/cast to the device and dtype of the model's
    first parameter. ``attention.wo.bias`` is copied only when the checkpoint
    contains it. The output weight is accepted either in the model's layout
    or transposed.

    Args:
        my_model: Model exposing ``tok_embedding``, ``layers`` (each with
            ``Attention``, ``FeedForward``, ``Attention_Norm``, ``FFN_Norm``),
            ``norm`` and ``output``.
        params: Mapping from checkpoint tensor name to tensor.

    Raises:
        KeyError: If a required tensor name is missing from ``params``.
        ValueError: If the output weight matches neither accepted layout.
    """
    reference = next(my_model.parameters())
    device = reference.device
    dtype = reference.dtype

    def _load(target, key):
        # Overwrite the parameter's storage with the converted checkpoint tensor.
        target.data.copy_(params[key].to(device=device, dtype=dtype))

    _load(my_model.tok_embedding.weight, "tok_embeddings.weight")

    for i, layer in enumerate(my_model.layers):
        attn = layer.Attention
        ffn = layer.FeedForward

        _load(attn.W_Q.weight, f"layers.{i}.attention.wq.weight")
        _load(attn.W_K.weight, f"layers.{i}.attention.wk.weight")
        _load(attn.W_V.weight, f"layers.{i}.attention.wv.weight")
        _load(attn.wo.weight, f"layers.{i}.attention.wo.weight")

        bias_key = f"layers.{i}.attention.wo.bias"
        if hasattr(attn.wo, "bias") and bias_key in params:
            _load(attn.wo.bias, bias_key)

        _load(ffn.w1.weight, f"layers.{i}.feed_forward.w1.weight")
        _load(ffn.w3.weight, f"layers.{i}.feed_forward.w3.weight")
        _load(ffn.w2.weight, f"layers.{i}.feed_forward.w2.weight")

        _load(layer.Attention_Norm.weight, f"layers.{i}.attention_norm.weight")
        _load(layer.FFN_Norm.weight, f"layers.{i}.ffn_norm.weight")

    _load(my_model.norm.weight, "norm.weight")

    out_w = params["output.weight"]
    out_param = my_model.output.weight
    if out_param.shape == out_w.shape:
        source = out_w
    elif out_param.shape[::-1] == out_w.shape:
        # Checkpoint stores the matrix transposed relative to the model.
        source = out_w.T
    else:
        raise ValueError(f"Output weight shape mismatch: model={my_model.output.weight.shape}, ckpt={out_w.shape}")
    out_param.data.copy_(source.to(device=device, dtype=dtype))

    print("All weights successfully injected!")

In [None]:
# Overwrite the model's parameters in place with the tensors loaded from the checkpoint.
weight_injector(llama, params)

In [None]:
def generate(
    model, idx, context_len, max_new_tok, top_k,
    temp=0.1,
    eos_id=128001,
    eot_id=128009,
    eom_id=128008,
):
    """Sample up to ``max_new_tok`` tokens after the prompt ``idx`` using the model's KV cache.

    The first step feeds the prompt (its last ``context_len`` tokens if it is
    longer than that) at cache position 0; later steps feed only the newest
    token at the next cache position.

    Args:
        model: Object providing ``eval()``, ``reset_cache()`` and
            ``model(tokens, start_pos) -> logits`` of shape
            ``(batch, seq_len, vocab)``.
        idx: Prompt token ids of shape ``(1, prompt_len)``. Only batch size 1
            is supported because the sampled token is read with ``.item()``.
        context_len: Maximum number of prompt tokens fed in the first step.
        max_new_tok: Maximum number of tokens to sample.
        top_k: Keep only the ``top_k`` largest logits before sampling;
            ``<= 0`` disables the filter. Values above the vocabulary size are
            clamped to it.
        temp: Sampling temperature (floored at ``1e-8``).
        eos_id, eot_id, eom_id: Token ids that stop generation. The stop token
            itself is not appended.

    Returns:
        ``idx`` with the sampled tokens appended along dimension 1.
    """
    model.eval()
    model.reset_cache()
    stop_ids = {eos_id, eot_id, eom_id}

    # Number of leading prompt tokens dropped by the context window. The cache
    # holds only the kept tokens, so cache positions lag behind positions in
    # `idx` by this amount.
    dropped = 0
    for tok_no in range(max_new_tok):
        if tok_no == 0:
            idx_cond = idx[:, -context_len:] if idx.shape[1] > context_len else idx
            dropped = idx.shape[1] - idx_cond.shape[1]
            start_pos = 0
        else:
            idx_cond = idx[:, -1:]
            start_pos = idx.shape[1] - 1 - dropped
        print(f"\rToken {tok_no}: seq_len={idx_cond.shape[1]}, start_pos={start_pos}", end="")
        with torch.inference_mode():
            logits = model(idx_cond, start_pos)[:, -1, :]

        if top_k > 0:
            k = min(top_k, logits.size(-1))
            top_k_logits, _ = torch.topk(logits, k)
            min_val = top_k_logits[:, -1:]
            logits = torch.where(logits < min_val, float("-inf"), logits)

        # Out-of-place division: `logits` may still be an inference tensor here
        # (when top_k <= 0), and in-place updates to inference tensors are not
        # allowed outside inference mode.
        logits = logits / max(temp, 1e-8)
        probs = torch.softmax(logits, dim=-1)
        preds = torch.multinomial(probs, 1)

        next_tok = preds.item()

        if next_tok in stop_ids:
            print(f"\n[Stop token {next_tok} reached at step {tok_no}]")
            break

        idx = torch.cat([idx, preds], dim=1)
    return idx


In [None]:
def format_and_tokenize(system_prompt, user_prompt, context=False, device=None):
    """Build a Llama 3 chat prompt and return its token ids as a ``(1, n)`` tensor.

    Args:
        system_prompt: Text of the system turn.
        user_prompt: Text of the user turn.
        context: When true, the result of
            ``clean_serp_data(fetch_context_from_web(user_prompt, n_results=1))``
            is printed and prepended to the user turn.
        device: Device for the returned tensor. ``None`` (default) selects
            CUDA when available and the CPU otherwise, so the function also
            works on machines without a GPU.

    Returns:
        Integer tensor of shape ``(1, n_tokens)``, starting with the BOS id
        and ending with an open assistant header.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if context:
        cntx = clean_serp_data(fetch_context_from_web(user_prompt, n_results=1))
        print(cntx)
        user_prompt = f"Use the following context to answer:\n{cntx.strip()}\n\nQuestion: {user_prompt.strip()}"

    # Llama 3 chat layout: each header is followed by a blank line, and
    # <|eot_id|> directly follows the message text with no newline around it.
    prompt = (
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt.strip()}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt.strip()}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # allowed_special="all" lets the header/eot markers be encoded as special tokens.
    input_ids = tok.encode(prompt, bos=True, eos=False, allowed_special="all")
    return torch.tensor([input_ids], device=device)

In [None]:
#When context is provided, summarize it in your own words rather than copying.

In [None]:
system_prompt = "You are a concise and factual AI assistant, and your name is Blue. You have to answer the Question asked by user."
user_prompt = "hi how are you"
# Put the prompt on the same device as the model instead of relying on the
# function's default, so the cell also runs when no GPU is available.
idx = format_and_tokenize(system_prompt, user_prompt, context=False, device=device)

# Decoding Block

In [None]:
# Sample up to 50 new tokens for the prompt in `idx` (top-50 filtering, temperature 0.3).
out = generate(llama, idx, context_len=2048, max_new_tok=50, top_k=50, temp=0.3)

In [None]:
def extract_assistant_output(out, tokenizer):
    """Decode ``out`` and return the text of the first assistant turn.

    Args:
        out: Tensor of token ids; it is flattened to one dimension before
            decoding, so a single-token tensor is handled as well.
        tokenizer: Object with ``decode(list_of_ids) -> str``.

    Returns:
        The stripped text between the first assistant header and the next
        ``<|eot_id|>`` (or the next assistant header / end of text, whichever
        comes first). If no assistant header is present, the whole decoded
        text, stripped.
    """
    # reshape(-1) always yields a 1-D tensor; squeeze() would turn a
    # single-token tensor into a scalar and tolist() into a bare int.
    text = tokenizer.decode(out.reshape(-1).tolist())
    segments = text.split("<|start_header_id|>assistant<|end_header_id|>")

    if len(segments) <= 1:
        return text.strip()

    # segments[1] is everything after the first assistant header.
    return segments[1].split("<|eot_id|>")[0].strip()


In [None]:
# Keep only the assistant's reply from the generated sequence and show it.
assistant_response = extract_assistant_output(out, tok)
print(assistant_response)

In [None]:
# Decode the complete generated sequence, prompt and special tokens included.
# NOTE: this rebinds `out` to the squeezed tensor, so cells that run after this
# one see the squeezed shape.
out = out.squeeze()
x = tok.decode(out.tolist())
x

In [None]:
# Inspect the raw token ids.
out

In [None]:
# Decode a single token id to see which text it maps to.
r = [78191]
print(tok.decode(r))


In [None]:
# Check which id(s) the tokenizer produces for the header marker when special tokens are allowed.
y = "<|start_header_id|>"
print(tok.encode(y, bos=False, eos=False, allowed_special="all"))