# 08.3b: Singularity Survey for Trained Embeddings

**Search the embedding matrix for singularities (groups of tokens with bit-for-bit identical vectors)**

After training, we expect:
- Tokens that appeared in training: unique vectors (moved by gradients)
- Dead tokens (never appeared): collapsed to 1-5 unique vectors (frozen at initialization ± quantization noise)

This notebook uses `torch.unique()` with `dim=0`, which deduplicates rows by sorting them (O(N log N) row comparisons), to find groups of tokens sharing identical vectors.
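As a rough illustration of what the deduplication returns, here is a minimal sketch on a toy matrix. The tensor `toy` and the names `groups` and `nudged` are made up for this example and are not part of the notebook. The sketch also shows that matching is exact: a row that differs by a single float32 step counts as a separate vector.

```python
import torch

# Toy 5x3 "embedding matrix" (hypothetical data): rows 0, 2 and 4 are identical.
toy = torch.tensor([
    [1.0, 2.0, 3.0],
    [0.5, 0.5, 0.5],
    [1.0, 2.0, 3.0],
    [9.0, 8.0, 7.0],
    [1.0, 2.0, 3.0],
])

unique_rows, inverse, counts = torch.unique(
    toy, dim=0, return_inverse=True, return_counts=True
)

# inverse[i] is the row of unique_rows that token i maps to;
# counts[j] is how many tokens share unique_rows[j].
shared = (counts > 1).nonzero(as_tuple=True)[0]
groups = {
    int(j): (inverse == j).nonzero(as_tuple=True)[0].tolist()
    for j in shared
}
print(groups)

# Exact matching: nudge one element of row 4 to the next representable float32.
nudged = toy.clone()
nudged[4, 0] = torch.nextafter(toy[4, 0], torch.tensor(2.0))
n_unique_nudged = len(torch.unique(nudged, dim=0))
print(len(unique_rows), n_unique_nudged)
```

In the toy data the three identical rows form one group; after the nudge, row 4 no longer belongs to it, so the number of unique rows goes up by one.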

## Parameters

In [1]:
# Input: embedding matrix to analyze
TENSOR_DIR = "../data/embeddings_128vocab_qweninit"
EMBEDDING_FILE = "step_0005000.safetensors"
EMBEDDING_KEY = "embeddings"

# Display options
MAX_TOKENS_PER_GROUP = 20  # Limit display for very large singularity groups

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file
from pathlib import Path
from collections import defaultdict

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

## Load Embedding Matrix

In [3]:
data_dir = Path(TENSOR_DIR)
embedding_path = data_dir / EMBEDDING_FILE

print(f"Loading embeddings from: {embedding_path}\n")

data = load_file(embedding_path)
gamma = data[EMBEDDING_KEY]
vocab_size, hidden_dim = gamma.shape

print(f"✓ Embeddings loaded")
print(f"Shape: {gamma.shape}")
print(f"Vocabulary size: {vocab_size:,}")
print(f"Hidden dimension: {hidden_dim:,}")
print(f"Total parameters: {vocab_size * hidden_dim:,}")
print(f"Memory footprint: {gamma.element_size() * gamma.numel() / 1e6:.2f} MB")

Loading embeddings from: ../data/embeddings_128vocab_qweninit/step_0005000.safetensors

✓ Embeddings loaded
Shape: torch.Size([128, 64])
Vocabulary size: 128
Hidden dimension: 64
Total parameters: 8,192
Memory footprint: 0.03 MB


## Find Singularities

Use `torch.unique()` to find all unique vectors and identify groups of tokens that share identical vectors.

In [4]:
print(f"\nSearching for singularities...\n")

# Find unique vectors
unique_vectors, inverse_indices, counts = torch.unique(
    gamma,
    dim=0,
    return_inverse=True,
    return_counts=True
)

n_unique = len(unique_vectors)
n_total = vocab_size
# Redundant copies only: a vector shared by k tokens contributes k - 1 here
n_duplicate = n_total - n_unique

print(f"Total tokens: {n_total:,}")
print(f"Unique vectors: {n_unique:,}")
print(f"Duplicate tokens: {n_duplicate:,}")
print(f"Uniqueness: {100 * n_unique / n_total:.2f}%\n")

if n_duplicate == 0:
    print("✓ No singularities found. Every token has a unique vector.")
else:
    print(f"⚠ Found {n_duplicate:,} duplicate tokens")


Searching for singularities...

Total tokens: 128
Unique vectors: 78
Duplicate tokens: 50
Uniqueness: 60.94%

⚠ Found 50 duplicate tokens
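The duplicate count above is total tokens minus unique vectors, so it counts redundant copies rather than group members: a vector shared by k tokens adds k - 1. A minimal sketch of that bookkeeping, using a hypothetical plain-Python `counts` list (one vector shared by 51 tokens plus 77 singletons, chosen to reproduce the totals printed in this notebook):

```python
# Hypothetical per-vector counts: one vector shared by 51 tokens, 77 singletons.
counts = [51] + [1] * 77

n_total = sum(counts)                              # all tokens
n_unique = len(counts)                             # distinct vectors
n_redundant = n_total - n_unique                   # copies beyond the first
n_in_groups = sum(c for c in counts if c > 1)      # tokens that share a vector
n_groups = sum(1 for c in counts if c > 1)         # vectors shared by 2+ tokens

# Each group keeps one "original", so the two tallies differ by the group count.
assert n_redundant == n_in_groups - n_groups
print(n_total, n_unique, n_redundant, n_in_groups)  # 128 78 50 51
```

This is why 50 duplicate tokens and a single 51-token group describe the same situation.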


## Analyze Singularity Groups

Group tokens by their shared vector and report statistics.

In [5]:
if n_duplicate > 0:
    # Build map from unique vector index to list of token IDs
    singularity_groups = defaultdict(list)
    
    for token_id, unique_idx in enumerate(inverse_indices.tolist()):
        if counts[unique_idx] > 1:  # Only include vectors shared by 2+ tokens
            singularity_groups[unique_idx].append(token_id)
    
    n_groups = len(singularity_groups)
    group_sizes = [len(tokens) for tokens in singularity_groups.values()]
    
    print(f"\nSingularity groups: {n_groups:,}")
    print(f"Largest group: {max(group_sizes):,} tokens")
    print(f"Smallest group: {min(group_sizes):,} tokens")
    print(f"Mean group size: {np.mean(group_sizes):.1f} tokens")
    print(f"Median group size: {np.median(group_sizes):.1f} tokens")
    
    # Histogram of group sizes
    size_counts = defaultdict(int)
    for size in group_sizes:
        size_counts[size] += 1
    
    print("\nGroup size distribution:")
    for size in sorted(size_counts.keys()):
        count = size_counts[size]
        print(f"  {size:4d} tokens: {count:4d} groups")


Singularity groups: 1
Largest group: 51 tokens
Smallest group: 51 tokens
Mean group size: 51.0 tokens
Median group size: 51.0 tokens

Group size distribution:
    51 tokens:    1 groups


## Display Singularity Token IDs

Show which token IDs are in each singularity group.

In [6]:
if n_duplicate > 0:
    # Sort groups by size (largest first)
    sorted_groups = sorted(
        singularity_groups.items(),
        key=lambda x: len(x[1]),
        reverse=True
    )
    
    print(f"\nDisplaying up to {MAX_TOKENS_PER_GROUP} tokens per group\n")
    print("=" * 80)
    
    for group_idx, (unique_idx, token_ids) in enumerate(sorted_groups, 1):
        n_tokens = len(token_ids)
        print(f"\nGroup {group_idx}/{n_groups}: {n_tokens} tokens sharing vector #{unique_idx}")
        print("-" * 80)
        
        # Show first N tokens
        display_tokens = token_ids[:MAX_TOKENS_PER_GROUP]
        
        # Show as byte values and ASCII chars where printable
        print(f"  Token IDs (byte values): {display_tokens}")
        
        # Show ASCII characters
        chars = []
        for token_id in display_tokens:
            if 32 <= token_id < 127:  # Printable ASCII
                chars.append(chr(token_id))
            else:
                chars.append(f"\\x{token_id:02x}")
        
        print(f"  Characters: {chars}")
        
        if n_tokens > MAX_TOKENS_PER_GROUP:
            print(f"  ... and {n_tokens - MAX_TOKENS_PER_GROUP} more tokens")
        
        print("=" * 80)


Displaying up to 20 tokens per group


Group 1/1: 51 tokens sharing vector #67
--------------------------------------------------------------------------------
  Token IDs (byte values): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
  Characters: ['\\x00', '\\x01', '\\x02', '\\x03', '\\x04', '\\x05', '\\x06', '\\x07', '\\x08', '\\x09', '\\x0b', '\\x0c', '\\x0d', '\\x0e', '\\x0f', '\\x10', '\\x11', '\\x12', '\\x13', '\\x14']
  ... and 31 more tokens
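The doubled backslashes in the listing above come from printing a Python list, which shows the `repr` of each string. A small sketch of one way to avoid that by joining the labels into a single string; the helper name `byte_label` and the sample IDs are hypothetical, not part of the notebook:

```python
def byte_label(token_id: int) -> str:
    """Return a printable label for a byte-level token ID."""
    if 32 <= token_id < 127:       # printable ASCII
        return chr(token_id)
    return f"\\x{token_id:02x}"    # escaped hex for control and non-ASCII bytes


token_ids = [0, 9, 11, 65, 97]     # sample IDs chosen for illustration
labels = [byte_label(t) for t in token_ids]
print(" ".join(labels))            # \x00 \x09 \x0b A a
```

Joining with a separator prints each label as written, at the cost of making a literal space token (ID 32) hard to see.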


## Compute Black Hole Geometry

If singularities exist, compute the geometric properties of the shared vectors (the "black holes", one per singularity group):
- L2 norms (distance from origin)
- Pairwise distances (how spread out are they?)
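The two quantities listed above can be sketched on a few made-up points. This version uses `torch.cdist` instead of explicit broadcasting and keeps each unordered pair once. The `vectors` tensor is hypothetical stand-in data, not taken from the embedding file.

```python
import torch

# Hypothetical stand-ins for black-hole vectors: three points in 2-D.
vectors = torch.tensor([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])

# L2 norm of each vector (distance from the origin).
norms = torch.linalg.norm(vectors, dim=1)

# (n, n) matrix of pairwise L2 distances; the diagonal is zero.
dists = torch.cdist(vectors, vectors, p=2)

# Keep each unordered pair once by reading the strict upper triangle.
i, j = torch.triu_indices(len(vectors), len(vectors), offset=1)
pair_dists = dists[i, j]

print(norms.tolist())
print(pair_dists.tolist())
```

Masking out only the diagonal, as the cell below does, keeps both (i, j) and (j, i). That leaves the min, max, mean, and median unchanged, since every distance simply appears twice.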

In [7]:
if n_duplicate > 0:
    print(f"\n{'='*80}")
    print("BLACK HOLE GEOMETRY")
    print(f"{'='*80}\n")
    
    # Get one representative from each singularity group
    black_hole_vectors = unique_vectors[counts > 1]
    
    # Compute L2 norms
    norms = torch.norm(black_hole_vectors, p=2, dim=1)
    
    print(f"Black hole L2 norms:")
    print(f"  Min: {norms.min().item():.6f}")
    print(f"  Max: {norms.max().item():.6f}")
    print(f"  Mean: {norms.mean().item():.6f}")
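    # Sample std is undefined for a single vector; report NaN without triggering a warning
    norms_std = norms.std().item() if len(norms) > 1 else float("nan")
    print(f"  Std: {norms_std:.6f}")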
    
    # Compute pairwise distances
    if len(black_hole_vectors) > 1:
        v1 = black_hole_vectors.unsqueeze(1)  # (n, 1, d)
        v2 = black_hole_vectors.unsqueeze(0)  # (1, n, d)
        diffs = v1 - v2  # (n, n, d)
        
        # L2 distances
        l2_distances = torch.norm(diffs, p=2, dim=2)
        
        # Mask out diagonal
        mask = ~torch.eye(len(black_hole_vectors), dtype=torch.bool)
        l2_nonzero = l2_distances[mask]
        
        print(f"\nPairwise L2 distances between black holes:")
        print(f"  Min: {l2_nonzero.min().item():.6e}")
        print(f"  Max: {l2_nonzero.max().item():.6e}")
        print(f"  Mean: {l2_nonzero.mean().item():.6e}")
        print(f"  Median: {l2_nonzero.median().item():.6e}")
    
    # Compare to full cloud
    full_norms = torch.norm(gamma, p=2, dim=1)
    centroid = gamma.mean(dim=0)
    centroid_norm = centroid.norm().item()
    
    print(f"\nComparison to full cloud:")
    print(f"  Full cloud L2 norms: mean={full_norms.mean().item():.6f}, std={full_norms.std().item():.6f}")
    print(f"  Centroid L2 norm: {centroid_norm:.6f}")
    print(f"  Black holes vs full cloud: {norms.mean().item() / full_norms.mean().item():.4f}× mean norm")


BLACK HOLE GEOMETRY

Black hole L2 norms:
  Min: 7.529632
  Max: 7.529632
  Mean: 7.529632
  Std: nan

Comparison to full cloud:
  Full cloud L2 norms: mean=7.749794, std=0.274956
  Centroid L2 norm: 7.731306
  Black holes vs full cloud: 0.9716× mean norm
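The `nan` reported for the standard deviation is expected when there is only one black hole: the sample standard deviation divides by n - 1, which is zero for a single value. A minimal sketch of the difference between the sample and population forms, with the single norm copied from the output above and variable names chosen for illustration:

```python
import math
import warnings

import torch

single = torch.tensor([7.529632])  # one norm, as printed above

with warnings.catch_warnings():
    warnings.simplefilter("ignore")    # PyTorch may warn about zero degrees of freedom
    sample_std = single.std().item()   # divides by n - 1 = 0

# Population form divides by n, so a single value has zero spread.
population_std = ((single - single.mean()) ** 2).mean().sqrt().item()

print(math.isnan(sample_std), population_std)
```

Neither number says anything about the spread between black holes here; with one group there is simply nothing to compare.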



## Summary

In [8]:
print(f"\n{'='*80}")
print("SUMMARY")
print(f"{'='*80}")
print(f"Embedding file: {EMBEDDING_FILE}")
print(f"Vocabulary size: {vocab_size:,}")
print(f"Unique vectors: {n_unique:,}")
print(f"Duplicate tokens: {n_duplicate:,} ({100 * n_duplicate / n_total:.2f}%)")

if n_duplicate > 0:
    print(f"Singularity groups: {n_groups:,}")
    print(f"Largest singularity: {max(group_sizes):,} tokens")
    print(f"Deduplication ratio: {n_total / n_unique:.2f}x")
else:
    print("No singularities detected.")

print(f"{'='*80}")


SUMMARY
Embedding file: step_0005000.safetensors
Vocabulary size: 128
Unique vectors: 78
Duplicate tokens: 50 (39.06%)
Singularity groups: 1
Largest singularity: 51 tokens
Deduplication ratio: 1.64x
