# Experiment 040: Zipf-Spectral Mapping

**Question:** Do token embeddings organize by Zipf rank in spectral space?

**Hypothesis:** Common tokens concentrate their embedding energy in low spectral bands; rare tokens concentrate theirs in high bands.

**Key Prediction:** Spearman correlation r > 0.5 between log(Zipf rank) and spectral centroid, where the spectrum is the FFT taken across each token's embedding dimensions.

In [None]:
!pip install transformers torch scipy matplotlib -q

In [None]:
import torch
import numpy as np
from scipy.fft import fft
from scipy.stats import spearmanr, pearsonr
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Load Model

In [None]:
MODEL_NAME = "gpt2"
N_BANDS = 7  # Match AKIRA's spectral structure

# Tokenizer plus causal LM; only the input-embedding table is read from the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
model.eval()

# Token-embedding table, detached from autograd and copied to the CPU.
embeddings = model.get_input_embeddings().weight.detach().cpu()
vocab_size, embed_dim = embeddings.shape

for message in (
    f"Loaded {MODEL_NAME}",
    f"Vocabulary size: {vocab_size}",
    f"Embedding dimension: {embed_dim}",
):
    print(message)

## 2. Spectral Decomposition Functions

In [None]:
def spectral_decomposition(embedding: torch.Tensor, n_bands: int = N_BANDS) -> dict:
    """Split one embedding vector's FFT power spectrum into contiguous bands.

    The FFT is taken across the vector's feature dimensions.
    - Only the first ``len(embedding) // 2`` frequency bins, starting at DC,
      are used.
    - Those bins are divided into ``n_bands`` bands of ``n_freq // n_bands``
      bins each.
    - The last band also absorbs any remainder bins.

    Args:
        embedding: 1-D vector (a CPU ``torch.Tensor`` or anything that
            ``np.asarray`` accepts).
        n_bands: Number of bands to produce.

    Returns:
        Dict mapping band index (0 = lowest frequency) to that band's share of
        the total power. The shares sum to 1 unless the vector has zero
        power, in which case the raw (all-zero) energies are returned.

    Raises:
        ValueError: If ``embedding`` is not 1-D or ``n_bands < 1``. Also
            raised if there are fewer usable frequency bins than bands, which
            would otherwise leave some bands silently empty.
    """
    # Convert to float64 so the FFT and the power sums are not computed in
    # single precision.
    emb_np = np.asarray(embedding, dtype=np.float64)
    if emb_np.ndim != 1:
        raise ValueError(f"expected a 1-D embedding, got shape {emb_np.shape}")
    if n_bands < 1:
        raise ValueError(f"n_bands must be >= 1, got {n_bands}")

    n_freq = emb_np.shape[0] // 2
    if n_freq < n_bands:
        raise ValueError(
            f"embedding of length {emb_np.shape[0]} yields only {n_freq} "
            f"frequency bins, fewer than n_bands={n_bands}"
        )

    # Power (squared magnitude) of the bins that are kept.
    power = np.abs(fft(emb_np)[:n_freq]) ** 2
    band_size = n_freq // n_bands

    band_energies = {}
    for band in range(n_bands):
        start = band * band_size
        # The last band runs to n_freq so remainder bins are not dropped.
        end = (band + 1) * band_size if band < n_bands - 1 else n_freq
        band_energies[band] = float(np.sum(power[start:end]))

    # Normalize to fractions of total power.
    total = sum(band_energies.values())
    if total > 0:
        band_energies = {k: v / total for k, v in band_energies.items()}

    return band_energies

def spectral_centroid(band_energies: dict) -> float:
    """Return the energy-weighted mean band index of a band-energy profile.

    Args:
        band_energies: Mapping of band index -> energy (normalized or not).

    Returns:
        ``sum(band * energy) / sum(energy)`` as a float. Returns ``0.0`` when
        the total energy is zero, including for an empty mapping.
    """
    total = sum(band_energies.values())
    if total == 0:
        # Degenerate profile: report band 0 rather than dividing by zero.
        return 0.0
    weighted = sum(band * energy for band, energy in band_energies.items())
    return float(weighted / total)

# Sanity check: band profile and centroid of the token with id 0.
test_bands = spectral_decomposition(embeddings[0])
test_centroid = spectral_centroid(test_bands)
for line in (f"Test token bands: {test_bands}", f"Test centroid: {test_centroid:.4f}"):
    print(line)

## 3. Analyze Vocabulary

In [None]:
# Analyze a sample of tokens, stratified by token id (the rank proxy used below).
N_SAMPLES = 2000
N_DECILES = 10
SEED = 0  # fixed so the sampled ids, and every number derived from them, are reproducible

rng = np.random.default_rng(SEED)

# Stratified sampling: equal samples from each rank decile.
n_per_bucket = N_SAMPLES // N_DECILES
bucket_size = vocab_size // N_DECILES

sample_ids = []
for i in range(N_DECILES):
    start = i * bucket_size
    # The last decile runs to vocab_size so the (vocab_size % N_DECILES)
    # highest ids are not silently excluded from sampling.
    end = vocab_size if i == N_DECILES - 1 else (i + 1) * bucket_size
    n_take = min(n_per_bucket, end - start)
    sampled = rng.choice(np.arange(start, end), size=n_take, replace=False)
    sample_ids.extend(int(t) for t in sampled)

print(f"Analyzing {len(sample_ids)} tokens (stratified sample)")

# Analyze each token.
results = []
for token_id in sample_ids:
    emb = embeddings[token_id]
    bands = spectral_decomposition(emb)
    centroid = spectral_centroid(bands)

    # Token ID is used as a proxy for Zipf rank (lower ID assumed more common).
    # NOTE(review): no corpus frequencies are used here, so this relies on an
    # assumption about the tokenizer's id ordering. Confirm it before reading
    # the correlation as a true Zipf-rank effect.
    zipf_rank = token_id + 1
    log_rank = np.log10(zipf_rank)

    try:
        token_str = tokenizer.decode([token_id])
    except Exception:  # one undecodable id should not abort the whole sweep
        token_str = f"[{token_id}]"

    results.append({
        'token_id': token_id,
        'token_str': token_str,
        'zipf_rank': zipf_rank,
        'log_rank': log_rank,
        'bands': bands,
        'centroid': centroid
    })

print(f"Analyzed {len(results)} tokens")

## 4. Compute Correlations

In [None]:
# Paired series over the sampled tokens: x = log10 rank proxy, y = centroid.
log_ranks, centroids = [], []
for row in results:
    log_ranks.append(row['log_rank'])
    centroids.append(row['centroid'])

# Rank-based (Spearman) and linear (Pearson) association over all samples.
spearman_r, spearman_p = spearmanr(log_ranks, centroids)
pearson_r, pearson_p = pearsonr(log_ranks, centroids)

print("=" * 50)
print("OVERALL CORRELATION")
print("=" * 50)
print(f"Spearman r = {spearman_r:.4f} (p = {spearman_p:.2e})")
print(f"Pearson r = {pearson_r:.4f} (p = {pearson_p:.2e})")

# Per-band Spearman correlation between log rank and that band's energy share.
print("\nBAND-SPECIFIC CORRELATIONS:")
band_correlations = {}
for band in range(N_BANDS):
    band_energies = [row['bands'][band] for row in results]
    corr, p = spearmanr(log_ranks, band_energies)
    band_correlations[band] = corr
    direction = "-" if not corr > 0 else "+"
    print(f"  Band {band}: r = {corr:+.4f} ({direction})")

## 5. Analyze Extreme Tokens

In [None]:
# Order the *sampled* tokens by rank proxy (ascending = assumed more common).
# The two groups below are therefore the extremes of the sample, not of the
# whole vocabulary.
sorted_results = sorted(results, key=lambda row: row['zipf_rank'])

common_100, rare_100 = sorted_results[:100], sorted_results[-100:]
common_centroid = np.mean([row['centroid'] for row in common_100])
rare_centroid = np.mean([row['centroid'] for row in rare_100])

print("=" * 50)
print("EXTREME TOKEN ANALYSIS")
print("=" * 50)

for title, group, mean_centroid in (
    ("100 Most Common tokens:", common_100, common_centroid),
    ("100 Rarest tokens:", rare_100, rare_centroid),
):
    print("\n" + title)
    print(f"  Mean centroid: {mean_centroid:.4f}")
    print(f"  Examples: {[row['token_str'][:10] for row in group[:5]]}")

print(f"\nCentroid Separation: {rare_centroid - common_centroid:.4f}")

## 6. Visualization

In [None]:
def _rank_buckets(rows, n_buckets):
    """Split rank-sorted ``rows`` into ``n_buckets`` contiguous groups.

    Each group holds ``len(rows) // n_buckets`` rows; the last group also
    takes the remainder so no row is dropped.
    """
    size = len(rows) // n_buckets
    buckets = []
    for i in range(n_buckets):
        start = i * size
        end = (i + 1) * size if i < n_buckets - 1 else len(rows)
        buckets.append(rows[start:end])
    return buckets


fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Plot 1: Scatter - Log Rank vs Centroid
ax1 = axes[0, 0]
ax1.scatter(log_ranks, centroids, alpha=0.3, s=10)
# Least-squares trend line. The r in the label is the Spearman coefficient
# computed earlier, not the fit's Pearson r.
z = np.polyfit(log_ranks, centroids, 1)
p = np.poly1d(z)
x_line = np.linspace(min(log_ranks), max(log_ranks), 100)
ax1.plot(x_line, p(x_line), 'r-', linewidth=2, label=f'Trend (Spearman r={spearman_r:.3f})')
ax1.set_xlabel('Log10(Zipf Rank)', fontsize=12)
ax1.set_ylabel('Spectral Centroid', fontsize=12)
ax1.set_title('Zipf Rank vs Spectral Centroid', fontsize=14)
ax1.legend()

# Plot 2: Band Correlations
ax2 = axes[0, 1]
bands = list(band_correlations.keys())
corrs = list(band_correlations.values())
colors = ['red' if c < 0 else 'green' for c in corrs]
ax2.bar(bands, corrs, color=colors, alpha=0.7)
ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax2.set_xlabel('Spectral Band', fontsize=12)
ax2.set_ylabel('Correlation with Log(Rank)', fontsize=12)
ax2.set_title('Band-Specific Correlations', fontsize=14)

# Buckets are equal-count slices of the rank-sorted *sample*, shared by plots 3 and 4.
n_buckets = 10
rank_buckets = _rank_buckets(sorted_results, n_buckets)
bucket_labels = [f'{i+1}' for i in range(n_buckets)]

# Plot 3: Heatmap of mean band energy share by rank bucket
ax3 = axes[1, 0]
heatmap_data = [
    [np.mean([r['bands'][b] for r in bucket]) for b in range(N_BANDS)]
    for bucket in rank_buckets
]
im = ax3.imshow(heatmap_data, aspect='auto', cmap='viridis')
ax3.set_xlabel('Spectral Band', fontsize=12)
ax3.set_ylabel('Rank Bucket (Common -> Rare)', fontsize=12)
ax3.set_title('Band Energy by Zipf Bucket', fontsize=14)
ax3.set_xticks(range(N_BANDS))
ax3.set_yticks(range(n_buckets))
ax3.set_yticklabels(bucket_labels)
plt.colorbar(im, ax=ax3, label='Energy')

# Plot 4: Centroid Distribution by Bucket
ax4 = axes[1, 1]
bucket_centroids = [[r['centroid'] for r in bucket] for bucket in rank_buckets]
ax4.boxplot(bucket_centroids)
# Set tick labels on the Axes rather than via boxplot's `labels=` keyword,
# which was renamed `tick_labels` in Matplotlib 3.9. Boxes sit at x = 1..n.
ax4.set_xticks(range(1, n_buckets + 1))
ax4.set_xticklabels(bucket_labels)
ax4.set_xlabel('Rank Bucket (Common -> Rare)', fontsize=12)
ax4.set_ylabel('Spectral Centroid', fontsize=12)
ax4.set_title('Centroid Distribution by Bucket', fontsize=14)

plt.tight_layout()
plt.show()

## 7. Summary

In [None]:
# Verdict thresholds. R_THRESHOLD is the Key Prediction from the notebook
# header (Spearman r > 0.5).
R_THRESHOLD = 0.5
P_THRESHOLD = 0.001    # significance level for the Spearman correlation
SEP_THRESHOLD = 0.3    # minimum (rare - common) centroid gap, in band-index units
BAND_NEUTRAL = 0.1     # |r| at or below this is reported as NEUTRAL

print("=" * 60)
print("EXPERIMENT 040 SUMMARY")
print("=" * 60)

print("\n1. CORRELATION RESULTS:")
print(f"   Spearman r = {spearman_r:.4f}")
print(f"   p-value = {spearman_p:.2e}")

print("\n2. BAND STRUCTURE:")
for band, corr in band_correlations.items():
    if corr < -BAND_NEUTRAL:
        print(f"   Band {band}: NEGATIVE (common tokens have more energy here)")
    elif corr > BAND_NEUTRAL:
        print(f"   Band {band}: POSITIVE (rare tokens have more energy here)")
    else:
        print(f"   Band {band}: NEUTRAL")

print("\n3. CENTROID SEPARATION:")
centroid_sep = rare_centroid - common_centroid
print(f"   Common tokens: {common_centroid:.4f}")
print(f"   Rare tokens: {rare_centroid:.4f}")
print(f"   Separation: {centroid_sep:.4f}")

print("\n4. VERDICT:")
# Each criterion is phrased as a strict comparison, so a NaN statistic counts
# as a failure and is reported below, not silently skipped.
corr_ok = spearman_r > R_THRESHOLD
sig_ok = spearman_p < P_THRESHOLD
sep_ok = centroid_sep > SEP_THRESHOLD
if corr_ok and sig_ok and sep_ok:
    print("   HYPOTHESIS SUPPORTED")
    print("   - Significant positive correlation")
    print("   - Common tokens -> Low bands (DC-like)")
    print("   - Rare tokens -> High bands (detail)")
    print("   - Zipf's Law maps to spectral structure")
else:
    print("   HYPOTHESIS NOT SUPPORTED")
    if not corr_ok:
        print(f"   - Correlation too weak: r = {spearman_r:.3f}")
    if not sig_ok:
        print(f"   - Not significant: p = {spearman_p:.2e}")
    if not sep_ok:
        print(f"   - Separation too small: {centroid_sep:.3f}")