In [3]:
from typing import Optional, Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from pathlib import Path
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Sequence,
    TypedDict,
    Union,
)
import json
import numpy as np
from torch.nn.attention import sdpa_kernel, SDPBackend
import time
from bitsandbytes.nn import Linear8bitLt


Could not detect ROCm GPU architecture: module 'torch' has no attribute 'version'

ROCm GPU architecture detection failed despite ROCm being available.
                


AttributeError: module 'torch' has no attribute 'version'

In [2]:
device = "cuda" if torch.cuda.is_available else "cpu"

# RoPE

In [3]:
def calculate_angles(theeta, dim, seq_len):
    pos = 1/theeta**(torch.arange(0, dim, 2, device=device).float()/dim)
    t = torch.arange(seq_len, device=device)
    freqs = torch.outer(t, pos)
    unit_vecs = torch.polar(torch.ones_like(freqs), freqs).to(device)
    return unit_vecs

def brodcast(unit_vecs, x):
    # print(f"unit_vecs.shape: {unit_vecs.shape}\nX.shape: {x.shape[1], x.shape[-1]}")
    assert unit_vecs.shape == (x.shape[1], x.shape[-1])
    n_dim = x.ndim
    shape = [d if i == 1 or i == n_dim-1 else 1 for i,d in enumerate(x.shape)]
    return unit_vecs.view(*shape)

def RoPE(W_Q, W_K, unit_vecs):
    complex_W_Q = torch.view_as_complex(W_Q.float().reshape(*W_Q.shape[:-1], -1, 2)).to(device)
    complex_W_K = torch.view_as_complex(W_K.float().reshape(*W_K.shape[:-1], -1, 2)).to(device)
    # print(complex_W_Q.shape)
    pos = brodcast(unit_vecs, complex_W_K).to(device)
    embedded_W_Q = torch.view_as_real(complex_W_Q * pos).float().flatten(3)
    embedded_W_K = torch.view_as_real(complex_W_K * pos).float().flatten(3)
    return embedded_W_Q, embedded_W_K

In [4]:
W_Q = torch.rand((1,6,2,4))
complex_W_Q = torch.view_as_complex(W_Q.float().reshape(*W_Q.shape[:-1], -1, 2))
complex_W_Q.shape


torch.Size([1, 6, 2, 2])

In [5]:
CONFIGURATIONS = {
    "DIM": 3072,
    "FFN_DIM": 8192,
    "N_LAYERS": 28,
    "N_HEADS": 24,
    "N_KV_HEADS": 8,
    "VOCAB_SIZE": 128256,
    "NORM_EPS": 1e-5,
    "ROPE_THETA": 500000,
    "MAX_BATCH_SIZE": 4,
    "MAX_SEQ_LEN": 128,
    "N_KV_HEAD_REP": 24 // 8,
    "HEAD_DIM": 3072 // 24
    }

# LAYERS

In [None]:
class Attention(nn.Module):
    def __init__(self, dim, n_heads, head_dim, n_kv_heads, n_kv_heads_reps, max_batch_size, max_seq_len):
        super().__init__()
        self.W_Q = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.W_K = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.W_V = nn.Linear(dim, n_kv_heads * head_dim, bias=False)

        self.CACHE_K = torch.zeros(
            (max_batch_size, max_seq_len, n_kv_heads, head_dim),
            device=device
        )
        self.CACHE_V = torch.zeros(
            (max_batch_size, max_seq_len, n_kv_heads, head_dim),
            device=device
        )

        self.wo = nn.Linear(dim, dim)

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.n_kv_heads_reps = n_kv_heads_reps


    def forward(self,x, freq=None, start_pos=0, mask=None):
        bhz, seq_len, _ = x.shape

        query = self.W_Q(x).view(bhz, seq_len, self.n_heads, self.head_dim)
        key = self.W_K(x).view(bhz, seq_len, self.n_kv_heads, self.head_dim)
        value = self.W_V(x).view(bhz, seq_len, self.n_kv_heads, self.head_dim)

        query, key = RoPE(query, key, freq)

        # self.CACHE_K.to(device)
        # self.CACHE_V.to(device)

        self.CACHE_K[:bhz, start_pos:start_pos+seq_len] = key
        self.CACHE_V[:bhz, start_pos:start_pos+seq_len] = value

        keys = self.CACHE_K[:bhz, :start_pos+seq_len]
        values = self.CACHE_V[:bhz, :start_pos+seq_len]

        keys = torch.repeat_interleave(input=keys, repeats=self.n_kv_heads_reps, dim=-2)
        values = torch.repeat_interleave(input=values, repeats=self.n_kv_heads_reps, dim=-2)

        queries = query.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        
        out = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)
        out = out.transpose(1,2).contiguous().view(bhz, seq_len, -1)

        return self.wo(out)

In [7]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim, norm_eps):
        super().__init__()
        self.norm_eps = norm_eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps)
    
    def forward(self, x):
        out = self._norm(x.float()).type_as(x)
        return out * self.weight

In [8]:
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_dim):
        super().__init__()

        self.w1 = nn.Linear(dim, ffn_dim, bias=False)
        self.w3 = nn.Linear(dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

In [9]:
class Transformer_Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.Attention_Norm = RMSNorm(dim=config["DIM"], norm_eps=config["NORM_EPS"])
        self.FFN_Norm = RMSNorm(dim=config["DIM"], norm_eps=config["NORM_EPS"])
        self.Attention = Attention(dim=config["DIM"],
                                   n_heads=config["N_HEADS"],
                                   head_dim=config["HEAD_DIM"],
                                   n_kv_heads=config["N_KV_HEADS"],
                                   n_kv_heads_reps=config["N_KV_HEAD_REP"],
                                   max_batch_size=config["MAX_BATCH_SIZE"],
                                   max_seq_len=config["MAX_SEQ_LEN"])
        self.FeedForward = FeedForward(dim=config["DIM"],
                                       ffn_dim=config["FFN_DIM"])
    def forward(self, x, freq, start_pos, mask):
        shortcut = x
        x = self.Attention_Norm(x)
        x = self.Attention(x, freq, start_pos, mask)
        x = x + shortcut

        shortcut = x
        x = self.FFN_Norm(x)
        x = self.FeedForward(x)
        x = x + shortcut

        return x

In [10]:
class LLAMA_3(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.tok_embedding = nn.Embedding(config["VOCAB_SIZE"], config["DIM"])
        self.layers = nn.ModuleList()
        for _ in range(config["N_LAYERS"]):
            self.layers.append(Transformer_Block(config))
        self.norm = RMSNorm(config["DIM"], config["NORM_EPS"])
        self.output = nn.Linear(config["DIM"], config["VOCAB_SIZE"], bias=False)

        self.freq = calculate_angles(
            config["ROPE_THETA"],
            config["HEAD_DIM"],
            config["MAX_SEQ_LEN"] * 2
        )
    
    def forward(self, tokens, start_pos):
        bhz, seq_len = tokens.shape
        x = self.tok_embedding(tokens)
        freq = self.freq[start_pos : start_pos+seq_len]

        mask = None
        if seq_len > 1:
            mask = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device) 
            mask = torch.triu(mask, diagonal=1)

        for layer in self.layers:
            x = layer(x, freq, start_pos, mask )
        x = self.norm(x)
        x = self.output(x).float()
        
        return x
        

In [11]:
llama = LLAMA_3(config=CONFIGURATIONS).to(device)

In [12]:
total_params = sum(p.numel() for p in llama.parameters())
print(f"Total Number of parameters: {total_params:,}")

Total Number of parameters: 3,606,838,272


In [13]:
total_size_bytes = total_params * 4
total_size_mb = total_size_bytes / (1024 * 1024)
print(f"Total size of the model: {total_size_mb/1024} GB")

Total size of the model: 13.436519622802734 GB


In [14]:
def Generate_Text(model, idx, max_tokens, context_size, start_pos):
    for _ in range(max_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.inference_mode():
            logits = model(idx_cond, start_pos)
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.argmax(probs, dim=-1, keepdim=True)
            idx = torch.cat((idx, next_idx), dim=-1)
    return idx

# Tokenizer

In [15]:
class Tokenizer:
    no_reserved_special_tokens = 256
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path):
        mergeable_ranks = load_tiktoken_bpe("./Weights/Llama-3.2-3B/original/tokenizer.model")
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.no_reserved_special_tokens - 5)]
        
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens
        )
        self.n_words = self.model.n_vocab
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
        ):
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000
        MAX_NO_WHITESPACES_CHARS = 25_000
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t
    def decode(self, t: Sequence[int]) -> str:
        return self.model.decode(cast(List[int], t))

    def _split_whitespaces_or_nonwhitespaces(self,
        s: str, max_consecutive_slice_len: int
    ):
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]     

In [16]:
tok = Tokenizer("Weights/Llama-3.2-3B/original/tokenizer.model")

In [19]:
start_time = time.time()
z = tok.encode(s="my name is sarthak", bos=False, eos=False)
z = torch.tensor([z], dtype=torch.long, device="cuda")
d = Generate_Text(llama, z, 5, 10, 0)
d = d.squeeze(dim=0)
tok.decode(d.tolist())
end_time = time.time()

print(f"Total Time: {end_time - start_time}")

Total Time: 4.8732240200042725


In [None]:
from load_gpt_2_355M import load_llama3_weights_and_settings, nested_state_dict

In [None]:
dir = "./Weights/Llama-3.2-3B/original/"
settings, params = load_llama3_weights_and_settings(dir=dir)

In [None]:
settings

In [None]:
def weight_injector(llama, params):
    device = next(llama.parameters()).device
    dtype = next(llama.parameters()).dtype    

    llama.tok_embedding.weight.data.copy_(
        torch.from_numpy(params["tok_embeddings"]["weight"]).to(device=device, dtype=dtype)
    )

    for b in range(len(params["layers"])):
        llama.layers[b].Attention.W_Q.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["attention"]["wq"]["weight"]).to(device=device, dtype=dtype)
        )
        llama.layers[b].Attention.W_K.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["attention"]["wk"]["weight"]).to(device=device, dtype=dtype)
        )
        llama.layers[b].Attention.W_V.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["attention"]["wv"]["weight"]).to(device=device, dtype=dtype)
        )
        llama.layers[b].Attention.wo.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["attention"]["wo"]["weight"]).to(device=device, dtype=dtype)
        )

        llama.layers[b].FeedForward.w1.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["feed_forward"]["w1"]["weight"]).to(device=device, dtype=dtype)
        )
        llama.layers[b].FeedForward.w3.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["feed_forward"]["w3"]["weight"]).to(device=device, dtype=dtype)
        )
        llama.layers[b].FeedForward.w2.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["feed_forward"]["w2"]["weight"]).to(device=device, dtype=dtype)
        )

        llama.layers[b].Attention_Norm.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["attention_norm"]["weight"]).to(device=device, dtype=dtype)
        )
        llama.layers[b].FFN_Norm.weight.data.copy_(
            torch.from_numpy(params["layers"][b]["ffn_norm"]["weight"]).to(device=device, dtype=dtype)
        )

    llama.norm.weight.data.copy_(
        torch.from_numpy(params["norm"]["weight"]).to(device=device, dtype=dtype)
    )
    llama.output.weight.data.copy_(
        torch.from_numpy(params["output"]["weight"]).to(device=device, dtype=dtype)
    )

    print("Weights have been loaded successfully......")

In [None]:
weight_injector(llama, params)

In [None]:
z = tok.encode(s="Whats i saw there was", bos=False, eos=False)
z = torch.tensor([z], dtype=torch.long, device="cuda")
d = Generate_Text(llama, z, 20, 10, 0)
d = d.squeeze(dim=0)
tok.decode(d.tolist())

# DATASET

In [None]:
DIR = "./DATA/the-verdict.txt"

with open(DIR, "r", encoding="utf-8") as file:
    text = file.read()

len(text)

# Training

In [None]:
def train_model_sample(model, train_loader, val_loader, optimizer, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokenizer):
    train_losses, val_losses, track_token_seen = [], [], []

    token_seen, global_step = 0, -1

    for epoch in range(num_epochs):
        model.train()
        batch_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for input_batch, target_batch in batch_iter:
            optimizer.zero_grad()
            loss = corss_entropy_loss(input_batch, target_batch, device, model)
            loss.backward()
            optimizer.step()
            token_seen += input_batch.numel()
            global_step +=1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_token_seen.append(token_seen)
                print(f"Ep {epoch+1}(Step {global_step:06d}): "
                      f"Train Loss: {train_loss:.3f}, val Loss: {val_loss:.3f}")
                
        generate_and_print_sample(model, tokenizer, device, start_context)

    return train_losses, val_losses, track_token_seen
            