# GPT to Llama Implementation

This notebook implements the conversion from a GPT architecture to Llama 3.2, using Sebastian Raschka's template and his book "Build a Large Language Model (From Scratch)". Building on the standalone Llama 3.2 notebook, it covers the key architectural differences between GPT and Llama models: Grouped Query Attention (GQA) instead of standard multi-head attention, RoPE (Rotary Position Embedding) instead of absolute positional embeddings, SwiGLU activations in the feed-forward networks, and RMSNorm instead of LayerNorm for normalization.

The implementation serves as a practical guide to how modern language models like Llama 3.2 evolved from the original GPT architecture, with hands-on code for each component change. It is aimed at researchers and developers who want to understand the technical improvements Llama models introduced over earlier GPT models.

## Code Overview

### Overview
Converting GPT architecture to Llama 3.2 with modern improvements: GQA, RoPE, SwiGLU, RMSNorm

### Architecture Components

#### 1. FeedForward (SwiGLU)
- **Purpose**: Gated feed-forward network with SiLU activation
- **Key**: `silu(fc1(x)) * fc2(x)` → fc3
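- **Improvement**: Replaces GPT-2's two-layer GELU feed-forward network; the gated SiLU path is generally reported to improve model quality at a comparable parameter budget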

#### 2. RoPE Functions
- **compute_rope_params()**: Precompute sin/cos for position encoding
- **apply_rope()**: Apply rotary transformation to Q/K
- **Improvement**: Encodes relative position by rotating Q/K, instead of GPT's learned absolute position embeddings

#### 3. GroupedQueryAttention
- **Purpose**: Memory-efficient attention mechanism
- **Key**: Multiple query heads share K/V pairs
- **Memory**: 32 query heads, 8 KV groups → 4:1 query-to-KV ratio (K/V projections and KV cache are 4x smaller than with standard multi-head attention)

#### 4. TransformerBlock
- **Components**: GQA + SwiGLU + RMSNorm + residuals
- **Flow**: norm → attention → residual → norm → ffn → residual

#### 5. Llama3Model
- **Complete**: Embedding → 16 blocks → norm → output
- **Buffers**: Precomputed RoPE values registered

### Configuration (1B Model)
```python
LLAMA32_CONFIG = {
    "vocab_size": 128_256,      # Vocabulary
    "context_length": 131_072,  # Max sequence
    "emb_dim": 2048,           # Embedding dim
    "n_heads": 32,             # Attention heads
    "n_layers": 16,            # Transformer layers
    "hidden_dim": 8192,        # FFN intermediate
    "n_kv_groups": 8,          # GQA groups
    "rope_base": 500_000.0,    # RoPE theta
    "dtype": torch.bfloat16,   # Memory efficiency
    "rope_freq": {...}         # Frequency scaling
}
```

### Packages needed for Llama implementation

In [11]:
# Install packages individually for better error handling
packages = [
    "blobfile",
    "huggingface_hub",
    "ipywidgets",
    "safetensors",
    "sentencepiece",
    "tiktoken"
]

for package in packages:
    %pip install {package}

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [12]:
from importlib.metadata import version

pkgs = [
    "blobfile",         # required to download pretrained weights
    "huggingface_hub",  # required to download pretrained weights
    "tiktoken",         # required to implement the tokenizer
    "torch",            # required to implement models
]

for pkg in pkgs:
    print(f"{pkg} version: {version(pkg)}")

blobfile version: 3.1.0
huggingface_hub version: 0.27.1
tiktoken version: 0.11.0
torch version: 2.8.0


### Architecture of Llama 3.2 including SwiGLU feed-forward network

In [1]:
# SwiGLU feed-forward network used in Llama 3.2

import torch
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)
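To make the difference from GPT concrete, the sketch below compares parameter counts of a GPT-2-style FFN (two linear layers with GELU and biases) and a SwiGLU FFN built like the `FeedForward` class above. The toy sizes and the standalone `SwiGLU` class are illustrative assumptions, not part of the original notebook.

```python
import torch
import torch.nn as nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

# Toy sizes for a quick check (the 1B config uses emb_dim=2048, hidden_dim=8192)
emb_dim, hidden_dim = 64, 256

# GPT-2-style FFN: expand, GELU, project back (with biases, as in GPT-2)
gpt2_ffn = nn.Sequential(
    nn.Linear(emb_dim, hidden_dim),
    nn.GELU(approximate="tanh"),
    nn.Linear(hidden_dim, emb_dim),
)

class SwiGLU(nn.Module):
    """Hypothetical standalone copy of the FeedForward logic above, without cfg/dtype."""
    def __init__(self, emb_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(emb_dim, hidden_dim, bias=False)  # gate projection
        self.fc2 = nn.Linear(emb_dim, hidden_dim, bias=False)  # up projection
        self.fc3 = nn.Linear(hidden_dim, emb_dim, bias=False)  # down projection

    def forward(self, x):
        # The SiLU-activated gate multiplies the linear "up" path element-wise
        return self.fc3(nn.functional.silu(self.fc1(x)) * self.fc2(x))

swiglu_ffn = SwiGLU(emb_dim, hidden_dim)
x = torch.randn(2, 5, emb_dim)  # (batch, tokens, emb_dim)

print(gpt2_ffn(x).shape, swiglu_ffn(x).shape)
print("GPT-2 FFN params:", count_params(gpt2_ffn))
print("SwiGLU FFN params:", count_params(swiglu_ffn))

# Per-layer SwiGLU weights at the 1B sizes: 3 * 2048 * 8192
print("1B SwiGLU params per layer:", 3 * 2048 * 8192)
```

At the same `hidden_dim`, the extra gate matrix gives SwiGLU three weight matrices instead of two; in the 1B config that is 50,331,648 weights per layer.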

In [None]:
# RoPE (Rotary Position Embedding): encodes token positions so that attention scores depend on relative rather than absolute positions.

def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Frequency adjustments
    if freq_config is not None:
        low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
        high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]

        wavelen = 2 * torch.pi / inv_freq

        inv_freq_llama = torch.where(
            wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
        )

        smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
            freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
        )

        smoothed_inv_freq = (
            (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
        )

        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        inv_freq = inv_freq_llama

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)
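A quick way to see why RoPE is called relative: after rotation, the dot product between a query at position m and a key at position n depends only on m - n. The sketch below checks this with a minimal re-implementation of the same half-split rotation used in `apply_rope` (no Llama frequency scaling, float64 for a tight comparison); the helper names `rope_tables` and `rotate_at` are hypothetical.

```python
import torch

def rope_tables(head_dim, context_length, theta_base=10_000.0):
    # Same construction as compute_rope_params without freq_config
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
    angles = torch.arange(context_length, dtype=torch.float64)[:, None] * inv_freq[None, :]
    angles = torch.cat([angles, angles], dim=1)  # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

def rotate_at(x, pos, cos, sin):
    # Rotate a single head vector x of shape (head_dim,) as if it sat at position `pos`,
    # pairing element i with element i + head_dim // 2 like apply_rope does
    half = x.shape[-1] // 2
    rotated = torch.cat((-x[half:], x[:half]))
    return x * cos[pos] + rotated * sin[pos]

torch.manual_seed(0)
head_dim = 16
cos, sin = rope_tables(head_dim, context_length=64)
q = torch.randn(head_dim, dtype=torch.float64)
k = torch.randn(head_dim, dtype=torch.float64)

# Same offset (m - n = 3) at two different absolute positions
score_near = rotate_at(q, 10, cos, sin) @ rotate_at(k, 7, cos, sin)
score_far = rotate_at(q, 40, cos, sin) @ rotate_at(k, 37, cos, sin)
print(score_near.item(), score_far.item())

# A rotation preserves vector length
print(rotate_at(q, 25, cos, sin).norm().item(), q.norm().item())
```

Each pair `(x[i], x[i + head_dim // 2])` is rotated by the angle `pos * inv_freq[i]`, so shifting both positions by the same amount leaves their dot product unchanged.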

In [None]:
# Grouped Query Attention (GQA): a memory-efficient alternative to standard multi-head attention in which several query heads share the same key/value heads. This shrinks the K/V projections and the KV cache while largely maintaining model quality.

class GroupedQueryAttention(nn.Module):
    def __init__(
            self, d_in, d_out, num_heads,
            num_kv_groups,
            dtype=None
        ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        b, num_tokens, d_in = x.shape

        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)
        keys = self.W_key(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)

        # Reshape queries, keys, and values
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Transpose keys, values, and queries
        keys = keys.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
        values = values.transpose(1, 2)  # Shape: (b, num_kv_groups, num_tokens, head_dim)
        queries = queries.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)

        # Apply RoPE
        keys = apply_rope(keys, cos, sin)
        queries = apply_rope(queries, cos, sin)

        # Expand keys and values to match the number of heads
        # Shape: (b, num_heads, num_tokens, head_dim)
        keys = keys.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)
        # For example, before repeat_interleave along dim=1 (KV groups):
        #   [K1, K2]
        # After repeat_interleave (each KV group is repeated group_size times):
        #   [K1, K1, K2, K2]
        # If we used regular repeat instead of repeat_interleave, we'd get:
        #   [K1, K2, K1, K2]

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        # Shape: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Apply the causal mask: future positions get -inf so softmax assigns them zero weight
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        assert keys.shape[-1] == self.head_dim

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
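The `repeat_interleave` comment above can be checked directly. The toy sketch below (4 query heads, 2 KV groups; sizes chosen only for illustration) tags each KV group with a constant so the head-to-group mapping is visible, and contrasts it with `repeat`. It also computes the per-token K/V cache size at the 1B config sizes, assuming one K and one V vector per KV group per layer.

```python
import torch

b, num_tokens, head_dim = 1, 3, 2
num_heads, num_kv_groups = 4, 2
group_size = num_heads // num_kv_groups

# Each KV group holds a constant equal to its index: shape (b, num_kv_groups, num_tokens, head_dim)
keys = torch.arange(num_kv_groups, dtype=torch.float32).view(1, num_kv_groups, 1, 1)
keys = keys.expand(b, num_kv_groups, num_tokens, head_dim)

interleaved = keys.repeat_interleave(group_size, dim=1)  # what GroupedQueryAttention uses
tiled = keys.repeat(1, group_size, 1, 1)                 # the alternative the comment warns about

print(interleaved.shape)                 # torch.Size([1, 4, 3, 2])
print(interleaved[0, :, 0, 0].tolist())  # [0.0, 0.0, 1.0, 1.0]
print(tiled[0, :, 0, 0].tolist())        # [0.0, 1.0, 0.0, 1.0]

# With repeat_interleave, query head h reads KV group h // group_size.
# Per-token, per-layer K/V elements at the 1B sizes (head_dim = 2048 // 32 = 64):
mha_kv = 2 * 32 * 64  # one K and one V vector per query head
gqa_kv = 2 * 8 * 64   # one K and one V vector per KV group
print(mha_kv // gqa_kv)  # 4
```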

In [7]:
# A complete Transformer block as used in Llama models, combining Grouped Query Attention, a SwiGLU feed-forward network, and RMSNorm. It represents one layer of the Llama transformer stack: input passes through attention and feed-forward sub-layers, each wrapped in a pre-norm residual connection.

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, mask, cos, sin):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back
        return x

In [None]:
# A top-level class that combines all previously defined components (e.g., GroupedQueryAttention, SwiGLU FeedForward, RoPE, RMSNorm) into a complete transformer model, from token IDs to vocabulary logits. It incorporates the key architectural changes that distinguish Llama from earlier GPT models.

class Llama3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities: precomputed RoPE cos/sin tables shared by all blocks
        cos, sin = compute_rope_params(
            head_dim=cfg["emb_dim"] // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg


    def forward(self, in_idx):
        # Forward pass
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits
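`Llama3Model.forward` builds its causal mask with `torch.triu(..., diagonal=1)`, marking every position above the diagonal (future tokens) as `True`. The toy sketch below, which assumes all-zero attention scores for readability, shows that after `masked_fill` with `-inf` and softmax each token attends only to itself and earlier tokens.

```python
import torch

num_tokens = 4
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)

scores = torch.zeros(num_tokens, num_tokens)  # toy scores: every visible token scores the same
weights = torch.softmax(scores.masked_fill(mask, -torch.inf), dim=-1)

print(mask.int())
print(weights)  # row i spreads weight 1 / (i + 1) over positions 0..i
```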

## Initialization of the Llama 3.2 model

In [9]:
#Llama 3.2 1B

LLAMA32_CONFIG = {
    "vocab_size": 128_256,          # Vocabulary size
    "context_length": 131_072,      # Maximum supported context length
    "emb_dim": 2048,                # Embedding dimension
    "n_heads": 32,                  # Number of attention heads
    "n_layers": 16,                 # Number of layers
    "hidden_dim": 8192,             # Size of the intermediate dimension in FeedForward
    "n_kv_groups": 8,               # Key-Value groups for GQA
    "rope_base": 500_000.0,         # The base in RoPE's "theta"
    "dtype": torch.bfloat16,        # Lower-precision dtype to reduce memory usage
    "rope_freq": {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}

# Llama 3.2 3B config (uncomment to use it instead of the 1B config above)

# LLAMA32_CONFIG = {
#     "vocab_size": 128_256,           # Vocabulary size
#     "context_length": 131_072,       # Context length that was used to train the model
#     "emb_dim": 3072,                 # Embedding dimension
#     "n_heads": 24,                   # Number of attention heads
#     "n_layers": 28,                  # Number of layers
#     "hidden_dim": 8192,              # Size of the intermediate dimension in FeedForward
#     "n_kv_groups": 8,                # Key-Value groups for grouped-query attention
#     "rope_base": 500_000.0,          # The base in RoPE's "theta"
#     "dtype": torch.bfloat16,         # Lower-precision dtype to reduce memory usage
#     "rope_freq": {                   # RoPE frequency scaling
#         "factor": 32.0,
#         "low_freq_factor": 1.0,
#         "high_freq_factor": 4.0,
#         "original_context_length": 8192,
#     }
# }

LLAMA_SIZE_STR = "1B" if LLAMA32_CONFIG["emb_dim"] == 2048 else "3B"

In [8]:
model = Llama3Model(LLAMA32_CONFIG)

In [14]:
# Total parameter count: tok_emb and out_head are separate matrices here, so both are counted
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params}")

# Account for weight tying: the Llama 3.2 1B checkpoint reuses the token-embedding matrix
# as the output head, so subtract one copy to get the number of unique parameters
total_params_normalized = total_params - model.tok_emb.weight.numel()

# Unique parameters in the model
print(f"\nTotal number of unique parameters: {total_params_normalized}")


Total number of parameters: 1498482688

Total number of unique parameters: 1235814400
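Both printed counts can be reproduced from the config alone. The sketch below is a hypothetical helper that assumes the layer layout of the classes above (no biases, two RMSNorm weights per block, one final norm) and tallies each weight matrix.

```python
def llama_param_count(cfg, tied=False):
    """Hypothetical helper: count the weights of the Llama3Model layout defined above."""
    d = cfg["emb_dim"]
    head_dim = d // cfg["n_heads"]
    kv_dim = cfg["n_kv_groups"] * head_dim

    attn = 2 * d * d + 2 * d * kv_dim  # W_query + out_proj, W_key + W_value
    ffn = 3 * d * cfg["hidden_dim"]    # fc1 (gate), fc2 (up), fc3 (down)
    norms = 2 * d                      # norm1 and norm2 weights
    per_layer = attn + ffn + norms

    embedding = cfg["vocab_size"] * d  # tok_emb
    out_head = 0 if tied else cfg["vocab_size"] * d
    final_norm = d
    return embedding + cfg["n_layers"] * per_layer + final_norm + out_head

cfg_1b = {"vocab_size": 128_256, "emb_dim": 2048, "n_heads": 32,
          "n_layers": 16, "hidden_dim": 8192, "n_kv_groups": 8}

print(llama_param_count(cfg_1b))             # 1498482688
print(llama_param_count(cfg_1b, tied=True))  # 1235814400
```

The breakdown also shows where the weights sit: each layer's SwiGLU block (50,331,648) is about five times its attention block (10,485,760), and the 262,668,288-entry embedding matrix is what weight tying saves.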


In [None]:
# Estimates the memory footprint of the model from three components: parameters,
# parameter gradients, and buffers. Activations and optimizer state are not included,
# and every element is assumed to have the size of `input_dtype`.

def model_memory_size(model, input_dtype=torch.float32):
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        # Number of elements in this parameter
        param_size = param.numel()
        total_params += param_size
        # Gradients take as much space as the parameter when they are tracked
        if param.requires_grad:
            total_grads += param_size

    # Buffers: non-parameter tensors that still occupy memory (here, the RoPE cos/sin tables)
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Size in bytes = (number of elements) * (size of each element)
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes (GiB)
    total_memory_gb = total_memory_bytes / (1024**3)

    return total_memory_gb

print(f"float32: {model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")




float32: 11.23 GB
bfloat16: 5.61 GB
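The same figures follow from simple arithmetic: the parameter count, an equal number of gradient elements, and the two RoPE buffers of shape `(context_length, head_dim)` = `(131072, 64)`, multiplied by the bytes per element and divided by 1024³. The sketch below takes the parameter count from the output above; the last line is an added, hedged estimate for inference without gradients, where the bf16 weights alone dominate.

```python
n_params = 1_498_482_688      # "Total number of parameters" printed above
n_grads = n_params            # every parameter has requires_grad=True
n_buffers = 2 * 131_072 * 64  # cos and sin buffers, each (context_length, head_dim)

def to_gib(n_elements, bytes_per_element):
    return n_elements * bytes_per_element / 1024**3

total_elements = n_params + n_grads + n_buffers
print(f"float32: {to_gib(total_elements, 4):.2f} GB")
print(f"bfloat16: {to_gib(total_elements, 2):.2f} GB")

# Inference only (no gradients), bf16 weights
print(f"bf16 weights only: {to_gib(n_params, 2):.2f} GB")
```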


In [18]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device);

## Load Tokenizer

In [19]:
import os
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe



class Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
    def __init__(self, model_path):
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # hard-coded from Meta's tokenizer.json
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
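        # Llama 3 reserves IDs 128000-128255 for special tokens; fill the unused ones
        # with placeholder names, stopping at the last valid ID (128255)
        self.special.update({f"<|reserved_{i}|>": 128002 + i
                             for i in range(254)
                             if 128002 + i not in self.special.values()})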

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False):
        ids = ([self.special["<|begin_of_text|>"]] if bos else []) \
              + self.model.encode(text)
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

    def decode(self, ids):
        return self.model.decode(ids)


class ChatFormat:

    def __init__(self, tokenizer: Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        return (
            [self.tok.special["<|start_header_id|>"]]
            + self.tok.encode(role)
            + [self.tok.special["<|end_header_id|>"]]
            + self.tok.encode("\n\n")
        )

    def encode(self, user_message, system_message=None):
        sys_msg = system_message if system_message is not None else self.default_system

        ids = [self.tok.special["<|begin_of_text|>"]]

        # system
        ids += self._header("system")
        ids += self.tok.encode(sys_msg)
        ids += [self.tok.special["<|eot_id|>"]]

        # user
        ids += self._header("user")
        ids += self.tok.encode(user_message)
        ids += [self.tok.special["<|eot_id|>"]]

        # assistant header (no content yet)
        ids += self._header("assistant")

        return ids
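For readability, the sketch below renders the same layout that `ChatFormat.encode` produces as a plain string, with the special tokens written out as text instead of IDs. `render_chat_prompt` is a hypothetical helper for illustration only; the model itself consumes the integer IDs returned by `ChatFormat`.

```python
def render_chat_prompt(user_message, system_message="You are a helpful assistant."):
    # Hypothetical string-level view of ChatFormat.encode (special tokens shown as text)
    def header(role):
        return f"<|start_header_id|>{role}<|end_header_id|>\n\n"

    return (
        "<|begin_of_text|>"
        + header("system") + system_message + "<|eot_id|>"
        + header("user") + user_message + "<|eot_id|>"
        + header("assistant")  # the model continues generating from here
    )

print(render_chat_prompt("What do llamas eat?"))
```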

In [None]:
# Hugging Face account authentication
# Run this cell the first time you execute the notebook; it prompts for a Hugging Face access token.

from huggingface_hub import login
login()


In [36]:
# Downloads the official Llama 3.2 tokenizer file from the Hugging Face Hub to the local machine (requires approved access to the gated meta-llama repository)
from huggingface_hub import hf_hub_download

tokenizer_file_path = hf_hub_download(
    repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
    filename="original/tokenizer.model",
    local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
)


In [37]:
# Initialize Llama 3.2 tokenizer and chat formatting for encoding user/system messages and special tokens.

tokenizer = Tokenizer(tokenizer_file_path)
chat_tokenizer = ChatFormat(tokenizer)

## Load pretrained weights

Purpose: 
- This section defines the functions that map the pretrained Llama 3.2 checkpoint, published on the Hugging Face Hub with Hugging Face parameter names, onto our custom implementation. The mapping checks that shapes match and assigns each pretrained tensor to the corresponding layer of our step-by-step educational implementation.

Core Architecture Mapping Functions: 
- assign(): Tensor assignment with shape validation. It raises a `ValueError` naming the offending tensor when a model layer and the pretrained weight have different shapes, and accepts both `torch.Tensor` and NumPy array inputs
- load_weights_into_llama(): The full weight-mapping pipeline. It maps Hugging Face checkpoint parameter names to this custom architecture, processes every transformer layer (16 for the 1B model), including attention, feed-forward, and normalization weights, and handles weight tying between the embedding and output layers

Key Architectural Differences Handled: 
- GQA Compatibility: K/V projections are smaller than the query projection (4x for the 1B model: 8 KV groups vs 32 query heads), so their output dimension is `n_kv_groups * head_dim` rather than `emb_dim`
- SwiGLU Structure: Maps 3 linear layers (gate + up + down) instead of traditional 2-layer feed-forward networks 
- RMSNorm Integration: Ensures proper normalization layer compatibility across pre-attention and pre-FFN positions 
- Weight Tying Logic: Supports both tied and untied output head configurations depending on checkpoint format

This mapping shows how each component of our educational implementation corresponds to a tensor in Meta's released checkpoint, which lets us load the pretrained weights directly into our custom Llama implementation.
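The GQA point above translates into concrete shapes. The sketch below is a hypothetical helper built on the checkpoint key names used in `load_weights_into_llama`; it derives the shape each per-layer weight should have from a config. `nn.Linear` stores weights as `(out_features, in_features)`, the same layout as the checkpoint tensors. The 3B dimensions are taken from the commented-out 3B config.

```python
def expected_layer_shapes(cfg, layer=0):
    """Hypothetical helper: expected checkpoint shapes for one transformer layer."""
    d = cfg["emb_dim"]
    head_dim = d // cfg["n_heads"]
    kv_dim = cfg["n_kv_groups"] * head_dim
    h = cfg["hidden_dim"]
    p = f"model.layers.{layer}."
    return {
        p + "self_attn.q_proj.weight": (d, d),
        p + "self_attn.k_proj.weight": (kv_dim, d),
        p + "self_attn.v_proj.weight": (kv_dim, d),
        p + "self_attn.o_proj.weight": (d, d),
        p + "mlp.gate_proj.weight": (h, d),
        p + "mlp.up_proj.weight": (h, d),
        p + "mlp.down_proj.weight": (d, h),
        p + "input_layernorm.weight": (d,),
        p + "post_attention_layernorm.weight": (d,),
    }

cfg_1b = {"emb_dim": 2048, "n_heads": 32, "n_kv_groups": 8, "hidden_dim": 8192}
cfg_3b = {"emb_dim": 3072, "n_heads": 24, "n_kv_groups": 8, "hidden_dim": 8192}

for name, cfg in [("1B", cfg_1b), ("3B", cfg_3b)]:
    shapes = expected_layer_shapes(cfg)
    q = shapes["model.layers.0.self_attn.q_proj.weight"]
    k = shapes["model.layers.0.self_attn.k_proj.weight"]
    print(name, "q_proj", q, "k_proj", k, "query/KV ratio", q[0] // k[0])
```

In the notebook, one could compare these against `tuple(combined_weights[name].shape)` for each layer before calling `load_weights_into_llama`, catching a mismatched config before `assign` raises.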

In [38]:
def assign(left, right, tensor_name="unknown"):
    if left.shape != right.shape:
        raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")

    if isinstance(right, torch.Tensor):
        return torch.nn.Parameter(right.clone().detach())
    else:
        return torch.nn.Parameter(torch.tensor(right))


def load_weights_into_llama(model, param_config, params):
    model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for l in range(param_config["n_layers"]):

        # Load attention weights
        model.trf_blocks[l].att.W_query.weight = assign(
            model.trf_blocks[l].att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight"
        )
        model.trf_blocks[l].att.W_key.weight = assign(
            model.trf_blocks[l].att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight"
        )
        model.trf_blocks[l].att.W_value.weight = assign(
            model.trf_blocks[l].att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight"
        )
        model.trf_blocks[l].att.out_proj.weight = assign(
            model.trf_blocks[l].att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight"
        )
        model.trf_blocks[l].norm1.weight = assign(
            model.trf_blocks[l].norm1.weight,
            params[f"model.layers.{l}.input_layernorm.weight"],
            f"model.layers.{l}.input_layernorm.weight"
        )

        # Load FeedForward weights
        model.trf_blocks[l].ff.fc1.weight = assign(
            model.trf_blocks[l].ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight"
        )
        model.trf_blocks[l].ff.fc2.weight = assign(
            model.trf_blocks[l].ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight"
        )
        model.trf_blocks[l].ff.fc3.weight = assign(
            model.trf_blocks[l].ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight"
        )
        model.trf_blocks[l].norm2.weight = assign(
            model.trf_blocks[l].norm2.weight,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight"
        )

    # Load output layer weights
    model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in params.keys():
        model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
    else:
        model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
        print("Model uses weight tying.")

## **Pretrained Weight Download and Model Initialization with SafeTensors**

**Purpose:** This code downloads Meta's official Llama 3.2 pretrained weights from Hugging Face Hub and loads them into this custom model implementation, transforming it from a randomly initialized neural network into a fully functional language model. The process uses SafeTensors format for secure weight loading and handles both 1B and 3B model variants with different file splitting strategies.

**Key Operations and Logic:**
• **SafeTensors Import:** Imports the `load_file` function from safetensors library for memory-efficient and secure tensor loading
• **Model Size Detection:** Uses conditional logic to check `LLAMA_SIZE_STR` variable - if "1B", downloads single weight file; otherwise assumes 3B model requiring multiple files
• **1B Model Path:** Downloads `model.safetensors` file (~2.5GB) directly using `hf_hub_download()` from Meta's repository and loads it with `load_file()`
• **3B Model Path:** Iterates through files `model-00001-of-00002.safetensors` and `model-00002-of-00002.safetensors`, downloading each separately and combining dictionaries with `update()`
• **Weight Integration:** Calls `load_weights_into_llama()` function to map HuggingFace parameter names to our custom architecture layer names
• **Device Transfer:** Moves the now-pretrained model to the appropriate computational device (GPU, MPS, or CPU) for inference
• **Memory Cleanup:** Deletes the `combined_weights` dictionary to free up system memory after weight transfer is complete

SafeTensors avoids the arbitrary-code-execution risk of pickle-based checkpoints and supports fast, memory-mapped loading, while the `if`/`else` branch lets the single-file 1B checkpoint and the two-shard 3B checkpoint share the same loading and mapping code.

In [39]:
from safetensors.torch import load_file


if LLAMA_SIZE_STR == "1B":
    weights_file = hf_hub_download(
        repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
        filename="model.safetensors",
        local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
    )
    combined_weights = load_file(weights_file)


else:
    combined_weights = {}
    for i in range(1, 3):
        weights_file = hf_hub_download(
            repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
            filename=f"model-0000{i}-of-00002.safetensors",
            local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
        )
        current_weights = load_file(weights_file)
        combined_weights.update(current_weights)


load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)
model.to(device)
del combined_weights  # free up memory


Model uses weight tying.


## Generate text


**Purpose:** This code implements the complete text generation pipeline for the Llama 3.2 model, including utility functions for tokenization conversion and an advanced sampling function that supports multiple generation strategies. The implementation enables the pretrained model to generate coherent text responses with configurable randomness and quality controls through temperature scaling and top-k sampling techniques.

**Key Functions and Operations:**
• **text_to_token_ids():** Converts input text strings into tensor format by encoding with the tokenizer and adding batch dimension for model processing
• **token_ids_to_text():** Converts model output token tensors back to readable text by removing batch dimension and decoding through tokenizer
• **generate():** Core autoregressive generation function that iteratively produces new tokens using the model's forward pass with gradient computation disabled for efficiency
• **Context Window Management:** Truncates input sequence to fit within model's context size limit using `idx[:, -context_size:]` slicing
• **Top-k Sampling:** Filters vocabulary to keep only the k most probable tokens, setting others to negative infinity to prevent selection
• **Temperature Scaling:** Controls generation randomness by dividing logits by the temperature: values below 1 sharpen the distribution (more deterministic), values above 1 flatten it (more diverse output)
• **Sampling Strategies:** Supports both deterministic generation (argmax) when temperature=0 and probabilistic sampling (multinomial) when temperature>0
• **Early Stopping:** Implements optional end-of-sequence token detection to halt generation when model signals completion

**Technical Implementation:** The function operates through an autoregressive loop where each iteration generates one new token based on the current sequence, appends it to the growing output, and continues until reaching the maximum token limit or encountering a stop condition. The torch.no_grad() context ensures memory efficiency during inference by preventing gradient computation.

In [40]:
```python
import torch


def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text)
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0)  # remove batch dimension
    return tokenizer.decode(flat.tolist())


def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    # For-loop is the same as before: get logits and keep only the last time step
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]  # crop to the supported context window
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]  # (batch_size, vocab_size)

        # New: filter logits with top-k sampling
        if top_k is not None:
            # Keep only the top_k values per row; min_val has shape (batch_size, 1)
            # so the comparison broadcasts correctly across the vocabulary dimension
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1:]
            logits = logits.masked_fill(logits < min_val, float("-inf"))

        # New: apply temperature scaling
        if temperature > 0.0:
            logits = logits / temperature

            # Apply softmax to get probabilities
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)

            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # Otherwise same as before: get idx of the vocab entry with the highest logits value
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Stop early if eos_id is specified and every sequence in the batch produced it
        if eos_id is not None and torch.all(idx_next == eos_id):
            break

        # Same as before: append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx
```
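
The two helpers only assume that `tokenizer.encode()` returns a list of integer IDs and that `tokenizer.decode()` accepts such a list. The hypothetical character-level `ToyCharTokenizer` below is not the Llama 3 tokenizer; it just illustrates that contract, with a nested list standing in for the batch dimension that `unsqueeze(0)` adds and `squeeze(0)` removes.

```python
class ToyCharTokenizer:
    """Hypothetical character-level tokenizer exposing encode()/decode()."""

    def __init__(self, corpus):
        self.vocab = sorted(set(corpus))                      # one ID per distinct character
        self.stoi = {ch: i for i, ch in enumerate(self.vocab)}

    def encode(self, text):
        return [self.stoi[ch] for ch in text]                 # str -> list[int]

    def decode(self, ids):
        return "".join(self.vocab[i] for i in ids)            # list[int] -> str


tok = ToyCharTokenizer("hello world")
ids = tok.encode("hello")
batch = [ids]            # analogous to .unsqueeze(0): a batch holding one sequence
flat = batch[0]          # analogous to .squeeze(0)
print(ids)               # → [3, 2, 4, 4, 5]
print(tok.decode(flat))  # → hello
```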

## **Llama 3.2 Text Generation Demo with Performance Monitoring**

**Purpose:** This code demonstrates the complete text generation workflow using the pretrained Llama 3.2 model with a sample prompt, while measuring inference time and memory usage. It shows the model producing a coherent response with greedy decoding and post-processes the output to extract the assistant's reply from the chat format.

**Key Operations and Workflow:**
• **Prompt Setup:** Defines the question "What are the core symptoms of depression?" as the input prompt for demonstration
• **Reproducibility:** Sets the PyTorch random seed to 123; greedy decoding is deterministic anyway, but the seed makes sampling repeatable if the temperature is raised above 0
• **Text Generation:** Calls `generate()` with deterministic settings (top_k=1, temperature=0.0) to produce up to 150 new tokens, picking the single most probable token at each step (greedy decoding, which does not guarantee the most probable overall sequence)
• **Input Processing:** Converts the prompt text to token IDs using chat_tokenizer and transfers them to the selected device (GPU/MPS/CPU)
• **Performance Measurement:** Records wall-clock generation time with time.time() to benchmark inference speed
• **Memory Monitoring:** Reports peak GPU memory allocated by PyTorch (`torch.cuda.max_memory_allocated()`) when CUDA is available; nothing is printed on MPS or CPU
• **Output Conversion:** Transforms the generated token IDs back to readable text using the tokenizer
• **Text Cleaning:** Applies `clean_text()` to keep only the text after the `assistant<|end_header_id|>` header, i.e. the assistant's response without the preceding chat headers
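
`clean_text()` keeps everything after the first assistant header. Because this demo passes no `eos_id`, an end-of-turn token emitted by the model (and anything generated after it) would remain in the decoded string. The sketch below additionally cuts the reply at `<|eot_id|>`, the end-of-turn marker in the Llama 3 chat format; the function name `extract_assistant_reply` is hypothetical, and the transcript is a made-up example rather than model output.

```python
def extract_assistant_reply(text,
                            header_end="assistant<|end_header_id|>\n\n",
                            end_of_turn="<|eot_id|>"):
    """Hypothetical helper: return the assistant's reply from a Llama 3-style transcript."""
    start = text.find(header_end)
    if start == -1:
        return text  # header missing: fall back to the raw text, like clean_text()
    reply = text[start + len(header_end):]
    # Assumption: the reply ends at the first end-of-turn token, if one was generated
    stop = reply.find(end_of_turn)
    if stop != -1:
        reply = reply[:stop]
    return reply.strip()


sample = (  # made-up transcript in the Llama 3 chat layout
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What are the core symptoms of depression?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
    "This is an example reply.<|eot_id|>"
)
print(extract_assistant_reply(sample))  # → This is an example reply.
```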

**Technical Configuration:** The generation uses greedy decoding (temperature=0.0); the top_k=1 setting is redundant here, since argmax already selects the single highest logit. The context window is set to `LLAMA32_CONFIG["context_length"]` (131,072 tokens in the standard Llama 3.2 configuration), so this short prompt plus 150 new tokens is never truncated. The chat tokenizer wraps the prompt in the Llama 3 chat format with role headers (such as user and assistant), and the timing and memory metrics give a rough picture of the model's efficiency on the current hardware.
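
The redundancy of `top_k=1` can be checked with a small pure-Python version of one decoding step. `pick_next` is a hypothetical helper that mirrors the branch structure of `generate()` on a plain list of made-up logits, with `random.choices` standing in for `torch.multinomial`. Under greedy decoding, `top_k=1` leaves the argmax unchanged, and with `top_k=1` even temperature sampling can only ever pick the top token.

```python
import math
import random


def pick_next(logits, top_k=None, temperature=0.0, rng=random):
    """Hypothetical single decoding step on a list of logits, mirroring generate()."""
    if top_k is not None:
        cutoff = sorted(logits, reverse=True)[top_k - 1]
        logits = [x if x >= cutoff else float("-inf") for x in logits]
    if temperature > 0.0:
        scaled = [x / temperature for x in logits]
        m = max(scaled)
        weights = [math.exp(x - m) for x in scaled]  # unnormalized softmax weights
        return rng.choices(range(len(logits)), weights=weights, k=1)[0]
    # Greedy decoding: index of the largest logit
    return max(range(len(logits)), key=lambda i: logits[i])


logits = [0.3, 2.5, 1.9, -0.7]  # made-up scores
rng = random.Random(123)

print(pick_next(logits, temperature=0.0))           # → 1 (plain argmax)
print(pick_next(logits, top_k=1, temperature=0.0))  # → 1 (top_k=1 changes nothing)
print([pick_next(logits, top_k=1, temperature=1.0, rng=rng) for _ in range(5)])  # → [1, 1, 1, 1, 1]
```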

```python
import time

import torch

PROMPT = "What are the core symptoms of depression?"

torch.manual_seed(123)  # has no effect under greedy decoding; matters only if temperature > 0

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),
    max_new_tokens=150,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.0
)

# CUDA kernels run asynchronously; wait for them to finish before reading the clock
if torch.cuda.is_available():
    torch.cuda.synchronize()

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    # Find the first occurrence of the assistant header ("assistant<|end_header_id|>\n\n")
    index = text.find(header_end)

    if index != -1:
        # Return the substring that follows the assistant header
        return text[index + len(header_end):].strip()  # strip() removes leading/trailing whitespace
    else:
        # If the header is not found, return the original text
        return text


print("\n\nOutput text:\n\n", clean_text(output_text))
```

```text
Time: 6.47 sec


Output text:

 The core symptoms of depression can vary from person to person, but here are some common ones:

**Primary Symptoms:**

1. **Feeling sad, empty, or hopeless**: A persistent feeling of sadness, emptiness, or hopelessness that interferes with daily life.
2. **Loss of interest in activities**: A lack of interest or pleasure in activities that once brought joy, such as hobbies, socializing, or sex.
3. **Changes in appetite or sleep**: Significant changes in appetite or sleep patterns, such as overeating or insomnia.
4. **Fatigue or low energy**: Feeling tired, sluggish, or lacking the energy to perform daily tasks.
5. **Difficulty concentrating or making decisions**: Trouble focusing, making decisions, or
```
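
Because no `eos_id` was passed, the loop ran for all 150 new tokens (which is also why the output stops mid-sentence), so the reported 6.47 s corresponds to roughly 150 / 6.47 ≈ 23.2 tokens per second on this hardware. A small sketch for computing this follows; `tokens_per_second` and `timed` are hypothetical helpers, and `time.perf_counter()` is used because it is a monotonic, high-resolution clock intended for measuring intervals. In the notebook, the number of new tokens could be taken as `token_ids.shape[1]` minus the prompt length.

```python
import time


def tokens_per_second(num_new_tokens, elapsed_sec):
    """Decoding throughput; assumes every requested token was actually generated."""
    return num_new_tokens / elapsed_sec


# Numbers from the run above: 150 new tokens in 6.47 s
print(f"{tokens_per_second(150, 6.47):.1f} tokens/sec")  # → 23.2 tokens/sec


def timed(fn, *args, **kwargs):
    """Run fn and return (result, elapsed seconds) measured with a monotonic clock."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


total, elapsed = timed(sum, range(1_000_000))
print(f"sum took {elapsed:.4f} sec")
```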
