# 1.25b: Thimble 1 Accounting Check

Simple validation: does the AdamW formula predict observed weight changes?

## What we're testing

- Load W[t] for t=0..1000 from Thimble 1
- Compute observed ΔW[t] = W[t+1] - W[t] for t=0..999 (1000 differences)
- Load recorded gradients, momentum, variance
- Compute predicted ΔW[t] from AdamW formula
- Compare them

## Output

Table showing for each timestep:
- t
- ||predicted ΔW|| (L2 norm)
- ||observed ΔW|| (L2 norm)
- ||difference|| (L2 norm)
- cosine similarity
- ratio (predicted/observed)
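
The last three columns answer different questions, which a toy case makes concrete. The sketch below uses made-up two-element vectors (not Thimble data) and a hypothetical helper `compare`: a prediction that is off only in magnitude keeps cosine similarity at 1 but shifts the ratio, while a prediction that is off only in direction keeps the ratio at 1 but lowers the cosine similarity.

```python
import torch

# Made-up vectors for illustration only; nothing here comes from Thimble 1.
observed = torch.tensor([3.0, 4.0])
predicted_scaled = 2.0 * observed              # same direction, twice the length
predicted_rotated = torch.tensor([4.0, 3.0])   # same length, different direction


def compare(pred: torch.Tensor, obs: torch.Tensor) -> dict:
    """Return the difference norm, cosine similarity, and norm ratio."""
    norm_pred = torch.norm(pred)
    norm_obs = torch.norm(obs)
    return {
        "norm_difference": torch.norm(pred - obs).item(),
        "cosine_similarity": (torch.dot(pred, obs) / (norm_pred * norm_obs)).item(),
        "ratio": (norm_pred / norm_obs).item(),
    }


scaled = compare(predicted_scaled, observed)
rotated = compare(predicted_rotated, observed)
print("scaled: ", scaled)
print("rotated:", rotated)
```

A perfect accounting would show cosine similarity and ratio both at 1 and a difference norm near zero.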

Focus on dead tokens only to keep it simple.

## Imports

In [6]:
import torch
import numpy as np
from pathlib import Path
from safetensors.torch import load_file
import pandas as pd

## Load Data

In [7]:
# Load Thimble 1 data
thimble_path = Path("../tensors/Thimble/thimble_1.safetensors")
data = load_file(str(thimble_path))

# Extract trajectory tensors
W = data['W']  # Shape: (1001, 10000, 64)
grad_W = data['grad_W']  # Shape: (1001, 10000, 64)
momentum_W = data['momentum_W']  # Shape: (1001, 10000, 64)
variance_W = data['variance_W']  # Shape: (1001, 10000, 64)
losses = data['losses']  # Shape: (1001,)

# Extract hyperparameters
LEARNING_RATE = data['learning_rate'].item()
WEIGHT_DECAY = data['weight_decay'].item()
BETA1 = data['adam_beta1'].item()
BETA2 = data['adam_beta2'].item()
EPSILON = data['adam_epsilon'].item()
NUM_STEPS = data['num_steps'].item()

print(f"Loaded W: {W.shape}")
print(f"Loaded grad_W: {grad_W.shape}")
print(f"Loaded momentum_W: {momentum_W.shape}")
print(f"Loaded variance_W: {variance_W.shape}")
print(f"Loaded losses: {losses.shape}")
print()
print("Hyperparameters:")
print(f"  learning_rate: {LEARNING_RATE}")
print(f"  weight_decay: {WEIGHT_DECAY}")
print(f"  beta1: {BETA1}")
print(f"  beta2: {BETA2}")
print(f"  epsilon: {EPSILON}")
print(f"  num_steps: {NUM_STEPS}")

Loaded W: torch.Size([1001, 10000, 64])
Loaded grad_W: torch.Size([1001, 10000, 64])
Loaded momentum_W: torch.Size([1001, 10000, 64])
Loaded variance_W: torch.Size([1001, 10000, 64])
Loaded losses: torch.Size([1001])

Hyperparameters:
  learning_rate: 0.0010000000474974513
  weight_decay: 0.0
  beta1: 0.8999999761581421
  beta2: 0.9990000128746033
  epsilon: 9.99999993922529e-09
  num_steps: 1000
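
The long decimals in the printed hyperparameters are what one would expect if the values were stored as 32-bit floats and widened to Python floats by `.item()`. The source does not state the stored dtype, so treat that as an assumption. A minimal sketch using only the standard library (`as_float32` is a hypothetical helper) round-trips the nominal values through float32:

```python
import struct


def as_float32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 float32 and widen it back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# Nominal values are assumptions inferred from the printed output above.
lr32 = as_float32(0.001)
beta1_32 = as_float32(0.9)
print(lr32)
print(beta1_32)
```

If the assumption holds, the printed values match the `learning_rate` and `beta1` lines of the output above, which means the run used the nominal 0.001 and 0.9.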


## Load Dead Token Mask

In [8]:
# Load dead token mask from Flannel directory
mask_path = Path("../tensors/Flannel/live_dead_tokens.safetensors")
mask_data = load_file(str(mask_path))
dead_mask = mask_data['dead_mask'].bool()

print(f"Dead tokens: {dead_mask.sum().item()}/{len(dead_mask)}")

Dead tokens: 3699/10000


## Compute Observed ΔW

In [9]:
# Observed weight changes: W[t+1] - W[t]
# W has shape (1001, 10000, 64), so observed_dW has shape (1000, 10000, 64)
observed_dW = W[1:] - W[:-1]

# Extract dead tokens only
observed_dW_dead = observed_dW[:, dead_mask, :]  # Shape: (1000, 3699, 64)

print(f"Observed ΔW (dead tokens): {observed_dW_dead.shape}")

Observed ΔW (dead tokens): torch.Size([1000, 3699, 64])
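
The differencing and masking above can be checked on a tensor small enough to inspect by eye. The sketch below uses hypothetical toy sizes (4 snapshots, 3 tokens, 2 dimensions) and a made-up mask: slicing along the first axis yields one fewer step than snapshots, and the boolean mask selects along the token axis only.

```python
import torch

# Toy trajectory with hypothetical sizes: 4 snapshots, 3 tokens, 2 dims.
W_toy = torch.arange(24, dtype=torch.float32).reshape(4, 3, 2) ** 2
dead_toy = torch.tensor([True, False, True])  # made-up mask: tokens 0 and 2

# Consecutive differences along the time axis: entry i is W_toy[i+1] - W_toy[i].
dW_toy = W_toy[1:] - W_toy[:-1]               # shape (3, 3, 2)

# The boolean mask indexes the token axis and leaves time and dims untouched.
dW_dead_toy = dW_toy[:, dead_toy, :]          # shape (3, 2, 2)

print(dW_toy.shape, dW_dead_toy.shape)
```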


## Compute Predicted ΔW from AdamW

AdamW update rule, with optimizer steps 1-indexed so that step t takes W[t-1] to W[t]:

$$\Delta W[t] = -\eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \lambda \cdot W[t-1]$$

where:
- $\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$ (bias-corrected momentum)
- $\hat{v}_t = \frac{v_t}{1 - \beta_2^t}$ (bias-corrected variance)
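- $\lambda$ is weight_decay (0.0 in Thimble 1, so the second term vanishes; `torch.optim.AdamW` additionally scales this term by $\eta$, which makes no difference when $\lambda = 0$)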
- $m_t$ and $v_t$ are the recorded momentum and variance states

**Important:** The recorded states are those AFTER the optimizer step at time t, so:
- `momentum_W[t]` contains $m_t$ (state after step t)
- `variance_W[t]` contains $v_t$ (state after step t)
- These are the states used in the update that takes W[t-1] to W[t], so they pair with the observed difference `W[t] - W[t-1]`
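
Before applying the rule to recorded data, it can help to confirm on a toy problem that it matches what an optimizer actually does. The sketch below assumes the run used `torch.optim.AdamW` (the source does not say which implementation produced Thimble 1); the tensor sizes, loss, and variable names are made up. It tracks $m_t$ and $v_t$ by hand, applies the formula, and compares the result with the step the optimizer takes.

```python
import torch

torch.manual_seed(0)
lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

# float64 keeps rounding error far below the size of the updates.
w = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
# weight_decay=0.0 mirrors Thimble 1, so the decay term is omitted below.
opt = torch.optim.AdamW([w], lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=0.0)

m = torch.zeros_like(w)  # hand-tracked first moment, assumed to start at zero
v = torch.zeros_like(w)  # hand-tracked second moment, assumed to start at zero
max_err = 0.0

for t in range(1, 6):  # steps are 1-indexed, as in the formula
    w_before = w.detach().clone()
    opt.zero_grad()
    loss = (w ** 2).sum()  # arbitrary toy loss
    loss.backward()
    g = w.grad.detach().clone()
    opt.step()
    observed = w.detach() - w_before

    # State AFTER step t, then bias correction with the same t.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    predicted = -lr * m_hat / (v_hat.sqrt() + eps)

    max_err = max(max_err, (predicted - observed).abs().max().item())

print(f"max |predicted - observed| over 5 steps: {max_err:.3e}")
```

The point of the sketch is the pairing: the state after step t, corrected with the same t, predicts the move from the weights before step t to the weights after it.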

In [10]:
# Extract dead tokens from optimizer states
# We need states at t=1..1000 to predict ΔW for steps 1..1000
momentum_dead = momentum_W[1:, dead_mask, :]  # Shape: (1000, 3699, 64)
variance_dead = variance_W[1:, dead_mask, :]  # Shape: (1000, 3699, 64)
W_prev_dead = W[:-1, dead_mask, :]  # W[t-1] for weight decay term

# Compute bias correction terms for each timestep
# t goes from 1 to 1000 (AdamW uses 1-indexed steps)
timesteps = torch.arange(1, NUM_STEPS + 1, dtype=torch.float32)
bias_correction1 = 1 - BETA1 ** timesteps  # Shape: (1000,)
bias_correction2 = 1 - BETA2 ** timesteps  # Shape: (1000,)

# Reshape for broadcasting: (1000, 1, 1)
bias_correction1 = bias_correction1.view(-1, 1, 1)
bias_correction2 = bias_correction2.view(-1, 1, 1)

# Apply bias correction
m_hat = momentum_dead / bias_correction1
v_hat = variance_dead / bias_correction2

# Compute AdamW update
adam_term = LEARNING_RATE * m_hat / (torch.sqrt(v_hat) + EPSILON)
decay_term = WEIGHT_DECAY * W_prev_dead.float()

predicted_dW_dead = -adam_term - decay_term

print(f"Predicted ΔW (dead tokens): {predicted_dW_dead.shape}")

Predicted ΔW (dead tokens): torch.Size([1000, 3699, 64])


## Compare Predicted vs Observed

For each timestep, compute:
1. L2 norm of predicted ΔW
2. L2 norm of observed ΔW
3. L2 norm of difference
4. Cosine similarity
5. Ratio (||predicted|| / ||observed||)

In [11]:
# Flatten to (1000, 3699*64) for easier norm computation
pred_flat = predicted_dW_dead.reshape(NUM_STEPS, -1)
obs_flat = observed_dW_dead.float().reshape(NUM_STEPS, -1)

# Compute metrics for each timestep
norm_predicted = torch.norm(pred_flat, dim=1)
norm_observed = torch.norm(obs_flat, dim=1)
norm_difference = torch.norm(pred_flat - obs_flat, dim=1)

# Cosine similarity
cosine_sim = torch.sum(pred_flat * obs_flat, dim=1) / (norm_predicted * norm_observed + 1e-10)

# Ratio
ratio = norm_predicted / (norm_observed + 1e-10)

# Create results table
results = pd.DataFrame({
    't': range(1, NUM_STEPS + 1),
    'norm_predicted': norm_predicted.numpy(),
    'norm_observed': norm_observed.numpy(),
    'norm_difference': norm_difference.numpy(),
    'cosine_similarity': cosine_sim.numpy(),
    'ratio': ratio.numpy()
})

# Display first 20 rows, last 20 rows, and some summary stats
print("\n=== First 20 timesteps ===")
print(results.head(20).to_string(index=False))

print("\n=== Last 20 timesteps ===")
print(results.tail(20).to_string(index=False))

print("\n=== Summary Statistics ===")
print(results.describe())


=== First 20 timesteps ===
 t  norm_predicted  norm_observed  norm_difference  cosine_similarity    ratio
 1        0.485381       0.479647         0.010112           0.999925 1.011954
 2        0.383173       0.383000         0.016097           0.999145 1.000450
 3        0.406970       0.408090         0.017424           0.999109 0.997255
 4        0.426508       0.430362         0.017317           0.999233 0.991043
 5        0.441335       0.445059         0.014930           0.999488 0.991632
 6        0.452530       0.455810         0.012278           0.999688 0.992806
 7        0.461319       0.462185         0.010693           0.999765 0.998125
 8        0.468135       0.467225         0.010490           0.999780 1.001948
 9        0.473129       0.471550         0.010939           0.999779 1.003348
10        0.476528       0.474047         0.011538           0.999758 1.005233
11        0.478597       0.476321         0.012327           0.999704 1.004779
12        0.479543      ...
[output truncated]
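
The first row admits a quick sanity check. Assuming the optimizer state starts at zero (the usual initialization, not stated in the source), at t=1 the bias-corrected moments reduce to $\hat{m}_1 = g$ and $\hat{v}_1 = g^2$, so each element moves by $\eta \cdot |g| / (|g| + \epsilon) \le \eta$. The predicted norm therefore cannot exceed $\eta \sqrt{N}$, where N is the number of dead-token elements. A minimal sketch of that arithmetic, using values copied from the outputs above:

```python
import math

lr = 0.0010000000474974513    # learning_rate as printed above
n_dead, dim = 3699, 64        # dead-token count and embedding width from above
n_elements = n_dead * dim

# Upper bound on ||predicted dW|| at t=1 if every element moves by at most lr.
upper_bound = lr * math.sqrt(n_elements)

norm_predicted_t1 = 0.485381  # first row of the table above
print(f"bound = {upper_bound:.6f}, table value = {norm_predicted_t1:.6f}")
print(f"table value / bound = {norm_predicted_t1 / upper_bound:.4f}")
```

The table value sits just under the bound, which is consistent with nearly every element taking a step of almost exactly $\eta$ on the first update.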

## Save Full Results

Save complete table for detailed inspection.

In [12]:
# Save to CSV for easy inspection
output_path = Path("../tensors/Thimble/thimble_1_accounting_results.csv")
results.to_csv(output_path, index=False)
print(f"\nSaved full results to {output_path}")


Saved full results to ../tensors/Thimble/thimble_1_accounting_results.csv
