# 14.2a: Dead Token Statistical Kinematics

**Compute center-of-mass frame quantities for dead tokens**

## The Goal

Transform dead token dynamics into the center-of-mass frame to separate two components:
- **Bulk motion**: Coherent movement of the primordial atom as a whole
- **Thermal motion**: Random internal motion (actual "temperature")

This notebook is a **generator**: it computes and saves derived quantities for analysis notebooks to use.

## What We Compute

1. **Positions**: Reconstruct embedding positions by cumulatively summing deltas (measured from the starting point, which is taken as the origin)
2. **Centroid**: Mean position of all dead tokens at each step
3. **Bulk velocity**: Step-to-step displacement of centroid
4. **Thermal velocities**: Token velocities in the center-of-mass frame (deltas minus bulk velocity)
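
The four steps above can be sketched end to end on toy data. This is a minimal illustration, not the notebook's actual run: the sizes, the `drift` vector, and every `toy_*` name are made up, and `torch.diff` with a zero `prepend` is used as a compact stand-in for the explicit step-to-step difference.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes: 5 steps, 4 tokens, 3 dimensions
n_steps, n_tokens, n_dims = 5, 4, 3

# Each token's delta = a shared drift plus small per-token noise
drift = torch.tensor([1.0, 0.0, 0.0])
toy_deltas = drift + 0.01 * torch.randn(n_steps, n_tokens, n_dims)

# 1. Positions: cumulative sum of deltas (relative to the starting point)
toy_positions = torch.cumsum(toy_deltas, dim=0)
# 2. Centroid: mean position over tokens at each step
toy_centroid = toy_positions.mean(dim=1)
# 3. Bulk velocity: step-to-step centroid displacement (step 0 measured from the origin)
toy_bulk = torch.diff(toy_centroid, dim=0, prepend=torch.zeros(1, n_dims))
# 4. Thermal velocities: per-token deltas with the bulk velocity removed
toy_thermal = toy_deltas - toy_bulk.unsqueeze(1)

print(toy_bulk.norm(dim=1))            # close to 1.0 at every step (the drift)
print(toy_thermal.norm(dim=2).max())   # small: only the noise remains
```

In this toy setup the shared drift ends up entirely in the bulk velocity, and only the per-token noise survives in the thermal velocities.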

## Output

`data/instrumented_run/dead_token_kinematics.safetensors` (~254 MB)
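
The output size can be estimated from the tensor shapes this notebook reports further down (10000 recorded steps, 51 dead tokens, 64 dimensions). The sketch below assumes the four float tensors are stored as float32, which the notebook does not state explicitly.

```python
# Shapes as reported by the notebook's own cell outputs
n_recorded, n_dead, hidden_dim = 10_000, 51, 64
bytes_per_float32 = 4  # assumption: tensors are float32

# positions and thermal_velocities: [n_recorded, n_dead, hidden_dim]
per_token_tensor = n_recorded * n_dead * hidden_dim * bytes_per_float32
# centroid and bulk_velocity: [n_recorded, hidden_dim]
per_step_tensor = n_recorded * hidden_dim * bytes_per_float32

total_mib = (2 * per_token_tensor + 2 * per_step_tensor) / (1024 * 1024)
print(f"{total_mib:.1f} MiB")  # prints 253.9 MiB
```

This lands within about 0.1 MiB of the 254.0 MB reported by the save cell. The remainder is plausibly the two small index tensors and the file header.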

## Parameters

In [1]:
# Data
INPUT_PATH = "../data/instrumented_run/gradient_delta_history.safetensors"
OUTPUT_PATH = "../data/instrumented_run/dead_token_kinematics.safetensors"

RANDOM_SEED = 42

## Imports

In [2]:
import torch
import numpy as np
from safetensors.torch import load_file, save_file

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print("✓ Imports complete")

✓ Imports complete


## Load Data

In [3]:
print(f"Loading: {INPUT_PATH}")

data = load_file(INPUT_PATH)

recorded_steps = data['recorded_steps']
dead_token_ids = data['dead_token_ids']
deltas = data['deltas']  # [n_recorded, vocab_size, hidden_dim]

n_recorded = len(recorded_steps)
n_dead = len(dead_token_ids)

print(f"\n  Recorded steps: {n_recorded}")
print(f"  Dead tokens: {n_dead}")
print(f"  Step range: {recorded_steps[0]} to {recorded_steps[-1]}")
print(f"\n✓ Data loaded")

Loading: ../data/instrumented_run/gradient_delta_history.safetensors

  Recorded steps: 10000
  Dead tokens: 51
  Step range: 0 to 10000

✓ Data loaded


## Extract Dead Token Deltas

In [4]:
print("Extracting dead token deltas...")

# Index into deltas to get only dead tokens
dead_deltas = deltas[:, dead_token_ids, :]  # [n_recorded, n_dead, hidden_dim]

print(f"✓ Dead token deltas: {dead_deltas.shape}")

Extracting dead token deltas...
✓ Dead token deltas: torch.Size([10000, 51, 64])


## Compute Absolute Positions

In [5]:
print("Computing absolute positions...")

# Cumulative sum of deltas gives each token's position at each step,
# measured from its starting point (treated as the origin)
positions = torch.cumsum(dead_deltas, dim=0)  # [n_recorded, n_dead, hidden_dim]

print(f"✓ Positions computed: {positions.shape}")

Computing absolute positions...
✓ Positions computed: torch.Size([10000, 51, 64])
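
A cumulative sum of deltas recovers positions only up to the starting point. If true embedding coordinates are ever needed, the initial embeddings would have to be added back. The sketch below uses a hypothetical `W0` and made-up deltas, since the input file shown here does not appear to include the initial embeddings.

```python
import torch

# Hypothetical starting embeddings and per-step updates (not from the run)
W0 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])           # [n_tokens, dim]
toy_deltas = torch.tensor([
    [[0.1, 0.0], [0.0, 0.1]],
    [[0.1, 0.0], [0.0, 0.1]],
    [[0.1, 0.0], [0.0, 0.1]],
])                                                     # [n_steps, n_tokens, dim]

displacement = torch.cumsum(toy_deltas, dim=0)         # position relative to W0
absolute = W0.unsqueeze(0) + displacement              # embedding after each step

print(displacement[-1])  # total movement of each token
print(absolute[-1])      # where each token ends up
```

Velocities are unaffected by this choice of origin, so the bulk and thermal quantities computed below do not depend on it.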


## Compute Centroid Trajectory

In [6]:
print("Computing centroid trajectory...")

# Mean position over all dead tokens at each step
centroid = positions.mean(dim=1)  # [n_recorded, hidden_dim]

print(f"✓ Centroid computed: {centroid.shape}")

Computing centroid trajectory...
✓ Centroid computed: torch.Size([10000, 64])


## Compute Bulk Velocity

In [7]:
print("Computing bulk velocity...")

# Step-to-step displacement of centroid
# Note: bulk_velocity[0] will be centroid[0] - 0 (assuming start at origin)
bulk_velocity = torch.zeros_like(centroid)
bulk_velocity[0] = centroid[0]
bulk_velocity[1:] = centroid[1:] - centroid[:-1]

print(f"✓ Bulk velocity computed: {bulk_velocity.shape}")

Computing bulk velocity...
✓ Bulk velocity computed: torch.Size([10000, 64])
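
Because the mean over tokens and the cumulative sum over steps are both linear, they commute: the centroid's step-to-step displacement equals the mean of the per-token deltas. The sketch below checks this on made-up data. `toy_deltas` and its sizes are hypothetical, and float64 is used so that rounding does not dominate the comparison.

```python
import torch

torch.manual_seed(0)
toy_deltas = torch.randn(100, 8, 4, dtype=torch.float64)  # [steps, tokens, dim]

# Route 1 (as in the cell above): difference the centroid of cumulative positions
toy_centroid = torch.cumsum(toy_deltas, dim=0).mean(dim=1)
bulk_from_centroid = torch.zeros_like(toy_centroid)
bulk_from_centroid[0] = toy_centroid[0]
bulk_from_centroid[1:] = toy_centroid[1:] - toy_centroid[:-1]

# Route 2: average the per-token deltas directly
bulk_from_mean = toy_deltas.mean(dim=1)

print(torch.allclose(bulk_from_centroid, bulk_from_mean))
```

The direct mean avoids differencing large accumulated sums, which may reduce rounding error over long float32 runs. That is a general numerical consideration, not something measured on this data.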


## Compute Thermal Velocities

In [8]:
print("Computing thermal velocities...")

# Velocity of each token minus the bulk velocity
# dead_deltas is velocity in lab frame, bulk_velocity is centroid velocity
# Broadcast bulk_velocity across the token dimension
thermal_velocities = dead_deltas - bulk_velocity.unsqueeze(1)  # [n_recorded, n_dead, hidden_dim]

print(f"✓ Thermal velocities computed: {thermal_velocities.shape}")

Computing thermal velocities...
✓ Thermal velocities computed: torch.Size([10000, 51, 64])


## Sanity Checks

In [9]:
print(f"\n{'='*80}")
print(f"SANITY CHECKS")
print(f"{'='*80}\n")

# Check that thermal velocities sum to zero (by construction)
thermal_sum = thermal_velocities.sum(dim=1)  # [n_recorded, hidden_dim]
thermal_sum_norm = torch.norm(thermal_sum, dim=1).numpy()  # [n_recorded]

print(f"Thermal velocities should sum to zero (by construction):")
print(f"  Max norm of sum: {thermal_sum_norm.max():.6e}")
print(f"  Mean norm of sum: {thermal_sum_norm.mean():.6e}")
print(f"  (Should be ~machine epsilon)\n")

# Centroid norm at every recorded step
centroid_norm = torch.norm(centroid, dim=1).numpy()
print(f"Centroid trajectory:")
print(f"  Initial norm: {centroid_norm[0]:.6f}")
print(f"  Final norm: {centroid_norm[-1]:.6f}")
print(f"  Min norm: {centroid_norm.min():.6f} (step {centroid_norm.argmin()})")
print(f"  Max norm: {centroid_norm.max():.6f} (step {centroid_norm.argmax()})\n")

# Bulk vs thermal velocity magnitudes at t=0
bulk_vel_norm = torch.norm(bulk_velocity, dim=1).numpy()
thermal_vel_norms = torch.norm(thermal_velocities, dim=2).numpy()  # [n_recorded, n_dead]
thermal_rms = np.sqrt((thermal_vel_norms**2).mean(axis=1))  # [n_recorded]

print(f"Velocity comparison at t=0:")
print(f"  Bulk velocity magnitude: {bulk_vel_norm[0]:.6e}")
print(f"  RMS thermal velocity: {thermal_rms[0]:.6e}")
print(f"  Ratio (bulk/thermal): {bulk_vel_norm[0] / thermal_rms[0]:.2f}\n")

print(f"{'='*80}")


SANITY CHECKS

Thermal velocities should sum to zero (by construction):
  Max norm of sum: 5.215406e-07
  Mean norm of sum: 4.724886e-09
  (Should be ~machine epsilon)

Centroid trajectory:
  Initial norm: 0.008006
  Final norm: 0.632026
  Min norm: 0.008006 (step 0)
  Max norm: 0.632026 (step 1408)

Velocity comparison at t=0:
  Bulk velocity magnitude: 8.006386e-03
  RMS thermal velocity: 1.373231e-05
  Ratio (bulk/thermal): 583.03
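
What counts as "~machine epsilon" depends on the magnitude of the velocities, so a relative check can be easier to interpret than a raw norm. The sketch below shows one possible way to turn the zero-sum check into an assertion. The helper `max_residual_ratio`, the toy data, and the `1e-4` threshold are all illustrative choices, not part of the notebook.

```python
import torch

def max_residual_ratio(thermal: torch.Tensor) -> float:
    """Largest |sum over tokens| relative to the largest single-token speed."""
    residual = thermal.sum(dim=1).norm(dim=1)            # [n_steps]
    scale = thermal.norm(dim=2).max().clamp_min(1e-30)   # guard against division by zero
    return (residual.max() / scale).item()

torch.manual_seed(0)
toy_deltas = torch.randn(50, 51, 64)                     # hypothetical float32 data
toy_thermal = toy_deltas - toy_deltas.mean(dim=1, keepdim=True)

ratio = max_residual_ratio(toy_thermal)
assert ratio < 1e-4, f"thermal velocities do not cancel: {ratio:.3e}"
print(f"relative residual: {ratio:.3e}")
```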



## Save Results

In [10]:
print(f"\nSaving to: {OUTPUT_PATH}")

save_dict = {
    'recorded_steps': recorded_steps,
    'dead_token_ids': dead_token_ids,
    'positions': positions,
    'centroid': centroid,
    'bulk_velocity': bulk_velocity,
    'thermal_velocities': thermal_velocities,
}

save_file(save_dict, OUTPUT_PATH)

# Check file size
import os
file_size_mb = os.path.getsize(OUTPUT_PATH) / (1024 * 1024)

print(f"\n✓ Saved successfully")
print(f"  File size: {file_size_mb:.1f} MB")


Saving to: ../data/instrumented_run/dead_token_kinematics.safetensors

✓ Saved successfully
  File size: 254.0 MB
