Fix non-canonical cu_seqlens_k from preprocessor #514
Open
jlamypoirier wants to merge 2 commits into
Conversation
The data preprocessor emitted `cu_seqlens_k[0] = first_document_begin` rather than 0, violating the canonical varlen prefix-sum layout required by every public varlen attention API. SDPA's EFFICIENT backward writes corrupted dK/dV rows when fed this layout, propagating wrong gradients through the K/V projection's reduce-scatter under sequence-data-parallel + micro-batch splits.

Three changes that compose:

- `LengthModelInputPreprocessor` now produces `cu_seqlens_k` starting at 0 and narrows `document_index_k` / `position_index` to the active K extent. The dropped leading-prefix length is exposed as a new `first_document_begin` int kwarg (see the sketch after this list).
- Pre-allocate one K/V buffer per attention layer across all micro-sequences of a sequence. Each forward writes the SDP-gather result into the next slice via `gather_op(out=)`; backward accumulates each micro-seq's K/V grad into a shared grad buffer slice. The leading + trailing narrows and the per-step `torch.cat` / `AttachGrad` workaround for the cross-micro-seq splice are all absorbed into the `_query_key_value` custom autograd region.
- `_preprocess_for_backup_attention` builds the attention mask against the narrowed K cols, so `sdpa_dense` and `backup` consume the same K extent as flash and `sdpa_nested`.

Update `tests/data/test_preprocessing.py` to expect the canonical layout.
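For reference, a minimal sketch of the canonical layout this restores. The helper name, shapes, and offsets below are illustrative only, not the actual `LengthModelInputPreprocessor` code:

```python
import torch

def canonicalize_k_layout(document_begins, document_ends, first_document_begin):
    # Illustrative helper, not the preprocessor's real code: rebase per-document
    # K offsets so the varlen prefix sum starts at 0, and report the dropped
    # leading-prefix length separately instead of folding it into cu_seqlens_k.
    begins = torch.as_tensor(document_begins, dtype=torch.int32) - first_document_begin
    ends = torch.as_tensor(document_ends, dtype=torch.int32) - first_document_begin
    lengths = ends - begins
    # Canonical varlen layout: cu_seqlens_k[0] == 0, cu_seqlens_k[-1] == active K tokens.
    cu_seqlens_k = torch.cat([lengths.new_zeros(1), lengths.cumsum(0).to(torch.int32)])
    # document_index_k / position_index would be narrowed to [0, cu_seqlens_k[-1]) here.
    return cu_seqlens_k, first_document_begin

# Two active documents spanning tokens [7, 12) and [12, 15), behind a 7-token
# prefix that the preprocessor drops from the K window.
cu_seqlens_k, first_document_begin = canonicalize_k_layout([7, 12], [12, 15], 7)
assert cu_seqlens_k.tolist() == [0, 5, 8]  # previously the leading entry was 7
assert first_document_begin == 7           # dropped prefix exposed as the new kwarg
```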
`_test_first_document_begin` injects a fake past K/V slot with arbitrary leading data, drives attention through manually built kwargs with `sequence_k_past` and `first_document_begin` both set to a non-zero `past_length`, and verifies that:

- the forward output matches a per-doc reference computed on the active documents alone (the dropped prefix has no observable effect),
- parameter gradients match the reference,
- the K/V grad buffer at `[:past_length]` is exactly zero, which is the specific guarantee of the `cu_seqlens_k` canonicalization fix (see the toy illustration below).

Runs backup + sdpa_dense in fp32 and flash + sdpa_nested in bf16 (flash rejects fp32). Plugged into the existing `test_attention` parametrization as a new case with `name="first_document_begin"`, dispatched via a name check.
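A stripped-down illustration of the grad-buffer guarantee the new case checks, using plain SDPA on a toy buffer rather than the test's real fixtures (every name and shape here is made up):

```python
import torch
import torch.nn.functional as F

# Toy stand-in: a K/V buffer whose first `past_length` rows hold arbitrary
# stale data, with attention driven only over the active extent.
past_length, active_len, heads, head_dim = 4, 6, 2, 8
k_buffer = torch.randn(1, heads, past_length + active_len, head_dim, requires_grad=True)
v_buffer = torch.randn(1, heads, past_length + active_len, head_dim, requires_grad=True)
query = torch.randn(1, heads, active_len, head_dim, requires_grad=True)

# With the canonical layout, attention only ever consumes K/V past the dropped prefix.
out = F.scaled_dot_product_attention(
    query, k_buffer[:, :, past_length:], v_buffer[:, :, past_length:]
)
out.sum().backward()

# The specific guarantee of the fix: no gradient ever reaches the stale prefix rows.
assert torch.all(k_buffer.grad[:, :, :past_length] == 0)
assert torch.all(v_buffer.grad[:, :, :past_length] == 0)
```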
Summary
- The data preprocessor emitted `cu_seqlens_k[0] = first_document_begin` instead of `0`, violating the canonical varlen prefix-sum layout. SDPA EFFICIENT backward writes corrupt dK/dV rows when fed this, propagating wrong K/V projection grads through the reduce-scatter under sequence-data-parallel + micro-batch splits.
- The preprocessor now produces `cu_seqlens_k` starting at 0; narrows `document_index_k` / `position_index` to the active K extent; exposes the dropped leading-prefix length as a new `first_document_begin` int kwarg.
- Attention pre-allocates one K/V buffer per layer across a sequence's micro-sequences; each forward writes its slice via `gather_op(out=)`; backward accumulates per-micro-seq K/V grad into a shared grad buffer slice. The leading + trailing narrows and the per-step `torch.cat` / `AttachGrad` cross-micro-seq splice are absorbed into `_query_key_value`'s custom autograd region (sketched below).
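For reviewers, a minimal sketch of the cross-micro-seq splice idea, under stated assumptions: module-level buffers stand in for the per-layer pre-allocated K/V and grad buffers, a plain slice copy stands in for `gather_op(out=)`, and `_SpliceK` is a hypothetical stand-in, not the PR's `_query_key_value` autograd region. How the visible-extent grad is split between the shared buffer and the live projection is illustrative only:

```python
import torch

seq_len, n_micro, heads, head_dim = 8, 2, 2, 4
micro_len = seq_len // n_micro

# One pre-allocated K buffer (and matching grad buffer) shared by every
# micro-sequence of the sequence; the PR keeps one such pair per attention layer.
k_buffer = torch.zeros(heads, seq_len, head_dim)
k_grad_buffer = torch.zeros_like(k_buffer)

class _SpliceK(torch.autograd.Function):
    """Hypothetical splice: write this micro-seq's K into its buffer slice
    (in place, standing in for gather_op(out=)) and return the buffer up to
    the current end, so attention sees past + current keys without torch.cat."""

    @staticmethod
    def forward(ctx, k_micro, begin, end):
        ctx.begin, ctx.end = begin, end
        with torch.no_grad():
            k_buffer[:, begin:end] = k_micro
        return k_buffer[:, :end]

    @staticmethod
    def backward(ctx, grad_k_visible):
        # Grad for the already-materialized past rows accumulates into the shared buffer...
        k_grad_buffer[:, : ctx.begin] += grad_k_visible[:, : ctx.begin]
        # ...and only the current micro-seq's slice flows back to the K projection.
        return grad_k_visible[:, ctx.begin : ctx.end], None, None

# Toy usage: a shared "K projection" whose grads arrive from each micro-seq.
w = torch.randn(head_dim, head_dim, requires_grad=True)
loss = 0.0
for i in range(n_micro):
    x = torch.randn(heads, micro_len, head_dim)
    k_visible = _SpliceK.apply(x @ w, i * micro_len, (i + 1) * micro_len)
    loss = loss + k_visible.sum()  # stand-in for the attention that consumes K
loss.backward()
```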
Test plan

- `tests/data/test_preprocessing.py` (843/843 pass): updated to expect the canonical layout.
- `tests/layers/test_attention.py` (57/57 pass on CPU and CUDA, including the new `first_document_begin` regression case).
- `test_attention[first_document_begin-[4, 1, 10]]` injects a fake past K/V slot, drives attention with `sequence_k_past > 0` and `first_document_begin > 0`, and verifies that output + parameter grads match a per-doc reference and that `slot.grad_buffer[:past_length]` is exactly zero. Passes for backup, sdpa_dense, flash, sdpa_nested.
- `tests/models/test_model.py`: `gpt_2-ms4`, `gpt_2-sdp2`, `gpt_2-sdp2_stp2`, `gpt_2-sdp2_stp2_bf4`, `gpt_2-stp2_pp2s1_bf4` still fail with ~0.5–1% relative gradient drift vs the `simple` baseline. The new unit test proves the narrow logic is correct in isolation, so the residual is in something the unit test doesn't exercise: multi-micro-sequence buffer chaining (`pasts` / `presents` flowing across 2+ micro-seqs), SDP multi-rank reduce-scatter, or schedule integration. Debugging continues, likely as additional commits on this PR.