# 1.20a: Flannel 1 - Baseline 1,000-Step Run

**Experiment:** Train tiny GPT-2 model with Flannel tokenizer (10k vocab, ~3.7k dead tokens) on English-only corpus.

## The Flannel Experiment

**Goal:** Study token dynamics under controlled conditions with engineered dead tokens.

**Design:**
1. **Tokenizer:** Trained on 80% English + 20% Thai mixed corpus → 10,000 tokens
2. **Model training:** Pure English corpus (5MB FineWeb) → Thai tokens never appear
3. **Result:** ~1,272 Thai tokens are dead by construction + ~2,384 other dead tokens (CJK, symbols, etc.)

This gives us:
- **6,301 live tokens** (63%) that receive gradient updates
- **3,699 dead tokens** (37%) that should stay frozen near initialization

## Why Flannel?

We want to understand how the "spongecrystal" formed in Qwen 3 4B—a dense cluster of 2,100 Thai tokens, many collapsed to identical vectors. 

**Hypothesis:** Dead tokens (never trained) experience only thermal jitter and may cluster at lattice scale due to bfloat16 quantization.

**Flannel tests this** by creating dead tokens we can track from initialization through training.

## Flannel 1 Parameters

**Baseline run:**
- **1,000 training steps** (~3 epochs through corpus)
- **Batch size: 32** (test M4 Pro performance)
- **Record every step** (full trajectory data)
- Same architecture as Wordybird/Lil Gatsby: 2 layers, 2 heads, 64D
- bfloat16-native training

## Parameters

In [1]:
# Model architecture
VOCAB_SIZE = 10000  # Flannel
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training
BATCH_SIZE = 32
NUM_TRAIN_STEPS = 1000
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0

# Optimizer: Adam
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8

# Initialization (bfloat16 native)
INIT_SCALE = 0.02  # N(0, 0.02)

# Data
TOKENIZER_PATH = "../data/flannel_tokenizer_chars.json"
CORPUS_PATH = "../data/flannel_model_corpus.txt"
TOKEN_MASK_PATH = "../tensors/Flannel/live_dead_tokens.safetensors"
OUTPUT_DIR = "../tensors/Flannel"
OUTPUT_FILE = "1.20a_flannel_1.safetensors"

# Instrumentation
RECORD_EVERY_N_STEPS = 1  # Record every step

RANDOM_SEED = 42

print("✓ Parameters set")

✓ Parameters set


## Imports

In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file, load_file
import time

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Device Detection

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Tokenizer

In [4]:
print(f"Loading Flannel tokenizer: {TOKENIZER_PATH}\n")

tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
vocab = tokenizer.get_vocab()

print(f"✓ Loaded Flannel tokenizer")
print(f"  Vocabulary size: {len(vocab):,} tokens")

Loading Flannel tokenizer: ../data/flannel_tokenizer_chars.json

✓ Loaded Flannel tokenizer
  Vocabulary size: 10,000 tokens


## Load Corpus and Tokenize

In [5]:
print(f"\nLoading corpus: {CORPUS_PATH}\n")

with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    corpus_text = f.read()

corpus_bytes = len(corpus_text.encode('utf-8'))
corpus_mb = corpus_bytes / (1024 * 1024)

print(f"✓ Loaded corpus")
print(f"  Size: {corpus_mb:.2f} MB")
print(f"  Characters: {len(corpus_text):,}")
print()

# Tokenize
print("Tokenizing corpus...\n")
encoding = tokenizer.encode(corpus_text)
tokens = encoding.ids

print(f"✓ Tokenized")
print(f"  Tokens: {len(tokens):,}")
print()

# Calculate expected epochs
tokens_per_step = BATCH_SIZE * MAX_SEQ_LEN
steps_per_epoch = len(tokens) / tokens_per_step
expected_epochs = NUM_TRAIN_STEPS / steps_per_epoch

print(f"Training coverage:")
print(f"  Tokens per step: {tokens_per_step:,}")
print(f"  Steps per epoch: {steps_per_epoch:.1f}")
print(f"  Expected epochs in {NUM_TRAIN_STEPS:,} steps: {expected_epochs:.2f}")
print()

# Pre-load to device
corpus_tensor = torch.tensor(tokens, dtype=torch.long, device=device)
print(f"✓ Corpus on device: {device}")


Loading corpus: ../data/flannel_model_corpus.txt

✓ Loaded corpus
  Size: 5.01 MB
  Characters: 5,225,690

Tokenizing corpus...

✓ Tokenized
  Tokens: 1,371,328

Training coverage:
  Tokens per step: 4,096
  Steps per epoch: 334.8
  Expected epochs in 1,000 steps: 2.99

✓ Corpus on device: mps


## Load Token Masks

In [6]:
print(f"\nLoading token masks: {TOKEN_MASK_PATH}\n")

mask_data = load_file(TOKEN_MASK_PATH)
live_mask = mask_data['live_mask']
dead_mask = mask_data['dead_mask']
live_indices = mask_data['live_indices']
dead_indices = mask_data['dead_indices']
token_occurrence_counts = mask_data['token_occurrence_counts']

n_live = live_mask.sum().item()
n_dead = dead_mask.sum().item()

print(f"✓ Loaded token masks")
print(f"  Live tokens: {n_live:,} ({100*n_live/VOCAB_SIZE:.1f}%)")
print(f"  Dead tokens: {n_dead:,} ({100*n_dead/VOCAB_SIZE:.1f}%)")
print()

# Token frequency stats for live tokens
live_counts = token_occurrence_counts[live_indices]
print(f"Live token frequency:")
print(f"  Min occurrences: {live_counts.min().item():,}")
print(f"  Max occurrences: {live_counts.max().item():,}")
print(f"  Mean occurrences: {live_counts.float().mean().item():.1f}")
print(f"  Median occurrences: {live_counts.float().median().item():.0f}")


Loading token masks: ../tensors/Flannel/live_dead_tokens.safetensors

✓ Loaded token masks
  Live tokens: 6,301 (63.0%)
  Dead tokens: 3,699 (37.0%)

Live token frequency:
  Min occurrences: 1
  Max occurrences: 45,630
  Mean occurrences: 217.6
  Median occurrences: 68


## Dataset

In [7]:
class TokenDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
    def __getitem__(self, idx):
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }

dataset = TokenDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"\n✓ Dataset: {len(dataset):,} examples")


✓ Dataset: 1,371,200 examples


## Model

In [8]:
print(f"\nCreating model...\n")

config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,
)

model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16).to(device)

total_params = sum(p.numel() for p in model.parameters())
embedding_params = model.transformer.wte.weight.numel()

print(f"✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Embedding parameters (E+W): {embedding_params:,}")
print(f"  Other parameters: {total_params - embedding_params:,}")


Creating model...

✓ Model created
  Total parameters: 748,288
  Embedding parameters (E+W): 640,000
  Other parameters: 108,288


## Initialization (bfloat16 Native)

In [9]:
print(f"\n{'='*80}")
print(f"INITIALIZING: N(0, {INIT_SCALE}) bfloat16-native")
print(f"{'='*80}\n")

torch.manual_seed(RANDOM_SEED)

# Standard GPT-2 init: each embedding drawn independently from N(0, 0.02)
# Generate in float32, immediately convert to bfloat16
init_f32 = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32, device=device) * INIT_SCALE
init_bf16 = init_f32.to(torch.bfloat16)

print(f"Initialization: N(0, {INIT_SCALE})")
print(f"  Shape: {init_bf16.shape}")
print(f"  Each token initialized independently")
print()

# Assign to model
with torch.no_grad():
    model.transformer.wte.weight[:] = init_bf16

print(f"✓ Initialized embeddings (pure bfloat16)")
print(f"  Shape: {model.transformer.wte.weight.shape}")
print(f"  Dtype: {model.transformer.wte.weight.dtype}")
print()

# Verify initialization stats for dead tokens
W_check = model.transformer.wte.weight.cpu().float()
W_dead = W_check[dead_indices]

centroid = W_dead.mean(dim=0)
centroid_norm = torch.norm(centroid).item()
radii = torch.norm(W_dead - centroid, dim=1)
mean_radius = radii.mean().item()
max_radius = radii.max().item()

print(f"Initial dead token statistics ({n_dead:,} tokens):")
print(f"  Centroid norm: {centroid_norm:.6f}")
print(f"  Mean radius from centroid: {mean_radius:.6f}")
print(f"  Max radius from centroid: {max_radius:.6f}")
print(f"  Bounding hypersphere volume ∝ R^{HIDDEN_DIM} = {max_radius**HIDDEN_DIM:.2e}")
print()
print(f"  ✓ Dead tokens distributed in hypersphere (standard init)")
print(f"\n{'='*80}\n")


INITIALIZING: N(0, 0.02) bfloat16-native

Initialization: N(0, 0.02)
  Shape: torch.Size([10000, 64])
  Each token initialized independently

✓ Initialized embeddings (pure bfloat16)
  Shape: torch.Size([10000, 64])
  Dtype: torch.bfloat16

Initial dead token statistics (3,699 tokens):
  Centroid norm: 0.002689
  Mean radius from centroid: 0.159301
  Max radius from centroid: 0.206368
  Bounding hypersphere volume ∝ R^64 = 1.37e-44

  ✓ Dead tokens distributed in hypersphere (standard init)




## Comprehensive Recorder

In [10]:
class ComprehensiveRecorder:
    """Records embeddings, gradients, optimizer state, logits, loss at every step in bfloat16."""
    
    def __init__(self, vocab_size, hidden_dim, record_every_n):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.record_every_n = record_every_n
        
        # Storage (lists of tensors, keep in RAM)
        self.recorded_steps = []
        self.embeddings = []      # [n_recorded, vocab_size, hidden_dim]
        self.grads = []           # [n_recorded, vocab_size, hidden_dim]
        self.momentum = []        # [n_recorded, vocab_size, hidden_dim]
        self.variance = []        # [n_recorded, vocab_size, hidden_dim]
        self.logits = []          # [n_recorded, vocab_size]
        self.losses = []          # [n_recorded]
        
        # Temporary storage
        self.current_step = 0
        self.recorded_initial = False
        self.grad_before = None
        self.loss_value = None
        self.logits_sample = None
    
    def record_initial_state(self, model, optimizer):
        """Record step 0: initial state before training."""
        if not self.recorded_initial:
            W = model.transformer.wte.weight.data.clone().cpu().bfloat16()
            
            # Step 0: no gradients, no optimizer state yet (zeros)
            self.recorded_steps.append(0)
            self.embeddings.append(W)
            self.grads.append(torch.zeros_like(W))
            self.momentum.append(torch.zeros_like(W))
            self.variance.append(torch.zeros_like(W))
            self.logits.append(torch.zeros(self.vocab_size, dtype=torch.bfloat16))
            self.losses.append(torch.tensor(float('nan'), dtype=torch.bfloat16))  # No loss yet
            
            self.recorded_initial = True
            self.current_step = 1
            
            print(f"✓ Recorded initial state (step 0)")
    
    def record_before_step(self, model, loss, logits):
        """Call after forward/backward, before optimizer step."""
        if self.current_step % self.record_every_n == 0:
            # Capture gradients in bfloat16
            if model.transformer.wte.weight.grad is not None:
                self.grad_before = model.transformer.wte.weight.grad.clone().cpu().bfloat16()
            else:
                self.grad_before = torch.zeros(self.vocab_size, self.hidden_dim, dtype=torch.bfloat16)
            
            # Capture loss
            self.loss_value = loss.item()
            
            # Capture logits from first sequence, last position in bfloat16
            self.logits_sample = logits[0, -1, :].detach().cpu().bfloat16()
    
    def record_after_step(self, model, optimizer):
        """Call after optimizer step."""
        if self.current_step % self.record_every_n == 0:
            if self.grad_before is not None and self.loss_value is not None:
                # Capture embeddings in bfloat16
                W = model.transformer.wte.weight.data.clone().cpu().bfloat16()

                # Capture optimizer state (Adam momentum and variance)
                param = model.transformer.wte.weight
                if param in optimizer.state:
                    state = optimizer.state[param]
                    # Get state tensors if they exist, convert to bfloat16
                    mom_src = state.get('exp_avg', None)
                    var_src = state.get('exp_avg_sq', None)
                    mom = mom_src.clone().cpu().bfloat16() if mom_src is not None else torch.zeros_like(W)
                    var = var_src.clone().cpu().bfloat16() if var_src is not None else torch.zeros_like(W)
                else:
                    mom = torch.zeros_like(W)
                    var = torch.zeros_like(W)

                # Store everything
                self.recorded_steps.append(self.current_step)
                self.embeddings.append(W)
                self.grads.append(self.grad_before)
                self.momentum.append(mom)
                self.variance.append(var)
                self.logits.append(self.logits_sample)
                self.losses.append(torch.tensor(self.loss_value, dtype=torch.bfloat16))

                # Clear temp storage
                self.grad_before = None
                self.loss_value = None
                self.logits_sample = None
                
                # Progress indicator every 50 steps
                if self.current_step % 50 == 0:
                    print(f"  Recorded step {self.current_step}")

        self.current_step += 1
    
    def get_data(self):
        """Return recorded data as stacked tensors."""
        print(f"\nStacking {len(self.embeddings)} recorded states...")
        
        return {
            'recorded_steps': torch.tensor(self.recorded_steps, dtype=torch.long),
            'embeddings': torch.stack(self.embeddings) if self.embeddings else torch.tensor([]),
            'grads': torch.stack(self.grads) if self.grads else torch.tensor([]),
            'momentum': torch.stack(self.momentum) if self.momentum else torch.tensor([]),
            'variance': torch.stack(self.variance) if self.variance else torch.tensor([]),
            'logits': torch.stack(self.logits) if self.logits else torch.tensor([]),
            'losses': torch.stack(self.losses) if self.losses else torch.tensor([]),
        }

print("✓ Recorder class defined")

✓ Recorder class defined


## Custom Trainer with Instrumentation

In [11]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder
        self.last_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Override to capture logits."""
        outputs = model(**inputs)
        loss = outputs.loss
        
        # Store logits for recorder
        self.last_logits = outputs.logits
        
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Override to inject recording."""
        # Standard forward + backward
        loss = super().training_step(model, inputs, num_items_in_batch)
        
        # Record BEFORE optimizer step
        self.recorder.record_before_step(model, loss, self.last_logits)
        
        return loss

    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
        """Override to record AFTER optimizer step."""
        # Record AFTER optimizer updates parameters
        self.recorder.record_after_step(model, self.optimizer)
        
        # Call parent
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined


## Training Configuration

In [12]:
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

recorder = ComprehensiveRecorder(VOCAB_SIZE, HIDDEN_DIM, RECORD_EVERY_N_STEPS)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    adam_beta1=ADAM_BETA1,
    adam_beta2=ADAM_BETA2,
    adam_epsilon=ADAM_EPSILON,
    optim="adamw_torch",
    logging_steps=50,
    save_steps=NUM_TRAIN_STEPS + 1,  # Don't save checkpoints
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,  # Native bfloat16 training
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
)

trainer = InstrumentedTrainer(
    recorder=recorder,
    model=model,
    args=training_args,
    train_dataset=dataset,
)

print(f"\n✓ Trainer ready (Adam, bf16=True, batch_size={BATCH_SIZE})")


✓ Trainer ready (Adam, bf16=True, batch_size=32)


## Record Initial State

In [13]:
print()
recorder.record_initial_state(model, trainer.optimizer)


✓ Recorded initial state (step 0)


## Train

**1,000 steps should take ~2-3 minutes on M4 Pro.**

In [14]:
print(f"\n{'='*80}")
print(f"STARTING FLANNEL 1 TRAINING")
print(f"{'='*80}")
print(f"\nConfiguration:")
print(f"  Vocabulary: {VOCAB_SIZE:,} tokens (Flannel)")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Live tokens: {n_live:,} ({100*n_live/VOCAB_SIZE:.1f}%)")
print(f"  Dead tokens: {n_dead:,} ({100*n_dead/VOCAB_SIZE:.1f}%)")
print()
print(f"  Initialization: N(0, {INIT_SCALE}) bfloat16-native")
print(f"  Optimizer: Adam (lr={LEARNING_RATE})")
print(f"  Precision: bfloat16 (native)")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Steps: {NUM_TRAIN_STEPS:,} (~{expected_epochs:.1f} epochs)")
print(f"  Recording: every step")
print(f"\n{'='*80}\n")

start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete")
print(f"  Elapsed time: {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
print(f"  Throughput: {NUM_TRAIN_STEPS / elapsed:.1f} steps/second")
print(f"{'='*80}")


STARTING FLANNEL 1 TRAINING

Configuration:
  Vocabulary: 10,000 tokens (Flannel)
  Hidden dim: 64
  Live tokens: 6,301 (63.0%)
  Dead tokens: 3,699 (37.0%)

  Initialization: N(0, 0.02) bfloat16-native
  Optimizer: Adam (lr=0.001)
  Precision: bfloat16 (native)
  Batch size: 32
  Steps: 1,000 (~3.0 epochs)
  Recording: every step




`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
50,8.0405
100,7.1565
150,7.1153
200,7.1052
250,7.1085
300,7.0918
350,7.0964
400,7.0901
450,7.0873
500,7.0755


  Recorded step 50
  Recorded step 100
  Recorded step 150
  Recorded step 200
  Recorded step 250
  Recorded step 300
  Recorded step 350
  Recorded step 400
  Recorded step 450
  Recorded step 500
  Recorded step 550
  Recorded step 600
  Recorded step 650
  Recorded step 700
  Recorded step 750
  Recorded step 800
  Recorded step 850
  Recorded step 900
  Recorded step 950
  Recorded step 1000

✓ Training complete
  Elapsed time: 26.7 seconds (0.4 minutes)
  Throughput: 37.4 steps/second


## Save Recorded Data

In [15]:
print(f"\nPreparing data for save...\n")

recorded_data = recorder.get_data()

save_dict = {
    'recorded_steps': recorded_data['recorded_steps'],
    'embeddings': recorded_data['embeddings'],
    'grads': recorded_data['grads'],
    'momentum': recorded_data['momentum'],
    'variance': recorded_data['variance'],
    'logits': recorded_data['logits'],
    'losses': recorded_data['losses'],
    # Metadata
    'init_scale': torch.tensor(INIT_SCALE, dtype=torch.float32),
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
    'adam_beta1': torch.tensor(ADAM_BETA1, dtype=torch.float32),
    'adam_beta2': torch.tensor(ADAM_BETA2, dtype=torch.float32),
    'n_live': torch.tensor(n_live, dtype=torch.long),
    'n_dead': torch.tensor(n_dead, dtype=torch.long),
}

output_path = Path(OUTPUT_DIR) / OUTPUT_FILE

print(f"Saving to: {output_path}")

save_start = time.time()
save_file(save_dict, str(output_path))
save_elapsed = time.time() - save_start

file_size_mb = output_path.stat().st_size / 1e6

print(f"\n✓ Saved successfully")
print(f"  File: {output_path}")
print(f"  Size: {file_size_mb:.1f} MB")
print(f"  Save time: {save_elapsed:.1f} seconds")
print(f"  Recorded steps: {len(recorded_data['recorded_steps'])}")
print(f"  Step range: {recorded_data['recorded_steps'][0]} to {recorded_data['recorded_steps'][-1]}")


Preparing data for save...


Stacking 1001 recorded states...
Saving to: ../tensors/Flannel/1.20a_flannel_1.safetensors

✓ Saved successfully
  File: ../tensors/Flannel/1.20a_flannel_1.safetensors
  Size: 5145.2 MB
  Save time: 18.4 seconds
  Recorded steps: 1001
  Step range: 0 to 1000


## Quick Verification

In [16]:
print(f"\n{'='*80}")
print(f"QUICK VERIFICATION")
print(f"{'='*80}\n")

embeddings = recorded_data['embeddings']
losses = recorded_data['losses']

print(f"Data shapes:")
print(f"  embeddings: {embeddings.shape}")
print(f"  losses: {losses.shape}")
print()

# Analyze dead token movement
W_step0 = embeddings[0, dead_indices].float()
W_step1000 = embeddings[-1, dead_indices].float()

# Compute displacements
displacements = torch.norm(W_step1000 - W_step0, dim=1)
max_displacement = displacements.max().item()
mean_displacement = displacements.mean().item()
median_displacement = displacements.median().item()

print(f"Dead token displacement (steps 0 → {NUM_TRAIN_STEPS}):")
print(f"  Max: {max_displacement:.2e}")
print(f"  Mean: {mean_displacement:.2e}")
print(f"  Median: {median_displacement:.2e}")
print()

# Compare to live tokens
W_live_step0 = embeddings[0, live_indices].float()
W_live_step1000 = embeddings[-1, live_indices].float()
live_displacements = torch.norm(W_live_step1000 - W_live_step0, dim=1)

print(f"Live token displacement (steps 0 → {NUM_TRAIN_STEPS}):")
print(f"  Max: {live_displacements.max().item():.2e}")
print(f"  Mean: {live_displacements.mean().item():.2e}")
print(f"  Median: {live_displacements.median().item():.2e}")
print()

print(f"Loss trajectory:")
print(f"  Step 1: {losses[1].float().item():.4f}")
print(f"  Step {NUM_TRAIN_STEPS}: {losses[-1].float().item():.4f}")
print(f"  Reduction: {(losses[1].float() - losses[-1].float()).item():.4f}")

print(f"\n{'='*80}")


QUICK VERIFICATION

Data shapes:
  embeddings: torch.Size([1001, 10000, 64])
  losses: torch.Size([1001])

Dead token displacement (steps 0 → 1000):
  Max: 5.66e-01
  Mean: 5.05e-01
  Median: 5.04e-01

Live token displacement (steps 0 → 1000):
  Max: 6.22e-01
  Mean: 1.82e-01
  Median: 1.70e-01

Loss trajectory:
  Step 1: 9.2500
  Step 1000: 6.9688
  Reduction: 2.2812



## Summary

In [17]:
print(f"\n{'='*80}")
print(f"FLANNEL 1 COMPLETE")
print(f"{'='*80}\n")

print(f"Experiment: Baseline 1,000-step run with engineered dead tokens")
print(f"  Tokenizer: Flannel (10k vocab, char-level BPE)")
print(f"  Corpus: 5MB English-only FineWeb")
print(f"  Training: {NUM_TRAIN_STEPS:,} steps (~{expected_epochs:.1f} epochs)")
print(f"  Architecture: 2 layers, 2 heads, 64D, bfloat16-native")
print()
print(f"Token demographics:")
print(f"  Live: {n_live:,} tokens ({100*n_live/VOCAB_SIZE:.1f}%)")
print(f"  Dead: {n_dead:,} tokens ({100*n_dead/VOCAB_SIZE:.1f}%)")
print()
print(f"Data saved: {output_path}")
print(f"Size: {file_size_mb:.1f} MB")
print()
print(f"Next steps:")
print(f"  1. Run lattice hop detection (1.17b pattern)")
print(f"  2. Analyze dead token clustering")
print(f"  3. Look for spongecrystal formation")
print(f"  4. Compare to Wordybird dynamics")

print(f"\n{'='*80}")


FLANNEL 1 COMPLETE

Experiment: Baseline 1,000-step run with engineered dead tokens
  Tokenizer: Flannel (10k vocab, char-level BPE)
  Corpus: 5MB English-only FineWeb
  Training: 1,000 steps (~3.0 epochs)
  Architecture: 2 layers, 2 heads, 64D, bfloat16-native

Token demographics:
  Live: 6,301 tokens (63.0%)
  Dead: 3,699 tokens (37.0%)

Data saved: ../tensors/Flannel/1.20a_flannel_1.safetensors
Size: 5145.2 MB

Next steps:
  1. Run lattice hop detection (1.17b pattern)
  2. Analyze dead token clustering
  3. Look for spongecrystal formation
  4. Compare to Wordybird dynamics

