# 1.4b: Pairwise Distances

This notebook computes pairwise L2 distances between tokens in the overdensity to reveal internal structure.

## The Question

We identified ~20,000 tokens in a spatial bounding box (1.4a), but this includes both:
- **The tight cluster** we're interested in (the "spike")
- **Normal tokens** that happen to pass through that region

To separate them, we need to look at **density**—not just where tokens are, but how close they are to each other.

## Method

We'll:
1. Extract vectors for all tokens in the bounded region
2. Compute pairwise L2 distances between them
3. Save the distance matrix for further analysis

Tokens that sit extremely close together (or exactly on top of each other) would indicate a genuinely dense cluster, as opposed to tokens at ordinary separations.

## Parameters

In [1]:
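```python
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

# PCA basis (must match 1.4a): 1-based PC numbers; a negative number flips the axis
NORTH_PC = 2
MERIDIAN_PC = 1
EQUINOX_PC = 3

# Bounding box (must match 1.4a): lat/lon in degrees, r = length of the 3-PC projection
LAT_MIN = -15
LAT_MAX = 5
LON_MIN = -10
LON_MAX = 20
R_MIN = 0.2
R_MAX = 0.5

# Distance computation
# Rows per chunk; each chunk builds a (CHUNK_SIZE, N_spike, d) float32 tensor (~7.8 GiB at 40)
CHUNK_SIZE = 40
```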

## Imports

In [2]:
```python
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path
from tqdm.auto import tqdm
```

## Detect Device

In [3]:
```python
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")
```

```
Using device: mps
```


## Load W and Compute Spherical Coordinates

This repeats the 1.4a computation so that we select exactly the same tokens.

In [4]:
```python
# Load W
tensor_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(tensor_path)["W"]
W = W_bf16.to(torch.float32)
N, d = W.shape

print(f"Loaded W: {W.shape}")
```

```
Loaded W: torch.Size([151936, 2560])
```


In [5]:
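```python
# PCA via eigendecomposition of the covariance matrix
print("Computing PCA...")
W_centered = W - W.mean(dim=0)
cov = (W_centered.T @ W_centered) / N                # (d, d) covariance
eigenvalues, eigenvectors = torch.linalg.eigh(cov)  # eigh returns ascending order
idx = torch.argsort(eigenvalues, descending=True)   # re-sort so PC1 has the largest variance
eigenvalues = eigenvalues[idx]
eigenvectors = eigenvectors[:, idx]                 # columns are the principal axes
print("✓ PCA computed")
```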

```
Computing PCA...
✓ PCA computed
```
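`torch.linalg.eigh` returns eigenvalues in ascending order, which is why the cell re-sorts them. The resulting axes should match the right singular vectors of the centered matrix, but only up to sign, because eigenvectors have no intrinsic orientation. The negative indices accepted by `get_pc_vector` offer one way to pin an orientation. Below is a NumPy sketch of that check on synthetic data (a stand-in for W, not the real matrix):

```python
import numpy as np

def pca_via_eigh(X):
    """Principal axes of X (rows = samples) from the covariance eigendecomposition,
    sorted by decreasing variance. Mirrors the notebook cell above."""
    Xc = X - X.mean(axis=0)
    cov = (Xc.T @ Xc) / X.shape[0]
    evals, evecs = np.linalg.eigh(cov)      # ascending order
    order = np.argsort(evals)[::-1]         # re-sort descending
    return evals[order], evecs[:, order]

rng = np.random.default_rng(0)
# Synthetic data with clearly separated variances per axis
X = rng.normal(size=(500, 6)) * np.array([5.0, 3.0, 2.0, 1.0, 0.5, 0.1])

evals, evecs = pca_via_eigh(X)

# Right singular vectors of the centered data should give the same axes
_, s, Vt = np.linalg.svd(X - X.mean(axis=0), full_matrices=False)
alignment = np.abs(np.sum(evecs * Vt.T, axis=0))   # |cos| between matching axes

print(evals[:3])
print(alignment)   # each entry should be ~1.0; the signs themselves may differ
```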


In [6]:
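```python
# Define basis
def get_pc_vector(pcs, index):
    """Return principal axis `index` (1-based) from the columns of `pcs`;
    a negative index returns the same axis with its sign flipped."""
    pc_num = abs(index) - 1
    vector = pcs[:, pc_num].clone()
    if index < 0:
        vector = -vector
    return vector

north = get_pc_vector(eigenvectors, NORTH_PC)
meridian = get_pc_vector(eigenvectors, MERIDIAN_PC)
equinox = get_pc_vector(eigenvectors, EQUINOX_PC)
```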
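A quick check of the indexing convention in `get_pc_vector`. This is a NumPy port of the helper (`.clone()` becomes `.copy()`), run on a made-up matrix whose columns stand in for eigenvectors:

```python
import numpy as np

def get_pc_vector(pcs, index):
    """NumPy port of the notebook helper: 1-based PC number, negative = flipped."""
    pc_num = abs(index) - 1          # PC1 -> column 0
    vector = pcs[:, pc_num].copy()   # copy so flipping never mutates pcs
    if index < 0:
        vector = -vector
    return vector

# Toy "eigenvector matrix": column j is filled with the value j + 1
pcs = np.tile(np.arange(1.0, 4.0), (3, 1))

print(get_pc_vector(pcs, 2))    # second column: [2. 2. 2.]
print(get_pc_vector(pcs, -2))   # second column, flipped: [-2. -2. -2.]
```

The helper has no guard for index `0`. `abs(0) - 1` is `-1`, which silently selects the *last* column, so the basis parameters should never be zero.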

In [7]:
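```python
# Spherical coordinates
print("Computing spherical coordinates...")
# Project the (uncentered) vectors onto the three chosen PC axes
x = W @ meridian
y = W @ equinox
z = W @ north

# r is the length of the 3-D projection, not the full 2560-D norm
r = torch.sqrt(x**2 + y**2 + z**2)
lat_rad = torch.asin(torch.clamp(z / r, -1, 1))  # clamp guards against rounding just outside [-1, 1]
lat_deg = torch.rad2deg(lat_rad)
lon_rad = torch.atan2(y, x)
lon_deg = torch.rad2deg(lon_rad)
print("✓ Spherical coordinates computed")
```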

```
Computing spherical coordinates...
✓ Spherical coordinates computed
```
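Written out, the cell above computes the following for each row $\mathbf{w}$ of W (uncentered) and the unit axes $\hat{\mathbf{m}}$, $\hat{\mathbf{e}}$, $\hat{\mathbf{n}}$ (meridian, equinox, north):

$$
\begin{aligned}
x &= \mathbf{w}\cdot\hat{\mathbf{m}}, \qquad y = \mathbf{w}\cdot\hat{\mathbf{e}}, \qquad z = \mathbf{w}\cdot\hat{\mathbf{n}},\\
r &= \sqrt{x^2 + y^2 + z^2},\\
\text{lat} &= \arcsin\!\left(z / r\right) \in [-90^\circ,\, 90^\circ],\\
\text{lon} &= \operatorname{atan2}(y,\, x) \in [-180^\circ,\, 180^\circ].
\end{aligned}
$$

The three axes are orthonormal eigenvectors, so $r$ is the length of the projection of $\mathbf{w}$ onto their span. The `R_MIN`/`R_MAX` cut therefore bounds that projected length, not the full norm $\lVert\mathbf{w}\rVert$.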


In [8]:
```python
# Filter by bounding box
mask = (
    (lat_deg >= LAT_MIN) & (lat_deg <= LAT_MAX) &
    (lon_deg >= LON_MIN) & (lon_deg <= LON_MAX) &
    (r >= R_MIN) & (r <= R_MAX)
)

spike_token_ids = torch.where(mask)[0]

print(f"\n✓ Found {len(spike_token_ids):,} tokens in bounding box")
```


```
✓ Found 20,373 tokens in bounding box
```


## Extract Spike Vectors

In [9]:
```python
# Extract vectors for spike tokens and move to device
spike_vecs = W[spike_token_ids].to(device)

print(f"Spike vectors: {spike_vecs.shape}")
print(f"  {len(spike_token_ids):,} tokens")
print(f"  {d:,} dimensions")
print(f"  Device: {spike_vecs.device}")
```

```
Spike vectors: torch.Size([20373, 2560])
  20,373 tokens
  2,560 dimensions
  Device: mps:0
```


## Compute Pairwise Distances

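We'll compute exact L2 distances using chunked processing to manage memory. The distance matrix will be N_spike × N_spike, where N_spike is the number of spike tokens (~20,000). Each chunk of CHUNK_SIZE rows is compared against every spike vector by forming explicit differences. The resulting rows of distances are copied into a matrix held on the CPU.

This computation is expensive but necessary. We don't yet know whether there are duplicate vectors, so we need to examine all pairwise distances.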
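Here is a rough cost model for the cell below, assuming float32 (4 bytes per element) and the sizes printed above. Each iteration materializes a (CHUNK_SIZE, N_spike, d) difference tensor, and CHUNK_SIZE has to keep that tensor within device memory:

```python
import math

BYTES_PER_FLOAT32 = 4
GIB = 1024**3

N_spike = 20_373   # tokens in the bounding box (printed above)
d = 2_560          # embedding dimension
chunk_size = 40    # CHUNK_SIZE

# Intermediate tensor built per iteration: (chunk_size, N_spike, d)
diff_bytes = chunk_size * N_spike * d * BYTES_PER_FLOAT32

# Full output matrix kept on the CPU: (N_spike, N_spike)
dist_bytes = N_spike * N_spike * BYTES_PER_FLOAT32

n_iters = math.ceil(N_spike / chunk_size)
n_pairs = N_spike * (N_spike - 1) // 2

print(f"per-chunk diffs: {diff_bytes / GIB:.2f} GiB")
print(f"distance matrix: {dist_bytes / GIB:.2f} GiB")
print(f"iterations: {n_iters}, unique pairs: {n_pairs:,}")
```

These figures line up with the 510 iterations, the 1.55 (GiB) matrix size and the 207,519,378 pairs that the notebook reports below. Halving CHUNK_SIZE roughly halves the per-chunk memory but doubles the number of iterations.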

In [10]:
```python
print(f"Computing exact pairwise L2 distances...\n")
print(f"Using chunked algorithm with chunk_size={CHUNK_SIZE}\n")

# Exact distance computation using chunked processing
N_spike = len(spike_vecs)
dists = torch.zeros((N_spike, N_spike), dtype=torch.float32, device='cpu')

with torch.no_grad():
    for i in tqdm(range(0, N_spike, CHUNK_SIZE), desc="Computing distances"):
        end_i = min(i + CHUNK_SIZE, N_spike)
        chunk = spike_vecs[i:end_i]
        
        # Compute differences: (chunk_size, N_spike, d)
        diffs = chunk.unsqueeze(1) - spike_vecs.unsqueeze(0)
        
        # Compute L2 norm along dimension 2, move to CPU
        dists[i:end_i] = torch.linalg.vector_norm(diffs, ord=2, dim=2).cpu()

print(f"\n✓ Distance matrix computed")
print(f"  Shape: {dists.shape}")
# Size is computed in units of 1024**3 bytes (GiB)
print(f"  Memory: {dists.element_size() * dists.nelement() / 1024**3:.2f} GB")
```

```
Computing exact pairwise L2 distances...

Using chunked algorithm with chunk_size=40

✓ Distance matrix computed
  Shape: torch.Size([20373, 20373])
  Memory: 1.55 GB
```
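Why build explicit differences instead of using one matrix multiply? Many fast pairwise-distance routines use the expansion ‖a − b‖² = ‖a‖² + ‖b‖² − 2a·b. That form rounds differently and is not guaranteed to return exactly 0 for identical rows. Subtracting first makes a − a exactly zero, which is what makes the exact-zero count in the next section meaningful. PyTorch's `torch.cdist` exposes the same trade-off through its `compute_mode` argument (`'donot_use_mm_for_euclid_dist'` avoids the expansion), which might be worth checking as an alternative.

Below is a NumPy sketch of the chunked loop next to the expansion, run on synthetic data with planted duplicates. The function names are illustrative and are not part of the notebook:

```python
import numpy as np

def chunked_pairwise_l2(vecs, chunk_size):
    """Exact pairwise L2 distances from explicit differences,
    one block of rows at a time (same strategy as the cell above)."""
    n = len(vecs)
    out = np.zeros((n, n), dtype=np.float32)
    for i in range(0, n, chunk_size):
        chunk = vecs[i:i + chunk_size]
        diffs = chunk[:, None, :] - vecs[None, :, :]      # (rows, n, d)
        out[i:i + chunk_size] = np.linalg.norm(diffs, axis=2)
    return out

def gram_pairwise_l2(vecs):
    """Matmul-based distances via ||a||^2 + ||b||^2 - 2 a.b (fast, not exact)."""
    sq = np.sum(vecs * vecs, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * (vecs @ vecs.T)
    return np.sqrt(np.clip(d2, 0.0, None))               # clip negative rounding error

rng = np.random.default_rng(0)
vecs = rng.normal(size=(50, 64)).astype(np.float32)
vecs[10] = vecs[3]                    # plant exact duplicates
vecs[20] = vecs[3]

exact = chunked_pairwise_l2(vecs, chunk_size=7)
approx = gram_pairwise_l2(vecs)

print(exact[3, 10], exact[3, 20], exact[10, 20])   # 0.0 0.0 0.0
print(np.abs(exact - approx).max())   # small, but not guaranteed to be 0
```

The explicit-difference loop needs far more memory per step, which is why the notebook has to chunk it. In return, identical vectors come out at exactly 0.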


## Distance Statistics

In [11]:
```python
print(f"Distance statistics (excluding diagonal):\n")

# Get upper triangle (excluding diagonal)
upper_tri_mask = torch.triu(torch.ones_like(dists, dtype=torch.bool), diagonal=1)
dists_upper = dists[upper_tri_mask]

print(f"  Number of pairs: {len(dists_upper):,}")
print(f"  Min: {dists_upper.min():.10f}")
print(f"  Median: {dists_upper.median():.10f}")
print(f"  Mean: {dists_upper.mean():.10f}")
print(f"  Max: {dists_upper.max():.10f}")
print()

# Count exact zeros
n_zeros = (dists_upper == 0).sum().item()
print(f"Exact zeros (off-diagonal): {n_zeros:,}")
if n_zeros > 0:
    print(f"  ⚠️  Found {n_zeros:,} token pairs with distance = 0")
    print(f"     Multiple tokens occupy the same point in space!")
else:
    print(f"  ✓ No exact duplicates found")
```

```
Distance statistics (excluding diagonal):

  Number of pairs: 207,519,378
  Min: 0.0000000000
  Median: 1.2463527918
  Mean: 1.1983120441
  Max: 2.0540900230

Exact zeros (off-diagonal): 651,034
  ⚠️  Found 651,034 token pairs with distance = 0
     Multiple tokens occupy the same point in space!
```
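One way to read the exact-zero count: if k tokens share one identical vector, that group contributes k(k − 1)/2 zero-distance pairs. Assuming a zero distance means the two vectors are identical, the 651,034 zero pairs are the sum of k(k − 1)/2 over all duplicate groups. That number is not itself of the form k(k − 1)/2; it falls between 650,370 (k = 1141) and 651,511 (k = 1142). So the zeros cannot all come from a single group.

A cheaper cross-check groups identical rows with `np.unique(..., axis=0, return_counts=True)` and sums those terms without building the full matrix. It is sketched here on toy data. Applying it to the real vectors (e.g. via `spike_vecs.cpu().numpy()`) is an assumption of this sketch, not something the notebook does:

```python
import numpy as np

def zero_pairs_from_duplicates(vecs):
    """Count off-diagonal pairs at distance exactly 0 by grouping identical rows.
    A group of k identical rows contributes k * (k - 1) / 2 pairs."""
    _, counts = np.unique(vecs, axis=0, return_counts=True)
    return int(np.sum(counts * (counts - 1) // 2))

def zero_pairs_brute_force(vecs):
    """Reference count from the full distance matrix (upper triangle only)."""
    diffs = vecs[:, None, :] - vecs[None, :, :]
    dists = np.linalg.norm(diffs, axis=2)
    iu = np.triu_indices(len(vecs), k=1)
    return int(np.sum(dists[iu] == 0))

rng = np.random.default_rng(1)
base = rng.normal(size=(5, 8)).astype(np.float32)
# Row 0 repeated 3 times (3 pairs), row 1 twice (1 pair), rows 2-4 unique
vecs = base[[0, 0, 0, 1, 1, 2, 3, 4]]

print(zero_pairs_from_duplicates(vecs))  # 4
print(zero_pairs_brute_force(vecs))      # 4
```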


## Save Distance Matrix

In [12]:
```python
# Save distance matrix and metadata
output_dir = Path(f"../tensors/{MODEL_NAME}")
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "1.4b_overdensity_distances.safetensors"

save_file({
    'distances': dists,
    'spike_token_ids': spike_token_ids.cpu(),
}, output_path)

print(f"\n✓ Saved to {output_path}")
print(f"  Size: {output_path.stat().st_size / 1024**2:.1f} MB")
```


```
✓ Saved to ../tensors/Qwen3-4B-Instruct-2507/1.4b_overdensity_distances.safetensors
  Size: 1583.5 MB
```
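Row and column i of the saved `distances` matrix correspond to `spike_token_ids[i]`. `torch.where` returns indices in ascending order, so a token ID can be mapped back to its row with a binary search. Below is a hypothetical helper (`nearest_in_spike` is not part of the notebook) that sketches a nearest-neighbor lookup in NumPy on a toy matrix:

```python
import numpy as np

def nearest_in_spike(distances, spike_token_ids, token_id, k=5):
    """Return (token_ids, distances) of the k nearest spike tokens to token_id,
    excluding token_id itself. Assumes spike_token_ids is sorted ascending."""
    row = np.searchsorted(spike_token_ids, token_id)
    if row >= len(spike_token_ids) or spike_token_ids[row] != token_id:
        raise KeyError(f"token {token_id} is not in the spike set")
    d = distances[row].copy()
    d[row] = np.inf                      # skip the token itself
    order = np.argsort(d, kind="stable")[:k]
    return spike_token_ids[order], d[order]

# Toy stand-ins for the saved tensors: 4 tokens on a line at positions 0, 1, 1, 3
ids = np.array([7, 42, 100, 512])
pos = np.array([0.0, 1.0, 1.0, 3.0])
dists = np.abs(pos[:, None] - pos[None, :])

near_ids, near_d = nearest_in_spike(dists, ids, 42, k=2)
print(near_ids, near_d)   # ids 100 then 7, at distances 0.0 and 1.0
```

With the real file, the inputs would come from `load_file(output_path)` (from `safetensors.torch`), followed by `.numpy()` on the `distances` and `spike_token_ids` tensors.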


## Summary

We've computed the full pairwise distance matrix for all 20,373 tokens in the overdensity. It shows:

- **Minimum distance:** exactly 0, so some tokens share identical vectors.
- **Exact zeros:** 651,034 of the 207,519,378 off-diagonal pairs are at distance 0 (multiple tokens at the same point).
- **Distance distribution:** median ≈ 1.25, mean ≈ 1.20, max ≈ 2.05. Whether the duplicates form a tight cluster separated from normal tokens is still to be determined.

The distance matrix is saved for further analysis—we can now visualize the distribution, identify clusters, and separate the true "spike" from tokens that just happen to pass through the region.