# 15.1a: Comprehensive Instrumentation

**Record everything: embeddings, gradients, optimizer state, logits, loss**

## Volume 15: The Everything Dataset

Previous instrumented runs recorded deltas (changes). This time we record **absolute state** at each step:

1. **Embedding matrix W**: The actual token positions in space
2. **Gradients**: The instantaneous forces acting on tokens
3. **Adam momentum**: The inertial velocity each token has built up
4. **Adam variance**: The adaptive learning rate scaling per token
5. **Logits**: Model predictions (one position per step)
6. **Loss**: Training loss value

This lets us study:
- Token trajectories (from W directly, no cumsum needed)
- Force decomposition (gradient vs momentum)
- Prediction evolution (logits over time)
- Optimizer dynamics (how Adam modulates the motion)

## File Size

~660 MB for 10,001 steps (0-10,000)

## Parameters

In [1]:
# Model architecture
VOCAB_SIZE = 128
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training
BATCH_SIZE = 32
GRADIENT_ACCUMULATION = 1
NUM_TRAIN_STEPS = 10000
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01

# Adam parameters (defaults from torch.optim.AdamW)
ADAM_BETA1 = 0.9  # momentum decay
ADAM_BETA2 = 0.999  # variance decay
ADAM_EPSILON = 1e-8

# Initialization
INIT_SIGMA = 1e-5  # float32 → bfloat16

# Data
CORPUS_PATH = "../data/training_corpus.txt"
OUTPUT_DIR = "../data/comprehensive_run"

# Instrumentation
RECORD_EVERY_N_STEPS = 1  # Record every step

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file
import time

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
print("✓ Imports complete")

Using device: mps
✓ Imports complete


## Load Gatsby Corpus

In [3]:
print(f"Loading corpus: {CORPUS_PATH}")

with open(CORPUS_PATH, 'r', encoding='ascii') as f:
    corpus_text = f.read()

corpus_bytes = [b for b in corpus_text.encode('ascii') if b < VOCAB_SIZE]

unique_bytes = set(corpus_bytes)
dead_token_ids = sorted(set(range(VOCAB_SIZE)) - unique_bytes)
live_token_ids = sorted(unique_bytes)

print(f"  Total bytes: {len(corpus_bytes):,}")
print(f"  Live tokens: {len(live_token_ids)} / {VOCAB_SIZE}")
print(f"  Dead tokens: {len(dead_token_ids)} / {VOCAB_SIZE}")

# Pre-load to device
corpus_tensor = torch.tensor(corpus_bytes, dtype=torch.long, device=device)
print(f"\n✓ Corpus on device")

Loading corpus: ../data/training_corpus.txt
  Total bytes: 265,905
  Live tokens: 77 / 128
  Dead tokens: 51 / 128

✓ Corpus on device


## Dataset

In [4]:
class ByteDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
    def __getitem__(self, idx):
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }

dataset = ByteDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"✓ Dataset: {len(dataset):,} examples")

✓ Dataset: 265,777 examples


## Model

In [5]:
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,
)

model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model created: {total_params:,} parameters")

✓ Model created: 116,480 parameters


## Float32 Initialization

In [6]:
print(f"\nFloat32 initialization (σ = {INIT_SIGMA:.2e})\n")

with torch.no_grad():
    # Generate random unit vector in float32
    random_vector = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device)
    random_vector = random_vector / random_vector.norm()
    
    # Add Gaussian noise in float32
    noise = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32, device=device) * INIT_SIGMA
    init_f32 = random_vector + noise
    
    # Convert to bfloat16 for training
    init_bf16 = init_f32.to(torch.bfloat16)
    
    # Assign to model
    model.transformer.wte.weight[:] = init_bf16
    
    print(f"Initialized: {VOCAB_SIZE} tokens at norm ~{init_f32.norm(dim=1).mean().item():.6f}")

print(f"\n✓ Embeddings initialized")


Float32 initialization (σ = 1.00e-05)

Initialized: 128 tokens at norm ~0.999999

✓ Embeddings initialized


## Comprehensive Recorder

In [7]:
class ComprehensiveRecorder:
    """Records everything: embeddings, gradients, Adam state, logits, loss."""
    
    def __init__(self, vocab_size, hidden_dim, record_every_n):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.record_every_n = record_every_n
        
        # Storage
        self.recorded_steps = []
        self.embeddings = []      # [n_recorded, vocab_size, hidden_dim]
        self.grads = []           # [n_recorded, vocab_size, hidden_dim]
        self.momentum = []        # [n_recorded, vocab_size, hidden_dim]
        self.variance = []        # [n_recorded, vocab_size, hidden_dim]
        self.logits = []          # [n_recorded, vocab_size]
        self.losses = []          # [n_recorded]
        
        # Temporary storage
        self.current_step = 0
        self.recorded_initial = False
        self.grad_before = None
        self.loss_value = None
        self.logits_sample = None
    
    def record_initial_state(self, model, optimizer):
        """Record step 0: initial state before training."""
        if not self.recorded_initial:
            W = model.transformer.wte.weight.data.clone().cpu().float()
            
            # Step 0: no gradients, no Adam state yet (zeros)
            self.recorded_steps.append(0)
            self.embeddings.append(W)
            self.grads.append(torch.zeros_like(W))
            self.momentum.append(torch.zeros_like(W))
            self.variance.append(torch.zeros_like(W))
            self.logits.append(torch.zeros(self.vocab_size))
            self.losses.append(torch.tensor(float('nan')))  # No loss yet
            
            self.recorded_initial = True
            self.current_step = 1
    
    def record_before_step(self, model, loss, logits):
        """Call after forward/backward, before optimizer step."""
        if self.current_step % self.record_every_n == 0:
            # Capture gradients
            if model.transformer.wte.weight.grad is not None:
                self.grad_before = model.transformer.wte.weight.grad.clone().cpu().float()
            else:
                self.grad_before = torch.zeros(self.vocab_size, self.hidden_dim)
            
            # Capture loss
            self.loss_value = loss.item()
            
            # Capture logits from first sequence, last position
            # logits shape: [batch_size, seq_len, vocab_size]
            self.logits_sample = logits[0, -1, :].detach().cpu().float()
    
    def record_after_step(self, model, optimizer):
        """Call after optimizer step."""
        if self.current_step % self.record_every_n == 0:
            # Only record if we have data from before_step
            if self.grad_before is not None and self.loss_value is not None:
                # Capture embeddings
                W = model.transformer.wte.weight.data.clone().cpu().float()

                # Capture Adam state
                # Adam stores state in optimizer.state[param]
                # We need to find which param corresponds to wte.weight
                param = model.transformer.wte.weight
                if param in optimizer.state:
                    state = optimizer.state[param]
                    # exp_avg = first moment (momentum)
                    # exp_avg_sq = second moment (variance)
                    mom = state['exp_avg'].clone().cpu().float()
                    var = state['exp_avg_sq'].clone().cpu().float()
                else:
                    # Optimizer hasn't initialized state yet (shouldn't happen after step 1)
                    mom = torch.zeros_like(W)
                    var = torch.zeros_like(W)

                # Store everything
                self.recorded_steps.append(self.current_step)
                self.embeddings.append(W)
                self.grads.append(self.grad_before)
                self.momentum.append(mom)
                self.variance.append(var)
                self.logits.append(self.logits_sample)
                self.losses.append(torch.tensor(self.loss_value))

                # Clear temp storage
                self.grad_before = None
                self.loss_value = None
                self.logits_sample = None

        self.current_step += 1
    
    def get_data(self):
        """Return recorded data as tensors."""
        return {
            'recorded_steps': torch.tensor(self.recorded_steps, dtype=torch.long),
            'embeddings': torch.stack(self.embeddings) if self.embeddings else torch.tensor([]),
            'grads': torch.stack(self.grads) if self.grads else torch.tensor([]),
            'momentum': torch.stack(self.momentum) if self.momentum else torch.tensor([]),
            'variance': torch.stack(self.variance) if self.variance else torch.tensor([]),
            'logits': torch.stack(self.logits) if self.logits else torch.tensor([]),
            'losses': torch.stack(self.losses) if self.losses else torch.tensor([]),
        }

print("✓ Recorder class defined")

✓ Recorder class defined


## Custom Trainer

In [8]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder
        self.last_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Override to capture logits."""
        outputs = model(**inputs)
        loss = outputs.loss
        
        # Store logits for recorder
        self.last_logits = outputs.logits
        
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Override to inject recording."""
        # Standard forward + backward
        loss = super().training_step(model, inputs, num_items_in_batch)
        
        # Record BEFORE optimizer step (gradients + loss + logits)
        self.recorder.record_before_step(model, loss, self.last_logits)
        
        return loss

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
        """Override to record AFTER optimizer step."""
        # Record AFTER optimizer updates parameters
        self.recorder.record_after_step(model, self.optimizer)
        
        # Call parent
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined


## Training Configuration

In [9]:
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

recorder = ComprehensiveRecorder(VOCAB_SIZE, HIDDEN_DIM, RECORD_EVERY_N_STEPS)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    adam_beta1=ADAM_BETA1,
    adam_beta2=ADAM_BETA2,
    adam_epsilon=ADAM_EPSILON,
    logging_steps=1000,
    save_steps=NUM_TRAIN_STEPS + 1,  # Don't save checkpoints
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
)

trainer = InstrumentedTrainer(
    recorder=recorder,
    model=model,
    args=training_args,
    train_dataset=dataset,
)

print("✓ Trainer ready")

✓ Trainer ready


## Record Initial State

In [10]:
print("Recording initial state (step 0)...")
recorder.record_initial_state(model, trainer.optimizer)
print(f"✓ Step 0 recorded")

Recording initial state (step 0)...
✓ Step 0 recorded


## Train

In [11]:
print(f"\n{'='*80}")
print(f"Starting comprehensive instrumented training")
print(f"  Steps: {NUM_TRAIN_STEPS} (plus initial step 0)")
print(f"  Recording: embeddings, gradients, Adam state, logits, loss")
print(f"  Expected file size: ~660 MB")
print(f"{'='*80}\n")

start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete ({elapsed/60:.1f} min)")
print(f"{'='*80}")


Starting comprehensive instrumented training
  Steps: 10000 (plus initial step 0)
  Recording: embeddings, gradients, Adam state, logits, loss
  Expected file size: ~660 MB



`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
1000,2.912
2000,2.5288
3000,2.3422
4000,2.2622
5000,2.2184
6000,2.1884
7000,2.1739
8000,2.1635
9000,2.1561
10000,2.1538



✓ Training complete (1.7 min)


## Save Recorded Data

In [12]:
recorded_data = recorder.get_data()

save_dict = {
    'recorded_steps': recorded_data['recorded_steps'],
    'dead_token_ids': torch.tensor(dead_token_ids, dtype=torch.long),
    'live_token_ids': torch.tensor(live_token_ids, dtype=torch.long),
    'embeddings': recorded_data['embeddings'],
    'grads': recorded_data['grads'],
    'momentum': recorded_data['momentum'],
    'variance': recorded_data['variance'],
    'logits': recorded_data['logits'],
    'losses': recorded_data['losses'],
    'init_sigma': torch.tensor(INIT_SIGMA, dtype=torch.float32),
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
    'adam_beta1': torch.tensor(ADAM_BETA1, dtype=torch.float32),
    'adam_beta2': torch.tensor(ADAM_BETA2, dtype=torch.float32),
    'init_method_code': torch.tensor(0, dtype=torch.int32),  # 0=f32→bf16, 1=pure_bf16
}

output_path = Path(OUTPUT_DIR) / "comprehensive_training_data.safetensors"
save_file(save_dict, output_path)

file_size_mb = output_path.stat().st_size / 1e6

print(f"✓ Saved to: {output_path}")
print(f"  File size: {file_size_mb:.1f} MB")
print(f"  Recorded steps: {len(recorded_data['recorded_steps'])}")
print(f"  Step range: {recorded_data['recorded_steps'][0]} to {recorded_data['recorded_steps'][-1]}")
print(f"\nShapes:")
print(f"  embeddings: {recorded_data['embeddings'].shape}")
print(f"  grads: {recorded_data['grads'].shape}")
print(f"  momentum: {recorded_data['momentum'].shape}")
print(f"  variance: {recorded_data['variance'].shape}")
print(f"  logits: {recorded_data['logits'].shape}")
print(f"  losses: {recorded_data['losses'].shape}")

✓ Saved to: ../data/comprehensive_run/comprehensive_training_data.safetensors
  File size: 1316.1 MB
  Recorded steps: 10001
  Step range: 0 to 10001

Shapes:
  embeddings: torch.Size([10001, 128, 64])
  grads: torch.Size([10001, 128, 64])
  momentum: torch.Size([10001, 128, 64])
  variance: torch.Size([10001, 128, 64])
  logits: torch.Size([10001, 128])
  losses: torch.Size([10001])


## Quick Verification

In [13]:
print(f"\n{'='*80}")
print(f"QUICK VERIFICATION")
print(f"{'='*80}\n")

embeddings = recorded_data['embeddings']
grads = recorded_data['grads']
logits_vec = recorded_data['logits']
losses = recorded_data['losses']

if len(embeddings) > 1:
    print(f"Embeddings evolution:")
    print(f"  Step 0 centroid norm: {embeddings[0].mean(dim=0).norm().item():.6f}")
    print(f"  Step {NUM_TRAIN_STEPS} centroid norm: {embeddings[-1].mean(dim=0).norm().item():.6f}")
    
    print(f"\nGradient magnitudes:")
    dead_grads_step1 = grads[1, dead_token_ids, :]
    print(f"  Step 1 dead token grads (mean): {torch.norm(dead_grads_step1, p=2, dim=1).mean().item():.6e}")
    
    print(f"\nLogits at step 1 (should be ~equal for all tokens):")
    logits_step1 = logits_vec[1]
    print(f"  Min: {logits_step1.min().item():.4f}")
    print(f"  Max: {logits_step1.max().item():.4f}")
    print(f"  Range: {(logits_step1.max() - logits_step1.min()).item():.4f}")
    print(f"  Std: {logits_step1.std().item():.4f}")
    
    print(f"\nLoss evolution:")
    print(f"  Step 1: {losses[1].item():.4f}")
    print(f"  Step {NUM_TRAIN_STEPS}: {losses[-1].item():.4f}")

print(f"\n{'='*80}")


QUICK VERIFICATION

Embeddings evolution:
  Step 0 centroid norm: 1.000153
  Step 10000 centroid norm: 0.808508

Gradient magnitudes:
  Step 1 dead token grads (mean): 6.171330e-02

Logits at step 1 (should be ~equal for all tokens):
  Min: 7.6875
  Max: 7.6875
  Range: 0.0000
  Std: 0.0000

Loss evolution:
  Step 1: 4.8520
  Step 10000: 2.1623



## Summary

This notebook records the complete training history:

- **Embeddings**: Actual token positions (no reconstruction needed)
- **Gradients**: Instantaneous forces from loss function
- **Adam momentum**: Inertial velocity built up from past gradients
- **Adam variance**: Adaptive per-parameter learning rate scaling
- **Logits**: Model predictions at one position per step
- **Loss**: Training objective value

This lets us analyze:
1. Token trajectories and velocities
2. Force decomposition (gradient vs momentum)
3. Prediction evolution (logits over time)
4. Whether early logits are uniform (Jeffery's hypothesis)
5. How optimizer momentum affects dead token motion