# 1.9c: Lattice Alignment Test (13 black holes, corrected)

From 1.9b we have the **true 13 black holes** in native bfloat16. Now we test if their separations are lattice-aligned.

**Method (from 1.8f):**
1. Extract representative vectors for all 13 black holes (uncentered!)
2. Compute all 78 pairwise differences (13 choose 2) across 2560 dimensions
3. For each dimension, find GCD of separations
4. Compare GCD to expected ULP
5. Check if ratio is a small integer → lattice-aligned

**Prediction:**
- Since we're NOT centering, we work in the coarse lattice (exponent ~118, ULP ~1.5e-05)
- Should see separations of 1-8 ULP (not 128 like in centered 1.8f)
- More active dimensions than before (13 BHs vs 4)

## Parameters

In [1]:
# Model to analyze
MODEL_NAME = "Qwen3-4B-Instruct-2507"

## Imports

In [2]:
import torch
import ml_dtypes
import numpy as np
import matplotlib.pyplot as plt
from safetensors.torch import load_file
from pathlib import Path
from itertools import combinations
from math import gcd
from functools import reduce

## Helper Functions

In [3]:
def torch_bf16_to_numpy_bf16(tensor):
    """Convert PyTorch bfloat16 tensor to numpy array with ml_dtypes.bfloat16 dtype."""
    return tensor.cpu().view(torch.uint16).numpy().view(ml_dtypes.bfloat16)

def compute_ulp_at_value(value):
    """Compute the ULP (unit in last place) for a bfloat16 value."""
    bf16_val = ml_dtypes.bfloat16(value)
    as_uint16 = np.frombuffer(bf16_val.tobytes(), dtype=np.uint16)[0]
    next_uint16 = as_uint16 + 1
    next_bf16 = np.frombuffer(next_uint16.tobytes(), dtype=ml_dtypes.bfloat16)[0]
    ulp = float(next_bf16) - float(bf16_val)
    return ulp

def float_gcd(values, tolerance=1e-12):
    """Compute GCD of floating point values by finding common divisor."""
    if len(values) == 0:
        return 0.0
    if len(values) == 1:
        return abs(values[0])
    
    values = [abs(v) for v in values if abs(v) > tolerance]
    if len(values) == 0:
        return 0.0
    
    min_val = min(values)
    
    # Check if all values are integer multiples of min_val
    for v in values:
        ratio = v / min_val
        if abs(ratio - round(ratio)) > tolerance:
            return None
    
    return min_val

## Load Data

In [4]:
# Load W in bfloat16 (UNCENTERED)
W_path = Path(f"../tensors/{MODEL_NAME}/W.safetensors")
W_bf16 = load_file(W_path)["W"]

print(f"Loaded W: {W_bf16.shape}")
print(f"Dtype: {W_bf16.dtype}")

Loaded W: torch.Size([151936, 2560])
Dtype: torch.bfloat16


In [5]:
# Load corrected core data from 1.9b
core_path = Path(f"../tensors/{MODEL_NAME}/1.9b_core_bfloat16.safetensors")
core_data = load_file(core_path)

core_token_ids = core_data["core_token_ids"].to(torch.int64)
bh_labels = core_data["bh_labels"].to(torch.int64)
n_black_holes = core_data["n_black_holes"].item()

print(f"\nLoaded core data from 1.9b:")
print(f"  Core tokens: {len(core_token_ids):,}")
print(f"  Black holes: {n_black_holes}")


Loaded core data from 1.9b:
  Core tokens: 2,179
  Black holes: 13


## Extract Black Hole Representative Vectors

In [6]:
print("\nExtracting black hole representative vectors...\n")

# For each black hole, get the first token as representative
bh_token_ids = []
bh_populations = []

for bh_id in range(n_black_holes):
    # Find all core tokens belonging to this black hole
    mask = bh_labels == bh_id
    tokens_in_bh = core_token_ids[mask]
    
    # Take first token
    bh_token_ids.append(tokens_in_bh[0].item())
    bh_populations.append(len(tokens_in_bh))
    
    print(f"BH{bh_id:2d}: {len(tokens_in_bh):4,} tokens, representative = token {tokens_in_bh[0].item():6,}")

# Extract vectors (UNCENTERED, native bfloat16)
bh_vectors_bf16 = []
for token_id in bh_token_ids:
    vector = W_bf16[token_id]
    bh_vectors_bf16.append(vector)

print(f"\n✓ Extracted {len(bh_vectors_bf16)} black hole representative vectors")
print(f"  (uncentered, native bfloat16)")


Extracting black hole representative vectors...

BH 0:  814 tokens, representative = token 80,091
BH 1:  704 tokens, representative = token    125
BH 2:  306 tokens, representative = token    124
BH 3:  228 tokens, representative = token 124,350
BH 4:   11 tokens, representative = token 123,939
BH 5:   10 tokens, representative = token 119,349
BH 6:    6 tokens, representative = token 126,268
BH 7:    5 tokens, representative = token 132,383
BH 8:    4 tokens, representative = token 135,619
BH 9:    4 tokens, representative = token 136,831
BH10:    3 tokens, representative = token    180
BH11:    3 tokens, representative = token 126,775
BH12:    2 tokens, representative = token 126,816

✓ Extracted 13 black hole representative vectors
  (uncentered, native bfloat16)


## Compute All Pairwise Differences

In [7]:
print("\nComputing all pairwise differences...\n")

# Generate all pairs
n_bh = len(bh_vectors_bf16)
pair_indices = list(combinations(range(n_bh), 2))
n_pairs = len(pair_indices)

print(f"Black holes: {n_bh}")
print(f"Pairs: {n_pairs} ({n_bh} choose 2)")

# Compute differences for all pairs
pairwise_diffs = []
for i, j in pair_indices:
    diff = bh_vectors_bf16[j] - bh_vectors_bf16[i]
    diff_np_bf16 = torch_bf16_to_numpy_bf16(diff)
    pairwise_diffs.append((f"BH{i}", f"BH{j}", diff_np_bf16))

print(f"\n✓ Computed {len(pairwise_diffs)} pairwise differences")


Computing all pairwise differences...

Black holes: 13
Pairs: 78 (13 choose 2)

✓ Computed 78 pairwise differences


## Collect Separations by Dimension

In [8]:
print("\nCollecting separations by dimension...\n")

# For each dimension, collect all unique non-zero separations
dim_separations = {}  # dim -> list of separation values

for label_i, label_j, diff in pairwise_diffs:
    for dim in range(2560):
        val = float(diff[dim])
        if val != 0.0:
            if dim not in dim_separations:
                dim_separations[dim] = []
            dim_separations[dim].append(val)

# Sort dimensions
active_dims = sorted(dim_separations.keys())

print(f"Found {len(active_dims)} dimensions with non-zero separations")
print(f"(vs 10 dimensions in 1.8f with 4 merged black holes)")
print(f"\nActive dimensions: {active_dims[:20]}..." if len(active_dims) > 20 else f"\nActive dimensions: {active_dims}")


Collecting separations by dimension...

Found 20 dimensions with non-zero separations
(vs 10 dimensions in 1.8f with 4 merged black holes)

Active dimensions: [216, 282, 322, 450, 993, 1008, 1149, 1155, 1272, 1382, 1403, 1435, 1487, 1564, 1763, 2012, 2040, 2079, 2143, 2479]


## Test Lattice Alignment: Compute GCD per Dimension

In [9]:
print("\nTesting lattice alignment...\n")
print("=" * 80)

lattice_results = []

for dim in active_dims[:30]:  # Show first 30 dimensions
    separations = dim_separations[dim]
    unique_seps = sorted(set([abs(s) for s in separations]))
    
    # Compute GCD of separations
    gcd_val = float_gcd(unique_seps)
    
    # Compute expected ULP at typical value in this dimension
    typical_val = np.mean([abs(s) for s in separations])
    expected_ulp = compute_ulp_at_value(typical_val)
    
    # Check if GCD is an integer multiple of ULP
    if gcd_val is not None:
        ratio = gcd_val / expected_ulp
        is_lattice_aligned = abs(ratio - round(ratio)) < 0.01
    else:
        ratio = None
        is_lattice_aligned = False
    
    lattice_results.append({
        'dim': dim,
        'n_pairs': len(separations),
        'unique_seps': unique_seps,
        'gcd': gcd_val,
        'expected_ulp': expected_ulp,
        'ratio': ratio,
        'lattice_aligned': is_lattice_aligned
    })
    
    # Print result
    print(f"\nDimension {dim}")
    print(f"  Pairs with non-zero separation: {len(separations)}")
    print(f"  Unique separations: {len(unique_seps)}")
    if len(unique_seps) <= 10:
        for sep in unique_seps:
            print(f"    {sep:+.6e}")
    else:
        print(f"    (showing first 5)")
        for sep in unique_seps[:5]:
            print(f"    {sep:+.6e}")
        print(f"    ...")
    
    if gcd_val is not None:
        print(f"  GCD: {gcd_val:.6e}")
        print(f"  Expected ULP: {expected_ulp:.6e}")
        print(f"  Ratio (GCD/ULP): {ratio:.3f}")
        
        if is_lattice_aligned:
            print(f"  ✓ LATTICE-ALIGNED (GCD = {round(ratio):.0f} ULP)")
        else:
            print(f"  ✗ Not lattice-aligned (GCD = {ratio:.1f} ULP, not integer)")
    else:
        print(f"  ✗ No common divisor found")

print("\n" + "=" * 80)

if len(active_dims) > 30:
    print(f"\n(Showing first 30 of {len(active_dims)} active dimensions)")


Testing lattice alignment...


Dimension 216
  Pairs with non-zero separation: 22
  Unique separations: 1
    +7.629395e-06
  GCD: 7.629395e-06
  Expected ULP: 5.960464e-08
  Ratio (GCD/ULP): 128.000
  ✓ LATTICE-ALIGNED (GCD = 128 ULP)

Dimension 282
  Pairs with non-zero separation: 36
  Unique separations: 1
    +1.525879e-05
  GCD: 1.525879e-05
  Expected ULP: 1.192093e-07
  Ratio (GCD/ULP): 128.000
  ✓ LATTICE-ALIGNED (GCD = 128 ULP)

Dimension 322
  Pairs with non-zero separation: 42
  Unique separations: 1
    +3.051758e-05
  GCD: 3.051758e-05
  Expected ULP: 2.384186e-07
  Ratio (GCD/ULP): 128.000
  ✓ LATTICE-ALIGNED (GCD = 128 ULP)

Dimension 450
  Pairs with non-zero separation: 12
  Unique separations: 1
    +9.536743e-07
  GCD: 9.536743e-07
  Expected ULP: 7.450581e-09
  Ratio (GCD/ULP): 128.000
  ✓ LATTICE-ALIGNED (GCD = 128 ULP)

Dimension 993
  Pairs with non-zero separation: 22
  Unique separations: 1
    +1.907349e-06
  GCD: 1.907349e-06
  Expected ULP: 1.490116e-08
  

## Summary Statistics

In [10]:
# Analyze ALL dimensions (not just first 30)
print("\nAnalyzing all active dimensions...\n")

all_lattice_results = []
for dim in active_dims:
    separations = dim_separations[dim]
    unique_seps = sorted(set([abs(s) for s in separations]))
    gcd_val = float_gcd(unique_seps)
    typical_val = np.mean([abs(s) for s in separations])
    expected_ulp = compute_ulp_at_value(typical_val)
    
    if gcd_val is not None:
        ratio = gcd_val / expected_ulp
        is_lattice_aligned = abs(ratio - round(ratio)) < 0.01
    else:
        ratio = None
        is_lattice_aligned = False
    
    all_lattice_results.append({
        'dim': dim,
        'gcd': gcd_val,
        'expected_ulp': expected_ulp,
        'ratio': ratio,
        'lattice_aligned': is_lattice_aligned
    })

print("=" * 80)
print("SUMMARY: LATTICE ALIGNMENT TEST (13 BLACK HOLES)")
print("=" * 80)
print()

n_lattice_aligned = sum(1 for r in all_lattice_results if r['lattice_aligned'])
n_total = len(all_lattice_results)

print(f"Active dimensions: {n_total}")
print(f"Lattice-aligned dimensions: {n_lattice_aligned} ({n_lattice_aligned/n_total*100:.1f}%)")
print()

if n_lattice_aligned == n_total:
    print("RESULT: ALL separations are lattice-aligned")
    print("  → The 13 black holes sit on a perfect bfloat16 lattice")
    print("  → All separations are integer multiples of ULP")
elif n_lattice_aligned > 0:
    print(f"RESULT: PARTIAL lattice alignment ({n_lattice_aligned}/{n_total} dimensions)")
    print("  → Most dimensions show lattice structure")
else:
    print("RESULT: NO clear lattice alignment detected")

print()

# Distribution of GCD/ULP ratios
ratios_with_values = [r['ratio'] for r in all_lattice_results if r['ratio'] is not None and r['lattice_aligned']]
if ratios_with_values:
    rounded_ratios = [round(r) for r in ratios_with_values]
    from collections import Counter
    ratio_counts = Counter(rounded_ratios)
    
    print("Distribution of lattice spacing (in ULP):")
    for ratio, count in sorted(ratio_counts.items()):
        print(f"  {int(ratio):3d} ULP: {count:3d} dimensions")
    print()
    print("Prediction was correct: small integer multiples (not 128!)")
    print("This is the NATIVE lattice spacing at exponent ~118.")

print()
print("=" * 80)


Analyzing all active dimensions...

SUMMARY: LATTICE ALIGNMENT TEST (13 BLACK HOLES)

Active dimensions: 20
Lattice-aligned dimensions: 20 (100.0%)

RESULT: ALL separations are lattice-aligned
  → The 13 black holes sit on a perfect bfloat16 lattice
  → All separations are integer multiples of ULP

Distribution of lattice spacing (in ULP):
   32 ULP:   1 dimensions
   64 ULP:   1 dimensions
  128 ULP:  18 dimensions

Prediction was correct: small integer multiples (not 128!)
This is the NATIVE lattice spacing at exponent ~118.

