# 1.13a: Lattice Structure Search

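**Goal:** Find black holes (tokens that share an identical embedding vector) and orthogonal lattice neighbors (pairs of vectors that differ by one ULP in exactly one dimension).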

## Method

### Stage 1: Black Hole Detection
Use `torch.unique(dim=0)` to find vectors shared by more than one token. Each such vector is a black hole centroid.
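A toy run of the Stage 1 call, assuming a small hand-made float32 tensor in place of the bfloat16 embedding matrix (the notebook applies the same call to `W_cpu`). Rows 0 and 2 are identical, so they form one black hole centroid shared by two tokens:

```python
import torch

# Hypothetical toy "embedding matrix": rows 0 and 2 are exact duplicates
W = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [1.0, 2.0],
                  [5.0, 6.0]])

# dim=0 treats each row as one item; inverse maps every row to its unique id
W_unique, inverse, counts = torch.unique(W, dim=0, return_inverse=True, return_counts=True)

# A black hole centroid is any unique row shared by more than one token
centroid_ids = (counts > 1).nonzero(as_tuple=True)[0]
black_hole_tokens = sorted(
    t
    for c in centroid_ids.tolist()
    for t in (inverse == c).nonzero(as_tuple=True)[0].tolist()
)
print(len(W_unique), black_hole_tokens)
```

Each centroid keeps one copy and every other member is a redundant copy, so black-hole tokens = redundant copies + centroids. In the Stage 1 output, 2,100 = 2,087 + 13.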

### Stage 2: Deduplication
Keep one representative token per unique vector (the first token that maps to it). This collapses every black hole to a single point and shrinks the search space.
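The Stage 2 loop rescans `inverse_indices` once per unique vector, which is O(N × U) work. A minimal sketch of a one-pass alternative, assuming `Tensor.scatter_reduce_` with `reduce="amin"` (available in recent PyTorch releases) and a hypothetical toy tensor. It keeps the smallest token id per unique id, the same first-occurrence choice the loop makes:

```python
import torch

# Hypothetical toy data: rows 0/2 and 1/4 are duplicates
W = torch.tensor([[4.0, 5.0],
                  [0.0, 1.0],
                  [4.0, 5.0],
                  [2.0, 3.0],
                  [0.0, 1.0]])
n_vectors = W.shape[0]

W_unique, inverse = torch.unique(W, dim=0, return_inverse=True)
n_unique = W_unique.shape[0]

# Start every slot above any valid token id, then keep the minimum token id
# that maps to each unique id, in a single pass over all tokens
first_token = torch.full((n_unique,), n_vectors, dtype=torch.long)
first_token.scatter_reduce_(0, inverse, torch.arange(n_vectors), reduce="amin")

representative_tokens = first_token.tolist()
W_dedup = W[first_token]
print(representative_tokens)
```

Because the representatives are ordered by unique id, `W[first_token]` reproduces `W_unique` row for row, and the original token ids stay available in `first_token`.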

### Stage 3: Isolated Token Exclusion (ε-sphere filter)
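For each token, find neighbors within ε = ⌈√D⌉ × ULP, where ULP is the unit in the last place at the token's largest exponent. √D × ULP is the length of a one-ULP step in every dimension (the worst-case diagonal).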
Exclude tokens with no neighbors.
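A worked version of the ε choice, using the notebook's D = 2560 and a hypothetical token whose largest coordinate lies in [2^-5, 2^-4) (exponent field 122). `epsilon_multiplier` and `epsilon_for_exponent` are illustrative helpers, not part of the notebook:

```python
import math

def epsilon_multiplier(n_dims: int) -> int:
    """Stage 3 multiplier: ceil(sqrt(D)) ULPs covers a 1-ULP step in every dimension."""
    return math.ceil(math.sqrt(n_dims))

def epsilon_for_exponent(max_exponent: int, n_dims: int) -> float:
    """Hypothetical helper: epsilon for a token whose largest bf16 exponent field is max_exponent."""
    ulp = 2.0 ** (max_exponent - 134)   # same formula as compute_ulp_at_exponent
    return epsilon_multiplier(n_dims) * ulp

D = 2560
diagonal = math.sqrt(D)     # L2 length, in ULPs, of a 1-ULP step in all D coordinates
single_axis = 1.0           # L2 length, in ULPs, of a 1-ULP step in one coordinate
print(epsilon_multiplier(D), round(diagonal, 3), single_axis)

# A token whose largest coordinate lies in [2^-5, 2^-4) has exponent field 122
print(epsilon_for_exponent(122, D))
```

A single-coordinate 1-ULP step sits far inside ε, and the rounded-up multiplier also covers the full diagonal. Coordinates with smaller exponents have smaller ULPs, so taking the ULP at the largest exponent keeps ε an upper bound for the whole token.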

### Stage 4: Orthogonal Neighbor Detection
For candidate pairs within ε:
1. **Geometric filter**: L∞ = L1 (exactly one dimension differs)
2. **ULP distance check**: Verify that the differing coordinate changes by ≈ 1 ULP at its own exponent (within `ULP_TOLERANCE`)
3. **Bit-level verification**: Same sign, same exponent, mantissa differs by 1
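A pure-Python sketch of the three checks for a single pair, assuming vectors are given as lists of bfloat16 bit patterns (ints) rather than tensors. Step 4a is tested exactly by counting nonzero coordinates, which is equivalent to L∞ = L1 without a floating-point tolerance. `bf16_bits`, `bf16_value`, `split_bf16` and `orthogonal_neighbor_dim` are hypothetical names:

```python
import struct

def bf16_bits(x: float) -> int:
    """Top 16 bits of the float32 encoding; exact when x is representable in bfloat16."""
    return struct.unpack(">I", struct.pack(">f", x))[0] >> 16

def bf16_value(bits: int) -> float:
    """Inverse of bf16_bits: place the 16 bits in the top half of a float32."""
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

def split_bf16(bits: int):
    """(sign, exponent field, mantissa field) of a bfloat16 bit pattern."""
    return bits >> 15, (bits >> 7) & 0xFF, bits & 0x7F

def orthogonal_neighbor_dim(a_bits, b_bits, ulp_tolerance=0.01):
    """Return the differing dimension if the pair passes Stage 4a-4c, else None."""
    diffs = [abs(bf16_value(x) - bf16_value(y)) for x, y in zip(a_bits, b_bits)]

    # 4a: L-inf == L1 exactly when a single coordinate carries all of the difference
    nonzero = [d for d, v in enumerate(diffs) if v != 0.0]
    if len(nonzero) != 1:
        return None
    dim = nonzero[0]

    # 4b: the coordinate moved by about 1 ULP at its exponent
    sign_a, exp_a, man_a = split_bf16(a_bits[dim])
    ulp = 2.0 ** (exp_a - 134)
    if abs(diffs[dim] - ulp) >= ulp * ulp_tolerance:
        return None

    # 4c: same sign, same exponent, mantissas one step apart
    sign_b, exp_b, man_b = split_bf16(b_bits[dim])
    if sign_a == sign_b and exp_a == exp_b and abs(man_a - man_b) == 1:
        return dim
    return None

a = [bf16_bits(v) for v in (0.5, 0.25, 1.0)]
b = [bf16_bits(v) for v in (0.5, 0.25 + 2**-9, 1.0)]   # dim 1 moved by one ULP (2^-9)
print(orthogonal_neighbor_dim(a, b))
```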

## Design Goals

- **Memory-efficient**: No massive allocations, batch processing
- **Hardware accelerated**: Explicit device management for GPU/MPS
- **Scalable**: Works for Qwen (151k tokens × 2560D)
- **Correct**: Finds nothing in a random Gaussian tensor and detects the spongecrystal in Qwen

## Parameters

In [24]:
# Tensor to analyze
TENSOR_FILE = "../tensors/Qwen3-4B-Instruct-2507/W.safetensors"
TENSOR_KEY = "W"
TENSOR_INDEX = None  # None = load full tensor

# Stage 3 parameters
BATCH_SIZE = 100  # For distance computations (100 × 150k × 4 bytes ≈ 60 MB per batch)

# Stage 4 parameters
ULP_TOLERANCE = 0.01  # Allow 1% tolerance when checking ULP distances
MAX_EXAMPLES = 10  # How many example tokens to show in output

## Imports

In [33]:
import torch
import numpy as np
import math
from safetensors.torch import load_file
from pathlib import Path
from collections import Counter
from tqdm import tqdm

## Device Detection

In [26]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Data

In [27]:
# Load tensor (safetensors loads to CPU by default)
data = load_file(TENSOR_FILE)
W = data[TENSOR_KEY]

# Apply indexing if specified
if TENSOR_INDEX is not None:
    W = W[TENSOR_INDEX]

# Move to device for hardware acceleration
W = W.to(device)

n_vectors, n_dims = W.shape

print(f"✓ Loaded W from {Path(TENSOR_FILE).name}")
print(f"  Shape: {W.shape}")
print(f"  Dtype: {W.dtype}")
print(f"  Device: {W.device}")
print(f"  Memory: ~{W.element_size() * W.numel() / 1024**3:.2f} GB")
print()
print(f"Analyzing {n_vectors:,} vectors in {n_dims:,} dimensions")

✓ Loaded W from W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16
  Device: mps:0
  Memory: ~0.72 GB

Analyzing 151,936 vectors in 2,560 dimensions


## Helper Functions

In [35]:
def get_max_exponent(vector_bf16):
    """
    Get the maximum exponent across all dimensions of a bfloat16 vector.
    
    Args:
        vector_bf16: (D,) tensor of bfloat16 values (on any device)
    
    Returns:
        int: maximum exponent value
    """
    values_uint16 = vector_bf16.view(torch.int16).to(torch.int64) & 0xFFFF
    exponents = (values_uint16 >> 7) & 0xFF
    return exponents.max().item()

def decode_bfloat16_bits(value_bf16):
    """
    Decode a single bfloat16 value into its bit components.
    
    Args:
        value_bf16: scalar torch.bfloat16 value (on any device)
    
    Returns:
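        dict with the raw bits ('bits_uint16', 'bits_binary'), the integer fields
        'sign', 'exponent', 'mantissa', and their bit strings
    """
    # Reinterpret the 16 bits as int16, then cast to uint16 so the pattern reads as unsigned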
    bits_uint16 = value_bf16.view(torch.int16).cpu().numpy().astype(np.uint16).item()
    bits_binary = format(bits_uint16, '016b')
    
    sign_bit = bits_binary[0]
    exponent_bits = bits_binary[1:9]
    mantissa_bits = bits_binary[9:16]
    
    sign = int(sign_bit)
    exponent = int(exponent_bits, 2)
    mantissa = int(mantissa_bits, 2)
    
    return {
        'bits_uint16': bits_uint16,
        'bits_binary': bits_binary,
        'sign': sign,
        'exponent': exponent,
        'mantissa': mantissa,
        'sign_bit': sign_bit,
        'exponent_bits': exponent_bits,
        'mantissa_bits': mantissa_bits
    }

def compute_ulp_at_exponent(exponent):
    """
    Compute ULP (unit in last place) for bfloat16 at given exponent.
    
    bfloat16 has 7 mantissa bits, so for normal values (exponent field 1 to 254)
    ULP = 2^(exponent - 127 - 7) = 2^(exponent - 134)
    """
    return 2.0 ** (exponent - 134)

print("✓ Helper functions defined")

✓ Helper functions defined
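A stdlib-only check of the bit layout the helpers assume (1 sign bit, 8 exponent bits, 7 mantissa bits) and of the ULP formula: for a normal value, the gap to the next bfloat16 equals 2^(exponent - 134). It assumes inputs are already exactly representable in bfloat16, since it truncates the float32 encoding; `decode_bf16`, `ulp_at_exponent` and `next_bf16_up` are illustrative names:

```python
import struct

def decode_bf16(x: float) -> dict:
    """Split a bfloat16-representable float into sign, exponent and mantissa fields."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0] >> 16   # top 16 bits of float32
    return {"bits": bits, "sign": bits >> 15,
            "exponent": (bits >> 7) & 0xFF, "mantissa": bits & 0x7F}

def ulp_at_exponent(exponent: int) -> float:
    """Same formula as compute_ulp_at_exponent: 2^(exponent - 127 - 7)."""
    return 2.0 ** (exponent - 134)

def next_bf16_up(x: float) -> float:
    """Next larger bfloat16 after a positive finite x: add 1 to its 16-bit pattern."""
    bits = decode_bf16(x)["bits"] + 1
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

x = 0.0625                     # 2^-4
fields = decode_bf16(x)
gap = next_bf16_up(x) - x
print(fields, ulp_at_exponent(fields["exponent"]), gap)
```

For subnormals (exponent field 0) the spacing is 2^-133, so `compute_ulp_at_exponent(0)` underestimates it by a factor of 2; this only matters if the tensor contains subnormal values.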


## Stage 1: Black Hole Detection

In [29]:
print("\n" + "=" * 80)
print("STAGE 1: BLACK HOLE DETECTION")
print("=" * 80)
print()

# torch.unique not implemented on MPS in Torch 2.8, use CPU
print("Finding unique vectors...")
W_cpu = W.cpu()
W_unique, inverse_indices, counts = torch.unique(W_cpu, dim=0, return_inverse=True, return_counts=True)

n_unique = len(W_unique)
n_duplicates = n_vectors - n_unique

print(f"  ✓ Found {n_unique:,} unique vectors")
print(f"  ✓ {n_duplicates:,} vectors are duplicates")
print()

# Count tokens participating in black holes
duplicate_mask = counts > 1
n_black_hole_centroids = duplicate_mask.sum().item()

black_hole_tokens = []
if n_black_hole_centroids > 0:
    print(f"Found {n_black_hole_centroids} black hole centroids")
    print("Counting tokens...")
    
    black_hole_unique_ids = duplicate_mask.nonzero(as_tuple=True)[0]
    
    for unique_id in tqdm(black_hole_unique_ids, desc="Processing"):
        # Find all tokens that map to this unique vector
        tokens = (inverse_indices == unique_id).nonzero(as_tuple=True)[0].tolist()
        black_hole_tokens.extend(tokens)
    
    print()

n_black_hole_tokens = len(black_hole_tokens)

print(f"Black hole tokens: {n_black_hole_tokens:,} ({100 * n_black_hole_tokens / n_vectors:.2f}%)")
if n_black_hole_tokens > 0:
    print(f"  Organized into {n_black_hole_centroids} centroids")
print()


STAGE 1: BLACK HOLE DETECTION

Finding unique vectors...
  ✓ Found 149,849 unique vectors
  ✓ 2,087 vectors are duplicates

Found 13 black hole centroids
Counting tokens...


Processing: 100%|██████████| 13/13 [00:00<00:00, 5757.15it/s]


Black hole tokens: 2,100 (1.38%)
  Organized into 13 centroids

## Stage 2: Deduplication

In [30]:
print("=" * 80)
print("STAGE 2: DEDUPLICATION")
print("=" * 80)
print()

# Keep one representative per unique vector
print("Creating deduplicated token set...")
representative_tokens = []
for unique_id in range(n_unique):
    # Get first token that maps to this unique vector
    token_id = (inverse_indices == unique_id).nonzero(as_tuple=True)[0][0].item()
    representative_tokens.append(token_id)

# Index W (which is on device) and keep on device
W_dedup = W[representative_tokens]
n_dedup = len(representative_tokens)

print(f"  ✓ Deduplicated: {n_vectors:,} → {n_dedup:,} tokens")
print(f"  ✓ Reduced search space by {n_vectors - n_dedup:,} tokens ({100 * (n_vectors - n_dedup) / n_vectors:.2f}%)")
print(f"  ✓ W_dedup on device: {W_dedup.device}")
print()

STAGE 2: DEDUPLICATION

Creating deduplicated token set...
  ✓ Deduplicated: 151,936 → 149,849 tokens
  ✓ Reduced search space by 2,087 tokens (1.37%)
  ✓ W_dedup on device: mps:0



## Stage 3: Isolated Token Exclusion (ε-sphere filter)

In [31]:
print("=" * 80)
print("STAGE 3: ISOLATED TOKEN EXCLUSION (ε-sphere filter)")
print("=" * 80)
print()

# Compute epsilon multiplier (worst-case diagonal distance)
EPSILON_MULTIPLIER = math.ceil(math.sqrt(n_dims))
print(f"Using ε = {EPSILON_MULTIPLIER} × ULP (worst-case {n_dims}D diagonal)")
print()

# Estimate memory usage
batch_memory_gb = (BATCH_SIZE * n_dedup * 4) / 1024**3  # float32
print(f"Memory estimate: {batch_memory_gb:.2f} GB per batch (batch size = {BATCH_SIZE})")

if batch_memory_gb > 0.5:
    print(f"  ⚠️  Large memory usage! Consider reducing BATCH_SIZE if this crashes.")

print()
print("Finding candidate pairs within ε...")

candidate_pairs = []
isolated_tokens = set(range(n_dedup))

# Pre-convert W_dedup to float32 once (stays on device)
W_dedup_float = W_dedup.float()

for i in tqdm(range(0, n_dedup, BATCH_SIZE), desc="Processing batches"):
    batch_end = min(i + BATCH_SIZE, n_dedup)
    batch = W_dedup_float[i:batch_end]  # Already float32, on device
    
    # Compute distances to all deduplicated tokens (stays on device)
    distances = torch.cdist(batch, W_dedup_float)  # (B, N) on device
    
    for b in range(batch.shape[0]):
        token_i = i + b
        
        # Compute epsilon for this token
        max_exp = get_max_exponent(W_dedup[token_i])
        ulp = compute_ulp_at_exponent(max_exp)
        epsilon = ulp * EPSILON_MULTIPLIER
        
        # Find neighbors within epsilon
        neighbors = (distances[b] < epsilon).nonzero(as_tuple=True)[0]
        neighbors = neighbors[neighbors != token_i]  # Exclude self
        
        if len(neighbors) > 0:
            # This token has neighbors - not isolated
            isolated_tokens.discard(token_i)
            
            # Add pairs (only j > i to avoid duplicates)
            for j in neighbors.tolist():
                if j > token_i:
                    candidate_pairs.append((token_i, j))

n_isolated = len(isolated_tokens)
n_candidates = len(candidate_pairs)

print()
print(f"  ✓ Found {n_candidates:,} candidate pairs within ε")
print(f"  ✓ Excluded {n_isolated:,} isolated tokens ({100 * n_isolated / n_dedup:.2f}%)")
print()

STAGE 3: ISOLATED TOKEN EXCLUSION (ε-sphere filter)

Using ε = 51 × ULP (worst-case 2560D diagonal)

Memory estimate: 0.06 GB per batch (batch size = 100)

Finding candidate pairs within ε...


Processing batches: 100%|██████████| 1499/1499 [04:01<00:00,  6.20it/s]


  ✓ Found 11,212 candidate pairs within ε
  ✓ Excluded 149,698 isolated tokens (99.90%)
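The Stage 3 loop calls `get_max_exponent` once per token and copies a scalar back to the host with `.item()` each time. A sketch of a vectorized variant, assuming the same ε rule: exponents are computed for all rows at once, and each batch row is compared against its own ε by broadcasting. `max_exponents` and `find_candidate_pairs` are hypothetical names; the returned pairs follow the notebook's j > i convention:

```python
import math
import torch

def max_exponents(W_bf16: torch.Tensor) -> torch.Tensor:
    """Largest bf16 exponent field in each row, computed for all rows at once."""
    bits = W_bf16.view(torch.int16).to(torch.int32)
    # >> 7 drops the 7 mantissa bits; & 0xFF keeps the 8 exponent bits and discards the sign
    return ((bits >> 7) & 0xFF).amax(dim=1)

def find_candidate_pairs(W_bf16: torch.Tensor, batch_size: int = 100):
    n, d = W_bf16.shape
    multiplier = math.ceil(math.sqrt(d))
    # Per-token epsilon = ceil(sqrt(D)) * 2^(max_exponent - 134), as in Stage 3
    eps = multiplier * torch.pow(2.0, (max_exponents(W_bf16) - 134).float())
    W_f = W_bf16.float()
    pairs = []
    has_neighbor = torch.zeros(n, dtype=torch.bool, device=W_bf16.device)
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        dist = torch.cdist(W_f[start:end], W_f)                # (B, N)
        rows, cols = (dist < eps[start:end, None]).nonzero(as_tuple=True)
        rows = rows + start                                    # batch row -> token id
        not_self = cols != rows
        has_neighbor[rows[not_self]] = True                    # token is not isolated
        keep = cols > rows                                     # each unordered pair once
        pairs.extend(zip(rows[keep].tolist(), cols[keep].tolist()))
    return pairs, has_neighbor

# Toy check: row 1 is row 0 with one coordinate moved by one ULP (2^-9)
W = torch.tensor([[0.5, 0.25, 0.125, 1.0],
                  [0.5, 0.251953125, 0.125, 1.0],
                  [-3.0, 2.0, 5.0, 7.0]], dtype=torch.bfloat16)
pairs, has_neighbor = find_candidate_pairs(W, batch_size=2)
print(pairs, has_neighbor.tolist())
```

Memory per batch is the same `BATCH_SIZE × N` distance matrix as in the notebook; `~has_neighbor` gives the isolated tokens.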

## Stage 4: Orthogonal Neighbor Detection

Three-step verification for candidate pairs:
1. Geometric filter (L∞ = L1)
2. ULP distance check
3. Bit-level verification

In [36]:
print("=" * 80)
print("STAGE 4: ORTHOGONAL NEIGHBOR DETECTION")
print("=" * 80)
print()

if n_candidates == 0:
    print("No candidate pairs to check. Skipping Stage 4.\n")
    orthogonal_pairs = []
else:
    # Step 4a: Geometric filter (L∞ = L1)
    print("Step 4a: Geometric filter (L∞ = L1)...")
    orthogonal_candidates = []
    
    for i, j in tqdm(candidate_pairs, desc="  Checking"):
        # Operations on device
        diff = (W_dedup_float[i] - W_dedup_float[j]).abs()
        l_inf = diff.max().item()
        l_1 = diff.sum().item()
        
        if abs(l_inf - l_1) < 1e-7:  # L∞ ≈ L1 (absolute tolerance): the difference sits in one dimension
            orthogonal_candidates.append((i, j))
    
    print(f"  ✓ {len(orthogonal_candidates):,} pairs pass L∞ = L1 test")
    print()
    
    # Step 4b: ULP distance check
    if len(orthogonal_candidates) == 0:
        print("No candidates passed geometric filter. Skipping steps 4b-4c.\n")
        orthogonal_pairs = []
    else:
        print("Step 4b: ULP distance check...")
        ulp_candidates = []
        
        for i, j in tqdm(orthogonal_candidates, desc="  Checking"):
            diff = (W_dedup_float[i] - W_dedup_float[j]).abs()
            dim = diff.argmax().item()  # The one dimension that differs
            actual_dist = diff[dim].item()
            
            # Decode exponent at this dimension
            val_i = W_dedup[i, dim]
            decoded = decode_bfloat16_bits(val_i)
            ulp = compute_ulp_at_exponent(decoded['exponent'])
            
            if abs(actual_dist - ulp) < ulp * ULP_TOLERANCE:
                ulp_candidates.append((i, j, dim))
        
        print(f"  ✓ {len(ulp_candidates):,} pairs pass ULP distance test")
        print()
        
        # Step 4c: Bit-level verification
        if len(ulp_candidates) == 0:
            print("No candidates passed ULP check. Skipping step 4c.\n")
            orthogonal_pairs = []
        else:
            print("Step 4c: Bit-level verification...")
            orthogonal_pairs = []
            
            for i, j, dim in tqdm(ulp_candidates, desc="  Checking"):
                val_i = W_dedup[i, dim]
                val_j = W_dedup[j, dim]
                
                decoded_i = decode_bfloat16_bits(val_i)
                decoded_j = decode_bfloat16_bits(val_j)
                
                same_sign = decoded_i['sign'] == decoded_j['sign']
                same_exp = decoded_i['exponent'] == decoded_j['exponent']
                mant_diff = abs(decoded_i['mantissa'] - decoded_j['mantissa'])
                
                if same_sign and same_exp and mant_diff == 1:
                    orthogonal_pairs.append((i, j, dim))
            
            print(f"  ✓ {len(orthogonal_pairs):,} pairs verified as orthogonal neighbors")
            print()

# Extract unique tokens that participate in orthogonal structure
orthogonal_tokens = set()
for i, j, dim in orthogonal_pairs:
    orthogonal_tokens.add(i)
    orthogonal_tokens.add(j)

n_orthogonal_tokens = len(orthogonal_tokens)

STAGE 4: ORTHOGONAL NEIGHBOR DETECTION

Step 4a: Geometric filter (L∞ = L1)...


  Checking: 100%|██████████| 11212/11212 [00:03<00:00, 3442.61it/s]


  ✓ 240 pairs pass L∞ = L1 test

Step 4b: ULP distance check...


  Checking: 100%|██████████| 240/240 [00:00<00:00, 2500.93it/s]


  ✓ 159 pairs pass ULP distance test

Step 4c: Bit-level verification...


  Checking: 100%|██████████| 159/159 [00:00<00:00, 4337.64it/s]

  ✓ 159 pairs verified as orthogonal neighbors
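Steps 4a to 4c can also run on all candidate pairs at once by comparing raw bit patterns. A sketch, assuming `pairs` is the `candidate_pairs` list and the weights are bfloat16; `verify_orthogonal_pairs` is a hypothetical name. For normal values, equal sign and exponent bits plus a mantissa step of 1 imply a gap of exactly 1 ULP, so step 4b needs no separate test here. Two caveats: it compares bits, so +0.0 and -0.0 count as different coordinates (the notebook's float difference treats them as equal), and it materializes two P × D integer tensors, so large candidate sets may need batching.

```python
import torch

def verify_orthogonal_pairs(W_bf16: torch.Tensor, pairs):
    """Vectorized Stage 4 on bit patterns. Returns a list of (i, j, dim)."""
    if not pairs:
        return []
    idx = torch.tensor(pairs, dtype=torch.long, device=W_bf16.device)   # (P, 2)
    i, j = idx[:, 0], idx[:, 1]
    bits = W_bf16.view(torch.int16).to(torch.int32) & 0xFFFF             # unsigned 16-bit patterns
    differs = bits[i] != bits[j]                                          # (P, D) bool

    # 4a: exactly one coordinate differs (exact form of L-inf == L1)
    one_dim = differs.sum(dim=1) == 1
    dim = differs.to(torch.int32).argmax(dim=1)

    # 4c: bits 7-15 (exponent + sign) equal and mantissas one step apart
    a = bits[i, dim]
    b = bits[j, dim]
    same_sign_exp = (a >> 7) == (b >> 7)
    mantissa_step = ((a & 0x7F) - (b & 0x7F)).abs() == 1

    ok = one_dim & same_sign_exp & mantissa_step
    return list(zip(i[ok].tolist(), j[ok].tolist(), dim[ok].tolist()))

W = torch.tensor([[0.5, 0.25, 1.0],
                  [0.5, 0.251953125, 1.0],   # dim 1: mantissa 1
                  [0.5, 0.25390625, 1.0],    # dim 1: mantissa 2
                  [0.50390625, 0.251953125, 1.0]], dtype=torch.bfloat16)
print(verify_orthogonal_pairs(W, [(0, 1), (0, 2), (0, 3), (1, 2)]))
```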

## Summary

In [37]:
print("=" * 80)
print("SUMMARY")
print("=" * 80)
print()

print(f"Input: {n_vectors:,} vectors × {n_dims:,} dimensions")
print()

print(f"Black holes:")
print(f"  {n_black_hole_tokens:,} tokens ({100 * n_black_hole_tokens / n_vectors:.2f}%)")
if n_black_hole_tokens > 0:
    print(f"  {n_black_hole_centroids} centroids")
print()

print(f"Orthogonal neighbors:")
print(f"  {len(orthogonal_pairs):,} pairs found")
print(f"  {n_orthogonal_tokens:,} unique tokens participate ({100 * n_orthogonal_tokens / n_dedup:.2f}% of deduplicated)")
print()

if len(orthogonal_pairs) > 0:
    # Dimension distribution
    dim_counts = Counter([dim for _, _, dim in orthogonal_pairs])
    print(f"Most common dimensions (top 10):")
    for dim, count in dim_counts.most_common(10):
        print(f"  Dimension {dim:4d}: {count:4d} pairs")
    print()

print("=" * 80)

SUMMARY

Input: 151,936 vectors × 2,560 dimensions

Black holes:
  2,100 tokens (1.38%)
  13 centroids

Orthogonal neighbors:
  159 pairs found
  39 unique tokens participate (0.03% of deduplicated)

Most common dimensions (top 10):
  Dimension 1564:   67 pairs
  Dimension 1435:   27 pairs
  Dimension 1718:   21 pairs
  Dimension  216:   13 pairs
  Dimension 1362:   13 pairs
  Dimension 1008:   10 pairs
  Dimension 1382:    8 pairs

