feat(pi07): add ring-attention sequence parallelism #298
Adds a paper-style ring-attention path (arxiv.org/abs/2310.01889) for pi07's prefix forward as a new `attention_implementation="ring"` choice. Each rank holds 1/W of the sequence; K/V rotate around the ring while an online softmax accumulates the output per rank.

The forward and backward are written as a custom autograd Function with NCCL P2P (isend/irecv) for the rotation; only Q, K, V, the final O, and the per-row log-sum-exp survive the forward -> backward boundary, so the legacy per-layer torch.utils.checkpoint is no longer needed under the ring path (skipped automatically) and the blockwise rematerialisation described in the paper's Section 3.2.2 replaces it.

Suffix (action-expert) forwards transparently route through SDPA via a small dispatch flag, since they are short and cross-attend to the cached prefix K/V. KV-cache writes gather across ranks before storing so the suffix sees the full prefix cache exactly as before.

Verified on 2x RTX 3090:
- src/opentau/scripts/ringattn_experiments/unit_test.py: kernel matches eager attention to 1.2e-7 (fp32 chain).
- correctness.py: full-model output diff (ring vs sdpa) is within 1.5x of the existing sdpa-vs-eager bf16 reassociation noise.
- memory_scaling.py: at hidden=256, 1 layer, ws=2: 33% peak memory reduction at seq=8192; 22% at seq=32768; 12% at seq=65536.
- longest_context.py: ring extends the largest feasible context past SDPA's ceiling (>=81920 vs ~75264 in the harness).
GPU pytests on the 5090 box (worktree at 10 skips are pre-existing fixture-conditional cases (LIBERO env config etc.); no regressions from this branch.
A100 benchmark results: ring vs PR #273 SDPA baseline.

Setup

Matrix

FSDP sweep stopped after FS-1: at 8.2 sps it is 2.5× behind DS-ring's cam=4 peak (20.6 sps), so FSDP cannot beat DS at this scale and the per-task early-stop rule fired. FS-2..FS-6 were not run.

Headline (peak per backend × cam)

Verdict

Ring attention is a clear win on the production backend (DeepSpeed ZeRO-2) at both camera counts. Apples-to-apples at PR #273's batch sizes, ring buys 1.4-1.6× more samples/s. Past that, ring opens batch headroom: cam=4 fits bs=32 under ring (PR #273 OOM'd at bs=24); cam=1 fits bs=48 (PR #273 OOM'd at bs=44).

Why ring helps more than the seq-length argument predicts. The +56% throughput jump on cam=1 (whose prefix is relatively short) is not fully explained by the 2.6%-32.8% attention-memory reduction in this PR's 3090 scaling table alone. The dominant factor is almost certainly that ring subsumes per-layer

Production recommendation update vs PR #273. PR #273 settled on bs=20 cam=4 → 13.7 sps. With ring the new sweet spot is bs=32 cam=4 → 20.6 sps at 99% HBM (1.50× PR #273), or bs=28 cam=4 → 19.7 sps at 91% if 99% feels too tight. For cam=1 (some pretrain phases): bs=40 → 41.6 sps at 84% HBM.

Notes: two unrelated bugs surfaced en route
Repro for the new cam=4 production setting

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True PROFILE_STEPS=10 \
accelerate launch \
--config_file configs/examples/accelerate_deepspeed_config.yaml \
--num_processes=8 \
src/opentau/scripts/profile_step.py \
--config_path=configs/examples/pi07_low_level_libero.json \
--batch_size=32 --dataloader_batch_size=32 \
--gradient_accumulation_steps=1 --num_cams=4 \
--policy.freeze_vision_encoder=false --policy.train_expert_only=false \
--policy.gradient_checkpointing=true \
--policy.attention_implementation=ring \
--wandb.enable=false

(Needs note 1 fixed first, or a hand-stripped copy of the JSON.)
See #300 for a discussion on NaNs caused by numerical instability with BF16 and FP32 |
Adds `cfg.ring_group_size` (TrainPipelineConfig) so users can carve WORLD into ring sub-groups of size R along the intra-node axis with DP-replicates across rings.

- `build_ring_and_dp_groups(R)` in ring_attention.py creates contiguous ring sub-groups and orthogonal DP sub-groups, pinning the ring group via `set_ring_group`. Called from train.py right after Accelerator init.
- HierarchicalSampler now accepts a seed; the trainer keys it on `dp_rank` so ranks inside the same ring sub-group draw identical sample streams, and different DP groups draw different ones.
- `_broadcast_batch_in_ring` makes that equality robust to dataloader-worker stochasticity (augmentations etc.): it broadcasts tensors in-place from the sub-group leader.
- `loss *= ring_group_size` before backward rescales the world-MEAN reduction ZeRO/DDP performs into the correct (ring-SUM, DP-MEAN) gradient; this keeps ZeRO over WORLD so no DeepSpeed plugin surgery is needed.

`ring_group_size == world_size` (or unset) preserves the previous single-ring behaviour from this PR; `ring_group_size < world_size` activates the 2D path.
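The group layout and the loss rescaling can be checked with plain arithmetic, without creating any process groups. The sketch below is illustrative only: `ring_groups`, `dp_groups`, and the per-rank gradient values are hypothetical and are not the PR's `build_ring_and_dp_groups`. It assumes each rank in a ring contributes a partial gradient for its own sequence shard, which is why the ring axis needs a sum rather than a mean.

```python
def ring_groups(world_size, ring_size):
    """Contiguous ring sub-groups: [0..R-1], [R..2R-1], ..."""
    assert world_size % ring_size == 0, "ring size must divide world size"
    return [list(range(start, start + ring_size))
            for start in range(0, world_size, ring_size)]


def dp_groups(world_size, ring_size):
    """Orthogonal DP groups: ranks sharing the same position inside their ring."""
    return [list(range(pos, world_size, ring_size)) for pos in range(ring_size)]


def world_mean(values):
    return sum(values) / len(values)


WORLD, R = 8, 4
rings = ring_groups(WORLD, R)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
dps = dp_groups(WORLD, R)       # [[0, 4], [1, 5], [2, 6], [3, 7]]

# Hypothetical per-rank partial gradients (one scalar per rank).
grads = [float(rank + 1) for rank in range(WORLD)]

# Desired reduction: SUM inside each ring, then MEAN across rings.
target = world_mean([sum(grads[r] for r in ring) for ring in rings])

# What a world-wide MEAN gives once every rank's loss is scaled by R.
rescaled = world_mean([g * R for g in grads])
```

Here `target` and `rescaled` agree because a mean over `WORLD = R * num_rings` ranks differs from (ring-SUM, DP-MEAN) by exactly a factor of `R`.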
The defensive broadcast inside _broadcast_batch_in_ring was issuing one extra NCCL collective per training step on every ring sub-group, and could diverge across sub-group members when a leaf tensor's .is_cuda evaluated differently per rank (the prepared dataloader was the suspect). With the HierarchicalSampler seeded per dp_rank, ranks in the same ring sub-group already draw identical sample streams from a deterministic dataset path — the broadcast was belt-and-braces, not load-bearing. Keeps _broadcast_batch_in_ring as a public helper for future callers whose dataloader has unkeyed stochasticity.
Multi-ring (ring_group_size < world_size) hung at the first training step under DeepSpeed ZeRO-2. The hang correlates with the orthogonal DP sub-groups built by build_ring_and_dp_groups, even though no code path ever called a collective on them. NCCL is documented as unsafe under concurrent multi-communicator use on shared CUDA streams, and every extra sub-group communicator competes with ZeRO's world-group all-gather / reduce-scatter on the same streams.

The fix:
- get_dp_rank now computes the DP index arithmetically (world_rank // ring_group_size). It's only used to seed the sampler; no collective is ever issued on a DP communicator, so the communicator itself was unnecessary.
- When ring_group_size == world_size, the ring group is just WORLD; no duplicate communicator is created.
- Only when ring_group_size < world_size do we create explicit ring sub-groups, and only those, not the orthogonal DP groups.

The (ring_group, dp_group) return signature is kept so future call sites that genuinely need an explicit DP communicator (e.g. for gradient all-reduce on a non-world DP axis) can be added without breaking the API. dp_group is always None today.
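A minimal sketch of why the arithmetic DP index is enough for seeding: ranks in the same ring share a `dp_rank`, so a sampler seeded from it yields the same stream on each of them. The `sample_stream` helper and the `base_seed + dp_rank` keying are assumptions for illustration; the real HierarchicalSampler may derive its seed differently.

```python
import random


def get_dp_rank(world_rank, ring_group_size):
    """DP index computed arithmetically; no communicator needed."""
    return world_rank // ring_group_size


def sample_stream(base_seed, dp_rank, n=5, dataset_size=1000):
    """Hypothetical stand-in for a sampler seeded per DP rank."""
    rng = random.Random(base_seed + dp_rank)
    return [rng.randrange(dataset_size) for _ in range(n)]


WORLD, R, BASE_SEED = 8, 4, 1234
streams = {
    rank: sample_stream(BASE_SEED, get_dp_rank(rank, R))
    for rank in range(WORLD)
}
# Ranks 0..3 form one ring and share a stream; ranks 4..7 share another.
```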
When ring_group_size == world_size the ring sub-group is WORLD, so
within-sub-group rank == global rank and the bug was masked. With a
proper sub-group (e.g. {4,5,6,7}), passing within-sub-group rank 1 to
P2POp triggers 'Global rank 1 is not part of group' at the first
batch_isend_irecv inside _RingAttention.forward, killing the first
training step. Translate the local rotation step to a global peer rank
via dist.get_global_rank.
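
The translation can be illustrated without a process group. In the sketch below, `ring_peers` is a hypothetical helper and the send-to-next / receive-from-previous direction is an assumption; with a real `ProcessGroup`, the lookup `group_ranks[local]` corresponds to `torch.distributed.get_global_rank(group, local)`.

```python
def ring_peers(group_ranks, global_rank):
    """Return (send_to, recv_from) as GLOBAL ranks for one rotation step.

    group_ranks: global ranks of the ring sub-group, in ring order.
    """
    size = len(group_ranks)
    local = group_ranks.index(global_rank)   # rank within the sub-group
    next_local = (local + 1) % size          # rotation is computed locally...
    prev_local = (local - 1) % size
    # ...but P2P ops need global ranks, so translate before using them.
    return group_ranks[next_local], group_ranks[prev_local]


# Proper sub-group: local rank 1 is global rank 5, not global rank 1.
sub_group_peers = ring_peers([4, 5, 6, 7], 4)   # (5, 7)
# Single ring over WORLD: local == global, which is why the bug was masked.
world_peers = ring_peers(list(range(8)), 3)     # (4, 2)
```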
Pushed 2D parallelism wiring on top of the original PR:
Bugs found and fixed during testing on a100-training / gpu-387 (bot_387 reservation, ZeRO-2)
Verification
Known issue (separate from this PR's topology work)

Both `rs=8` and `rs=4` show `total_loss = nan` from step 2 onward in real pi07 training. The same NaN appears when `ring_group_size` is unset (legacy single-ring across WORLD from this PR's original commit, no loss scaling), so it is not introduced by the 2D wiring. It is a pre-existing bug in the ring attention kernel's real-data bf16 backward, likely in how the online-softmax accumulator interacts with ZeRO-2's bf16 master grads. Needs a separate investigation; tracking as a follow-up.
Real-data pi07 training under ZeRO-2 produces NaN at step 2 even with ring_group_size unset (single ring across WORLD). The current ring kernel uses _NEG_INF = -2.38e38 (mirrors the eager constant) and does several fp32 subtractions on it inside the online softmax / LSE chain; values that close to fp32's representable limit can cancel into NaN under those compositions. A mask value of -1e9 is still well below -log(fp32_eps), so masked positions still underflow to 0 in fp32 wherever any unmasked score is present, and fully-masked rows still self-cancel to exp(0)=1.

Adds RING_DEBUG_PROBES=1: guarded torch.isfinite() assertions in the forward and backward to localise which tensor first goes non-finite under a real run. Off by default; cheap enough to leave on for a smoke run.
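The two masking properties claimed for -1e9 can be checked on a single score row. This is a sketch in plain Python, so it runs in float64 rather than fp32; fp32 `exp` underflows even earlier, so the conclusions carry over. `masked_softmax` is a hypothetical helper, not the kernel's code.

```python
import math

NEG_INF = -1e9  # mask value discussed in the commit message


def masked_softmax(scores, keep):
    """Max-subtracted softmax over one row; masked positions get NEG_INF."""
    s = [x if k else NEG_INF for x, k in zip(scores, keep)]
    m = max(s)
    p = [math.exp(x - m) for x in s]   # masked entries underflow to 0.0
    z = sum(p)
    return [x / z for x in p]


# One masked position among unmasked ones: its weight is exactly 0.
partly = masked_softmax([0.5, 3.0, 1.2], [True, False, True])

# Fully masked row: every entry is NEG_INF, so x - m == 0 and exp(0) == 1;
# the row degrades to a uniform distribution instead of NaN.
fully = masked_softmax([0.5, 3.0, 1.2], [False, False, False])
```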
Forward.O / LSE were finite per the previous probe firing on backward.dq_local; the NaN is being introduced somewhere in the per-tile loop. Add probes on saved tensors, S, P, dP, dS, and D_tile so the next run pin-points the exact operation that first emits a non-finite value.
Local diagnostic test scripts:
- backward_test.py: gradient-equivalence vs eager autograd (within bf16 noise).
- nan_repro.py: kernel stress with pi07-style block-causal mask + PAD rows.
- deepspeed_repro.py: tiny Gemma3 with expert under accelerate's DeepSpeed ZeRO-2 bf16 plugin, 5 training steps with a pi07-shaped CE loss path.

All three pass on 2x RTX 3090; no NaN.

Defensive _RingAttention.backward: nan_to_num the returned dq/dk/dv before casting to bf16. The kernel is finite-input-finite-output for every case we could construct locally, but the real-data pi07 + ZeRO-2 run still hit a step-2 NaN whose intermediate source we could not localise under RING_DEBUG_PROBES (the production dataset cache kept failing the fine-grained probe job before training started). nan_to_num gives the optimizer the same skip-this-element semantics DeepSpeed's overflow check would, without changing the math on the finite-output path. The strict _assert_finite probes still fire before sanitisation, so RING_DEBUG_PROBES=1 remains useful for future investigation.
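The ordering described above (strict probe first, then sanitisation) can be sketched on plain lists. Everything here is illustrative: `assert_finite`, `nan_to_num`, and `finish_backward` are hypothetical stand-ins for the PR's `_assert_finite` and for `torch.nan_to_num`, whose defaults map NaN to 0 and infinities to the largest finite values of the dtype. Reading the environment variable at call time is also an assumption.

```python
import math
import os
import sys


def assert_finite(name, values):
    """Strict probe; only active when RING_DEBUG_PROBES=1."""
    if os.environ.get("RING_DEBUG_PROBES") != "1":
        return
    bad = [i for i, x in enumerate(values) if not math.isfinite(x)]
    if bad:
        raise FloatingPointError(f"{name}: non-finite at indices {bad}")


def nan_to_num(values, max_abs=sys.float_info.max):
    """NaN -> 0.0, +/-inf -> +/-max_abs; finite values pass through."""
    out = []
    for x in values:
        if math.isnan(x):
            out.append(0.0)
        elif math.isinf(x):
            out.append(math.copysign(max_abs, x))
        else:
            out.append(x)
    return out


def finish_backward(name, grad):
    """Probe first so debugging still sees the raw values, then sanitise."""
    assert_finite(name, grad)
    return nan_to_num(grad)
```

With probes off, `finish_backward("dq", [1.0, float("nan")])` returns `[1.0, 0.0]`; with `RING_DEBUG_PROBES=1` the same call raises before any value is rewritten.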
What this does
Adds a paper-style ring attention implementation (Liu, Zaharia & Abbeel, 2023) to the pi07 prefix forward as a new `attention_implementation="ring"` choice alongside `eager` / `sdpa`.

Each rank holds 1/W of the sequence along the seq axis; K/V rotate around the ring via NCCL P2P (`isend` / `irecv`) while an online softmax accumulates the per-rank output. Forward and backward are a custom `torch.autograd.Function` so only Q, K, V, the final O, and the per-row log-sum-exp survive the forward → backward boundary; full `(S, S)` attention-score matrices are never stored on any rank. The per-block score / probability tensors are also Q-tiled to keep peak memory bounded by a small constant (`_RING_Q_TILE`, default 256) rather than `Sq_local`.

Per the paper's Section 3.2.2, that blockwise rematerialisation replaces the legacy per-layer `torch.utils.checkpoint` wrapping: under `ring`, `gradient_checkpointing` is ignored (and reported in the `attention_implementation` docstring as such).

Suffix (action-expert) forwards transparently route through SDPA via a small `_ring_active` dispatch flag on the model: they are short and cross-attend to the already-cached prefix K/V, so ring's fixed costs outweigh the benefit. KV-cache writes `gather_seq` across ranks before storing so the suffix forward sees the full prefix cache.
Scope is intentionally narrow: pi07 low-level training, prefix forward
only. The high-level planner is unchanged.
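
To make the rotation plus online-softmax accumulation concrete, here is a single-process simulation in plain Python. It is a sketch under simplifying assumptions, not the PR's kernel: one head, no mask, no Q-tiling, float64 lists instead of tensors, and the K/V "rotation" is modelled by indexing which block a rank would hold after each hop. All function names are hypothetical.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v, one query row at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        s = [dot(qi, kj) * scale for kj in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(pj * vj[c] for pj, vj in zip(p, v)) / z
                    for c in range(len(v[0]))])
    return out


def ring_attention_sim(q, k, v, world):
    """Simulate `world` ranks; returns (output rows, per-row log-sum-exp)."""
    n, dv = len(q), len(v[0])
    assert n % world == 0, "sequence must split evenly across ranks"
    chunk = n // world
    scale = 1.0 / math.sqrt(len(q[0]))
    # Block r is the K/V shard that starts out on rank r.
    kv = [(k[r * chunk:(r + 1) * chunk], v[r * chunk:(r + 1) * chunk])
          for r in range(world)]
    out, lse = [], []
    for rank in range(world):
        for qi in q[rank * chunk:(rank + 1) * chunk]:
            m, l, acc = -math.inf, 0.0, [0.0] * dv   # running max, sum, output
            for step in range(world):
                # After `step` hops this rank holds the block that started
                # on rank (rank - step) mod world.
                k_blk, v_blk = kv[(rank - step) % world]
                s = [dot(qi, kj) * scale for kj in k_blk]
                m_new = max(m, max(s))
                alpha = math.exp(m - m_new)          # rescale old state; 0 on first block
                p = [math.exp(x - m_new) for x in s]
                l = l * alpha + sum(p)
                acc = [a * alpha + sum(pj * vj[c] for pj, vj in zip(p, v_blk))
                       for c, a in enumerate(acc)]
                m = m_new
            out.append([a / l for a in acc])
            lse.append(m + math.log(l))              # the only softmax stat kept
    return out, lse


random.seed(0)
N, D, W = 8, 4, 4
q = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
k = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
v = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]

ref = full_attention(q, k, v)
out, lse = ring_attention_sim(q, k, v, W)
max_err = max(abs(a - b) for ro, rr in zip(out, ref) for a, b in zip(ro, rr))
```

No rank in the simulation ever materialises more than one `chunk`-wide score row at a time, and the per-row `lse` is all a backward pass would need to recompute the probabilities block by block.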
🗃️ Feature
How it was tested
Hardware
2x NVIDIA RTX 3090 (24 GiB each) on a single node. The implementation
itself is hardware-agnostic — it uses standard NCCL collectives over
the default process group, so it transparently inherits the production
A100 + NVLink + InfiniBand path.
Numerical correctness
Two scripts under `src/opentau/scripts/ringattn_experiments/`:

- `unit_test.py` (kernel-level): builds random Q/K/V, runs the eager reference and `_RingAttention.apply` on the same data, checks the per-rank slice of the output matches.
- `correctness.py` (full-model): a tiny Gemma 3 with expert, prefix forward run under `eager`, `sdpa`, and `ring`. Ring should be no further from SDPA than SDPA is from eager (bf16 reassociation noise the model already absorbs today).
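
The acceptance criterion for the full-model check can be written as a relative tolerance. The sketch below is an assumption about how such a check might look, not the contents of `correctness.py`; the `factor=1.5` default echoes the "within 1.5x" figure reported for this PR, and the sample values are made up.

```python
def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


def ring_within_noise(eager, sdpa, ring, factor=1.5):
    """True if ring-vs-sdpa error stays within `factor` x sdpa-vs-eager noise."""
    noise = max_abs_diff(sdpa, eager)    # reassociation noise already accepted
    ring_err = max_abs_diff(ring, sdpa)
    return ring_err <= factor * noise


# Made-up flattened outputs for illustration only.
eager = [1.000, 2.000, 3.000]
sdpa = [1.002, 1.999, 3.001]
ring_ok = [1.003, 1.998, 3.002]
ring_bad = [1.010, 1.999, 3.001]

passes = ring_within_noise(eager, sdpa, ring_ok)
fails = ring_within_noise(eager, sdpa, ring_bad)
```

Anchoring the tolerance to the existing SDPA-vs-eager gap avoids picking an absolute bf16 threshold by hand.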
Memory scaling (`memory_scaling.py`)

hidden=256, 1 backbone layer, ws=2, forward + backward:
Longest feasible context (`longest_context.py`)

Same harness, binary-search the largest seq_len that completes a forward+backward without OOM:
So ring extends the feasible context by at least ~9% on this
harness, with headroom for a more refined search.
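
A binary search of this kind can be sketched as follows. This is not `longest_context.py`: `largest_feasible` and the toy `fits` predicate are hypothetical, and the search assumes feasibility is monotonic (if a length fits, every shorter one does). In a real harness `fits` would run a forward+backward and report whether it hit an out-of-memory error.

```python
def largest_feasible(fits, lo, hi, step=512):
    """Largest multiple of `step` in [lo, hi] for which fits(seq_len) is True.

    `lo` and `hi` are assumed to be multiples of `step`.
    Returns None if even `lo` does not fit.
    """
    if not fits(lo):
        return None
    lo_i, hi_i = lo // step, hi // step
    while lo_i < hi_i:
        mid = (lo_i + hi_i + 1) // 2     # round up so the loop always shrinks
        if fits(mid * step):
            lo_i = mid                   # mid fits: answer is mid or larger
        else:
            hi_i = mid - 1               # mid fails: answer is below mid
    return lo_i * step


# Toy stand-in for "forward+backward completes without OOM".
toy_fits = lambda seq_len: seq_len <= 5000
best = largest_feasible(toy_fits, lo=512, hi=8192)
```

A coarse `step` keeps the number of expensive probe runs logarithmic in the search range, at the cost of under-reporting the true ceiling by up to one step.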
Existing tests
`pytest tests/policies/ -m "not gpu" -k "pi07 or pi06"`: 222 passed, 98 deselected (the GPU subset is gated for the nightly runner). No existing tests regress.
How to checkout & try? (for the reviewer)
Checklist
GPU pytest / nightly regression runs are pending — they need an A100
box; the 3090 harness here is for the experiment numbers above.