# Thimble 1: Adam Accounting Validation

**Question:** Can we reconstruct observed weight changes from recorded gradients and optimizer states?

**Method:** Train a small language model for 1,000 steps using bare PyTorch (no Trainer framework). Record at every step:
- W: embedding weights
- grad_W: gradients
- momentum_W: Adam momentum (exp_avg)
- variance_W: Adam variance (exp_avg_sq)

Then test: does the AdamW update formula reproduce the observed ΔW?

$$\Delta W(t) = -\eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \eta \lambda \cdot W(t-1)$$

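where $\hat{m}_t = m_t / (1 - \beta_1^t)$ and $\hat{v}_t = v_t / (1 - \beta_2^t)$ are the bias-corrected first and second moments, $\eta$ is the learning rate, and $\lambda$ is the (decoupled) weight-decay coefficient.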

**Why this matters:** If we can't validate the accounting, we can't trust any downstream analysis of forces and dynamics. This is the foundation for understanding dead token motion.

---

## Model Architecture

Same as Flannel experiments:
- 10,000 token vocabulary (6,301 live, 3,699 dead)
- 64-dimensional embeddings
- 2-layer transformer
- 2 attention heads
- Tied embeddings (input embedding = output unembedding transpose)

## Parameters

In [None]:
# --- Model architecture ---------------------------------------------------
VOCAB_SIZE = 10_000   # rows of the embedding matrix
HIDDEN_DIM = 64       # embedding width
N_LAYERS, N_HEADS = 2, 2
MAX_SEQ_LEN = 128     # tokens per training example

# --- Training schedule ----------------------------------------------------
NUM_STEPS = 1_000
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0    # no decay: dynamics are studied in "flat spacetime"

# --- AdamW hyperparameters ------------------------------------------------
ADAM_BETA1, ADAM_BETA2 = 0.9, 0.999
ADAM_EPSILON = 1e-8

# --- Initialization -------------------------------------------------------
INIT_SCALE = 0.02     # standard deviation of the normal embedding init
SEED = 42

# --- Input / output locations ---------------------------------------------
TOKENIZER_PATH = "../data/flannel_tokenizer_chars.json"
CORPUS_PATH = "../data/flannel_model_corpus.txt"
TOKEN_MASK_PATH = "../tensors/Flannel/live_dead_tokens.safetensors"
OUTPUT_PATH = "../tensors/Thimble/thimble_1.safetensors"

print("✓ Parameters set")

## Imports

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, GPT2LMHeadModel
from tokenizers import Tokenizer
import numpy as np
from pathlib import Path
from safetensors.torch import save_file, load_file
from tqdm.auto import tqdm
import time

print("✓ Imports complete")

## Safety Check

In [None]:
print(f"\n{'='*80}")
print(f"MEMORY & DISK SAFETY CHECK")
print(f"{'='*80}\n")

# Bytes per element for the two dtypes used in the estimate
bytes_bf16 = 2  # bfloat16
bytes_f32 = 4   # float32

# Each recorded trajectory holds one (VOCAB_SIZE, HIDDEN_DIM) snapshot per
# training step plus one extra slot for the initial state.
trajectory_elems = (NUM_STEPS+1) * VOCAB_SIZE * HIDDEN_DIM

recording_w = trajectory_elems * bytes_bf16
recording_grad = trajectory_elems * bytes_bf16
recording_momentum = trajectory_elems * bytes_f32
recording_variance = trajectory_elems * bytes_f32
recording_losses = (NUM_STEPS+1) * bytes_f32  # one scalar per slot

total_recording = sum((
    recording_w,
    recording_grad,
    recording_momentum,
    recording_variance,
    recording_losses,
))

print(f"Recording tensors (CPU memory):")
print(f"  W:         {recording_w/1e9:.2f} GB")
print(f"  grad_W:    {recording_grad/1e9:.2f} GB")
print(f"  momentum:  {recording_momentum/1e9:.2f} GB")
print(f"  variance:  {recording_variance/1e9:.2f} GB")
print(f"  losses:    {recording_losses/1e9:.4f} GB")
print(f"  {'─'*40}")
print(f"  Total:     {total_recording/1e9:.2f} GB")
print()

# Rough parameter count: embedding table plus ~12 * d^2 per transformer layer
embedding_params = VOCAB_SIZE * HIDDEN_DIM
params_per_layer = 12 * HIDDEN_DIM**2
transformer_params = N_LAYERS * params_per_layer
total_model_params = embedding_params + transformer_params

# Device-side footprint. The two Adam moment buffers are budgeted at 4 bytes
# per element; NOTE(review): this is an upper bound if the optimizer keeps its
# state in the (bfloat16) parameter dtype — confirm.
model_memory = total_model_params * bytes_bf16
optimizer_memory = 2 * total_model_params * bytes_f32
activation_memory = BATCH_SIZE * MAX_SEQ_LEN * HIDDEN_DIM * N_LAYERS * 2 * bytes_bf16
device_total = model_memory + optimizer_memory + activation_memory

print(f"Model memory (device memory):")
print(f"  Model weights: {model_memory/1e9:.2f} GB ({total_model_params:,} params)")
print(f"  Optimizer:     {optimizer_memory/1e9:.2f} GB (Adam states)")
print(f"  Activations:   {activation_memory/1e9:.2f} GB (batch={BATCH_SIZE})")
print(f"  {'─'*40}")
print(f"  Total:         {device_total/1e9:.2f} GB")
print()

# Peak RAM adds the tokenized corpus (approximate token count, 8 bytes per
# int64 id) and a flat 1 GB allowance for the interpreter and libraries.
corpus_memory = 1371328 * 8
misc_overhead = 1e9
peak_ram = total_recording + device_total + corpus_memory + misc_overhead

print(f"Peak RAM estimate:")
print(f"  Recording:     {total_recording/1e9:.2f} GB")
print(f"  Model+opt+act: {device_total/1e9:.2f} GB")
print(f"  Corpus+misc:   {(corpus_memory + misc_overhead)/1e9:.2f} GB")
print(f"  {'─'*40}")
print(f"  Total:         {peak_ram/1e9:.2f} GB")
print()

# On disk the file is the recorded trajectories plus ~1 MB of scalar metadata
metadata_overhead = 1e6
disk_needed = total_recording + metadata_overhead

print(f"Disk space needed:")
print(f"  Safetensors:   {disk_needed/1e9:.2f} GB")
print()

# Verdict against the fixed 24 GB RAM budget
print(f"{'='*80}")
verdict = (
    f"✓ SAFE: Peak RAM ({peak_ram/1e9:.1f} GB) within 24 GB budget"
    if peak_ram <= 24e9
    else f"⚠️  WARNING: Peak RAM ({peak_ram/1e9:.1f} GB) exceeds 24 GB budget!"
)
print(verdict)
print(f"{'='*80}\n")

## Device Detection

In [None]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU otherwise.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'

print(f"Using device: {device}")

## Set Random Seeds

In [None]:
# Seed NumPy and PyTorch (including every CUDA device, when present) from SEED
# so that initialization and sampling are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"✓ Random seed set to {SEED}")

## Load Data

In [None]:
# --- Tokenizer ------------------------------------------------------------
print(f"Loading tokenizer: {TOKENIZER_PATH}")
tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
print(f"  ✓ Vocabulary: {tokenizer.get_vocab_size():,} tokens\n")

# --- Corpus: read the whole text file and tokenize it in one pass ---------
print(f"Loading corpus: {CORPUS_PATH}")
corpus_text = Path(CORPUS_PATH).read_text(encoding='utf-8')
encoding = tokenizer.encode(corpus_text)
tokens = encoding.ids
corpus_tensor = torch.tensor(tokens, dtype=torch.long)
print(f"  ✓ Tokens: {len(tokens):,}\n")

# --- Live/dead token masks (not used for training; saved with the results) -
print(f"Loading token masks: {TOKEN_MASK_PATH}")
mask_data = load_file(TOKEN_MASK_PATH)
live_mask = mask_data['live_mask'].bool()
dead_mask = mask_data['dead_mask'].bool()
n_live, n_dead = (int(mask.sum()) for mask in (live_mask, dead_mask))
print(f"  ✓ Live: {n_live:,} | Dead: {n_dead:,}")

## Dataset and DataLoader

In [None]:
class TokenDataset(Dataset):
    """Sliding-window next-token dataset over a 1-D sequence of token ids.

    Example ``i`` is the window ``corpus[i : i + max_seq_len + 1]``: the inputs
    are every token but the last and the labels are every token but the first,
    so both have length ``max_seq_len`` and the labels are the inputs shifted
    by one position.

    Args:
        corpus_tensor: 1-D sequence of token ids supporting ``len()`` and slicing.
        max_seq_len: Number of tokens in each input (and label) sequence.
    """

    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len

    def __len__(self):
        # A full window needs max_seq_len + 1 tokens; a shorter corpus yields
        # zero examples rather than a negative length.
        return max(0, len(self.corpus) - self.max_seq_len)

    def __getitem__(self, idx):
        """Return the ``idx``-th window as ``{'input_ids', 'labels'}``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Without this check, slicing past the end would silently return
                short or empty windows and plain iteration over the dataset
                would never stop.
        """
        n_examples = len(self)
        position = idx + n_examples if idx < 0 else idx
        if not 0 <= position < n_examples:
            raise IndexError(
                f"index {idx} is out of range for a dataset of {n_examples} examples"
            )

        chunk = self.corpus[position : position + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }

# Sliding-window dataset over the tokenized corpus
dataset = TokenDataset(corpus_tensor=corpus_tensor, max_seq_len=MAX_SEQ_LEN)

# DataLoader with deterministic sampling
def seed_worker(worker_id):
    """Seed NumPy's global RNG from the current PyTorch initial seed.

    The PyTorch seed is reduced modulo 2**32 to fit NumPy's seed range.
    ``worker_id`` is accepted to match the ``worker_init_fn`` signature but is
    not used.
    """
    np.random.seed(torch.initial_seed() % (1 << 32))

# Dedicated generator so the shuffle order is driven by SEED
g = torch.Generator().manual_seed(SEED)

# num_workers=0 keeps batch loading in the main process for reproducibility.
# NOTE(review): worker_init_fn is documented as running in worker
# subprocesses, so it is presumably a no-op with num_workers=0 — confirm.
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=g,
    worker_init_fn=seed_worker,
    num_workers=0,
)
dataloader = DataLoader(dataset, **loader_options)

print(f"\n✓ Dataset: {len(dataset):,} examples")
print(f"✓ DataLoader: {len(dataloader):,} batches per epoch")

## Create Model

In [None]:
print("Creating model...\n")

# GPT-2 style configuration: sizes come from the notebook parameters, every
# dropout is disabled, and the input embedding is tied to the output head.
architecture = dict(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYERS,
    n_head=N_HEADS,
)
no_dropout = dict(resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0)
config = GPT2Config(**architecture, **no_dropout, tie_word_embeddings=True)

model = GPT2LMHeadModel(config)

# Re-draw the token embedding matrix from a zero-mean normal with std INIT_SCALE
with torch.no_grad():
    nn.init.normal_(model.transformer.wte.weight, mean=0.0, std=INIT_SCALE)

# Cast all parameters to bfloat16 first, then move the model to the device
model = model.to(torch.bfloat16)
model = model.to(device)

# Total number of parameter elements
n_params = sum(param.numel() for param in model.parameters())

print(f"  Architecture: {N_LAYERS} layers, {N_HEADS} heads, {HIDDEN_DIM}d embeddings")
print(f"  Parameters: {n_params:,}")
print(f"  Device: {device}")
print(f"  Dtype: {model.dtype}")
print(f"\n✓ Model created")

## Create Optimizer

In [None]:
# AdamW over every model parameter, configured from the notebook constants
adamw_options = dict(
    lr=LEARNING_RATE,
    betas=(ADAM_BETA1, ADAM_BETA2),
    eps=ADAM_EPSILON,
    weight_decay=WEIGHT_DECAY,
)
optimizer = torch.optim.AdamW(model.parameters(), **adamw_options)

print(f"✓ Optimizer: AdamW")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Betas: ({ADAM_BETA1}, {ADAM_BETA2})")
print(f"  Epsilon: {ADAM_EPSILON}")
print(f"  Weight decay: {WEIGHT_DECAY}")

## Pre-allocate Recording Tensors

In [None]:
print("\nPre-allocating recording tensors...\n")

# One (VOCAB_SIZE, HIDDEN_DIM) slot per training step, plus slot 0 for the
# state before any update.
trajectory_shape = (NUM_STEPS+1, VOCAB_SIZE, HIDDEN_DIM)

# Weights and gradients are recorded in bfloat16
W_history = torch.zeros(trajectory_shape, dtype=torch.bfloat16)
grad_history = torch.zeros(trajectory_shape, dtype=torch.bfloat16)

# Adam moments are recorded in float32.
# NOTE(review): with a bfloat16 model the optimizer presumably allocates its
# state in the parameter dtype, making this an upcast rather than the native
# storage format — confirm.
momentum_history = torch.zeros(trajectory_shape, dtype=torch.float32)
variance_history = torch.zeros(trajectory_shape, dtype=torch.float32)

# One scalar loss per slot
loss_history = torch.zeros(NUM_STEPS+1, dtype=torch.float32)

# Footprint of each buffer: element count times bytes per element
memory_w = W_history.numel() * W_history.element_size()
memory_grad = grad_history.numel() * grad_history.element_size()
memory_momentum = momentum_history.numel() * momentum_history.element_size()
memory_variance = variance_history.numel() * variance_history.element_size()
memory_loss = loss_history.numel() * loss_history.element_size()
total_memory = sum((memory_w, memory_grad, memory_momentum, memory_variance, memory_loss))

print(f"  W:         {tuple(W_history.shape)} (bfloat16) = {memory_w/1e9:.2f} GB")
print(f"  grad_W:    {tuple(grad_history.shape)} (bfloat16) = {memory_grad/1e9:.2f} GB")
print(f"  momentum:  {tuple(momentum_history.shape)} (float32) = {memory_momentum/1e9:.2f} GB")
print(f"  variance:  {tuple(variance_history.shape)} (float32) = {memory_variance/1e9:.2f} GB")
print(f"  losses:    {tuple(loss_history.shape)} (float32) = {memory_loss/1e9:.4f} GB")
print(f"\n  Total: {total_memory/1e9:.2f} GB")
print(f"\n✓ Tensors allocated")

## Training Loop

In [None]:
print(f"\n{'='*80}")
print(f"THIMBLE 1: TRAINING")
print(f"{'='*80}\n")

# The token embedding matrix is the only parameter whose trajectory is recorded.
wte_param = model.transformer.wte.weight

# Record initial state (step 0). Assigning into a slice of the pre-allocated
# CPU buffer copies the values immediately, so no separate clone is needed.
W_history[0] = wte_param.detach().cpu().bfloat16()
loss_history[0] = float('nan')  # No loss before first step
print("✓ Recorded initial state (t=0)\n")

# Iterator over the dataloader; restarted below whenever an epoch is exhausted
data_iter = iter(dataloader)

# Training loop
model.train()
start_time = time.time()

for step in tqdm(range(1, NUM_STEPS+1), desc="Training"):
    # Get next batch (cycle through dataset if needed)
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(dataloader)
        batch = next(data_iter)

    # Move batch to device
    input_ids = batch['input_ids'].to(device)
    labels = batch['labels'].to(device)

    # Forward pass
    outputs = model(input_ids=input_ids, labels=labels)
    loss = outputs.loss

    # Backward pass
    loss.backward()

    # === RECORD GRADIENTS (before optimizer.step) ===
    grad_history[step] = wte_param.grad.detach().cpu().bfloat16()

    # Optimizer step, then clear gradients for the next iteration
    optimizer.step()
    optimizer.zero_grad()

    # === RECORD WEIGHTS & OPTIMIZER STATE (after optimizer.step) ===
    W_history[step] = wte_param.detach().cpu().bfloat16()

    # Fail loudly if the optimizer holds no state for the embedding: skipping
    # the write would leave all-zero moments in the recording, which would
    # silently corrupt the accounting check this run exists for.
    opt_state = optimizer.state.get(wte_param)
    if not opt_state:
        raise RuntimeError(
            f"No optimizer state for the embedding weights after step {step}; "
            "cannot record Adam moments."
        )
    momentum_history[step] = opt_state['exp_avg'].detach().cpu().float()
    variance_history[step] = opt_state['exp_avg_sq'].detach().cpu().float()

    loss_history[step] = loss.item()

elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete")
print(f"  Time: {elapsed:.1f}s ({elapsed/60:.1f} minutes)")
print(f"  Final loss: {loss_history[-1]:.4f}")
print(f"{'='*80}")

## Quick Validation: Test Adam Accounting

Before saving, let's verify that we can reconstruct ΔW from our recorded data.

In [21]:
print("\nValidating Adam accounting...\n")

# Test at several timesteps
test_steps = [1, 10, 50, 100, 200, 500, 800]

for t in test_steps:
    # Step t compares snapshots t-1 and t, so it must lie in [1, NUM_STEPS];
    # t=0 would otherwise wrap around to the last snapshot via index -1.
    if t < 1 or t > NUM_STEPS:
        continue

    # Measured ΔW. Upcast before subtracting so the difference itself is not
    # rounded to bfloat16 on top of the rounding already present in W.
    W_prev = W_history[t-1].float()
    delta_W_measured = W_history[t].float() - W_prev
    measured_norm = torch.norm(delta_W_measured)

    # Compute ΔW from AdamW formula using the moments recorded after step t
    m_t = momentum_history[t]
    v_t = variance_history[t]

    # Bias correction
    m_hat = m_t / (1 - ADAM_BETA1**t)
    v_hat = v_t / (1 - ADAM_BETA2**t)

    # AdamW update. The decoupled weight decay is scaled by the learning rate:
    # PyTorch's AdamW multiplies the weights by (1 - lr * weight_decay).
    adam_term = LEARNING_RATE * m_hat / (torch.sqrt(v_hat) + ADAM_EPSILON)
    decay_term = LEARNING_RATE * WEIGHT_DECAY * W_prev

    delta_W_computed = -adam_term - decay_term
    computed_norm = torch.norm(delta_W_computed)

    # Compare magnitudes
    ratio = computed_norm / measured_norm

    # Cosine similarity between the predicted and the measured update
    cosine = (delta_W_computed.flatten() @ delta_W_measured.flatten()) / (computed_norm * measured_norm)

    print(f"t={t:3d}: ratio={ratio:.3f}, cosine={cosine:+.3f}")

# NOTE(review): W is recorded in bfloat16, so the measured ΔW is quantized to
# the bfloat16 grid of the weights; this is a plausible source of residual
# mismatch at later steps — confirm in the follow-up analysis.
print("\n(Ratio should be ~1.0, cosine should be ~1.0 for perfect match)")
print("\n✓ Validation complete")


Validating Adam accounting...

t=  1: ratio=1.012, cosine=+1.000
t= 10: ratio=1.003, cosine=+0.999
t= 50: ratio=0.987, cosine=+0.991
t=100: ratio=0.957, cosine=+0.918
t=200: ratio=1.048, cosine=+0.882
t=500: ratio=0.917, cosine=+0.879
t=800: ratio=1.086, cosine=+0.821

(Ratio should be ~1.0, cosine should be ~1.0 for perfect match)

✓ Validation complete


## Save Data

In [None]:
print(f"\nSaving data to {OUTPUT_PATH}...\n")

# Make sure the destination directory exists
Path(OUTPUT_PATH).parent.mkdir(parents=True, exist_ok=True)

# Recorded trajectories go in first ...
save_dict = {
    'W': W_history,
    'grad_W': grad_history,
    'momentum_W': momentum_history,
    'variance_W': variance_history,
    'losses': loss_history,
}

# ... followed by the run's scalar metadata, each stored as a 0-dim tensor
# (integers as int64, real-valued hyperparameters as float32).
scalar_specs = [
    # Model hyperparameters
    ('vocab_size', VOCAB_SIZE, torch.long),
    ('hidden_dim', HIDDEN_DIM, torch.long),
    ('n_layers', N_LAYERS, torch.long),
    ('n_heads', N_HEADS, torch.long),
    # Training hyperparameters
    ('num_steps', NUM_STEPS, torch.long),
    ('batch_size', BATCH_SIZE, torch.long),
    ('learning_rate', LEARNING_RATE, torch.float32),
    ('weight_decay', WEIGHT_DECAY, torch.float32),
    ('adam_beta1', ADAM_BETA1, torch.float32),
    ('adam_beta2', ADAM_BETA2, torch.float32),
    ('adam_epsilon', ADAM_EPSILON, torch.float32),
    ('init_scale', INIT_SCALE, torch.float32),
    ('seed', SEED, torch.long),
    # Token counts
    ('n_live', n_live, torch.long),
    ('n_dead', n_dead, torch.long),
]
save_dict.update({
    name: torch.tensor(value, dtype=dtype)
    for name, value, dtype in scalar_specs
})

# Write the file and time the write
save_start = time.time()
save_file(save_dict, str(OUTPUT_PATH))
save_elapsed = time.time() - save_start

# Size of the file actually written
file_size_bytes = Path(OUTPUT_PATH).stat().st_size
file_size_gb = file_size_bytes / 1e9

print(f"✓ Saved successfully")
print(f"  File: {Path(OUTPUT_PATH).name}")
print(f"  Size: {file_size_gb:.2f} GB")
print(f"  Save time: {save_elapsed:.1f}s")

## Summary

In [None]:
# Assemble the end-of-run report line by line and print it in one call
rule = '=' * 80
summary_lines = [
    f"\n{rule}",
    "THIMBLE 1 COMPLETE",
    f"{rule}\n",
    f"Trained small language model for {NUM_STEPS:,} steps",
    f"  Seed: {SEED}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Weight decay: {WEIGHT_DECAY}",
    "",
    "Recorded at every step:",
    "  • W: embedding weights (bfloat16)",
    "  • grad_W: gradients (bfloat16)",
    "  • momentum_W: Adam exp_avg (float32)",
    "  • variance_W: Adam exp_avg_sq (float32)",
    "  • losses: training loss",
    "",
    f"Data saved: {OUTPUT_PATH}",
    f"  Size: {file_size_gb:.2f} GB",
    f"  Training time: {elapsed/60:.1f} minutes",
    "",
    "Next step: Analyze in separate notebook to fully validate Adam accounting.",
    f"\n{rule}",
]
print("\n".join(summary_lines))