# 1.13a: Lattice Structure Search

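**Goal:** Find black holes (groups of tokens whose embedding vectors are bit-identical) and lattice neighbors: pairs of vectors that sit one bfloat16 ULP apart in exactly one dimension (orthogonally adjacent) or in several dimensions (diagonally adjacent).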

## Method

### Stage 1: Black Hole Detection
Use `torch.unique(dim=0)` to group bit-identical vectors. Any unique vector shared by more than one token is a black hole centroid.

### Stage 2: Deduplication
Keep one representative token per unique vector, so each black hole collapses to a single row and the search space shrinks.

### Stage 3: Isolated Token Exclusion (ε-sphere filter)
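For each token, find neighbors within ε = ULP × ⌈√D⌉, where ULP is taken at the token's largest exponent and D is the number of dimensions. This covers the worst-case diagonal: a one-ULP step in every dimension has Euclidean length ULP × √D.
Exclude tokens with no neighbors within ε.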

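As a worked example of the ε formula, the sketch below computes the multiplier for D = 2560 and the resulting ε at an illustrative biased exponent of 120. The function name `epsilon_for` and the example exponent are assumptions for illustration; the ULP formula 2^(exponent − 134) is the one used by the helper functions in this notebook.

```python
import math

def epsilon_for(max_exponent: int, n_dims: int) -> float:
    """ε = ULP × ceil(√D), with ULP taken at the given biased bfloat16 exponent."""
    ulp = 2.0 ** (max_exponent - 134)  # 7 mantissa bits, exponent bias 127
    return ulp * math.ceil(math.sqrt(n_dims))

multiplier = math.ceil(math.sqrt(2560))          # √2560 ≈ 50.6, rounded up
eps = epsilon_for(max_exponent=120, n_dims=2560)  # hypothetical exponent

print(f"multiplier = {multiplier}")  # 51
print(f"epsilon    = {eps}")         # 51 × 2^-14
```

Rounding √D up makes ε slightly larger than the exact diagonal, so the filter errs on the side of keeping candidates.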
### Stage 4: Orthogonal Neighbor Detection
For candidate pairs within ε:
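1. **Geometric filter**: L∞ = L1 holds when exactly one dimension differs. The implementation below keeps all candidates instead of filtering here, so that diagonal neighbors can be classified as well.
2. **ULP distance check**: Verify that the difference in each differing dimension is ≈ 1 ULP.
3. **Bit-level verification**: Same sign, same exponent, mantissa differs by 1.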

## Design Goals

- **Memory-efficient**: No massive allocations, batch processing
- **Hardware accelerated**: Explicit device management for GPU/MPS
- **Scalable**: Handles the Qwen3-4B embedding matrix (151,936 tokens × 2,560 dimensions)
- **Correct**: Finds no lattice structure in random Gaussian data, detects the spongecrystal in Qwen

## Parameters

In [38]:
# Tensor to analyze
TENSOR_FILE = "../tensors/Qwen3-4B-Instruct-2507/W.safetensors"
TENSOR_KEY = "W"
TENSOR_INDEX = None  # None = load full tensor

# Stage 3 parameters
BATCH_SIZE = 100  # For distance computations (100 × 150k × 4 bytes ≈ 60 MB per batch)

# Stage 4 parameters
ULP_TOLERANCE = 0.01  # Allow 1% tolerance when checking ULP distances
MAX_EXAMPLES = 10  # How many example tokens to show in output

## Imports

In [39]:
import torch
import numpy as np
import math
import ml_dtypes
from safetensors.torch import load_file
from pathlib import Path
from collections import Counter
from tqdm import tqdm

## Device Detection

In [40]:
# Detect available device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Data

In [41]:
# Load tensor (safetensors loads to CPU by default)
data = load_file(TENSOR_FILE)
W = data[TENSOR_KEY]

# Apply indexing if specified
if TENSOR_INDEX is not None:
    W = W[TENSOR_INDEX]

# Move to device for hardware acceleration
W = W.to(device)

n_vectors, n_dims = W.shape

print(f"✓ Loaded W from {Path(TENSOR_FILE).name}")
print(f"  Shape: {W.shape}")
print(f"  Dtype: {W.dtype}")
print(f"  Device: {W.device}")
print(f"  Memory: ~{W.element_size() * W.numel() / 1024**3:.2f} GB")
print()
print(f"Analyzing {n_vectors:,} vectors in {n_dims:,} dimensions")

✓ Loaded W from W.safetensors
  Shape: torch.Size([151936, 2560])
  Dtype: torch.bfloat16
  Device: mps:0
  Memory: ~0.72 GB

Analyzing 151,936 vectors in 2,560 dimensions


## Helper Functions

In [42]:
def get_max_exponent(vector_bf16):
    """
    Get the maximum exponent across all dimensions of a bfloat16 vector.
    
    Args:
        vector_bf16: (D,) tensor of bfloat16 values (on any device)
    
    Returns:
        int: maximum biased exponent field (0-255), not the unbiased power of two
    """
    values_uint16 = vector_bf16.view(torch.int16).to(torch.int64) & 0xFFFF
    exponents = (values_uint16 >> 7) & 0xFF
    return exponents.max().item()

def decode_bfloat16_bits(value_bf16):
    """
    Decode a single bfloat16 value into its bit components.
    
    Args:
        value_bf16: scalar torch.bfloat16 value (on any device)
    
    Returns:
        dict with 'sign', 'exponent', 'mantissa' as integers
    """
    # Convert to bytes and interpret as uint16
    bits_uint16 = value_bf16.view(torch.int16).cpu().numpy().astype(np.uint16).item()
    bits_binary = format(bits_uint16, '016b')
    
    sign_bit = bits_binary[0]
    exponent_bits = bits_binary[1:9]
    mantissa_bits = bits_binary[9:16]
    
    sign = int(sign_bit)
    exponent = int(exponent_bits, 2)
    mantissa = int(mantissa_bits, 2)
    
    return {
        'bits_uint16': bits_uint16,
        'bits_binary': bits_binary,
        'sign': sign,
        'exponent': exponent,
        'mantissa': mantissa,
        'sign_bit': sign_bit,
        'exponent_bits': exponent_bits,
        'mantissa_bits': mantissa_bits
    }

def compute_ulp_at_exponent(exponent):
    """
    Compute ULP (unit in last place) for bfloat16 at given exponent.
    
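    bfloat16 has 7 mantissa bits and an exponent bias of 127, so for normal
    numbers ULP = 2^(exponent - 127 - 7) = 2^(exponent - 134).
    `exponent` is the biased exponent field; subnormals (exponent 0) are not handled.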
    """
    return 2.0 ** (exponent - 134)

print("✓ Helper functions defined")

✓ Helper functions defined

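The ULP formula can be sanity-checked without PyTorch, because a bfloat16 value is the upper 16 bits of a float32. The sketch below is an illustration under that assumption; `bf16_bits_to_float` and `ulp_at_exponent` are hypothetical helper names, not part of the notebook.

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    """Decode a 16-bit bfloat16 pattern by placing it in the top half of a float32."""
    return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]

def ulp_at_exponent(exponent: int) -> float:
    """Spacing between adjacent normal bfloat16 values with this biased exponent."""
    return 2.0 ** (exponent - 134)

# 0x3F80 = 1.0, 0x4000 = 2.0, 0x3C00 = 2^-7
for bits in (0x3F80, 0x4000, 0x3C00):
    exponent = (bits >> 7) & 0xFF
    # Incrementing the bit pattern by 1 steps the mantissa by one unit
    gap = bf16_bits_to_float(bits + 1) - bf16_bits_to_float(bits)
    assert gap == ulp_at_exponent(exponent)
    print(f"bits=0x{bits:04X} exponent={exponent} gap={gap}")
```

The three patterns have a zero mantissa, so adding 1 to the bit pattern stays within the same exponent.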

## Stage 1: Black Hole Detection

In [43]:
print("\n" + "=" * 80)
print("STAGE 1: BLACK HOLE DETECTION")
print("=" * 80)
print()

# torch.unique not implemented on MPS in Torch 2.8, use CPU
print("Finding unique vectors...")
W_cpu = W.cpu()
W_unique, inverse_indices, counts = torch.unique(W_cpu, dim=0, return_inverse=True, return_counts=True)

n_unique = len(W_unique)
n_duplicates = n_vectors - n_unique

print(f"  ✓ Found {n_unique:,} unique vectors")
print(f"  ✓ {n_duplicates:,} vectors are duplicates")
print()

# Count tokens participating in black holes
duplicate_mask = counts > 1
n_black_hole_centroids = duplicate_mask.sum().item()

black_hole_tokens = []
if n_black_hole_centroids > 0:
    print(f"Found {n_black_hole_centroids} black hole centroids")
    print("Counting tokens...")
    
    black_hole_unique_ids = duplicate_mask.nonzero(as_tuple=True)[0]
    
    for unique_id in tqdm(black_hole_unique_ids, desc="Processing"):
        # Find all tokens that map to this unique vector
        tokens = (inverse_indices == unique_id).nonzero(as_tuple=True)[0].tolist()
        black_hole_tokens.extend(tokens)
    
    print()

n_black_hole_tokens = len(black_hole_tokens)

print(f"Black hole tokens: {n_black_hole_tokens:,} ({100 * n_black_hole_tokens / n_vectors:.2f}%)")
if n_black_hole_tokens > 0:
    print(f"  Organized into {n_black_hole_centroids} centroids")
print()


STAGE 1: BLACK HOLE DETECTION

Finding unique vectors...
  ✓ Found 149,849 unique vectors
  ✓ 2,087 vectors are duplicates

Found 13 black hole centroids
Counting tokens...


Processing: 100%|██████████| 13/13 [00:00<00:00, 5811.76it/s]


Black hole tokens: 2,100 (1.38%)
  Organized into 13 centroids

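The per-centroid loop in the cell above can also be written as a single indexing step: a token belongs to a black hole exactly when the unique vector it maps to has a count above 1. The following is a minimal sketch on toy data, assuming CPU tensors; the matrix `W_toy` and the other variable names are illustrative and not taken from the notebook.

```python
import torch

# Toy embedding matrix: rows 0, 2, 4 are identical, and rows 1, 5 are identical
W_toy = torch.tensor([[1., 2.], [3., 4.], [1., 2.], [5., 6.], [1., 2.], [3., 4.]])

unique_rows, inverse, counts = torch.unique(
    W_toy, dim=0, return_inverse=True, return_counts=True
)

# counts > 1 marks black hole centroids; indexing with `inverse`
# broadcasts that flag back to every original token
in_black_hole = (counts > 1)[inverse]
black_hole_token_list = in_black_hole.nonzero(as_tuple=True)[0].tolist()
n_centroids = int((counts > 1).sum())

print(black_hole_token_list)  # [0, 1, 2, 4, 5]
print(n_centroids)            # 2
```

This yields the same token set as the loop, but in sorted token order instead of grouped by centroid.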
## Stage 2: Deduplication

In [44]:
print("=" * 80)
print("STAGE 2: DEDUPLICATION")
print("=" * 80)
print()

# Keep one representative per unique vector
print("Creating deduplicated token set...")
representative_tokens = []
for unique_id in range(n_unique):
    # Get first token that maps to this unique vector
    token_id = (inverse_indices == unique_id).nonzero(as_tuple=True)[0][0].item()
    representative_tokens.append(token_id)

# Index W (which is on device) and keep on device
W_dedup = W[representative_tokens]
n_dedup = len(representative_tokens)

print(f"  ✓ Deduplicated: {n_vectors:,} → {n_dedup:,} tokens")
print(f"  ✓ Reduced search space by {n_vectors - n_dedup:,} tokens ({100 * (n_vectors - n_dedup) / n_vectors:.2f}%)")
print(f"  ✓ W_dedup on device: {W_dedup.device}")
print()

STAGE 2: DEDUPLICATION

Creating deduplicated token set...
  ✓ Deduplicated: 151,936 → 149,849 tokens
  ✓ Reduced search space by 2,087 tokens (1.37%)
  ✓ W_dedup on device: mps:0

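The representative lookup in the cell above scans `inverse_indices` once per unique vector. One possible alternative, sketched here with NumPy on a toy array, finds all first occurrences in a single call. It assumes the inverse indices are available as a CPU array (for example via `inverse_indices.numpy()`); the toy values are illustrative.

```python
import numpy as np

# Toy stand-in for inverse_indices: token t maps to unique vector inverse[t]
inverse = np.array([2, 0, 2, 1, 0, 2])

# return_index gives, for each unique id in sorted order,
# the position of its first occurrence in the input
unique_ids, first_positions = np.unique(inverse, return_index=True)
representatives = first_positions.tolist()

# Reference: the same result computed one unique id at a time
looped = [int(np.nonzero(inverse == u)[0][0]) for u in range(len(unique_ids))]

print(representatives)  # [1, 3, 0]
```

Because the unique ids run from 0 to n_unique − 1, `representatives[k]` is the first token that maps to unique vector `k`, matching the order produced by the loop.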


## Stage 3: Isolated Token Exclusion (ε-sphere filter)

In [45]:
print("=" * 80)
print("STAGE 3: ISOLATED TOKEN EXCLUSION (ε-sphere filter)")
print("=" * 80)
print()

# Compute epsilon multiplier (worst-case diagonal distance)
EPSILON_MULTIPLIER = math.ceil(math.sqrt(n_dims))
print(f"Using ε = {EPSILON_MULTIPLIER} × ULP (worst-case {n_dims}D diagonal)")
print()

# Estimate memory usage
batch_memory_gb = (BATCH_SIZE * n_dedup * 4) / 1024**3  # float32
print(f"Memory estimate: {batch_memory_gb:.2f} GB per batch (batch size = {BATCH_SIZE})")

if batch_memory_gb > 0.5:
    print(f"  ⚠️  Large memory usage! Consider reducing BATCH_SIZE if this crashes.")

print()
print("Finding candidate pairs within ε...")

candidate_pairs = []
isolated_tokens = set(range(n_dedup))

# Pre-convert W_dedup to float32 once (stays on device)
W_dedup_float = W_dedup.float()

for i in tqdm(range(0, n_dedup, BATCH_SIZE), desc="Processing batches"):
    batch_end = min(i + BATCH_SIZE, n_dedup)
    batch = W_dedup_float[i:batch_end]  # Already float32, on device
    
    # Compute distances to all deduplicated tokens (stays on device)
    distances = torch.cdist(batch, W_dedup_float)  # (B, N) on device
    
    for b in range(batch.shape[0]):
        token_i = i + b
        
        # Compute epsilon for this token
        max_exp = get_max_exponent(W_dedup[token_i])
        ulp = compute_ulp_at_exponent(max_exp)
        epsilon = ulp * EPSILON_MULTIPLIER
        
        # Find neighbors within epsilon
        neighbors = (distances[b] < epsilon).nonzero(as_tuple=True)[0]
        neighbors = neighbors[neighbors != token_i]  # Exclude self
        
        if len(neighbors) > 0:
            # This token has neighbors - not isolated
            isolated_tokens.discard(token_i)
            
            # Add pairs (only j > i to avoid duplicates)
            for j in neighbors.tolist():
                if j > token_i:
                    candidate_pairs.append((token_i, j))

n_isolated = len(isolated_tokens)
n_candidates = len(candidate_pairs)

print()
print(f"  ✓ Found {n_candidates:,} candidate pairs within ε")
print(f"  ✓ Excluded {n_isolated:,} isolated tokens ({100 * n_isolated / n_dedup:.2f}%)")
print()

STAGE 3: ISOLATED TOKEN EXCLUSION (ε-sphere filter)

Using ε = 51 × ULP (worst-case 2560D diagonal)

Memory estimate: 0.06 GB per batch (batch size = 100)

Finding candidate pairs within ε...


Processing batches: 100%|██████████| 1499/1499 [03:59<00:00,  6.25it/s]


  ✓ Found 11,212 candidate pairs within ε
  ✓ Excluded 149,698 isolated tokens (99.90%)

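The inner per-token loop in the cell above can be replaced by one mask per batch, comparing each row of the distance matrix against that row's own ε. The sketch below illustrates the idea in NumPy on toy data; the function name `candidate_pairs_batched`, the toy points, and the uniform ε are assumptions for illustration. It builds a (B, N, D) difference array, which is fine for a toy but not for the full matrix, where `torch.cdist` as used in the notebook is the better fit.

```python
import numpy as np

def candidate_pairs_batched(X, eps, batch_size=2):
    """Return (pairs with j > i, has_neighbor mask) using a per-row epsilon."""
    n = X.shape[0]
    pairs = []
    has_neighbor = np.zeros(n, dtype=bool)
    for start in range(0, n, batch_size):
        batch = X[start:start + batch_size]
        # (B, N) Euclidean distances from each batch row to every row
        diff = batch[:, None, :] - X[None, :, :]
        dist = np.sqrt((diff ** 2).sum(axis=-1))
        # Compare each row against its own epsilon
        within = dist < eps[start:start + batch_size, None]
        rows = np.arange(batch.shape[0])
        within[rows, start + rows] = False  # exclude self-matches
        has_neighbor[start:start + batch.shape[0]] = within.any(axis=1)
        b_idx, j_idx = np.nonzero(within)
        i_idx = b_idx + start
        keep = j_idx > i_idx  # record each unordered pair once
        pairs.extend(zip(i_idx[keep].tolist(), j_idx[keep].tolist()))
    return pairs, has_neighbor

X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.0, 5.05], [9.0, 9.0]])
eps = np.full(len(X), 0.2)  # hypothetical uniform epsilon
pairs, has_neighbor = candidate_pairs_batched(X, eps)
print(pairs)         # [(0, 1), (2, 3)]
print(has_neighbor)  # only the last point is isolated
```

As in the notebook, a pair (i, j) with j > i is recorded only when j falls inside token i's ε. When ε differs per token, a pair that is within ε of j but not of i is not recorded.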
## Stage 4: Lattice Neighbor Detection

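Verification for candidate pairs, as implemented in the cell below:
1. Geometric filter: currently a pass-through. L∞ = L1 would indicate an orthogonal pair, but all ε-sphere candidates are kept so that diagonal pairs can be classified too.
2. Bit-level verification: in every dimension, test for same sign, same exponent, and mantissa differing by 1.
3. Classification: orthogonal if exactly one dimension passes the bit-level test, diagonal if more than one does.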

In [46]:
print("=" * 80)
print("STAGE 4: LATTICE NEIGHBOR DETECTION")
print("=" * 80)
print()

if n_candidates == 0:
    print("No candidate pairs to check. Skipping Stage 4.\n")
    orthogonal_pairs = []
    diagonal_pairs = []
else:
    # Step 4a: Geometric filter
    print("Step 4a: Geometric filter...")
    print("  (L∞ = L1 suggests orthogonal, but keeping all candidates for full check)")
    lattice_candidates = candidate_pairs  # Keep all ε-sphere candidates
    
    print(f"  ✓ {len(lattice_candidates):,} pairs to verify")
    print()
    
    # Step 4b: Bit-level verification across ALL dimensions
    print("Step 4b: Full bit-level verification...")
    orthogonal_pairs = []
    diagonal_pairs = []
    
    for i, j in tqdm(lattice_candidates, desc="  Checking"):
        # Decode ALL dimensions for both vectors
        vec_i = W_dedup[i]
        vec_j = W_dedup[j]
        
        # Convert to uint16 for bit manipulation
        bits_i = vec_i.view(torch.int16).cpu().numpy().astype(np.uint16)
        bits_j = vec_j.view(torch.int16).cpu().numpy().astype(np.uint16)
        
        # Extract sign, exponent, mantissa for all dimensions
        signs_i = (bits_i >> 15) & 0x1
        signs_j = (bits_j >> 15) & 0x1
        exps_i = (bits_i >> 7) & 0xFF
        exps_j = (bits_j >> 7) & 0xFF
        mants_i = bits_i & 0x7F
        mants_j = bits_j & 0x7F
        
        # Find dimensions where vectors are lattice neighbors
        same_sign = signs_i == signs_j
        same_exp = exps_i == exps_j
        mant_diff = np.abs(mants_i.astype(np.int16) - mants_j.astype(np.int16))
        
        # One-step dimension d: same sign, same exponent, mantissa differs by 1
        lattice_neighbor_dims = same_sign & same_exp & (mant_diff == 1)
        neighbor_dims = np.where(lattice_neighbor_dims)[0]
        
        n_diff = len(neighbor_dims)
        
        # Note: the remaining dimensions are not checked for equality here;
        # the only other constraint on the pair is the Stage 3 ε-sphere.
        if n_diff == 1:
            # Orthogonal neighbor (exactly one one-step dimension)
            orthogonal_pairs.append((i, j, neighbor_dims[0]))
        elif n_diff > 1:
            # Diagonal neighbor (one-step in multiple dimensions)
            diagonal_pairs.append((i, j, neighbor_dims.tolist(), n_diff))
        # Pairs with n_diff == 0 are dropped
    
    print(f"  ✓ {len(orthogonal_pairs):,} orthogonal pairs (1D)")
    print(f"  ✓ {len(diagonal_pairs):,} diagonal pairs (2D+)")
    print()

# Extract unique tokens that participate in lattice structure
orthogonal_tokens = set()
for i, j, dim in orthogonal_pairs:
    orthogonal_tokens.add(i)
    orthogonal_tokens.add(j)

diagonal_tokens = set()
for i, j, dims, n_diff in diagonal_pairs:
    diagonal_tokens.add(i)
    diagonal_tokens.add(j)

n_orthogonal_tokens = len(orthogonal_tokens)
n_diagonal_tokens = len(diagonal_tokens)
n_lattice_tokens = len(orthogonal_tokens | diagonal_tokens)

STAGE 4: LATTICE NEIGHBOR DETECTION

Step 4a: Geometric filter...
  (L∞ = L1 suggests orthogonal, but keeping all candidates for full check)
  ✓ 11,212 pairs to verify

Step 4b: Full bit-level verification...


  Checking: 100%|██████████| 11212/11212 [00:02<00:00, 4185.96it/s]

  ✓ 90 orthogonal pairs (1D)
  ✓ 11,093 diagonal pairs (2D+)

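The bit-level test in the cell above counts one-step dimensions but does not require the remaining dimensions to be identical. A stricter variant is sketched below: a pair counts as a lattice neighbor only if every differing dimension is a one-step dimension. This is an illustration, not the notebook's method. The function names are hypothetical, it has not been run on the Qwen tensor, and applying it would likely change the counts reported here.

```python
import numpy as np

def one_step_dims(bits_i, bits_j):
    """Mask of dimensions with same sign, same exponent, mantissa differing by 1."""
    same_sign = (bits_i >> 15) == (bits_j >> 15)
    same_exp = ((bits_i >> 7) & 0xFF) == ((bits_j >> 7) & 0xFF)
    mant_diff = np.abs((bits_i & 0x7F).astype(np.int32) - (bits_j & 0x7F).astype(np.int32))
    return same_sign & same_exp & (mant_diff == 1)

def classify_pair(bits_i, bits_j):
    """Classify two vectors given as bfloat16 bit patterns (uint16)."""
    bits_i = np.asarray(bits_i, dtype=np.uint16)
    bits_j = np.asarray(bits_j, dtype=np.uint16)
    differs = bits_i != bits_j
    stepped = one_step_dims(bits_i, bits_j)
    if not differs.any():
        return "identical"
    if not np.array_equal(differs, stepped):
        return "not_lattice"  # some dimension differs by more than one step
    return "orthogonal" if stepped.sum() == 1 else "diagonal"

a = [0x3F80, 0x4000, 0xBF80]                      # 1.0, 2.0, -1.0
print(classify_pair(a, [0x3F80, 0x4001, 0xBF80]))  # orthogonal
print(classify_pair(a, [0x3F81, 0x4001, 0xBF80]))  # diagonal
print(classify_pair(a, [0x3F85, 0x4001, 0xBF80]))  # not_lattice
```

The last pair has one one-step dimension plus a dimension that differs by five mantissa steps, so the count-only test would label it orthogonal while the strict variant rejects it. Neither version treats a step across an exponent boundary (mantissa 0x7F to 0x00 of the next exponent) as adjacent.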
## Save Results

In [48]:
from safetensors.torch import save_file

print("=" * 80)
print("SAVING RESULTS")
print("=" * 80)
print()

# Map deduplicated indices back to original token IDs
representative_tokens_tensor = torch.tensor(representative_tokens, dtype=torch.int64)

# Black hole tokens (already in original token ID space)
black_hole_token_ids = torch.tensor(sorted(black_hole_tokens), dtype=torch.int64)
black_hole_mask = torch.zeros(n_vectors, dtype=torch.bool)
black_hole_mask[black_hole_token_ids] = True

# Orthogonal neighbor tokens (map from deduplicated to original)
orthogonal_token_ids_dedup = torch.tensor(sorted(orthogonal_tokens), dtype=torch.int64)
orthogonal_token_ids = representative_tokens_tensor[orthogonal_token_ids_dedup]
orthogonal_mask = torch.zeros(n_vectors, dtype=torch.bool)
orthogonal_mask[orthogonal_token_ids] = True

# Diagonal neighbor tokens (map from deduplicated to original)
diagonal_token_ids_dedup = torch.tensor(sorted(diagonal_tokens), dtype=torch.int64)
diagonal_token_ids = representative_tokens_tensor[diagonal_token_ids_dedup]
diagonal_mask = torch.zeros(n_vectors, dtype=torch.bool)
diagonal_mask[diagonal_token_ids] = True

# All lattice tokens (union)
lattice_token_ids_dedup = torch.tensor(sorted(orthogonal_tokens | diagonal_tokens), dtype=torch.int64)
lattice_token_ids = representative_tokens_tensor[lattice_token_ids_dedup]
lattice_mask = torch.zeros(n_vectors, dtype=torch.bool)
lattice_mask[lattice_token_ids] = True

# Edge lists for graph construction
# Store as pairs of original token IDs
# reshape(-1, 2) keeps the (n_edges, 2) shape even when a pair list is empty
orthogonal_edges = torch.tensor(
    [(representative_tokens[i], representative_tokens[j]) for i, j, _ in orthogonal_pairs],
    dtype=torch.int64
).reshape(-1, 2)

diagonal_edges = torch.tensor(
    [(representative_tokens[i], representative_tokens[j]) for i, j, _, _ in diagonal_pairs],
    dtype=torch.int64
).reshape(-1, 2)

# Prepare save dictionary
save_dict = {
    # Black holes
    'black_hole_token_ids': black_hole_token_ids,
    'black_hole_mask': black_hole_mask,
    'n_black_hole_tokens': torch.tensor(n_black_hole_tokens, dtype=torch.int64),
    'n_black_hole_centroids': torch.tensor(n_black_hole_centroids, dtype=torch.int64),
    
    # Orthogonal neighbors
    'orthogonal_token_ids': orthogonal_token_ids,
    'orthogonal_mask': orthogonal_mask,
    'orthogonal_edges': orthogonal_edges,
    
    # Diagonal neighbors
    'diagonal_token_ids': diagonal_token_ids,
    'diagonal_mask': diagonal_mask,
    'diagonal_edges': diagonal_edges,
    
    # All lattice structure
    'lattice_token_ids': lattice_token_ids,
    'lattice_mask': lattice_mask,
    
    # Deduplication mapping (for reference)
    'representative_tokens': representative_tokens_tensor,
    'inverse_indices': inverse_indices,
}

# Determine output path based on input
model_name = Path(TENSOR_FILE).parent.name
output_path = Path(f"../tensors/{model_name}/1.13a_lattice_structure.safetensors")

print(f"Saving to: {output_path}")
save_file(save_dict, str(output_path))

print()
print("✓ Saved lattice structure data:")
print(f"  Black holes: {len(black_hole_token_ids)} tokens")
print(f"  Orthogonal: {len(orthogonal_token_ids)} tokens, {len(orthogonal_edges)} edges")
print(f"  Diagonal: {len(diagonal_token_ids)} tokens, {len(diagonal_edges)} edges")
print(f"  Total lattice: {len(lattice_token_ids)} tokens")
print()

SAVING RESULTS

Saving to: ../tensors/Qwen3-4B-Instruct-2507/1.13a_lattice_structure.safetensors

✓ Saved lattice structure data:
  Black holes: 2100 tokens
  Orthogonal: 41 tokens, 90 edges
  Diagonal: 151 tokens, 11093 edges
  Total lattice: 151 tokens



## Summary

In [47]:
print("=" * 80)
print("SUMMARY")
print("=" * 80)
print()

print(f"Input: {n_vectors:,} vectors × {n_dims:,} dimensions")
print()

print(f"Black holes:")
print(f"  {n_black_hole_tokens:,} tokens ({100 * n_black_hole_tokens / n_vectors:.2f}%)")
if n_black_hole_tokens > 0:
    print(f"  {n_black_hole_centroids} centroids")
print()

print(f"Lattice neighbors:")
print(f"  Orthogonal: {len(orthogonal_pairs):,} pairs, {n_orthogonal_tokens:,} tokens")
print(f"  Diagonal:   {len(diagonal_pairs):,} pairs, {n_diagonal_tokens:,} tokens")
print(f"  Total:      {n_lattice_tokens:,} unique tokens in lattice structure ({100 * n_lattice_tokens / n_dedup:.2f}% of deduplicated)")
print()

if len(orthogonal_pairs) > 0:
    # Dimension distribution for orthogonal neighbors
    dim_counts = Counter([dim for _, _, dim in orthogonal_pairs])
    print(f"Orthogonal pairs by dimension (top 10):")
    for dim, count in dim_counts.most_common(10):
        print(f"  Dimension {dim:4d}: {count:4d} pairs")
    print()

if len(diagonal_pairs) > 0:
    # Dimensionality distribution for diagonal neighbors
    diag_dim_counts = Counter([n_diff for _, _, _, n_diff in diagonal_pairs])
    print(f"Diagonal pairs by dimensionality:")
    for n_diff in sorted(diag_dim_counts.keys()):
        count = diag_dim_counts[n_diff]
        print(f"  {n_diff}D diagonal: {count:4d} pairs")
    print()

print("=" * 80)

SUMMARY

Input: 151,936 vectors × 2,560 dimensions

Black holes:
  2,100 tokens (1.38%)
  13 centroids

Lattice neighbors:
  Orthogonal: 90 pairs, 41 tokens
  Diagonal:   11,093 pairs, 151 tokens
  Total:      151 unique tokens in lattice structure (0.10% of deduplicated)

Orthogonal pairs by dimension (top 10):
  Dimension 1435:   20 pairs
  Dimension 1382:   19 pairs
  Dimension 1564:   16 pairs
  Dimension 1008:   13 pairs
  Dimension 1718:    8 pairs
  Dimension  216:    5 pairs
  Dimension 1272:    4 pairs
  Dimension  993:    3 pairs
  Dimension 2012:    2 pairs

Diagonal pairs by dimensionality:
  2D diagonal:  155 pairs
  3D diagonal:  155 pairs
  4D diagonal:  203 pairs
  5D diagonal:  202 pairs
  6D diagonal:  117 pairs
  7D diagonal:   53 pairs
  8D diagonal:   41 pairs
  9D diagonal:   29 pairs
  10D diagonal:   46 pairs
  11D diagonal:   53 pairs
  12D diagonal:   67 pairs
  13D diagonal:   19 pairs
  14D diagonal:    7 pairs
  15D diagonal:    5 pairs
  16D diagonal:   21