# 07.3a: Compute PCA Basis Vectors

**Goal:** Compute the complete PCA basis for the centered token cloud and save all 2560 principal components.

This notebook:
1. Loads centered embedding matrix γ'
2. Computes covariance matrix and eigendecomposition
3. Saves **all 2560 principal components** in ranked order (descending by variance)
4. Saves eigenvalues for variance analysis

**Output:** `pca_basis_vectors.safetensors` containing:
- Keys '1' through '2560': principal component vectors
- Key 'eigenvalues': all 2560 eigenvalues (variance along each PC)

**Usage:** Load this basis in 07.2a and specify any three PCs to define your coordinate system:
- Standard view: PCs 1, 2, 3 (maximum variance directions)
- Alternative view: PCs 10, 11, 12 (lower variance structure)
- Noise floor: PCs 2550, 2551, 2552 (near-minimal variance directions)

**Run once:** This is a generator notebook. Run it once to create the basis file, then use 07.2a repeatedly to visualize from different orientations.
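
To illustrate what "specifying three PCs" means downstream, here is a minimal NumPy sketch. It assumes the basis has already been loaded into a dict that maps string keys (`'1'` through `'2560'`) to unit vectors, matching the output file described above. `project_onto_pcs` is a hypothetical helper and not the actual code in 07.2a.

```python
import numpy as np


def project_onto_pcs(gamma_centered, basis, pc_keys):
    """Project centered embeddings onto three chosen principal components.

    gamma_centered: (N, d) array of centered embeddings.
    basis: dict mapping string keys like '1' to (d,) unit vectors.
    pc_keys: three string keys, e.g. ('1', '2', '3').
    Returns an (N, 3) array of coordinates.
    """
    # Stack the chosen PCs as the columns of a (d, 3) matrix
    axes = np.stack([basis[k] for k in pc_keys], axis=1)
    # Each coordinate is the dot product with one orthonormal PC
    return gamma_centered @ axes


# Toy example: 3-D points and an identity "basis"
basis = {
    "1": np.array([1.0, 0.0, 0.0]),
    "2": np.array([0.0, 1.0, 0.0]),
    "3": np.array([0.0, 0.0, 1.0]),
}
points = np.array([[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]])
coords = project_onto_pcs(points, basis, ("3", "1", "2"))
print(coords)
```

Switching views is then only a matter of changing `pc_keys`, for example to `('10', '11', '12')`.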

## Parameters

In [1]:
TENSOR_DIR = "../data/tensors"

# Input: centered embedding matrix
INPUT_FILE = "gamma_centered_qwen3_4b_instruct_2507.safetensors"
INPUT_KEY = "gamma_centered"

# Output: PCA basis vectors and eigenvalues
OUTPUT_FILE = "pca_basis_vectors.safetensors"

# Random seed (for reproducibility if using approximate methods)
RANDOM_SEED = 42

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("Imports loaded successfully.")

Imports loaded successfully.


## Step 1: Load Data

In [3]:
data_dir = Path(TENSOR_DIR)

print("Loading centered embedding matrix...")
gamma_prime_data = load_file(data_dir / INPUT_FILE)
gamma_prime = gamma_prime_data[INPUT_KEY]

N, d = gamma_prime.shape

print(f"  Shape: ({N:,}, {d:,})")
print(f"  Dtype: {gamma_prime.dtype}")
print(f"  Device: {gamma_prime.device}")
print()

# Verify it's actually centered
mean_vec = gamma_prime.mean(dim=0)
mean_norm = torch.norm(mean_vec).item()
print(f"Mean vector norm: {mean_norm:.6e} (should be ~0 if centered)")
print()

Loading centered embedding matrix...
  Shape: (151,936, 2,560)
  Dtype: torch.float32
  Device: cpu

Mean vector norm: 3.952082e-08 (should be ~0 if centered)



## Step 2: Compute Covariance Matrix

In [4]:
print("Computing covariance matrix...")
print(f"  This will create a ({d:,} × {d:,}) matrix")
print(f"  Memory required: ~{d*d*4 / 1e9:.2f} GB (float32)")
print()

# Cov(γ') = (γ')ᵀ γ' / (N-1); valid only because γ' is already centered
Cov = (gamma_prime.T @ gamma_prime) / (N - 1)

print(f"Covariance matrix computed.")
print(f"  Shape: {Cov.shape}")
print(f"  Dtype: {Cov.dtype}")
print()

Computing covariance matrix...
  This will create a (2,560 × 2,560) matrix
  Memory required: ~0.03 GB (float32)

Covariance matrix computed.
  Shape: torch.Size([2560, 2560])
  Dtype: torch.float32
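
The formula `Cov = (γ')ᵀγ' / (N-1)` equals the sample covariance only because γ' has zero column means. The NumPy sketch below checks this against `np.cov` on small synthetic data (the shapes, seed, and offset are arbitrary choices for illustration) and shows that skipping the centering produces a different matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for an embedding matrix, deliberately given a nonzero mean
N, d = 1000, 8
gamma = rng.normal(loc=0.5, scale=1.0, size=(N, d))

# Subtract the per-dimension mean, as done upstream to produce gamma_centered
gamma_centered = gamma - gamma.mean(axis=0)
print("mean norm after centering:", np.linalg.norm(gamma_centered.mean(axis=0)))

# Cov = X^T X / (N - 1) on the centered matrix (float64 throughout)
cov_manual = gamma_centered.T @ gamma_centered / (N - 1)
cov_reference = np.cov(gamma, rowvar=False)  # np.cov centers internally
print("matches np.cov:", np.allclose(cov_manual, cov_reference))

# Without centering, the Gram matrix picks up the outer product of the mean
gram_uncentered = gamma.T @ gamma / (N - 1)
print("uncentered matches np.cov:", np.allclose(gram_uncentered, cov_reference))
```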



## Step 3: Eigendecomposition

Compute all eigenvalues and eigenvectors of the covariance matrix.

**Note:** For a 2560×2560 matrix, this step may take anywhere from a few seconds to about a minute, depending on hardware.

In [5]:
print("Computing eigendecomposition...")
print("  (This may take 30-60 seconds for 2560×2560 matrix)")
print()

# torch.linalg.eigh is specialized for symmetric/Hermitian matrices: it is faster and
# more numerically stable than the general torch.linalg.eig, and returns real eigenvalues
eigenvalues, eigenvectors = torch.linalg.eigh(Cov)

print("Eigendecomposition complete.")
print(f"  Eigenvalues shape: {eigenvalues.shape}")
print(f"  Eigenvectors shape: {eigenvectors.shape}")
print()

Computing eigendecomposition...
  (This may take 30-60 seconds for 2560×2560 matrix)

Eigendecomposition complete.
  Eigenvalues shape: torch.Size([2560])
  Eigenvectors shape: torch.Size([2560, 2560])
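
As a cross-check (this notebook does not use it), the same basis can be obtained from the SVD of the centered data matrix: if X = U S Vᵀ, then XᵀX = V S² Vᵀ, so the right singular vectors are the covariance eigenvectors and λᵢ = sᵢ² / (N−1). The NumPy sketch below uses synthetic data with well-separated variances. Eigenvectors are defined only up to sign, so the comparison uses absolute dot products.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d = 500, 6
# Synthetic data with a different spread along each axis, then centered
X = rng.normal(size=(N, d)) * np.array([5.0, 3.0, 2.0, 1.0, 0.5, 0.1])
X -= X.mean(axis=0)

# Route 1: eigendecomposition of the covariance (eigh returns ascending order)
cov = X.T @ X / (N - 1)
evals, evecs = np.linalg.eigh(cov)
evals, evecs = evals[::-1], evecs[:, ::-1]  # flip to descending, as in Step 4

# Route 2: thin SVD of the data matrix (singular values come out descending)
_, s, Vt = np.linalg.svd(X, full_matrices=False)
evals_svd = s**2 / (N - 1)

print("eigenvalues agree:", np.allclose(evals, evals_svd))
# Column-wise |dot product| between matching PCs should be ~1
alignment = np.abs(np.sum(evecs * Vt.T, axis=0))
print("PC alignment:", np.round(alignment, 6))
```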



## Step 4: Sort by Variance (Descending)

Sort eigenvalues and eigenvectors in descending order so PC1 has the highest variance.

In [6]:
print("Sorting by variance (descending)...")

# torch.linalg.eigh returns eigenvalues in ascending order, so flip
eigenvalues = eigenvalues.flip(0)
eigenvectors = eigenvectors.flip(1)  # Flip columns

print("  Sorted.")
print()

# Report statistics
total_variance = eigenvalues.sum()

print("Top 10 principal components:")
print(f"  {'PC':<4} {'Eigenvalue':<15} {'Variance %':<12} {'Cumulative %':<15}")
print("  " + "-" * 50)

cumulative = 0.0
for i in range(10):
    var_explained = eigenvalues[i] / total_variance
    cumulative += var_explained
    print(f"  {i+1:<4} {eigenvalues[i].item():<15.6e} {var_explained.item()*100:<12.2f} {cumulative.item()*100:<15.2f}")

print()

print("Bottom 5 principal components:")
print(f"  {'PC':<4} {'Eigenvalue':<15} {'Variance %':<12}")
print("  " + "-" * 35)

for i in range(5):
    pc_idx = d - 5 + i
    var_explained = eigenvalues[pc_idx] / total_variance
    print(f"  {pc_idx+1:<4} {eigenvalues[pc_idx].item():<15.6e} {var_explained.item()*100:<12.2f}")

print()

Sorting by variance (descending)...
  Sorted.

Top 10 principal components:
  PC   Eigenvalue      Variance %   Cumulative %   
  --------------------------------------------------
  1    1.048719e-02    0.94         0.94           
  2    3.177739e-03    0.28         1.22           
  3    2.791374e-03    0.25         1.47           
  4    2.616169e-03    0.23         1.71           
  5    1.973001e-03    0.18         1.88           
  6    1.805293e-03    0.16         2.04           
  7    1.609086e-03    0.14         2.19           
  8    1.549411e-03    0.14         2.33           
  9    1.468294e-03    0.13         2.46           
  10   1.389096e-03    0.12         2.58           

Bottom 5 principal components:
  PC   Eigenvalue      Variance %  
  -----------------------------------
  2556 5.328412e-05    0.00        
  2557 5.124237e-05    0.00        
  2558 3.992779e-05    0.00        
  2559 1.252859e-05    0.00        
  2560 9.613870e-06    0.00        
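
The "Variance %" and "Cumulative %" columns are λᵢ / Σλ and its running sum. The NumPy sketch below reproduces that bookkeeping and adds a hypothetical helper, `n_components_for`, which returns how many leading PCs are needed to reach a target fraction of the total variance. The toy eigenvalues are powers of two, so every fraction is exact in floating point. Applied to this notebook's `eigenvalues` (via `.numpy()`), the same helper would report how many PCs are needed for, say, 50% of the variance.

```python
import numpy as np


def explained_variance(eigenvalues_desc):
    """Per-PC and cumulative variance fractions for eigenvalues sorted descending."""
    ratios = eigenvalues_desc / eigenvalues_desc.sum()
    return ratios, np.cumsum(ratios)


def n_components_for(eigenvalues_desc, target_fraction):
    """Smallest k whose top-k PCs reach target_fraction of the total variance."""
    _, cumulative = explained_variance(eigenvalues_desc)
    # First index where cumulative >= target (side='left' is the default)
    idx = int(np.searchsorted(cumulative, target_fraction))
    # Clamp in case rounding leaves the final cumulative value just below 1.0
    return min(idx + 1, len(eigenvalues_desc))


# Pretend these came from eigh, which returns ascending order; flip to descending
raw = np.array([1.0, 1.0, 2.0, 4.0])
eigs = raw[::-1]  # [4, 2, 1, 1]

ratios, cumulative = explained_variance(eigs)
print(ratios, cumulative)
print(n_components_for(eigs, 0.75))  # 2
print(n_components_for(eigs, 0.80))  # 3
```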



## Step 5: Verify Orthonormality

Sanity check: the eigenvectors returned by `eigh` should be orthonormal by construction. The cell below spot-checks the first three PCs.

In [7]:
print("Verifying orthonormality of first 3 PCs...")

pc1 = eigenvectors[:, 0]
pc2 = eigenvectors[:, 1]
pc3 = eigenvectors[:, 2]

# Dot products (should be ~0 for orthogonality)
dot_12 = (pc1 @ pc2).item()
dot_13 = (pc1 @ pc3).item()
dot_23 = (pc2 @ pc3).item()

# Norms (should be ~1)
norm_1 = torch.norm(pc1).item()
norm_2 = torch.norm(pc2).item()
norm_3 = torch.norm(pc3).item()

print(f"  Dot products (should be ~0):")
print(f"    PC1 · PC2: {dot_12:.6e}")
print(f"    PC1 · PC3: {dot_13:.6e}")
print(f"    PC2 · PC3: {dot_23:.6e}")
print()
print(f"  Norms (should be ~1):")
print(f"    ||PC1||: {norm_1:.6f}")
print(f"    ||PC2||: {norm_2:.6f}")
print(f"    ||PC3||: {norm_3:.6f}")
print()

# Overall check
max_dot = max(abs(dot_12), abs(dot_13), abs(dot_23))
max_norm_error = max(abs(norm_1 - 1), abs(norm_2 - 1), abs(norm_3 - 1))

if max_dot < 1e-5 and max_norm_error < 1e-5:
    print("✓ Basis is orthonormal.")
else:
    print("⚠ Warning: Basis may not be perfectly orthonormal.")
    print(f"  Max dot product: {max_dot:.6e}")
    print(f"  Max norm error: {max_norm_error:.6e}")

print()

Verifying orthonormality of first 3 PCs...
  Dot products (should be ~0):
    PC1 · PC2: 1.618173e-07
    PC1 · PC3: 3.725290e-09
    PC2 · PC3: 1.713634e-07

  Norms (should be ~1):
    ||PC1||: 1.000000
    ||PC2||: 0.999999
    ||PC3||: 1.000000

✓ Basis is orthonormal.
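
Checking the whole basis at once is cheap at d = 2560: form VᵀV and compare it with the identity. Off-diagonal entries are the pairwise dot products, and diagonal entries are the squared norms. The sketch below is in NumPy and uses a synthetic symmetric matrix. With this notebook's tensors, the torch analogue would compare `eigenvectors.T @ eigenvectors` against `torch.eye(d)`. The 1e-5 threshold mirrors the one used above.

```python
import numpy as np


def max_orthonormality_error(V):
    """Largest absolute deviation of V^T V from the identity (columns = basis vectors)."""
    gram = V.T @ V
    return float(np.max(np.abs(gram - np.eye(V.shape[1]))))


rng = np.random.default_rng(2)
A = rng.normal(size=(200, 200))
cov_like = A @ A.T / 200  # symmetric positive semi-definite stand-in
_, V = np.linalg.eigh(cov_like)

err = max_orthonormality_error(V)
print(f"max |V^T V - I| = {err:.3e}")
print("orthonormal:", err < 1e-5)
```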



## Step 6: Save All Principal Components

Save all 2560 principal components to a safetensors file under string keys '1' through '2560', plus the eigenvalues under 'eigenvalues'.

In [8]:
print("Preparing to save all principal components...")

# Build dictionary with string keys
save_dict = {}

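for i in range(d):
    pc_idx = i + 1  # 1-indexed
    # .clone() gives each column its own contiguous storage; safetensors' save_file
    # refuses non-contiguous tensors and can refuse tensors that share storage
    pc_vector = eigenvectors[:, i].clone()
    save_dict[str(pc_idx)] = pc_vector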

# Add eigenvalues
save_dict['eigenvalues'] = eigenvalues

print(f"  Prepared {d:,} principal components + eigenvalues")
print(f"  Total keys: {len(save_dict):,}")
print()

# Save to file
output_path = data_dir / OUTPUT_FILE
print(f"Saving to {output_path}...")

save_file(save_dict, output_path)

print(f"✓ Saved successfully.")
print()

# Report file size
file_size_mb = output_path.stat().st_size / 1e6
print(f"File size: {file_size_mb:.2f} MB")
print()

Preparing to save all principal components...
  Prepared 2,560 principal components + eigenvalues
  Total keys: 2,561

Saving to ../data/tensors/pca_basis_vectors.safetensors...
✓ Saved successfully.

File size: 26.41 MB
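
The reported size is consistent with the raw payload: 2,560 float32 vectors of length 2,560, plus one 2,560-element eigenvalue vector. A quick check of the arithmetic follows. The remaining ~0.19 MB is presumably the safetensors JSON header, which lists the dtype, shape, and byte offsets for each of the 2,561 keys. That attribution is an inference; the notebook does not measure it.

```python
# Payload size for the saved basis, assuming float32 (4 bytes per element)
d = 2560
bytes_per_float32 = 4

pc_bytes = d * d * bytes_per_float32        # 2,560 PCs of length 2,560
eigenvalue_bytes = d * bytes_per_float32    # one vector of 2,560 eigenvalues
payload_bytes = pc_bytes + eigenvalue_bytes

print(f"payload: {payload_bytes:,} bytes = {payload_bytes / 1e6:.2f} MB")
# payload: 26,224,640 bytes = 26.22 MB
```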



## Step 7: Verify the Saved File

Verify we can load the saved basis successfully.

In [9]:
print("Verifying saved file...")

# Load back
loaded = load_file(output_path)

print(f"  Total keys in file: {len(loaded)}")
print()

# Spot-check a few PCs spread across the spectrum
test_indices = [1, 10, 100, 1000, 2560]
print("Checking sample principal components:")
for idx in test_indices:
    key = str(idx)
    if key in loaded:
        vec = loaded[key]
        norm = torch.norm(vec).item()
        print(f"  PC{idx}: shape {vec.shape}, norm {norm:.6f}")
    else:
        print(f"  PC{idx}: ⚠ NOT FOUND")

print()

# Check eigenvalues
if 'eigenvalues' in loaded:
    eigs = loaded['eigenvalues']
    print(f"Eigenvalues: shape {eigs.shape}")
    print(f"  First 3: {eigs[:3].tolist()}")
    print(f"  Last 3: {eigs[-3:].tolist()}")
else:
    print("⚠ Eigenvalues NOT FOUND")

print()
print("✓ Verification complete.")

Verifying saved file...
  Total keys in file: 2561

Checking sample principal components:
  PC1: shape torch.Size([2560]), norm 1.000000
  PC10: shape torch.Size([2560]), norm 1.000000
  PC100: shape torch.Size([2560]), norm 1.000000
  PC1000: shape torch.Size([2560]), norm 1.000000
  PC2560: shape torch.Size([2560]), norm 1.000000

Eigenvalues: shape torch.Size([2560])
  First 3: [0.010487192310392857, 0.0031777385156601667, 0.0027913739904761314]
  Last 3: [3.992778874817304e-05, 1.2528589650173672e-05, 9.613870133762248e-06]

✓ Verification complete.
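
The check above confirms that the keys exist and that the norms are 1. It does not confirm that the reloaded vectors equal the computed ones or that the key order matches the variance ranking. The sketch below performs a stricter round-trip check in NumPy. It uses hypothetical helpers that rebuild the (d, d) matrix from the string keys, compare it exactly against the in-memory eigenvectors, and confirm that the eigenvalues are non-increasing. In this notebook, the inputs would be `loaded`, `eigenvectors.numpy()`, and `eigenvalues.numpy()`, with the loaded tensors also converted via `.numpy()`.

```python
import numpy as np


def basis_matrix_from_keys(store, d):
    """Reassemble a (d, d) matrix whose column i-1 is the vector stored under key str(i)."""
    return np.stack([store[str(i)] for i in range(1, d + 1)], axis=1)


def check_round_trip(store, eigenvectors, eigenvalues):
    """True if stored vectors and eigenvalues match exactly and are in descending order."""
    d = eigenvectors.shape[1]
    rebuilt = basis_matrix_from_keys(store, d)
    same_vectors = np.array_equal(rebuilt, eigenvectors)
    same_eigs = np.array_equal(store["eigenvalues"], eigenvalues)
    descending = bool(np.all(np.diff(eigenvalues) <= 0))
    return same_vectors and same_eigs and descending


# Toy stand-in for the saved file: a 3-D basis in descending-variance order
evecs = np.eye(3)
evals = np.array([3.0, 2.0, 1.0])
store = {str(i + 1): evecs[:, i].copy() for i in range(3)}
store["eigenvalues"] = evals.copy()

print(check_round_trip(store, evecs, evals))  # True
```

Exact equality is a reasonable bar here because safetensors stores the raw tensor bytes, so a correct round trip should not introduce any rounding.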


## Summary

In [10]:
print("=" * 60)
print("PCA BASIS GENERATION COMPLETE")
print("=" * 60)
print()
print(f"Input: {INPUT_FILE}")
print(f"  Tokens: {N:,}")
print(f"  Dimensions: {d:,}")
print()
print(f"Output: {OUTPUT_FILE}")
print(f"  Principal components: {d:,} (keys '1' through '{d}')")
print(f"  Eigenvalues: (key 'eigenvalues')")
print(f"  File size: {file_size_mb:.2f} MB")
print()
print("Variance summary:")
print(f"  Top 3 PCs explain {(eigenvalues[:3].sum() / total_variance * 100).item():.2f}% of variance")
print(f"  Top 10 PCs explain {(eigenvalues[:10].sum() / total_variance * 100).item():.2f}% of variance")
print(f"  Top 100 PCs explain {(eigenvalues[:100].sum() / total_variance * 100).item():.2f}% of variance")
print()
print("Usage in 07.2a:")
print("  BASIS_FILE = 'pca_basis_vectors.safetensors'")
print("  BASIS_KEYS = {")
print("      'north': '1',      # PC1 (highest variance)")
print("      'meridian': '2',   # PC2")
print("      'equinox': '3'     # PC3")
print("  }")
print()
print("  Or try different orientations:")
print("  BASIS_KEYS = {'north': '10', 'meridian': '11', 'equinox': '12'}")
print()
print("=" * 60)

PCA BASIS GENERATION COMPLETE

Input: gamma_centered_qwen3_4b_instruct_2507.safetensors
  Tokens: 151,936
  Dimensions: 2,560

Output: pca_basis_vectors.safetensors
  Principal components: 2,560 (keys '1' through '2560')
  Eigenvalues: (key 'eigenvalues')
  File size: 26.41 MB

Variance summary:
  Top 3 PCs explain 1.47% of variance
  Top 10 PCs explain 2.58% of variance
  Top 100 PCs explain 10.16% of variance

Usage in 07.2a:
  BASIS_FILE = 'pca_basis_vectors.safetensors'
  BASIS_KEYS = {
      'north': '1',      # PC1 (highest variance)
      'meridian': '2',   # PC2
      'equinox': '3'     # PC3
  }

  Or try different orientations:
  BASIS_KEYS = {'north': '10', 'meridian': '11', 'equinox': '12'}

