# 08.3a: PCA Basis for Trained Embeddings

**Compute a principal component basis from the final trained embedding matrix**

After training our tiny transformer, we have a 128×64 embedding matrix that evolved from initialization through 5000 training steps. To visualize the geometry (sky maps, density plots, etc.), we need a coordinate system.

This notebook computes all 64 principal components and saves them for use in visualization notebooks.

## Parameters

In [1]:
# Input: final trained embedding matrix
EMBEDDING_PATH = "../data/embeddings_128vocab_qweninit/step_0005000.safetensors"

# Output: PCA basis (saved alongside embeddings)
PCA_BASIS_PATH = "../data/embeddings_128vocab_qweninit/step_0005000_pca_basis.safetensors"

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

## Load Embedding Matrix

In [3]:
print(f"Loading embeddings from: {EMBEDDING_PATH}\n")

gamma = load_file(EMBEDDING_PATH)['embeddings']
vocab_size, hidden_dim = gamma.shape

print(f"✓ Embeddings loaded")
print(f"Shape: {gamma.shape}")
print(f"Vocabulary: {vocab_size} tokens")
print(f"Dimensions: {hidden_dim}")

Loading embeddings from: ../data/embeddings_128vocab_qweninit/step_0005000.safetensors

✓ Embeddings loaded
Shape: torch.Size([128, 64])
Vocabulary: 128 tokens
Dimensions: 64


## Compute Statistics

In [4]:
# Token norms
norms = torch.norm(gamma, p=2, dim=1)

# Centroid
centroid = gamma.mean(dim=0)
centroid_norm = centroid.norm().item()

print(f"\nToken L2 norms:")
print(f"  Min: {norms.min().item():.6f}")
print(f"  Max: {norms.max().item():.6f}")
print(f"  Mean: {norms.mean().item():.6f}")
print(f"  Median: {norms.median().item():.6f}")

print(f"\nCentroid:")
print(f"  L2 norm: {centroid_norm:.6f}")
print(f"  As fraction of mean token norm: {centroid_norm / norms.mean().item():.4f}")


Token L2 norms:
  Min: 7.527178
  Max: 8.430494
  Mean: 7.749794
  Median: 7.639194

Centroid:
  L2 norm: 7.731306
  As fraction of mean token norm: 0.9976


## Center the Cloud

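Compute γ' = γ - μ, where μ is the centroid (the mean embedding over all tokens). PCA measures variance about the mean, so the cloud must be centered at the origin first.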

In [5]:
gamma_prime = gamma - centroid

# Verify centering
new_centroid = gamma_prime.mean(dim=0)
print(f"Centered cloud:")
print(f"  New centroid L2 norm: {new_centroid.norm().item():.6e}")
print(f"  (Should be ~0)")

Centered cloud:
  New centroid L2 norm: 5.995414e-07
  (Should be ~0)
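Centering matters here because the centroid norm (7.73) is almost as large as the mean token norm (7.75), so the cloud sits far from the origin. The following is a minimal sketch on synthetic data, not the trained embeddings; the offset and noise scale are made-up assumptions. It suggests that without centering, the leading singular vector mostly points at the centroid rather than along the cloud's spread:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for the embedding matrix (assumption, not real data):
# a small cloud offset far from the origin, like the trained embeddings.
n_tokens, dim = 128, 64
offset = np.full(dim, 7.7 / np.sqrt(dim))  # centroid with L2 norm ~7.7
cloud = offset + 0.1 * rng.standard_normal((n_tokens, dim))

mu = cloud.mean(axis=0)
mu_dir = mu / np.linalg.norm(mu)

# Leading right singular vector without and with centering
_, _, vt_raw = np.linalg.svd(cloud, full_matrices=False)
_, _, vt_centered = np.linalg.svd(cloud - mu, full_matrices=False)

cos_raw = abs(vt_raw[0] @ mu_dir)
cos_centered = abs(vt_centered[0] @ mu_dir)

print(f"|cos(PC1, centroid)| without centering: {cos_raw:.4f}")
print(f"|cos(PC1, centroid)| with centering:    {cos_centered:.4f}")
```

In the uncentered case the first component is spent describing where the cloud is, not how it is shaped, which is why the notebook subtracts the centroid before the SVD.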


## Compute PCA

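Use SVD to find the principal components: γ' = U Σ V^T

The columns of V are the principal components: the eigenvectors of the covariance matrix γ'^T γ' / (n - 1), ordered by decreasing variance. The corresponding eigenvalues are σ_i² / (n - 1), where σ_i are the singular values and n is the number of tokens.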
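The link between the SVD and the covariance eigendecomposition can be checked numerically. This is a small sketch on synthetic centered data (an assumption standing in for `gamma_prime`), comparing the two routes with NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic centered data standing in for gamma_prime (assumption)
n, d = 128, 64
x = rng.standard_normal((n, d)) * np.linspace(3.0, 0.5, d)  # unequal variances
x = x - x.mean(axis=0)

# Route 1: SVD of the centered data
_, s, vt = np.linalg.svd(x, full_matrices=False)
eig_from_svd = s**2 / (n - 1)

# Route 2: eigendecomposition of the covariance matrix
cov = x.T @ x / (n - 1)
eig_from_cov, vecs = np.linalg.eigh(cov)  # eigh returns ascending order
eig_from_cov = eig_from_cov[::-1]         # match the SVD's descending order
vecs = vecs[:, ::-1]

# Eigenvectors agree only up to sign, so compare absolute dot products
alignment = np.abs(np.sum(vt.T * vecs, axis=0))

print("eigenvalues match: ", np.allclose(eig_from_svd, eig_from_cov))
print("eigenvectors match:", np.allclose(alignment, 1.0, atol=1e-6))
```

The sign of each component is arbitrary in both routes, which is worth keeping in mind when comparing bases computed at different training steps.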

In [6]:
print(f"\nComputing PCA via SVD...\n")

# SVD: gamma_prime = U @ diag(S) @ V^T
# V columns are principal components
U, S, Vt = torch.linalg.svd(gamma_prime, full_matrices=False)

# V is Vt transposed
V = Vt.T

# Eigenvalues are singular values squared, divided by (n-1)
eigenvalues = (S ** 2) / (vocab_size - 1)

print(f"✓ SVD complete")
print(f"Principal components: {V.shape}")
print(f"Eigenvalues: {eigenvalues.shape}")

# Show variance explained
total_variance = eigenvalues.sum()
variance_explained = eigenvalues / total_variance
cumulative_variance = torch.cumsum(variance_explained, dim=0)

print(f"\nVariance explained by top components:")
for i in range(min(10, len(eigenvalues))):
    print(f"  PC{i+1:2d}: {100 * variance_explained[i].item():5.2f}%  (cumulative: {100 * cumulative_variance[i].item():5.2f}%)")


Computing PCA via SVD...

✓ SVD complete
Principal components: torch.Size([64, 64])
Eigenvalues: torch.Size([64])

Variance explained by top components:
  PC 1: 66.20%  (cumulative: 66.20%)
  PC 2: 10.48%  (cumulative: 76.69%)
  PC 3:  5.14%  (cumulative: 81.83%)
  PC 4:  3.94%  (cumulative: 85.77%)
  PC 5:  3.53%  (cumulative: 89.29%)
  PC 6:  1.97%  (cumulative: 91.27%)
  PC 7:  1.81%  (cumulative: 93.08%)
  PC 8:  1.61%  (cumulative: 94.68%)
  PC 9:  1.19%  (cumulative: 95.88%)
  PC10:  0.78%  (cumulative: 96.66%)
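The cumulative column also answers how many components are needed to reach a given share of variance. Below is a short sketch using the percentages printed above (top 10 components only); `components_needed` is a hypothetical helper, not part of the notebook:

```python
import numpy as np

# Per-component variance percentages copied from the printed output above
# (top 10 components only)
variance_pct = np.array([66.20, 10.48, 5.14, 3.94, 3.53, 1.97, 1.81, 1.61, 1.19, 0.78])
cumulative_pct = np.cumsum(variance_pct)

def components_needed(threshold_pct):
    """Smallest k whose cumulative variance reaches threshold_pct (hypothetical helper)."""
    if threshold_pct > cumulative_pct[-1]:
        raise ValueError("threshold exceeds the variance covered by the listed components")
    # searchsorted returns the first index whose cumulative value is >= threshold
    return int(np.searchsorted(cumulative_pct, threshold_pct) + 1)

for threshold in (80, 90, 95):
    print(f"{threshold}% of variance: {components_needed(threshold)} components")
```

With these numbers, three components cover 80%, six cover 90%, and nine cover 95%.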


## Save PCA Basis

Save all principal components with string keys ('1', '2', ..., '64') plus eigenvalues and centroid.

In [7]:
# Build dictionary with string keys
save_dict = {
    'eigenvalues': eigenvalues,
    'centroid': centroid,
}

# Add each principal component (clone so each tensor owns its storage instead of being a view into V)
for i in range(hidden_dim):
    save_dict[str(i + 1)] = V[:, i].clone()

# Save
save_file(save_dict, PCA_BASIS_PATH)

print(f"\n✓ Saved PCA basis to: {PCA_BASIS_PATH}")
print(f"\nContents:")
print(f"  eigenvalues: {eigenvalues.shape}")
print(f"  centroid: {centroid.shape}")
print(f"  Principal components '1' through '{hidden_dim}': each {V[:, 0].shape}")


✓ Saved PCA basis to: ../data/embeddings_128vocab_qweninit/step_0005000_pca_basis.safetensors

Contents:
  eigenvalues: torch.Size([64])
  centroid: torch.Size([64])
  Principal components '1' through '64': each torch.Size([64])
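A downstream notebook needs to rebuild the basis from this key layout and project embeddings into it. The sketch below shows one way this might look; `basis_from_dict` and `project` are hypothetical helpers, and a synthetic dictionary stands in for the saved file. With the real file, the dictionary would presumably come from `load_file(PCA_BASIS_PATH)`, with tensors converted to arrays as needed.

```python
import numpy as np

def basis_from_dict(tensors):
    """Rebuild (V, eigenvalues, centroid) from the saved key layout (hypothetical helper)."""
    # Sort component keys numerically: plain string sorting gives '1', '10', '11', ...
    pc_keys = sorted((k for k in tensors if k.isdigit()), key=int)
    V = np.stack([np.asarray(tensors[k]) for k in pc_keys], axis=1)  # columns = PCs
    return V, np.asarray(tensors['eigenvalues']), np.asarray(tensors['centroid'])

def project(embeddings, V, centroid, n_components=3):
    """Center the embeddings and express them in the first n_components PCs."""
    return (embeddings - centroid) @ V[:, :n_components]

# Synthetic stand-in with the same key layout as the saved file (assumption)
rng = np.random.default_rng(0)
data = rng.standard_normal((128, 64)) + 1.0
mu = data.mean(axis=0)
_, s, vt = np.linalg.svd(data - mu, full_matrices=False)
fake_file = {'eigenvalues': s**2 / 127, 'centroid': mu}
fake_file.update({str(i + 1): vt[i] for i in range(64)})

V, eigenvalues, centroid = basis_from_dict(fake_file)
coords = project(data, V, centroid)
print(coords.shape)
```

Sorting the keys with `key=int` is the important detail: because the components are stored under string keys, a default sort would interleave `'10'` through `'19'` between `'1'` and `'2'` and scramble the basis.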


## Verify Orthonormality

Check that the principal components are orthonormal: with Q holding the first three PCs as columns, Q^T Q should equal I up to floating-point rounding.

In [8]:
# Check first 3 PCs (the ones we'll use for visualization)
Q = V[:, :3]
gram = Q.T @ Q

print(f"\nGram matrix (Q^T Q) for first 3 PCs:")
print(gram.numpy())
print(f"\n(Should be identity matrix)")

# Check deviation from identity
I = torch.eye(3)
error = torch.norm(gram - I, p='fro').item()
print(f"\nFrobenius norm of (Q^T Q - I): {error:.6e}")
print(f"(Should be ~0)")


Gram matrix (Q^T Q) for first 3 PCs:
[[ 9.9999994e-01 -1.6242453e-08 -4.7340816e-08]
 [-1.6242453e-08  1.0000008e+00 -1.6752810e-07]
 [-4.7340816e-08 -1.6752810e-07  1.0000005e+00]]

(Should be identity matrix)

Frobenius norm of (Q^T Q - I): 9.941829e-07
(Should be ~0)
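Orthonormality is what makes projection and reconstruction simple matrix products. The following sketch, on synthetic centered data (an assumption standing in for `gamma_prime`), illustrates two consequences: the full basis reproduces the data exactly, and a top-k basis loses exactly the unexplained fraction of variance.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic centered data standing in for gamma_prime (assumption)
n, d, k = 128, 64, 3
x = rng.standard_normal((n, d)) * np.linspace(3.0, 0.5, d)
x = x - x.mean(axis=0)

_, s, vt = np.linalg.svd(x, full_matrices=False)
V = vt.T

# Full basis: projecting and mapping back recovers the data (V is orthogonal)
full_roundtrip = (x @ V) @ V.T
print("full round trip exact:", np.allclose(full_roundtrip, x))

# Top-k basis: the lost fraction of squared norm equals the unexplained variance
Q = V[:, :k]
approx = (x @ Q) @ Q.T
lost = np.linalg.norm(x - approx) ** 2 / np.linalg.norm(x) ** 2
unexplained = 1.0 - (s[:k] ** 2).sum() / (s ** 2).sum()
print(f"lost fraction: {lost:.6f}, unexplained variance: {unexplained:.6f}")
```

For the trained embeddings, where the top 3 components explain 81.83% of the variance, the same identity implies that a 3D view discards roughly 18% of the centered cloud's squared norm.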


## Summary

In [9]:
print(f"\n{'='*80}")
print("PCA BASIS COMPUTED")
print(f"{'='*80}")
print(f"Embedding matrix: {vocab_size} tokens × {hidden_dim} dimensions")
print(f"Centroid L2 norm: {centroid_norm:.6f}")
print(f"\nPrincipal components: {hidden_dim}")
print(f"Top 3 explain {100 * cumulative_variance[2].item():.2f}% of variance")
print(f"\nOutput: {PCA_BASIS_PATH}")
print(f"\nReady for visualization with 07.2a (or adapted sky map notebook)")
print(f"{'='*80}")


PCA BASIS COMPUTED
Embedding matrix: 128 tokens × 64 dimensions
Centroid L2 norm: 7.731306

Principal components: 64
Top 3 explain 81.83% of variance

Output: ../data/embeddings_128vocab_qweninit/step_0005000_pca_basis.safetensors

Ready for visualization with 07.2a (or adapted sky map notebook)
