Skip to content

feat: support staleness-window in ReplayBufferNew#2458

Open
yuki-97 wants to merge 3 commits into
yukih/refactor-async-utilsfrom
yukih/staleness-sample
Open

feat: support staleness-window in ReplayBufferNew#2458
yuki-97 wants to merge 3 commits into
yukih/refactor-async-utilsfrom
yukih/staleness-sample

Conversation

@yuki-97
Copy link
Copy Markdown
Contributor

@yuki-97 yuki-97 commented May 11, 2026

Part of RL-727. Stacks on #2448.

Implements ReplayBufferNew, a temporary replacement for ReplayBuffer until TQReplayBuffer is ready.

Motivation: ReplayBuffer.sample() requires target_weight_version == current_weight_version, which stalls training when the exact-match trajectories haven't arrived yet (buffer starvation). ReplayBufferNew fixes this by allowing slightly older trajectories to be used, with an importance-sampling correction.

Changes:

  • Add max_staleness config: trajectories with trainer_version - weight_version > max_staleness are evicted at the start of each sample() call.
  • sample() selects from the staleness window [trainer_version - max_staleness, trainer_version], removing the strict target_weight_version == current_weight_version gate.
  • Add sample_freshest_first flag (default True): when True, selects the highest-version trajectories first; when False, uses FIFO (insertion order).
  • target_weight_versions is intentionally unused in ReplayBufferNew — it gates generation on specific trainer steps, causing generation pauses. Will be removed when cleaning up after TQReplayBuffer lands.
  • Unit tests covering eviction, staleness-window sampling, freshest-first ordering, and FIFO ordering.

Signed-off-by: Yuki Huang <yukih@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yuki-97 yuki-97 marked this pull request as ready for review May 11, 2026 05:06
@yuki-97 yuki-97 requested review from a team as code owners May 11, 2026 05:06
@yuki-97 yuki-97 requested review from mehraakash and terrykong May 11, 2026 05:06
Signed-off-by: Yuki Huang <yukih@nvidia.com>
@yuki-97 yuki-97 force-pushed the yukih/staleness-sample branch from 85e179f to 19334ad Compare May 11, 2026 10:04
Signed-off-by: Yuki Huang <yukih@nvidia.com>
# limitations under the License.

import threading as _threading
from collections import Counter
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

sampled_weights
)
sampled_items = [self.trajectories[i] for i in selected]
for idx in sorted(selected, reverse=True):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we refactored into another function:

  def _remove_indices(self, indices: Iterable[int]) -> None:
      for idx in sorted(indices, reverse=True):
          self.trajectory_versions.pop(idx)
          self.target_weight_versions.pop(idx)
          self.trajectories.pop(idx)

can then use it in _evict and sample and provide different Iterables?

"""
min_valid = current_weight_version - self.max_staleness
stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid]
for idx in sorted(stale, reverse=True):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment below on adding it in a function.

stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid]
for idx in sorted(stale, reverse=True):
self.trajectory_versions.pop(idx)
self.trajectories.pop(idx)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we want to eventually get rid of target_weight_versions but since we inherited from ReplayBufferImpl that list will be created. So we either keep that state aligned or remove it?

@ray.remote # pragma: no cover
class ReplayBufferNew(ReplayBufferImpl):
pass
"""Staleness-window replay buffer.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need a follow-up task here before wiring this in:

ReplayBufferNew removes exact target matching in sample() here, but the collector still enforces
target-version reservation and generation-limit pauses through last_target_weight_already_generated.

For end-to-end staleness-window sampling, the collector needs a mode that generates based on
current generation_weight_version and buffer/backpressure capacity, and not future target_weight_version slots.

We'll control generation using SingleController by:

  • Buffer Capacity
  • Inflight Semaphore
  • Refit pause
  • Any manual Pause
  • Dataloader availability

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants