# 13.4a: Gatsby Float32 Edition

**The Critical Experiment**

## The Question

Does **float32 initialization + bfloat16 training + gradient escape** naturally produce:
1. The topology we see in Qwen (complete graph, L∞ ≤ 2ε)?
2. The demographics we see in Qwen (13 black holes with specific populations)?

## The Setup

Train a toy GPT-2 on The Great Gatsby corpus:
- 128 ASCII tokens
- **51 dead tokens** (control characters and math symbols that never appear in Gatsby)
- Initialize embeddings in **float32**: one shared random unit vector plus per-token Gaussian(0, σ) noise
- Train in **bfloat16** for `NUM_TRAIN_STEPS` steps (1,000 in the run recorded below)
- Live tokens escape via gradients
- Dead tokens stay frozen at f32→bf16 quantization boundaries
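
The f32→bf16 rounding step in the setup above can be illustrated on a single component. This is a minimal standard-library sketch, not the notebook's code: `to_bf16` is a hypothetical helper that emulates round-to-nearest-even truncation of a float32 bit pattern, and the two base values (0.125 and the midpoint above it) are chosen only for illustration.

```python
import random
import struct

def to_bf16(x: float) -> float:
    """Round a finite value to the nearest bfloat16 (ties to even). NaN/inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bias = 0x7FFF + ((bits >> 16) & 1)                    # round-to-nearest-even on the low 16 bits
    bits = (bits + bias) & 0xFFFF0000                     # keep sign, exponent, top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

SIGMA = 1e-5  # same value as INIT_SIGMA in the notebook
random.seed(0)

# Case 1: the base component sits exactly on a bf16 value.
# Noise of size ~1e-5 is far below the local spacing, so every sample rounds back.
on_lattice = {to_bf16(0.125 + random.gauss(0.0, SIGMA)) for _ in range(128)}

# Case 2: the base component sits on the midpoint between two bf16 neighbors.
# The sign of the noise now decides which neighbor each sample lands on.
midpoint = 0.125 + 2**-11
on_boundary = {to_bf16(midpoint + random.gauss(0.0, SIGMA)) for _ in range(128)}

print(sorted(on_lattice))
print(sorted(on_boundary))
```

In this sketch, tokens only become distinguishable in dimensions where the shared base vector happens to fall near a rounding boundary, which is one plausible reading of "frozen at f32→bf16 quantization boundaries".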

## What We'll Measure

At the final step (t = 1,000 in the run recorded below):
- How many black holes?
- Population demographics [n₁, n₂, n₃, ...]
- L∞ topology (adjacency density)
- Stripe structure (bf16 quantization)

If this matches Qwen's structure → **we've found the mechanism**.

## Parameters

In [51]:
# Model architecture (small for speed, matches 11.2a)
VOCAB_SIZE = 128
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training
BATCH_SIZE = 32
GRADIENT_ACCUMULATION = 1
NUM_TRAIN_STEPS = 1000
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01

# CRITICAL: float32 initialization
INIT_SIGMA = 1e-5  # Tunable; start conservative

# Checkpointing
SAVE_EVERY_N_STEPS = 1000

# Data
CORPUS_PATH = "../data/training_corpus.txt"
OUTPUT_DIR = f"../data/embeddings_gatsby_f32_sigma{INIT_SIGMA:.0e}"
OUTPUT_FILE = "embedding_evolution.safetensors"

RANDOM_SEED = 42

## Imports

In [52]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, TrainerCallback
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file
import time

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
print("✓ Imports complete")

Using device: mps
✓ Imports complete


## Load Gatsby Corpus

In [53]:
print(f"Loading corpus: {CORPUS_PATH}")

with open(CORPUS_PATH, 'r', encoding='ascii') as f:
    corpus_text = f.read()

corpus_bytes = [b for b in corpus_text.encode('ascii') if b < VOCAB_SIZE]

unique_bytes = set(corpus_bytes)
dead_tokens = sorted(set(range(VOCAB_SIZE)) - unique_bytes)

print(f"  Total bytes: {len(corpus_bytes):,}")
print(f"  Unique tokens in corpus: {len(unique_bytes)} / {VOCAB_SIZE}")
print(f"  Dead tokens: {len(dead_tokens)} ({100 * len(dead_tokens) / VOCAB_SIZE:.1f}%)")
print(f"\nDead token IDs: {dead_tokens[:20]}{'...' if len(dead_tokens) > 20 else ''}")

# Pre-load to device
corpus_tensor = torch.tensor(corpus_bytes, dtype=torch.long, device=device)
print(f"\n✓ Corpus on device")

Loading corpus: ../data/training_corpus.txt
  Total bytes: 265,905
  Unique tokens in corpus: 77 / 128
  Dead tokens: 51 (39.8%)

Dead token IDs: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]...

✓ Corpus on device


## Dataset

In [54]:
class ByteDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
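    def __getitem__(self, idx):
        # Window of max_seq_len + 1 tokens: inputs are the first max_seq_len,
        # labels are the same window moved one token ahead.
        # NOTE: GPT2LMHeadModel also shifts labels by one internally when computing
        # the loss, so with pre-shifted labels position i is scored against token i + 2.
        # The run recorded below used this code as-is.
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }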

dataset = ByteDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"✓ Dataset: {len(dataset):,} examples")

✓ Dataset: 265,777 examples
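
The window arithmetic and the label offset can be checked on plain lists. This is a sketch under stated assumptions: `make_example` mirrors `ByteDataset.__getitem__`, while `internal_shift_targets` is a hypothetical stand-in for the one-position label shift that Hugging Face causal-LM heads apply inside the loss. The stand-in token ids are arbitrary.

```python
def make_example(corpus, idx, max_seq_len):
    """Mirror of ByteDataset.__getitem__ on a plain Python list."""
    chunk = corpus[idx : idx + max_seq_len + 1]
    return {"input_ids": chunk[:-1], "labels": chunk[1:]}

def internal_shift_targets(labels):
    """Stand-in for the in-model shift: position i is scored against labels[i + 1]."""
    return labels[1:]

# Number of sliding windows, using the corpus size and MAX_SEQ_LEN reported above
n_examples = 265_905 - 128

corpus = list(range(100, 120))  # distinct stand-in token ids so offsets are easy to read
ex = make_example(corpus, idx=0, max_seq_len=8)

# Offset between the token at position 0 and the target it is scored against
pre_shifted_offset = internal_shift_targets(ex["labels"])[0] - ex["input_ids"][0]
unshifted_offset = internal_shift_targets(ex["input_ids"])[0] - ex["input_ids"][0]

print(n_examples, pre_shifted_offset, unshifted_offset)
```

If the next-token objective is the intent, passing `labels` equal to `input_ids` gives an offset of one; pre-shifted labels give an offset of two.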


## Model

In [55]:
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,
)

model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model created: {total_params:,} parameters")

✓ Model created: 116,480 parameters
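
The reported parameter count can be reproduced by hand. The sketch below assumes the standard GPT-2 block layout (fused QKV projection, MLP inner width of 4× the hidden size, biases on every linear layer, no bias on the output head); `gpt2_param_count` is a hypothetical helper, not part of `transformers`.

```python
def gpt2_param_count(vocab, n_pos, d, n_layer, tied=True):
    """Count parameters of a GPT-2 style model with hidden size d."""
    embeddings = vocab * d + n_pos * d             # token (wte) + position (wpe) embeddings
    attn = (d * 3 * d + 3 * d) + (d * d + d)       # fused QKV projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)    # up-projection + down-projection
    layer_norms = 2 * (2 * d)                      # two LayerNorms per block, scale + shift
    per_layer = attn + mlp + layer_norms
    final_ln = 2 * d
    lm_head = 0 if tied else vocab * d             # tied head reuses wte, so it adds nothing
    return embeddings + n_layer * per_layer + final_ln + lm_head

total = gpt2_param_count(vocab=128, n_pos=128, d=64, n_layer=2)
print(f"{total:,}")
```

With tied embeddings the token matrix is counted once, which is why only 8,192 of the 116,480 parameters belong to the embeddings under study.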


## **CRITICAL: Float32 Initialization**

Initialize embeddings in **float32**, then convert to bfloat16 for training.

In [56]:
print(f"\nFloat32 initialization (σ = {INIT_SIGMA:.2e})\n")

with torch.no_grad():
    # Generate random unit vector in float32
    random_vector = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device)
    random_vector = random_vector / random_vector.norm()
    
    # Add Gaussian noise in float32
    noise = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32, device=device) * INIT_SIGMA
    init_f32 = random_vector + noise
    
    # Convert to bfloat16 for training
    init_bf16 = init_f32.to(torch.bfloat16)
    
    # Assign to model
    model.transformer.wte.weight[:] = init_bf16
    
    # Stats
    f32_norms = torch.norm(init_f32, p=2, dim=1)
    bf16_norms = torch.norm(init_bf16.float(), p=2, dim=1)
    f32_centroid = init_f32.mean(dim=0)
    bf16_centroid = init_bf16.float().mean(dim=0)
    
    print(f"Float32 init:")
    print(f"  Base vector norm: {random_vector.norm().item():.6f}")
    print(f"  Token norms: {f32_norms.min().item():.6f} to {f32_norms.max().item():.6f}")
    print(f"  Centroid norm: {f32_centroid.norm().item():.6f}")
    
    print(f"\nAfter bf16 conversion:")
    print(f"  Token norms: {bf16_norms.min().item():.6f} to {bf16_norms.max().item():.6f}")
    print(f"  Centroid norm: {bf16_centroid.norm().item():.6f}")
    print(f"  Centroid shift: {(bf16_centroid - f32_centroid).norm().item():.6e}")
    
    # Store initial state for analysis
    initial_embeddings_f32 = init_f32.cpu()
    initial_embeddings_bf16 = init_bf16.cpu()

print(f"\n✓ Embeddings initialized in float32, converted to bfloat16")


Float32 initialization (σ = 1.00e-05)

Float32 init:
  Base vector norm: 1.000000
  Token norms: 0.999968 to 1.000023
  Centroid norm: 0.999999

After bf16 conversion:
  Token norms: 1.000109 to 1.000240
  Centroid norm: 1.000153
  Centroid shift: 1.203360e-03

✓ Embeddings initialized in float32, converted to bfloat16
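
The same initialization can be emulated without PyTorch to see how many distinct vectors survive the cast. This is an illustrative sketch only: it uses Python's `random` module (so the vectors differ from the run above), adds the noise in float64 before rounding, and relies on a hypothetical `to_bf16` helper.

```python
import math
import random
import struct
from collections import Counter

def to_bf16(x: float) -> float:
    """Round a finite value to the nearest bfloat16 (ties to even). NaN/inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

VOCAB, DIM, SIGMA = 128, 64, 1e-5  # mirror VOCAB_SIZE, HIDDEN_DIM, INIT_SIGMA
random.seed(42)

# One shared random unit vector
base = [random.gauss(0.0, 1.0) for _ in range(DIM)]
norm = math.sqrt(sum(v * v for v in base))
base = [v / norm for v in base]

# Per-token Gaussian noise, then quantize every component to bf16
rows = [
    tuple(to_bf16(b + random.gauss(0.0, SIGMA)) for b in base)
    for _ in range(VOCAB)
]

# Populations of identical vectors immediately after the cast (before any training)
populations = sorted(Counter(rows).values(), reverse=True)
print(len(populations), populations[:10])
```

Counting identical rows at step 0 gives a baseline: any clustering already present here comes from quantization alone rather than from training.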


## Pre-allocate Embedding History

In [57]:
embedding_history = torch.zeros(
    (NUM_TRAIN_STEPS + 1, VOCAB_SIZE, HIDDEN_DIM),
    dtype=torch.bfloat16
)

embedding_history[0] = model.transformer.wte.weight.data.clone().cpu()

history_size_mb = embedding_history.element_size() * embedding_history.numel() / 1e6
print(f"✓ Embedding history allocated: {history_size_mb:.1f} MB")

✓ Embedding history allocated: 16.4 MB
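
The allocation size follows directly from the tensor shape. A short worked check, assuming 2 bytes per bfloat16 element; the 10,000-step figure is a projection for a longer run, not a measured value.

```python
steps, vocab, dim = 1000, 128, 64   # NUM_TRAIN_STEPS, VOCAB_SIZE, HIDDEN_DIM
bytes_per_bf16 = 2

# One snapshot per step plus the initial state at index 0
history_bytes = (steps + 1) * vocab * dim * bytes_per_bf16
print(f"{history_bytes:,} bytes = {history_bytes / 1e6:.1f} MB")

# Projected size if the same history were kept for a 10,000-step run
bytes_at_10k = (10_000 + 1) * vocab * dim * bytes_per_bf16
print(f"{bytes_at_10k / 1e6:.1f} MB")
```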


## Snapshot Callback

In [58]:
class EmbeddingSnapshotCallback(TrainerCallback):
    def __init__(self, embedding_history, output_dir, output_file, save_every_n):
        self.embedding_history = embedding_history
        self.output_path = Path(output_dir) / output_file
        self.save_every_n = save_every_n
        self.last_time = time.time()
        self.steps_since_print = 0
        
        Path(output_dir).mkdir(parents=True, exist_ok=True)
    
    def on_step_end(self, args, state, control, model=None, **kwargs):
        step = state.global_step
        
        # Store in memory
        self.embedding_history[step] = model.transformer.wte.weight.data.clone().cpu()
        
        # Save to disk periodically
        should_save = (step % self.save_every_n == 0) or (step == args.max_steps)
        if should_save:
            save_file(
                {'embedding_history': self.embedding_history[:step+1]},
                self.output_path
            )
        
        # Print every 1000 steps
        self.steps_since_print += 1
        if step % 1000 == 0 and step > 0:
            elapsed = time.time() - self.last_time
            throughput = self.steps_since_print / elapsed
            
            embeddings = self.embedding_history[step]
            centroid_norm = embeddings.float().mean(dim=0).norm().item()
            
            marker = "[SAVED]" if should_save else ""
            print(f"[{step:5d}] {throughput:6.1f} it/s | centroid: {centroid_norm:.6f} {marker}")
            
            self.last_time = time.time()
            self.steps_since_print = 0
        
        return control

print(f"✓ Callback ready")

✓ Callback ready


## Training Configuration

In [59]:
training_args = TrainingArguments(
    output_dir="./training_output",
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    logging_steps=100,
    save_steps=NUM_TRAIN_STEPS + 1,
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[EmbeddingSnapshotCallback(
        embedding_history,
        OUTPUT_DIR,
        OUTPUT_FILE,
        SAVE_EVERY_N_STEPS
    )],
)

print("✓ Trainer ready")

✓ Trainer ready


## Train

In [60]:
print(f"\n{'='*80}")
print(f"Starting training (σ = {INIT_SIGMA:.2e})")
print(f"{'='*80}\n")

start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete ({elapsed/60:.1f} min)")
print(f"{'='*80}")


Starting training (σ = 1.00e-05)



Step,Training Loss
100,3.3425
200,3.0786
300,3.011
400,2.9101
500,2.8663
600,2.844
700,2.8238
800,2.8061
900,2.8041
1000,2.8012


[ 1000]  106.0 it/s | centroid: 0.828891 [SAVED]

✓ Training complete (0.2 min)


## Final Save

In [61]:
output_path = Path(OUTPUT_DIR) / OUTPUT_FILE
save_file({
    'embedding_history': embedding_history,
    'dead_token_ids': torch.tensor(dead_tokens, dtype=torch.long),
    'init_sigma': torch.tensor(INIT_SIGMA, dtype=torch.float32),
}, output_path)

print(f"✓ Saved to: {output_path}")
print(f"  File size: {output_path.stat().st_size / 1e6:.1f} MB")

✓ Saved to: ../data/embeddings_gatsby_f32_sigma1e-05/embedding_evolution.safetensors
  File size: 16.4 MB


## Analysis: Dead Token Structure at the Final Step (t=1,000)

In [62]:
print(f"\n{'='*80}")
print(f"DEAD TOKEN ANALYSIS (t = {NUM_TRAIN_STEPS:,})")
print(f"{'='*80}\n")

# Extract dead token embeddings
final_embeddings = embedding_history[-1].float()
dead_embeddings = final_embeddings[dead_tokens]

print(f"Dead tokens: {len(dead_tokens)}")
print(f"Dead embedding shape: {dead_embeddings.shape}")

# Compute unique vectors (using exact equality for bf16)
unique_vectors = []
populations = []

for vec in dead_embeddings:
    found = False
    for i, unique_vec in enumerate(unique_vectors):
        if torch.equal(vec, unique_vec):
            populations[i] += 1
            found = True
            break
    if not found:
        unique_vectors.append(vec)
        populations.append(1)

# Sort by population (descending)
sorted_pops = sorted(populations, reverse=True)
black_holes = [p for p in sorted_pops if p >= 2]
singletons = [p for p in sorted_pops if p == 1]

print(f"\nBlack holes: {len(black_holes)}")
print(f"Singletons: {len(singletons)}")
print(f"Unique vectors: {len(unique_vectors)}")
print(f"\nDemographics: {sorted_pops[:20]}{'...' if len(sorted_pops) > 20 else ''}")


DEAD TOKEN ANALYSIS (t = 1,000)

Dead tokens: 51
Dead embedding shape: torch.Size([51, 64])

Black holes: 6
Singletons: 7
Unique vectors: 13

Demographics: [19, 14, 3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1]
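
The nested-loop grouping above is quadratic in the number of unique vectors. A hash-based count gives the same populations in one pass; the sketch below works on plain tuples, and `demographics` is a hypothetical helper name. On tensors, `torch.unique(dead_embeddings, dim=0, return_counts=True)` should give equivalent counts, though that path was not used in the recorded run.

```python
from collections import Counter

def demographics(vectors):
    """Group exactly-equal vectors and return populations, black holes, and singletons."""
    counts = Counter(tuple(v) for v in vectors)
    pops = sorted(counts.values(), reverse=True)
    black_holes = [p for p in pops if p >= 2]   # points shared by two or more tokens
    singletons = [p for p in pops if p == 1]
    return pops, black_holes, singletons

# Toy input: three tokens at one point, two at another, one alone
toy = [(0.5, 0.25)] * 3 + [(0.5, 0.125)] * 2 + [(0.0, 1.0)]
pops, black_holes, singletons = demographics(toy)
print(pops, len(black_holes), len(singletons))

# Consistency check on the populations reported above
reported = [19, 14, 3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1]
print(sum(reported), sum(p >= 2 for p in reported), sum(p == 1 for p in reported))
```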


## Topology: Adjacency Graph

In [63]:
print(f"\n{'='*80}")
print(f"TOPOLOGY ANALYSIS")
print(f"{'='*80}\n")

# Stack unique vectors
n_unique = len(unique_vectors)

if n_unique == 0:
    print("No unique vectors found (all dead tokens identical or none exist)")
elif n_unique == 1:
    print("Only 1 unique vector (complete singularity - all dead tokens collapsed to one point)")
    print(f"  Population: {populations[0]}")
else:
    unique_stack = torch.stack(unique_vectors)
    
    # Compute pairwise L∞ distances
    print(f"Computing pairwise L∞ distances for {n_unique} unique vectors...")
    linf_matrix = torch.zeros(n_unique, n_unique)
    
    for i in range(n_unique):
        diff = torch.abs(unique_stack[i].unsqueeze(0) - unique_stack)
        linf_matrix[i] = torch.max(diff, dim=1)[0]
    
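    # Define adjacency threshold as 2ε, where ε = (centroid L2 norm) × 2^-7.
    # 2^-7 is the bf16 spacing for values in [1, 2). Because no centroid component
    # exceeds the centroid norm, ε is an upper bound on any centroid component's ULP.
    centroid = dead_embeddings.mean(dim=0)
    centroid_norm = centroid.norm().item()
    epsilon = centroid_norm * 2**(-7)  # norm-scaled bf16 spacing, not a per-component ULP
    adjacency_threshold = 2 * epsilon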
    
    print(f"\nCentroid norm: {centroid_norm:.6f}")
    print(f"ULP (ε): {epsilon:.6e}")
    print(f"Adjacency threshold (2ε): {adjacency_threshold:.6e}")
    
    # Adjacency matrix
    adjacency = (linf_matrix <= adjacency_threshold)
    n_edges = (adjacency.sum() - n_unique).item() / 2  # exclude diagonal, count each edge once
    max_edges = n_unique * (n_unique - 1) / 2
    density = n_edges / max_edges if max_edges > 0 else 0.0
    
    print(f"\nAdjacency graph:")
    print(f"  Nodes: {n_unique}")
    print(f"  Edges: {int(n_edges)} / {int(max_edges)}")
    print(f"  Density: {density:.6f}")
    
    if density >= 0.999:
        print(f"\n  ✓ FULLY CONNECTED (complete graph!)")
    elif density >= 0.5:
        print(f"\n  ~ DENSELY CONNECTED")
    else:
        print(f"\n  ✗ SPARSE")
    
    # L∞ statistics
    upper_tri_indices = torch.triu_indices(n_unique, n_unique, offset=1)
    pairwise_linf = linf_matrix[upper_tri_indices[0], upper_tri_indices[1]]
    
    print(f"\nPairwise L∞ distances:")
    print(f"  Min: {pairwise_linf.min().item():.6e}")
    print(f"  Max: {pairwise_linf.max().item():.6e}")
    print(f"  Mean: {pairwise_linf.mean().item():.6e}")
    print(f"  Max / ε: {(pairwise_linf.max().item() / epsilon):.2f} ULP")


TOPOLOGY ANALYSIS

Computing pairwise L∞ distances for 13 unique vectors...

Centroid norm: 0.842696
ULP (ε): 6.583559e-03
Adjacency threshold (2ε): 1.316712e-02

Adjacency graph:
  Nodes: 13
  Edges: 78 / 78
  Density: 1.000000

  ✓ FULLY CONNECTED (complete graph!)

Pairwise L∞ distances:
  Min: 2.441406e-04
  Max: 9.765625e-04
  Mean: 8.670122e-04
  Max / ε: 0.15 ULP
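
The reported distances are easier to read against the bf16 spacing at component scale rather than at the scale of the centroid norm. The sketch below is an assumption-labeled aid: `bf16_ulp` is a hypothetical helper for normal-range values, and which component magnitude actually produced the maximum distance is not known from the output above.

```python
import math

def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bfloat16 values at magnitude |x| (normal range, x != 0)."""
    _, e = math.frexp(abs(x))    # |x| = m * 2**e with 0.5 <= m < 1
    return 2.0 ** (e - 8)        # 2**floor(log2|x|) * 2**-7

max_linf = 9.765625e-04          # reported maximum pairwise L∞ distance
min_linf = 2.441406e-04          # reported minimum (printed to 7 significant digits)
eps_norm = 0.842696 * 2**-7      # ε as defined in the cell above, from the reported norm

print(f"max L∞ / ε          = {max_linf / eps_norm:.2f}")
print(f"max L∞ / ULP(0.2)   = {max_linf / bf16_ulp(0.2):.0f}")
print(f"min L∞ / ULP(0.04)  = {min_linf / bf16_ulp(0.04):.0f}")
```

Both reported extremes are exact powers of two (2^-10 and 2^-12), so each corresponds to a whole number of lattice steps at some component magnitude. For example, 2^-10 is exactly one step for a component between 0.125 and 0.25.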


## Quantization Check: Are Dead Tokens on BF16 Lattice?

In [64]:
print(f"\n{'='*80}")
print(f"QUANTIZATION VERIFICATION")
print(f"{'='*80}\n")

if len(dead_embeddings) == 0:
    print("No dead token embeddings to verify")
else:
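    # Round-trip test: bf16 → f32 → bf16.
    # embedding_history is stored as bf16, so this passes by construction; it
    # confirms the storage dtype rather than a property learned in training.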
    dead_bf16 = dead_embeddings.to(torch.bfloat16)
    roundtrip = dead_bf16.float()
    
    max_diff = (dead_embeddings - roundtrip).abs().max().item()
    
    if max_diff == 0.0:
        print("✓ CONFIRMED: Dead tokens are bfloat16-quantized (bit-for-bit match)")
    else:
        print(f"✗ NOT quantized (max diff: {max_diff:.6e})")
    
    # Component value analysis
    all_components = dead_embeddings.flatten()
    unique_components = torch.unique(all_components)
    
    print(f"\nComponent statistics:")
    print(f"  Total components: {len(all_components):,}")
    print(f"  Unique values: {len(unique_components):,}")
    print(f"  Range: [{all_components.min().item():.6e}, {all_components.max().item():.6e}]")


QUANTIZATION VERIFICATION

✓ CONFIRMED: Dead tokens are bfloat16-quantized (bit-for-bit match)

Component statistics:
  Total components: 3,264
  Unique values: 64
  Range: [-4.179688e-01, 2.080078e-01]


## Summary

In [65]:
print(f"\n{'='*80}")
print(f"SUMMARY")
print(f"{'='*80}\n")

print(f"Initialization: σ = {INIT_SIGMA:.2e} (float32 → bfloat16)")
print(f"Training: {NUM_TRAIN_STEPS:,} steps")
print(f"\nDead tokens ({len(dead_tokens)}):")
print(f"  Black holes: {len(black_holes)}")
print(f"  Singletons: {len(singletons)}")
print(f"  Demographics: {sorted_pops[:15]}{'...' if len(sorted_pops) > 15 else ''}")

if n_unique > 1:
    print(f"\nTopology:")
    print(f"  Unique vectors: {n_unique}")
    print(f"  Adjacency density: {density:.6f}")
    print(f"  Max L∞: {pairwise_linf.max().item():.6e} ({pairwise_linf.max().item() / epsilon:.2f} ULP)")
elif n_unique == 1:
    print(f"\nTopology: COMPLETE SINGULARITY (all {populations[0]} dead tokens at one point)")

if len(dead_embeddings) > 0:
    print(f"\nQuantization: {'✓ bfloat16' if max_diff == 0.0 else '✗ NOT bf16'}")

print(f"\n{'='*80}")


SUMMARY

Initialization: σ = 1.00e-05 (float32 → bfloat16)
Training: 1,000 steps

Dead tokens (51):
  Black holes: 6
  Singletons: 7
  Demographics: [19, 14, 3, 3, 3, 2, 1, 1, 1, 1, 1, 1, 1]

Topology:
  Unique vectors: 13
  Adjacency density: 1.000000
  Max L∞: 9.765625e-04 (0.15 ULP)

Quantization: ✓ bfloat16



## Interpretation

**If this matches Qwen:**
- Black hole count ~13
- Demographics similar to [814, 704, 306, 228, ...]
- Adjacency density = 1.0 (complete graph)
- Max L∞ ≤ 2ε
- bfloat16-quantized

**Then we've found the mechanism:**
- Float32 initialization creates slight variation
- bfloat16 training quantizes to lattice
- Dead tokens stay frozen at f32→bf16 boundaries
- Demographics reflect token frequency in corpus

**Next steps if successful:**
- Sweep σ to find best match
- Scale up to larger vocab (2,221 tokens like Qwen)
- Test on other corpora