# 11.2a: RMSNorm Diffusion Test

**Hypothesis:** RMSNorm (used in Qwen) causes black hole diffusion during training.

## Background

Our previous experiments (run_1001, 100k steps) showed **zero black hole diffusion** using the GPT-2 architecture with LayerNorm.

Qwen uses **RMSNorm** instead of LayerNorm. The difference:
- **LayerNorm**: `(x - mean(x)) / std(x)` — centers and normalizes
- **RMSNorm**: `x / sqrt(mean(x²))` — normalizes only, no centering

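Because RMSNorm rescales its input without centering it, it may produce different gradient dynamics than LayerNorm, and those dynamics could break the symmetry that keeps black hole tokens on a single shared vector.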
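
To make the centering difference concrete, here is a minimal plain-Python sketch of both normalizations without the learned scale or bias. The helper names `layer_norm` and `rms_norm` and the toy input are illustrative assumptions, not part of the experiment code.

```python
import math

def layer_norm(x, eps=1e-6):
    """Subtract the mean, then divide by the standard deviation."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    """Divide by the root mean square; the mean is left in place."""
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

x = [3.0, 4.0, 5.0, 6.0]  # toy activation with a non-zero mean
ln_out = layer_norm(x)
rms_out = rms_norm(x)

# LayerNorm output is centered at zero; RMSNorm output keeps a positive mean.
print("LayerNorm mean:", round(sum(ln_out) / len(ln_out), 6))
print("RMSNorm mean:  ", round(sum(rms_out) / len(rms_out), 6))
```

Any component shared by all hidden dimensions survives RMSNorm but is removed by LayerNorm, which is the property this experiment isolates.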

## Experimental Design

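**Control:** GPT-2 with LayerNorm → C = 1, P = 51 (no diffusion). Here C is the number of black holes (embedding vectors shared by two or more tokens) and P is the total number of tokens sitting in them.

**Treatment:** Same model but with RMSNorm → if C > 1 or P < 51, RMSNorm causes diffusion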

## Implementation

We'll create a custom GPT-2 model that replaces all LayerNorm layers with RMSNorm.

## Success Criteria

After 10,000 training steps:
- **Null result:** C = 1, P = 51 (same as control)
- **Positive result:** C > 1 or P < 51 (RMSNorm causes diffusion)
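
The decision rule above can be written as a small function. This is a sketch that mirrors the criteria as stated; the name `classify_result` is hypothetical, and the threshold for P is taken to be the number of dead tokens in the corpus (51 in this run).

```python
def classify_result(c: int, p: int, dead_tokens: int) -> str:
    """Return 'positive' if black holes fragmented or lost tokens, else 'null'."""
    fragmented = c > 1          # more than one distinct shared vector
    escaped = p < dead_tokens   # some dead tokens no longer share a vector
    return "positive" if (fragmented or escaped) else "null"

print(classify_result(c=1, p=51, dead_tokens=51))  # matches the control
print(classify_result(c=3, p=51, dead_tokens=51))  # fragmentation
print(classify_result(c=1, p=48, dead_tokens=51))  # tokens escaped
```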

## Parameters

In [1]:
# Model architecture (same as 08.2a)
VOCAB_SIZE = 128      # ASCII tokens
HIDDEN_DIM = 64       # Embedding dimension
N_LAYER = 2           # Transformer layers
N_HEAD = 2            # Attention heads
MAX_SEQ_LEN = 128     # Context window

# Initialization
INIT_MODE = "qwen"    # All tokens start at same point

# Training
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01
BATCH_SIZE = 32
NUM_TRAIN_STEPS = 10000  # 2× longer than 08.2a to give diffusion more time

# Data
CORPUS_PATH = "../data/training_corpus.txt"
OUTPUT_DIR = "../data/embeddings_128vocab_rmsnorm_test"
OUTPUT_FILE = "embedding_evolution.safetensors"

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments, TrainerCallback
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from safetensors.torch import save_file
from typing import Optional, Tuple

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

## Detect Hardware

In [3]:
if torch.cuda.is_available():
    DEVICE = "cuda"
    print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Using MPS (Apple Silicon)")
else:
    DEVICE = "cpu"
    print("Using CPU")

Using MPS (Apple Silicon)


## Define RMSNorm Layer

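RMSNorm implementation, following the Qwen reference. The layer computes the normalization in float32 and casts back to the input dtype before applying the learned scale.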

In [4]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.
    
    RMSNorm(x) = x / sqrt(mean(x²) + eps) * scale
    
    Unlike LayerNorm, this does NOT subtract the mean (no centering).
    """
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        
        # Compute RMS
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        
        return self.weight * hidden_states.to(input_dtype)

print("✓ RMSNorm defined")

✓ RMSNorm defined


## Replace LayerNorm with RMSNorm

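Walk the GPT-2 module tree and swap every `nn.LayerNorm` for an `RMSNorm` with the same width and epsilon. The new layers start with unit scale and have no bias, so any existing LayerNorm parameters are discarded.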

In [5]:
def replace_layernorm_with_rmsnorm(model):
    """Recursively replace all LayerNorm layers with RMSNorm."""
    for name, module in model.named_children():
        if isinstance(module, nn.LayerNorm):
            # Replace with RMSNorm
            rmsnorm = RMSNorm(
                hidden_size=module.normalized_shape[0],
                eps=module.eps
            )
            setattr(model, name, rmsnorm)
            print(f"  Replaced {name}: LayerNorm → RMSNorm")
        else:
            # Recurse into child modules
            replace_layernorm_with_rmsnorm(module)

print("✓ Replacement function defined")

✓ Replacement function defined


## Load Training Corpus

In [6]:
print(f"Loading corpus from: {CORPUS_PATH}\n")

with open(CORPUS_PATH, 'r', encoding='ascii') as f:
    corpus_text = f.read()

# Convert to bytes and filter to vocab size
corpus_bytes = [b for b in corpus_text.encode('ascii') if b < VOCAB_SIZE]

print(f"✓ Corpus loaded")
print(f"Total bytes: {len(corpus_bytes):,}")
print(f"Vocabulary size: {VOCAB_SIZE}")

# Count unique bytes
unique_bytes = set(corpus_bytes)
dead_tokens = VOCAB_SIZE - len(unique_bytes)

print(f"Unique bytes in corpus: {len(unique_bytes)}")
print(f"Dead tokens (never appear): {dead_tokens} ({100 * dead_tokens / VOCAB_SIZE:.1f}%)")

Loading corpus from: ../data/training_corpus.txt

✓ Corpus loaded
Total bytes: 265,905
Vocabulary size: 128
Unique bytes in corpus: 77
Dead tokens (never appear): 51 (39.8%)


## Create Dataset

In [7]:
class ByteDataset(Dataset):
    """Dataset for byte-level language modeling."""
    def __init__(self, byte_sequence, max_seq_len):
        self.byte_sequence = byte_sequence
        self.max_seq_len = max_seq_len
    
    def __len__(self):
        return max(0, len(self.byte_sequence) - self.max_seq_len)
    
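    def __getitem__(self, idx):
        # Take max_seq_len + 1 bytes so inputs and labels each hold max_seq_len tokens
        chunk = self.byte_sequence[idx : idx + self.max_seq_len + 1]
        # Note: GPT2LMHeadModel also shifts labels by one position internally,
        # so these pre-shifted labels are scored against the token two positions ahead.
        return {
            'input_ids': torch.tensor(chunk[:-1], dtype=torch.long),
            'labels': torch.tensor(chunk[1:], dtype=torch.long)
        }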

dataset = ByteDataset(corpus_bytes, MAX_SEQ_LEN)
print(f"\n✓ Dataset created")
print(f"Training examples: {len(dataset):,}")


✓ Dataset created
Training examples: 265,777
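
The example count follows from the sliding-window indexing: every start offset that leaves room for `max_seq_len + 1` bytes yields one example. Here is a plain-Python sketch of the same indexing on a toy sequence; `make_windows` is an illustrative helper, not part of the notebook.

```python
def make_windows(seq, max_seq_len):
    """Return (inputs, targets) pairs using the same offsets as ByteDataset."""
    n_examples = max(0, len(seq) - max_seq_len)
    windows = []
    for idx in range(n_examples):
        chunk = seq[idx : idx + max_seq_len + 1]
        windows.append((chunk[:-1], chunk[1:]))  # targets are inputs shifted by one
    return windows

toy = list(b"black hole")  # 10 bytes
windows = make_windows(toy, max_seq_len=4)
print(len(windows))   # number of examples for the toy sequence
print(windows[0])     # first (inputs, targets) pair

# Same arithmetic as the run above: corpus bytes minus context length
print(265_905 - 128)
```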


## Create Model with RMSNorm

In [8]:
# Create standard GPT-2 config
config = GPT2Config(
    vocab_size=VOCAB_SIZE,
    n_positions=MAX_SEQ_LEN,
    n_embd=HIDDEN_DIM,
    n_layer=N_LAYER,
    n_head=N_HEAD,
    resid_pdrop=0.0,
    embd_pdrop=0.0,
    attn_pdrop=0.0,
    tie_word_embeddings=True,
)

# Create model
model = GPT2LMHeadModel(config)

# Replace all LayerNorm with RMSNorm
print("\nReplacing LayerNorm with RMSNorm:")
replace_layernorm_with_rmsnorm(model)

# Convert to bfloat16
model = model.to(torch.bfloat16)

total_params = sum(p.numel() for p in model.parameters())
print(f"\n✓ Model created (bfloat16)")
print(f"Total parameters: {total_params:,}")


Replacing LayerNorm with RMSNorm:
  Replaced ln_1: LayerNorm → RMSNorm
  Replaced ln_2: LayerNorm → RMSNorm
  Replaced ln_1: LayerNorm → RMSNorm
  Replaced ln_2: LayerNorm → RMSNorm
  Replaced ln_f: LayerNorm → RMSNorm

✓ Model created (bfloat16)
Total parameters: 116,160
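
The reported total can be reproduced by hand. The sketch below assumes the standard GPT-2 block layout (fused QKV projection, 4× MLP expansion, biases on all linear layers, output head tied to the token embedding) and that each RMSNorm contributes only a weight vector; `gpt2_rmsnorm_param_count` is a hypothetical helper.

```python
def gpt2_rmsnorm_param_count(vocab: int, n_pos: int, d: int, n_layer: int) -> int:
    """Count parameters of a GPT-2 style model whose norms have weight only."""
    embeddings = vocab * d + n_pos * d             # wte + wpe (lm_head tied to wte)
    attn = (d * 3 * d + 3 * d) + (d * d + d)       # c_attn + c_proj, with biases
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)    # c_fc + c_proj, with biases
    norms = 2 * d                                  # ln_1 and ln_2, weight only
    final_norm = d                                 # ln_f, weight only
    return embeddings + n_layer * (attn + mlp + norms) + final_norm

total = gpt2_rmsnorm_param_count(vocab=128, n_pos=128, d=64, n_layer=2)
print(f"{total:,}")
```

Under the same assumptions, the LayerNorm version would carry one extra bias vector of size `d` for each of the five norm layers.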


## Apply Qwen Initialization

In [9]:
print(f"\nApplying Qwen-style initialization (singular vector)...")

with torch.no_grad():
    # Generate one random unit vector
    random_vector = torch.randn(HIDDEN_DIM)
    random_vector = random_vector / random_vector.norm()
    
    # Set ALL embedding vectors to this single vector
    model.transformer.wte.weight[:] = random_vector

print(f"✓ All {VOCAB_SIZE} tokens initialized to same random unit vector")
print(f"  Initial vector norm: {random_vector.norm().item():.6f}")


Applying Qwen-style initialization (singular vector)...
✓ All 128 tokens initialized to same random unit vector
  Initial vector norm: 1.000000


## Pre-allocate Embedding History

In [10]:
# Pre-allocate tensor for all snapshots
embedding_history = torch.zeros(
    (NUM_TRAIN_STEPS + 1, VOCAB_SIZE, HIDDEN_DIM),
    dtype=torch.bfloat16
)

# Save initial state
embedding_history[0] = model.transformer.wte.weight.data.clone().cpu()

print(f"\n✓ Pre-allocated embedding history")
print(f"  Shape: {embedding_history.shape}")
print(f"  Memory: {embedding_history.element_size() * embedding_history.numel() / 1e6:.1f} MB")


✓ Pre-allocated embedding history
  Shape: torch.Size([10001, 128, 64])
  Memory: 163.9 MB
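
The memory figure is just the tensor shape times two bytes per bfloat16 element. A quick arithmetic check, with the constants copied from the Parameters cell:

```python
NUM_TRAIN_STEPS = 10_000
VOCAB_SIZE = 128
HIDDEN_DIM = 64
BYTES_PER_BF16 = 2  # bfloat16 stores each value in 16 bits

num_snapshots = NUM_TRAIN_STEPS + 1  # initial state plus one snapshot per step
total_bytes = num_snapshots * VOCAB_SIZE * HIDDEN_DIM * BYTES_PER_BF16
print(f"{total_bytes / 1e6:.1f} MB")
```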


## Define Callback

In [11]:
class EmbeddingHistoryCallback(TrainerCallback):
    """Save embeddings to history tensor (in memory only, write at end)."""
    
    def __init__(self, embedding_history):
        self.embedding_history = embedding_history
    
    def on_step_end(self, args, state, control, model=None, **kwargs):
        step = state.global_step
        
        # Store in memory
        self.embedding_history[step] = model.transformer.wte.weight.data.clone().cpu()
        
        # Print progress every 1000 steps
        if step % 1000 == 0 and step > 0:
            embeddings = self.embedding_history[step]
            centroid_norm = embeddings.mean(dim=0).norm().item()
            print(f"[Step {step:5d}] Centroid norm: {centroid_norm:.6f}")
        
        return control

print("✓ Callback defined")

✓ Callback defined
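
The centroid norm printed by the callback is a cheap proxy for how tightly the vocabulary is still clustered. It is close to 1 while every token sits on the same unit vector and shrinks as tokens spread in different directions. A plain-Python sketch on toy 2D embeddings (the helper and data are illustrative assumptions):

```python
import math

def centroid_norm(embeddings):
    """Euclidean norm of the mean embedding vector."""
    dim = len(embeddings[0])
    centroid = [sum(vec[i] for vec in embeddings) / len(embeddings) for i in range(dim)]
    return math.sqrt(sum(c * c for c in centroid))

collapsed = [(0.6, 0.8)] * 4                                    # all tokens on one unit vector
spread = [(0.6, 0.8), (-0.6, 0.8), (0.6, -0.8), (-0.6, -0.8)]   # directions cancel out

print(centroid_norm(collapsed))
print(centroid_norm(spread))
```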


## Configure Training

In [12]:
training_args = TrainingArguments(
    output_dir="./training_output_rmsnorm",
    max_steps=NUM_TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    logging_steps=100,
    save_strategy="no",  # Don't save checkpoints
    seed=RANDOM_SEED,
    dataloader_num_workers=0,
    use_cpu=(DEVICE == "cpu"),
    bf16=True,
    report_to="none",
)

print("Training configuration:")
print(f"  Device: {DEVICE}")
print(f"  Steps: {NUM_TRAIN_STEPS:,}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay: {WEIGHT_DECAY}")

Training configuration:
  Device: mps
  Steps: 10,000
  Batch size: 32
  Learning rate: 0.001
  Weight decay: 0.01


## Create Trainer and Train

In [13]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    callbacks=[EmbeddingHistoryCallback(embedding_history)],
)

print("\n" + "="*80)
print("Starting training...")
print("="*80 + "\n")

trainer.train()

print("\n" + "="*80)
print("✓ Training complete!")
print("="*80)


Starting training...



`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
100,3.3447
200,3.0811
300,3.0212
400,2.9096
500,2.8596
600,2.8349
700,2.8091
800,2.7844
900,2.7653
1000,2.7358


[Step  1000] Centroid norm: 0.808594
[Step  2000] Centroid norm: 0.804688
[Step  3000] Centroid norm: 0.804688
[Step  4000] Centroid norm: 0.804688
[Step  5000] Centroid norm: 0.808594
[Step  6000] Centroid norm: 0.808594
[Step  7000] Centroid norm: 0.808594
[Step  8000] Centroid norm: 0.808594
[Step  9000] Centroid norm: 0.808594
[Step 10000] Centroid norm: 0.808594

✓ Training complete!


## Save Embedding History

In [14]:
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)
output_file = output_path / OUTPUT_FILE

print(f"\nSaving embedding history to {output_file}...")
save_file({'embedding_history': embedding_history}, output_file)

print(f"✓ Saved {output_file.stat().st_size / 1e6:.1f} MB")


Saving embedding history to ../data/embeddings_128vocab_rmsnorm_test/embedding_evolution.safetensors...
✓ Saved 163.9 MB


## Analyze Results: Count Black Holes

In [15]:
from collections import Counter

# Get final embeddings
final_embeddings = embedding_history[-1]

# Find unique vectors
unique_vectors, inverse_indices = torch.unique(
    final_embeddings,
    dim=0,
    return_inverse=True
)

# Count populations
vector_populations = Counter(inverse_indices.tolist())

# Black holes are vectors with population ≥ 2
black_holes = {vec_id: pop for vec_id, pop in vector_populations.items() if pop >= 2}

C = len(black_holes)  # Black hole count
P = sum(black_holes.values())  # Total population

print(f"\n{'='*80}")
print(f"BLACK HOLE ANALYSIS")
print(f"{'='*80}")
print(f"Black hole count (C): {C}")
print(f"Total black hole population (P): {P}")
print(f"Dead tokens in corpus: {dead_tokens}")
print(f"\nExpected (LayerNorm control): C = 1, P = 51")
print(f"Observed (RMSNorm test): C = {C}, P = {P}")
print(f"{'='*80}")

if C > 1 or P < dead_tokens:
    print(f"\n✓ POSITIVE RESULT: RMSNorm causes black hole diffusion!")
    if C > 1:
        print(f"  → Black holes fragmented: {C} distinct clusters")
    if P < dead_tokens:
        print(f"  → Tokens escaped: {dead_tokens - P} tokens no longer in black holes")
else:
    print(f"\n✗ NULL RESULT: No diffusion detected")
    print(f"  RMSNorm does not appear to cause black hole breakup")


BLACK HOLE ANALYSIS
Black hole count (C): 1
Total black hole population (P): 51
Dead tokens in corpus: 51

Expected (LayerNorm control): C = 1, P = 51
Observed (RMSNorm test): C = 1, P = 51

✗ NULL RESULT: No diffusion detected
  RMSNorm does not appear to cause black hole breakup
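
The counting logic can be checked on toy data without PyTorch. The sketch below groups exactly identical vectors with a `Counter`, which plays the role of `torch.unique(..., dim=0, return_inverse=True)` in the cell above; `count_black_holes` and the toy embeddings are illustrative assumptions.

```python
from collections import Counter

def count_black_holes(embeddings):
    """Return (C, P): number of shared vectors and total tokens sitting on them."""
    populations = Counter(tuple(vec) for vec in embeddings)
    black_holes = [pop for pop in populations.values() if pop >= 2]
    return len(black_holes), sum(black_holes)

intact = [
    (0.5, 0.5), (0.5, 0.5), (0.5, 0.5),  # three dead tokens on one shared vector
    (0.1, 0.9), (0.7, 0.2),              # trained tokens, each unique
]
escaped = [
    (0.5, 0.5), (0.5, 0.5), (0.5, 0.51),  # one dead token drifted away
    (0.1, 0.9), (0.7, 0.2),
]
fragmented = [
    (0.5, 0.5), (0.5, 0.5),              # first cluster
    (0.4, 0.6), (0.4, 0.6),              # second cluster
    (0.7, 0.2),
]

for name, emb in [("intact", intact), ("escaped", escaped), ("fragmented", fragmented)]:
    c, p = count_black_holes(emb)
    print(f"{name}: C = {c}, P = {p}")
```

Because the comparison is exact equality, even a one-bit change in a bfloat16 component removes a token from its black hole, so this test is sensitive to very small drifts.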


## Black Hole Details

In [16]:
if C > 0:
    sorted_bhs = sorted(black_holes.items(), key=lambda x: x[1], reverse=True)
    
    print(f"\nBlack hole populations (sorted by size):\n")
    for i, (vec_id, pop) in enumerate(sorted_bhs, 1):
        print(f"BH #{i}: {pop} tokens")
    
    print(f"\nLargest: {sorted_bhs[0][1]} tokens")
    if C > 1:
        print(f"Smallest: {sorted_bhs[-1][1]} tokens")


Black hole populations (sorted by size):

BH #1: 51 tokens

Largest: 51 tokens


## Summary

In [17]:
initial_centroid = embedding_history[0].mean(dim=0)
final_centroid = final_embeddings.mean(dim=0)
displacement = (final_centroid - initial_centroid).norm().item()

print(f"\n{'='*80}")
print(f"EXPERIMENT SUMMARY")
print(f"{'='*80}")
print(f"Model: GPT-2 with RMSNorm (not LayerNorm)")
print(f"Training steps: {NUM_TRAIN_STEPS:,}")
print(f"Dead tokens: {dead_tokens}")
print(f"\nResults:")
print(f"  Black hole count: {C}")
print(f"  Black hole population: {P}")
print(f"  Centroid displacement: {displacement:.6f}")
print(f"\nConclusion:")
if C > 1 or P < dead_tokens:
    print(f"  RMSNorm appears to cause black hole diffusion")
else:
    print(f"  No evidence of diffusion - mechanism remains unknown")
print(f"{'='*80}")


EXPERIMENT SUMMARY
Model: GPT-2 with RMSNorm (not LayerNorm)
Training steps: 10,000
Dead tokens: 51

Results:
  Black hole count: 1
  Black hole population: 51
  Centroid displacement: 0.527344

Conclusion:
  No evidence of diffusion - mechanism remains unknown
