Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions tests/ray/test_producer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import unittest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
from xtuner.v1.rl.base.producer import Sampler, SamplerWithReplayBuffer, SyncProduceStrategy, AsyncProduceStrategy
from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StalenessBackend
from xtuner.v1.data_proto.rl_data import RolloutState, Status

class MockRolloutState:
    """Lightweight stand-in for RolloutState.

    Carries only the three attributes the producer code path touches:
    ``id``, ``status`` and ``seq_staleness``.
    """

    # NOTE: the `id` parameter shadows the builtin id(); acceptable in a test double.
    def __init__(self, id, seq_staleness=1, status=Status.COMPLETED):
        self.id = id
        self.status = status
        self.seq_staleness = seq_staleness

class TestProducer(unittest.IsolatedAsyncioTestCase):
    """Unit tests for the samplers and the sync/async produce strategies."""

    def setUp(self):
        # 1. Mock DataloaderConfig and the Dataloader it builds.
        self.mock_dataloader_cfg = MagicMock()
        self.mock_dataloader = MagicMock()
        # next(dataloader_iter) yields [RolloutState]; ids run 0..99.
        self.mock_dataloader.__iter__.return_value = iter([[MockRolloutState(i)] for i in range(100)])
        self.mock_dataloader_cfg.build.return_value = self.mock_dataloader

        # 2. Mock the tokenizer (never used for actual encoding in these tests).
        self.mock_tokenizer = MagicMock()

        # 3. A real ReplayBuffer backed by a staleness-bucketed backend.
        self.backend = StalenessBackend(limit=10, max_staleness=5)
        self.replay_buffer = ReplayBuffer(storage_backend=self.backend)

    async def test_sampler_with_replay_buffer(self):
        """sample() prefers ABORTED groups in the buffer over fresh dataloader items."""
        task_name = "test_task"
        sampler = SamplerWithReplayBuffer(self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer)

        # Case A: replay buffer is empty -> draw from the dataloader.
        data = await sampler.sample(task_name)
        self.assertEqual(data[0].id, 0)

        # Case B: replay buffer holds ABORTED data -> it is returned first.
        aborted_item = MockRolloutState(999, status=Status.ABORTED)
        await self.replay_buffer.put([aborted_item], task_name)

        data = await sampler.sample(task_name)
        self.assertEqual(data[0].id, 999)

    async def test_sync_produce_strategy(self):
        """SyncProduceStrategy fills the buffer with exactly batch_size COMPLETED groups."""
        mock_agent_loop = MagicMock()
        async def mock_gen(rs):
            # Fake generation: mark every state in the group COMPLETED.
            for r in rs:
                r.status = Status.COMPLETED
            return rs
        mock_agent_loop.generate_group = mock_gen

        sampler = Sampler(self.mock_dataloader_cfg, self.mock_tokenizer)
        strategy = SyncProduceStrategy(self.replay_buffer)

        # Act: produce a batch of size 2.
        await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, task_name="test_task")

        # Assert: the ReplayBuffer should hold exactly 2 COMPLETED groups, in order.
        final_data = await self.replay_buffer.get(10, "test_task", Status.COMPLETED)
        self.assertEqual(len(final_data), 2)
        self.assertEqual(final_data[0].id, 0)
        self.assertEqual(final_data[1].id, 1)

    async def test_async_produce_strategy(self):
        # This test verifies the over-subscription logic plus staleness-priority get.
        # Other async features (partial_rollout, tail_batch) are not verified here.
        mock_agent_loop = MagicMock()
        task_name = "test_task"
        call_count = 0
        async def mock_gen(rs):
            nonlocal call_count
            call_count += 1
            for r in rs:
                # The resumed item (999) is forced to be the stalest; fresh items
                # get increasing staleness per generation call.
                if r.id == 999:
                    r.seq_staleness = 5
                else:
                    r.seq_staleness = call_count
                r.status = Status.COMPLETED
                print(r.id, r.seq_staleness, r.status)
            return rs

        mock_agent_loop.generate_group = mock_gen

        sampler = SamplerWithReplayBuffer(self.mock_dataloader_cfg, self.mock_tokenizer, self.replay_buffer)
        strategy = AsyncProduceStrategy(self.replay_buffer, staleness_threshold = 1)
        # Seed the buffer with one ABORTED item so the sampler resumes it first.
        aborted_item = MockRolloutState(999, status=Status.ABORTED)
        await self.replay_buffer.put([aborted_item], task_name)
        # Act.
        await strategy.produce_batch(mock_agent_loop, sampler, batch_size=2, task_name=task_name)

        # Assert: the ReplayBuffer should hold 4 COMPLETED groups.
        # NOTE(@duanyanhui): pausing is not implemented yet, so all 4 finish; they
        # come back ordered by staleness — 999 is the stalest, 0 the freshest.
        final_data = await self.replay_buffer.get(10, task_name, Status.COMPLETED)
        self.assertEqual(len(final_data), 4)
        self.assertEqual(final_data[0].id, 999)
        self.assertEqual(final_data[1].id, 2)
        self.assertEqual(final_data[2].id, 1)
        self.assertEqual(final_data[3].id, 0)
11 changes: 5 additions & 6 deletions xtuner/v1/rl/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import abc
import asyncio
from abc import ABC
from copy import deepcopy
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable

import ray

Expand All @@ -29,11 +28,11 @@ def __init__(
@abc.abstractmethod
async def generate_sample(self, rollout_state: RolloutState) -> RolloutState: ...

async def generate_group(self, rollout_state, prompt_repeat_k) -> List[RolloutState]:
async def generate_group(self, rollout_state: list[RolloutState]) -> list[RolloutState]:
pending_tasks = []
for _ in range(prompt_repeat_k):
rollout_state.sample_params = self.sample_params
task = asyncio.create_task(self.generate_sample(deepcopy(rollout_state)))
for state in rollout_state:
state.sample_params = self.sample_params
task = asyncio.create_task(self.generate_sample(state))
pending_tasks.append(task)
generated_samples = asyncio.gather(*pending_tasks)
group_samples = await generated_samples
Expand Down
198 changes: 198 additions & 0 deletions xtuner/v1/rl/base/producer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import asyncio
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Iterator, Optional, Union

from tqdm.auto import tqdm

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.datasets.config import DataloaderConfig

from ..agent_loop import AgentLoop
from .replay_buffer import ReplayBuffer


# TODO: where should users put the logic that converts their own dataset into RolloutState?
class Sampler:
    """Draws prompt groups from a dataloader, restarting it across epochs.

    Each call to :meth:`sample` returns ``prompt_repeat_k`` independent copies
    of the next dataloader item, so downstream mutation of one group member
    (status, sampled tokens, staleness, ...) cannot leak into its siblings.
    """

    def __init__(
        self,
        dataloader_cfg: DataloaderConfig,
        tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast],
        prompt_repeat_k: int = 1,
    ):
        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
        if isinstance(tokenizer, str):
            # A model path / repo id was given; load the tokenizer from it.
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
        else:
            self.tokenizer = tokenizer
        self.dataloader = dataloader_cfg.build(
            tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1
        )
        self.dataloader_iter: Optional[Iterator] = None
        self.cur_epoch = 0
        self.prompt_repeat_k = prompt_repeat_k

    def _next_item(self):
        """Return the next dataloader item, rolling over to a new epoch on exhaustion."""
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)
        try:
            return next(self.dataloader_iter)[0]
        except StopIteration:
            # Dataloader exhausted: advance the epoch and restart iteration.
            self.cur_epoch += 1
            self.dataloader.set_epoch(self.cur_epoch)
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)[0]

    async def sample(self, task_name: str) -> list[RolloutState]:
        """Return a group of ``prompt_repeat_k`` deep-copied rollout states.

        BUGFIX: the previous ``[data] * prompt_repeat_k`` repeated *references*
        to one object, so every member of the group aliased the same state
        (the pre-refactor ``generate_group`` deep-copied per repeat; that copy
        was lost in the refactor). Deep-copying each member restores
        independence for ``prompt_repeat_k > 1``.
        """
        data = self._next_item()
        return [deepcopy(data) for _ in range(self.prompt_repeat_k)]


class SamplerWithReplayBuffer(Sampler):
    """Sampler that prefers resuming ABORTED groups from a replay buffer.

    ``sample`` first asks the replay buffer for an aborted group under
    ``task_name``; only when none is available does it fall back to drawing a
    fresh group from the dataloader.
    """

    def __init__(
        self,
        dataloader_cfg: DataloaderConfig,
        tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast],
        replay_buffer: ReplayBuffer,
        prompt_repeat_k: int = 1,
    ):
        super().__init__(dataloader_cfg, tokenizer, prompt_repeat_k)
        self.replay_buffer = replay_buffer

    def _sample_from_dataloader(self) -> list[RolloutState]:
        """Draw the next item and expand it into ``prompt_repeat_k`` independent copies.

        BUGFIX: ``[data] * prompt_repeat_k`` aliased one object k times; group
        members must be independently mutable (the pre-refactor
        ``generate_group`` deep-copied each repeat), so deep-copy per member.
        """
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)
        try:
            data = next(self.dataloader_iter)[0]
        except StopIteration:
            # Dataloader exhausted: advance the epoch and restart iteration.
            self.cur_epoch += 1
            self.dataloader.set_epoch(self.cur_epoch)
            self.dataloader_iter = iter(self.dataloader)
            data = next(self.dataloader_iter)[0]
        return [deepcopy(data) for _ in range(self.prompt_repeat_k)]

    async def sample(self, task_name: str) -> list[RolloutState]:
        """Return an ABORTED group from the replay buffer if any, else a fresh group."""
        data = await self.replay_buffer.get(1, task_name=task_name, group_status=Status.ABORTED)
        if len(data) == 0:
            data = self._sample_from_dataloader()
        return data


class ProduceStrategy(ABC):
    """Strategy interface for filling a replay buffer with generated sample groups.

    NOTE: the dataloader is deliberately NOT a member of ProduceStrategy —
    a produce strategy is not bound to any particular dataloader; callers pass
    a Sampler to each ``produce_batch`` call instead.
    """

    def __init__(self, replay_buffer: ReplayBuffer):
        self.replay_buffer = replay_buffer

    @abstractmethod
    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str): ...


class SyncProduceStrategy(ProduceStrategy):
    """Synchronous producer: launches ``batch_size`` generation groups and waits
    until the replay buffer holds exactly ``batch_size`` COMPLETED groups,
    topping up in-flight work as groups finish or fail."""

    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str):
        # Refresh the progress bar roughly every 10% of the batch.
        pbar_refrash_step = max(1, int(batch_size * 0.1))  # NOTE(review): typo, should be "refresh"
        pending_tasks = set()
        completed_sample_count = await self.replay_buffer.count(task_name=task_name, group_status=Status.COMPLETED)
        assert completed_sample_count == 0, "SyncProduceStrategy assumes no completed samples at the start."
        with tqdm(total=batch_size, desc=f"Sync Producer [{task_name}]", miniters=pbar_refrash_step) as pbar:
            last_pbar_n = completed_sample_count
            pbar.update(last_pbar_n)
            # Kick off one generation task per requested group.
            for _ in range(batch_size):
                rollout_state = await sampler.sample(task_name=task_name)
                task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                pending_tasks.add(task)

            while completed_sample_count < batch_size:
                if not pending_tasks:
                    print("All tasks are done but not enough samples collected.")
                    break
                done_tasks, pending_tasks = await asyncio.wait(
                    pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED
                )
                # If filtering is needed, handle it here before putting into the replay
                # buffer; filtered-out data would go to the put_to_filtered pool instead.
                for task in done_tasks:
                    try:
                        await self.replay_buffer.put(items=task.result(), task_name=task_name)
                    except Exception as e:
                        # A failed group is dropped; the top-up below replaces it.
                        print(f"Error in generating trajectory: {e}")

                # Top up so that (in-flight + completed) stays at batch_size.
                if len(pending_tasks) + completed_sample_count < batch_size:
                    rollout_state = await sampler.sample(task_name=task_name)
                    task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                    pending_tasks.add(task)

                completed_sample_count = await self.replay_buffer.count(
                    task_name=task_name, group_status=Status.COMPLETED
                )
                pbar.update(completed_sample_count - last_pbar_n)
                last_pbar_n = completed_sample_count


class AsyncProduceStrategy(ProduceStrategy):
    """Asynchronous producer: over-subscribes generation by ``staleness_threshold``
    and, once the batch target is reached, pauses the agent loop and drains the
    remaining in-flight tasks."""

    def __init__(
        self,
        replay_buffer: ReplayBuffer,
        staleness_threshold: float = 0.0,
        enable_partial_rollout: bool = False,
        tail_batch_trigger_size: int = 0,
        tail_batch_candidate_step: int = 0,
    ):
        super().__init__(replay_buffer)
        # Over-subscription factor: concurrency = (1 + threshold) * batch_size.
        self.staleness_threshold = staleness_threshold
        # NOTE(review): the three knobs below are stored but not yet consumed by
        # produce_batch — presumably reserved for partial-rollout / tail-batch support.
        self.enable_partial_rollout = enable_partial_rollout
        self.tail_batch_trigger_size = tail_batch_trigger_size
        self.tail_batch_candidate_step = tail_batch_candidate_step

    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str):
        """Produce until the buffer holds ``batch_size`` COMPLETED groups beyond its
        initial content, keeping up to ``(1 + staleness_threshold) * batch_size``
        groups in flight; then pause the agent loop and drain leftover tasks."""
        data_concurrency = int((1 + self.staleness_threshold) * batch_size)
        pbar_refrash_step = max(1, int(data_concurrency * 0.1))  # NOTE(review): typo, should be "refresh"
        pending_tasks = set()
        init_completed_sample_count = await self.replay_buffer.count(
            task_name=task_name, group_status=Status.COMPLETED
        )
        with tqdm(total=batch_size, desc=f"ASync Producer [{task_name}]", miniters=pbar_refrash_step) as pbar:
            last_pbar_n = init_completed_sample_count
            pbar.update(last_pbar_n)
            # Launch the over-subscribed initial wave of generation tasks.
            for _ in range(data_concurrency):
                rollout_state = await sampler.sample(task_name=task_name)
                task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                pending_tasks.add(task)

            completed_sample_count = init_completed_sample_count
            while completed_sample_count < batch_size:
                if not pending_tasks:
                    print("All tasks are done but not enough samples collected.")
                    break
                done_tasks, pending_tasks = await asyncio.wait(
                    pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED
                )

                # If filtering is needed, handle it here before putting into the replay
                # buffer; filtered-out data would go to the put_to_filtered pool instead.
                for task in done_tasks:
                    try:
                        await self.replay_buffer.put(items=task.result(), task_name=task_name)
                    except Exception as e:
                        print(f"Error in generating trajectory: {e}")

                # NOTE(review): this prints the previous iteration's count; it is
                # refreshed on the next statement.
                print(f"Completed sample count: {completed_sample_count}, Pending task count: {len(pending_tasks)}")
                completed_sample_count = await self.replay_buffer.count(
                    task_name=task_name, group_status=Status.COMPLETED
                )
                pbar.update(completed_sample_count - last_pbar_n)
                last_pbar_n = completed_sample_count
                # Top up so in-flight + newly-completed stays at data_concurrency.
                if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count:
                    rollout_state = await sampler.sample(task_name=task_name)
                    task = asyncio.create_task(agent_loop.generate_group(rollout_state))
                    pending_tasks.add(task)

        if len(pending_tasks) > 0:
            # Batch target reached: pause the agent loop, then drain what is still in flight.
            await agent_loop.pause()
            while len(pending_tasks) > 0:
                _, pending_tasks = await asyncio.wait(pending_tasks, timeout=1, return_when=asyncio.FIRST_COMPLETED)
                if len(pending_tasks) > 0:
                    await agent_loop.pause()
                    await asyncio.sleep(1)
            print("All worker tasks have completed after pausing env controller.")
18 changes: 18 additions & 0 deletions xtuner/v1/rl/base/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ async def put(self, items: list[RolloutState], storage_indices: StorageIndices):
@abstractmethod
async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ...
@abstractmethod
async def count(self, storage_indices: StorageIndices) -> int: ...
@abstractmethod
def __len__(self): ...


Expand Down Expand Up @@ -50,6 +52,10 @@ async def get(self, count: int, storage_indices: StorageIndices) -> list[Rollout
    def __len__(self):
        # Total number of stored items across all index buckets.
        return sum(len(v) for v in self._storage.values())

def count(self, storage_indices: StorageIndices) -> int:
indices = self._hash_storage_indices(storage_indices)
return len(self._storage[indices])


class PandasStorage(Storage):
def __init__(self, limit: int = 0):
Expand Down Expand Up @@ -204,6 +210,13 @@ async def get(self, count: int, storage_indices: StorageIndices) -> list[Rollout
needed -= take
return target_items

async def count(self, storage_indices: StorageIndices) -> int:
indices = self._hash_storage_indices(storage_indices)
total_len = 0
for s in range(self.min_staleness, self.max_staleness + 1):
total_len += len(self._storage[indices][s])
return total_len

def __len__(self):
return sum(count for count in self._bucket_counts.values())

Expand All @@ -223,3 +236,8 @@ async def get(self, batch_size: int, task_name: str, group_status: Status, **kwa
indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
async with self._lock:
return await self._storage.get(batch_size, indices)

    async def count(self, task_name: str, group_status: Status, **kwargs) -> int:
        """Return how many items the backing storage holds for (task, status, tags)."""
        indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
        # Serialize against concurrent put/get on the same storage.
        async with self._lock:
            return await self._storage.count(indices)
Loading