# 08.4b: Pairwise Distance Matrices

**Compute pairwise distance matrices across all training snapshots**

With only 128 tokens, each snapshot has 128 × 128 = 16,384 pairwise distances per metric (8,128 unique pairs off the diagonal), so computing them is cheap. We compute distance matrices for all 5,001 snapshots and save them for reuse across several analysis notebooks.

## Distance Metrics

For each snapshot, we compute two distance metrics:

1. **Euclidean (L2)**: Standard geometric distance
   $$d_2(u, v) = \sqrt{\sum_i (u_i - v_i)^2}$$

2. **Chebyshev (L∞)**: Maximum absolute difference across dimensions
   $$d_\infty(u, v) = \max_i |u_i - v_i|$$
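A minimal worked example of the two definitions in plain Python on two 3-dimensional vectors. The helper names `euclidean` and `chebyshev` are illustrative, not part of the notebook. It also checks the standard norm inequality $d_\infty \le d_2 \le \sqrt{n}\, d_\infty$ for $n$ dimensions, which bounds how far the two metrics can disagree.

```python
import math

def euclidean(u, v):
    """L2 distance: square root of the summed squared differences."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def chebyshev(u, v):
    """L-infinity distance: largest absolute per-dimension difference."""
    return max(abs(a - b) for a, b in zip(u, v))

u = [1.0, 2.0, 3.0]
v = [4.0, 6.0, 3.0]  # per-dimension differences: 3, 4, 0

d2 = euclidean(u, v)    # sqrt(9 + 16 + 0) = 5.0
dinf = chebyshev(u, v)  # max(3, 4, 0) = 4.0

# The two metrics always satisfy d_inf <= d_2 <= sqrt(n) * d_inf
n = len(u)
assert dinf <= d2 <= math.sqrt(n) * dinf
print(d2, dinf)  # 5.0 4.0
```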

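Chebyshev distance is particularly useful for detecting quantization neighbors. Two tokens differ by less than a threshold ε in every dimension exactly when their Chebyshev distance is below ε, so thresholding the Chebyshev matrix at the bfloat16 quantization step (the spacing between adjacent bfloat16 values at the relevant magnitude) picks out these pairs directly. A Euclidean threshold cannot express this condition: a pair that is within ε in every one of 64 dimensions can still have an L2 distance of nearly √64 · ε = 8ε.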
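A minimal sketch of that thresholding step, assuming the Chebyshev matrix for one snapshot is available as a NumPy array. The function names `quantization_neighbors` and `bfloat16_spacing` and the threshold `eps` are hypothetical. `bfloat16_spacing` assumes bfloat16's 8-bit significand (7 stored bits plus the implicit leading bit), so the gap between adjacent representable values near a nonzero magnitude $|x|$ is $2^{\lfloor \log_2 |x| \rfloor - 7}$.

```python
import math
import numpy as np

def bfloat16_spacing(x):
    """Gap between adjacent bfloat16 values near x (assumes 8 significand bits, x != 0)."""
    _, exp = math.frexp(x)           # x = m * 2**exp with 0.5 <= |m| < 1
    return math.ldexp(1.0, exp - 8)  # 2**(floor(log2|x|) - 7)

def quantization_neighbors(cheb, eps):
    """Return unique pairs (i, j), i < j, whose Chebyshev distance is below eps.

    cheb: (V, V) symmetric Chebyshev distance matrix for one snapshot.
    A pair qualifies exactly when every coordinate differs by less than eps.
    """
    i, j = np.triu_indices(cheb.shape[0], k=1)  # strict upper triangle
    close = cheb[i, j] < eps
    return list(zip(i[close].tolist(), j[close].tolist()))

# Toy example: 4 tokens in 2 dimensions
emb = np.array([[0.50, 0.25],
                [0.50, 0.25],    # identical to token 0
                [0.501, 0.25],   # differs from token 0 by about 0.001 in one coordinate
                [0.75, -0.25]], dtype=np.float32)
cheb = np.abs(emb[None, :, :] - emb[:, None, :]).max(axis=2)

eps = bfloat16_spacing(0.5)  # 2**-8 = 0.00390625
print(quantization_neighbors(cheb, eps))  # [(0, 1), (0, 2), (1, 2)]
```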

## Output

Saves two tensors:
- `euclidean_distances`: (5001, 128, 128) float32
- `chebyshev_distances`: (5001, 128, 128) float32

Total storage: ~655 MB (about 328 MB per tensor)
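The storage figures follow directly from the tensor shapes. A quick check of the arithmetic for both candidate storage dtypes, using MB = 10^6 bytes as the notebook's printouts do:

```python
n_snapshots, vocab_size = 5001, 128
elements = n_snapshots * vocab_size * vocab_size  # entries per distance tensor

for dtype, nbytes in [("float16", 2), ("float32", 4)]:
    per_tensor_mb = elements * nbytes / 1e6
    # Two tensors are saved: Euclidean and Chebyshev
    print(f"{dtype}: {per_tensor_mb:.2f} MB per tensor, {2 * per_tensor_mb:.2f} MB total")
# float16: 163.87 MB per tensor, 327.75 MB total
# float32: 327.75 MB per tensor, 655.49 MB total
```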

## Parameters

In [1]:
# Input: consolidated embedding history tensor
EMBEDDING_DIR = "../data/embeddings_128vocab_qweninit"  # directory holding this run's embedding history
EMBEDDING_FILE = "embedding_evolution.safetensors"
EMBEDDING_KEY = "embedding_history"

# Output: distance matrices
OUTPUT_FILE = "pairwise_distances.safetensors"

# Training run parameters (for validation)
EXPECTED_STEPS = 5001
VOCAB_SIZE = 128
HIDDEN_DIM = 64

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path
from tqdm.auto import tqdm

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

## Load Embedding History

In [3]:
embedding_dir = Path(EMBEDDING_DIR)
embedding_path = embedding_dir / EMBEDDING_FILE

print(f"Loading embedding history from: {embedding_path}")

data = load_file(embedding_path)
embedding_history = data[EMBEDDING_KEY]
n_snapshots, vocab_size, hidden_dim = embedding_history.shape

print(f"\n✓ Embedding history loaded")
print(f"Shape: {embedding_history.shape}")
print(f"Snapshots: {n_snapshots:,}")
print(f"Vocabulary size: {vocab_size:,}")
print(f"Hidden dimension: {hidden_dim:,}")
print(f"Memory footprint: {embedding_history.element_size() * embedding_history.numel() / 1e6:.2f} MB")

# Validate dimensions
if n_snapshots != EXPECTED_STEPS:
    print(f"\n⚠ WARNING: Snapshot count mismatch! Expected {EXPECTED_STEPS:,}, got {n_snapshots:,}")
else:
    print(f"\n✓ Snapshot count matches expectation")

if (vocab_size, hidden_dim) != (VOCAB_SIZE, HIDDEN_DIM):
    raise ValueError(f"Unexpected dimensions: expected ({VOCAB_SIZE}, {HIDDEN_DIM}), got ({vocab_size}, {hidden_dim})")

Loading embedding history from: ../data/embeddings_128vocab_qweninit/embedding_evolution.safetensors

✓ Embedding history loaded
Shape: torch.Size([5001, 128, 64])
Snapshots: 5,001
Vocabulary size: 128
Hidden dimension: 64
Memory footprint: 81.94 MB

✓ Snapshot count matches expectation


## Allocate Distance Matrices

Pre-allocate tensors for both distance metrics. They are stored in float32, the same precision used for the computation; casting to float16 before saving would halve the file size if storage becomes a concern.

In [4]:
# Allocate distance matrices (float32, matching the precision of the computation)
euclidean_distances = torch.zeros((n_snapshots, vocab_size, vocab_size), dtype=torch.float32)
chebyshev_distances = torch.zeros((n_snapshots, vocab_size, vocab_size), dtype=torch.float32)

print(f"Allocated distance matrices:")
print(f"  Euclidean: {euclidean_distances.shape} ({euclidean_distances.element_size() * euclidean_distances.numel() / 1e6:.2f} MB)")
print(f"  Chebyshev: {chebyshev_distances.shape} ({chebyshev_distances.element_size() * chebyshev_distances.numel() / 1e6:.2f} MB)")
# Use each tensor's actual element size rather than assuming 2 bytes per entry
total_bytes = (euclidean_distances.element_size() * euclidean_distances.numel()
               + chebyshev_distances.element_size() * chebyshev_distances.numel())
print(f"  Total: {total_bytes / 1e6:.2f} MB")

Allocated distance matrices:
  Euclidean: torch.Size([5001, 128, 128]) (327.75 MB)
  Chebyshev: torch.Size([5001, 128, 128]) (327.75 MB)
  Total: 655.49 MB
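Whether float16 storage would lose meaningful precision for these distances can be checked empirically before committing to it. A minimal sketch, assuming a float32 NumPy array of distances; `float16_roundtrip_error` is a hypothetical helper and the random values are only a stand-in for a real snapshot. For values in float16's normal range, round-to-nearest keeps the relative error within 2^-11 (about 0.05%).

```python
import numpy as np

def float16_roundtrip_error(dist):
    """Max absolute and max relative error from storing float32 distances as float16."""
    dist = np.asarray(dist, dtype=np.float32)
    back = dist.astype(np.float16).astype(np.float32)  # round-trip through float16
    abs_err = np.abs(back - dist)
    nonzero = dist > 0  # relative error is undefined for zero distances
    max_rel = float((abs_err[nonzero] / dist[nonzero]).max()) if nonzero.any() else 0.0
    return float(abs_err.max()), max_rel

# Random stand-in distances on the same order as the final snapshot's values
rng = np.random.default_rng(0)
stand_in = rng.uniform(0.01, 2.0, size=(128, 128)).astype(np.float32)
max_abs, max_rel = float16_roundtrip_error(stand_in)
print(f"max abs error {max_abs:.1e}, max rel error {max_rel:.1e}")
```

Exact zeros, such as the diagonal and the identical pairs at step 0, survive the cast unchanged. Distances below float16's smallest normal value (about 6.1e-5) would lose relative precision, which matters if very small Chebyshev gaps are of interest.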


## Compute Distance Matrices

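For each snapshot, we form the (128, 128, 64) tensor of pairwise differences by broadcasting, then reduce over the hidden dimension: the L2 norm gives the Euclidean distance and the largest absolute difference gives the Chebyshev distance. Each difference tensor takes only about 4 MB in float32, so a plain loop over snapshots is sufficient.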
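The broadcasting step is easiest to see on a toy example. This NumPy sketch mirrors the body of the PyTorch loop for a single random stand-in snapshot (array names are illustrative) and cross-checks the vectorized result against an explicit double loop over token pairs.

```python
import numpy as np

rng = np.random.default_rng(0)
gamma = rng.standard_normal((5, 3)).astype(np.float32)  # (vocab, hidden) stand-in

# (1, V, H) - (V, 1, H) broadcasts to (V, V, H): diff[a, b] = gamma[b] - gamma[a]
diff = gamma[None, :, :] - gamma[:, None, :]
euclid = np.sqrt((diff ** 2).sum(axis=2))  # reduce over the hidden dimension
cheb = np.abs(diff).max(axis=2)

# Reference: explicit loop over every ordered pair of tokens
V = gamma.shape[0]
euclid_ref = np.zeros((V, V), dtype=np.float32)
cheb_ref = np.zeros((V, V), dtype=np.float32)
for a in range(V):
    for b in range(V):
        d = gamma[b] - gamma[a]
        euclid_ref[a, b] = np.sqrt(np.sum(d ** 2))
        cheb_ref[a, b] = np.max(np.abs(d))

print(diff.shape)                       # (5, 5, 3)
print(np.allclose(euclid, euclid_ref))  # True
print(np.array_equal(cheb, cheb_ref))   # True
```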

In [5]:
print(f"\nComputing pairwise distances...\n")

for i in tqdm(range(n_snapshots), desc="Processing snapshots"):
    # Extract embedding matrix for this step
    gamma = embedding_history[i].float()  # Convert to float32 for computation
    
    # Compute pairwise differences using broadcasting
    # gamma: (vocab_size, hidden_dim)
    # gamma.unsqueeze(0): (1, vocab_size, hidden_dim)
    # gamma.unsqueeze(1): (vocab_size, 1, hidden_dim)
    # diff: (vocab_size, vocab_size, hidden_dim)
    diff = gamma.unsqueeze(0) - gamma.unsqueeze(1)
    
    # Euclidean distance: L2 norm across the hidden dimension
    # sqrt(sum(diff^2, dim=2))
    euclidean = torch.linalg.vector_norm(diff, ord=2, dim=2)
    euclidean_distances[i] = euclidean
    
    # Chebyshev distance: max absolute difference across dimensions
    # max(|diff|, dim=2)
    chebyshev = diff.abs().amax(dim=2)
    chebyshev_distances[i] = chebyshev

print(f"\n✓ Computed distance matrices for {n_snapshots:,} snapshots")


Computing pairwise distances...

✓ Computed distance matrices for 5,001 snapshots
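The per-snapshot loop could also be vectorized over chunks of snapshots, trading memory for fewer Python iterations. A hedged NumPy sketch: the helper `pairwise_distances_chunked` and its `chunk` parameter are hypothetical, and each chunk materializes a (chunk, V, V, H) difference array, roughly 4 MB per snapshot at V = 128, H = 64 in float32.

```python
import numpy as np

def pairwise_distances_chunked(history, chunk=64):
    """Euclidean and Chebyshev matrices for every snapshot in history (T, V, H).

    Processes `chunk` snapshots at a time; memory for the difference array
    grows as chunk * V * V * H.
    """
    T, V, _ = history.shape
    euclid = np.empty((T, V, V), dtype=np.float32)
    cheb = np.empty((T, V, V), dtype=np.float32)
    for start in range(0, T, chunk):
        g = history[start:start + chunk].astype(np.float32)  # (c, V, H)
        diff = g[:, None, :, :] - g[:, :, None, :]            # (c, V, V, H)
        euclid[start:start + chunk] = np.sqrt((diff ** 2).sum(axis=-1))
        cheb[start:start + chunk] = np.abs(diff).max(axis=-1)
    return euclid, cheb

rng = np.random.default_rng(1)
history = rng.standard_normal((10, 6, 4)).astype(np.float32)  # toy (T, V, H)
euclid, cheb = pairwise_distances_chunked(history, chunk=3)

# Matches a single-snapshot computation like the notebook's loop body
g = history[7]
d = g[None, :, :] - g[:, None, :]
print(np.allclose(euclid[7], np.sqrt((d ** 2).sum(axis=2))))  # True
print(np.array_equal(cheb[7], np.abs(d).max(axis=2)))         # True
```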


## Summary Statistics

In [6]:
print(f"\n{'='*80}")
print("DISTANCE MATRIX SUMMARY")
print(f"{'='*80}\n")

print(f"Snapshots: {n_snapshots:,}")
print(f"Tokens: {vocab_size:,}")
print(f"Distance pairs per snapshot: {vocab_size * (vocab_size - 1) // 2:,} (excluding diagonal)")

# Analyze initial distances (step 0)
print(f"\nInitial distances (step 0):")
# Mask out diagonal (self-distances = 0)
mask = ~torch.eye(vocab_size, dtype=torch.bool)
euclidean_init = euclidean_distances[0][mask].float()
chebyshev_init = chebyshev_distances[0][mask].float()

print(f"  Euclidean: min={euclidean_init.min().item():.6f}, max={euclidean_init.max().item():.6f}, mean={euclidean_init.mean().item():.6f}")
print(f"  Chebyshev: min={chebyshev_init.min().item():.6f}, max={chebyshev_init.max().item():.6f}, mean={chebyshev_init.mean().item():.6f}")

# Analyze final distances (step 5000)
print(f"\nFinal distances (step {n_snapshots - 1}):")
euclidean_final = euclidean_distances[-1][mask].float()
chebyshev_final = chebyshev_distances[-1][mask].float()

print(f"  Euclidean: min={euclidean_final.min().item():.6f}, max={euclidean_final.max().item():.6f}, mean={euclidean_final.mean().item():.6f}")
print(f"  Chebyshev: min={chebyshev_final.min().item():.6f}, max={chebyshev_final.max().item():.6f}, mean={chebyshev_final.mean().item():.6f}")

print(f"\n{'='*80}")


DISTANCE MATRIX SUMMARY

Snapshots: 5,001
Tokens: 128
Distance pairs per snapshot: 8,128 (excluding diagonal)

Initial distances (step 0):
  Euclidean: min=0.000000, max=0.000000, mean=0.000000
  Chebyshev: min=0.000000, max=0.000000, mean=0.000000

Final distances (step 5000):
  Euclidean: min=0.000000, max=1.593674, mean=0.672264
  Chebyshev: min=0.000000, max=0.597656, mean=0.204198
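The summary statistics mask only the diagonal, so each unordered pair enters twice; because both matrices are symmetric, the min, max, and mean are the same as over unique pairs. To work with exactly the 8,128 unique pairs instead, one can take the strict upper triangle. A small NumPy sketch on a symmetric stand-in matrix; `unique_pair_values` is an illustrative name, not part of the notebook.

```python
import numpy as np

def unique_pair_values(dist):
    """Distances for unordered pairs (i < j) of a symmetric (V, V) matrix."""
    i, j = np.triu_indices(dist.shape[0], k=1)  # strictly above the diagonal
    return dist[i, j]

V = 128
rng = np.random.default_rng(0)
pts = rng.standard_normal((V, 4))
dist = np.sqrt(((pts[None, :, :] - pts[:, None, :]) ** 2).sum(axis=2))  # symmetric stand-in

upper = unique_pair_values(dist)
off_diag = dist[~np.eye(V, dtype=bool)]  # same mask as the summary cell

print(upper.size, off_diag.size)                  # 8128 16256
print(np.isclose(upper.mean(), off_diag.mean()))  # True
```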



## Save Distance Matrices

In [7]:
# Build save dictionary
save_dict = {
    'euclidean_distances': euclidean_distances,
    'chebyshev_distances': chebyshev_distances,
}

# Save
output_path = embedding_dir / OUTPUT_FILE
output_path.parent.mkdir(parents=True, exist_ok=True)
save_file(save_dict, output_path)

print(f"✓ Saved distance matrices to: {output_path}")
print(f"\nFile size: {output_path.stat().st_size / 1e6:.2f} MB")
print(f"\nSaved tensors:")
for key, tensor in save_dict.items():
    print(f"  {key}: {tensor.shape} ({tensor.dtype})")

✓ Saved distance matrices to: ../data/embeddings_128vocab_qweninit/pairwise_distances.safetensors

File size: 655.49 MB

Saved tensors:
  euclidean_distances: torch.Size([5001, 128, 128]) (torch.float32)
  chebyshev_distances: torch.Size([5001, 128, 128]) (torch.float32)
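Before downstream notebooks rely on the saved file, a few cheap invariants can catch indexing or dtype mistakes: both matrices should be symmetric with a zero diagonal, and per pair the metrics must satisfy $d_\infty \le d_2 \le \sqrt{H}\, d_\infty$ for hidden dimension $H$. A hedged NumPy sketch; `check_distance_invariants` is a hypothetical helper and the tolerance `atol` is an assumption.

```python
import math
import numpy as np

def check_distance_invariants(euclid, cheb, hidden_dim, atol=1e-5):
    """Validate (T, V, V) distance stacks; raise ValueError listing failed checks."""
    euclid = np.asarray(euclid)
    cheb = np.asarray(cheb)
    diag = np.arange(euclid.shape[1])
    checks = {
        "Euclidean symmetric": np.allclose(euclid, euclid.transpose(0, 2, 1), atol=atol),
        "Chebyshev symmetric": np.allclose(cheb, cheb.transpose(0, 2, 1), atol=atol),
        "zero diagonals": (np.allclose(euclid[:, diag, diag], 0, atol=atol)
                           and np.allclose(cheb[:, diag, diag], 0, atol=atol)),
        # Norm inequalities d_inf <= d_2 <= sqrt(H) * d_inf, with slack for rounding
        "d_inf <= d_2": bool(np.all(cheb <= euclid + atol)),
        "d_2 <= sqrt(H) * d_inf": bool(np.all(euclid <= math.sqrt(hidden_dim) * cheb + atol)),
    }
    failed = [name for name, ok in checks.items() if not ok]
    if failed:
        raise ValueError(f"distance invariants violated: {failed}")
    return True

# Toy usage with a random (T, V, H) = (3, 8, 5) history
rng = np.random.default_rng(2)
hist = rng.standard_normal((3, 8, 5))
diff = hist[:, None, :, :] - hist[:, :, None, :]  # (T, V, V, H)
euclid = np.sqrt((diff ** 2).sum(axis=-1))
cheb = np.abs(diff).max(axis=-1)
print(check_distance_invariants(euclid, cheb, hidden_dim=5))  # True
```

For the notebook's tensors one would pass `euclidean_distances.numpy()` and `chebyshev_distances.numpy()` with `hidden_dim=HIDDEN_DIM`. On the full stacks the comparisons allocate temporaries about the size of one stack, so checking a slice of snapshots at a time keeps memory down.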
