# 1.4a: Cluster Membership by Cosine Resolution

This notebook identifies the tokens that belong to the cluster: those that share a single cosine similarity value, computed in bfloat16, with respect to a reference direction.

## The Question

The telescope view (1.3c) showed ~2,208 tokens packed into a 0.01° field of view where bfloat16 cannot distinguish them. But what's the *exact* count?

We define cluster membership as: **all tokens that share the same bfloat16 cosine similarity to the reference direction**, where the shared value is the most heavily populated one.

This is the "circle of confusion" at bfloat16 precision: the set of tokens that look identical from the origin when looking toward the cluster.
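To make "cannot distinguish" concrete, here is a minimal, dependency-free sketch of bfloat16 rounding. The helper `to_bfloat16` is a hypothetical name introduced for illustration and is not part of the notebook. It assumes finite inputs and round-to-nearest-even. The sample cosines are invented, chosen to sit near the value this notebook ends up reporting.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite Python float to the nearest bfloat16 value (ties to even).

    Hypothetical helper for illustration; not part of the notebook.
    """
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps only the top 16 bits. Adding this bias before masking
    # turns plain truncation into round-to-nearest-even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Invented sample cosines, all within about 0.004 of each other.
cosines = [0.8420, 0.8440, 0.8450, 0.8460]
rounded = [to_bfloat16(c) for c in cosines]
print(rounded)  # [0.84375, 0.84375, 0.84375, 0.84765625]
```

The first three inputs collapse onto the same bfloat16 value, so a comparison at this precision treats them as identical. The fourth crosses into the next representable value.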

## Approach

1. Define reference direction from lat/long (using PCA basis from 1.3c)
2. Compute cosine similarity from all tokens to reference direction (in bfloat16)
3. Find the most common bfloat16 cosine value (the mode) and select every token that has exactly that value
4. Count and characterize cluster members

## Parameters

In [11]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

# PCA basis (same as 1.3c)
NORTH_PC = 2      # North pole (+90° latitude)
MERIDIAN_PC = 1   # Prime meridian (0° longitude)
EQUINOX_PC = 3    # Equinox (+90° longitude)

# Reference direction (from telescope view)
REF_LAT = -7.288  # Latitude (degrees)
REF_LON = 6.941   # Longitude (degrees)

## Imports

In [12]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path

## Device Detection

In [13]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load W

In [14]:
# Load W in bfloat16
tensor_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(tensor_path)["W"]

print(f"Loaded W from {tensor_path}")
print(f"  Shape: {W_bf16.shape}")
print(f"  Dtype: {W_bf16.dtype}")

N, d = W_bf16.shape
print(f"\nToken space: {N:,} tokens in {d:,} dimensions")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16

Token space: 151,936 tokens in 2,560 dimensions


## Compute PCA Basis

We need the same PCA basis used in the telescope view to convert lat/long to a direction vector.

In [15]:
print("Computing PCA...")

# Work in float32 for PCA (more stable)
W = W_bf16.to(torch.float32)

# Center the data
W_centered = W - W.mean(dim=0)

# Compute covariance matrix
print("  Computing covariance matrix...")
cov = (W_centered.T @ W_centered) / N

# Eigendecomposition
print("  Computing eigendecomposition...")
eigenvalues, eigenvectors = torch.linalg.eigh(cov)

# Sort by descending eigenvalue
idx = torch.argsort(eigenvalues, descending=True)
eigenvalues = eigenvalues[idx]
eigenvectors = eigenvectors[:, idx]

print(f"\n✓ PCA computed")

Computing PCA...
  Computing covariance matrix...
  Computing eigendecomposition...

✓ PCA computed


## Define Reference Direction

In [16]:
def get_pc_vector(pcs, index):
    """Get PC vector by index (1-indexed), with sign flip for negative indices."""
    pc_num = abs(index) - 1
    vector = pcs[:, pc_num].clone()
    if index < 0:
        vector = -vector
    return vector


# Extract basis vectors
north = get_pc_vector(eigenvectors, NORTH_PC)
meridian = get_pc_vector(eigenvectors, MERIDIAN_PC)
equinox = get_pc_vector(eigenvectors, EQUINOX_PC)

print("Spherical coordinate basis:")
print(f"  North (+Z):    PC{NORTH_PC}")
print(f"  Meridian (+X): PC{MERIDIAN_PC}")
print(f"  Equinox (+Y):  PC{EQUINOX_PC}")
print()

# Convert lat/long to Cartesian coordinates on unit sphere
lat_rad = np.deg2rad(REF_LAT)
lon_rad = np.deg2rad(REF_LON)

# Spherical to Cartesian: (r=1, lat, lon) -> (x, y, z)
x = np.cos(lat_rad) * np.cos(lon_rad)  # Meridian component
y = np.cos(lat_rad) * np.sin(lon_rad)  # Equinox component
z = np.sin(lat_rad)                     # North component

print(f"Reference direction:")
print(f"  Lat: {REF_LAT:.3f}°, Lon: {REF_LON:.3f}°")
print(f"  Cartesian (on unit sphere): x={x:.6f}, y={y:.6f}, z={z:.6f}")
print()

# Convert to direction in full 2560D space
ref_direction = x * meridian + y * equinox + z * north
ref_direction = ref_direction / ref_direction.norm()  # Normalize

print(f"Reference direction in {d}D space:")
print(f"  Norm: {ref_direction.norm().item():.6f} (should be 1.0)")

Spherical coordinate basis:
  North (+Z):    PC2
  Meridian (+X): PC1
  Equinox (+Y):  PC3

Reference direction:
  Lat: -7.288°, Lon: 6.941°
  Cartesian (on unit sphere): x=0.984651, y=0.119871, z=-0.126857

Reference direction in 2560D space:
  Norm: 1.000000 (should be 1.0)
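The latitude/longitude conversion above can be checked in isolation with the standard library. This is a sketch under the notebook's convention (x along the meridian, y along the equinox, z toward north). The function names `latlon_to_xyz` and `xyz_to_latlon` are hypothetical and introduced only for illustration.

```python
import math


def latlon_to_xyz(lat_deg: float, lon_deg: float) -> tuple[float, float, float]:
    """Convert latitude/longitude in degrees to a unit vector (x, y, z)."""
    lat = math.radians(lat_deg)
    lon = math.radians(lon_deg)
    x = math.cos(lat) * math.cos(lon)  # meridian component
    y = math.cos(lat) * math.sin(lon)  # equinox component
    z = math.sin(lat)                  # north component
    return x, y, z


def xyz_to_latlon(x: float, y: float, z: float) -> tuple[float, float]:
    """Inverse conversion: recover latitude/longitude in degrees."""
    norm = math.sqrt(x * x + y * y + z * z)
    lat = math.degrees(math.asin(z / norm))
    lon = math.degrees(math.atan2(y, x))
    return lat, lon


x, y, z = latlon_to_xyz(-7.288, 6.941)
print(f"x={x:.6f}, y={y:.6f}, z={z:.6f}")
print(f"norm={math.sqrt(x * x + y * y + z * z):.6f}")
print(xyz_to_latlon(x, y, z))
```

The (x, y, z) triple already has unit length, and the three PCs are orthonormal eigenvectors, so the explicit normalization of `ref_direction` in the cell above only removes floating-point rounding error.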


## Compute Cosine Similarities in Bfloat16

This is the key step: we compute the cosines in bfloat16, the dtype in which W is stored, so the results are quantized at the model's own precision.

In [17]:
print("Computing cosine similarities in bfloat16...\n")

# Convert reference to bfloat16
ref_bf16 = ref_direction.to(torch.bfloat16).to(device)

# Normalize W in bfloat16
W_bf16_device = W_bf16.to(device)
W_norm_bf16 = W_bf16_device / W_bf16_device.norm(dim=1, keepdim=True)

# Compute cosines in bfloat16
with torch.no_grad():
    cosines_bf16 = W_norm_bf16 @ ref_bf16

# Move to CPU for analysis
cosines_bf16_cpu = cosines_bf16.cpu()

print(f"✓ Computed cosines for {N:,} tokens")
print()
print(f"Cosine distribution (bfloat16):")
print(f"  Range: [{cosines_bf16_cpu.min():.8f}, {cosines_bf16_cpu.max():.8f}]")
print(f"  Mean: {cosines_bf16_cpu.mean():.8f}")
print(f"  Median: {cosines_bf16_cpu.median():.8f}")

Computing cosine similarities in bfloat16...

✓ Computed cosines for 151,936 tokens

Cosine distribution (bfloat16):
  Range: [-0.68359375, 0.84765625]
  Mean: 0.24023438
  Median: 0.23632812


## Find Cluster Members

Cluster members are tokens with the **most common cosine value** (the mode of the distribution).

This gives us the largest equivalence class: the biggest group of tokens that cosine similarity to the reference cannot tell apart at bfloat16 precision.

In [18]:
print("Identifying cluster members...\n")

# Find the mode (most common cosine value)
unique_cosines, counts = torch.unique(cosines_bf16_cpu, return_counts=True)
mode_idx = counts.argmax()
mode_cosine = unique_cosines[mode_idx].item()
mode_count = counts[mode_idx].item()

print(f"Mode cosine: {mode_cosine:.8f}")
print(f"Mode count: {mode_count:,} tokens")
print()

# Find all tokens with this cosine
cluster_mask = (cosines_bf16_cpu == mode_cosine)
cluster_indices = cluster_mask.nonzero(as_tuple=True)[0]
n_cluster = cluster_indices.numel()

print(f"✓ Found {n_cluster:,} tokens with cosine = {mode_cosine:.8f}")
print(f"  ({n_cluster / N * 100:.2f}% of vocabulary)")
print()
print(f"Cluster token IDs (first 20): {cluster_indices[:20].tolist()}")
print()

# Also report maximum cosine for comparison
max_cosine = cosines_bf16_cpu.max().item()
max_count = (cosines_bf16_cpu == max_cosine).sum().item()
print(f"For comparison:")
print(f"  Maximum cosine: {max_cosine:.8f} ({max_count:,} tokens)")
print(f"  Our reference was slightly off-center from the cluster centroid")

Identifying cluster members...

Mode cosine: 0.84375000
Mode count: 2,251 tokens

✓ Found 2,251 tokens with cosine = 0.84375000
  (1.48% of vocabulary)

Cluster token IDs (first 20): [124, 125, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 27487, 57408, 77150, 78323, 79269, 79270, 80091]

For comparison:
  Maximum cosine: 0.84765625 (3 tokens)
  Our reference was slightly off-center from the cluster centroid
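The mode-based selection used above can be sketched on toy data with the standard library alone. The values in `toy_cosines` are invented and chosen to be exactly representable in bfloat16. The variable names are illustrative and do not come from the notebook.

```python
from collections import Counter

# Invented cosines, one per token ID (the list index is the token ID).
toy_cosines = [0.84375, 0.26171875, 0.84375, 0.84765625,
               0.84375, 0.265625, 0.26171875]

# Count how many tokens share each distinct cosine value.
counts = Counter(toy_cosines)

# The mode is the most populated value; its count is the cluster size.
mode_cosine, mode_count = counts.most_common(1)[0]

# Membership is exact equality with the mode, which is only meaningful
# because the values have already been quantized to a coarse grid.
member_ids = [i for i, c in enumerate(toy_cosines) if c == mode_cosine]

print(mode_cosine, mode_count, member_ids)  # 0.84375 3 [0, 2, 4]
```

Exact equality would be fragile on unquantized floats. Here it is the point of the exercise, because it asks which tokens land on the same representable value. If two values tied for the largest count, an explicit tie-breaking rule would be needed. The toy data avoids that case.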


## Distribution of Unique Cosines

How many unique cosine values exist at bfloat16 precision?

In [19]:
print("Analyzing unique cosine values...\n")

unique_cosines, counts = torch.unique(cosines_bf16_cpu, return_counts=True)
n_unique = unique_cosines.numel()

print(f"Total unique cosine values: {n_unique:,}")
print()

# Sort by count (descending)
sorted_indices = torch.argsort(counts, descending=True)
top_cosines = unique_cosines[sorted_indices[:10]]
top_counts = counts[sorted_indices[:10]]

print(f"Top 10 most common cosine values:")
for i in range(10):
    cosine_val = top_cosines[i].item()
    count = top_counts[i].item()
    print(f"  {i+1}. cosine={cosine_val:.8f}: {count:,} tokens")

Analyzing unique cosine values...

Total unique cosine values: 2,113

Top 10 most common cosine values:
  1. cosine=0.84375000: 2,251 tokens
  2. cosine=0.26171875: 1,539 tokens
  3. cosine=0.26562500: 1,528 tokens
  4. cosine=0.25195312: 1,517 tokens
  5. cosine=0.25585938: 1,517 tokens
  6. cosine=0.26757812: 1,507 tokens
  7. cosine=0.25976562: 1,492 tokens
  8. cosine=0.26367188: 1,479 tokens
  9. cosine=0.25781250: 1,470 tokens
  10. cosine=0.25390625: 1,469 tokens
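One thing to keep in mind when comparing these counts is that bfloat16 bins do not have a uniform width. The format has 8 significant bits, so the gap between adjacent values doubles at each power of two. The sketch below computes that gap for normal-range values. `bf16_spacing` is a hypothetical helper introduced for illustration.

```python
import math


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values at magnitude |x| (normal range only)."""
    # frexp gives abs(x) = m * 2**exponent with 0.5 <= m < 1.
    _, exponent = math.frexp(abs(x))
    # 8 significant bits (7 stored + 1 implicit) divide [2**(e-1), 2**e)
    # into 128 steps of width 2**(e-8).
    return 2.0 ** (exponent - 8)


print(bf16_spacing(0.84375))     # 0.00390625  (values in [0.5, 1))
print(bf16_spacing(0.26171875))  # 0.001953125 (values in [0.25, 0.5))
```

This matches the listing above: the values near 0.26 are spaced 0.001953125 apart, while 0.84375 and the maximum 0.84765625 differ by 0.00390625. The bin at 0.84375 therefore covers twice the cosine interval of each bin near 0.26.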


## Save Results

In [20]:
print("\nSaving results...\n")

# Save cluster membership
output_path = Path(f"../tensors/{MODEL_NAME}/1.4a_cluster_members.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)

save_file({
    "cluster_mask": cluster_mask.to(torch.uint8),
    "cluster_token_ids": cluster_indices.to(torch.int32),
    "cluster_cosine": torch.tensor([mode_cosine], dtype=torch.float32),
    "n_cluster_members": torch.tensor([n_cluster], dtype=torch.int32),
    "ref_lat": torch.tensor([REF_LAT], dtype=torch.float32),
    "ref_lon": torch.tensor([REF_LON], dtype=torch.float32),
}, str(output_path))

print(f"✓ Saved cluster membership to {output_path}")
print()
print("Saved tensors:")
print(f"  cluster_mask: {cluster_mask.shape} - binary mask (1 = cluster member)")
print(f"  cluster_token_ids: {cluster_indices.shape} - indices of cluster members")
print(f"  cluster_cosine: scalar - shared cosine value ({mode_cosine:.8f})")
print(f"  n_cluster_members: scalar - count ({n_cluster:,})")
print(f"  ref_lat, ref_lon: scalars - reference direction")


Saving results...

✓ Saved cluster membership to ../tensors/Qwen3-4B-Instruct-2507/1.4a_cluster_members.safetensors

Saved tensors:
  cluster_mask: torch.Size([151936]) - binary mask (1 = cluster member)
  cluster_token_ids: torch.Size([2251]) - indices of cluster members
  cluster_cosine: scalar - shared cosine value (0.84375000)
  n_cluster_members: scalar - count (2,251)
  ref_lat, ref_lon: scalars - reference direction


## Summary

This notebook identified cluster members by computing every token's cosine similarity to the reference direction in bfloat16 and selecting the tokens that share the most common value.

**Key findings:**
- Reference: lat=-7.288°, lon=6.941°
- Cluster cosine: 0.84375 (bfloat16)
- Cluster size: 2,251 tokens (1.48% of the vocabulary)

This is the "circle of confusion" at bfloat16 precision: the tokens that are indistinguishable from the origin when looking toward this direction in the sky.