
feat(pi07): add ring-attention sequence parallelism#298

Closed
WilliamYue37 wants to merge 8 commits into main from feat/pi07-ring-attention


@WilliamYue37 WilliamYue37 commented May 12, 2026

What this does

Adds a paper-style ring attention implementation
(Liu, Zaharia & Abbeel, 2023) to the
pi07 prefix forward as a new attention_implementation="ring" choice
alongside eager / sdpa.

Each rank holds 1/W of the sequence along the seq axis; K/V rotate
around the ring via NCCL P2P (isend/irecv) while an online softmax
accumulates the per-rank output. Forward and backward are a custom
torch.autograd.Function so only Q, K, V, the final O, and the per-row
log-sum-exp survive the forward → backward boundary — full (S, S)
attention-score matrices are never stored on any rank. The per-block
score / probability tensors are also Q-tiled to keep peak memory bounded
by a small constant (_RING_Q_TILE, default 256) rather than
`Sq_local * Sk_block`.
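To make the online-softmax accumulation concrete, here is a minimal single-process sketch in plain Python. It is not the PR's kernel: there is no NCCL, no Q-tiling and no autograd, and the function names are made up for illustration. It only shows that visiting K/V one block at a time (as each ring step would deliver them) while tracking a running max and running sum reproduces ordinary softmax attention for one query row, along with the per-row log-sum-exp.

```python
import math


def full_attention_row(q, keys, values):
    """Reference: softmax(q . k) weighted sum of values for one query row."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom
            for d in range(dim)]


def ring_attention_row(q, kv_blocks):
    """Online-softmax accumulation over K/V blocks (hypothetical helper).

    Each element of kv_blocks is a (keys, values) pair, standing in for the
    block that arrives at one ring step. Returns (output_row, log_sum_exp).
    """
    dim = len(kv_blocks[0][1][0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim
    for keys, values in kv_blocks:
        scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the old max.
        # On the first block this is exp(-inf) == 0.0, which is harmless.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, values):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    out = [a / running_sum for a in acc]
    lse = running_max + math.log(running_sum)
    return out, lse
```

Splitting the same keys and values into any number of blocks should give the same row as `full_attention_row` up to floating-point rounding, which is the property `unit_test.py` checks at the kernel level.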

Per the paper's Section 3.2.2, that blockwise rematerialisation
replaces the legacy per-layer torch.utils.checkpoint wrapping: under
ring, gradient_checkpointing is ignored, and the
attention_implementation docstring documents this.

Suffix (action-expert) forwards transparently route through SDPA via a
small _ring_active dispatch flag on the model — they are short and
cross-attend to the already-cached prefix K/V, so ring's fixed costs
outweigh the benefit. KV-cache writes gather_seq across ranks before
storing so the suffix forward sees the full prefix cache.

Scope is intentionally narrow: pi07 low-level training, prefix forward
only. The high-level planner is unchanged.

🗃️ Feature

How it was tested

Hardware

2x NVIDIA RTX 3090 (24 GiB each) on a single node. The implementation
itself is hardware-agnostic — it uses standard NCCL collectives over
the default process group, so it transparently inherits the production
A100 + NVLink + InfiniBand path.

Numerical correctness

Two scripts under src/opentau/scripts/ringattn_experiments/:

  • unit_test.py — kernel-level: builds random Q/K/V, runs the eager
    reference and _RingAttention.apply on the same data, checks the
    per-rank slice of the output matches.

    [rank 0] ref local |max| = 1.2188, max |ref - ring| = 0.000e+00, mean = 0.000e+00
    [rank 1] ref local |max| = 0.9453, max |ref - ring| = 1.192e-07, mean = 2.183e-11
    PASS
    
  • correctness.py — full-model: a tiny Gemma 3 with expert, prefix
    forward run under eager, sdpa, and ring. Ring should be no
    further from SDPA than SDPA is from eager (bf16 reassociation noise
    the model already absorbs today).

    ||eager|| = 512.0071   ||sdpa|| = 511.9952   ||ring|| = 511.9866
    max |eager - sdpa|  = 2.031e-01   mean = 2.706e-02
    max |eager - ring|  = 1.914e-01   mean = 2.609e-02
    max |sdpa  - ring|  = 2.734e-01   mean = 2.285e-02
    PASS
    

Memory scaling (memory_scaling.py)

hidden=256, 1 backbone layer, ws=2, forward + backward:

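| seq_len | sdpa (GiB) | ring (GiB) | reduction |
|---------|------------|------------|-----------|
| 2048    | 0.077      | 0.075      | 2.6%      |
| 8192    | 0.423      | 0.285      | 32.8%     |
| 32768   | 3.879      | 3.036      | 21.7%     |
| 65536   | 13.681     | 12.052     | 11.9%     |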

Longest feasible context (longest_context.py)

Same harness, binary-search the largest seq_len that completes a
forward+backward without OOM:

  • SDPA: largest OK ≈ 75264 tokens, first OOM at 75520.
  • Ring: largest OK ≥ 81920 tokens (first OOM only seen at 98304).

So ring extends the feasible context by at least ~9% on this
harness, with headroom for a more refined search.
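For reference, the search described above can be sketched as a plain binary search over a feasibility predicate. This is a minimal sketch, not the contents of longest_context.py: `largest_feasible`, `fits` and the 256-token `step` are assumptions (the step is chosen only because the SDPA boundary above, 75264 vs 75520, is 256 tokens wide), and the search presumes that once a length OOMs every longer length does too.

```python
from typing import Callable


def largest_feasible(fits: Callable[[int], bool], lo: int, hi: int,
                     step: int = 256) -> int:
    """Largest multiple of `step` in [lo, hi] for which fits() is True.

    Assumes fits(lo) is True and that feasibility is monotone in seq_len.
    """
    lo_k, hi_k = lo // step, hi // step
    while lo_k < hi_k:
        mid = (lo_k + hi_k + 1) // 2  # round up so the loop always progresses
        if fits(mid * step):
            lo_k = mid
        else:
            hi_k = mid - 1
    return lo_k * step


# Relative gain implied by the two numbers reported above.
sdpa_ok, ring_ok = 75264, 81920
gain = ring_ok / sdpa_ok - 1.0  # a little under 9%
```

In a real harness `fits` would presumably run one forward+backward inside a try/except for the CUDA out-of-memory error and free cached memory between probes; that part is left out here because it depends on the harness.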

Existing tests

pytest tests/policies/ -m "not gpu" -k "pi07 or pi06" — 222 passed,
98 deselected (the GPU subset is gated for the nightly runner). No
existing tests regress.

How to checkout & try? (for the reviewer)

git checkout feat/pi07-ring-attention
uv sync --extra dev --extra libero
source .venv/bin/activate

# Kernel-level numerical sanity (any 2-GPU box):
torchrun --nproc_per_node=2 -m opentau.scripts.ringattn_experiments.unit_test

# Full-model output drift vs sdpa/eager:
torchrun --nproc_per_node=2 -m opentau.scripts.ringattn_experiments.correctness

# Memory scaling table:
torchrun --nproc_per_node=2 -m opentau.scripts.ringattn_experiments.memory_scaling

# Longest feasible context (slow — takes minutes per arm):
torchrun --nproc_per_node=2 -m opentau.scripts.ringattn_experiments.longest_context

# To enable ring on a real config, set:
#   --policy.gemma3_with_expert_config.attention_implementation=ring
# in any pi07 training command (only needs >=2 GPUs to do anything useful).

Checklist

  • I have added Google-style docstrings to important functions and ensured function parameters are typed.
  • My PR includes policy-related changes.
    • If the above is checked: I have run the GPU pytests (pytest -m "gpu") and regression tests.

GPU pytest / nightly regression runs are pending — they need an A100
box; the 3090 harness here is for the experiment numbers above.

Adds a paper-style ring-attention path
(arxiv.org/abs/2310.01889) for pi07's prefix forward as a new
`attention_implementation="ring"` choice. Each rank holds 1/W of the
sequence; K/V rotate around the ring while an online softmax accumulates
the output per rank. The forward and backward are written as a custom
autograd Function with NCCL P2P (isend/irecv) for the rotation; only Q,
K, V, the final O, and the per-row log-sum-exp survive the forward ->
backward boundary, so the legacy per-layer torch.utils.checkpoint is no
longer needed under the ring path (skipped automatically) and the
blockwise rematerialisation described in the paper's Section 3.2.2
replaces it.

Suffix (action-expert) forwards transparently route through SDPA via a
small dispatch flag, since they are short and cross-attend to the
cached prefix K/V. KV-cache writes gather across ranks before storing
so the suffix sees the full prefix cache exactly as before.

Verified on 2x RTX 3090:
- src/opentau/scripts/ringattn_experiments/unit_test.py — kernel matches
  eager attention to 1.2e-7 (fp32 chain).
- correctness.py — full-model output diff (ring vs sdpa) is within
  1.5x of the existing sdpa-vs-eager bf16 reassociation noise.
- memory_scaling.py — at hidden=256, 1 layer, ws=2: 33% peak memory
  reduction at seq=8192; 22% at seq=32768; 12% at seq=65536.
- longest_context.py — ring extends the largest feasible context past
  SDPA's ceiling (>=81920 vs ~75264 in the harness).
@WilliamYue37 WilliamYue37 self-assigned this May 12, 2026
@WilliamYue37 WilliamYue37 added the feature and optimization labels May 12, 2026

WilliamYue37 commented May 12, 2026

GPU pytests on the 5090 box (worktree at ~/OpenTau-ringattn, branch feat/pi07-ring-attention @ f35c1fa):

pytest -m "gpu" -n 0
...........ssssss....ss..ss.                                             [100%]
18 passed, 10 skipped, 1124 deselected, 5 warnings in 167.82s (0:02:47)

10 skips are pre-existing fixture-conditional cases (LIBERO env config etc.) — no regressions from this branch.

@shuheng-liu (Member)

A100 benchmark results — ring vs PR #273 SDPA baseline.

Setup

Matrix

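| cell | backend   | cam | bs | alloc GiB | resv GiB | step ms | ring sps | PR #273 SDPA sps | Δ %    | %resv |
|------|-----------|-----|----|-----------|----------|---------|----------|------------------|--------|-------|
| DS-1 | deepspeed | 4   | 20 | 57.96     | 59.41    | 8353    | 19.15    | 13.70            | +39.8% | 74%   |
| DS-2 | deepspeed | 4   | 24 | 64.32     | 66.07    | 10285   | 18.67    | OOM in #273      |        | 83%   |
| DS-3 | deepspeed | 4   | 28 | 70.68     | 72.71    | 11367   | 19.71    | not tried        |        | 91%   |
| DS-4 | deepspeed | 4   | 32 | 77.04     | 79.38    | 12434   | 20.59    | not tried        |        | 99%   |
| DS-5 | deepspeed | 1   | 40 | 66.03     | 67.25    | 7690    | 41.61    | 26.70            | +55.8% | 84%   |
| DS-6 | deepspeed | 1   | 48 | 74.01     | 75.47    | 9415    | 40.79    | OOM in #273      |        | 94%   |
| DS-7 | deepspeed | 1   | 56 | OOM       |          |         |          |                  |        |       |
| FS-1 | fsdp      | 4   | 8  | 54.99     | 71.12    | 7827    | 8.18     | 7.00             | +16.8% | 89%   |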

FSDP sweep stopped after FS-1 — at 8.2 sps it is 2.5× behind DS-ring's cam=4 peak (20.6 sps), so FSDP cannot beat DS at this scale and the per-task early-stop rule fired. FS-2..FS-6 were not run.
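As a sanity check on the matrix, the "ring sps" column appears to be the global batch divided by the step time. The sketch below assumes `bs` is the per-process batch size and that all cells ran with 8 processes (the `--num_processes=8` in the repro command further down); neither is stated explicitly for every cell, so treat it as a reading aid rather than the profiler's definition.

```python
def samples_per_second(per_process_batch: int, num_processes: int,
                       step_ms: float) -> float:
    """Global samples per second for one optimizer step (assumed definition)."""
    return per_process_batch * num_processes / (step_ms / 1000.0)


# (bs, step ms, reported ring sps) taken from the matrix above.
cells = {
    "DS-1": (20, 8353, 19.15),
    "DS-4": (32, 12434, 20.59),
    "DS-5": (40, 7690, 41.61),
    "FS-1": (8, 7827, 8.18),
}
for name, (bs, step_ms, reported) in cells.items():
    computed = samples_per_second(bs, 8, step_ms)
    print(f"{name}: computed {computed:.2f} sps, reported {reported:.2f} sps")
```

Under that assumption the computed values agree with the reported column to two decimals, which also explains why a larger batch can win on sps even though its step is slower.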

Headline (peak per backend × cam)

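| backend          | cam | ring peak         | PR #273 SDPA peak | ratio |
|------------------|-----|-------------------|-------------------|-------|
| DeepSpeed ZeRO-2 | 1   | bs=40 → 41.6 sps  | 26.7 sps          | 1.56× |
| DeepSpeed ZeRO-2 | 4   | bs=32 → 20.6 sps  | 13.7 sps          | 1.50× |
| FSDP-FULL_SHARD  | 4   | bs=8 → 8.2 sps    | 7.0 sps           | 1.17× |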

Verdict

Ring attention is a clear win on the production backend (DeepSpeed ZeRO-2) at both camera counts. Apples-to-apples at PR #273's batch sizes, ring buys 1.4-1.6× more samples/s. Past that, ring opens batch headroom: cam=4 fits bs=32 under ring (PR #273 OOM'd at bs=24); cam=1 fits bs=48 (PR #273 OOM at bs=44).

Why ring helps more than the seq-length argument predicts. The +56% throughput jump on cam=1 (whose prefix is relatively short) is not explained by the 2.6%-32.8% attention-memory reduction in this PR's 3090 scaling table alone. The dominant factor is almost certainly that ring subsumes per-layer gradient_checkpointing via blockwise rematerialisation in the kernel itself: PR #273's cam=1 ZeRO-2 cell was at 93% HBM where checkpointing was load-bearing both for memory and for throughput, and ring relaxes both at once.

Production recommendation update vs PR #273. PR #273 settled on bs=20 cam=4 → 13.7 sps. With ring the new sweet spot is bs=32 cam=4 → 20.6 sps at 99% HBM (1.50× PR #273), or bs=28 cam=4 → 19.7 sps at 91% if 99% feels too tight. For cam=1 (some pretrain phases): bs=40 → 41.6 sps at 84% HBM.

Notes — two unrelated bugs surfaced en route

  1. configs/examples/pi07_low_level_libero.json (added in PR #273) has dataset_mixture.history_interval and policy.n_obs_history fields that draccus rejects because they aren't declared on the corresponding dataclasses. That config currently cannot be loaded via the PR #273 PR-body repro recipe (any profile_step.py / train.py invocation against it crashes at config-decode). The matrix above used a sibling config with the two stray fields removed. Worth a small follow-up PR.
  2. PR #273's PR-body suggests --policy.gemma3_with_expert_config.disable_internal_bf16_cast=true is a CLI flag, but draccus does not accept it. The flag is auto-applied inside profile_step.py / train.py when FSDP is detected (lines 217-218 / 392-393). FS-1 above ran with the default behaviour and still saw the ring gain.

Repro for the new cam=4 production setting

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True PROFILE_STEPS=10 \
  accelerate launch \
    --config_file configs/examples/accelerate_deepspeed_config.yaml \
    --num_processes=8 \
    src/opentau/scripts/profile_step.py \
    --config_path=configs/examples/pi07_low_level_libero.json \
    --batch_size=32 --dataloader_batch_size=32 \
    --gradient_accumulation_steps=1 --num_cams=4 \
    --policy.freeze_vision_encoder=false --policy.train_expert_only=false \
    --policy.gradient_checkpointing=true \
    --policy.attention_implementation=ring \
    --wandb.enable=false

(needs note 1 fixed first, or a hand-stripped copy of the JSON).


shuheng-liu commented May 13, 2026

See #300 for a discussion on NaNs caused by numerical instability with BF16 and FP32

Adds `cfg.ring_group_size` (TrainPipelineConfig) so users can carve WORLD
into ring sub-groups of size R along the intra-node axis with DP-replicates
across rings.

- `build_ring_and_dp_groups(R)` in ring_attention.py creates contiguous
  ring sub-groups and orthogonal DP sub-groups, pinning the ring group via
  `set_ring_group`. Called from train.py right after Accelerator init.
- HierarchicalSampler now accepts a seed; the trainer keys it on `dp_rank`
  so ranks inside the same ring sub-group draw identical sample streams,
  and different DP groups draw different ones.
- `_broadcast_batch_in_ring` makes that equality robust to dataloader-
  worker stochasticity (augmentations etc.) — broadcasts tensors in-place
  from the sub-group leader.
- `loss *= ring_group_size` before backward rescales the world-MEAN
  reduction ZeRO/DDP performs into the correct (ring-SUM, DP-MEAN)
  gradient — keeps ZeRO over WORLD so no DeepSpeed plugin surgery is
  needed.

`ring_group_size == world_size` (or unset) preserves the previous
single-ring behaviour from this PR; `ring_group_size < world_size`
activates the 2D path.

The defensive broadcast inside _broadcast_batch_in_ring was issuing one extra
NCCL collective per training step on every ring sub-group, and could diverge
across sub-group members when a leaf tensor's .is_cuda evaluated
differently per rank (the prepared dataloader was the suspect). With the
HierarchicalSampler seeded per dp_rank, ranks in the same ring sub-group
already draw identical sample streams from a deterministic dataset path —
the broadcast was belt-and-braces, not load-bearing.

Keeps _broadcast_batch_in_ring as a public helper for future callers whose
dataloader has unkeyed stochasticity.

Multi-ring (ring_group_size < world_size) hung at the first training step
under DeepSpeed ZeRO-2. The hang correlates with the orthogonal DP
sub-groups built by build_ring_and_dp_groups — even though no code path
ever called a collective on them. NCCL is documented as unsafe under
concurrent multi-communicator use on shared CUDA streams, and every
extra sub-group communicator competes with ZeRO's world-group all-gather
/ reduce-scatter on the same streams.

The fix:
- get_dp_rank now computes the DP index arithmetically
  (world_rank // ring_group_size). It's only used to seed the sampler;
  no collective is ever issued on a DP communicator, so the
  communicator itself was unnecessary.
- When ring_group_size == world_size, the ring group is just WORLD —
  no duplicate communicator is created.
- Only when ring_group_size < world_size do we create explicit ring
  sub-groups, and only those — not the orthogonal DP groups.
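The resulting topology can be sketched without any communicator at all. The snippet below is an illustration under stated assumptions, not the code in ring_attention.py: `ring_groups`, `dp_rank` and `sample_stream` are hypothetical names, the contiguous layout follows the "contiguous ring sub-groups" wording earlier in this PR, and `base_seed + dp_rank` is just one plausible way to key a sampler on the DP index.

```python
import random


def ring_groups(world_size: int, ring_group_size: int) -> list[list[int]]:
    """Contiguous ring sub-groups, e.g. 8 ranks with R=4 -> [0..3], [4..7]."""
    if world_size % ring_group_size != 0:
        raise ValueError("ring_group_size must divide world_size")
    return [list(range(start, start + ring_group_size))
            for start in range(0, world_size, ring_group_size)]


def dp_rank(world_rank: int, ring_group_size: int) -> int:
    """DP index computed arithmetically; no DP communicator is needed."""
    return world_rank // ring_group_size


def sample_stream(base_seed: int, world_rank: int, ring_group_size: int,
                  n: int = 4) -> list[int]:
    """Stand-in for a sampler seeded per dp_rank (assumed seeding scheme)."""
    rng = random.Random(base_seed + dp_rank(world_rank, ring_group_size))
    return [rng.randrange(1000) for _ in range(n)]
```

With 8 ranks and `ring_group_size=4`, ranks 4 to 7 share `dp_rank == 1` and therefore draw the same stream, while rank 0 (in the other ring) draws a different one.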

The (ring_group, dp_group) return signature is kept so future call sites
that genuinely need an explicit DP communicator (e.g. for gradient
all-reduce on a non-world DP axis) can be added without breaking the
API. dp_group is always None today.

When ring_group_size == world_size the ring sub-group is WORLD, so
within-sub-group rank == global rank and the bug was masked. With a
proper sub-group (e.g. {4,5,6,7}), passing within-sub-group rank 1 to
P2POp triggers 'Global rank 1 is not part of group' at the first
batch_isend_irecv inside _RingAttention.forward, killing the first
training step. Translate the local rotation step to a global peer rank
via dist.get_global_rank.

@WilliamYue37 (Member, Author)

Pushed 2D parallelism wiring on top of the original PR:

  • `cfg.ring_group_size` (TrainPipelineConfig) splits WORLD into ring sub-groups of size R along the intra-node axis. Validated in `TrainPipelineConfig.validate()` against `WORLD_SIZE` (from env) and the policy's `attention_implementation`.
  • `build_ring_and_dp_groups(R)` creates ring sub-groups, pins via `set_ring_group`. DP topology is exposed arithmetically (`world_rank // R`) without a separate communicator — every extra NCCL sub-group competes with ZeRO's world-group collectives on the same CUDA streams.
  • HierarchicalSampler accepts a seed; the trainer keys it on `dp_rank` so ranks inside the same ring sub-group draw identical sample streams.
  • `loss *= ring_group_size` before backward rescales ZeRO's world-MEAN reduce into the correct `(ring-SUM, DP-MEAN)` gradient. ZeRO stays over WORLD — no DeepSpeed plugin surgery.
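The loss-rescaling bullet can be checked with a few lines of arithmetic. This is a toy sketch with scalar gradients and hypothetical function names, assuming each rank's local gradient is the partial contribution of its sequence shard: contributions inside one ring belong to the same samples and must be summed, while different rings are data-parallel replicas and must be averaged.

```python
def world_mean(grads_by_rank: list[float]) -> float:
    """What a world-group mean reduction (ZeRO/DDP style) produces."""
    return sum(grads_by_rank) / len(grads_by_rank)


def ring_sum_dp_mean(grads_by_rank: list[float], ring_group_size: int) -> float:
    """Target gradient: sum inside each ring, then mean across rings."""
    rings = [grads_by_rank[i:i + ring_group_size]
             for i in range(0, len(grads_by_rank), ring_group_size)]
    return sum(sum(ring) for ring in rings) / len(rings)


def scaled_world_mean(grads_by_rank: list[float], ring_group_size: int) -> float:
    """Scaling the loss by R scales every local gradient by R before the mean."""
    return world_mean([ring_group_size * g for g in grads_by_rank])
```

With W ranks and ring size R there are W/R rings, so the world mean divides by W while the target divides by W/R; the factor between them is exactly R, which is why a single multiplication before backward is enough.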

Bugs found and fixed during testing on a100-training / gpu-387 (bot_387 reservation, ZeRO-2)

  1. `62c4a88` — dropped a redundant per-step batch broadcast (the seeded sampler already gives identical batches within a ring sub-group; the broadcast was issuing an extra collective per step and could diverge if `is_cuda` differed across ranks).
  2. `6ce3cef` — stopped creating an orthogonal DP communicator. It was unused (DP index is just `world_rank // R`), and creating the extra NCCL sub-group correlated with a first-step hang under ZeRO-2.
  3. `dc6b4df` — `P2POp`'s `peer` argument is a global rank, not a within-group rank. When `ring_group_size == world_size` the bug was masked (the sub-group equals WORLD). With a proper sub-group (e.g. `{4,5,6,7}`), passing local rank 1 as the peer crashed at the first `batch_isend_irecv` with "Global rank 1 is not part of group". Now translates the local rotation step to a global peer via `dist.get_global_rank`.
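Fix 3 is easy to get wrong again, so here is a small stand-alone sketch of the rank translation. It is an assumption-laden illustration: the list lookup stands in for `dist.get_global_rank(group, group_rank)`, `ring_peers` is a hypothetical helper, and the rotation direction (send to the next local rank, receive from the previous one) is a guess rather than something this PR states.

```python
def ring_peers(group_ranks: list[int], global_rank: int) -> tuple[int, int]:
    """Return (send_to, recv_from) as GLOBAL ranks for one rotation step.

    group_ranks lists the global ranks of one ring sub-group in ring order.
    """
    size = len(group_ranks)
    local = group_ranks.index(global_rank)      # within-group rank
    send_to = group_ranks[(local + 1) % size]   # local -> global translation
    recv_from = group_ranks[(local - 1) % size]
    return send_to, recv_from
```

For the sub-group `[4, 5, 6, 7]`, global rank 4 has local rank 0 and its next peer is local rank 1, which must be passed on as global rank 5. Passing the untranslated `1` is what produced "Global rank 1 is not part of group"; for a single ring over WORLD the two numberings coincide, which is why the bug stayed hidden.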

Verification

  • Local 2× RTX 3090 `unit_test.py` (kernel-level): still matches eager to within fp32 rounding (max abs diff 1.2e-7).
  • a100-training, 8× A100, DeepSpeed ZeRO-2, pi07 low-level pretraining config:
    • `ring_group_size=8` (single ring) — runs cleanly (job 1310, 20 steps).
    • `ring_group_size=4` (2 rings × 4 GPUs) — runs after the three bugfixes above (job 1317 step 1: `total_loss 64.94 scaled / 16.24 unscaled`; matches the rs=8 baseline within bf16 noise).

Known issue (separate from this PR's topology work)

Both `rs=8` and `rs=4` show `total_loss = nan` from step 2 onward in real pi07 training. The same NaN appears when `ring_group_size` is unset (legacy single-ring across WORLD from this PR's original commit, no loss scaling), so it is not introduced by the 2D wiring. It is a pre-existing bug in the ring attention kernel's real-data bf16 backward — likely in how the online-softmax accumulator interacts with ZeRO-2's bf16 master grads. Needs a separate investigation; tracking as a follow-up.

Real-data pi07 training under ZeRO-2 produces NaN at step 2 even with
ring_group_size unset (single ring across WORLD). The current ring kernel
uses _NEG_INF = -2.38e38 (mirrors the eager constant) and does several
fp32 subtractions on it inside the online softmax / LSE chain; values
that close to fp32's representable limit can cancel into NaN under those
compositions. The replacement value -1e9 is still far below the point where
exp() underflows to 0 in fp32 (roughly -104), so masked positions still
underflow to 0 wherever any unmasked score is present, and fully-masked rows
still self-cancel to exp(0)=1.
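The two properties claimed for -1e9 can be checked in a few lines. This sketch runs in Python's float64, not fp32, and the helper names are hypothetical. The "doubling" check is only one illustrative way a sentinel near the fp32 limit could leave the finite range; the actual composition that produced NaN in training was not identified in this PR.

```python
import math

FP32_MAX = 3.4028234663852886e38  # largest finite float32 value


def masked_softmax_weights(scores, mask, neg_inf):
    """Softmax weights with masked positions replaced by a finite sentinel."""
    filled = [s if keep else neg_inf for s, keep in zip(scores, mask)]
    m = max(filled)
    exps = [math.exp(x - m) for x in filled]
    total = sum(exps)
    return [e / total for e in exps]


def overflows_fp32(x: float) -> bool:
    """True if x would not fit in a finite float32."""
    return abs(x) > FP32_MAX


# Partially masked row: the masked weight underflows to exactly 0.
partial = masked_softmax_weights([0.3, 1.2, -0.7], [True, False, True], -1e9)

# Fully masked row: every entry becomes exp(0) = 1, so weights are uniform.
full = masked_softmax_weights([0.3, 1.2, -0.7], [False, False, False], -1e9)

# Adding the old sentinel to itself leaves the fp32 range; -1e9 does not.
old_sum_overflows = overflows_fp32(-2.38e38 + -2.38e38)
new_sum_overflows = overflows_fp32(-1e9 + -1e9)

# Once a value has become -inf, subtracting two such values yields NaN.
inf_minus_inf_is_nan = math.isnan(-math.inf - (-math.inf))
```

The point of the smaller sentinel is headroom: -1e9 can be added, subtracted or doubled many times over and stay finite, while still being indistinguishable from "minus infinity" after the exponential.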

Adds RING_DEBUG_PROBES=1 — guarded torch.isfinite() assertions in the
forward and backward to localise which tensor first goes non-finite under
a real run. Off by default; cheap enough to leave on for a smoke run.

Forward.O / LSE were finite per the previous probe firing on
backward.dq_local; the NaN is being introduced somewhere in the per-tile
loop. Add probes on saved tensors, S, P, dP, dS, and D_tile so the next
run pin-points the exact operation that first emits a non-finite value.

Local diagnostic test scripts:
- backward_test.py: gradient-equivalence vs eager autograd (within bf16 noise).
- nan_repro.py: kernel stress with pi07-style block-causal mask + PAD rows.
- deepspeed_repro.py: tiny Gemma3 with expert under accelerate's DeepSpeed
  ZeRO-2 bf16 plugin, 5 training steps with a pi07-shaped CE loss path.

All three pass on 2x RTX 3090; no NaN.

Defensive _RingAttention.backward: nan_to_num the returned dq/dk/dv before
casting to bf16. The kernel is finite-input-finite-output for every case
we could construct locally, but the real-data pi07 + ZeRO-2 run still hit
a step-2 NaN whose intermediate source we could not localise under
RING_DEBUG_PROBES (the production dataset cache kept failing the
fine-grained probe job before training started). nan_to_num gives the
optimizer the same skip-this-element semantics DeepSpeed's overflow check
would, without changing the math on the finite-output path. The strict
_assert_finite probes still fire before sanitisation, so RING_DEBUG_PROBES=1
remains useful for future investigation.
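
The probe-then-sanitise ordering described above can be sketched on plain lists of floats. This is not the PR's `_assert_finite` or its backward: the function names are hypothetical, and `nan_to_num` here mirrors the default behaviour of `torch.nan_to_num` (NaN to 0, infinities to the largest finite magnitude); whether the PR passes non-default arguments is not stated.

```python
import math
import os
import sys


def assert_finite(name: str, values: list[float]) -> None:
    """Raise on NaN/inf, but only when RING_DEBUG_PROBES=1 is set."""
    if os.environ.get("RING_DEBUG_PROBES") != "1":
        return
    bad = [i for i, v in enumerate(values) if not math.isfinite(v)]
    if bad:
        raise FloatingPointError(f"{name}: non-finite at indices {bad[:8]}")


def nan_to_num(values: list[float]) -> list[float]:
    """NaN -> 0.0, +inf/-inf -> +/- largest finite float."""
    big = sys.float_info.max
    out = []
    for v in values:
        if math.isnan(v):
            out.append(0.0)
        elif math.isinf(v):
            out.append(big if v > 0 else -big)
        else:
            out.append(v)
    return out


def finish_backward(name: str, grads: list[float]) -> list[float]:
    """Strict probe first, then sanitise what is returned to the optimizer."""
    assert_finite(name, grads)
    return nan_to_num(grads)
```

With the environment variable unset, a NaN in `grads` is silently replaced; with `RING_DEBUG_PROBES=1` the same input raises before sanitisation, so the probe keeps its diagnostic value.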