# Eigenbasis Precomputation

**Goal:** Compute and save expensive matrices needed for eigenbasis analysis.

**Why separate notebook:** These computations are heavy (minutes to run) but only need to be done **once**. Future analysis notebooks (09.2+) can load the precomputed results instantly.

**What we compute:**
1. Full eigendecomposition of M (eigenvalues + eigenvectors)
2. Token projections onto all 2,560 eigenvectors (each token's coordinates in the eigenbasis)
3. [Future: Add more precomputations as needed]

**Output files:**
- `data/vectors/eigenbasis_qwen3_4b.pt` - Eigenvalues and eigenvectors
- `data/vectors/token_eigenbasis_projections_qwen3_4b.pt` - Full projection matrix

**Run time:** ~2-5 minutes on CPU

## Configuration

In [1]:
import sys
sys.path.append('..')

from azimuth.config import RANDOM_SEED

# Model
MODEL_NAME = 'Qwen/Qwen3-4B-Instruct-2507'

# Input paths
METRIC_TENSOR_PATH = '../data/vectors/causal_metric_tensor_qwen3_4b.pt'

# Output paths
EIGENBASIS_PATH = '../data/vectors/eigenbasis_qwen3_4b.pt'
PROJECTIONS_PATH = '../data/vectors/token_eigenbasis_projections_qwen3_4b.pt'

print(f"Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Random seed: {RANDOM_SEED}")
print(f"  Output: {EIGENBASIS_PATH}")
print(f"  Output: {PROJECTIONS_PATH}")

Configuration:
  Model: Qwen/Qwen3-4B-Instruct-2507
  Random seed: 42
  Output: ../data/vectors/eigenbasis_qwen3_4b.pt
  Output: ../data/vectors/token_eigenbasis_projections_qwen3_4b.pt


## Setup

In [2]:
import numpy as np
import torch
from transformers import AutoModelForCausalLM
from pathlib import Path
from datetime import datetime

print("✓ Imports complete")

✓ Imports complete


## Load Model and Metric Tensor

In [3]:
print("Loading model and metric tensor...\n")

# Load model (for unembedding matrix)
print(f"Loading model from {MODEL_NAME}...")
print("  This will take a minute...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map='cpu',
)

# Extract FULL unembedding matrix (all vocab)
gamma = model.lm_head.weight.data.to(torch.float32).cpu()  # [vocab_size, hidden_dim]
vocab_size, hidden_dim = gamma.shape

# Load metric tensor
print(f"\nLoading causal metric tensor from {METRIC_TENSOR_PATH}...")
metric_data = torch.load(METRIC_TENSOR_PATH, weights_only=False)
M = metric_data['M'].to(torch.float32).cpu()  # [hidden_dim, hidden_dim]

print(f"\n✓ All data loaded")
print(f"  Vocab size: {vocab_size:,}")
print(f"  Hidden dim: {hidden_dim:,}")
print(f"  Unembedding matrix shape: {gamma.shape}")
print(f"  Metric tensor shape: {M.shape}")
print(f"  Memory usage: {(gamma.element_size() * gamma.nelement() + M.element_size() * M.nelement()) / 1e9:.2f} GB")

Loading model and metric tensor...

Loading model from Qwen/Qwen3-4B-Instruct-2507...
  This will take a minute...


`torch_dtype` is deprecated! Use `dtype` instead!




Loading causal metric tensor from ../data/vectors/causal_metric_tensor_qwen3_4b.pt...

✓ All data loaded
  Vocab size: 151,936
  Hidden dim: 2,560
  Unembedding matrix shape: torch.Size([151936, 2560])
  Metric tensor shape: torch.Size([2560, 2560])
  Memory usage: 1.58 GB


---

# Eigendecomposition of M

Compute the full eigendecomposition: M = V Λ V^T

- **Eigenvalues (Λ):** The factor by which M scales vectors along each principal axis
- **Eigenvectors (V):** The axes themselves (columns of V)

This defines the natural coordinate system for the causal metric.
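A minimal NumPy sketch of the same decomposition, on a small random symmetric positive-definite matrix standing in for M (the 8 × 8 size and the construction are assumptions for illustration, not the real metric tensor). It checks the two properties the cell below relies on: `eigh` returns eigenvalues in ascending order, and V Λ Vᵀ rebuilds the original matrix.

```python
import numpy as np

# Hypothetical stand-in for the causal metric tensor M: a small random SPD matrix.
rng = np.random.default_rng(42)
A = rng.standard_normal((8, 8))
M_toy = A @ A.T + 8 * np.eye(8)  # symmetric positive definite by construction

# Like torch.linalg.eigh, np.linalg.eigh assumes a symmetric input and returns
# eigenvalues in ascending order, with eigenvectors as the COLUMNS of V.
lam, V = np.linalg.eigh(M_toy)

# Rebuild M = V diag(λ) V^T and measure the worst entrywise error.
M_rebuilt = V @ np.diag(lam) @ V.T
reconstruction_error = np.abs(M_toy - M_rebuilt).max()

print("ascending order:", bool(np.all(np.diff(lam) >= 0)))
print(f"max reconstruction error: {reconstruction_error:.2e}")
print(f"max |V^T V - I|: {np.abs(V.T @ V - np.eye(8)).max():.2e}")
```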

In [4]:
print("\n" + "=" * 80)
print("COMPUTING EIGENDECOMPOSITION OF M")
print("=" * 80)

print(f"\nComputing eigenvalues and eigenvectors...")
print(f"  Matrix size: {M.shape}")
print(f"  This will take 1-2 minutes...\n")

# Compute eigendecomposition
# eigh returns eigenvalues in ascending order
eigenvalues, eigenvectors = torch.linalg.eigh(M)

print(f"✓ Eigendecomposition complete\n")

# Summary statistics
print(f"Eigenvalue statistics:")
print(f"  Min: {eigenvalues.min().item():.2f}")
print(f"  Max: {eigenvalues.max().item():.2f}")
print(f"  Mean: {eigenvalues.mean().item():.2f}")
print(f"  Median: {eigenvalues.median().item():.2f}")

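# Verify eigenvectors are orthonormal: V^T V should equal the identity.
# Note: the max below is taken over ALL entries of V^T V - I, so it also
# catches diagonal entries that drift from 1, not just off-diagonal ones.
identity_check = eigenvectors.T @ eigenvectors
off_diagonal_max = (identity_check - torch.eye(hidden_dim)).abs().max().item()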
print(f"\nOrthonormality check:")
print(f"  Max off-diagonal: {off_diagonal_max:.2e} (should be ~0)")

if off_diagonal_max < 1e-5:
    print(f"  ✓ Eigenvectors are orthonormal")
else:
    print(f"  ⚠️  Eigenvectors may have numerical issues")


COMPUTING EIGENDECOMPOSITION OF M

Computing eigenvalues and eigenvectors...
  Matrix size: torch.Size([2560, 2560])
  This will take 1-2 minutes...

✓ Eigendecomposition complete

Eigenvalue statistics:
  Min: 95.35
  Max: 94217.94
  Mean: 2713.64
  Median: 2498.26

Orthonormality check:
  Max off-diagonal: 2.74e-06 (should be ~0)
  ✓ Eigenvectors are orthonormal
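The cell above assumes M is symmetric, which this notebook does not check explicitly. That matters because `torch.linalg.eigh` (like `np.linalg.eigh`) reads only the lower triangle by default (`UPLO='L'`), so any asymmetry in the upper triangle is silently ignored rather than reported. A small NumPy sketch on a deliberately non-symmetric toy matrix (an assumption for illustration) shows the effect.

```python
import numpy as np

rng = np.random.default_rng(5)
B = rng.standard_normal((5, 5))  # deliberately NOT symmetric

# eigh uses only the lower triangle by default (UPLO='L').
lam_B, _ = np.linalg.eigh(B)

# Symmetric matrix built from B's lower triangle alone.
S_lower = np.tril(B) + np.tril(B, -1).T
lam_lower, _ = np.linalg.eigh(S_lower)

# Symmetrizing first takes both triangles into account.
lam_sym, _ = np.linalg.eigh((B + B.T) / 2)

asymmetry = np.abs(B - B.T).max()
print(f"asymmetry of B: {asymmetry:.2f}")
print("eigh(B) == eigh(lower-triangle matrix):", np.allclose(lam_B, lam_lower))
print("eigh(B) == eigh(symmetrized matrix):   ", np.allclose(lam_B, lam_sym))
```

On the real tensor, checking `(M - M.T).abs().max()` before calling `eigh`, and symmetrizing with `(M + M.T) / 2` if the gap is not negligible, would be a cheap safeguard.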


## Save Eigenbasis

In [5]:
print(f"\nSaving eigenbasis to {EIGENBASIS_PATH}...")

Path(EIGENBASIS_PATH).parent.mkdir(parents=True, exist_ok=True)

torch.save({
    'eigenvalues': eigenvalues,  # [hidden_dim]
    'eigenvectors': eigenvectors,  # [hidden_dim, hidden_dim]
    'metadata': {
        'model': MODEL_NAME,
        'hidden_dim': hidden_dim,
        'source_metric_tensor': METRIC_TENSOR_PATH,
        'computation_date': datetime.now().isoformat(),
        'description': 'Eigendecomposition of causal metric tensor M = V Λ V^T',
        'note': 'Eigenvalues are in ascending order. Eigenvectors are columns of the matrix.',
    }
}, EIGENBASIS_PATH)

file_size = Path(EIGENBASIS_PATH).stat().st_size / 1e6
print(f"✓ Saved ({file_size:.1f} MB)")


Saving eigenbasis to ../data/vectors/eigenbasis_qwen3_4b.pt...
✓ Saved (26.2 MB)
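Because the saved eigenvalues are in ascending order, later notebooks may want the largest-eigenvalue axes first. A minimal NumPy sketch of reordering both arrays together on a toy matrix (the helper name `sort_eigenpairs_descending` is hypothetical, not part of the saved file's format). The key point is that the eigenvector *columns* must be permuted with the same order as the eigenvalues.

```python
import numpy as np

def sort_eigenpairs_descending(eigenvalues: np.ndarray, eigenvectors: np.ndarray):
    """Reorder eigenpairs so the largest eigenvalue comes first.

    Eigenvectors are stored one per column, so the permutation is
    applied along axis 1, not axis 0.
    """
    order = np.argsort(eigenvalues)[::-1]
    return eigenvalues[order], eigenvectors[:, order]

# Toy symmetric matrix standing in for M (assumption for illustration).
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
M_toy = A @ A.T

lam_asc, V_asc = np.linalg.eigh(M_toy)
lam_desc, V_desc = sort_eigenpairs_descending(lam_asc, V_asc)

# Each column must still satisfy M v = λ v after reordering.
# V_desc * lam_desc scales column j by lam_desc[j] via broadcasting.
residual = np.abs(M_toy @ V_desc - V_desc * lam_desc).max()
print(lam_desc)
print(f"max |M v - λ v|: {residual:.2e}")
```

Since the saved eigenvalues are already sorted ascending, the PyTorch equivalent on the loaded tensors is simply `torch.flip(eigenvalues, dims=[0])` and `torch.flip(eigenvectors, dims=[1])`.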


---

# Token Projections onto Eigenbasis

**Compute eigenbasis coordinates for all tokens.**

Project each token vector onto all 2,560 eigenvectors:

```
projections[i, j] = gamma[i] · eigenvectors[:, j]
```

This gives us the "coordinates" of each token in the eigenbasis.

**Matrix form:** `projections = gamma @ eigenvectors`

**Cost:** 151,936 × 2,560 × 2,560 ≈ 1 trillion multiply-adds, i.e. ≈ 2 trillion FLOPs counting multiplies and adds separately (takes ~30-60 seconds on CPU)
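A small NumPy sketch of the same projection on random data (the toy sizes are assumptions), confirming that the single matrix product matches the per-entry definition `gamma[i] · eigenvectors[:, j]`. It also spells out the cost arithmetic for the real shapes, counting one multiply-add per (token, output column, inner index) triple.

```python
import numpy as np

rng = np.random.default_rng(1)
n_tokens, d = 6, 4  # toy sizes standing in for vocab_size and hidden_dim

gamma_toy = rng.standard_normal((n_tokens, d))
A = rng.standard_normal((d, d))
_, V = np.linalg.eigh(A @ A.T)  # orthonormal columns, like the eigenvectors of M

# Matrix form: one product computes every coordinate at once.
proj = gamma_toy @ V

# Per-entry definition: projections[i, j] = gamma[i] · V[:, j]
proj_loop = np.array([[gamma_toy[i] @ V[:, j] for j in range(d)]
                      for i in range(n_tokens)])
max_gap = np.abs(proj - proj_loop).max()

# Cost for the real shapes: one multiply-add per (i, j, k) triple.
vocab_size, hidden_dim = 151_936, 2_560
multiply_adds = vocab_size * hidden_dim * hidden_dim
flops = 2 * multiply_adds  # multiply and add counted separately

print(f"max gap between matmul and loop: {max_gap:.2e}")
print(f"multiply-adds: {multiply_adds:.3e}, FLOPs: {flops:.3e}")
```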

In [6]:
print("\n" + "=" * 80)
print("COMPUTING TOKEN PROJECTIONS ONTO EIGENBASIS")
print("=" * 80)

print(f"\nMatrix multiplication: gamma @ eigenvectors")
print(f"  gamma shape: {gamma.shape}")
print(f"  eigenvectors shape: {eigenvectors.shape}")
print(f"  Result shape: [{vocab_size}, {hidden_dim}]")
print(f"  Total operations: ~1 trillion FLOPs")
print(f"  Estimated time: 30-60 seconds...\n")

# Compute projections
projections = gamma @ eigenvectors

print(f"✓ Projections computed\n")

# Summary statistics
print(f"Projection matrix properties:")
print(f"  Shape: {projections.shape}")
print(f"  Memory: {projections.element_size() * projections.nelement() / 1e9:.2f} GB")
print(f"  Value range: [{projections.min().item():.2f}, {projections.max().item():.2f}]")
print(f"  Mean: {projections.mean().item():.4f}")
print(f"  Std: {projections.std().item():.4f}")


COMPUTING TOKEN PROJECTIONS ONTO EIGENBASIS

Matrix multiplication: gamma @ eigenvectors
  gamma shape: torch.Size([151936, 2560])
  eigenvectors shape: torch.Size([2560, 2560])
  Result shape: [151936, 2560]
  Total operations: ~1 trillion FLOPs
  Estimated time: 30-60 seconds...

✓ Projections computed

Projection matrix properties:
  Shape: torch.Size([151936, 2560])
  Memory: 1.56 GB
  Value range: [-0.80, 0.73]
  Mean: 0.0001
  Std: 0.0217


## Verify: Reconstruct Token Norms from Projections

The eigenvectors form an orthonormal basis, so:

```
||token||² = Σ (projection onto eigenvector_i)²
```

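This is a sanity check on the matrix product and on the orthonormality of the eigenvectors. It is a necessary condition, not a sufficient one: any orthogonal matrix preserves norms, so this test alone does not confirm that the columns are eigenvectors of M.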
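Because norm preservation holds for *any* orthogonal matrix, a stronger check is the eigen-equation residual ‖MV − VΛ‖. A minimal NumPy sketch on a toy SPD matrix (sizes and construction are assumptions for illustration): a random orthogonal Q passes the norm test just as well as the true eigenvectors V, while only V drives the eigen-equation residual to ~0.

```python
import numpy as np

rng = np.random.default_rng(7)
d, n_tokens = 6, 10
A = rng.standard_normal((d, d))
M_toy = A @ A.T + np.eye(d)                     # toy stand-in for the metric tensor
gamma_toy = rng.standard_normal((n_tokens, d))  # toy "token vectors"

lam, V = np.linalg.eigh(M_toy)                    # true eigenbasis
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))  # random orthogonal matrix

def norm_gap(basis):
    """Max | ||gamma_i|| - ||gamma_i @ basis|| | over tokens (norm-preservation check)."""
    return np.abs(np.linalg.norm(gamma_toy, axis=1)
                  - np.linalg.norm(gamma_toy @ basis, axis=1)).max()

def eig_residual(basis, values):
    """Max |M v_j - λ_j v_j| over columns (eigen-equation check)."""
    return np.abs(M_toy @ basis - basis * values).max()

print(f"norm gap     V: {norm_gap(V):.1e}   Q: {norm_gap(Q):.1e}")
print(f"eig residual V: {eig_residual(V, lam):.1e}   Q: {eig_residual(Q, lam):.1e}")
```

On the real tensors the analogous expression would be `(M @ eigenvectors - eigenvectors * eigenvalues).abs().max()`, one extra 2,560 × 2,560 matmul. Since M's eigenvalues reach ~94,000 and the computation runs in float32, that residual is best judged relative to `eigenvalues.abs().max()` rather than against a fixed absolute threshold.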

In [7]:
print("\nVerifying projections...\n")

# Compute norms two ways
# Method 1: Direct from gamma
norms_direct = torch.norm(gamma, dim=1)

# Method 2: From projections (since eigenvectors are orthonormal)
norms_from_projections = torch.norm(projections, dim=1)

# Compare
max_diff = (norms_direct - norms_from_projections).abs().max().item()
mean_diff = (norms_direct - norms_from_projections).abs().mean().item()

print(f"Norm reconstruction check:")
print(f"  Max difference: {max_diff:.2e}")
print(f"  Mean difference: {mean_diff:.2e}")

if max_diff < 1e-4:
    print(f"\n  ✓ Projections are correct!")
    print(f"    Norms reconstructed from eigenbasis match original norms.")
else:
    print(f"\n  ⚠️  Large discrepancy detected - possible numerical issues")


Verifying projections...

Norm reconstruction check:
  Max difference: 1.79e-06
  Mean difference: 6.27e-07

  ✓ Projections are correct!
    Norms reconstructed from eigenbasis match original norms.


## Save Token Projections

In [8]:
print(f"\nSaving token projections to {PROJECTIONS_PATH}...")

Path(PROJECTIONS_PATH).parent.mkdir(parents=True, exist_ok=True)

torch.save({
    'projections': projections,  # [vocab_size, hidden_dim]
    'metadata': {
        'model': MODEL_NAME,
        'vocab_size': vocab_size,
        'hidden_dim': hidden_dim,
        'eigenbasis_source': EIGENBASIS_PATH,
        'computation_date': datetime.now().isoformat(),
        'description': 'Token projections onto eigenvectors of causal metric tensor M',
        'note': 'projections[i, j] = dot product of token i with eigenvector j',
    }
}, PROJECTIONS_PATH)

file_size = Path(PROJECTIONS_PATH).stat().st_size / 1e9
print(f"✓ Saved ({file_size:.2f} GB)")


Saving token projections to ../data/vectors/token_eigenbasis_projections_qwen3_4b.pt...
✓ Saved (1.56 GB)


---

# Summary

**What we computed:**

1. **Eigendecomposition of M:**
   - 2,560 eigenvalues (the factor by which M scales each principal axis)
   - 2,560 eigenvectors (the principal axes themselves)
   - Saved to: `data/vectors/eigenbasis_qwen3_4b.pt`

2. **Token projections onto eigenbasis:**
   - 151,936 tokens × 2,560 eigenvectors: each token's coordinates in the eigenbasis
   - Saved to: `data/vectors/token_eigenbasis_projections_qwen3_4b.pt`

**Next notebooks (09.2+) can now:**
- Load these precomputed matrices instantly
- Analyze token distribution along any eigenvector
- Find geometric structure in eigenbasis coordinates
- Identify which eigenspaces contain semantic information
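As one example of analyzing the token distribution along an eigenvector, a later notebook might rank tokens by their coordinate on eigenvector j. A minimal NumPy sketch on random data; the helper `extreme_tokens` is hypothetical, and the real inputs would be the saved `projections` tensor (converted with `.numpy()`) plus a tokenizer to decode the returned IDs.

```python
import numpy as np

def extreme_tokens(projections: np.ndarray, j: int, k: int = 5):
    """Return token indices with the k most positive and k most negative
    coordinates along eigenvector j (column j of projections)."""
    column = projections[:, j]
    order = np.argsort(column)      # ascending by coordinate
    top_positive = order[::-1][:k]  # largest values first
    top_negative = order[:k]        # most negative values first
    return top_positive, top_negative

rng = np.random.default_rng(3)
proj_toy = rng.standard_normal((100, 8))  # 100 toy "tokens", 8 eigen-axes

pos, neg = extreme_tokens(proj_toy, j=2, k=3)
print("most positive on axis 2:", pos, proj_toy[pos, 2].round(2))
print("most negative on axis 2:", neg, proj_toy[neg, 2].round(2))
```

For the full 151,936-row matrix, `torch.topk(projections[:, j], k)` (and `largest=False` for the negative end) avoids sorting the whole column.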

**Ready for eigenbasis analysis!** 🚀