# Llama 3.2 From Scratch

In [1]:
#Reqs:
#blobfile>=3.0.0
#huggingface_hub>=0.24.7
#ipywidgets>=8.1.2
#safetensors>=0.4.4
#sentencepiece>=0.1.99

In [2]:
!pip install tiktoken



In [3]:
#%load_ext cuml.accel

In [4]:
from importlib.metadata import version

pkgs = [
    "blobfile",         # to download pretrained weights
    "huggingface_hub",  # to download pretrained weights
    "tiktoken",         # to implement the tokenizer
    "torch",            # to implement the model
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

blobfile version: 3.0.0
huggingface_hub version: 0.34.4
tiktoken version: 0.11.0
torch version: 2.8.0+cu126


In [5]:
import torch
import torch.nn as nn
from safetensors.torch import load_file
import os
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe

&nbsp;
# 1. Architecture code

In [6]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)

In [7]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Frequency adjustments
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)

In [8]:
class GroupedQueryAttention(nn.Module):
    def __init__(
            self, d_in, d_out, num_heads,
            num_kv_groups,
            dtype=None
        ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)
        keys = self.W_key(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)

        # Reshape queries, keys, and values
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Transpose keys, values, and queries
        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)

        # Apply RoPE
        keys = apply_rope(keys, cos, sin)
        queries = apply_rope(queries, cos, sin)

        # Expand keys and values to match the number of heads
        # Shape: (b, num_heads, num_tokens, head_dim)
        keys = keys.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)

        # Compute scaled dot-product attention (self-attention) with a causal mask
        # Shape: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Compute attention scores
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        assert keys.shape[-1] == self.head_dim

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

In [9]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, mask, cos, sin):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x

In [10]:
class Llama3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusuable utilities
        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg


    def forward(self, in_idx):
        # Forward pass
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits

&nbsp;
# 2. Initialize model

In [11]:
# Llama 3.2 1B

LLAMA32_CONFIG = {
    "vocab_size": 128_256,           # Vocabulary size
    "context_length": 131_072,       # Context length that was used to train the model
    "emb_dim": 2048,                 # Embedding dimension
    "n_heads": 32,                   # Number of attention heads
    "n_layers": 16,                  # Number of layers
    "hidden_dim": 8192,              # Size of the intermediate dimension in FeedForward
    "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
    "rope_base": 500_000.0,          # The base in RoPE's "theta"
    "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage
    "rope_freq": {                   # RoPE frequency scaling
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}

# Llama 3.2 3B

# LLAMA32_CONFIG = {
#     "vocab_size": 128_256,           # Vocabulary size
#     "context_length": 131_072,       # Context length that was used to train the model
#     "emb_dim": 3072,                 # Embedding dimension
#     "n_heads": 24,                   # Number of attention heads
#     "n_layers": 28,                  # Number of layers
#     "hidden_dim": 8192,              # Size of the intermediate dimension in FeedForward
#     "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
#     "rope_base": 500_000.0,          # The base in RoPE's "theta"
#     "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage
#     "rope_freq": {                   # RoPE frequency scaling
#         "factor": 32.0,
#         "low_freq_factor": 1.0,
#         "high_freq_factor": 4.0,
#         "original_context_length": 8192,
#     }
# }

LLAMA_SIZE_STR = "1B" if LLAMA32_CONFIG["emb_dim"] == 2048 else "3B"

In [12]:
model = Llama3Model(LLAMA32_CONFIG)

- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:

In [13]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

# Account for weight tying
total_params_normalized = total_params - model.tok_emb.weight.numel()
print(f"\nTotal number of unique parameters: {total_params_normalized:,}")

Total number of parameters: 1,498,482,688

Total number of unique parameters: 1,235,814,400


In [14]:
def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Calculate total number of elements per parameter
        param_size = param.numel()
        total_params += param_size
        # Check if gradients are stored for this parameter
        if param.requires_grad:
            total_grads += param_size

    # Calculate buffer size (non-parameters that require memory)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (Number of elements) * (Size of each element in bytes)
    # We assume parameters and gradients are stored in the same type as input dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb

print(f"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")
print("If quantization is enabled above, actual memory usage will be lower.")

float32 (PyTorch default): 11.23 GB
bfloat16: 5.61 GB
If quantization is enabled above, actual memory usage will be lower.


In [15]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device);

&nbsp;
# 3. Load tokenizer

In [16]:
class Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
    def __init__(self, model_path):
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # hard-coded from Meta's tokenizer.json
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        self.special.update({f"<|reserved_{i}|>": 128002 + i
                             for i in range(256)
                             if 128002 + i not in self.special.values()})

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False):
        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
              + self.model.encode(text)
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

    def decode(self, ids):
        return self.model.decode(ids)


class ChatFormat:

    def __init__(self, tokenizer: Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        return (
            [self.tok.special["<|start_header_id|>"]]
            + self.tok.encode(role)
            + [self.tok.special["<|end_header_id|>"]]
            + self.tok.encode("\n\n")
        )

    def encode(self, user_message, system_message=None):
        sys_msg = system_message if system_message is not None else self.default_system

        ids = [self.tok.special["<|begin_of_text|>"]]

        # system
        ids += self._header("system")
        ids += self.tok.encode(sys_msg)
        ids += [self.tok.special["<|eot_id|>"]]

        # user
        ids += self._header("user")
        ids += self.tok.encode(user_message)
        ids += [self.tok.special["<|eot_id|>"]]

        # assistant header (no content yet)
        ids += self._header("assistant")

        return ids

In [17]:
#from huggingface_hub import login
#login()

In [18]:
from huggingface_hub import hf_hub_download

tokenizer_file_path = hf_hub_download(
    repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
    filename="original/tokenizer.model",
    local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
)

original/tokenizer.model:   0%|          | 0.00/2.18M [00:00<?, ?B/s]

In [19]:
tokenizer = Tokenizer(tokenizer_file_path)
chat_tokenizer = ChatFormat(tokenizer)

&nbsp;
# 4. Load pretrained weights

In [20]:
def assign(left, right, tensor_name="unknown"):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Load FeedForward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Load output layer weights
    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in params.keys():
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    else:
        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
        print("Model uses weight tying.")

In [21]:
if LLAMA_SIZE_STR == "1B":
    weights_file = hf_hub_download(
        repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        filename="model.safetensors",
        local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
    )
    combined_weights = load_file(weights_file)


else:
    combined_weights = {}
    for i in range(1, 3):
        weights_file = hf_hub_download(
            repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
            filename=f"model-0000{i}-of-00002.safetensors",
            local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
        )
        current_weights = load_file(weights_file)
        combined_weights.update(current_weights)


load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)
model.to(device)
del combined_weights  # free up memory

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

Model uses weight tying.


In [22]:
print("Weight tying:", torch.equal(model.tok_emb.weight, model.out_head.weight))

Weight tying: True


&nbsp;
# 5. Generate text

In [23]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: Get logits, and only focus on last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # New: Filter logits with top_k sampling
        if top_k is not None:
            # Keep only top_k values
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # New: Apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx

In [24]:
import time


PROMPT = "What is quantization?"

torch.manual_seed(123)

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Find the index of the first occurrence of "<|end_header_id|>"
    index = text.find(header_end)

    if index != -1:
        # Return the substring starting after "<|end_header_id|>"
        return text[index + len(header_end):].strip()  # Strip removes leading/trailing whitespace
    else:
        # If the token is not found, return the original text
        return text

print("\n\nOutput text:\n\n", clean_text(output_text))

Time: 21.00 sec
Max memory allocated: 3.10 GB


Output text:

 Quantization is a fundamental concept in physics and engineering that deals with the conversion of continuous signals into discrete values. In other words, it's the process of representing a continuous quantity, such as sound waves, light waves, or electrical signals, as a series of discrete values, known as quanta.

In a continuous signal, such as a sound wave, a light wave, or an electrical signal, there are an infinite number of possible values that can be represented. However, in practice, these values are not always discrete, and it's often necessary to represent them as a series of discrete values, such as 0, 1, 2, 3, etc.

Quantization is used in many areas, including:

1. **Signal


#LoRA fine tunning

In [25]:
!pip install unsloth

Collecting unsloth
  Downloading unsloth-2025.8.9-py3-none-any.whl.metadata (52 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/52.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.3/52.3 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting unsloth_zoo>=2025.8.8 (from unsloth)
  Downloading unsloth_zoo-2025.8.8-py3-none-any.whl.metadata (9.4 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.1 kB)
Collecting bitsandbytes (from unsloth)
  Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl.metadata (11 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.28-py3-none-any.whl.metadata (11 kB)
Collecting datasets<4.0.0,>=3.4.1 (from unsloth)
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting trl!=0.15.0,!=0.19.0,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,>=0.7.9 (from unsloth)
  Do

In [26]:
!pip install -q -U bitsandbytes
!pip install -q -U peft
!pip install -q -U trl
!pip install -q -U tensorboardX
!pip install -q wandb
!pip install -q -U torchvision
!pip install -q -U transformers

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m504.9/504.9 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.3/11.3 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25h

# Task
Adapt the provided Python code to preprocess the "Jofthomas/hermes-function-calling-thinking-V1" dataset for fine-tuning a Llama 3.2 model using its specific tokenizer and chat template, and then prepare the data for training.

## Install necessary libraries

### Subtask:
Install `transformers` and `datasets` to load and process the dataset.


**Reasoning**:
Install the necessary libraries for loading and processing the dataset.



In [27]:
%pip install transformers datasets



## Define a preprocessing function

### Subtask:
Adapt the `preprocess` function to work with your existing Llama 3.2 tokenizer and chat format.


**Reasoning**:
Define the preprocess function to encode the dataset examples using the chat tokenizer.



In [28]:
def preprocess(example):
    # Use the chat_tokenizer to encode the conversation
    # Assumes the dataset has a 'text' field containing the conversation in a format
    # that chat_tokenizer can handle (e.g., a simple prompt/response structure).
    # If the dataset has a different structure (e.g., 'prompt' and 'response' fields),
    # you would need to adapt this to format the text appropriately before encoding.
    # For this specific dataset 'Jofthomas/hermes-function-calling-thinking-V1',
    # the 'text' field contains the entire conversation in a structured format.
    encoded_text = chat_tokenizer.encode(example['text'])
    return {'input_ids': encoded_text}

## Load and preprocess the dataset

### Subtask:
Load the dataset using `load_dataset` from the `datasets` library and apply the adapted preprocessing function.


**Reasoning**:
Load the dataset, apply the preprocessing function, and remove the original text column.



In [33]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("Jofthomas/hermes-function-calling-thinking-V1")




**Reasoning**:
The previous attempt failed because the dataset does not have a 'text' column. I need to inspect the dataset structure to understand what columns are available and how to construct the input text for the tokenizer from those columns.



In [31]:
print(dataset['train'].column_names)
print(dataset['train'][0])

['conversations']
{'conversations': [{'content': "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags.You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.Here are the available tools:<tools> [{'type': 'function', 'function': {'name': 'get_stock_price', 'description': 'Get the current stock price of a company', 'parameters': {'type': 'object', 'properties': {'company': {'type': 'string', 'description': 'The name of the company'}}, 'required': ['company']}}}, {'type': 'function', 'function': {'name': 'get_movie_details', 'description': 'Get details about a movie', 'parameters': {'type': 'object', 'properties': {'title': {'type': 'string', 'description': 'The title of the movie'}}, 'required': ['title']}}}] </tools>Use the following pydantic model json schema for each tool call you will make: {'title': 'FunctionCall', 'type': 'object', 'properties': {'ar

**Reasoning**:
The dataset contains a 'conversations' column, which is a list of dictionaries, each representing a turn in the conversation with 'role' and 'content' keys. The previous `preprocess` function expected a 'text' column. I need to modify the `preprocess` function to iterate through the 'conversations' list, format it into a single string that the `chat_tokenizer` can understand, and then encode it.



In [34]:
def preprocess(example):
    conversation = example['conversations']

    # Manually format the conversation turns into a single string, including special tokens
    formatted_text = ""
    for i, turn in enumerate(conversation):
        role = turn['role']
        content = turn['content']

        # Add the beginning of text token at the very start of the first turn
        if i == 0:
             formatted_text += "<|begin_of_text|>"

        # Add header and content based on the role
        if role in ['system', 'human']:
             formatted_text += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
        elif role == 'model':
             # Map 'model' role to 'assistant' for the tokenizer
             formatted_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
        elif role == 'tool':
             # Include tool responses as content, followed by an EOT token
             # We are assuming the tokenizer understands these tokens within the flow.
             formatted_text += f"{content}<|eot_id|>"

    # Encode the formatted text using the base tokenizer, allowing special tokens
    # We need to explicitly allow the special tokens that we manually added to the formatted_text
    allowed_special_tokens = set(tokenizer.special.keys())
    encoded_text = tokenizer.model.encode(formatted_text, allowed_special=allowed_special_tokens)

    return {'input_ids': encoded_text}

# Load the dataset
dataset = load_dataset("Jofthomas/hermes-function-calling-thinking-V1")

# Apply the preprocess function
processed_dataset = dataset.map(preprocess, batched=False) # Set batched=False as the preprocess function processes one example at a time

# Remove the original conversations column
processed_dataset = processed_dataset.remove_columns('conversations')

# Print some information about the processed dataset
print(processed_dataset)
print(processed_dataset['train'][0])

Map:   0%|          | 0/3570 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 3570
    })
})
{'input_ids': [128000, 128006, 9125, 128007, 271, 2675, 527, 264, 734, 8260, 15592, 1646, 13, 1472, 527, 3984, 449, 734, 33728, 2949, 366, 16297, 1500, 16297, 29, 12138, 9681, 39537, 1253, 1650, 832, 477, 810, 5865, 311, 7945, 449, 279, 1217, 3319, 13, 4418, 956, 1304, 32946, 922, 1148, 2819, 311, 20206, 1139, 5865, 91173, 527, 279, 2561, 7526, 32352, 16297, 29, 62208, 1337, 1232, 364, 1723, 518, 364, 1723, 1232, 5473, 609, 1232, 364, 456, 31641, 9217, 518, 364, 4789, 1232, 364, 1991, 279, 1510, 5708, 3430, 315, 264, 2883, 518, 364, 14105, 1232, 5473, 1337, 1232, 364, 1735, 518, 364, 13495, 1232, 5473, 10348, 1232, 5473, 1337, 1232, 364, 928, 518, 364, 4789, 1232, 364, 791, 836, 315, 279, 2883, 8439, 2186, 364, 6413, 1232, 2570, 10348, 663, 3500, 2186, 5473, 1337, 1232, 364, 1723, 518, 364, 1723, 1232, 5473, 609, 1232, 364, 456, 51829, 13563, 518, 364, 4789, 1232, 364, 1991, 3649, 922, 

## Prepare data for training

### Subtask:
Further process the preprocessed data into a format suitable for training your Llama 3.2 model.


**Reasoning**:
Determine the maximum sequence length, pad or truncate the sequences, create attention masks, and format the data for training.



In [35]:
import torch
from torch.utils.data import Dataset

# 1. Determine the maximum sequence length
# Consider the model's context length and potentially available memory.
# For simplicity, we will use a value that is a power of 2 and less than the model's context length.
# A smaller context length can be used for faster training and less memory usage.
max_sequence_length = 1024 # Example: Using a smaller sequence length for demonstration

# 2. Pad or truncate the sequences
def pad_and_truncate(example, max_length, pad_token_id):
    input_ids = example['input_ids']
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
    elif len(input_ids) < max_length:
        padding = [pad_token_id] * (max_length - len(input_ids))
        input_ids = input_ids + padding
    return {'input_ids': input_ids}

# Use the <|eot_id|> as the padding token
pad_token_id = tokenizer.special["<|eot_id|>"]

# Apply padding and truncation
processed_dataset = processed_dataset.map(
    lambda example: pad_and_truncate(example, max_sequence_length, pad_token_id),
    batched=False
)

# 3. Create attention masks
def create_attention_mask(example, pad_token_id):
    input_ids = example['input_ids']
    # Mask is 1 for non-padding tokens and 0 for padding tokens
    attention_mask = [1 if token_id != pad_token_id else 0 for token_id in input_ids]
    return {'attention_mask': attention_mask}

# Apply attention mask creation
processed_dataset = processed_dataset.map(
    lambda example: create_attention_mask(example, pad_token_id),
    batched=False
)

# 4. Format the processed dataset into a PyTorch tensor dataset
class LlamaDataset(Dataset):
    def __init__(self, dataset):
        self.input_ids = [torch.tensor(x, dtype=torch.long) for x in dataset['input_ids']]
        self.attention_mask = [torch.tensor(x, dtype=torch.long) for x in dataset['attention_mask']]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.input_ids[idx].clone() # For language modeling, labels are the same as input_ids
        }

train_dataset = LlamaDataset(processed_dataset['train'])

# Check if a 'test' split exists before creating eval_dataset
eval_dataset = None
if 'test' in processed_dataset:
    eval_dataset = LlamaDataset(processed_dataset['test'])

print("Processed and formatted datasets:")
print(train_dataset)
if eval_dataset:
    print(eval_dataset)
print("\nSample from training dataset:")
print(train_dataset[0])

Map:   0%|          | 0/3570 [00:00<?, ? examples/s]

Map:   0%|          | 0/3570 [00:00<?, ? examples/s]

Processed and formatted datasets:
<__main__.LlamaDataset object at 0x7fd1ecdc2330>

Sample from training dataset:
{'input_ids': tensor([128000, 128006,   9125,  ..., 128009, 128009, 128009]), 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0]), 'labels': tensor([128000, 128006,   9125,  ..., 128009, 128009, 128009])}


## Define training loop

### Subtask:
Implement a training loop to fine-tune the Llama 3.2 model using the prepared dataset.


In [39]:
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import gc
from datasets import load_dataset # Need to re-import load_dataset

# Clear memory before retrying
if 'train_dataloader' in locals():
    del train_dataloader
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# 1. Define the training parameters (Reduced sequence length)
learning_rate = 1e-4
num_epochs = 1
batch_size = 1 # Keeping batch size at 1 for now

# Reduce the maximum sequence length
max_sequence_length = 512 # Reduced sequence length

# Load the dataset again
dataset = load_dataset("Jofthomas/hermes-function-calling-thinking-V1")

# Define the preprocess function again (from previous successful step)
def preprocess(example):
    conversation = example['conversations']
    formatted_text = ""
    for i, turn in enumerate(conversation):
        role = turn['role']
        content = turn['content']
        if i == 0:
             formatted_text += "<|begin_of_text|>"
        if role in ['system', 'human']:
             formatted_text += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
        elif role == 'model':
             formatted_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
        elif role == 'tool':
             formatted_text += f"{content}<|eot_id|>"

    allowed_special_tokens = set(tokenizer.special.keys())
    encoded_text = tokenizer.model.encode(formatted_text, allowed_special=allowed_special_tokens)
    return {'input_ids': encoded_text}

# Apply the preprocess function
processed_dataset = dataset.map(preprocess, batched=False)
processed_dataset = processed_dataset.remove_columns('conversations')


# Define the padding and truncation function again (from previous successful step)
def pad_and_truncate(example, max_length, pad_token_id):
    input_ids = example['input_ids']
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
    elif len(input_ids) < max_length:
        padding = [pad_token_id] * (max_length - len(input_ids))
        input_ids = input_ids + padding
    return {'input_ids': input_ids}

# Use the <|eot_id|> as the padding token
pad_token_id = tokenizer.special["<|eot_id|>"]

# Apply padding and truncation to the processed_dataset
processed_dataset = processed_dataset.map(
    lambda example: pad_and_truncate(example, max_sequence_length, pad_token_id),
    batched=False
)

# Define the attention mask creation function again (from previous successful step)
def create_attention_mask(example, pad_token_id):
    input_ids = example['input_ids']
    attention_mask = [1 if token_id != pad_token_id else 0 for token_id in input_ids]
    return {'attention_mask': attention_mask}

# Re-create attention masks with the new sequence length and apply to processed_dataset
processed_dataset = processed_dataset.map(
    lambda example: create_attention_mask(example, pad_token_id),
    batched=False
)

# Define the LlamaDataset class again (from previous successful step)
class LlamaDataset(Dataset):
    def __init__(self, dataset):
        self.input_ids = [torch.tensor(x, dtype=torch.long) for x in dataset['input_ids']]
        self.attention_mask = [torch.tensor(x, dtype=torch.long) for x in dataset['attention_mask']]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.input_ids[idx].clone()
        }

# Re-create the training dataset with the updated processed data
train_dataset = LlamaDataset(processed_dataset['train'])


# 2. Create a PyTorch DataLoader for the training dataset
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define the optimizer
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Define the loss function (Cross-Entropy Loss)
loss_fn = nn.CrossEntropyLoss()

# 3. Set the model to training mode
model.train()

print(f"Starting training for {num_epochs} epochs with batch size {batch_size} and sequence length {max_sequence_length}...")

# 4. Iterate through the specified number of epochs
for epoch in range(num_epochs):
    total_loss = 0
    # 5. Within each epoch, iterate through the DataLoader to get batches of data
    for i, batch in enumerate(train_dataloader):
        # 6. For each batch:
        # a. Move the input IDs, attention masks, and labels to the appropriate device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # b. Perform a forward pass through the model
        logits = model(input_ids)

        # c. Calculate the loss
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        # d. Perform backpropagation
        loss.backward()

        # e. Update the model's parameters
        optimizer.step()

        # f. Zero the gradients
        optimizer.zero_grad()

        total_loss += loss.item()

        # 7. (Optional) Implement periodic logging of the training loss
        if (i + 1) % 100 == 0: # Log every 100 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] finished, Average Loss: {avg_loss:.4f}")

    # 8. (Optional) Implement saving the model checkpoints periodically
    # You can add code here to save the model state dictionary
    # Example: torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")

print("Training finished.")

Map:   0%|          | 0/3570 [00:00<?, ? examples/s]

Map:   0%|          | 0/3570 [00:00<?, ? examples/s]

Map:   0%|          | 0/3570 [00:00<?, ? examples/s]

Starting training for 1 epochs with batch size 1 and sequence length 512...


OutOfMemoryError: CUDA out of memory. Tried to allocate 502.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 430.12 MiB is free. Process 11793 has 14.32 GiB memory in use. Of the allocated memory 13.68 GiB is allocated by PyTorch, and 523.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Evaluate the fine-tuned model

### Subtask:
Evaluate the performance of the fine-tuned model on a test set.


**Reasoning**:
Evaluate the fine-tuned model on the evaluation dataset if it exists.



In [None]:
# 1. Set the model to evaluation mode
model.eval()

# 2. If an evaluation dataset (eval_dataset) was created in the "Prepare data for training" step, create a PyTorch DataLoader for it.
eval_dataloader = None
if eval_dataset:
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False) # No need to shuffle evaluation data

# 3. Iterate through the evaluation DataLoader (if it exists).
if eval_dataloader:
    print("\nStarting evaluation...")
    total_eval_loss = 0
    with torch.no_grad(): # 4b. Perform a forward pass through the model without calculating gradients
        for i, batch in enumerate(eval_dataloader):
            # 4a. Move the input IDs, attention masks, and labels to the appropriate device.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # 4b. (Continued)
            logits = model(input_ids)

            # 4c. Calculate the loss using the same loss function (loss_fn) used for training.
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_eval_loss += loss.item()

    # 5. Calculate and print the average loss over the entire evaluation dataset.
    avg_eval_loss = total_eval_loss / len(eval_dataloader)
    print(f"Evaluation finished, Average Loss: {avg_eval_loss:.4f}")
else:
    # 6. If no evaluation dataset exists, print a message indicating that evaluation cannot be performed.
    print("\nNo evaluation dataset found, skipping evaluation.")


In [None]:
import torch
import torch.nn as nn
from safetensors.torch import load_file
import os
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from huggingface_hub import hf_hub_download
from datasets import load_dataset # Need to re-import load_dataset
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import gc

# Re-initialize the model with the defined configuration
model = Llama3Model(LLAMA32_CONFIG)

# Determine the device again
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the weights file again based on the model size
if LLAMA_SIZE_STR == "1B":
    weights_file = hf_hub_download(
        repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        filename="model.safetensors",
        local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
    )
    combined_weights = load_file(weights_file)
else:
    combined_weights = {}
    for i in range(1, 3):
        weights_file = hf_hub_download(
            repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
            filename=f"model-0000{i}-of-00002.safetensors",
            local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
        )
        current_weights = load_file(weights_file)
        combined_weights.update(current_weights)

# Load the weights into the re-initialized model
load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)
model.to(device)
del combined_weights # free up memory

# Re-load and preprocess the dataset to recreate eval_dataset
# Define the preprocess function again (from previous successful step)
def preprocess(example):
    conversation = example['conversations']
    formatted_text = ""
    for i, turn in enumerate(conversation):
        role = turn['role']
        content = turn['content']
        if i == 0:
             formatted_text += "<|begin_of_text|>"
        if role in ['system', 'human']:
             formatted_text += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
        elif role == 'model':
             formatted_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
        elif role == 'tool':
             formatted_text += f"{content}<|eot_id|>"

    allowed_special_tokens = set(tokenizer.special.keys())
    encoded_text = tokenizer.model.encode(formatted_text, allowed_special=allowed_special_tokens)
    return {'input_ids': encoded_text}

# Load the dataset
dataset = load_dataset("Jofthomas/hermes-function-calling-thinking-V1")

# Apply the preprocess function
processed_dataset = dataset.map(preprocess, batched=False)
processed_dataset = processed_dataset.remove_columns('conversations')

# Define the padding and truncation function again (from previous successful step)
def pad_and_truncate(example, max_length, pad_token_id):
    input_ids = example['input_ids']
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
    elif len(input_ids) < max_length:
        padding = [pad_token_id] * (max_length - len(input_ids))
        input_ids = input_ids + padding
    return {'input_ids': input_ids}

# Use the <|eot_id|> as the padding token
pad_token_id = tokenizer.special["<|eot_id|>"]

# Apply padding and truncation to the processed_dataset
max_sequence_length = 512 # Re-set the sequence length
processed_dataset = processed_dataset.map(
    lambda example: pad_and_truncate(example, max_sequence_length, pad_token_id),
    batched=False
)

# Define the attention mask creation function again (from previous successful step)
def create_attention_mask(example, pad_token_id):
    input_ids = example['input_ids']
    attention_mask = [1 if token_id != pad_token_id else 0 for token_id in input_ids]
    return {'attention_mask': attention_mask}

# Re-create attention masks with the new sequence length and apply to processed_dataset
processed_dataset = processed_dataset.map(
    lambda example: create_attention_mask(example, pad_token_id),
    batched=False
)

# Define the LlamaDataset class again (from previous successful step)
class LlamaDataset(Dataset):
    def __init__(self, dataset):
        self.input_ids = [torch.tensor(x, dtype=torch.long) for x in dataset['input_ids']]
        self.attention_mask = [torch.tensor(x, dtype=torch.long) for x in dataset['attention_mask']]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.input_ids[idx].clone()
        }

# Re-create the training dataset with the updated processed data
train_dataset = LlamaDataset(processed_dataset['train'])

# Check if a 'test' split exists before creating eval_dataset
eval_dataset = None
if 'test' in processed_dataset:
    eval_dataset = LlamaDataset(processed_dataset['test'])

# Re-define the loss function
loss_fn = nn.CrossEntropyLoss()

# Now, proceed with evaluation as planned
# 1. Set the model to evaluation mode
model.eval()

# 2. If an evaluation dataset (eval_dataset) was created in the "Prepare data for training" step, create a PyTorch DataLoader for it.
eval_dataloader = None
batch_size = 1 # Re-set batch size for evaluation
if eval_dataset:
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False) # No need to shuffle evaluation data

# 3. Iterate through the evaluation DataLoader (if it exists).
if eval_dataloader:
    print("\nStarting evaluation...")
    total_eval_loss = 0
    with torch.no_grad(): # 4b. Perform a forward pass through the model without calculating gradients
        for i, batch in enumerate(eval_dataloader):
            # 4a. Move the input IDs, attention masks, and labels to the appropriate device.
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # 4b. (Continued)
            logits = model(input_ids)

            # 4c. Calculate the loss using the same loss function (loss_fn) used for training.
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_eval_loss += loss.item()

    # 5. Calculate and print the average loss over the entire evaluation dataset.
    avg_eval_loss = total_eval_loss / len(eval_dataloader)
    print(f"Evaluation finished, Average Loss: {avg_eval_loss:.4f}")
else:
    # 6. If no evaluation dataset exists, print a message indicating that evaluation cannot be performed.
    print("\nNo evaluation dataset found, skipping evaluation.")
