# 08.4d: Black Hole Counts (Pre-computed)

**Compute black hole counts from distance matrices and save for reuse**

This notebook processes pre-computed Chebyshev distance matrices to extract black hole statistics across all training steps. Computing the statistics once and saving them lets fission detection analysis iterate quickly without repeating the Union-Find pass.

## What We Compute

For each run, at each step:
- **Number of black holes** (connected components with population ≥ 2)
- **Total population** in black holes
- **Largest black hole size**
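
All three statistics are simple aggregates over the component list. A minimal sketch, assuming each component arrives as a list of token indices; the helper name `summarize_components` is hypothetical and not part of this notebook:

```python
def summarize_components(components, min_population=2):
    """Return (count, total population, largest size) over components
    that hold at least `min_population` tokens."""
    sizes = [len(c) for c in components if len(c) >= min_population]
    return len(sizes), sum(sizes), max(sizes, default=0)


# Toy partition of 7 tokens: two black holes plus two singletons.
components = [[0, 3, 5], [1], [2, 6], [4]]
count, population, largest = summarize_components(components)
print(count, population, largest)  # 2 5 3
```

Singletons are filtered out before aggregating, so a step with no coincident tokens yields `(0, 0, 0)`.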

## Strategy

1. **Load** pre-computed Chebyshev distance matrices
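2. **Build** adjacency graphs by thresholding distances (vectorized tensor ops; as written below they run on the CPU, because `load_file` loads tensors to the CPU by default)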
3. **Run** Union-Find to extract connected components (CPU, sequential)
4. **Save** results as small safetensors files (~120 KB each)

## Performance

- Expected: ~7 minutes for 16 runs (one-time cost)
- Bottleneck: the Union-Find pass, which runs as a sequential pure-Python loop over adjacent pairs at every step
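
For a sense of scale, the estimate above can be turned into a per-matrix time budget. This is plain arithmetic on the numbers quoted in this notebook (16 runs, 10001 steps, vocabulary of 128), not a measurement:

```python
runs = 16
steps_per_run = 10001
vocab_size = 128
expected_minutes = 7  # estimate quoted above, not a measured value

matrices = runs * steps_per_run
pairs_per_matrix = vocab_size * (vocab_size - 1) // 2  # upper triangle, diagonal excluded
ms_per_matrix = expected_minutes * 60 / matrices * 1000

print(f"{matrices} matrices, {pairs_per_matrix} candidate pairs each")
print(f"budget: {ms_per_matrix:.2f} ms per matrix")
```

Only pairs that pass the threshold reach the Python loop, so the actual cost per step depends on how many tokens coincide at that step.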

## Output

For each run, saves `black_hole_counts.safetensors` containing:
- `counts`: (10001,) int32 - number of black holes at each step
- `populations`: (10001,) int32 - total population in black holes
- `largest_bh_size`: (10001,) int32 - size of largest black hole
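
The file size follows from these shapes. A safetensors file is an 8-byte little-endian header length, a JSON header, then the raw tensor bytes. The sketch below rebuilds an approximate header with the standard library only; the real library may order or pad the header differently, so treat the total as an estimate:

```python
import json
import struct

N_STEPS = 10001
BYTES_PER_INT32 = 4
names = ["counts", "populations", "largest_bh_size"]

# Approximate header: one entry per tensor with dtype, shape, and byte range.
header = {}
offset = 0
for name in names:
    end = offset + N_STEPS * BYTES_PER_INT32
    header[name] = {"dtype": "I32", "shape": [N_STEPS], "data_offsets": [offset, end]}
    offset = end

payload_bytes = offset
header_bytes = len(json.dumps(header, separators=(",", ":")).encode("utf-8"))
length_prefix_bytes = len(struct.pack("<Q", header_bytes))  # 8-byte length prefix
total_bytes = length_prefix_bytes + header_bytes + payload_bytes

print(f"payload: {payload_bytes} bytes")
print(f"approximate file size: {total_bytes / 1e3:.1f} KB")
```

The payload alone is 3 × 10001 × 4 = 120,012 bytes, which is in line with the 120.2 KB logged per run when the files are written.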

## Parameters

In [7]:
# Data directories
DATA_DIR = "../data"
RUN_PATTERN = "embeddings_128vocab_qweninit_run_*"
DISTANCE_FILE = "pairwise_distances.safetensors"
DISTANCE_KEY = "chebyshev_distances"
OUTPUT_FILE = "black_hole_counts.safetensors"

# Expected dimensions
EXPECTED_RUNS = 16
EXPECTED_STEPS = 10001
VOCAB_SIZE = 128

# Black hole threshold: tokens at distance strictly below this count as identical.
# A tiny epsilon stands in for "bit-identical" (distance exactly 0).
BLACK_HOLE_THRESHOLD = 1e-10

RANDOM_SEED = 42
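
To illustrate what the threshold admits, here is a small standard-library sketch of the Chebyshev distance (the largest absolute coordinate difference). The function name and toy vectors are hypothetical; the notebook itself reads pre-computed distances rather than computing them:

```python
def chebyshev_distance(a, b):
    """Largest absolute coordinate difference between two equal-length vectors."""
    return max(abs(x - y) for x, y in zip(a, b))


THRESHOLD = 1e-10  # same value as BLACK_HOLE_THRESHOLD above

token_a = [0.25, -1.5, 3.0]
token_b = [0.25, -1.5, 3.0]         # bit-identical copy of token_a
token_c = [0.25, -1.5, 3.0 + 1e-6]  # differs in a single coordinate

print(chebyshev_distance(token_a, token_b) < THRESHOLD)  # True: distance is exactly 0.0
print(chebyshev_distance(token_a, token_c) < THRESHOLD)  # False: distance is about 1e-6
```

Because the comparison is a strict `<`, a threshold of exactly 0 would reject even identical vectors, which is why a small positive epsilon is used.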

## Imports

In [8]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file
from pathlib import Path
from tqdm.auto import tqdm
from collections import defaultdict

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

## Find All Runs

In [9]:
data_dir = Path(DATA_DIR)
run_dirs = sorted(data_dir.glob(RUN_PATTERN))

print(f"Found {len(run_dirs)} runs:")
for run_dir in run_dirs:
    print(f"  {run_dir.name}")

if len(run_dirs) != EXPECTED_RUNS:
    print(f"\n⚠ WARNING: Expected {EXPECTED_RUNS} runs, found {len(run_dirs)}")
else:
    print(f"\n✓ Found all {EXPECTED_RUNS} runs")

Found 16 runs:
  embeddings_128vocab_qweninit_run_001
  embeddings_128vocab_qweninit_run_002
  embeddings_128vocab_qweninit_run_003
  embeddings_128vocab_qweninit_run_004
  embeddings_128vocab_qweninit_run_005
  embeddings_128vocab_qweninit_run_006
  embeddings_128vocab_qweninit_run_007
  embeddings_128vocab_qweninit_run_008
  embeddings_128vocab_qweninit_run_009
  embeddings_128vocab_qweninit_run_010
  embeddings_128vocab_qweninit_run_011
  embeddings_128vocab_qweninit_run_012
  embeddings_128vocab_qweninit_run_013
  embeddings_128vocab_qweninit_run_014
  embeddings_128vocab_qweninit_run_015
  embeddings_128vocab_qweninit_run_016

✓ Found all 16 runs


## Union-Find for Connected Components

In [10]:
class UnionFind:
    """Union-Find data structure for finding connected components."""
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
    
    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]
    
    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        
        if root_x == root_y:
            return
        
        # Union by rank
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
    
    def get_components(self):
        """Return list of connected components (each component is a list of indices)."""
        components = defaultdict(list)
        for i in range(len(self.parent)):
            root = self.find(i)
            components[root].append(i)
        return list(components.values())


def find_black_holes(distance_matrix, threshold):
    """
    Find black holes in a distance matrix using Union-Find.
    
    Args:
        distance_matrix: (n, n) tensor of pairwise Chebyshev distances
        threshold: distance threshold for considering tokens identical
    
    Returns:
        black_holes: list of lists, each sublist contains token IDs in a black hole
    """
    n = distance_matrix.shape[0]
    uf = UnionFind(n)
    
    # Build adjacency graph: tokens within threshold are connected
    adjacency = (distance_matrix < threshold)
    
    # Extract upper triangle pairs that are adjacent (avoid redundant checks)
    triu_indices = torch.triu_indices(n, n, offset=1)
    adjacent_pairs = triu_indices[:, adjacency[triu_indices[0], triu_indices[1]]]
    
    # Union adjacent pairs (convert to Python ints once instead of calling .item() per element)
    for i, j in adjacent_pairs.t().tolist():
        uf.union(i, j)
    
    # Get connected components with population ≥ 2
    components = uf.get_components()
    black_holes = [comp for comp in components if len(comp) >= 2]
    
    return black_holes


print("✓ Union-Find and black hole detection functions defined")

✓ Union-Find and black hole detection functions defined
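
The detection logic can be sanity-checked on a matrix small enough to verify by hand. The sketch below is a self-contained, pure-Python analogue with no torch dependency; `connected_components` is a hypothetical helper, and it uses path halving without union by rank, so it is a simplification of the `UnionFind` class above, not a copy of it:

```python
def connected_components(n, edges):
    """Group indices 0..n-1 into components given undirected edges (i, j)."""
    parent = list(range(n))

    def find(x):
        # Iterative find with path halving, so long chains cannot hit the recursion limit.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for i, j in edges:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            parent[root_j] = root_i

    groups = {}
    for i in range(n):
        groups.setdefault(find(i), []).append(i)
    return sorted(groups.values())


# Toy symmetric distance matrix for 5 tokens: {0, 1, 3} coincide, as do {2, 4}.
distances = [
    [0.0, 0.0, 0.7, 0.0, 0.7],
    [0.0, 0.0, 0.7, 0.0, 0.7],
    [0.7, 0.7, 0.0, 0.7, 0.0],
    [0.0, 0.0, 0.7, 0.0, 0.7],
    [0.7, 0.7, 0.0, 0.7, 0.0],
]
threshold = 1e-10
n = len(distances)
edges = [(i, j) for i in range(n) for j in range(i + 1, n) if distances[i][j] < threshold]
black_holes = [c for c in connected_components(n, edges) if len(c) >= 2]
print(black_holes)  # [[0, 1, 3], [2, 4]]
```

As in `find_black_holes`, only the upper triangle is scanned, and components of size 1 are dropped before reporting.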


## Process Each Run

In [11]:
print(f"\nProcessing runs...\n")

for run_dir in tqdm(run_dirs, desc="Runs"):
    run_name = run_dir.name.split('_')[-1]
    distance_path = run_dir / DISTANCE_FILE
    output_path = run_dir / OUTPUT_FILE
    
    # Skip if already computed
    if output_path.exists():
        print(f"  {run_name}: already exists, skipping")
        continue
    
    if not distance_path.exists():
        print(f"  {run_name}: distance file not found, skipping")
        continue
    
    # Load distance matrices
    data = load_file(distance_path)
    chebyshev_distances = data[DISTANCE_KEY]
    n_steps = chebyshev_distances.shape[0]
    
    # Validate dimensions
    expected_shape = (EXPECTED_STEPS, VOCAB_SIZE, VOCAB_SIZE)
    if chebyshev_distances.shape != expected_shape:
        print(f"  {run_name}: unexpected shape {chebyshev_distances.shape}, skipping")
        continue
    
    # Allocate result arrays
    counts = np.zeros(n_steps, dtype=np.int32)
    populations = np.zeros(n_steps, dtype=np.int32)
    largest_bh_size = np.zeros(n_steps, dtype=np.int32)
    
    # Process each step
    for step in tqdm(range(n_steps), desc=f"  {run_name}", leave=False):
        distance_matrix = chebyshev_distances[step]
        black_holes = find_black_holes(distance_matrix, BLACK_HOLE_THRESHOLD)
        
        counts[step] = len(black_holes)
        populations[step] = sum(len(bh) for bh in black_holes)
        largest_bh_size[step] = max((len(bh) for bh in black_holes), default=0)
    
    # Save results
    save_dict = {
        'counts': torch.from_numpy(counts),
        'populations': torch.from_numpy(populations),
        'largest_bh_size': torch.from_numpy(largest_bh_size),
    }
    
    save_file(save_dict, output_path)
    
    file_size_kb = output_path.stat().st_size / 1e3
    print(f"  {run_name}: saved {file_size_kb:.1f} KB")

print(f"\n✓ All runs processed")


Processing runs...

  001: already exists, skipping
  002: already exists, skipping
  003: already exists, skipping
  004: already exists, skipping
  005: saved 120.2 KB
  006: saved 120.2 KB
  007: saved 120.2 KB
  008: saved 120.2 KB
  009: saved 120.2 KB
  010: saved 120.2 KB
  011: saved 120.2 KB
  012: saved 120.2 KB
  013: saved 120.2 KB
  014: saved 120.2 KB
  015: saved 120.2 KB
  016: saved 120.2 KB

✓ All runs processed


## Summary

In [12]:
# Check what we created
print(f"\n{'='*80}")
print("SUMMARY")
print(f"{'='*80}\n")

total_size = 0
for run_dir in sorted(run_dirs):
    output_path = run_dir / OUTPUT_FILE
    if output_path.exists():
        size_kb = output_path.stat().st_size / 1e3
        total_size += size_kb
        run_name = run_dir.name.split('_')[-1]
        print(f"  {run_name}: {size_kb:.1f} KB")

print(f"\nTotal storage: {total_size:.1f} KB ({total_size / 1e3:.2f} MB)")
print(f"\nEach file contains:")
print(f"  counts: ({EXPECTED_STEPS},) int32 - number of black holes per step")
print(f"  populations: ({EXPECTED_STEPS},) int32 - total population in black holes per step")
print(f"  largest_bh_size: ({EXPECTED_STEPS},) int32 - size of largest black hole per step")
print(f"\n{'='*80}")


SUMMARY

  001: 120.2 KB
  002: 120.2 KB
  003: 120.2 KB
  004: 120.2 KB
  005: 120.2 KB
  006: 120.2 KB
  007: 120.2 KB
  008: 120.2 KB
  009: 120.2 KB
  010: 120.2 KB
  011: 120.2 KB
  012: 120.2 KB
  013: 120.2 KB
  014: 120.2 KB
  015: 120.2 KB
  016: 120.2 KB

Total storage: 1923.9 KB (1.92 MB)

Each file contains:
  counts: (10001,) int32 - number of black holes per step
  populations: (10001,) int32 - total population in black holes per step
  largest_bh_size: (10001,) int32 - size of largest black hole per step
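
The three arrays constrain one another, which gives a cheap consistency check before they are reused. A minimal sketch, assuming the per-step values have been read out of the saved arrays as plain integers; `check_step` is a hypothetical helper and not part of this notebook:

```python
def check_step(count, population, largest, vocab_size=128):
    """Return True if one step's three statistics are mutually consistent."""
    if count == 0:
        return population == 0 and largest == 0
    return (
        largest >= 2                                  # a black hole holds at least 2 tokens
        and population >= largest + 2 * (count - 1)   # the other black holes hold >= 2 each
        and population <= largest * count             # none is bigger than the largest
        and population <= vocab_size                  # components are disjoint token sets
    )


print(check_step(0, 0, 0))  # True: no black holes at this step
print(check_step(2, 5, 3))  # True: sizes 3 and 2
print(check_step(2, 3, 2))  # False: two black holes need at least 4 tokens
```

With a single black hole the two population bounds coincide, so `population` must equal `largest`.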

