# 1.5d: Cluster Members by Dot Product

This notebook identifies cluster membership using a simple criterion:

1. Pick an arbitrary token `h` from the cluster region
2. Compute `W @ h` in bfloat16 (all dot products)
3. Find all tokens `t` where `t @ h == h @ h` in bfloat16

These are the tokens indistinguishable from `h` when the hidden state points at `h`: each one has exactly the same bfloat16 dot product with `h` as `h` has with itself.

## Parameters

In [26]:
# --- Model whose token matrix W is analyzed ---
MODEL_NAME = "Qwen3-4B-Instruct-2507"

# --- Spherical frame: which principal components serve as axes (from 1.3c) ---
# Each value is a 1-indexed PC number; a negative value selects the same PC
# with its sign flipped.
NORTH_PC = 2      # axis pointing at the north pole (+90° latitude)
MERIDIAN_PC = 1   # axis pointing at the prime meridian (0° longitude)
EQUINOX_PC = 3    # axis pointing at the equinox (+90° longitude)

# --- Telescope pointing, all in degrees (from 1.3c) ---
CENTER_LAT = -7.2888       # latitude of the field center
CENTER_LON = 6.9400        # longitude of the field center
ANGULAR_DIAMETER = 0.0010  # full width of the field of view

## Imports

In [27]:
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path

## Device Detection

In [28]:
# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
# The chained conditional short-circuits, so MPS is only probed when CUDA is absent.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f"Using device: {device}")

Using device: mps


## Load W

In [29]:
# Read the token matrix stored under the "W" key. No cast is applied here, so
# the dtype is whatever the file holds (the analysis below assumes bfloat16).
tensor_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(tensor_path)["W"]

# Rows are tokens, columns are embedding dimensions.
N, d = W_bf16.shape

# Full-precision copy used for the PCA and the spherical coordinates.
W_f32 = W_bf16.float()

print(f"Loaded W from {tensor_path}")
print(f"  Shape: {W_bf16.shape}")
print(f"  Dtype: {W_bf16.dtype}")
print(f"\nToken space: {N:,} tokens in {d:,} dimensions")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16

Token space: 151,936 tokens in 2,560 dimensions


## Compute PCA

In [30]:
print("Computing PCA...\n")

# Subtract the per-dimension mean so the covariance is taken about the centroid.
W_centered = W_f32 - W_f32.mean(dim=0)

# d x d covariance matrix, normalised by the token count N.
cov = W_centered.T.matmul(W_centered).div(N)

# cov is symmetric, so the symmetric eigensolver applies.
eigenvalues, eigenvectors = torch.linalg.eigh(cov)

# Reorder so that column 0 (PC 1) carries the largest eigenvalue.
idx = eigenvalues.argsort(descending=True)
eigenvalues, eigenvectors = eigenvalues[idx], eigenvectors[:, idx]

print("✓ PCA computed")

Computing PCA...

✓ PCA computed


## Define Spherical Basis

In [31]:
def get_pc_vector(pcs, index):
    """Return one principal-component direction, optionally sign-flipped.

    Args:
        pcs: Tensor whose columns are the principal components; it is indexed
            as ``pcs[:, k]``, so column 0 is PC 1.
        index: Signed, 1-indexed PC number. ``k`` selects column ``k - 1``;
            ``-k`` selects the same column negated.

    Returns:
        A new tensor (a clone, not a view of ``pcs``) holding the selected
        column, negated when ``index`` is negative.

    Raises:
        ValueError: If ``index`` is 0. PCs are 1-indexed, and 0 would otherwise
            wrap around to column -1 and silently return the last PC.
        IndexError: If ``abs(index)`` exceeds the number of columns in ``pcs``.
    """
    if index == 0:
        raise ValueError(
            "PC index is 1-indexed and must be non-zero "
            "(use a negative index to flip the sign)"
        )

    pc_num = abs(index) - 1
    n_pcs = pcs.shape[1]
    if pc_num >= n_pcs:
        raise IndexError(
            f"PC index {index} is out of range: only {n_pcs} components available"
        )

    # clone() so the caller can modify the result without touching pcs.
    vector = pcs[:, pc_num].clone()
    if index < 0:
        vector = -vector
    return vector


# Pull the three axes of the spherical frame out of the sorted eigenvectors,
# one per configured PC index (north pole, prime meridian, equinox).
north, meridian, equinox = (
    get_pc_vector(eigenvectors, pc_index)
    for pc_index in (NORTH_PC, MERIDIAN_PC, EQUINOX_PC)
)

## Project to Spherical Coordinates

In [32]:
print("Projecting to spherical coordinates...\n")

# Cartesian components of every token along the three frame axes.
# NOTE(review): the projection uses W_f32 rather than W_centered — presumably
# to match the convention of 1.3c; confirm.
z = W_f32 @ north      # polar axis
x = W_f32 @ meridian   # 0° longitude axis
y = W_f32 @ equinox    # +90° longitude axis

# Length of each token's projection into the 3-D subspace.
r = (x**2 + y**2 + z**2).sqrt()

# Latitude: the clamp keeps the ratio inside asin's domain [-1, 1].
lat_rad = (z / r).clamp(-1, 1).asin()
lat_deg = lat_rad.rad2deg()

# Longitude: angle in the meridian/equinox plane, measured from the meridian axis.
lon_rad = torch.atan2(y, x)
lon_deg = lon_rad.rad2deg()

print("✓ Spherical coordinates computed")

Projecting to spherical coordinates...

✓ Spherical coordinates computed


## Spatial Filter: Find Reference Token

In [33]:
print("\nApplying spatial filter...")
print(f"  Center: ({CENTER_LAT:.4f}°, {CENTER_LON:.4f}°)")
print(f"  Angular diameter: {ANGULAR_DIAMETER:.4f}°")
print()

# The field of view is an axis-aligned box in (lat, lon), ANGULAR_DIAMETER wide
# on each side and centered on the telescope pointing.
# NOTE: the box is not wrapped across ±180° longitude, so a center close to the
# antimeridian would lose the tokens on the far side.
half_width = ANGULAR_DIAMETER / 2
lat_min = CENTER_LAT - half_width
lat_max = CENTER_LAT + half_width
lon_min = CENTER_LON - half_width
lon_max = CENTER_LON + half_width

# Keep tokens whose coordinates fall inside the box (bounds inclusive).
spatial_mask = (
    (lat_deg >= lat_min) & (lat_deg <= lat_max) &
    (lon_deg >= lon_min) & (lon_deg <= lon_max)
)

n_spatial = spatial_mask.sum().item()
spatial_token_ids = spatial_mask.nonzero(as_tuple=True)[0]

print(f"✓ Found {n_spatial:,} tokens in spatial region")

# Without this guard an empty region fails below with an opaque IndexError.
if n_spatial == 0:
    raise ValueError(
        f"No tokens found within {ANGULAR_DIAMETER}° of "
        f"({CENTER_LAT}°, {CENTER_LON}°); check CENTER_LAT / CENTER_LON / "
        "ANGULAR_DIAMETER and the PC basis selection."
    )

# Any token in the region may serve as the reference; take the first one
# returned (nonzero yields indices in ascending order, so the lowest token ID).
ref_token_id = spatial_token_ids[0].item()
print(f"\n✓ Selected token {ref_token_id} as reference h")


Applying spatial filter...
  Center: (-7.2888°, 6.9400°)
  Angular diameter: 0.0010°

✓ Found 2,176 tokens in spatial region

✓ Selected token 124 as reference h


## Compute W @ h in bfloat16

In [34]:
print("\nComputing W @ h in bfloat16...\n")

# Put the un-cast token matrix on the selected device.
W_bf16_device = W_bf16.to(device)

# The reference vector h is the row of W belonging to the reference token.
h = W_bf16_device[ref_token_id]

# A single matrix-vector product yields t @ h for every token t.
with torch.no_grad():
    dots = torch.matmul(W_bf16_device, h)

# Bring the result back to host memory for the comparison step.
dots_cpu = dots.cpu()

# h @ h is read out of the same result vector (the entry for the reference
# token), not recomputed separately.
h_dot_h = float(dots_cpu[ref_token_id])

print("✓ Computed all dot products")
print(f"  h @ h = {h_dot_h:.8f}")


Computing W @ h in bfloat16...

✓ Computed all dot products
  h @ h = 0.13769531


## Find Cluster: t @ h == h @ h

In [35]:
print("\nFinding cluster members...\n")

# Membership test: exact equality between a token's dot product with h and h @ h.
cluster_mask = dots_cpu.eq(h_dot_h)

# Token IDs of the members, and how many there are.
cluster_token_ids = torch.nonzero(cluster_mask, as_tuple=True)[0]
n_cluster = cluster_mask.sum().item()

print(f"✓ Found {n_cluster:,} cluster members")
print(f"  ({n_cluster/N*100:.3f}% of vocabulary)")
print()
print(f"First 10 cluster token IDs: {cluster_token_ids[:10].tolist()}")


Finding cluster members...

✓ Found 3,245 cluster members
  (2.136% of vocabulary)

First 10 cluster token IDs: [124, 125, 141, 177, 178, 179, 180, 181, 182, 183]


## Save Results

In [36]:
print("\nSaving results...\n")

# Destination file; create the model's tensor directory if it is missing.
output_path = Path(f"../tensors/{MODEL_NAME}/1.5d_cluster_mask.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)

# The boolean mask is stored as uint8.
cluster_mask_uint8 = cluster_mask.to(torch.uint8)

# The two scalar values are wrapped in 1-element int32 tensors, since every
# entry handed to save_file here is a tensor.
tensors_to_save = {
    "cluster_mask": cluster_mask_uint8,
    "cluster_token_ids": cluster_token_ids.to(torch.int32),
    "ref_token_id": torch.tensor([ref_token_id], dtype=torch.int32),
    "n_cluster": torch.tensor([n_cluster], dtype=torch.int32),
}
save_file(tensors_to_save, str(output_path))

print(f"✓ Saved to {output_path}")
print()
print(
    "Saved tensors:",
    f"  cluster_mask: ({N},) - binary mask",
    f"  cluster_token_ids: ({n_cluster},) - token IDs",
    f"  ref_token_id: scalar - {ref_token_id}",
    f"  n_cluster: scalar - {n_cluster:,}",
    sep="\n",
)


Saving results...

✓ Saved to ../tensors/Qwen3-4B-Instruct-2507/1.5d_cluster_mask.safetensors

Saved tensors:
  cluster_mask: (151936,) - binary mask
  cluster_token_ids: (3245,) - token IDs
  ref_token_id: scalar - 124
  n_cluster: scalar - 3,245
