Summary
`pi05.paligemma_with_expert._run_layer` mutates `past_key_values[layer_idx]` during forward. Under `gradient_checkpointing=True`, the layer body is wrapped in `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, so each layer's forward is re-executed during backward to recompute activations.
The hoisted `if past_key_values is None: past_key_values = {}` (added in #182, commit 3f645b2) and the comment claiming idempotency rely on a deeper invariant: autograd's topological traversal order during backward. Specifically, for layers `0..N-1`:
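A minimal CPU sketch of that re-execution, using a hypothetical `layer` function (not the pi05 code) that writes into a shared dict the way `_run_layer` writes `past_key_values[layer_idx]`. The body runs once in forward and once more during backward, so the dict write happens twice:

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = []  # one entry per execution of the layer body
pkv = {}    # hypothetical stand-in for past_key_values

def layer(x, w, layer_idx):
    # Record side effects up front: non-reentrant recompute may stop early
    # once every saved tensor has been rebuilt.
    calls.append(layer_idx)
    kv = x @ w
    pkv[layer_idx] = kv.detach()  # cache write, like _run_layer's mutation
    return torch.tanh(kv)

torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

y = checkpoint(layer, x, w, 0, use_reentrant=False)
print(calls)  # [0]: the body ran once in forward
y.sum().backward()
print(calls)  # [0, 0]: it ran again during backward to rebuild activations
```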
- The original forward writes `pkv[0], pkv[1], ..., pkv[N-1]` in order.
- During backward, autograd walks the layers in reverse: layer N-1's checkpoint recomputes first, reading `pkv[0..N-2]` (set during the original forward) and writing `pkv[N-1]`.
- Then layer N-2's checkpoint recomputes, reading `pkv[0..N-3]` and writing `pkv[N-2]`, and so on down to layer 0.
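The reverse ordering can be observed on a toy stack of checkpointed layers. This is a CPU sketch with hypothetical layer functions; it records today's behaviour for a simple sequential chain and is not a documented guarantee:

```python
import torch
from torch.utils.checkpoint import checkpoint

order = []  # layer indices in the order their bodies execute

def make_layer(idx):
    def layer(x, w):
        # Record first: non-reentrant recompute may stop early after the
        # last saved tensor is rebuilt.
        order.append(idx)
        return torch.tanh(x @ w)
    return layer

n_layers = 4
torch.manual_seed(0)
weights = [torch.randn(8, 8, requires_grad=True) for _ in range(n_layers)]
h = torch.randn(2, 8, requires_grad=True)
for idx, w in enumerate(weights):
    h = checkpoint(make_layer(idx), h, w, use_reentrant=False)

forward_order = list(order)
h.sum().backward()
recompute_order = order[len(forward_order):]
print(forward_order)    # [0, 1, 2, 3]: layers run 0..N-1 in forward
print(recompute_order)  # [3, 2, 1, 0]: layers recomputed N-1..0 in backward
```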
The scheme works only because autograd currently processes checkpoint recomputes strictly in reverse topological order, so a prefix layer's recompute (which overwrites `pkv[K]`) never fires before a suffix layer's recompute has finished consuming the original `pkv[K]`. If autograd ever batched, reordered, or parallelised checkpoint recompute (it doesn't today, but the contract isn't formally documented), gradients could be silently corrupted.
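A pure-Python toy model of the invariant (not pi05 code; each cache entry is only tagged with the pass that wrote it) shows why the order matters: replaying recomputes in reverse never reads a recomputed prefix entry, while a hypothetical forward-order replay does:

```python
def simulate(recompute_order, n_layers):
    """Return (layer, index) reads that saw a recomputed pkv entry.

    Layer k's recompute reads pkv[0..k-1], then overwrites pkv[k],
    mirroring the pattern described above.
    """
    pkv = {k: ("forward", k) for k in range(n_layers)}  # original forward writes
    stale_reads = []
    for k in recompute_order:
        for j in range(k):                  # read pkv[0..k-1]
            if pkv[j][0] != "forward":
                stale_reads.append((k, j))
        pkv[k] = ("recompute", k)           # overwrite pkv[k]
    return stale_reads

n = 4
print(simulate(reversed(range(n)), n))  # [] : autograd's order today
print(simulate(range(n), n))            # [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]
```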
The empirical loss-equivalence A/B on PR #182 (median 0.82% relative delta on a 1020-step run; see comment 4318601215) is currently the only thing pinning this invariant.
Recommendation
Add a focused regression test under tests/policies/test_pi05.py (or similar):
```python
import pytest
import torch

# make_pi05_config, make_policy and make_synthetic_batch are test helpers
# (assumed to exist in the test suite or to be added with this test).


@pytest.mark.gpu
def test_grad_ckpt_produces_bit_identical_grads_to_eager():
    """
    Pin the autograd-topology invariant for pi05 grad-ckpt + KV cache.

    With identical seeds and inputs (and no optimizer step), the gradients
    from the grad-checkpointed forward must match the eager forward to
    within floating-point noise. A future regression in autograd's
    checkpoint recompute ordering, or in our pkv mutation pattern, would
    surface here as a per-tensor diff that the throughput / loss-curve
    A/Bs would never catch.
    """
    cfg_eager = make_pi05_config(gradient_checkpointing=False)
    cfg_ckpt = make_pi05_config(gradient_checkpointing=True)
    torch.manual_seed(0)
    policy_eager = make_policy(cfg_eager).to(torch.bfloat16).cuda()
    torch.manual_seed(0)
    policy_ckpt = make_policy(cfg_ckpt).to(torch.bfloat16).cuda()

    # Confirm same init.
    for p_e, p_c in zip(policy_eager.parameters(), policy_ckpt.parameters(), strict=True):
        torch.testing.assert_close(p_e.data, p_c.data, rtol=0, atol=0)

    batch = make_synthetic_batch(...)

    # Reseed before each forward so stochastic ops in forward (e.g. dropout
    # or sampled noise) draw the same values on both paths.
    torch.manual_seed(1)
    loss_e = policy_eager.forward(batch)
    loss_e["MSE"].backward()
    torch.manual_seed(1)
    loss_c = policy_ckpt.forward(batch)
    loss_c["MSE"].backward()

    grads_eager = {name: p.grad for name, p in policy_eager.named_parameters()}
    for name, p_c in policy_ckpt.named_parameters():
        g_e, g_c = grads_eager[name], p_c.grad
        # A grad present on one path but missing on the other is itself a bug.
        assert (g_e is None) == (g_c is None), f"grad presence differs on {name}"
        if g_e is None:
            continue
        torch.testing.assert_close(
            g_e.float(), g_c.float(),
            rtol=1e-4, atol=1e-5,  # bf16 reassociation noise
            msg=f"grad mismatch on {name}: ckpt path is not equivalent to eager",
        )
```
The reason this matters even though the loss A/B passed: a per-tensor gradient mismatch with mean ≈ 0 averages out across hundreds of parameter tensors and ~1000 steps, so loss-trajectory A/Bs are insensitive to it. A per-tensor `assert_close` is the right granularity.
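A toy numeric illustration of that averaging argument (synthetic numbers, not measurements from #182): a zero-mean per-parameter offset leaves an aggregate statistic essentially unchanged, while every per-parameter comparison sees the full offset:

```python
import random

random.seed(0)
n_params = 300
eager = [random.gauss(0.0, 1.0) for _ in range(n_params)]
# Hypothetical corruption: +0.05 on even indices, -0.05 on odd ones (zero mean).
ckpt = [g + (0.05 if i % 2 == 0 else -0.05) for i, g in enumerate(eager)]

# Aggregate view (a stand-in for a loss-curve A/B): difference of the means.
aggregate_diff = abs(sum(ckpt) / n_params - sum(eager) / n_params)

# Per-parameter view (what assert_close does): count entries over atol=1e-5.
n_failing = sum(abs(c - e) > 1e-5 for c, e in zip(ckpt, eager))

print(f"aggregate diff: {aggregate_diff:.1e}")  # floating-point rounding only
print(f"per-parameter failures: {n_failing} / {n_params}")
```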
Owner / scope
- Touches: `tests/policies/test_pi05.py` (or a new file). GPU-marked test.
- No production code change.
- Bonus: a similar test combining the SDPA attention backend with grad-ckpt would catch any KV-mutation interaction with the SDPA path (currently only loss-curve-tested).
Refs: #182 (review thread, item #6 — flagged by the reviewer as one of the two items most worth hardening in a follow-up).