import unittest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
from xtuner.v1.rl.base.producer import Sampler, SamplerWithReplayBuffer, SyncProduceStrategy, AsyncProduceStrategy
from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StalenessBackend
from xtuner.v1.data_proto.rl_data import RolloutState, Status


class MockRolloutState:
    """Minimal stand-in for RolloutState carrying only the fields the producer reads."""

    def __init__(self, id, seq_staleness=1, status=Status.COMPLETED):
        self.id = id
        self.status = status
        self.seq_staleness = seq_staleness


class TestProducer(unittest.IsolatedAsyncioTestCase):

    def setUp(self):
        # 1. Mock DataloaderConfig / dataloader: each next() yields a
        #    single-element batch [MockRolloutState(i)].
        self.mock_dataloader_cfg = MagicMock()
        self.mock_dataloader = MagicMock()
        self.mock_dataloader.__iter__.return_value = iter([[MockRolloutState(i)] for i in range(100)])
        self.mock_dataloader_cfg.build.return_value = self.mock_dataloader

        # 2. Mock tokenizer — the producer never calls it in these tests.
        self.mock_tokenizer = MagicMock()

        # 3. Real ReplayBuffer backed by a staleness-bucketed storage.
        self.backend = StalenessBackend(limit=10, max_staleness=5)
        self.replay_buffer = ReplayBuffer(storage_backend=self.backend)

    async def test_sampler_with_replay_buffer(self):
        task_name = "test_task"
        sampler = SamplerWithReplayBuffer(self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer)

        # Case A: replay buffer empty -> falls back to the dataloader.
        data = await sampler.sample(task_name)
        self.assertEqual(data[0].id, 0)

        # Case B: an ABORTED item in the buffer takes priority over the dataloader.
        aborted_item = MockRolloutState(999, status=Status.ABORTED)
        await self.replay_buffer.put([aborted_item], task_name)

        data = await sampler.sample(task_name)
        self.assertEqual(data[0].id, 999)

    async def test_sync_produce_strategy(self):
        mock_agent_loop = MagicMock()

        async def mock_gen(rs):
            for r in rs:
                r.status = Status.COMPLETED
            return rs

        mock_agent_loop.generate_group = mock_gen

        sampler = Sampler(self.mock_dataloader_cfg, self.mock_tokenizer)
        strategy = SyncProduceStrategy(self.replay_buffer)

        # Produce a batch of size 2.
        await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, task_name="test_task")

        # The replay buffer should now hold exactly 2 COMPLETED groups.
        final_data = await self.replay_buffer.get(10, "test_task", Status.COMPLETED)
        self.assertEqual(len(final_data), 2)
        self.assertEqual(final_data[0].id, 0)
        self.assertEqual(final_data[1].id, 1)

    async def test_async_produce_strategy(self):
        # Verifies the over-subscription (staleness) logic and the
        # staleness-ordered get; partial_rollout / tail_batch behaviour of the
        # async strategy is intentionally not covered here.
        mock_agent_loop = MagicMock()
        task_name = "test_task"
        call_count = 0

        async def mock_gen(rs):
            nonlocal call_count
            call_count += 1
            for r in rs:
                if r.id == 999:
                    r.seq_staleness = 5
                else:
                    r.seq_staleness = call_count
                r.status = Status.COMPLETED
                print(r.id, r.seq_staleness, r.status)
            return rs

        mock_agent_loop.generate_group = mock_gen

        sampler = SamplerWithReplayBuffer(self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer)
        strategy = AsyncProduceStrategy(self.replay_buffer, staleness_threshold=1)
        # Pre-seed one ABORTED group so it is resumed first.
        aborted_item = MockRolloutState(999, status=Status.ABORTED)
        await self.replay_buffer.put([aborted_item], task_name)
        # Execute.
        await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, task_name=task_name)

        # Expect 4 COMPLETED groups ordered by freshness: 999 is the stalest,
        # 0 the freshest.
        # NOTE(@duanyanhui): pausing is not implemented yet, so all 4 in-flight
        # groups run to completion.
        final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED)
        self.assertEqual(len(final_data), 4)
        self.assertEqual(final_data[0].id, 999)
        self.assertEqual(final_data[1].id, 2)
        self.assertEqual(final_data[2].id, 1)
        self.assertEqual(final_data[3].id, 0)
- async def generate_group(self, rollout_state, prompt_repeat_k) -> List[RolloutState]: + async def generate_group(self, rollout_state: list[RolloutState]) -> list[RolloutState]: pending_tasks = [] - for _ in range(prompt_repeat_k): - rollout_state.sample_params = self.sample_params - task = asyncio.create_task(self.generate_sample(deepcopy(rollout_state))) + for state in rollout_state: + state.sample_params = self.sample_params + task = asyncio.create_task(self.generate_sample(state)) pending_tasks.append(task) generated_samples = asyncio.gather(*pending_tasks) group_samples = await generated_samples diff --git a/xtuner/v1/rl/base/producer.py b/xtuner/v1/rl/base/producer.py new file mode 100644 index 000000000..87a42787e --- /dev/null +++ b/xtuner/v1/rl/base/producer.py @@ -0,0 +1,198 @@ +import asyncio +from abc import ABC, abstractmethod +from typing import Iterator, Optional, Union + +from tqdm.auto import tqdm + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.datasets.config import DataloaderConfig + +from ..agent_loop import AgentLoop +from .replay_buffer import ReplayBuffer + + +# TODO: 用户把自己的数据集转换成rolloutstate的逻辑放在哪里? 
class Sampler:
    """Draws prompt groups from a dataloader, repeating each prompt ``prompt_repeat_k`` times.

    Args:
        dataloader_cfg: config whose ``build`` returns the dataloader.
        tokenizer: tokenizer instance or HF model path (loaded with
            ``trust_remote_code=True`` when a path is given).
        prompt_repeat_k: group size — how many rollouts per prompt.
    """

    def __init__(
        self,
        dataloader_cfg: DataloaderConfig,
        tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast],
        prompt_repeat_k: int = 1,
    ):
        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
        else:
            self.tokenizer = tokenizer
        self.dataloader = dataloader_cfg.build(
            tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1
        )
        # Created lazily on first sample so epoch rollover can rebuild it.
        self.dataloader_iter: Optional[Iterator] = None
        self.cur_epoch = 0
        self.prompt_repeat_k = prompt_repeat_k

    def _sample_from_dataloader(self) -> list[RolloutState]:
        """Pull the next prompt (rolling over to a new epoch on exhaustion)
        and expand it into a group of ``prompt_repeat_k`` states.
        """
        from copy import deepcopy  # local import: leaves the module's top imports untouched

        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)
        assert self.dataloader_iter is not None
        try:
            data = next(self.dataloader_iter)[0]
        except StopIteration:
            # Dataloader exhausted: advance the epoch and restart iteration.
            self.cur_epoch += 1
            self.dataloader.set_epoch(self.cur_epoch)
            self.dataloader_iter = iter(self.dataloader)
            data = next(self.dataloader_iter)[0]
        # BUG FIX: `[data] * k` repeated the SAME object k times; since
        # generate_group no longer deepcopies each state (the copy was removed
        # from AgentLoop), all k rollouts of a group would mutate one shared
        # RolloutState. Give every repeat beyond the first its own deep copy.
        return [data] + [deepcopy(data) for _ in range(self.prompt_repeat_k - 1)]

    async def sample(self, task_name: str) -> list[RolloutState]:
        """Return the next prompt group; ``task_name`` is unused here but kept
        for interface parity with SamplerWithReplayBuffer."""
        return self._sample_from_dataloader()


class SamplerWithReplayBuffer(Sampler):
    """Sampler that first drains ABORTED groups from the replay buffer,
    falling back to the dataloader only when none are pending."""

    def __init__(
        self,
        dataloader_cfg: DataloaderConfig,
        tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast],
        replay_buffer: ReplayBuffer,
        prompt_repeat_k: int = 1,
    ):
        super().__init__(dataloader_cfg, tokenizer, prompt_repeat_k)
        self.replay_buffer = replay_buffer

    # NOTE: the duplicated `_sample_from_dataloader` body that used to live
    # here was identical to the base class's — it is now simply inherited.

    async def sample(self, task_name: str) -> list[RolloutState]:
        data = await self.replay_buffer.get(1, task_name=task_name, group_status=Status.ABORTED)
        if len(data) == 0:
            data = self._sample_from_dataloader()
        return data


class ProduceStrategy(ABC):
    """Drives an AgentLoop + Sampler to fill a ReplayBuffer with one batch.

    NOTE: the dataloader is deliberately NOT a member of ProduceStrategy —
    a produce strategy is not bound to any particular dataloader.
    """

    def __init__(self, replay_buffer: ReplayBuffer):
        self.replay_buffer = replay_buffer

    @abstractmethod
    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str): ...


class SyncProduceStrategy(ProduceStrategy):
    """Produce exactly ``batch_size`` COMPLETED groups, keeping at most
    ``batch_size`` generation tasks in flight and topping up as tasks finish."""

    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str):
        pbar_refresh_step = max(1, int(batch_size * 0.1))  # fixed typo: was pbar_refrash_step
        pending_tasks = set()
        completed_sample_count = await self.replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED)
        assert completed_sample_count == 0, "SyncProduceStrategy assumes no completed samples at the start."
        with tqdm(total=batch_size, desc=f"Sync Producer [{task_name}]", miniters=pbar_refresh_step) as pbar:
            last_pbar_n = completed_sample_count
            pbar.update(last_pbar_n)
            for _ in range(batch_size):
                rollout_state = await sampler.sample(task_name=task_name)
                task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                pending_tasks.add(task)

            while completed_sample_count < batch_size:
                if not pending_tasks:
                    print("All tasks are done but not enough samples collected.")
                    break
                done_tasks, pending_tasks = await asyncio.wait(
                    pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED
                )
                # Filtering hook: groups rejected here would go to a separate
                # filtered pool instead of the replay buffer.
                for task in done_tasks:
                    try:
                        await self.replay_buffer.put(items=task.result(), task_name=task_name)
                    except Exception as e:
                        print(f"Error in generating trajectory: {e}")

                # Top up so that in-flight + completed covers the batch.
                if len(pending_tasks) + completed_sample_count < batch_size:
                    rollout_state = await sampler.sample(task_name=task_name)
                    task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                    pending_tasks.add(task)

                completed_sample_count = await self.replay_buffer.count(
                    task_name=task_name, group_status=Status.COMPLETED
                )
                pbar.update(completed_sample_count - last_pbar_n)
                last_pbar_n = completed_sample_count


class AsyncProduceStrategy(ProduceStrategy):
    """Over-subscribed producer: launches ``(1 + staleness_threshold) * batch_size``
    groups so that surplus completions pre-fill the next step's batch.

    Args:
        staleness_threshold: fraction of extra concurrent groups to launch.
        enable_partial_rollout: reserved — partial rollout is not implemented yet.
        tail_batch_trigger_size / tail_batch_candidate_step: reserved for
            tail-batch handling; currently unused.
    """

    def __init__(
        self,
        replay_buffer: ReplayBuffer,
        staleness_threshold: float = 0.0,
        enable_partial_rollout: bool = False,
        tail_batch_trigger_size: int = 0,
        tail_batch_candidate_step: int = 0,
    ):
        super().__init__(replay_buffer)
        self.staleness_threshold = staleness_threshold
        self.enable_partial_rollout = enable_partial_rollout
        self.tail_batch_trigger_size = tail_batch_trigger_size
        self.tail_batch_candidate_step = tail_batch_candidate_step

    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str):
        # Over-subscription: keep this many groups in flight at once.
        data_concurrency = int((1 + self.staleness_threshold) * batch_size)
        pbar_refresh_step = max(1, int(data_concurrency * 0.1))  # fixed typo: was pbar_refrash_step
        pending_tasks = set()
        # Completions left over from a previous step count toward this batch.
        init_completed_sample_count = await self.replay_buffer.count(
            task_name=task_name, group_status=Status.COMPLETED
        )
        with tqdm(total=batch_size, desc=f"ASync Producer [{task_name}]", miniters=pbar_refresh_step) as pbar:
            last_pbar_n = init_completed_sample_count
            pbar.update(last_pbar_n)
            for _ in range(data_concurrency):
                rollout_state = await sampler.sample(task_name=task_name)
                task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                pending_tasks.add(task)

            completed_sample_count = init_completed_sample_count
            while completed_sample_count < batch_size:
                if not pending_tasks:
                    print("All tasks are done but not enough samples collected.")
                    break
                done_tasks, pending_tasks = await asyncio.wait(
                    pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED
                )

                # Filtering hook: groups rejected here would go to a separate
                # filtered pool instead of the replay buffer.
                for task in done_tasks:
                    try:
                        await self.replay_buffer.put(items=task.result(), task_name=task_name)
                    except Exception as e:
                        print(f"Error in generating trajectory: {e}")

                print(f"Completed sample count: {completed_sample_count}, Pending task count: {len(pending_tasks)}")
                completed_sample_count = await self.replay_buffer.count(
                    task_name=task_name, group_status=Status.COMPLETED
                )
                pbar.update(completed_sample_count - last_pbar_n)
                last_pbar_n = completed_sample_count
                # Keep in-flight + newly-completed at the over-subscribed level.
                if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count:
                    rollout_state = await sampler.sample(task_name=task_name)
                    task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                    pending_tasks.add(task)

            # Drain stragglers: ask the agent loop to pause, then wait them out.
            if len(pending_tasks) > 0:
                await agent_loop.pause()
            while len(pending_tasks) > 0:
                _, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
                if len(pending_tasks) > 0:
                    await agent_loop.pause()
                    await asyncio.sleep(1)
            print("All worker tasks have completed after pausing env controller.")
# --- DictStorage.count (hunk @@ -50,6 +52,10 @@) ----------------------------
# BUG FIX: the Storage ABC declares `async def count`, and ReplayBuffer.count
# awaits `self._storage.count(indices)`. The original added a *synchronous*
# `def count` here, so awaiting its int result raises
# `TypeError: object int can't be used in 'await' expression` at runtime.
async def count(self, storage_indices: StorageIndices) -> int:
    """Number of items stored under *storage_indices*."""
    indices = self._hash_storage_indices(storage_indices)
    # NOTE(review): assumes `indices` already exists in self._storage (e.g. a
    # defaultdict); otherwise this raises KeyError — confirm against __init__.
    return len(self._storage[indices])


# --- StalenessBackend.count (hunk @@ -204,6 +210,13 @@) ---------------------
async def count(self, storage_indices: StorageIndices) -> int:
    """Total items for *storage_indices* summed across all staleness buckets."""
    indices = self._hash_storage_indices(storage_indices)
    return sum(
        len(self._storage[indices][s])
        for s in range(self.min_staleness, self.max_staleness + 1)
    )


# --- ReplayBuffer.count (hunk @@ -223,3 +236,8 @@) --------------------------
async def count(self, task_name: str, group_status: Status, **kwargs) -> int:
    """Lock-protected count of groups matching (task_name, group_status, tags)."""
    indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
    async with self._lock:
        return await self._storage.count(indices)