# Word2Vec Test: Causal Geometry

**Goal:** Test the canonical word2vec arithmetic in **causal metric space**.

**Hypothesis:** `'woman' - 'man' + 'king' ≈ 'queen'` (measured with causal metric M)

**Why this matters:** This tests whether the **causal metric** (Park et al. 2024) preserves linear semantic relationships. We compare with 07.59a (Euclidean) to see if the probability geometry helps or hurts compositional semantics.

**Method:**
1. Compute synthetic vector: `v = 'woman' - 'man' + 'king'`
2. Find nearest neighbors by **causal distance** d_M(u,v) = sqrt((u-v)^T M (u-v))
3. Find nearest neighbors by **causal cosine similarity** (u^T M v) / (||u||_M ||v||_M)
4. Check where 'queen' ranks
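
The two measures in steps 2 and 3 can be sketched on toy data. This is a minimal illustration, not the notebook's implementation: the helper names `causal_distance` and `causal_cosine` and the 2-D matrix `M_toy` are hypothetical, and `M_toy` is simply chosen to be symmetric positive definite.

```python
import numpy as np

def causal_distance(u, v, M):
    """d_M(u, v) = sqrt((u - v)^T M (u - v))."""
    d = u - v
    return float(np.sqrt(d @ M @ d))

def causal_cosine(u, v, M):
    """cos_M(u, v) = (u^T M v) / (||u||_M ||v||_M)."""
    norm_u = np.sqrt(u @ M @ u)
    norm_v = np.sqrt(v @ M @ v)
    return float((u @ M @ v) / (norm_u * norm_v))

# Toy metric (hypothetical values): stretches the first axis by a factor of 2
M_toy = np.array([[4.0, 0.0],
                  [0.0, 1.0]])
u = np.array([1.0, 0.0])
v = np.array([0.0, 1.0])

d_causal = causal_distance(u, v, M_toy)         # sqrt(4 + 1) = sqrt(5)
d_euclid = causal_distance(u, v, np.eye(2))     # identity metric gives sqrt(2)
c_causal = causal_cosine(u, v, M_toy)           # 0: u and v stay orthogonal here
```

With `M` set to the identity, both functions reduce to the ordinary Euclidean distance and cosine, which is the comparison case in 07.59a.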

**Note:** We search the **full vocabulary** (151k tokens), not just our 32k sample.

## Configuration

In [33]:
# Model
MODEL_NAME = 'Qwen/Qwen3-4B-Instruct-2507'

# Data paths
METRIC_TENSOR_PATH = '../data/vectors/causal_metric_tensor_qwen3_4b.pt'

# Tokens for arithmetic
TOKENS = {
    'man': None,    # Will tokenize
    'woman': None,  # Will tokenize
    'king': None,   # Will tokenize
    'queen': None,  # Will tokenize (for checking)
}

# Analysis parameters
TOP_N = 20  # How many neighbors to show

print(f"Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Arithmetic: 'woman' - 'man' + 'king' = ?")
print(f"  Metric: CAUSAL (using M tensor)")
print(f"  Search space: Full vocabulary (151k tokens)")
print(f"  Top N results: {TOP_N}")

Configuration:
  Model: Qwen/Qwen3-4B-Instruct-2507
  Arithmetic: 'woman' - 'man' + 'king' = ?
  Metric: CAUSAL (using M tensor)
  Search space: Full vocabulary (151k tokens)
  Top N results: 20


## Setup

In [34]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

print("✓ Imports complete")

✓ Imports complete


## Load Model, Tokenizer, and Metric Tensor

In [35]:
print("Loading model, tokenizer, and metric tensor...\n")

# Tokenizer
print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Model (for full unembedding matrix)
print(f"\nLoading model (for unembedding matrix)...")
print("  This will take a minute...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map='cpu',
)

# Extract FULL unembedding matrix (all vocab)
gamma = model.lm_head.weight.data.to(torch.float32).cpu()  # [vocab_size, hidden_dim]

# Load metric tensor
print(f"\nLoading causal metric tensor from {METRIC_TENSOR_PATH}...")
metric_data = torch.load(METRIC_TENSOR_PATH, weights_only=False)
M = metric_data['M'].to(torch.float32).cpu()  # [hidden_dim, hidden_dim]

print(f"\n✓ All data loaded")
print(f"  Vocab size: {tokenizer.vocab_size:,}")
print(f"  Unembedding matrix shape: {gamma.shape}")
print(f"  Metric tensor shape: {M.shape}")
print(f"  Memory: {(gamma.element_size() * gamma.nelement() + M.element_size() * M.nelement()) / 1e9:.2f} GB")

Loading model, tokenizer, and metric tensor...

Loading tokenizer from Qwen/Qwen3-4B-Instruct-2507...

Loading model (for unembedding matrix)...
  This will take a minute...



Loading causal metric tensor from ../data/vectors/causal_metric_tensor_qwen3_4b.pt...

✓ All data loaded
  Vocab size: 151,643
  Unembedding matrix shape: torch.Size([151936, 2560])
  Metric tensor shape: torch.Size([2560, 2560])
  Memory: 1.58 GB
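
The output above reports `tokenizer.vocab_size` as 151,643 while the unembedding matrix has 151,936 rows, so 293 rows fall outside that count. Some of them may be added special tokens and others unused padding rows; which is which for this model is not established here. Because the searches later in this notebook rank all 151,936 rows, one cautious option is to push rows without a real token to the end of the ranking. The sketch below uses a hypothetical helper `mask_unused_rows` and toy sizes; choosing the right cutoff (for example `len(tokenizer)` rather than `tokenizer.vocab_size`) is an assumption to verify.

```python
import numpy as np

def mask_unused_rows(scores, n_real_tokens, fill):
    """Return a copy of `scores` with rows at index >= n_real_tokens set to `fill`."""
    out = np.array(scores, dtype=float, copy=True)
    out[n_real_tokens:] = fill
    return out

# Toy example: 6 matrix rows but only 4 real tokens (illustrative sizes)
toy_distances = np.array([3.0, 1.0, 2.0, 5.0, 0.1, 0.2])

# For distances (smaller is better) fill with +inf so unused rows sort last;
# for cosines (larger is better) one would fill with -inf instead.
masked = mask_unused_rows(toy_distances, n_real_tokens=4, fill=np.inf)
ranking = np.argsort(masked, kind="stable")
```

Without the mask, rows 4 and 5 of the toy array would have ranked first even though they correspond to no token.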


## Tokenize and Verify

In [36]:
print("Tokenizing target words...\n")

words = ['man', 'woman', 'king', 'queen']
all_single_tokens = True

for word in words:
    tokens = tokenizer.encode(word, add_special_tokens=False)
    
    if len(tokens) == 1:
        token_id = tokens[0]
        TOKENS[word] = token_id
        text = tokenizer.decode([token_id])
        print(f"✓ '{word}' → token {token_id}: '{text}'")
    else:
        print(f"✗ '{word}' → {len(tokens)} tokens: {tokens}")
        all_single_tokens = False

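if not all_single_tokens:
    # A multi-token word leaves TOKENS[word] as None, and gamma[None] would
    # silently return the whole matrix with an extra axis instead of one row.
    raise ValueError("Not all words are single tokens; the arithmetic needs one token ID per word.")
else:
    print("\n✓ All words are single tokens")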

Tokenizing target words...

✓ 'man' → token 1515: 'man'
✓ 'woman' → token 22028: 'woman'
✓ 'king' → token 10566: 'king'
✓ 'queen' → token 93114: 'queen'

✓ All words are single tokens


## Compute Synthetic Vector: 'woman' - 'man' + 'king'

In [37]:
print("\nComputing synthetic vector...\n")

# Get embeddings
v_man = gamma[TOKENS['man']]
v_woman = gamma[TOKENS['woman']]
v_king = gamma[TOKENS['king']]

# Arithmetic
v_synthetic = v_woman - v_man + v_king

print(f"v_synthetic = 'woman' - 'man' + 'king'")
print(f"\nSynthetic vector properties:")
print(f"  Shape: {v_synthetic.shape}")
print(f"  Euclidean norm: {torch.norm(v_synthetic).item():.2f}")

# Compute causal norm
causal_norm_synthetic = torch.sqrt(v_synthetic @ M @ v_synthetic).item()
print(f"  Causal norm: {causal_norm_synthetic:.2f} logometers")


Computing synthetic vector...

v_synthetic = 'woman' - 'man' + 'king'

Synthetic vector properties:
  Shape: torch.Size([2560])
  Euclidean norm: 1.81
  Causal norm: 86.75 logometers


## Direct Vector Comparison: Are the Transformations Parallel?

**Simple test:** Do `woman - man` and `queen - king` point in the same direction (in causal space)?

If word2vec arithmetic works, these should be **parallel** under the causal metric.

In [38]:
print("\n" + "=" * 80)
print("DIRECT VECTOR COMPARISON (CAUSAL METRIC)")
print("=" * 80)

# Get queen vector
v_queen = gamma[TOKENS['queen']]

# Compute the two displacement vectors
v_gender_shift = v_woman - v_man      # man → woman
v_royalty_shift = v_queen - v_king    # king → queen

# Causal norms: ||v||_M = sqrt(v^T M v)
causal_norm_gender = torch.sqrt(v_gender_shift @ M @ v_gender_shift).item()
causal_norm_royalty = torch.sqrt(v_royalty_shift @ M @ v_royalty_shift).item()

# Causal dot product and cosine: (u^T M v) / (||u||_M ||v||_M)
causal_dot_product = (v_gender_shift @ M @ v_royalty_shift).item()
causal_cosine_similarity = causal_dot_product / (causal_norm_gender * causal_norm_royalty)
causal_angle_rad = np.arccos(np.clip(causal_cosine_similarity, -1.0, 1.0))
causal_angle_deg = np.degrees(causal_angle_rad)

print(f"\n(woman - man) properties:")
print(f"  Causal norm: {causal_norm_gender:.4f} logometers")

print(f"\n(queen - king) properties:")
print(f"  Causal norm: {causal_norm_royalty:.4f} logometers")

print(f"\nAlignment between the two vectors (causal metric):")
print(f"  Causal dot product: {causal_dot_product:.4f}")
print(f"  Causal cosine similarity: {causal_cosine_similarity:.4f}")
print(f"  Causal angle: {causal_angle_deg:.2f}°")

print(f"\n" + "=" * 80)
print("INTERPRETATION:")
print("=" * 80)

if abs(causal_cosine_similarity) > 0.9:
    print(f"✅ Vectors are HIGHLY ALIGNED in causal space (cos={causal_cosine_similarity:.3f})")
    print(f"   Word2vec arithmetic should work in causal geometry!")
elif abs(causal_cosine_similarity) > 0.7:
    print(f"✓ Vectors are moderately aligned in causal space (cos={causal_cosine_similarity:.3f})")
    print(f"  Some linear structure exists under the causal metric")
elif abs(causal_cosine_similarity) > 0.3:
    print(f"⚠️  Vectors are weakly aligned in causal space (cos={causal_cosine_similarity:.3f})")
    print(f"   Minimal linear relationship under the causal metric")
else:
    print(f"❌ Vectors are nearly ORTHOGONAL in causal space (cos={causal_cosine_similarity:.3f}, angle={causal_angle_deg:.1f}°)")
    print(f"   No meaningful linear relationship under the causal metric")
    print(f"   Word2vec arithmetic will NOT work in causal geometry")

print(f"\nFor reference:")
print(f"  • Parallel vectors: cos=1.0, angle=0°")
print(f"  • Orthogonal vectors: cos=0.0, angle=90°")
print(f"  • Opposite vectors: cos=-1.0, angle=180°")


DIRECT VECTOR COMPARISON (CAUSAL METRIC)

(woman - man) properties:
  Causal norm: 60.7513 logometers

(queen - king) properties:
  Causal norm: 68.3953 logometers

Alignment between the two vectors (causal metric):
  Causal dot product: 253.7651
  Causal cosine similarity: 0.0611
  Causal angle: 86.50°

INTERPRETATION:
❌ Vectors are nearly ORTHOGONAL in causal space (cos=0.061, angle=86.5°)
   No meaningful linear relationship under the causal metric
   Word2vec arithmetic will NOT work in causal geometry

For reference:
  • Parallel vectors: cos=1.0, angle=0°
  • Orthogonal vectors: cos=0.0, angle=90°
  • Opposite vectors: cos=-1.0, angle=180°


## Precompute Causal Norms for All Tokens

We need these for cosine similarity calculations.

In [27]:
print("\nPrecomputing causal norms for all tokens...")
print("  (This may take a minute for 151k tokens...)\n")

# ||v||_M = sqrt(v^T M v), computed row-wise for every token
causal_norms = torch.sqrt(torch.sum((gamma @ M) * gamma, dim=1)).numpy()

print(f"✓ Causal norms computed")
print(f"  Range: [{causal_norms.min():.2f}, {causal_norms.max():.2f}] logometers")
print(f"  Mean: {causal_norms.mean():.2f} logometers")


Precomputing causal norms for all tokens...
  (This may take a minute for 151k tokens...)

✓ Causal norms computed
  Range: [21.35, 85.29] logometers
  Mean: 54.13 logometers
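
The causal norm and cosine can also be read as ordinary Euclidean quantities after a change of coordinates. If M is symmetric positive definite (which the square roots in this notebook already presuppose) and M = L Lᵀ is its Cholesky factorization, then uᵀ M v = (Lᵀu) · (Lᵀv). The sketch below checks this on random toy data; all names and sizes are hypothetical stand-ins, not the notebook's tensors.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins (hypothetical sizes and values, not the notebook's data)
A = rng.normal(size=(5, 5))
M_toy = A @ A.T + 5.0 * np.eye(5)      # symmetric positive definite by construction
gamma_toy = rng.normal(size=(8, 5))    # 8 "tokens", 5 dimensions
query = rng.normal(size=5)

# Factor M = L L^T, then map every vector g to L^T g
L = np.linalg.cholesky(M_toy)
gamma_w = gamma_toy @ L                # row i holds (L^T g_i)
query_w = query @ L

# Causal norms and cosines computed directly with M
norms_direct = np.sqrt(np.sum((gamma_toy @ M_toy) * gamma_toy, axis=1))
query_norm = np.sqrt(query @ M_toy @ query)
cos_direct = (gamma_toy @ M_toy @ query) / (norms_direct * query_norm)

# The same quantities as plain Euclidean operations on the mapped vectors
norms_w = np.linalg.norm(gamma_w, axis=1)
cos_w = (gamma_w @ query_w) / (norms_w * np.linalg.norm(query_w))
```

One practical consequence is that the mapped matrix can be computed once and reused for every query, so later searches need no further products with M.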


## Testing Scaled Vector Arithmetic (Causal Metric)

**Question:** Is the failure due to **magnitude** or **direction** in causal space?

**Hypothesis:** Maybe `queen = king + α * (woman - man)` for some optimal α ≠ 1 (measured with M)

**Method:** Solve for optimal α via least-squares projection **in causal metric**:

```
α = (queen - king)^T M (woman - man) / [(woman - man)^T M (woman - man)]
```

This finds the best scaling in the **causal geometry**, not Euclidean. The nearest-neighbor searches run first; the scaled test itself runs after the 'queen' check.

## Find Nearest Neighbors by Causal Distance

In [28]:
print("\nComputing causal distances to ALL tokens in vocabulary...\n")
print("  (This may take a few minutes for 151k tokens...)\n")

# Compute causal distances: d_M(u, v) = sqrt((u-v)^T M (u-v))
vocab_size = gamma.shape[0]
causal_distances = np.zeros(vocab_size)

for i in range(vocab_size):
    if i % 10000 == 0:
        print(f"  Progress: {i:,} / {vocab_size:,} tokens...")
    diff = gamma[i] - v_synthetic
    causal_distances[i] = torch.sqrt(diff @ M @ diff).item()

print(f"\n✓ Causal distances computed")

# Sort by distance (ascending)
sorted_indices = np.argsort(causal_distances)
top_indices = sorted_indices[:TOP_N]

print(f"\nTop {TOP_N} tokens by CAUSAL DISTANCE:")
print(f"{'Rank':<6} {'Distance':<12} {'Token ID':<10} {'Text':<40}")
print("=" * 80)

for rank, idx in enumerate(top_indices, 1):
    dist = causal_distances[idx]
    text = tokenizer.decode([idx])
    print(f"{rank:<6} {dist:<12.2f} {idx:<10} {text:<40}")

print("\n" + "=" * 80)


Computing causal distances to ALL tokens in vocabulary...

  (This may take a few minutes for 151k tokens...)

  Progress: 0 / 151,936 tokens...
  Progress: 10,000 / 151,936 tokens...
  Progress: 20,000 / 151,936 tokens...
  Progress: 30,000 / 151,936 tokens...
  Progress: 40,000 / 151,936 tokens...
  Progress: 50,000 / 151,936 tokens...
  Progress: 60,000 / 151,936 tokens...
  Progress: 70,000 / 151,936 tokens...
  Progress: 80,000 / 151,936 tokens...
  Progress: 90,000 / 151,936 tokens...
  Progress: 100,000 / 151,936 tokens...
  Progress: 110,000 / 151,936 tokens...
  Progress: 120,000 / 151,936 tokens...
  Progress: 130,000 / 151,936 tokens...
  Progress: 140,000 / 151,936 tokens...
  Progress: 150,000 / 151,936 tokens...

✓ Causal distances computed

Top 20 tokens by CAUSAL DISTANCE:
Rank   Distance     Token ID   Text                                    
1      60.75        10566      king                                    
2      79.26        22028      woman                   
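
The per-token loop in the cell above can be replaced by one broadcasted expression, mirroring the row-wise formula already used for the causal norms. The sketch below is written in NumPy on toy data so that it is self-contained; the function names are hypothetical. For the real matrices, the intermediate `diffs` array has the same size as `gamma` (about 1.56 GB in float32), so processing in chunks may be needed on a memory-constrained machine.

```python
import numpy as np

def causal_distances_vectorized(query, gamma, M):
    """d_M(g_i, query) for every row g_i of gamma, without a Python loop."""
    diffs = gamma - query                         # broadcasts query across rows
    return np.sqrt(np.sum((diffs @ M) * diffs, axis=1))

def causal_distances_loop(query, gamma, M):
    """Reference version: one row at a time, as in the notebook cell."""
    out = np.zeros(gamma.shape[0])
    for i in range(gamma.shape[0]):
        d = gamma[i] - query
        out[i] = np.sqrt(d @ M @ d)
    return out

# Toy data (hypothetical sizes): 50 "tokens" in 6 dimensions
rng = np.random.default_rng(1)
A = rng.normal(size=(6, 6))
M_toy = A @ A.T + np.eye(6)                       # symmetric positive definite
gamma_toy = rng.normal(size=(50, 6))
query = rng.normal(size=6)

fast = causal_distances_vectorized(query, gamma_toy, M_toy)
slow = causal_distances_loop(query, gamma_toy, M_toy)
```

Both versions return the same distances; the vectorized one trades memory for speed.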

## Find Nearest Neighbors by Causal Cosine Similarity

In [29]:
print("\nComputing causal cosine similarities to ALL tokens...\n")

# Compute causal cosine: cos(θ_M) = (u^T M v) / (||u||_M · ||v||_M)
causal_dot_products = (gamma @ M @ v_synthetic).numpy()  # [vocab_size]
causal_cosines = causal_dot_products / (causal_norms * causal_norm_synthetic)

# Sort by cosine (descending)
sorted_indices_cos = np.argsort(-causal_cosines)
top_indices_cos = sorted_indices_cos[:TOP_N]

print(f"Top {TOP_N} tokens by CAUSAL COSINE SIMILARITY:")
print(f"{'Rank':<6} {'Cosine':<12} {'Distance':<12} {'Token ID':<10} {'Text':<40}")
print("=" * 80)

for rank, idx in enumerate(top_indices_cos, 1):
    cos_sim = causal_cosines[idx]
    dist = causal_distances[idx]
    text = tokenizer.decode([idx])
    print(f"{rank:<6} {cos_sim:<12.4f} {dist:<12.2f} {idx:<10} {text:<40}")

print("\n" + "=" * 80)


Computing causal cosine similarities to ALL tokens...

Top 20 tokens by CAUSAL COSINE SIMILARITY:
Rank   Cosine       Distance     Token ID   Text                                    
1      0.7139       60.75        10566      king                                    
2      0.4474       79.26        22028      woman                                   
3      0.4105       81.09        33555      King                                    
4      0.3983       81.89        11477       king                                   
5      0.3612       83.86        6210        King                                   
6      0.3440       84.82        64662      women                                   
7      0.3440       85.14        73811       KING                                   
8      0.3278       85.43        27906       queen                                  
9      0.3227       84.65        3198        women                                  
10     0.3176       86.02        25079       kingdo
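
In both rankings the first two entries are the input words themselves: 'king' at rank 1 and 'woman' at rank 2. This is expected from the construction, since `v_synthetic - king = woman - man`, whose causal norm (60.75) is exactly the rank-1 distance. A common convention in word-analogy evaluation is to drop the query words from the candidate pool before ranking. The sketch below shows one way to compute a rank under that convention; `rank_excluding` is a hypothetical helper and the scores are toy values, not notebook results.

```python
import numpy as np

def rank_excluding(scores, target_id, exclude_ids, descending=True):
    """1-based rank of target_id after removing exclude_ids from the candidates."""
    scores = np.asarray(scores, dtype=float)
    keep = np.ones(scores.shape[0], dtype=bool)
    keep[np.asarray(list(exclude_ids), dtype=int)] = False
    if not keep[target_id]:
        raise ValueError("target_id must not be in exclude_ids")
    target = scores[target_id]
    kept = scores[keep]
    # Count candidates that strictly beat the target
    better = kept > target if descending else kept < target
    return int(np.sum(better)) + 1

# Toy cosines: ids 0 and 1 play the role of query words, id 4 is the target
toy_cosines = np.array([0.71, 0.45, 0.10, 0.33, 0.29])

rank_all = rank_excluding(toy_cosines, target_id=4, exclude_ids=[])
rank_filtered = rank_excluding(toy_cosines, target_id=4, exclude_ids=[0, 1])
```

For distances, pass `descending=False` so that smaller values rank higher. In the notebook's setting the excluded IDs would be `TOKENS['man']`, `TOKENS['woman']`, and `TOKENS['king']`.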

## Check for 'queen' Specifically

In [30]:
if TOKENS['queen'] is not None:
    idx_queen = TOKENS['queen']
    
    dist_to_queen = causal_distances[idx_queen]
    cos_to_queen = causal_cosines[idx_queen]
    
    # Find rank
    rank_by_distance = np.where(sorted_indices == idx_queen)[0][0] + 1
    rank_by_cosine = np.where(sorted_indices_cos == idx_queen)[0][0] + 1
    
    print("\n" + "=" * 80)
    print(f"CHECKING FOR 'queen' (token {idx_queen}):")
    print("=" * 80)
    print(f"\nCausal distance: {dist_to_queen:.2f} logometers (rank {rank_by_distance})")
    print(f"Causal cosine: {cos_to_queen:.4f} (rank {rank_by_cosine})")
    
    if rank_by_distance <= 5:
        print(f"\n🎉🎉🎉 'queen' is in the TOP 5 by causal distance!")
        print(f"         WORD2VEC MAGIC WORKS IN CAUSAL SPACE!")
    elif rank_by_distance <= 10:
        print(f"\n🎉 'queen' is in the TOP 10 by causal distance!")
        print(f"    Linear semantics confirmed in causal geometry!")
    elif rank_by_distance <= 20:
        print(f"\n✓ 'queen' is in the top 20 by causal distance")
        print(f"  Moderate evidence for linear semantics in causal space")
    elif rank_by_distance <= 100:
        print(f"\n'queen' is rank {rank_by_distance} by causal distance")
        print(f"  Weak signal, but detectable")
    else:
        print(f"\n❌ 'queen' is rank {rank_by_distance} by causal distance")
        print(f"   Word2vec arithmetic does NOT work in causal space")
    
    if rank_by_cosine <= 5:
        print(f"\n🎉🎉🎉 'queen' is in the TOP 5 by causal cosine!")
    elif rank_by_cosine <= 10:
        print(f"🎉 'queen' is in the TOP 10 by causal cosine!")
    elif rank_by_cosine <= 20:
        print(f"✓ 'queen' is in the top 20 by causal cosine")
    elif rank_by_cosine <= 100:
        print(f"'queen' is rank {rank_by_cosine} by causal cosine")
    else:
        print(f"❌ 'queen' is rank {rank_by_cosine} by causal cosine")
else:
    print("\n⚠️  Token 'queen' not found or not a single token")


CHECKING FOR 'queen' (token 93114):

Causal distance: 88.66 logometers (rank 5860)
Causal cosine: 0.2876 (rank 16)

❌ 'queen' is rank 5860 by causal distance
   Word2vec arithmetic does NOT work in causal space
✓ 'queen' is in the top 20 by causal cosine


In [31]:
if TOKENS['queen'] is not None:
    print("\n" + "=" * 80)
    print("TESTING SCALED VECTOR ARITHMETIC (CAUSAL METRIC)")
    print("=" * 80)
    
    # Get queen embedding
    v_queen = gamma[TOKENS['queen']]
    
    # Compute the gender shift vector
    v_gender_shift = v_woman - v_man
    
    # Compute the target displacement
    v_target_displacement = v_queen - v_king
    
    # Solve for α via least-squares projection in CAUSAL metric
    # α = (target^T M shift) / (shift^T M shift)
    numerator = (v_target_displacement @ M @ v_gender_shift).item()
    denominator = (v_gender_shift @ M @ v_gender_shift).item()
    alpha_optimal = numerator / denominator
    
    print(f"\nOptimal scaling factor α (causal) = {alpha_optimal:.6f}")
    print(f"\nInterpretation:")
    if abs(alpha_optimal) < 0.01:
        print(f"  α ≈ 0 → Gender shift is ORTHOGONAL to king→queen in causal space")
        print(f"  The vectors have no meaningful relationship")
    elif 0.8 <= alpha_optimal <= 1.2:
        print(f"  α ≈ 1 → Gender shift aligns well with king→queen in causal geometry!")
        print(f"  Word2vec arithmetic works (just needed right magnitude)")
    else:
        print(f"  α = {alpha_optimal:.2f} → Vectors are aligned but need rescaling")
        print(f"  Partial compositional structure in causal space")
    
    # Construct the scaled synthetic vector
    v_scaled = v_king + alpha_optimal * v_gender_shift
    
    # === CAUSAL DISTANCE ===
    # Unscaled synthetic to queen (already computed)
    causal_dist_synthetic_to_queen = causal_distances[TOKENS['queen']]
    
    # Scaled synthetic to queen
    diff_scaled_queen = v_scaled - v_queen
    causal_dist_scaled_to_queen = torch.sqrt(diff_scaled_queen @ M @ diff_scaled_queen).item()
    
    # === CAUSAL ANGLE ===
    # Compute causal norms
    causal_norm_synthetic = torch.sqrt(v_synthetic @ M @ v_synthetic).item()
    causal_norm_scaled = torch.sqrt(v_scaled @ M @ v_scaled).item()
    causal_norm_queen = causal_norms[TOKENS['queen']]
    
    # Compute causal cosine and convert to degrees
    causal_dot_synthetic_queen = (v_synthetic @ M @ v_queen).item()
    cos_causal_synthetic_queen = causal_dot_synthetic_queen / (causal_norm_synthetic * causal_norm_queen)
    angle_causal_synthetic_queen_rad = np.arccos(np.clip(cos_causal_synthetic_queen, -1.0, 1.0))
    angle_causal_synthetic_queen_deg = np.degrees(angle_causal_synthetic_queen_rad)
    
    causal_dot_scaled_queen = (v_scaled @ M @ v_queen).item()
    cos_causal_scaled_queen = causal_dot_scaled_queen / (causal_norm_scaled * causal_norm_queen)
    angle_causal_scaled_queen_rad = np.arccos(np.clip(cos_causal_scaled_queen, -1.0, 1.0))
    angle_causal_scaled_queen_deg = np.degrees(angle_causal_scaled_queen_rad)
    
    print(f"\n{'Method':<30} {'Distance (logo)':<20} {'Angle (degrees)':<15}")
    print("=" * 65)
    print(f"{'Original (α=1.0)':<30} {causal_dist_synthetic_to_queen:<20.4f} {angle_causal_synthetic_queen_deg:<15.2f}")
    print(f"{'Scaled (α=' + f'{alpha_optimal:.4f})':<30} {causal_dist_scaled_to_queen:<20.4f} {angle_causal_scaled_queen_deg:<15.2f}")
    
    dist_improvement = (1 - causal_dist_scaled_to_queen / causal_dist_synthetic_to_queen) * 100
    angle_improvement = (1 - angle_causal_scaled_queen_deg / angle_causal_synthetic_queen_deg) * 100
    
    print(f"\nDistance improvement: {dist_improvement:.1f}%")
    print(f"Angular improvement: {angle_improvement:.1f}%")
    
    if dist_improvement > 50:
        print("  → LARGE distance improvement! The issue was magnitude, not direction")
    elif dist_improvement > 10:
        print("  → Moderate distance improvement. Direction is partially correct")
    else:
        print("  → Minimal distance improvement. The vectors are not aligned in causal space")


TESTING SCALED VECTOR ARITHMETIC (CAUSAL METRIC)

Optimal scaling factor α (causal) = 0.068758

Interpretation:
  α = 0.07 → Vectors are aligned but need rescaling
  Partial compositional structure in causal space

Method                         Distance (logo)      Angle (degrees)
Original (α=1.0)               88.6629              73.29          
Scaled (α=0.0688)              68.2676              70.20          

Distance improvement: 23.0%
Angular improvement: 4.2%
  → Moderate distance improvement. Direction is partially correct
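
The printed interpretation needs care. The least-squares α equals the causal cosine between the two displacement vectors times the ratio of their causal norms, so α ≈ 0.069 is the near-orthogonality from the direct comparison (cos ≈ 0.061) restated, not evidence of alignment. The "aligned but need rescaling" branch fires only because α exceeds the 0.01 threshold. Likewise, with α this small the scaled vector is almost `king` itself, and `king` alone is already 68.3953 logometers from 'queen' (the causal norm of `queen - king`). The check below recomputes these quantities from the values printed earlier in the notebook, using only the law of cosines; the variable names are ours.

```python
import math

# Values copied from the notebook outputs above
dot = 253.7651           # (woman - man)^T M (queen - king)
norm_gender = 60.7513    # ||woman - man||_M
norm_royalty = 68.3953   # ||queen - king||_M

cos = dot / (norm_gender * norm_royalty)
alpha = dot / norm_gender ** 2                 # least-squares scaling factor
alpha_from_cos = cos * norm_royalty / norm_gender

# Distance from king + alpha * (woman - man) to queen: the part of
# (queen - king) that is M-orthogonal to (woman - man)
d_scaled = norm_royalty * math.sqrt(1.0 - cos ** 2)

# Distance from king + 1.0 * (woman - man) to queen
d_unscaled = math.sqrt(norm_gender ** 2 + norm_royalty ** 2 - 2.0 * dot)

print(f"alpha      = {alpha:.6f}")
print(f"d_scaled   = {d_scaled:.4f}")
print(f"d_unscaled = {d_unscaled:.4f}")
```

Up to rounding of the inputs, this reproduces the reported α (0.068758) and both distances (68.2676 and 88.6629). Relative to using `king` alone, the optimal scaling shortens the distance to 'queen' only from 68.3953 to 68.2676, so most of the 23.0% improvement comes from removing the gender shift rather than from rescaling it.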


## Overlap Analysis

In [32]:
print("\n" + "=" * 80)
print("OVERLAP ANALYSIS")
print("=" * 80)

# Find tokens that appear in BOTH top N lists
top_by_distance = set(sorted_indices[:TOP_N])
top_by_cosine = set(sorted_indices_cos[:TOP_N])

overlap = top_by_distance & top_by_cosine

print(f"\nTokens in top {TOP_N} by BOTH metrics:")
print(f"  Count: {len(overlap)}")

if overlap:
    print(f"\nTokens appearing in both lists:")
    for tid in sorted(overlap):
        text = tokenizer.decode([tid])
        dist = causal_distances[tid]
        cos = causal_cosines[tid]
        print(f"  • {tid}: '{text}' (distance={dist:.2f}, cosine={cos:.4f})")


OVERLAP ANALYSIS

Tokens in top 20 by BOTH metrics:
  Count: 5

Tokens appearing in both lists:
  • 6210: ' King' (distance=83.86, cosine=0.3612)
  • 10566: 'king' (distance=60.75, cosine=0.7139)
  • 11477: ' king' (distance=81.89, cosine=0.3983)
  • 22028: 'woman' (distance=79.26, cosine=0.4474)
  • 33555: 'King' (distance=81.09, cosine=0.4105)


## Summary

**What we tested:** The classic word2vec arithmetic `'woman' - 'man' + 'king' ≈ 'queen'` in **causal metric geometry**.

**What we found:** The displacement vectors `woman - man` and `queen - king` are nearly orthogonal under M (causal cosine 0.061, angle 86.5°). 'queen' ranks 5860 by causal distance and 16 by causal cosine, and the top two neighbors under both measures are the input tokens 'king' and 'woman'. Rescaling the gender shift (α ≈ 0.069) reduces the causal distance to 'queen' by 23.0% but the angle by only 4.2%.

**Why this matters:** 
- If it works here but not in Euclidean (07.59a), the causal metric **reveals** hidden semantic structure
- If it works in Euclidean but not here, the causal metric **destroys** linear semantics
- If it works in both, the causal metric **preserves** compositional structure
- If it works in neither, unembedding space lacks linear semantic relationships

**Comparison:** Check 07.59a results to determine which scenario holds.