# Crucible 1: The Full Evolution

A clean-slate experiment to capture the complete lifecycle of dead token dynamics, from supernova through Fimbulwinter.

**What's new:**
- Hyperparameters tuned for faster freezing (lr=1e-3, no weight decay)
- Lattice displacement ΔW′ computed live, not post-hoc
- Last-motion tracker to pinpoint when each token freezes
- Centroid tracking for cloud drift analysis

**Data captured:**
- W[t]: dead token embeddings (bfloat16)
- ΔW′[t]: displacement W[t+1] − W[t] in lattice-cell units, i.e. divided by the bfloat16 ULP at W[t] (float32)
- centroid[t]: cloud center of mass (float32)
- loss[t]: training loss
- last_motion_step: per-token freeze time, i.e. the last step at which the token moved (updated in place during training; -1 if it never moved)

**Goal:** Validate that we capture the full evolution—supernova, cooling, stumbling, Fimbulwinter—with statistical confidence that we've reached the true end state.

## Parameters

In [1]:
# ---- Optimization --------------------------------------------------------
TOTAL_STEPS = 5_000
BATCH_SIZE = 128
SEQ_LEN = 128
LEARNING_RATE = 1e-3        # Faster than Thimble 8/9, matches Thimble 7
WEIGHT_DECAY = 0.0          # No weight decay—removes confounding variable
BETA1, BETA2 = 0.9, 0.999   # AdamW moment decay rates
EPSILON = 1e-8              # AdamW denominator epsilon

# ---- Model architecture --------------------------------------------------
VOCAB_SIZE = 10_000
HIDDEN_DIM = 64
NUM_LAYERS = 2
NUM_HEADS = 2

# ---- Reproducibility -----------------------------------------------------
RANDOM_SEED = 42

# ---- Freeze detection ----------------------------------------------------
# A token counts as "moved" at a step when the L2 norm of its ΔW′ row
# (lattice-cell units) is >= this value; below it, the token is "frozen".
MOTION_THRESHOLD = 0.1

# ---- Paths (relative; resolved against the kernel's working directory) ---
CORPUS_PATH = '../../data/flannel_model_corpus.txt'
TOKENIZER_PATH = '../../data/flannel_tokenizer_chars.json'
DEAD_MASK_PATH = '../../tensors/Flannel/live_dead_tokens.safetensors'
OUTPUT_DIR = '../../tensors/Crucible-1'

print(f"Crucible 1: {TOTAL_STEPS} steps, lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}")

Crucible 1: 5000 steps, lr=0.001, weight_decay=0.0


## Imports

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from safetensors.torch import save_file, load_file
from pathlib import Path
from tqdm import tqdm
import json
import random
import numpy as np

## Device Detection

In [3]:
# Pick the first available backend in priority order: CUDA, Apple MPS, CPU.
# The generator is lazy, so later availability checks only run if needed.
device = next(
    backend
    for backend, is_available in (
        ('cuda', torch.cuda.is_available),
        ('mps', torch.backends.mps.is_available),
        ('cpu', lambda: True),
    )
    if is_available()
)

print(f"Using device: {device}")

Using device: mps


## Set Random Seeds

In [4]:
# Seed every RNG this notebook draws from: torch (CPU), Python `random`, NumPy.
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(RANDOM_SEED)

# On a GPU run, also seed all CUDA devices explicitly.
if device == 'cuda':
    torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"Random seed set to {RANDOM_SEED}")

Random seed set to 42


## Load Tokenizer and Dead Token Mask

In [5]:
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
print(f"Loaded tokenizer with vocab size {tokenizer.get_vocab_size()}")

# Token ids produced by the tokenizer index the model's VOCAB_SIZE-row
# embedding table; a larger tokenizer vocabulary could index past it.
assert tokenizer.get_vocab_size() <= VOCAB_SIZE, (
    f"tokenizer vocab {tokenizer.get_vocab_size()} exceeds VOCAB_SIZE={VOCAB_SIZE}"
)

masks = load_file(DEAD_MASK_PATH)
dead_mask = masks['dead_mask'].bool()
dead_indices = masks['dead_indices'].long()

# dead_mask is later used to select rows of the embedding weight, so it needs
# exactly one entry per vocabulary id; fail fast here rather than mid-training.
assert dead_mask.shape == (VOCAB_SIZE,), (
    f"dead_mask has shape {tuple(dead_mask.shape)}, expected ({VOCAB_SIZE},)"
)
n_dead = dead_mask.sum().item()

print(f"Dead tokens: {n_dead}")
print(f"Live tokens: {VOCAB_SIZE - n_dead}")

Loaded tokenizer with vocab size 10000
Dead tokens: 3699
Live tokens: 6301


## Dataset

In [6]:
class TextDataset(Dataset):
    """Non-overlapping, fixed-length windows of token ids from a text corpus.

    The whole corpus is read as UTF-8 and tokenized once at construction.
    Trailing tokens that do not fill a complete window are dropped.
    """

    def __init__(self, corpus_path, tokenizer, seq_len):
        """
        Args:
            corpus_path: Path to a UTF-8 text file.
            tokenizer: Object whose ``encode(text)`` returns something with an
                ``ids`` sequence of token ids.
            seq_len: Number of tokens per returned window.
        """
        with open(corpus_path, 'r', encoding='utf-8') as f:
            text = f.read()
        encoding = tokenizer.encode(text)
        self.tokens = encoding.ids
        self.seq_len = seq_len
        self.n_sequences = len(self.tokens) // seq_len

    def __len__(self):
        return self.n_sequences

    def __getitem__(self, idx):
        """Return window ``idx`` as a 1-D ``torch.long`` tensor of length ``seq_len``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is out of range. Without this check, slicing
                would silently return a short or empty window.
        """
        if idx < 0:
            idx += self.n_sequences
        if not 0 <= idx < self.n_sequences:
            raise IndexError(
                f"sequence index out of range for dataset of {self.n_sequences} sequences"
            )
        start = idx * self.seq_len
        chunk = self.tokens[start:start + self.seq_len]
        return torch.tensor(chunk, dtype=torch.long)


dataset = TextDataset(corpus_path=CORPUS_PATH, tokenizer=tokenizer, seq_len=SEQ_LEN)
print("Dataset: {} sequences of length {}".format(len(dataset), SEQ_LEN))

Dataset: 10713 sequences of length 128


## Model Definition

In [7]:
class TinyLM(nn.Module):
    """Small causal transformer LM with tied input/output token embeddings.

    Token and learned positional embeddings are summed and run through a
    pre-norm ``nn.TransformerEncoder`` under a causal mask. A final LayerNorm
    follows, and the result is projected onto the vocabulary with the
    token-embedding matrix.

    Args:
        vocab_size: Number of rows in the token-embedding table.
        hidden_dim: Model width.
        num_layers: Number of transformer layers.
        num_heads: Attention heads per layer.
        seq_len: Maximum supported sequence length (size of the positional table).
    """

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, seq_len):
        super().__init__()
        # NOTE: the order in which submodules are built and initialized decides
        # how the seeded RNG is consumed; keep it unchanged so runs reproduce.
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(seq_len, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln_f = nn.LayerNorm(hidden_dim)

        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.pos_embedding.weight, mean=0.0, std=0.02)

    def forward(self, input_ids):
        """Return next-token logits of shape (batch, seq, vocab_size).

        Args:
            input_ids: 2-D integer tensor of shape (batch, seq), where seq is at
                most the ``seq_len`` given at construction.

        Raises:
            ValueError: If ``input_ids`` is not 2-D, or if the sequence is longer
                than the positional-embedding table.
        """
        _, seq_len = input_ids.shape
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds maximum {max_len}")

        device = input_ids.device
        tok_emb = self.embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_ids)
        hidden = tok_emb + pos_emb

        # Additive mask (-inf above the diagonal). is_causal=True tells PyTorch
        # this mask is causal, which lets it choose a causal attention path.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
        hidden = self.transformer(hidden, mask=causal_mask, is_causal=True)
        hidden = self.ln_f(hidden)

        # Weight tying: the output projection reuses the token-embedding matrix.
        return hidden @ self.embedding.weight.T


# Build on the target device, then cast every parameter (embeddings included) to bfloat16.
model = (
    TinyLM(VOCAB_SIZE, HIDDEN_DIM, NUM_LAYERS, NUM_HEADS, SEQ_LEN)
    .to(device)
    .to(torch.bfloat16)
)

print(f"Model parameters: {sum(map(torch.numel, model.parameters())):,}")

Model parameters: 748,288




## Optimizer

In [8]:
# AdamW over every model parameter; with WEIGHT_DECAY = 0.0 this is plain Adam.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    weight_decay=WEIGHT_DECAY,
    eps=EPSILON,
    betas=(BETA1, BETA2),
    lr=LEARNING_RATE,
)

print(f"Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})")

Optimizer: AdamW (lr=0.001, weight_decay=0.0)


## ULP Computation

In [9]:
def compute_ulp_bf16(tensor_bf16: torch.Tensor) -> torch.Tensor:
    """Compute ULP for each element of a bfloat16 tensor."""
    bits = tensor_bf16.view(torch.uint16).to(torch.int32)
    exponent = (bits >> 7) & 0xFF
    effective_exp = torch.where(exponent == 0, torch.ones_like(exponent), exponent)
    ulp = torch.pow(2.0, (effective_exp - 134).float())
    return ulp

## Data Collection Setup

In [10]:
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

# CPU buffers filled step by step during training. Index 0 of the W, centroid
# and loss histories holds the state before the first optimizer step.
W_history = torch.zeros((TOTAL_STEPS + 1, n_dead, HIDDEN_DIM), dtype=torch.bfloat16)
delta_W_prime_history = torch.zeros((TOTAL_STEPS, n_dead, HIDDEN_DIM), dtype=torch.float32)
centroid_history = torch.zeros((TOTAL_STEPS + 1, HIDDEN_DIM), dtype=torch.float32)
loss_history = torch.zeros(TOTAL_STEPS + 1, dtype=torch.float32)

# Last step at which each dead token moved; -1 is the "never moved" sentinel.
last_motion_step = torch.full((n_dead,), -1, dtype=torch.int32)

# Memory estimate: element count times bytes per element, summed over buffers.
total_bytes = sum(
    buf.numel() * buf.element_size()
    for buf in (
        W_history,
        delta_W_prime_history,
        centroid_history,
        loss_history,
        last_motion_step,
    )
)
print(f"Pre-allocated {total_bytes / 1e9:.2f} GB for data collection")

Pre-allocated 7.10 GB for data collection


## Training Loop

In [11]:
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)

loss_fn = nn.CrossEntropyLoss()


def snapshot_dead_embeddings():
    """Copy the dead-token rows of the embedding matrix to CPU as bfloat16.

    Rows are selected with the boolean ``dead_mask``, so they appear in
    ascending token-id order.
    """
    return model.embedding.weight.detach().cpu()[dead_mask].to(torch.bfloat16)


# Capture initial state (t=0); there is no loss before the first step.
W_prev = snapshot_dead_embeddings()
W_history[0] = W_prev
centroid_history[0] = W_prev.float().mean(dim=0)
loss_history[0] = float('nan')

print(f"Starting training for {TOTAL_STEPS} steps...")

model.train()
step = 0
epoch = 0

pbar = tqdm(total=TOTAL_STEPS, desc="Training")

while step < TOTAL_STEPS:
    epoch += 1

    for batch in dataloader:
        if step >= TOTAL_STEPS:
            break

        input_ids = batch.to(device)

        # No autocast here. The model was cast to bfloat16 wholesale, so the
        # forward pass and loss already run in bf16. The previous
        # `torch.autocast(device_type='cpu')` used on MPS never touched MPS
        # tensors. On CUDA/CPU, autocast would upcast some ops (e.g. the loss)
        # to float32, so the numerics would depend on the device.
        logits = model(input_ids)
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        loss = loss_fn(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1
        loss_value = loss.item()  # one host sync per step, reused below

        # === Capture state ===
        W_curr = snapshot_dead_embeddings()
        W_history[step] = W_curr

        # ΔW′ = (W[t] - W[t-1]) / ULP(W[t-1]): displacement in lattice cells.
        delta_W = W_curr.float() - W_prev.float()
        ulp = compute_ulp_bf16(W_prev)
        delta_W_prime = delta_W / ulp
        delta_W_prime_history[step - 1] = delta_W_prime  # step-1 because ΔW′ is indexed from 0

        centroid_history[step] = W_curr.float().mean(dim=0)
        loss_history[step] = loss_value

        # A token "moved" this step if its ΔW′ row has L2 norm >= MOTION_THRESHOLD.
        displacement_magnitude = torch.norm(delta_W_prime, dim=1)  # (n_dead,)
        moved = displacement_magnitude >= MOTION_THRESHOLD
        last_motion_step[moved] = step

        W_prev = W_curr

        pbar.set_postfix({'loss': f'{loss_value:.4f}', 'epoch': epoch})
        pbar.update(1)

pbar.close()
print(f"\nTraining complete. Final loss: {loss_value:.4f}")
print(f"Completed {epoch} epochs")
print(f"Completed {epoch} epochs")

Starting training for 5000 steps...


Training: 100%|██████████| 5000/5000 [05:17<00:00, 15.73it/s, loss=5.2500, epoch=61]



Training complete. Final loss: 5.2500
Completed 61 epochs


## Quick Fimbulwinter Check

In [12]:
print("=" * 60)
print("FIMBULWINTER CHECK")
print("=" * 60)

# -1 in last_motion_step is the "never moved" sentinel. Keep those tokens out
# of the freeze-time statistics so the sentinel is never reported as a step.
moved_ever = last_motion_step >= 0
freeze_steps = last_motion_step[moved_ever]
never_moved = int((~moved_ever).sum().item())

# If no token ever moved, the whole run is frozen: count from step 0 so the
# duration cannot exceed TOTAL_STEPS.
global_last_motion = int(freeze_steps.max().item()) if freeze_steps.numel() > 0 else 0
fimbulwinter_duration = TOTAL_STEPS - global_last_motion

print(f"Last motion (any token): step {global_last_motion}")
print(f"Fimbulwinter duration: {fimbulwinter_duration} steps")
print(f"Fimbulwinter fraction: {fimbulwinter_duration / TOTAL_STEPS:.1%}")

# Distribution of freeze times (last motion step per token that moved)
print(f"\nFreeze time distribution:")
if freeze_steps.numel() > 0:
    print(f"  Earliest freeze: step {freeze_steps.min().item()}")
    print(f"  Median freeze: step {freeze_steps.float().median().item():.0f}")
    print(f"  Latest freeze: step {freeze_steps.max().item()}")
else:
    print("  (no token moved; no freeze times to report)")

# Check for any tokens that never moved (shouldn't happen)
if never_moved > 0:
    print(f"\n⚠️  {never_moved} tokens never moved (unexpected!)")
else:
    print(f"\n✓ All {n_dead} tokens moved at least once")

FIMBULWINTER CHECK
Last motion (any token): step 2509
Fimbulwinter duration: 2491 steps
Fimbulwinter fraction: 49.8%

Freeze time distribution:
  Earliest freeze: step 560
  Median freeze: step 1543
  Latest freeze: step 2509

✓ All 3699 tokens moved at least once


## Stillness Period Analysis

In [13]:
# Compute displacement magnitude per step
displacement_per_step = torch.norm(delta_W_prime_history, dim=2)  # (TOTAL_STEPS, n_dead)
any_motion = (displacement_per_step >= MOTION_THRESHOLD).any(dim=1)  # (TOTAL_STEPS,) bool

# Find runs of stillness
stillness_runs = []
current_run = 0

for moved in any_motion:
    if not moved:
        current_run += 1
    else:
        if current_run > 0:
            stillness_runs.append(current_run)
        current_run = 0

if current_run > 0:
    stillness_runs.append(current_run)

stillness_runs = sorted(stillness_runs, reverse=True)

print("=" * 60)
print("STILLNESS PERIOD ANALYSIS")
print("=" * 60)
print(f"Total stillness periods: {len(stillness_runs)}")
print(f"\nLongest 10:")
for i, run in enumerate(stillness_runs[:10]):
    print(f"  {i+1}. {run} steps")

if len(stillness_runs) >= 2:
    print(f"\nFinal vs second-longest: {stillness_runs[0]} / {stillness_runs[1]} = {stillness_runs[0]/stillness_runs[1]:.1f}x")

print(f"\nMean stillness: {sum(stillness_runs)/len(stillness_runs):.1f} steps")
print(f"Median stillness: {sorted(stillness_runs)[len(stillness_runs)//2]} steps")

STILLNESS PERIOD ANALYSIS
Total stillness periods: 58

Longest 10:
  1. 2491 steps
  2. 63 steps
  3. 43 steps
  4. 37 steps
  5. 34 steps
  6. 33 steps
  7. 25 steps
  8. 22 steps
  9. 20 steps
  10. 19 steps

Final vs second-longest: 2491 / 63 = 39.5x

Mean stillness: 51.8 steps
Median stillness: 4 steps


## Save Data

In [14]:
# W is stored as its raw bfloat16 bit patterns reinterpreted as uint16.
# Readers recover the values with `.view(torch.bfloat16)`. The same note is
# written into the safetensors header so the file is self-describing.
data_to_save = {
    'W': W_history.view(torch.uint16),  # (TOTAL_STEPS+1, n_dead, HIDDEN_DIM) uint16 bits of bf16
    'delta_W_prime': delta_W_prime_history,  # (TOTAL_STEPS, n_dead, HIDDEN_DIM) float32
    'centroid': centroid_history,  # (TOTAL_STEPS+1, HIDDEN_DIM) float32
    'loss': loss_history,  # (TOTAL_STEPS+1,) float32; loss[0] is NaN
    'last_motion_step': last_motion_step,  # (n_dead,) int32; -1 = never moved
    'dead_mask': dead_mask,
    'dead_indices': dead_indices,
}

save_path = output_path / 'crucible_1_trajectory.safetensors'
save_file(
    data_to_save,
    str(save_path),
    metadata={
        'W_encoding': 'bfloat16 bit patterns stored as uint16; recover with .view(torch.bfloat16)',
    },
)

print(f"Saved to {save_path}")
print(f"File size: {save_path.stat().st_size / 1e9:.2f} GB")

Saved to ../../tensors/Crucible-1/crucible_1_trajectory.safetensors
File size: 7.10 GB


## Save Metadata

In [15]:
# Experiment record written next to the trajectory file. Key order is kept
# stable so metadata.json diffs cleanly across runs.
metadata = dict(
    experiment='Crucible 1',
    series='Crucible',
    date='2025-11-25',
    # --- training configuration ---
    total_steps=TOTAL_STEPS,
    batch_size=BATCH_SIZE,
    seq_len=SEQ_LEN,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    beta1=BETA1,
    beta2=BETA2,
    epsilon=EPSILON,
    # --- model configuration ---
    vocab_size=VOCAB_SIZE,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    random_seed=RANDOM_SEED,
    motion_threshold=MOTION_THRESHOLD,
    # --- run outcome ---
    n_dead_tokens=n_dead,
    final_loss=loss_history[-1].item(),
    total_epochs=epoch,
    device=device,
    global_last_motion=int(global_last_motion),
    fimbulwinter_duration=int(fimbulwinter_duration),
    data_shapes={
        name: list(tensor.shape)
        for name, tensor in (
            ('W', W_history),
            ('delta_W_prime', delta_W_prime_history),
            ('centroid', centroid_history),
            ('loss', loss_history),
            ('last_motion_step', last_motion_step),
        )
    },
    notes='First Crucible experiment. Fast lr, no weight decay. Full lifecycle capture.',
)

with open(output_path / 'metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)

print("Saved metadata.json")
print("\n" + "=" * 60)
print("CRUCIBLE 1 COMPLETE")
print("=" * 60)

Saved metadata.json

CRUCIBLE 1 COMPLETE
