# 14.1a: Lil Gatsby with Instrumentation

**Do dead tokens receive gradients from the loss function?**

## The Question

We observed in 13.4b that dead tokens MOVE during training. Our hypothesis:
- Early in training, all tokens are stacked together
- Model predicts dead tokens "by accident" (uniform softmax probabilities)
- Cross-entropy loss penalizes these predictions → gradients push dead tokens down
- Even tiny gradients (~0.008 per position) accumulate across thousands of positions

## The Experiment

Train the same Gatsby model as 13.4a, but this time record:
1. **Initial state** (step 0, before any training)
2. **Raw gradients** after backward pass (before optimizer)
3. **Parameter updates** (delta from optimizer step)
4. Separate analysis for dead vs live tokens

If dead tokens have **non-zero gradients** → hypothesis confirmed!  
If gradients are **zero** → movement is purely from weight decay.

## Parameters

In [13]:
# Model architecture (tiny for speed)
VOCAB_SIZE = 128
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training
BATCH_SIZE = 32
GRADIENT_ACCUMULATION = 1
NUM_TRAIN_STEPS = 10000
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01

# Initialization
INIT_SIGMA = 1e-5  # float32 → bfloat16

# Data
CORPUS_PATH = "../data/training_corpus.txt"
OUTPUT_DIR = "../data/instrumented_run"

# Instrumentation
RECORD_EVERY_N_STEPS = 1  # Record every step

RANDOM_SEED = 42

## Imports

In [14]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, TrainerCallback
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file
import time

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")
print("✓ Imports complete")

Using device: mps
✓ Imports complete


## Load Gatsby Corpus

In [15]:
print(f"Loading corpus: {CORPUS_PATH}")

with open(CORPUS_PATH, 'r', encoding='ascii') as f:
    corpus_text = f.read()

corpus_bytes = [b for b in corpus_text.encode('ascii') if b < VOCAB_SIZE]

unique_bytes = set(corpus_bytes)
dead_token_ids = sorted(set(range(VOCAB_SIZE)) - unique_bytes)
live_token_ids = sorted(unique_bytes)

print(f"  Total bytes: {len(corpus_bytes):,}")
print(f"  Live tokens: {len(live_token_ids)} / {VOCAB_SIZE}")
print(f"  Dead tokens: {len(dead_token_ids)} / {VOCAB_SIZE}")

print(f"\nTracking ALL {VOCAB_SIZE} tokens")

# Pre-load to device
corpus_tensor = torch.tensor(corpus_bytes, dtype=torch.long, device=device)
print(f"\n✓ Corpus on device")

Loading corpus: ../data/training_corpus.txt
  Total bytes: 265,905
  Live tokens: 77 / 128
  Dead tokens: 51 / 128

Tracking ALL 128 tokens

✓ Corpus on device


## Dataset

In [16]:
class ByteDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
    def __getitem__(self, idx):
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }

dataset = ByteDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"✓ Dataset: {len(dataset):,} examples")

✓ Dataset: 265,777 examples


## Model

In [17]:
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,
)

model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model created: {total_params:,} parameters")

✓ Model created: 116,480 parameters


## Float32 Initialization

In [18]:
print(f"\nFloat32 initialization (σ = {INIT_SIGMA:.2e})\n")

with torch.no_grad():
    # Generate random unit vector in float32
    random_vector = torch.randn(HIDDEN_DIM, dtype=torch.float32, device=device)
    random_vector = random_vector / random_vector.norm()
    
    # Add Gaussian noise in float32
    noise = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32, device=device) * INIT_SIGMA
    init_f32 = random_vector + noise
    
    # Convert to bfloat16 for training
    init_bf16 = init_f32.to(torch.bfloat16)
    
    # Assign to model
    model.transformer.wte.weight[:] = init_bf16
    
    print(f"Initialized: {VOCAB_SIZE} tokens at norm ~{init_f32.norm(dim=1).mean().item():.6f}")

print(f"\n✓ Embeddings initialized")


Float32 initialization (σ = 1.00e-05)

Initialized: 128 tokens at norm ~0.999999

✓ Embeddings initialized


## Instrumentation Callback

The challenge: we need to capture:
1. **Step 0**: Initial embeddings before any training
2. **Steps 1-N**: Gradients BEFORE optimizer step, and deltas AFTER

Strategy: Use a custom Trainer that exposes the right hooks.

In [19]:
class GradientDeltaRecorder:
    """Records initial state, gradients, and parameter updates for ALL tokens."""
    
    def __init__(self, vocab_size, record_every_n):
        self.vocab_size = vocab_size
        self.record_every_n = record_every_n
        
        # Storage
        self.recorded_steps = []
        self.grads = []      # [n_recorded, vocab_size, hidden_dim]
        self.deltas = []     # [n_recorded, vocab_size, hidden_dim]
        
        # Temporary storage for current step
        self.W_before = None
        self.grad_before = None
        self.current_step = 0
        self.recorded_initial = False
    
    def record_initial_state(self, model):
        """Record step 0: initial embeddings before any training."""
        if not self.recorded_initial:
            W_init = model.transformer.wte.weight.data.clone().cpu().float()
            
            # Step 0: no gradients yet, no deltas yet (use zeros)
            self.recorded_steps.append(0)
            self.grads.append(torch.zeros_like(W_init))
            self.deltas.append(torch.zeros_like(W_init))
            
            self.recorded_initial = True
            self.current_step = 1  # Next step will be step 1
    
    def record_before_step(self, model):
        """Call this after backward, before optimizer step."""
        if self.current_step % self.record_every_n == 0:
            self.W_before = model.transformer.wte.weight.data.clone().cpu().float()
            
            if model.transformer.wte.weight.grad is not None:
                self.grad_before = model.transformer.wte.weight.grad.clone().cpu().float()
            else:
                self.grad_before = torch.zeros_like(self.W_before)
    
    def record_after_step(self, model):
        """Call this after optimizer step."""
        if self.current_step % self.record_every_n == 0:
            # Guard: only record if we actually have W_before from training_step
            if self.W_before is not None:
                W_after = model.transformer.wte.weight.data.clone().cpu().float()
                delta = W_after - self.W_before
                
                # Record ALL tokens
                self.recorded_steps.append(self.current_step)
                self.grads.append(self.grad_before)
                self.deltas.append(delta)
                
                # Clear temp storage
                self.W_before = None
                self.grad_before = None
        
        self.current_step += 1
    
    def get_data(self):
        """Return recorded data as tensors."""
        return {
            'recorded_steps': torch.tensor(self.recorded_steps, dtype=torch.long),
            'grads': torch.stack(self.grads) if self.grads else torch.tensor([]),
            'deltas': torch.stack(self.deltas) if self.deltas else torch.tensor([]),
        }

print("✓ Recorder class defined")

✓ Recorder class defined


## Custom Trainer with Gradient Recording

In [20]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Override training_step to inject gradient recording."""
        # Standard forward + backward
        loss = super().training_step(model, inputs, num_items_in_batch)

        # Record BEFORE optimizer step (gradients are fresh)
        self.recorder.record_before_step(model)

        return loss

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
        """Override to record AFTER optimizer step."""
        # Record AFTER optimizer has updated parameters
        self.recorder.record_after_step(model)

        # Call parent with all arguments
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined


## Training Configuration

In [21]:
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

recorder = GradientDeltaRecorder(VOCAB_SIZE, RECORD_EVERY_N_STEPS)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    logging_steps=1000,
    save_steps=NUM_TRAIN_STEPS + 1,  # Don't save checkpoints
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
)

trainer = InstrumentedTrainer(
    recorder=recorder,
    model=model,
    args=training_args,
    train_dataset=dataset,
)

print("✓ Trainer ready")

✓ Trainer ready


## Record Initial State

In [22]:
print("Recording initial state (step 0)...")
recorder.record_initial_state(model)
print(f"✓ Step 0 recorded")

Recording initial state (step 0)...
✓ Step 0 recorded


## Train

In [23]:
print(f"\n{'='*80}")
print(f"Starting instrumented training")
print(f"  Steps: {NUM_TRAIN_STEPS} (plus initial step 0)")
print(f"  Recording every {RECORD_EVERY_N_STEPS} step(s)")
print(f"  Expected recordings: {NUM_TRAIN_STEPS // RECORD_EVERY_N_STEPS + 1}")
print(f"  Expected file size: ~{(NUM_TRAIN_STEPS // RECORD_EVERY_N_STEPS + 1) * VOCAB_SIZE * HIDDEN_DIM * 2 * 2 / 1e9:.2f} GB")
print(f"{'='*80}\n")

start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete ({elapsed/60:.1f} min)")
print(f"{'='*80}")


Starting instrumented training
  Steps: 10000 (plus initial step 0)
  Recording every 1 step(s)
  Expected recordings: 10001
  Expected file size: ~0.33 GB



Step,Training Loss
1000,2.9279
2000,2.5873
3000,2.4081
4000,2.3234
5000,2.2755
6000,2.2436
7000,2.2277
8000,2.2164
9000,2.2092
10000,2.2068



✓ Training complete (1.7 min)


## Save Recorded Data

In [24]:
recorded_data = recorder.get_data()

save_dict = {
    'recorded_steps': recorded_data['recorded_steps'],
    'dead_token_ids': torch.tensor(dead_token_ids, dtype=torch.long),
    'live_token_ids': torch.tensor(live_token_ids, dtype=torch.long),
    'grads': recorded_data['grads'],
    'deltas': recorded_data['deltas'],
    'init_sigma': torch.tensor(INIT_SIGMA, dtype=torch.float32),
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
}

output_path = Path(OUTPUT_DIR) / "gradient_delta_history.safetensors"
save_file(save_dict, output_path)

file_size_mb = output_path.stat().st_size / 1e6

print(f"✓ Saved to: {output_path}")
print(f"  File size: {file_size_mb:.1f} MB")
print(f"  Recorded steps: {len(recorded_data['recorded_steps'])}")
print(f"  Step range: {recorded_data['recorded_steps'][0]} to {recorded_data['recorded_steps'][-1]}")
print(f"  Shape - grads: {recorded_data['grads'].shape}")
print(f"  Shape - deltas: {recorded_data['deltas'].shape}")

✓ Saved to: ../data/instrumented_run/gradient_delta_history.safetensors
  File size: 655.5 MB
  Recorded steps: 10001
  Step range: 0 to 10001
  Shape - grads: torch.Size([10001, 128, 64])
  Shape - deltas: torch.Size([10001, 128, 64])


## Quick Analysis: Do Dead Tokens Have Gradients?

In [25]:
print(f"\n{'='*80}")
print(f"GRADIENT ANALYSIS")
print(f"{'='*80}\n")

grads = recorded_data['grads']  # [n_recorded, vocab_size, hidden_dim]

if len(grads) > 1:  # Need at least step 0 and step 1
    # Split into dead and live
    dead_grads = grads[:, dead_token_ids, :]
    live_grads = grads[:, live_token_ids, :]
    
    # Compute L2 norm of gradients for each token at each step
    dead_grad_norms = torch.norm(dead_grads, p=2, dim=2)  # [n_recorded, n_dead]
    live_grad_norms = torch.norm(live_grads, p=2, dim=2)  # [n_recorded, n_live]
    
    # Statistics (skip step 0 which has zero grads by construction)
    print(f"Dead token gradients (L2 norm, n={len(dead_token_ids)}):")
    print(f"  First training step (step 1):")
    print(f"    Mean: {dead_grad_norms[1].mean().item():.6e}")
    print(f"    Max: {dead_grad_norms[1].max().item():.6e}")
    print(f"    Min: {dead_grad_norms[1].min().item():.6e}")
    
    print(f"\n  Last recorded step (step {recorded_data['recorded_steps'][-1].item()}):")
    print(f"    Mean: {dead_grad_norms[-1].mean().item():.6e}")
    print(f"    Max: {dead_grad_norms[-1].max().item():.6e}")
    print(f"    Min: {dead_grad_norms[-1].min().item():.6e}")
    
    print(f"\nLive token gradients (L2 norm, n={len(live_token_ids)}):")
    print(f"  First training step (step 1):")
    print(f"    Mean: {live_grad_norms[1].mean().item():.6e}")
    print(f"    Max: {live_grad_norms[1].max().item():.6e}")
    
    print(f"\n  Last recorded step (step {recorded_data['recorded_steps'][-1].item()}):")
    print(f"    Mean: {live_grad_norms[-1].mean().item():.6e}")
    print(f"    Max: {live_grad_norms[-1].max().item():.6e}")
    
    # Key question: Are dead grads exactly zero? (exclude step 0)
    dead_max_grad = dead_grad_norms[1:].max().item()
    
    print(f"\n{'='*80}")
    if dead_max_grad == 0.0:
        print(f"RESULT: Dead token gradients are EXACTLY ZERO")
        print(f"  → Movement is purely from weight decay")
    else:
        print(f"RESULT: Dead token gradients are NON-ZERO!")
        print(f"  → Maximum gradient norm: {dead_max_grad:.6e}")
        print(f"  → Hypothesis CONFIRMED: Dead tokens receive loss gradients")
    print(f"{'='*80}")
else:
    print("Not enough data recorded (need at least 2 steps)")


GRADIENT ANALYSIS

Dead token gradients (L2 norm, n=51):
  First training step (step 1):
    Mean: 6.121443e-02
    Max: 6.125058e-02
    Min: 6.121239e-02

  Last recorded step (step 10001):
    Mean: 9.438379e-04
    Max: 9.453714e-04
    Min: 9.427978e-04

Live token gradients (L2 norm, n=77):
  First training step (step 1):
    Mean: 1.160393e-01
    Max: 1.241577e+00

  Last recorded step (step 10001):
    Mean: 6.798930e-02
    Max: 8.619771e-01

RESULT: Dead token gradients are NON-ZERO!
  → Maximum gradient norm: 6.125058e-02
  → Hypothesis CONFIRMED: Dead tokens receive loss gradients


## Summary

This notebook trains the Gatsby model and records:
- **Step 0**: Initial embeddings (before training)
- **Steps 1-N**: Raw gradients (∂loss/∂W) after backward pass
- **Steps 1-N**: Parameter updates (ΔW) after optimizer step

The critical test: **Do dead tokens have non-zero gradients?**

If YES → They receive updates from cross-entropy loss (model predicting them by accident)  
If NO → They only move due to weight decay (regularization)

**Next:** Analyze the recorded data in detail to understand gradient magnitudes and dynamics over time.