# 13.3e: What Are You, Stubborn Token?

**Identity crisis of the one live token that refuses to leave the dead core.**

From 13.3d, we know:
- **SGD step 10000:** Main class has 50 dead + 50 live tokens (100 total)
- **Adam step 10000:** Main class has 50 dead + **1 live token** (51 total)

**Questions:**
1. Which live token stays with the dead core in Adam?
2. Which 50 live tokens stay in SGD's core?
3. What's special about them? Low frequency? Rare characters?
4. How many times do they appear in the Gatsby corpus?

**Note:** Lil Gatsby uses pure ASCII (bytes 0-127), no fancy tokenizer. Each token ID = ASCII byte value.
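
As a quick illustration of that mapping, the snippet below encodes a short string to token IDs and back. It is a minimal sketch using only Python built-ins; the sample string is arbitrary, and the actual Lil Gatsby data pipeline may differ in its details.

```python
# Byte-level "tokenization": each ASCII character maps to its byte value.
text = "Gatsby"
token_ids = list(text.encode("ascii"))
print(token_ids)  # [71, 97, 116, 115, 98, 121]

# Decoding is the inverse mapping.
decoded = bytes(token_ids).decode("ascii")
print(decoded)  # Gatsby

# Every ID falls inside the 128-entry vocabulary.
assert all(0 <= t < 128 for t in token_ids)
```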

## Parameters

In [5]:
# Data sources
SGD_DATA_PATH = "../tensors/Lil_Gatsby/13.1a_training_data.safetensors"
ADAM_DATA_PATH = "../tensors/Lil_Gatsby/13.1b_training_data.safetensors"

# Gatsby corpus (to count byte occurrences)
CORPUS_PATH = "../data/the_great_gatsby.txt"

# Final step to analyze
FINAL_STEP = 10000

# Equivalence threshold
EQUIVALENCE_THRESHOLD = 1.0

# ASCII vocabulary
VOCAB_SIZE = 128

RANDOM_SEED = 42

## Imports

In [6]:
import torch
import numpy as np
from safetensors import safe_open
from collections import deque, Counter

np.random.seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Device Detection

In [7]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Training Data

In [8]:
def load_training_data(path):
    """Load embeddings and metadata from safetensors."""
    print(f"Loading: {path}")
    with safe_open(path, framework='pt', device='cpu') as f:
        embeddings_bf16 = f.get_tensor('embeddings')
        dead_token_ids = f.get_tensor('dead_token_ids')
        live_token_ids = f.get_tensor('live_token_ids')
        recorded_steps = f.get_tensor('recorded_steps')
    
    print(f"  Embeddings: {embeddings_bf16.shape}")
    print(f"  Dead: {len(dead_token_ids)}, Live: {len(live_token_ids)}")
    
    return embeddings_bf16, dead_token_ids, live_token_ids, recorded_steps

print("=" * 80)
print("SGD (13.1a)")
print("=" * 80)
sgd_emb, sgd_dead, sgd_live, sgd_steps = load_training_data(SGD_DATA_PATH)

print("\n" + "=" * 80)
print("Adam (13.1b)")
print("=" * 80)
adam_emb, adam_dead, adam_live, adam_steps = load_training_data(ADAM_DATA_PATH)

print("\n✓ Data loaded")

SGD (13.1a)
Loading: ../tensors/Lil_Gatsby/13.1a_training_data.safetensors
  Embeddings: torch.Size([10001, 128, 64])
  Dead: 50, Live: 78

Adam (13.1b)
Loading: ../tensors/Lil_Gatsby/13.1b_training_data.safetensors
  Embeddings: torch.Size([10001, 128, 64])
  Dead: 50, Live: 78

✓ Data loaded


## Find Main Class at Step 10000

In [9]:
def find_equivalence_classes_bfs(cos_sim_matrix, threshold=1.0):
    """
    Find connected components in the equivalence graph.
    
    Returns:
        classes: list of lists (token IDs in each class)
    """
    n = cos_sim_matrix.shape[0]
    adjacency = cos_sim_matrix >= threshold
    
    visited = torch.zeros(n, dtype=torch.bool)
    classes = []
    
    for start in range(n):
        if visited[start]:
            continue
        
        component = []
        queue = deque([start])
        visited[start] = True
        
        while queue:
            node = queue.popleft()
            component.append(node)
            
            # Enqueue every unvisited neighbor of this node (as plain Python ints)
            for neighbor in torch.where(adjacency[node])[0].tolist():
                if not visited[neighbor]:
                    visited[neighbor] = True
                    queue.append(neighbor)
        
        classes.append(component)
    
    # Sort by size (largest first)
    classes.sort(key=len, reverse=True)
    
    return classes

def compute_pairwise_cosine_similarity(embeddings_step, device):
    """
    Compute pairwise cosine similarity for one step.
    """
    emb_f32 = embeddings_step.to(torch.float32).to(device)
    emb_norm = emb_f32 / emb_f32.norm(p=2, dim=1, keepdim=True)
    cos_sim = emb_norm @ emb_norm.T
    
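    # Round to BF16 and back. BF16 keeps 8 significant bits, so any similarity
    # above roughly 0.998 rounds to exactly 1.0 and passes a threshold of 1.0.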
    cos_sim_bf16 = cos_sim.to(torch.bfloat16).to(torch.float32).cpu()
    
    return cos_sim_bf16

# Get final step index
sgd_final_idx = (sgd_steps - FINAL_STEP).abs().argmin().item()
adam_final_idx = (adam_steps - FINAL_STEP).abs().argmin().item()

print(f"Analyzing step {sgd_steps[sgd_final_idx].item()} (SGD)")
print(f"Analyzing step {adam_steps[adam_final_idx].item()} (Adam)\n")

# Compute pairwise cosine similarity at final step
print("Computing pairwise cosine similarity...")
sgd_cos_sim_final = compute_pairwise_cosine_similarity(sgd_emb[sgd_final_idx], device)
adam_cos_sim_final = compute_pairwise_cosine_similarity(adam_emb[adam_final_idx], device)

# Find equivalence classes
print("Finding equivalence classes...")
sgd_classes = find_equivalence_classes_bfs(sgd_cos_sim_final, threshold=EQUIVALENCE_THRESHOLD)
adam_classes = find_equivalence_classes_bfs(adam_cos_sim_final, threshold=EQUIVALENCE_THRESHOLD)

print(f"\nSGD: {len(sgd_classes)} classes, main size = {len(sgd_classes[0])}")
print(f"Adam: {len(adam_classes)} classes, main size = {len(adam_classes[0])}")

print("\n✓ Classes identified")

Analyzing step 10000 (SGD)
Analyzing step 10000 (Adam)

Computing pairwise cosine similarity...
Finding equivalence classes...

SGD: 5 classes, main size = 100
Adam: 77 classes, main size = 51

✓ Classes identified
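
Because similarities are rounded to BF16 before being compared with `EQUIVALENCE_THRESHOLD = 1.0`, "equivalent" here means "indistinguishable at BF16 precision" rather than bit-identical. The sketch below emulates round-to-nearest-even BF16 rounding in pure Python to show roughly where that cutoff sits. `round_to_bf16` is a hypothetical helper written for this illustration, not a PyTorch function, and it ignores NaN and infinity handling.

```python
import struct

def round_to_bf16(x: float) -> float:
    """Emulate float32 -> bfloat16 -> float32 with round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that dropping the low 16 bits rounds to nearest, ties to even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Sample similarities chosen for illustration only.
for cos_sim in (0.9999, 0.9981, 0.9979, 0.9900):
    rounded = round_to_bf16(cos_sim)
    print(f"{cos_sim:.4f} -> {rounded:.8f}  passes threshold: {rounded >= 1.0}")
```

Under this emulation, similarities just above about 0.998 round up to 1.0, so the main class groups vectors whose directions agree to within BF16 resolution.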


## Identify Live Tokens in Main Classes

In [10]:
# Convert to sets for fast lookup
sgd_dead_set = set(sgd_dead.tolist())
sgd_live_set = set(sgd_live.tolist())
adam_dead_set = set(adam_dead.tolist())
adam_live_set = set(adam_live.tolist())

# SGD main class
sgd_main = sgd_classes[0]
sgd_main_live = sorted([tid for tid in sgd_main if tid in sgd_live_set])
sgd_main_dead = sorted([tid for tid in sgd_main if tid in sgd_dead_set])

print("=" * 80)
print("SGD MAIN CLASS (Step 10000)")
print("=" * 80)
print(f"Total: {len(sgd_main)} tokens")
print(f"  Dead: {len(sgd_main_dead)}")
print(f"  Live: {len(sgd_main_live)}")
print(f"\nLive token IDs in main class (first 20): {sgd_main_live[:20]}")
if len(sgd_main_live) > 20:
    print(f"  ... ({len(sgd_main_live) - 20} more)")

# Adam main class
adam_main = adam_classes[0]
adam_main_live = sorted([tid for tid in adam_main if tid in adam_live_set])
adam_main_dead = sorted([tid for tid in adam_main if tid in adam_dead_set])

print("\n" + "=" * 80)
print("ADAM MAIN CLASS (Step 10000)")
print("=" * 80)
print(f"Total: {len(adam_main)} tokens")
print(f"  Dead: {len(adam_main_dead)}")
print(f"  Live: {len(adam_main_live)}")
print(f"\nLive token ID(s) in main class: {adam_main_live}")

SGD MAIN CLASS (Step 10000)
Total: 100 tokens
  Dead: 50
  Live: 50

Live token IDs in main class (first 20): [13, 33, 36, 40, 41, 42, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 63, 65]
  ... (30 more)

ADAM MAIN CLASS (Step 10000)
Total: 51 tokens
  Dead: 50
  Live: 1

Live token ID(s) in main class: [13]


## Decode ASCII Byte IDs to Characters

In [11]:
def decode_ascii_byte(byte_id):
    """
    Decode ASCII byte ID to human-readable string.
    
    Returns:
        (byte_id, char_repr, description)
    """
    if byte_id < 0 or byte_id >= 128:
        return (byte_id, "<invalid>", "Out of ASCII range")
    
    char = chr(byte_id)
    
    # Representation
    if byte_id < 32:
        # Control characters
        char_repr = f"\\x{byte_id:02x}"
        control_names = {
            0: "NUL", 9: "TAB", 10: "LF", 13: "CR",
        }
        desc = control_names.get(byte_id, "control")
    elif byte_id == 32:
        char_repr = "' '"
        desc = "space"
    elif byte_id == 127:
        char_repr = "\\x7f"
        desc = "DEL"
    else:
        # Printable ASCII
        char_repr = f"'{char}'"
        if char.isdigit():
            desc = "digit"
        elif char.isalpha():
            desc = "letter"
        else:
            desc = "punctuation"
    
    return (byte_id, char_repr, desc)

def print_ascii_tokens(token_ids, title):
    """
    Print decoded ASCII token table.
    """
    print("=" * 80)
    print(title)
    print("=" * 80)
    
    for tid in token_ids[:30]:  # Show first 30
        byte_id, char_repr, desc = decode_ascii_byte(tid)
        print(f"  Byte {byte_id:3d} (0x{byte_id:02x}): {char_repr:>6s}  [{desc}]")
    
    if len(token_ids) > 30:
        print(f"  ... ({len(token_ids) - 30} more)")
    print()

# SGD live tokens in main class
print_ascii_tokens(sgd_main_live, "SGD LIVE TOKENS IN MAIN CLASS (Decoded)")

# Adam live token in main class
print_ascii_tokens(adam_main_live, "ADAM LIVE TOKEN IN MAIN CLASS (Decoded)")

SGD LIVE TOKENS IN MAIN CLASS (Decoded)
  Byte  13 (0x0d):   \x0d  [CR]
  Byte  33 (0x21):    '!'  [punctuation]
  Byte  36 (0x24):    '$'  [punctuation]
  Byte  40 (0x28):    '('  [punctuation]
  Byte  41 (0x29):    ')'  [punctuation]
  Byte  42 (0x2a):    '*'  [punctuation]
  Byte  48 (0x30):    '0'  [digit]
  Byte  49 (0x31):    '1'  [digit]
  Byte  50 (0x32):    '2'  [digit]
  Byte  51 (0x33):    '3'  [digit]
  Byte  52 (0x34):    '4'  [digit]
  Byte  53 (0x35):    '5'  [digit]
  Byte  54 (0x36):    '6'  [digit]
  Byte  55 (0x37):    '7'  [digit]
  Byte  56 (0x38):    '8'  [digit]
  Byte  57 (0x39):    '9'  [digit]
  Byte  58 (0x3a):    ':'  [punctuation]
  Byte  59 (0x3b):    ';'  [punctuation]
  Byte  63 (0x3f):    '?'  [punctuation]
  Byte  65 (0x41):    'A'  [letter]
  Byte  66 (0x42):    'B'  [letter]
  Byte  67 (0x43):    'C'  [letter]
  Byte  68 (0x44):    'D'  [letter]
  Byte  69 (0x45):    'E'  [letter]
  Byte  70 (0x46):    'F'  [letter]
  Byte  71 (0x47):    'G'  [letter]

## Count Byte Occurrences in Gatsby Corpus

In [12]:
print(f"Loading corpus: {CORPUS_PATH}")
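# Note: text mode applies universal-newline translation, so any "\r\n" in the
# file is read as "\n" and CR (byte 13) is dropped before counting.
with open(CORPUS_PATH, 'r', encoding='ascii') as f: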
    corpus_text = f.read()

print(f"  Corpus length: {len(corpus_text)} characters")

# Convert to byte array
corpus_bytes = corpus_text.encode('ascii')
print(f"  Byte count: {len(corpus_bytes)}")

# Count occurrences
byte_counts = Counter(corpus_bytes)

print("\n✓ Corpus loaded and counted")

Loading corpus: ../data/the_great_gatsby.txt
  Corpus length: 265905 characters
  Byte count: 265905

✓ Corpus loaded and counted
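
One caveat when counting carriage returns this way: Python's text mode applies universal-newline translation by default, so a file with `\r\n` line endings is read with every CR removed before anything is counted. The sketch below demonstrates the effect on a small temporary file (a stand-in, not the actual Gatsby corpus) and shows that binary mode preserves the CR bytes. Whether the real corpus file contains any CRs is not established here; this only shows how one could check.

```python
import os
import tempfile
from collections import Counter

# Write a tiny stand-in file with Windows-style CRLF line endings.
with tempfile.NamedTemporaryFile("wb", suffix=".txt", delete=False) as tmp:
    tmp.write(b"In my younger\r\nand more vulnerable years\r\n")
    sample_path = tmp.name

# Text mode (default newline=None) translates "\r\n" to "\n" on read.
with open(sample_path, "r", encoding="ascii") as f:
    text_counts = Counter(f.read().encode("ascii"))

# Binary mode returns the bytes exactly as stored on disk.
with open(sample_path, "rb") as f:
    binary_counts = Counter(f.read())

os.remove(sample_path)

print(f"CR count, text mode:   {text_counts[13]}")
print(f"CR count, binary mode: {binary_counts[13]}")
```

A mismatch between the two counts on the real corpus would be one possible explanation for a token that is labeled live yet appears to have zero occurrences.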


## Frequency Analysis

In [13]:
def analyze_byte_frequencies(byte_ids, byte_counts, name):
    """
    Print frequency statistics for a list of byte IDs.
    """
    print("=" * 80)
    print(f"{name} - FREQUENCY ANALYSIS")
    print("=" * 80)
    
    freq_data = []
    for byte_id in byte_ids:
        count = byte_counts.get(byte_id, 0)
        _, char_repr, desc = decode_ascii_byte(byte_id)
        freq_data.append((byte_id, char_repr, desc, count))
    
    # Sort by frequency (ascending)
    freq_data.sort(key=lambda x: x[3])
    
    print(f"\nTotal bytes: {len(freq_data)}")
    print(f"Bytes with 0 occurrences: {sum(1 for _, _, _, c in freq_data if c == 0)}")
    print(f"Bytes with 1 occurrence: {sum(1 for _, _, _, c in freq_data if c == 1)}")
    print(f"Bytes with 2-10 occurrences: {sum(1 for _, _, _, c in freq_data if 2 <= c <= 10)}")
    print(f"Bytes with >10 occurrences: {sum(1 for _, _, _, c in freq_data if c > 10)}")
    
    print("\n" + "-" * 80)
    print("Rarest bytes (sorted by frequency):")
    print("-" * 80)
    for byte_id, char_repr, desc, count in freq_data[:20]:
        print(f"  Byte {byte_id:3d} ({char_repr:>6s}, {desc:>12s}): {count:6,} occurrences")
    
    if len(freq_data) > 20:
        print("\n" + "-" * 80)
        print("Most common bytes in this group:")
        print("-" * 80)
        for byte_id, char_repr, desc, count in freq_data[-10:]:
            print(f"  Byte {byte_id:3d} ({char_repr:>6s}, {desc:>12s}): {count:6,} occurrences")
    
    print()

# Analyze SGD live bytes in main class
analyze_byte_frequencies(sgd_main_live, byte_counts, "SGD Main Class Live Bytes")

# Analyze Adam's stubborn byte
analyze_byte_frequencies(adam_main_live, byte_counts, "Adam Main Class Live Byte")

SGD Main Class Live Bytes - FREQUENCY ANALYSIS

Total bytes: 50
Bytes with 0 occurrences: 1
Bytes with 1 occurrence: 1
Bytes with 2-10 occurrences: 16
Bytes with >10 occurrences: 32

--------------------------------------------------------------------------------
Rarest bytes (sorted by frequency):
--------------------------------------------------------------------------------
  Byte  13 (  \x0d,           CR):      0 occurrences
  Byte  90 (   'Z',       letter):      1 occurrences
  Byte  36 (   '$',  punctuation):      2 occurrences
  Byte  52 (   '4',        digit):      2 occurrences
  Byte  55 (   '7',        digit):      2 occurrences
  Byte  88 (   'X',       letter):      2 occurrences
  Byte  91 (   '[',  punctuation):      2 occurrences
  Byte  93 (   ']',  punctuation):      2 occurrences
  Byte  56 (   '8',        digit):      3 occurrences
  Byte  50 (   '2',        digit):      4 occurrences
  Byte  81 (   'Q',       letter):      4 occurrences
  Byte  54 (   '6',      

## The Stubborn Token: Special Analysis

In [14]:
if len(adam_main_live) == 1:
    stubborn_id = adam_main_live[0]
    stubborn_byte_id, stubborn_char, stubborn_desc = decode_ascii_byte(stubborn_id)
    stubborn_count = byte_counts.get(stubborn_id, 0)
    
    print("=" * 80)
    print("THE STUBBORN TOKEN")
    print("=" * 80)
    print(f"\nByte ID: {stubborn_id} (0x{stubborn_id:02x})")
    print(f"Character: {stubborn_char}")
    print(f"Type: {stubborn_desc}")
    print(f"Occurrences in Gatsby: {stubborn_count:,}")
    
    # Compare to dead tokens
    print(f"\n" + "-" * 80)
    print("Context: How does this compare to the dead core?")
    print("-" * 80)
    print(f"Dead tokens in main class: {len(adam_main_dead)}")
    print(f"Dead token occurrence range in Gatsby: 0 (by definition)")
    print(f"\nThe stubborn token has {stubborn_count:,} occurrence(s).")
    
    if stubborn_count == 0:
        print("  → This token is EFFECTIVELY DEAD (zero gradients)!")
        print("  → Must have been miscategorized as 'live' somehow.")
    elif stubborn_count == 1:
        print("  → Near-dead: appears exactly once in the entire corpus.")
        print("  → Gradient updates are negligible.")
    elif stubborn_count <= 10:
        print("  → Very rare: appears only a handful of times.")
        print("  → Gradient signal is weak but present.")
    else:
        print(f"  → Surprisingly common ({100 * stubborn_count / len(corpus_bytes):.2f}% of corpus bytes)")
        print("  → This is weird—strong gradient signal but still frozen!")
    
    # Show a few examples from corpus if count is small
    if 0 < stubborn_count <= 10:
        print(f"\n" + "-" * 80)
        print(f"Occurrences in corpus (showing all {stubborn_count}):")
        print("-" * 80)
        
        # Find positions
        positions = [i for i, b in enumerate(corpus_bytes) if b == stubborn_id]
        
        for i, pos in enumerate(positions, 1):
            # Extract context (±40 chars)
            start = max(0, pos - 40)
            end = min(len(corpus_text), pos + 40)
            context = corpus_text[start:end]
            
            # Highlight the character
            rel_pos = pos - start
            highlighted = context[:rel_pos] + f"[{context[rel_pos]}]" + context[rel_pos+1:]
            
            print(f"\n  Occurrence {i} at position {pos}:")
            print(f"    ...{highlighted}...")

else:
    print(f"\nExpected 1 stubborn token in Adam main class, found {len(adam_main_live)}")

THE STUBBORN TOKEN

Byte ID: 13 (0x0d)
Character: \x0d
Type: CR
Occurrences in Gatsby: 0

--------------------------------------------------------------------------------
Context: How does this compare to the dead core?
--------------------------------------------------------------------------------
Dead tokens in main class: 50
Dead token occurrence range in Gatsby: 0 (by definition)

The stubborn token has 0 occurrence(s).
  → This token is EFFECTIVELY DEAD (zero gradients)!
  → Must have been miscategorized as 'live' somehow.
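
The "zero gradients" reasoning can be made concrete with a toy calculation. For an embedding used purely as an input lookup table, the gradient of a row is the sum of the upstream gradients at the positions where that token appears, so a token that never appears accumulates nothing. The sketch below uses plain Python lists; `embedding_row_grads` is a hypothetical helper, and it assumes untied input embeddings and no weight decay, neither of which is confirmed for Lil Gatsby here.

```python
def embedding_row_grads(token_ids, upstream_grads, vocab_size, dim):
    """Scatter-add upstream gradients into per-token embedding rows."""
    grads = [[0.0] * dim for _ in range(vocab_size)]
    for tid, grad in zip(token_ids, upstream_grads):
        for j in range(dim):
            grads[tid][j] += grad[j]
    return grads

# Toy batch: a short ASCII string with LF line endings and no CR (byte 13).
batch = list("In my younger\nyears".encode("ascii"))
dim = 4
# Stand-in upstream gradient of all ones at every position (an assumption).
upstream = [[1.0] * dim for _ in batch]

grads = embedding_row_grads(batch, upstream, vocab_size=128, dim=dim)

print("CR row grad: ", grads[13])         # never looked up, stays zero
print("'n' row grad:", grads[ord("n")])   # looked up, accumulates gradient
```

With an all-zero gradient at every step, an optimizer without weight decay has nothing to apply to that row, which is consistent with byte 13 staying with the dead core.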


## Done

In [15]:
print("✓ Stubborn token investigation complete")

✓ Stubborn token investigation complete
