# 1.12f: Wordybird 1b - Clean 200-Step Run (No Checkpoint Artifact)

**Experiment:** Exact copy of Wordybird 1, but run for 200 steps straight through without checkpointing.

## The Checkpoint Artifact Problem

**Wordybird 1 (0-100):** Sharp freeze around step 42, nearly complete by step 70.

**Wordybird 3 (100-200, resumed from checkpoint):** Massive spike at step 100 when training resumes, then rapid decay and freeze by ~step 125.

**The Artifact:** Stitching WB1 and WB3 together shows an unnatural "kick" at the seam (step 100) that does not look like continuous training dynamics.

## Test

Run Wordybird 1b: **same seed, same hyperparameters, same everything** as WB1, but:
- Train for 200 steps continuously (no checkpointing)
- Record every step throughout

This will show us the TRUE dynamics of steps 0-200 without checkpoint artifacts.
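
One way to put a number on the "kick" is to compare the mean per-step displacement at the seam with the motion just before it. The sketch below is illustrative only: `per_step_displacement` and `seam_kick_ratio` are hypothetical helpers that do not appear in this notebook, and the toy trajectory stands in for the recorded embeddings.

```python
import math

def per_step_displacement(trajectory):
    """Mean Euclidean distance moved per token between consecutive snapshots.

    trajectory: list of snapshots, each a list of token vectors (lists of floats).
    Entry i of the result is the mean motion from snapshot i to snapshot i + 1.
    """
    out = []
    for prev, curr in zip(trajectory, trajectory[1:]):
        dists = [math.dist(p, c) for p, c in zip(prev, curr)]
        out.append(sum(dists) / len(dists))
    return out

def seam_kick_ratio(displacements, seam_index, window=5):
    """Motion at the seam divided by the mean motion in the preceding window."""
    before = displacements[max(0, seam_index - window):seam_index]
    if not before:
        raise ValueError("seam_index must leave at least one earlier step")
    baseline = sum(before) / len(before)
    return displacements[seam_index] / baseline if baseline > 0 else float("inf")

# Toy 1-D trajectory for a single token: motion decays, then jumps at the seam.
positions = [0.0, 1.0, 1.5, 1.75, 1.875, 1.9375, 3.9375, 4.0]
trajectory = [[[x]] for x in positions]

disp = per_step_displacement(trajectory)
ratio = seam_kick_ratio(disp, seam_index=5, window=2)
print(disp)
print(f"kick ratio at seam: {ratio:.1f}")
```

A continuous run with no artifact would be expected to give a ratio near 1 at step 100, while a stitched run with a kick would give a much larger value.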

## Wordybird 1b Parameters

**Identical to Wordybird 1 except:**
- **Steps: 200** (was 100)
- **Output: `1.12f_wordybird_1b.safetensors`**

Everything else unchanged:
- Same random seed (42)
- Same model architecture (2 layers, 2 heads, 64D)
- Same training config (batch=32, lr=0.001, Adam)
- Same corpus (FineWeb 2MB)
- Same initialization (N(0, 0.02) bfloat16)

## Parameters

In [1]:
# Model architecture
VOCAB_SIZE = 50257  # GPT-2
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training
BATCH_SIZE = 32
NUM_TRAIN_STEPS = 200  # ← CHANGED from 100
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0

# Optimizer: Adam
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8

# Initialization (bfloat16 native)
INIT_SCALE = 0.02  # N(0, 0.02)

# Data
CORPUS_PATH = "../data/fineweb_2mb_unicode.txt"
TOKEN_MASK_PATH = "../tensors/Wordybird/fineweb_token_masks.safetensors"
OUTPUT_DIR = "../tensors/Wordybird"
OUTPUT_FILE = "1.12f_wordybird_1b.safetensors"  # ← CHANGED

# Instrumentation
RECORD_EVERY_N_STEPS = 1  # Record every step

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file, load_file
import time

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Device Detection

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Tokenizer

In [4]:
print("Loading GPT-2 tokenizer...\n")

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

print(f"✓ Loaded GPT-2 tokenizer")
print(f"  Vocabulary size: {len(tokenizer):,} tokens")

Loading GPT-2 tokenizer...

✓ Loaded GPT-2 tokenizer
  Vocabulary size: 50,257 tokens


## Load Corpus and Tokenize

In [5]:
print(f"\nLoading corpus: {CORPUS_PATH}\n")

with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    corpus_text = f.read()

corpus_bytes = len(corpus_text.encode('utf-8'))
corpus_mb = corpus_bytes / (1024 * 1024)

print(f"✓ Loaded corpus")
print(f"  Size: {corpus_mb:.2f} MB")
print(f"  Characters: {len(corpus_text):,}")
print()

# Tokenize
print("Tokenizing corpus...\n")
tokens = tokenizer.encode(corpus_text)

print(f"✓ Tokenized")
print(f"  Tokens: {len(tokens):,}")
print()

# Pre-load to device
corpus_tensor = torch.tensor(tokens, dtype=torch.long, device=device)
print(f"✓ Corpus on device: {device}")


Loading corpus: ../data/fineweb_2mb_unicode.txt

✓ Loaded corpus
  Size: 2.00 MB
  Characters: 2,089,201

Tokenizing corpus...



Token indices sequence length is longer than the specified maximum sequence length for this model (475160 > 1024). Running this sequence through the model will result in indexing errors


✓ Tokenized
  Tokens: 475,160

✓ Corpus on device: mps


## Load Token Masks

In [6]:
print(f"\nLoading token masks: {TOKEN_MASK_PATH}\n")

mask_data = load_file(TOKEN_MASK_PATH)
trained_mask = mask_data['trained_mask']
untrained_mask = mask_data['untrained_mask']
trained_indices = mask_data['trained_indices']
untrained_indices = mask_data['untrained_indices']

n_trained = trained_mask.sum().item()
n_untrained = untrained_mask.sum().item()

print(f"✓ Loaded token masks")
print(f"  Trained tokens: {n_trained:,} ({100*n_trained/VOCAB_SIZE:.1f}%)")
print(f"  Untrained tokens: {n_untrained:,} ({100*n_untrained/VOCAB_SIZE:.1f}%)")


Loading token masks: ../tensors/Wordybird/fineweb_token_masks.safetensors

✓ Loaded token masks
  Trained tokens: 30,590 (60.9%)
  Untrained tokens: 19,667 (39.1%)
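
The notebook loads precomputed masks, and how they were built is not shown here. One plausible construction, assumed purely for illustration, marks a token as "trained" if it occurs at least once in the tokenized corpus. `build_token_masks` is a hypothetical helper, written with plain lists rather than tensors.

```python
def build_token_masks(token_ids, vocab_size):
    """Split the vocabulary into tokens that occur in the corpus and tokens that do not.

    ASSUMPTION: "trained" means "appears at least once in token_ids".
    """
    seen = set(token_ids)
    trained_mask = [tid in seen for tid in range(vocab_size)]
    untrained_mask = [not flag for flag in trained_mask]
    trained_indices = [i for i, flag in enumerate(trained_mask) if flag]
    untrained_indices = [i for i, flag in enumerate(trained_mask) if not flag]
    return trained_mask, untrained_mask, trained_indices, untrained_indices

# Toy corpus over a 6-token vocabulary: tokens 0, 1 and 3 occur; 2, 4 and 5 never do.
toy_tokens = [3, 1, 3, 0, 1]
masks = build_token_masks(toy_tokens, vocab_size=6)
print("trained:", masks[2])
print("untrained:", masks[3])
```

Under this definition the two masks always partition the vocabulary, which matches the counts above: 30,590 + 19,667 = 50,257.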


## Dataset

In [7]:
class TokenDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
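    def __getitem__(self, idx):
        # Window of max_seq_len + 1 tokens, split into inputs and next-token labels.
        # NOTE: GPT2LMHeadModel also shifts labels by one position inside its loss,
        # so these pre-shifted labels make the effective target two tokens ahead.
        # Left unchanged here so the run stays an exact copy of Wordybird 1.
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }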

dataset = TokenDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"\n✓ Dataset: {len(dataset):,} examples")


✓ Dataset: 475,032 examples
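
The example count follows directly from the sliding window: every start index yields one window of `max_seq_len + 1` tokens. The sketch below reproduces the windowing with plain lists (`window` and `num_windows` are illustrative names) and also shows which token each position ends up predicting once a causal LM head applies its own one-position shift, as Hugging Face's `GPT2LMHeadModel` does.

```python
def window(corpus, idx, max_seq_len):
    """Return (input_ids, labels) exactly as TokenDataset.__getitem__ slices them."""
    chunk = corpus[idx: idx + max_seq_len + 1]
    return chunk[:-1], chunk[1:]

def num_windows(corpus_len, max_seq_len):
    """Number of start indices that leave room for a full window."""
    return max(0, corpus_len - max_seq_len)

corpus = list(range(10))  # toy corpus: token id == position
inputs, labels = window(corpus, idx=2, max_seq_len=4)
print("inputs:", inputs)
print("labels:", labels)

# Inside the loss, logits at position t are compared with labels[t + 1].
effective_pairs = list(zip(inputs[:-1], labels[1:]))
print("(input token, effective target):", effective_pairs)

print("windows for this notebook:", num_windows(475_160, 128))
```

In the toy output each effective target is two positions ahead of its input token, which is worth keeping in mind when interpreting the absolute loss values.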


## Model

In [8]:
print(f"\nCreating model...\n")

config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,
)

model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16).to(device)

total_params = sum(p.numel() for p in model.parameters())
embedding_params = model.transformer.wte.weight.numel()

print(f"✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Embedding parameters (E+W): {embedding_params:,}")
print(f"  Other parameters: {total_params - embedding_params:,}")


Creating model...

✓ Model created
  Total parameters: 3,324,736
  Embedding parameters (E+W): 3,216,448
  Other parameters: 108,288
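
The parameter counts can be checked by hand. The sketch below assumes the standard GPT-2 block layout (fused QKV projection, MLP inner width of 4 × hidden size, biases everywhere, LayerNorm with weight and bias) and tied input/output embeddings, so the token embedding matrix is counted once. `gpt2_param_count` is an illustrative helper, not a library function.

```python
def gpt2_param_count(vocab_size, n_positions, d, n_layer):
    """Return (embedding_params, other_params) for a GPT-2 style model with tied embeddings."""
    wte = vocab_size * d              # token embeddings (shared with the LM head)
    wpe = n_positions * d             # learned position embeddings
    layer_norm = 2 * d                # weight + bias
    attn_qkv = d * 3 * d + 3 * d      # fused query/key/value projection
    attn_out = d * d + d              # attention output projection
    mlp_in = d * 4 * d + 4 * d        # d -> 4d
    mlp_out = 4 * d * d + d           # 4d -> d
    per_layer = 2 * layer_norm + attn_qkv + attn_out + mlp_in + mlp_out
    other = wpe + n_layer * per_layer + layer_norm  # trailing term is the final LayerNorm
    return wte, other

embedding, other = gpt2_param_count(vocab_size=50257, n_positions=128, d=64, n_layer=2)
print(f"embedding: {embedding:,}  other: {other:,}  total: {embedding + other:,}")
```

About 97% of the parameters sit in the embedding matrix, which is why the recorder focuses on `wte`.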


## Initialization (bfloat16 Native)

In [9]:
print(f"\n{'='*80}")
print(f"INITIALIZING: N(0, {INIT_SCALE}) bfloat16-native")
print(f"{'='*80}\n")

torch.manual_seed(RANDOM_SEED)

# Standard GPT-2 init: each embedding drawn independently from N(0, 0.02)
# Generate in float32, immediately convert to bfloat16
init_f32 = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32, device=device) * INIT_SCALE
init_bf16 = init_f32.to(torch.bfloat16)

print(f"Initialization: N(0, {INIT_SCALE})")
print(f"  Shape: {init_bf16.shape}")
print(f"  Each token initialized independently")
print()

# Assign to model
with torch.no_grad():
    model.transformer.wte.weight[:] = init_bf16

print(f"✓ Initialized embeddings (pure bfloat16)")
print(f"  Shape: {model.transformer.wte.weight.shape}")
print(f"  Dtype: {model.transformer.wte.weight.dtype}")
print()

# Verify initialization stats
W_check = model.transformer.wte.weight.cpu().float()
W_untrained = W_check[untrained_indices]

centroid = W_untrained.mean(dim=0)
centroid_norm = torch.norm(centroid).item()
radii = torch.norm(W_untrained - centroid, dim=1)
mean_radius = radii.mean().item()
max_radius = radii.max().item()

print(f"Initial untrained token statistics ({n_untrained:,} tokens):")
print(f"  Centroid norm: {centroid_norm:.6f}")
print(f"  Mean radius from centroid: {mean_radius:.6f}")
print(f"  Max radius from centroid: {max_radius:.6f}")
print(f"  Bounding hypersphere volume ∝ R^{HIDDEN_DIM} = {max_radius**HIDDEN_DIM:.2e}")
print()
print(f"  ✓ Tokens distributed in hypersphere (standard init)")
print(f"\n{'='*80}\n")


INITIALIZING: N(0, 0.02) bfloat16-native

Initialization: N(0, 0.02)
  Shape: torch.Size([50257, 64])
  Each token initialized independently

✓ Initialized embeddings (pure bfloat16)
  Shape: torch.Size([50257, 64])
  Dtype: torch.bfloat16

Initial untrained token statistics (19,667 tokens):
  Centroid norm: 0.001090
  Mean radius from centroid: 0.159273
  Max radius from centroid: 0.219594
  Bounding hypersphere volume ∝ R^64 = 7.31e-43

  ✓ Tokens distributed in hypersphere (standard init)
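
The mean radius of about 0.159 is what an isotropic Gaussian predicts. For x ~ N(0, σ²I) in d dimensions, the expected norm is σ·√2·Γ((d+1)/2)/Γ(d/2), which is close to σ·√d for large d. The sketch below evaluates that formula; it ignores bfloat16 rounding and the small centroid offset, so only approximate agreement with the measured value should be expected.

```python
import math

def expected_gaussian_norm(sigma, dim):
    """E[||x||] for x ~ N(0, sigma^2 * I_dim), via the chi distribution mean."""
    log_ratio = math.lgamma((dim + 1) / 2) - math.lgamma(dim / 2)
    return sigma * math.sqrt(2.0) * math.exp(log_ratio)

exact = expected_gaussian_norm(sigma=0.02, dim=64)
rough = 0.02 * math.sqrt(64)
print(f"chi mean: {exact:.6f}   sigma*sqrt(d): {rough:.6f}   measured: 0.159273")
```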




## Comprehensive Recorder

In [10]:
class ComprehensiveRecorder:
    """Records embeddings, gradients, optimizer state, logits, loss at every step in bfloat16."""
    
    def __init__(self, vocab_size, hidden_dim, record_every_n):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.record_every_n = record_every_n
        
        # Storage (lists of tensors, keep in RAM)
        self.recorded_steps = []
        self.embeddings = []      # [n_recorded, vocab_size, hidden_dim]
        self.grads = []           # [n_recorded, vocab_size, hidden_dim]
        self.momentum = []        # [n_recorded, vocab_size, hidden_dim]
        self.variance = []        # [n_recorded, vocab_size, hidden_dim]
        self.logits = []          # [n_recorded, vocab_size]
        self.losses = []          # [n_recorded]
        
        # Temporary storage
        self.current_step = 0
        self.recorded_initial = False
        self.grad_before = None
        self.loss_value = None
        self.logits_sample = None
    
    def record_initial_state(self, model, optimizer):
        """Record step 0: initial state before training."""
        if not self.recorded_initial:
            W = model.transformer.wte.weight.data.clone().cpu().bfloat16()
            
            # Step 0: no gradients, no optimizer state yet (zeros)
            self.recorded_steps.append(0)
            self.embeddings.append(W)
            self.grads.append(torch.zeros_like(W))
            self.momentum.append(torch.zeros_like(W))
            self.variance.append(torch.zeros_like(W))
            self.logits.append(torch.zeros(self.vocab_size, dtype=torch.bfloat16))
            self.losses.append(torch.tensor(float('nan'), dtype=torch.bfloat16))  # No loss yet
            
            self.recorded_initial = True
            self.current_step = 1
            
            print(f"✓ Recorded initial state (step 0)")
    
    def record_before_step(self, model, loss, logits):
        """Call after forward/backward, before optimizer step."""
        if self.current_step % self.record_every_n == 0:
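            # Capture gradients in bfloat16. This runs right after backward, so the
            # values are taken before any gradient clipping the Trainer applies
            # (max_grad_norm is left at its default in this notebook).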
            if model.transformer.wte.weight.grad is not None:
                self.grad_before = model.transformer.wte.weight.grad.clone().cpu().bfloat16()
            else:
                self.grad_before = torch.zeros(self.vocab_size, self.hidden_dim, dtype=torch.bfloat16)
            
            # Capture loss
            self.loss_value = loss.item()
            
            # Capture logits from first sequence, last position in bfloat16
            self.logits_sample = logits[0, -1, :].detach().cpu().bfloat16()
    
    def record_after_step(self, model, optimizer):
        """Call after optimizer step."""
        if self.current_step % self.record_every_n == 0:
            if self.grad_before is not None and self.loss_value is not None:
                # Capture embeddings in bfloat16
                W = model.transformer.wte.weight.data.clone().cpu().bfloat16()

                # Capture optimizer state (Adam momentum and variance)
                param = model.transformer.wte.weight
                if param in optimizer.state:
                    state = optimizer.state[param]
                    # Get state tensors if they exist, convert to bfloat16
                    mom_src = state.get('exp_avg', None)
                    var_src = state.get('exp_avg_sq', None)
                    mom = mom_src.clone().cpu().bfloat16() if mom_src is not None else torch.zeros_like(W)
                    var = var_src.clone().cpu().bfloat16() if var_src is not None else torch.zeros_like(W)
                else:
                    mom = torch.zeros_like(W)
                    var = torch.zeros_like(W)

                # Store everything
                self.recorded_steps.append(self.current_step)
                self.embeddings.append(W)
                self.grads.append(self.grad_before)
                self.momentum.append(mom)
                self.variance.append(var)
                self.logits.append(self.logits_sample)
                self.losses.append(torch.tensor(self.loss_value, dtype=torch.bfloat16))

                # Clear temp storage
                self.grad_before = None
                self.loss_value = None
                self.logits_sample = None
                
                # Progress indicator every 20 steps
                if self.current_step % 20 == 0:
                    print(f"  Recorded step {self.current_step}")

        self.current_step += 1
    
    def get_data(self):
        """Return recorded data as stacked tensors."""
        print(f"\nStacking {len(self.embeddings)} recorded states...")
        
        return {
            'recorded_steps': torch.tensor(self.recorded_steps, dtype=torch.long),
            'embeddings': torch.stack(self.embeddings) if self.embeddings else torch.tensor([]),
            'grads': torch.stack(self.grads) if self.grads else torch.tensor([]),
            'momentum': torch.stack(self.momentum) if self.momentum else torch.tensor([]),
            'variance': torch.stack(self.variance) if self.variance else torch.tensor([]),
            'logits': torch.stack(self.logits) if self.logits else torch.tensor([]),
            'losses': torch.stack(self.losses) if self.losses else torch.tensor([]),
        }

print("✓ Recorder class defined")

✓ Recorder class defined
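
The `momentum` and `variance` arrays hold Adam's raw running averages (`exp_avg` and `exp_avg_sq`), before bias correction. As a reminder of how they enter the update, here is a minimal scalar sketch of one Adam step using the hyperparameters from this notebook. `adam_step` is an illustrative function, not the PyTorch implementation, and it works in float64 rather than bfloat16.

```python
import math

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar weight; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * g        # first moment (stored as exp_avg)
    v = beta2 * v + (1 - beta2) * g * g    # second moment (stored as exp_avg_sq)
    m_hat = m / (1 - beta1 ** t)           # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)           # bias-corrected second moment
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

w, m, v = adam_step(w=0.02, g=0.5, m=0.0, v=0.0, t=1)
print(f"w={w:.6f}  exp_avg={m:.4f}  exp_avg_sq={v:.6f}")
```

On the first step the bias-corrected ratio is close to ±1, so the weight moves by roughly the learning rate regardless of the gradient's magnitude.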


## Custom Trainer with Instrumentation

In [11]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder
        self.last_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Override to capture logits."""
        outputs = model(**inputs)
        loss = outputs.loss
        
        # Store logits for recorder
        self.last_logits = outputs.logits
        
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Override to inject recording."""
        # Standard forward + backward
        loss = super().training_step(model, inputs, num_items_in_batch)
        
        # Record BEFORE optimizer step
        self.recorder.record_before_step(model, loss, self.last_logits)
        
        return loss

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
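        """Override to record AFTER optimizer step.

        This hooks a private Trainer method whose signature differs between
        transformers versions; extra arguments are forwarded through **kwargs.
        """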
        # Record AFTER optimizer updates parameters
        self.recorder.record_after_step(model, self.optimizer)
        
        # Call parent
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined
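
The two hooks fix how the recorded arrays line up: `grads[k]` is the gradient computed at the weights from step `k - 1`, and `embeddings[k]` is the result of applying that gradient. The toy loop below mirrors that ordering with a scalar weight and plain gradient descent (an assumption made for brevity; the notebook uses Adam). `run_toy_training` is a hypothetical helper.

```python
def run_toy_training(w0, lr, n_steps):
    """Record (weights, grads) in the same order the instrumented trainer does."""
    w = w0
    weights = [w0]   # step 0: initial state
    grads = [0.0]    # step 0: placeholder, no gradient yet
    for _ in range(n_steps):
        g = 2.0 * (w - 3.0)      # backward pass for loss = (w - 3)^2
        grad_before = g          # record_before_step: gradient at the old weights
        w = w - lr * g           # optimizer step
        weights.append(w)        # record_after_step: weights after the update
        grads.append(grad_before)
    return weights, grads

weights, grads = run_toy_training(w0=0.0, lr=0.1, n_steps=5)
for k in range(1, len(weights)):
    print(f"step {k}: grad={grads[k]:+.4f}  w_before={weights[k - 1]:.4f}  w_after={weights[k]:.4f}")
```

When analysing the saved tensors, the update attributable to `grads[k]` is therefore `embeddings[k] - embeddings[k - 1]`.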


## Training Configuration

In [12]:
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

recorder = ComprehensiveRecorder(VOCAB_SIZE, HIDDEN_DIM, RECORD_EVERY_N_STEPS)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
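    # Peak LR only: lr_scheduler_type is left at its default ("linear"), so the
    # effective LR decays toward 0 at max_steps and depends on the run length.
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    adam_beta1=ADAM_BETA1,
    adam_beta2=ADAM_BETA2,
    adam_epsilon=ADAM_EPSILON,
    optim="adamw_torch",  # AdamW with weight_decay=0.0 behaves like plain Adam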
    logging_steps=20,
    save_steps=NUM_TRAIN_STEPS + 1,  # Don't save checkpoints
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,  # Native bfloat16 training
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
)

trainer = InstrumentedTrainer(
    recorder=recorder,
    model=model,
    args=training_args,
    train_dataset=dataset,
)

print("\n✓ Trainer ready (Adam, bf16=True, batch_size=32)")


✓ Trainer ready (Adam, bf16=True, batch_size=32)
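
Because no `lr_scheduler_type` is passed, `TrainingArguments` falls back to its default linear schedule with no warmup, so the learning rate at a given step depends on `max_steps`. The sketch below evaluates that schedule for a 100-step and a 200-step run. `linear_lr` is an illustrative reimplementation of the decay rule, and it assumes Wordybird 1 used the same default scheduler, which this notebook does not confirm.

```python
import math

def linear_lr(completed_steps, total_steps, base_lr=1e-3, warmup_steps=0):
    """Learning rate under linear warmup followed by linear decay to zero."""
    if completed_steps < warmup_steps:
        return base_lr * completed_steps / max(1, warmup_steps)
    remaining = total_steps - completed_steps
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))

for s in (0, 50, 100, 150, 200):
    print(f"after {s:3d} steps: 100-step run lr={linear_lr(s, 100):.2e}   "
          f"200-step run lr={linear_lr(s, 200):.2e}")
```

If that assumption holds, the two runs share a seed and a peak learning rate but not a learning-rate trajectory, which is worth accounting for when comparing WB1b against the stitched WB1+WB3 dynamics.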


## Record Initial State

In [13]:
print()
recorder.record_initial_state(model, trainer.optimizer)


✓ Recorded initial state (step 0)


## Train

**200 steps should take ~20-30 seconds.**

In [14]:
print(f"\n{'='*80}")
print(f"STARTING WORDYBIRD 1B TRAINING (CLEAN 200-STEP RUN)")
print(f"{'='*80}")
print(f"\nConfiguration:")
print(f"  Vocabulary: {VOCAB_SIZE:,} tokens (GPT-2)")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Trained tokens: {n_trained:,} ({100*n_trained/VOCAB_SIZE:.1f}%)")
print(f"  Untrained tokens: {n_untrained:,} ({100*n_untrained/VOCAB_SIZE:.1f}%)")
print()
print(f"  Initialization: N(0, {INIT_SCALE}) bfloat16-native")
print(f"  Optimizer: Adam (lr={LEARNING_RATE})")
print(f"  Precision: bfloat16 (native)")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Steps: {NUM_TRAIN_STEPS} (continuous, no checkpointing)")
print(f"  Recording: every step")
print(f"\n{'='*80}\n")

start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete")
print(f"  Elapsed time: {elapsed:.1f} seconds")
print(f"  Throughput: {NUM_TRAIN_STEPS / elapsed:.1f} steps/second")
print(f"{'='*80}")


STARTING WORDYBIRD 1B TRAINING (CLEAN 200-STEP RUN)

Configuration:
  Vocabulary: 50,257 tokens (GPT-2)
  Hidden dim: 64
  Trained tokens: 30,590 (60.9%)
  Untrained tokens: 19,667 (39.1%)

  Initialization: N(0, 0.02) bfloat16-native
  Optimizer: Adam (lr=0.001)
  Precision: bfloat16 (native)
  Batch size: 32
  Steps: 200 (continuous, no checkpointing)
  Recording: every step




`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


| Step | Training Loss |
|-----:|--------------:|
| 20 | 10.2193 |
| 40 | 9.0734 |
| 60 | 8.3722 |
| 80 | 7.9931 |
| 100 | 7.7838 |
| 120 | 7.6844 |
| 140 | 7.6415 |
| 160 | 7.6503 |
| 180 | 7.6561 |
| 200 | 7.6354 |


  Recorded step 20
  Recorded step 40
  Recorded step 60
  Recorded step 80
  Recorded step 100
  Recorded step 120
  Recorded step 140
  Recorded step 160
  Recorded step 180
  Recorded step 200

✓ Training complete
  Elapsed time: 23.9 seconds
  Throughput: 8.4 steps/second


## Save Recorded Data

In [15]:
print(f"\nPreparing data for save...\n")

recorded_data = recorder.get_data()

save_dict = {
    'recorded_steps': recorded_data['recorded_steps'],
    'embeddings': recorded_data['embeddings'],
    'grads': recorded_data['grads'],
    'momentum': recorded_data['momentum'],
    'variance': recorded_data['variance'],
    'logits': recorded_data['logits'],
    'losses': recorded_data['losses'],
    # Metadata
    'init_scale': torch.tensor(INIT_SCALE, dtype=torch.float32),
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
    'adam_beta1': torch.tensor(ADAM_BETA1, dtype=torch.float32),
    'adam_beta2': torch.tensor(ADAM_BETA2, dtype=torch.float32),
    'n_trained': torch.tensor(n_trained, dtype=torch.long),
    'n_untrained': torch.tensor(n_untrained, dtype=torch.long),
}

output_path = Path(OUTPUT_DIR) / OUTPUT_FILE

print(f"Saving to: {output_path}")

save_start = time.time()
save_file(save_dict, str(output_path))
save_elapsed = time.time() - save_start

file_size_mb = output_path.stat().st_size / 1e6

print(f"\n✓ Saved successfully")
print(f"  File: {output_path}")
print(f"  Size: {file_size_mb:.1f} MB")
print(f"  Save time: {save_elapsed:.1f} seconds")
print(f"  Recorded steps: {len(recorded_data['recorded_steps'])}")
print(f"  Step range: {recorded_data['recorded_steps'][0]} to {recorded_data['recorded_steps'][-1]}")


Preparing data for save...


Stacking 201 recorded states...
Saving to: ../tensors/Wordybird/1.12f_wordybird_1b.safetensors

✓ Saved successfully
  File: ../tensors/Wordybird/1.12f_wordybird_1b.safetensors
  Size: 5192.3 MB
  Save time: 4.4 seconds
  Recorded steps: 201
  Step range: 0 to 200
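
The 5.2 GB file size is what the tensor shapes predict: four full `[201, 50257, 64]` bfloat16 arrays dominate, with the logits adding about 20 MB. The estimate below ignores the safetensors header and the scalar metadata entries, which are negligible. `recording_size_bytes` is an illustrative helper.

```python
def recording_size_bytes(n_states, vocab_size, hidden_dim, bytes_per_value=2):
    """Approximate payload size of the recorded tensors (bfloat16 = 2 bytes per value)."""
    matrix = n_states * vocab_size * hidden_dim * bytes_per_value
    full_matrices = 4 * matrix                         # embeddings, grads, momentum, variance
    logits = n_states * vocab_size * bytes_per_value   # one logit vector per state
    losses = n_states * bytes_per_value                # one bfloat16 scalar per state
    steps = n_states * 8                               # int64 step indices
    return full_matrices + logits + losses + steps

size_mb = recording_size_bytes(n_states=201, vocab_size=50257, hidden_dim=64) / 1e6
print(f"Estimated size: {size_mb:.1f} MB")
```

Since the size scales linearly with the number of recorded states, raising `RECORD_EVERY_N_STEPS` is the simplest way to shrink the file for longer runs.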


## Quick Verification

In [16]:
print(f"\n{'='*80}")
print(f"QUICK VERIFICATION")
print(f"{'='*80}\n")

embeddings = recorded_data['embeddings']
losses = recorded_data['losses']

print(f"Data shapes:")
print(f"  embeddings: {embeddings.shape}")
print(f"  losses: {losses.shape}")
print()

# Analyze untrained token movement
W_step0 = embeddings[0, untrained_indices].float()
W_step200 = embeddings[-1, untrained_indices].float()

# Compute displacements
displacements = torch.norm(W_step200 - W_step0, dim=1)
max_displacement = displacements.max().item()
mean_displacement = displacements.mean().item()
median_displacement = displacements.median().item()

print(f"Untrained token displacement (steps 0 → 200):")
print(f"  Max: {max_displacement:.2e}")
print(f"  Mean: {mean_displacement:.2e}")
print(f"  Median: {median_displacement:.2e}")
print()

print(f"Loss trajectory:")
print(f"  Step 1: {losses[1].float().item():.4f}")
print(f"  Step 200: {losses[-1].float().item():.4f}")
print(f"  Reduction: {(losses[1].float() - losses[-1].float()).item():.4f}")

print(f"\n{'='*80}")


QUICK VERIFICATION

Data shapes:
  embeddings: torch.Size([201, 50257, 64])
  losses: torch.Size([201])

Untrained token displacement (steps 0 → 200):
  Max: 4.55e-01
  Mean: 4.34e-01
  Median: 4.34e-01

Loss trajectory:
  Step 1: 10.8125
  Step 200: 7.7188
  Reduction: 3.0938
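
The recorded losses look coarse (10.8125, 7.7188) because the recorder stores each loss as a bfloat16 scalar, which keeps 8 significant bits. The sketch below emulates float32 to bfloat16 rounding with the standard library, under the assumption of round-to-nearest-even and finite inputs. `to_bfloat16` and `bf16_spacing` are illustrative helpers, not PyTorch functions.

```python
import math
import struct

def to_bfloat16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)        # round half to even on the dropped 16 bits
    rounded = (bits + bias) & 0xFFFF0000      # keep sign, exponent and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", rounded))[0]

def bf16_spacing(x):
    """Gap between adjacent bfloat16 values at the magnitude of x (normal range)."""
    _, exponent = math.frexp(abs(x))
    return 2.0 ** (exponent - 8)

uniform_loss = math.log(50257)  # cross-entropy of a uniform guess over the vocabulary
print(f"ln(V) = {uniform_loss:.4f} -> bfloat16 {to_bfloat16(uniform_loss)}")
print(f"spacing near 10.8: {bf16_spacing(10.8125)}")
print(f"spacing near 7.7:  {bf16_spacing(7.7188)}")
print(f"spacing near 0.02: {bf16_spacing(0.02):.3e}")
```

The step-1 loss of 10.8125 is the value that ln(50257) rounds to in bfloat16, consistent with a freshly initialized model predicting close to uniformly. For loss curves that need finer resolution, the Trainer's own float logs above are the better source.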



## Summary

In [17]:
print(f"\n{'='*80}")
print(f"WORDYBIRD 1B COMPLETE")
print(f"{'='*80}\n")

print(f"Experiment: Clean 200-step run (no checkpoint artifact)")
print(f"  Steps: 0-200 continuous")
print(f"  Data saved: {output_path}")
print(f"  Size: {file_size_mb:.1f} MB")
print()
print(f"Next steps:")
print(f"  1. Run 1.17b lattice hop analysis on WB1b")
print(f"  2. Compare to WB1+WB3 stitched dynamics")
print(f"  3. Determine if step-100 'kick' was a checkpoint artifact")

print(f"\n{'='*80}")


WORDYBIRD 1B COMPLETE

Experiment: Clean 200-step run (no checkpoint artifact)
  Steps: 0-200 continuous
  Data saved: ../tensors/Wordybird/1.12f_wordybird_1b.safetensors
  Size: 5192.3 MB

Next steps:
  1. Run 1.17b lattice hop analysis on WB1b
  2. Compare to WB1+WB3 stitched dynamics
  3. Determine if step-100 'kick' was a checkpoint artifact

