# 1.4a: Identify Spike Tokens

This notebook identifies which tokens belong to the low-norm overdensity ("the spike") by filtering based on spherical coordinates.

## The Question

We've established that there's a concentrated overdensity of tokens near the origin at specific coordinates. But **which tokens** are in this cluster?

Once we have the token IDs, we can:
- Decode them to see what text they represent
- Look for patterns (language scripts, special characters, rare tokens)
- Understand why these tokens are clustered together

## Method

We'll:
1. Compute spherical coordinates (same as 1.3a/1.3b)
2. Define a bounding box around the overdensity in (latitude, longitude, radius) space
3. Filter tokens that fall within this box
4. Extract their token IDs and decode them
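
The four steps above can be sketched on a handful of toy points before running them on the full matrix. This is a minimal standard-library illustration, not the notebook's implementation: the helper names `to_spherical` and `in_box` and the sample points are made up, and the box limits simply mirror the values used in the Parameters cell.

```python
import math

def to_spherical(x, y, z):
    """Return (lat_deg, lon_deg, r); latitude is measured from the x-y plane."""
    r = math.sqrt(x * x + y * y + z * z)
    if r == 0.0:
        # Direction is undefined at the origin; returning zeros is an arbitrary convention.
        return 0.0, 0.0, 0.0
    lat = math.degrees(math.asin(max(-1.0, min(1.0, z / r))))
    lon = math.degrees(math.atan2(y, x))
    return lat, lon, r

def in_box(lat, lon, r, lat_range, lon_range, r_range):
    """True if the point lies inside the (latitude, longitude, radius) box."""
    return (lat_range[0] <= lat <= lat_range[1]
            and lon_range[0] <= lon <= lon_range[1]
            and r_range[0] <= r <= r_range[1])

# Toy (x, y, z) points, invented for illustration
points = [(0.3, 0.0, 0.0), (0.0, 0.3, 0.0), (0.0, 0.0, 0.3), (2.0, 0.1, -0.1)]

selected = [i for i, p in enumerate(points)
            if in_box(*to_spherical(*p), (-15, 5), (-10, 20), (0.2, 0.5))]
print(selected)  # indices of points inside the box
```

Only the first point passes: the second fails on longitude, the third on latitude, and the fourth on radius.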

## Parameters

In [1]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"
MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"  # For tokenizer

# PCA basis selection (match 1.3a/1.3b)
NORTH_PC = 2
MERIDIAN_PC = 1
EQUINOX_PC = 3

# Overdensity bounding box (adjust based on visual inspection)
LAT_MIN = -15    # degrees
LAT_MAX = 5      # degrees
LON_MIN = -10    # degrees
LON_MAX = 20     # degrees
R_MIN = 0.2      # distance units
R_MAX = 0.5      # distance units

# How many example tokens to decode
NUM_EXAMPLES = 50

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file
from transformers import AutoTokenizer
from pathlib import Path
from collections import Counter

## Load W

In [3]:
# Load W in bfloat16
tensor_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(tensor_path)["W"]

print(f"Loaded W from {tensor_path}")
print(f"  Shape: {W_bf16.shape}")
print(f"  Dtype: {W_bf16.dtype}")

# Convert to float32
W = W_bf16.to(torch.float32)

N, d = W.shape
print(f"\nToken space: {N:,} tokens in {d:,} dimensions")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16

Token space: 151,936 tokens in 2,560 dimensions


## Compute PCA

In [4]:
print("Computing PCA...")

W_centered = W - W.mean(dim=0)
cov = (W_centered.T @ W_centered) / N
eigenvalues, eigenvectors = torch.linalg.eigh(cov)

# Sort descending
idx = torch.argsort(eigenvalues, descending=True)
eigenvalues = eigenvalues[idx]
eigenvectors = eigenvectors[:, idx]

print("✓ PCA computed")

Computing PCA...
✓ PCA computed
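
Because only three principal components are used to build the sphere, it can be useful to check how much variance they capture. The sketch below repeats the same eigendecomposition recipe in NumPy on made-up toy data and reports the share of variance in the first three components. The data, shapes, and variable names are assumptions for illustration and say nothing about the real `W`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy data (made up): 500 points in 6 dimensions with unequal per-axis scale
X = rng.normal(size=(500, 6)) * np.array([5.0, 3.0, 2.0, 0.5, 0.5, 0.5])

Xc = X - X.mean(axis=0)
cov = (Xc.T @ Xc) / X.shape[0]
eigenvalues, eigenvectors = np.linalg.eigh(cov)  # eigh returns ascending eigenvalues

order = np.argsort(eigenvalues)[::-1]            # reorder to descending
eigenvalues = eigenvalues[order]
eigenvectors = eigenvectors[:, order]

explained = eigenvalues / eigenvalues.sum()      # fraction of variance per component
top3_share = float(explained[:3].sum())
print(f"Variance captured by first 3 PCs: {top3_share:.1%}")
```

In the notebook itself, the equivalent quantity would be `eigenvalues[:3].sum() / eigenvalues.sum()` on the sorted eigenvalues.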


## Define Spherical Basis

In [5]:
def get_pc_vector(pcs, index):
    """Get PC vector by index (1-indexed), with sign flip for negative indices."""
    pc_num = abs(index) - 1
    vector = pcs[:, pc_num].clone()
    if index < 0:
        vector = -vector
    return vector

north = get_pc_vector(eigenvectors, NORTH_PC)
meridian = get_pc_vector(eigenvectors, MERIDIAN_PC)
equinox = get_pc_vector(eigenvectors, EQUINOX_PC)

print("Spherical basis defined:")
print(f"  North: PC{NORTH_PC}")
print(f"  Meridian: PC{MERIDIAN_PC}")
print(f"  Equinox: PC{EQUINOX_PC}")

Spherical basis defined:
  North: PC2
  Meridian: PC1
  Equinox: PC3
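
Eigenvectors are only defined up to sign, so the orientation of each axis (and therefore the sign of latitude and the direction of longitude) can differ between runs or platforms. The negative-index option in `get_pc_vector` lets you flip an axis by hand; the sketch below shows one possible automatic convention instead. The helper `canonical_sign` is hypothetical and is not used anywhere in the notebook.

```python
import numpy as np

def canonical_sign(vectors):
    """Flip each column so that its largest-magnitude entry is positive."""
    vectors = np.asarray(vectors, dtype=float)
    pivot_rows = np.argmax(np.abs(vectors), axis=0)
    pivots = vectors[pivot_rows, np.arange(vectors.shape[1])]
    signs = np.where(pivots < 0, -1.0, 1.0)
    return vectors * signs  # broadcasts one sign per column

# Two equally valid eigenvector matrices that differ only in the sign of column 0
V = np.array([[0.6, -0.8],
              [0.8,  0.6]])
flipped = V * np.array([-1.0, 1.0])

same = bool(np.allclose(canonical_sign(V), canonical_sign(flipped)))
print(same)  # both inputs map to the same orientation
```

If a convention like this were applied, it would need to match whatever was used when the bounding box was chosen, since flipping an axis moves the overdensity to different coordinates.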


## Project to Spherical Coordinates

In [6]:
print("Projecting to spherical coordinates...\n")

# Project onto basis
x = W @ meridian
y = W @ equinox
z = W @ north

# Spherical coordinates
r = torch.sqrt(x**2 + y**2 + z**2)
lat_rad = torch.asin(torch.clamp(z / r, -1, 1))
lat_deg = torch.rad2deg(lat_rad)
lon_rad = torch.atan2(y, x)
lon_deg = torch.rad2deg(lon_rad)

print("✓ Spherical coordinates computed")

Projecting to spherical coordinates...

✓ Spherical coordinates computed
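
Note that `r` here is the length of each token's projection onto the three chosen axes, not its norm in the full embedding space. Since the axes are orthonormal, the projected radius can never exceed the full norm. The sketch below checks this on random toy data; `W_toy` and `basis` are made-up stand-ins, not the notebook's tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
W_toy = rng.normal(size=(1000, 16))            # made-up stand-in for W

# Any three orthonormal directions work for the illustration
basis, _ = np.linalg.qr(rng.normal(size=(16, 3)))  # columns are orthonormal

coords = W_toy @ basis                         # (x, y, z) for every row
r_projected = np.linalg.norm(coords, axis=1)   # radius in the 3D subspace
full_norm = np.linalg.norm(W_toy, axis=1)      # norm in the full space

print(bool(np.all(r_projected <= full_norm + 1e-12)))
print(f"Median r / full norm: {np.median(r_projected / full_norm):.2f}")
```

A token with a small `r` may therefore still have a large full-space norm if most of its length lies outside the three selected components.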


## Filter by Bounding Box

In [7]:
print("Filtering tokens in bounding box:")
print(f"  Latitude: [{LAT_MIN}°, {LAT_MAX}°]")
print(f"  Longitude: [{LON_MIN}°, {LON_MAX}°]")
print(f"  Radius: [{R_MIN}, {R_MAX}]")
print()

# Create mask
mask = (
    (lat_deg >= LAT_MIN) & (lat_deg <= LAT_MAX) &
    (lon_deg >= LON_MIN) & (lon_deg <= LON_MAX) &
    (r >= R_MIN) & (r <= R_MAX)
)

# Extract token IDs
spike_token_ids = torch.where(mask)[0]

print(f"✓ Found {len(spike_token_ids):,} tokens in overdensity")
print(f"  ({len(spike_token_ids)/N*100:.2f}% of vocabulary)")
print()
print(f"First 20 token IDs: {spike_token_ids[:20].tolist()}")

Filtering tokens in bounding box:
  Latitude: [-15°, 5°]
  Longitude: [-10°, 20°]
  Radius: [0.2, 0.5]

✓ Found 20,373 tokens in overdensity
  (13.41% of vocabulary)

First 20 token IDs: [124, 125, 127, 128, 129, 138, 140, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187]
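
The mask above assumes `LON_MIN <= LON_MAX`, which holds for the current box. Longitude from `atan2` lies between -180° and 180°, so a box that straddled the ±180° seam would need a wraparound-aware comparison. The helper below is a hypothetical sketch of that case in plain Python and is not part of the notebook.

```python
def _normalize(angle_deg):
    """Map any angle to the interval [-180, 180)."""
    return (angle_deg + 180.0) % 360.0 - 180.0

def lon_in_range(lon_deg, lon_min, lon_max):
    """True if lon_deg lies in [lon_min, lon_max], allowing ranges that cross +/-180."""
    if lon_max - lon_min >= 360.0:
        return True  # the range covers the full circle
    lon, lo, hi = _normalize(lon_deg), _normalize(lon_min), _normalize(lon_max)
    if lo <= hi:
        return lo <= lon <= hi
    # Range wraps through the seam, e.g. from 170 degrees to -170 degrees
    return lon >= lo or lon <= hi

print(lon_in_range(5, -10, 20))       # ordinary range
print(lon_in_range(-175, 170, -170))  # range crossing the seam
```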


## Load Tokenizer and Decode

In [8]:
print(f"Loading tokenizer: {MODEL_ID}\n")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"✓ Tokenizer loaded (vocab size: {len(tokenizer):,})")

Loading tokenizer: Qwen/Qwen3-4B-Instruct-2507

✓ Tokenizer loaded (vocab size: 151,669)
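
The tokenizer reports 151,669 entries while `W` has 151,936 rows, so some rows have no corresponding tokenizer entry (likely padding of the embedding matrix, though the notebook does not confirm this). If any spike token IDs fall in that range, they cannot be decoded meaningfully. The sketch below separates the two groups; it assumes token IDs run contiguously from 0 to `len(tokenizer) - 1`, and `split_by_vocab` is a hypothetical helper.

```python
def split_by_vocab(token_ids, vocab_size):
    """Partition IDs into those the tokenizer covers and those beyond its vocabulary."""
    known = [t for t in token_ids if 0 <= t < vocab_size]
    beyond = [t for t in token_ids if t >= vocab_size]
    return known, beyond

W_ROWS = 151_936          # from the "Load W" output
TOKENIZER_SIZE = 151_669  # from the tokenizer output

print(f"Rows without a tokenizer entry: {W_ROWS - TOKENIZER_SIZE}")

# Example IDs, chosen to sit on both sides of the boundary
known, beyond = split_by_vocab([124, 125, 151_668, 151_669, 151_935], TOKENIZER_SIZE)
print(known, beyond)
```

In the notebook this could be applied to `spike_token_ids.tolist()` with `len(tokenizer)` as the boundary.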


## Examine Sample Tokens

In [9]:
print(f"\nDecoding first {NUM_EXAMPLES} spike tokens:\n")
print(f"{'Token ID':<10} {'Repr':<30} {'Description'}")
print("-" * 70)

for tid_int in spike_token_ids[:NUM_EXAMPLES].tolist():
    # Decode token
    try:
        token_str = tokenizer.decode([tid_int])
        token_repr = repr(token_str)[:28]  # Truncate for display
        
        # Classify (same categories as the statistics cell below)
        if token_str.strip() == '':
            desc = "(whitespace/empty)"
        elif all(ord(c) < 128 for c in token_str):
            desc = "ASCII"
        elif any('\u0e00' <= c <= '\u0e7f' for c in token_str):
            desc = "Thai"
        elif any('\u4e00' <= c <= '\u9fff' for c in token_str):
            desc = "CJK"
        elif any('\u0600' <= c <= '\u06ff' for c in token_str):
            desc = "Arabic"
        elif any('\u0400' <= c <= '\u04ff' for c in token_str):
            desc = "Cyrillic"
        else:
            desc = "Other Unicode"
            
    except Exception as e:
        token_repr = f"(decode error: {e})"
        desc = "Error"
    
    print(f"{tid_int:<10} {token_repr:<30} {desc}")


Decoding first 50 spike tokens:

Token ID   Repr                           Description
----------------------------------------------------------------------
124        '�'                            Other Unicode
125        '�'                            Other Unicode
127        '�'                            Other Unicode
128        '�'                            Other Unicode
129        '�'                            Other Unicode
138        '�'                            Other Unicode
140        '�'                            Other Unicode
175        '�'                            Other Unicode
176        '�'                            Other Unicode
177        '�'                            Other Unicode
178        '�'                            Other Unicode
179        '�'                            Other Unicode
180        '�'                            Other Unicode
181        '�'                            Other Unicode
182        '�'                            Other Unicode
... (output truncated)
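
Every token shown decodes to `'�'` (U+FFFD, the Unicode replacement character). With a byte-level tokenizer this usually means the token is a fragment of a multi-byte UTF-8 sequence rather than a complete character, which the "Other Unicode" label does not reveal. The sketch below shows how such fragments can be flagged; `is_partial_utf8` is a hypothetical helper, and the byte values are chosen purely for illustration.

```python
REPLACEMENT = "\ufffd"  # U+FFFD, emitted when bytes cannot be decoded

def is_partial_utf8(decoded):
    """True if a decoded string contains the replacement character."""
    return REPLACEMENT in decoded

fragment = bytes([0xE0, 0xB8])        # first two bytes of a three-byte Thai character
complete = bytes([0xE0, 0xB8, 0x81])  # full UTF-8 encoding of U+0E01

fragment_flag = is_partial_utf8(fragment.decode("utf-8", errors="replace"))
complete_flag = is_partial_utf8(complete.decode("utf-8", errors="replace"))
print(fragment_flag, complete_flag)
```

To see the raw vocabulary entry behind such a token in the notebook, `tokenizer.convert_ids_to_tokens(tid_int)` can be printed next to the decoded string.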

## Character Type Statistics

In [10]:
print("\nAnalyzing character types across all spike tokens...\n")

categories = []

for tid in spike_token_ids:
    tid_int = tid.item()
    try:
        token_str = tokenizer.decode([tid_int])
        
        if token_str.strip() == '':
            categories.append("Whitespace/Empty")
        elif all(ord(c) < 128 for c in token_str):
            categories.append("ASCII")
        elif any('\u0e00' <= c <= '\u0e7f' for c in token_str):
            categories.append("Thai")
        elif any('\u4e00' <= c <= '\u9fff' for c in token_str):
            categories.append("CJK")
        elif any('\u0600' <= c <= '\u06ff' for c in token_str):
            categories.append("Arabic")
        elif any('\u0400' <= c <= '\u04ff' for c in token_str):
            categories.append("Cyrillic")
        else:
            categories.append("Other Unicode")
    except Exception:
        categories.append("Decode Error")

# Count and display
category_counts = Counter(categories)

print(f"{'Category':<20} {'Count':<10} {'Percentage'}")
print("-" * 50)
for category, count in category_counts.most_common():
    pct = count / len(spike_token_ids) * 100
    print(f"{category:<20} {count:<10} {pct:>6.2f}%")

print()
print(f"Total spike tokens: {len(spike_token_ids):,}")


Analyzing character types across all spike tokens...

Category             Count      Percentage
--------------------------------------------------
ASCII                9337        45.83%
Other Unicode        6848        33.61%
Thai                 1666         8.18%
CJK                  1600         7.85%
Whitespace/Empty     362          1.78%
Arabic               329          1.61%
Cyrillic             231          1.13%

Total spike tokens: 20,373
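
A third of the spike tokens land in "Other Unicode" because only a few script ranges are hardcoded. One way to break that bucket down, sketched below, is to read the script from each character's Unicode name using the standard `unicodedata` module. The helpers `script_of` and `dominant_script` are hypothetical, and taking the first word of the character name is a rough heuristic rather than a proper script lookup.

```python
import unicodedata
from collections import Counter

def script_of(ch):
    """First word of the Unicode character name, e.g. 'THAI', 'CYRILLIC', 'CJK'."""
    name = unicodedata.name(ch, "")
    return name.split(" ")[0] if name else "UNNAMED"

def dominant_script(token_str):
    """Most common script among the non-ASCII, non-space characters of a token."""
    counts = Counter(script_of(c) for c in token_str
                     if not c.isspace() and ord(c) >= 128)
    return counts.most_common(1)[0][0] if counts else "ASCII/whitespace"

for sample in ["hello", "ไทย", "привет", "中文", "한국", "\ufffd"]:
    print(repr(sample), dominant_script(sample))
```

With this approach the replacement character shows up as its own `REPLACEMENT` category instead of being mixed in with genuine scripts.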


## Summary

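We've identified the tokens in the low-norm overdensity by filtering on spherical coordinates. With the bounding box above, 20,373 tokens (13.41% of the vocabulary) fall inside it.

**Key observations:**
- The box holds a mix of scripts rather than a single one: ASCII (45.83%) and Other Unicode (33.61%) account for most tokens, followed by Thai (8.18%), CJK (7.85%), Whitespace/Empty (1.78%), Arabic (1.61%), and Cyrillic (1.13%).
- The lowest-ID spike tokens shown above all decode to `'�'` (U+FFFD), so at least part of the Other Unicode bucket consists of tokens that do not decode to valid text on their own.

**Questions still open:**
- Are these rare/unused tokens?
- Do they have something in common semantically?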

This bridges from geometric analysis ("there's a cluster") to semantic analysis ("the cluster contains Thai script tokens").
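
Raw percentages inside the box do not show whether a script is over-represented, because the vocabulary as a whole is not evenly split between scripts. A possible follow-up, sketched below, is to compare each category's share in the spike with its share in the full vocabulary. The `enrichment` helper and the toy labels are made up for illustration; no such comparison was run in this notebook.

```python
from collections import Counter

def enrichment(subset_labels, all_labels):
    """Ratio of each category's share in the subset to its share overall.

    Assumes every category in subset_labels also appears in all_labels.
    """
    sub, full = Counter(subset_labels), Counter(all_labels)
    n_sub, n_full = len(subset_labels), len(all_labels)
    return {cat: (sub[cat] / n_sub) / (full[cat] / n_full) for cat in sub}

# Toy labels (invented): a 100-token vocabulary and a 20-token subset
all_labels = ["ASCII"] * 80 + ["Thai"] * 5 + ["CJK"] * 15
spike_labels = ["ASCII"] * 10 + ["Thai"] * 4 + ["CJK"] * 6

ratios = enrichment(spike_labels, all_labels)
for cat, ratio in sorted(ratios.items(), key=lambda kv: -kv[1]):
    print(f"{cat:<8} {ratio:.2f}x")
```

A ratio above 1 means the category is more common in the subset than in the vocabulary overall; in the toy data, Thai is the most enriched category even though ASCII has the largest raw count.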