> **DEPRECATED:** This task is outdated and may not reflect current best practices.
> See `causalab/tasks/MCQA/` for an up-to-date example.

# Token Positions for Addition: Llama vs Gemma

This notebook demonstrates how two language models tokenize numbers differently, and why this matters for causal interventions.

## The Critical Question

When we want to intervene on "the tens digit of 23", which tokens do we intervene on?

The answer is **model-dependent**:
- **Llama 3.1 8B**: Tokenizes "23" as a **single token** → we intervene on the entire number
- **Gemma 2 9B**: Tokenizes "23" as **two tokens** ("2" and "3") → we can intervene on individual digits

This notebook shows these differences concretely and explains their implications for experiments.
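To make the difference concrete without loading either model, here is a toy sketch of digit grouping. It makes two assumptions: Llama 3's pre-tokenizer groups digits left to right into runs of at most three (the `\p{N}{1,3}` rule in its split pattern), and Gemma gives every digit its own token. The helpers `split_digits` and `digit_to_token` are hypothetical. They model only the grouping, not either real vocabulary.

```python
import re


def split_digits(number: str, max_chunk: int) -> list[str]:
    """Group a digit string left to right into chunks of at most max_chunk digits."""
    return re.findall(r"\d{1,%d}" % max_chunk, number)


def digit_to_token(number: str, max_chunk: int) -> list[int]:
    """For each digit (most significant first), return the index of the chunk holding it."""
    mapping = []
    for chunk_idx, chunk in enumerate(split_digits(number, max_chunk)):
        mapping.extend([chunk_idx] * len(chunk))
    return mapping


# Llama-style grouping (runs of up to 3 digits) vs. Gemma-style (one digit per token)
for number in ["23", "123", "1234"]:
    print(number, split_digits(number, 3), split_digits(number, 1))
```

Under these assumptions, both digits of "23" map to the same Llama-style token (`digit_to_token("23", 3)` gives `[0, 0]`) but to different Gemma-style tokens (`[0, 1]`). This is exactly the intervention question posed above.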

## Setup

In [1]:
# Autoreload
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("../..")

import torch
from causalab.neural.pipeline import LMPipeline
from causalab.tasks.general_addition.config import (
    create_two_number_two_digit_config,
    create_two_number_three_digit_config,
    create_general_config,
)
from causalab.tasks.general_addition.causal_models import (
    create_basic_addition_model,
    sample_valid_addition_input,
)
from causalab.tasks.general_addition.token_positions import (
    get_digit_token_position,
    create_token_positions,
)
from causalab.tasks.general_addition.experiments.tokenization_config import (
    get_tokens_per_number,
    TOKENIZATION_CONFIG,
)



## Part 1: Loading the Models

We'll load pipeline objects for both Llama 3.1 8B and Gemma 2 9B. The pipeline includes the tokenizer, which is what we need to see how each model breaks text into tokens.

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading Llama 3.1 8B pipeline...")
llama_pipeline = LMPipeline(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    max_new_tokens=5,
    device=device,
    dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    max_length=256,
)

print("\nLoading Gemma 2 9B pipeline...")
gemma_pipeline = LMPipeline(
    "google/gemma-2-9b",
    max_new_tokens=5,
    device=device,
    dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    max_length=256,
)

print("\n✓ Both models ready!")

Loading Llama 3.1 8B pipeline...

Loading Gemma 2 9B pipeline...

✓ Both models ready!


## Part 2: Create an Example Prompt (2-Digit Addition)

Let's create a simple addition problem: 23 + 45 = 68

In [3]:
# Create configuration and model
config = create_two_number_two_digit_config()
model = create_basic_addition_model(config)

# Create a specific example: 23 + 45
input_sample = {
    "digit_0_0": 2,
    "digit_0_1": 3,  # 23
    "digit_1_0": 4,
    "digit_1_1": 5,  # 45
    "num_addends": 2,
    "num_digits": 2,
    "template": config.templates[0],
}

# Generate the prompt
output = model.new_trace(input_sample)
prompt = output['raw_input']
input_sample["raw_input"] = output['raw_input']

print(f"Example prompt: '{prompt}'")
print(f"Expected answer: '{output['raw_output'].strip()}'")
print("\nBreakdown: We have two numbers: '23' and '45'")
print("  - digit_0_0 = 2 (tens place of first number)")
print("  - digit_0_1 = 3 (ones place of first number)")
print("  - digit_1_0 = 4 (tens place of second number)")
print("  - digit_1_1 = 5 (ones place of second number)")

Example prompt: 'The sum of 23 and 45 is'
Expected answer: '68'

Breakdown: We have two numbers: '23' and '45'
  - digit_0_0 = 2 (tens place of first number)
  - digit_0_1 = 3 (ones place of first number)
  - digit_1_0 = 4 (tens place of second number)
  - digit_1_1 = 5 (ones place of second number)
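The `digit_{i}_{j}` keys index the number (`i`) and the digit place (`j`, where `j = 0` is the most significant digit). The hypothetical helper below builds these keys from plain integers. It mirrors the hand-written `input_sample` above but omits `template`, which comes from the config.

```python
def digits_to_sample(numbers: list[int], num_digits: int) -> dict:
    """Build digit_{i}_{j} entries for non-negative integers (j = 0 is most significant).

    Hypothetical helper: mirrors the hand-written input_sample keys,
    but leaves out the config-specific 'template' entry.
    """
    sample = {"num_addends": len(numbers), "num_digits": num_digits}
    for i, n in enumerate(numbers):
        digits = str(n)
        if len(digits) != num_digits:
            raise ValueError(f"{n} does not have exactly {num_digits} digits")
        for j, d in enumerate(digits):
            sample[f"digit_{i}_{j}"] = int(d)
    return sample


print(digits_to_sample([23, 45], 2))
```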


## Part 3: Llama 3.1 8B Tokenization

Let's see how Llama tokenizes this prompt.

In [4]:
# Tokenize with Llama using pipeline.load() for consistency with get_digit_token_position
llama_token_dict = llama_pipeline.load({"raw_input": prompt}, add_special_tokens=False)
llama_tokens = llama_token_dict["input_ids"][0].tolist()

print("Llama 3.1 8B Tokenization:")
print("=" * 70)
print(f"Total tokens: {len(llama_tokens)}\n")

for i, token_id in enumerate(llama_tokens):
    token_str = llama_pipeline.tokenizer.decode([token_id])
    marker = ""
    if "23" in token_str:
        marker = " <-- FIRST NUMBER (entire '23')"
    elif "45" in token_str:
        marker = " <-- SECOND NUMBER (entire '45')"
    print(f"Token {i}: {repr(token_str):20s} (ID: {token_id}){marker}")

print("\n" + "=" * 70)
print("KEY OBSERVATION: '23' and '45' are each encoded as SINGLE TOKENS")
print("This means we cannot intervene on individual digits!")
print("=" * 70)

Llama 3.1 8B Tokenization:
Total tokens: 256

Token 0: '<|eot_id|>'         (ID: 128009)
...
Token 20: '<|eot_id|>'         (ID: 128009)
[output truncated; every token shown is '<|eot_id|>' left-padding]

## Part 4: Gemma 2 9B Tokenization

Now let's see how Gemma tokenizes the same prompt.

In [5]:
# Tokenize with Gemma using pipeline.load() for consistency with get_digit_token_position
gemma_token_dict = gemma_pipeline.load({"raw_input": prompt}, add_special_tokens=False)
gemma_tokens = gemma_token_dict["input_ids"][0].tolist()

print("Gemma 2 9B Tokenization:")
print("=" * 70)
print(f"Total tokens: {len(gemma_tokens)}\n")

number_idx = -1  # index of the number currently being read
prev_was_digit = False
for i, token_id in enumerate(gemma_tokens):
    token_str = gemma_pipeline.tokenizer.decode([token_id])
    is_digit = token_str.strip().isdigit()
    if is_digit and not prev_was_digit:
        number_idx += 1  # first digit of a new number (robust to left-padding)
    prev_was_digit = is_digit
    marker = ""
    if is_digit:
        label = "FIRST" if number_idx == 0 else "SECOND"
        marker = f" <-- {label} NUMBER digit '{token_str.strip()}'"
    print(f"Token {i}: {repr(token_str):20s} (ID: {token_id}){marker}")

print("\n" + "=" * 70)
print("KEY OBSERVATION: Numbers are tokenized DIGIT-BY-DIGIT")
print("'23' becomes two tokens, '45' becomes two tokens")
print("This means we CAN intervene on individual digits!")
print("=" * 70)

Gemma 2 9B Tokenization:
Total tokens: 256

Token 0: '<eos>'              (ID: 1)
...
Token 23: '<eos>'              (ID: 1)
[output truncated; every token shown is '<eos>' left-padding]

## Part 5: Side-by-Side Comparison

Let's create a clear visual comparison of the two tokenization schemes.

In [6]:
print("SIDE-BY-SIDE COMPARISON")
print("=" * 70)
print(f"Prompt: '{prompt}'\n")

print("Llama 3.1 8B (whole-number tokens):")
print("-" * 70)
llama_decoded = [llama_pipeline.tokenizer.decode([tid]) for tid in llama_tokens]
for i, tok in enumerate(llama_decoded):
    print(f"  [{i}] {repr(tok)}")
print(f"  Total: {len(llama_tokens)} tokens")

print("\nGemma 2 9B (digit-by-digit tokens):")
print("-" * 70)
gemma_decoded = [gemma_pipeline.tokenizer.decode([tid]) for tid in gemma_tokens]
for i, tok in enumerate(gemma_decoded):
    print(f"  [{i}] {repr(tok)}")
print(f"  Total: {len(gemma_tokens)} tokens")

print("\n" + "=" * 70)
print("Notice: both pipelines LEFT-pad to max_length=256 (Llama with '<|eot_id|>',")
print("Gemma with '<eos>'), so the prompt tokens sit at the END of both sequences")
print("=" * 70)

SIDE-BY-SIDE COMPARISON
Prompt: 'The sum of 23 and 45 is'

Llama 3.1 8B (whole-number tokens):
----------------------------------------------------------------------
  [0] '<|eot_id|>'
  ...
  [41] '<|eot_id|>'
  [output truncated; every token shown is '<|eot_id|>' left-padding]
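Both pipelines left-pad every prompt to `max_length=256`. As a result, the printed token lists are dominated by padding, and the prompt sits at the end, which is why `get_digit_token_position` returns indices like 251. The minimal sketch below hides the padding while keeping the original indices. It assumes the padding id is known. The helper name `content_positions` is hypothetical.

```python
def content_positions(token_ids: list[int], pad_id: int) -> list[tuple[int, int]]:
    """Return (position, token_id) pairs for non-padding tokens.

    Positions index the full padded sequence, so they stay comparable
    with indices returned by position-finding helpers.
    """
    return [(pos, tid) for pos, tid in enumerate(token_ids) if tid != pad_id]


# Toy left-padded sequence: three pad tokens (id 0) followed by content.
padded = [0, 0, 0, 17, 42, 99]
print(content_positions(padded, pad_id=0))  # [(3, 17), (4, 42), (5, 99)]
```

With the notebook's variables, this might look like `content_positions(llama_tokens, llama_pipeline.tokenizer.pad_token_id)`, assuming the pipeline's tokenizer exposes `pad_token_id` as Hugging Face tokenizers do. Because Gemma pads with `<eos>` here, filtering by id would also drop a genuine `<eos>` in the prompt, although this prompt contains none.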

## Part 6: Finding Specific Digit Positions

Now let's use the `get_digit_token_position` function to locate specific digits.

In [7]:
print("Finding digit positions using get_digit_token_position()")
print("=" * 70)

# Find the tens digit of the first number ("2" in "23")
llama_digit_0_0 = get_digit_token_position(input_sample, llama_pipeline, 0, 0)
gemma_digit_0_0 = get_digit_token_position(input_sample, gemma_pipeline, 0, 0)

print("\ndigit_0_0 (tens place of first number - the '2' in '23'):")
print(f"  Llama: tokens {llama_digit_0_0}")
llama_decoded_digit = [
    llama_pipeline.tokenizer.decode([llama_tokens[i]]) for i in llama_digit_0_0
]
print(f"         decoded: {llama_decoded_digit}")
print(f"  Gemma: tokens {gemma_digit_0_0}")
gemma_decoded_digit = [
    gemma_pipeline.tokenizer.decode([gemma_tokens[i]]) for i in gemma_digit_0_0
]
print(f"         decoded: {gemma_decoded_digit}")

# Find the ones digit of the first number ("3" in "23")
llama_digit_0_1 = get_digit_token_position(input_sample, llama_pipeline, 0, 1)
gemma_digit_0_1 = get_digit_token_position(input_sample, gemma_pipeline, 0, 1)

print("\ndigit_0_1 (ones place of first number - the '3' in '23'):")
print(f"  Llama: tokens {llama_digit_0_1}")
llama_decoded_digit = [
    llama_pipeline.tokenizer.decode([llama_tokens[i]]) for i in llama_digit_0_1
]
print(f"         decoded: {llama_decoded_digit}")
print(f"  Gemma: tokens {gemma_digit_0_1}")
gemma_decoded_digit = [
    gemma_pipeline.tokenizer.decode([gemma_tokens[i]]) for i in gemma_digit_0_1
]
print(f"         decoded: {gemma_decoded_digit}")

print("\n" + "=" * 70)
print("CRITICAL INSIGHT:")
print("  Llama: digit_0_0 and digit_0_1 both return the SAME token (the whole number '23')")
print("  Gemma: '23' is split into two digit tokens, but digit_0_0 and digit_0_1 BOTH")
print("         return both of them, so neither position selects a single digit")
print("=" * 70)

Finding digit positions using get_digit_token_position()

digit_0_0 (tens place of first number - the '2' in '23'):
  Llama: tokens [251]
         decoded: ['23']
  Gemma: tokens [249, 250]
         decoded: ['2', '3']

digit_0_1 (ones place of first number - the '3' in '23'):
  Llama: tokens [251]
         decoded: ['23']
  Gemma: tokens [249, 250]
         decoded: ['2', '3']

CRITICAL INSIGHT:
  Llama: digit_0_0 and digit_0_1 both return the SAME token (the whole number '23')
  Gemma: '23' is split into two digit tokens, but digit_0_0 and digit_0_1 BOTH
         return both of them, so neither position selects a single digit
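In the outputs above, `get_digit_token_position` returns every token of the number for both models. For Gemma, it returns `[249, 250]` for either digit. A digit-precise alternative is to select only the token whose character span contains the digit. The sketch below is a hypothetical reimplementation of that idea on decoded token strings, not the library's algorithm. It assumes the decoded tokens concatenate back to the prompt and that the prompt contains no digits other than the addends.

```python
import re


def digit_token_positions(tokens: list[str], number_idx: int, digit_idx: int) -> list[int]:
    """Positions of tokens covering one digit of the number_idx-th number.

    Assumes ''.join(tokens) reproduces the prompt text.
    """
    text = "".join(tokens)
    numbers = list(re.finditer(r"\d+", text))
    char = numbers[number_idx].start() + digit_idx  # character index of the digit
    positions, offset = [], 0
    for pos, tok in enumerate(tokens):
        if offset <= char < offset + len(tok):
            positions.append(pos)
        offset += len(tok)
    return positions


llama_like = ["The", " sum", " of", " ", "23", " and", " ", "45", " is"]
gemma_like = ["The", " sum", " of", " ", "2", "3", " and", " ", "4", "5", " is"]
print(digit_token_positions(llama_like, 0, 1))  # [4]: the whole '23' token
print(digit_token_positions(gemma_like, 0, 1))  # [5]: just the '3' token
```

The token lists are illustrative stand-ins shaped like the tokenizations above, with padding removed and positions renumbered from 0.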


## Part 7: Implications for Interventions

This tokenization difference has major implications for what we can test with interventions.

In [8]:
print("INTERVENTION IMPLICATIONS")
print("=" * 70)

print("\nScenario: We want to test carry propagation by intervening on the ones digit")
print("\nWith Llama 3.1 8B:")
print("  - We request tokens for 'digit_0_1' (ones place)")
print(f"  - We get: {llama_digit_0_1}")
print(
    f"  - This token represents: {llama_pipeline.tokenizer.decode([llama_tokens[llama_digit_0_1[0]]])}"
)
print("  - ⚠️  We're intervening on the ENTIRE NUMBER '23', not just the ones digit!")
print("  - Interpretation: Tests how the model handles whole-number changes")

print("\nWith Gemma 2 9B:")
print("  - We request tokens for 'digit_0_1' (ones place)")
print(f"  - We get: {gemma_digit_0_1}")
gemma_selected = [
    gemma_pipeline.tokenizer.decode([gemma_tokens[i]]) for i in gemma_digit_0_1
]
print(f"  - These tokens represent: {gemma_selected}")
print("  - ⚠️  Both digit tokens are selected, so we are intervening on all of '23' here too!")
print("  - Interpretation: Digit-level tokens make a ones-digit-only intervention")
print("    possible, but only if we select the '3' token on its own")

print("\n" + "=" * 70)
print("The selected positions differ (one whole-number token vs. two digit tokens),")
print("so these experiments are not equivalent!")
print("=" * 70)

INTERVENTION IMPLICATIONS

Scenario: We want to test carry propagation by intervening on the ones digit

With Llama 3.1 8B:
  - We request tokens for 'digit_0_1' (ones place)
  - We get: [251]
  - This token represents: 23
  - ⚠️  We're intervening on the ENTIRE NUMBER '23', not just the ones digit!
  - Interpretation: Tests how the model handles whole-number changes

With Gemma 2 9B:
  - We request tokens for 'digit_0_1' (ones place)
  - We get: [249, 250]
  - These tokens represent: ['2', '3']
  - ⚠️  Both digit tokens are selected, so we are intervening on all of '23' here too!
  - Interpretation: Digit-level tokens make a ones-digit-only intervention
    possible, but only if we select the '3' token on its own

The selected positions differ (one whole-number token vs. two digit tokens),
so these experiments are not equivalent!
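The selected positions matter because an interchange intervention copies hidden states at the chosen positions from a counterfactual (source) run into the base run. In the toy sketch below, lists stand in for per-position hidden vectors. The name `interchange` is hypothetical. A real experiment patches model activations at a chosen layer.

```python
def interchange(
    base: list[list[float]], source: list[list[float]], positions: list[int]
) -> list[list[float]]:
    """Return a copy of base in which the vectors at `positions` come from source."""
    patched = [vec[:] for vec in base]  # copy so the base run is left untouched
    for pos in positions:
        patched[pos] = source[pos][:]
    return patched


base = [[0.0], [1.0], [2.0]]  # e.g. hidden states for three token positions
source = [[9.0], [8.0], [7.0]]
print(interchange(base, source, [1]))  # [[0.0], [8.0], [2.0]]
```

With Llama, `positions=[251]` replaces the single vector that carries all of "23". With Gemma, `[249, 250]` replaces both digit vectors, whereas patching `[250]` alone would target only the "3" token.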


## Part 8: 3-Digit Numbers (123 + 456)

Let's see how the models handle larger numbers.

In [9]:
# Create 3-digit example
config_3d = create_two_number_three_digit_config()
model_3d = create_basic_addition_model(config_3d)

input_sample_3d = {
    "digit_0_0": 1,
    "digit_0_1": 2,
    "digit_0_2": 3,  # 123
    "digit_1_0": 4,
    "digit_1_1": 5,
    "digit_1_2": 6,  # 456
    "num_addends": 2,
    "num_digits": 3,
    "template": config_3d.templates[0],
}

output_3d = model_3d.new_trace(input_sample_3d)
prompt_3d = output_3d['raw_input']
input_sample_3d['raw_input'] = prompt_3d  # Add raw_input to sample for later use

print(f"3-Digit Example: '{prompt_3d}'\n")

# Tokenize with both models using pipeline.load()
llama_token_dict_3d = llama_pipeline.load(
    {"raw_input": prompt_3d}, add_special_tokens=False
)
llama_tokens_3d = llama_token_dict_3d["input_ids"][0].tolist()

gemma_token_dict_3d = gemma_pipeline.load(
    {"raw_input": prompt_3d}, add_special_tokens=False
)
gemma_tokens_3d = gemma_token_dict_3d["input_ids"][0].tolist()

print("Llama 3.1 8B (3-digit):")
for i, tok_id in enumerate(llama_tokens_3d):
    tok = llama_pipeline.tokenizer.decode([tok_id])
    marker = " <-- NUMBER" if any(d in tok for d in ["123", "456"]) else ""
    print(f"  [{i}] {repr(tok)}{marker}")

print("\nGemma 2 9B (3-digit):")
for i, tok_id in enumerate(gemma_tokens_3d):
    tok = gemma_pipeline.tokenizer.decode([tok_id])
    marker = " <-- DIGIT" if tok.strip() in ["1", "2", "3", "4", "5", "6"] else ""
    print(f"  [{i}] {repr(tok)}{marker}")

print("\nObservation:")
print("  Llama: STILL uses 1 token per number! ('123' and '456' are single tokens)")
print("  Gemma: Uses 3 tokens per number (digit-by-digit)")

3-Digit Example: 'The sum of 123 and 456 is'

Llama 3.1 8B (3-digit):
  [0] '<|eot_id|>'
  ...
  [46] '<|eot_id|>'
  [output truncated; every token shown is '<|eot_id|>' left-padding]


## Part 9: 4-Digit Numbers (1234 + 5678)

Now let's push it further with 4-digit numbers.

In [10]:
# Create 4-digit example
config_4d = create_general_config(2, 4)
model_4d = create_basic_addition_model(config_4d)

input_sample_4d = {
    "digit_0_0": 1,
    "digit_0_1": 2,
    "digit_0_2": 3,
    "digit_0_3": 4,  # 1234
    "digit_1_0": 5,
    "digit_1_1": 6,
    "digit_1_2": 7,
    "digit_1_3": 8,  # 5678
    "num_addends": 2,
    "num_digits": 4,
    "template": config_4d.templates[0],
}

output_4d = model_4d.new_trace(input_sample_4d)
prompt_4d = output_4d['raw_input']
input_sample_4d['raw_input'] = prompt_4d  # Add raw_input to sample for later use

print(f"4-Digit Example: '{prompt_4d}'\n")

# Tokenize with both models using pipeline.load()
llama_token_dict_4d = llama_pipeline.load(
    {"raw_input": prompt_4d}, add_special_tokens=False
)
llama_tokens_4d = llama_token_dict_4d["input_ids"][0].tolist()

gemma_token_dict_4d = gemma_pipeline.load(
    {"raw_input": prompt_4d}, add_special_tokens=False
)
gemma_tokens_4d = gemma_token_dict_4d["input_ids"][0].tolist()

print("Llama 3.1 8B (4-digit):")
for i, tok_id in enumerate(llama_tokens_4d):
    tok = llama_pipeline.tokenizer.decode([tok_id])
    # Mark any all-digit token; a fixed substring list misses chunks like '4' or '8'
    marker = " <-- NUMBER TOKEN" if tok.strip().isdigit() else ""
    print(f"  [{i}] {repr(tok)}{marker}")

print("\nGemma 2 9B (4-digit):")
for i, tok_id in enumerate(gemma_tokens_4d):
    tok = gemma_pipeline.tokenizer.decode([tok_id])
    marker = (
        " <-- DIGIT" if tok.strip() in ["1", "2", "3", "4", "5", "6", "7", "8"] else ""
    )
    print(f"  [{i}] {repr(tok)}{marker}")

print("\nObservation:")
print("  Llama: NOW splits into 2 tokens per number (e.g., '123' + '4' for '1234',")
print("         since Llama 3 groups digits left to right in runs of at most three)")
print("  Gemma: Still uses 4 tokens per number (digit-by-digit as always)")
print("\nLlama's digit grouping has limits: numbers longer than 3 digits get split!")

4-Digit Example: 'The sum of 1234 and 5678 is'

Llama 3.1 8B (4-digit):
  [0] '<|eot_id|>'
  ...
  [45] '<|eot_id|>'
  [output truncated; every token shown is '<|eot_id|>' left-padding]

## Part 10: The Configuration Lookup Table

The `tokenization_config.py` module provides a pre-computed lookup table of how many tokens each model uses per number, keyed by digit count.

In [11]:
print("TOKENIZATION CONFIGURATION LOOKUP")
print("=" * 70)

print("\nLlama configuration:")
print(TOKENIZATION_CONFIG["llama"])
print("  Interpretation: 2-digit → 1 token, 3-digit → 1 token, 4-digit → 2 tokens")

print("\nGemma configuration:")
print(TOKENIZATION_CONFIG["gemma"])
print("  Interpretation: 2-digit → 2 tokens, 3-digit → 3 tokens, 4-digit → 4 tokens")

print("\n" + "=" * 70)
print("Using the helper function:")
print("=" * 70)

for num_digits in [2, 3, 4]:
    llama_count = get_tokens_per_number(
        "meta-llama/Meta-Llama-3.1-8B-Instruct", num_digits
    )
    gemma_count = get_tokens_per_number("google/gemma-2-9b", num_digits)
    print(f"{num_digits}-digit numbers:")
    print(f"  Llama: {llama_count} token(s) per number")
    print(f"  Gemma: {gemma_count} token(s) per number")
    print()

TOKENIZATION CONFIGURATION LOOKUP

Llama configuration:
{2: 1, 3: 1, 4: 2}
  Interpretation: 2-digit → 1 token, 3-digit → 1 token, 4-digit → 2 tokens

Gemma configuration:
{2: 2, 3: 3, 4: 4}
  Interpretation: 2-digit → 2 tokens, 3-digit → 3 tokens, 4-digit → 4 tokens

Using the helper function:
2-digit numbers:
  Llama: 1 token(s) per number
  Gemma: 2 token(s) per number

3-digit numbers:
  Llama: 1 token(s) per number
  Gemma: 3 token(s) per number

4-digit numbers:
  Llama: 2 token(s) per number
  Gemma: 4 token(s) per number
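Both lookup tables are consistent with a simple rule. If a tokenizer groups digits into chunks of at most `k`, an `n`-digit number needs ceil(n / k) tokens. Setting `k = 3` (the assumed Llama 3 digit grouping) and `k = 1` (Gemma) reproduces both tables. This sketch is only a sanity check on the config. The module itself stores pre-computed values, and `tokens_per_number` is a hypothetical name.

```python
import math


def tokens_per_number(num_digits: int, max_chunk: int) -> int:
    """Tokens needed for a num_digits-long number if digits are grouped
    into chunks of at most max_chunk (assumed tokenizer behavior)."""
    return math.ceil(num_digits / max_chunk)


llama_table = {n: tokens_per_number(n, 3) for n in (2, 3, 4)}
gemma_table = {n: tokens_per_number(n, 1) for n in (2, 3, 4)}
print(llama_table)  # {2: 1, 3: 1, 4: 2}
print(gemma_table)  # {2: 2, 3: 3, 4: 4}
```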



## Part 11: Creating Token Positions for Experiments

The `create_token_positions` function sets up all the token position objects we need for interventions.

In [12]:
# Create token positions for 2-number, 2-digit addition
llama_positions = create_token_positions(llama_pipeline, num_addends=2, num_digits=2)
gemma_positions = create_token_positions(gemma_pipeline, num_addends=2, num_digits=2)

print("Token Positions for 2-Number, 2-Digit Addition")
print("=" * 70)

print("\nAvailable token positions:")
print(list(llama_positions.keys()))

print("\nThese include:")
print("  - digit_0_0, digit_0_1: First number digits")
print("  - digit_1_0, digit_1_1: Second number digits")
print("  - delimiter_and: The ' and ' between numbers")
print("  - delimiter_is: The ' is' before the answer")
print("  - last_token: The last token (the position the model generates from)")

print("\n" + "=" * 70)
print("These TokenPosition objects are used in intervention experiments!")
print("=" * 70)

Token Positions for 2-Number, 2-Digit Addition

Available token positions:
['digit_0_0', 'digit_0_1', 'digit_1_0', 'digit_1_1', 'delimiter_and', 'delimiter_is', 'last_token']

These include:
  - digit_0_0, digit_0_1: First number digits
  - digit_1_0, digit_1_1: Second number digits
  - delimiter_and: The ' and ' between numbers
  - delimiter_is: The ' is' before the answer
  - last_token: The last token (the position the model generates from)

These TokenPosition objects are used in intervention experiments!
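The keys follow a regular pattern: one `digit_{i}_{j}` entry per digit of each addend, followed by the delimiter and last-token positions. The hypothetical helper below reproduces the printed list for the two-addend case. The delimiter names are copied from the output above, and prompts with more addends may use different delimiters.

```python
def position_names(num_addends: int, num_digits: int) -> list[str]:
    """Hypothetical: rebuild the position keys printed above for a two-addend template."""
    digits = [f"digit_{i}_{j}" for i in range(num_addends) for j in range(num_digits)]
    return digits + ["delimiter_and", "delimiter_is", "last_token"]


print(position_names(2, 2))
```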


## Part 12: Testing with Multiple Random Examples

Let's verify the token position functions work correctly across different examples.

In [13]:
print("Testing token position functions across random examples")
print("=" * 70)

for i in range(3):
    # Generate random 2-digit addition problem
    sample = sample_valid_addition_input(config, num_addends=2, num_digits=2)
    output = model.new_trace(sample)
    sample['raw_input'] = output['raw_input']  # Add raw_input to sample
    
    print(f"\nExample {i+1}: {output['raw_input']}")
    
    # Get digit positions
    llama_d00 = get_digit_token_position(sample, llama_pipeline, 0, 0)
    gemma_d00 = get_digit_token_position(sample, gemma_pipeline, 0, 0)

    # Tokenize using pipeline.load() for consistency
    llama_token_dict_test = llama_pipeline.load(
        {"raw_input": output["raw_input"]}, add_special_tokens=False
    )
    llama_toks = llama_token_dict_test["input_ids"][0].tolist()

    gemma_token_dict_test = gemma_pipeline.load(
        {"raw_input": output["raw_input"]}, add_special_tokens=False
    )
    gemma_toks = gemma_token_dict_test["input_ids"][0].tolist()

    # Decode every selected position (not just the first one)
    llama_found = [llama_pipeline.tokenizer.decode([llama_toks[p]]) for p in llama_d00]
    gemma_found = [gemma_pipeline.tokenizer.decode([gemma_toks[p]]) for p in gemma_d00]

    print("  Looking for digit_0_0 (tens place of first number):")
    print(f"    Llama found: {llama_found} at position {llama_d00}")
    print(f"    Gemma found: {gemma_found} at position {gemma_d00}")

print("\n✓ Token positions are located consistently across all examples!")

Testing token position functions across random examples

Example 1: The sum of 42 and 21 is
  Looking for digit_0_0 (tens place of first number):
    Llama found: ['42'] at position [251]
    Gemma found: ['4', '2'] at position [249, 250]

Example 2: The sum of 91 and 90 is
  Looking for digit_0_0 (tens place of first number):
    Llama found: ['91'] at position [251]
    Gemma found: ['9', '1'] at position [249, 250]

Example 3: The sum of 68 and 31 is
  Looking for digit_0_0 (tens place of first number):
    Llama found: ['68'] at position [251]
    Gemma found: ['6', '8'] at position [249, 250]

✓ Token positions are located consistently across all examples!
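Eyeballing three examples is a weak check. One option is to assert that the tokens selected for a number decode to exactly that number. This matches the behavior visible in these outputs, where both models select every token of the first number. The minimal sketch below works on decoded token strings, and the function names are hypothetical.

```python
def selected_text(tokens: list[str], positions: list[int]) -> str:
    """Concatenate the decoded tokens at the selected positions."""
    return "".join(tokens[p] for p in positions).strip()


def check_number_selection(tokens: list[str], positions: list[int], number: int) -> bool:
    """True if the selected tokens spell out exactly `number`."""
    return selected_text(tokens, positions) == str(number)


gemma_like = ["The", " sum", " of", " ", "4", "2", " and", " ", "2", "1", " is"]
llama_like = ["The", " sum", " of", " ", "42", " and", " ", "21", " is"]
print(check_number_selection(gemma_like, [4, 5], 42))  # True
print(check_number_selection(llama_like, [4], 42))  # True
print(check_number_selection(gemma_like, [4], 42))  # False: only the '4'
```

In the loop above, `tokens` would be the decoded `llama_toks` or `gemma_toks`, and `positions` would be `llama_d00` or `gemma_d00`.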


## Part 13: Visual Token Highlighting with TokenPosition Objects

Now let's use the `highlight_selected_token` method to visualize which tokens are selected for each digit position across different number sizes.
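The highlighted outputs below wrap each selected token in `**`, so adjacent selected tokens render as `**2****3**`. The minimal sketch below reproduces that format from decoded tokens. The `highlight` helper is hypothetical. The outputs below also show a leading BOS token, which this sketch leaves out.

```python
def highlight(tokens: list[str], positions: list[int]) -> str:
    """Join tokens, wrapping each selected token in ** markers."""
    selected = set(positions)
    return "".join(f"**{tok}**" if i in selected else tok for i, tok in enumerate(tokens))


gemma_like = ["The", " sum", " of", " ", "2", "3", " and", " ", "4", "5", " is"]
print(highlight(gemma_like, [4, 5]))  # The sum of **2****3** and 45 is
print(highlight(gemma_like, [6, 7]))  # The sum of 23** and**** **45 is
```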

In [14]:
print("=" * 80)
print("GEMMA 2 9B: Highlighted Token Selections")
print("=" * 80)

# 2-digit example
print("\n### 2-Digit Addition: 23 + 45")
print(f"Full prompt: '{prompt}'")
print()

gemma_positions_2d = create_token_positions(gemma_pipeline, num_addends=2, num_digits=2)

for digit_name in ["digit_0_0", "digit_0_1", "digit_1_0", "digit_1_1"]:
    token_pos = gemma_positions_2d[digit_name]
    highlighted = token_pos.highlight_selected_token(input_sample)
    print(f"{digit_name}: {highlighted}")

print(
    f"\ndelimiter_and: {gemma_positions_2d['delimiter_and'].highlight_selected_token(input_sample)}"
)
print(
    f"delimiter_is: {gemma_positions_2d['delimiter_is'].highlight_selected_token(input_sample)}"
)
print(
    f"last_token: {gemma_positions_2d['last_token'].highlight_selected_token(input_sample)}"
)

# 3-digit example
print("\n\n### 3-Digit Addition: 123 + 456")
print(f"Full prompt: '{prompt_3d}'")
print()

gemma_positions_3d = create_token_positions(gemma_pipeline, num_addends=2, num_digits=3)

for digit_name in [
    "digit_0_0",
    "digit_0_1",
    "digit_0_2",
    "digit_1_0",
    "digit_1_1",
    "digit_1_2",
]:
    token_pos = gemma_positions_3d[digit_name]
    highlighted = token_pos.highlight_selected_token(input_sample_3d)
    print(f"{digit_name}: {highlighted}")

# 4-digit example
print("\n\n### 4-Digit Addition: 1234 + 5678")
print(f"Full prompt: '{prompt_4d}'")
print()

gemma_positions_4d = create_token_positions(gemma_pipeline, num_addends=2, num_digits=4)

for digit_name in [
    "digit_0_0",
    "digit_0_1",
    "digit_0_2",
    "digit_0_3",
    "digit_1_0",
    "digit_1_1",
    "digit_1_2",
    "digit_1_3",
]:
    token_pos = gemma_positions_4d[digit_name]
    highlighted = token_pos.highlight_selected_token(input_sample_4d)
    print(f"{digit_name}: {highlighted}")

print("\n" + "=" * 80)
print("Notice: Gemma's positions are digit tokens, but each digit_* position selects ALL digits of its number!")
print("=" * 80)

GEMMA 2 9B: Highlighted Token Selections

### 2-Digit Addition: 23 + 45
Full prompt: 'The sum of 23 and 45 is'

digit_0_0: <bos>The sum of **2****3** and 45 is
digit_0_1: <bos>The sum of **2****3** and 45 is
digit_1_0: <bos>The sum of 23 and **4****5** is
digit_1_1: <bos>The sum of 23 and **4****5** is

delimiter_and: <bos>The sum of 23** and**** **45 is
delimiter_is: <bos>The sum of 23 and 45** is**
last_token: <bos>The sum of 23 and 45** is**


### 3-Digit Addition: 123 + 456
Full prompt: 'The sum of 123 and 456 is'

digit_0_0: <bos>The sum of **1****2****3** and 456 is
digit_0_1: <bos>The sum of **1****2****3** and 456 is
digit_0_2: <bos>The sum of **1****2****3** and 456 is
digit_1_0: <bos>The sum of 123 and **4****5****6** is
digit_1_1: <bos>The sum of 123 and **4****5****6** is
digit_1_2: <bos>The sum of 123 and **4****5****6** is


### 4-Digit Addition: 1234 + 5678
Full prompt: 'The sum of 1234 and 5678 is'

digit_0_0: <bos>The sum of **1****2****3****4** and 5678 is
[output truncated]

In [15]:
print("=" * 80)
print("LLAMA 3.1 8B: Highlighted Token Selections")
print("=" * 80)

# 2-digit example
print("\n### 2-Digit Addition: 23 + 45")
print(f"Full prompt: '{prompt}'")
print()

llama_positions_2d = create_token_positions(llama_pipeline, num_addends=2, num_digits=2)

for digit_name in ["digit_0_0", "digit_0_1", "digit_1_0", "digit_1_1"]:
    token_pos = llama_positions_2d[digit_name]
    highlighted = token_pos.highlight_selected_token(input_sample)
    print(f"{digit_name}: {highlighted}")

print(
    f"\ndelimiter_and: {llama_positions_2d['delimiter_and'].highlight_selected_token(input_sample)}"
)
print(
    f"delimiter_is: {llama_positions_2d['delimiter_is'].highlight_selected_token(input_sample)}"
)
print(
    f"last_token: {llama_positions_2d['last_token'].highlight_selected_token(input_sample)}"
)

# 3-digit example
print("\n\n### 3-Digit Addition: 123 + 456")
print(f"Full prompt: '{prompt_3d}'")
print()

llama_positions_3d = create_token_positions(llama_pipeline, num_addends=2, num_digits=3)

for digit_name in [
    "digit_0_0",
    "digit_0_1",
    "digit_0_2",
    "digit_1_0",
    "digit_1_1",
    "digit_1_2",
]:
    token_pos = llama_positions_3d[digit_name]
    highlighted = token_pos.highlight_selected_token(input_sample_3d)
    print(f"{digit_name}: {highlighted}")

# 4-digit example
print("\n\n### 4-Digit Addition: 1234 + 5678")
print(f"Full prompt: '{prompt_4d}'")
print()

llama_positions_4d = create_token_positions(llama_pipeline, num_addends=2, num_digits=4)

for digit_name in [
    "digit_0_0",
    "digit_0_1",
    "digit_0_2",
    "digit_0_3",
    "digit_1_0",
    "digit_1_1",
    "digit_1_2",
    "digit_1_3",
]:
    token_pos = llama_positions_4d[digit_name]
    highlighted = token_pos.highlight_selected_token(input_sample_4d)
    print(f"{digit_name}: {highlighted}")

print("\n" + "=" * 80)
print("Notice how Llama selects the SAME whole-number token for multiple digits!")
print(
    "For 2-digit and 3-digit numbers: all digit positions select the entire number token"
)
print("For 4-digit numbers: each number is split into 2 tokens (e.g., '123' and '4' for '1234')")
print("=" * 80)

LLAMA 3.1 8B: Highlighted Token Selections

### 2-Digit Addition: 23 + 45
Full prompt: 'The sum of 23 and 45 is'

digit_0_0: <|begin_of_text|>The sum of **23** and 45 is
digit_0_1: <|begin_of_text|>The sum of **23** and 45 is
digit_1_0: <|begin_of_text|>The sum of 23 and **45** is
digit_1_1: <|begin_of_text|>The sum of 23 and **45** is

delimiter_and: <|begin_of_text|>The sum of 23** and**** **45 is
delimiter_is: <|begin_of_text|>The sum of 23 and 45** is**
last_token: <|begin_of_text|>The sum of 23 and 45** is**


### 3-Digit Addition: 123 + 456
Full prompt: 'The sum of 123 and 456 is'

digit_0_0: <|begin_of_text|>The sum of **123** and 456 is
digit_0_1: <|begin_of_text|>The sum of **123** and 456 is
digit_0_2: <|begin_of_text|>The sum of **123** and 456 is
digit_1_0: <|begin_of_text|>The sum of 123 and **456** is
digit_1_1: <|begin_of_text|>The sum of 123 and **456** is
digit_1_2: <|begin_of_text|>The sum of 123 and **456** is


### 4-Digit Addition: 1234 + 5678
Full prompt: 'The sum of 1234 and 5678 is'
[output truncated]

## Summary: Key Takeaways

### Tokenization Patterns

| Model | 2-digit | 3-digit | 4-digit | Pattern |
|-------|---------|---------|---------|----------|
| **Llama 3.1 8B** | 1 token | 1 token | 2 tokens | Whole numbers (with limits) |
| **Gemma 2 9B** | 2 tokens | 3 tokens | 4 tokens | Digit-by-digit |

### Implications for Interventions

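1. **Llama 3.1 8B**:
   - Uses compact whole-number tokens for numbers of up to 3 digits
   - Interventions affect entire numbers (or multi-digit chunks for 4-digit numbers), not individual digits
   - Tests how the model handles number-level representations
   - Cannot test fine-grained digit operations

2. **Gemma 2 9B**:
   - Uses digit-by-digit tokenization
   - Interventions can in principle target individual digits, but as run here `get_digit_token_position` selects every digit token of the number
   - Tests how the model handles digit-level operations once positions are narrowed to single digits
   - Better suited to testing carry propagation hypotheses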

### Critical Insight

**Both tokenization schemes are valid, but they test different aspects of the models!**

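- When we intervene on "digit_0_0" with Llama, we're really testing whole-number processing
- When we intervene on "digit_0_0" with Gemma, the selected tokens are digits, but with the positions used here they still cover the whole number; testing a single digit requires selecting one digit token
- The lookup table in `tokenization_config.py` records how many tokens each model uses per number, so experiments can account for these differences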

### Next Steps

In the next notebook, we'll use these token positions to:
1. Generate counterfactual datasets
2. Test different computational hypotheses (basic vs intermediate models)
3. Run actual neural interventions to see which variables the models represent