# 1.20b: Flannel 2 - Untied Weights Control (1,000 Steps)

**Experiment:** Same as Flannel 1, but with **untied embedding weights** (E ≠ W).

## Why Flannel 2?

**Goal:** Isolate the gradient contributions of the embedding layer (E) from those of the unembedding layer (W).

**In Flannel 1 (tied weights):**
- E and W are the same tensor
- Gradient = ∂L/∂W (unembedding, affects all tokens) + ∂L/∂E (embedding, affects only input tokens)
- Both contributions accumulate into the same parameter

**In Flannel 2 (untied weights):**
- E and W are independent tensors
- ∂L/∂W affects W only (unembedding backscatter)
- ∂L/∂E affects E only (embedding supervision)
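- Dead-token rows of W still receive gradients (the softmax assigns every vocabulary entry a nonzero probability); dead-token rows of E receive exactly zero gradient (they are never looked up as inputs)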
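
The asymmetry between E and W listed above can be checked by hand on a toy model. The sketch below is a simplified stand-in, not the Flannel model: it assumes a bigram model with no transformer layers (`logits = W · E[x]`), hypothetical sizes `V = 5` and `D = 3`, and hand-derived cross-entropy gradients. Tokens 3 and 4 never occur, so they play the role of dead tokens.

```python
import math

def softmax(z):
    m = max(z)
    exps = [math.exp(x - m) for x in z]
    s = sum(exps)
    return [e / s for e in exps]

V, D = 5, 3  # toy vocabulary and hidden size (hypothetical, not the Flannel values)
E = [[0.1 * (v + 1) * (d + 1) for d in range(D)] for v in range(V)]   # embedding rows
W = [[0.05 * (v + 2) * (d + 1) for d in range(D)] for v in range(V)]  # unembedding rows

inputs, targets = [0, 1], [1, 2]  # tokens 3 and 4 never occur: they are "dead"

grad_E = [[0.0] * D for _ in range(V)]
grad_W = [[0.0] * D for _ in range(V)]
for x, y in zip(inputs, targets):
    h = E[x]  # the "hidden state" is just the embedding lookup in this toy model
    logits = [sum(W[v][d] * h[d] for d in range(D)) for v in range(V)]
    p = softmax(logits)
    for v in range(V):
        err = p[v] - (1.0 if v == y else 0.0)  # dL/dlogit_v for cross-entropy
        for d in range(D):
            grad_W[v][d] += err * h[d]     # every row of W is touched
            grad_E[x][d] += err * W[v][d]  # only the looked-up row of E is touched

def row_norm(row):
    return math.sqrt(sum(g * g for g in row))

for v in range(V):
    print(v, f"|dL/dE|={row_norm(grad_E[v]):.4f}", f"|dL/dW|={row_norm(grad_W[v]):.4f}")
```

Rows 3 and 4 of `grad_E` stay exactly zero, while the same rows of `grad_W` are nonzero because each contributes `p[v] * h`. Token 2 appears only as a target here, so its E row also gets no gradient in this toy setup.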

## The Question We're Testing

Does the "inflationary expansion" at step 0→1 come from:
1. **Unembedding dynamics** (would persist in W even when untied)
2. **Tied weight interaction** (would disappear or change when untied)
3. **Embedding dynamics** (would show up in E but not W)

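**Bonus question:** Does W₁[t] ≈ W₂[t] + E₂[t]? (Linear decomposition hypothesis. Because E₂ and W₂ both start from the same initialization as W₁, the comparison is only meaningful for displacements from that initialization: ΔW₁[t] ≈ ΔW₂[t] + ΔE₂[t].)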

## Flannel 2 Parameters

**Identical to Flannel 1 except:**
- `tie_word_embeddings=False`
- Record both E and W matrices at every step
- Same random seed (42) for direct comparison

## Parameters

In [1]:
# Model architecture
VOCAB_SIZE = 10000  # Flannel
HIDDEN_DIM = 64
N_LAYER = 2
N_HEAD = 2
MAX_SEQ_LEN = 128

# Training
BATCH_SIZE = 32
NUM_TRAIN_STEPS = 1000
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.0

# Optimizer: Adam
ADAM_BETA1 = 0.9
ADAM_BETA2 = 0.999
ADAM_EPSILON = 1e-8

# Initialization (bfloat16 native)
INIT_SCALE = 0.02  # Standard deviation: weights drawn from N(0, 0.02²)

# Data
TOKENIZER_PATH = "../data/flannel_tokenizer_chars.json"
CORPUS_PATH = "../data/flannel_model_corpus.txt"
TOKEN_MASK_PATH = "../tensors/Flannel/live_dead_tokens.safetensors"
OUTPUT_DIR = "../tensors/Flannel"
OUTPUT_FILE = "1.20b_flannel_2.safetensors"

# Instrumentation
RECORD_EVERY_N_STEPS = 1  # Record every step

RANDOM_SEED = 42  # SAME AS FLANNEL 1

print("✓ Parameters set")

✓ Parameters set
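
Since the question under test concerns the very first update (step 0→1), it helps to see what Adam does on that step with the hyperparameters above. The following is a minimal scalar sketch of the textbook Adam update with bias correction. It deliberately ignores bfloat16 rounding, gradient clipping, and any learning-rate schedule, and `adam_first_step` is a hypothetical helper, not part of the notebook.

```python
import math

def adam_first_step(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Update term (to be subtracted from the weight) on the first Adam step.

    Both moment estimates start at zero, as in a freshly created optimizer.
    """
    m = (1 - beta1) * grad          # first moment after one update
    v = (1 - beta2) * grad * grad   # second moment after one update
    m_hat = m / (1 - beta1)         # bias correction at t = 1 recovers grad
    v_hat = v / (1 - beta2)         # bias correction at t = 1 recovers grad**2
    return lr * m_hat / (math.sqrt(v_hat) + eps)

for g in (1e-6, 1e-3, 5.0, -0.2):
    print(f"grad={g:+.1e} -> update={adam_first_step(g):+.3e}")

hidden_dim = 64
print(f"row displacement if all {hidden_dim} coordinates move by lr: "
      f"{1e-3 * math.sqrt(hidden_dim):.1e}")
```

On the first step the update is `lr * g / (|g| + eps)`, so any coordinate whose gradient is well above `eps` moves by roughly `lr` regardless of how small that gradient is. A coordinate with exactly zero gradient does not move at all.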


## Imports

In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from tokenizers import Tokenizer
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file, load_file
import time

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Device Detection

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

Using device: mps


## Load Tokenizer

In [4]:
print(f"Loading Flannel tokenizer: {TOKENIZER_PATH}\n")

tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
vocab = tokenizer.get_vocab()

print(f"✓ Loaded Flannel tokenizer")
print(f"  Vocabulary size: {len(vocab):,} tokens")

Loading Flannel tokenizer: ../data/flannel_tokenizer_chars.json

✓ Loaded Flannel tokenizer
  Vocabulary size: 10,000 tokens


## Load Corpus and Tokenize

In [5]:
print(f"\nLoading corpus: {CORPUS_PATH}\n")

with open(CORPUS_PATH, 'r', encoding='utf-8') as f:
    corpus_text = f.read()

corpus_bytes = len(corpus_text.encode('utf-8'))
corpus_mb = corpus_bytes / (1024 * 1024)

print(f"✓ Loaded corpus")
print(f"  Size: {corpus_mb:.2f} MB")
print(f"  Characters: {len(corpus_text):,}")
print()

# Tokenize
print("Tokenizing corpus...\n")
encoding = tokenizer.encode(corpus_text)
tokens = encoding.ids

print(f"✓ Tokenized")
print(f"  Tokens: {len(tokens):,}")
print()

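# Calculate expected epochs in the token-coverage sense (tokens processed / corpus tokens).
# This is not the Trainer's notion of an epoch, which counts dataset examples.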
tokens_per_step = BATCH_SIZE * MAX_SEQ_LEN
steps_per_epoch = len(tokens) / tokens_per_step
expected_epochs = NUM_TRAIN_STEPS / steps_per_epoch

print(f"Training coverage:")
print(f"  Tokens per step: {tokens_per_step:,}")
print(f"  Steps per epoch: {steps_per_epoch:.1f}")
print(f"  Expected epochs in {NUM_TRAIN_STEPS:,} steps: {expected_epochs:.2f}")
print()

# Pre-load to device
corpus_tensor = torch.tensor(tokens, dtype=torch.long, device=device)
print(f"✓ Corpus on device: {device}")


Loading corpus: ../data/flannel_model_corpus.txt

✓ Loaded corpus
  Size: 5.01 MB
  Characters: 5,225,690

Tokenizing corpus...

✓ Tokenized
  Tokens: 1,371,328

Training coverage:
  Tokens per step: 4,096
  Steps per epoch: 334.8
  Expected epochs in 1,000 steps: 2.99

✓ Corpus on device: mps
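
The "epochs" figure above counts tokens processed relative to corpus size. The `TokenDataset` defined in the next section instead exposes one example per start offset (overlapping windows), so an epoch in the Trainer's sense is far longer. The arithmetic below is a sketch using the token count printed above and assumes no gradient accumulation and a single device.

```python
import math

n_tokens = 1_371_328  # from the tokenizer output above
batch_size, seq_len, steps = 32, 128, 1000

# Token-coverage view (what the notebook prints)
tokens_per_step = batch_size * seq_len
coverage_epochs = steps * tokens_per_step / n_tokens

# Trainer view: one example per start offset of a sliding window
n_examples = n_tokens - seq_len
steps_per_trainer_epoch = math.ceil(n_examples / batch_size)
trainer_epochs = steps / steps_per_trainer_epoch

print(f"tokens per step:          {tokens_per_step:,}")
print(f"coverage epochs:          {coverage_epochs:.2f}")
print(f"steps per Trainer epoch:  {steps_per_trainer_epoch:,}")
print(f"Trainer epochs completed: {trainer_epochs:.4f}")
```

Under these assumptions the run samples 32,000 of the roughly 1.37 million overlapping windows, so "about 3 epochs" should be read as the number of tokens seen, not as three full passes over every window.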


## Load Token Masks

In [6]:
print(f"\nLoading token masks: {TOKEN_MASK_PATH}\n")

mask_data = load_file(TOKEN_MASK_PATH)
live_mask = mask_data['live_mask']
dead_mask = mask_data['dead_mask']
live_indices = mask_data['live_indices']
dead_indices = mask_data['dead_indices']
token_occurrence_counts = mask_data['token_occurrence_counts']

n_live = live_mask.sum().item()
n_dead = dead_mask.sum().item()

print(f"✓ Loaded token masks")
print(f"  Live tokens: {n_live:,} ({100*n_live/VOCAB_SIZE:.1f}%)")
print(f"  Dead tokens: {n_dead:,} ({100*n_dead/VOCAB_SIZE:.1f}%)")
print()

# Token frequency stats for live tokens
live_counts = token_occurrence_counts[live_indices]
print(f"Live token frequency:")
print(f"  Min occurrences: {live_counts.min().item():,}")
print(f"  Max occurrences: {live_counts.max().item():,}")
print(f"  Mean occurrences: {live_counts.float().mean().item():.1f}")
print(f"  Median occurrences: {live_counts.float().median().item():.0f}")


Loading token masks: ../tensors/Flannel/live_dead_tokens.safetensors

✓ Loaded token masks
  Live tokens: 6,301 (63.0%)
  Dead tokens: 3,699 (37.0%)

Live token frequency:
  Min occurrences: 1
  Max occurrences: 45,630
  Mean occurrences: 217.6
  Median occurrences: 68


## Dataset

In [7]:
class TokenDataset(Dataset):
    def __init__(self, corpus_tensor, max_seq_len):
        self.corpus = corpus_tensor
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.corpus) - self.max_seq_len)
    
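    def __getitem__(self, idx):
        # Take a window of max_seq_len + 1 tokens and split it into inputs and labels.
        # NOTE: GPT2LMHeadModel also shifts labels internally, so with these
        # pre-shifted labels position i is scored against the token at i + 2.
        # Passing labels identical to input_ids would give plain next-token prediction.
        chunk = self.corpus[idx : idx + self.max_seq_len + 1]
        return {
            'input_ids': chunk[:-1],
            'labels': chunk[1:]
        }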

dataset = TokenDataset(corpus_tensor, MAX_SEQ_LEN)
print(f"\n✓ Dataset: {len(dataset):,} examples")


✓ Dataset: 1,371,200 examples
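
One detail of `TokenDataset` is worth illustrating: Hugging Face causal LM heads shift the labels themselves (logits at position i are compared with `labels[i + 1]`). The sketch below mimics that pairing with plain lists to show what pre-shifted labels lead to. `internal_shift` is a hypothetical stand-in for the library's loss alignment, not its actual code.

```python
def internal_shift(input_ids, labels):
    """Pair each input position i with the label the loss compares it to: labels[i + 1]."""
    return list(zip(input_ids[:-1], labels[1:]))

chunk = [10, 11, 12, 13, 14, 15]  # hypothetical token ids in corpus order

# As in TokenDataset: labels are already shifted by one
pre_shifted = internal_shift(chunk[:-1], chunk[1:])

# Alternative: labels identical to the inputs, letting the model do the shift
unshifted = internal_shift(chunk[:-1], chunk[:-1])

print("pre-shifted labels:", pre_shifted)
print("labels = input_ids:", unshifted)
```

With pre-shifted labels each input token is paired with the token two positions later; with `labels = input_ids` it is paired with the immediate successor. This changes the prediction task but not which embedding rows receive gradients, so the live/dead comparison remains valid; the loss values are those of a skip-one prediction task.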


## Model

In [8]:
print(f"\nCreating model...\n")

config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=False,  # ← KEY DIFFERENCE: UNTIED WEIGHTS
)

model = GPT2LMHeadModel(config)
model = model.to(torch.bfloat16).to(device)

total_params = sum(p.numel() for p in model.parameters())
embedding_params = model.transformer.wte.weight.numel()
unembedding_params = model.lm_head.weight.numel()

print(f"✓ Model created (UNTIED WEIGHTS)")
print(f"  Total parameters: {total_params:,}")
print(f"  Embedding parameters (E): {embedding_params:,}")
print(f"  Unembedding parameters (W): {unembedding_params:,}")
print(f"  Other parameters: {total_params - embedding_params - unembedding_params:,}")
print(f"  E and W are independent tensors: {model.transformer.wte.weight is not model.lm_head.weight}")


Creating model...

✓ Model created (UNTIED WEIGHTS)
  Total parameters: 1,388,288
  Embedding parameters (E): 640,000
  Unembedding parameters (W): 640,000
  Other parameters: 108,288
  E and W are independent tensors: True
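
The parameter counts above can be reproduced by hand, which is a useful check that untying adds exactly one extra vocabulary-sized matrix. The sketch assumes the standard GPT-2 block layout (fused QKV projection with bias, 4× MLP expansion, two layer norms per block plus a final one, no bias on the LM head).

```python
V, D, L, T = 10_000, 64, 2, 128  # vocab size, hidden dim, layers, positions

wte = V * D                      # token embedding E
lm_head = V * D                  # unembedding W (separate because weights are untied)
wpe = T * D                      # learned position embedding
layer_norm = 2 * D               # scale + shift
attn = (D * 3 * D + 3 * D) + (D * D + D)      # QKV projection + output projection
mlp = (D * 4 * D + 4 * D) + (4 * D * D + D)   # up projection + down projection
per_layer = 2 * layer_norm + attn + mlp

other = wpe + L * per_layer + layer_norm      # includes the final layer norm
total = wte + lm_head + other

print(f"E: {wte:,}  W: {lm_head:,}  other: {other:,}  total: {total:,}")
```

The embedding and unembedding matrices together account for about 92% of all parameters, so the dynamics of E and W dominate this model by parameter count.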


## Initialization (bfloat16 Native)

**Initialize E and W as separate tensors from a single shared N(0, 0.02²) draw (same seed as Flannel 1), so that E = W at step 0 and any later divergence comes from training alone.**

In [9]:
print(f"\n{'='*80}")
print(f"INITIALIZING: N(0, {INIT_SCALE}) bfloat16-native (UNTIED)")
print(f"{'='*80}\n")

torch.manual_seed(RANDOM_SEED)

# Generate the SAME initialization as Flannel 1 (same seed, same distribution),
# then copy that one draw into both E and W, which remain separate parameters
init_f32 = torch.randn(VOCAB_SIZE, HIDDEN_DIM, dtype=torch.float32, device=device) * INIT_SCALE
init_bf16 = init_f32.to(torch.bfloat16)

print(f"Initialization: N(0, {INIT_SCALE})")
print(f"  Shape: {init_bf16.shape}")
print(f"  Each token initialized independently")
print()

# Assign to BOTH embedding and unembedding
with torch.no_grad():
    model.transformer.wte.weight[:] = init_bf16  # E
    model.lm_head.weight[:] = init_bf16  # W (same starting point)

print(f"✓ Initialized both E and W (pure bfloat16)")
print(f"  E shape: {model.transformer.wte.weight.shape}")
print(f"  W shape: {model.lm_head.weight.shape}")
print(f"  Dtype: {model.transformer.wte.weight.dtype}")
print(f"  E == W initially: {torch.allclose(model.transformer.wte.weight, model.lm_head.weight)}")
print()

# Verify initialization stats for dead tokens
E_check = model.transformer.wte.weight.cpu().float()
E_dead = E_check[dead_indices]

centroid = E_dead.mean(dim=0)
centroid_norm = torch.norm(centroid).item()
radii = torch.norm(E_dead - centroid, dim=1)
mean_radius = radii.mean().item()
max_radius = radii.max().item()

print(f"Initial dead token statistics ({n_dead:,} tokens):")
print(f"  Centroid norm: {centroid_norm:.6f}")
print(f"  Mean radius from centroid: {mean_radius:.6f}")
print(f"  Max radius from centroid: {max_radius:.6f}")
print(f"  Bounding hypersphere volume ∝ R^{HIDDEN_DIM} = {max_radius**HIDDEN_DIM:.2e}")
print()
print(f"  ✓ Dead tokens distributed in hypersphere (standard init)")
print(f"\n{'='*80}\n")


INITIALIZING: N(0, 0.02) bfloat16-native (UNTIED)

Initialization: N(0, 0.02)
  Shape: torch.Size([10000, 64])
  Each token initialized independently

✓ Initialized both E and W (pure bfloat16)
  E shape: torch.Size([10000, 64])
  W shape: torch.Size([10000, 64])
  Dtype: torch.bfloat16
  E == W initially: True

Initial dead token statistics (3,699 tokens):
  Centroid norm: 0.002689
  Mean radius from centroid: 0.159301
  Max radius from centroid: 0.206368
  Bounding hypersphere volume ∝ R^64 = 1.37e-44

  ✓ Dead tokens distributed in hypersphere (standard init)
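
The dead-token statistics above are close to what an isotropic Gaussian predicts. As a sanity check, the sketch below computes the expected vector norm for N(0, σ²I) in 64 dimensions (the mean of a scaled chi distribution) and the typical norm of the centroid of 3,699 such vectors. It ignores bfloat16 rounding and the small effect of subtracting the sample centroid.

```python
import math

sigma, d, n_dead = 0.02, 64, 3699

# Mean norm of a d-dimensional N(0, sigma^2 I) vector: sigma * sqrt(2) * Γ((d+1)/2) / Γ(d/2)
mean_radius = sigma * math.sqrt(2) * math.exp(math.lgamma((d + 1) / 2) - math.lgamma(d / 2))

# Common rule of thumb: sigma * sqrt(d)
rough_radius = sigma * math.sqrt(d)

# The centroid of n independent vectors is N(0, (sigma^2 / n) I); its RMS norm is sigma * sqrt(d / n)
centroid_norm = sigma * math.sqrt(d / n_dead)

print(f"expected mean radius:  {mean_radius:.6f}")
print(f"rule-of-thumb radius:  {rough_radius:.6f}")
print(f"typical centroid norm: {centroid_norm:.6f}")
```

Both predictions land near the measured values (mean radius 0.159301, centroid norm 0.002689), which supports reading the initial dead-token cloud as a thin shell of radius about σ√d around the origin.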




## Comprehensive Recorder (Modified for Untied Weights)

**Records both E and W matrices at every step.**

In [10]:
class UntiedRecorder:
    """Records E, W, gradients, optimizer state, logits, loss at every step in bfloat16."""
    
    def __init__(self, vocab_size, hidden_dim, record_every_n):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.record_every_n = record_every_n
        
        # Storage (lists of tensors, keep in RAM)
        self.recorded_steps = []
        self.embeddings_E = []    # Embedding matrix E
        self.embeddings_W = []    # Unembedding matrix W
        self.grads_E = []
        self.grads_W = []
        self.momentum_E = []
        self.momentum_W = []
        self.variance_E = []
        self.variance_W = []
        self.logits = []
        self.losses = []
        
        # Temporary storage
        self.current_step = 0
        self.recorded_initial = False
        self.grad_E_before = None
        self.grad_W_before = None
        self.loss_value = None
        self.logits_sample = None
    
    def record_initial_state(self, model, optimizer):
        """Record step 0: initial state before training."""
        if not self.recorded_initial:
            E = model.transformer.wte.weight.data.clone().cpu().bfloat16()
            W = model.lm_head.weight.data.clone().cpu().bfloat16()
            
            # Step 0: no gradients, no optimizer state yet (zeros)
            self.recorded_steps.append(0)
            self.embeddings_E.append(E)
            self.embeddings_W.append(W)
            self.grads_E.append(torch.zeros_like(E))
            self.grads_W.append(torch.zeros_like(W))
            self.momentum_E.append(torch.zeros_like(E))
            self.momentum_W.append(torch.zeros_like(W))
            self.variance_E.append(torch.zeros_like(E))
            self.variance_W.append(torch.zeros_like(W))
            self.logits.append(torch.zeros(self.vocab_size, dtype=torch.bfloat16))
            self.losses.append(torch.tensor(float('nan'), dtype=torch.bfloat16))
            
            self.recorded_initial = True
            self.current_step = 1
            
            print(f"✓ Recorded initial state (step 0)")
            print(f"  E and W are {'IDENTICAL' if torch.allclose(E.float(), W.float()) else 'DIFFERENT'}")
    
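    def record_before_step(self, model, loss, logits):
        """Call after forward/backward, before optimizer step.

        Gradients are captured as left by backward(), i.e. before the Trainer
        applies gradient clipping (max_grad_norm defaults to 1.0).
        """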
        if self.current_step % self.record_every_n == 0:
            # Capture gradients for both E and W in bfloat16
            if model.transformer.wte.weight.grad is not None:
                self.grad_E_before = model.transformer.wte.weight.grad.clone().cpu().bfloat16()
            else:
                self.grad_E_before = torch.zeros(self.vocab_size, self.hidden_dim, dtype=torch.bfloat16)
            
            if model.lm_head.weight.grad is not None:
                self.grad_W_before = model.lm_head.weight.grad.clone().cpu().bfloat16()
            else:
                self.grad_W_before = torch.zeros(self.vocab_size, self.hidden_dim, dtype=torch.bfloat16)
            
            # Capture loss
            self.loss_value = loss.item()
            
            # Capture logits from first sequence, last position in bfloat16
            self.logits_sample = logits[0, -1, :].detach().cpu().bfloat16()
    
    def record_after_step(self, model, optimizer):
        """Call after optimizer step."""
        if self.current_step % self.record_every_n == 0:
            if self.grad_E_before is not None and self.grad_W_before is not None and self.loss_value is not None:
                # Capture E and W in bfloat16
                E = model.transformer.wte.weight.data.clone().cpu().bfloat16()
                W = model.lm_head.weight.data.clone().cpu().bfloat16()

                # Capture optimizer state for E
                param_E = model.transformer.wte.weight
                if param_E in optimizer.state:
                    state = optimizer.state[param_E]
                    mom_E = state.get('exp_avg', torch.zeros_like(E)).clone().cpu().bfloat16()
                    var_E = state.get('exp_avg_sq', torch.zeros_like(E)).clone().cpu().bfloat16()
                else:
                    mom_E = torch.zeros_like(E)
                    var_E = torch.zeros_like(E)

                # Capture optimizer state for W
                param_W = model.lm_head.weight
                if param_W in optimizer.state:
                    state = optimizer.state[param_W]
                    mom_W = state.get('exp_avg', torch.zeros_like(W)).clone().cpu().bfloat16()
                    var_W = state.get('exp_avg_sq', torch.zeros_like(W)).clone().cpu().bfloat16()
                else:
                    mom_W = torch.zeros_like(W)
                    var_W = torch.zeros_like(W)

                # Store everything
                self.recorded_steps.append(self.current_step)
                self.embeddings_E.append(E)
                self.embeddings_W.append(W)
                self.grads_E.append(self.grad_E_before)
                self.grads_W.append(self.grad_W_before)
                self.momentum_E.append(mom_E)
                self.momentum_W.append(mom_W)
                self.variance_E.append(var_E)
                self.variance_W.append(var_W)
                self.logits.append(self.logits_sample)
                self.losses.append(torch.tensor(self.loss_value, dtype=torch.bfloat16))

                # Clear temp storage
                self.grad_E_before = None
                self.grad_W_before = None
                self.loss_value = None
                self.logits_sample = None
                
                # Progress indicator every 50 steps
                if self.current_step % 50 == 0:
                    print(f"  Recorded step {self.current_step}")

        self.current_step += 1
    
    def get_data(self):
        """Return recorded data as stacked tensors."""
        print(f"\nStacking {len(self.embeddings_E)} recorded states...")
        
        return {
            'recorded_steps': torch.tensor(self.recorded_steps, dtype=torch.long),
            'embeddings_E': torch.stack(self.embeddings_E) if self.embeddings_E else torch.tensor([]),
            'embeddings_W': torch.stack(self.embeddings_W) if self.embeddings_W else torch.tensor([]),
            'grads_E': torch.stack(self.grads_E) if self.grads_E else torch.tensor([]),
            'grads_W': torch.stack(self.grads_W) if self.grads_W else torch.tensor([]),
            'momentum_E': torch.stack(self.momentum_E) if self.momentum_E else torch.tensor([]),
            'momentum_W': torch.stack(self.momentum_W) if self.momentum_W else torch.tensor([]),
            'variance_E': torch.stack(self.variance_E) if self.variance_E else torch.tensor([]),
            'variance_W': torch.stack(self.variance_W) if self.variance_W else torch.tensor([]),
            'logits': torch.stack(self.logits) if self.logits else torch.tensor([]),
            'losses': torch.stack(self.losses) if self.losses else torch.tensor([]),
        }

print("✓ UntiedRecorder class defined")

✓ UntiedRecorder class defined


## Custom Trainer with Instrumentation

In [11]:
class InstrumentedTrainer(Trainer):
    def __init__(self, recorder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.recorder = recorder
        self.last_logits = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Override to capture logits."""
        outputs = model(**inputs)
        loss = outputs.loss
        
        # Store logits for recorder
        self.last_logits = outputs.logits
        
        return (loss, outputs) if return_outputs else loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Override to inject recording."""
        # Standard forward + backward
        loss = super().training_step(model, inputs, num_items_in_batch)
        
        # Record BEFORE optimizer step
        self.recorder.record_before_step(model, loss, self.last_logits)
        
        return loss

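    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, **kwargs):
        """Override to record AFTER optimizer step."""
        # NOTE: this is a private Trainer hook; its signature has changed across
        # transformers releases, hence the start_time default and **kwargs.
        # Record AFTER optimizer updates parameters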
        self.recorder.record_after_step(model, self.optimizer)
        
        # Call parent
        super()._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, **kwargs)

print("✓ InstrumentedTrainer defined")

✓ InstrumentedTrainer defined


## Training Configuration

In [12]:
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

recorder = UntiedRecorder(VOCAB_SIZE, HIDDEN_DIM, RECORD_EVERY_N_STEPS)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
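    learning_rate=LEARNING_RATE,  # Peak LR; the Trainer's default linear schedule decays it to 0 by max_steps
    weight_decay=WEIGHT_DECAY,
    adam_beta1=ADAM_BETA1,
    adam_beta2=ADAM_BETA2,
    adam_epsilon=ADAM_EPSILON,
    optim="adamw_torch",  # AdamW; with WEIGHT_DECAY = 0.0 the update is the same as plain Adam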
    logging_steps=50,
    save_steps=NUM_TRAIN_STEPS + 1,  # Don't save checkpoints
    save_total_limit=0,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    bf16=True,  # Native bfloat16 training
    seed=RANDOM_SEED,
    report_to="none",
    disable_tqdm=False,
)

trainer = InstrumentedTrainer(
    recorder=recorder,
    model=model,
    args=training_args,
    train_dataset=dataset,
)

print(f"\n✓ Trainer ready (Adam, bf16=True, batch_size={BATCH_SIZE}, UNTIED WEIGHTS)")


✓ Trainer ready (Adam, bf16=True, batch_size=32, UNTIED WEIGHTS)


## Record Initial State

In [13]:
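print()
# trainer.optimizer is still None at this point (the Trainer creates it inside
# train()); record_initial_state does not read it and stores zeros for step 0.
recorder.record_initial_state(model, trainer.optimizer)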


✓ Recorded initial state (step 0)
  E and W are IDENTICAL


## Train

**1,000 steps should take ~30 seconds on an M4 Pro (slightly slower than the tied run, because E and W are now separate parameters with their own gradients and optimizer state).**

In [14]:
print(f"\n{'='*80}")
print(f"STARTING FLANNEL 2 TRAINING (UNTIED WEIGHTS)")
print(f"{'='*80}")
print(f"\nConfiguration:")
print(f"  Vocabulary: {VOCAB_SIZE:,} tokens (Flannel)")
print(f"  Hidden dim: {HIDDEN_DIM}")
print(f"  Live tokens: {n_live:,} ({100*n_live/VOCAB_SIZE:.1f}%)")
print(f"  Dead tokens: {n_dead:,} ({100*n_dead/VOCAB_SIZE:.1f}%)")
print()
print(f"  Initialization: N(0, {INIT_SCALE}) bfloat16-native")
print(f"  Optimizer: Adam (lr={LEARNING_RATE})")
print(f"  Precision: bfloat16 (native)")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Steps: {NUM_TRAIN_STEPS:,} (~{expected_epochs:.1f} epochs)")
print(f"  Recording: every step (BOTH E AND W)")
print(f"  Weight tying: DISABLED (E ≠ W)")
print(f"\n{'='*80}\n")

start_time = time.time()
trainer.train()
elapsed = time.time() - start_time

print(f"\n{'='*80}")
print(f"✓ Training complete")
print(f"  Elapsed time: {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
print(f"  Throughput: {NUM_TRAIN_STEPS / elapsed:.1f} steps/second")
print(f"{'='*80}")


STARTING FLANNEL 2 TRAINING (UNTIED WEIGHTS)

Configuration:
  Vocabulary: 10,000 tokens (Flannel)
  Hidden dim: 64
  Live tokens: 6,301 (63.0%)
  Dead tokens: 3,699 (37.0%)

  Initialization: N(0, 0.02) bfloat16-native
  Optimizer: Adam (lr=0.001)
  Precision: bfloat16 (native)
  Batch size: 32
  Steps: 1,000 (~3.0 epochs)
  Recording: every step (BOTH E AND W)
  Weight tying: DISABLED (E ≠ W)




`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
50,8.0354
100,7.1546
150,7.1153
200,7.103
250,7.0892
300,7.0379
350,7.0056
400,6.9657
450,6.9383
500,6.9144


  Recorded step 50
  Recorded step 100
  Recorded step 150
  Recorded step 200
  Recorded step 250
  Recorded step 300
  Recorded step 350
  Recorded step 400
  Recorded step 450
  Recorded step 500
  Recorded step 550
  Recorded step 600
  Recorded step 650
  Recorded step 700
  Recorded step 750
  Recorded step 800
  Recorded step 850
  Recorded step 900
  Recorded step 950
  Recorded step 1000

✓ Training complete
  Elapsed time: 28.5 seconds (0.5 minutes)
  Throughput: 35.1 steps/second


## Save Recorded Data

In [15]:
print(f"\nPreparing data for save...\n")

recorded_data = recorder.get_data()

save_dict = {
    'recorded_steps': recorded_data['recorded_steps'],
    'embeddings_E': recorded_data['embeddings_E'],
    'embeddings_W': recorded_data['embeddings_W'],
    'grads_E': recorded_data['grads_E'],
    'grads_W': recorded_data['grads_W'],
    'momentum_E': recorded_data['momentum_E'],
    'momentum_W': recorded_data['momentum_W'],
    'variance_E': recorded_data['variance_E'],
    'variance_W': recorded_data['variance_W'],
    'logits': recorded_data['logits'],
    'losses': recorded_data['losses'],
    # Metadata
    'init_scale': torch.tensor(INIT_SCALE, dtype=torch.float32),
    'learning_rate': torch.tensor(LEARNING_RATE, dtype=torch.float32),
    'weight_decay': torch.tensor(WEIGHT_DECAY, dtype=torch.float32),
    'adam_beta1': torch.tensor(ADAM_BETA1, dtype=torch.float32),
    'adam_beta2': torch.tensor(ADAM_BETA2, dtype=torch.float32),
    'n_live': torch.tensor(n_live, dtype=torch.long),
    'n_dead': torch.tensor(n_dead, dtype=torch.long),
    'tied_weights': torch.tensor(False, dtype=torch.bool),  # Flag for analysis
}

output_path = Path(OUTPUT_DIR) / OUTPUT_FILE

print(f"Saving to: {output_path}")

save_start = time.time()
save_file(save_dict, str(output_path))
save_elapsed = time.time() - save_start

file_size_mb = output_path.stat().st_size / 1e6

print(f"\n✓ Saved successfully")
print(f"  File: {output_path}")
print(f"  Size: {file_size_mb:.1f} MB")
print(f"  Save time: {save_elapsed:.1f} seconds")
print(f"  Recorded steps: {len(recorded_data['recorded_steps'])}")
print(f"  Step range: {recorded_data['recorded_steps'][0]} to {recorded_data['recorded_steps'][-1]}")


Preparing data for save...


Stacking 1001 recorded states...
Saving to: ../tensors/Flannel/1.20b_flannel_2.safetensors

✓ Saved successfully
  File: ../tensors/Flannel/1.20b_flannel_2.safetensors
  Size: 10270.3 MB
  Save time: 14.1 seconds
  Recorded steps: 1001
  Step range: 0 to 1000
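
The file size follows directly from the recording scheme, which is worth checking before scaling the experiment up. The sketch below counts only the large tensors (eight vocabulary-sized matrix series plus the logit samples, all in bfloat16) and ignores the small metadata entries and the safetensors header.

```python
steps_recorded, V, D = 1001, 10_000, 64
bytes_bf16 = 2

matrix_series = 8  # E, W, and the gradient, momentum, and variance series for each
matrix_bytes = matrix_series * steps_recorded * V * D * bytes_bf16
logit_bytes = steps_recorded * V * bytes_bf16  # one logit vector per recorded step
total_bytes = matrix_bytes + logit_bytes

print(f"matrices: {matrix_bytes / 1e6:,.1f} MB")
print(f"logits:   {logit_bytes / 1e6:,.1f} MB")
print(f"total:    {total_bytes / 1e6:,.1f} MB")
```

The estimate matches the 10270.3 MB reported above. Storage grows linearly with the number of recorded steps, so raising `RECORD_EVERY_N_STEPS` is the most direct way to shrink the file for longer runs.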


## Quick Verification

In [16]:
print(f"\n{'='*80}")
print(f"QUICK VERIFICATION")
print(f"{'='*80}\n")

E = recorded_data['embeddings_E']
W = recorded_data['embeddings_W']
losses = recorded_data['losses']

print(f"Data shapes:")
print(f"  E (embedding): {E.shape}")
print(f"  W (unembedding): {W.shape}")
print(f"  losses: {losses.shape}")
print()

# Check E vs W divergence
print(f"E vs W divergence over time:")
for t in [0, 1, 10, 100, 1000]:
    if t < len(E):
        diff = torch.norm(E[t].float() - W[t].float(), p='fro').item()
        print(f"  Step {t:4d}: ||E - W||_F = {diff:.4f}")
print()

# Analyze dead token movement IN W (unembedding)
W_step0 = W[0, dead_indices].float()
W_step1000 = W[-1, dead_indices].float()

displacements_W = torch.norm(W_step1000 - W_step0, dim=1)

print(f"Dead token displacement in W (steps 0 → {NUM_TRAIN_STEPS}):")
print(f"  Max: {displacements_W.max().item():.2e}")
print(f"  Mean: {displacements_W.mean().item():.2e}")
print(f"  Median: {displacements_W.median().item():.2e}")
print()

# Analyze dead token movement IN E (embedding)
E_step0 = E[0, dead_indices].float()
E_step1000 = E[-1, dead_indices].float()

displacements_E = torch.norm(E_step1000 - E_step0, dim=1)

print(f"Dead token displacement in E (steps 0 → {NUM_TRAIN_STEPS}):")
print(f"  Max: {displacements_E.max().item():.2e}")
print(f"  Mean: {displacements_E.mean().item():.2e}")
print(f"  Median: {displacements_E.median().item():.2e}")
print()

# Compare to live tokens
W_live_step0 = W[0, live_indices].float()
W_live_step1000 = W[-1, live_indices].float()
live_displacements_W = torch.norm(W_live_step1000 - W_live_step0, dim=1)

E_live_step0 = E[0, live_indices].float()
E_live_step1000 = E[-1, live_indices].float()
live_displacements_E = torch.norm(E_live_step1000 - E_live_step0, dim=1)

print(f"Live token displacement in W (steps 0 → {NUM_TRAIN_STEPS}):")
print(f"  Max: {live_displacements_W.max().item():.2e}")
print(f"  Mean: {live_displacements_W.mean().item():.2e}")
print(f"  Median: {live_displacements_W.median().item():.2e}")
print()

print(f"Live token displacement in E (steps 0 → {NUM_TRAIN_STEPS}):")
print(f"  Max: {live_displacements_E.max().item():.2e}")
print(f"  Mean: {live_displacements_E.mean().item():.2e}")
print(f"  Median: {live_displacements_E.median().item():.2e}")
print()

# Losses were stored as bfloat16, so the values below carry only about 3 significant digits
print(f"Loss trajectory:")
print(f"  Step 1: {losses[1].float().item():.4f}")
print(f"  Step {NUM_TRAIN_STEPS}: {losses[-1].float().item():.4f}")
print(f"  Reduction: {(losses[1].float() - losses[-1].float()).item():.4f}")

print(f"\n{'='*80}")


QUICK VERIFICATION

Data shapes:
  E (embedding): torch.Size([1001, 10000, 64])
  W (unembedding): torch.Size([1001, 10000, 64])
  losses: torch.Size([1001])

E vs W divergence over time:
  Step    0: ||E - W||_F = 0.0000
  Step    1: ||E - W||_F = 0.8465
  Step   10: ||E - W||_F = 5.8140
  Step  100: ||E - W||_F = 30.5975
  Step 1000: ||E - W||_F = 42.8370

Dead token displacement in W (steps 0 → 1000):
  Max: 6.12e-01
  Mean: 5.60e-01
  Median: 5.60e-01

Dead token displacement in E (steps 0 → 1000):
  Max: 0.00e+00
  Mean: 0.00e+00
  Median: 0.00e+00

Live token displacement in W (steps 0 → 1000):
  Max: 7.75e-01
  Mean: 2.40e-01
  Median: 2.30e-01

Live token displacement in E (steps 0 → 1000):
  Max: 7.27e-01
  Mean: 2.33e-01
  Median: 2.27e-01

Loss trajectory:
  Step 1: 9.2500
  Step 1000: 6.7188
  Reduction: 2.5312



## Summary

In [17]:
print(f"\n{'='*80}")
print(f"FLANNEL 2 COMPLETE")
print(f"{'='*80}\n")

print(f"Experiment: Untied weights control (1,000 steps)")
print(f"  Tokenizer: Flannel (10k vocab, char-level BPE)")
print(f"  Corpus: 5MB English-only FineWeb")
print(f"  Training: {NUM_TRAIN_STEPS:,} steps (~{expected_epochs:.1f} epochs)")
print(f"  Architecture: 2 layers, 2 heads, 64D, bfloat16-native")
print(f"  Weight tying: DISABLED (E ≠ W)")
print()
print(f"Token demographics:")
print(f"  Live: {n_live:,} tokens ({100*n_live/VOCAB_SIZE:.1f}%)")
print(f"  Dead: {n_dead:,} tokens ({100*n_dead/VOCAB_SIZE:.1f}%)")
print()
print(f"Data saved: {output_path}")
print(f"Size: {file_size_mb:.1f} MB")
print()
print(f"Key differences from Flannel 1:")
print(f"  - E and W evolve independently")
print(f"  - Dead tokens in E should get zero gradient (never in input)")
print(f"  - Dead tokens in W should get gradients (unembedding backscatter)")
print(f"  - Allows testing: does W₁[t] ≈ W₂[t] + E₂[t]?")
print()
print(f"Next steps:")
print(f"  1. Compare E vs W dynamics (comoving frame for each)")
print(f"  2. Test linear decomposition: W_flannel1 ≈ W_flannel2 + E_flannel2")
print(f"  3. Check if inflationary expansion persists in W when untied")
print(f"  4. Analyze gradient magnitudes for dead tokens in E vs W")

print(f"\n{'='*80}")


FLANNEL 2 COMPLETE

Experiment: Untied weights control (1,000 steps)
  Tokenizer: Flannel (10k vocab, char-level BPE)
  Corpus: 5MB English-only FineWeb
  Training: 1,000 steps (~3.0 epochs)
  Architecture: 2 layers, 2 heads, 64D, bfloat16-native
  Weight tying: DISABLED (E ≠ W)

Token demographics:
  Live: 6,301 tokens (63.0%)
  Dead: 3,699 tokens (37.0%)

Data saved: ../tensors/Flannel/1.20b_flannel_2.safetensors
Size: 10270.3 MB

Key differences from Flannel 1:
  - E and W evolve independently
  - Dead tokens in E should get zero gradient (never in input)
  - Dead tokens in W should get gradients (unembedding backscatter)
  - Allows testing: does W₁[t] ≈ W₂[t] + E₂[t]?

Next steps:
  1. Compare E vs W dynamics (comoving frame for each)
  2. Test linear decomposition: W_flannel1 ≈ W_flannel2 + E_flannel2
  3. Check if inflationary expansion persists in W when untied
  4. Analyze gradient magnitudes for dead tokens in E vs W

