# 1.5c: Cluster Logit Equivalence

This notebook identifies tokens that receive **identical logit scores** when the hidden state points directly at the cluster centroid.

## The Question

In 1.4a we found 2,251 tokens that share an identical **cosine similarity** value. But cosine ignores magnitude: it captures only direction.

For **true logit equivalence**, we need identical **dot products**:
```
logit[i] = hidden_state @ W[i]
```

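Because `hidden_state @ W[i] = ‖hidden_state‖ · ‖W[i]‖ · cos θ_i` (θ_i being the angle between the hidden state and `W[i]`), in exact arithmetic two tokens share a logit exactly when `‖W[i]‖ · cos θ_i`, their projection onto the hidden state, agrees. Identical cosine is therefore not enough: tokens at the same angle but with different norms receive different logits, even if they lie on the same ray. In bfloat16 there is a second route to equality: dot products that differ slightly can round to the same representable value.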
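The factorization `logit = ‖h‖ · ‖W[i]‖ · cos θ_i` can be checked numerically. The pure-Python sketch below uses hypothetical 3-D vectors (not model data) and small helpers `dot`, `norm` and `cosine` defined only for this illustration: two tokens on the same ray share a cosine with `h` but, having different norms, get different logits.

```python
import math

def dot(a, b):
    # Plain dot product: sum of elementwise products
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

def cosine(a, b):
    return dot(a, b) / (norm(a) * norm(b))

# Hypothetical hidden state and two token embeddings on the same ray
h = [1.0, 2.0, 2.0]
w_short = [0.5, 0.5, 0.0]
w_long = [1.0, 1.0, 0.0]   # same direction as w_short, twice the norm

# Same cosine with h ...
print(cosine(h, w_short), cosine(h, w_long))
# ... but different logits, because the logit scales with the token norm
print(dot(h, w_short), dot(h, w_long))

# The logit factors as ||h|| * ||w|| * cos(theta)
for w in (w_short, w_long):
    assert math.isclose(dot(h, w), norm(h) * norm(w) * cosine(h, w))
```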

## Approach

1. **Bootstrap:** Load cosine-based cluster from 1.4a (2,251 tokens)
2. **Compute centroid:** Mean of cluster embeddings in raw 2560D space (no PCA)
3. **Point at centroid:** Use centroid as simulated hidden state `h`
4. **Compute dot products:** `W @ h` in bfloat16 for all tokens
5. **Find equivalence class:** Tokens with identical dot products

This gives us the **logit-based cluster**: tokens that receive exactly the same bfloat16 logit when the hidden state points directly at the centroid, and are therefore indistinguishable from one another.
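The five steps can be sketched end to end on toy data. This is a minimal pure-Python sketch, not the notebook's code: `round_to_bf16` and `logit_equivalence_class` are hypothetical helpers, the rounding emulates bfloat16 round-to-nearest-even for normal numbers only, and each dot product is computed in float64 and rounded once, whereas the notebook's `W @ h` runs in bfloat16 on the device, so accumulation details differ.

```python
import math
from collections import Counter

def round_to_bf16(x):
    """Round a float to the nearest bfloat16 value (ties to even).

    Simplified emulation: zero and normal numbers only (no subnormals,
    overflow, inf or NaN). bfloat16 keeps 8 significant bits.
    """
    if x == 0.0:
        return 0.0
    _, e = math.frexp(x)             # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2.0 ** (8 - e)           # put 8 significant bits left of the point
    return round(x * scale) / scale  # round() is round-half-to-even

def logit_equivalence_class(W, cluster_ids):
    """Return (modal bf16 dot product, token ids attaining it) for a toy W."""
    # Step 2: centroid = mean of the cluster rows, stored in bfloat16
    dim = len(W[0])
    centroid = [round_to_bf16(sum(W[i][k] for i in cluster_ids) / len(cluster_ids))
                for k in range(dim)]
    # Steps 3-4: use the centroid as h; compute each dot product, round to bfloat16
    dots = [round_to_bf16(sum(w_k * h_k for w_k, h_k in zip(row, centroid)))
            for row in W]
    # Step 5: the most common value and every token that hits it exactly
    mode_value, _ = Counter(dots).most_common(1)[0]
    members = [i for i, v in enumerate(dots) if v == mode_value]
    return mode_value, members

# Hypothetical toy vocabulary of 5 tokens in 2-D; tokens 0-2 play the "cosine cluster"
W = [[1.0, 0.0], [1.0, 0.0], [1.001, 0.0], [0.0, 1.0], [0.5, 0.5]]
print(logit_equivalence_class(W, cluster_ids=[0, 1, 2]))  # (1.0, [0, 1, 2])
```

Token 2 differs from tokens 0 and 1 in float64 (1.001 vs 1.0), but the gap is smaller than half a bfloat16 step near 1.0 (the step there is 2^-7), so all three land in the same class.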

## Parameters

In [1]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path

## Device Detection

In [3]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load W

In [4]:
# Load W in bfloat16
tensor_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(tensor_path)["W"]

print(f"Loaded W from {tensor_path}")
print(f"  Shape: {W_bf16.shape}")
print(f"  Dtype: {W_bf16.dtype}")

N, d = W_bf16.shape
print(f"\nToken space: {N:,} tokens in {d:,} dimensions")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16

Token space: 151,936 tokens in 2,560 dimensions


## Load Cosine-Based Cluster (Bootstrap)

We'll use the cosine-based cluster from 1.4a as a starting point to compute the centroid.

In [5]:
# Load cluster membership from 1.4a
cluster_path = Path(f"../tensors/{MODEL_NAME}/1.4a_cluster_members.safetensors")
cluster_data = load_file(cluster_path)

cosine_cluster_token_ids = cluster_data["cluster_token_ids"].long()
n_cosine_cluster = cluster_data["n_cluster_members"].item()
cluster_cosine = cluster_data["cluster_cosine"].item()

print(f"Loaded cosine-based cluster (1.4a):")
print(f"  Cluster size: {n_cosine_cluster:,} tokens")
print(f"  Shared cosine: {cluster_cosine:.8f}")
print(f"  Token IDs (first 10): {cosine_cluster_token_ids[:10].tolist()}")

Loaded cosine-based cluster (1.4a):
  Cluster size: 2,251 tokens
  Shared cosine: 0.84375000
  Token IDs (first 10): [124, 125, 177, 178, 179, 180, 181, 182, 183, 184]


## Compute Cluster Centroid

Compute the mean of cluster embeddings in raw 2560D space (no PCA, no normalization).

In [6]:
print("\nComputing cluster centroid...\n")

# Extract cluster embeddings
W_cluster = W_bf16[cosine_cluster_token_ids]

# Compute centroid (mean in bfloat16)
centroid_bf16 = W_cluster.mean(dim=0)

print(f"Cluster centroid:")
print(f"  Shape: {centroid_bf16.shape}")
print(f"  Norm: {centroid_bf16.norm().item():.6f}")
print(f"  Dtype: {centroid_bf16.dtype}")
print()

# For comparison: individual cluster token norms
cluster_norms = W_cluster.norm(dim=1)
print(f"Cluster token norms:")
print(f"  Min: {cluster_norms.min().item():.6f}")
print(f"  Max: {cluster_norms.max().item():.6f}")
print(f"  Mean: {cluster_norms.mean().item():.6f}")
print(f"  Centroid norm: {centroid_bf16.norm().item():.6f}")


Computing cluster centroid...

Cluster centroid:
  Shape: torch.Size([2560])
  Norm: 0.371094
  Dtype: torch.bfloat16

Cluster token norms:
  Min: 0.359375
  Max: 0.373047
  Mean: 0.371094
  Centroid norm: 0.371094
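By the triangle inequality, the norm of a mean can never exceed the mean of the norms, and for nonzero vectors the two are equal only when all vectors point in the same direction. The centroid norm matching the mean token norm (both 0.371094 above) therefore suggests the cluster is extremely tight, at least at the displayed bfloat16 precision. A pure-Python illustration on hypothetical 2-D vectors; `tightness` is a helper introduced here:

```python
import math

def norm(v):
    return math.sqrt(sum(x * x for x in v))

def mean_vector(vectors):
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def tightness(vectors):
    """Ratio ||mean|| / mean(||v||): always <= 1, equal to 1 only if all vectors share a direction."""
    mean_norm = sum(norm(v) for v in vectors) / len(vectors)
    return norm(mean_vector(vectors)) / mean_norm

aligned = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]   # same direction, different norms
spread = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]   # different directions

print(tightness(aligned))  # 1.0
print(tightness(spread))   # about 0.333
```

A ratio of 1 says the vectors share a direction, not that they are identical; the narrow range of token norms printed above (0.359375 to 0.373047) covers the magnitude side.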


## Compute Dot Products in Bfloat16

Point our simulated hidden state directly at the cluster centroid and compute dot products for all tokens.

In [7]:
print("\nComputing dot products in bfloat16...\n")

# Use centroid as hidden state
h_bf16 = centroid_bf16.to(device)

# Move W to device
W_bf16_device = W_bf16.to(device)

# Compute dot products in bfloat16
with torch.no_grad():
    dot_products_bf16 = W_bf16_device @ h_bf16

# Move to CPU for analysis
dot_products_bf16_cpu = dot_products_bf16.cpu()

print(f"✓ Computed dot products for {N:,} tokens")
print()
print(f"Dot product distribution (bfloat16):")
print(f"  Range: [{dot_products_bf16_cpu.min():.8f}, {dot_products_bf16_cpu.max():.8f}]")
print(f"  Mean: {dot_products_bf16_cpu.mean():.8f}")
print(f"  Median: {dot_products_bf16_cpu.median():.8f}")


Computing dot products in bfloat16...

✓ Computed dot products for 151,936 tokens

Dot product distribution (bfloat16):
  Range: [-0.25390625, 0.23242188]
  Mean: 0.10156250
  Median: 0.10693359


## Find Logit Equivalence Class

Take the **most common** (modal) dot product value and find every token whose bfloat16 dot product equals it exactly.

In [8]:
print("\nIdentifying logit equivalence class...\n")

# Find the mode (most common dot product value)
unique_dots, counts = torch.unique(dot_products_bf16_cpu, return_counts=True)
mode_idx = counts.argmax()
mode_dot = unique_dots[mode_idx].item()
mode_count = counts[mode_idx].item()

print(f"Mode dot product: {mode_dot:.8f}")
print(f"Mode count: {mode_count:,} tokens")
print()

# Find all tokens with this dot product
logit_equiv_mask = (dot_products_bf16_cpu == mode_dot)
logit_equiv_indices = logit_equiv_mask.nonzero(as_tuple=True)[0]
n_logit_equiv = logit_equiv_indices.numel()

print(f"✓ Found {n_logit_equiv:,} tokens with dot product = {mode_dot:.8f}")
print(f"  ({n_logit_equiv / N * 100:.2f}% of vocabulary)")
print()
print(f"Token IDs (first 20): {logit_equiv_indices[:20].tolist()}")
print()

# Also report maximum dot product for comparison
max_dot = dot_products_bf16_cpu.max().item()
max_count = (dot_products_bf16_cpu == max_dot).sum().item()
print(f"For comparison:")
print(f"  Maximum dot product: {max_dot:.8f} ({max_count:,} tokens)")


Identifying logit equivalence class...

Mode dot product: 0.13769531
Mode count: 3,248 tokens

✓ Found 3,248 tokens with dot product = 0.13769531
  (2.14% of vocabulary)

Token IDs (first 20): [124, 125, 141, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 1680, 3864, 4732, 7267, 8054, 12370]

For comparison:
  Maximum dot product: 0.23242188 (1 tokens)
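A consistency check (our observation, not part of the original analysis): a token whose embedding equals the centroid `c` would score exactly `‖c‖²`. Squaring the bfloat16 centroid norm printed earlier, 0.37109375 (displayed as 0.371094), and rounding to bfloat16 reproduces the modal value 0.13769531, which is consistent with the class containing tokens that sit essentially at the centroid. `round_to_bf16` is a simplified pure-Python emulation (normal numbers only), not a PyTorch call.

```python
import math

def round_to_bf16(x):
    """Round to nearest bfloat16 (8 significant bits, ties to even); normal numbers only."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(x)
    scale = 2.0 ** (8 - e)
    return round(x * scale) / scale

centroid_norm = 0.37109375                  # bfloat16 value shown as 0.371094 above
print(centroid_norm ** 2)                   # 0.1377105712890625 (exact square)
print(round_to_bf16(centroid_norm ** 2))    # 0.1376953125, i.e. 0.13769531
```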


## Compare with Cosine-Based Cluster

How does the logit-based equivalence class compare to the cosine-based cluster?

In [9]:
print("\nComparing with cosine-based cluster (1.4a)...\n")

# Create mask for cosine cluster
cosine_cluster_mask = torch.zeros(N, dtype=torch.bool)
cosine_cluster_mask[cosine_cluster_token_ids] = True

# Compare overlap
overlap_mask = logit_equiv_mask & cosine_cluster_mask
n_overlap = overlap_mask.sum().item()

print(f"Cosine-based cluster (1.4a): {n_cosine_cluster:,} tokens")
print(f"Logit-based equivalence (1.5c): {n_logit_equiv:,} tokens")
print(f"Overlap: {n_overlap:,} tokens")
print()

if n_overlap == n_logit_equiv and n_overlap == n_cosine_cluster:
    print("✓ Perfect match: cosine and logit clusters are identical")
elif n_overlap == min(n_logit_equiv, n_cosine_cluster):
    if n_logit_equiv > n_cosine_cluster:
        print(f"✓ Logit class is LARGER: includes {n_logit_equiv - n_cosine_cluster:,} additional tokens")
        print(f"  Logit-only tokens: {(logit_equiv_mask & ~cosine_cluster_mask).sum().item():,}")
    else:
        print(f"✓ Cosine class is LARGER: includes {n_cosine_cluster - n_logit_equiv:,} additional tokens")
        print(f"  Cosine-only tokens: {(cosine_cluster_mask & ~logit_equiv_mask).sum().item():,}")
else:
    print(f"✓ Partial overlap")
    print(f"  Cosine-only: {(cosine_cluster_mask & ~logit_equiv_mask).sum().item():,} tokens")
    print(f"  Logit-only: {(logit_equiv_mask & ~cosine_cluster_mask).sum().item():,} tokens")
    print(f"  Both: {n_overlap:,} tokens")

print()
print(f"Overlap percentage:")
print(f"  Of cosine cluster in logit class: {n_overlap / n_cosine_cluster * 100:.1f}%")
print(f"  Of logit class in cosine cluster: {n_overlap / n_logit_equiv * 100:.1f}%")


Comparing with cosine-based cluster (1.4a)...

Cosine-based cluster (1.4a): 2,251 tokens
Logit-based equivalence (1.5c): 3,248 tokens
Overlap: 2,226 tokens

✓ Partial overlap
  Cosine-only: 25 tokens
  Logit-only: 1,022 tokens
  Both: 2,226 tokens

Overlap percentage:
  Of cosine cluster in logit class: 98.9%
  Of logit class in cosine cluster: 68.5%
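The same bookkeeping can be written with Python sets, which makes the identities explicit: cosine-only = |cosine| − |both| and logit-only = |logit| − |both|. A sketch with hypothetical token IDs; `overlap_report` is a helper introduced for illustration:

```python
def overlap_report(cosine_ids, logit_ids):
    """Set-based overlap between two clusters of token ids."""
    cosine_ids, logit_ids = set(cosine_ids), set(logit_ids)
    both = cosine_ids & logit_ids
    return {
        "both": len(both),
        "cosine_only": len(cosine_ids - logit_ids),
        "logit_only": len(logit_ids - cosine_ids),
        "pct_cosine_in_logit": 100 * len(both) / len(cosine_ids),
        "pct_logit_in_cosine": 100 * len(both) / len(logit_ids),
    }

print(overlap_report([1, 2, 3, 4], [3, 4, 5, 6, 7, 8]))
```

With the notebook's counts, 2,251 − 2,226 = 25 cosine-only and 3,248 − 2,226 = 1,022 logit-only tokens, matching the output above.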


## Check for Embedding-Level Duplicates

Do tokens in the logit equivalence class have **identical embeddings**, or just identical dot products?

In [10]:
print("\nChecking for embedding-level duplicates...\n")

# Extract logit equivalence class embeddings
W_logit_equiv = W_bf16[logit_equiv_indices]

# Find unique embeddings using torch (works with bfloat16)
unique_embeddings = torch.unique(W_logit_equiv, dim=0)
n_unique_embeddings = len(unique_embeddings)

print(f"Tokens in logit equivalence class: {n_logit_equiv:,}")
print(f"Unique embeddings: {n_unique_embeddings:,}")
print()

if n_unique_embeddings < n_logit_equiv:
    print(f"✓ Found duplicates: {n_logit_equiv - n_unique_embeddings:,} tokens share embeddings with others")
    print(f"  Average degeneracy: {n_logit_equiv / n_unique_embeddings:.1f} tokens per unique embedding")
else:
    print("✓ No embedding duplicates: all tokens have distinct embeddings")
    print("  These tokens are indistinguishable by DOT PRODUCT but not by EMBEDDING")
    print("  (the shared value comes from equal projections onto h and/or bfloat16 rounding)")


Checking for embedding-level duplicates...

Tokens in logit equivalence class: 3,248
Unique embeddings: 1,161

✓ Found duplicates: 2,087 tokens share embeddings with others
  Average degeneracy: 2.8 tokens per unique embedding
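Note that `n_logit_equiv - n_unique_embeddings` (2,087) counts redundant copies: tokens beyond one representative per distinct embedding. The number of tokens that share their embedding with at least one other token is that figure plus the number of duplicated groups, so it is at least 2,088. In the notebook, per-embedding group sizes could be read off `torch.unique(W_logit_equiv, dim=0, return_counts=True)`; the pure-Python sketch below (hypothetical helper `degeneracy_stats`, with tuples standing in for embedding rows) separates the two counts.

```python
from collections import Counter

def degeneracy_stats(rows):
    """Summarize exact-duplicate structure among embedding rows (given as tuples)."""
    group_sizes = Counter(rows).values()      # tokens per distinct embedding
    n_tokens = sum(group_sizes)
    n_unique = len(group_sizes)
    shared_groups = [s for s in group_sizes if s > 1]
    return {
        "unique": n_unique,
        "redundant_copies": n_tokens - n_unique,        # what the notebook prints
        "tokens_in_shared_groups": sum(shared_groups),  # tokens with at least one twin
        "largest_group": max(group_sizes),
    }

rows = [(0.5, 0.25), (0.5, 0.25), (0.5, 0.25), (1.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
print(degeneracy_stats(rows))
# {'unique': 3, 'redundant_copies': 3, 'tokens_in_shared_groups': 5, 'largest_group': 3}
```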


## Distribution of Unique Dot Products

In [11]:
print("\nAnalyzing unique dot product values...\n")

unique_dots, counts = torch.unique(dot_products_bf16_cpu, return_counts=True)
n_unique = unique_dots.numel()

print(f"Total unique dot product values: {n_unique:,}")
print(f"  ({n_unique / N * 100:.2f}% of vocabulary)")
print()

# Sort by count (descending)
sorted_indices = torch.argsort(counts, descending=True)
top_dots = unique_dots[sorted_indices[:10]]
top_counts = counts[sorted_indices[:10]]

print(f"Top 10 most common dot product values:")
for i in range(10):
    dot_val = top_dots[i].item()
    count = top_counts[i].item()
    mode_flag = "← MODE (LOGIT CLUSTER)" if abs(dot_val - mode_dot) < 1e-6 else ""
    print(f"  {i+1}. dot={dot_val:.8f}: {count:,} tokens {mode_flag}")


Analyzing unique dot product values...

Total unique dot product values: 1,576
  (1.04% of vocabulary)

Top 10 most common dot product values:
  1. dot=0.13769531: 3,248 tokens ← MODE (LOGIT CLUSTER)
  2. dot=0.12597656: 2,044 tokens 
  3. dot=0.12695312: 1,983 tokens 
  4. dot=0.12792969: 1,963 tokens 
  5. dot=0.12890625: 1,919 tokens 
  6. dot=0.12988281: 1,785 tokens 
  7. dot=0.13085938: 1,757 tokens 
  8. dot=0.12500000: 1,635 tokens 
  9. dot=0.13183594: 1,605 tokens 
  10. dot=0.13281250: 1,548 tokens 
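The values in this table sit on the bfloat16 grid. bfloat16 keeps 8 significant bits, so in [0.125, 0.25) adjacent representable values are 2^-10 = 0.0009765625 apart, exactly the gap between 0.12500000, 0.12597656, 0.12695312 and so on. A small pure-Python helper (hypothetical, normal numbers only) computes that spacing and expresses the ten values, sorted by value, as multiples of 2^-10:

```python
import math

def bf16_spacing(x):
    """Spacing of bfloat16 values in the binade containing |x| (normal range only)."""
    _, e = math.frexp(abs(x))   # |x| = m * 2**e with 0.5 <= m < 1
    return 2.0 ** (e - 8)

print(bf16_spacing(0.125))   # 0.0009765625
print(bf16_spacing(0.371))   # 0.001953125

# Top-10 dot product values from the table above, sorted by value
table = [0.12500000, 0.12597656, 0.12695312, 0.12792969, 0.12890625,
         0.12988281, 0.13085938, 0.13183594, 0.13281250, 0.13769531]
print([round(v * 1024) for v in table])
# [128, 129, 130, 131, 132, 133, 134, 135, 136, 141]
```

The nine runners-up occupy nine consecutive grid points, while the modal value sits at grid point 141.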


## Save Results

In [12]:
print("\nSaving results...\n")

# Save logit equivalence class membership
output_path = Path(f"../tensors/{MODEL_NAME}/1.5c_cluster_logits.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)

save_file({
    "logit_equiv_mask": logit_equiv_mask.to(torch.uint8),
    "logit_equiv_token_ids": logit_equiv_indices.to(torch.int32),
    "logit_equiv_dot_product": torch.tensor([mode_dot], dtype=torch.float32),
    "n_logit_equiv_members": torch.tensor([n_logit_equiv], dtype=torch.int32),
    "cluster_centroid": centroid_bf16,
}, str(output_path))

print(f"✓ Saved logit equivalence class to {output_path}")
print()
print("Saved tensors:")
print(f"  logit_equiv_mask: {logit_equiv_mask.shape} - binary mask (1 = logit cluster member)")
print(f"  logit_equiv_token_ids: {logit_equiv_indices.shape} - indices of logit cluster members")
print(f"  logit_equiv_dot_product: scalar - shared dot product value ({mode_dot:.8f})")
print(f"  n_logit_equiv_members: scalar - count ({n_logit_equiv:,})")
print(f"  cluster_centroid: {centroid_bf16.shape} - cluster centroid in 2560D space")


Saving results...

✓ Saved logit equivalence class to ../tensors/Qwen3-4B-Instruct-2507/1.5c_cluster_logits.safetensors

Saved tensors:
  logit_equiv_mask: torch.Size([151936]) - binary mask (1 = logit cluster member)
  logit_equiv_token_ids: torch.Size([3248]) - indices of logit cluster members
  logit_equiv_dot_product: scalar - shared dot product value (0.13769531)
  n_logit_equiv_members: scalar - count (3,248)
  cluster_centroid: torch.Size([2560]) - cluster centroid in 2560D space


## Summary

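This notebook identified tokens that receive **identical logit scores** when the hidden state points directly at the cluster centroid.

**Key findings:**
- Bootstrapped from the cosine-based cluster (1.4a) with 2,251 tokens
- Computed the cluster centroid in raw 2560D space (norm 0.371094)
- Found a logit equivalence class of 3,248 tokens (2.14% of the vocabulary) sharing the dot product 0.13769531; these tokens have only 1,161 distinct embeddings
- Overlap with the cosine cluster: 2,226 tokens (98.9% of the cosine cluster, 68.5% of the logit class); 25 tokens are cosine-only and 1,022 are logit-only

These tokens receive **identical logit scores** when the hidden state points at the cluster centroid: in bfloat16, the dot product cannot distinguish one from another. The shared value is not the top logit, however; the maximum dot product is 0.23242188, attained by a single token.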