# 1.12e: Wordybird 3 - Continuation from Checkpoint 100

**Experiment:** Load Wordybird 1's final state (step 100) and continue training for another 100 steps.

## The Mystery

**Wordybird 1 (steps 0-100):** Lattice hop analysis (1.17b) showed apparent freeze around step 42, with mean displacement magnitude dropping to ~1e-5 by step 100.

**Wordybird 2 (steps 0-1000):** Same analysis showed gradual freeze continuing all the way to step 900!

**The Question:** Was Wordybird 1 truly frozen at step 100, or was it just the beginning of a longer freeze process?

## Test

If WB1 was **truly frozen** at step 100:
- WB3 (steps 101-200) should show zero non-lattice movement
- Tokens remain stuck in lattice cells
- Loss might still improve (attention keeps learning)

If WB1 was **not frozen** (threshold issue or incomplete freeze):
- WB3 should show continued movement
- Gradual freeze continues past step 100
- Matches WB2's gradual decay pattern
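The two outcomes above reduce to a threshold test on per-token displacement between the step-100 and step-200 snapshots. Below is a minimal pure-Python sketch. It assumes the embeddings are lists of equal-length float vectors and uses the same 1e-4 mean-displacement threshold as the verification cell further down. The function names are hypothetical.

```python
import math
from statistics import mean

def displacements(before, after):
    """Per-token L2 distance between two snapshots of the same tokens."""
    return [math.dist(b, a) for b, a in zip(before, after)]

def classify_freeze(before, after, threshold=1e-4):
    """Return ('FROZEN' or 'STILL MOVING', mean displacement)."""
    m = mean(displacements(before, after))
    return ("FROZEN" if m < threshold else "STILL MOVING"), m

# Toy example: three 2-D "tokens", each nudged by a tiny amount.
step100 = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
step200 = [[0.0, 1e-6], [1.0, 2e-6], [3e-6, 1.0]]
print(classify_freeze(step100, step200))
```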

## Wordybird 3 Parameters

**Starting from:**
- Checkpoint: `box_3/tensors/Wordybird/checkpoint-100`
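- Model state: WB1 step 100 weights (embeddings, attention layers)
- Optimizer state: **not restored.** `GPT2LMHeadModel.from_pretrained` loads only the weights, and `trainer.train()` is called without `resume_from_checkpoint`. `Trainer` therefore creates a fresh AdamW with zeroed momentum and variance. It also creates a fresh learning-rate schedule, which by default decays linearly from `LEARNING_RATE` to 0 over `max_steps` with no warmup.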

**Training:**
- **Steps: 100** (continuing from 100 → 200)
- All other hyperparameters identical to WB1
- Record every step (101 snapshots total)

**Output:** `1.12e_wordybird_3.safetensors`
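Whether Adam's state is restored matters for this test. With zeroed moments, the bias-corrected first update is close to `lr` in every coordinate that receives any nonzero gradient, however small. With warmed-up state, the same tiny gradient is divided by a large running variance, which gives a tiny step.

Untrained tokens still receive gradient through the output layer, assuming GPT-2's default tied input/output embeddings. The sketch below shows a single coordinate under the textbook Adam update, with no weight decay to match `WEIGHT_DECAY = 0.0`. The warmed-up state values are illustrative and are not read from WB1's checkpoint.

```python
def adam_update(grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step for a single coordinate.

    Returns (delta, m_new, v_new): delta is the change applied to the
    parameter, and t is the 1-indexed step count used for bias correction.
    """
    m_new = beta1 * m + (1 - beta1) * grad
    v_new = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m_new / (1 - beta1 ** t)
    v_hat = v_new / (1 - beta2 ** t)
    delta = -lr * m_hat / (v_hat ** 0.5 + eps)
    return delta, m_new, v_new

tiny_grad = 1e-6

# Fresh optimizer: zero moments, first step.
fresh_delta, _, _ = adam_update(tiny_grad, m=0.0, v=0.0, t=1)

# Warmed-up optimizer (illustrative): running variance from earlier
# gradients of size ~1e-2, momentum near zero, step 101.
warm_delta, _, _ = adam_update(tiny_grad, m=0.0, v=1e-4, t=101)

print(f"fresh: {fresh_delta:.2e}   warm: {warm_delta:.2e}")
```

Under these assumptions, the fresh step is about 1e-3 while the warmed-up step is several orders of magnitude smaller, even though the gradient is identical.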

## Parameters

In [25]:
# Model architecture (unchanged)
VOCAB_SIZE = 50257  # GPT-2
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training
BATCH_SIZE = 32
NUM_TRAIN_STEPS = 100  # Continue for another 100 steps
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0

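# Optimizer: AdamW (optim="adamw_torch"); with WEIGHT_DECAY = 0.0 this is plain Adam.
# Its state is NOT loaded from the checkpoint: trainer.train() below starts with zeroed moments.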
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8

# Data
CORPUS_PATH = "../data/fineweb_2mb_unicode.txt"
TOKEN_MASK_PATH = "../tensors/Wordybird/fineweb_token_masks.safetensors"
CHECKPOINT_PATH = "../tensors/Wordybird/checkpoint-100"  # ← Load from here
OUTPUT_DIR = "../tensors/Wordybird"
OUTPUT_FILE = "1.12e_wordybird_3.safetensors"  # ← New output

# Instrumentation
RECORD_EVERY_N_STEPS = 1  # Record every step

RANDOM_SEED = 42
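`trainer.train()` also builds a new learning-rate schedule. Assume the `TrainingArguments` defaults (`lr_scheduler_type="linear"`, no warmup). Then the update at 0-indexed step `k` uses `LEARNING_RATE * (max_steps - k) / max_steps`, so WB3's first update runs at 1e-3 again. A short sketch of that arithmetic, with a hypothetical helper name:

```python
def linear_lrs(base_lr, max_steps):
    """Learning rate for each of the max_steps updates (linear decay, no warmup)."""
    return [base_lr * (max_steps - k) / max_steps for k in range(max_steps)]

lrs = linear_lrs(1e-3, 100)  # LEARNING_RATE, NUM_TRAIN_STEPS
total = sum(lrs)
print(f"first={lrs[0]:.1e} last={lrs[-1]:.1e} sum={total:.4f}")
```

This is a rough scale, not a strict bound. When gradients keep a consistent sign, a fresh Adam optimizer moves each coordinate by about the current learning rate per step. A summed learning rate of 0.0505 therefore corresponds to about 0.0505 × sqrt(64) ≈ 0.40 in L2 norm for a 64-dimensional embedding that moves consistently. That is the same order as the ~0.25 mean displacement reported in the verification cell.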

## Imports

In [26]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file, load_file
import time

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Device Detection

In [27]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Tokenizer

In [28]:
print("Loading GPT-2 tokenizer...\n")

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

print(f"✓ Loaded GPT-2 tokenizer")
print(f"  Vocabulary size: {len(tokenizer):,} tokens")

Loading GPT-2 tokenizer...

✓ Loaded GPT-2 tokenizer
  Vocabulary size: 50,257 tokens


## Load Corpus and Tokenize

In [29]:
print(f"\nLoading corpus: {CORPUS_PATH}\n")

with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    corpus_text = f.read()

corpus_bytes = len(corpus_text.encode('utf-8'))
corpus_mb = corpus_bytes / (1024 * 1024)

print(f"✓ Loaded corpus")
print(f"  Size: {corpus_mb:.2f} MB")
print(f"  Characters: {len(corpus_text):,}")
print()

# Tokenize
print("Tokenizing corpus...\n")
tokens = tokenizer.encode(corpus_text)

print(f"✓ Tokenized")
print(f"  Tokens: {len(tokens):,}")
print()

# Pre-load to device
corpus_tensor = torch.tensor(tokens, dtype=torch.long, device=device)
print(f"✓ Corpus on device: {device}")


Loading corpus: ../data/fineweb_2mb_unicode.txt

✓ Loaded corpus
  Size: 2.00 MB
  Characters: 2,089,201

Tokenizing corpus...



Token indices sequence length is longer than the specified maximum sequence length for this model (475160 > 1024). Running this sequence through the model will result in indexing errors


✓ Tokenized
  Tokens: 475,160

✓ Corpus on device: mps


## Load Token Masks

In [30]:
print(f"\nLoading token masks: {TOKEN_MASK_PATH}\n")

mask_data = load_file(TOKEN_MASK_PATH)
trained_mask = mask_data['trained_mask']
untrained_mask = mask_data['untrained_mask']
trained_indices = mask_data['trained_indices']
untrained_indices = mask_data['untrained_indices']

n_trained = trained_mask.sum().item()
n_untrained = untrained_mask.sum().item()

print(f"✓ Loaded token masks")
print(f"  Trained tokens: {n_trained:,} ({100*n_trained/VOCAB_SIZE:.1f}%)")
print(f"  Untrained tokens: {n_untrained:,} ({100*n_untrained/VOCAB_SIZE:.1f}%)")


Loading token masks: ../tensors/Wordybird/fineweb_token_masks.safetensors

✓ Loaded token masks
  Trained tokens: 30,590 (60.9%)
  Untrained tokens: 19,667 (39.1%)


## Dataset

In [31]:
class TokenDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
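    def __getitem__(self, idx):
        # Sliding window of max_seq_len + 1 tokens starting at idx.
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        # NOTE: GPT2LMHeadModel shifts `labels` internally, so passing
        # labels = chunk[1:] trains position t to predict chunk[t + 2]
        # (two tokens ahead). Standard next-token training would use
        # labels = input_ids. Left unchanged so WB3 keeps the same objective
        # as WB1, assuming WB1 used this same dataset class.
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }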

dataset = TokenDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"\n✓ Dataset: {len(dataset):,} examples")


✓ Dataset: 475,032 examples
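The label convention in `TokenDataset` interacts with the model's own shift. The model's internal shift drops the last logit and the first label before computing the loss. The sketch below uses plain integers as stand-in token ids and shows which target each input position ends up paired with once that shift is applied. The helper names are hypothetical.

```python
def dataset_item(corpus, idx, max_seq_len):
    """Mirror of TokenDataset.__getitem__ on a plain list."""
    chunk = corpus[idx : idx + max_seq_len + 1]
    return chunk[:-1], chunk[1:]

def model_pairs(input_ids, labels):
    """(token at position t, target) after the model's internal shift:
    the logits at position t are scored against labels[t + 1]."""
    return [(input_ids[t], labels[t + 1]) for t in range(len(input_ids) - 1)]

corpus = list(range(10))                     # token ids 0..9
inp, lab = dataset_item(corpus, idx=0, max_seq_len=4)
print(inp, lab)                              # [0, 1, 2, 3] [1, 2, 3, 4]
print(model_pairs(inp, lab))                 # [(0, 2), (1, 3), (2, 4)]
```

Passing `labels = input_ids` instead gives the pairs `(0, 1), (1, 2), (2, 3)`, which is ordinary next-token prediction.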


## Load Model from Checkpoint

In [32]:
print(f"\n{'='*80}")
print(f"LOADING WORDYBIRD 1 CHECKPOINT (STEP 100)")
print(f"{'='*80}\n")

print(f"Loading from: {CHECKPOINT_PATH}\n")

# Load model (this includes all weights at step 100)
model = GPT2LMHeadModel.from_pretrained(CHECKPOINT_PATH)
model = model.to(torch.bfloat16).to(device)

total_params = sum(p.numel() for p in model.parameters())
embedding_params = model.transformer.wte.weight.numel()

print(f"✓ Model loaded from checkpoint")
print(f"  Total parameters: {total_params:,}")
print(f"  Embedding parameters: {embedding_params:,}")
print(f"  Model dtype: {model.transformer.wte.weight.dtype}")
print()

# Verify initial state
W_initial = model.transformer.wte.weight.cpu().float()
W_untrained_initial = W_initial[untrained_indices]
centroid_initial = W_untrained_initial.mean(dim=0)
radius_initial = torch.norm(W_untrained_initial - centroid_initial, dim=1).max().item()

print(f"Initial state (step 100 from WB1):")
print(f"  Untrained centroid norm: {torch.norm(centroid_initial).item():.6f}")
print(f"  Bounding radius: {radius_initial:.6f}")

print(f"\n{'='*80}\n")


LOADING WORDYBIRD 1 CHECKPOINT (STEP 100)

Loading from: ../tensors/Wordybird/checkpoint-100

✓ Model loaded from checkpoint
  Total parameters: 3,324,736
  Embedding parameters: 3,216,448
  Model dtype: torch.bfloat16

Initial state (step 100 from WB1):
  Untrained centroid norm: 0.325751
  Bounding radius: 0.211612
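The two summary numbers above are:

- the norm of the untrained-token centroid, and
- the largest distance from that centroid to any untrained token.

The same computation on plain lists, as a sketch with a hypothetical helper name:

```python
import math

def centroid_and_radius(vectors):
    """Return (||centroid||, max distance from centroid), as in the cell above."""
    n, dim = len(vectors), len(vectors[0])
    centroid = [sum(v[i] for v in vectors) / n for i in range(dim)]
    radius = max(math.dist(v, centroid) for v in vectors)
    return math.hypot(*centroid), radius

pts = [[1.0, 1.0], [3.0, 1.0], [1.0, 3.0], [3.0, 3.0]]
print(centroid_and_radius(pts))  # centroid (2, 2): norm sqrt(8), radius sqrt(2)
```

The radius gives a scale for the later displacement numbers: a displacement comparable to or larger than 0.21 is not a small jitter inside the step-100 cloud.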




## Comprehensive Recorder

In [33]:
class ComprehensiveRecorder:
    """Records embeddings, gradients, optimizer state, logits, loss at every step in bfloat16."""
    
    def __init__(self, vocab_size, hidden_dim, record_every_n, starting_step=100):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.record_every_n = record_every_n
        self.starting_step = starting_step  # ← NEW: track that we start at 100
        
        # Storage
        self.recorded_steps = []
        self.embeddings = []
        self.grads = []
        self.momentum = []
        self.variance = []
        self.logits = []
        self.losses = []
        
        # Temporary storage
        self.current_step = starting_step
        self.recorded_initial = False
        self.grad_before = None
        self.loss_value = None
        self.logits_sample = None
    
    def record_initial_state(self, model, optimizer=None):
        """Record step 100: initial state from checkpoint."""
        if not self.recorded_initial:
            W = model.transformer.wte.weight.data.clone().cpu().bfloat16()
            
            # Extract optimizer state (if optimizer exists)
            param = model.transformer.wte.weight
            if optimizer is not None and param in optimizer.state:
                state = optimizer.state[param]
                mom_src = state.get('exp_avg', None)
                var_src = state.get('exp_avg_sq', None)
                mom = mom_src.clone().cpu().bfloat16() if mom_src is not None else torch.zeros_like(W)
                var = var_src.clone().cpu().bfloat16() if var_src is not None else torch.zeros_like(W)
            else:
                # Optimizer not initialized yet, use zeros
                mom = torch.zeros_like(W)
                var = torch.zeros_like(W)
            
            self.recorded_steps.append(self.starting_step)
            self.embeddings.append(W)
            self.grads.append(torch.zeros_like(W))  # No grad at checkpoint load
            self.momentum.append(mom)
            self.variance.append(var)
            self.logits.append(torch.zeros(self.vocab_size, dtype=torch.bfloat16))
            self.losses.append(torch.tensor(float('nan'), dtype=torch.bfloat16))
            
            self.recorded_initial = True
            self.current_step = self.starting_step + 1
            
            opt_status = "with optimizer state" if (optimizer is not None and param in optimizer.state) else "without optimizer state (will load on first step)"
            print(f"✓ Recorded initial state (step {self.starting_step} from checkpoint, {opt_status})")
    
    def record_before_step(self, model, loss, logits):
        """Call after forward/backward, before optimizer step."""
        if self.current_step % self.record_every_n == 0:
            if model.transformer.wte.weight.grad is not None:
                self.grad_before = model.transformer.wte.weight.grad.clone().cpu().bfloat16()
            else:
                self.grad_before = torch.zeros(self.vocab_size, self.hidden_dim, dtype=torch.bfloat16)
            
            self.loss_value = loss.item()
            self.logits_sample = logits[0, -1, :].detach().cpu().bfloat16()
    
    def record_after_step(self, model, optimizer):
        """Call after optimizer step."""
        if self.current_step % self.record_every_n == 0:
            if self.grad_before is not None and self.loss_value is not None:
                W = model.transformer.wte.weight.data.clone().cpu().bfloat16()

                param = model.transformer.wte.weight
                if param in optimizer.state:
                    state = optimizer.state[param]
                    mom_src = state.get('exp_avg', None)
                    var_src = state.get('exp_avg_sq', None)
                    mom = mom_src.clone().cpu().bfloat16() if mom_src is not None else torch.zeros_like(W)
                    var = var_src.clone().cpu().bfloat16() if var_src is not None else torch.zeros_like(W)
                else:
                    mom = torch.zeros_like(W)
                    var = torch.zeros_like(W)

                self.recorded_steps.append(self.current_step)
                self.embeddings.append(W)
                self.grads.append(self.grad_before)
                self.momentum.append(mom)
                self.variance.append(var)
                self.logits.append(self.logits_sample)
                self.losses.append(torch.tensor(self.loss_value, dtype=torch.bfloat16))

                self.grad_before = None
                self.loss_value = None
                self.logits_sample = None
                
                if self.current_step % 10 == 0:
                    print(f"  Recorded step {self.current_step}")

        self.current_step += 1
    
    def get_data(self):
        """Return recorded data as stacked tensors."""
        print(f"\nStacking {len(self.embeddings)} recorded states...")
        
        return {
            'recorded_steps': torch.tensor(self.recorded_steps, dtype=torch.long),
            'embeddings': torch.stack(self.embeddings) if self.embeddings else torch.tensor([]),
            'grads': torch.stack(self.grads) if self.grads else torch.tensor([]),
            'momentum': torch.stack(self.momentum) if self.momentum else torch.tensor([]),
            'variance': torch.stack(self.variance) if self.variance else torch.tensor([]),
            'logits': torch.stack(self.logits) if self.logits else torch.tensor([]),
            'losses': torch.stack(self.losses) if self.losses else torch.tensor([]),
        }

print("✓ Recorder class defined")

✓ Recorder class defined


## Custom Trainer with Instrumentation

In [34]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder
        self.last_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs)
        loss = outputs.loss
        self.last_logits = outputs.logits
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        loss = super().training_step(model, inputs, num_items_in_batch)
        self.recorder.record_before_step(model, loss, self.last_logits)
        return loss

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
        self.recorder.record_after_step(model, self.optimizer)
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined
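`training_step` returns before `Trainer` applies gradient clipping. As a result, `record_before_step` stores the raw gradient, while with the default `max_grad_norm=1.0` the optimizer may receive a rescaled one. The sketch below shows global-norm clipping on plain lists, following the rule used by `torch.nn.utils.clip_grad_norm_`: one shared scale factor, with a 1e-6 term in the denominator. It is a reimplementation for illustration, not the library code.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradient vectors by one factor so their joint L2 norm <= max_norm."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    coef = min(1.0, max_norm / (total + 1e-6))
    return [[g * coef for g in vec] for vec in grads], total

raw = [[3.0, 0.0], [0.0, 4.0]]          # joint norm 5
clipped, norm = clip_by_global_norm(raw)
print(norm, clipped)
```

Adam's normalization largely cancels a uniform rescale. Even so, the recorded `grads` are not guaranteed to be exactly what the optimizer consumed.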


## Training Configuration

In [35]:
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

recorder = ComprehensiveRecorder(VOCAB_SIZE, HIDDEN_DIM, RECORD_EVERY_N_STEPS, starting_step=100)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    adam_beta1=ADAM_BETA1,
    adam_beta2=ADAM_BETA2,
    adam_epsilon=ADAM_EPSILON,
    optim="adamw_torch",
    logging_steps=10,
    save_steps=NUM_TRAIN_STEPS + 1,
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
)

trainer = InstrumentedTrainer(
    recorder=recorder,
    model=model,
    args=training_args,
    train_dataset=dataset,
)

print("\n✓ Trainer ready (continuing from checkpoint)")


✓ Trainer ready (continuing from checkpoint)


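## Record Initial State (Step 100 from Checkpoint)

`trainer.optimizer` is still `None` here, because `Trainer` only creates its optimizer inside `trainer.train()`. The recorder therefore stores zeros for the step-100 momentum and variance. Despite the "will load on first step" wording in the output below, WB1's Adam state is never loaded. `train()` runs without `resume_from_checkpoint` and starts from zeroed moments.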

In [36]:
print()
recorder.record_initial_state(model, trainer.optimizer)


✓ Recorded initial state (step 100 from checkpoint, without optimizer state (will load on first step))


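## Train (Continue for 100 More Steps)

Training is not resumed with `resume_from_checkpoint`, so the `Trainer` step counter starts at 0. Logged step 10 corresponds to overall step 110, and logged step 100 to overall step 200.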

In [37]:
print(f"\n{'='*80}")
print(f"STARTING WORDYBIRD 3 TRAINING (STEPS 101-200)")
print(f"{'='*80}")
print(f"\nContinuing from Wordybird 1 step 100...")
print(f"  Steps: 101-200 (100 additional steps)")
print(f"  Recording: every step")
print(f"\n{'='*80}\n")

start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete")
print(f"  Elapsed time: {elapsed:.1f} seconds")
print(f"  Throughput: {NUM_TRAIN_STEPS / elapsed:.1f} steps/second")
print(f"{'='*80}")


STARTING WORDYBIRD 3 TRAINING (STEPS 101-200)

Continuing from Wordybird 1 step 100...
  Steps: 101-200 (100 additional steps)
  Recording: every step




`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
10,7.9254
20,7.7454
30,7.6548
40,7.6787
50,7.678
60,7.6898
70,7.6894
80,7.7074
90,7.6607
100,7.703


  Recorded step 110
  Recorded step 120
  Recorded step 130
  Recorded step 140
  Recorded step 150
  Recorded step 160
  Recorded step 170
  Recorded step 180
  Recorded step 190
  Recorded step 200

✓ Training complete
  Elapsed time: 14.0 seconds
  Throughput: 7.1 steps/second


## Save Recorded Data

In [38]:
print(f"\nPreparing data for save...\n")

recorded_data = recorder.get_data()

save_dict = {
    'recorded_steps': recorded_data['recorded_steps'],
    'embeddings': recorded_data['embeddings'],
    'grads': recorded_data['grads'],
    'momentum': recorded_data['momentum'],
    'variance': recorded_data['variance'],
    'logits': recorded_data['logits'],
    'losses': recorded_data['losses'],
    # Metadata
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
    'adam_beta1': torch.tensor(ADAM_BETA1, dtype=torch.float32),
    'adam_beta2': torch.tensor(ADAM_BETA2, dtype=torch.float32),
    'n_trained': torch.tensor(n_trained, dtype=torch.long),
    'n_untrained': torch.tensor(n_untrained, dtype=torch.long),
    'starting_step': torch.tensor(100, dtype=torch.long),  # Mark that we continued from 100
}

output_path = Path(OUTPUT_DIR) / OUTPUT_FILE

print(f"Saving to: {output_path}")

save_start = time.time()
save_file(save_dict, str(output_path))
save_elapsed = time.time() - save_start

file_size_mb = output_path.stat().st_size / 1e6

print(f"\n✓ Saved successfully")
print(f"  File: {output_path}")
print(f"  Size: {file_size_mb:.1f} MB")
print(f"  Save time: {save_elapsed:.1f} seconds")
print(f"  Recorded steps: {len(recorded_data['recorded_steps'])}")
print(f"  Step range: {recorded_data['recorded_steps'][0]} to {recorded_data['recorded_steps'][-1]}")


Preparing data for save...


Stacking 101 recorded states...
Saving to: ../tensors/Wordybird/1.12e_wordybird_3.safetensors

✓ Saved successfully
  File: ../tensors/Wordybird/1.12e_wordybird_3.safetensors
  Size: 2609.0 MB
  Save time: 0.8 seconds
  Recorded steps: 101
  Step range: 100 to 200
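The 2609.0 MB file size follows from the recorder's layout. Each recorded step stores:

- four `[VOCAB_SIZE, HIDDEN_DIM]` bfloat16 tensors (embeddings, grads, momentum, variance),
- one `[VOCAB_SIZE]` logits row, and
- a scalar loss,

all at 2 bytes per bfloat16 value. The quick check below ignores the safetensors header and the small metadata tensors, and uses 1 MB = 10^6 bytes as the save cell does.

```python
VOCAB_SIZE, HIDDEN_DIM, N_SNAPSHOTS = 50257, 64, 101
BF16_BYTES = 2

# 4 full [V, D] tensors + 1 logits row + 1 scalar loss, per snapshot
per_step = BF16_BYTES * (4 * VOCAB_SIZE * HIDDEN_DIM + VOCAB_SIZE + 1)
total_bytes = N_SNAPSHOTS * per_step
print(f"{per_step:,} bytes/step, {total_bytes / 1e6:.1f} MB total")
```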


## Quick Verification

In [39]:
print(f"\n{'='*80}")
print(f"QUICK VERIFICATION")
print(f"{'='*80}\n")

embeddings = recorded_data['embeddings']
losses = recorded_data['losses']

print(f"Data shapes:")
print(f"  embeddings: {embeddings.shape}")
print(f"  losses: {losses.shape}")
print()

# Analyze untrained token movement
W_step100 = embeddings[0, untrained_indices].float()  # Initial (from checkpoint)
W_step200 = embeddings[-1, untrained_indices].float()  # Final

# Compute displacements
displacements = torch.norm(W_step200 - W_step100, dim=1)
max_displacement = displacements.max().item()
mean_displacement = displacements.mean().item()
median_displacement = displacements.median().item()

print(f"Untrained token displacement (steps 100 → 200):")
print(f"  Max: {max_displacement:.2e}")
print(f"  Mean: {mean_displacement:.2e}")
print(f"  Median: {median_displacement:.2e}")
print()

# Threshold test only: WB1's data is not loaded here, so this is not a direct WB1 comparison
if mean_displacement < 1e-4:
    print(f"  → FROZEN (mean displacement < 1e-4)")
    print(f"  → Supports WB1 freeze hypothesis!")
else:
    print(f"  → STILL MOVING (mean displacement ≥ 1e-4)")
    print(f"  → Freeze continues past step 100")

print()
print(f"Loss trajectory:")
print(f"  Step 101: {losses[1].float().item():.4f}")
print(f"  Step 200: {losses[-1].float().item():.4f}")
print(f"  Change: {(losses[-1].float() - losses[1].float()).item():+.4f}")

print(f"\n{'='*80}")


QUICK VERIFICATION

Data shapes:
  embeddings: torch.Size([101, 50257, 64])
  losses: torch.Size([101])

Untrained token displacement (steps 100 → 200):
  Max: 2.55e-01
  Mean: 2.47e-01
  Median: 2.47e-01

  → STILL MOVING (mean displacement ≥ 1e-4)
  → Freeze continues past step 100

Loss trajectory:
  Step 101: 8.0625
  Step 200: 7.7500
  Change: -0.3125
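The mean, median, and max displacements are nearly equal (0.247, 0.247, 0.255). That pattern suggests the untrained tokens moved largely together rather than spreading apart. One way to check this is to split each token's displacement into the shared centroid shift and a per-token residual. A sketch on plain lists, with a hypothetical helper name:

```python
import math

def split_displacement(before, after):
    """Return (||centroid shift||, mean residual displacement after removing it)."""
    n, dim = len(before), len(before[0])
    shift = [sum(a[i] - b[i] for b, a in zip(before, after)) / n for i in range(dim)]
    residuals = [
        math.hypot(*[(a[i] - b[i]) - shift[i] for i in range(dim)])
        for b, a in zip(before, after)
    ]
    return math.hypot(*shift), sum(residuals) / n

before = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
after = [[0.2, 0.1], [1.2, 0.1], [0.2, 1.1]]   # every token shifted by (0.2, 0.1)
print(split_displacement(before, after))
```

To apply this, pass `W_step100.tolist()` and `W_step200.tolist()`. If the centroid shift comes out close to the mean displacement, the whole untrained cloud drifted coherently, rather than individual tokens moving independently.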



## Summary

In [40]:
print(f"\n{'='*80}")
print(f"WORDYBIRD 3 COMPLETE")
print(f"{'='*80}\n")

print(f"Experiment: Continuation from WB1 checkpoint (step 100 → 200)")
print(f"  Continued for: {NUM_TRAIN_STEPS} additional steps")
print(f"  Data saved: {output_path}")
print(f"  Size: {file_size_mb:.1f} MB")
print()
print(f"Next steps:")
print(f"  1. Run 1.17b lattice hop analysis on WB3")
print(f"  2. Compare to WB1 and WB2 freeze patterns")
print(f"  3. Determine if WB1 was truly frozen or still freezing")

print(f"\n{'='*80}")


WORDYBIRD 3 COMPLETE

Experiment: Continuation from WB1 checkpoint (step 100 → 200)
  Continued for: 100 additional steps
  Data saved: ../tensors/Wordybird/1.12e_wordybird_3.safetensors
  Size: 2609.0 MB

Next steps:
  1. Run 1.17b lattice hop analysis on WB3
  2. Compare to WB1 and WB2 freeze patterns
  3. Determine if WB1 was truly frozen or still freezing

