# 1.5a: Logit Equivalence Classes

This notebook identifies tokens that produce **identical logit scores** for a given hidden state due to bfloat16 quantization.

## The Question

We know from 1.4a that 2,251 tokens share an identical **cosine similarity** (0.84375) with a reference direction. But cosine similarity captures only direction, not magnitude.

**Logits are computed via dot products:**
```
logit[token_i] = hidden_state @ W[i]
```

The dot product depends on both direction AND magnitude. Because bfloat16 keeps only 8 significant bits, tokens with **different embeddings** may produce **identical dot products** once the result is rounded to bfloat16.

This notebook finds the **logit equivalence classes**: sets of tokens that are indistinguishable by dot product at bfloat16 precision.

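To make the rounding concrete, here is a minimal pure-Python sketch (standard library only) of float-to-bfloat16 conversion. It assumes round-to-nearest-even on the float32 bit pattern, the usual rounding mode for this conversion; NaN and infinity handling is omitted, and the helper `to_bf16` and the toy vectors are hypothetical. Two different dot products land on the same bfloat16 value:

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a Python float.

    bfloat16 is the upper 16 bits of an IEEE-754 float32, so we take the float32
    bit pattern and round away the lower 16 bits. NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                      # lowest kept bit, for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Hypothetical 2-D "hidden state" and two different "embeddings"
h = (0.3, 0.2)
w1 = (0.2, 0.28)
w2 = (0.3, 0.131)

dot1 = sum(a * b for a, b in zip(w1, h))   # about 0.1160
dot2 = sum(a * b for a, b in zip(w2, h))   # about 0.1162

print(dot1 == dot2)                    # False: the exact dot products differ
print(to_bf16(dot1) == to_bf16(dot2))  # True: both round to 238 * 2**-11
print(to_bf16(dot1))                   # 0.1162109375
```

Any exact dot product in roughly (0.11597, 0.11646) rounds to the same bfloat16 value, so the "logit" no longer tells the two embeddings apart.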
## Approach

1. Define a simulated hidden state `h` using spherical coordinates (lat/lon/distance)
2. Compute dot products `W @ h` in bfloat16
3. Find tokens with identical dot product values
4. Report the largest equivalence class (mode of dot product distribution)
5. Compare to embedding-based equivalence (tokens with literally identical W[i])
6. Compare to the cosine-based cluster from 1.4a

## Parameters

In [48]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

# PCA basis (same as 1.4a)
NORTH_PC = 2      # North pole (+90° latitude)
MERIDIAN_PC = 1   # Prime meridian (0° longitude)
EQUINOX_PC = 3    # Equinox (+90° longitude)

# Simulated hidden state (spherical coordinates)
REF_LAT = -7.288   # Latitude (degrees) - pointing at cluster
REF_LON = 6.941    # Longitude (degrees)
REF_DIST = None    # Distance from origin (gamma units) - None = use cluster mean norm

## Imports

In [49]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path

## Device Detection

In [50]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load W

In [51]:
# Load W in bfloat16
tensor_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(tensor_path)["W"]

print(f"Loaded W from {tensor_path}")
print(f"  Shape: {W_bf16.shape}")
print(f"  Dtype: {W_bf16.dtype}")

N, d = W_bf16.shape
print(f"\nToken space: {N:,} tokens in {d:,} dimensions")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16

Token space: 151,936 tokens in 2,560 dimensions


## Compute PCA Basis

We need the PCA basis to convert lat/lon/distance to a 2560D vector.

In [52]:
print("Computing PCA...")

# Work in float32 for PCA (more stable)
W = W_bf16.to(torch.float32)

# Center the data
W_centered = W - W.mean(dim=0)

# Compute covariance matrix
print("  Computing covariance matrix...")
cov = (W_centered.T @ W_centered) / N

# Eigendecomposition
print("  Computing eigendecomposition...")
eigenvalues, eigenvectors = torch.linalg.eigh(cov)

# Sort by descending eigenvalue
idx = torch.argsort(eigenvalues, descending=True)
eigenvalues = eigenvalues[idx]
eigenvectors = eigenvectors[:, idx]

print(f"\n✓ PCA computed")

Computing PCA...
  Computing covariance matrix...
  Computing eigendecomposition...

✓ PCA computed

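As a sanity check on this cell's approach, the eigendecomposition of the covariance matrix should agree with an SVD of the centered data: the eigenvalues equal the squared singular values divided by N, and the eigenvectors match the right singular vectors up to sign. The NumPy sketch below uses a small random stand-in for W (an assumption for illustration; it does not load the real matrix). The sign ambiguity matters for this notebook: each eigenvector is only defined up to sign, so a given lat/lon names a fixed direction only if the PC signs are reproduced exactly as in 1.4a.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 500, 8                                # small stand-in for (151936, 2560)
W = rng.standard_normal((N, d)) * np.linspace(3.0, 0.5, d)  # distinct variance per axis

W_centered = W - W.mean(axis=0)

# Route 1: eigendecomposition of the covariance matrix (as in the notebook)
cov = W_centered.T @ W_centered / N
evals, evecs = np.linalg.eigh(cov)           # ascending order
order = np.argsort(evals)[::-1]              # sort descending
evals, evecs = evals[order], evecs[:, order]

# Route 2: SVD of the centered data
_, S, Vt = np.linalg.svd(W_centered, full_matrices=False)

print(np.allclose(evals, S**2 / N))          # True: same variances
# Eigenvectors agree with right singular vectors up to a per-column sign
signs = np.sign(np.sum(evecs * Vt.T, axis=0))
print(np.allclose(evecs, Vt.T * signs))      # True
```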

## Define Simulated Hidden State

Convert lat/lon/distance to a 2560D vector representing our simulated hidden state.

In [53]:
def get_pc_vector(pcs, index):
    """Get PC vector by index (1-indexed), with sign flip for negative indices."""
    pc_num = abs(index) - 1
    vector = pcs[:, pc_num].clone()
    if index < 0:
        vector = -vector
    return vector


# Extract basis vectors
north = get_pc_vector(eigenvectors, NORTH_PC)
meridian = get_pc_vector(eigenvectors, MERIDIAN_PC)
equinox = get_pc_vector(eigenvectors, EQUINOX_PC)

print("Spherical coordinate basis:")
print(f"  North (+Z):    PC{NORTH_PC}")
print(f"  Meridian (+X): PC{MERIDIAN_PC}")
print(f"  Equinox (+Y):  PC{EQUINOX_PC}")
print()

# Convert lat/long to Cartesian coordinates on unit sphere
lat_rad = np.deg2rad(REF_LAT)
lon_rad = np.deg2rad(REF_LON)

# Spherical to Cartesian: (r=1, lat, lon) -> (x, y, z)
x = np.cos(lat_rad) * np.cos(lon_rad)  # Meridian component
y = np.cos(lat_rad) * np.sin(lon_rad)  # Equinox component
z = np.sin(lat_rad)                     # North component

print(f"Direction (unit sphere):")
print(f"  Lat: {REF_LAT:.3f}°, Lon: {REF_LON:.3f}°")
print(f"  Cartesian: x={x:.6f}, y={y:.6f}, z={z:.6f}")
print()

# Convert to direction in full 2560D space
direction = x * meridian + y * equinox + z * north
direction = direction / direction.norm()  # Normalize

# Determine distance (default to cluster mean norm if not specified)
if REF_DIST is None:
    # Load cluster membership to get mean norm
    cluster_path = Path(f"../tensors/{MODEL_NAME}/1.4a_cluster_members.safetensors")
    cluster_data = load_file(cluster_path)
    cluster_token_ids = cluster_data["cluster_token_ids"].long()
    
    # Compute mean norm of cluster embeddings
    cluster_norms = W_bf16[cluster_token_ids].norm(dim=1)
    ref_dist = cluster_norms.mean().item()
    print(f"Distance: {ref_dist:.6f} (cluster mean norm)")
else:
    ref_dist = REF_DIST
    print(f"Distance: {ref_dist:.6f} (user specified)")

# Construct hidden state
h = direction * ref_dist

print(f"\nSimulated hidden state in {d}D space:")
print(f"  Norm: {h.norm().item():.6f}")
print(f"  Direction: lat={REF_LAT:.3f}°, lon={REF_LON:.3f}°")

Spherical coordinate basis:
  North (+Z):    PC2
  Meridian (+X): PC1
  Equinox (+Y):  PC3

Direction (unit sphere):
  Lat: -7.288°, Lon: 6.941°
  Cartesian: x=0.984651, y=0.119871, z=-0.126857

Distance: 0.371094 (cluster mean norm)

Simulated hidden state in 2560D space:
  Norm: 0.371094
  Direction: lat=-7.288°, lon=6.941°

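The lat/lon conversion can be checked in isolation. This NumPy sketch uses a random orthonormal 3-column basis as a stand-in for (meridian, equinox, north) (an assumption; the notebook uses PC1, PC3 and PC2 of W), builds `h` the same way as the cell above, and recovers latitude (arcsin of the north component) and longitude (atan2 of the equinox and meridian components) by projecting back. The helpers `sph_to_vec` and `vec_to_sph` are hypothetical. Because eigenvectors are orthonormal, x² + y² + z² = 1 already holds, so the cell's extra `direction / direction.norm()` step only removes rounding error.

```python
import numpy as np


def sph_to_vec(lat_deg, lon_deg, dist, meridian, equinox, north):
    """Map (lat, lon, distance) to a vector in the span of three orthonormal axes."""
    lat, lon = np.deg2rad(lat_deg), np.deg2rad(lon_deg)
    x = np.cos(lat) * np.cos(lon)   # meridian component
    y = np.cos(lat) * np.sin(lon)   # equinox component
    z = np.sin(lat)                 # north component
    return dist * (x * meridian + y * equinox + z * north)


def vec_to_sph(v, meridian, equinox, north):
    """Inverse map: project v onto the three axes and read off lat, lon, distance."""
    x, y, z = v @ meridian, v @ equinox, v @ north
    dist = np.sqrt(x**2 + y**2 + z**2)
    lat = np.rad2deg(np.arcsin(z / dist))
    lon = np.rad2deg(np.arctan2(y, x))
    return lat, lon, dist


# Random orthonormal stand-in basis in a 2560-D space
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((2560, 3)))
meridian, equinox, north = Q[:, 0], Q[:, 1], Q[:, 2]

h = sph_to_vec(-7.288, 6.941, 0.371094, meridian, equinox, north)
print(vec_to_sph(h, meridian, equinox, north))   # approximately (-7.288, 6.941, 0.371094)
```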

## Compute Dot Products in Bfloat16

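This is the key step: compute `W @ h` with both operands in bfloat16, so the logits come out rounded to bfloat16 as in the model's logit computation. (The precision used internally to accumulate the 2,560 products depends on the backend.)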

In [54]:
print("Computing dot products in bfloat16...\n")

# Convert h to bfloat16
h_bf16 = h.to(torch.bfloat16).to(device)

# Move W to device
W_bf16_device = W_bf16.to(device)

# Compute dot products in bfloat16
with torch.no_grad():
    dot_products_bf16 = W_bf16_device @ h_bf16

# Move to CPU for analysis
dot_products_bf16_cpu = dot_products_bf16.cpu()

print(f"✓ Computed dot products for {N:,} tokens")
print()
print(f"Dot product distribution (bfloat16):")
print(f"  Range: [{dot_products_bf16_cpu.min():.8f}, {dot_products_bf16_cpu.max():.8f}]")
print(f"  Mean: {dot_products_bf16_cpu.mean():.8f}")
print(f"  Median: {dot_products_bf16_cpu.median():.8f}")

Computing dot products in bfloat16...

✓ Computed dot products for 151,936 tokens

Dot product distribution (bfloat16):
  Range: [-0.28320312, 0.25781250]
  Mean: 0.09179688
  Median: 0.09912109

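Whether ties appear can depend on how the backend accumulates the products inside `W @ h`. It is common for bfloat16 matrix-multiply kernels to accumulate in float32 and round once at the end, but this is backend-specific and has not been checked here for MPS. The pure-Python sketch below, with a hypothetical `to_bf16` helper that emulates round-to-nearest-even, shows how far the two strategies can diverge on a contrived sum (Python floats stand in for the wider accumulator):

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


terms = [1.0] + [2.0**-9] * 256   # one large term followed by many small ones

# Strategy 1: round the running sum to bfloat16 after every addition
acc_bf16 = 0.0
for t in terms:
    acc_bf16 = to_bf16(acc_bf16 + t)

# Strategy 2: accumulate in higher precision, round once at the end
acc_wide = to_bf16(sum(terms))

print(acc_bf16)   # 1.0: each 2**-9 is below half the bfloat16 spacing at 1.0 (2**-7)
print(acc_wide)   # 1.5
```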

## Find Logit Equivalence Classes

Find tokens with **identical dot product values** at bfloat16 precision.

In [55]:
print("Identifying logit equivalence classes...\n")

# Find the mode (most common dot product value)
unique_dots, counts = torch.unique(dot_products_bf16_cpu, return_counts=True)
mode_idx = counts.argmax()
mode_dot = unique_dots[mode_idx].item()
mode_count = counts[mode_idx].item()

print(f"Mode dot product: {mode_dot:.8f}")
print(f"Mode count: {mode_count:,} tokens")
print()

# Find all tokens with this dot product
equiv_mask = (dot_products_bf16_cpu == mode_dot)
equiv_indices = equiv_mask.nonzero(as_tuple=True)[0]
n_equiv = equiv_indices.numel()

print(f"✓ Found {n_equiv:,} tokens with dot product = {mode_dot:.8f}")
print(f"  ({n_equiv / N * 100:.2f}% of vocabulary)")
print()
print(f"Token IDs (first 20): {equiv_indices[:20].tolist()}")
print()

# Also report maximum dot product for comparison
max_dot = dot_products_bf16_cpu.max().item()
max_count = (dot_products_bf16_cpu == max_dot).sum().item()
print(f"For comparison:")
print(f"  Maximum dot product: {max_dot:.8f} ({max_count:,} tokens)")

Identifying logit equivalence classes...

Mode dot product: 0.11621094
Mode count: 3,281 tokens

✓ Found 3,281 tokens with dot product = 0.11621094
  (2.16% of vocabulary)

Token IDs (first 20): [124, 125, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 3402, 4902, 7392, 7550, 9083, 10092, 12813]

For comparison:
  Maximum dot product: 0.25781250 (1 tokens)

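The cell above reports only the largest class. A sketch that partitions the whole vocabulary into logit equivalence classes at once, using `np.unique(..., return_inverse=True, return_counts=True)`, is shown below on a toy array of logits (the data and the helper name `equivalence_classes` are hypothetical). Applied to this notebook, the input would be `dot_products_bf16_cpu.float().numpy()`; NumPy has no bfloat16 dtype, and the bfloat16-to-float32 upcast is exact, so it neither creates nor breaks ties.

```python
import numpy as np


def equivalence_classes(values):
    """Group indices by exactly equal value; return {value: index array} for classes of size > 1."""
    uniq, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    inverse = inverse.ravel()                     # guard against NumPy 2.0 shape changes
    order = np.argsort(inverse, kind="stable")    # indices grouped by class, ascending within class
    groups = np.split(order, np.cumsum(counts)[:-1])
    return {float(u): g for u, g, c in zip(uniq, groups, counts) if c > 1}


logits = np.array([0.5, 0.25, 0.5, 0.125, 0.25, 0.5], dtype=np.float32)
classes = equivalence_classes(logits)
for value, members in sorted(classes.items(), key=lambda kv: -len(kv[1])):
    print(value, members.tolist())
# 0.5 [0, 2, 5]
# 0.25 [1, 4]
```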

## Distribution of Unique Dot Products

How many unique dot product values exist at bfloat16 precision?

In [56]:
print("Analyzing unique dot product values...\n")

unique_dots, counts = torch.unique(dot_products_bf16_cpu, return_counts=True)
n_unique = unique_dots.numel()

print(f"Total unique dot product values: {n_unique:,}")
print(f"  ({n_unique / N * 100:.2f}% of vocabulary)")
print()

# Sort by count (descending)
sorted_indices = torch.argsort(counts, descending=True)
top_dots = unique_dots[sorted_indices[:10]]
top_counts = counts[sorted_indices[:10]]

print(f"Top 10 most common dot product values:")
for i in range(10):
    dot_val = top_dots[i].item()
    count = top_counts[i].item()
    print(f"  {i+1}. dot={dot_val:.8f}: {count:,} tokens")

Analyzing unique dot product values...

Total unique dot product values: 2,033
  (1.34% of vocabulary)

Top 10 most common dot product values:
  1. dot=0.11621094: 3,281 tokens
  2. dot=0.12597656: 1,410 tokens
  3. dot=0.12792969: 1,395 tokens
  4. dot=0.12695312: 1,325 tokens
  5. dot=0.12890625: 1,174 tokens
  6. dot=0.12988281: 1,174 tokens
  7. dot=0.11279297: 1,173 tokens
  8. dot=0.11132812: 1,145 tokens
  9. dot=0.13085938: 1,140 tokens
  10. dot=0.11083984: 1,126 tokens

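The values in this table sit on the bfloat16 grid. With 8 significant bits, adjacent bfloat16 values in the range [2^e, 2^(e+1)) are 2^(e-7) apart: 2^-11 ≈ 0.000488 below 0.125 and 2^-10 ≈ 0.000977 above it. That is exactly the step between neighbouring entries such as 0.12597656, 0.12695312 and 0.12792969, or 0.11083984 and 0.11132812. A small helper (hypothetical name `bf16_spacing`, normal nonzero numbers only) makes this explicit:

```python
import math


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values near x (normal, nonzero x only).

    bfloat16 has 8 significant bits, so in the range [2**e, 2**(e+1)) the
    spacing is 2**(e - 7).
    """
    _, exp = math.frexp(abs(x))   # abs(x) = m * 2**exp with 0.5 <= m < 1
    e = exp - 1                   # so abs(x) lies in [2**e, 2**(e+1))
    return 2.0 ** (e - 7)


for v in (0.11083984, 0.11621094, 0.12597656, 0.25781250):
    print(f"{v:.8f}: spacing {bf16_spacing(v):.8f}")
```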

## Check for Embedding-Level Duplicates

Do these tokens with identical dot products also have **identical embeddings**?

Or are some truly distinct embeddings that just happen to round to the same dot product?

In [57]:
print("Checking for embedding-level duplicates...\n")

# Extract equivalence class embeddings
W_equiv = W_bf16[equiv_indices]

# Find unique embeddings using torch (works with bfloat16)
# torch.unique with dim=0 finds unique rows (unique 2560D embeddings)
unique_embeddings = torch.unique(W_equiv, dim=0)
n_unique_embeddings = len(unique_embeddings)

print(f"Tokens in equivalence class: {n_equiv:,}")
print(f"Unique embeddings: {n_unique_embeddings:,}")
print()

if n_unique_embeddings < n_equiv:
    print(f"✓ Found duplicates: {n_equiv - n_unique_embeddings:,} tokens share embeddings with others")
    print(f"  Average degeneracy: {n_equiv / n_unique_embeddings:.1f} tokens per unique embedding")
else:
    print("✓ No embedding duplicates: all tokens have distinct embeddings")
    print("  These tokens are indistinguishable by DOT PRODUCT but not by EMBEDDING")
    print("  This is pure bfloat16 quantization!")

Checking for embedding-level duplicates...

Tokens in equivalence class: 3,281
Unique embeddings: 1,194

✓ Found duplicates: 2,087 tokens share embeddings with others
  Average degeneracy: 2.7 tokens per unique embedding

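The 2,087 above counts redundant copies (tokens beyond the first in each duplicate group); the number of tokens that have at least one twin is somewhat larger. To see which tokens share an embedding, rows can be grouped by exact bit pattern. The NumPy sketch below represents bfloat16 rows by their raw 16-bit patterns as `uint16`; adapting it to the notebook via `W_equiv.view(torch.int16)` (same element size) is an assumption, not what the cell above does. Comparing bits is slightly stricter than comparing values: `+0.0` and `-0.0` count as different. The helper `duplicate_row_groups` and the toy rows are hypothetical.

```python
import numpy as np


def duplicate_row_groups(rows):
    """Return lists of row indices whose rows are bitwise identical (groups of size > 1)."""
    _, inverse, counts = np.unique(rows, axis=0, return_inverse=True, return_counts=True)
    inverse = inverse.ravel()                      # guard against NumPy 2.0 shape changes
    groups = {}
    for row_idx, cls in enumerate(inverse):
        groups.setdefault(cls, []).append(row_idx)
    return [g for cls, g in groups.items() if counts[cls] > 1]


# Toy bfloat16 bit patterns (0x3F80 = 1.0, 0x4000 = 2.0, 0xBF80 = -1.0)
rows = np.array([
    [0x3F80, 0x4000],
    [0xBF80, 0x3F80],
    [0x3F80, 0x4000],
    [0x3F80, 0x4000],
    [0xBF80, 0x3F80],
    [0x4000, 0x4000],
], dtype=np.uint16)

groups = duplicate_row_groups(rows)
print(groups)   # [[0, 2, 3], [1, 4]]
n_unique = len(np.unique(rows, axis=0))
print(f"{len(rows)} rows, {n_unique} unique, {len(rows) - n_unique} redundant copies")
```

Here 5 of the 6 rows have a twin, while only 3 are redundant copies, which is the distinction behind the 2,087 figure.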

## Comparison with Cosine-Based Cluster

How does this logit-based equivalence class compare to the cosine-based cluster from 1.4a?

In [58]:
print("Comparing with cosine-based cluster (1.4a)...\n")

# Load cosine cluster
cluster_path = Path(f"../tensors/{MODEL_NAME}/1.4a_cluster_members.safetensors")
cluster_data = load_file(cluster_path)
cluster_mask = cluster_data["cluster_mask"].bool()
n_cluster = cluster_data["n_cluster_members"].item()

# Compare overlap
overlap_mask = equiv_mask & cluster_mask
n_overlap = overlap_mask.sum().item()

print(f"Cosine-based cluster (1.4a): {n_cluster:,} tokens")
print(f"Dot-product-based equivalence (1.5a): {n_equiv:,} tokens")
print(f"Overlap: {n_overlap:,} tokens")
print()

if n_overlap == n_equiv and n_overlap == n_cluster:
    print("✓ Perfect match: cosine and dot product identify the same tokens")
elif n_overlap == min(n_equiv, n_cluster):
    if n_equiv > n_cluster:
        print(f"✓ Dot-product class is LARGER: includes {n_equiv - n_cluster:,} additional tokens")
        print("  These tokens have different directions but same dot product (varying magnitudes)")
    else:
        print(f"✓ Cosine class is LARGER: includes {n_cluster - n_equiv:,} additional tokens")
        print("  These tokens have same direction but different magnitudes (different dot products)")
else:
    print(f"✓ Partial overlap: {n_overlap:,} tokens in both")
    print(f"  Cosine-only: {n_cluster - n_overlap:,} tokens")
    print(f"  Dot-product-only: {n_equiv - n_overlap:,} tokens")
    print("  These are different geometric structures!")

Comparing with cosine-based cluster (1.4a)...

Cosine-based cluster (1.4a): 2,251 tokens
Dot-product-based equivalence (1.5a): 3,281 tokens
Overlap: 2,227 tokens

✓ Partial overlap: 2,227 tokens in both
  Cosine-only: 24 tokens
  Dot-product-only: 1,054 tokens
  These are different geometric structures!

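The two one-sided groups can be illustrated with the identity `dot(w, h) = |w| * |h| * cos(theta)`, assuming the 1.4a reference direction is the direction of `h`. In the pure-Python sketch below, |h| is the 0.371094 used above, but the three token norms and cosines are made up (they are not read from W), and the rounding helper `to_bf16` is a hypothetical emulation of float32-to-bfloat16 round-to-nearest-even. Tokens A and B have different cosines yet share a bfloat16 logit (like the dot-product-only tokens); A and C share the 1.4a cosine of 0.84375 but land in different bfloat16 buckets because their norms differ (like the cosine-only tokens).

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


h_norm = 0.371094  # norm of the simulated hidden state above

# Hypothetical tokens: (embedding norm, cosine with h)
tokens = {
    "A": (0.371, 0.84375),   # the 1.4a cosine
    "B": (0.390, 0.803),     # different direction and norm
    "C": (0.400, 0.84375),   # same cosine as A, larger norm
}

logits = {name: to_bf16(norm * h_norm * cos) for name, (norm, cos) in tokens.items()}
for name, value in logits.items():
    print(name, value)

print(logits["A"] == logits["B"])   # True: different cosines, same bfloat16 logit
print(logits["A"] == logits["C"])   # False: same cosine, different bfloat16 logit
```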

## Save Results

In [59]:
print("\nSaving results...\n")

# Save equivalence class membership
output_path = Path(f"../tensors/{MODEL_NAME}/1.5a_logit_equivalence.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)

save_file({
    "equiv_mask": equiv_mask.to(torch.uint8),
    "equiv_token_ids": equiv_indices.to(torch.int32),
    "equiv_dot_product": torch.tensor([mode_dot], dtype=torch.float32),
    "n_equiv_members": torch.tensor([n_equiv], dtype=torch.int32),
    "ref_lat": torch.tensor([REF_LAT], dtype=torch.float32),
    "ref_lon": torch.tensor([REF_LON], dtype=torch.float32),
    "ref_dist": torch.tensor([ref_dist], dtype=torch.float32),
}, str(output_path))

print(f"✓ Saved logit equivalence class to {output_path}")
print()
print("Saved tensors:")
print(f"  equiv_mask: {equiv_mask.shape} - binary mask (1 = equivalence class member)")
print(f"  equiv_token_ids: {equiv_indices.shape} - indices of equivalence class members")
print(f"  equiv_dot_product: scalar - shared dot product value ({mode_dot:.8f})")
print(f"  n_equiv_members: scalar - count ({n_equiv:,})")
print(f"  ref_lat, ref_lon, ref_dist: scalars - simulated hidden state")


Saving results...

✓ Saved logit equivalence class to ../tensors/Qwen3-4B-Instruct-2507/1.5a_logit_equivalence.safetensors

Saved tensors:
  equiv_mask: torch.Size([151936]) - binary mask (1 = equivalence class member)
  equiv_token_ids: torch.Size([3281]) - indices of equivalence class members
  equiv_dot_product: scalar - shared dot product value (0.11621094)
  n_equiv_members: scalar - count (3,281)
  ref_lat, ref_lon, ref_dist: scalars - simulated hidden state


## Summary

This notebook identified tokens with **identical logit scores** for a given hidden state due to bfloat16 quantization.

**Key findings:**
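- Hidden state: lat=-7.288°, lon=6.941°, dist=0.371 (cluster mean norm)
- Equivalence class dot product: 0.11621094 (the mode of the bfloat16 dot-product distribution)
- Equivalence class size: 3,281 tokens (2.16% of the vocabulary), with only 1,194 unique embeddings among them
- Across all 151,936 tokens, only 2,033 distinct dot-product values occur
- Overlap with the 1.4a cosine cluster: 2,227 tokens (24 cosine-only, 1,054 dot-product-only)

For this simulated hidden state, these tokens are **indistinguishable by the model** during next-token prediction: they receive identical logit scores and therefore identical probabilities.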
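
Identical logits imply identical softmax probabilities, since each probability depends only on its own logit and the shared normalizer. A minimal NumPy illustration with made-up logits (the values and the `softmax` helper are hypothetical):

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over a 1-D array."""
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()


# Made-up logits: tokens 1, 2 and 4 are tied, as members of an equivalence class would be
logits = np.array([0.2578125, 0.1162109375, 0.1162109375, 0.0, 0.1162109375])
probs = softmax(logits)
print(probs)
print(probs[1] == probs[2] == probs[4])   # True: tied logits give bitwise-equal probabilities
```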