# Pairwise Distance Matrix Extraction (Causal Metric)

Compute pairwise causal distances for a random sample of tokens.

**Why this approach:**
- Configurable sample size N (tune for your time/memory budget)
- Full N×N distance matrix stored once, reused forever for UMAP experiments
- Upper-triangle compression saves 50% storage
- fp16 precision: 15× faster via tensor cores, adequate precision for visualization

**Method:**
1. Load unembedding matrix γ and metric tensor M
2. Randomly sample N tokens from vocabulary
3. Compute full N×N pairwise causal distances using fp16
4. Save upper triangle (symmetric matrix)

**Expected runtime (H100):**
- N=16,000: ~2-3 minutes
- N=50,000: ~15-20 minutes
- N=152,936 (full vocab): ~2-3 hours

**Output size:**
- N=16,000: ~256 MB (upper triangle)
- N=50,000: ~2.5 GB
- N=152,936: ~23 GB

## Configuration

In [1]:
# Sample size: tune this to fit your time/memory budget
N_TOKENS = 64000  # Start here, scale up to 50k or 152k

# Model configuration
MODEL_NAME = 'Qwen/Qwen3-4B-Instruct-2507'
DEVICE = 'cuda'  # 'cuda' for cloud GPU, 'mps' for Mac, 'cpu' for fallback

# Precision (use fp16 for speed via tensor cores)
DTYPE = 'float16'  # 'float16' or 'float32'

# Batching (tuned for H100 80GB VRAM)
BATCH_SIZE_I = 500   # Query batch size
BATCH_SIZE_J = 10000 # Target batch size
# Memory: 500 × 10,000 × 2560 × 2 bytes ≈ 25 GB (safe)

# Input paths
METRIC_TENSOR_PATH = '../data/vectors/causal_metric_tensor_qwen3_4b.pt'

# Output paths
OUTPUT_DISTANCES = f'../data/vectors/distances_causal_{N_TOKENS}.pt'

# Random seed
RANDOM_SEED = 42

print(f"Configuration:")
print(f"  Sample size: {N_TOKENS:,} tokens")
print(f"  Precision: {DTYPE}")
print(f"  Batch sizes: {BATCH_SIZE_I} × {BATCH_SIZE_J}")
print(f"  Expected memory: ~{BATCH_SIZE_I * BATCH_SIZE_J * 2560 * (2 if DTYPE == 'float16' else 4) / 1e9:.1f} GB peak")
print(f"  Expected output: ~{N_TOKENS * (N_TOKENS - 1) // 2 * (2 if DTYPE == 'float16' else 4) / 1e9:.2f} GB (upper triangle)")

Configuration:
  Sample size: 64,000 tokens
  Precision: float16
  Batch sizes: 500 × 10000
  Expected memory: ~25.6 GB peak
  Expected output: ~4.10 GB (upper triangle)


## Setup

In [2]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from datetime import datetime
from tqdm.auto import tqdm

device = torch.device(DEVICE)
dtype = torch.float16 if DTYPE == 'float16' else torch.float32

print(f"✓ Using device: {device}")
print(f"✓ Using dtype: {dtype}")

if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

✓ Using device: cuda
✓ Using dtype: torch.float16
  GPU: NVIDIA H200
  VRAM: 150.1 GB


## Load Unembedding Matrix (γ)

In [3]:
print(f"Loading model to extract unembedding matrix...")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
    device_map=device,
)

# Extract unembedding matrix γ
gamma = model.lm_head.weight.data.clone()  # [vocab_size, hidden_dim]
vocab_size, hidden_dim = gamma.shape

print(f"✓ Extracted unembedding matrix")
print(f"  Vocabulary size: {vocab_size:,}")
print(f"  Hidden dim: {hidden_dim:,}")
print(f"  Memory: {gamma.element_size() * gamma.nelement() / 1e9:.2f} GB")

# Free model memory
del model
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print(f"✓ Freed model memory")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading model to extract unembedding matrix...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

✓ Extracted unembedding matrix
  Vocabulary size: 151,936
  Hidden dim: 2,560
  Memory: 0.78 GB
✓ Freed model memory


## Load Causal Metric Tensor (M)

In [4]:
print(f"Loading causal metric tensor from {METRIC_TENSOR_PATH}...")
metric_data = torch.load(METRIC_TENSOR_PATH, map_location=device, weights_only=False)

M = metric_data['M'].to(dtype=dtype)  # [hidden_dim, hidden_dim]

print(f"✓ Loaded metric tensor")
print(f"  Shape: {M.shape}")
print(f"  Dtype: {M.dtype}")
print(f"  Memory: {M.element_size() * M.nelement() / 1e6:.1f} MB")

Loading causal metric tensor from ../data/vectors/causal_metric_tensor_qwen3_4b.pt...
✓ Loaded metric tensor
  Shape: torch.Size([2560, 2560])
  Dtype: torch.float16
  Memory: 13.1 MB


## Sample N Tokens Randomly

In [5]:
print(f"\nSampling {N_TOKENS:,} random tokens...")
print("(Random sampling avoids vocabulary ordering bias)")

torch.manual_seed(RANDOM_SEED)
sample_indices = torch.randperm(vocab_size)[:N_TOKENS].to(device)
sample_embeddings = gamma[sample_indices]  # [N_TOKENS, hidden_dim]

print(f"✓ Sampled embeddings shape: {sample_embeddings.shape}")
print(f"  Memory: {sample_embeddings.element_size() * sample_embeddings.nelement() / 1e6:.1f} MB")


Sampling 64,000 random tokens...
(Random sampling avoids vocabulary ordering bias)
✓ Sampled embeddings shape: torch.Size([64000, 2560])
  Memory: 327.7 MB


## Compute Pairwise Distances

Uses batched computation optimized for tensor cores:
- Broadcasts to [batch_i, batch_j, hidden_dim] tensors
- Matmul with M: ~6.5 billion ops per batch (excellent for tensor cores)
- Memory-safe batching: ~25 GB peak on H100

In [6]:
def compute_pairwise_distances(embeddings, M, batch_i=500, batch_j=10000):
    """
    Compute full pairwise causal distance matrix.
    
    Distance formula: d(i,j) = sqrt((γ_i - γ_j)^T M (γ_i - γ_j))
    
    Args:
        embeddings: [N, hidden_dim] - Token embeddings
        M: [hidden_dim, hidden_dim] - Causal metric tensor
        batch_i: Query batch size
        batch_j: Target batch size
    
    Returns:
        distances: [N, N] - Pairwise causal distances
    """
    N = embeddings.shape[0]
    distances = torch.zeros(N, N, dtype=embeddings.dtype, device=embeddings.device)
    
    n_batches_i = int(np.ceil(N / batch_i))
    n_batches_j = int(np.ceil(N / batch_j))
    total_batches = n_batches_i * n_batches_j
    
    with tqdm(total=total_batches, desc="Computing distances") as pbar:
        for i in range(0, N, batch_i):
            i_end = min(i + batch_i, N)
            tokens_i = embeddings[i:i_end]  # [batch_i, hidden_dim]
            
            for j in range(0, N, batch_j):
                j_end = min(j + batch_j, N)
                tokens_j = embeddings[j:j_end]  # [batch_j, hidden_dim]
                
                # Broadcasting: [batch_i, 1, hidden_dim] - [1, batch_j, hidden_dim]
                diff = tokens_i[:, None, :] - tokens_j[None, :, :]  # [batch_i, batch_j, hidden_dim]
                
                # Matmul with metric tensor (tensor cores activate here!)
                # einsum handles the batch dimensions automatically
                M_delta = torch.einsum('ijk,kl->ijl', diff, M)  # [batch_i, batch_j, hidden_dim]
                
                # Inner product and sqrt
                squared_dist = (diff * M_delta).sum(dim=-1)  # [batch_i, batch_j]
                distances[i:i_end, j:j_end] = torch.sqrt(torch.clamp(squared_dist, min=0))
                
                pbar.update(1)
    
    return distances

print("\nComputing pairwise distance matrix...")
print(f"  This will compute {N_TOKENS * N_TOKENS:,} distances")
print(f"  Estimated time: {N_TOKENS**2 / 16000**2 * 2:.1f} minutes on H100\n")

distances = compute_pairwise_distances(sample_embeddings, M, BATCH_SIZE_I, BATCH_SIZE_J)

print(f"\n✓ Distance matrix computed!")
print(f"  Shape: {distances.shape}")
print(f"  Memory: {distances.element_size() * distances.nelement() / 1e9:.2f} GB")


Computing pairwise distance matrix...
  This will compute 4,096,000,000 distances
  Estimated time: 32.0 minutes on H100



Computing distances:   0%|          | 0/896 [00:00<?, ?it/s]


✓ Distance matrix computed!
  Shape: torch.Size([64000, 64000])
  Memory: 8.19 GB


## Validation

In [7]:
print("\nValidation checks:")

# Check 1: Diagonal should be zero (self-distances)
diag = torch.diagonal(distances)
print(f"  Self-distances (diagonal, should be ~0):")
print(f"    Mean: {diag.mean().item():.6f}")
print(f"    Max: {diag.max().item():.6f}")

# Check 2: Matrix should be symmetric
asymmetry = (distances - distances.T).abs().max()
print(f"  Symmetry check (should be ~0): {asymmetry.item():.6f}")

# Check 3: All distances should be non-negative
min_dist = distances.min()
print(f"  Minimum distance (should be ≥0): {min_dist.item():.6f}")

# Distance statistics
# Exclude diagonal for statistics
mask = ~torch.eye(N_TOKENS, dtype=torch.bool, device=distances.device)
off_diag = distances[mask]

print(f"\nDistance statistics (excluding diagonal):")
print(f"  Min: {off_diag.min().item():.2f}")
print(f"  Max: {off_diag.max().item():.2f}")
print(f"  Mean: {off_diag.mean().item():.2f}")

# Median computation - handle INT_MAX limitation for large arrays
if off_diag.numel() < 2_000_000_000:  # Safe threshold below INT_MAX
    median_val = off_diag.median().item()
    print(f"  Median: {median_val:.2f}")
else:
    # Array too large - use reservoir sampling for median approximation
    sample_size = 100_000_000  # 100M samples
    # Use random integers directly (no sorting required)
    sample_indices = torch.randint(0, off_diag.numel(), (sample_size,), device=off_diag.device)
    median_val = off_diag[sample_indices].median().item()
    print(f"  Median (approx, n={sample_size:,}): {median_val:.2f}")

print(f"  Std: {off_diag.std().item():.2f}")


Validation checks:
  Self-distances (diagonal, should be ~0):
    Mean: 0.000000
    Max: 0.000000
  Symmetry check (should be ~0): 0.000000
  Minimum distance (should be ≥0): 0.000000

Distance statistics (excluding diagonal):
  Min: 0.00
  Max: 112.69
  Mean: 71.00
  Median (approx, n=100,000,000): 72.19
  Std: 8.03


## Save Distance Matrix (Upper Triangle)

Since the matrix is symmetric, we only need to store the upper triangle.

In [8]:
print("\nExtracting upper triangle...")

# Get upper triangle indices (excluding diagonal)
triu_indices = torch.triu_indices(N_TOKENS, N_TOKENS, offset=1, device=device)
triu_values = distances[triu_indices[0], triu_indices[1]]

print(f"✓ Upper triangle extracted")
print(f"  Values: {triu_values.shape[0]:,}")
print(f"  Memory: {triu_values.element_size() * triu_values.nelement() / 1e9:.2f} GB")

# Save
print(f"\nSaving to {OUTPUT_DISTANCES}...")
torch.save({
    'triu_values': triu_values.cpu(),
    'token_indices': sample_indices.cpu(),
    'N': N_TOKENS,
    'metadata': {
        'model': MODEL_NAME,
        'metric_tensor_path': METRIC_TENSOR_PATH,
        'vocab_size': vocab_size,
        'hidden_dim': hidden_dim,
        'dtype': str(dtype),
        'random_seed': RANDOM_SEED,
        'distance_stats': {
            'min': off_diag.min().item(),
            'max': off_diag.max().item(),
            'mean': off_diag.mean().item(),
            'median': median_val,  # Use pre-computed median from validation cell
            'std': off_diag.std().item(),
        },
        'timestamp': datetime.now().isoformat(),
    }
}, OUTPUT_DISTANCES)

import os
file_size = os.path.getsize(OUTPUT_DISTANCES) / 1e9
print(f"✓ Saved!")
print(f"  File size: {file_size:.2f} GB")


Extracting upper triangle...
✓ Upper triangle extracted
  Values: 2,047,968,000
  Memory: 4.10 GB

Saving to ../data/vectors/distances_causal_64000.pt...
✓ Saved!
  File size: 4.90 GB


## How to Load and Reconstruct

```python
# Load
data = torch.load('distances_causal_16000.pt')
triu_values = data['triu_values']
token_indices = data['token_indices']
N = data['N']

# Reconstruct full symmetric matrix
distances = torch.zeros(N, N, dtype=triu_values.dtype)
triu_indices = torch.triu_indices(N, N, offset=1)
distances[triu_indices[0], triu_indices[1]] = triu_values
distances = distances + distances.T  # Make symmetric

# Now run UMAP with different hyperparameters
from umap import UMAP

umap_2d = UMAP(n_components=2, metric='precomputed', n_neighbors=15, min_dist=0.1)
embedding_2d = umap_2d.fit_transform(distances.numpy())

umap_3d = UMAP(n_components=3, metric='precomputed', n_neighbors=50, min_dist=0.01)
embedding_3d = umap_3d.fit_transform(distances.numpy())

# etc.
```

## Summary

✓ Computed pairwise causal distances for {N_TOKENS:,} tokens

✓ Saved upper triangle (symmetric matrix compression)

**Next steps:**
1. Download the `.pt` file to your local machine
2. Load and reconstruct the full distance matrix
3. Run UMAP with different hyperparameters (n_neighbors, min_dist, n_components)
4. Visualize semantic space! 🌌

**To scale up:**
- Increase N_TOKENS to 50,000 or 152,936 (full vocab)
- Rent more VRAM if needed (H200 has 141 GB)
- Same code works at any scale!