# 08.5d: Long Run Diagnostic

**Quick diagnostic check for late-stage black hole fission**

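This notebook loads the embedding history of a long training run and inspects its final snapshot, using step 0 as a baseline. A *black hole* is a group of two or more tokens whose embeddings are indistinguishable (Chebyshev distance below `BLACK_HOLE_THRESHOLD`). For the final snapshot it checks: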
- How many black holes exist at the end?
- What is the total population in black holes?
- What is the size of the largest black hole?

## Hypothesis

If bfloat16 diffusion causes late-stage fission, we expect:
- **Multiple black holes** at step 100,000 (not just 1)
- **Total population ≈ 51** (the number of dead tokens)
- **Fragment sizes < 51** (the original cluster broke apart)

If we see this, it is evidence that the dead-token cluster can spontaneously fragment over long training runs.

## Parameters

In [1]:
# --- Which run to analyze -----------------------------------------------
DATA_DIR = "../data"                                # root directory holding run folders
RUN_NAME = "embeddings_128vocab_qweninit_run_1001"  # run folder inside DATA_DIR
EMBEDDING_FILE = "embedding_evolution.safetensors"  # safetensors file inside the run folder
EMBEDDING_KEY = "embedding_history"                 # tensor name inside that file

# --- Black-hole detection -----------------------------------------------
# Two tokens count as identical when the Chebyshev (L-infinity) distance
# between their embeddings is strictly below this value. It is set tiny so
# that only (effectively) bit-identical embeddings match.
BLACK_HOLE_THRESHOLD = 1e-10

# --- Reproducibility ----------------------------------------------------
RANDOM_SEED = 42

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file
from pathlib import Path
from collections import defaultdict

# Seed NumPy's global RNG. NOTE(review): none of the visible cells draw random
# numbers, and torch is not seeded here; confirm whether this seed is still needed.
np.random.seed(RANDOM_SEED)

## Union-Find for Black Hole Detection

In [3]:
class UnionFind:
    """Disjoint-set forest over the integers 0..n-1.

    ``find`` applies path compression and ``union`` applies union by rank,
    which keeps the trees shallow across repeated operations.
    """

    def __init__(self, n):
        # Every element starts out as the root of its own singleton set.
        self.parent = list(range(n))
        # Per-root upper bound on tree height, used to decide which root wins a merge.
        self.rank = [0] * n

    def find(self, x):
        """Return the root of x's set and repoint every node on the path at it."""
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: compress the path so future lookups take one hop.
        node = x
        while self.parent[node] != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, x, y):
        """Merge the sets containing x and y (no-op if they are already joined)."""
        winner, loser = self.find(x), self.find(y)
        if winner == loser:
            return
        # Make `winner` the root with the larger (or equal) rank.
        if self.rank[winner] < self.rank[loser]:
            winner, loser = loser, winner
        self.parent[loser] = winner
        # Equal-height trees grow by one level when merged.
        if self.rank[winner] == self.rank[loser]:
            self.rank[winner] += 1

    def get_components(self):
        """Return list of connected components (each component is a list of indices).

        Indices within a component are ascending, and components are ordered
        by their smallest member.
        """
        groups = {}
        for idx in range(len(self.parent)):
            groups.setdefault(self.find(idx), []).append(idx)
        return list(groups.values())


def find_black_holes(embeddings, threshold, chunk_size=None):
    """
    Find black holes in an embedding matrix.

    Two tokens are linked when the Chebyshev (L-infinity) distance between
    their embedding rows is strictly less than ``threshold``. A black hole is
    a connected component of this graph containing two or more tokens.

    Args:
        embeddings: (vocab_size, hidden_dim) tensor
        threshold: Chebyshev distance threshold for considering tokens identical
        chunk_size: number of rows whose distances to every token are computed
            at once. ``None`` (default) picks a size that keeps the temporary
            (chunk_size, vocab_size, hidden_dim) difference tensor at roughly
            2**26 elements. Small vocabularies are therefore handled in a
            single chunk.

    Returns:
        black_holes: list of lists, each sublist contains token IDs in a black
        hole (ascending). Sublists are ordered by their smallest token ID.

    Raises:
        ValueError: if ``embeddings`` is not 2-D or ``chunk_size`` < 1.
    """
    if embeddings.dim() != 2:
        raise ValueError(
            "embeddings must be 2-D (vocab_size, hidden_dim), "
            f"got shape {tuple(embeddings.shape)}"
        )
    n, d = embeddings.shape

    if chunk_size is None:
        # Bound peak memory instead of materialising the full (n, n, d) tensor.
        chunk_size = max(1, (1 << 26) // max(1, n * d))
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    uf = UnionFind(n)
    for start in range(0, n, chunk_size):
        rows_block = embeddings[start:start + chunk_size]  # (c, d)
        # (c, 1, d) - (1, n, d) -> (c, n, d); the max |diff| over the hidden
        # dim is the Chebyshev distance from each chunk row to every token.
        distances = (rows_block.unsqueeze(1) - embeddings.unsqueeze(0)).abs().amax(dim=2)
        rows, cols = torch.nonzero(distances < threshold, as_tuple=True)
        rows = rows + start  # chunk-local row index -> global token ID
        # Keep each unordered pair once and drop self-pairs (i == j).
        upper = cols > rows
        for i, j in zip(rows[upper].tolist(), cols[upper].tolist()):
            uf.union(i, j)

    # Get components with population ≥ 2
    return [comp for comp in uf.get_components() if len(comp) >= 2]


# Confirmation that the helper cell above executed.
print("✓ Union-Find functions defined")

✓ Union-Find functions defined


## Load Final Snapshot

In [4]:
run_dir = Path(DATA_DIR) / RUN_NAME
embedding_path = run_dir / EMBEDDING_FILE

if not embedding_path.exists():
    raise FileNotFoundError(f"Embedding file not found: {embedding_path}")

print(f"Loading embeddings from: {embedding_path}")
print(f"File size: {embedding_path.stat().st_size / 1e9:.2f} GB\n")

# Load full history. NOTE(review): this materialises every snapshot in memory,
# although the analysis cells shown only read step 0 and the final step.
data = load_file(embedding_path)
if EMBEDDING_KEY not in data:
    raise KeyError(
        f"Tensor {EMBEDDING_KEY!r} not found in {embedding_path}; "
        f"available keys: {sorted(data)}"
    )
embedding_history = data[EMBEDDING_KEY]

# Later cells index embedding_history[step] and treat each step as a
# (vocab_size, hidden_dim) matrix, so require a non-empty 3-D tensor.
if embedding_history.dim() != 3 or embedding_history.shape[0] == 0:
    raise ValueError(
        "Expected a non-empty (steps, vocab_size, hidden_dim) tensor, "
        f"got shape {tuple(embedding_history.shape)}"
    )

print(f"Shape: {embedding_history.shape}")
print(f"dtype: {embedding_history.dtype}")
print(f"\n✓ Loaded {embedding_history.shape[0]} snapshots")

Loading embeddings from: ../data/embeddings_128vocab_qweninit_run_1001/embedding_evolution.safetensors
File size: 1.64 GB

Shape: torch.Size([100001, 128, 64])
dtype: torch.bfloat16

✓ Loaded 100001 snapshots


## Analyze Final State

In [5]:
# Get the last step
final_step = embedding_history.shape[0] - 1
final_embeddings = embedding_history[final_step].float()  # Convert to float for computation

print(f"Analyzing step {final_step}...\n")

# Find black holes
black_holes = find_black_holes(final_embeddings, BLACK_HOLE_THRESHOLD)

# Compute statistics
num_black_holes = len(black_holes)
total_population = sum(len(bh) for bh in black_holes)
largest_bh_size = max((len(bh) for bh in black_holes), default=0)

print(f"{'='*80}")
print(f"FINAL STATE (Step {final_step})")
print(f"{'='*80}\n")

print(f"Number of black holes: {num_black_holes}")
print(f"Total population in black holes: {total_population}")
print(f"Largest black hole size: {largest_bh_size}")

if num_black_holes > 0:
    sizes = sorted([len(bh) for bh in black_holes], reverse=True)
    print(f"\nBlack hole size distribution: {sizes}")

print(f"\n{'='*80}")

Analyzing step 100000...

FINAL STATE (Step 100000)

Number of black holes: 1
Total population in black holes: 51
Largest black hole size: 51

Black hole size distribution: [51]



## Interpretation

In [6]:
# Size of the dead-token cluster stated in the Hypothesis section.
EXPECTED_DEAD_TOKENS = 51
# Black-hole count above which fragmentation is flagged as unusually high.
HIGH_FRAGMENTATION_COUNT = 10

print("\nINTERPRETATION:\n")

if num_black_holes == 0:
    print("❌ No black holes detected.")
    print("   All tokens have separated into singletons.")
    print("   Complete evaporation occurred.")

elif num_black_holes == 1:
    print("⚠️  Single black hole detected.")
    if total_population == EXPECTED_DEAD_TOKENS:
        print("   Dead token cluster remained intact.")
        print("   NO late-stage fission observed.")
    elif total_population < EXPECTED_DEAD_TOKENS:
        print(f"   Population ({total_population}) < {EXPECTED_DEAD_TOKENS} suggests partial evaporation.")
        print("   Some dead tokens may have escaped via bfloat16 diffusion.")
    else:
        # More members than expected: the cluster captured extra tokens.
        print(f"   Population ({total_population}) > {EXPECTED_DEAD_TOKENS}: additional tokens joined the cluster.")
        print("   NO late-stage fission observed.")

else:  # num_black_holes > 1
    print(f"🎉 MULTIPLE BLACK HOLES DETECTED! ({num_black_holes} total)")
    print("   Late-stage fission CONFIRMED.")

    if total_population >= EXPECTED_DEAD_TOKENS:
        print(f"   Total population ({total_population}) ≥ {EXPECTED_DEAD_TOKENS} → dead token cluster fragmented.")
        print("   This is consistent with bfloat16 diffusion hypothesis!")
    else:
        print(f"   Total population ({total_population}) < {EXPECTED_DEAD_TOKENS} suggests mixed dynamics:")
        print("   - Some fission (cluster broke apart)")
        print("   - Some evaporation (tokens escaped to singletons)")

    if num_black_holes > HIGH_FRAGMENTATION_COUNT:
        print(f"\n   ⚠️  VERY HIGH fragmentation ({num_black_holes} BHs).")
        print("   Possible cascading fission or spontaneous formation events.")


INTERPRETATION:

⚠️  Single black hole detected.
   Dead token cluster remained intact.
   NO late-stage fission observed.


## Compare to Initial State

In [7]:
# Analyze step 0 for comparison
initial_embeddings = embedding_history[0].float()
initial_black_holes = find_black_holes(initial_embeddings, BLACK_HOLE_THRESHOLD)

initial_num = len(initial_black_holes)
initial_pop = sum(len(bh) for bh in initial_black_holes)

print(f"\n{'='*80}")
print("COMPARISON TO INITIAL STATE")
print(f"{'='*80}\n")

print(f"Step 0:")
print(f"  Black holes: {initial_num}")
print(f"  Population: {initial_pop}")

print(f"\nStep {final_step}:")
print(f"  Black holes: {num_black_holes}")
print(f"  Population: {total_population}")

print(f"\nChange:")
print(f"  Œî Black holes: {num_black_holes - initial_num:+d}")
print(f"  Œî Population: {total_population - initial_pop:+d}")

print(f"\n{'='*80}")


COMPARISON TO INITIAL STATE

Step 0:
  Black holes: 1
  Population: 128

Step 100000:
  Black holes: 1
  Population: 51

Change:
  Δ Black holes: +0
  Δ Population: -77

