# Thimble 9: Long-Haul Fimbulwinter Hunt

Train for 10,000 steps to definitively capture the onset of Fimbulwinter.

**Stripped-down data capture:**
- Embedding weights W, dead-token rows only (bfloat16 values stored as uint16 bit patterns)
- Loss per step
- Sparse checkpoints

**Verification mode:** Set `VERIFICATION_MODE = True` to first run 1,000 steps and confirm bitwise equality with Thimble 8 before launching the full run.

**Expected output size:** ~4.7 GB for the full 10k-step run

## Parameters

In [14]:
# === VERIFICATION MODE ===
# True  -> short 1,000-step run, checked bit-for-bit against Thimble 8
# False -> full 10,000-step production run
VERIFICATION_MODE = False

# --- Training hyperparameters (MUST be identical to Thimble 8) ---
TOTAL_STEPS = 1000 if VERIFICATION_MODE else 10000
BATCH_SIZE = 128
SEQ_LEN = 128
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.1
BETA1, BETA2 = 0.9, 0.999
EPSILON = 1e-8

# --- Model shape (MUST be identical to Thimble 8) ---
VOCAB_SIZE = 10000
HIDDEN_DIM = 64
NUM_LAYERS = 2
NUM_HEADS = 2

# --- Reproducibility (MUST be identical to Thimble 8) ---
RANDOM_SEED = 42

# --- Checkpoint schedule: kept sparse to limit disk use on the long run ---
# Verification: a few early steps. Production: every 1,000 steps, 0..10,000 inclusive.
CHECKPOINT_STEPS = (
    [0, 100, 500, 1000]
    if VERIFICATION_MODE
    else list(range(0, 10000 + 1, 1000))
)

# --- Input / output locations (relative paths, resolved against the working directory) ---
CORPUS_PATH = '../../data/flannel_model_corpus.txt'
TOKENIZER_PATH = '../../data/flannel_tokenizer_chars.json'
DEAD_MASK_PATH = '../../tensors/Flannel/live_dead_tokens.safetensors'
OUTPUT_DIR = '../../tensors/Thimble-9'

# Reference trajectory, read only by the verification cell
THIMBLE_8_PATH = '../../tensors/Thimble-8/thimble_8_trajectory.safetensors'

print(f"MODE: {'VERIFICATION (1000 steps)' if VERIFICATION_MODE else 'PRODUCTION (10000 steps)'}")
print(f"Total steps: {TOTAL_STEPS}")

MODE: PRODUCTION (10000 steps)
Total steps: 10000


## Imports

In [15]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from safetensors.torch import save_file, load_file
from pathlib import Path
from tqdm import tqdm
import json
import random
import numpy as np

## Device Detection

In [16]:
# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

print(f"Using device: {device}")

Using device: mps


## Set Random Seeds

In [17]:
# Seed every RNG source with the shared RANDOM_SEED, in a fixed order:
# torch, then Python's `random`, then NumPy's legacy global generator.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(RANDOM_SEED)

# When training on CUDA, seed all CUDA generators explicitly as well.
if device == 'cuda':
    torch.cuda.manual_seed_all(RANDOM_SEED)

print(f"Random seed set to {RANDOM_SEED}")

Random seed set to 42


## Load Tokenizer and Dead Token Mask

In [18]:
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
print(f"Loaded tokenizer with vocab size {tokenizer.get_vocab_size()}")
# The embedding table has VOCAB_SIZE rows; a tokenizer that emits larger ids would
# index past it during training, so fail fast here instead.
assert tokenizer.get_vocab_size() == VOCAB_SIZE, (
    f"Tokenizer vocab size {tokenizer.get_vocab_size()} != VOCAB_SIZE {VOCAB_SIZE}"
)

masks = load_file(DEAD_MASK_PATH)
dead_mask = masks['dead_mask'].bool()
dead_indices = masks['dead_indices'].long()
n_dead = dead_mask.sum().item()

# capture_state selects embedding rows with this boolean mask, and the live-token
# count below is VOCAB_SIZE - n_dead, so the mask must be 1-D over the full vocab.
assert dead_mask.shape == (VOCAB_SIZE,), (
    f"dead_mask has shape {tuple(dead_mask.shape)}, expected ({VOCAB_SIZE},)"
)

print(f"Dead tokens: {n_dead}")
print(f"Live tokens: {VOCAB_SIZE - n_dead}")

Loaded tokenizer with vocab size 10000
Dead tokens: 3699
Live tokens: 6301


## Dataset

In [19]:
class TextDataset(Dataset):
    """Map-style dataset of fixed-length, non-overlapping token chunks.

    The whole corpus is read as UTF-8, tokenized once, and cut into
    ``len(tokens) // seq_len`` consecutive chunks. Trailing tokens that do not
    fill a complete chunk are dropped.

    Args:
        corpus_path: path to the UTF-8 text corpus.
        tokenizer: object whose ``encode(text)`` returns something with an
            ``ids`` sequence (e.g. a ``tokenizers.Tokenizer``).
        seq_len: number of tokens in each returned chunk.
    """

    def __init__(self, corpus_path, tokenizer, seq_len):
        with open(corpus_path, 'r', encoding='utf-8') as f:
            text = f.read()

        encoding = tokenizer.encode(text)
        self.tokens = encoding.ids
        self.seq_len = seq_len
        self.n_sequences = len(self.tokens) // seq_len

    def __len__(self):
        return self.n_sequences

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a 1-D ``torch.long`` tensor of length ``seq_len``.

        Negative indices count from the end, as for a Python sequence.
        Out-of-range indices raise ``IndexError``. Raising ``IndexError`` is also
        what lets plain iteration (``for x in dataset``) terminate.
        """
        if idx < 0:
            idx += self.n_sequences
        if not 0 <= idx < self.n_sequences:
            raise IndexError(
                f"index out of range for dataset of length {self.n_sequences}"
            )
        start = idx * self.seq_len
        chunk = self.tokens[start:start + self.seq_len]
        return torch.tensor(chunk, dtype=torch.long)


# Tokenize the corpus once and split it into non-overlapping SEQ_LEN-token chunks.
dataset = TextDataset(CORPUS_PATH, tokenizer, seq_len=SEQ_LEN)
print(f"Dataset: {len(dataset)} sequences of length {SEQ_LEN}")
print(f"Total tokens: {len(dataset) * SEQ_LEN:,}")

Dataset: 10713 sequences of length 128
Total tokens: 1,371,264


## Model Definition

In [20]:
class TinyLM(nn.Module):
    """Minimal causal language model with tied input/output embeddings.

    Token embeddings plus learned positional embeddings feed a pre-norm
    (``norm_first=True``) ``nn.TransformerEncoder`` run under a causal mask,
    followed by a final LayerNorm. Logits are produced with the transposed token
    embedding matrix (weight tying), so there is no separate output projection.

    Args:
        vocab_size: number of token ids (rows of the embedding table).
        hidden_dim: model width (``d_model``).
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
        max_seq_len: number of learned positions. ``None`` (default) falls back
            to the notebook-level ``SEQ_LEN``, preserving the original behaviour.
    """

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, max_seq_len=None):
        super().__init__()
        if max_seq_len is None:
            max_seq_len = SEQ_LEN

        # NOTE: submodule construction order determines how the global RNG is
        # consumed during initialization. Keep it unchanged to stay bitwise
        # reproducible with earlier runs.
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.ln_f = nn.LayerNorm(hidden_dim)

        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.pos_embedding.weight, mean=0.0, std=0.02)

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: 2-D integer tensor ``(batch, seq_len)`` with
                ``seq_len <= max_seq_len``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        _, seq_len = input_ids.shape

        tok_emb = self.embedding(input_ids)
        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_ids)

        hidden = tok_emb + pos_emb

        # Position t may only attend to positions <= t.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=input_ids.device)

        hidden = self.transformer(hidden, mask=causal_mask, is_causal=True)
        hidden = self.ln_f(hidden)

        # Tied output head: reuse the token embedding matrix.
        logits = hidden @ self.embedding.weight.T

        return logits


# Build on CPU (initialization consumes the seeded CPU RNG), then move to the
# target device and cast every parameter to bfloat16.
model = TinyLM(VOCAB_SIZE, HIDDEN_DIM, NUM_LAYERS, NUM_HEADS).to(device).to(torch.bfloat16)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")
print(f"Embedding shape: {model.embedding.weight.shape}")

Model parameters: 748,288
Embedding shape: torch.Size([10000, 64])




## Optimizer

In [21]:
# AdamW over every parameter (embeddings included), using the shared hyperparameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=WEIGHT_DECAY,
    eps=EPSILON,
    betas=(BETA1, BETA2),
    lr=LEARNING_RATE,
)

print("Optimizer: AdamW")
print(f"  lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}")
print(f"  betas=({BETA1}, {BETA2}), eps={EPSILON}")

Optimizer: AdamW
  lr=0.0003, weight_decay=0.1
  betas=(0.9, 0.999), eps=1e-08


## Data Collection Setup

In [22]:
# Output layout: OUTPUT_DIR/ for trajectory + metadata, OUTPUT_DIR/checkpoints/ for .pt files.
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)
checkpoint_path = output_path.joinpath('checkpoints')
checkpoint_path.mkdir(exist_ok=True)

# Preallocate one row per step, 0..TOTAL_STEPS inclusive (row 0 = initialization).
# W keeps only dead-token embedding rows in bfloat16; loss is float32.
W_history = torch.zeros((TOTAL_STEPS + 1, n_dead, HIDDEN_DIM), dtype=torch.bfloat16)
loss_history = torch.zeros((TOTAL_STEPS + 1,), dtype=torch.float32)

total_bytes = sum(buf.numel() * buf.element_size() for buf in (W_history, loss_history))
print(f"Pre-allocated {total_bytes / 1e9:.2f} GB for data collection")

Pre-allocated 4.74 GB for data collection


## Helper Functions

In [23]:
def capture_state(model, step, dead_mask, loss_val=0.0):
    """Record one timestep of the trajectory into the preallocated buffers.

    Copies the embedding rows selected by the boolean ``dead_mask`` into
    ``W_history[step]`` (cast to bfloat16) and writes ``loss_val`` into
    ``loss_history[step]``. Both buffers are module-level globals allocated in
    the data-collection setup cell.
    """
    embedding_cpu = model.embedding.weight.detach().cpu()
    dead_rows = embedding_cpu[dead_mask]
    W_history[step] = dead_rows.to(torch.bfloat16)
    loss_history[step] = loss_val


def save_checkpoint(model, optimizer, step, path):
    """Save a full model + optimizer checkpoint for ``step`` under ``path``.

    Writes ``path / f'checkpoint_step_{step:05d}.pt'`` containing:
    - the step number,
    - the model and optimizer state dicts,
    - the CPU torch RNG state,
    - the accelerator generator state when the model's parameters live on CUDA
      or MPS.

    The device is read from the model's own parameters rather than from a
    notebook-level ``device`` global, so the function is self-contained.

    Args:
        model: the ``nn.Module`` to checkpoint.
        optimizer: its optimizer.
        step: integer step number, used in the file name (zero-padded to 5 digits).
        path: output directory (``pathlib.Path`` or ``str``).
    """
    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'rng_state': torch.get_rng_state(),
    }

    first_param = next(model.parameters(), None)
    device_type = first_param.device.type if first_param is not None else 'cpu'
    if device_type == 'cuda':
        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state_all()
    elif device_type == 'mps':
        # MPS has its own generator; without it the RNG state is incomplete on Apple GPUs.
        checkpoint['mps_rng_state'] = torch.mps.get_rng_state()

    torch.save(checkpoint, Path(path) / f'checkpoint_step_{step:05d}.pt')

## Training Loop

In [24]:
# Shuffled, fixed-size batches. No explicit generator is passed, so the shuffle
# order depends on the global torch seed set in the "Set Random Seeds" cell.
# drop_last=True discards the final partial batch, so every batch has BATCH_SIZE rows.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)

loss_fn = nn.CrossEntropyLoss()

# Capture initial state
# Trajectory row 0 is the initialization, before any update. Its loss slot is a
# 0.0 placeholder because no forward pass has run yet.
capture_state(model, step=0, dead_mask=dead_mask, loss_val=0.0)
if 0 in CHECKPOINT_STEPS:
    save_checkpoint(model, optimizer, 0, checkpoint_path)
    print("Saved checkpoint at step 0")

print(f"\nStarting training for {TOTAL_STEPS} steps...")
print(f"Checkpoints at: {CHECKPOINT_STEPS}")

model.train()
step = 0
epoch = 0

pbar = tqdm(total=TOTAL_STEPS, desc="Training")

# Repeat epochs until TOTAL_STEPS optimizer updates are done. The last epoch is
# usually partial: the inner `break` stops it mid-way.
while step < TOTAL_STEPS:
    epoch += 1
    
    for batch in dataloader:
        if step >= TOTAL_STEPS:
            break
            
        input_ids = batch.to(device)
        
        # NOTE(review): on MPS this requests autocast with device_type 'cpu', which
        # presumably does not affect ops executing on MPS. The model was already cast
        # to bfloat16 when it was built, so compute is bfloat16 either way. Do not
        # change this line without re-verifying bitwise equality with Thimble 8.
        with torch.autocast(device_type=device if device != 'mps' else 'cpu', dtype=torch.bfloat16):
            logits = model(input_ids)
            
            # Next-token prediction: logits at position t are scored against token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            
            loss = loss_fn(
                shift_logits.view(-1, VOCAB_SIZE),
                shift_labels.view(-1)
            )
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # `step` now counts completed updates: W_history[step] holds the weights after
        # `step` updates. loss_history[step] holds the loss computed just before that
        # update, i.e. the loss at the previous row's weights.
        step += 1
        
        # loss.item() moves the scalar to the host on every step.
        capture_state(model, step=step, dead_mask=dead_mask, loss_val=loss.item())
        
        if step in CHECKPOINT_STEPS:
            save_checkpoint(model, optimizer, step, checkpoint_path)
            tqdm.write(f"Saved checkpoint at step {step}")
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'epoch': epoch})
        pbar.update(1)

pbar.close()
# `loss` is from the last batch. This line would raise NameError if TOTAL_STEPS == 0.
print(f"\nTraining complete. Final loss: {loss.item():.4f}")
print(f"Completed {epoch} epochs")

Saved checkpoint at step 0

Starting training for 10000 steps...
Checkpoints at: [0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]


Training:  10%|█         | 1002/10000 [01:00<09:21, 16.04it/s, loss=6.6562, epoch=13]

Saved checkpoint at step 1000


Training:  20%|██        | 2002/10000 [02:00<08:16, 16.10it/s, loss=6.3750, epoch=25]

Saved checkpoint at step 2000


Training:  30%|███       | 3002/10000 [03:02<07:15, 16.09it/s, loss=6.3438, epoch=37]

Saved checkpoint at step 3000


Training:  40%|████      | 4002/10000 [04:03<06:20, 15.76it/s, loss=6.2500, epoch=49]

Saved checkpoint at step 4000


Training:  50%|█████     | 5002/10000 [05:06<05:29, 15.19it/s, loss=6.2188, epoch=61]

Saved checkpoint at step 5000


Training:  60%|██████    | 6002/10000 [06:09<04:21, 15.30it/s, loss=6.2812, epoch=73]

Saved checkpoint at step 6000


Training:  70%|███████   | 7002/10000 [07:12<03:25, 14.59it/s, loss=6.2188, epoch=85]

Saved checkpoint at step 7000


Training:  80%|████████  | 8002/10000 [08:15<02:06, 15.75it/s, loss=6.2188, epoch=97]

Saved checkpoint at step 8000


Training:  90%|█████████ | 9002/10000 [09:15<01:03, 15.74it/s, loss=6.2500, epoch=109]

Saved checkpoint at step 9000


Training: 100%|██████████| 10000/10000 [10:17<00:00, 16.21it/s, loss=6.2188, epoch=121]

Saved checkpoint at step 10000

Training complete. Final loss: 6.2188
Completed 121 epochs





## Verification: Compare to Thimble 8

In [25]:
if VERIFICATION_MODE:
    print("=" * 60)
    print("VERIFICATION: Comparing to Thimble 8")
    print("=" * 60)

    # Load Thimble 8 data. Its W is stored as uint16 holding bfloat16 bit patterns,
    # so reinterpret (not convert) it back to bfloat16.
    thimble_8_data = load_file(THIMBLE_8_PATH)
    W_thimble_8 = thimble_8_data['W'].view(torch.bfloat16)

    # Compare every step this run captured (rows 0..TOTAL_STEPS), rather than a
    # hard-coded 1001, so the check stays correct if TOTAL_STEPS changes.
    n_compare = W_history.shape[0]
    W_thimble_9 = W_history
    W_thimble_8_subset = W_thimble_8[:n_compare]

    print(f"Thimble 9 W shape: {W_thimble_9.shape}")
    print(f"Thimble 8 W subset shape: {W_thimble_8_subset.shape}")

    if W_thimble_9.shape != W_thimble_8_subset.shape:
        # Different step count or dead-token set: an elementwise compare would
        # raise or broadcast, so report the mismatch instead.
        print("\n✗ SHAPE MISMATCH: W trajectories cannot be compared bitwise")
    else:
        # Bitwise comparison: reinterpret the 16-bit patterns as int16. These are the
        # same bits as uint16, but int16 has full operator support across torch versions.
        W9_bits = W_thimble_9.view(torch.int16)
        W8_bits = W_thimble_8_subset.view(torch.int16)
        diff_mask = W9_bits != W8_bits
        n_diff = diff_mask.sum().item()

        if n_diff == 0:
            print("\n✓ BITWISE IDENTICAL: Thimble 9 matches Thimble 8 exactly!")
        else:
            total_elements = W9_bits.numel()
            print(f"\n✗ MISMATCH: {n_diff:,} / {total_elements:,} elements differ ({n_diff/total_elements:.4%})")

            # First timestep with any differing element
            diff_steps = torch.nonzero(diff_mask.flatten(start_dim=1).any(dim=1)).flatten()
            if diff_steps.numel() > 0:
                print(f"First difference at step: {diff_steps[0].item()}")

            # Step 0 is the initialization, before any optimizer update
            step0_match = (W9_bits[0] == W8_bits[0]).all().item()
            print(f"Step 0 (initialization) matches: {step0_match}")

    # Loss is checked unconditionally: identical W with diverging loss would also be a bug.
    loss_9 = loss_history
    loss_8 = thimble_8_data['loss'][:n_compare].to(loss_9.dtype)
    if loss_8.shape != loss_9.shape:
        print(f"Loss shape mismatch: Thimble 9 {tuple(loss_9.shape)} vs Thimble 8 {tuple(loss_8.shape)}")
    else:
        loss_match = torch.allclose(loss_8, loss_9, rtol=1e-5, atol=1e-5)
        print(f"Loss values match (within tolerance): {loss_match}")

        if not torch.equal(loss_8, loss_9):
            loss_diff = (loss_8 - loss_9).abs()
            diverged = torch.nonzero(loss_8 != loss_9).flatten()
            if diverged.numel() > 0:
                print(f"First loss divergence at step: {diverged[0].item()}")
            max_diff_idx = loss_diff.argmax().item()
            print(f"Max loss difference at step {max_diff_idx}: {loss_diff[max_diff_idx]:.6f}")
else:
    print("Skipping verification (production mode)")

Skipping verification (production mode)


## Save Data (production mode only; verification runs are not saved)

In [26]:
if not VERIFICATION_MODE:
    # Persist the trajectory. W is reinterpreted (not converted) as uint16 so the
    # exact bfloat16 bit patterns survive the round trip through safetensors.
    data_to_save = dict(
        W=W_history.view(torch.uint16),
        loss=loss_history,
        dead_mask=dead_mask,
        dead_indices=dead_indices,
    )

    save_path = output_path / 'thimble_9_trajectory.safetensors'
    save_file(data_to_save, str(save_path))

    print(f"Saved trajectory data to {save_path}")
    print(f"File size: {save_path.stat().st_size / 1e9:.2f} GB")

    # Run configuration and summary, written next to the trajectory as JSON.
    metadata = dict(
        experiment='Thimble 9',
        date='2025-11-25',
        total_steps=TOTAL_STEPS,
        batch_size=BATCH_SIZE,
        seq_len=SEQ_LEN,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        beta1=BETA1,
        beta2=BETA2,
        epsilon=EPSILON,
        vocab_size=VOCAB_SIZE,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        random_seed=RANDOM_SEED,
        n_dead_tokens=n_dead,
        checkpoint_steps=CHECKPOINT_STEPS,
        final_loss=loss_history[-1].item(),
        total_epochs=epoch,
        device=device,
        data_shapes=dict(
            W=list(W_history.shape),
            loss=list(loss_history.shape),
        ),
        notes='Dead tokens only. W stored as uint16 (bfloat16 bit pattern). Stripped-down for long run.',
    )

    with (output_path / 'metadata.json').open('w') as fh:
        json.dump(metadata, fh, indent=2)

    print("Saved metadata.json")
else:
    print("\nVerification mode: Not saving trajectory data.")
    print("If verification passed, set VERIFICATION_MODE = False and rerun.")

Saved trajectory data to ../../tensors/Thimble-9/thimble_9_trajectory.safetensors
File size: 4.74 GB
Saved metadata.json
