# 08.2a_h100: GPU-Optimized Embedding Evolution Training

**H100-optimized version with maximum throughput**

This notebook is optimized for high-end GPUs (H100, A100, etc.) with aggressive batching, pre-loaded data, and minimal I/O overhead.

## Optimizations Applied

1. **Pre-load corpus to GPU memory**: Entire dataset lives on GPU, zero CPU→GPU transfer per batch
2. **Large batch sizes**: 4096-8192 to saturate GPU compute
3. **Longer sequences**: 512 tokens (or more) for better GPU utilization
4. **Periodic checkpointing**: Save every 100 steps instead of every step
5. **Multi-worker data loading**: Parallel batch preparation
6. **TF32 + bfloat16**: Maximum throughput on Ampere/Hopper GPUs
7. **Gradient accumulation**: Optional for even larger effective batch sizes

## Expected Performance

- **H100**: 2000-5000 it/s
- **A100**: 1000-3000 it/s
- **Consumer GPUs (4090, etc.)**: 500-1500 it/s

## Parameters

In [None]:
# Model architecture
VOCAB_SIZE = 128      # Token ids are raw byte values; 128 covers 7-bit ASCII
HIDDEN_DIM = 64       # Embedding dimension (GPT2Config n_embd)
N_LAYER = 2           # Number of transformer layers
N_HEAD = 2            # Number of attention heads
# NOTE(review): HIDDEN_DIM presumably must be divisible by N_HEAD — confirm
# against the GPT-2 attention implementation before changing either value.
MAX_SEQ_LEN = 512     # Context window (512 for GPU efficiency)

# Initialization
INIT_MODE = "qwen"  # "qwen" = every token starts at one shared unit vector; anything else = library default

# Training (GPU-optimized)
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01
BATCH_SIZE = 4096              # Per-device batch size (large, for GPU saturation)
GRADIENT_ACCUMULATION = 1      # Effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION
# One embedding snapshot is kept per step, so the history tensor holds
# (NUM_TRAIN_STEPS + 1) * VOCAB_SIZE * HIDDEN_DIM elements in host memory.
NUM_TRAIN_STEPS = 50000        # Longer training to see more dynamics

# Checkpointing
SAVE_EVERY_N_STEPS = 100       # Write the embedding history to disk every N steps

# Data loading
# NOTE(review): the training dataset is built on a CUDA tensor. PyTorch
# advises against returning CUDA tensors from multi-process data loading —
# confirm these two values are usable with that dataset.
NUM_WORKERS = 8                # DataLoader worker processes
PREFETCH_FACTOR = 4            # Batches to pre-fetch per worker

# Data
CORPUS_PATH = "../data/training_corpus.txt"
OUTPUT_DIR = f"../data/embeddings_{VOCAB_SIZE}vocab_{INIT_MODE}init_h100"
OUTPUT_FILE = "embedding_evolution.safetensors"  # plain literal: nothing to interpolate

RANDOM_SEED = 42

## Imports

In [None]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, TrainerCallback
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file
import time

# Turn on TF32 for both CUDA matmuls and cuDNN (Ampere/Hopper throughput)
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Seed the NumPy and PyTorch RNGs from the shared RANDOM_SEED
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

## Detect Hardware

In [None]:
# Fail fast: this notebook refuses to run without a CUDA device.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires CUDA. Use 08.2a for CPU/MPS training.")

DEVICE = "cuda"

# Report what device 0 is and whether TF32 matmuls are active.
print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"CUDA Capability: {torch.cuda.get_device_capability(0)}")
print(f"TF32 enabled: {torch.backends.cuda.matmul.allow_tf32}")

## Load and Pre-Process Corpus

In [None]:
print(f"Loading corpus from: {CORPUS_PATH}\n")

# Strict ASCII read: any non-ASCII byte in the file raises UnicodeDecodeError.
with open(CORPUS_PATH, 'r', encoding='ascii') as f:
    corpus_text = f.read()

# One token id per byte, dropping any id that does not fit in the vocabulary
corpus_bytes = [
    byte_value
    for byte_value in corpus_text.encode('ascii')
    if byte_value < VOCAB_SIZE
]

print(
    "✓ Corpus loaded",
    f"Total bytes: {len(corpus_bytes):,}",
    f"Vocabulary size: {VOCAB_SIZE}",
    sep="\n",
)

# Vocabulary ids that never occur in the corpus are reported as "dead" tokens
unique_bytes = set(corpus_bytes)
dead_tokens = VOCAB_SIZE - len(unique_bytes)

print(
    f"Unique bytes in corpus: {len(unique_bytes)}",
    f"Dead tokens (never appear): {dead_tokens} ({100 * dead_tokens / VOCAB_SIZE:.1f}%)",
    sep="\n",
)

# Materialize the whole corpus once as an int64 tensor on DEVICE
corpus_tensor = torch.tensor(corpus_bytes, dtype=torch.long, device=DEVICE)
print(
    "\n✓ Corpus loaded to GPU memory",
    f"  Memory usage: {corpus_tensor.element_size() * corpus_tensor.numel() / 1e6:.2f} MB",
    sep="\n",
)

## GPU-Optimized Dataset

Samples directly from GPU tensor—zero CPU→GPU transfer during training.

In [None]:
class GPUByteDataset(Dataset):
    """Map-style dataset of overlapping windows over a pre-loaded token tensor.

    Example ``idx`` is the window ``corpus[idx : idx + max_seq_len]``. Items
    are slices (views) of ``corpus``, so they live on whatever device
    ``corpus`` lives on and nothing is copied here.

    ``labels`` is the SAME window as ``input_ids`` — deliberately not shifted.
    Hugging Face causal-LM heads such as ``GPT2LMHeadModel`` shift the labels
    by one position inside ``forward``; shifting them here as well would make
    the model learn to predict the token two positions ahead.

    Args:
        corpus_tensor: 1-D tensor of token ids.
        max_seq_len: Number of tokens in every returned window.
    """

    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len

    def __len__(self):
        # Unchanged example count: one start position per full window, minus
        # the final one. Clamped so a too-short corpus yields an empty dataset.
        return max(0, len(self.corpus) - self.max_seq_len)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'labels'}`` for the window starting at ``idx``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Without this check, slicing would silently hand back a short
                or empty window.
        """
        num_examples = len(self)
        if idx < 0:
            idx = idx + num_examples
        if not 0 <= idx < num_examples:
            raise IndexError(
                f"index out of range for dataset with {num_examples} examples"
            )

        input_ids = self.corpus[idx : idx + self.max_seq_len]

        return {
            'input_ids': input_ids,
            # Same tokens on purpose: the model performs the one-position
            # shift itself when computing the loss.
            'labels': input_ids,
        }

# Wrap the pre-loaded corpus tensor and report how much data one pass covers
dataset = GPUByteDataset(corpus_tensor, max_seq_len=MAX_SEQ_LEN)
print(
    "✓ GPU dataset created",
    f"Training examples: {len(dataset):,}",
    f"Tokens per example: {MAX_SEQ_LEN}",
    f"Total tokens per epoch: {len(dataset) * MAX_SEQ_LEN:,}",
    sep="\n",
)

## Create Model Configuration

In [None]:
config = GPT2Config(
    # Size of the model
    vocab_size=VOCAB_SIZE,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    n_positions=MAX_SEQ_LEN,
    # All three dropout rates are switched off
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    resid_pdrop=0.0,
    # Input embedding and output projection share one weight matrix
    tie_word_embeddings=True,
)

# Echo the values as stored on the config object
print(
    "Model configuration:",
    f"  Vocabulary: {config.vocab_size} tokens",
    f"  Hidden dimension: {config.n_embd}",
    f"  Layers: {config.n_layer}",
    f"  Attention heads: {config.n_head}",
    f"  Context window: {config.n_positions} tokens",
    f"  Weight tying: {config.tie_word_embeddings}",
    sep="\n",
)

## Initialize Model

In [None]:
# Build the model, cast its weights to bfloat16, then move it to DEVICE
model = GPT2LMHeadModel(config).to(dtype=torch.bfloat16).to(device=DEVICE)

# Parameter counts: whole model vs. the token-embedding matrix alone
total_params = sum(param.numel() for param in model.parameters())
embedding_params = model.transformer.wte.weight.numel()

print(
    f"\n✓ Model initialized (bfloat16, {DEVICE})",
    f"Total parameters: {total_params:,}",
    f"Embedding parameters: {embedding_params:,} ({100 * embedding_params / total_params:.1f}% of total)",
    f"Model memory: {sum(param.element_size() * param.numel() for param in model.parameters()) / 1e6:.2f} MB",
    sep="\n",
)

## Apply Custom Initialization

In [None]:
if INIT_MODE != "qwen":
    # Any other mode keeps the weights the model constructor produced;
    # only their per-token norms are reported.
    print(f"\nUsing normal initialization (default PyTorch)")

    with torch.no_grad():
        norms = model.transformer.wte.weight.norm(p=2, dim=1)
        print(f"  Token L2 norms: min={norms.min().item():.6f}, max={norms.max().item():.6f}, mean={norms.mean().item():.6f}")

else:
    print(f"\nApplying Qwen-style initialization (singular unit vector)...")

    with torch.no_grad():
        # Draw one Gaussian vector and scale it to unit length
        random_vector = torch.randn(HIDDEN_DIM, device=DEVICE)
        random_vector = random_vector / random_vector.norm()

        # Broadcast that single vector into every row of the embedding matrix
        model.transformer.wte.weight.copy_(random_vector)

    # Statistics below describe the vector as drawn, before it was written
    # into the (lower-precision) embedding matrix.
    print(
        f"✓ All {VOCAB_SIZE} tokens initialized to same random unit vector",
        f"  Vector L2 norm: {random_vector.norm().item():.6f}",
        f"  Vector mean component: {random_vector.mean().item():.6e}",
        sep="\n",
    )

## Pre-allocate Embedding History Tensor

In [None]:
# History is kept in the same dtype as the live embedding matrix
embedding_dtype = model.transformer.wte.weight.dtype

# One (VOCAB_SIZE, HIDDEN_DIM) slot per training step plus the initial state,
# allocated up front in host memory
embedding_history = torch.zeros(
    NUM_TRAIN_STEPS + 1,
    VOCAB_SIZE,
    HIDDEN_DIM,
    dtype=embedding_dtype,
)

# Slot 0 holds the embeddings as they are before any training step
embedding_history[0] = model.transformer.wte.weight.detach().cpu()

# Baseline statistics: per-token norms and the norm of the mean embedding
initial_norms = embedding_history[0].norm(p=2, dim=1)
initial_centroid = embedding_history[0].mean(dim=0)
initial_centroid_norm = initial_centroid.norm().item()

print(
    "✓ Pre-allocated embedding history tensor",
    f"  Shape: {embedding_history.shape}",
    f"  Dtype: {embedding_history.dtype}",
    f"  Memory: {embedding_history.element_size() * embedding_history.numel() / 1e6:.1f} MB",
    "\nInitial embeddings (step 0):",
    f"  Token L2 norms: min={initial_norms.min().item():.6f}, max={initial_norms.max().item():.6f}, mean={initial_norms.mean().item():.6f}",
    f"  Centroid L2 norm: {initial_centroid_norm:.6f}",
    sep="\n",
)

## Optimized Snapshot Callback

Writes the embedding history to disk periodically (every N steps) instead of every step to minimize I/O overhead; the in-memory history is still updated after every step.

In [None]:
class OptimizedEmbeddingCallback(TrainerCallback):
    """Record the token-embedding matrix every step; write it to disk periodically.

    On every ``on_step_end`` the current ``model.transformer.wte.weight`` is
    copied into ``embedding_history[state.global_step]``. Every
    ``save_every_n`` steps, and on the final step, the filled prefix of the
    history is written to ``output_dir / output_file`` with safetensors,
    overwriting the previous file. A throughput / centroid line is printed
    every 100 steps.

    Args:
        embedding_history: Pre-allocated tensor indexed by step along dim 0;
            it must have a slot for every step that will be reached.
        output_dir: Directory for the snapshot file (created if missing).
        output_file: File name of the snapshot inside ``output_dir``.
        save_every_n: Disk-write period in steps; must be >= 1.

    Raises:
        ValueError: If ``save_every_n`` is smaller than 1 (it is used as a
            modulus, so 0 would otherwise fail on the first training step).
    """

    def __init__(self, embedding_history, output_dir, output_file, save_every_n):
        if save_every_n < 1:
            raise ValueError(f"save_every_n must be >= 1, got {save_every_n}")

        self.embedding_history = embedding_history
        self.output_dir = output_dir
        self.output_file = output_file
        self.output_path = Path(output_dir) / output_file
        self.save_every_n = save_every_n
        self.last_print_time = time.time()
        self.steps_since_print = 0

        Path(output_dir).mkdir(parents=True, exist_ok=True)

    def on_train_begin(self, args, state, control, **kwargs):
        # Restart the throughput window when training starts, so time spent
        # between constructing the callback and the first step is not
        # charged to the first report.
        self.last_print_time = time.time()
        self.steps_since_print = 0
        return control

    def on_step_end(self, args, state, control, model=None, **kwargs):
        step = state.global_step

        # Record every step in host memory (one small device-to-host copy)
        embeddings = model.transformer.wte.weight.data.clone().cpu()
        self.embedding_history[step] = embeddings

        # Write to disk periodically, and always on the last step
        should_save = (step % self.save_every_n == 0) or (step == args.max_steps)

        if should_save:
            # Only the filled prefix is written; later slots are still zeros
            save_file(
                {'embedding_history': self.embedding_history[:step+1]},
                self.output_path
            )

        # Print throughput every 100 steps
        self.steps_since_print += 1
        if step % 100 == 0 and step > 0:
            elapsed = time.time() - self.last_print_time
            # Guard against a zero-length window on coarse clocks
            throughput = self.steps_since_print / elapsed if elapsed > 0 else float("inf")

            centroid = embeddings.mean(dim=0)
            centroid_norm = centroid.norm().item()

            saved_marker = "[SAVED]" if should_save else ""
            print(f"[Step {step:6d}] {throughput:6.1f} it/s | Centroid: {centroid_norm:.6f} {saved_marker}")

            self.last_print_time = time.time()
            self.steps_since_print = 0

        return control

print(f"✓ Optimized callback defined (save every {SAVE_EVERY_N_STEPS} steps)")

## Configure Training

In [None]:
# The training dataset hands out slices of `corpus_tensor`. When that tensor
# lives on the GPU, the DataLoader must stay in the main process:
#   * worker processes cannot safely touch a CUDA tensor inherited from the
#     parent (PyTorch advises against CUDA tensors in multi-process loading),
#   * pinned memory only applies to CPU tensors, so pin_memory must be off,
#   * prefetch_factor is only accepted when there is at least one worker.
# For a CPU-resident corpus the NUM_WORKERS / PREFETCH_FACTOR settings are
# used as given.
dataset_on_gpu = corpus_tensor.is_cuda
dataloader_workers = 0 if dataset_on_gpu else NUM_WORKERS
dataloader_prefetch = PREFETCH_FACTOR if dataloader_workers > 0 else None

training_args = TrainingArguments(
    output_dir="./training_output",
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    logging_steps=100,
    save_steps=NUM_TRAIN_STEPS + 1,  # Don't save model checkpoints
    save_total_limit=0,
    seed=RANDOM_SEED,

    # Data loading (see note above for why these are derived, not raw)
    dataloader_num_workers=dataloader_workers,
    dataloader_prefetch_factor=dataloader_prefetch,
    dataloader_pin_memory=not dataset_on_gpu,
    dataloader_persistent_workers=(dataloader_workers > 0),

    # Precision
    bf16=True,
    tf32=True,  # Enable TF32 on Ampere/Hopper

    # Disable unnecessary features
    report_to="none",
    disable_tqdm=False,
)

effective_batch_size = BATCH_SIZE * GRADIENT_ACCUMULATION

print("Training configuration:")
print(f"  Device: {DEVICE}")
print(f"  Steps: {training_args.max_steps:,}")
print(f"  Batch size: {BATCH_SIZE} × {GRADIENT_ACCUMULATION} = {effective_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Weight decay: {training_args.weight_decay}")
print(f"  Precision: bfloat16 + TF32")
print(f"  Data workers: {dataloader_workers}")
print(f"  Prefetch factor: {dataloader_prefetch}")
if dataloader_workers != NUM_WORKERS:
    print(f"  (NUM_WORKERS={NUM_WORKERS} ignored: dataset is GPU-resident, loading in main process)")

## Create Trainer

In [None]:
# Wire model, arguments, dataset and the snapshot callback together.
# The callback writes into the pre-allocated `embedding_history` tensor.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[
        OptimizedEmbeddingCallback(
            embedding_history=embedding_history,
            output_dir=OUTPUT_DIR,
            output_file=OUTPUT_FILE,
            save_every_n=SAVE_EVERY_N_STEPS,
        ),
    ],
)

print("✓ Trainer initialized")

## Train!

Should see 2000-5000 it/s on H100.

In [None]:
print("\n" + "=" * 80, "Starting training...", "=" * 80 + "\n", sep="\n")

# Wall-clock time around the whole Trainer.train() call
start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

# Average rate uses the configured step count, not a measured one
print(
    "\n" + "=" * 80,
    "✓ Training complete!",
    f"Total time: {elapsed:.1f}s ({elapsed / 60:.1f} min)",
    f"Average throughput: {NUM_TRAIN_STEPS / elapsed:.1f} it/s",
    "=" * 80,
    sep="\n",
)

## Final Save and Analysis

In [None]:
# Write the complete history tensor once more so the file on disk is final
output_path = Path(OUTPUT_DIR) / OUTPUT_FILE
save_file(dict(embedding_history=embedding_history), output_path)

# Statistics of the last recorded snapshot, compared with the initial centroid
final_embeddings = embedding_history[-1]
final_norms = final_embeddings.norm(p=2, dim=1)
final_centroid = final_embeddings.mean(dim=0)
final_centroid_norm = final_centroid.norm().item()
centroid_displacement = (final_centroid - initial_centroid).norm().item()

print(
    "\n" + "=" * 80,
    "TRAINING SUMMARY",
    "=" * 80,
    f"Model: GPT2 ({N_LAYER} layers, {N_HEAD} heads, {HIDDEN_DIM}-dim)",
    f"Vocabulary: {VOCAB_SIZE} tokens",
    f"Dead tokens: {dead_tokens} ({100 * dead_tokens / VOCAB_SIZE:.1f}%)",
    f"Initialization: {INIT_MODE}",
    f"Training steps: {NUM_TRAIN_STEPS:,}",
    f"Sequence length: {MAX_SEQ_LEN}",
    f"Batch size: {effective_batch_size}",
    "Precision: bfloat16",
    f"\nOutput file: {output_path}",
    f"  Shape: {embedding_history.shape}",
    f"  Size: {output_path.stat().st_size / 1e6:.1f} MB",
    f"\nInitial centroid norm: {initial_centroid_norm:.6f}",
    f"Final centroid norm: {final_centroid_norm:.6f}",
    f"Centroid displacement: {centroid_displacement:.6f}",
    "\nFinal token norms:",
    f"  Min: {final_norms.min().item():.6f}",
    f"  Max: {final_norms.max().item():.6f}",
    f"  Mean: {final_norms.mean().item():.6f}",
    f"  Std: {final_norms.std().item():.6f}",
    "=" * 80,
    sep="\n",
)