# 1.9b: Core Definition (bfloat16 corrected)

**The Error:** In 1.8a, we loaded W and immediately converted to float32. This merged the 13 real black holes into 4 apparent ones.

**The Fix:** Keep W in **native bfloat16** for all equality checks and grouping. Convert to float32 only for arithmetic that benefits from the extra precision (such as centering), and only after grouping is done.

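**Expected Results:**
- 13 black holes (distinct vectors shared by two or more core tokens), not the 4 seen in 1.8a
- Populations matching 1.9a: 814, 704, 306, 228, 11, 10, 6, 5, 4, 4, 3, 3, 2
- Total: 2,100 degenerate tokens (core tokens that belong to some black hole)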

This is the **corrected** version of 1.8a.

## Parameters

In [1]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

# Magic radius that defines the core (from 1.8a)
MAGIC_RADIUS = 0.00007553
RADIUS_TOLERANCE = 1e-7

# Visualization
DPI = 200
N_BINS = 100

## Imports

In [2]:
import torch
import ml_dtypes
import numpy as np
import matplotlib.pyplot as plt
from safetensors.torch import load_file, save_file
from pathlib import Path
from collections import defaultdict

## Helper Functions

In [3]:
def torch_bf16_to_numpy_bf16(tensor):
    """Convert PyTorch bfloat16 tensor to numpy array with ml_dtypes.bfloat16 dtype."""
    return tensor.cpu().view(torch.uint16).numpy().view(ml_dtypes.bfloat16)
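
The helper reinterprets the raw 16-bit patterns instead of converting values. As a numpy-only sketch of what those patterns mean (the function `bf16_bits_to_f32` is hypothetical and not part of the notebook), a bfloat16 bit pattern is the upper half of an IEEE-754 float32, so placing it in the high 16 bits of a 32-bit word recovers its value:

```python
import numpy as np

def bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray:
    """Decode raw bfloat16 bit patterns (uint16) into float32 values.

    bfloat16 keeps float32's sign bit and 8 exponent bits but only the top
    7 mantissa bits, so shifting the pattern into the high half of a 32-bit
    word and reinterpreting it as float32 reproduces the value exactly.
    """
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

# 0x3F80 is 1.0, 0xC000 is -2.0, 0x8000 is -0.0 (signed zero has its own bit pattern)
patterns = np.array([0x3F80, 0xC000, 0x8000], dtype=np.uint16)
values = bf16_bits_to_f32(patterns)
print(values)
```
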

## Load Data

In [4]:
# Load W in NATIVE bfloat16 (DO NOT CONVERT!)
W_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(W_path)["W"]

print(f"Loaded W from {W_path}")
print(f"  Shape: {W_bf16.shape}")
print(f"  Dtype: {W_bf16.dtype} ← NATIVE bfloat16 (not converted!)")

# Load cluster data
cluster_path = Path(f"../tensors/{MODEL_NAME}/1.6a_cluster_mask.safetensors")
cluster_data = load_file(cluster_path)
cluster_mask = cluster_data["cluster_mask"].to(torch.bool)
cluster_token_ids = cluster_data["cluster_token_ids"].to(torch.int64)

print(f"\nLoaded cluster from {cluster_path}")
print(f"  Cluster size: {cluster_mask.sum().item():,} tokens")

# Load cluster spherical coords
spherical_path = Path(f"../tensors/{MODEL_NAME}/1.7c_cluster_spherical.safetensors")
spherical_data = load_file(spherical_path)
r_cluster = spherical_data["r"]

print(f"\nLoaded cluster spherical coords from {spherical_path}")

Loaded W from ../tensors/Qwen3-4B-Instruct-2507/W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16 ← NATIVE bfloat16 (not converted!)

Loaded cluster from ../tensors/Qwen3-4B-Instruct-2507/1.6a_cluster_mask.safetensors
  Cluster size: 2,248 tokens

Loaded cluster spherical coords from ../tensors/Qwen3-4B-Instruct-2507/1.7c_cluster_spherical.safetensors


## Define the Core

In [5]:
print("\nDefining the core...\n")

# Core = tokens at magic radius within cluster
core_mask_within_cluster = torch.abs(r_cluster - MAGIC_RADIUS) < RADIUS_TOLERANCE
n_core = core_mask_within_cluster.sum().item()

# Get core token IDs
core_token_ids = cluster_token_ids[core_mask_within_cluster]

print(f"Core definition:")
print(f"  Radius: {MAGIC_RADIUS} ± {RADIUS_TOLERANCE}")
print(f"  Core size: {n_core:,} tokens")
print(f"  Percentage of cluster: {n_core/len(cluster_token_ids)*100:.1f}%")
print(f"\nCore token ID range:")
print(f"  Min: {core_token_ids.min().item():,}")
print(f"  Max: {core_token_ids.max().item():,}")


Defining the core...

Core definition:
  Radius: 7.553e-05 ± 1e-07
  Core size: 2,179 tokens
  Percentage of cluster: 96.9%

Core token ID range:
  Min: 124
  Max: 151,935
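
The core is a tolerance band around `MAGIC_RADIUS`, not a single value. For scale: near 7.553e-05, adjacent float32 values are 2^-37 (about 7.3e-12) apart, so a tolerance of 1e-7 covers roughly 14,000 representable float32 radii on each side. A toy version of the mask, with made-up radii for illustration only:

```python
import numpy as np

MAGIC_RADIUS = 0.00007553
RADIUS_TOLERANCE = 1e-7

# Resolution of float32 near the magic radius (gap to the next representable value)
ulp = np.spacing(np.float32(MAGIC_RADIUS))
steps_per_side = RADIUS_TOLERANCE / float(ulp)

# Hypothetical radii: the first two fall inside the band, the last two do not
r_toy = np.array([7.553e-05, 7.560e-05, 7.600e-05, 1.0e-05], dtype=np.float32)
in_core = np.abs(r_toy - MAGIC_RADIUS) < RADIUS_TOLERANCE
print(ulp, round(steps_per_side), in_core)
```
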


## Extract Core Embeddings (bfloat16)

In [6]:
print("\nExtracting core embeddings in bfloat16...\n")

# Create global mask (in full vocabulary space)
core_mask_global = torch.zeros(W_bf16.shape[0], dtype=torch.bool)
core_mask_global[core_token_ids] = True

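# Extract embeddings IN BFLOAT16, indexing by token ID so that row i of
# W_core_bf16 corresponds to core_token_ids[i] (a boolean mask would return
# rows in ascending token-ID order, which only matches if the IDs are sorted)
W_core_bf16 = W_bf16[core_token_ids]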

print(f"✓ Extracted {W_core_bf16.shape[0]:,} core embeddings")
print(f"  Dimensionality: {W_core_bf16.shape[1]:,}")
print(f"  Dtype: {W_core_bf16.dtype} ← Still bfloat16!")


Extracting core embeddings in bfloat16...

✓ Extracted 2,179 core embeddings
  Dimensionality: 2,560
  Dtype: torch.bfloat16 ← Still bfloat16!
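
Boolean-mask indexing always returns rows in ascending index order, whereas indexing with an integer array returns them in the order the array lists them. If `core_token_ids` were ever unsorted, row i of a mask-extracted core would no longer correspond to `core_token_ids[i]`, and per-row results such as black-hole labels and radii would be attached to the wrong tokens. A small numpy illustration with toy data (PyTorch indexing behaves the same way):

```python
import numpy as np

W_toy = np.arange(6 * 2).reshape(6, 2)   # row k holds [2k, 2k+1]
ids = np.array([4, 1, 3])                # deliberately unsorted token IDs

mask = np.zeros(W_toy.shape[0], dtype=bool)
mask[ids] = True

rows_by_mask = W_toy[mask]   # rows 1, 3, 4 (ascending index order)
rows_by_ids = W_toy[ids]     # rows 4, 1, 3 (same order as ids)

print(rows_by_mask[:, 0] // 2, rows_by_ids[:, 0] // 2)
```

Sorting the IDs first, or indexing by them directly, keeps rows and IDs aligned.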


## Find Black Holes via Exact Equality (bfloat16)

In [7]:
print("\nGrouping core tokens by exact vector equality (bfloat16)...\n")

# Convert to numpy bfloat16 for hashing
W_core_np_bf16 = torch_bf16_to_numpy_bf16(W_core_bf16)

# Group by vector
vector_groups = defaultdict(list)
for i in range(len(W_core_np_bf16)):
    vector = W_core_np_bf16[i]
    vector_key = tuple(vector)  # Hashable
    vector_groups[vector_key].append(i)  # Local index within core

n_unique = len(vector_groups)

print(f"Unique vectors in core: {n_unique}")
print(f"Total core tokens: {n_core}")
print(f"Redundant copies (core tokens - unique vectors): {n_core - n_unique}")

# Find black holes (groups with >1 token)
black_holes = [(vector_key, indices) for vector_key, indices in vector_groups.items()
               if len(indices) > 1]
black_holes.sort(key=lambda x: len(x[1]), reverse=True)

print(f"\nBlack holes found: {len(black_holes)}")
print(f"\nTop 20 black holes:")
print("Rank  Population  Sample Local Indices")
print("-" * 60)
for i, (vector_key, indices) in enumerate(black_holes[:20], 1):
    sample = indices[:5]
    sample_str = ", ".join(str(idx) for idx in sample)
    if len(indices) > 5:
        sample_str += ", ..."
    print(f"{i:4d}  {len(indices):10,}  {sample_str}")

print(f"\n✓ Black hole grouping complete")


Grouping core tokens by exact vector equality (bfloat16)...

Unique vectors in core: 92
Total core tokens: 2179
Redundant copies (core tokens - unique vectors): 2087

Black holes found: 13

Top 20 black holes:
Rank  Population  Sample Local Indices
------------------------------------------------------------
   1         814  14, 16, 18, 20, 21, ...
   2         704  1, 2, 3, 4, 6, ...
   3         306  0, 24, 26, 33, 38, ...
   4         228  51, 85, 136, 160, 264, ...
   5          11  25, 668, 670, 885, 886, ...
   6          10  19, 127, 276, 1121, 1147, ...
   7           6  244, 724, 1134, 1695, 1742, ...
   8           5  698, 700, 1234, 1545, 1546
   9           4  940, 1188, 1385, 1627
  10           4  1037, 1141, 1142, 1256
  11           3  5, 1224, 1448
  12           3  295, 1340, 1710
  13           2  300, 1739

✓ Black hole grouping complete
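
The loop above hashes each row as a tuple of bfloat16 values, which compares by value. A vectorized alternative, sketched here under the assumption that bitwise identity is what "exactly equal" should mean, runs `np.unique(..., axis=0)` over the raw uint16 bit patterns (for example the `view(torch.uint16)` used in `torch_bf16_to_numpy_bf16`). Comparing bits rather than values also distinguishes +0.0 from -0.0, which value equality treats as equal. The function `group_rows_bitwise` is hypothetical and the bit patterns are toy data:

```python
import numpy as np

def group_rows_bitwise(bits: np.ndarray):
    """Group identical rows of a 2-D uint16 bit-pattern array.

    Returns (populations, members) for groups with more than one row, sorted
    by population (largest first); members are local row indices.
    """
    _, inverse, counts = np.unique(bits, axis=0, return_inverse=True, return_counts=True)
    inverse = inverse.reshape(-1)  # some numpy versions return a non-flat inverse with axis=0
    groups = [np.flatnonzero(inverse == g).tolist() for g in np.flatnonzero(counts > 1)]
    groups.sort(key=len, reverse=True)
    return [len(m) for m in groups], groups

# Rows 0 and 2 are bit-identical; row 3 differs from row 0 only by a -0.0 (0x8000)
toy_bits = np.array([
    [0x3F80, 0x0000],
    [0x4000, 0x3F80],
    [0x3F80, 0x0000],
    [0x3F80, 0x8000],
], dtype=np.uint16)

pops, members = group_rows_bitwise(toy_bits)
print(pops, members)
```
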


## Verification Against 1.9a Global Results

In [8]:
print("\n" + "=" * 80)
print("VERIFICATION: COMPARING TO 1.9a GLOBAL RESULTS")
print("=" * 80)
print()

# Expected from 1.9a
expected_n_bh = 13
expected_populations = [814, 704, 306, 228, 11, 10, 6, 5, 4, 4, 3, 3, 2]
expected_total = 2100

print("Expected (from 1.9a global check):")
print(f"  Black holes: {expected_n_bh}")
print(f"  Populations: {expected_populations}")
print(f"  Total degenerate: {expected_total:,}")
print()

# Actual from this notebook
actual_n_bh = len(black_holes)
actual_populations = [len(indices) for _, indices in black_holes]
actual_total = sum(actual_populations)

print("Actual (from bfloat16 core analysis):")
print(f"  Black holes: {actual_n_bh}")
print(f"  Populations: {actual_populations}")
print(f"  Total degenerate: {actual_total:,}")
print()

# Check
if actual_n_bh == expected_n_bh:
    print(f"✓ Number of black holes MATCHES ({actual_n_bh})")
else:
    print(f"✗ Number MISMATCH (expected {expected_n_bh}, found {actual_n_bh})")

if actual_populations == expected_populations:
    print(f"✓ Populations EXACTLY MATCH")
else:
    print(f"⚠ Populations differ (might be ordering or subset)")

if actual_total == expected_total:
    print(f"✓ Total degenerate MATCHES ({actual_total:,})")
else:
    print(f"✗ Total MISMATCH (expected {expected_total:,}, found {actual_total:,})")

print()
if (actual_n_bh == expected_n_bh
        and actual_populations == expected_populations
        and actual_total == expected_total):
    print("=" * 80)
    print("VERDICT: ✓✓✓ CORRECTION SUCCESSFUL ✓✓✓")
    print("=" * 80)
    print()
    print("The bfloat16-native analysis correctly identifies all 13 black holes.")
    print("The 1.8a error has been corrected.")
else:
    print("=" * 80)
    print("VERDICT: ⚠ STILL DISCREPANCY ⚠")
    print("=" * 80)

print()
print("=" * 80)


VERIFICATION: COMPARING TO 1.9a GLOBAL RESULTS

Expected (from 1.9a global check):
  Black holes: 13
  Populations: [814, 704, 306, 228, 11, 10, 6, 5, 4, 4, 3, 3, 2]
  Total degenerate: 2,100

Actual (from bfloat16 core analysis):
  Black holes: 13
  Populations: [814, 704, 306, 228, 11, 10, 6, 5, 4, 4, 3, 3, 2]
  Total degenerate: 2,100

✓ Number of black holes MATCHES (13)
✓ Populations EXACTLY MATCH
✓ Total degenerate MATCHES (2,100)

VERDICT: ✓✓✓ CORRECTION SUCCESSFUL ✓✓✓

The bfloat16-native analysis correctly identifies all 13 black holes.
The 1.8a error has been corrected.
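
The notebook reports two "degenerate" counts that measure different things, and both follow from the populations. Tokens that sit in a black hole total 2,100; subtracting one representative per black hole leaves the 2,087 redundant copies (core size minus unique vectors); the remaining 79 unique vectors are singletons. The arithmetic, using only the numbers reported above:

```python
populations = [814, 704, 306, 228, 11, 10, 6, 5, 4, 4, 3, 3, 2]
n_core = 2179
n_unique = 92

in_black_holes = sum(populations)               # tokens sharing their vector with another token
redundant = in_black_holes - len(populations)   # extra copies beyond one per black hole
singletons = n_core - in_black_holes            # tokens whose vector is unique

assert redundant == n_core - n_unique
assert singletons + len(populations) == n_unique
print(in_black_holes, redundant, singletons)
```
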



## Compute Centered Coordinates (for geometry, not grouping)

In [9]:
print("\nComputing centered coordinates for geometric analysis...\n")

# NOW we can convert to float32 for centering (for geometry only, not grouping!)
W_core_f32 = W_core_bf16.to(torch.float32)
core_centroid_f32 = W_core_f32.mean(dim=0)
W_core_centered_f32 = W_core_f32 - core_centroid_f32

# Verify centering
mean_norm = W_core_centered_f32.mean(dim=0).norm().item()

print(f"✓ Computed centered coordinates in float32 (for geometry)")
print(f"  Mean of centered core: {mean_norm:.2e} (should be ~0)")
print()
print("NOTE: Centering is ONLY for geometric analysis (PCA, spherical coords).")
print("      Black hole grouping was done in native bfloat16 before centering.")


Computing centered coordinates for geometric analysis...

✓ Computed centered coordinates in float32 (for geometry)
  Mean of centered core: 9.86e-10 (should be ~0)

NOTE: Centering is ONLY for geometric analysis (PCA, spherical coords).
      Black hole grouping was done in native bfloat16 before centering.


## Find High-Variance Dimensions

In [10]:
print("\nFinding high-variance dimensions...\n")

# Compute variance per dimension
variances = W_core_centered_f32.var(dim=0)

# Find top 3 dimensions
top_indices = variances.argsort(descending=True)[:3]

# Assign to spherical basis: 1st-largest variance -> meridian (x axis),
# 2nd -> north (z axis), 3rd -> equinox (y axis)
meridian_idx = top_indices[0].item()
north_idx = top_indices[1].item()
equinox_idx = top_indices[2].item()

print(f"Top 3 dimensions by variance (in core):")
print(f"  Meridian (1st): dimension {meridian_idx}, variance = {variances[meridian_idx].item():.2e}")
print(f"  North (2nd):    dimension {north_idx}, variance = {variances[north_idx].item():.2e}")
print(f"  Equinox (3rd):  dimension {equinox_idx}, variance = {variances[equinox_idx].item():.2e}")

print(f"\n✓ Identified basis dimensions")


Finding high-variance dimensions...

Top 3 dimensions by variance (in core):
  Meridian (1st): dimension 322, variance = 2.33e-10
  North (2nd):    dimension 163, variance = 1.09e-10
  Equinox (3rd):  dimension 1564, variance = 5.77e-11

✓ Identified basis dimensions
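
Choosing the basis is a per-dimension variance followed by a descending sort. A numpy sketch on synthetic data (note that `torch.Tensor.var` uses the unbiased estimator by default while `np.var` defaults to `ddof=0`; the ranking is identical either way because the two differ by a constant factor):

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy "centered core": 500 points in 8 dims; dims 5, 2, 7 get the largest spread
scales = np.array([1.0, 1.0, 4.0, 1.0, 1.0, 6.0, 1.0, 3.0])
X = rng.normal(size=(500, 8)) * scales
X -= X.mean(axis=0)

variances = X.var(axis=0, ddof=1)        # ddof=1 matches torch's default
top = np.argsort(variances)[::-1][:3]    # indices of the 3 largest variances
meridian_idx, north_idx, equinox_idx = (int(i) for i in top)
print(meridian_idx, north_idx, equinox_idx)
```
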


## Project to Spherical Coordinates

In [11]:
print("\nProjecting to spherical coordinates...\n")

# Extract Cartesian coordinates
x = W_core_centered_f32[:, meridian_idx]
y = W_core_centered_f32[:, equinox_idx]
z = W_core_centered_f32[:, north_idx]

# Compute radius
r_core = torch.sqrt(x**2 + y**2 + z**2)

# Latitude
lat_rad = torch.asin(torch.clamp(z / (r_core + 1e-10), -1, 1))
lat_deg = torch.rad2deg(lat_rad)

# Longitude
lon_rad = torch.atan2(y, x)
lon_deg = torch.rad2deg(lon_rad)

print(f"✓ Computed spherical coordinates")
print(f"\nRadius statistics:")
print(f"  Min: {r_core.min().item():.2e}")
print(f"  Max: {r_core.max().item():.2e}")
print(f"  Median: {r_core.median().item():.2e}")
print(f"  Mean: {r_core.mean().item():.2e}")


Projecting to spherical coordinates...

✓ Computed spherical coordinates

Radius statistics:
  Min: 1.64e-05
  Max: 4.88e-04
  Median: 1.71e-05
  Mean: 1.72e-05
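
The projection uses the standard convention: latitude is the angle above the meridian-equinox (x-y) plane toward north (z), lat = asin(z/r), and longitude is atan2(y, x), measured from the meridian axis toward the equinox axis. Inverting the map with x = r cos(lat) cos(lon), y = r cos(lat) sin(lon), z = r sin(lat) gives a quick sanity check. The helper names below are hypothetical; the `eps` of 1e-10 added to r mirrors the notebook and slightly biases latitude when r is only about 1e-5 (most visibly near the poles), so the toy points avoid the poles:

```python
import numpy as np

def to_spherical(x, y, z, eps=1e-10):
    """Same formulas as the notebook cell, in numpy."""
    r = np.sqrt(x**2 + y**2 + z**2)
    lat_deg = np.degrees(np.arcsin(np.clip(z / (r + eps), -1.0, 1.0)))
    lon_deg = np.degrees(np.arctan2(y, x))
    return r, lat_deg, lon_deg

def to_cartesian(r, lat_deg, lon_deg):
    """Inverse map back to (x, y, z)."""
    lat, lon = np.radians(lat_deg), np.radians(lon_deg)
    return r * np.cos(lat) * np.cos(lon), r * np.cos(lat) * np.sin(lon), r * np.sin(lat)

# Toy points at the core's ~1e-5 scale, chosen away from the poles
pts = 1e-5 * np.array([[1.0, 0.5, 0.2], [-0.3, 1.2, -0.8], [0.7, -0.9, 0.4], [-1.1, -0.2, 0.6]])
x, y, z = pts.T
r, lat_deg, lon_deg = to_spherical(x, y, z)
x2, y2, z2 = to_cartesian(r, lat_deg, lon_deg)
max_err = np.max(np.abs(np.stack([x - x2, y - y2, z - z2])))
print(max_err)
```
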


## Save Corrected Core Data

In [12]:
print("\nSaving corrected core data...\n")

output_path = Path(f"../tensors/{MODEL_NAME}/1.9b_core_bfloat16.safetensors")
output_path.parent.mkdir(parents=True, exist_ok=True)

# Create black hole labels (which BH does each token belong to?)
bh_labels = torch.full((n_core,), -1, dtype=torch.int64)  # -1 = unique
for bh_id, (vector_key, indices) in enumerate(black_holes):
    for local_idx in indices:
        bh_labels[local_idx] = bh_id

save_file({
    # Core definition
    "core_mask": core_mask_global.to(torch.uint8),
    "core_token_ids": core_token_ids.to(torch.int32),
    "n_core": torch.tensor([n_core], dtype=torch.int32),
    
    # Black hole assignments
    "bh_labels": bh_labels.to(torch.int16),  # Which BH each token belongs to
    "n_black_holes": torch.tensor([len(black_holes)], dtype=torch.int32),
    
    # Core centroid (float32, for geometry)
    "core_centroid": core_centroid_f32,
    
    # Spherical coordinates (centered at core centroid, float32)
    "r": r_core,
    "lat_deg": lat_deg,
    "lon_deg": lon_deg,
    
    # Basis indices
    "north_idx": torch.tensor([north_idx], dtype=torch.int32),
    "meridian_idx": torch.tensor([meridian_idx], dtype=torch.int32),
    "equinox_idx": torch.tensor([equinox_idx], dtype=torch.int32),
}, str(output_path))

print(f"✓ Saved to {output_path}")
print()
print("Saved tensors:")
print(f"  core_mask, core_token_ids, n_core")
print(f"  bh_labels: ({n_core},) - which black hole each token belongs to")
print(f"  n_black_holes: {len(black_holes)}")
print(f"  core_centroid, r, lat_deg, lon_deg (for geometry)")
print(f"  basis indices: {meridian_idx}, {north_idx}, {equinox_idx}")


Saving corrected core data...

✓ Saved to ../tensors/Qwen3-4B-Instruct-2507/1.9b_core_bfloat16.safetensors

Saved tensors:
  core_mask, core_token_ids, n_core
  bh_labels: (2179,) - which black hole each token belongs to
  n_black_holes: 13
  core_centroid, r, lat_deg, lon_deg (for geometry)
  basis indices: 322, 163, 1564
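
Downstream analyses will usually want the token IDs in each black hole rather than local indices. Assuming row i of the core corresponds to `core_token_ids[i]`, that is one mask per label. The function `token_ids_by_black_hole` is hypothetical, and the toy arrays stand in for the saved `bh_labels` (int16) and `core_token_ids` (int32), which would be loaded with `load_file` as elsewhere in this notebook:

```python
import numpy as np

def token_ids_by_black_hole(bh_labels: np.ndarray, core_token_ids: np.ndarray) -> dict:
    """Map each black-hole id (0 = most populous) to the sorted token IDs it contains.

    Label -1 marks core tokens whose vector is unique; they are skipped.
    """
    bh_labels = np.asarray(bh_labels)
    core_token_ids = np.asarray(core_token_ids)
    assert bh_labels.shape == core_token_ids.shape, "labels and IDs must be aligned row-for-row"
    return {
        int(bh): np.sort(core_token_ids[bh_labels == bh]).tolist()
        for bh in np.unique(bh_labels)
        if bh >= 0
    }

# Toy stand-ins: 6 core tokens, two black holes and one unique vector
toy_labels = np.array([0, 1, 0, -1, 0, 1], dtype=np.int16)
toy_ids = np.array([500, 124, 9001, 77, 151935, 3000], dtype=np.int32)
groups = token_ids_by_black_hole(toy_labels, toy_ids)
print(groups)
```
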
