# 1.31j: Verifying AdamW Update Rule

**Goal:** Confirm we can reconstruct W[t] from W[t-1] and the recorded optimizer states m[t], v[t] using the AdamW update rule.

## Recording Order (from 1.31a)

```python
loss.backward()                    # Compute gradients
grad_dset[step] = grad             # Record gradient BEFORE step
optimizer.step()                   # Update m, v, W
W_dset[step] = W                   # Record W AFTER step
momentum_dset[step] = m            # Record m AFTER step
variance_dset[step] = v            # Record v AFTER step
```
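The index alignment this order implies can be checked on a toy problem. The sketch below is a stand-in for the 1.31a loop, not its actual code: it assumes a single float64 weight tensor, a synthetic quadratic loss, and Python lists in place of the HDF5 datasets. Because the gradient is recorded before `optimizer.step()` and `exp_avg` / `exp_avg_sq` after it, the recurrences m[s] = β₁·m[s-1] + (1-β₁)·g[s] and v[s] = β₂·v[s-1] + (1-β₂)·g[s]² should hold with the same index s on both sides.

```python
import torch

torch.manual_seed(0)
beta1, beta2 = 0.9, 0.999

# Hypothetical stand-in for the 1.31a model: one float64 weight tensor and a toy loss
W_param = torch.nn.Parameter(torch.randn(8, 4, dtype=torch.float64))
target = torch.randn(8, 4, dtype=torch.float64)
optimizer = torch.optim.AdamW([W_param], lr=1e-3, betas=(beta1, beta2),
                              eps=1e-8, weight_decay=0.0)

grad_rec, W_rec, m_rec, v_rec = [], [], [], []
for step in range(5):
    optimizer.zero_grad()
    loss = ((W_param - target) ** 2).sum()
    loss.backward()
    grad_rec.append(W_param.grad.detach().clone())      # gradient BEFORE step
    optimizer.step()
    state = optimizer.state[W_param]
    W_rec.append(W_param.detach().clone())              # W AFTER step
    m_rec.append(state["exp_avg"].detach().clone())     # m AFTER step
    v_rec.append(state["exp_avg_sq"].detach().clone())  # v AFTER step

# Same-index check: grad_rec[s] is the gradient folded into m_rec[s] and v_rec[s]
checks = []
for s in range(len(grad_rec)):
    m_prev = m_rec[s - 1] if s > 0 else torch.zeros_like(m_rec[0])
    v_prev = v_rec[s - 1] if s > 0 else torch.zeros_like(v_rec[0])
    m_ok = torch.allclose(m_rec[s], beta1 * m_prev + (1 - beta1) * grad_rec[s])
    v_ok = torch.allclose(v_rec[s], beta2 * v_prev + (1 - beta2) * grad_rec[s] ** 2)
    checks.append((m_ok, v_ok))
print(checks)
```

Both optimizer states start at zero, which is why index 0 uses zero tensors as the "previous" state.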

## AdamW Update (from PyTorch source)

```python
m[t] = β₁·m[t-1] + (1-β₁)·g[t]
v[t] = β₂·v[t-1] + (1-β₂)·g[t]²
m̂[t] = m[t] / (1 - β₁^t)          # t = optimizer step count, starting at 1
v̂[t] = v[t] / (1 - β₂^t)
W[t] = W[t-1]·(1 - η·λ) - η · m̂[t] / (√v̂[t] + ε)   # λ = weight decay (decoupled)
```
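The equations translate almost line for line into a function. `adamw_reference_step` below is a hypothetical helper written for this sketch, not a PyTorch API. It applies the decoupled weight-decay factor (1 - η·λ) to W[t-1] before the Adam update, which is how `torch.optim.AdamW` orders the two. To exercise that term, the comparison against `torch.optim.AdamW` uses a nonzero λ and random float64 gradients, which are assumptions of the sketch.

```python
import torch

def adamw_reference_step(W_prev, m_prev, v_prev, g, t, lr=1e-3,
                         betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """One AdamW step written directly from the update equations; t starts at 1."""
    beta1, beta2 = betas
    m = beta1 * m_prev + (1 - beta1) * g
    v = beta2 * v_prev + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    W = W_prev * (1 - lr * weight_decay) - lr * m_hat / (torch.sqrt(v_hat) + eps)
    return W, m, v

torch.manual_seed(1)
p = torch.nn.Parameter(torch.randn(16, dtype=torch.float64))
opt = torch.optim.AdamW([p], lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Reference state starts from the same weights and zero moments
W_ref = p.detach().clone()
m_ref = torch.zeros_like(W_ref)
v_ref = torch.zeros_like(W_ref)
for t in range(1, 11):
    g = torch.randn(16, dtype=torch.float64)
    p.grad = g.clone()
    opt.step()
    W_ref, m_ref, v_ref = adamw_reference_step(W_ref, m_ref, v_ref, g, t, weight_decay=0.01)

max_err = (p.detach() - W_ref).abs().max().item()
print(f"max |torch - reference| after 10 steps: {max_err:.2e}")
```

PyTorch computes the same quantity with a different operation order (it divides √v by √(1 - β₂^t) rather than forming v̂ first), so in float64 the two agree only up to rounding.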

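**Key insight:** To predict W[t], combine the pre-step weights W[t-1] with the momentum and variance computed DURING step t, i.e. the post-step states recorded as momentum_W[t] and variance_W[t]. The exponent t in the bias correction is PyTorch's internal step counter, which is 1 during the first `optimizer.step()`. It equals the dataset index only if index 0 holds the initial state; if index 0 was written after the first step, the counter is index + 1.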
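The insight can be tested on a float64 toy run where rounding is negligible. This sketch uses a synthetic parameter and random gradients (not Thimble data) and records W, m, v after each step. It then reconstructs W[t] from W[t-1] twice: once with the post-step states m[t], v[t] and once with the stale m[t-1], v[t-1]. `predict_W` is a hypothetical helper. In this toy loop list index t was written after optimizer step t + 1, so the bias-correction exponent is k = t + 1.

```python
import torch

torch.manual_seed(2)
beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-8
p = torch.nn.Parameter(torch.randn(32, dtype=torch.float64))
opt = torch.optim.AdamW([p], lr=lr, betas=(beta1, beta2), eps=eps, weight_decay=0.0)

W_rec, m_rec, v_rec = [], [], []
for _ in range(20):
    p.grad = torch.randn(32, dtype=torch.float64)
    opt.step()
    st = opt.state[p]
    W_rec.append(p.detach().clone())   # recorded AFTER step, as in 1.31a
    m_rec.append(st["exp_avg"].clone())
    v_rec.append(st["exp_avg_sq"].clone())

def predict_W(W_prev, m_t, v_t, k):
    """Reconstruct W after optimizer step k (k counts from 1) from moment estimates."""
    m_hat = m_t / (1 - beta1 ** k)
    v_hat = v_t / (1 - beta2 ** k)
    return W_prev - lr * m_hat / (torch.sqrt(v_hat) + eps)

t = 10      # list index; this entry was produced by optimizer step t + 1
k = t + 1
err_post = (predict_W(W_rec[t - 1], m_rec[t], v_rec[t], k) - W_rec[t]).abs().max().item()
err_stale = (predict_W(W_rec[t - 1], m_rec[t - 1], v_rec[t - 1], k) - W_rec[t]).abs().max().item()
print(f"using m[t], v[t]:     {err_post:.2e}")
print(f"using m[t-1], v[t-1]: {err_stale:.2e}")
```

The stale version misses the contribution of g[t] to the moments, so its error is on the scale of the update itself rather than of float64 rounding.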

## Hyperparameters (from 1.31a)

- η = 1e-3 (learning rate)
- β₁ = 0.9
- β₂ = 0.999
- ε = 1e-8
- Weight decay = 0.0 (so we can ignore it)
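For reference, these settings map onto `torch.optim.AdamW` as sketched below. The actual construction call in 1.31a is not shown here, and `make_optimizer` is a hypothetical helper. One detail matters for "we can ignore it": `torch.optim.AdamW` defaults to `weight_decay=1e-2`, so the zero must be passed explicitly. With it, AdamW should behave like plain Adam, which the sketch checks on a toy float64 parameter fed identical random gradients.

```python
import torch

def make_optimizer(params, optimizer_cls):
    """Build an optimizer with the hyperparameters listed above (weight decay passed explicitly)."""
    return optimizer_cls(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0)

# torch.optim.AdamW's own default weight decay is 1e-2, not 0
default_wd = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))]).param_groups[0]["weight_decay"]

torch.manual_seed(3)
init = torch.randn(10, dtype=torch.float64)
p_adamw = torch.nn.Parameter(init.clone())
p_adam = torch.nn.Parameter(init.clone())
opt_adamw = make_optimizer([p_adamw], torch.optim.AdamW)
opt_adam = make_optimizer([p_adam], torch.optim.Adam)

# Feed both optimizers identical gradients
for _ in range(25):
    g = torch.randn(10, dtype=torch.float64)
    p_adamw.grad = g.clone()
    p_adam.grad = g.clone()
    opt_adamw.step()
    opt_adam.step()

same = torch.allclose(p_adamw.detach(), p_adam.detach(), rtol=0.0, atol=1e-12)
print(f"AdamW default weight_decay: {default_wd}")
print(f"AdamW(weight_decay=0) matches Adam: {same}")
```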

In [1]:
import h5py
import torch
from pathlib import Path

# Hyperparameters from 1.31a
LEARNING_RATE = 1e-3
BETA1 = 0.9
BETA2 = 0.999
EPS = 1e-8

In [2]:
# Load first 101 timesteps from Thimble 7
h5_path = Path('../tensors/Thimble/thimble_7.h5')

with h5py.File(h5_path, 'r') as f:
    W_uint16 = torch.from_numpy(f['W'][:101, :, :])
    m_uint16 = torch.from_numpy(f['momentum_W'][:101, :, :])
    v_uint16 = torch.from_numpy(f['variance_W'][:101, :, :])

# Reinterpret the stored uint16 bit patterns as bfloat16 (no numeric conversion)
W = W_uint16.view(torch.bfloat16)
m = m_uint16.view(torch.bfloat16)
v = v_uint16.view(torch.bfloat16)

print(f"Loaded W: {W.shape} ({W.dtype})")
print(f"Loaded m: {m.shape} ({m.dtype})")
print(f"Loaded v: {v.shape} ({v.dtype})")

Loaded W: torch.Size([101, 10000, 64]) (torch.bfloat16)
Loaded m: torch.Size([101, 10000, 64]) (torch.bfloat16)
Loaded v: torch.Size([101, 10000, 64]) (torch.bfloat16)
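The `.view(torch.bfloat16)` calls reinterpret the stored 16-bit patterns rather than converting values numerically. This is lossless only if 1.31a wrote raw bfloat16 bits into the uint16 datasets, which is assumed here. The round-trip below illustrates the idea on random data. It goes through `int16` on the PyTorch side because `torch.uint16` is only available, with limited operator support, in recent PyTorch releases.

```python
import numpy as np
import torch

torch.manual_seed(4)
x = torch.randn(5, 3).to(torch.bfloat16)

# bfloat16 -> raw 16-bit patterns held as uint16 (the assumed on-disk layout)
bits_u16 = x.view(torch.int16).numpy().view(np.uint16)

# uint16 -> bfloat16 by reinterpreting the same bits; no rounding happens
restored = torch.from_numpy(bits_u16.view(np.int16)).view(torch.bfloat16)

lossless = torch.equal(restored.view(torch.int16), x.view(torch.int16))
# 1.0 in bfloat16 keeps the upper 16 bits of float32 1.0 (0x3F800000)
one_bits = int(torch.tensor([1.0], dtype=torch.bfloat16).view(torch.int16).item())
print(lossless, hex(one_bits))
```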


In [3]:
# Test at t=50
t = 50

# Apply bias correction to the recorded post-step states m[t] and v[t],
# which are the states used to compute W[t].
# Assumes dataset index t equals PyTorch's internal step count.
m_hat = m[t] / (1 - BETA1**t)
v_hat = v[t] / (1 - BETA2**t)

# Compute update (all in bfloat16)
update = -LEARNING_RATE * m_hat / (torch.sqrt(v_hat) + EPS)

# Predicted W[t]
W_predicted = W[t-1] + update

# Actual W[t]
W_actual = W[t]

# Compare with an absolute tolerance covering bfloat16 rounding in this step's intermediate ops
TOLERANCE = 1e-3
diff = W_predicted.to(torch.float32) - W_actual.to(torch.float32)  # upcast before subtracting
abs_diff = diff.abs()
within_tolerance = abs_diff < TOLERANCE

# Perfect bitwise match
exact_match = (W_predicted == W_actual)

print(f"\nSanity check at t={t}:")
print(f"  Perfect bitwise match: {exact_match.sum().item():,} / {exact_match.numel():,} ({100*exact_match.sum().item()/exact_match.numel():.2f}%)")
print(f"  Within tolerance (< {TOLERANCE:.0e}): {within_tolerance.sum().item():,} / {within_tolerance.numel():,} ({100*within_tolerance.sum().item()/within_tolerance.numel():.2f}%)")

print(f"\n  Error statistics:")
print(f"    Max absolute error: {abs_diff.max().item():.2e}")
print(f"    Mean absolute error: {abs_diff.mean().item():.2e}")
print(f"    Median absolute error: {abs_diff.median().item():.2e}")

if within_tolerance.all().item():
    print(f"\n✓ SUCCESS: Update rule validated!")
    print(f"  All elements within tolerance. Differences due to bfloat16 operation order and intermediate rounding.")
    print(f"  No systematic error in decomposition of ΔW into momentum and variance components.")
else:
    print(f"\n✗ FAILED: Some elements exceed tolerance.")
    outliers = ~within_tolerance
    print(f"  Outliers: {outliers.sum().item():,} elements")
    print(f"  Max outlier error: {abs_diff[outliers].max().item():.2e}")


Sanity check at t=50:
  Perfect bitwise match: 627,873 / 640,000 (98.11%)
  Within tolerance (< 1e-03): 640,000 / 640,000 (100.00%)

  Error statistics:
    Max absolute error: 4.88e-04
    Mean absolute error: 9.63e-07
    Median absolute error: 0.00e+00

✓ SUCCESS: Update rule validated!
  All elements within tolerance. Differences due to bfloat16 operation order and intermediate rounding.
  No systematic error in decomposition of ΔW into momentum and variance components.
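A fixed absolute tolerance of 1e-3 is loose for small weights and tight for large ones, because bfloat16 spacing scales with magnitude. Measuring the error in units of the bfloat16 spacing (ULP) at each actual value makes the check scale-free. To the printed precision, the reported max error of 4.88e-04 equals 2⁻¹¹, which is one ULP for bfloat16 values with magnitude in [0.0625, 0.125); whether the worst element lies in that range is not shown above. `bf16_ulp` and `error_in_ulps` are hypothetical helpers built on `torch.frexp`.

```python
import torch

def bf16_ulp(x):
    """Spacing between adjacent bfloat16 values at each element of x (hypothetical helper).

    bfloat16 keeps 8 significant bits, so for |x| in [2^(e-1), 2^e) the spacing is
    2^(e-8), where e is the exponent returned by torch.frexp. Zeros and subnormals
    map to the subnormal spacing 2^-133.
    """
    x = x.to(torch.float32)
    _, exponent = torch.frexp(x)
    exponent = torch.where(x == 0, torch.full_like(exponent, -125), exponent)
    exponent = exponent.clamp(min=-125)
    return torch.pow(2.0, (exponent - 8).to(torch.float32))

def error_in_ulps(predicted, actual):
    """|predicted - actual| measured in bfloat16 ULPs of the actual value."""
    diff = (predicted.to(torch.float32) - actual.to(torch.float32)).abs()
    return diff / bf16_ulp(actual)

# Small demonstration on hand-picked bfloat16 values
actual = torch.tensor([1.0, 0.1, 3.0], dtype=torch.bfloat16)
ulps = bf16_ulp(actual)
nudged = torch.tensor([1.0078125, 0.1, 3.0], dtype=torch.bfloat16)  # 1.0 moved up by one ULP
print(ulps.tolist())
print(error_in_ulps(nudged, actual).tolist())
```

Applied to the notebook's tensors, `error_in_ulps(W_predicted, W_actual).max()` would report the worst case in ULPs instead of absolute units.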
