feat: support staleness-window in ReplayBufferNew #2458
Signed-off-by: Yuki Huang <yukih@nvidia.com>
85e179f to 19334ad
# limitations under the License.

import threading as _threading
from collections import Counter
    sampled_weights
)
sampled_items = [self.trajectories[i] for i in selected]
for idx in sorted(selected, reverse=True):
Could we refactor this into a separate helper function?

def _remove_indices(self, indices: Iterable[int]) -> None:
    for idx in sorted(indices, reverse=True):
        self.trajectory_versions.pop(idx)
        self.target_weight_versions.pop(idx)
        self.trajectories.pop(idx)

We can then use it in both _evict and sample, passing a different iterable to each.
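To illustrate why the helper pops in descending index order, here is a minimal, self-contained sketch. TinyBuffer and push are hypothetical stand-ins, not the real ReplayBufferImpl; only the three parallel lists are modeled. The set() deduplication is an extra safeguard assumed here and is not part of the suggestion above.

```python
from typing import Iterable


class TinyBuffer:
    """Hypothetical stand-in for the buffer: three parallel lists only."""

    def __init__(self) -> None:
        self.trajectories = []
        self.trajectory_versions = []
        self.target_weight_versions = []

    def push(self, trajectory: str, version: int, target: int) -> None:
        # Hypothetical helper: append one entry to every parallel list.
        self.trajectories.append(trajectory)
        self.trajectory_versions.append(version)
        self.target_weight_versions.append(target)

    def _remove_indices(self, indices: Iterable[int]) -> None:
        # Pop from the highest index down so that an earlier pop never
        # shifts the position of an index that is still to be removed.
        # set() is an added assumption that guards against duplicates.
        for idx in sorted(set(indices), reverse=True):
            self.trajectory_versions.pop(idx)
            self.target_weight_versions.pop(idx)
            self.trajectories.pop(idx)


buf = TinyBuffer()
for i, name in enumerate(["a", "b", "c", "d"]):
    buf.push(name, version=i, target=i + 1)

buf._remove_indices([0, 2])
print(buf.trajectories)  # ['b', 'd']
```

Because all three lists are popped in the same call, they cannot drift out of alignment regardless of whether the caller is eviction or sampling.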
"""
min_valid = current_weight_version - self.max_staleness
stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid]
for idx in sorted(stale, reverse=True):
See comment below on adding it in a function.
stale = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid]
for idx in sorted(stale, reverse=True):
    self.trajectory_versions.pop(idx)
    self.trajectories.pop(idx)
I know we eventually want to get rid of target_weight_versions, but since we inherit from ReplayBufferImpl, that list will still be created. Should we either keep that state aligned with the other lists or remove it?
@ray.remote  # pragma: no cover
class ReplayBufferNew(ReplayBufferImpl):
    pass
    """Staleness-window replay buffer.
I think we need a follow-up task here before wiring this in:
ReplayBufferNew removes exact target matching in sample() here, but the collector still enforces
target-version reservation and generation-limit pauses through last_target_weight_already_generated.
For end-to-end staleness-window sampling, the collector needs a mode that generates based on
current generation_weight_version and buffer/backpressure capacity, and not future target_weight_version slots.
We'll control generation through SingleController, based on:
- Buffer Capacity
- Inflight Semaphore
- Refit pause
- Any manual Pause
- Dataloader availability
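As a rough sketch of how the five conditions above could combine into a single gate, the snippet below models them as one predicate. Every name here (GenerationState, can_generate, and all fields) is hypothetical and chosen for illustration; this is not the SingleController API, and the real controller may weigh or order these checks differently.

```python
from dataclasses import dataclass


@dataclass
class GenerationState:
    """Hypothetical snapshot of the signals listed above."""

    buffer_size: int            # trajectories currently in the buffer
    buffer_capacity: int        # maximum trajectories the buffer may hold
    inflight: int               # generation requests currently running
    max_inflight: int           # size of the inflight semaphore
    refit_in_progress: bool     # weights are being refit
    manually_paused: bool       # an operator requested a pause
    dataloader_has_batch: bool  # a prompt batch is available


def can_generate(s: GenerationState) -> bool:
    # Generation proceeds only when every condition allows it. Note that
    # no target_weight_version slot is consulted anywhere in this gate.
    return (
        s.buffer_size < s.buffer_capacity
        and s.inflight < s.max_inflight
        and not s.refit_in_progress
        and not s.manually_paused
        and s.dataloader_has_batch
    )


ready = GenerationState(3, 8, 1, 4, False, False, True)
full = GenerationState(8, 8, 1, 4, False, False, True)
print(can_generate(ready), can_generate(full))  # True False
```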
Part of RL-727. Stacks on #2448.
Implements ReplayBufferNew, a temporary replacement for ReplayBuffer until TQReplayBuffer is ready.

Motivation:
ReplayBuffer.sample() requires target_weight_version == current_weight_version, which stalls training when the exact-match trajectories haven't arrived yet (buffer starvation). ReplayBufferNew fixes this by allowing slightly older trajectories to be used, with an importance-sampling correction.

Changes:
- max_staleness config: trajectories with trainer_version - weight_version > max_staleness are evicted at the start of each sample() call.
- sample() selects from the staleness window [trainer_version - max_staleness, trainer_version], removing the strict target_weight_version == current_weight_version gate.
- sample_freshest_first flag (default True): when True, selects the highest-version trajectories first; when False, uses FIFO (insertion order).
- target_weight_versions is intentionally unused in ReplayBufferNew: it gates generation on specific trainer steps, causing generation pauses. Will be removed when cleaning up after TQReplayBuffer lands.
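The eviction and selection rules described above can be sketched on a plain list of versions. This is a minimal illustration, not the code in this PR: evict_stale and select_indices are hypothetical names, versions are assumed never to exceed the trainer version, ties in freshest-first mode are assumed to keep insertion order, and a short buffer is assumed to simply return fewer indices.

```python
from typing import List


def evict_stale(versions: List[int], current: int, max_staleness: int) -> List[int]:
    """Keep versions inside the window [current - max_staleness, current]."""
    min_valid = current - max_staleness
    # An entry is stale when current - version > max_staleness,
    # which is the same as version < min_valid.
    return [v for v in versions if v >= min_valid]


def select_indices(
    versions: List[int], num_samples: int, sample_freshest_first: bool = True
) -> List[int]:
    """Return up to num_samples indices into versions."""
    order = list(range(len(versions)))
    if sample_freshest_first:
        # Stable sort by descending version; equal versions keep
        # insertion order (an assumption of this sketch).
        order.sort(key=lambda i: -versions[i])
    # Otherwise FIFO: insertion order is already the index order.
    return order[:num_samples]


versions = [3, 7, 5, 6, 7, 2]
kept = evict_stale(versions, current=7, max_staleness=2)
fresh = select_indices(kept, 2, sample_freshest_first=True)
fifo = select_indices(kept, 2, sample_freshest_first=False)
print(kept, fresh, fifo)  # [7, 5, 6, 7] [0, 3] [0, 1]
```

With max_staleness=2 and trainer version 7, versions 3 and 2 fall outside the window and are evicted; freshest-first then picks both version-7 entries, while FIFO picks the two oldest insertions.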