# Thimble 8: Dead Token Dynamics Training Run

Train a minimal language model and capture full optimizer state for dead tokens at every step.

**Focus:** First 3,500 training steps—the supernova, cooling, and fimbulwinter onset.

**Data captured per step (dead tokens only):**
- W: embedding weights (bfloat16)
- m: Adam momentum / first moment (float32)
- v: Adam variance / second moment (float32)  
- g: gradients (float32)
- loss: training loss (float32 scalar)

**Checkpoints:** Full model + optimizer state at t = 0, 1, 10, 100, 500, 1000, 2000, 3000, 3500

**Expected output:** ~11.6 GB in `box_4/tensors/Thimble-8/`

## Parameters

In [1]:
# Training parameters
TOTAL_STEPS = 4000
BATCH_SIZE = 128
SEQ_LEN = 128
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.1
BETA1 = 0.9
BETA2 = 0.999
EPSILON = 1e-8

# Model parameters
VOCAB_SIZE = 10000
HIDDEN_DIM = 64
NUM_LAYERS = 2
NUM_HEADS = 2

# Reproducibility
RANDOM_SEED = 42

# Checkpointing
CHECKPOINT_STEPS = [0, 1, 10, 100, 500, 1000, 2000, 3000, 3500, 4000]

# Paths (relative to notebook location: box_4/notebooks/training/)
CORPUS_PATH = '../../data/flannel_model_corpus.txt'
TOKENIZER_PATH = '../../data/flannel_tokenizer_chars.json'
DEAD_MASK_PATH = '../../tensors/Flannel/live_dead_tokens.safetensors'
OUTPUT_DIR = '../../tensors/Thimble-8'

## Imports

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from safetensors.torch import save_file, load_file
from pathlib import Path
from tqdm import tqdm
import json
import random
import numpy as np

## Device Detection

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Set Random Seeds

In [4]:
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

if device == 'cuda':
    torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"Random seed set to {RANDOM_SEED}")

Random seed set to 42


## Load Tokenizer and Dead Token Mask

In [5]:
# Load tokenizer (using tokenizers library, same as Thimble 7)
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
print(f"Loaded tokenizer with vocab size {tokenizer.get_vocab_size()}")

# Load dead token mask
masks = load_file(DEAD_MASK_PATH)
dead_mask = masks['dead_mask'].bool()
dead_indices = masks['dead_indices'].long()
n_dead = dead_mask.sum().item()

print(f"Dead tokens: {n_dead}")
print(f"Live tokens: {VOCAB_SIZE - n_dead}")

Loaded tokenizer with vocab size 10000
Dead tokens: 3699
Live tokens: 6301


## Dataset

In [6]:
class TextDataset(Dataset):
    """Simple dataset that returns chunks of tokenized text."""
    
    def __init__(self, corpus_path, tokenizer, seq_len):
        with open(corpus_path, 'r', encoding='utf-8') as f:
            text = f.read()
        
        # Tokenize entire corpus (tokenizers library returns Encoding object)
        encoding = tokenizer.encode(text)
        self.tokens = encoding.ids
        self.seq_len = seq_len
        
        # Number of complete sequences we can make
        self.n_sequences = len(self.tokens) // seq_len
        
    def __len__(self):
        return self.n_sequences
    
    def __getitem__(self, idx):
        start = idx * self.seq_len
        chunk = self.tokens[start:start + self.seq_len]
        return torch.tensor(chunk, dtype=torch.long)


dataset = TextDataset(CORPUS_PATH, tokenizer, SEQ_LEN)
print(f"Dataset: {len(dataset)} sequences of length {SEQ_LEN}")
print(f"Total tokens: {len(dataset) * SEQ_LEN:,}")

Dataset: 10713 sequences of length 128
Total tokens: 1,371,264


## Model Definition

Minimal transformer with tied embeddings (E = W^T), matching Qwen architecture choice.

In [7]:
class TinyLM(nn.Module):
    """Minimal language model with tied embeddings."""
    
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads):
        super().__init__()
        
        # Token embeddings - this is W^T, the transpose of unembedding
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        
        # Positional embeddings
        self.pos_embedding = nn.Embedding(SEQ_LEN, hidden_dim)
        
        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,  # No dropout for reproducibility
            activation='gelu',
            batch_first=True,
            norm_first=True  # Pre-norm like modern architectures
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Final layer norm
        self.ln_f = nn.LayerNorm(hidden_dim)
        
        # Output projection (tied to embedding)
        # We'll compute logits as: hidden @ embedding.weight.T
        
        # Initialize embeddings with N(0, 0.02)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.pos_embedding.weight, mean=0.0, std=0.02)
        
    def forward(self, input_ids):
        batch_size, seq_len = input_ids.shape
        
        # Get embeddings
        tok_emb = self.embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_ids)
        
        hidden = tok_emb + pos_emb
        
        # Causal mask
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=input_ids.device)
        
        # Transform
        hidden = self.transformer(hidden, mask=causal_mask, is_causal=True)
        hidden = self.ln_f(hidden)
        
        # Tied output projection: logits = hidden @ W where W = embedding.weight
        logits = hidden @ self.embedding.weight.T
        
        return logits


# Create model
model = TinyLM(VOCAB_SIZE, HIDDEN_DIM, NUM_LAYERS, NUM_HEADS)
model = model.to(device).to(torch.bfloat16)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")
print(f"Embedding shape: {model.embedding.weight.shape}")

Model parameters: 748,288
Embedding shape: torch.Size([10000, 64])




## Optimizer

In [8]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2),
    eps=EPSILON,
    weight_decay=WEIGHT_DECAY
)

print(f"Optimizer: AdamW")
print(f"  lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}")
print(f"  betas=({BETA1}, {BETA2}), eps={EPSILON}")

Optimizer: AdamW
  lr=0.0003, weight_decay=0.1
  betas=(0.9, 0.999), eps=1e-08


## Data Collection Setup

In [9]:
# Create output directory
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)
checkpoint_path = output_path / 'checkpoints'
checkpoint_path.mkdir(exist_ok=True)

# Pre-allocate tensors for data collection (on CPU to save GPU memory)
# Shape: (steps+1, n_dead_tokens, hidden_dim) - +1 because we capture t=0
W_history = torch.zeros(TOTAL_STEPS + 1, n_dead, HIDDEN_DIM, dtype=torch.bfloat16)
m_history = torch.zeros(TOTAL_STEPS + 1, n_dead, HIDDEN_DIM, dtype=torch.float32)
v_history = torch.zeros(TOTAL_STEPS + 1, n_dead, HIDDEN_DIM, dtype=torch.float32)
g_history = torch.zeros(TOTAL_STEPS + 1, n_dead, HIDDEN_DIM, dtype=torch.float32)
loss_history = torch.zeros(TOTAL_STEPS + 1, dtype=torch.float32)

# Memory estimate
total_bytes = (
    W_history.numel() * 2 +  # bfloat16
    m_history.numel() * 4 +  # float32
    v_history.numel() * 4 +
    g_history.numel() * 4 +
    loss_history.numel() * 4
)
print(f"Pre-allocated {total_bytes / 1e9:.2f} GB for data collection")

Pre-allocated 13.26 GB for data collection


## Helper Functions

In [10]:
def get_embedding_param_id(model):
    """Get the parameter ID for the embedding weights in optimizer state."""
    for i, (name, param) in enumerate(model.named_parameters()):
        if name == 'embedding.weight':
            return i
    raise ValueError("Could not find embedding.weight parameter")


def capture_state(model, optimizer, step, dead_mask, loss_val=0.0):
    """Capture current state for dead tokens."""
    
    # Get embedding weights (dead tokens only)
    W = model.embedding.weight.detach().cpu()[dead_mask]
    W_history[step] = W.to(torch.bfloat16)
    
    # Get optimizer state for embeddings
    emb_param_id = get_embedding_param_id(model)
    emb_param = list(model.parameters())[emb_param_id]
    
    if emb_param in optimizer.state:
        state = optimizer.state[emb_param]
        # Momentum (exp_avg)
        m_history[step] = state['exp_avg'].detach().cpu()[dead_mask].float()
        # Variance (exp_avg_sq)
        v_history[step] = state['exp_avg_sq'].detach().cpu()[dead_mask].float()
    # else: stays zero (step 0, before any optimizer step)
    
    # Gradients (if they exist)
    if emb_param.grad is not None:
        g_history[step] = emb_param.grad.detach().cpu()[dead_mask].float()
    # else: stays zero
    
    # Loss
    loss_history[step] = loss_val


def save_checkpoint(model, optimizer, step, path):
    """Save full model and optimizer checkpoint."""
    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'rng_state': torch.get_rng_state(),
    }
    if device == 'cuda':
        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state_all()
    
    torch.save(checkpoint, path / f'checkpoint_step_{step:05d}.pt')


print(f"Embedding parameter ID: {get_embedding_param_id(model)}")

Embedding parameter ID: 0


## Training Loop

In [11]:
# DataLoader with shuffling
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True  # Ensure consistent batch sizes
)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Capture initial state (t=0, before any training)
capture_state(model, optimizer, step=0, dead_mask=dead_mask, loss_val=0.0)
if 0 in CHECKPOINT_STEPS:
    save_checkpoint(model, optimizer, 0, checkpoint_path)
    print("Saved checkpoint at step 0")

print(f"\nStarting training for {TOTAL_STEPS} steps...")
print(f"Checkpoints at: {CHECKPOINT_STEPS}")

# Training
model.train()
step = 0
epoch = 0

pbar = tqdm(total=TOTAL_STEPS, desc="Training")

while step < TOTAL_STEPS:
    epoch += 1
    
    for batch in dataloader:
        if step >= TOTAL_STEPS:
            break
            
        # Move to device
        input_ids = batch.to(device)
        
        # Forward pass
        with torch.autocast(device_type=device if device != 'mps' else 'cpu', dtype=torch.bfloat16):
            logits = model(input_ids)
            
            # Shift for next-token prediction
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            
            # Compute loss
            loss = loss_fn(
                shift_logits.view(-1, VOCAB_SIZE),
                shift_labels.view(-1)
            )
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Optimizer step
        optimizer.step()
        
        step += 1
        
        # Capture state (after optimizer step, so this is state at end of step t)
        capture_state(model, optimizer, step=step, dead_mask=dead_mask, loss_val=loss.item())
        
        # Checkpoint if needed
        if step in CHECKPOINT_STEPS:
            save_checkpoint(model, optimizer, step, checkpoint_path)
            tqdm.write(f"Saved checkpoint at step {step}")
        
        # Update progress bar
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'epoch': epoch})
        pbar.update(1)

pbar.close()
print(f"\nTraining complete. Final loss: {loss.item():.4f}")
print(f"Completed {epoch} epochs")

Saved checkpoint at step 0

Starting training for 4000 steps...
Checkpoints at: [0, 1, 10, 100, 500, 1000, 2000, 3000, 3500, 4000]


Training:   0%|          | 3/4000 [00:01<24:20,  2.74it/s, loss=9.1250, epoch=1]  

Saved checkpoint at step 1


Training:   0%|          | 13/4000 [00:01<05:49, 11.41it/s, loss=8.9375, epoch=1]

Saved checkpoint at step 10


Training:   3%|▎         | 101/4000 [00:07<04:31, 14.35it/s, loss=7.6250, epoch=2]

Saved checkpoint at step 100


Training:  13%|█▎        | 503/4000 [00:33<03:51, 15.10it/s, loss=6.9688, epoch=7]

Saved checkpoint at step 500


Training:  25%|██▌       | 1003/4000 [01:05<03:16, 15.25it/s, loss=6.6562, epoch=13]

Saved checkpoint at step 1000


Training:  50%|█████     | 2003/4000 [02:07<02:08, 15.53it/s, loss=6.3750, epoch=25]

Saved checkpoint at step 2000


Training:  75%|███████▌  | 3003/4000 [03:10<01:04, 15.40it/s, loss=6.3438, epoch=37]

Saved checkpoint at step 3000


Training:  88%|████████▊ | 3503/4000 [03:42<00:32, 15.46it/s, loss=6.2812, epoch=43]

Saved checkpoint at step 3500


Training: 100%|██████████| 4000/4000 [04:13<00:00, 15.76it/s, loss=6.2812, epoch=49]

Saved checkpoint at step 4000

Training complete. Final loss: 6.2812
Completed 49 epochs





## Save Collected Data

In [12]:
# Save as safetensors
data_to_save = {
    'W': W_history,           # (3501, n_dead, 64) bfloat16
    'm': m_history,           # (3501, n_dead, 64) float32
    'v': v_history,           # (3501, n_dead, 64) float32
    'g': g_history,           # (3501, n_dead, 64) float32
    'loss': loss_history,     # (3501,) float32
    'dead_mask': dead_mask,   # (10000,) bool -> for reference
    'dead_indices': dead_indices,  # (n_dead,) long -> for reference
}

# Note: safetensors doesn't support bfloat16 directly in all versions
# Convert to uint16 view to preserve exact bits
data_to_save['W'] = W_history.view(torch.uint16)

save_path = output_path / 'thimble_8_trajectory.safetensors'
save_file(data_to_save, str(save_path))

print(f"Saved trajectory data to {save_path}")
print(f"File size: {save_path.stat().st_size / 1e9:.2f} GB")

Saved trajectory data to ../../tensors/Thimble-8/thimble_8_trajectory.safetensors
File size: 13.26 GB


## Save Metadata

In [13]:
metadata = {
    'experiment': 'Thimble 8',
    'date': '2025-11-25',
    'total_steps': TOTAL_STEPS,
    'batch_size': BATCH_SIZE,
    'seq_len': SEQ_LEN,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'beta1': BETA1,
    'beta2': BETA2,
    'epsilon': EPSILON,
    'vocab_size': VOCAB_SIZE,
    'hidden_dim': HIDDEN_DIM,
    'num_layers': NUM_LAYERS,
    'num_heads': NUM_HEADS,
    'random_seed': RANDOM_SEED,
    'n_dead_tokens': n_dead,
    'checkpoint_steps': CHECKPOINT_STEPS,
    'final_loss': loss_history[-1].item(),
    'total_epochs': epoch,
    'device': device,
    'data_shapes': {
        'W': list(W_history.shape),
        'm': list(m_history.shape),
        'v': list(v_history.shape),
        'g': list(g_history.shape),
        'loss': list(loss_history.shape),
    },
    'notes': 'Dead tokens only. W stored as uint16 (bfloat16 bit pattern).'
}

with open(output_path / 'metadata.json', 'w') as f:
    json.dump(metadata, f, indent=2)

print("Saved metadata.json")
print(f"\nExperiment complete!")
print(f"Data: {save_path}")
print(f"Checkpoints: {checkpoint_path}")
print(f"Metadata: {output_path / 'metadata.json'}")

Saved metadata.json

Experiment complete!
Data: ../../tensors/Thimble-8/thimble_8_trajectory.safetensors
Checkpoints: ../../tensors/Thimble-8/checkpoints
Metadata: ../../tensors/Thimble-8/metadata.json


## Quick Sanity Check

In [14]:
# Verify we captured data correctly
print("=" * 50)
print("SANITY CHECK")
print("=" * 50)

# W should change over time
W_bf16 = W_history.view(torch.bfloat16)  # Convert back
W_delta = (W_bf16[-1] - W_bf16[0]).float().abs().mean()
print(f"Mean |W[3500] - W[0]|: {W_delta:.6f}")

# m and v should be non-zero at end
print(f"Mean |m[3500]|: {m_history[-1].abs().mean():.6f}")
print(f"Mean |v[3500]|: {v_history[-1].abs().mean():.6f}")

# Loss should decrease
print(f"Loss[1]: {loss_history[1]:.4f}")
print(f"Loss[3500]: {loss_history[-1]:.4f}")

# Check m[0] and v[0] are zero (before any optimizer step)
print(f"m[0] all zero: {(m_history[0] == 0).all().item()}")
print(f"v[0] all zero: {(v_history[0] == 0).all().item()}")

print("\nSanity check complete!")

SANITY CHECK
Mean |W[3500] - W[0]|: 0.054386
Mean |m[3500]|: 0.000008
Mean |v[3500]|: 0.000000
Loss[1]: 9.1875
Loss[3500]: 6.2812
m[0] all zero: True
v[0] all zero: True

Sanity check complete!
