<table style="width:100%">
<tr>
<td style="vertical-align:middle; text-align:left;">
<font size="2">
Supplementary code for the <a href="http://mng.bz/orYv">Build a Large Language Model From Scratch</a> book by <a href="https://sebastianraschka.com">Sebastian Raschka</a><br>
<br>Code repository: <a href="https://github.com/rasbt/LLMs-from-scratch">https://github.com/rasbt/LLMs-from-scratch</a>
</font>
</td>
<td style="vertical-align:middle; text-align:left;">
<a href="http://mng.bz/orYv"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px"></a>
</td>
</tr>
</table>

# Olmo 3 From Scratch (A Standalone Notebook)

- This notebook is purposefully minimal and focuses on the code to re-implement the Olmo 3 7B and 32B models from Allen AI in pure PyTorch without relying on other external LLM libraries; Olmo 3 is interesting because, at the time of writing, it is the leading fully open-source model
- For more information, see the official [Olmo 3 announcement](https://allenai.org/blog/olmo3) and model cards:
  - [Olmo-3-1025-7B](https://huggingface.co/allenai/Olmo-3-1025-7B) (base model)
  - [Olmo-3-7B-Instruct](https://huggingface.co/allenai/Olmo-3-7B-Instruct)
  - [Olmo-3-7B-Think](https://huggingface.co/allenai/Olmo-3-7B-Think)
- Note that there are also 32B versions, which are not listed above for brevity; you can find a complete list [here](https://huggingface.co/collections/allenai/olmo-3-post-training)
- Below is a side-by-side comparison with Qwen3 8B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)
<br>

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/olmo3/olmo3.webp">
  
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/olmo3/olmo3-pipeline.webp">
  
  
- About the code:
  - all code is my own code, mapping the Olmo 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))

In [1]:
# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt

In [2]:
from importlib.metadata import version

pkgs = [
    "huggingface_hub",  # to download pretrained weights
    "tokenizers",       # to implement the tokenizer
    "torch",            # to implement the model
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

huggingface_hub version: 0.35.0
tokenizers version: 0.22.1
torch version: 2.9.1+cu130


- Note that there are three main model types, and each comes in a 7B and a 32B size:
1. Base (`Olmo-3-1025-7B` and `Olmo-3-1125-32B`)
2. Instruct (`Olmo-3-7B-Instruct` and `Olmo-3-32B-Instruct`)
3. Reasoning (`Olmo-3-7B-Think` and `Olmo-3-32B-Think`)

In [3]:
# Select which model to use

# USE_MODEL = "Olmo-3-1025-7B"
# USE_MODEL = "Olmo-3-1125-32B"
USE_MODEL = "Olmo-3-7B-Instruct"
# USE_MODEL = "Olmo-3-32B-Instruct"
# USE_MODEL = "Olmo-3-7B-Think"
# USE_MODEL = "Olmo-3-32B-Think"
# USE_MODEL = "Olmo-3-7B-RLZero-IF"

- In addition to the checkpoints listed above, you can also use the intermediate checkpoints listed [here](https://huggingface.co/collections/allenai/olmo-3-post-training); since they all have the same architecture, they are all compatible with this notebook

&nbsp;
# 1. Architecture code

In [4]:
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

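    def forward(self, x):
        x_fc1 = self.fc1(x)  # gate branch: (batch, seq, hidden_dim)
        x_fc2 = self.fc2(x)  # linear branch: (batch, seq, hidden_dim)
        x = nn.functional.silu(x_fc1) * x_fc2  # SwiGLU: SiLU-gated elementwise product
        return self.fc3(x)  # project back down: (batch, seq, emb_dim)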

In [5]:
class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        input_dtype = x.dtype
        x_f = x.float()
        var = x_f.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x_f * torch.rsqrt(var + self.eps)
        return (self.weight * x_norm).to(input_dtype)

In [6]:
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type="default", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (
        theta_base ** (
            torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()
            / head_dim
        )
    )

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Optional YaRN scaling
    if rope_type == "yarn":
        positions = positions / rope_factor
        positions = torch.clamp(positions, max=rope_orig_max - 1)

    # Compute the base angles (shape: [context_length, head_dim // 2])
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)

    # Expand to full head_dim (shape: [context_length, head_dim])
    angles = torch.cat([angles, angles], dim=1)

    # Precompute sine and cosine
    cos = torch.cos(angles) * attention_factor
    sin = torch.sin(angles) * attention_factor

    return cos, sin


def apply_rope(x, cos, sin, offset=0):
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2 :]  # Second half

    # Adjust sin and cos shapes
    cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)

Below is an "atomic trace": we simulate a single forward pass through the **Olmo 3** architecture with concrete numbers to make the invisible visible.

### 🧪 **The Laboratory Setup**

To make the tensors readable, we will use a "Mini-Olmo" configuration.

  * **Batch Size ($B$):** 1
  * **Sequence Length ($T$):** 2 (Input tokens: "Hello", "World")
  * **Embedding Dim ($D$):** 8
  * **Query Heads ($H_q$):** 4
  * **KV Heads ($H_{kv}$):** 2 (This creates a **GQA Group Size of 2**)
  * **Head Dim ($D_h$):** 2 (since $4 \text{ heads} \times 2 \text{ dim} = 8 \text{ total dim}$)

-----

### **Phase 1: Input to Embedding**

**Input:** A list of token IDs (integers).

  * `input_ids`: `[[101, 205]]` (imagine 101 = "Hello", 205 = "World")
  * **Shape:** `[1, 2]`

**Operation:** `self.tok_emb(input_ids)`
The model looks up the vector for ID 101 and ID 205.

**Tensor State (`x`):**
**Shape:** `[1, 2, 8]`

```text
Token 0 ("Hello"): [ 0.1,  0.2, -0.1,  0.5,  1.0,  0.0, -0.5,  0.3]
Token 1 ("World"): [ 0.9, -0.9,  0.2,  0.2,  0.5,  0.5, -0.1,  0.0]
```
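
The lookup itself is a plain `nn.Embedding` call. The sketch below assumes a hypothetical toy vocabulary of 1,000 tokens and random weights, so its numbers will not match the hand-picked values above; only the shapes carry over.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

VOCAB_SIZE = 1000  # hypothetical toy vocabulary (the real tokenizer is much larger)
EMB_DIM = 8        # D in the Mini-Olmo setup

tok_emb = nn.Embedding(VOCAB_SIZE, EMB_DIM)

input_ids = torch.tensor([[101, 205]])  # one sequence: "Hello", "World"
x = tok_emb(input_ids)

print(input_ids.shape)  # torch.Size([1, 2])
print(x.shape)          # torch.Size([1, 2, 8])

# Each output row is just the matching row of the embedding matrix
print(torch.equal(x[0, 0], tok_emb.weight[101]))  # True
```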

-----

### **Phase 2: Inside `GroupedQueryAttention`**

This is the most complex part. We step into `block.att`.

#### **Step 2.1: Projections (The Split)**

The input `x` branches into three paths: Query, Key, Value.

**Code:**

```python
queries = self.W_query(x) # Output size: 4 heads * 2 dim = 8
keys    = self.W_key(x)   # Output size: 2 heads * 2 dim = 4  <-- SMALLER!
values  = self.W_value(x) # Output size: 2 heads * 2 dim = 4  <-- SMALLER!
```

**Tensor State:**

  * `queries`: `[1, 2, 8]` (Full size)
  * `keys`: `[1, 2, 4]` (Compressed)
  * `values`: `[1, 2, 4]` (Compressed)
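
A minimal sketch of the three projections with the Mini-Olmo sizes (random weights, `bias=False` assumed) also makes the parameter savings concrete: compared with full-size key and value projections, GQA drops 64 of the 192 projection weights here.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

D, H_Q, H_KV, D_H = 8, 4, 2, 2  # Mini-Olmo setup from above

x = torch.randn(1, 2, D)  # (batch, seq, emb_dim), stand-in for the embeddings

W_query = nn.Linear(D, H_Q * D_H, bias=False)   # 8 -> 8
W_key = nn.Linear(D, H_KV * D_H, bias=False)    # 8 -> 4 (GQA: fewer K heads)
W_value = nn.Linear(D, H_KV * D_H, bias=False)  # 8 -> 4 (GQA: fewer V heads)

queries, keys, values = W_query(x), W_key(x), W_value(x)
print(queries.shape, keys.shape, values.shape)
# torch.Size([1, 2, 8]) torch.Size([1, 2, 4]) torch.Size([1, 2, 4])


def n_params(*modules):
    return sum(p.numel() for m in modules for p in m.parameters())


gqa_params = n_params(W_query, W_key, W_value)  # 64 + 32 + 32 = 128
mha_params = 3 * D * (H_Q * D_H)                # full-size Q, K, V: 3 * 64 = 192
print(gqa_params, mha_params)  # 128 192
```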

#### **Step 2.2: Reshape & Transpose**

We split the flat vectors into heads and move "Heads" to the 2nd dimension.

**Code:**

```python
queries = queries.view(1, 2, 4, 2).transpose(1, 2)
keys    = keys.view(1, 2, 2, 2).transpose(1, 2)
```

**Tensor Trace (Queries):**
**Shape:** `[1, 4, 2, 2]` (Batch, **4 Heads**, Seq, HeadDim)

```text
HEAD 0: [[q0_t0_d0, q0_t0_d1], [q0_t1_d0, q0_t1_d1]] (Token 0, Token 1)
HEAD 1: ...
HEAD 2: ...
HEAD 3: ...
```

**Tensor Trace (Keys - Pre-Expansion):**
**Shape:** `[1, 2, 2, 2]` (Batch, **2 Heads**, Seq, HeadDim)

```text
KV_HEAD A: [[kA_t0_d0, kA_t0_d1], [kA_t1_d0, kA_t1_d1]] (Covers Query Heads 0 & 1)
KV_HEAD B: [[kB_t0_d0, kB_t0_d1], [kB_t1_d0, kB_t1_d1]] (Covers Query Heads 2 & 3)
```
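
The reshape can be checked with `arange` inputs, which make it easy to see which slice of the flat vector ends up in which head (sizes follow the Mini-Olmo setup; this is an illustration, not code from the model):

```python
import torch

B, T, H_Q, H_KV, D_H = 1, 2, 4, 2, 2

# Flat projection outputs; arange makes the element ordering easy to follow
queries = torch.arange(B * T * H_Q * D_H, dtype=torch.float32).view(B, T, H_Q * D_H)
keys = torch.arange(B * T * H_KV * D_H, dtype=torch.float32).view(B, T, H_KV * D_H)

# Split the last dim into heads, then move heads in front of the sequence dim
q = queries.view(B, T, H_Q, D_H).transpose(1, 2)  # (1, 4, 2, 2)
k = keys.view(B, T, H_KV, D_H).transpose(1, 2)    # (1, 2, 2, 2)

print(q.shape, k.shape)  # torch.Size([1, 4, 2, 2]) torch.Size([1, 2, 2, 2])

# Head h of token t is the slice [h*D_H:(h+1)*D_H] of that token's flat vector
print(torch.equal(q[0, 1, 0], queries[0, 0, 2:4]))  # True (head 1, token 0)
```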

#### **Step 2.3: RoPE (Rotary Embeddings)**

We apply the rotation math. Let's visualize what happens to just **Token 1, Head 0**.

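  * **Before RoPE:** `[1.0, 0.5]`
  * **RoPE (Rotation):** Rotates the vector by the angle $m\theta$ for position $m=1$; with `head_dim = 2` there is a single frequency, $\theta = 10000^{0} = 1$, so the angle is 1 radian.
  * **After RoPE:** `[0.12, 1.11]` (the numbers change, but the length stays $\approx 1.118$ because a rotation preserves magnitude).

**Crucial Note:** `keys` are rotated according to their absolute position in the sequence. Since we are processing a sequence of 2, Token 0 is rotated for position 0 (no rotation) and Token 1 for position 1.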
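
The rotation above can be reproduced numerically. This sketch re-implements the split-halves rotation from `apply_rope` for a single 2-dimensional vector (assuming `theta_base = 10_000`, the default in `compute_rope_params`) and confirms that the length is unchanged:

```python
import torch

# Rotate one 2-dim head vector for position m = 1, using the same
# "split halves" convention as apply_rope above (head_dim = 2).
head_dim, theta_base, m = 2, 10_000, 1

inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # tensor([1.])
angle = m * inv_freq                # 1 radian for the only frequency pair
cos = torch.cos(angle).repeat(2)    # [cos(1), cos(1)]
sin = torch.sin(angle).repeat(2)    # [sin(1), sin(1)]

v = torch.tensor([1.0, 0.5])
x1, x2 = v[:1], v[1:]               # first half, second half
rotated = torch.cat((-x2, x1))
v_rot = v * cos + rotated * sin

print(v_rot)                                 # approx. tensor([0.1196, 1.1116])
print(v.norm().item(), v_rot.norm().item())  # both approx. 1.1180
```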

#### **Step 2.4: GQA Expansion (The "Repeat")**

We must match the 2 Key heads to the 4 Query heads.
Group Size = 2.

**Code:**

```python
keys = keys.repeat_interleave(2, dim=1)
```

**Tensor Trace (Keys - Post-Expansion):**
**Shape:** `[1, 4, 2, 2]` (Now matches Query heads\!)

```text
HEAD 0 (Was KV_A): [Data from KV_HEAD A]
HEAD 1 (Was KV_A): [Data from KV_HEAD A]  <-- DUPLICATE
HEAD 2 (Was KV_B): [Data from KV_HEAD B]
HEAD 3 (Was KV_B): [Data from KV_HEAD B]  <-- DUPLICATE
```

*Now we have 4 Key heads, but they are pairwise identical.*
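
`repeat_interleave` is what produces the A, A, B, B ordering; `Tensor.repeat` would instead give A, B, A, B and pair the wrong query heads with each KV head. A small check with hand-written keys:

```python
import torch

# Pre-expansion keys: (batch=1, kv_heads=2, seq=2, head_dim=2)
keys = torch.tensor([[[[1., 1.], [2., 2.]],     # KV_HEAD A
                      [[3., 3.], [4., 4.]]]])   # KV_HEAD B
group_size = 4 // 2  # query heads / KV heads

keys_expanded = keys.repeat_interleave(group_size, dim=1)
print(keys_expanded.shape)  # torch.Size([1, 4, 2, 2])

# Heads 0 and 1 are copies of A, heads 2 and 3 are copies of B
print(torch.equal(keys_expanded[0, 0], keys_expanded[0, 1]))  # True
print(torch.equal(keys_expanded[0, 2], keys_expanded[0, 3]))  # True
```

The expanded tensor is a real copy, so the attention matmul itself costs as much as in full multi-head attention; the savings of GQA come from the smaller K/V projections and the smaller KV cache.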

#### **Step 2.5: Attention Scores (MatMul)**

We calculate similarity.
`scores = queries @ keys.transpose(2, 3)`
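Shapes: `(1, 4, 2, 2) @ (1, 4, 2, 2)` $\rightarrow$ `(1, 4, 2, 2)` (Seq x Seq); the queries are (Seq x HeadDim) and the transposed keys are (HeadDim x Seq), which both happen to be 2 x 2 here.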

**Visualization of Score Matrix (Head 0):**

```text
      Key0   Key1
Q0  [ 10.5,  2.1 ]  <- "Hello" matches "Hello" strongly
Q1  [  5.2,  8.8 ]  <- "World" matches "World" strongly
```

#### **Step 2.6: Masking (Causal)**

We apply the mask so Q0 cannot see K1 (future).

```text
      Key0   Key1
Q0  [ 10.5, -inf ]  <- Masked!
Q1  [  5.2,  8.8 ]
```

#### **Step 2.7: Softmax & Context**

1.  **Softmax:** Convert scores to probabilities (0.0 to 1.0).
2.  **Context:** `probs @ values`
      * Q0 (Hello) attends 100% to K0. Context = V0.
      * Q1 (World) attends about 3% to K0 and about 97% to K1 (softmax of `[5.2, 8.8]` is approximately `[0.027, 0.973]`). Context ≈ `0.03 * V0 + 0.97 * V1`.

**Result (`context`):**
**Shape:** `[1, 4, 2, 2]` (Batch, Heads, Seq, HeadDim)
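
Steps 2.5 to 2.7 for Head 0 can be replayed with the example scores. The value vectors are hypothetical one-hot vectors, chosen so that the context rows equal the attention weights:

```python
import torch

# Head-0 scores from Step 2.5 (rows: queries, cols: keys), used as-is;
# the 1/sqrt(head_dim) scaling is omitted here for readability
scores = torch.tensor([[10.5, 2.1],
                       [5.2, 8.8]])

# Causal mask: True above the diagonal = "future key, hide it"
mask = torch.triu(torch.ones(2, 2, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, -torch.inf)

weights = torch.softmax(scores, dim=-1)
print(weights)  # approx. [[1.0000, 0.0000], [0.0266, 0.9734]]

# Hypothetical 2-dim value vectors for the two tokens
values = torch.tensor([[1.0, 0.0],   # V0
                       [0.0, 1.0]])  # V1
context = weights @ values           # row 1 = 0.0266 * V0 + 0.9734 * V1
print(context)
```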

#### **Step 2.8: Output Projection**

We merge the heads back together.

1.  **Transpose & Reshape:** `[1, 4, 2, 2]` $\rightarrow$ `[1, 2, 8]`
2.  **Linear Proj (`out_proj`):** Mixes the head data.

**Result (`x_attn`):** `[1, 2, 8]`
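
The merge is the inverse of the earlier reshape: transpose the heads back behind the sequence dimension, then flatten. A sketch with random per-head outputs (the output projection weights are random as well):

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
B, H, T, D_H = 1, 4, 2, 2
d_out = H * D_H  # 8

context = torch.randn(B, H, T, D_H)  # per-head attention outputs

# (B, H, T, D_H) -> (B, T, H, D_H) -> (B, T, H * D_H)
merged = context.transpose(1, 2).reshape(B, T, d_out)

# Token t's merged vector is its head outputs laid side by side
print(torch.equal(merged[0, 1, 2:4], context[0, 1, 1]))  # True (head 1, token 1)

out_proj = nn.Linear(d_out, 8, bias=False)  # mixes information across heads
x_attn = out_proj(merged)
print(x_attn.shape)  # torch.Size([1, 2, 8])
```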

-----

### **Phase 3: The FeedForward Network (FFN)**

Now each token is processed independently (no information is mixed across positions in this step).

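1.  **Input:** `x` (Shape: `[1, 2, 8]`)
2.  **Expansion:** Two parallel projections (`fc1` and `fc2`) map up to the hidden dim (several times larger than the embedding dim; 32 in this toy example).
      * Shape: `[1, 2, 32]` for each branch
3.  **Activation:** SwiGLU, i.e., `silu(fc1(x)) * fc2(x)`.
4.  **Contraction:** `fc3` projects back down to 8.
      * Shape: `[1, 2, 8]`

**Residual Add:** `x = x_prev + post_feedforward_layernorm(x_ffn)` (Olmo normalizes the FFN output with RMSNorm *before* the residual addition).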
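
The FFN steps above mirror the `FeedForward` class from the architecture code; here they are written out with the toy sizes (random weights; the residual connection and normalization are left out):

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
emb_dim, hidden_dim = 8, 32  # toy sizes from this walkthrough

fc1 = nn.Linear(emb_dim, hidden_dim, bias=False)  # gate branch
fc2 = nn.Linear(emb_dim, hidden_dim, bias=False)  # linear ("up") branch
fc3 = nn.Linear(hidden_dim, emb_dim, bias=False)  # down projection

x = torch.randn(1, 2, emb_dim)
hidden = nn.functional.silu(fc1(x)) * fc2(x)  # SwiGLU: (1, 2, 32)
x_ffn = fc3(hidden)                           # (1, 2, 8)
print(hidden.shape, x_ffn.shape)  # torch.Size([1, 2, 32]) torch.Size([1, 2, 8])
```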

-----

### **Phase 4: KV Cache (The Atomic "Store")**

The function returns `next_cache`. What exactly is in it?

**Code:**

```python
next_cache = (keys_new, values_new)
```

Remember `keys_new` from Steps 2.1 and 2.2 (after projection, QK-norm, and reshape, but before RoPE and GQA expansion)?

  * It holds the **UNROTATED** keys.
  * **Shape:** `[1, 2, 2, 2]` (Batch, **KV\_Heads**, Seq, HeadDim)

**Where is it stored?**
It is returned to `Olmo3Model`, which stores it in a per-layer container: a Python dict or list that lives in CPU memory and holds references to the tensors in GPU memory.
`cache = {0: (k_tensor, v_tensor), 1: ...}`

-----

### **Phase 5: Next Token Generation (The "Loop")**

Let's say we generate Token 3 ("\!").
**Input:** `[301]` (ID for "\!")
**Shape:** `[1, 1]` (Seq len is 1 now\!)

1.  **Projections:** We get `k_new` for just this 1 token.
      * Shape: `[1, 2, 1, 2]`
2.  **Cache Retrieval:** We assume we passed the `cache` from Phase 4.
      * `prev_k` Shape: `[1, 2, 2, 2]` (Tokens 0, 1)
3.  **The Concatenation (`torch.cat`):**
    ```python
    keys_cat = torch.cat([prev_k, k_new], dim=2)
    ```
      * `prev_k` (len 2) + `k_new` (len 1) = **Total len 3**.
      * **New Shape:** `[1, 2, 3, 2]`
4.  **RoPE:** Applied to the full sequence of 3.
5.  **Attention:** The new token "\!" attends to all 3 keys (Hello, World, \!).
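
The cache concatenation and the RoPE offsets from Phase 5 can be traced with random tensors; the offsets follow the `start_pos` and `start_pos - prev_len` logic in `GroupedQueryAttention.forward`:

```python
import torch

torch.manual_seed(123)
B, H_KV, D_H = 1, 2, 2

prev_k = torch.randn(B, H_KV, 2, D_H)  # cached, unrotated keys for "Hello", "World"
k_new = torch.randn(B, H_KV, 1, D_H)   # unrotated key for "!"

keys_cat = torch.cat([prev_k, k_new], dim=2)
print(keys_cat.shape)  # torch.Size([1, 2, 3, 2])

# Position bookkeeping as in GroupedQueryAttention.forward
start_pos = 2                      # "!" is the token at index 2
prev_len = prev_k.size(2)          # 2 cached tokens
query_offset = start_pos           # the query is rotated for position 2
key_offset = start_pos - prev_len  # keys are rotated for positions 0, 1, 2
key_positions = list(range(key_offset, key_offset + keys_cat.size(2)))
print(query_offset, key_positions)  # 2 [0, 1, 2]
```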

### **Summary Tensor Flowchart**

```text
Input IDs [1, 2]
    |
Embedding --> [1, 2, 8]
    |
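    +--> Query Proj --> [1, 2, 8] --> Reshape [1, 4, 2, 2] --> RoPE
    |
    +--> Key Proj   --> [1, 2, 4] --> Reshape [1, 2, 2, 2] (KV Heads=2)
    |       |
    |       +--> SAVE TO CACHE (Unrotated)
    |       |
    |       +--> RoPE --> GQA Expand --> Keys become [1, 4, 2, 2] (Repeated)
    |
    +--> Value Proj --> [1, 2, 4] --> Reshape [1, 2, 2, 2] --> SAVE TO CACHE
    |       |
    |       +--> GQA Expand --> Values become [1, 4, 2, 2] (Repeated)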
    |
Attention (Q @ K.T) --> Scores [1, 4, 2, 2]
    |
Softmax & Multiply V --> Context [1, 4, 2, 2]
    |
Merge Heads --> [1, 2, 8]
    |
Output Proj --> [1, 2, 8]
```

This is the "Molecular Biology" of the Transformer. We are going to zoom in until we see the individual "atoms" of memory and how the code acts as the traffic controller for this massive data structure.

### 1\. The "Filing Cabinet" Architecture

The `cache` object passed to the model isn't just a single tensor; it is a **Structured Container** (the snippets below access it with `cache.get(i)` and `cache.update(i, ...)`) that acts like a filing cabinet.

  * **The Cabinet (`cache` object):** Holds the memory for the *entire* model.
  * **The Drawers (`i`):** Each drawer corresponds to one **Transformer Block** (Layer).
  * **The Folders (`k`, `v`):** Inside each drawer, there are exactly two folders: one for **Keys**, one for **Values**.

#### The Data Structure Visualization

If we could X-ray the `cache` object during inference, it would look like this:

```python
# Conceptual Structure of the 'cache' object on GPU
cache_storage = {
    # LAYER 0 (The Bottom Drawer)
    0: (
        Tensor_K0,  # Shape: [Batch, KV_Heads, Seq_Len, Head_Dim]
        Tensor_V0   # Shape: [Batch, KV_Heads, Seq_Len, Head_Dim]
    ),
    
    # LAYER 1 (The Middle Drawer)
    1: (
        Tensor_K1,  # Shape: [Batch, KV_Heads, Seq_Len, Head_Dim]
        Tensor_V1   # Shape: [Batch, KV_Heads, Seq_Len, Head_Dim]
    ),
    
    # LAYER 2... (and so on)
    2: ( ... ),
}
```
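
To make the `get`/`update` calls used below concrete, here is a hypothetical minimal cache class that stores one `(keys, values)` tuple per layer in a Python list. The class name and methods are assumptions for illustration; the actual container in the full notebook may differ:

```python
import torch


class KVCache:
    """Hypothetical minimal per-layer KV cache: one slot ("drawer") per block."""

    def __init__(self, n_layers):
        self.cache = [None] * n_layers  # None = nothing cached yet

    def get(self, layer_idx):
        return self.cache[layer_idx]  # (keys, values) tuple or None

    def update(self, layer_idx, value):
        self.cache[layer_idx] = value  # replace the whole (keys, values) tuple

    def reset(self):
        self.cache = [None] * len(self.cache)


cache = KVCache(n_layers=2)
print(cache.get(0))  # None (first call, nothing stored)
cache.update(0, (torch.zeros(1, 2, 3, 2), torch.zeros(1, 2, 3, 2)))
print(cache.get(0)[0].shape)  # torch.Size([1, 2, 3, 2])
print(cache.get(1))  # None (layer 1 is untouched)
```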

-----

### 2\. The Tracking Logic: The "Pass-by-Reference" Relay

How does the code ensure Layer 0's data doesn't leak into Layer 1? It uses the **Layer Index `i`** as a strict address key.

Let's trace the **Atomic Flow** of a single forward pass (generating one new token) through the `Olmo3Model.forward` loop.

#### **State 0: The Setup**

  * **Current Token:** "Robot"
  * **Layer Index (`i`):** 0
  * **Cache Object:** Passed into the model.

#### **Step 1: The Retrieval (The `get` call)**

Inside `Olmo3Model.forward`:

```python
for i, block in enumerate(self.blocks):
    # ATOMIC ACTION: Look at Drawer 'i' (0). 
    # Grab the tuple (K0, V0) inside.
    # If this is the first token ever, grab None.
    blk_cache = cache.get(i) 
```

  * **Visual:** The code binds the name `blk_cache` to the tuple stored in Drawer 0 (a reference; no tensor data is copied). Drawers 1-31 (in a 32-layer model) are ignored.

#### **Step 2: The Block Processing (Inside `TransformerBlock`)**

The `block` takes `blk_cache` (Layer 0's history) and the input `x`.

1.  **Compute:** Generates `k_new_0` and `v_new_0` for "Robot".
2.  **Concat:** Merges `k_new_0` with `blk_cache` (history).
3.  **Return:** Sends back a *new* tuple: `next_cache = (K0_updated, V0_updated)`.

#### **Step 3: The Update (The `update` call)**

Back in `Olmo3Model.forward`:

```python
    # ATOMIC ACTION: Open Drawer 'i' (0).
    # THROW AWAY the old folders.
    # REPLACE them with the new tuple from Step 2.
    if cache is not None:
        cache.update(i, new_blk_cache)
```

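  * **Memory Note:** The old `K0` tensor (length $T$) is released as soon as nothing references it anymore (PyTorch frees tensors by reference counting and returns the memory to its caching allocator). The `cache` object now holds a reference to `K0_updated` (length $T+1$).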

#### **Step 4: The Loop Continues**

The loop increments `i` to **1**.

1.  `cache.get(1)` pulls data from Drawer 1.
2.  Block 1 computes attention using *Layer 1's* weights and *Layer 1's* cache.
3.  `cache.update(1, ...)` updates Drawer 1.

**Crucial Insight:** Layer 1 **never** sees Layer 0's cache. The variable `blk_cache` is overwritten in every iteration of the loop, but the persistent `cache` object safely stores the state for every layer in separate slots.
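
The isolation between layers can be simulated with a hypothetical `fake_block` that stands in for a `TransformerBlock` and just appends constant-valued keys tagged with its layer index:

```python
import torch

N_LAYERS, B, H_KV, D_H = 2, 1, 2, 2
cache = [None] * N_LAYERS  # one (keys, values) slot per layer


def fake_block(layer_idx, x_len, blk_cache):
    """Hypothetical stand-in for a TransformerBlock: creates K/V for the new
    tokens (filled with the layer index) and appends them to this layer's history."""
    k_new = torch.full((B, H_KV, x_len, D_H), float(layer_idx))
    v_new = torch.full((B, H_KV, x_len, D_H), float(layer_idx))
    if blk_cache is None:
        return (k_new, v_new)
    prev_k, prev_v = blk_cache
    return (torch.cat([prev_k, k_new], dim=2), torch.cat([prev_v, v_new], dim=2))


for step_len in [2, 1, 1]:  # prefill 2 tokens, then decode 1 token twice
    for i in range(N_LAYERS):
        blk_cache = cache[i]                           # "get" drawer i
        cache[i] = fake_block(i, step_len, blk_cache)  # "update" drawer i

print([c[0].shape[2] for c in cache])  # [4, 4]: both layers hold 4 tokens
# Layer 0 only ever received layer-0 keys, layer 1 only layer-1 keys
print(bool(torch.all(cache[0][0] == 0)), bool(torch.all(cache[1][0] == 1)))  # True True
```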

-----

### 3\. Visual Trace: The Memory Address Map

Let's visualize the GPU memory addresses to see "where" the data lives.
**Scenario:** 2 Layers. Generating Token 3.

**Before Forward Pass:**

| Address | Variable | Content |
| :--- | :--- | :--- |
| `0xA00` | `cache[0][0]` | **Layer 0 Key** (Tokens 0-2) |
| `0xB00` | `cache[0][1]` | **Layer 0 Value** (Tokens 0-2) |
| `0xC00` | `cache[1][0]` | **Layer 1 Key** (Tokens 0-2) |
| `0xD00` | `cache[1][1]` | **Layer 1 Value** (Tokens 0-2) |

**Inside Loop `i=0` (Layer 0):**

1.  Model calculates `k_new` (Address `0xE00`).
2.  `torch.cat([0xA00, 0xE00])` allocates **NEW** memory at `0xF00`.
3.  `cache.update(0, ...)` changes the pointer. `cache[0][0]` now points to `0xF00`.
4.  Address `0xA00` (Old L0 Key) is freed.

**Inside Loop `i=1` (Layer 1):**

1.  Model calculates `k_new` for Layer 1 (Address `0x100`).
2.  `torch.cat([0xC00, 0x100])` allocates **NEW** memory at `0x200`.
3.  `cache.update(1, ...)` changes pointer. `cache[1][0]` now points to `0x200`.
4.  Address `0xC00` (Old L1 Key) is freed.

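The value tensors go through the same concatenate-and-replace cycle as the keys, so every cache entry now points to a new address:

**After Forward Pass:**

| Address | Variable | Content |
| :--- | :--- | :--- |
| `0xF00` | `cache[0][0]` | **Layer 0 Key** (Tokens 0-3) **[UPDATED]** |
| (new address) | `cache[0][1]` | **Layer 0 Value** (Tokens 0-3) **[UPDATED]** |
| `0x200` | `cache[1][0]` | **Layer 1 Key** (Tokens 0-3) **[UPDATED]** |
| (new address) | `cache[1][1]` | **Layer 1 Value** (Tokens 0-3) **[UPDATED]** |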
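
The "new memory" claim is easy to verify on the CPU: `torch.cat` always writes its result to a freshly allocated tensor (the shapes are the toy sizes; the values are placeholders):

```python
import torch

k_old = torch.zeros(1, 2, 3, 2)  # cached keys for tokens 0-2
k_new = torch.ones(1, 2, 1, 2)   # key for token 3

k_updated = torch.cat([k_old, k_new], dim=2)

# torch.cat never writes in place: the result lives in a fresh buffer
print(k_old.data_ptr() == k_updated.data_ptr())  # False
print(k_updated.shape)  # torch.Size([1, 2, 4, 2])
```

Because the whole history is copied on each step, every layer pays a copy proportional to the current sequence length for each generated token. Many production inference engines avoid this by preallocating a fixed-size buffer and writing new keys into it in place; that is a different design from the one used in this notebook.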

-----

### 4\. How `start_pos` Keeps It Synchronized

You might wonder: *How does RoPE know that "Robot" is the 5th word?*

The `Olmo3Model` keeps a position counter as an instance attribute:

```python
self.current_pos = 0 # Initialized at start
```

In `forward(input_ids, cache=None)`:

```python
if cache is not None:
    pos_start = self.current_pos
    pos_end = pos_start + seq_len
    self.current_pos = pos_end # Advance the counter for the next call
```

This `pos_start` integer is passed down (it arrives in the attention module as the `start_pos` argument):
`Olmo3Model` $\rightarrow$ `TransformerBlock` $\rightarrow$ `GroupedQueryAttention`.

Inside `GroupedQueryAttention`, it is used to calculate the rotation angle:

```python
# If start_pos is 4 ("Robot" is the 5th token), apply_rope uses row 4 of the precomputed cos/sin tables.
queries = apply_rope(queries, cos, sin, offset=start_pos)
```

**The "Atomic" Harmony:**

1.  **The Cache** stores the *content* (Tensor data).
2.  **The Index (`i`)** selects the *location* (Layer drawer).
3.  **The Pos (`current_pos`)** defines the *context* (Time step).
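
The `current_pos` bookkeeping can be isolated in a few lines. `PositionTracker` is a hypothetical helper that mimics the snippet above:

```python
class PositionTracker:
    """Hypothetical stand-in for the current_pos bookkeeping in Olmo3Model.forward."""

    def __init__(self):
        self.current_pos = 0

    def step(self, seq_len):
        pos_start = self.current_pos
        pos_end = pos_start + seq_len
        self.current_pos = pos_end
        return pos_start, list(range(pos_start, pos_end))  # positions of the new tokens

    def reset(self):
        self.current_pos = 0  # call before starting a new prompt


tracker = PositionTracker()
print(tracker.step(4))  # (0, [0, 1, 2, 3])  prefill of 4 tokens
print(tracker.step(1))  # (4, [4])           "Robot" is the 5th token -> index 4
print(tracker.step(1))  # (5, [5])
```

A counter like this has to be reset together with the cache before a new prompt; otherwise the next prompt's tokens would be rotated as if they continued the previous sequence.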


In [None]:
class GroupedQueryAttention(nn.Module):
    def __init__(self, d_in, num_heads, num_kv_groups, head_dim, attention_bias=False, dtype=None, sliding_window=None, attn_type="full_attention"):
        super().__init__()
        # VALIDATION: Ensure we can evenly split heads into groups.
        # If we have 8 heads and 3 groups, this math fails. We need perfect integers.
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads          # Total Query Heads (e.g., 32)
        self.num_kv_groups = num_kv_groups  # Total Key/Value Heads (e.g., 8)
        
        # GQA RATIO: How many Query heads share a single Key/Value head?
        # Example: 32 / 8 = 4. So 4 Queries share 1 Key.
        self.group_size = num_heads // num_kv_groups

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim   # Total output dimension
        self.attn_type = attn_type
        
        # SLIDING WINDOW: If enabled, we only look back 'sliding_window' steps.
        self.sliding_window = sliding_window if attn_type == "sliding_attention" else None

        # --- PROJECTIONS (The "Learnable" Part) ---
        # 1. Query Projection: Maps input to Query space. Full size.
        self.W_query = nn.Linear(d_in, self.d_out, bias=attention_bias, dtype=dtype)
        
        # 2. Key/Value Projections: SMALLER than Query!
        # Notice we use 'num_kv_groups' instead of 'num_heads'.
        # This is where GQA saves massive parameters and compute.
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)
        
        # 3. Output Projection: Mixes the results back to embedding dimension.
        self.out_proj = nn.Linear(self.d_out, d_in, bias=attention_bias, dtype=dtype)

        # --- STABILITY TRICK (Olmo Specific) ---
        # QK-Norm: We normalize Queries and Keys *before* they interact.
        # This prevents attention scores from growing too large (instability).
        self.q_norm = RMSNorm(self.d_out)
        self.k_norm = RMSNorm(num_kv_groups * head_dim)

    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):
        b, num_tokens, _ = x.shape  # b: Batch Size, num_tokens: Seq Length

        # 1. PROJECTION
        # We turn the input vectors into Q, K, and V vectors.
        queries = self.W_query(x)  # Shape: (b, seq, 32 * 128) -> Full Size
        keys = self.W_key(x)       # Shape: (b, seq, 8 * 128)  -> Compressed (GQA)
        values = self.W_value(x)   # Shape: (b, seq, 8 * 128)  -> Compressed (GQA)

        # 2. INTERNAL NORM (Olmo Trick)
        # Normalize the vectors to keep math stable.
        queries = self.q_norm(queries)
        keys_new = self.k_norm(keys) # 'keys_new' means "Key for THIS token only"

        # 3. RESHAPE & TRANSPOSE
        # We need to separate the heads to do parallel attention.
        # .view() cuts the big vector into chunks (heads).
        # .transpose(1, 2) moves 'heads' to dim 1, so 'seq' is dim 2.
        # Result: (Batch, Heads, Seq, Head_Dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # 4. KV CACHE MANAGEMENT (The Memory Heavy Lifting)
        prev_len = 0
        if cache is not None:
            prev_k, prev_v = cache # Retrieve history from GPU memory
            if prev_k is not None:
                prev_len = prev_k.size(2) # How many tokens have we seen before?
                
                # --- CRITICAL MEMORY OPERATION ---
                # torch.cat creates a NEW block of memory combining old + new.
                # This is "Contiguous Memory Allocation".
                keys_cat_raw = torch.cat([prev_k, keys_new], dim=2)
                values_cat_raw = torch.cat([prev_v, values_new], dim=2)
            else:
                # Cache entry present but empty: nothing to prepend yet
                keys_cat_raw = keys_new
                values_cat_raw = values_new
        else:
            # No cached history for this layer (training, or the prefill step
            # when this layer's cache slot is still empty)
            keys_cat_raw = keys_new
            values_cat_raw = values_new

        # 5. ROPE (Rotary Positional Embeddings)
        # We rotate the vectors to encode "Order".
        # Queries start at the CURRENT position (start_pos).
        # Keys cover the cached tokens plus the new ones, so they start at the
        # position of the oldest cached token (start_pos - prev_len).
        queries = apply_rope(queries, cos, sin, offset=start_pos)
        keys = apply_rope(keys_cat_raw, cos, sin, offset=start_pos - prev_len)

        # 6. GQA EXPANSION
        # We have 32 Query heads but only 8 Key heads. Math requires 32 vs 32.
        # We replicate the 8 Key heads 4 times each to match the Queries.
        if self.group_size > 1:
            # Before: K1, K2
            # After:  K1, K1, K1, K1, K2, K2, K2, K2
            keys = keys.repeat_interleave(self.group_size, dim=1)
            values = values_cat_raw.repeat_interleave(self.group_size, dim=1)
        else:
            values = values_cat_raw

        # 7. SCALING
        # Scale down queries by sqrt(dim) to prevent Softmax from exploding.
        scale = self.head_dim ** -0.5
        queries = queries * scale

        # 8. UPDATE CACHE
        # This implementation caches the UNROTATED keys/values and re-applies RoPE
        # to the full key sequence on every call (step 5). Since a key's rotation
        # depends only on its own absolute position, caching rotated keys would
        # also work; this design simply recomputes the rotation each step.
        # keys_cat_raw/values_cat_raw already contain (cached + new) tensors, so we
        # reuse them instead of running torch.cat a second time.
        next_cache = (keys_cat_raw, values_cat_raw)

        # 9. ATTENTION MECHANISM
        # The dot product: "How much does Query match Key?"
        # Shapes: (B, 32, Seq, Dim) @ (B, 32, Dim, Seq) -> (B, 32, Seq, Seq)
        attn_scores = queries @ keys.transpose(2, 3)
        
        # MASKING
        # Set future positions (or far past positions) to -Infinity.
        # Softmax(-Inf) = 0, so we ignore them.
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        # PROBABILITIES
        attn_weights = torch.softmax(attn_scores, dim=-1)
        
        # WEIGHTED SUM
        # Combine Values based on attention weights.
        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        
        # 10. OUTPUT
        out = self.out_proj(context)

        return out, next_cache

Below is an atomic-level breakdown: we freeze time inside the GPU and look at the exact state of memory, the index counters, and the tensors for **each separate block** as the model generates text.

### 1\. The Global Memory Map

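First, you must visualize the **`cache`** object itself. It is not a single tensor. Conceptually, it is a **List of Tuples**, one `(keys, values)` tuple per layer: the Python list lives in CPU memory, while the tensors it references live in GPU memory (VRAM).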

If you have a model with **2 Layers** (Layer 0 and Layer 1), the `cache` object looks like this in Python memory:

```python
# Conceptual layout (not runnable): every tensor has shape
# (Batch, KV_Heads, Seq_Len, Head_Dim)
cache = [
    # Layer 0 Cache (Tuple of K, V)
    (
        Tensor[Batch, KV_Heads, Seq_Len, Head_Dim],  # Key Cache Layer 0
        Tensor[Batch, KV_Heads, Seq_Len, Head_Dim]   # Value Cache Layer 0
    ),

    # Layer 1 Cache (Tuple of K, V)
    (
        Tensor[Batch, KV_Heads, Seq_Len, Head_Dim],  # Key Cache Layer 1
        Tensor[Batch, KV_Heads, Seq_Len, Head_Dim]   # Value Cache Layer 1
    )
]
```

  * **Crucial Insight:** Each layer has its *own independent memory*. Layer 1 cannot see Layer 0's cache.
  * **`current_pos`** is a simple integer counter stored on the CPU that tells us "How many tokens have we processed in total?".

-----

### 2\. The Trace: "Hello World" -\> "\!"

Let's trace the generation of the token "\!" after the prompt "Hello World".

  * **Setup:**
      * **Prompt:** "Hello" (ID 10), "World" (ID 20)
      * **Sliding Window:** 3 (We only keep the last 3 tokens).
      * **Heads:** 1 (Simple visualization).
      * **Dim:** 4.

#### **State 0: The Prefill (Context Phase)**

We process "Hello World" all at once.

  * **Input:** `[10, 20]` (Length 2)
  * **`current_pos`:** Starts at 0. Ends at 2.

**GPU Memory After Prefill:**

```text
LAYER 0 CACHE:
Keys:   [[Hello_K0, World_K0]]  (Shape 1x1x2x4)
Values: [[Hello_V0, World_V0]]

LAYER 1 CACHE:
Keys:   [[Hello_K1, World_K1]]
Values: [[Hello_V1, World_V1]]
```

-----

#### **State 1: Generation Step (The Loop begins)**

We are now generating the **3rd token**. The input is the *result* of the prefill (let's say we predicted token ID 99: "\!").

**A. The Setup**

  * **Input:** `[99]` (Shape `1x1`)
  * **`current_pos`:** 2 (We have 2 tokens in history).
  * **Pos Start:** 2.
  * **Pos End:** 3.

**B. Entering Layer 0**

1.  **Retrieval:** `blk_cache = cache[0]` gets Layer 0's data.
2.  **Projection:** We compute K and V for "\!" (let's call them `!_K0` and `!_V0`).
      * `!_K0` is **UNROTATED**.
3.  **Concatenation (`torch.cat`):**
      * Take `[Hello_K0, World_K0]`
      * Append `[!_K0]`
      * **New Layer 0 Key:** `[Hello_K0, World_K0, !_K0]` (Length 3).
4.  **RoPE Calculation:**
      * We apply RoPE to the *entire* key sequence of 3 (the cache stores unrotated keys).
      * "Hello" is rotated for position 0.
      * "World" is rotated for position 1.
      * "!" is rotated for position 2 (the query uses `start_pos = 2`; the keys start at `start_pos - prev_len = 0`).
5.  **Attention:**
      * Query "!" looks at Keys `[Hello, World, !]`.
6.  **Cache Update:**
      * We return the **UNROTATED** tuple of concatenated keys and values (cached history plus `keys_new`/`values_new`), now of length 3.
      * `cache.update(0, new_blk_cache)` replaces the old Layer 0 entry.

**C. Entering Layer 1**

  * The input `x` is now the *output* of Layer 0.
  * We repeat the exact same process but using `cache[1]`.
  * Layer 1's cache grows from 2 to 3 tokens independently.

-----

### 3\. The Sliding Window Visualized (The "Chop")

Now, let's see what happens when we generate a **4th token** ("Game"), exceeding our window of 3.

  * **Current Cache (Length 3):** `[Hello, World, !]`
  * **New Input:** "Game"
  * **Window Size:** 3

#### **Step 1: Append (Briefly Length 4)**

Inside the block, `torch.cat` happens first.

```text
Temp Cache: [Hello, World, !, Game]
Indices:      0      1    2    3
```

#### **Step 2: Attention & Masking**

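The model computes attention. With a window of 3, the query "Game" (position 3) may attend to positions 1, 2, and 3; "Hello" (position 0) is $3 - 0 = 3$ positions back and is masked. Whether a distance equal to the window size is masked depends on the mask convention; here we assume the window includes the current token, so keys at distance $\ge 3$ are masked.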
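
A sliding-window causal mask under the convention assumed above can be sketched as follows. `sliding_causal_mask` is a hypothetical helper written for this walkthrough; the notebook's own `mask_local` construction is not shown in this section and may differ in details:

```python
import torch


def sliding_causal_mask(num_queries, num_keys, window, q_offset):
    """True = masked. Query i sits at absolute position q_offset + i; keys sit at
    positions 0..num_keys-1. A key is visible if it is not in the future and is
    fewer than `window` positions behind the query (the window includes the query)."""
    q_pos = torch.arange(q_offset, q_offset + num_queries).unsqueeze(1)  # (Q, 1)
    k_pos = torch.arange(num_keys).unsqueeze(0)                          # (1, K)
    future = k_pos > q_pos
    too_old = (q_pos - k_pos) >= window
    return future | too_old


# "Game" is the single query at position 3; keys: Hello(0), World(1), !(2), Game(3)
mask = sliding_causal_mask(num_queries=1, num_keys=4, window=3, q_offset=3)
print(mask)  # tensor([[ True, False, False, False]]) -> only "Hello" is hidden
```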

#### **Step 3: The Truncation Code**

In `TransformerBlock.forward`, this truncation runs right after the attention module returns `next_cache`:

```python
if k.size(2) > self.sliding_window:
    k = k[:, :, -self.sliding_window:, :] # Take last 3
```

**Visualizing the Tensor Slice:**

```text
Before Slice: [ Hello, World, !, Game ]  (Shape 4)
                  |      |    |    |
Slice [ -3: ]:    X    [ World, !, Game ]  (Shape 3)
```
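
The truncation slice can be checked on a toy tensor in which each token's key is filled with its index. The storage check shows that the slice is a view rather than a copy:

```python
import torch

sliding_window = 3

# Keys for [Hello, World, !, Game]; token t's key is filled with the value t
k = torch.arange(4, dtype=torch.float32).view(1, 1, 4, 1).expand(1, 1, 4, 2).clone()

k_trunc = k[:, :, -sliding_window:, :]  # keep the last 3 tokens
print(k_trunc[0, 0, :, 0])  # tensor([1., 2., 3.]) -> World, !, Game

# Basic slicing returns a view that still shares storage with the 4-token tensor
print(k_trunc.untyped_storage().data_ptr() == k.untyped_storage().data_ptr())  # True

# On the next step, torch.cat copies the 3 kept tokens plus the new one into fresh memory
k_next = torch.cat([k_trunc, torch.full((1, 1, 1, 2), 4.0)], dim=2)
print(k_next.untyped_storage().data_ptr() == k.untyped_storage().data_ptr())  # False
```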

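**Memory Result:**
"Hello" is dropped from the cached tuple. Note that `k[:, :, -3:, :]` is a view that shares storage with the 4-token tensor, so the memory holding "Hello" is not released immediately; it is freed on the next step, when `torch.cat` copies the 3 remaining tokens into a new tensor and the old one is no longer referenced.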

-----

### 4\. `current_pos` Flow Chart

This integer drives the RoPE rotation. Here is how it transforms.

| Step | Action | Input Shape | `current_pos` (Start) | `current_pos` (End) | RoPE Positions of New Tokens |
| :--- | :--- | :--- | :--- | :--- | :--- |
| **Prefill** | Process "Hello World" | `[1, 2]` | 0 | 2 | 0, 1 |
| **Gen 1** | Process "\!" | `[1, 1]` | 2 | 3 | 2 |
| **Gen 2** | Process "Game" | `[1, 1]` | 3 | 4 | 3 |

**Crucial Logic in Code:**

```python
# In Olmo3Model.forward
if cache is not None:
    pos_start = self.current_pos      # e.g., 2
    pos_end = pos_start + seq_len     # e.g., 2 + 1 = 3
    self.current_pos = pos_end        # Update global tracker to 3
    
    # ... later, inside GroupedQueryAttention (pos_start arrives as start_pos) ...
    # apply_rope uses start_pos to rotate the new token for position 2
    queries = apply_rope(queries, cos, sin, offset=start_pos)
```

### 5\. Summary Diagram: The Cycle of a Token

1.  **Born:** Chosen from the `logits` produced at the end of `forward` (e.g., by argmax or sampling).
2.  **Re-entry:** Fed back as `input_ids` in the next loop.
3.  **Projected:** Turned into `K` and `V` vectors in Layer 0.
4.  **Stored:** `torch.cat` appends it to Layer 0 Cache (HBM).
5.  **Rotated:** RoPE rotates it based on `current_pos`.
6.  **Attended:** Its query attends to all visible cached keys, including its own key.
7.  **Propagated:** The info flows to Layer 1, where it is projected/cached/rotated again.
8.  **Aged:** In future loops, it sits in the cache unrotated and is re-rotated on every step. Because RoPE uses absolute positions, it always receives the same rotation (the one for its own position).
9.  **Death:** On sliding-attention layers, once it falls out of `sliding_window`, the tensor slice excludes it, and its memory is released once the old tensor is no longer referenced. On full-attention layers it stays in the cache.

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg, attn_type):
        super().__init__()
        self.attn_type = attn_type
        self.sliding_window = cfg["sliding_window"]
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_heads"],
            head_dim=cfg["head_dim"],
            attention_bias=cfg["attention_bias"],
            dtype=cfg["dtype"],
            sliding_window=cfg["sliding_window"],
            attn_type=attn_type,
        )
        self.ff = FeedForward(cfg)
        self.post_attention_layernorm = RMSNorm(cfg["emb_dim"], eps=cfg["rms_norm_eps"])
        self.post_feedforward_layernorm = RMSNorm(cfg["emb_dim"], eps=cfg["rms_norm_eps"])

    def forward(self, x, mask_global, mask_local, cos, sin, start_pos=0, cache=None):
        shortcut = x  # Save input for residual connection
        
        # --- SELECT MASK ---
        # If this layer is "sliding", use the local mask (band).
        # Otherwise, use the global mask (triangle).
        if self.attn_type == "sliding_attention":
            # Slice the mask's key columns to line up with the cached keys (already truncated to the window) plus the new keys
            if cache is not None and isinstance(cache, tuple):
                prev_k, _ = cache
                prev_len = prev_k.size(2) if prev_k is not None else 0
            else:
                prev_len = 0
            eff_kv_len = prev_len + x.size(1)
            attn_mask = mask_local[..., -eff_kv_len:]
        else:
            attn_mask = mask_global

        # 1. ATTENTION BLOCK
        x_attn, next_cache = self.att(x, attn_mask, cos, sin, start_pos=start_pos, cache=cache)
        
        # 2. SLIDING WINDOW CACHE TRUNCATION
        # If the cache gets too big (older than window), chop it off.
        if next_cache is not None and self.attn_type == "sliding_attention":
            k, v = next_cache
            if k.size(2) > self.sliding_window:
                k = k[:, :, -self.sliding_window:, :] # Slice the tensor
                v = v[:, :, -self.sliding_window:, :]
            next_cache = (k, v)

        # 3. NORM & RESIDUAL
        # Norm is applied to the ATTENTION OUTPUT, not the input.
        # Formula: x = x + Norm(Attention(x))
        x_attn = self.post_attention_layernorm(x_attn)
        x = shortcut + x_attn # Add Skip Connection

        # 4. FEEDFORWARD BLOCK
        shortcut = x # Save new state for next residual
        x_ffn = self.ff(x)
        x_ffn = self.post_feedforward_layernorm(x_ffn) # Branch Norm
        x = shortcut + x_ffn # Add Skip Connection
        
        return x, next_cache
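The `eff_kv_len` slicing in the sliding-attention branch only works because the cache was truncated in the previous step. A minimal sketch with toy values (assumed: window 4, generating the token at index 6) rebuilds the local mask as `create_masks` does and checks that the sliced columns line up with the key positions actually present.

```python
import torch

window = 4
pos_start, pos_end = 6, 7           # generating the token at index 6 (toy example)

# Local mask built the same way as in Olmo3Model.create_masks
ones = torch.ones((pos_end, pos_end), dtype=torch.bool)
mask_local_full = torch.triu(ones, diagonal=1) | torch.triu(ones, diagonal=window).T
mask_local = mask_local_full[pos_start:pos_end, :pos_end][None, None, :, :]   # [1, 1, 1, 7]

# The sliding-layer cache holds the last `window` keys (positions 2..5), plus 1 new key
prev_len = window
eff_kv_len = prev_len + 1
attn_mask = mask_local[..., -eff_kv_len:]                                      # [1, 1, 1, 5]
key_positions = list(range(pos_end - eff_kv_len, pos_end))

print(attn_mask.shape, key_positions, attn_mask[0, 0, 0].int().tolist())
```

Only the oldest cached key (position 2) is still masked, since $6 - 2 = 4 \ge$ `window`; the remaining four keys form the visible window.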

This is the **Cartographer's Guide** to the Olmo 3 Masking system.

In Transformer mechanics, the **Mask** is the "Rulebook of Visibility." It dictates exactly which tokens are allowed to "see" (attend to) which other tokens.

We will break down `create_masks` into its atomic geometric operations.

-----

### 1\. First Principles: The Geometry of Attention

Think of Attention as a grid where:

  * **Rows ($i$):** The Current Token (The one "looking").
  * **Columns ($j$):** The Context Tokens (The ones being "looked at").
  * **Value:** `0` = Visible (Clear Line of Sight). `1` = Masked (Wall).

We are building two specific masks here:

1.  **`mask_global`**: A standard "Causal" mask. You can see everything in the past. You cannot see the future.
2.  **`mask_local`**: A "Sliding Window" mask. You can only see the recent past: the current token plus the previous `sliding_window - 1` tokens (e.g., 4 tokens in total for a window of 4). Everything else is darkness.

-----

### 2\. The Code Execution: Step-by-Step Visualization

Let's simulate the environment variables:

  * **`total_len`** = 8
  * **`sliding_window`** = 4

#### **Step A: The Blank Canvas**

```python
ones = torch.ones((total_len, total_len), dtype=torch.bool, device=device)
```

We start with a square of `True` values (shown as 1s).
*(`triu` then keeps only some of these ones and zeroes out the rest. In every matrix below, 1 (`True`) means "Masked/Blocked" and 0 (`False`) means "Visible".)*

#### **Step B: The Global Mask (No Future)**

```python
mask_global_full = torch.triu(ones, diagonal=1)
```

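**`torch.triu`** (upper triangular) keeps the elements *on and above* the $k$-th diagonal and zeroes out the rest.
**`diagonal=1`** moves that boundary one step above the main diagonal, so the main diagonal itself is zeroed (each token may see itself).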

**The Logic:** "If Column ($j$) is greater than Row ($i$), Block it."

**Visual Matrix ($8 \times 8$):**

```text
      j (Keys) ->
      0 1 2 3 4 5 6 7
i 0 | 0 1 1 1 1 1 1 1 |  <- Token 0 can only see 0. Future (1-7) is blocked.
  1 | 0 0 1 1 1 1 1 1 |
  2 | 0 0 0 1 1 1 1 1 |
  3 | 0 0 0 0 1 1 1 1 |
  4 | 0 0 0 0 0 1 1 1 |
  5 | 0 0 0 0 0 0 1 1 |
  6 | 0 0 0 0 0 0 0 1 |
  7 | 0 0 0 0 0 0 0 0 |  <- Token 7 sees everything.
```

*This is the standard mask used in almost every GPT model.*
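A quick sketch to check Step B with the same toy size (`total_len = 8`). It confirms that `triu(ones, diagonal=1)` is the complement of the lower triangle including the diagonal, and that row $i$ sees exactly $i + 1$ keys.

```python
import torch

total_len = 8
ones = torch.ones((total_len, total_len), dtype=torch.bool)
mask_global_full = torch.triu(ones, diagonal=1)     # True = masked (future)

# Equivalent formulation: everything outside the lower triangle (incl. diagonal) is masked
assert torch.equal(mask_global_full, ~torch.tril(ones))

# Number of visible keys for each query row
visible_per_row = (~mask_global_full).sum(dim=1)
print(visible_per_row.tolist())                     # [1, 2, 3, 4, 5, 6, 7, 8]
```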

#### **Step C: The Far Past Mask (The Sliding Window)**

This is the trickiest line in the code.

```python
# 1. Create Upper Triangle shifted by 'window' (4)
# 2. Transpose (.T) it to make it Lower Left
far_past_full = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T
```

**Atomic Breakdown:**

1.  **`triu(..., diagonal=4)`**: Keeps ones where $j \ge i + 4$. (Top right corner).
2.  **`.T` (Transpose)**: Flips $i$ and $j$. Now it keeps ones where $i \ge j + 4$. (Bottom left corner).

**The Logic:** "If Row ($i$) is 4 or more steps ahead of Column ($j$), Block it."

**Visual Matrix:**

```text
      j (Keys) ->
      0 1 2 3 4 5 6 7
i 0 | 0 0 0 0 0 0 0 0 |
  1 | 0 0 0 0 0 0 0 0 |
  2 | 0 0 0 0 0 0 0 0 |
  3 | 0 0 0 0 0 0 0 0 |
  4 | 1 0 0 0 0 0 0 0 |  <- Token 4 is 4 steps from 0 (i - j >= 4). BLOCK 0.
  5 | 1 1 0 0 0 0 0 0 |  <- Token 5 blocks 0 and 1.
  6 | 1 1 1 0 0 0 0 0 |
  7 | 1 1 1 1 0 0 0 0 |
```
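For a square matrix, transposing an upper triangle that starts at diagonal $k$ gives the lower triangle that ends at diagonal $-k$. The sketch below (same toy sizes) checks that `torch.tril(ones, diagonal=-sliding_window)` produces the same far-past corner as the `triu(...).T` line.

```python
import torch

total_len, sliding_window = 8, 4
ones = torch.ones((total_len, total_len), dtype=torch.bool)

far_past_full = torch.triu(ones, diagonal=sliding_window).T
# The same corner written directly as a lower triangle: keep j <= i - sliding_window
far_past_tril = torch.tril(ones, diagonal=-sliding_window)

print(torch.equal(far_past_full, far_past_tril))
print(far_past_full.int()[4:].tolist())   # rows 4..7
```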

#### **Step D: The Local Mask (Combination)**

```python
mask_local_full = mask_global_full | far_past_full
```

**The Logic:** "Block if it is the Future **OR** if it is the Far Past."
This creates the **Band Diagonal** matrix.

**Visual Matrix:**

```text
      j (Keys) ->
      0 1 2 3 4 5 6 7
i 0 | 0 1 1 1 1 1 1 1 |
  1 | 0 0 1 1 1 1 1 1 |
  2 | 0 0 0 1 1 1 1 1 |
  3 | 0 0 0 0 1 1 1 1 |
  4 | 1 0 0 0 0 1 1 1 |  <- Visible window: [1, 2, 3, 4]
  5 | 1 1 0 0 0 0 1 1 |  <- Visible window: [2, 3, 4, 5]
  6 | 1 1 1 0 0 0 0 1 |
  7 | 1 1 1 1 0 0 0 0 |
```
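The band structure can be checked row by row. In this sketch (toy sizes as above), each query $i$ should see exactly the positions from $\max(0, i - W + 1)$ to $i$, where $W$ is `sliding_window`.

```python
import torch

total_len, sliding_window = 8, 4
ones = torch.ones((total_len, total_len), dtype=torch.bool)
mask_global_full = torch.triu(ones, diagonal=1)
far_past_full = torch.triu(ones, diagonal=sliding_window).T
mask_local_full = mask_global_full | far_past_full

for i in range(total_len):
    visible = (~mask_local_full[i]).nonzero().flatten().tolist()
    # Each query sees itself plus at most (sliding_window - 1) earlier tokens
    assert visible == list(range(max(0, i - sliding_window + 1), i + 1))
    print(i, visible)
```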

-----

### 3\. The "Camera View": Slicing for Inference

The code calculates the *full* matrix (inefficiently, but correctly), but then it uses **slices** to pick only the rows relevant to the *current* tokens being processed.

```python
row_slice = slice(pos_start, pos_end)
mask_local = mask_local_full[row_slice, :pos_end][None, None, :, :]
```

#### **Scenario 1: The "Prefill" (First Prompt)**

**User:** "Hello World"

  * `pos_start` = 0
  * `pos_end` = 2
  * `slice` = `0:2` (Rows 0 and 1)

**Resulting Mask Tensor:**

```text
Row 0: [0, 1] (Hello sees Hello, Block World)
Row 1: [0, 0] (World sees Hello, World)
```

#### **Scenario 2: The "Generation" (Token 5)**

We are generating the 6th token (Index 5). We have cache for indices 0-4.

  * `pos_start` = 5
  * `pos_end` = 6
  * `slice` = `5:6` (Only Row 5\!)

**The Logic:**
The code generates the full 6x6 matrix (indices 0 to 5).
Then it extracts **Row 5** and columns **0 to 5**.

**The View for Token 5:**
Looking at `mask_local_full` Row 5 above:
`[1, 1, 0, 0, 0, 0]`

  * Index 0: `1` (Masked - Too old\!)
  * Index 1: `1` (Masked - Too old\!)
  * Index 2: `0` (Visible)
  * Index 3: `0` (Visible)
  * Index 4: `0` (Visible)
  * Index 5: `0` (Visible - Self)

**Interpretation:** Token 5 can only attend to indices `{2, 3, 4, 5}`. It effectively "forgets" `{0, 1}`.
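Both scenarios can be reproduced with a standalone copy of the method body. The function name `create_masks_standalone` and its `sliding_window` argument are hypothetical: they replace the method and its `self.cfg["sliding_window"]` lookup so the sketch runs without a model.

```python
import torch


def create_masks_standalone(sliding_window, pos_start, pos_end, device="cpu"):
    # Hypothetical free-function copy of Olmo3Model.create_masks
    ones = torch.ones((pos_end, pos_end), dtype=torch.bool, device=device)
    mask_global_full = torch.triu(ones, diagonal=1)
    far_past_full = torch.triu(ones, diagonal=sliding_window).T
    mask_local_full = mask_global_full | far_past_full
    row_slice = slice(pos_start, pos_end)
    mask_global = mask_global_full[row_slice, :pos_end][None, None, :, :]
    mask_local = mask_local_full[row_slice, :pos_end][None, None, :, :]
    return mask_global, mask_local


# Scenario 1: prefill of "Hello World" (positions 0 and 1)
g, l = create_masks_standalone(sliding_window=4, pos_start=0, pos_end=2)
print(l[0, 0].int().tolist())       # [[0, 1], [0, 0]]

# Scenario 2: processing the token at index 5
g, l = create_masks_standalone(sliding_window=4, pos_start=5, pos_end=6)
print(l.shape, l[0, 0, 0].int().tolist())
```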

-----

### 4\. Tensor Trace: The `create_masks` Call

Let's trace the exact shapes during the generation of that Token 5.

1.  **Inputs:** `cur_len=1` (Input is 1 token), `pos_start=5`.
2.  **Calculated:** `pos_end = 5 + 1 = 6`.
3.  **`total_len`** = 6.
4.  **`ones`**: Shape `[6, 6]`.
5.  **`mask_global_full`**: Shape `[6, 6]`.
6.  **`far_past_full`**: Shape `[6, 6]`.
7.  **`mask_local_full`**: Shape `[6, 6]`.
8.  **Slicing:** `mask_local_full[5:6, :6]`.
      * Takes Row 5.
      * Takes Cols 0-5 (`:6` stops before index 6).
      * Result Shape: `[1, 6]`.
9.  **Broadcasting:** `[None, None, :, :]`.
      * Adds Batch and Head dimensions.
      * Final Shape: `[1, 1, 1, 6]`.

**Application in `GroupedQueryAttention`:**

```python
attn_scores = queries @ keys.transpose(2, 3)
# attn_scores shape: [1, 32, 1, 6]
# (Batch 1, 32 Heads in the 7B config, 1 Query Token, 6 Key Tokens)

attn_scores = attn_scores.masked_fill(mask, -torch.inf)
```

The mask `[1, 1, 1, 6]` broadcasts over the batch and the 32 heads. It puts `-inf` on indices 0 and 1, so after the softmax those keys receive exactly zero attention weight.
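A small sketch of this broadcast with random toy tensors (assumed: 32 heads and `head_dim = 128` as in the 7B config, 6 keys). The scaling by $\sqrt{d_{\text{head}}}$ is the standard scaled dot-product convention and is assumed here. The single mask row applies to every head, and the masked keys get zero weight.

```python
import torch

torch.manual_seed(123)
n_heads, head_dim, n_keys = 32, 128, 6          # 7B head count; random toy values
queries = torch.randn(1, n_heads, 1, head_dim)
keys = torch.randn(1, n_heads, n_keys, head_dim)

# Local-mask row for the token at index 5 (window 4), shaped [1, 1, 1, 6]
mask = torch.tensor([True, True, False, False, False, False])[None, None, None, :]

attn_scores = queries @ keys.transpose(2, 3)                 # [1, 32, 1, 6]
attn_scores = attn_scores.masked_fill(mask, -torch.inf)      # mask broadcasts over heads
attn_weights = torch.softmax(attn_scores / head_dim**0.5, dim=-1)

print(attn_scores.shape)
print(attn_weights[0, 0, 0])    # first two weights are exactly 0
```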

-----

### 5\. Why `diagonal=window` and `.T`? (Deep Dive)

Why does `triu(ones, diagonal=4).T` create the "Far Past" mask?

**Let's analyze the math of `triu`:**

  * `triu` preserves values where $col \ge row + k$.
  * If $k=4$: $j \ge i + 4$.
  * This means "Column index is at least 4 greater than Row index". This is **Far Future**.

**Let's analyze the math of Transpose (`.T`):**

  * Transpose swaps $i$ and $j$.
  * So the condition becomes: $i \ge j + 4$.
  * Rearranged: $i - j \ge 4$.
  * "Row index is at least 4 steps ahead of Column index".
  * This is **Far Past**.

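**Why use this trick?**
It is extremely concise. Instead of filling the mask entry by entry with a Python double loop (which is slow):

```python
# Slow: a Python double loop over every (i, j) pair (mask is a pre-allocated [n, n] bool tensor)
for i in range(n):
    for j in range(n):
        mask[i, j] = (i - j) >= sliding_window
```

the code uses built-in tensor primitives: `torch.triu` runs as a single vectorized operation on the tensor's device, and `.T` on a 2D tensor just returns a transposed view without copying data.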
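For comparison, the same condition can also be evaluated in one vectorized step by broadcasting row and column indices. This is an alternative formulation, not what `create_masks` uses; the sketch checks that it matches the `triu` versions exactly.

```python
import torch

total_len, sliding_window = 8, 4
i = torch.arange(total_len).unsqueeze(1)    # query (row) indices, shape [8, 1]
j = torch.arange(total_len).unsqueeze(0)    # key (column) indices, shape [1, 8]

# Broadcasting compares every (i, j) pair in a single vectorized op
mask_global_bc = j > i
far_past_bc = (i - j) >= sliding_window
mask_local_bc = mask_global_bc | far_past_bc

ones = torch.ones((total_len, total_len), dtype=torch.bool)
print(torch.equal(mask_global_bc, torch.triu(ones, diagonal=1)))
print(torch.equal(far_past_bc, torch.triu(ones, diagonal=sliding_window).T))
```

Both formulations are vectorized; the `triu`-based one simply avoids building the index tensors.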

### 6\. Summary of the "Mask View"

| Mask Type | Masked Region (1s) | Meaning | Visible Region (0s) |
| :--- | :--- | :--- | :--- |
| **Global** | Upper triangle above the diagonal ($j > i$) | "Don't cheat (No Future)" | ◣ Lower-left triangle, diagonal included |
| **Far Past** | Bottom-left corner ($i - j \ge W$, with $W$ = `sliding_window`) | "Don't remember (No Old History)" | Everything except that corner |
| **Local** | Both regions combined | "Only recent history" | Diagonal band of width $W$ |

**In terms of 0s and 1s:**

  * Global Mask (`triu` with diagonal 1) = **Zeros** on the bottom left (including the diagonal). **Ones** (Masked) on the top right.
  * Far Past (`triu` with diagonal $W$, then `.T`) = **Ones** (Masked) in the bottom-left corner. **Zeros** everywhere else.
  * Combined = **Ones** in both corners. **Zeros** only in the diagonal band.

In [9]:
class Olmo3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]

        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
        self.blocks = nn.ModuleList([TransformerBlock(cfg, attn_type) for attn_type in cfg["layer_types"]])
        self.final_norm = RMSNorm(cfg["emb_dim"], eps=cfg["rms_norm_eps"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
        self.cfg = cfg
        self.current_pos = 0

        cos, sin = compute_rope_params(
            head_dim=cfg["head_dim"],
            context_length=cfg["context_length"],
            theta_base=cfg["rope_base"],
            attention_factor=cfg["rope_attention_factor"],
            rope_type=cfg["rope_type"],
            rope_factor=cfg["rope_factor"],
            rope_orig_max=cfg["rope_orig_max"],
            dtype=torch.float32,
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def create_masks(self, cur_len, device, pos_start=0, pos_end=None):
        if pos_end is None:
            pos_end = cur_len
        total_len = pos_end

        ones = torch.ones((total_len, total_len), dtype=torch.bool, device=device)
        # mask_global_full (future is masked: j > i)
        #     j:  0 1 2 3 4 5 6 7
        #  i
        #     0:  0 1 1 1 1 1 1 1
        #     1:  0 0 1 1 1 1 1 1
        #     2:  0 0 0 1 1 1 1 1
        #     3:  0 0 0 0 1 1 1 1
        #     4:  0 0 0 0 0 1 1 1
        #     5:  0 0 0 0 0 0 1 1
        #     6:  0 0 0 0 0 0 0 1
        #     7:  0 0 0 0 0 0 0 0
        mask_global_full = torch.triu(ones, diagonal=1)

        # far_past (too far back is masked: i - j >= sliding_window)
        # where sliding_window = 4
        #     j:  0 1 2 3 4 5 6 7
        #  i
        #     0:  0 0 0 0 0 0 0 0
        #     1:  0 0 0 0 0 0 0 0
        #     2:  0 0 0 0 0 0 0 0
        #     3:  0 0 0 0 0 0 0 0
        #     4:  1 0 0 0 0 0 0 0
        #     5:  1 1 0 0 0 0 0 0
        #     6:  1 1 1 0 0 0 0 0
        #     7:  1 1 1 1 0 0 0 0
        far_past_full = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T

        # Local (sliding_window) = future OR far-past
        # mask_local
        #     j:  0 1 2 3 4 5 6 7
        #  i
        #     0:  0 1 1 1 1 1 1 1
        #     1:  0 0 1 1 1 1 1 1
        #     2:  0 0 0 1 1 1 1 1
        #     3:  0 0 0 0 1 1 1 1
        #     4:  1 0 0 0 0 1 1 1
        #     5:  1 1 0 0 0 0 1 1
        #     6:  1 1 1 0 0 0 0 1
        #     7:  1 1 1 1 0 0 0 0
        mask_local_full = mask_global_full | far_past_full

        row_slice = slice(pos_start, pos_end)
        mask_global = mask_global_full[row_slice, :pos_end][None, None, :, :]
        mask_local = mask_local_full[row_slice, :pos_end][None, None, :, :]
        return mask_global, mask_local

    def forward(self, input_ids, cache=None):
        b, seq_len = input_ids.shape
        x = self.tok_emb(input_ids)

        if cache is not None:
            pos_start = self.current_pos
            pos_end = pos_start + seq_len
            self.current_pos = pos_end
            mask_global, mask_local = self.create_masks(
                cur_len=seq_len, device=x.device, pos_start=pos_start, pos_end=pos_end
            )
        else:
            pos_start = 0
            mask_global, mask_local = self.create_masks(
                cur_len=seq_len, device=x.device, pos_start=0, pos_end=seq_len
            )

        cos = self.cos
        sin = self.sin

        for i, block in enumerate(self.blocks):
            blk_cache = cache.get(i) if cache is not None else None
            x, new_blk_cache = block(
                x,
                mask_global=mask_global,
                mask_local=mask_local,
                cos=cos,
                sin=sin,
                start_pos=pos_start,
                cache=blk_cache,
            )

            if cache is not None:
                cache.update(i, new_blk_cache)

        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits

    def reset_kv_cache(self):
        self.current_pos = 0

In [10]:
class KVCache:
    def __init__(self, n_layers):
        self.cache = [None] * n_layers

    def get(self, layer_idx):
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        self.cache[layer_idx] = value

    def get_all(self):
        return self.cache

    def reset(self):
        for i in range(len(self.cache)):
            self.cache[i] = None

&nbsp;
# 2. Initialize model

In [11]:
OLMO3_CONFIG_7B = {
    "vocab_size": 100_278,
    "context_length": 65_536,
    "emb_dim": 4_096,
    "n_heads": 32,
    "n_layers": 32,
    "hidden_dim": 11_008,
    "head_dim": 128,
    "n_kv_heads": 32,
    "attention_bias": False,
    "attention_dropout": 0.0,
    "sliding_window": 4_096,
    "layer_types": [
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
    ],
    "rope_base": 500_000.0,
    "rope_attention_factor": 1.2079441541679836,
    "rope_type": "yarn",
    "rope_factor": 8.0,
    "rope_orig_max": 8_192,
    "rms_norm_eps": 1e-6,
    "dtype": torch.bfloat16,
    "eos_token_id": 100_257,
    "pad_token_id": 100_277,
}

OLMO3_CONFIG_32B = {
    "vocab_size": 100_278,
    "context_length": 65_536,
    "emb_dim": 5_120,
    "n_heads": 40,
    "n_layers": 64,
    "hidden_dim": 27_648,
    "head_dim": 128,
    "n_kv_heads": 8,
    "attention_bias": False,
    "attention_dropout": 0.0,
    "sliding_window": 4_096,
    "layer_types": [
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
        "sliding_attention",
        "sliding_attention",
        "sliding_attention",
        "full_attention",
    ],
    "rope_base": 500_000.0,
    "rope_attention_factor": 1.2079441541679836,
    "rope_type": "yarn",
    "rope_factor": 8.0,
    "rope_orig_max": 8_192,
    "rms_norm_eps": 1e-6,
    "dtype": torch.bfloat16,
    "eos_token_id": 100_257,
    "pad_token_id": 100_277,
}

OLMO3_CONFIG = OLMO3_CONFIG_32B if "32B" in USE_MODEL else OLMO3_CONFIG_7B
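Several config values follow simple rules, sketched below under stated assumptions. `olmo3_layer_types` and `kv_cache_bytes` are hypothetical helpers, not part of the notebook. The first reproduces the "3 sliding, 1 full" pattern of the 7B `layer_types` list. The YaRN attention factor is computed with the commonly used formula $0.1 \ln(\text{factor}) + 1$, which appears to match `rope_attention_factor`. The KV-cache estimate assumes bfloat16 (2 bytes per value) and that sliding layers keep at most `sliding_window` tokens, as the truncation in `TransformerBlock` does (ignoring the brief extra token held during each step).

```python
import math


def olmo3_layer_types(n_layers):
    # Hypothetical helper: three sliding-window layers followed by one full-attention layer
    return (["sliding_attention"] * 3 + ["full_attention"]) * (n_layers // 4)


# YaRN attention scaling as commonly defined: 0.1 * ln(factor) + 1
rope_factor = 8.0
yarn_attention_factor = 0.1 * math.log(rope_factor) + 1.0


def kv_cache_bytes(n_tokens, layer_types, n_kv_heads, head_dim, sliding_window, bytes_per_elem=2):
    # Hypothetical estimate: K and V each store n_kv_heads * head_dim values per cached token
    total = 0
    for layer_type in layer_types:
        cached = min(n_tokens, sliding_window) if layer_type == "sliding_attention" else n_tokens
        total += 2 * n_kv_heads * head_dim * cached * bytes_per_elem
    return total


layer_types_7b = olmo3_layer_types(32)
print(layer_types_7b.count("full_attention"))    # 8
print(yarn_attention_factor)

# 7B settings: 32 KV heads, head_dim 128, window 4096, full 65,536-token context
full_ctx = kv_cache_bytes(65_536, layer_types_7b, 32, 128, 4_096)
print(f"{full_ctx / 2**30:.2f} GiB")
```

Under these assumptions, the 24 sliding layers need only a small fraction of the cache at long context; the 8 full-attention layers dominate.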

In [12]:
torch.manual_seed(123)
model = Olmo3Model(OLMO3_CONFIG)

In [13]:
model

Olmo3Model(
  (tok_emb): Embedding(100278, 4096)
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (att): GroupedQueryAttention(
        (W_query): Linear(in_features=4096, out_features=4096, bias=False)
        (W_key): Linear(in_features=4096, out_features=4096, bias=False)
        (W_value): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_norm): RMSNorm()
        (k_norm): RMSNorm()
      )
      (ff): FeedForward(
        (fc1): Linear(in_features=4096, out_features=11008, bias=False)
        (fc2): Linear(in_features=4096, out_features=11008, bias=False)
        (fc3): Linear(in_features=11008, out_features=4096, bias=False)
      )
      (post_attention_layernorm): RMSNorm()
      (post_feedforward_layernorm): RMSNorm()
    )
  )
  (final_norm): RMSNorm()
  (out_head): Linear(in_features=4096, out_features=100278, bias=False)
)

- A quick check that the forward pass works before continuing:

In [14]:
model(torch.tensor([1, 2, 3]).unsqueeze(0))

tensor([[[ 0.3594, -0.6289, -0.2754,  ...,  1.1016,  0.4219,  0.0381],
         [ 1.1719,  0.0283,  0.6055,  ...,  0.4863, -0.1953,  0.2246],
         [ 0.4902, -0.0425,  0.6758,  ...,  0.3730, -0.5781, -0.1670]]],
       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)

In [15]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model.to(device);

    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    


&nbsp;
# 3. Load pretrained weights

In [16]:
def load_weights_into_olmo(model, param_config, params):
    def assign(left, right, tensor_name="unknown"):
        if left.shape != right.shape:
            raise ValueError(
                f"Shape mismatch in tensor '{tensor_name}'. "
                f"Left: {left.shape}, Right: {right.shape}"
            )
        
        with torch.no_grad():
            if isinstance(right, torch.Tensor):
                left.copy_(right)
            else:
                left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))
        
        return left

    # Token embedding
    if "model.embed_tokens.weight" in params:
        model.tok_emb.weight = assign(
            model.tok_emb.weight,
            params["model.embed_tokens.weight"],
            "model.embed_tokens.weight",
        )

    for l in range(param_config["n_layers"]):
        block = model.blocks[l]
        att = block.att

        # Q, K, V projections
        att.W_query.weight = assign(
            att.W_query.weight,
            params[f"model.layers.{l}.self_attn.q_proj.weight"],
            f"model.layers.{l}.self_attn.q_proj.weight",
        )
        att.W_key.weight = assign(
            att.W_key.weight,
            params[f"model.layers.{l}.self_attn.k_proj.weight"],
            f"model.layers.{l}.self_attn.k_proj.weight",
        )
        att.W_value.weight = assign(
            att.W_value.weight,
            params[f"model.layers.{l}.self_attn.v_proj.weight"],
            f"model.layers.{l}.self_attn.v_proj.weight",
        )

        # Output projection
        att.out_proj.weight = assign(
            att.out_proj.weight,
            params[f"model.layers.{l}.self_attn.o_proj.weight"],
            f"model.layers.{l}.self_attn.o_proj.weight",
        )

        # QK norms
        att.q_norm.weight = assign(
            att.q_norm.weight,
            params[f"model.layers.{l}.self_attn.q_norm.weight"],
            f"model.layers.{l}.self_attn.q_norm.weight",
        )
        att.k_norm.weight = assign(
            att.k_norm.weight,
            params[f"model.layers.{l}.self_attn.k_norm.weight"],
            f"model.layers.{l}.self_attn.k_norm.weight",
        )

        # Feedforward weights
        block.ff.fc1.weight = assign(
            block.ff.fc1.weight,
            params[f"model.layers.{l}.mlp.gate_proj.weight"],
            f"model.layers.{l}.mlp.gate_proj.weight",
        )
        block.ff.fc2.weight = assign(
            block.ff.fc2.weight,
            params[f"model.layers.{l}.mlp.up_proj.weight"],
            f"model.layers.{l}.mlp.up_proj.weight",
        )
        block.ff.fc3.weight = assign(
            block.ff.fc3.weight,
            params[f"model.layers.{l}.mlp.down_proj.weight"],
            f"model.layers.{l}.mlp.down_proj.weight",
        )

        # Post-attention and post-feedforward norms
        block.post_attention_layernorm.weight = assign(
            block.post_attention_layernorm.weight,
            params[f"model.layers.{l}.post_attention_layernorm.weight"],
            f"model.layers.{l}.post_attention_layernorm.weight",
        )
        block.post_feedforward_layernorm.weight = assign(
            block.post_feedforward_layernorm.weight,
            params[f"model.layers.{l}.post_feedforward_layernorm.weight"],
            f"model.layers.{l}.post_feedforward_layernorm.weight",
        )

    # Final normalization and output head
    if "model.norm.weight" in params:
        model.final_norm.weight = assign(
            model.final_norm.weight,
            params["model.norm.weight"],
            "model.norm.weight",
        )

    if "lm_head.weight" in params:
        model.out_head.weight = assign(
            model.out_head.weight,
            params["lm_head.weight"],
            "lm_head.weight",
        )
    else:
        model.out_head.weight = model.tok_emb.weight
        print("Model uses weight tying.")

In [17]:
import json
import os
from pathlib import Path
from safetensors.torch import load_file
from huggingface_hub import snapshot_download

repo_id = f"allenai/{USE_MODEL}"
local_dir = Path(repo_id).parts[-1]

repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)
index_path = os.path.join(repo_dir, "model.safetensors.index.json")
with open(index_path, "r") as f:
    index = json.load(f)

weights_dict = {}
for filename in sorted(set(index["weight_map"].values())):
    shard_path = os.path.join(repo_dir, filename)
    shard = load_file(shard_path)
    weights_dict.update(shard)

load_weights_into_olmo(model, OLMO3_CONFIG, weights_dict)
model.to(device)
del weights_dict


&nbsp;
# 4. Load tokenizer

In [18]:
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download


class OlmoTokenizer:
    def __init__(self, tokenizer_file_path, eos_token_id, pad_token_id):
        tok_file = Path(tokenizer_file_path)
        self._tok = Tokenizer.from_file(str(tok_file))
        eos_from_tok = (
            self._tok.token_to_id("<|endoftext|>")
            or self._tok.token_to_id("<end_of_turn>")
        )
        self.eos_token_id = eos_from_tok if eos_from_tok is not None else eos_token_id
        pad_from_tok = (
            self._tok.token_to_id("<|pad|>")
            or self._tok.token_to_id("<pad>")
        )
        self.pad_token_id = pad_from_tok if pad_from_tok is not None else pad_token_id

    def encode(self, text):
        return self._tok.encode(text).ids

    def decode(self, ids):
        return self._tok.decode(ids, skip_special_tokens=False)


def apply_chat_template(user_text):
    return (
        "<|im_start|>user\n"
        f"{user_text}\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


tokenizer_file_path = os.path.join(local_dir, "tokenizer.json")
if not os.path.exists(tokenizer_file_path):
    try:
        tokenizer_file_path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json", local_dir=local_dir)
    except Exception as e:
        print(f"Warning: failed to download tokenizer.json: {e}")
        tokenizer_file_path = "tokenizer.json"

tokenizer = OlmoTokenizer(
    tokenizer_file_path=tokenizer_file_path,
    eos_token_id=OLMO3_CONFIG["eos_token_id"],
    pad_token_id=OLMO3_CONFIG["pad_token_id"],
)

In [19]:
prompt = apply_chat_template("Give me a short intro to large language models in 3 sentences.")

input_token_ids = tokenizer.encode(prompt)
text = tokenizer.decode(input_token_ids)
text

'<|im_start|>user\nGive me a short intro to large language models in 3 sentences.\n<|im_end|>\n<|im_start|>assistant\n'

&nbsp;
# 5. Generate text

In [20]:
def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):

    model.eval()
    with torch.no_grad():
        cache = KVCache(n_layers=model.cfg["n_layers"])
        model.reset_kv_cache()

        logits = model(token_ids, cache=cache)

        for _ in range(max_new_tokens):
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)

            if (eos_token_id is not None
                    and torch.all(next_token == eos_token_id)):
                break

            yield next_token

            token_ids = torch.cat([token_ids, next_token], dim=1)

            logits = model(next_token, cache=cache)

In [21]:
input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)


if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()


for token in generate_text_basic_stream(
    model=model,
    token_ids=input_token_ids_tensor,
    max_new_tokens=500,
    eos_token_id=tokenizer.eos_token_id
):
    token_id = token.squeeze(0).tolist()
    print(
        tokenizer.decode(token_id),
        end="",
        flush=True
    )

if torch.cuda.is_available():
    def gpu_gb(x):
        return f"{x / 1024 / 1024 / 1024:.2f} GB"
    
    print(f"\n\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}")

Sure! Here’s a brief introduction to large language models:  
Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating text, learning from vast amounts of data, learning language, performing diverse tasks, assisting in many applications, and adapting various tasks.

GPU memory used: 13.71 GB


&nbsp;
# What's next?

- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)

<a href="http://mng.bz/orYv"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px"></a>