# 1.20c: Flannel 3 - Ten Big Bangs (Full Dataset)

**Experiment:** Exact replica of Flannel 1 (1.20a) run 10 times sequentially with different seeds.

## Motivation

Test reproducibility of the five epochs observed in Flannel 1:
1. **The Inhale** (t=0–6): Slight contraction
2. **The Sneeze** (t=6–100): Explosive outward expansion
3. **Deceleration** (t=100–300): Rapid slowing
4. **Re-expansion** (t=300–400): Linear second growth
5. **Fimbulwinter** (t=400+): Quantization freeze

## Design

**Exact copy of 1.20a, but:**
- Run 10 times sequentially
- Seeds: 42, 43, 44, ..., 51
- Full ComprehensiveRecorder for each run (embeddings, grads, momentum, variance, logits, losses)
- Save all runs in one file

**Expected output:**
- File size: ~12 GB (10 runs × 1.2 GB per run)
- Data structure: Separate tensors for each run, combined in one safetensors file
- Total runtime: ~5 minutes (10 runs × 30s each)

## Parameters

In [1]:
# Model architecture (identical to Flannel 1)
VOCAB_SIZE = 10000
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training (identical to Flannel 1)
BATCH_SIZE = 32
NUM_TRAIN_STEPS = 1000
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0

# Optimizer: Adam
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8

# Initialization
INIT_SCALE = 0.02  # N(0, 0.02)

# Batch experiment
NUM_RUNS = 10
BASE_SEED = 42  # Seeds: 42, 43, 44, ..., 51

# Data
TOKENIZER_PATH = "../data/flannel_tokenizer_chars.json"
CORPUS_PATH = "../data/flannel_model_corpus.txt"
TOKEN_MASK_PATH = "../tensors/Flannel/live_dead_tokens.safetensors"
OUTPUT_DIR = "../tensors/Flannel"
OUTPUT_FILE = "1.20c_flannel_3.safetensors"

# Recording
RECORD_EVERY_N_STEPS = 1  # Every step

print("✓ Parameters set")

✓ Parameters set


## Imports

In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file, load_file
import time

print("✓ Imports complete")

✓ Imports complete


## Device Detection

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Tokenizer

In [4]:
print(f"Loading Flannel tokenizer: {TOKENIZER_PATH}\n")

tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
vocab = tokenizer.get_vocab()

print(f"✓ Loaded Flannel tokenizer")
print(f"  Vocabulary size: {len(vocab):,} tokens")

Loading Flannel tokenizer: ../data/flannel_tokenizer_chars.json

✓ Loaded Flannel tokenizer
  Vocabulary size: 10,000 tokens


## Load Corpus and Tokenize

In [5]:
print(f"\nLoading corpus: {CORPUS_PATH}\n")

with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    corpus_text = f.read()

corpus_bytes = len(corpus_text.encode('utf-8'))
corpus_mb = corpus_bytes / (1024 * 1024)

print(f"✓ Loaded corpus")
print(f"  Size: {corpus_mb:.2f} MB")
print(f"  Characters: {len(corpus_text):,}")
print()

# Tokenize
print("Tokenizing corpus...\n")
encoding = tokenizer.encode(corpus_text)
tokens = encoding.ids

print(f"✓ Tokenized")
print(f"  Tokens: {len(tokens):,}")
print()

# Pre-load to device
corpus_tensor = torch.tensor(tokens, dtype=torch.long, device=device)
print(f"✓ Corpus on device: {device}")


Loading corpus: ../data/flannel_model_corpus.txt

✓ Loaded corpus
  Size: 5.01 MB
  Characters: 5,225,690

Tokenizing corpus...

✓ Tokenized
  Tokens: 1,371,328

✓ Corpus on device: mps


## Load Token Masks

In [6]:
print(f"\nLoading token masks: {TOKEN_MASK_PATH}\n")

mask_data = load_file(TOKEN_MASK_PATH)
live_mask = mask_data['live_mask']
dead_mask = mask_data['dead_mask']
live_indices = mask_data['live_indices']
dead_indices = mask_data['dead_indices']

n_live = live_mask.sum().item()
n_dead = dead_mask.sum().item()

print(f"✓ Loaded token masks")
print(f"  Live tokens: {n_live:,} ({100*n_live/VOCAB_SIZE:.1f}%)")
print(f"  Dead tokens: {n_dead:,} ({100*n_dead/VOCAB_SIZE:.1f}%)")


Loading token masks: ../tensors/Flannel/live_dead_tokens.safetensors

✓ Loaded token masks
  Live tokens: 6,301 (63.0%)
  Dead tokens: 3,699 (37.0%)


## Dataset

In [7]:
class TokenDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
    def __getitem__(self, idx):
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }

dataset = TokenDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"\n✓ Dataset: {len(dataset):,} examples")


✓ Dataset: 1,371,200 examples


## Comprehensive Recorder (from 1.20a)

In [8]:
class ComprehensiveRecorder:
    """Records embeddings, gradients, optimizer state, logits, loss at every step in bfloat16."""
    
    def __init__(self, vocab_size, hidden_dim, record_every_n):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.record_every_n = record_every_n
        
        # Storage (lists of tensors, keep in RAM)
        self.recorded_steps = []
        self.embeddings = []      # [n_recorded, vocab_size, hidden_dim]
        self.grads = []           # [n_recorded, vocab_size, hidden_dim]
        self.momentum = []        # [n_recorded, vocab_size, hidden_dim]
        self.variance = []        # [n_recorded, vocab_size, hidden_dim]
        self.logits = []          # [n_recorded, vocab_size]
        self.losses = []          # [n_recorded]
        
        # Temporary storage
        self.current_step = 0
        self.recorded_initial = False
        self.grad_before = None
        self.loss_value = None
        self.logits_sample = None
    
    def record_initial_state(self, model, optimizer):
        """Record step 0: initial state before training."""
        if not self.recorded_initial:
            W = model.transformer.wte.weight.data.clone().cpu().bfloat16()
            
            # Step 0: no gradients, no optimizer state yet (zeros)
            self.recorded_steps.append(0)
            self.embeddings.append(W)
            self.grads.append(torch.zeros_like(W))
            self.momentum.append(torch.zeros_like(W))
            self.variance.append(torch.zeros_like(W))
            self.logits.append(torch.zeros(self.vocab_size, dtype=torch.bfloat16))
            self.losses.append(torch.tensor(float('nan'), dtype=torch.bfloat16))  # No loss yet
            
            self.recorded_initial = True
            self.current_step = 1
            
            print(f"    ✓ Recorded initial state (step 0)")
    
    def record_before_step(self, model, loss, logits):
        """Call after forward/backward, before optimizer step."""
        if self.current_step % self.record_every_n == 0:
            # Capture gradients in bfloat16
            if model.transformer.wte.weight.grad is not None:
                self.grad_before = model.transformer.wte.weight.grad.clone().cpu().bfloat16()
            else:
                self.grad_before = torch.zeros(self.vocab_size, self.hidden_dim, dtype=torch.bfloat16)
            
            # Capture loss
            self.loss_value = loss.item()
            
            # Capture logits from first sequence, last position in bfloat16
            self.logits_sample = logits[0, -1, :].detach().cpu().bfloat16()
    
    def record_after_step(self, model, optimizer):
        """Call after optimizer step."""
        if self.current_step % self.record_every_n == 0:
            if self.grad_before is not None and self.loss_value is not None:
                # Capture embeddings in bfloat16
                W = model.transformer.wte.weight.data.clone().cpu().bfloat16()

                # Capture optimizer state (Adam momentum and variance)
                param = model.transformer.wte.weight
                if param in optimizer.state:
                    state = optimizer.state[param]
                    # Get state tensors if they exist, convert to bfloat16
                    mom_src = state.get('exp_avg', None)
                    var_src = state.get('exp_avg_sq', None)
                    mom = mom_src.clone().cpu().bfloat16() if mom_src is not None else torch.zeros_like(W)
                    var = var_src.clone().cpu().bfloat16() if var_src is not None else torch.zeros_like(W)
                else:
                    mom = torch.zeros_like(W)
                    var = torch.zeros_like(W)

                # Store everything
                self.recorded_steps.append(self.current_step)
                self.embeddings.append(W)
                self.grads.append(self.grad_before)
                self.momentum.append(mom)
                self.variance.append(var)
                self.logits.append(self.logits_sample)
                self.losses.append(torch.tensor(self.loss_value, dtype=torch.bfloat16))

                # Clear temp storage
                self.grad_before = None
                self.loss_value = None
                self.logits_sample = None
                
                # Progress indicator every 100 steps
                if self.current_step % 100 == 0:
                    print(f"    Step {self.current_step}")

        self.current_step += 1
    
    def get_data(self):
        """Return recorded data as stacked tensors."""
        print(f"    Stacking {len(self.embeddings)} recorded states...")
        
        return {
            'recorded_steps': torch.tensor(self.recorded_steps, dtype=torch.long),
            'embeddings': torch.stack(self.embeddings) if self.embeddings else torch.tensor([]),
            'grads': torch.stack(self.grads) if self.grads else torch.tensor([]),
            'momentum': torch.stack(self.momentum) if self.momentum else torch.tensor([]),
            'variance': torch.stack(self.variance) if self.variance else torch.tensor([]),
            'logits': torch.stack(self.logits) if self.logits else torch.tensor([]),
            'losses': torch.stack(self.losses) if self.losses else torch.tensor([]),
        }

print("✓ Recorder class defined")

✓ Recorder class defined


## Custom Trainer with Instrumentation

In [9]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder
        self.last_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Override to capture logits."""
        outputs = model(**inputs)
        loss = outputs.loss
        
        # Store logits for recorder
        self.last_logits = outputs.logits
        
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Override to inject recording."""
        # Standard forward + backward
        loss = super().training_step(model, inputs, num_items_in_batch)
        
        # Record BEFORE optimizer step
        self.recorder.record_before_step(model, loss, self.last_logits)
        
        return loss

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
        """Override to record AFTER optimizer step."""
        # Record AFTER optimizer updates parameters
        self.recorder.record_after_step(model, self.optimizer)
        
        # Call parent
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined


## Batch Training Loop

In [10]:
print(f"\n{'='*80}")
print(f"FLANNEL 3: TEN BIG BANGS (FULL DATASET)")
print(f"{'='*80}\n")

print(f"Configuration:")
print(f"  Runs: {NUM_RUNS}")
print(f"  Steps per run: {NUM_TRAIN_STEPS:,}")
print(f"  Seeds: {BASE_SEED}–{BASE_SEED + NUM_RUNS - 1}")
print(f"  Recording: full ComprehensiveRecorder (embeddings, grads, momentum, variance, logits, losses)")
print(f"  Expected file size: ~12 GB")
print(f"\n{'='*80}\n")

# Storage for all runs
all_runs_data = []

experiment_start = time.time()

for run_idx in range(NUM_RUNS):
    seed = BASE_SEED + run_idx
    
    print(f"\n{'='*80}")
    print(f"RUN {run_idx + 1}/{NUM_RUNS} (seed={seed})")
    print(f"{'='*80}\n")
    
    # Set seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Create fresh model
    config = GPT2Config(
        vocab_size=VOCAB_SIZE,
        n_positions=MAX_SEQ_LEN,
        n_embd=HIDDEN_DIM,
        n_layer=N_LAYER,
        n_head=N_HEAD,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        tie_word_embeddings=True,
    )
    
    model = GPT2LMHeadModel(config)
    model = model.to(torch.bfloat16).to(device)
    
    # Initialize embeddings
    init_f32 = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32, device=device) * INIT_SCALE
    init_bf16 = init_f32.to(torch.bfloat16)
    
    with torch.no_grad():
        model.transformer.wte.weight[:] = init_bf16
    
    print(f"  ✓ Model initialized (seed={seed})")
    
    # Create recorder
    recorder = ComprehensiveRecorder(VOCAB_SIZE, HIDDEN_DIM, RECORD_EVERY_N_STEPS)
    
    # Training args
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        max_steps=NUM_TRAIN_STEPS,
        per_device_train_batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        adam_beta1=ADAM_BETA1,
        adam_beta2=ADAM_BETA2,
        adam_epsilon=ADAM_EPSILON,
        optim="adamw_torch",
        logging_steps=1000,  # Minimal logging
        save_steps=NUM_TRAIN_STEPS + 1,
        save_total_limit=0,
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
        bf16=True,
        seed=seed,
        report_to="none",
        disable_tqdm=True,  # Quieter output
    )
    
    trainer = InstrumentedTrainer(
        recorder=recorder,
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    
    # Record initial state
    recorder.record_initial_state(model, trainer.optimizer)
    
    # Train
    print(f"  Training...")
    run_start = time.time()
    trainer.train()
    run_elapsed = time.time() - run_start
    
    print(f"\n  ✓ Run {run_idx + 1} complete ({run_elapsed:.1f}s)")
    
    # Collect data
    run_data = recorder.get_data()
    all_runs_data.append(run_data)
    
    # Clean up
    del model
    del trainer
    del recorder
    
    if device == 'mps':
        torch.mps.empty_cache()
    elif device == 'cuda':
        torch.cuda.empty_cache()

experiment_elapsed = time.time() - experiment_start

print(f"\n{'='*80}")
print(f"✓ All {NUM_RUNS} runs complete")
print(f"  Total time: {experiment_elapsed:.1f}s ({experiment_elapsed/60:.1f} minutes)")
print(f"  Average per run: {experiment_elapsed/NUM_RUNS:.1f}s")
print(f"{'='*80}")


FLANNEL 3: TEN BIG BANGS (FULL DATASET)

Configuration:
  Runs: 10
  Steps per run: 1,000
  Seeds: 42–51
  Recording: full ComprehensiveRecorder (embeddings, grads, momentum, variance, logits, losses)
  Expected file size: ~12 GB



RUN 1/10 (seed=42)

  ✓ Model initialized (seed=42)
    ✓ Recorded initial state (step 0)
  Training...


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


    Step 100
    Step 200
    Step 300
    Step 400
    Step 500
    Step 600
    Step 700
    Step 800
    Step 900
    Step 1000
{'loss': 7.1299, 'grad_norm': 0.2177734375, 'learning_rate': 1e-06, 'epoch': 0.023337222870478413}
{'train_runtime': 25.9053, 'train_samples_per_second': 1235.269, 'train_steps_per_second': 38.602, 'train_loss': 7.12989404296875, 'epoch': 0.023337222870478413}

  ✓ Run 1 complete (26.0s)
    Stacking 1001 recorded states...

RUN 2/10 (seed=43)

  ✓ Model initialized (seed=43)
    ✓ Recorded initial state (step 0)
  Training...
    Step 100
    Step 200
    Step 300
    Step 400
    Step 500
    Step 600
    Step 700
    Step 800
    Step 900
    Step 1000
{'loss': 7.1002, 'grad_norm': 0.275390625, 'learning_rate': 1e-06, 'epoch': 0.023337222870478413}
{'train_runtime': 25.2983, 'train_samples_per_second': 1264.909, 'train_steps_per_second': 39.528, 'train_loss': 7.10022265625, 'epoch': 0.023337222870478413}

  ✓ Run 2 complete (25.4s)
    Stacking 1001 reco

## Combine and Save All Runs

In [11]:
print(f"\nCombining {NUM_RUNS} runs into single file...\n")

# Build save dictionary with run-indexed keys
save_dict = {
    'n_runs': torch.tensor(NUM_RUNS, dtype=torch.long),
    'base_seed': torch.tensor(BASE_SEED, dtype=torch.long),
    'n_steps': torch.tensor(NUM_TRAIN_STEPS, dtype=torch.long),
    'n_live': torch.tensor(n_live, dtype=torch.long),
    'n_dead': torch.tensor(n_dead, dtype=torch.long),
    'init_scale': torch.tensor(INIT_SCALE, dtype=torch.float32),
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
    'adam_beta1': torch.tensor(ADAM_BETA1, dtype=torch.float32),
    'adam_beta2': torch.tensor(ADAM_BETA2, dtype=torch.float32),
}

# Add each run's data with indexed keys
for run_idx, run_data in enumerate(all_runs_data):
    save_dict[f'run_{run_idx}_recorded_steps'] = run_data['recorded_steps']
    save_dict[f'run_{run_idx}_embeddings'] = run_data['embeddings']
    save_dict[f'run_{run_idx}_grads'] = run_data['grads']
    save_dict[f'run_{run_idx}_momentum'] = run_data['momentum']
    save_dict[f'run_{run_idx}_variance'] = run_data['variance']
    save_dict[f'run_{run_idx}_logits'] = run_data['logits']
    save_dict[f'run_{run_idx}_losses'] = run_data['losses']
    print(f"  Run {run_idx}: {run_data['embeddings'].shape}")

print()

# Save
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
output_path = Path(OUTPUT_DIR) / OUTPUT_FILE

print(f"Saving to: {output_path}\n")

save_start = time.time()
save_file(save_dict, str(output_path))
save_elapsed = time.time() - save_start

file_size_mb = output_path.stat().st_size / 1e6
file_size_gb = file_size_mb / 1000

print(f"✓ Saved successfully")
print(f"  File: {output_path}")
print(f"  Size: {file_size_mb:.1f} MB ({file_size_gb:.2f} GB)")
print(f"  Save time: {save_elapsed:.1f}s")


Combining 10 runs into single file...

  Run 0: torch.Size([1001, 10000, 64])
  Run 1: torch.Size([1001, 10000, 64])
  Run 2: torch.Size([1001, 10000, 64])
  Run 3: torch.Size([1001, 10000, 64])
  Run 4: torch.Size([1001, 10000, 64])
  Run 5: torch.Size([1001, 10000, 64])
  Run 6: torch.Size([1001, 10000, 64])
  Run 7: torch.Size([1001, 10000, 64])
  Run 8: torch.Size([1001, 10000, 64])
  Run 9: torch.Size([1001, 10000, 64])

Saving to: ../tensors/Flannel/1.20c_flannel_3.safetensors



KeyboardInterrupt: 

## Summary

In [None]:
print(f"\n{'='*80}")
print(f"FLANNEL 3 COMPLETE")
print(f"{'='*80}\n")

print(f"Experiment: Ten independent training runs (full dataset)")
print(f"  Runs: {NUM_RUNS}")
print(f"  Steps per run: {NUM_TRAIN_STEPS:,}")
print(f"  Seeds: {BASE_SEED}–{BASE_SEED + NUM_RUNS - 1}")
print()
print(f"Data saved: {output_path}")
print(f"  Size: {file_size_gb:.2f} GB")
print(f"  Total experiment time: {experiment_elapsed/60:.1f} minutes")
print()
print(f"Next steps:")
print(f"  1. Run 1.22d_flannel_3_prelim (updated to load this new format)")
print(f"  2. Verify run 0 (seed=42) matches Flannel 1 exactly")
print(f"  3. Analyze epoch reproducibility across runs")
print(f"  4. Deeper statistical mechanics analysis")

print(f"\n{'='*80}")