# 07.3b: Identify Black Hole Tokens

**Goal:** Find all tokens with degenerate (bit-for-bit identical) embedding vectors and save them as a boolean mask.

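A black hole is a token whose embedding vector is bit-for-bit identical to that of at least one other token. Every token in such a degenerate group lands on exactly the same point, so these groups create extreme density hotspots that can overwhelm density-based visualizations.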
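"Bit-for-bit identical" is stricter than numerical equality. Here is a minimal illustration using only the standard library's `struct` module (the helper `same_bits` is hypothetical, written for this example): `0.0` and `-0.0` compare equal but have different byte patterns, while a NaN compares unequal to itself even though its bytes match. Grouping by raw bytes therefore follows bit patterns, not `==`.

```python
import math
import struct

def same_bits(a: float, b: float) -> bool:
    """Return True if a and b have identical float32 byte patterns."""
    return struct.pack("<f", a) == struct.pack("<f", b)

# Signed zeros: numerically equal, bitwise different
print(0.0 == -0.0, same_bits(0.0, -0.0))   # True False

# NaN: numerically unequal to itself, bitwise identical to itself
nan = math.nan
print(nan == nan, same_bits(nan, nan))     # False True
```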

**Output:** `black_hole_mask.safetensors` containing a boolean tensor of shape `(N,)`, where `True` marks a black hole token (exclude from plots).
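A downstream notebook would typically drop the flagged rows with boolean indexing on the inverted mask. Here is a minimal sketch on a toy matrix. The values are made up for illustration; in practice `mask` would come from `load_file(...)[OUTPUT_KEY]` and `gamma` from the gamma file.

```python
import torch

# Toy 5x3 "embedding matrix": rows 1 and 3 are identical, so both are black holes
gamma = torch.tensor([
    [0.1, 0.2, 0.3],
    [1.0, 1.0, 1.0],
    [0.4, 0.5, 0.6],
    [1.0, 1.0, 1.0],
    [0.7, 0.8, 0.9],
])

# True = black hole (exclude), False = normal token (include)
mask = torch.tensor([False, True, False, True, False])

kept = gamma[~mask]                          # rows to plot
kept_ids = torch.nonzero(~mask).squeeze(1)   # original token IDs of those rows

print(kept.shape)         # torch.Size([3, 3])
print(kept_ids.tolist())  # [0, 2, 4]
```

Keeping `kept_ids` alongside the filtered rows matters because, after masking, a row's position no longer equals its token ID.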

## Parameters

In [1]:
TENSOR_DIR = "../data/tensors"
GAMMA_FILE = "gamma_qwen3_4b_instruct_2507.safetensors"
GAMMA_KEY = "gamma"

OUTPUT_FILE = "black_hole_mask.safetensors"
OUTPUT_KEY = "mask"

## Imports

In [2]:
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path

print("Imports loaded successfully.")

Imports loaded successfully.


## Load Gamma Matrix

In [3]:
data_dir = Path(TENSOR_DIR)

print("Loading gamma matrix...")
gamma_data = load_file(data_dir / GAMMA_FILE)
gamma = gamma_data[GAMMA_KEY]
N, d = gamma.shape

print(f"  Shape: ({N:,}, {d:,})")
print(f"  Dtype: {gamma.dtype}")
print()

Loading gamma matrix...
  Shape: (151,936, 2,560)
  Dtype: torch.float32



## Find Degenerate Tokens

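Strategy: use each vector's raw bytes as a dictionary key. Python hashes the bytes, so vectors that are bit-for-bit identical map to the same entry, and their token IDs accumulate in a single list.

Any vector that appears more than once defines a degenerate group, and every token in that group is a black hole.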
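The per-row Python loop is simple but processes all N rows one at a time. As a hedged alternative (not what this notebook runs), the same bit-for-bit grouping can be vectorized with NumPy. The approach reinterprets the float32 rows as `uint32` so that comparison happens on bit patterns, then lets `np.unique(..., axis=0)` find the unique rows. The function name `black_hole_mask` is hypothetical, and the sketch assumes a 2-D float32 input, e.g. `gamma.cpu().numpy()`.

```python
import numpy as np

def black_hole_mask(gamma: np.ndarray) -> np.ndarray:
    """Boolean mask, True where a row is bit-for-bit identical to another row.

    Assumes `gamma` is a 2-D float32 array of shape (N, d).
    """
    # Reinterpret float bits as integers so that 0.0 and -0.0 differ and
    # identical NaN payloads match, mirroring byte-level comparison.
    bits = np.ascontiguousarray(gamma, dtype=np.float32).view(np.uint32)
    _, inverse, counts = np.unique(
        bits, axis=0, return_inverse=True, return_counts=True
    )
    # Flatten defensively: some NumPy 2.x releases return a non-1-D inverse
    # when `axis` is given.
    inverse = inverse.reshape(-1)
    # A row is a black hole if its unique-row group has more than one member
    return counts[inverse] > 1

# Usage: rows 0 and 2 match exactly; row 3 differs only by the sign of zero
toy = np.array([[0.0, 1.0], [2.0, 3.0], [0.0, 1.0], [-0.0, 1.0]], dtype=np.float32)
print(black_hole_mask(toy))  # [ True False  True False]
```

`np.unique` sorts the rows instead of hashing them. This trades the Python-level loop for a single vectorized sort, at the cost of materializing a sorted copy of the matrix.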

In [4]:
from collections import defaultdict

print("Finding degenerate tokens...\n")

# Convert to NumPy once instead of once per row
gamma_np = gamma.cpu().numpy()

# Map each vector's raw bytes (bit-for-bit identity) to the token IDs that share it
vector_to_tokens = defaultdict(list)

print("Hashing vectors...")
for token_id in range(N):
    vec_bytes = gamma_np[token_id].tobytes()
    vector_to_tokens[vec_bytes].append(token_id)

print(f"  Found {len(vector_to_tokens):,} unique vectors\n")

# Find all tokens that share a vector with at least one other token
black_hole_tokens = []
degenerate_groups = []

for token_ids in vector_to_tokens.values():
    if len(token_ids) > 1:
        # This is a degenerate group
        black_hole_tokens.extend(token_ids)
        degenerate_groups.append(token_ids)

print("Degeneracy analysis:")
print(f"  Unique vectors: {len(vector_to_tokens):,}")
print(f"  Degenerate groups: {len(degenerate_groups):,}")
print(f"  Total black hole tokens: {len(black_hole_tokens):,} ({len(black_hole_tokens)/N*100:.2f}%)")
print()

# Show largest degenerate groups
degenerate_groups_sorted = sorted(degenerate_groups, key=len, reverse=True)
print("Largest degenerate groups:")
for i, group in enumerate(degenerate_groups_sorted[:10]):
    print(f"  {i+1}. {len(group):,} tokens")
print()

Finding degenerate tokens...

Hashing vectors...
  Found 149,849 unique vectors

Degeneracy analysis:
  Unique vectors: 149,849
  Degenerate groups: 13
  Total black hole tokens: 2,100 (1.38%)

Largest degenerate groups:
  1. 814 tokens
  2. 704 tokens
  3. 306 tokens
  4. 228 tokens
  5. 11 tokens
  6. 10 tokens
  7. 6 tokens
  8. 5 tokens
  9. 4 tokens
  10. 4 tokens

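Because the degenerate groups are disjoint, the counts above must satisfy a simple bookkeeping identity. Each group of size k contributes one unique vector but k black hole tokens, so unique vectors = N - black hole tokens + groups. With the reported numbers, 151,936 - 2,100 + 13 = 149,849, which matches. A small check that could be run after the cell above follows; the function name `check_degeneracy_counts` is hypothetical.

```python
def check_degeneracy_counts(n_tokens: int, n_unique: int, groups: list[list[int]]) -> None:
    """Assert the identity n_unique = n_tokens - n_black_holes + n_groups."""
    n_black_holes = sum(len(g) for g in groups)
    # Groups must be disjoint: no token ID may appear in two groups
    all_ids = [t for g in groups for t in g]
    assert len(all_ids) == len(set(all_ids)), "a token appears in two groups"
    assert n_unique == n_tokens - n_black_holes + len(groups)

# Reported numbers: 151,936 tokens, 2,100 black holes in 13 groups
print(151_936 - 2_100 + 13)  # 149849
```

In the notebook this would be called as `check_degeneracy_counts(N, len(vector_to_tokens), degenerate_groups)`.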


## Create Boolean Mask

In [5]:
print("Creating boolean mask...")

# Create mask: True = black hole (exclude), False = normal token (include)
mask = torch.zeros(N, dtype=torch.bool)
mask[black_hole_tokens] = True

print(f"  Mask shape: {mask.shape}")
print(f"  Tokens to exclude: {mask.sum().item():,}")
print(f"  Tokens to include: {(~mask).sum().item():,}")
print()

Creating boolean mask...
  Mask shape: torch.Size([151936])
  Tokens to exclude: 2,100
  Tokens to include: 149,836



## Save Mask

In [6]:
output_path = data_dir / OUTPUT_FILE

print(f"Saving mask to {output_path}...")
save_file({OUTPUT_KEY: mask}, output_path)

print("✓ Black hole mask saved successfully.")
print()
print("To use in 07.2a, set:")
print(f"  MASK_FILE = '{OUTPUT_FILE}'")
print(f"  MASK_KEY = '{OUTPUT_KEY}'")

Saving mask to ../data/tensors/black_hole_mask.safetensors...
✓ Black hole mask saved successfully.

To use in 07.2a, set:
  MASK_FILE = 'black_hole_mask.safetensors'
  MASK_KEY = 'mask'
