Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions tests/rl/test_rl_disaggregated_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
- checkpoint 保存发生在 fit 完成的 model_step 上,且 manager.save 为 async 调用。
- eval 在 producer 恢复前运行;update_weights 本身不直接 pause/continue rollout controller。
- sync/checkpoint/eval interval 必须是 sync_weights_interval 的整数倍,资源布局必须 fail fast。
- 前台训练 batch 阻塞时,后台 producer 仍能在事件循环中继续推进。
"""

import asyncio
import json
import tempfile
import threading
import unittest
from pathlib import Path
from types import SimpleNamespace
Expand Down Expand Up @@ -65,6 +67,22 @@ def shutdown(self):
self._finish_event.set()


class _TickingManager(_FakeManager):
    """Fake manager whose producer loop reports that it ran during training.

    The loop watches ``training_started`` and, once it is set, signals
    ``producer_ticked`` so a test can prove the producer coroutine was still
    being scheduled while the training call was in progress.

    Args:
        get_batch_results: Forwarded unchanged to ``_FakeManager.__init__``.
        training_started: Event the test sets when the training call begins.
        producer_ticked: Event this manager sets after observing
            ``training_started`` from inside ``produce_loop``.
    """

    def __init__(self, get_batch_results, training_started: threading.Event, producer_ticked: threading.Event):
        super().__init__(get_batch_results)
        self._training_started = training_started
        self._producer_ticked = producer_ticked

    async def produce_loop(self, batch_size: int):
        """Spin until the finish event is set, recording one tick seen during training.

        NOTE(review): ``self.calls`` and ``self._finish_event`` are presumably
        provided by ``_FakeManager`` -- confirm against the base class.
        """
        self.calls.append(("produce_loop_start", batch_size))
        while not self._finish_event.is_set():
            # Record the tick only once. This loop yields with sleep(0) and so
            # spins on every event-loop iteration; appending on each pass would
            # flood ``self.calls`` with identical entries for the rest of fit.
            if self._training_started.is_set() and not self._producer_ticked.is_set():
                self.calls.append("produce_loop_tick_during_training")
                self._producer_ticked.set()
            # Yield control so other coroutines on the loop can make progress.
            await asyncio.sleep(0)
        self.calls.append("produce_loop_exit")


class TestRLDisaggregatedTrainer(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -219,6 +237,33 @@ def test_fit_trains_non_empty_expired_batch_then_syncs_current_step(self):
self.assertIn(("continue_produce", 1), manager.calls)
self.assertEqual(trainer._cur_step, 1)

def test_fit_keeps_background_producer_running_while_training_blocks(self):
    # Verify that while the synchronous training batch blocks, the background
    # producer can still be scheduled and make progress.
    started = threading.Event()
    ticked = threading.Event()
    sample = SimpleNamespace(message_uid=1, uid=1)
    batch_result = ProduceBatchResult(rollout_states=[[sample]], status=ProduceBatchStatus.NORMAL)
    manager = _TickingManager([batch_result], started, ticked)

    trainer = self._make_trainer(manager)
    trainer._sync_weights_and_save = AsyncMock()

    def blocking_train_one_batch(*args, **kwargs):
        # Announce that training began, then block until the producer loop
        # reports a tick (or give up after one second).
        started.set()
        producer_ran = ticked.wait(timeout=1.0)
        if producer_ran:
            return self._minimal_train_info(training_samples=1, training_tokens=4)
        raise AssertionError("background producer did not run while training was blocked")

    trainer._train_one_batch = MagicMock(side_effect=blocking_train_one_batch)

    self._run_fit(trainer)

    # Exactly one batch was trained, the producer ticked during it, and the
    # trainer advanced by a single step.
    trainer._train_one_batch.assert_called_once()
    self.assertIn("produce_loop_tick_during_training", manager.calls)
    self.assertEqual(trainer._cur_step, 1)

def test_fit_runs_eval_before_reset_and_stops_producer(self):
# 验证 eval 在 producer 恢复前执行,避免生产侧提前抢占 rollout 资源。
# 确定性排序依赖 RolloutState 的 message_uid 和 uid,测试用轻量对象模拟即可。
Expand Down
6 changes: 5 additions & 1 deletion xtuner/v1/train/rl_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import math
import os
Expand Down Expand Up @@ -1623,7 +1624,10 @@ async def _fit(self):
"RLDisaggregatedTrainer expects get_batch() to return non-empty rollout_states "
"unless status is empty EXPIRED_BATCH."
)
train_log_info = self._train_one_batch(
# 非共卡训练要求后台 producer 在训练当前 batch 时继续推进;
# 同步训练路径放到线程里执行,避免 ray.get / 文件写入阻塞事件循环。
train_log_info = await asyncio.to_thread(
self._train_one_batch,
train_batch,
train_step,
step_timer_dict,
Expand Down
Loading