# Llama 3.2 From Scratch

For reference:
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama32.webp" width="700px">
  

In [5]:
!pip install tiktoken



In [6]:
!pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt



In [7]:
from importlib.metadata import version
# Packages this notebook depends on, each with the reason it is needed.
pkgs = [
    "blobfile",         # to download pretrained weights
    "huggingface_hub",  # to download pretrained weights
    "tiktoken",         # to implement the tokenizer
    "torch",            # to implement the model
]
# Print the installed version of each package (raises if one is not installed).
for p in pkgs:
    print(f"{p} version: {version(p)}")

blobfile version: 3.0.0
huggingface_hub version: 0.33.1
tiktoken version: 0.9.0
torch version: 2.6.0+cu124


&nbsp;
# 1. Architecture code

In [8]:
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """SwiGLU feed-forward sub-layer of a Llama transformer block.

    Computes ``fc3(silu(fc1(x)) * fc2(x))`` with three bias-free projections.
    ``fc1`` is the gate projection and ``fc2`` the up projection (this is how
    the weight loader maps ``gate_proj`` / ``up_proj``); ``fc3`` projects back
    down to ``emb_dim``.

    Args:
        cfg: config dict; reads ``"emb_dim"``, ``"hidden_dim"`` and ``"dtype"``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim, hidden_dim, dtype = cfg["emb_dim"], cfg["hidden_dim"], cfg["dtype"]
        self.fc1 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, dtype=dtype, bias=False)

    def forward(self, x):
        gated = nn.functional.silu(self.fc1(x)) * self.fc2(x)
        return self.fc3(gated)

In [9]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    """Precompute cosine/sine tables for rotary position embeddings (RoPE).

    Args:
        head_dim: per-head feature size; must be even (features are rotated in pairs).
        theta_base: RoPE base ("theta").
        context_length: number of positions (rows) to precompute.
        freq_config: optional Llama 3 frequency-scaling settings with keys
            ``"factor"``, ``"low_freq_factor"``, ``"high_freq_factor"`` and
            ``"original_context_length"``. Components whose wavelength exceeds
            ``original_context_length / low_freq_factor`` are divided by
            ``factor``; components whose wavelength lies between
            ``original_context_length / high_freq_factor`` and that bound are
            linearly interpolated between scaled and unscaled; shorter
            wavelengths are left unchanged.
        dtype: dtype used for ``torch.arange`` (the inverse frequencies are
            converted to float32 via ``.float()``).

    Returns:
        ``(cos, sin)``, each of shape ``(context_length, head_dim)``. The second
        half of the last axis repeats the first half, matching the
        first-half/second-half split used by ``apply_rope``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "Head dimension must be even"

    # One inverse frequency per rotated pair: theta_base ** (-2i / head_dim)
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Llama 3 frequency adjustments for long-context extension
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        # Long wavelengths (low frequencies) are slowed down by `factor`
        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        # Interpolation weight for the medium band: 0 -> fully scaled, 1 -> unscaled
        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Position indices 0 .. context_length - 1
    positions = torch.arange(context_length, dtype=dtype)

    # Angles: (context_length, head_dim // 2)
    angles = positions[:, None] * inv_freq[None, :]

    # Duplicate so both halves of the head use the same angle: (context_length, head_dim)
    angles = torch.cat([angles, angles], dim=1)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: 4-D tensor ``(batch_size, num_heads, seq_len, head_dim)``; ``head_dim`` must be even.
        cos, sin: tables whose first ``seq_len`` rows are used, each row of
            width ``head_dim`` (as returned by ``compute_rope_params``).

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    half = head_dim // 2
    front, back = x[..., :half], x[..., half:]

    # Broadcast the tables over the batch and head axes: (1, 1, seq_len, head_dim)
    cos_b = cos[None, None, :seq_len, :]
    sin_b = sin[None, None, :seq_len, :]

    # Rotate each (front[i], back[i]) pair by its angle
    rotated = torch.cat((-back, front), dim=-1)
    out = (x * cos_b) + (rotated * sin_b)

    # Computation may run in the tables' dtype; cast back to the input's precision
    return out.to(dtype=x.dtype)

In [10]:
class GroupedQueryAttention(nn.Module):
    """Causal self-attention with grouped key/value heads (GQA).

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads: each
    key/value head is reused by ``num_heads // num_kv_groups`` consecutive
    query heads. RoPE (``apply_rope``) is applied to queries and keys.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        num_heads: number of query heads; must be divisible by ``num_kv_groups``.
        num_kv_groups: number of key/value heads.
        dtype: dtype of all projection weights (``None`` = PyTorch default).
    """

    def __init__(
            self, d_in, d_out, num_heads,
            num_kv_groups,
            dtype=None
        ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        kv_width = num_kv_groups * self.head_dim

        # Layers are created in the order key, value, query, output; this fixes the
        # order of random initialisation and of parameters()/state_dict() entries.
        self.W_key = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, kv_width, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        """Attend over ``x``.

        Args:
            x: ``(batch, num_tokens, d_in)`` input.
            mask: boolean tensor broadcastable to the
                ``(batch, num_heads, num_tokens, num_tokens)`` scores; True entries
                are filled with -inf before the softmax.
            cos, sin: RoPE tables forwarded to ``apply_rope``.

        Returns:
            ``(batch, num_tokens, d_out)`` tensor.
        """
        batch, seq_len, _ = x.shape

        # Project, split the feature axis into heads, move heads to dim 1:
        # q -> (batch, num_heads, seq_len, head_dim); k, v -> (batch, num_kv_groups, seq_len, head_dim)
        q = self.W_query(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).view(batch, seq_len, self.num_kv_groups, self.head_dim).transpose(1, 2)
        v = self.W_value(x).view(batch, seq_len, self.num_kv_groups, self.head_dim).transpose(1, 2)

        k = apply_rope(k, cos, sin)
        q = apply_rope(q, cos, sin)

        # Give every query head its key/value head. repeat_interleave keeps each
        # group contiguous: [K1, K2] -> [K1, K1, K2, K2]; plain repeat would give
        # [K1, K2, K1, K2] and pair query heads with the wrong KV head.
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # Masked, scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        scores = (q @ k.transpose(2, 3)).masked_fill(mask, -torch.inf)
        weights = torch.softmax(scores / k.shape[-1]**0.5, dim=-1)
        assert k.shape[-1] == self.head_dim

        # Merge heads back: (batch, seq_len, num_heads * head_dim) == (batch, seq_len, d_out)
        context = (weights @ v).transpose(1, 2).reshape(batch, seq_len, self.d_out)
        return self.out_proj(context)

In [11]:
class TransformerBlock(nn.Module):
    """Pre-norm Llama decoder block.

    ``h = att(norm1(x)) + x`` followed by ``out = ff(norm2(h)) + h``, where
    ``att`` is grouped-query attention and ``ff`` the SwiGLU feed-forward layer.

    Args:
        cfg: config dict; reads ``"emb_dim"``, ``"n_heads"``, ``"n_kv_groups"``,
            ``"dtype"`` here and the FeedForward keys via ``FeedForward(cfg)``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim, dtype = cfg["emb_dim"], cfg["dtype"]
        self.att = GroupedQueryAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=dtype
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=dtype)
        self.norm2 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        # Attention sub-layer with residual connection: (batch, num_tokens, emb_dim)
        hidden = self.att(self.norm1(x), mask, cos, sin) + x
        # Feed-forward sub-layer with residual connection
        return self.ff(self.norm2(hidden)) + hidden

In [12]:
class Llama3Model(nn.Module):
    """Llama 3 decoder-only language model.

    Token embedding -> ``n_layers`` TransformerBlocks -> final RMSNorm -> linear
    projection to vocabulary logits. The RoPE cos/sin tables are computed once
    and shared by every block.

    Args:
        cfg: config dict; reads ``"vocab_size"``, ``"emb_dim"``, ``"n_layers"``,
            ``"n_heads"``, ``"rope_base"``, ``"context_length"``, ``"rope_freq"``
            and ``"dtype"`` (blocks read further keys).
    """

    def __init__(self, cfg):
        super().__init__()
        dtype = cfg["dtype"]

        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=dtype)

        # nn.ModuleList rather than nn.Sequential: each block takes (x, mask, cos, sin)
        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=dtype)
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=dtype)

        # Reusable RoPE tables; persistent=False keeps them out of state_dict()
        rope_cos, rope_sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        self.register_buffer("cos", rope_cos, persistent=False)
        self.register_buffer("sin", rope_sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        """Map ``(batch, num_tokens)`` token IDs to ``(batch, num_tokens, vocab_size)`` logits."""
        x = self.tok_emb(in_idx)

        # Causal mask: True strictly above the diagonal marks future positions
        seq_len = x.shape[1]
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1
        )

        for block in self.trf_blocks:
            x = block(x, causal_mask, self.cos, self.sin)

        x = self.final_norm(x)
        return self.out_head(x.to(self.cfg["dtype"]))

&nbsp;
# 2. Initialize model

In [13]:
# Llama 3.2 1B

LLAMA32_CONFIG = {
    "vocab_size": 128_256,           # Vocabulary size
    "context_length": 131_072,       # Context length that was used to train the model
    "emb_dim": 2048,                 # Embedding dimension
    "n_heads": 32,                   # Number of attention heads
    "n_layers": 16,                  # Number of layers
    "hidden_dim": 8192,              # Size of the intermediate dimension in FeedForward
    "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
    "rope_base": 500_000.0,          # The base in RoPE's "theta"
    "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage
    "rope_freq": {                   # RoPE frequency scaling
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}

# Llama 3.2 3B

# LLAMA32_CONFIG = {
#     "vocab_size": 128_256,           # Vocabulary size
#     "context_length": 131_072,       # Context length that was used to train the model
#     "emb_dim": 3072,                 # Embedding dimension
#     "n_heads": 24,                   # Number of attention heads
#     "n_layers": 28,                  # Number of layers
#     "hidden_dim": 8192,              # Size of the intermediate dimension in FeedForward
#     "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
#     "rope_base": 500_000.0,          # The base in RoPE's "theta"
#     "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage
#     "rope_freq": {                   # RoPE frequency scaling
#         "factor": 32.0,
#         "low_freq_factor": 1.0,
#         "high_freq_factor": 4.0,
#         "original_context_length": 8192,
#     }
# }

LLAMA_SIZE_STR = "1B" if LLAMA32_CONFIG["emb_dim"] == 2048 else "3B"

In [14]:
model = Llama3Model(LLAMA32_CONFIG)

In [15]:
# Count every element of every parameter registered on the model.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Account for weight tying: the pretrained checkpoint loaded later ties the
# output head to the token embeddings, but here tok_emb and out_head are two
# separate (vocab_size x emb_dim) parameters, so one copy is subtracted.
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")

Total number of parameters: 1,498,482,688

Total number of unique parameters: 1,235,814,400


In [16]:
def model_memory_size(model, input_dtype=torch.float32):
    """Estimate the memory (GiB) for a model's parameters, gradients and buffers.

    A gradient is counted for every parameter with ``requires_grad=True``.
    Every element (parameters, gradients and buffers alike) is assumed to take
    the byte width of ``input_dtype``, regardless of the tensors' actual dtype.

    Args:
        model: any ``nn.Module``.
        input_dtype: dtype whose element size is used for the estimate.

    Returns:
        Estimated size in GiB (bytes / 1024**3) as a float.
    """
    param_list = list(model.parameters())
    n_param_elems = sum(param.numel() for param in param_list)
    n_grad_elems = sum(param.numel() for param in param_list if param.requires_grad)
    n_buffer_elems = sum(buf.numel() for buf in model.buffers())

    bytes_per_elem = torch.tensor(0, dtype=input_dtype).element_size()
    n_bytes = (n_param_elems + n_grad_elems + n_buffer_elems) * bytes_per_elem
    return n_bytes / (1024**3)

print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")
# Nothing in this notebook quantizes the model, so describe what the estimate counts
# rather than hinting at a quantization option that does not exist.
print("Estimates count parameters, gradients and buffers at the given dtype.")

float32 (PyTorch default): 11.23 GB
bfloat16: 5.61 GB
If quantization is enabled above, actual memory usage will be lower.


In [17]:
# Pick the first available backend: CUDA, then Apple-silicon MPS, then CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model.to(device);

&nbsp;
# 3. Load tokenizer

In [18]:
import os
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe



class Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.

    Args:
        model_path: path to the BPE ranks file read by ``load_tiktoken_bpe``
            (e.g. ``original/tokenizer.model`` from the Hugging Face repo).

    Raises:
        FileNotFoundError: if ``model_path`` is not an existing file.
    """

    def __init__(self, model_path):
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        self.special = self._build_special_tokens()

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    @staticmethod
    def _build_special_tokens():
        """Return the special-token -> ID map covering exactly IDs 128000..128255.

        The named tokens are hard-coded from Meta's tokenizer.json; every other ID
        in the range gets the placeholder ``<|reserved_{id - 128002}|>``. The range
        ends at 128255 because the model vocabulary (and its embedding table) has
        128_256 entries: IDs >= 128256 cannot be embedded.
        """
        special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        taken = set(special.values())
        special.update({
            f"<|reserved_{token_id - 128002}|>": token_id
            for token_id in range(128002, 128256)
            if token_id not in taken
        })
        return special

    def encode(self, text, bos=False, eos=False):
        """Encode ``text`` to token IDs, optionally wrapped in BOS/EOS.

        Special-token strings inside ``text`` are rejected by tiktoken's default
        ``disallowed_special`` setting.
        """
        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
              + self.model.encode(text)
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

    def decode(self, ids):
        """Decode token IDs (special tokens included) back to a string."""
        return self.model.decode(ids)


class ChatFormat:
    """Build Llama 3 chat-template token IDs for a single-turn conversation.

    The sequence is: ``<|begin_of_text|>``, a system turn, a user turn (each
    closed by ``<|eot_id|>``), then an open assistant header for the model to
    continue from.

    Args:
        tokenizer: object exposing ``encode(text)`` and a ``special`` dict of
            token-name -> ID (a ``Tokenizer``).
        default_system: system message used when none is passed to ``encode``.
    """

    def __init__(self, tokenizer: Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        special = self.tok.special
        return [
            special["<|start_header_id|>"],
            *self.tok.encode(role),
            special["<|end_header_id|>"],
            *self.tok.encode("\n\n"),
        ]

    def _turn(self, role, content):
        """Encode a complete turn: header, content, then <|eot_id|>."""
        return self._header(role) + self.tok.encode(content) + [self.tok.special["<|eot_id|>"]]

    def encode(self, user_message, system_message=None):
        """Return the token IDs for one system + user exchange awaiting a reply."""
        system_text = self.default_system if system_message is None else system_message

        token_ids = [self.tok.special["<|begin_of_text|>"]]
        token_ids.extend(self._turn("system", system_text))
        token_ids.extend(self._turn("user", user_message))
        # Assistant header only; the model generates the content
        token_ids.extend(self._header("assistant"))
        return token_ids

In [19]:
#!pip install ipywidgets==7.7.1

In [20]:
# from google.colab import output
# output.enable_custom_widget_manager()

In [23]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [24]:
from huggingface_hub import hf_hub_download

tokenizer_file_path = hf_hub_download(
    repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
    filename="original/tokenizer.model",
    local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
)

original/tokenizer.model:   0%|          | 0.00/2.18M [00:00<?, ?B/s]

In [29]:
tokenizer = Tokenizer(tokenizer_file_path)
chat_tokenizer = ChatFormat(tokenizer)

&nbsp;
# 4. Load pretrained weights

In [30]:
def assign(left, right, tensor_name="unknown"):
    """Return ``right`` as a new ``nn.Parameter`` after checking it matches ``left``'s shape.

    Args:
        left: the existing tensor/parameter being replaced (only its shape is used).
        right: replacement values; a ``torch.Tensor`` is cloned and detached,
            anything else (e.g. a NumPy array) goes through ``torch.tensor``.
        tensor_name: name used in the error message.

    Raises:
        ValueError: if the shapes differ.
    """
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    values = right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right)
    return torch.nn.Parameter(values)


def load_weights_into_llama(model, param_config, params):
    """Copy Hugging Face-format Llama 3 weights into ``model`` in place.

    Each target ``.weight`` is replaced by a new ``nn.Parameter`` built by
    ``assign`` (shape-checked, cloned from the source tensor).

    Args:
        model: a ``Llama3Model`` (uses ``tok_emb``, ``trf_blocks``, ``final_norm``, ``out_head``).
        param_config: config dict; only ``"n_layers"`` is read.
        params: mapping from Hugging Face tensor names (e.g.
            ``"model.layers.0.self_attn.q_proj.weight"``) to tensors.

    Raises:
        KeyError: if a required tensor name is missing from ``params``.
        ValueError: propagated from ``assign`` on a shape mismatch.

    If ``params`` has no ``"lm_head.weight"``, the output head is filled from the
    token-embedding matrix and "Model uses weight tying." is printed. Because
    ``assign`` clones, the two weights hold equal values in separate storage.
    """
    def _load(module, name):
        # Replace module.weight with params[name]; `name` also labels shape errors.
        module.weight = assign(module.weight, params[name], name)

    _load(model.tok_emb, "model.embed_tokens.weight")

    # (attribute path inside a TransformerBlock, Hugging Face name suffix),
    # listed in the order the tensors are loaded for every layer.
    layer_mapping = (
        ("att.W_query", "self_attn.q_proj.weight"),
        ("att.W_key", "self_attn.k_proj.weight"),
        ("att.W_value", "self_attn.v_proj.weight"),
        ("att.out_proj", "self_attn.o_proj.weight"),
        ("norm1", "input_layernorm.weight"),
        ("ff.fc1", "mlp.gate_proj.weight"),
        ("ff.fc2", "mlp.up_proj.weight"),
        ("ff.fc3", "mlp.down_proj.weight"),
        ("norm2", "post_attention_layernorm.weight"),
    )

    for layer_idx in range(param_config["n_layers"]):
        block = model.trf_blocks[layer_idx]
        for attr_path, hf_suffix in layer_mapping:
            target = block
            for attr in attr_path.split("."):
                target = getattr(target, attr)
            _load(target, f"model.layers.{layer_idx}.{hf_suffix}")

    _load(model.final_norm, "model.norm.weight")

    if "lm_head.weight" in params:
        _load(model.out_head, "lm_head.weight")
    else:
        # No separate output head in the checkpoint: reuse the embedding matrix
        _load(model.out_head, "model.embed_tokens.weight")
        print("Model uses weight tying.")

In [31]:
from safetensors.torch import load_file


if LLAMA_SIZE_STR == "1B":
    weights_file = hf_hub_download(
        repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        filename="model.safetensors",
        local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
    )
    combined_weights = load_file(weights_file)


else:
    combined_weights = {}
    for i in range(1, 3):
        weights_file = hf_hub_download(
            repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
            filename=f"model-0000{i}-of-00002.safetensors",
            local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
        )
        current_weights = load_file(weights_file)
        combined_weights.update(current_weights)


load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)
model.to(device)
del combined_weights  # free up memory

Model uses weight tying.


In [32]:
print("Weight tying:", torch.equal(model.tok_emb.weight, model.out_head.weight))

Weight tying: True


&nbsp;
# 5. Generate text

In [33]:
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` with ``tokenizer.encode`` and return a ``(1, num_tokens)`` tensor."""
    token_list = tokenizer.encode(text)
    # Add a leading batch dimension of size 1
    return torch.tensor(token_list).unsqueeze(dim=0)


def token_ids_to_text(token_ids, tokenizer):
    """Decode a ``(1, num_tokens)`` tensor back to text, dropping the batch dimension."""
    id_list = token_ids.squeeze(0).tolist()
    return tokenizer.decode(id_list)


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

    Args:
        model: callable mapping a ``(batch, num_tokens)`` tensor of token IDs to
            logits of shape ``(batch, num_tokens, vocab_size)``.
        idx: ``(batch, num_tokens)`` tensor of prompt token IDs.
        max_new_tokens: maximum number of tokens to append.
        context_size: only the last ``context_size`` tokens are fed to ``model``.
        temperature: ``0`` selects greedy (argmax) decoding; ``> 0`` divides the
            logits by it and samples from the softmax.
        top_k: if given, only the ``top_k`` largest logits of each row remain eligible.
        eos_id: if given, generation stops as soon as every row's next token equals
            it (that token is not appended).

    Returns:
        ``(batch, num_tokens + generated)`` tensor.
    """
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        # Only the last position predicts the next token: (batch, vocab_size)
        logits = logits[:, -1, :]

        if top_k is not None:
            # Keep the threshold as a (batch, 1) column so it broadcasts along the
            # vocabulary axis; a (batch,) vector would broadcast against the vocab
            # dimension and fail (or mis-filter) whenever batch > 1.
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1:]
            logits = logits.masked_fill(logits < min_val, float("-inf"))

        if temperature > 0.0:
            probs = torch.softmax(logits / temperature, dim=-1)  # (batch, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch, 1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

        # Reduce explicitly: `idx_next == eos_id` has no single truth value when batch > 1.
        if eos_id is not None and bool((idx_next == eos_id).all()):
            break

        idx = torch.cat((idx, idx_next), dim=1)  # (batch, num_tokens + 1)

    return idx

In [None]:
import time


PROMPT = "What is quantization?"

torch.manual_seed(123)

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.,
    # Stop at the end of the assistant turn (<|eot_id|>, the same token that closes
    # the system and user turns in ChatFormat) instead of always running 150 steps.
    eos_id=tokenizer.special["<|eot_id|>"],
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    """Return the part of ``text`` after the first ``header_end``, whitespace-stripped.

    If ``header_end`` does not occur, ``text`` is returned unchanged (not stripped).
    """
    marker_pos = text.find(header_end)
    if marker_pos == -1:
        return text
    return text[marker_pos + len(header_end):].strip()

print("\n\nOutput text:\n\n", clean_text(output_text))

# 6. Evaluation (When2Call)

In [None]:
import os
import subprocess
import shutil
from pathlib import Path

def _run_command(cmd, cwd):
    """Run ``cmd`` in directory ``cwd``; return ``(succeeded, stderr_text)``.

    A missing executable (e.g. ``git`` not installed) or a missing ``cwd`` makes
    ``subprocess.run`` raise ``OSError``; that is reported as a failure.
    """
    try:
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    except OSError as exc:
        return False, str(exc)
    return result.returncode == 0, result.stderr


def setup_when2call_evaluation():
    """Setup When2Call dataset with LM-Eval-Harness.

    Steps (paths relative to the current working directory):
      1. clone NVIDIA/When2Call into ``When2Call`` (skipped if it exists);
      2. clone EleutherAI/lm-evaluation-harness into
         ``When2Call/lm-evaluation-harness`` (skipped if it exists);
      3. ``pip install -e`` the harness with this kernel's interpreter;
      4. copy When2Call's task definitions into ``lm_eval/tasks/when2call``;
      5. copy the MCQ test split next to those task files;
      6. check that the expected files exist.

    Commands run with an explicit ``cwd`` instead of ``os.chdir``, so the
    notebook's working directory is never changed, even if a step raises.

    Returns:
        True if every required file is present afterwards, otherwise False.
    """
    import sys  # run pip with this kernel's interpreter, not whatever `pip` is on PATH

    print("🚀 Setting up When2Call evaluation framework...")

    base_dir = Path.cwd()
    when2call_dir = base_dir / "When2Call"
    harness_dir = when2call_dir / "lm-evaluation-harness"

    # Step 1: Clone When2Call repository
    print("\n📦 Cloning When2Call repository...")
    if not when2call_dir.exists():
        ok, err = _run_command(["git", "clone", "https://github.com/NVIDIA/When2Call.git"], base_dir)
        if not ok:
            print(f"❌ Error cloning When2Call: {err}")
            return False
        print("✅ When2Call cloned successfully")
    else:
        print("📁 When2Call directory already exists")

    # Step 2: Clone LM-Evaluation-Harness inside When2Call
    print("\n📦 Cloning LM-Evaluation-Harness repository...")
    if not harness_dir.exists():
        ok, err = _run_command(
            ["git", "clone", "https://github.com/EleutherAI/lm-evaluation-harness.git"], when2call_dir
        )
        if not ok:
            print(f"❌ Error cloning LM-Evaluation-Harness: {err}")
            return False
        print("✅ LM-Evaluation-Harness cloned successfully")
    else:
        print("📁 LM-Evaluation-Harness directory already exists")

    # Step 3: Editable install of LM-Evaluation-Harness
    print("\n🔧 Installing LM-Evaluation-Harness...")
    ok, err = _run_command([sys.executable, "-m", "pip", "install", "-e", "."], harness_dir)
    if not ok:
        print(f"❌ Error installing LM-Evaluation-Harness: {err}")
        return False
    print("✅ LM-Evaluation-Harness installed successfully")

    # Step 4: Copy When2Call evaluation tasks into the harness (replacing old copies)
    print("\n📋 Copying When2Call evaluation tasks...")
    source_rel = "evaluation/mcq/lm_eval_harness/when2call"
    source_path = when2call_dir / source_rel
    target_path = harness_dir / "lm_eval" / "tasks" / "when2call"

    if not source_path.exists():
        print(f"❌ Source path not found: {source_rel}")
        return False
    if target_path.exists():
        shutil.rmtree(target_path)
    shutil.copytree(source_path, target_path)
    print("✅ When2Call tasks copied successfully")

    # Step 5: Copy test data next to the task files (optional)
    print("\n📊 Setting up test data...")
    test_data_rel = "data/test/when2call_test_mcq.jsonl"
    test_data_source = when2call_dir / test_data_rel
    if test_data_source.exists():
        shutil.copy2(test_data_source, target_path / "when2call_test_mcq.jsonl")
        print("✅ Test data copied successfully")
    else:
        print(f"⚠️  Test data not found at: {test_data_rel}")

    # Step 6: Verify setup
    print("\n🔍 Verifying setup...")
    required_files = [
        "When2Call/data/test/when2call_test_mcq.jsonl",
        "When2Call/lm-evaluation-harness/lm_eval/tasks/when2call/when2call-llama3_2.yaml",
        "When2Call/lm-evaluation-harness/lm_eval/tasks/when2call/utils.py"
    ]

    all_good = True
    for file_path in required_files:
        if (base_dir / file_path).exists():
            print(f"✅ {file_path}")
        else:
            print(f"❌ Missing: {file_path}")
            all_good = False

    if not all_good:
        print("\n❌ Setup incomplete - some files are missing")
        return False

    print("\n🎉 Setup completed successfully!")
    print("\nYou can now run When2Call evaluation with:")
    print("cd When2Call/lm-evaluation-harness")
    print("lm_eval --model hf --model_args pretrained=your-model --tasks when2call-llama3_2 --batch_size 8")
    return True

# Run the setup
if __name__ == "__main__":
    success = setup_when2call_evaluation()
    if success:
        print("\n✅ Ready to evaluate your Llama 3.2 model on When2Call!")
    else:
        print("\n❌ Setup failed. Please check the errors above.")


In [None]:
# Verify the installation worked
import os
import subprocess

def verify_when2call_setup():
    """Verify that When2Call is properly set up with LM-Eval-Harness.

    Checks that ``lm_eval`` is importable, then runs ``lm_eval --tasks list``
    with ``cwd=When2Call/lm-evaluation-harness``. It passes the directory via
    ``cwd`` instead of calling ``os.chdir``, so the notebook's working directory
    is left untouched even when the command is missing or fails.

    Returns:
        True if the task list mentions "when2call", otherwise False.
    """
    print("🔍 Verifying When2Call setup...")

    # Check if we can import lm_eval
    try:
        import lm_eval  # noqa: F401  (only checking importability)
        print("✅ lm_eval imported successfully")
    except ImportError as e:
        print(f"❌ Cannot import lm_eval: {e}")
        return False

    harness_dir = "When2Call/lm-evaluation-harness"
    if not os.path.exists(harness_dir):
        print("❌ LM-Evaluation-Harness directory not found")
        return False

    # List available tasks to see if when2call is there
    try:
        result = subprocess.run(
            ["lm_eval", "--tasks", "list"], cwd=harness_dir, capture_output=True, text=True
        )
    except OSError as e:  # e.g. the lm_eval CLI is not on PATH
        print(f"❌ Could not run lm_eval: {e}")
        return False

    if "when2call" not in result.stdout:
        print("❌ When2Call tasks not found in task list")
        print("Available tasks preview:")
        print(result.stdout[:500] + "..." if len(result.stdout) > 500 else result.stdout)
        if result.returncode != 0:
            # Surface why the command itself failed
            print(result.stderr[-500:])
        return False

    print("✅ When2Call tasks found in LM-Eval-Harness")

    # Show specific When2Call tasks (first 10)
    when2call_tasks = [line.strip() for line in result.stdout.split('\n') if 'when2call' in line.lower()]
    print("\n📋 Available When2Call tasks:")
    for task in when2call_tasks[:10]:
        if task:
            print(f"  - {task}")

    return True

# Run verification
verify_when2call_setup()


In [None]:
%load_ext cuml.accel
import sklearn
import json
import time
from sklearn.metrics import f1_score, accuracy_score

In [None]:
def format_when2call_prompt_v2(tools, question, chat_tokenizer):
    """Build a ``(1, num_tokens)`` prompt tensor for one When2Call sample.

    Ultra-strict prompt that prevents parameter hallucination: when tools are
    given, the system message lists their schemas and forbids any parameter not
    declared there. With no tools, a plain assistant system message is used.

    Args:
        tools: sequence of tool specs, each a JSON string (parsed with
            ``json.loads``); ``None``/empty means no tools.
        question: the user message.
        chat_tokenizer: object with ``encode(user_message, system_message=...)``
            returning token IDs (a ``ChatFormat``).

    Returns:
        Token-ID tensor with a leading batch dimension of 1.
    """
    parsed_tools = [json.loads(spec) for spec in tools] if tools else []

    if not parsed_tools:
        system_message = "You are a helpful assistant. You do not have access to any external functions or tools."
    else:
        schema_dump = json.dumps(parsed_tools, indent=2)
        system_message = f"""You are an expert in composing functions. You are given a question and a set of possible functions.

CRITICAL RULES:
1. You can ONLY use the exact parameter names listed in each function's "properties"
2. You MUST NOT invent new parameters, even if they seem logical
3. You MUST NOT add parameters like "size", "temperature", "amount" unless they are explicitly listed
4. If the user asks for modifications that require parameters NOT in the schema, respond: "I cannot help with this request using the available functions."
5. Only make a tool call if ALL user requirements can be satisfied with the available parameters

Function schemas (use ONLY these parameters):
{schema_dump}

Valid response formats:
- Tool call: [func_name(param1="value1", param2="value2")] (ONLY using schema parameters)
- Cannot help: "I cannot help with this request using the available functions."
- Need info: "To assist you better, could you please specify the [missing_required_parameter]?"

Check each user request against the schema carefully. Do not add any parameters not explicitly defined."""

    prompt_ids = chat_tokenizer.encode(question, system_message=system_message)
    return torch.tensor(prompt_ids).unsqueeze(0)


In [None]:
def classify_response(response, sample, verbose=False):
    """Classify a model response into one of the four When2Call categories.

    Checks run in priority order: tool call, request for more information,
    refusal ("cannot answer"), and finally a direct answer as the fallback.

    Args:
        response: Decoded model output text.
        sample: The When2Call sample the response belongs to. Currently unused;
            kept so existing callers keep working.
        verbose: If True, print the normalized response before classifying.
            The original printed unconditionally, which floods evaluation logs.

    Returns:
        One of ``'tool_call'``, ``'request_for_info'``, ``'cannot_answer'`` or
        ``'direct'``.
    """
    import re

    # Normalize typographic apostrophes so "can’t" / "don’t have" match the
    # ASCII phrases below.
    response_lower = response.replace("\u2019", "'").lower()
    if verbose:
        print(f"Classifying response: {response_lower}")

    # Tool call in the format the prompt asks for: [func_name(...)].
    # Requiring "[" + identifier + "(" avoids false positives from markdown links
    # "[text](url)" or citations such as "[1] ... (see above)", which the old
    # "any '[' and '(' and ')'" test classified as tool calls.
    if re.search(r"\[\s*[A-Za-z_][\w.]*\s*\(", response):
        return 'tool_call'

    # Requests for more information
    request_phrases = ['could you', 'please specify', 'can you provide', 'need more',
                       'clarify', 'which', 'what type', 'more details', 'tell me more',
                       'can you tell me', 'please provide', 'what is the', 'need to know']
    if any(phrase in response_lower for phrase in request_phrases):
        return 'request_for_info'

    # Refusals / "cannot answer" responses
    cannot_phrases = ['cannot', "can't", 'unable', 'not possible', 'apologies',
                      "don't have", 'not available', 'insufficient', 'sorry',
                      'i cannot', "i can't", 'not able to', 'unable to provide']
    if any(phrase in response_lower for phrase in cannot_phrases):
        return 'cannot_answer'

    # Default to direct answer
    return 'direct'


In [None]:
def _load_when2call_samples(test_file_path, limit=None):
    """Read When2Call samples from a JSONL file, keeping only the first `limit` if given."""
    samples = []
    with open(test_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
            if line.strip():
                samples.append(json.loads(line))
    # `is not None` so that limit=0 means "no samples" rather than "all samples".
    return samples if limit is None else samples[:limit]


def _f1_per_category(y_true, y_pred, categories):
    """Return the one-vs-rest F1 score of each category, in the order given."""
    scores = []
    for cat in categories:
        true_pos = sum(1 for t, p in zip(y_true, y_pred) if t == cat and p == cat)
        false_pos = sum(1 for t, p in zip(y_true, y_pred) if t != cat and p == cat)
        false_neg = sum(1 for t, p in zip(y_true, y_pred) if t == cat and p != cat)

        precision = true_pos / (true_pos + false_pos) if true_pos + false_pos > 0 else 0
        recall = true_pos / (true_pos + false_neg) if true_pos + false_neg > 0 else 0
        scores.append(2 * precision * recall / (precision + recall) if precision + recall > 0 else 0)
    return scores


def _tool_hallucination_rate(results):
    """Fraction of no-tool 'cannot_answer' samples for which a tool call was still predicted."""
    no_tool_cannot = [r for r in results if r['correct'] == 'cannot_answer' and not r['has_tools']]
    hallucinated = sum(1 for r in no_tool_cannot if r['predicted'] == 'tool_call')
    return hallucinated / max(len(no_tool_cannot), 1)


def run_when2call_evaluation(test_file_path, limit=None):
    """Run the When2Call benchmark with the notebook's model and print a report.

    Relies on notebook-level globals from earlier cells: ``model``, ``device``,
    ``generate``, ``chat_tokenizer``, ``tokenizer``, ``token_ids_to_text``,
    ``clean_text``, ``LLAMA32_CONFIG``, ``format_when2call_prompt_v2`` and
    ``classify_response``.

    Args:
        test_file_path: Path to a When2Call JSONL file; each sample is read for its
            ``uuid``, ``question``, ``tools`` and ``correct_answer`` keys.
        limit: Evaluate only the first ``limit`` samples; ``None`` evaluates all.

    Returns:
        A list of per-sample dicts with keys ``uuid``, ``predicted``, ``correct``,
        ``is_correct``, ``response`` (truncated to 150 chars), ``question`` and
        ``has_tools``. Samples whose processing raised get ``predicted='error'``.
        An empty list is returned (without printing metrics) when there is
        nothing to evaluate.
    """
    categories = ['direct', 'tool_call', 'request_for_info', 'cannot_answer']

    test_samples = _load_when2call_samples(test_file_path, limit)
    print(f"Evaluating {len(test_samples)} samples...")
    if not test_samples:
        return []
    print(f"Sample questions: {[s['question'] for s in test_samples[:2]]}")

    results = []
    start_time = time.time()

    for i, sample in enumerate(test_samples):
        if i % 5 == 0:
            print(f"Progress: {i}/{len(test_samples)}")

        tools = sample.get('tools') or []
        try:
            # Format prompt using the notebook's ChatFormat
            token_tensor = format_when2call_prompt_v2(tools, sample['question'], chat_tokenizer)

            # Generate response using the notebook's inference pipeline
            torch.manual_seed(123)
            token_ids = generate(
                model=model,
                idx=token_tensor.to(device),
                max_new_tokens=100,  # Reduced for faster evaluation
                context_size=LLAMA32_CONFIG["context_length"],
                top_k=1,
                temperature=0.
            )

            output_text = token_ids_to_text(token_ids, tokenizer)
            response = clean_text(output_text)
            # Cut at the first end-of-turn token to drop any repeated continuation
            response = response.split('<|eot_id|>')[0].strip()

            predicted = classify_response(response, sample)
            shown_response = response[:150] + "..." if len(response) > 150 else response
        except Exception as e:
            print(f"Error processing sample {i}: {e}")
            # Sentinel that never equals a gold label, so failures always count as
            # wrong. The old 'cannot_answer' placeholder was scored as correct by
            # `accuracy` whenever the gold label was 'cannot_answer'.
            predicted = 'error'
            shown_response = f'ERROR: {str(e)}'

        results.append({
            'uuid': sample['uuid'],
            'predicted': predicted,
            'correct': sample['correct_answer'],
            'is_correct': predicted == sample['correct_answer'],
            'response': shown_response,
            'question': sample['question'],
            # Taken from the sample on both paths (errors used to force False),
            # so the hallucination-rate denominator stays correct.
            'has_tools': len(tools) > 0,
        })

    elapsed = time.time() - start_time

    # Calculate metrics
    y_true = [r['correct'] for r in results]
    y_pred = [r['predicted'] for r in results]
    accuracy = sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(results)
    f1_scores = _f1_per_category(y_true, y_pred, categories)
    macro_f1 = sum(f1_scores) / len(f1_scores)
    # Tool Hallucination Rate (key When2Call metric)
    hallucination_rate = _tool_hallucination_rate(results)
    speed = len(results) / elapsed if elapsed > 0 else float('inf')

    # Print results
    print("\n" + "=" * 60)
    print("WHEN2CALL BENCHMARK RESULTS")
    print("=" * 60)
    print("Model Performance Summary:")
    print(f"├─ Total samples: {len(results)}")
    print(f"├─ Evaluation time: {elapsed:.1f} seconds")
    print(f"├─ Speed: {speed:.2f} samples/sec")
    print(f"├─ Accuracy: {accuracy:.3f}")
    print(f"├─ Macro F1: {macro_f1:.3f}")
    print(f"└─ Tool Hallucination Rate: {hallucination_rate:.3f} (lower is better)")

    # Per-category breakdown
    print("\nPer-category Performance:")
    for cat, f1 in zip(categories, f1_scores):
        cat_results = [r for r in results if r['correct'] == cat]
        if cat_results:
            cat_correct = sum(1 for r in cat_results if r['is_correct'])
            cat_accuracy = cat_correct / len(cat_results)
            print(f"├─ {cat:15}: {cat_accuracy:.3f} ({cat_correct}/{len(cat_results)}) F1={f1:.3f}")

    # Comparison with reported results
    print("\nComparison with Llama 3.2 Baselines:")
    print("├─ Llama 3.2 1B: F1=0.217, Acc=0.451, Hall=0.43")
    print("├─ Llama 3.2 3B: F1=0.179, Acc=0.465, Hall=0.52")
    print(f"└─ Your model:   F1={macro_f1:.3f}, Acc={accuracy:.3f}, Hall={hallucination_rate:.2f}")

    return results


In [None]:
# Make sure you have the When2Call data file.
# Path to the When2Call MCQ test split, relative to the current working directory.
test_file = "When2Call/data/test/when2call_test_mcq.jsonl"

# Number of samples to evaluate; start small to test, set to None for the full split.
EVAL_LIMIT = 50

print("Starting When2Call evaluation...")
eval_results = run_when2call_evaluation(
    test_file_path=test_file,
    limit=EVAL_LIMIT,
)

print(f"\nCompleted evaluation of {len(eval_results)} samples")


In [None]:
import os

# Directory for saved results ("/content" — presumably a Google Colab runtime).
output_dir = "/content"
os.makedirs(output_dir, exist_ok=True)  # Create the directory if it doesn't exist

# Save detailed per-sample results
results_filename = f'when2call_results_{len(eval_results)}samples.json'
results_file_path = os.path.join(output_dir, results_filename)

with open(results_file_path, 'w', encoding='utf-8') as f:
    json.dump(eval_results, f, indent=2, default=str)

print(f"Results saved to: {os.path.abspath(results_file_path)}")


def _summary_macro_f1(results, categories=('direct', 'tool_call', 'request_for_info', 'cannot_answer')):
    """Macro-averaged F1 over `categories`, with per-category F1 = 2*TP / (2*TP + FP + FN).

    The previous inline formula divided by |true==cat OR pred==cat| = TP + FP + FN,
    which is twice the Jaccard index rather than F1 and could reach 2.0.
    """
    if not categories:
        return 0.0
    f1_scores = []
    for cat in categories:
        true_pos = sum(1 for r in results if r['correct'] == cat and r['predicted'] == cat)
        # |true == cat| + |pred == cat| == 2*TP + FP + FN
        denom = (sum(1 for r in results if r['correct'] == cat)
                 + sum(1 for r in results if r['predicted'] == cat))
        f1_scores.append(2 * true_pos / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


# Create summary (guarding against an empty result list)
summary = {
    'model_info': 'Custom Llama 3.2 Model',
    'total_samples': len(eval_results),
    'accuracy': (sum(1 for r in eval_results if r['is_correct']) / len(eval_results)) if eval_results else 0.0,
    'macro_f1': _summary_macro_f1(eval_results),
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
}

print("\nSUMMARY:")
print(f"Accuracy: {summary['accuracy']:.3f}")
print(f"Macro F1: {summary['macro_f1']:.3f}")

In [None]:
# Let's see what the model is actually generating
print("DEBUGGING - Actual model responses:")
print("="*60)

# `test_samples` is a local variable of run_when2call_evaluation, so any
# notebook-level name of that name is not the evaluated set (or doesn't exist).
# Re-read the test file to look up how many tools each sample offered.
tools_by_uuid = {}
with open(test_file, 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():
            sample = json.loads(line)
            tools_by_uuid[sample['uuid']] = sample.get('tools') or []

for i, result in enumerate(eval_results[:5]):  # Show first 5
    print(f"\nSample {i+1}:")
    print(f"Question: {result['question']}")
    print(f"Actual Response: '{result['response']}'")
    print(f"Predicted Category: {result['predicted']}")
    print(f"Correct Category: {result['correct']}")
    # .get() avoids the IndexError the old list lookup raised for an unknown uuid
    print(f"Tools Available: {len(tools_by_uuid.get(result['uuid'], []))} tools")
    print("-" * 40)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Tabulate the per-sample results; the same frame is written to CSV and plotted.
results_df = pd.DataFrame(eval_results)

# Persist the table as CSV under /content.
output_dir = "/content"
os.makedirs(output_dir, exist_ok=True)  # no-op when the directory already exists
results_csv_filename = f'when2call_detailed_results_{len(eval_results)}samples.csv'
results_csv_file_path = os.path.join(output_dir, results_csv_filename)
results_df.to_csv(results_csv_file_path, index=False)
print(f"\nDetailed results saved to: {os.path.abspath(results_csv_file_path)}")

# --- Visualization ---
print("\nVisualizing results...")

# 1) Headline metrics, read from `summary` only if an earlier cell created it.
if 'summary' not in locals():
    print("Summary variable not found. Cannot plot overall performance.")
else:
    summary_df = pd.DataFrame([summary])
    headline_scores = [summary_df['accuracy'].iloc[0], summary_df['macro_f1'].iloc[0]]
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(x=['Accuracy', 'Macro F1'], y=headline_scores, ax=ax)
    ax.set_title('Overall Model Performance on When2Call (Subset)')
    ax.set_ylabel('Score')
    ax.set_ylim(0, 1)
    plt.show()

# 2) Accuracy broken down by gold category (mean of the boolean is_correct column).
category_accuracy = (
    results_df
    .groupby('correct')['is_correct']
    .mean()
    .reset_index()
)
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(data=category_accuracy, x='correct', y='is_correct', ax=ax)
ax.set_title('Per-Category Accuracy on When2Call (Subset)')
ax.set_xlabel('Correct Category')
ax.set_ylabel('Accuracy')
ax.set_ylim(0, 1)
plt.show()

# 3) Tool hallucination: share of no-tool 'cannot_answer' samples predicted as tool calls.
if 'has_tools' in results_df.columns:
    no_tools_mask = (results_df['correct'] == 'cannot_answer') & (~results_df['has_tools'])
    cannot_answer_no_tools_df = results_df[no_tools_mask]
    no_tools_cannot_answer_count = len(cannot_answer_no_tools_df)
    hallucinated_tools_count = int((cannot_answer_no_tools_df['predicted'] == 'tool_call').sum())
    hallucination_rate_viz = hallucinated_tools_count / max(no_tools_cannot_answer_count, 1)

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(x=['Tool Hallucination Rate'], y=[hallucination_rate_viz], ax=ax)
    ax.set_title('Tool Hallucination Rate (No Tools)')
    ax.set_ylabel('Rate')
    ax.set_ylim(0, 1)
    plt.show()

In [None]:
# Quantize model weights to 8-bit (dynamic quantization for linear layers)
# import torch

# Uncomment to enable quantization (recommended for CPU inference)
# model = torch.quantization.quantize_dynamic(
#     model, {torch.nn.Linear}, dtype=torch.qint8
# )

# If using quantization, move the model to CPU before calling quantize_dynamic (dynamically quantized Linear layers run on CPU only):
# model.to('cpu')

# --- Remove model from memory and clear cache everywhere ---
# del model
# import gc
# gc.collect()
# if torch.cuda.is_available():
#     torch.cuda.empty_cache()
#     torch.cuda.ipc_collect()
# print("Model deleted and memory/cache cleared.")