# Crucible 2: High-Resolution Phase Transition

A 500-step experiment with full instrumentation to study the hot-to-cold transition.

**What's new vs Crucible 1:**
- Shorter run (500 steps) focused on the transition period
- Records optimizer state: m[t] (momentum) and v[t] (variance)
- Records raw gradients g[t] before Adam processing
- Enables analysis of phase space (m, v) dynamics and gradient structure

**Data captured:**
- W[t]: dead token embeddings (bfloat16 as uint16)
- m[t]: Adam momentum (float32)
- v[t]: Adam variance (float32)
- g[t]: raw gradients (float32)
- ΔW′[t]: displacement in lattice-cell units (float32)
- loss[t]: training loss

**Known issue:** ΔW′ contains `inf` values where W[t-1] ≈ 0. The ULP calculation gives ULP = 2^(-134) for zero weights, so ΔW/ULP explodes. These are edge cases where "displacement in lattice units" is undefined. Filter with `torch.isfinite()` in downstream analysis.

**Goal:** Understand *why* tokens transition in cohorts. Are they clustered in (m, v) phase space near the critical boundary? Are their gradients aligned?

## Parameters

In [1]:
# Training parameters
TOTAL_STEPS = 500
BATCH_SIZE = 128
SEQ_LEN = 128
LEARNING_RATE = 1e-3      # Same as Crucible 1
WEIGHT_DECAY = 0.0        # Same as Crucible 1
BETA1 = 0.9
BETA2 = 0.999
EPSILON = 1e-8

# Model parameters
VOCAB_SIZE = 10000
HIDDEN_DIM = 64
NUM_LAYERS = 2
NUM_HEADS = 2

# Reproducibility — SAME SEED as Crucible 1 for comparability
RANDOM_SEED = 42

# Motion detection threshold
MOTION_THRESHOLD = 0.1

# Paths
CORPUS_PATH = '../../data/flannel_model_corpus.txt'
TOKENIZER_PATH = '../../data/flannel_tokenizer_chars.json'
DEAD_MASK_PATH = '../../tensors/Flannel/live_dead_tokens.safetensors'
OUTPUT_DIR = '../../tensors/Crucible-2'

print(f"Crucible 2: {TOTAL_STEPS} steps, lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}")
print(f"High-resolution recording of W, m, v, g")

Crucible 2: 500 steps, lr=0.001, weight_decay=0.0
High-resolution recording of W, m, v, g


## Imports

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from safetensors.torch import save_file, load_file
from pathlib import Path
from tqdm import tqdm
import json
import random
import numpy as np

## Device Detection

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Set Random Seeds

In [4]:
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

if device == 'cuda':
    torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"Random seed set to {RANDOM_SEED}")

Random seed set to 42


## Load Tokenizer and Dead Token Mask

In [5]:
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
print(f"Loaded tokenizer with vocab size {tokenizer.get_vocab_size()}")

masks = load_file(DEAD_MASK_PATH)
dead_mask = masks['dead_mask'].bool()
dead_indices = masks['dead_indices'].long()
n_dead = dead_mask.sum().item()

print(f"Dead tokens: {n_dead}")
print(f"Live tokens: {VOCAB_SIZE - n_dead}")

Loaded tokenizer with vocab size 10000
Dead tokens: 3699
Live tokens: 6301


## Dataset

In [6]:
class TextDataset(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len):
        with open(corpus_path, 'r', encoding='utf-8') as f:
            text = f.read()
        encoding = tokenizer.encode(text)
        self.tokens = encoding.ids
        self.seq_len = seq_len
        self.n_sequences = len(self.tokens) // seq_len
        
    def __len__(self):
        return self.n_sequences
    
    def __getitem__(self, idx):
        start = idx * self.seq_len
        chunk = self.tokens[start:start + self.seq_len]
        return torch.tensor(chunk, dtype=torch.long)


dataset = TextDataset(CORPUS_PATH, tokenizer, SEQ_LEN)
print(f"Dataset: {len(dataset)} sequences of length {SEQ_LEN}")

Dataset: 10713 sequences of length 128


## Model Definition

In [7]:
class TinyLM(nn.Module):
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, seq_len):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(seq_len, hidden_dim)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln_f = nn.LayerNorm(hidden_dim)
        
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.pos_embedding.weight, mean=0.0, std=0.02)
        
    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape
        
        tok_emb = self.embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_ids)
        
        hidden = tok_emb + pos_emb
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=input_ids.device)
        hidden = self.transformer(hidden, mask=causal_mask, is_causal=True)
        hidden = self.ln_f(hidden)
        
        logits = hidden @ self.embedding.weight.T
        return logits


model = TinyLM(VOCAB_SIZE, HIDDEN_DIM, NUM_LAYERS, NUM_HEADS, SEQ_LEN)
model = model.to(device).to(torch.bfloat16)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Model parameters: 748,288




## Optimizer

In [8]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2),
    eps=EPSILON,
    weight_decay=WEIGHT_DECAY
)

print(f"Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})")

Optimizer: AdamW (lr=0.001, weight_decay=0.0)


## ULP Computation

In [9]:
def compute_ulp_bf16(tensor_bf16: torch.Tensor) -> torch.Tensor:
    """Compute ULP for each element of a bfloat16 tensor."""
    bits = tensor_bf16.view(torch.uint16).to(torch.int32)
    exponent = (bits >> 7) & 0xFF
    effective_exp = torch.where(exponent == 0, torch.ones_like(exponent), exponent)
    ulp = torch.pow(2.0, (effective_exp - 134).float())
    return ulp

## Data Collection Setup

Memory budget for 500 steps:
- W: 501 × 3699 × 64 × 2 bytes ≈ 237 MB
- m: 501 × 3699 × 64 × 4 bytes ≈ 475 MB
- v: 501 × 3699 × 64 × 4 bytes ≈ 475 MB
- g: 500 × 3699 × 64 × 4 bytes ≈ 474 MB
- ΔW′: 500 × 3699 × 64 × 4 bytes ≈ 474 MB

Total: ~2.1 GB

In [10]:
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

# Pre-allocate tensors
W_history = torch.zeros(TOTAL_STEPS + 1, n_dead, HIDDEN_DIM, dtype=torch.bfloat16)
m_history = torch.zeros(TOTAL_STEPS + 1, n_dead, HIDDEN_DIM, dtype=torch.float32)
v_history = torch.zeros(TOTAL_STEPS + 1, n_dead, HIDDEN_DIM, dtype=torch.float32)
g_history = torch.zeros(TOTAL_STEPS, n_dead, HIDDEN_DIM, dtype=torch.float32)
delta_W_prime_history = torch.zeros(TOTAL_STEPS, n_dead, HIDDEN_DIM, dtype=torch.float32)
loss_history = torch.zeros(TOTAL_STEPS + 1, dtype=torch.float32)

# Last motion tracker
last_motion_step = torch.full((n_dead,), -1, dtype=torch.int32)

# Memory estimate
total_bytes = (
    W_history.numel() * 2 +
    m_history.numel() * 4 +
    v_history.numel() * 4 +
    g_history.numel() * 4 +
    delta_W_prime_history.numel() * 4 +
    loss_history.numel() * 4 +
    last_motion_step.numel() * 4
)
print(f"Pre-allocated {total_bytes / 1e9:.2f} GB for data collection")

Pre-allocated 2.13 GB for data collection


## Find Embedding Parameter in Optimizer

We need to extract m and v specifically for the embedding layer.

In [11]:
# Find which param group index corresponds to embedding.weight
embedding_param_id = id(model.embedding.weight)
embedding_param_idx = None

for group_idx, group in enumerate(optimizer.param_groups):
    for param_idx, param in enumerate(group['params']):
        if id(param) == embedding_param_id:
            embedding_param_idx = (group_idx, param_idx)
            break
    if embedding_param_idx is not None:
        break

print(f"Embedding parameter location in optimizer: group {embedding_param_idx[0]}, param {embedding_param_idx[1]}")

# Helper to get optimizer state for embedding
def get_embedding_optimizer_state():
    state = optimizer.state[model.embedding.weight]
    if 'exp_avg' in state:
        m = state['exp_avg'].detach().cpu()[dead_mask].float()
        v = state['exp_avg_sq'].detach().cpu()[dead_mask].float()
        return m, v
    else:
        # Before first step, no state yet
        return torch.zeros(n_dead, HIDDEN_DIM), torch.zeros(n_dead, HIDDEN_DIM)

Embedding parameter location in optimizer: group 0, param 0


## Training Loop

In [12]:
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)

loss_fn = nn.CrossEntropyLoss()

# Capture initial state (t=0)
W_prev = model.embedding.weight.detach().cpu()[dead_mask].to(torch.bfloat16)
W_history[0] = W_prev
m_history[0] = torch.zeros(n_dead, HIDDEN_DIM)  # No momentum yet
v_history[0] = torch.zeros(n_dead, HIDDEN_DIM)  # No variance yet
loss_history[0] = float('nan')

print(f"Starting training for {TOTAL_STEPS} steps...")

model.train()
step = 0
epoch = 0

pbar = tqdm(total=TOTAL_STEPS, desc="Training")

while step < TOTAL_STEPS:
    epoch += 1
    
    for batch in dataloader:
        if step >= TOTAL_STEPS:
            break
            
        input_ids = batch.to(device)
        
        with torch.autocast(device_type=device if device != 'mps' else 'cpu', dtype=torch.bfloat16):
            logits = model(input_ids)
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            loss = loss_fn(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))
        
        optimizer.zero_grad()
        loss.backward()
        
        # === Capture gradient BEFORE optimizer step ===
        grad = model.embedding.weight.grad.detach().cpu()[dead_mask].float()
        g_history[step] = grad
        
        optimizer.step()
        
        step += 1
        
        # === Capture state AFTER optimizer step ===
        W_curr = model.embedding.weight.detach().cpu()[dead_mask].to(torch.bfloat16)
        m_curr, v_curr = get_embedding_optimizer_state()
        
        # W[t]
        W_history[step] = W_curr
        
        # m[t] and v[t] (after this step's update)
        m_history[step] = m_curr
        v_history[step] = v_curr
        
        # ΔW = W[t] - W[t-1]
        delta_W = W_curr.float() - W_prev.float()
        
        # ULP at W[t-1]
        ulp = compute_ulp_bf16(W_prev)
        
        # ΔW′ = ΔW / ULP (lattice displacement)
        delta_W_prime = delta_W / ulp
        delta_W_prime_history[step - 1] = delta_W_prime
        
        # Loss
        loss_history[step] = loss.item()
        
        # Update last_motion_step for tokens that moved
        displacement_magnitude = torch.norm(delta_W_prime, dim=1)
        moved = displacement_magnitude >= MOTION_THRESHOLD
        last_motion_step[moved] = step
        
        # Update W_prev for next iteration
        W_prev = W_curr
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'epoch': epoch})
        pbar.update(1)

pbar.close()
print(f"\nTraining complete. Final loss: {loss.item():.4f}")
print(f"Completed {epoch} epochs")

Starting training for 500 steps...


Training: 100%|██████████| 500/500 [00:33<00:00, 14.98it/s, loss=6.1562, epoch=7]


Training complete. Final loss: 6.1562
Completed 7 epochs





## Quick Stats

In [13]:
print("=" * 60)
print("CRUCIBLE 2 STATS")
print("=" * 60)

# Displacement stats
disp_l2 = torch.norm(delta_W_prime_history, dim=2)  # (500, 3699)
print(f"\nDisplacement |ΔW′|₂:")
print(f"  Range: [{disp_l2.min():.2f}, {disp_l2.max():.2f}]")
print(f"  Mean at step 1: {disp_l2[0].mean():.2f}")
print(f"  Mean at step 500: {disp_l2[-1].mean():.2f}")

# Momentum stats
m_norm = torch.norm(m_history[1:], dim=2)  # (500, 3699)
print(f"\nMomentum |m|:")
print(f"  Range: [{m_norm.min():.2e}, {m_norm.max():.2e}]")

# Variance stats
v_mean = v_history[1:].mean(dim=2)  # (500, 3699)
print(f"\nVariance v (mean across dims):")
print(f"  Range: [{v_mean.min():.2e}, {v_mean.max():.2e}]")

# Gradient stats
g_norm = torch.norm(g_history, dim=2)  # (500, 3699)
print(f"\nGradient |g|:")
print(f"  Range: [{g_norm.min():.2e}, {g_norm.max():.2e}]")

# Motion stats
global_last_motion = last_motion_step.max().item()
tokens_still_moving = (last_motion_step == TOTAL_STEPS).sum().item()
print(f"\nMotion:")
print(f"  Last motion (any token): step {global_last_motion}")
print(f"  Tokens still moving at step 500: {tokens_still_moving}")

CRUCIBLE 2 STATS

Displacement |ΔW′|₂:
  Range: [0.00, inf]
  Mean at step 1: 7450.73
  Mean at step 500: 0.70

Momentum |m|:
  Range: [1.32e-05, 8.18e-04]

Variance v (mean across dims):
  Range: [2.19e-12, 4.09e-10]

Gradient |g|:
  Range: [1.14e-05, 1.23e-03]

Motion:
  Last motion (any token): step 500
  Tokens still moving at step 500: 2270


## Save Data

In [14]:
data_to_save = {
    'W': W_history.view(torch.uint16),        # (501, 3699, 64) uint16
    'm': m_history,                            # (501, 3699, 64) float32
    'v': v_history,                            # (501, 3699, 64) float32
    'g': g_history,                            # (500, 3699, 64) float32
    'delta_W_prime': delta_W_prime_history,    # (500, 3699, 64) float32
    'loss': loss_history,                      # (501,) float32
    'last_motion_step': last_motion_step,      # (3699,) int32
    'dead_mask': dead_mask,
    'dead_indices': dead_indices,
}

save_path = output_path / 'crucible_2_trajectory.safetensors'
save_file(data_to_save, str(save_path))

print(f"Saved to {save_path}")
print(f"File size: {save_path.stat().st_size / 1e9:.2f} GB")

Saved to ../../tensors/Crucible-2/crucible_2_trajectory.safetensors
File size: 2.13 GB


## Save Metadata

In [15]:
metadata = {
    'experiment': 'Crucible 2',
    'series': 'Crucible',
    'date': '2025-11-25',
    'total_steps': TOTAL_STEPS,
    'batch_size': BATCH_SIZE,
    'seq_len': SEQ_LEN,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'beta1': BETA1,
    'beta2': BETA2,
    'epsilon': EPSILON,
    'vocab_size': VOCAB_SIZE,
    'hidden_dim': HIDDEN_DIM,
    'num_layers': NUM_LAYERS,
    'num_heads': NUM_HEADS,
    'random_seed': RANDOM_SEED,
    'motion_threshold': MOTION_THRESHOLD,
    'n_dead_tokens': n_dead,
    'final_loss': loss_history[-1].item(),
    'total_epochs': epoch,
    'device': device,
    'global_last_motion': int(global_last_motion),
    'tokens_still_moving': int(tokens_still_moving),
    'data_shapes': {
        'W': list(W_history.shape),
        'm': list(m_history.shape),
        'v': list(v_history.shape),
        'g': list(g_history.shape),
        'delta_W_prime': list(delta_W_prime_history.shape),
        'loss': list(loss_history.shape),
        'last_motion_step': list(last_motion_step.shape),
    },
    'notes': 'High-resolution phase transition study. Same seed as Crucible 1. Records W, m, v, g.'
}

with open(output_path / 'metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)

print("Saved metadata.json")
print(f"\n{'='*60}")
print("CRUCIBLE 2 COMPLETE")
print(f"{'='*60}")

Saved metadata.json

CRUCIBLE 2 COMPLETE
