# 08.2a: Embedding Evolution Training

**Train a tiny transformer and record embedding matrix evolution**

We're testing two hypotheses about how embedding matrices evolve during training:

## Hypothesis 1: Normal Initialization

Standard random initialization → soft explosion → centroid random-walks during training

**Predictions:**
- Token vectors start at different random locations
- Cloud gradually expands and drifts
- Centroid undergoes random walk (high-dimensional brownian motion)
- Dead tokens drift slightly but stay near origin

## Hypothesis 2: Qwen-Style Singular Initialization

All tokens start at ONE random point → violent explosion → random walk away from origin → black holes left behind

**Predictions:**
- All tokens start at identical location (singularity)
- Training causes rapid "supernova" expansion
- Active tokens random-walk to equilibrium
- Dead tokens stay frozen at initialization point ± quantization noise
- Final structure: tight black hole cluster + dispersed cloud ~0.2-0.4 units away

## Experimental Design

**Model:** Tiny GPT2 (2 layers, 2 heads, 64-dim hidden space)

**Tokenizer:** Byte-level (128 ASCII tokens)

**Training data:** The Great Gatsby (~50k words, pure ASCII)

**Dead tokens:** ~50 ASCII bytes never appear in corpus (~40% of vocab)

**Architecture:** Tied weights (embedding matrix = unembedding matrix, like Qwen)

**Data collection:** Save embedding matrix after EVERY training step

## Parameters

In [31]:
# Model architecture
VOCAB_SIZE = 128      # 128 for ASCII-only, 256 for full byte range
HIDDEN_DIM = 64       # Embedding dimension
N_LAYER = 2           # Number of transformer layers
N_HEAD = 2            # Number of attention heads
MAX_SEQ_LEN = 128     # Context window (tokens per training example)

# Initialization
INIT_MODE = "qwen"  # "normal" or "qwen"

# Training
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01   # Restoring force toward origin
BATCH_SIZE = 16
NUM_TRAIN_STEPS = 5000

# Data
CORPUS_PATH = "../data/training_corpus.txt"
SNAPSHOT_DIR = f"../data/embeddings_{VOCAB_SIZE}vocab_{INIT_MODE}init"

RANDOM_SEED = 42

## Imports

In [32]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, TrainerCallback
from torch.utils.data import Dataset
import numpy as np
import os
import glob
from safetensors.torch import save_file

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

## Detect Hardware Accelerator

In [33]:
if torch.cuda.is_available():
    DEVICE = "cuda"
    print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Using MPS (Apple Silicon)")
else:
    DEVICE = "cpu"
    print("Using CPU")

Using MPS (Apple Silicon)


## Load Training Corpus

In [34]:
print(f"Loading corpus from: {CORPUS_PATH}\n")

with open(CORPUS_PATH, 'r', encoding='ascii') as f:
    corpus_text = f.read()

# Convert to bytes and filter to vocab size
corpus_bytes = [b for b in corpus_text.encode('ascii') if b < VOCAB_SIZE]

print(f"✓ Corpus loaded")
print(f"Total bytes: {len(corpus_bytes):,}")
print(f"Vocabulary size: {VOCAB_SIZE}")

# Count unique bytes present
unique_bytes = set(corpus_bytes)
dead_tokens = VOCAB_SIZE - len(unique_bytes)

print(f"Unique bytes in corpus: {len(unique_bytes)}")
print(f"Dead tokens (never appear): {dead_tokens} ({100 * dead_tokens / VOCAB_SIZE:.1f}%)")

Loading corpus from: ../data/training_corpus.txt

✓ Corpus loaded
Total bytes: 265,905
Vocabulary size: 128
Unique bytes in corpus: 77
Dead tokens (never appear): 51 (39.8%)


## Create Dataset

PyTorch Dataset that chops the corpus into overlapping windows for next-token prediction training.

In [35]:
class ByteDataset(Dataset):
    """Dataset for byte-level language modeling.
    
    Returns overlapping sequences where:
    - input_ids: tokens [0, 1, 2, ..., N-1]
    - labels: tokens [1, 2, 3, ..., N] (shifted by 1 for next-token prediction)
    """
    
    def __init__(self, byte_sequence, max_seq_len):
        self.byte_sequence = byte_sequence
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        # Number of possible starting positions
        return max(0, len(self.byte_sequence) - self.max_seq_len)
    
    def __getitem__(self, idx):
        # Extract sequence of max_seq_len+1 bytes
        chunk = self.byte_sequence[idx : idx + self.max_seq_len + 1]
        
        # Input: first max_seq_len tokens
        # Target: shifted by 1 (next token prediction)
        input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
        labels = torch.tensor(chunk[1:], dtype=torch.long)
        
        return {
            'input_ids': input_ids,
            'labels': labels
        }

dataset = ByteDataset(corpus_bytes, MAX_SEQ_LEN)
print(f"\n✓ Dataset created")
print(f"Training examples: {len(dataset):,}")
print(f"Tokens per example: {MAX_SEQ_LEN}")


✓ Dataset created
Training examples: 265,777
Tokens per example: 128


## Create Model Configuration

In [36]:
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,  # No dropout (we want pure dynamics)
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,  # Tie embedding and unembedding matrices (like Qwen)
)

print("Model configuration:")
print(f"  Vocabulary: {config.vocab_size} tokens")
print(f"  Hidden dimension: {config.n_embd}")
print(f"  Layers: {config.n_layer}")
print(f"  Attention heads: {config.n_head}")
print(f"  Context window: {config.n_positions} tokens")
print(f"  Weight tying: {config.tie_word_embeddings}")

Model configuration:
  Vocabulary: 128 tokens
  Hidden dimension: 64
  Layers: 2
  Attention heads: 2
  Context window: 128 tokens
  Weight tying: True


## Initialize Model

In [37]:
model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16)

total_params = sum(p.numel() for p in model.parameters())
embedding_params = model.transformer.wte.weight.numel()

print(f"\n✓ Model initialized")
print(f"Total parameters: {total_params:,}")
print(f"Embedding parameters: {embedding_params:,} ({100 * embedding_params / total_params:.1f}% of total)")


✓ Model initialized
Total parameters: 116,480
Embedding parameters: 8,192 (7.0% of total)


## Apply Custom Initialization

Based on `INIT_MODE`, either use default PyTorch initialization (normal) or apply Qwen-style singular initialization.

In [38]:
if INIT_MODE == "qwen":
    print(f"\nApplying Qwen-style initialization (singular vector)...")
    
    with torch.no_grad():
        # Generate one random vector
        random_vector = torch.randn(HIDDEN_DIM)
        
        # Set ALL embedding vectors to this single vector
        # (lm_head is tied, so it automatically matches)
        model.transformer.wte.weight[:] = random_vector
    
    print(f"✓ All {VOCAB_SIZE} tokens initialized to same random vector")
    print(f"  Vector L2 norm: {random_vector.norm().item():.6f}")
    print(f"  Vector mean component: {random_vector.mean().item():.6e}")
    print(f"  (lm_head tied to wte, automatically matches)")
    
else:
    print(f"\nUsing normal initialization (default PyTorch)")
    
    # Show statistics of default initialization
    with torch.no_grad():
        norms = torch.norm(model.transformer.wte.weight, p=2, dim=1)
        print(f"  Token L2 norms: min={norms.min().item():.6f}, max={norms.max().item():.6f}, mean={norms.mean().item():.6f}")


Applying Qwen-style initialization (singular vector)...
✓ All 128 tokens initialized to same random vector
  Vector L2 norm: 8.201394
  Vector mean component: 2.535469e-01
  (lm_head tied to wte, automatically matches)


## Create Snapshot Directory

In [39]:
os.makedirs(SNAPSHOT_DIR, exist_ok=True)
print(f"\nSnapshots will be saved to: {SNAPSHOT_DIR}")


Snapshots will be saved to: ../data/embeddings_128vocab_qweninit


## Save Initial Embedding (Step 0)

In [40]:
# Save the embedding matrix BEFORE training starts
initial_embeddings = model.transformer.wte.weight.data.clone().cpu()

save_file(
    {'embeddings': initial_embeddings},
    os.path.join(SNAPSHOT_DIR, 'step_0000000.safetensors')
)

initial_norms = torch.norm(initial_embeddings, p=2, dim=1)
initial_centroid = initial_embeddings.mean(dim=0)
initial_centroid_norm = initial_centroid.norm().item()

print(f"✓ Saved initial embeddings (step 0)")
print(f"  Shape: {initial_embeddings.shape}")
print(f"  Token L2 norms: min={initial_norms.min().item():.6f}, max={initial_norms.max().item():.6f}, mean={initial_norms.mean().item():.6f}")
print(f"  Centroid L2 norm: {initial_centroid_norm:.6f}")

✓ Saved initial embeddings (step 0)
  Shape: torch.Size([128, 64])
  Token L2 norms: min=8.187500, max=8.187500, mean=8.187500
  Centroid L2 norm: 8.187500


## Define Snapshot Callback

Custom callback that saves the embedding matrix after every training step.

In [41]:
class EmbeddingSnapshotCallback(TrainerCallback):
    """Save embedding matrix after every training step."""
    
    def __init__(self, snapshot_dir):
        self.snapshot_dir = snapshot_dir
    
    def on_step_end(self, args, state, control, model=None, **kwargs):
        step = state.global_step
        
        # Extract embedding matrix and move to CPU for saving
        embeddings = model.transformer.wte.weight.data.clone().cpu()
        
        # Save with zero-padded step number
        filename = f"step_{step:07d}.safetensors"
        save_file(
            {'embeddings': embeddings},
            os.path.join(self.snapshot_dir, filename)
        )
        
        # Print progress every 100 steps
        if step % 100 == 0:
            centroid = embeddings.mean(dim=0)
            centroid_norm = centroid.norm().item()
            print(f"[Step {step:5d}] Saved snapshot | Centroid L2 norm: {centroid_norm:.6f}")
        
        return control

print("✓ Snapshot callback defined")

✓ Snapshot callback defined


## Configure Training

In [42]:
training_args = TrainingArguments(
    output_dir="./training_output",
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    logging_steps=100,
    save_steps=10000,  # We're saving embeddings manually, so don't need frequent full checkpoints
    save_total_limit=1,  # Only keep latest full checkpoint
    seed=RANDOM_SEED,
    dataloader_num_workers=0,
    use_cpu=(DEVICE == "cpu"),  # Only force CPU if we detected CPU
    # Trainer will auto-detect and use MPS or CUDA if available
    bf16=True,  # Use bfloat16 if supported
)

print("Training configuration:")
print(f"  Device: {DEVICE}")
print(f"  Steps: {training_args.max_steps:,}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Weight decay: {training_args.weight_decay}")

Training configuration:
  Device: mps
  Steps: 5,000
  Batch size: 16
  Learning rate: 0.001
  Weight decay: 0.01


## Create Trainer

In [43]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[EmbeddingSnapshotCallback(SNAPSHOT_DIR)],
)

print("✓ Trainer initialized with snapshot callback")

✓ Trainer initialized with snapshot callback


## Train!

This will take several minutes. The callback will save embedding snapshots after every step and print progress every 100 steps.

In [44]:
print(f"\n{'='*80}")
print(f"Starting training...")
print(f"{'='*80}\n")

trainer.train()

print(f"\n{'='*80}")
print(f"✓ Training complete!")
print(f"{'='*80}")


Starting training...





Step,Training Loss
100,3.4545
200,3.0939
300,3.0779
400,3.0836
500,3.0731
600,3.0816
700,3.0735
800,3.0737
900,3.0768
1000,3.0749


[Step   100] Saved snapshot | Centroid L2 norm: 8.187500
[Step   200] Saved snapshot | Centroid L2 norm: 8.187500
[Step   300] Saved snapshot | Centroid L2 norm: 8.187500
[Step   400] Saved snapshot | Centroid L2 norm: 8.187500
[Step   500] Saved snapshot | Centroid L2 norm: 8.187500
[Step   600] Saved snapshot | Centroid L2 norm: 8.187500
[Step   700] Saved snapshot | Centroid L2 norm: 8.187500
[Step   800] Saved snapshot | Centroid L2 norm: 8.187500
[Step   900] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1000] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1100] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1200] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1300] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1400] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1500] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1600] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1700] Saved snapshot | Centroid L2 norm: 8.187500
[Step  1800] Saved snapshot | C

## Verify Snapshots

In [45]:
snapshot_files = sorted(glob.glob(os.path.join(SNAPSHOT_DIR, "*.safetensors")))

print(f"\nSnapshot verification:")
print(f"  Total snapshots saved: {len(snapshot_files):,}")
print(f"  Expected: {NUM_TRAIN_STEPS + 1:,} (including step 0)")

if len(snapshot_files) == NUM_TRAIN_STEPS + 1:
    print(f"  ✓ All snapshots saved successfully")
else:
    print(f"  ⚠ Snapshot count mismatch!")

# Show first and last few filenames
print(f"\nFirst 5 snapshots:")
for f in snapshot_files[:5]:
    print(f"  {os.path.basename(f)}")

print(f"\nLast 5 snapshots:")
for f in snapshot_files[-5:]:
    print(f"  {os.path.basename(f)}")


Snapshot verification:
  Total snapshots saved: 5,001
  Expected: 5,001 (including step 0)
  ✓ All snapshots saved successfully

First 5 snapshots:
  step_0000000.safetensors
  step_0000001.safetensors
  step_0000002.safetensors
  step_0000003.safetensors
  step_0000004.safetensors

Last 5 snapshots:
  step_0004996.safetensors
  step_0004997.safetensors
  step_0004998.safetensors
  step_0004999.safetensors
  step_0005000.safetensors


## Analyze Final Embeddings

In [46]:
from safetensors.torch import load_file

# Load final embeddings
final_embeddings = load_file(snapshot_files[-1])['embeddings']

final_norms = torch.norm(final_embeddings, p=2, dim=1)
final_centroid = final_embeddings.mean(dim=0)
final_centroid_norm = final_centroid.norm().item()

print(f"\nFinal embeddings (step {NUM_TRAIN_STEPS}):")
print(f"  Token L2 norms: min={final_norms.min().item():.6f}, max={final_norms.max().item():.6f}, mean={final_norms.mean().item():.6f}")
print(f"  Centroid L2 norm: {final_centroid_norm:.6f}")

# Compare to initial
centroid_displacement = (final_centroid - initial_centroid).norm().item()
print(f"\nCentroid displacement: {centroid_displacement:.6f}")
print(f"  (Distance between initial and final centroid locations)")


Final embeddings (step 5000):
  Token L2 norms: min=8.187500, max=8.250000, mean=8.187500
  Centroid L2 norm: 8.187500

Centroid displacement: 0.365234
  (Distance between initial and final centroid locations)


## Summary

In [47]:
print(f"\n{'='*80}")
print(f"TRAINING COMPLETE")
print(f"{'='*80}")
print(f"Model: GPT2 ({N_LAYER} layers, {N_HEAD} heads, {HIDDEN_DIM}-dim)")
print(f"Vocabulary: {VOCAB_SIZE} tokens")
print(f"Dead tokens: {dead_tokens} ({100 * dead_tokens / VOCAB_SIZE:.1f}%)")
print(f"Initialization: {INIT_MODE}")
print(f"Training steps: {NUM_TRAIN_STEPS:,}")
print(f"Snapshots saved: {len(snapshot_files):,}")
print(f"\nOutput directory: {SNAPSHOT_DIR}")
print(f"\nInitial centroid norm: {initial_centroid_norm:.6f}")
print(f"Final centroid norm: {final_centroid_norm:.6f}")
print(f"Centroid displacement: {centroid_displacement:.6f}")
print(f"{'='*80}")


TRAINING COMPLETE
Model: GPT2 (2 layers, 2 heads, 64-dim)
Vocabulary: 128 tokens
Dead tokens: 51 (39.8%)
Initialization: qwen
Training steps: 5,000
Snapshots saved: 5,001

Output directory: ../data/embeddings_128vocab_qweninit

Initial centroid norm: 8.187500
Final centroid norm: 8.187500
Centroid displacement: 0.365234
