# 1.22e: Sanity Check - Flannel 1 vs Flannel 3

**Goal:** Diagnose why Flannel 3 shows completely different behavior than Flannel 1 despite using the same seed (42) for run 0.

## The Problem

**Flannel 1 (seed=42):**
- Mean radius: 0.159 → 0.565 (expansion factor: 3.55×)
- Five distinct epochs visible
- Fimbulwinter at t≈400

**Flannel 3 run 0 (seed=42):**
- Mean radius: 0.159 → 0.114 (expansion factor: 0.72× — contraction!)
- Different epoch structure
- "Fimbulwinter" at t≈101

These should be identical if the seed is truly the same. Something is wrong.

## Diagnostic Plan

1. Load both datasets
2. Check if dead token indices match
3. Compare initial embeddings (t=0) between runs
4. Compare embeddings at t=1, 10, 50, 100, 500, and 1000
5. Check if runs within Flannel 3 actually differ from each other
6. Identify the methodological error

## Parameters

In [1]:
# Data paths
FLANNEL_1_PATH = "../tensors/Flannel/1.20a_flannel_1.safetensors"
FLANNEL_3_PATH = "../tensors/Flannel/1.20c_flannel_3.safetensors"

print("✓ Parameters set")

✓ Parameters set


## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file

print("✓ Imports complete")

✓ Imports complete


## Load Flannel 1

In [3]:
print(f"Loading Flannel 1: {FLANNEL_1_PATH}\n")

f1_data = load_file(FLANNEL_1_PATH)

f1_embeddings = f1_data['embeddings'].to(torch.float32)  # (1001, 10000, 64)
f1_n_dead = f1_data['n_dead'].item()

print(f"✓ Loaded Flannel 1")
print(f"  Shape: {f1_embeddings.shape}")
print(f"  Dead tokens: {f1_n_dead}")
print()

Loading Flannel 1: ../tensors/Flannel/1.20a_flannel_1.safetensors

✓ Loaded Flannel 1
  Shape: torch.Size([1001, 10000, 64])
  Dead tokens: 3699



## Load Flannel 3

In [4]:
print(f"Loading Flannel 3: {FLANNEL_3_PATH}\n")

f3_data = load_file(FLANNEL_3_PATH)

f3_embeddings = f3_data['embeddings'].to(torch.float32)  # expected (10, 1001, 3699, 64); actual shape is printed below
f3_dead_indices = f3_data['dead_indices']
f3_n_runs = f3_data['n_runs'].item()
f3_n_dead = f3_data['n_dead'].item()
f3_base_seed = f3_data['base_seed'].item()

print(f"✓ Loaded Flannel 3")
print(f"  Shape: {f3_embeddings.shape}")
print(f"  Runs: {f3_n_runs}")
print(f"  Dead tokens: {f3_n_dead}")
print(f"  Seeds: {f3_base_seed}–{f3_base_seed + f3_n_runs - 1}")
print()

Loading Flannel 3: ../tensors/Flannel/1.20c_flannel_3.safetensors

✓ Loaded Flannel 3
  Shape: torch.Size([10, 1002, 3699, 64])
  Runs: 10
  Dead tokens: 3699
  Seeds: 42–51
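
The loaded shape has 1002 snapshots on the time axis rather than the 1001 seen in Flannel 1. A small check at load time could flag this kind of mismatch early. The sketch below is illustrative only: `check_trajectory_shape` is a hypothetical helper, and it assumes one snapshot per training step plus one at t=0.

```python
import torch

def check_trajectory_shape(embeddings, n_runs, n_steps, n_dead, dim):
    """Return {axis: (expected, actual)} for every axis whose size is unexpected."""
    # Assumption: one snapshot per training step plus the t=0 snapshot.
    expected = (n_runs, n_steps + 1, n_dead, dim)
    actual = tuple(embeddings.shape)
    if len(actual) != len(expected):
        return {"ndim": (len(expected), len(actual))}
    return {axis: (e, a) for axis, (e, a) in enumerate(zip(expected, actual)) if e != a}

# Toy tensor with one extra snapshot on the time axis (axis 1).
toy = torch.zeros(2, 12, 3, 4)
mismatches = check_trajectory_shape(toy, n_runs=2, n_steps=10, n_dead=3, dim=4)
print(mismatches)  # {1: (11, 12)}
```

An empty dictionary means every axis has the expected size.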



## Check 1: Dead Token Indices

Do both experiments use the same dead token mask?

In [5]:
print("Check 1: Dead token indices\n")

# Load the canonical mask
mask_data = load_file("../tensors/Flannel/live_dead_tokens.safetensors")
canonical_dead_indices = mask_data['dead_indices']

print(f"Canonical dead indices: {len(canonical_dead_indices)}")
print(f"Flannel 3 dead indices: {len(f3_dead_indices)}")
print()

if torch.equal(canonical_dead_indices, f3_dead_indices):
    print("✓ Dead indices match")
else:
    print("✗ Dead indices DO NOT match!")
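    # Report both directions; a one-sided set difference can hide extra indices
    only_canonical = set(canonical_dead_indices.tolist()) - set(f3_dead_indices.tolist())
    only_f3 = set(f3_dead_indices.tolist()) - set(canonical_dead_indices.tolist())
    print(f"  Only in canonical: {len(only_canonical)} tokens, only in Flannel 3: {len(only_f3)} tokens")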

Check 1: Dead token indices

Canonical dead indices: 3699
Flannel 3 dead indices: 3699

✓ Dead indices match
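
`torch.equal` also fails when the two index tensors hold the same indices in a different order. If that ever needs to be told apart from a real mismatch, an order-insensitive comparison along these lines might help. `compare_index_sets` is a hypothetical helper and the tensors are toy data, not the real masks.

```python
import torch

def compare_index_sets(a: torch.Tensor, b: torch.Tensor):
    """Return (only_in_a, only_in_b) as sorted lists of ints, ignoring order."""
    set_a, set_b = set(a.tolist()), set(b.tolist())
    return sorted(set_a - set_b), sorted(set_b - set_a)

canonical = torch.tensor([3, 7, 11, 42])
recorded = torch.tensor([42, 3, 7, 99])   # different order, one index replaced

only_canonical, only_recorded = compare_index_sets(canonical, recorded)
print(torch.equal(canonical, recorded))   # False
print(only_canonical, only_recorded)      # [11] [99]
```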


## Check 2: Extract Dead Tokens from Flannel 1

Flannel 1 saved the full W matrix. Let's extract just the dead tokens for comparison.

In [6]:
print("Check 2: Extracting dead tokens from Flannel 1\n")

# Extract dead tokens from full W matrix
f1_dead = f1_embeddings[:, f3_dead_indices, :]  # (1001, 3699, 64)

print(f"✓ Extracted dead tokens from Flannel 1")
print(f"  Shape: {f1_dead.shape}")
print()

# Note: Flannel 3 has shape (10, 1002, 3699, 64) - extra timestep!
print(f"Flannel 1 timesteps: {f1_dead.shape[0]}")
print(f"Flannel 3 timesteps: {f3_embeddings.shape[1]}")
print(f"  → Off-by-one error in Flannel 3 recorder!")
print()

# Trim Flannel 3 to match. Keeping the first 1001 timesteps assumes the extra
# snapshot sits at the end; Check 3 below verifies the alignment at t=0.
f3_run0 = f3_embeddings[0, :1001, :, :]  # (1001, 3699, 64)
print(f"✓ Trimmed Flannel 3 run 0 to match: {f3_run0.shape}")

Check 2: Extracting dead tokens from Flannel 1

✓ Extracted dead tokens from Flannel 1
  Shape: torch.Size([1001, 3699, 64])

Flannel 1 timesteps: 1001
Flannel 3 timesteps: 1002
  → Off-by-one error in Flannel 3 recorder!

✓ Trimmed Flannel 3 run 0 to match: torch.Size([1001, 3699, 64])
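
Trimming to the first 1001 snapshots assumes the extra one sits at the end of the recording. One way to test that assumption is to slide the shorter trajectory along the longer one and see which offset fits best. This is a sketch on toy data; `best_offset` is a hypothetical helper, and on the real tensors each candidate offset allocates a full-size difference tensor.

```python
import torch

def best_offset(ref, longer):
    """Return the offset k minimising ||ref - longer[k:k+T]||, plus all candidate errors."""
    T = ref.shape[0]
    extra = longer.shape[0] - T
    errors = [torch.norm(ref - longer[k:k + T]).item() for k in range(extra + 1)]
    best = min(range(len(errors)), key=errors.__getitem__)
    return best, errors

torch.manual_seed(0)
ref = torch.randn(6, 4, 3).cumsum(dim=0)      # toy trajectory with T=6 snapshots
longer = torch.cat([ref[:1], ref], dim=0)     # toy recorder bug: t=0 saved twice

offset, errors = best_offset(ref, longer)
print(offset, errors)
```

In this toy case the duplicate is at the start, so offset 1 gives a zero error. An offset of 0 on the real data would support trimming from the end. Because the two runs diverge after t=0, the real errors would not reach zero, so only the relative sizes are informative.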


## Check 3: Compare Initial Embeddings (t=0)

If seed=42 was used for both, the initial embeddings should be identical.

In [7]:
print("Check 3: Comparing initial embeddings (t=0)\n")

f1_t0 = f1_dead[0]  # (3699, 64)
f3_t0 = f3_run0[0]  # (3699, 64)

# Check if identical
diff = torch.norm(f1_t0 - f3_t0)

print(f"Frobenius norm of difference: {diff:.6e}")
print()

if diff < 1e-6:
    print("✓ Initial embeddings are IDENTICAL (seeds match)")
else:
    print("✗ Initial embeddings DIFFER (seeds don't match!)")
    print(f"  Max absolute difference: {(f1_t0 - f3_t0).abs().max():.6e}")
    print(f"  Mean absolute difference: {(f1_t0 - f3_t0).abs().mean():.6e}")

Check 3: Comparing initial embeddings (t=0)

Frobenius norm of difference: 0.000000e+00

✓ Initial embeddings are IDENTICAL (seeds match)


## Check 4: Compare Embeddings at Key Timesteps

In [8]:
print("Check 4: Comparing embeddings at key timesteps\n")

test_steps = [1, 10, 50, 100, 500, 1000]

for t in test_steps:
    f1_t = f1_dead[t]
    f3_t = f3_run0[t]
    
    diff = torch.norm(f1_t - f3_t)
    
    if diff < 1e-4:
        status = "✓ MATCH"
    else:
        status = "✗ DIFFER"
    
    print(f"t={t:4d}: diff={diff:8.6f} {status}")

print()
print("If all diffs are near zero: runs are identical (good!)")
print("If diffs grow over time: runs diverged (bad!)")

Check 4: Comparing embeddings at key timesteps

t=   1: diff=0.004673 ✗ DIFFER
t=  10: diff=0.029129 ✗ DIFFER
t=  50: diff=0.210857 ✗ DIFFER
t= 100: diff=0.203653 ✗ DIFFER
t= 500: diff=0.481526 ✗ DIFFER
t=1000: diff=0.482891 ✗ DIFFER

If all diffs are near zero: runs are identical (good!)
If diffs grow over time: runs diverged (bad!)
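
Spot-checking six timesteps shows that the runs differ but not exactly when they start to differ or how large the gap is relative to the embeddings themselves. A per-timestep profile answers both. This is a minimal sketch on toy data; `divergence_profile` is a hypothetical helper, and the `1e-4` tolerance mirrors the threshold used in the cell above.

```python
import torch

def divergence_profile(a, b, tol=1e-4):
    """Per-timestep Frobenius difference, relative difference, and first step above tol."""
    diff = (a - b).flatten(start_dim=1).norm(dim=1)               # shape (T,)
    scale = a.flatten(start_dim=1).norm(dim=1).clamp_min(1e-12)   # avoid division by zero
    rel = diff / scale
    above = torch.nonzero(diff > tol).flatten()
    first = int(above[0]) if above.numel() > 0 else None
    return diff, rel, first

# Toy trajectories: identical for t < 3, then offset by a constant.
torch.manual_seed(0)
a = torch.randn(8, 5, 4)
b = a.clone()
b[3:] += 0.01

diff, rel, first = divergence_profile(a, b)
print(first)  # 3
```

Applied to `f1_dead` and `f3_run0`, the first index above tolerance would be 1 at the latest, since the t=1 difference is already 0.004673.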


## Check 5: Variance Within Flannel 3

Do the 10 runs in Flannel 3 actually differ from each other?

In [9]:
print("Check 5: Variance within Flannel 3 runs\n")

# Compare run 0 vs run 1 at t=0
f3_r0_t0 = f3_embeddings[0, 0]  # (3699, 64)
f3_r1_t0 = f3_embeddings[1, 0]  # (3699, 64)

diff_runs = torch.norm(f3_r0_t0 - f3_r1_t0)

print(f"Difference between run 0 and run 1 at t=0: {diff_runs:.6e}")
print()

if diff_runs < 1e-6:
    print("✗ Runs are IDENTICAL at t=0 (should differ with different seeds!)")
    print("   → Flannel 3 likely used the same seed for all runs")
else:
    print("✓ Runs differ at t=0 (expected with different seeds)")
    print(f"  Max difference per element: {(f3_r0_t0 - f3_r1_t0).abs().max():.6e}")
    print(f"  Mean difference per element: {(f3_r0_t0 - f3_r1_t0).abs().mean():.6e}")
print()

# Check all pairwise differences at t=0
print("Pairwise differences at t=0 (first 5 runs):")
for i in range(5):
    for j in range(i+1, 5):
        diff_ij = torch.norm(f3_embeddings[i, 0] - f3_embeddings[j, 0])
        print(f"  Run {i} vs Run {j}: {diff_ij:.6e}")

Check 5: Variance within Flannel 3 runs

Difference between run 0 and run 1 at t=0: 1.377979e+01

✓ Runs differ at t=0 (expected with different seeds)
  Max difference per element: 1.247559e-01
  Mean difference per element: 2.259804e-02

Pairwise differences at t=0 (first 5 runs):
  Run 0 vs Run 1: 1.377979e+01
  Run 0 vs Run 2: 1.375070e+01
  Run 0 vs Run 3: 1.377166e+01
  Run 0 vs Run 4: 1.374709e+01
  Run 1 vs Run 2: 1.378423e+01
  Run 1 vs Run 3: 1.375857e+01
  Run 1 vs Run 4: 1.373965e+01
  Run 2 vs Run 3: 1.376126e+01
  Run 2 vs Run 4: 1.376263e+01
  Run 3 vs Run 4: 1.376143e+01
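
The pairwise norms all sit near 13.76, which is what independent random initializations of the same scale would produce. As a rough cross-check, and assuming the entries are i.i.d. zero-mean Gaussian at t=0 (an assumption, not something the files state), the norm implies a standard deviation that can be used to predict the other numbers reported in this notebook:

```python
import math

N_DEAD, DIM = 3699, 64      # from the tensor shapes above
pair_norm = 1.377979e+01    # run 0 vs run 1 at t=0 (Check 5 output)

# If entries are i.i.d. N(0, sigma^2), each entry of the difference is N(0, 2 sigma^2),
# so E||A - B||_F^2 = 2 * sigma^2 * N_DEAD * DIM.
sigma = pair_norm / math.sqrt(2 * N_DEAD * DIM)

# E|x| for x ~ N(0, 2 sigma^2) is 2 * sigma / sqrt(pi).
pred_mean_abs = 2 * sigma / math.sqrt(math.pi)

# Mean radius is roughly sigma * E[chi_DIM], ignoring the small centroid correction.
mean_chi = math.sqrt(2) * math.exp(math.lgamma((DIM + 1) / 2) - math.lgamma(DIM / 2))
pred_radius = sigma * mean_chi

print(f"implied sigma:         {sigma:.5f}")
print(f"predicted mean |diff|: {pred_mean_abs:.5f}  (observed 0.02260)")
print(f"predicted mean radius: {pred_radius:.4f}  (observed 0.1593)")
```

The implied sigma comes out close to 0.02. The predicted mean absolute difference and mean radius land near the observed values, which is consistent with the runs being independent draws from the same initializer.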


## Check 6: Mean Radius Calculation

Maybe the radius was computed incorrectly? Recompute it directly from the dead-token embeddings of both datasets.

In [10]:
print("Check 6: Recompute mean radius for both\n")

def compute_mean_radius(embeddings):
    """Compute mean radius from centroid."""
    centroid = embeddings.mean(dim=0)
    radii = torch.norm(embeddings - centroid, dim=1)
    return radii.mean().item()

# Flannel 1
f1_r0 = compute_mean_radius(f1_dead[0])
f1_r1000 = compute_mean_radius(f1_dead[1000])

print(f"Flannel 1 (seed=42):")
print(f"  t=0:    {f1_r0:.6f}")
print(f"  t=1000: {f1_r1000:.6f}")
print(f"  Ratio:  {f1_r1000/f1_r0:.3f}×")
print()

# Flannel 3 run 0
f3_r0 = compute_mean_radius(f3_run0[0])
f3_r1000 = compute_mean_radius(f3_run0[1000])

print(f"Flannel 3 run 0 (seed=42):")
print(f"  t=0:    {f3_r0:.6f}")
print(f"  t=1000: {f3_r1000:.6f}")
print(f"  Ratio:  {f3_r1000/f3_r0:.3f}×")
print()

if abs(f1_r1000 - f3_r1000) < 0.01:
    print("✓ Radii match (data is consistent)")
else:
    print("✗ Radii DO NOT match (something is wrong!)")

Check 6: Recompute mean radius for both

Flannel 1 (seed=42):
  t=0:    0.159301
  t=1000: 0.115900
  Ratio:  0.728×

Flannel 3 run 0 (seed=42):
  t=0:    0.159301
  t=1000: 0.115145
  Ratio:  0.723×

✓ Radii match (data is consistent)
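
On dead tokens alone, Flannel 1 contracts to 0.1159, not the 0.565 quoted in the Problem section. One possible explanation, not confirmed anywhere in this notebook, is that the earlier figure was computed over a different set of tokens, such as the full W matrix. The toy example below shows how the same radius function can report expansion on all tokens and contraction on a subset. The sizes, scales, and index split are invented for illustration.

```python
import torch

def compute_mean_radius(embeddings):
    """Mean distance of each row from the centroid of all rows."""
    centroid = embeddings.mean(dim=0)
    return torch.norm(embeddings - centroid, dim=1).mean().item()

torch.manual_seed(0)
n_tokens, dim = 1000, 16
dead_idx = torch.arange(0, 400)            # hypothetical dead subset
w0 = 0.02 * torch.randn(n_tokens, dim)     # toy initial embeddings
w1 = w0.clone()
w1[dead_idx] *= 0.7                        # toy dynamics: dead tokens contract
w1[400:] *= 5.0                            # toy dynamics: live tokens expand

ratios = {}
for name, sel in [("all tokens", slice(None)), ("dead only", dead_idx)]:
    r0, r1 = compute_mean_radius(w0[sel]), compute_mean_radius(w1[sel])
    ratios[name] = r1 / r0
    print(f"{name:10s}: {r0:.4f} -> {r1:.4f} ({ratios[name]:.2f}x)")
```

Re-running the original Flannel 1 radius calculation with and without the dead-token selection would confirm or rule out this explanation.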


## Summary and Diagnosis

In [11]:
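print(f"\n{'='*80}")
print("DIAGNOSTIC SUMMARY")
print(f"{'='*80}\n")

print("Expected: Flannel 1 and Flannel 3 run 0 should be identical (same seed=42)")
print("Observed:")
print("  - Dead indices match and t=0 embeddings are identical (diff = 0)")
print("  - Embeddings differ from t=1 onward (diff 0.0047 at t=1, 0.48 at t=1000)")
print("  - Dead-token mean radius contracts in BOTH datasets (0.728× vs 0.723×)")
print("  - The 3.55× expansion quoted for Flannel 1 is not reproduced on dead tokens;")
print("    it was presumably computed over a different token set (e.g. full W matrix)")
print("  - Flannel 3 runs differ from each other at t=0, as expected")
print()

print("Likely issues:")
print("  1. Off-by-one error in Flannel 3 recorder (1002 timesteps instead of 1001)")
print("  2. Seeding is correct at initialization, but RNG state after t=0 may differ")
print("     between 1.20a and 1.20c (trajectories diverge from the first step)")
print("  3. Flannel 3 recorder only saved dead tokens, not full W matrix")
print("     → Can't compare gradients, optimizer state, etc.")
print()

print("Recommended fix:")
print("  Rewrite 1.20c as exact copy of 1.20a (full recorder), run 10 times sequentially")
print("  Expected file size: ~12 GB (acceptable)")
print("  This ensures perfect reproducibility and full data capture")

print(f"\n{'='*80}")


DIAGNOSTIC SUMMARY

Expected: Flannel 1 and Flannel 3 run 0 should be identical (same seed=42)
Observed:
  - Dead indices match and t=0 embeddings are identical (diff = 0)
  - Embeddings differ from t=1 onward (diff 0.0047 at t=1, 0.48 at t=1000)
  - Dead-token mean radius contracts in BOTH datasets (0.728× vs 0.723×)
  - The 3.55× expansion quoted for Flannel 1 is not reproduced on dead tokens;
    it was presumably computed over a different token set (e.g. full W matrix)
  - Flannel 3 runs differ from each other at t=0, as expected

Likely issues:
  1. Off-by-one error in Flannel 3 recorder (1002 timesteps instead of 1001)
  2. Seeding is correct at initialization, but RNG state after t=0 may differ
     between 1.20a and 1.20c (trajectories diverge from the first step)
  3. Flannel 3 recorder only saved dead tokens, not full W matrix
     → Can't compare gradients, optimizer state, etc.

Recommended fix:
  Rewrite 1.20c as exact copy of 1.20a (full recorder), run 10 times sequentially
  Expected file size: ~12 GB (acceptable)
  This ensures perfect reproducibility and full data capture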
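
Checks 3 and 4 together show identical embeddings at t=0 and a nonzero difference from t=1 onward. One mechanism that could produce this pattern is the two scripts consuming the random number generator differently between initialization and the first update. Whether that is what happened in 1.20a and 1.20c is not established here. The sketch below uses a toy noise-driven update and a hypothetical `extra_draws` parameter purely to illustrate the effect.

```python
import torch

def run(seed, extra_draws=0, steps=3):
    """Toy 'training': seeded init followed by noise-driven updates."""
    gen = torch.Generator().manual_seed(seed)
    w = 0.02 * torch.randn(5, 4, generator=gen)    # initialization
    _ = torch.rand(extra_draws, generator=gen)     # hypothetical extra RNG use after init
    history = [w.clone()]
    for _ in range(steps):
        w = w - 0.1 * torch.randn(5, 4, generator=gen)
        history.append(w.clone())
    return torch.stack(history)

a = run(42)
b = run(42)
c = run(42, extra_draws=1)

print(torch.equal(a, b))          # True: same seed, same RNG usage
print(torch.equal(a[0], c[0]))    # True: initialization still matches
print(torch.equal(a[1:], c[1:]))  # False: later steps see a shifted random stream
```

If the rewritten 1.20c reuses the 1.20a code path unchanged, the order of RNG calls is the same by construction, which is the property the recommended fix relies on.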

