
pi05 grad-ckpt + KV cache: pin gradient-equivalence with a focused regression test #202

@shuheng-liu


Summary

pi05.paligemma_with_expert._run_layer mutates past_key_values[layer_idx] during forward. Under gradient_checkpointing=True, the layer body is wrapped in torch.utils.checkpoint.checkpoint(use_reentrant=False), so each layer's forward is re-executed during backward to recompute activations.
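Here is a minimal, self-contained illustration of that re-execution using stock PyTorch. The layer function is a hypothetical toy, not the real _run_layer. It appends to a log and writes into a shared dict. Under checkpoint(use_reentrant=False), both side effects happen once during forward and again when backward recomputes the layer.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = []  # one entry per execution of the layer body
cache = {}  # hypothetical stand-in for past_key_values


def toy_layer(x, pkv, layer_idx):
    # Toy stand-in for _run_layer (assumption): mutate the shared cache, then compute.
    calls.append(layer_idx)
    pkv[layer_idx] = x.detach()
    return torch.tanh(x)  # tanh saves its output, so backward needs a recompute


x = torch.randn(4, requires_grad=True)
y = checkpoint(toy_layer, x, cache, 0, use_reentrant=False)
calls_after_forward = list(calls)  # forward ran the body once

y.sum().backward()
calls_after_backward = list(calls)  # backward re-ran the body to rebuild tanh's output
print(calls_after_forward, calls_after_backward)
```

Any mutation inside the checkpointed body, such as the cache write here, is therefore performed a second time during backward.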

The hoisted guard if past_key_values is None: past_key_values = {} (added in #182, commit 3f645b2), together with the comment claiming idempotency, relies on a deeper invariant: the order in which autograd traverses the graph during backward. Writing pkv for past_key_values, for layers 0..N-1:

  1. The original forward writes pkv[0], pkv[1], ..., pkv[N-1] in order.
  2. During backward, autograd walks the layers in reverse: layer N-1's checkpoint recomputes first, reading pkv[0..N-2] (set during the original forward) and writing pkv[N-1].
  3. Layer N-2's checkpoint then recomputes, reading pkv[0..N-3] and writing pkv[N-2], and so on down to layer 0.
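The ordering in steps 1 to 3 can be observed directly with a toy stack of checkpointed layers. This is a sketch using stock PyTorch. run_layer and NUM_LAYERS are hypothetical names, and each toy layer only logs its index and writes its own cache slot, so it shows the recompute schedule rather than pi05's actual attention.

```python
import torch
from torch.utils.checkpoint import checkpoint

NUM_LAYERS = 4
execution_log = []  # layer indices, in the order their bodies run


def run_layer(h, pkv, layer_idx):
    # Hypothetical toy layer: record the call, write this layer's cache slot.
    execution_log.append(layer_idx)
    pkv[layer_idx] = h.detach()
    return torch.tanh(h)


x = torch.randn(8, requires_grad=True)
pkv = {}
h = x
for layer_idx in range(NUM_LAYERS):
    h = checkpoint(run_layer, h, pkv, layer_idx, use_reentrant=False)

forward_order = list(execution_log)
execution_log.clear()
h.sum().backward()  # each region is recomputed when its saved tensors are first needed
recompute_order = list(execution_log)

print(forward_order)
print(recompute_order)
```

For a sequential chain like this, recompute_order comes out as the exact reverse of forward_order. A batched, reordered, or parallel recompute schedule would break that relationship.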

This works only because autograd processes checkpointed regions strictly in reverse topological order. With use_reentrant=False, a region is recomputed lazily, when backward first needs one of its saved tensors. Consequently, layer K's recompute, which overwrites pkv[K], never fires before the recomputes of layers K+1..N-1 have finished consuming the original pkv[K]. If autograd ever batched, reordered, or parallelised checkpoint recompute (it does not today, but the contract is not formally documented), the gradients would be silently corrupted.

The empirical loss-equivalence A/B on PR #182 (median 0.82% relative delta on a 1020-step run; see comment 4318601215) is currently the only thing pinning this invariant.

Recommendation

Add a focused regression test under tests/policies/test_pi05.py (or similar):

import pytest
import torch

# make_pi05_config / make_policy / make_synthetic_batch are test helpers
# assumed to exist (or to be added) alongside this test.


@pytest.mark.gpu
def test_grad_ckpt_grads_match_eager():
    """
    Pin the autograd-topology invariant for pi05 grad-ckpt + KV cache.

    With identical seeds and inputs, the gradients from the grad-checkpointed
    forward must match the eager forward to within fp32 noise. A future
    regression in autograd's checkpoint recompute ordering, or in our pkv
    mutation pattern, would surface here as a per-tensor diff that the
    throughput / loss-curve A/Bs would never catch.
    """
    cfg_eager = make_pi05_config(gradient_checkpointing=False)
    cfg_ckpt = make_pi05_config(gradient_checkpointing=True)

    # Run in fp32 so the tight tolerances below are meaningful
    # (bf16 carries only ~2-3 significant decimal digits).
    torch.manual_seed(0)
    policy_eager = make_policy(cfg_eager).to(torch.float32).cuda()
    torch.manual_seed(0)
    policy_ckpt = make_policy(cfg_ckpt).to(torch.float32).cuda()
    # Confirm same init.
    for p_e, p_c in zip(policy_eager.parameters(), policy_ckpt.parameters(), strict=True):
        torch.testing.assert_close(p_e.data, p_c.data, rtol=0, atol=0)

    batch = make_synthetic_batch(...)
    # forward may sample noise / timesteps internally, so reseed before each pass.
    torch.manual_seed(0)
    loss_e = policy_eager.forward(batch)
    loss_e["MSE"].backward()
    torch.manual_seed(0)
    loss_c = policy_ckpt.forward(batch)
    loss_c["MSE"].backward()

    named_e = policy_eager.named_parameters()
    named_c = policy_ckpt.named_parameters()
    for (name, p_e), (_, p_c) in zip(named_e, named_c, strict=True):
        # A grad present on one path but not the other is itself a mismatch.
        assert (p_e.grad is None) == (p_c.grad is None), f"grad presence differs on {name}"
        if p_e.grad is None:
            continue
        torch.testing.assert_close(
            p_c.grad, p_e.grad,
            rtol=1e-4, atol=1e-5,  # fp32 reassociation / non-deterministic kernel noise
            msg=f"grad mismatch on {name}: ckpt path is not equivalent to eager",
        )
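The same comparison can be prototyped on CPU without the real policy. The sketch below is not the pi05 model. ToyLayer and ToyStack are hypothetical stand-ins in which layer i writes pkv[i] and attends over pkv[0..i] via F.scaled_dot_product_attention, loosely mirroring the read/write pattern described in the Summary. The sketch builds an eager copy and a checkpointed copy with identical weights, then compares gradients per parameter.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    """Hypothetical stand-in for one layer: caches its k/v, attends over pkv[0..idx]."""

    def __init__(self, dim):
        super().__init__()
        self.qkv = torch.nn.Linear(dim, 3 * dim)

    def forward(self, h, pkv, layer_idx):
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        pkv[layer_idx] = (k, v)  # mutate the shared cache during forward
        keys = torch.cat([pkv[j][0] for j in range(layer_idx + 1)], dim=-2)
        values = torch.cat([pkv[j][1] for j in range(layer_idx + 1)], dim=-2)
        return h + F.scaled_dot_product_attention(q, keys, values)


class ToyStack(torch.nn.Module):
    def __init__(self, dim=16, num_layers=3, gradient_checkpointing=False):
        super().__init__()
        self.layers = torch.nn.ModuleList(ToyLayer(dim) for _ in range(num_layers))
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, h, past_key_values=None):
        if past_key_values is None:  # hoisted guard: one dict shared by every layer
            past_key_values = {}
        for idx, layer in enumerate(self.layers):
            if self.gradient_checkpointing:
                h = checkpoint(layer, h, past_key_values, idx, use_reentrant=False)
            else:
                h = layer(h, past_key_values, idx)
        return h.pow(2).mean()


torch.manual_seed(0)
eager = ToyStack(gradient_checkpointing=False)
ckpt = ToyStack(gradient_checkpointing=True)
ckpt.load_state_dict(eager.state_dict())  # identical weights on both paths

batch = torch.randn(2, 5, 16)
eager(batch).backward()
ckpt(batch).backward()

max_grad_diff = 0.0
ckpt_params = dict(ckpt.named_parameters())
for name, p_e in eager.named_parameters():
    p_c = ckpt_params[name]
    torch.testing.assert_close(p_c.grad, p_e.grad, rtol=1e-5, atol=1e-6, msg=f"grad mismatch on {name}")
    max_grad_diff = max(max_grad_diff, (p_c.grad - p_e.grad).abs().max().item())
print(f"max per-parameter grad diff: {max_grad_diff:.3e}")
```

Because non-reentrant checkpointing builds the same forward graph and only swaps out saved tensors, the two paths are expected to agree here. A toy like this is cheap enough to run in CI without a GPU, while the GPU test above exercises the real model.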

This matters even though the loss A/B passed: a per-tensor gradient mismatch with mean ≈ 0 averages out across hundreds of parameters and ~1000 steps, so loss-trajectory A/Bs are insensitive to it. A per-tensor assert_close checks at the right granularity.
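The following toy numeric illustration of that argument uses synthetic tensors rather than real gradients. A zero-mean perturbation of about 1% leaves an averaged statistic essentially at zero. The same kind of per-tensor assert_close used in the test rejects it immediately.

```python
import torch

torch.manual_seed(0)
grad_eager = torch.randn(10_000)
noise = 1e-2 * torch.randn(10_000)
noise -= noise.mean()  # zero-mean error, the kind that averages out
grad_ckpt = grad_eager + noise

# An averaged statistic barely moves...
mean_delta = (grad_ckpt - grad_eager).mean().abs().item()

# ...but a per-tensor comparison catches the mismatch.
try:
    torch.testing.assert_close(grad_ckpt, grad_eager, rtol=1e-4, atol=1e-5)
    per_tensor_caught = False
except AssertionError:
    per_tensor_caught = True

print(f"mean delta: {mean_delta:.2e}, per-tensor check failed: {per_tensor_caught}")
```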

Owner / scope

  • Touches: tests/policies/test_pi05.py (or a new file). GPU-marked test.
  • No production code change.
  • Bonus: a similar test combining the SDPA attention backend with grad-ckpt would catch any interaction between the KV mutation and the SDPA path (currently covered only by loss-curve A/Bs).

Refs: #182 (review thread, item #6 — flagged by the reviewer as one of the two items most worth hardening in a follow-up).
