# Crucible 3: h_mean Persistence and Live/Dead Separation

**Hypothesis:** h_mean maintains a consistent direction over training, causing dead tokens to drift in a straight line (antiparallel to h) while live tokens drift in the opposite direction (parallel to h). This creates systematic separation between the two populations.

**What's new vs Crucible 2:**
- Records **all tokens** (live + dead), not just dead tokens
- Extracts **h_mean[t]** at each training step
- Simplified data collection (no m, v, g)
- Focused on h_mean autocorrelation and centroid dynamics

**Data captured:**
- W[t]: full embedding matrix (all 10k tokens, bfloat16 as uint16)
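- h_mean[t]: hidden state averaged over batch and sequence positions, taken from each step's forward pass (i.e. computed with the pre-update weights W[t]); stored as float32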
- loss[t]: training loss

**Questions to answer:**
1. Does h_mean autocorrelation stay high (~0.95) over 500 steps, or decay?
2. Do live and dead token centroids separate over time?
3. Does h_mean point toward the live centroid?

**Storage:** ~0.64 GB total (much smaller than Crucible 2's 2.1 GB)

## Parameters

In [None]:
# --- Training schedule and AdamW hyperparameters ---
# LEARNING_RATE and WEIGHT_DECAY are kept at the Crucible 1 & 2 values.
TOTAL_STEPS = 500
BATCH_SIZE = 128
SEQ_LEN = 128
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0
BETA1, BETA2 = 0.9, 0.999
EPSILON = 1e-8

# --- Model shape ---
VOCAB_SIZE = 10000
HIDDEN_DIM = 64
NUM_LAYERS, NUM_HEADS = 2, 2

# --- Reproducibility: SAME SEED as Crucible 1 & 2 so the runs are comparable ---
RANDOM_SEED = 42

# --- Input and output locations (relative paths) ---
CORPUS_PATH = '../../data/flannel_model_corpus.txt'
TOKENIZER_PATH = '../../data/flannel_tokenizer_chars.json'
DEAD_MASK_PATH = '../../tensors/Flannel/live_dead_tokens.safetensors'
OUTPUT_DIR = '../../tensors/Crucible-3'

# Echo the headline settings so they appear in the notebook output.
print(
    f"Crucible 3: {TOTAL_STEPS} steps, lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}"
)
print(
    f"Recording full W (all {VOCAB_SIZE} tokens) + h_mean[t]"
)

## Imports

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from safetensors.torch import save_file, load_file
from pathlib import Path
from tqdm import tqdm
import json
import random
import numpy as np

## Device Detection

In [None]:
# Choose the compute backend in order of preference: CUDA, then Apple MPS,
# falling back to the CPU when neither accelerator is available.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'

print(f"Using device: {device}")

## Set Random Seeds

In [None]:
# Seed every random number generator the notebook uses (Python, NumPy, PyTorch)
# with the same value. The three generators are independent, so the order of
# these calls does not matter.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Explicitly seed all CUDA devices when running on GPU.
if device == 'cuda':
    torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"Random seed set to {RANDOM_SEED}")

## Load Tokenizer and Dead Token Mask

In [None]:
# Tokenizer is restored from the JSON file at TOKENIZER_PATH.
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
print(f"Loaded tokenizer with vocab size {tokenizer.get_vocab_size()}")

# Read the tensors stored under 'dead_mask' and 'dead_indices', coercing them
# to bool / int64. The live mask is simply the complement of the dead mask.
masks = load_file(DEAD_MASK_PATH)
dead_indices = masks['dead_indices'].long()
dead_mask = masks['dead_mask'].bool()
live_mask = ~dead_mask

# Population sizes as plain Python ints.
n_dead = dead_mask.sum().item()
n_live = live_mask.sum().item()

print(f"Dead tokens: {n_dead}")
print(f"Live tokens: {n_live}")

## Dataset

In [None]:
class TextDataset(Dataset):
    """Fixed-length, non-overlapping token windows cut from one text file.

    The whole corpus is read and tokenized once at construction time. Item
    ``i`` is the slice ``tokens[i * seq_len : (i + 1) * seq_len]``; trailing
    tokens that do not fill a complete window are never indexed by ``__len__``.

    Args:
        corpus_path: Path of a UTF-8 text file.
        tokenizer: Object whose ``encode(text)`` result exposes an ``ids`` list.
        seq_len: Number of tokens per returned sequence.
    """

    def __init__(self, corpus_path, tokenizer, seq_len):
        self.seq_len = seq_len

        with open(corpus_path, 'r', encoding='utf-8') as corpus_file:
            raw_text = corpus_file.read()

        self.tokens = tokenizer.encode(raw_text).ids
        # Integer division drops the incomplete final window.
        self.n_sequences = len(self.tokens) // seq_len

    def __len__(self):
        """Return the number of complete windows in the corpus."""
        return self.n_sequences

    def __getitem__(self, idx):
        """Return window ``idx`` as a 1-D int64 tensor."""
        begin = idx * self.seq_len
        end = begin + self.seq_len
        return torch.tensor(self.tokens[begin:end], dtype=torch.long)


# Tokenize the corpus once and expose it as SEQ_LEN-token windows.
dataset = TextDataset(
    corpus_path=CORPUS_PATH,
    tokenizer=tokenizer,
    seq_len=SEQ_LEN,
)
print(f"Dataset: {len(dataset)} sequences of length {SEQ_LEN}")

## Model Definition

Modified so that the forward pass can optionally return the final hidden state h (the output of the last LayerNorm) alongside the logits.

In [None]:
class TinyLM(nn.Module):
    """Small causal transformer language model with a tied output projection.

    Token and learned position embeddings are summed, passed through a stack
    of pre-norm GELU encoder layers under a causal mask, normalised by a final
    LayerNorm, and projected to vocabulary logits with the transpose of the
    token embedding matrix.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_dim: Model width.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per layer.
        seq_len: Number of rows in the position embedding table.
    """

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, seq_len):
        super().__init__()

        # Lookup tables for token identity and absolute position.
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(seq_len, hidden_dim)

        # Pre-norm, batch-first encoder layers with a 4x feed-forward width and
        # dropout disabled.
        layer_template = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=4 * hidden_dim,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers=num_layers)
        self.ln_f = nn.LayerNorm(hidden_dim)

        # Re-initialise both embedding tables from N(0, 0.02^2), token table first.
        for table in (self.embedding, self.pos_embedding):
            nn.init.normal_(table.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, return_hidden=False):
        """Compute vocabulary logits for a 2-D batch of token ids.

        Args:
            input_ids: Integer tensor of shape (batch, seq).
            return_hidden: When True, also return the final hidden state.

        Returns:
            ``logits``, or ``(logits, h)`` when ``return_hidden`` is True, where
            ``h`` is the output of the final LayerNorm.
        """
        # Unpacking two values also rejects inputs that are not 2-D.
        _, n_positions = input_ids.shape
        target_device = input_ids.device

        positions = torch.arange(n_positions, device=target_device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(positions)

        attn_mask = nn.Transformer.generate_square_subsequent_mask(
            n_positions, device=target_device
        )
        x = self.transformer(x, mask=attn_mask, is_causal=True)
        h = self.ln_f(x)  # This is h

        # Weight tying: the embedding matrix doubles as the output projection.
        logits = h @ self.embedding.weight.T

        return (logits, h) if return_hidden else logits


# Build the network, move it to the selected device, then cast its parameters
# to bfloat16.
model = TinyLM(
    vocab_size=VOCAB_SIZE,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    seq_len=SEQ_LEN,
)
model = model.to(device)
model = model.to(torch.bfloat16)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Optimizer

In [None]:
# AdamW over every model parameter; all hyperparameters come from the
# constants defined in the Parameters cell.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    betas=(BETA1, BETA2),
    eps=EPSILON,
)

print(f"Optimizer: AdamW (lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY})")

## Data Collection Setup

Memory budget for 500 steps:
- W: 501 × 10000 × 64 × 2 bytes ≈ 641 MB
- h_mean: 500 × 64 × 4 bytes ≈ 0.13 MB
- loss: 501 × 4 bytes ≈ 0.002 MB

Total: ~641 MB

In [None]:
# Make sure the output directory exists before any training work is done.
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

# Pre-allocate the history buffers (default device).
# W and loss have TOTAL_STEPS + 1 slots: slot 0 holds the initial state.
W_history = torch.zeros((TOTAL_STEPS + 1, VOCAB_SIZE, HIDDEN_DIM), dtype=torch.bfloat16)
h_mean_history = torch.zeros((TOTAL_STEPS, HIDDEN_DIM), dtype=torch.float32)
loss_history = torch.zeros((TOTAL_STEPS + 1,), dtype=torch.float32)

# Memory estimate: element count times bytes per element, summed over buffers.
total_bytes = sum(
    buf.numel() * buf.element_size()
    for buf in (W_history, h_mean_history, loss_history)
)
print(f"Pre-allocated {total_bytes / 1e9:.3f} GB for data collection")

## Training Loop

In [None]:
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)

# With drop_last=True a dataset smaller than one batch yields no batches at
# all, and the epoch loop below would spin forever without advancing `step`.
if TOTAL_STEPS > 0 and len(dataloader) == 0:
    raise ValueError(
        f"Dataset has {len(dataset)} sequences, fewer than BATCH_SIZE={BATCH_SIZE}; "
        "no full batch can be formed."
    )

loss_fn = nn.CrossEntropyLoss()

# Capture initial state (t=0). No loss has been computed yet, so slot 0 is NaN.
W_history[0] = model.embedding.weight.detach().cpu().to(torch.bfloat16)
loss_history[0] = float('nan')

print(f"Starting training for {TOTAL_STEPS} steps...")

model.train()
step = 0
epoch = 0
loss_value = float('nan')  # remains NaN only when no step is run

# The context manager closes the progress bar even if a step raises.
with tqdm(total=TOTAL_STEPS, desc="Training") as pbar:
    while step < TOTAL_STEPS:
        epoch += 1

        for batch in dataloader:
            if step >= TOTAL_STEPS:
                break

            input_ids = batch.to(device)

            # Forward pass with hidden state extraction.
            # autocast is given 'cpu' as device_type when running on MPS.
            with torch.autocast(device_type=device if device != 'mps' else 'cpu', dtype=torch.bfloat16):
                logits, h = model(input_ids, return_hidden=True)
                # Next-token objective: position i predicts token i + 1.
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = input_ids[:, 1:].contiguous()
                loss = loss_fn(shift_logits.view(-1, VOCAB_SIZE), shift_labels.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1

            # === Record this step ===

            # W[step]: embedding matrix AFTER the optimizer update (all tokens).
            W_history[step] = model.embedding.weight.detach().cpu().to(torch.bfloat16)

            # h_mean[step - 1]: hidden state from this step's forward pass, i.e.
            # computed with the PRE-update weights W[step - 1], averaged over
            # batch and sequence -> [D].
            #  * detach(): writing a graph-attached tensor into the buffer would
            #    make h_mean_history require grad and keep every step's autograd
            #    nodes alive.
            #  * float() before mean(): h may be bfloat16 here (the model is cast
            #    to bfloat16), so average in float32 rather than rounding the
            #    mean to the activation dtype first.
            h_mean_history[step - 1] = h.detach().float().mean(dim=(0, 1)).cpu()

            # Loss[step]: read the scalar once (each .item() call syncs with the device).
            loss_value = loss.item()
            loss_history[step] = loss_value

            pbar.set_postfix({'loss': f'{loss_value:.4f}', 'epoch': epoch})
            pbar.update(1)

print(f"\nTraining complete. Final loss: {loss_value:.4f}")
print(f"Completed {epoch} epochs")

## Quick Stats

In [None]:
print("=" * 60)
print("CRUCIBLE 3 STATS")
print("=" * 60)

# Centroid stats: split the recorded embedding rows into the dead and live
# populations and cast to float32 (W_history is stored as bfloat16).
W_dead = W_history[:, dead_mask, :].float()
W_live = W_history[:, live_mask, :].float()

# Mean over the token axis gives one centroid per recorded step.
centroid_dead = W_dead.mean(dim=1)  # (TOTAL_STEPS + 1, HIDDEN_DIM)
centroid_live = W_live.mean(dim=1)  # (TOTAL_STEPS + 1, HIDDEN_DIM)

centroid_dead_norm = torch.norm(centroid_dead, dim=1)
centroid_live_norm = torch.norm(centroid_live, dim=1)

# The final snapshot is labelled with TOTAL_STEPS rather than a literal step
# count so the printout stays correct if the run length is changed.
print(f"\nDead token centroid:")
print(f"  Norm at t=0: {centroid_dead_norm[0].item():.6f}")
print(f"  Norm at t={TOTAL_STEPS}: {centroid_dead_norm[-1].item():.6f}")
print(f"  Total displacement: {torch.norm(centroid_dead[-1] - centroid_dead[0]).item():.6f}")

print(f"\nLive token centroid:")
print(f"  Norm at t=0: {centroid_live_norm[0].item():.6f}")
print(f"  Norm at t={TOTAL_STEPS}: {centroid_live_norm[-1].item():.6f}")
print(f"  Total displacement: {torch.norm(centroid_live[-1] - centroid_live[0]).item():.6f}")

# Separation: Euclidean distance between the two centroids at every step.
separation = torch.norm(centroid_dead - centroid_live, dim=1)
print(f"\nLive/Dead separation:")
print(f"  Distance at t=0: {separation[0].item():.6f}")
print(f"  Distance at t={TOTAL_STEPS}: {separation[-1].item():.6f}")

# h_mean stats. detach() keeps these summaries outside autograd in case the
# history buffer was filled from graph-attached tensors.
h_mean_norm = torch.norm(h_mean_history.detach(), dim=1)
print(f"\nh_mean:")
print(f"  L2 norm range: [{h_mean_norm.min().item():.6f}, {h_mean_norm.max().item():.6f}]")
print(f"  Mean L2 norm: {h_mean_norm.mean().item():.6f}")

## Save Data

In [None]:
# Destination file inside the experiment's output directory.
save_path = output_path / 'crucible_3_trajectory.safetensors'

# Tensors written to disk. W is reinterpreted bit-for-bit as uint16 (no value
# conversion); the masks are stored next to the trajectory.
data_to_save = {
    'W': W_history.view(torch.uint16),
    'h_mean': h_mean_history,
    'loss': loss_history,
    'dead_mask': dead_mask,
    'dead_indices': dead_indices,
}

save_file(data_to_save, str(save_path))

print(f"Saved to {save_path}")
print(f"File size: {save_path.stat().st_size / 1e9:.2f} GB")

## Save Metadata

In [None]:
# JSON-serialisable description of the run: identity, hyperparameters,
# outcome, and the shapes of the recorded tensors.
metadata = {
    'experiment': 'Crucible 3',
    'series': 'Crucible',
    'date': '2025-11-28',
    'total_steps': TOTAL_STEPS,
    'batch_size': BATCH_SIZE,
    'seq_len': SEQ_LEN,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'beta1': BETA1,
    'beta2': BETA2,
    'epsilon': EPSILON,
    'vocab_size': VOCAB_SIZE,
    'hidden_dim': HIDDEN_DIM,
    'num_layers': NUM_LAYERS,
    'num_heads': NUM_HEADS,
    'random_seed': RANDOM_SEED,
    'n_dead_tokens': n_dead,
    'n_live_tokens': n_live,
    'final_loss': loss_history[-1].item(),
    'total_epochs': epoch,
    'device': device,
    # torch.Size is not JSON-serialisable, so each shape is stored as a list.
    'data_shapes': {
        key: list(tensor.shape)
        for key, tensor in (
            ('W', W_history),
            ('h_mean', h_mean_history),
            ('loss', loss_history),
        )
    },
    'notes': 'h_mean persistence and live/dead separation study. Records all tokens + h_mean[t]. Same seed as Crucible 1 & 2.'
}

metadata_file = output_path / 'metadata.json'
with metadata_file.open('w') as f:
    json.dump(metadata, f, indent=2)

print("Saved metadata.json")
print(f"\n{'='*60}")
print("CRUCIBLE 3 COMPLETE")
print(f"{'='*60}")