# 1.20h: Flannel 7 - Single Deterministic Run with Gradients

**Purpose:** Create the canonical reference trajectory for Flannel analysis.

## What This Records

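A single training run (seed 42 for both initialization and training) of 500 steps, recording:
- **W**: Embedding matrix at every step, including the initial state at step 0 (501 snapshots)
- **∇W**: Gradient of the loss with respect to W at every step, captured after the backward pass and before the optimizer update (the force field driving token motion); the step-0 slot is all zeros
- **Losses**: Training loss at every step; the step-0 slot is NaN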

This gives us both the trajectory (W) and the forces (∇W) that cause it.
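
A minimal sketch of how the two recordings can be related, assuming arrays laid out as `(steps+1, vocab, dim)` with index 0 holding the initial state and a zero gradient. The toy tensors and the helper name `step_alignment` are hypothetical and not part of this notebook. It computes the cosine similarity between each token's displacement `W[t] - W[t-1]` and the negative gradient recorded at step `t`. The toy data uses plain SGD, where the alignment is 1 by construction; under Adam the update is a rescaled running average of gradients, so lower values should be expected.

```python
import torch

def step_alignment(W, grad_W, eps=1e-12):
    """Cosine similarity between per-token displacement and -gradient.

    W, grad_W: (steps+1, vocab, dim); index 0 is the initial state.
    Returns a (steps, vocab) tensor covering steps 1..T.
    """
    disp = (W[1:] - W[:-1]).float()
    force = -grad_W[1:].float()
    dot = (disp * force).sum(-1)
    norm = disp.norm(dim=-1) * force.norm(dim=-1)
    return dot / norm.clamp_min(eps)

# Hypothetical toy data: plain SGD, so displacement is exactly -lr * grad.
torch.manual_seed(0)
steps, vocab, dim, lr = 5, 8, 4, 0.1
grad_W = torch.randn(steps + 1, vocab, dim)
grad_W[0] = 0  # no gradient at t=0
W = torch.zeros(steps + 1, vocab, dim)
W[0] = torch.randn(vocab, dim)
for t in range(1, steps + 1):
    W[t] = W[t - 1] - lr * grad_W[t]

alignment = step_alignment(W, grad_W)
print(alignment.shape)  # torch.Size([5, 8])
```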

## Parameters

In [1]:
# === BATCH EXPERIMENT CONFIG ===
NUM_RUNS = 1           # Single canonical run
INIT_SEED = 42         # Seed for creating initial W
BASE_TRAIN_SEED = 42   # Training seed

# === RECORDING CONFIG ===
RECORD_CONFIG = {
    'W': True,
    'grads': True,      # NEW: Record gradients
    'momentum': False,
    'variance': False,
    'logits': False,
    'losses': True,
}

# === MODEL ARCHITECTURE ===
VOCAB_SIZE = 10000
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# === TRAINING CONFIG ===
BATCH_SIZE = 32
NUM_TRAIN_STEPS = 500  # NEW: Reduced from 1000 to 500
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0

# Optimizer: Adam
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8

# Initialization
INIT_SCALE = 0.02  # Standard deviation of the normal init: N(0, 0.02^2)

# === DATA PATHS ===
TOKENIZER_PATH = "../data/flannel_tokenizer_chars.json"
CORPUS_PATH = "../data/flannel_model_corpus.txt"
TOKEN_MASK_PATH = "../tensors/Flannel/live_dead_tokens.safetensors"
OUTPUT_DIR = "../tensors/Flannel"
OUTPUT_FILE = "1.20h_flannel_7.safetensors"

print("✓ Parameters set")

✓ Parameters set


## Imports

In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file, load_file
import time

print("✓ Imports complete")

✓ Imports complete


## Memory & Disk Requirements

In [3]:
print(f"\n{'='*80}")
print(f"MEMORY & DISK REQUIREMENTS")
print(f"{'='*80}\n")

bytes_per_element = 2  # bfloat16

# Calculate recording size
tensor_sizes = {}
if RECORD_CONFIG['W']:
    tensor_sizes['W'] = NUM_RUNS * (NUM_TRAIN_STEPS+1) * VOCAB_SIZE * HIDDEN_DIM * bytes_per_element
if RECORD_CONFIG['grads']:
    tensor_sizes['grads'] = NUM_RUNS * (NUM_TRAIN_STEPS+1) * VOCAB_SIZE * HIDDEN_DIM * bytes_per_element
if RECORD_CONFIG['losses']:
    tensor_sizes['losses'] = NUM_RUNS * (NUM_TRAIN_STEPS+1) * bytes_per_element

total_recorded = sum(tensor_sizes.values())

# Model memory
embedding_params = VOCAB_SIZE * HIDDEN_DIM
params_per_layer = 12 * HIDDEN_DIM**2
transformer_params = N_LAYER * params_per_layer
total_model_params = embedding_params + transformer_params
model_memory = total_model_params * bytes_per_element
optimizer_memory = 2 * total_model_params * 4

peak_ram = total_recorded + model_memory + optimizer_memory

print(f"Experiment: Flannel 7 - Single canonical run")
print(f"  Initialization seed: {INIT_SEED}")
print(f"  Training seed:       {BASE_TRAIN_SEED}")
print(f"  Steps:               {NUM_TRAIN_STEPS:,}")
print()
print(f"Recording: {', '.join([k for k, v in RECORD_CONFIG.items() if v])}")
for name, size in tensor_sizes.items():
    print(f"  {name:12} {size/1e9:.2f} GB")
print(f"  {'Total data:':12} {total_recorded/1e9:.2f} GB")
print()
print(f"Model parameters: {total_model_params:,}")
print(f"  Model (bf16):     {model_memory/1e9:.2f} GB")
print(f"  Optimizer (fp32): {optimizer_memory/1e9:.2f} GB")
print()
print(f"{'─'*80}")
print(f"PEAK RAM:     {peak_ram/1e9:.2f} GB")
print(f"DISK NEEDED:  {total_recorded/1e9:.2f} GB")
print(f"{'─'*80}")

if peak_ram <= 24e9:
    print(f"\n✓ Resources within budget\n")
else:
    print(f"\n⚠️  WARNING: Exceeds 24 GB RAM budget!\n")

print(f"{'='*80}\n")


MEMORY & DISK REQUIREMENTS

Experiment: Flannel 7 - Single canonical run
  Initialization seed: 42
  Training seed:       42
  Steps:               500

Recording: W, grads, losses
  W            0.64 GB
  grads        0.64 GB
  losses       0.00 GB
  Total data:  1.28 GB

Model parameters: 738,304
  Model (bf16):     0.00 GB
  Optimizer (fp32): 0.01 GB

────────────────────────────────────────────────────────────────────────────────
PEAK RAM:     1.29 GB
DISK NEEDED:  1.28 GB
────────────────────────────────────────────────────────────────────────────────

✓ Resources within budget
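
The figures above can be checked by hand. The short sketch below repeats the arithmetic with the values from the Parameters cell. Note that `12 * HIDDEN_DIM**2` per layer is the estimate used in this notebook, not an exact GPT-2 parameter count (it leaves out biases, layer norms, and position embeddings).

```python
steps, vocab, dim = 500, 10_000, 64
n_layer, bytes_bf16 = 2, 2

# One bfloat16 snapshot per step, plus the initial state at t=0
snapshot_bytes = (steps + 1) * vocab * dim * bytes_bf16
print(snapshot_bytes)  # 641280000 bytes, i.e. 0.64 GB each for W and grads

embedding_params = vocab * dim              # 640,000
transformer_params = n_layer * 12 * dim**2  # 2 * 12 * 4096 = 98,304
print(embedding_params + transformer_params)  # 738304
```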




## Device Detection

In [4]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Data

In [5]:
# Tokenizer
print(f"Loading tokenizer: {TOKENIZER_PATH}")
tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
print(f"  ✓ Vocabulary: {tokenizer.get_vocab_size():,} tokens\n")

# Corpus
print(f"Loading corpus: {CORPUS_PATH}")
with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    corpus_text = f.read()
encoding = tokenizer.encode(corpus_text)
tokens = encoding.ids
corpus_tensor = torch.tensor(tokens, dtype=torch.long, device=device)
print(f"  ✓ Tokens: {len(tokens):,}\n")

# Token masks
print(f"Loading token masks: {TOKEN_MASK_PATH}")
mask_data = load_file(TOKEN_MASK_PATH)
dead_indices = mask_data['dead_indices']
n_dead = mask_data['dead_mask'].sum().item()
n_live = mask_data['live_mask'].sum().item()
print(f"  ✓ Live: {n_live:,} | Dead: {n_dead:,}")

Loading tokenizer: ../data/flannel_tokenizer_chars.json
  ✓ Vocabulary: 10,000 tokens

Loading corpus: ../data/flannel_model_corpus.txt
  ✓ Tokens: 1,371,328

Loading token masks: ../tensors/Flannel/live_dead_tokens.safetensors
  ✓ Live: 6,301 | Dead: 3,699
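
A small sketch of how the masks could be used to split an embedding matrix into live and dead rows. It assumes `dead_mask` and `live_mask` are boolean vectors of length `VOCAB_SIZE` and `dead_indices` is an index vector, which matches how they are summed above but is not confirmed by this notebook. The tensors below are hypothetical stand-ins, not the contents of `live_dead_tokens.safetensors`.

```python
import torch

torch.manual_seed(0)
vocab, dim = 10, 4
W = torch.randn(vocab, dim)

# Hypothetical stand-ins for the tensors stored in the mask file
dead_mask = torch.zeros(vocab, dtype=torch.bool)
dead_mask[[2, 5, 7]] = True
live_mask = ~dead_mask
dead_indices = dead_mask.nonzero(as_tuple=True)[0]

W_live = W[live_mask]     # boolean-mask indexing keeps rows where the mask is True
W_dead = W[dead_indices]  # integer indexing selects the same rows as dead_mask
print(W_live.shape, W_dead.shape)  # torch.Size([7, 4]) torch.Size([3, 4])
```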


## Dataset

In [6]:
class TokenDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
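    def __getitem__(self, idx):
        # Window of max_seq_len + 1 tokens, split into inputs and next-token targets.
        # NOTE: GPT2LMHeadModel also shifts `labels` by one position internally,
        # so with these pre-shifted labels position i is scored against token i + 2.
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }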

dataset = TokenDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"\n✓ Dataset: {len(dataset):,} examples")


✓ Dataset: 1,371,200 examples
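
The example count follows from the sliding-window rule in `__len__`: one example per start offset that still leaves room for `max_seq_len + 1` tokens. A toy illustration with a hypothetical 10-token corpus:

```python
import torch

corpus = torch.arange(10)  # hypothetical 10-token corpus
max_seq_len = 4

n_examples = max(0, len(corpus) - max_seq_len)  # same rule as __len__
first = corpus[0 : 0 + max_seq_len + 1]
last = corpus[n_examples - 1 : n_examples - 1 + max_seq_len + 1]

print(n_examples)       # 6
print(first.tolist())   # [0, 1, 2, 3, 4]
print(last.tolist())    # [5, 6, 7, 8, 9]
print(1_371_328 - 128)  # 1371200, the count printed above
```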


## Create Initial Embedding Matrix

In [7]:
print(f"\nCreating initial embedding matrix (seed={INIT_SEED})...\n")

# Set seed for initialization
torch.manual_seed(INIT_SEED)
np.random.seed(INIT_SEED)

# Create initial W in float32, then convert to bfloat16
W_initial_f32 = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32) * INIT_SCALE
W_initial = W_initial_f32.to(torch.bfloat16)

print(f"  Shape: {tuple(W_initial.shape)}")
print(f"  Dtype: {W_initial.dtype}")
print(f"  Mean:  {W_initial.float().mean():.6f}")
print(f"  Std:   {W_initial.float().std():.6f}")
print(f"\n✓ Initial W created")


Creating initial embedding matrix (seed=42)...

  Shape: (10000, 64)
  Dtype: torch.bfloat16
  Mean:  -0.000041
  Std:   0.020019

✓ Initial W created
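
Casting to bfloat16 rounds every entry to 8 significant bits, so the stored `W_initial` differs slightly from the float32 draw. The sketch below measures that rounding error on a matrix generated the same way; it is an illustration and makes no claim to reproduce the exact values recorded in this run.

```python
import torch

torch.manual_seed(42)
w32 = torch.randn(10_000, 64, dtype=torch.float32) * 0.02
w16 = w32.to(torch.bfloat16)

# Round-to-nearest keeps the relative error below 2**-8 for normal numbers
rel_err = ((w16.float() - w32).abs() / w32.abs().clamp_min(1e-30)).max()
print(f"max relative rounding error: {rel_err:.4f}")
print(f"bound 2**-8 = {2**-8:.4f}")
```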


## Pre-allocate Recording Tensors

In [8]:
print("\nPre-allocating recording tensors...\n")

tensors = {}

if RECORD_CONFIG['W']:
    shape = (NUM_RUNS, NUM_TRAIN_STEPS+1, VOCAB_SIZE, HIDDEN_DIM)
    tensors['W'] = torch.zeros(shape, dtype=torch.bfloat16)
    print(f"  W:        {shape}")

if RECORD_CONFIG['grads']:
    shape = (NUM_RUNS, NUM_TRAIN_STEPS+1, VOCAB_SIZE, HIDDEN_DIM)
    tensors['grad_W'] = torch.zeros(shape, dtype=torch.bfloat16)
    print(f"  grad_W:   {shape}")

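if RECORD_CONFIG['losses']:
    shape = (NUM_RUNS, NUM_TRAIN_STEPS+1)
    # NaN marks slots that were never written (t=0 has no loss).
    # NOTE: bfloat16 keeps 8 significant bits, so a loss near 7 is stored in steps of about 0.03.
    tensors['losses'] = torch.full(shape, float('nan'), dtype=torch.bfloat16)
    print(f"  losses:   {shape}")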

print(f"\n✓ All tensors allocated on CPU")


Pre-allocating recording tensors...

  W:        (1, 501, 10000, 64)
  grad_W:   (1, 501, 10000, 64)
  losses:   (1, 501)

✓ All tensors allocated on CPU


## Batch Recorder

In [9]:
class BatchRecorder:
    """Records data directly into pre-allocated tensors."""
    
    def __init__(self, tensors, record_config, run_idx):
        self.tensors = tensors
        self.config = record_config
        self.run_idx = run_idx
        self.current_step = 0
        self.recorded_initial = False
        self.loss_value = None
    
    def record_initial_state(self, model, optimizer):
        """Record step 0."""
        if not self.recorded_initial:
            t = 0
            if self.config['W']:
                self.tensors['W'][self.run_idx, t] = model.transformer.wte.weight.data.clone().cpu().bfloat16()
            if self.config['grads']:
                # No gradient at t=0
                self.tensors['grad_W'][self.run_idx, t] = 0
            self.recorded_initial = True
            self.current_step = 1
            print(f"    ✓ Recorded initial state (t=0)")
    
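    def record_before_step(self, model, loss, logits):
        """Capture data after backward, before optimizer step.

        Called right after Trainer.training_step, so the gradient is the raw
        one: any clipping via max_grad_norm has not been applied yet.
        `logits` is accepted but not recorded.
        """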
        if self.config['losses']:
            self.loss_value = loss.item()
        
        # Record gradient BEFORE optimizer step
        if self.config['grads']:
            t = self.current_step
            if t <= self.tensors['grad_W'].shape[1] - 1:
                if model.transformer.wte.weight.grad is not None:
                    self.tensors['grad_W'][self.run_idx, t] = model.transformer.wte.weight.grad.clone().cpu().bfloat16()
    
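    def record_after_step(self, model, optimizer):
        """Record data after optimizer step."""
        t = self.current_step
        
        # Stop once every pre-allocated slot (t = 0 .. NUM_TRAIN_STEPS) is filled.
        # Uses whichever tensors exist, so this also works when W is not recorded.
        last_slot = min((v.shape[1] - 1 for v in self.tensors.values()), default=-1)
        if t > last_slot:
            return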
        
        if self.config['W']:
            self.tensors['W'][self.run_idx, t] = model.transformer.wte.weight.data.clone().cpu().bfloat16()
        
        if self.config['losses'] and self.loss_value is not None:
            self.tensors['losses'][self.run_idx, t] = self.loss_value
            self.loss_value = None
        
        if t % 100 == 0:
            print(f"    Step {t}")
        
        self.current_step += 1

print("✓ Recorder class defined")

✓ Recorder class defined
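
The recording order used by `BatchRecorder` can be sketched in a plain PyTorch loop: slot 0 holds the initial weights and a zero gradient, and for each later step `t` the gradient is stored after `backward()` and the weights after `optimizer.step()`. The toy embedding model and quadratic loss below are hypothetical stand-ins for the GPT-2 model and its language-modeling loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab, dim, steps = 6, 3, 4
emb = nn.Embedding(vocab, dim)
opt = torch.optim.Adam(emb.parameters(), lr=1e-3)

W = torch.zeros(steps + 1, vocab, dim)
grad_W = torch.zeros(steps + 1, vocab, dim)
W[0] = emb.weight.detach().clone()       # t=0: initial state, zero gradient

for t in range(1, steps + 1):
    ids = torch.randint(0, vocab, (8,))
    loss = emb(ids).pow(2).sum()         # toy loss, stands in for the LM loss
    opt.zero_grad()
    loss.backward()
    grad_W[t] = emb.weight.grad.clone()  # after backward, before the update
    opt.step()
    W[t] = emb.weight.detach().clone()   # after the update

print(W.shape, grad_W.shape)
```

With this convention, `grad_W[t]` is the gradient evaluated at `W[t-1]`, and it is the one that contributes to the move from `W[t-1]` to `W[t]`.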


## Instrumented Trainer

In [10]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder
        self.last_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        outputs = model(**inputs)
        loss = outputs.loss
        self.last_logits = outputs.logits
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        loss = super().training_step(model, inputs, num_items_in_batch)
        self.recorder.record_before_step(model, loss, self.last_logits)
        return loss

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
        self.recorder.record_after_step(model, self.optimizer)
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined
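
`InstrumentedTrainer` hooks into `_maybe_log_save_evaluate`, a private `Trainer` method whose signature can change between releases. As an alternative sketch, PyTorch optimizers (version 2.0 or later is assumed) expose `register_step_pre_hook` and `register_step_post_hook`, which fire immediately before and after each update. The example is plain PyTorch with a hypothetical toy model and is not wired into `Trainer`. One difference to keep in mind: a pre-step hook sees the gradient as the optimizer receives it, which in `Trainer` would be after any clipping, whereas this notebook records the gradient before clipping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(6, 3)
opt = torch.optim.Adam(emb.parameters(), lr=1e-3)
grads, weights = [], []

def before_step(optimizer, args, kwargs):
    # Runs inside optimizer.step(), before parameters change.
    grads.append(emb.weight.grad.detach().clone())

def after_step(optimizer, args, kwargs):
    # Runs right after parameters have been updated.
    weights.append(emb.weight.detach().clone())

opt.register_step_pre_hook(before_step)
opt.register_step_post_hook(after_step)

for _ in range(3):
    loss = emb(torch.randint(0, 6, (8,))).pow(2).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()

print(len(grads), len(weights))  # 3 3
```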


## Training Loop

In [11]:
print(f"\n{'='*80}")
print(f"FLANNEL 7: CANONICAL TRAINING RUN")
print(f"{'='*80}\n")

print(f"Configuration:")
print(f"  Steps:               {NUM_TRAIN_STEPS:,}")
print(f"  Initialization seed: {INIT_SEED}")
print(f"  Training seed:       {BASE_TRAIN_SEED}")
print(f"  Recording:           {', '.join([k for k, v in RECORD_CONFIG.items() if v])}")
print(f"\n{'='*80}\n")

experiment_start = time.time()

run_idx = 0
train_seed = BASE_TRAIN_SEED

# Set training seed
torch.manual_seed(train_seed)
np.random.seed(train_seed)

# Create model
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,
)

model = GPT2LMHeadModel(config).to(torch.bfloat16).to(device)

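# Overwrite the token embedding (tied to the output head) with W_initial.
# All other weights keep the default init drawn after torch.manual_seed(train_seed).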
with torch.no_grad():
    model.transformer.wte.weight[:] = W_initial.to(device)

print(f"  ✓ Model initialized (seed={INIT_SEED})")
print(f"  ✓ Training seed set to {train_seed}")

# Create recorder
recorder = BatchRecorder(tensors, RECORD_CONFIG, run_idx)

# Training args
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    adam_beta1=ADAM_BETA1,
    adam_beta2=ADAM_BETA2,
    adam_epsilon=ADAM_EPSILON,
    optim="adamw_torch",
    logging_steps=1000,
    save_steps=NUM_TRAIN_STEPS + 1,
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,
    seed=train_seed,
    report_to="none",
    disable_tqdm=True,
)

trainer = InstrumentedTrainer(
    recorder=recorder,
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# Record initial state (trainer.optimizer is still None here; the recorder ignores it)
recorder.record_initial_state(model, trainer.optimizer)

# Train
print(f"  Training...")
run_start = time.time()
trainer.train()
run_elapsed = time.time() - run_start

print(f"\n  ✓ Training complete ({run_elapsed:.1f}s)")

experiment_elapsed = time.time() - experiment_start

print(f"\n{'='*80}")
print(f"✓ Flannel 7 complete")
print(f"  Total time: {experiment_elapsed:.1f}s ({experiment_elapsed/60:.1f} minutes)")
print(f"{'='*80}")


FLANNEL 7: CANONICAL TRAINING RUN

Configuration:
  Steps:               500
  Initialization seed: 42
  Training seed:       42
  Recording:           W, grads, losses


  ✓ Model initialized (seed=42)
  ✓ Training seed set to 42
    ✓ Recorded initial state (t=0)
  Training...


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


    Step 100
    Step 200
    Step 300
    Step 400
    Step 500
{'train_runtime': 13.4044, 'train_samples_per_second': 1193.641, 'train_steps_per_second': 37.301, 'train_loss': 7.2028173828125, 'epoch': 0.011668611435239206}

  ✓ Training complete (13.5s)

✓ Flannel 7 complete
  Total time: 13.5s (0.2 minutes)
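
The summary line from the trainer can be cross-checked against the configuration. The sketch below recomputes the epoch fraction and throughput from values printed earlier, and adds the loss of a uniform guess over the vocabulary as a reference point for the reported `train_loss`.

```python
import math

steps, batch_size = 500, 32
n_examples = 1_371_200  # dataset size printed above
runtime_s = 13.4044     # 'train_runtime' from the log

samples_seen = steps * batch_size
print(samples_seen)               # 16000
print(samples_seen / n_examples)  # ~0.01167 epochs, matching 'epoch'
print(samples_seen / runtime_s)   # ~1193.6 samples/s
print(math.log(10_000))           # ~9.21, loss of a uniform guess over 10,000 tokens
```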


## Save Data

In [12]:
print(f"\nSaving data...\n")

# Build save dictionary
save_dict = {
    # Metadata
    'n_runs': torch.tensor(NUM_RUNS, dtype=torch.long),
    'init_seed': torch.tensor(INIT_SEED, dtype=torch.long),
    'train_seed': torch.tensor(BASE_TRAIN_SEED, dtype=torch.long),
    'n_steps': torch.tensor(NUM_TRAIN_STEPS, dtype=torch.long),
    'n_live': torch.tensor(n_live, dtype=torch.long),
    'n_dead': torch.tensor(n_dead, dtype=torch.long),
    'vocab_size': torch.tensor(VOCAB_SIZE, dtype=torch.long),
    'hidden_dim': torch.tensor(HIDDEN_DIM, dtype=torch.long),
    'init_scale': torch.tensor(INIT_SCALE, dtype=torch.float32),
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
    'adam_beta1': torch.tensor(ADAM_BETA1, dtype=torch.float32),
    'adam_beta2': torch.tensor(ADAM_BETA2, dtype=torch.float32),
    # Initial W
    'W_initial': W_initial,
    # Record config
    'recorded_W': torch.tensor(RECORD_CONFIG['W'], dtype=torch.bool),
    'recorded_grads': torch.tensor(RECORD_CONFIG['grads'], dtype=torch.bool),
    'recorded_losses': torch.tensor(RECORD_CONFIG['losses'], dtype=torch.bool),
}

# Add recorded tensors
for name, tensor in tensors.items():
    save_dict[name] = tensor
    print(f"  {name:12} {str(tuple(tensor.shape)):30}")

# Save
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
output_path = Path(OUTPUT_DIR) / OUTPUT_FILE

print(f"\nSaving to: {output_path}\n")

save_start = time.time()
save_file(save_dict, str(output_path))
save_elapsed = time.time() - save_start

file_size_mb = output_path.stat().st_size / 1e6
file_size_gb = file_size_mb / 1000

print(f"✓ Saved successfully")
print(f"  File: {output_path.name}")
print(f"  Size: {file_size_mb:.1f} MB ({file_size_gb:.2f} GB)")
print(f"  Save time: {save_elapsed:.1f}s")


Saving data...

  W            (1, 501, 10000, 64)           
  grad_W       (1, 501, 10000, 64)           
  losses       (1, 501)                      

Saving to: ../tensors/Flannel/1.20h_flannel_7.safetensors

✓ Saved successfully
  File: 1.20h_flannel_7.safetensors
  Size: 1283.8 MB (1.28 GB)
  Save time: 0.3s


## Summary

In [13]:
print(f"\n{'='*80}")
print(f"FLANNEL 7 COMPLETE")
print(f"{'='*80}\n")

print(f"Single deterministic training run:")
print(f"  Seed:       {INIT_SEED} (both init and training)")
print(f"  Steps:      {NUM_TRAIN_STEPS:,}")
print(f"  Recorded:   {', '.join([k for k, v in RECORD_CONFIG.items() if v])}")
print()
print(f"Data saved: {output_path}")
print(f"  Size: {file_size_gb:.2f} GB")
print(f"  Training time: {experiment_elapsed/60:.1f} minutes")
print()
print(f"Dead tokens: {n_dead:,} / {VOCAB_SIZE:,} ({100*n_dead/VOCAB_SIZE:.1f}%)")
print()
print(f"This is now the canonical Flannel reference trajectory.")
print(f"Use this for all single-run analysis going forward.")

print(f"\n{'='*80}")


FLANNEL 7 COMPLETE

Single deterministic training run:
  Seed:       42 (both init and training)
  Steps:      500
  Recorded:   W, grads, losses

Data saved: ../tensors/Flannel/1.20h_flannel_7.safetensors
  Size: 1.28 GB
  Training time: 0.2 minutes

Dead tokens: 3,699 / 10,000 (37.0%)

This is now the canonical Flannel reference trajectory.
Use this for all single-run analysis going forward.

