# 1.31b: Phase Transitions in Dead Token Dynamics

Analyzing Thimble 7 training trajectories to identify phase transitions in dead token behavior.

**Hypothesis:** Dead tokens undergo distinct phase transitions as training progresses:

1. **Classical Gas** (early, hot): ΔW ≫ ULP, continuous motion
2. **Quantum Regime** (cooling): ΔW ≈ k·ULP for medium k, quantized hops
3. **Thermal Solid** (cold): ΔW ≈ k·ULP for small k ≤ 8, lattice jitter
4. **Fimbulwinter** (frozen): ΔW = 0, permanent freeze

**Observables:**
- |ΔW| in ULP units over time
- Mean and standard deviation across dead tokens
- Individual token trajectories (superimposed)

## Parameters

In [None]:
THIMBLE_PATH = "../tensors/Flannel/thimble_7.h5"
DPI = 200
COLORMAP = 'inferno'
ALPHA_INDIVIDUAL = 0.01  # Transparency for individual token traces

## Imports

In [None]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

## Device Detection

In [None]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Using device: {device}")

## Load Data

Load the full `W` tensor (bfloat16), keep only the dead tokens, and upcast them to float32 on the selected device: (6001, 3699, 64).

In [None]:
with h5py.File(THIMBLE_PATH, 'r') as f:
    # Load dead token mask
    dead_mask = torch.from_numpy(f['dead_mask'][:]).bool()
    n_dead = dead_mask.sum().item()
    
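    # Load the full W (stored as 16-bit bfloat16 bit patterns), then keep dead tokens only: (6001, 3699, 64)
    # Reinterpreting as int16 first avoids torch.from_numpy errors on PyTorch builds without uint16 support
    W_all = torch.from_numpy(f['W'][:].view(np.int16)).view(torch.bfloat16)
    W_dead = W_all[:, dead_mask, :].to(torch.float32).to(device)
    del W_all  # release the full-vocabulary copy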
    
    print(f"Loaded W for {n_dead} dead tokens")
    print(f"Shape: {W_dead.shape}")
    print(f"Memory: {W_dead.element_size() * W_dead.nelement() / 1e9:.2f} GB")
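NumPy has no native bfloat16 dtype, so the cell above relies on the HDF5 file storing raw bfloat16 bit patterns in a 16-bit dataset (an assumption inferred from the `.view(torch.bfloat16)` call). If you need to inspect those values without PyTorch, one sketch is to shift each pattern into the high half of a 32-bit word: bfloat16 is the upper 16 bits of an IEEE-754 float32, so the result decodes exactly. `bf16_bits_to_float32` is a hypothetical helper, not part of the original notebook.

```python
import numpy as np


def bf16_bits_to_float32(bits: np.ndarray) -> np.ndarray:
    """Decode raw bfloat16 bit patterns (any 16-bit integer dtype) to float32."""
    bits = np.asarray(bits)
    if bits.dtype.itemsize != 2:
        raise ValueError(f"expected a 16-bit integer array, got {bits.dtype}")
    # Reinterpret as unsigned (bits unchanged), widen to 32 bits,
    # and move the pattern into the high half of the word
    u32 = bits.view(np.uint16).astype(np.uint32) << 16
    return u32.view(np.float32)


# 0x3F80 -> 1.0, 0x4000 -> 2.0, 0xBF80 -> -1.0, 0x3F81 -> 1 + 2**-7
print(bf16_bits_to_float32(np.array([0x3F80, 0x4000, 0xBF80, 0x3F81], dtype=np.uint16)))
```

The last value, 1.0078125, is 1.0 plus one ULP (2^-7), which is the smallest step a bfloat16 near 1.0 can take.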

## Compute ΔW and Magnitude in ULP

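bfloat16 has 7 explicit mantissa bits, so for a value $x$ with $2^e \le |x| < 2^{e+1}$ the exact ULP is:

$$\text{ULP}(x) = 2^{e - 7}, \qquad e = \lfloor \log_2 |x| \rfloor$$

We take the L2 norm of ΔW over the hidden dimension and normalize it by an approximate ULP taken at the scale of the whole embedding vector, $\lVert W \rVert_2 \cdot 2^{-7}$, rather than by the exact per-element ULP.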
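For reference, the exact per-element ULP from the formula above can be computed with `np.frexp`, which returns $x = m \cdot 2^{e'}$ with $0.5 \le |m| < 1$, so $\lfloor \log_2 |x| \rfloor = e' - 1$. The sketch below (hypothetical helper `bf16_ulp`, not used by the notebook) also handles zero and subnormals, which in bfloat16 share the spacing $2^{-126-7} = 2^{-133}$ because bfloat16 uses float32's exponent range.

```python
import numpy as np

BF16_MANTISSA_BITS = 7
BF16_MIN_NORMAL_EXP = -126  # same exponent range as float32


def bf16_ulp(x):
    """Exact spacing between adjacent bfloat16 values at |x|, element-wise."""
    x = np.asarray(x, dtype=np.float64)
    _, e = np.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    # floor(log2|x|) = e - 1; zero and subnormals use the minimum normal exponent
    exponent = np.where(x == 0, BF16_MIN_NORMAL_EXP,
                        np.maximum(e - 1, BF16_MIN_NORMAL_EXP))
    return np.ldexp(1.0, exponent - BF16_MANTISSA_BITS)


print(bf16_ulp([0.75, 1.0, 1.5, -3.0]))  # 2**-8, 2**-7, 2**-7, 2**-6
```

Dividing each element of ΔW by `bf16_ulp` of the corresponding element of `W_dead[:-1]` would give per-element hop counts, an alternative to the norm-based normalization used in the cells below.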

In [None]:
# Compute ΔW: (6000, 3699, 64)
delta_W = W_dead[1:] - W_dead[:-1]

print(f"ΔW shape: {delta_W.shape}")
print(f"Memory: {delta_W.element_size() * delta_W.nelement() / 1e9:.2f} GB")

In [None]:
# Magnitude of each step: L2 norm of ΔW over the hidden dimension
magnitude = torch.norm(delta_W, dim=2)  # (6000, 3699)

print(f"Magnitude shape: {magnitude.shape}")
print(f"Memory: {magnitude.element_size() * magnitude.nelement() / 1e9:.2f} GB")

In [None]:
# Convert to ULP units.
# bfloat16 has 7 explicit mantissa bits, so for |x| in [2^e, 2^(e+1)) the exact ULP is 2^(e - 7).
# Here we approximate it at the scale of the whole embedding vector instead of per element,
# using W_dead[:-1] (the position at the start of each step) as the reference.

# L2 norm of each dead-token embedding at the start of each step
W_magnitude = torch.norm(W_dead[:-1], dim=2)  # (6000, 3699)

# Approximate ULP: for a scalar x, |x| * 2^-7 lies in [ULP(x), 2 * ULP(x)).
# Applied to the vector norm, this gives a per-vector scale, not a per-element ULP.
ulp = W_magnitude * (2.0 ** -7)

# Normalize magnitude by ULP
magnitude_ulp = magnitude / ulp.clamp(min=1e-10)  # Avoid division by zero

print(f"Magnitude in ULP shape: {magnitude_ulp.shape}")
print(f"Min: {magnitude_ulp.min():.2f} ULP")
print(f"Max: {magnitude_ulp.max():.2f} ULP")
print(f"Median: {magnitude_ulp.median():.2f} ULP")
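To see what this norm-relative unit measures, here is a small worked check with a hypothetical 64-dimensional embedding whose elements all equal 1.0, so each element's exact bfloat16 ULP is 2^-7. `norm_relative_ulp` is a hypothetical NumPy re-implementation of the normalization above, not code from the original notebook.

```python
import numpy as np


def norm_relative_ulp(w_prev, w_next):
    """||W_next - W_prev||_2 / (||W_prev||_2 * 2**-7), mirroring the cell above."""
    delta = np.linalg.norm(w_next - w_prev)
    ulp = np.linalg.norm(w_prev) * 2.0 ** -7
    return delta / max(ulp, 1e-10)


d = 64
w = np.ones(d, dtype=np.float32)  # hypothetical embedding, every element = 1.0

one_hop = w.copy()
one_hop[0] += 2.0 ** -7           # one element moves by exactly one bfloat16 ULP
all_hop = w + 2.0 ** -7           # every element moves by one ULP

print(norm_relative_ulp(w, one_hop))  # 0.125
print(norm_relative_ulp(w, all_hop))  # 1.0
```

Under this normalization a single element hopping by one ULP scores 0.125, only just above the 0.1 `FROZEN_THRESHOLD` used later, while all 64 elements hopping once scores 1.0. Real embeddings have unequal elements, so actual values will differ.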

## Compute Statistics Over Time

In [None]:
# Mean and std across tokens at each timestep
mean_ulp = magnitude_ulp.mean(dim=1).cpu().numpy()  # (6000,)
std_ulp = magnitude_ulp.std(dim=1).cpu().numpy()    # (6000,)

# Step labels: ΔW[i] = W[i+1] - W[i] is labeled by the snapshot index it ends on (1..6000)
timesteps = np.arange(1, 6001)

print(f"Mean ULP range: [{mean_ulp.min():.2f}, {mean_ulp.max():.2f}]")
print(f"Std ULP range: [{std_ulp.min():.2f}, {std_ulp.max():.2f}]")
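Because |ΔW| spans several orders of magnitude and mean − std can fall to zero or below, per-step quantiles across tokens are one alternative summary to plot alongside the ±1σ band. A minimal NumPy sketch; the helper name `per_step_quantiles` and the 10/50/90 choice are assumptions, not part of the original analysis.

```python
import numpy as np


def per_step_quantiles(magnitude_ulp, qs=(0.1, 0.5, 0.9)):
    """Quantiles of |ΔW| across tokens at each step.

    magnitude_ulp: (n_steps, n_tokens) array. Returns (len(qs), n_steps).
    """
    return np.quantile(np.asarray(magnitude_ulp), qs, axis=1)


# Two steps, five tokens each
demo = np.array([[0.0, 1.0, 2.0, 3.0, 4.0],
                 [10.0, 10.0, 10.0, 10.0, 10.0]])
print(per_step_quantiles(demo, qs=(0.0, 0.5, 1.0)))
```

Usage in this notebook could look like `q10, q50, q90 = per_step_quantiles(magnitude_ulp.cpu().numpy())` followed by `ax.fill_between(timesteps, q10, q90, alpha=0.2)`. Once many tokens freeze, the lower quantiles may be exactly zero, which a log axis cannot display.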

## Plot: Mean ± Std Over Time

In [None]:
fig, ax = plt.subplots(figsize=(14, 6), dpi=DPI)

# Plot mean
ax.plot(timesteps, mean_ulp, color='cyan', linewidth=2, label='Mean |ΔW|')

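# Plot ±1σ band; clip the lower edge to a small positive floor because
# mean - std can be <= 0, which a log-scaled y-axis cannot display
band_floor = 1e-3  # ULP; chosen only for display
ax.fill_between(timesteps,
                np.maximum(mean_ulp - std_ulp, band_floor),
                mean_ulp + std_ulp,
                color='cyan', alpha=0.2, label='±1σ')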

ax.set_xlabel('Training Step', fontsize=14)
ax.set_ylabel('|ΔW| (ULP)', fontsize=14)
ax.set_title('Dead Token Movement: Mean Magnitude in ULP', fontsize=16)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)

# Log scale for y-axis to see structure across orders of magnitude
ax.set_yscale('log')

plt.tight_layout()
plt.show()

## Plot: All Individual Trajectories + Mean

Superimpose all dead token trajectories (faint) with mean highlighted.

In [None]:
# Move magnitude_ulp to CPU for plotting
magnitude_ulp_cpu = magnitude_ulp.cpu().numpy()  # (6000, 3699)

fig, ax = plt.subplots(figsize=(14, 6), dpi=DPI)

# Plot all individual trajectories (very transparent)
for i in range(magnitude_ulp_cpu.shape[1]):
    ax.plot(timesteps, magnitude_ulp_cpu[:, i], 
            color='gray', alpha=ALPHA_INDIVIDUAL, linewidth=0.5)

# Plot mean on top (bright)
ax.plot(timesteps, mean_ulp, color='cyan', linewidth=2, label='Mean |ΔW|')

ax.set_xlabel('Training Step', fontsize=14)
ax.set_ylabel('|ΔW| (ULP)', fontsize=14)
ax.set_title('Dead Token Movement: All Trajectories + Mean', fontsize=16)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

plt.tight_layout()
plt.show()
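The loop above creates one `Line2D` artist per token (3699 in total). A single `matplotlib.collections.LineCollection` can draw the same faint traces as one artist, which typically renders faster. A sketch, with `add_trajectories` as a hypothetical helper:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def add_trajectories(ax, timesteps, traces, alpha=0.01, color='gray', linewidth=0.5):
    """Draw every column of `traces` (n_steps, n_tokens) as one faint line."""
    traces = np.asarray(traces)
    n_steps, n_tokens = traces.shape
    # segments[i] holds the (n_steps, 2) array of (x, y) points for token i;
    # float32 keeps this buffer smaller than the float64 default
    segments = np.empty((n_tokens, n_steps, 2), dtype=np.float32)
    segments[:, :, 0] = np.asarray(timesteps)[None, :]
    segments[:, :, 1] = traces.T
    lc = LineCollection(segments, colors=color, alpha=alpha, linewidths=linewidth)
    ax.add_collection(lc)
    ax.autoscale_view()  # add_collection does not rescale the view on its own
    return lc
```

In the cell above, the `for` loop could be replaced by `add_trajectories(ax, timesteps, magnitude_ulp_cpu, alpha=ALPHA_INDIVIDUAL)`, keeping the mean line and `ax.set_yscale('log')` as they are.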

## Identify Phase Boundaries

Classify every dead token at every step by its |ΔW| in ULP units, using the thresholds defined in the next cell:

- **Classical Gas:** high velocity, |ΔW| ≥ 100 ULP
- **Quantum Regime:** medium velocity, 10 ≤ |ΔW| < 100 ULP
- **Thermal Solid:** low velocity, 0.1 ≤ |ΔW| < 10 ULP
- **Fimbulwinter:** frozen, |ΔW| < 0.1 ULP (effectively zero)

We'll compute the fraction of tokens in each regime at each timestep.
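The four boolean masks in the cell below can also be written as a single binning step. A minimal NumPy sketch using `np.digitize` with the same thresholds; `classify_phases` and `phase_fractions` are hypothetical helpers. With increasing bins, `np.digitize` uses half-open intervals `[a, b)`, matching the `>=`/`<` comparisons in the notebook.

```python
import numpy as np

PHASE_NAMES = ("Fimbulwinter", "Thermal Solid", "Quantum Regime", "Classical Gas")
PHASE_EDGES = (0.1, 10.0, 100.0)  # FROZEN, SOLID, QUANTUM thresholds in ULP


def classify_phases(magnitude_ulp):
    """Label entries 0..3 for [0, 0.1), [0.1, 10), [10, 100), [100, inf)."""
    return np.digitize(np.asarray(magnitude_ulp), PHASE_EDGES)


def phase_fractions(magnitude_ulp):
    """Fraction of tokens in each phase per step: (n_steps, 4), coldest first."""
    labels = classify_phases(magnitude_ulp)
    return np.stack([(labels == k).mean(axis=1) for k in range(len(PHASE_NAMES))],
                    axis=1)


print(classify_phases([0.0, 0.05, 0.1, 5.0, 10.0, 50.0, 100.0, 1e4]))
# [0 0 1 1 2 2 3 3]
```

`phase_fractions(x).cumsum(axis=1)` gives the cumulative boundaries used by the stacked-area plot. One difference: `np.digitize` places NaN in the top bin, whereas the boolean masks exclude NaN from every phase, so the two agree only when `magnitude_ulp` has no NaN.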

In [None]:
# Define thresholds
FROZEN_THRESHOLD = 0.1  # ULP (essentially zero)
SOLID_THRESHOLD = 10.0   # ULP
QUANTUM_THRESHOLD = 100.0  # ULP

# Compute fractions at each timestep
frac_frozen = (magnitude_ulp < FROZEN_THRESHOLD).float().mean(dim=1).cpu().numpy()
frac_solid = ((magnitude_ulp >= FROZEN_THRESHOLD) & (magnitude_ulp < SOLID_THRESHOLD)).float().mean(dim=1).cpu().numpy()
frac_quantum = ((magnitude_ulp >= SOLID_THRESHOLD) & (magnitude_ulp < QUANTUM_THRESHOLD)).float().mean(dim=1).cpu().numpy()
frac_gas = (magnitude_ulp >= QUANTUM_THRESHOLD).float().mean(dim=1).cpu().numpy()

print(f"Frozen fraction range: [{frac_frozen.min():.3f}, {frac_frozen.max():.3f}]")
print(f"Solid fraction range: [{frac_solid.min():.3f}, {frac_solid.max():.3f}]")
print(f"Quantum fraction range: [{frac_quantum.min():.3f}, {frac_quantum.max():.3f}]")
print(f"Gas fraction range: [{frac_gas.min():.3f}, {frac_gas.max():.3f}]")

## Plot: Phase Fractions Over Time (Stacked Area)

In [None]:
fig, ax = plt.subplots(figsize=(14, 6), dpi=DPI)

# Stacked area plot
ax.fill_between(timesteps, 0, frac_frozen, 
                color='navy', alpha=0.8, label='Fimbulwinter (|ΔW| < 0.1 ULP)')
ax.fill_between(timesteps, frac_frozen, frac_frozen + frac_solid, 
                color='steelblue', alpha=0.8, label='Thermal Solid (0.1-10 ULP)')
ax.fill_between(timesteps, frac_frozen + frac_solid, frac_frozen + frac_solid + frac_quantum, 
                color='orange', alpha=0.8, label='Quantum Regime (10-100 ULP)')
ax.fill_between(timesteps, frac_frozen + frac_solid + frac_quantum, 1.0, 
                color='red', alpha=0.8, label='Classical Gas (>100 ULP)')

ax.set_xlabel('Training Step', fontsize=14)
ax.set_ylabel('Fraction of Dead Tokens', fontsize=14)
ax.set_title('Phase Transitions: States of Dead Token Matter', fontsize=16)
ax.legend(fontsize=12, loc='upper left')
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

plt.tight_layout()
plt.show()

## Summary Statistics

When do phase transitions occur?

In [None]:
# Find when a majority of tokens first enter each phase
def find_transition(fraction, threshold=0.5):
    """Find first timestep where fraction exceeds threshold"""
    idx = np.where(fraction > threshold)[0]
    return timesteps[idx[0]] if len(idx) > 0 else None

# Transitions to each phase (when >50% of tokens are in that phase or colder)
t_quantum = find_transition(frac_frozen + frac_solid + frac_quantum)
t_solid = find_transition(frac_frozen + frac_solid)
t_frozen = find_transition(frac_frozen)

print("Phase Transition Timesteps (>50% of tokens):")
print(f"  Classical Gas → Quantum Regime: t = {t_quantum}")
print(f"  Quantum Regime → Thermal Solid: t = {t_solid}")
print(f"  Thermal Solid → Fimbulwinter: t = {t_frozen}")
print()
print(f"Final state (t={timesteps[-1]}):")
print(f"  Fimbulwinter: {frac_frozen[-1]*100:.1f}%")
print(f"  Thermal Solid: {frac_solid[-1]*100:.1f}%")
print(f"  Quantum Regime: {frac_quantum[-1]*100:.1f}%")
print(f"  Classical Gas: {frac_gas[-1]*100:.1f}%")
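`find_transition` reports the first step at which a fraction exceeds the threshold, so a brief fluctuation can trigger it. If you want a transition that persists, one option (an assumption, not the original criterion) is to require the fraction to stay above the threshold for a full window of consecutive steps. The helper name and the default window of 50 steps are hypothetical.

```python
import numpy as np


def find_sustained_transition(fraction, timesteps, threshold=0.5, window=50):
    """First step where `fraction` stays above `threshold` for `window` consecutive steps."""
    above = np.asarray(fraction) > threshold
    if window < 1 or len(above) < window:
        return None
    # counts[i] = number of above-threshold steps in above[i : i + window]
    counts = np.convolve(above.astype(int), np.ones(window, dtype=int), mode='valid')
    idx = np.flatnonzero(counts == window)
    return timesteps[idx[0]] if idx.size > 0 else None


t = np.arange(1, 7)
f = np.array([0.6, 0.4, 0.6, 0.6, 0.6, 0.6])
print(find_sustained_transition(f, t, window=3))  # 3 (first_crossing would give 1)
```

Applied here, `find_sustained_transition(frac_frozen, timesteps)` would report when more than half the dead tokens enter Fimbulwinter and stay there for at least 50 steps.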