Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 0 additions & 120 deletions tests/rl/test_agent_loop_utils.py

This file was deleted.

26 changes: 26 additions & 0 deletions tests/rl/test_multi_task_agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,32 @@ async def test_get_batch_refreshes_staleness_at_entry(self):
self.assertEqual(manager._produce_progress.next_consumer_step, 10)
self.assertEqual(manager._produce_progress.consumed_samples["task_a"], 1)

async def test_get_batch_returns_raw_reward_stats_from_progress(self):
    """Check that get_batch surfaces raw-reward totals recorded on the progress.

    Raw rewards (sum 1.25 over 2 samples) are seeded on the manager's
    produce progress before the call. The returned batch must report
    exactly those totals, and a follow-up ``consume_raw_rewards`` must
    yield ``(0.0, 0)``.
    """
    task = "task_a"

    # One completed rollout group is available for the task.
    buffer = _FakeReplayBuffer(
        rollout_states_by_task={task: [[_FakeRolloutState("a-0", 0.2)]]},
        leftover_counts={(task, Status.COMPLETED): 1},
    )
    runner = _TaskRunner(
        task_name=task,
        agent_loop=_fake_agent_loop(),
        produce_strategy=_FakeProduceStrategy(),
        sampler=_FakeSampler(),
        weight=1.0,
        order=0,
    )
    manager = AgentLoopManager(task_runners=[runner], replay_buffer=buffer)

    # Seed the raw-reward accumulator directly, bypassing production.
    manager._produce_progress.add_raw_rewards(task, 1.25, 2)

    result = await manager.get_batch(batch_size=1, train_step=9)

    self.assertEqual((result.raw_rewards_sum, result.raw_rewards_count), (1.25, 2))
    # Nothing is left to consume once get_batch has returned.
    leftover = manager._produce_progress.consume_raw_rewards(task)
    self.assertEqual(leftover, (0.0, 0))

async def test_get_batch_waits_until_requested_batch_size_is_ready(self):
replay_buffer = _SequencedCompletedReplayBuffer(
completed_counts=[0, 1, 2],
Expand Down
54 changes: 43 additions & 11 deletions tests/rl/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@


class MockRolloutState:
    """Lightweight stand-in for a rollout state used by the producer tests.

    Args:
        id: Identifier stored on both ``id`` and ``uid``.
        seq_staleness: Value stored on ``seq_staleness``.
        status: Value stored on ``status``; defaults to ``Status.COMPLETED``.
        reward_score: Optional score. When given, ``reward`` becomes
            ``{"score": reward_score}``; otherwise ``reward`` is ``None``.
    """

    # Only the updated signature is kept: the body reads ``reward_score``,
    # which the older three-argument header did not declare.
    def __init__(self, id, seq_staleness=1, status=Status.COMPLETED, reward_score=None):
        self.id = id
        self.uid = id
        self.status = status
        self.seq_staleness = seq_staleness
        self.response_ids = []
        self.extra_fields = {}
        # Compare against None so that a score of 0 still yields a reward dict.
        self.reward = {"score": reward_score} if reward_score is not None else None


class TestProducer(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -51,14 +52,18 @@ def _build_progress(
target: int,
train_step: int = 0,
consumed: int = 0,
producer_future_step: int | None = None,
target_upto_future_step: int | None = None,
) -> ProduceProgress:
return ProduceProgress(
next_consumer_step=train_step,
producer_future_step=train_step,
consumed_samples={task_name: consumed},
target_samples={task_name: target},
target_upto_future_step=train_step,
progress = ProduceProgress.build([task_name])
progress.next_consumer_step = train_step
progress.producer_future_step = producer_future_step if producer_future_step is not None else train_step
progress.consumed_samples[task_name] = consumed
progress.target_samples[task_name] = target
progress.target_upto_future_step = (
target_upto_future_step if target_upto_future_step is not None else train_step
)
return progress

def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None):
mock_agent_loop = MagicMock()
Expand Down Expand Up @@ -276,6 +281,32 @@ def is_valid_sample_fn(samples):
self.assertEqual(await self.replay_buffer.count(task_name, Status.FILTERED), 1)
self.assertEqual(await self.replay_buffer.count(task_name, Status.ABORTED), 1)

async def test_put_generated_group_records_raw_rewards_before_filtering(self):
    """Check that raw rewards are recorded even when the group is filtered.

    The strategy's validity callback rejects every group, so
    ``put_generated_group`` must return a falsy value and mark both samples
    ``FILTERED``. The progress must still hold the raw scores
    (0.25 + 0.75 over 2 samples), and reading them once must drain them.
    """
    task_name = "test_raw_reward_before_filter"

    def _reject_every_group(samples):
        return False

    strategy = SyncProduceStrategyConfig(is_valid_sample_fn=_reject_every_group).build()
    agent_loop = self._build_agent_loop()
    sampler = self._build_sampler()
    ctx = self._build_context(strategy, task_name, agent_loop, sampler, batch_size=1)

    completed_group = [
        MockRolloutState(sample_id, status=Status.COMPLETED, reward_score=score)
        for sample_id, score in ((1, 0.25), (2, 0.75))
    ]
    accepted = await ctx.put_generated_group(completed_group)
    self.assertFalse(accepted)

    statuses = [item.status for item in completed_group]
    self.assertEqual(statuses, [Status.FILTERED, Status.FILTERED])

    # First read returns the accumulated totals; second read finds them drained.
    self.assertEqual(ctx.progress.consume_raw_rewards(task_name), (1.0, 2))
    self.assertEqual(ctx.progress.consume_raw_rewards(task_name), (0.0, 0))

    filtered_count = await self.replay_buffer.count(task_name, Status.FILTERED)
    self.assertEqual(filtered_count, 1)

async def test_sync_produce_strategy(self):
task_name = "test_task"
mock_agent_loop = self._build_agent_loop({0: 0.0, 1: 0.01})
Expand Down Expand Up @@ -389,11 +420,12 @@ async def mock_gen(rs, **kwargs):
sampler = self._build_sampler()
# 该用例验证版本记录顺序,放宽 stale 策略避免在生产入口提前返回。
strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0, max_staleness=3).build()
progress = ProduceProgress(
next_consumer_step=1,
progress = self._build_progress(
task_name,
target=2,
train_step=1,
consumed=1,
producer_future_step=2,
consumed_samples={task_name: 1},
target_samples={task_name: 2},
target_upto_future_step=2,
)

Expand Down
9 changes: 0 additions & 9 deletions xtuner/v1/rl/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from xtuner.v1.rl.utils import create_task

from .agent_loop import AgentLoop, AgentLoopConfig
from .utils import PartialRolloutHandler


class SingleTurnAgentLoopConfig(AgentLoopConfig):
Expand Down Expand Up @@ -65,26 +64,18 @@ def __init__(
enable_batch_judge: bool = False,
):
super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger)
self.max_tokens = self.sample_params.max_tokens
self.partial_rollout_handler = PartialRolloutHandler(max_tokens=self.max_tokens)
self.enable_batch_judge = enable_batch_judge

async def generate_sample(
self,
rollout_state: RolloutState,
**kwargs,
) -> RolloutState:
enable_partial_rollout = kwargs.get("enable_partial_rollout", False)

# rollout state 预处理, enable_partial_rollout = True 会在这里拼接 token 和修正 max_token
rollout_state = self.partial_rollout_handler.preprocess(rollout_state, enable_partial_rollout)
if not rollout_state.tokens:
rollout_state.tokens = rollout_state.prompt_ids

# 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上
rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined]
# rollout state 后处理: 合并 partial rollout 的历史上下文
rollout_state = self.partial_rollout_handler.postprocess(rollout_state)
# 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分
if rollout_state.status != Status.COMPLETED:
return rollout_state
Expand Down
102 changes: 0 additions & 102 deletions xtuner/v1/rl/agent_loop/utils.py

This file was deleted.

Loading
Loading