From 653e9d2f6992a059ad7fe45ee5a12f2fcb806683 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Wed, 25 Feb 2026 15:22:05 +0800 Subject: [PATCH 1/2] [Producer] Add Sampler, SamplerWithBuffer, SyncProduceStrategy, AsyncProduceStrategy --- tests/ray/test_producer.py | 103 ++++++++++++++++ xtuner/v1/rl/base/producer.py | 184 +++++++++++++++++++++++++++++ xtuner/v1/rl/base/replay_buffer.py | 18 +++ 3 files changed, 305 insertions(+) create mode 100644 tests/ray/test_producer.py create mode 100644 xtuner/v1/rl/base/producer.py diff --git a/tests/ray/test_producer.py b/tests/ray/test_producer.py new file mode 100644 index 000000000..7c9ecf493 --- /dev/null +++ b/tests/ray/test_producer.py @@ -0,0 +1,103 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch +from xtuner.v1.rl.base.producer import Sampler, SamplerWithReplayBuffer, SyncProduceStrategy, AsyncProduceStrategy +from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StalenessBackend +from xtuner.v1.data_proto.rl_data import RolloutState, Status + +class MockRolloutState: + def __init__(self, id, seq_staleness=1, status=Status.COMPLETED): + self.id = id + self.status = status + self.seq_staleness = seq_staleness + +class TestProducer(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + # 1. 模拟 DataloaderConfig 和 Dataloader + self.mock_dataloader_cfg = MagicMock() + self.mock_dataloader = MagicMock() + # 模拟 next(dataloader_iter) 返回 [RolloutState] + self.mock_dataloader.__iter__.return_value = iter([[MockRolloutState(i)] for i in range(100)]) + self.mock_dataloader_cfg.build.return_value = self.mock_dataloader + + # 2. 模拟 Tokenizer + self.mock_tokenizer = MagicMock() + + # 3. 准备 ReplayBuffer + self.backend = StalenessBackend(limit=10, max_staleness=5) + self.replay_buffer = ReplayBuffer(storage_backend=self.backend) + + async def test_sampler_with_replay_buffer(self): + sampler = SamplerWithReplayBuffer("test_task", self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer) + + # 场景 A: ReplayBuffer 为空,从 Dataloader 拿 + data = await sampler.sample() + self.assertEqual(data.id, 0) + + # 场景 B: ReplayBuffer 有 ABORTED 数据,优先拿 + aborted_item = MockRolloutState(999, status=Status.ABORTED) + await self.replay_buffer.put([aborted_item], "test_task") + + data = await sampler.sample() + self.assertEqual(data[0].id, 999) + + async def test_sync_produce_strategy(self): + # 1. 模拟 AgentLoop + mock_agent_loop = MagicMock() + mock_agent_loop.task_name = "test_task" + # generate_group 返回的是 List[RolloutState] + async def mock_gen(rs, k): + rs.status = Status.COMPLETED + return [rs] + mock_agent_loop.generate_group = mock_gen + + sampler = Sampler("test_task", self.mock_dataloader_cfg, self.mock_tokenizer) + strategy = SyncProduceStrategy(self.replay_buffer) + + # 执行:生产 batch_size 为 2 的数据 + await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, prompt_k=1) + + # 验证:ReplayBuffer 中应该有 2 条 COMPLETED 数据 + final_data = await self.replay_buffer.get(10, "test_task", Status.COMPLETED) + self.assertEqual(len(final_data), 2) + self.assertEqual(final_data[0].id, 0) + self.assertEqual(final_data[1].id, 1) + + async def test_async_produce_strategy(self): + # 这个async_produce_strategy的测试主要验证超发逻辑 + staleness 优先get的逻辑 + # 异步的其他功能如 partial_rollout, tail_batch不在这里进行验证 + mock_agent_loop = MagicMock() + mock_agent_loop.task_name = "test_task" + + call_count = 0 + async def mock_gen(rs, k): + nonlocal call_count + call_count += 1 + if isinstance(rs, list): + for r in rs: + r.seq_staleness = 5 + r.status = Status.COMPLETED + return rs + else: + rs.seq_staleness = call_count + rs.status = Status.COMPLETED + return [rs] + mock_agent_loop.generate_group = mock_gen + + sampler = SamplerWithReplayBuffer("test_task", self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer) + strategy = AsyncProduceStrategy(self.replay_buffer, staleness_threshold = 1) + # 预处理 + aborted_item = MockRolloutState(999, status=Status.ABORTED) + await self.replay_buffer.put([aborted_item], "test_task") + # 执行 + await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, prompt_k=1) + + # 验证:ReplayBuffer 中应该有 4 条 COMPLETED 数据, + # NOTE(@duanyanhui): 目前还没实现暂停功能,所以4条都会推理完成,4条数据按照新鲜度顺序排列,999 是最旧的,0 是最新的 + final_data = await self.replay_buffer.get(10, "test_task", Status.COMPLETED) + self.assertEqual(len(final_data), 4) + self.assertEqual(final_data[0].id, 999) + self.assertEqual(final_data[1].id, 2) + self.assertEqual(final_data[2].id, 1) + self.assertEqual(final_data[3].id, 0) \ No newline at end of file diff --git a/xtuner/v1/rl/base/producer.py b/xtuner/v1/rl/base/producer.py new file mode 100644 index 000000000..72c652419 --- /dev/null +++ b/xtuner/v1/rl/base/producer.py @@ -0,0 +1,184 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import Union + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.datasets.config import DataloaderConfig + +from ..agent_loop import AgentLoop +from .replay_buffer import ReplayBuffer + + +# TODO: 用户把自己的数据集转换成rolloutstate的逻辑放在哪里? +class Sampler: + def __init__( + self, + task_name: str, + dataloader_cfg: DataloaderConfig, + tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast], + ): + self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + if isinstance(tokenizer, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + else: + self.tokenizer = tokenizer + self.dataloader = dataloader_cfg.build( + tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1 + ) + self.dataloader_iter = iter(self.dataloader) + self.cur_epoch = 0 + self.task_name = task_name + + async def sample(self) -> RolloutState: + try: + data = next(self.dataloader_iter)[0] + except StopIteration: + self.cur_epoch += 1 + self.dataloader.set_epoch(self.cur_epoch) + self.dataloader_iter = iter(self.dataloader) + data = next(self.dataloader_iter)[0] + return data + + +class SamplerWithReplayBuffer(Sampler): + def __init__( + self, + task_name: str, + dataloader_cfg: DataloaderConfig, + tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast], + replay_buffer: ReplayBuffer, + ): + super().__init__(task_name, dataloader_cfg, tokenizer) + self.replay_buffer = replay_buffer + + def _sample_from_dataloader(self) -> RolloutState: + try: + data = next(self.dataloader_iter)[0] + except StopIteration: + self.cur_epoch += 1 + self.dataloader.set_epoch(self.cur_epoch) + self.dataloader_iter = iter(self.dataloader) + data = next(self.dataloader_iter)[0] + return data + + async def sample(self) -> list[RolloutState]: + data = await self.replay_buffer.get(1, task_name=self.task_name, group_status=Status.ABORTED) + if len(data) == 0: + data = self._sample_from_dataloader() + return data + + +class ProduceStrategy(ABC): + # NOTE: dataloader不作为ProduceStrategy的成员变量的原因:produce_strategy不绑定dataloader + def __init__(self, replay_buffer: ReplayBuffer): + self.replay_buffer = replay_buffer + + @abstractmethod + async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, prompt_k: int): ... + + +class SyncProduceStrategy(ProduceStrategy): + async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, prompt_k: int): + data_concurrency = batch_size + pending_tasks = set() + for _ in range(data_concurrency): + rollout_state = await sampler.sample() + task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) + pending_tasks.add(task) + + init_completed_sample_count = await self.replay_buffer.count( + task_name=agent_loop.task_name, group_status=Status.COMPLETED + ) + completed_sample_count = init_completed_sample_count + while completed_sample_count < data_concurrency: + if not pending_tasks: + print("All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + + # 如果要过滤,在这个地方处理,然后加入到 replay buffer + # 如果被过滤的数据就放到 put_to_filtered pool 中 + for task in done_tasks: + try: + await self.replay_buffer.put(items=task.result(), task_name=agent_loop.task_name) + except Exception as e: + print(f"Error in generating trajectory: {e}") + + if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count: + rollout_state = await sampler.sample() + task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) + pending_tasks.add(task) + + completed_sample_count = await self.replay_buffer.count( + task_name=agent_loop.task_name, group_status=Status.COMPLETED + ) + + +class AsyncProduceStrategy(ProduceStrategy): + def __init__( + self, + replay_buffer: ReplayBuffer, + staleness_threshold: float = 0.0, + enable_partial_rollout: bool = False, + tail_batch_trigger_size: int = 0, + tail_batch_candidate_step: int = 0, + ): + super().__init__(replay_buffer) + self.staleness_threshold = staleness_threshold + self.enable_partial_rollout = enable_partial_rollout + self.tail_batch_trigger_size = tail_batch_trigger_size + self.tail_batch_candidate_step = tail_batch_candidate_step + + async def produce_batch( + self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int, prompt_k: int + ): + data_concurrency = (1 + self.staleness_threshold) * batch_size + print( + f"AsyncProduceStrategy: data_concurrency={data_concurrency}, staleness_threshold={self.staleness_threshold}" + ) + pending_tasks = set() + for _ in range(data_concurrency): + rollout_state = await sampler.sample() + task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) + pending_tasks.add(task) + + init_completed_sample_count = await self.replay_buffer.count( + task_name=agent_loop.task_name, group_status=Status.COMPLETED + ) + completed_sample_count = init_completed_sample_count + while completed_sample_count < data_concurrency: + if not pending_tasks: + print("All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + + # 如果要过滤,在这个地方处理,然后加入到 replay buffer + # 如果被过滤的数据就放到 put_to_filtered pool 中 + for task in done_tasks: + try: + await self.replay_buffer.put(items=task.result(), task_name=agent_loop.task_name) + except Exception as e: + print(f"Error in generating trajectory: {e}") + + print(f"Completed sample count: {completed_sample_count}, Pending task count: {len(pending_tasks)}") + completed_sample_count = await self.replay_buffer.count( + task_name=agent_loop.task_name, group_status=Status.COMPLETED + ) + if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count: + rollout_state = await sampler.sample() + task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) + pending_tasks.add(task) + + if len(pending_tasks) > 0: + await agent_loop.pause() + while len(pending_tasks) > 0: + _, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED) + if len(pending_tasks) > 0: + await agent_loop.pause() + await asyncio.sleep(1) + print("All worker tasks have completed after pausing env controller.") diff --git a/xtuner/v1/rl/base/replay_buffer.py b/xtuner/v1/rl/base/replay_buffer.py index 33bd5c7a6..30648bc9e 100644 --- a/xtuner/v1/rl/base/replay_buffer.py +++ b/xtuner/v1/rl/base/replay_buffer.py @@ -20,6 +20,8 @@ async def put(self, items: list[RolloutState], storage_indices: StorageIndices): @abstractmethod async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ... @abstractmethod + async def count(self, storage_indices: StorageIndices) -> int: ... + @abstractmethod def __len__(self): ... @@ -50,6 +52,10 @@ async def get(self, count: int, storage_indices: StorageIndices) -> list[Rollout def __len__(self): return sum(len(v) for v in self._storage.values()) + def count(self, storage_indices: StorageIndices) -> int: + indices = self._hash_storage_indices(storage_indices) + return len(self._storage[indices]) + class PandasStorage(Storage): def __init__(self, limit: int = 0): @@ -204,6 +210,13 @@ async def get(self, count: int, storage_indices: StorageIndices) -> list[Rollout needed -= take return target_items + async def count(self, storage_indices: StorageIndices) -> int: + indices = self._hash_storage_indices(storage_indices) + total_len = 0 + for s in range(self.min_staleness, self.max_staleness + 1): + total_len += len(self._storage[indices][s]) + return total_len + def __len__(self): return sum(count for count in self._bucket_counts.values()) @@ -223,3 +236,8 @@ async def get(self, batch_size: int, task_name: str, group_status: Status, **kwa indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs) async with self._lock: return await self._storage.get(batch_size, indices) + + async def count(self, task_name: str, group_status: Status, **kwargs) -> int: + indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs) + async with self._lock: + return await self._storage.count(indices) From 9fefd4ddbde1afd59fc936486776d49f9cda7524 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Thu, 26 Feb 2026 15:28:35 +0800 Subject: [PATCH 2/2] add tqdm in ProduceStrategy and fix comments on sampler --- tests/ray/test_producer.py | 54 ++++---- xtuner/v1/rl/agent_loop/agent_loop.py | 11 +- xtuner/v1/rl/base/producer.py | 180 ++++++++++++++------------ 3 files changed, 128 insertions(+), 117 deletions(-) diff --git a/tests/ray/test_producer.py b/tests/ray/test_producer.py index 7c9ecf493..cd181863d 100644 --- a/tests/ray/test_producer.py +++ b/tests/ray/test_producer.py @@ -29,34 +29,33 @@ def setUp(self): self.replay_buffer = ReplayBuffer(storage_backend=self.backend) async def test_sampler_with_replay_buffer(self): - sampler = SamplerWithReplayBuffer("test_task", self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer) + task_name = "test_task" + sampler = SamplerWithReplayBuffer(self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer) # 场景 A: ReplayBuffer 为空,从 Dataloader 拿 - data = await sampler.sample() - self.assertEqual(data.id, 0) + data = await sampler.sample(task_name) + self.assertEqual(data[0].id, 0) # 场景 B: ReplayBuffer 有 ABORTED 数据,优先拿 aborted_item = MockRolloutState(999, status=Status.ABORTED) - await self.replay_buffer.put([aborted_item], "test_task") + await self.replay_buffer.put([aborted_item], task_name) - data = await sampler.sample() + data = await sampler.sample(task_name) self.assertEqual(data[0].id, 999) async def test_sync_produce_strategy(self): - # 1. 模拟 AgentLoop mock_agent_loop = MagicMock() - mock_agent_loop.task_name = "test_task" - # generate_group 返回的是 List[RolloutState] - async def mock_gen(rs, k): - rs.status = Status.COMPLETED - return [rs] + async def mock_gen(rs): + for r in rs: + r.status = Status.COMPLETED + return rs mock_agent_loop.generate_group = mock_gen - sampler = Sampler("test_task", self.mock_dataloader_cfg, self.mock_tokenizer) + sampler = Sampler(self.mock_dataloader_cfg, self.mock_tokenizer) strategy = SyncProduceStrategy(self.replay_buffer) # 执行:生产 batch_size 为 2 的数据 - await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, prompt_k=1) + await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, task_name="test_task") # 验证:ReplayBuffer 中应该有 2 条 COMPLETED 数据 final_data = await self.replay_buffer.get(10, "test_task", Status.COMPLETED) @@ -68,34 +67,33 @@ async def test_async_produce_strategy(self): # 这个async_produce_strategy的测试主要验证超发逻辑 + staleness 优先get的逻辑 # 异步的其他功能如 partial_rollout, tail_batch不在这里进行验证 mock_agent_loop = MagicMock() - mock_agent_loop.task_name = "test_task" - + task_name = "test_task" call_count = 0 - async def mock_gen(rs, k): + async def mock_gen(rs): nonlocal call_count call_count += 1 - if isinstance(rs, list): - for r in rs: + for r in rs: + if r.id == 999: r.seq_staleness = 5 - r.status = Status.COMPLETED - return rs - else: - rs.seq_staleness = call_count - rs.status = Status.COMPLETED - return [rs] + else: + r.seq_staleness = call_count + r.status = Status.COMPLETED + print(r.id, r.seq_staleness, r.status) + return rs + mock_agent_loop.generate_group = mock_gen - sampler = SamplerWithReplayBuffer("test_task", self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer) + sampler = SamplerWithReplayBuffer(self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer) strategy = AsyncProduceStrategy(self.replay_buffer, staleness_threshold = 1) # 预处理 aborted_item = MockRolloutState(999, status=Status.ABORTED) - await self.replay_buffer.put([aborted_item], "test_task") + await self.replay_buffer.put([aborted_item], task_name) # 执行 - await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, prompt_k=1) + await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, task_name=task_name) # 验证:ReplayBuffer 中应该有 4 条 COMPLETED 数据, # NOTE(@duanyanhui): 目前还没实现暂停功能,所以4条都会推理完成,4条数据按照新鲜度顺序排列,999 是最旧的,0 是最新的 - final_data = await self.replay_buffer.get(10, "test_task", Status.COMPLETED) + final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED) self.assertEqual(len(final_data), 4) self.assertEqual(final_data[0].id, 999) self.assertEqual(final_data[1].id, 2) diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py index 005f1869e..c36a000d7 100644 --- a/xtuner/v1/rl/agent_loop/agent_loop.py +++ b/xtuner/v1/rl/agent_loop/agent_loop.py @@ -1,8 +1,7 @@ import abc import asyncio from abc import ABC -from copy import deepcopy -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable import ray @@ -29,11 +28,11 @@ def __init__( @abc.abstractmethod async def generate_sample(self, rollout_state: RolloutState) -> RolloutState: ... - async def generate_group(self, rollout_state, prompt_repeat_k) -> List[RolloutState]: + async def generate_group(self, rollout_state: list[RolloutState]) -> list[RolloutState]: pending_tasks = [] - for _ in range(prompt_repeat_k): - rollout_state.sample_params = self.sample_params - task = asyncio.create_task(self.generate_sample(deepcopy(rollout_state))) + for state in rollout_state: + state.sample_params = self.sample_params + task = asyncio.create_task(self.generate_sample(state)) pending_tasks.append(task) generated_samples = asyncio.gather(*pending_tasks) group_samples = await generated_samples diff --git a/xtuner/v1/rl/base/producer.py b/xtuner/v1/rl/base/producer.py index 72c652419..87a42787e 100644 --- a/xtuner/v1/rl/base/producer.py +++ b/xtuner/v1/rl/base/producer.py @@ -1,6 +1,8 @@ import asyncio from abc import ABC, abstractmethod -from typing import Union +from typing import Iterator, Optional, Union + +from tqdm.auto import tqdm from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from xtuner.v1.data_proto.rl_data import RolloutState, Status @@ -14,9 +16,9 @@ class Sampler: def __init__( self, - task_name: str, dataloader_cfg: DataloaderConfig, tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast], + prompt_repeat_k: int = 1, ): self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] if isinstance(tokenizer, str): @@ -26,11 +28,14 @@ def __init__( self.dataloader = dataloader_cfg.build( tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1 ) - self.dataloader_iter = iter(self.dataloader) + self.dataloader_iter: Optional[Iterator] = None self.cur_epoch = 0 - self.task_name = task_name + self.prompt_repeat_k = prompt_repeat_k - async def sample(self) -> RolloutState: + async def sample(self, task_name: str) -> list[RolloutState]: + if self.dataloader_iter is None: + self.dataloader_iter = iter(self.dataloader) + assert self.dataloader_iter is not None try: data = next(self.dataloader_iter)[0] except StopIteration: @@ -38,21 +43,26 @@ async def sample(self) -> RolloutState: self.dataloader.set_epoch(self.cur_epoch) self.dataloader_iter = iter(self.dataloader) data = next(self.dataloader_iter)[0] - return data + group_data = [data] * self.prompt_repeat_k + return group_data class SamplerWithReplayBuffer(Sampler): def __init__( self, - task_name: str, dataloader_cfg: DataloaderConfig, tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast], replay_buffer: ReplayBuffer, + prompt_repeat_k: int = 1, ): - super().__init__(task_name, dataloader_cfg, tokenizer) + super().__init__(dataloader_cfg, tokenizer, prompt_repeat_k) self.replay_buffer = replay_buffer - def _sample_from_dataloader(self) -> RolloutState: + def _sample_from_dataloader(self) -> list[RolloutState]: + if self.dataloader_iter is None: + self.dataloader_iter = iter(self.dataloader) + + assert self.dataloader_iter is not None try: data = next(self.dataloader_iter)[0] except StopIteration: @@ -60,10 +70,11 @@ def _sample_from_dataloader(self) -> RolloutState: self.dataloader.set_epoch(self.cur_epoch) self.dataloader_iter = iter(self.dataloader) data = next(self.dataloader_iter)[0] - return data + group_data = [data] * self.prompt_repeat_k + return group_data - async def sample(self) -> list[RolloutState]: - data = await self.replay_buffer.get(1, task_name=self.task_name, group_status=Status.ABORTED) + async def sample(self, task_name: str) -> list[RolloutState]: + data = await self.replay_buffer.get(1, task_name=task_name, group_status=Status.ABORTED) if len(data) == 0: data = self._sample_from_dataloader() return data @@ -75,46 +86,48 @@ def __init__(self, replay_buffer: ReplayBuffer): self.replay_buffer = replay_buffer @abstractmethod - async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, prompt_k: int): ... + async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str): ... class SyncProduceStrategy(ProduceStrategy): - async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, prompt_k: int): - data_concurrency = batch_size + async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str): + pbar_refrash_step = max(1, int(batch_size * 0.1)) pending_tasks = set() - for _ in range(data_concurrency): - rollout_state = await sampler.sample() - task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) - pending_tasks.add(task) - - init_completed_sample_count = await self.replay_buffer.count( - task_name=agent_loop.task_name, group_status=Status.COMPLETED - ) - completed_sample_count = init_completed_sample_count - while completed_sample_count < data_concurrency: - if not pending_tasks: - print("All tasks are done but not enough samples collected.") - break - done_tasks, pending_tasks = await asyncio.wait( - pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED - ) - - # 如果要过滤,在这个地方处理,然后加入到 replay buffer - # 如果被过滤的数据就放到 put_to_filtered pool 中 - for task in done_tasks: - try: - await self.replay_buffer.put(items=task.result(), task_name=agent_loop.task_name) - except Exception as e: - print(f"Error in generating trajectory: {e}") - - if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count: - rollout_state = await sampler.sample() - task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) + completed_sample_count = await self.replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED) + assert completed_sample_count == 0, "SyncProduceStrategy assumes no completed samples at the start." + with tqdm(total=batch_size, desc=f"Sync Producer [{task_name}]", miniters=pbar_refrash_step) as pbar: + last_pbar_n = completed_sample_count + pbar.update(last_pbar_n) + for _ in range(batch_size): + rollout_state = await sampler.sample(task_name=task_name) + task = asyncio.create_task(agent_loop.generate_group(rollout_state)) pending_tasks.add(task) - completed_sample_count = await self.replay_buffer.count( - task_name=agent_loop.task_name, group_status=Status.COMPLETED - ) + while completed_sample_count < batch_size: + if not pending_tasks: + print("All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + # 如果要过滤,在这个地方处理,然后加入到 replay buffer + # 如果被过滤的数据就放到 put_to_filtered pool 中 + for task in done_tasks: + try: + await self.replay_buffer.put(items=task.result(), task_name=task_name) + except Exception as e: + print(f"Error in generating trajectory: {e}") + + if len(pending_tasks) + completed_sample_count < batch_size: + rollout_state = await sampler.sample(task_name=task_name) + task = asyncio.create_task(agent_loop.generate_group(rollout_state)) + pending_tasks.add(task) + + completed_sample_count = await self.replay_buffer.count( + task_name=task_name, group_status=Status.COMPLETED + ) + pbar.update(completed_sample_count - last_pbar_n) + last_pbar_n = completed_sample_count class AsyncProduceStrategy(ProduceStrategy): @@ -132,48 +145,49 @@ def __init__( self.tail_batch_trigger_size = tail_batch_trigger_size self.tail_batch_candidate_step = tail_batch_candidate_step - async def produce_batch( - self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int, prompt_k: int - ): - data_concurrency = (1 + self.staleness_threshold) * batch_size - print( - f"AsyncProduceStrategy: data_concurrency={data_concurrency}, staleness_threshold={self.staleness_threshold}" - ) + async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str): + data_concurrency = int((1 + self.staleness_threshold) * batch_size) + pbar_refrash_step = max(1, int(data_concurrency * 0.1)) pending_tasks = set() - for _ in range(data_concurrency): - rollout_state = await sampler.sample() - task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) - pending_tasks.add(task) - init_completed_sample_count = await self.replay_buffer.count( - task_name=agent_loop.task_name, group_status=Status.COMPLETED + task_name=task_name, group_status=Status.COMPLETED ) - completed_sample_count = init_completed_sample_count - while completed_sample_count < data_concurrency: - if not pending_tasks: - print("All tasks are done but not enough samples collected.") - break - done_tasks, pending_tasks = await asyncio.wait( - pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED - ) - - # 如果要过滤,在这个地方处理,然后加入到 replay buffer - # 如果被过滤的数据就放到 put_to_filtered pool 中 - for task in done_tasks: - try: - await self.replay_buffer.put(items=task.result(), task_name=agent_loop.task_name) - except Exception as e: - print(f"Error in generating trajectory: {e}") - - print(f"Completed sample count: {completed_sample_count}, Pending task count: {len(pending_tasks)}") - completed_sample_count = await self.replay_buffer.count( - task_name=agent_loop.task_name, group_status=Status.COMPLETED - ) - if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count: - rollout_state = await sampler.sample() - task = asyncio.create_task(agent_loop.generate_group(rollout_state, prompt_k)) + with tqdm(total=batch_size, desc=f"ASync Producer [{task_name}]", miniters=pbar_refrash_step) as pbar: + last_pbar_n = init_completed_sample_count + pbar.update(last_pbar_n) + for _ in range(data_concurrency): + rollout_state = await sampler.sample(task_name=task_name) + task = asyncio.create_task(agent_loop.generate_group(rollout_state)) pending_tasks.add(task) + completed_sample_count = init_completed_sample_count + while completed_sample_count < batch_size: + if not pending_tasks: + print("All tasks are done but not enough samples collected.") + break + done_tasks, pending_tasks = await asyncio.wait( + pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED + ) + + # 如果要过滤,在这个地方处理,然后加入到 replay buffer + # 如果被过滤的数据就放到 put_to_filtered pool 中 + for task in done_tasks: + try: + await self.replay_buffer.put(items=task.result(), task_name=task_name) + except Exception as e: + print(f"Error in generating trajectory: {e}") + + print(f"Completed sample count: {completed_sample_count}, Pending task count: {len(pending_tasks)}") + completed_sample_count = await self.replay_buffer.count( + task_name=task_name, group_status=Status.COMPLETED + ) + pbar.update(completed_sample_count - last_pbar_n) + last_pbar_n = completed_sample_count + if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count: + rollout_state = await sampler.sample(task_name=task_name) + task = asyncio.create_task(agent_loop.generate_group(rollout_state)) + pending_tasks.add(task) + if len(pending_tasks) > 0: await agent_loop.pause() while len(pending_tasks) > 0: