From 543e367436a3a7be031efb01a7e5dbb55323e614 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Thu, 7 May 2026 17:10:29 +0800 Subject: [PATCH 1/4] move partial rollout logic to worker --- .../rl/agent_loop/single_turn_agent_loop.py | 9 -- xtuner/v1/rl/agent_loop/utils.py | 102 ----------------- .../agent_loop_manager/agent_loop_manager.py | 5 +- xtuner/v1/rl/agent_loop_manager/producer.py | 31 ++++- xtuner/v1/rl/rollout/controller.py | 6 + xtuner/v1/rl/rollout/utils.py | 108 +++++++++++++++++- xtuner/v1/rl/rollout/worker.py | 64 +++++++---- xtuner/v1/rl/trainer/controller.py | 2 +- 8 files changed, 189 insertions(+), 138 deletions(-) delete mode 100644 xtuner/v1/rl/agent_loop/utils.py diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py index 3edac1190..0e99bea20 100644 --- a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py +++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py @@ -6,7 +6,6 @@ from xtuner.v1.rl.utils import create_task from .agent_loop import AgentLoop, AgentLoopConfig -from .utils import PartialRolloutHandler class SingleTurnAgentLoopConfig(AgentLoopConfig): @@ -34,8 +33,6 @@ def __init__( enable_batch_judge: bool = False, ): super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger) - self.max_tokens = self.sample_params.max_tokens - self.partial_rollout_handler = PartialRolloutHandler(max_tokens=self.max_tokens) self.enable_batch_judge = enable_batch_judge async def generate_sample( @@ -43,17 +40,11 @@ async def generate_sample( rollout_state: RolloutState, **kwargs, ) -> RolloutState: - enable_partial_rollout = kwargs.get("enable_partial_rollout", False) - - # rollout state 预处理, enable_partial_rollout = True 会在这里拼接 token 和修正 max_token - rollout_state = self.partial_rollout_handler.preprocess(rollout_state, enable_partial_rollout) if not rollout_state.tokens: rollout_state.tokens = rollout_state.prompt_ids # 推理引擎generate, 生成的结果会覆盖到 rollout_state.response_ids 上 rollout_state = await self.rollout_ctl.generate.remote(rollout_state) # type: ignore[attr-defined] - # rollout state 后处理: 合并 partial rollout 的历史上下文 - rollout_state = self.partial_rollout_handler.postprocess(rollout_state) # 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分 if rollout_state.status != Status.COMPLETED: return rollout_state diff --git a/xtuner/v1/rl/agent_loop/utils.py b/xtuner/v1/rl/agent_loop/utils.py deleted file mode 100644 index 6da4c6504..000000000 --- a/xtuner/v1/rl/agent_loop/utils.py +++ /dev/null @@ -1,102 +0,0 @@ -import time - -import ray - -from xtuner.v1.data_proto.rl_data import RolloutState, Status -from xtuner.v1.rl.utils import clear_rollout_response_for_rerun -from xtuner.v1.utils import get_logger - - -def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]: - if isinstance(routed_experts, ray.ObjectRef): - routed_experts = ray.get(routed_experts) - if hasattr(routed_experts, "tolist"): - routed_experts = routed_experts.tolist() - assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}" - return routed_experts - - -class PartialRolloutHandler: - """Handle preprocessing and postprocessing for partial rollout - continuation.""" - - def __init__(self, max_tokens: int) -> None: - self.logger = get_logger(self.__class__.__name__) - self.max_tokens = max_tokens - - def preprocess(self, rollout_state: RolloutState, enable_partial_rollout: bool = False) -> RolloutState: - if rollout_state.status == Status.EXPIRED or ( - not enable_partial_rollout and rollout_state.status == Status.ABORTED - ): - rollout_state = clear_rollout_response_for_rerun(rollout_state) - rollout_state.sample_params = rollout_state.sample_params.model_copy( - update={"max_tokens": self.max_tokens} - ) - rollout_state.response = "" - rollout_state.status = Status.INIT - - if not rollout_state.response_ids or rollout_state.status == Status.COMPLETED: - return rollout_state - - # Set up token and length variable - response_ids = rollout_state.response_ids - prompt_ids = list(rollout_state.prompt_ids or []) - response_len = len(response_ids) - prompt_len = len(prompt_ids) - - rollout_state.tokens = prompt_ids + response_ids # concatenate for partial rollout continuation - remaining_tokens = self.max_tokens - response_len # compute remaining max_tokens budget - rollout_state.sample_params = rollout_state.sample_params.copy(update={"max_tokens": remaining_tokens}) - - self.logger.debug( - f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}" - ) - # TODO: handle routed_experts - rollout_state.extra_fields["history_response_dict"] = { - "response_ids": rollout_state.tokens[prompt_len:] if rollout_state.tokens else [], - "response": rollout_state.response or "", - "logprobs": rollout_state.logprobs or [], - "response_mask": rollout_state.response_mask or [], - "routed_experts": rollout_state.routed_experts, - } - return rollout_state - - def postprocess(self, rollout_state: RolloutState) -> RolloutState: - # TODO: if not enable partial rollout, return directly? - - # Concatenate history response fields - history_dict = rollout_state.extra_fields.pop("history_response_dict", None) - if not history_dict: - return rollout_state - - rollout_state.response_ids = history_dict.get("response_ids", []) + (rollout_state.response_ids or []) - rollout_state.response = history_dict.get("response", "") + (rollout_state.response or "") - rollout_state.logprobs = history_dict.get("logprobs", []) + (rollout_state.logprobs or []) - rollout_state.response_mask = history_dict.get("response_mask", []) + (rollout_state.response_mask or []) - history_routed_experts_ref = history_dict.get("routed_experts") - cur_routed_experts_ref = rollout_state.routed_experts - if history_routed_experts_ref is not None and cur_routed_experts_ref is not None: - start_time = time.time() - history_routed_experts = _resolve_routed_experts(history_routed_experts_ref) - cur_routed_experts = _resolve_routed_experts(cur_routed_experts_ref) - cur_routed_experts_len = len(cur_routed_experts) - history_routed_experts_len = len(history_routed_experts) - assert history_routed_experts_len - 1 <= cur_routed_experts_len, ( - f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}" - ) - cur_routed_experts = cur_routed_experts[history_routed_experts_len:] - concat_routed_experts = history_routed_experts + cur_routed_experts - rollout_state.routed_experts = ray.put(concat_routed_experts) - # free_object_refs( - # [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)] - # ) - end_time = time.time() - self.logger.info( - f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds" - ) - elif history_routed_experts_ref is None and cur_routed_experts_ref is not None: - rollout_state.routed_experts = cur_routed_experts_ref - elif history_routed_experts_ref is not None and cur_routed_experts_ref is None: - rollout_state.routed_experts = history_routed_experts_ref - - return rollout_state diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py index 7df634796..3685caaa1 100644 --- a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py +++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py @@ -244,7 +244,10 @@ def build( judger=build_judger(task_cfg.judger_config) if task_cfg.judger_config is not None else None, logger=logger, ) - produce_strategy = task_cfg.produce_strategy_config.build(sync_weights_interval=sync_weights_interval) + produce_strategy = task_cfg.produce_strategy_config.build( + sync_weights_interval=sync_weights_interval, + rollout_controller=rollout_controller, + ) sampler = task_cfg.sampler_config.build(tokenizer=tokenizer, replay_buffer=replay_buffer) task_runners.append( _TaskRunner( diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py index 71e0ead04..55745930c 100644 --- a/xtuner/v1/rl/agent_loop_manager/producer.py +++ b/xtuner/v1/rl/agent_loop_manager/producer.py @@ -4,7 +4,11 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum, auto -from typing import Any, Awaitable, Callable, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from xtuner.v1.rl.rollout.controller import RolloutControllerProxy import ray from pydantic import BaseModel, ConfigDict, Field @@ -267,11 +271,21 @@ class ProduceStrategyConfig(ABC, BaseModel): should_continue_fn: ShouldContinueFn = default_should_continue_fn @abstractmethod - def build(self, *, sync_weights_interval: int = 1) -> "ProduceStrategy": ... + def build( + self, + *, + sync_weights_interval: int = 1, + rollout_controller: "Optional[RolloutControllerProxy]" = None, + ) -> "ProduceStrategy": ... class SyncProduceStrategyConfig(ProduceStrategyConfig): - def build(self, *, sync_weights_interval: int = 1) -> "SyncProduceStrategy": + def build( + self, + *, + sync_weights_interval: int = 1, + rollout_controller: "Optional[RolloutControllerProxy]" = None, + ) -> "SyncProduceStrategy": return SyncProduceStrategy( is_valid_sample_fn=self.is_valid_sample_fn, should_continue_fn=self.should_continue_fn ) @@ -283,7 +297,16 @@ class AsyncProduceStrategyConfig(ProduceStrategyConfig): max_staleness: int = Field(default=0, ge=0) tail_batch_trigger_size: int = 0 - def build(self, *, sync_weights_interval: int = 1) -> "AsyncProduceStrategy": + def build( + self, + *, + sync_weights_interval: int = 1, + rollout_controller: "Optional[RolloutControllerProxy]" = None, + ) -> "AsyncProduceStrategy": + if rollout_controller is not None: + import ray + + ray.get(rollout_controller.set_enable_partial_rollout.remote(self.enable_partial_rollout)) return AsyncProduceStrategy( over_sample_threshold=self.over_sample_threshold, enable_partial_rollout=self.enable_partial_rollout, diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index 3bbc3e979..dc6da1b50 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -231,6 +231,12 @@ def _apply_output_parsers(self, rollout_state: RolloutState) -> None: else: rollout_state.extra_fields.pop("reasoning_text", None) + def set_enable_partial_rollout(self, enable: bool) -> None: + """Propagate enable_partial_rollout flag to all active workers.""" + with self.worker_info_lock: + active_actors = [info.actor for info in self.rank2info.values() if info.is_active] + ray.get([actor.set_enable_partial_rollout.remote(enable) for actor in active_actors]) # type: ignore[attr-defined] + def pause_generation(self): self.health_checker.pause() diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index 2fd75da19..7c45f7a24 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -8,8 +8,10 @@ import httpx import ray +from ray import ObjectRef as RayObjectRef -from xtuner.v1.rl.utils import asyncio_run +from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.rl.utils import asyncio_run, clear_rollout_response_for_rerun from xtuner.v1.utils import get_logger @@ -277,3 +279,107 @@ async def check_worker_health( f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}" ) return False + + +def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]: + if isinstance(routed_experts, ray.ObjectRef): + routed_experts = ray.get(routed_experts) + if hasattr(routed_experts, "tolist"): + routed_experts = routed_experts.tolist() + assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}" + return routed_experts + + +class PartialRolloutHandler: + """Handle preprocessing and postprocessing for partial rollout + continuation.""" + + def __init__(self) -> None: + self.logger = get_logger(self.__class__.__name__) + + def preprocess( + self, rollout_state: RolloutState, max_tokens: int, enable_partial_rollout: bool = False + ) -> RolloutState: + if rollout_state.status == Status.EXPIRED or ( + not enable_partial_rollout and rollout_state.status == Status.ABORTED + ): + rollout_state = clear_rollout_response_for_rerun(rollout_state) + rollout_state.sample_params = rollout_state.sample_params.model_copy(update={"max_tokens": max_tokens}) + rollout_state.response = "" + rollout_state.status = Status.INIT + + if not rollout_state.response_ids or rollout_state.status == Status.COMPLETED: + return rollout_state + + # Set up token and length variable + response_ids = rollout_state.response_ids + prompt_ids = list(rollout_state.prompt_ids or []) + response_len = len(response_ids) + prompt_len = len(prompt_ids) + + rollout_state.tokens = prompt_ids + response_ids # concatenate for partial rollout continuation + remaining_tokens = max_tokens - response_len # compute remaining max_tokens budget + rollout_state.sample_params = rollout_state.sample_params.copy(update={"max_tokens": remaining_tokens}) + + self.logger.info( + f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}" + ) + return rollout_state + + def postprocess( + self, + rollout_state: RolloutState, + *, + response: str, + response_ids: list[int], + logprobs: list[float], + routed_experts: list[int] | RayObjectRef | None, + finish_reason: str, + status: Status, + enable_partial_rollout: bool = False, + ) -> RolloutState: + if not enable_partial_rollout: + rollout_state.response = response + rollout_state.response_ids = response_ids + rollout_state.logprobs = logprobs + rollout_state.routed_experts = routed_experts + rollout_state.status = status + rollout_state.finish_reason = finish_reason + return rollout_state + + else: + rollout_state.finish_reason = finish_reason + rollout_state.status = status + history_response = rollout_state.response or "" + history_response_ids = rollout_state.response_ids or [] + history_logprobs = rollout_state.logprobs or [] + rollout_state.response = history_response + response + rollout_state.response_ids = history_response_ids + response_ids + rollout_state.logprobs = history_logprobs + logprobs + + # 处理routed experts + history_routed_experts = rollout_state.routed_experts or None + if history_routed_experts is not None and routed_experts is not None: + start_time = time.time() + history_routed_experts = _resolve_routed_experts(history_routed_experts) + cur_routed_experts = _resolve_routed_experts(routed_experts) + cur_routed_experts_len = len(cur_routed_experts) + history_routed_experts_len = len(history_routed_experts) + assert history_routed_experts_len - 1 <= cur_routed_experts_len, ( + f"Existing routed_experts len: {history_routed_experts_len}, current routed_experts len: {cur_routed_experts_len}, history_response_ids len: {len(history_response_ids)}, current response_ids len: {len(response_ids)}" + ) + cur_routed_experts = cur_routed_experts[history_routed_experts_len:] + concat_routed_experts = history_routed_experts + cur_routed_experts + rollout_state.routed_experts = ray.put(concat_routed_experts) + # free_object_refs( + # [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)] + # ) + end_time = time.time() + self.logger.info( + f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds" + ) + elif history_routed_experts is None and routed_experts is not None: + rollout_state.routed_experts = routed_experts + elif history_routed_experts is not None and routed_experts is None: + rollout_state.routed_experts = history_routed_experts + return rollout_state diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index dc101d71c..71e83d314 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -8,7 +8,7 @@ import traceback from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union, cast import httpx import ray @@ -30,6 +30,8 @@ from xtuner.v1.utils import get_logger from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult +from .utils import PartialRolloutHandler + if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -472,6 +474,11 @@ def __init__( self.abort_timeout = 5.0 self.dist_init_addr: str = "" self.serverl_url: str = "" + self.partial_rollout_handler = PartialRolloutHandler() + self.enable_partial_rollout: bool = False + + def set_enable_partial_rollout(self, enable: bool) -> None: + self.enable_partial_rollout = enable def init(self, dist_init_addr: str) -> tuple[int, str]: """Initialize the worker and launch the server. @@ -583,7 +590,8 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: uid = rollout_state.uid sample_params: SampleParams = rollout_state.sample_params - + max_tokens = sample_params.max_tokens + enable_partial_rollout = self.enable_partial_rollout if sample_params.return_token_ids: endpoint_url = f"{self.server_url}/{self.endpoints['generate']}" else: @@ -594,8 +602,9 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: "Authorization": f"Bearer {self.config.api_key}", } - max_retries = self.config.max_retry_per_sample + rollout_state = self.partial_rollout_handler.preprocess(rollout_state, max_tokens, enable_partial_rollout) payload = self._get_request_payload(rollout_state) + max_retries = self.config.max_retry_per_sample # 早退逻辑 1:检查是否已被标记为完成 if rollout_state.status == Status.COMPLETED: @@ -604,7 +613,7 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: # 早退逻辑 2:检测输入是否还需要 generation (安全获取变量) input_ids = payload.get("input_ids", []) - max_tokens = payload.get("max_tokens") + max_tokens = cast(int, payload.get("max_tokens")) last_id = input_ids[-1] if len(input_ids) > 0 else "None" is_max_tokens_zero = max_tokens is not None and max_tokens <= 0 @@ -830,6 +839,7 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response sample_params = rollout_state.sample_params is_token_out = sample_params.return_token_ids response = http_response.json() + if is_token_out: response_ids: list[int] = [] logprobs: list[float] = [] @@ -914,30 +924,39 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response rollout_state.error_msg = error_msg return rollout_state - rollout_state.response = returned_response - rollout_state.response_ids = response_ids - rollout_state.logprobs = logprobs - rollout_state.routed_experts = routed_experts - rollout_state.finish_reason = finish_reason - rollout_state.status = rollout_status + rollout_state = self.partial_rollout_handler.postprocess( + rollout_state, + response=returned_response, + response_ids=response_ids, + logprobs=logprobs, + routed_experts=routed_experts, + finish_reason=finish_reason, + status=rollout_status, + enable_partial_rollout=self.enable_partial_rollout, + ) return rollout_state except KeyError as e: - error_msg = f"Missing expected key {e} in response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}" raise RuntimeError(error_msg) except IndexError as e: - error_msg = f"Index error {e} while processing response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"Index error {e} while processing response {response_for_log} for {uid}" raise RuntimeError(error_msg) except AssertionError as e: - error_msg = f"AssertionError: {e} when processing response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}" raise RuntimeError(error_msg) except json.JSONDecodeError as e: error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" raise RuntimeError(error_msg) except TypeError as e: - error_msg = f"TypeError: {e} when processing response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}" raise RuntimeError(error_msg) except Exception as e: - error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}" raise RuntimeError(error_msg) else: # v1/chat/completions API response @@ -956,22 +975,27 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response rollout_state.status = rollout_status return rollout_state except KeyError as e: - error_msg = f"Missing expected key {e} in response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"Missing expected key {e} in response {response_for_log} for {uid}" raise RuntimeError(error_msg) except IndexError as e: - error_msg = f"Index error {e} while processing response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"Index error {e} while processing response {response_for_log} for {uid}" raise RuntimeError(error_msg) except AssertionError as e: - error_msg = f"AssertionError: {e} when processing response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"AssertionError: {e} when processing response {response_for_log} for {uid}" raise RuntimeError(error_msg) except json.JSONDecodeError as e: error_msg = f"JSONDecodeError: {e} when processing response {response} for {uid}" raise RuntimeError(error_msg) except TypeError as e: - error_msg = f"TypeError: {e} when processing response {response} for {uid}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"TypeError: {e} when processing response {response_for_log} for {uid}" raise RuntimeError(error_msg) except Exception as e: - error_msg = f"Unexpected error: {e} when processing response {response} for {uid}\nTraceback: {traceback.format_exc()}" + response_for_log = {k: v for k, v in response.items() if k not in ("logprobs", "response_ids")} + error_msg = f"Unexpected error: {e} when processing response {response_for_log} for {uid}\nTraceback: {traceback.format_exc()}" raise RuntimeError(error_msg) def _adapt_input_to_openai_spec(self, prompts, tools, tool_choice): diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index 3ce836bef..0c13e8fdb 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -264,7 +264,7 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: if data["seq_ctx"].pixel_values is not None: free_pixel_value_refs.extend(data["seq_ctx"].pixel_values) # if len(free_pixel_value_refs) > 0: - # free_object_refs(free_pixel_value_refs) + # free_object_refs(free_pixel_value_refs) del packed_data_batches return log_infos From 57a3857b10b018aa80a28c9f83fabc64341744f3 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 12 May 2026 15:36:21 +0800 Subject: [PATCH 2/4] [Feat] support log raw_reward --- tests/rl/test_agent_loop_utils.py | 120 ------------------ .../rl/test_multi_task_agent_loop_manager.py | 26 ++++ tests/rl/test_producer.py | 29 ++++- .../agent_loop_manager/agent_loop_manager.py | 36 ++++-- xtuner/v1/rl/agent_loop_manager/producer.py | 36 ++++++ xtuner/v1/rl/rollout/lmdeploy.py | 3 +- xtuner/v1/train/rl_trainer.py | 26 ++-- 7 files changed, 127 insertions(+), 149 deletions(-) delete mode 100644 tests/rl/test_agent_loop_utils.py diff --git a/tests/rl/test_agent_loop_utils.py b/tests/rl/test_agent_loop_utils.py deleted file mode 100644 index 75ee89c9c..000000000 --- a/tests/rl/test_agent_loop_utils.py +++ /dev/null @@ -1,120 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, MagicMock, patch - -from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status, refresh_seq_staleness -from xtuner.v1.rl.agent_loop.single_turn_agent_loop import SingleTurnAgentLoop -from xtuner.v1.rl.agent_loop.utils import PartialRolloutHandler - - -def _make_rollout_state( - response_ids: list[int], - response_model_steps: list[int] | None = None, - seq_staleness: int = 0, - status: Status = Status.ABORTED, - extra_fields: dict | None = None, -): - return RolloutState( - uid=1, - message=[{"role": "user", "content": "hello"}], - prompt_ids=[101, 102], - response_ids=response_ids, - response="resp", - logprobs=[0.0] * len(response_ids), - response_mask=[1] * len(response_ids), - response_model_steps=response_model_steps, - seq_staleness=seq_staleness, - sample_params=SampleParams(max_tokens=8), - status=status, - extra_fields=extra_fields or {}, - ) - - -class TestAgentLoopUtils(unittest.TestCase): - def test_refresh_seq_staleness_recomputes_from_response_model_steps(self): - group = [_make_rollout_state(response_ids=[1, 2], response_model_steps=[3, 4], seq_staleness=0)] - - refresh_seq_staleness(group, current_train_step=8) - - self.assertEqual(group[0].seq_staleness, 4) - - def test_refresh_seq_staleness_resets_without_response_model_steps(self): - group = [_make_rollout_state(response_ids=[1, 2], response_model_steps=None, seq_staleness=5)] - - refresh_seq_staleness(group, current_train_step=8) - - self.assertEqual(group[0].seq_staleness, 0) - - def test_partial_rollout_postprocess_only_concatenates_history(self): - handler = PartialRolloutHandler(max_tokens=8) - rollout_state = _make_rollout_state( - response_ids=[30, 31], - response_model_steps=[2, 2], - seq_staleness=0, - extra_fields={ - "history_response_dict": { - "response_ids": [10, 11], - "response": "hi", - "logprobs": [0.1, 0.2], - "response_mask": [1, 1], - "routed_experts": None, - } - }, - ) - - result = handler.postprocess(rollout_state) - - self.assertEqual(result.response_ids, [10, 11, 30, 31]) - self.assertEqual(result.response_model_steps, [2, 2]) - self.assertEqual(result.seq_staleness, 0) - - -class TestSingleTurnAgentLoop(unittest.IsolatedAsyncioTestCase): - def _build_agent_loop(self): - rollout_ctl = MagicMock() - rollout_ctl.generate.remote = AsyncMock() - with ( - patch("xtuner.v1.rl.agent_loop.agent_loop.load_tokenizer", return_value=MagicMock()), - patch("xtuner.v1.rl.agent_loop.agent_loop.load_processor", return_value=MagicMock()), - ): - return SingleTurnAgentLoop( - rollout_ctl=rollout_ctl, - sample_params=SampleParams(max_tokens=8), - hf_checkpoint="dummy", - judger=None, - logger=MagicMock(), - ) - - async def test_generate_sample_does_not_update_staleness(self): - agent_loop = self._build_agent_loop() - rollout_state = _make_rollout_state(response_ids=[], status=Status.ABORTED) - generated_state = _make_rollout_state(response_ids=[30, 31], seq_staleness=7, status=Status.ABORTED) - agent_loop.rollout_ctl.generate.remote.return_value = generated_state - - result = await agent_loop.generate_sample( - rollout_state, - ) - - self.assertIsNone(result.response_model_steps) - self.assertEqual(result.seq_staleness, 7) - - async def test_generate_sample_does_not_update_sample_version(self): - agent_loop = self._build_agent_loop() - rollout_state = _make_rollout_state(response_ids=[], status=Status.ABORTED) - generated_state = _make_rollout_state(response_ids=[30, 31], status=Status.ABORTED) - agent_loop.rollout_ctl.generate.remote.return_value = generated_state - - result = await agent_loop.generate_sample(rollout_state) - - self.assertIsNone(result.response_model_steps) - self.assertEqual(result.seq_staleness, 0) - - async def test_generate_sample_does_not_require_model_step(self): - agent_loop = self._build_agent_loop() - rollout_state = _make_rollout_state(response_ids=[], status=Status.ABORTED) - generated_state = _make_rollout_state(response_ids=[30, 31], status=Status.ABORTED) - agent_loop.rollout_ctl.generate.remote.return_value = generated_state - - result = await agent_loop.generate_sample(rollout_state) - - self.assertIsNone(result.response_model_steps) - self.assertEqual(result.seq_staleness, 0) diff --git a/tests/rl/test_multi_task_agent_loop_manager.py b/tests/rl/test_multi_task_agent_loop_manager.py index a87a1ccfa..7db631f66 100644 --- a/tests/rl/test_multi_task_agent_loop_manager.py +++ b/tests/rl/test_multi_task_agent_loop_manager.py @@ -632,6 +632,32 @@ async def test_get_batch_refreshes_staleness_at_entry(self): self.assertEqual(manager._produce_progress.next_consumer_step, 10) self.assertEqual(manager._produce_progress.consumed_samples["task_a"], 1) + async def test_get_batch_returns_raw_reward_stats_from_progress(self): + replay_buffer = _FakeReplayBuffer( + rollout_states_by_task={"task_a": [[_FakeRolloutState("a-0", 0.2)]]}, + leftover_counts={("task_a", Status.COMPLETED): 1}, + ) + manager = AgentLoopManager( + task_runners=[ + _TaskRunner( + task_name="task_a", + agent_loop=_fake_agent_loop(), + produce_strategy=_FakeProduceStrategy(), + sampler=_FakeSampler(), + weight=1.0, + order=0, + ), + ], + replay_buffer=replay_buffer, + ) + manager._produce_progress.add_raw_rewards("task_a", 1.25, 2) + + result = await manager.get_batch(batch_size=1, train_step=9) + + self.assertEqual(result.raw_rewards_sum, 1.25) + self.assertEqual(result.raw_rewards_count, 2) + self.assertEqual(manager._produce_progress.consume_raw_rewards("task_a"), (0.0, 0)) + async def test_get_batch_waits_until_requested_batch_size_is_ready(self): replay_buffer = _SequencedCompletedReplayBuffer( completed_counts=[0, 1, 2], diff --git a/tests/rl/test_producer.py b/tests/rl/test_producer.py index 3fa53dfca..be06e1844 100644 --- a/tests/rl/test_producer.py +++ b/tests/rl/test_producer.py @@ -16,13 +16,14 @@ class MockRolloutState: - def __init__(self, id, seq_staleness=1, status=Status.COMPLETED): + def __init__(self, id, seq_staleness=1, status=Status.COMPLETED, reward_score=None): self.id = id self.uid = id self.status = status self.seq_staleness = seq_staleness self.response_ids = [] self.extra_fields = {} + self.reward = {"score": reward_score} if reward_score is not None else None class TestProducer(unittest.IsolatedAsyncioTestCase): @@ -276,6 +277,32 @@ def is_valid_sample_fn(samples): self.assertEqual(await self.replay_buffer.count(task_name, Status.FILTERED), 1) self.assertEqual(await self.replay_buffer.count(task_name, Status.ABORTED), 1) + async def test_put_generated_group_records_raw_rewards_before_filtering(self): + task_name = "test_raw_reward_before_filter" + + def is_valid_sample_fn(samples): + return False + + strategy = SyncProduceStrategyConfig(is_valid_sample_fn=is_valid_sample_fn).build() + ctx = self._build_context( + strategy, + task_name, + self._build_agent_loop(), + self._build_sampler(), + batch_size=1, + ) + + completed_group = [ + MockRolloutState(1, status=Status.COMPLETED, reward_score=0.25), + MockRolloutState(2, status=Status.COMPLETED, reward_score=0.75), + ] + self.assertFalse(await ctx.put_generated_group(completed_group)) + + self.assertEqual([item.status for item in completed_group], [Status.FILTERED, Status.FILTERED]) + self.assertEqual(ctx.progress.consume_raw_rewards(task_name), (1.0, 2)) + self.assertEqual(ctx.progress.consume_raw_rewards(task_name), (0.0, 0)) + self.assertEqual(await self.replay_buffer.count(task_name, Status.FILTERED), 1) + async def test_sync_produce_strategy(self): task_name = "test_task" mock_agent_loop = self._build_agent_loop({0: 0.0, 1: 0.01}) diff --git a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py index 3685caaa1..776695b48 100644 --- a/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py +++ b/xtuner/v1/rl/agent_loop_manager/agent_loop_manager.py @@ -48,6 +48,8 @@ class ProduceBatchResult: leftover_expired (int): Number of expired groups remaining in the replay buffer. leftover_failed (int): Number of failed groups remaining in the replay buffer. leftover_filtered (int): Number of filtered groups remaining in the replay buffer. + raw_rewards_sum (float): Sum of rewards produced before replay-buffer insertion for the current window. + raw_rewards_count (int): Number of reward-bearing samples included in ``raw_rewards_sum``. """ rollout_states: list[list[RolloutState]] @@ -66,9 +68,9 @@ class ProduceBatchResult: leftover_expired: int = 0 leftover_failed: int = 0 leftover_filtered: int = 0 - # filtered rewards are not included in rollout_states, but we want to record their scores for diagnostics. - filtered_rewards_sum: float = 0.0 - filtered_rewards_count: int = 0 + # rewards produced during the current produce window, including completed and filtered groups. + raw_rewards_sum: float = 0.0 + raw_rewards_count: int = 0 task_batch_sizes: dict[str, int] | None = None task_results: dict[str, "ProduceBatchResult"] | None = None @@ -423,8 +425,8 @@ def _aggregate_task_results( weighted_group_p99_sum = 0.0 weighted_group_ratio_sum = 0.0 total_pause_time_s = 0.0 - filtered_rewards_sum = 0.0 - filtered_rewards_count = 0 + raw_rewards_sum = 0.0 + raw_rewards_count = 0 for task in ordered_tasks: result = task_results[task.task_name] @@ -435,8 +437,8 @@ def _aggregate_task_results( leftover_expired += result.leftover_expired leftover_failed += result.leftover_failed leftover_filtered += result.leftover_filtered - filtered_rewards_sum += result.filtered_rewards_sum - filtered_rewards_count += result.filtered_rewards_count + raw_rewards_sum += result.raw_rewards_sum + raw_rewards_count += result.raw_rewards_count if result.group_gen_count is not None and result.group_gen_mean_s is not None: total_group_count += result.group_gen_count weighted_group_mean_sum += result.group_gen_count * result.group_gen_mean_s @@ -453,8 +455,8 @@ def _aggregate_task_results( leftover_expired=leftover_expired, leftover_failed=leftover_failed, leftover_filtered=leftover_filtered, - filtered_rewards_sum=filtered_rewards_sum, - filtered_rewards_count=filtered_rewards_count, + raw_rewards_sum=raw_rewards_sum, + raw_rewards_count=raw_rewards_count, task_results={task.task_name: task_results[task.task_name] for task in ordered_tasks}, ) if total_group_count > 0: @@ -580,18 +582,29 @@ def _build_result_from_batch( batch_by_task: dict[str, list[list[RolloutState]]], leftover_counts: dict[str, dict[Status, int]], *, + progress: ProduceProgress, pause_time_s: float, ) -> ProduceBatchResult: if len(self.task_runners) == 1: task = self.task_runners[0] - result = ProduceBatchResult(rollout_states=batch_by_task.get(task.task_name, [])) + raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name) + result = ProduceBatchResult( + rollout_states=batch_by_task.get(task.task_name, []), + raw_rewards_sum=raw_rewards_sum, + raw_rewards_count=raw_rewards_count, + ) _fill_leftover_counts(result, leftover_counts.get(task.task_name, {})) _fill_group_timing_stats(result, result.rollout_states, pause_time_s=pause_time_s) return result task_results: dict[str, ProduceBatchResult] = {} for task in self.task_runners: - result = ProduceBatchResult(rollout_states=batch_by_task.get(task.task_name, [])) + raw_rewards_sum, raw_rewards_count = progress.consume_raw_rewards(task.task_name) + result = ProduceBatchResult( + rollout_states=batch_by_task.get(task.task_name, []), + raw_rewards_sum=raw_rewards_sum, + raw_rewards_count=raw_rewards_count, + ) _fill_leftover_counts(result, leftover_counts.get(task.task_name, {})) task_results[task.task_name] = result @@ -620,6 +633,7 @@ async def _get_batch_from_buffer( task_batch_sizes, batch_by_task, leftover_counts, + progress=consume_progress, pause_time_s=pause_time_s, ) diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py index 55745930c..2f2ea9c5f 100644 --- a/xtuner/v1/rl/agent_loop_manager/producer.py +++ b/xtuner/v1/rl/agent_loop_manager/producer.py @@ -55,6 +55,8 @@ class ProduceProgress: - consumed_samples:各 task 已被 consumer 从 replay buffer 取走的 group 绝对累计数。 - target_samples:各 task 截至 target_upto_future_step 应生产出的 group 绝对累计目标。 - target_upto_future_step:target_samples 已覆盖到的最大 future step。 + - raw_rewards_sum / raw_rewards_count:各 task 自上次 consumer 取 batch 后,producer 实际生成出的 + completed group reward 统计。filtered group 在过滤前仍按 completed 生成结果计入。 """ next_consumer_step: int = 1 @@ -62,12 +64,16 @@ class ProduceProgress: consumed_samples: dict[str, int] = field(default_factory=dict) target_samples: dict[str, int] = field(default_factory=dict) target_upto_future_step: int = 0 + raw_rewards_sum: dict[str, float] = field(default_factory=dict) + raw_rewards_count: dict[str, int] = field(default_factory=dict) @classmethod def build(cls, task_names: list[str]) -> "ProduceProgress": return cls( consumed_samples={task_name: 0 for task_name in task_names}, target_samples={task_name: 0 for task_name in task_names}, + raw_rewards_sum={task_name: 0.0 for task_name in task_names}, + raw_rewards_count={task_name: 0 for task_name in task_names}, ) @classmethod @@ -84,6 +90,8 @@ def build_local( consumed_samples={task_name: 0 for task_name in task_names}, target_samples=dict(task_batch_sizes), target_upto_future_step=train_step, + raw_rewards_sum={task_name: 0.0 for task_name in task_names}, + raw_rewards_count={task_name: 0 for task_name in task_names}, ) def ensure_target_upto( @@ -112,6 +120,17 @@ def mark_consumed(self, consumed_counts: dict[str, int]) -> None: for task_name, count in consumed_counts.items(): self.consumed_samples[task_name] += count + def add_raw_rewards(self, task_name: str, rewards_sum: float, rewards_count: int) -> None: + self.raw_rewards_sum[task_name] += rewards_sum + self.raw_rewards_count[task_name] += rewards_count + + def consume_raw_rewards(self, task_name: str) -> tuple[float, int]: + rewards_sum = self.raw_rewards_sum[task_name] + rewards_count = self.raw_rewards_count[task_name] + self.raw_rewards_sum[task_name] = 0.0 + self.raw_rewards_count[task_name] = 0 + return rewards_sum, rewards_count + def finish_consume(self, train_step: int) -> None: self.next_consumer_step = train_step + 1 @@ -125,6 +144,8 @@ def state_dict(self) -> dict[str, Any]: "consumed_samples": dict(self.consumed_samples), "target_samples": dict(self.target_samples), "target_upto_future_step": self.target_upto_future_step, + "raw_rewards_sum": dict(self.raw_rewards_sum), + "raw_rewards_count": dict(self.raw_rewards_count), } def load_state_dict(self, state: dict[str, Any]) -> None: @@ -136,6 +157,15 @@ def load_state_dict(self, state: dict[str, Any]) -> None: self.consumed_samples.update(state["consumed_samples"]) self.target_samples.clear() self.target_samples.update(state["target_samples"]) + task_names = set(self.consumed_samples) | set(self.target_samples) + self.raw_rewards_sum.clear() + self.raw_rewards_sum.update( + {task_name: float(state.get("raw_rewards_sum", {}).get(task_name, 0.0)) for task_name in task_names} + ) + self.raw_rewards_count.clear() + self.raw_rewards_count.update( + {task_name: int(state.get("raw_rewards_count", {}).get(task_name, 0)) for task_name in task_names} + ) class ProduceBatchStatus(Enum): @@ -249,6 +279,12 @@ async def put_generated_group(self, group: list[RolloutState]) -> bool: # 只有完整生成的 group 才需要业务有效性过滤;ABORTED / EXPIRED 保留原状态供重试或统计。 is_completed = get_group_status(group) == Status.COMPLETED if is_completed: + rewards_sum = 0.0 + rewards_count = 0 + for item in group: + rewards_sum += float(item.reward["score"]) # type: ignore[index] + rewards_count += 1 + self.progress.add_raw_rewards(self.task_name, rewards_sum, rewards_count) is_valid = self.is_valid_sample_fn(group) if not is_valid: for item in group: diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index 3f2d6f3fe..21584d7b6 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -197,8 +197,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: Returns: Namespace: A namespace object containing the server configuration. """ - from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig - from lmdeploy.messages import SpeculativeConfig + from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig accelerator_to_device_type = { "GPU": "cuda", diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 52bc02d1b..abb89c539 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -510,8 +510,8 @@ def _train_one_batch( *, offload_rollout_before_train: bool = False, onload_train_before_train: bool = False, - filtered_rewards_sum: float = 0.0, - filtered_rewards_count: int = 0, + raw_rewards_sum: float = 0.0, + raw_rewards_count: int = 0, ) -> TrainInfo: train_sample_count = sum(len(group) for group in train_batch) self.logger.info(f"generate {train_sample_count} samples for training") @@ -534,8 +534,8 @@ def _train_one_batch( data_batches, data_info = self._prepare_train_data( train_batch, self._train_worker_cfg.pack_max_length, - filtered_rewards_sum=filtered_rewards_sum, - filtered_rewards_count=filtered_rewards_count, + raw_rewards_sum=raw_rewards_sum, + raw_rewards_count=raw_rewards_count, ) self.logger.info(f"Prepared {len(data_batches)} training data batches") @@ -574,8 +574,8 @@ def _prepare_train_data( self, data_groups: list[list[RolloutState]], pack_max_length: int, - filtered_rewards_sum: float = 0.0, - filtered_rewards_count: int = 0, + raw_rewards_sum: float = 0.0, + raw_rewards_count: int = 0, ): rewards_list = [] advantages_list = [] @@ -697,11 +697,7 @@ def _prepare_train_data( prompt_len_t = torch.tensor(prompt_len_list).float() if prompt_len_list else torch.tensor([0.0]).float() response_len_t = torch.tensor(response_len_list).float() if response_len_list else torch.tensor([0.0]).float() - for rewards in rewards_list: - filtered_rewards_sum += rewards - filtered_rewards_count += 1 - - raw_rewards_mean = filtered_rewards_sum / filtered_rewards_count if filtered_rewards_count > 0 else 0.0 + raw_rewards_mean = raw_rewards_sum / raw_rewards_count if raw_rewards_count > 0 else rewards_t.mean().item() info_dict = { "batch_size": len(rewards_list), "rewards/mean": rewards_t.mean().item(), @@ -951,8 +947,8 @@ def fit(self): step_timer_dict, offload_rollout_before_train=True, onload_train_before_train=True, - filtered_rewards_sum=produce_result.filtered_rewards_sum, - filtered_rewards_count=produce_result.filtered_rewards_count, + raw_rewards_sum=produce_result.raw_rewards_sum, + raw_rewards_count=produce_result.raw_rewards_count, ) weights_synced = self._sync_weights_and_save(train_step, step_timer_dict) @@ -1116,8 +1112,8 @@ async def _fit(self): train_batch, train_step, step_timer_dict, - filtered_rewards_sum=produce_result.filtered_rewards_sum, - filtered_rewards_count=produce_result.filtered_rewards_count, + raw_rewards_sum=produce_result.raw_rewards_sum, + raw_rewards_count=produce_result.raw_rewards_count, ) else: self.logger.info( From 3de263a6b0e3dd3ca6f6316c778ee2bb079890d8 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 12 May 2026 19:27:43 +0800 Subject: [PATCH 3/4] fix ut --- tests/rl/test_producer.py | 25 ++++++++++++--------- xtuner/v1/rl/agent_loop_manager/producer.py | 5 +++++ xtuner/v1/rl/rollout/utils.py | 4 ++-- xtuner/v1/rl/rollout/worker.py | 21 +++++++++++------ 4 files changed, 36 insertions(+), 19 deletions(-) diff --git a/tests/rl/test_producer.py b/tests/rl/test_producer.py index be06e1844..b6ca89ae2 100644 --- a/tests/rl/test_producer.py +++ b/tests/rl/test_producer.py @@ -52,14 +52,18 @@ def _build_progress( target: int, train_step: int = 0, consumed: int = 0, + producer_future_step: int | None = None, + target_upto_future_step: int | None = None, ) -> ProduceProgress: - return ProduceProgress( - next_consumer_step=train_step, - producer_future_step=train_step, - consumed_samples={task_name: consumed}, - target_samples={task_name: target}, - target_upto_future_step=train_step, + progress = ProduceProgress.build([task_name]) + progress.next_consumer_step = train_step + progress.producer_future_step = producer_future_step if producer_future_step is not None else train_step + progress.consumed_samples[task_name] = consumed + progress.target_samples[task_name] = target + progress.target_upto_future_step = ( + target_upto_future_step if target_upto_future_step is not None else train_step ) + return progress def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None): mock_agent_loop = MagicMock() @@ -416,11 +420,12 @@ async def mock_gen(rs, **kwargs): sampler = self._build_sampler() # 该用例验证版本记录顺序,放宽 stale 策略避免在生产入口提前返回。 strategy = AsyncProduceStrategyConfig(over_sample_threshold=0.0, max_staleness=3).build() - progress = ProduceProgress( - next_consumer_step=1, + progress = self._build_progress( + task_name, + target=2, + train_step=1, + consumed=1, producer_future_step=2, - consumed_samples={task_name: 1}, - target_samples={task_name: 2}, target_upto_future_step=2, ) diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py index 2f2ea9c5f..88500ae2c 100644 --- a/xtuner/v1/rl/agent_loop_manager/producer.py +++ b/xtuner/v1/rl/agent_loop_manager/producer.py @@ -282,6 +282,11 @@ async def put_generated_group(self, group: list[RolloutState]) -> bool: rewards_sum = 0.0 rewards_count = 0 for item in group: + if item.reward is None or "score" not in item.reward: + logger.warning( + f"Missing reward score in item (uid: {item.uid}) of completed group for task {self.task_name}. This item will be skipped in reward statistics." + ) + continue rewards_sum += float(item.reward["score"]) # type: ignore[index] rewards_count += 1 self.progress.add_raw_rewards(self.task_name, rewards_sum, rewards_count) diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index 7c45f7a24..b8b74ca6b 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -321,7 +321,7 @@ def preprocess( remaining_tokens = max_tokens - response_len # compute remaining max_tokens budget rollout_state.sample_params = rollout_state.sample_params.copy(update={"max_tokens": remaining_tokens}) - self.logger.info( + self.logger.debug( f"[PartialRolloutHandler] Sample {rollout_state.uid} continue rollout | Remaining tokens allowed: {remaining_tokens} | Status: {rollout_state.status} | Prompt len: {prompt_len} | Response len: {response_len} | Staleness: {rollout_state.seq_staleness} | Total tokens: {len(rollout_state.tokens)}" ) return rollout_state @@ -375,7 +375,7 @@ def postprocess( # [ref for ref in (history_routed_experts_ref, cur_routed_experts_ref) if isinstance(ref, ray.ObjectRef)] # ) end_time = time.time() - self.logger.info( + self.logger.debug( f"[PartialRolloutHandler] Postprocess routed_experts concatenation time: {end_time - start_time:.4f} seconds" ) elif history_routed_experts is None and routed_experts is not None: diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index 71e83d314..bf5a23c95 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -623,13 +623,20 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: self.logger.debug( f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token." ) - rollout_state.status = Status.COMPLETED - rollout_state.response_ids = [] - rollout_state.response = "" - rollout_state.logprobs = [] - rollout_state.response_mask = [] - rollout_state.response_model_steps = [] - rollout_state.finish_reason = "stop" if is_eos_reached else "length" + finish_reason = "stop" if is_eos_reached else "length" + rollout_state = self.partial_rollout_handler.postprocess( + rollout_state, + response="", + response_ids=[], + logprobs=[], + routed_experts=None, + finish_reason=finish_reason, + status=Status.COMPLETED, + enable_partial_rollout=enable_partial_rollout, + ) + if not enable_partial_rollout: + rollout_state.response_mask = [] + rollout_state.response_model_steps = [] return rollout_state for attempt in range(max_retries + 1): From deaf090713c3f0be5dbca7931fa0eb96f680722a Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 12 May 2026 19:52:57 +0800 Subject: [PATCH 4/4] replace ray.get with await in partial_rollout_handler --- xtuner/v1/rl/agent_loop_manager/producer.py | 1 + xtuner/v1/rl/rollout/utils.py | 16 ++++++++++------ xtuner/v1/rl/rollout/worker.py | 4 ++-- xtuner/v1/train/rl_trainer.py | 2 +- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/xtuner/v1/rl/agent_loop_manager/producer.py b/xtuner/v1/rl/agent_loop_manager/producer.py index 2ae3b6d02..081ce1bcb 100644 --- a/xtuner/v1/rl/agent_loop_manager/producer.py +++ b/xtuner/v1/rl/agent_loop_manager/producer.py @@ -355,6 +355,7 @@ class SyncProduceStrategyConfig(ProduceStrategyConfig): config = SyncProduceStrategyConfig() """ + def build( self, *, diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index b8b74ca6b..87b6467d7 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -281,9 +281,9 @@ async def check_worker_health( return False -def _resolve_routed_experts(routed_experts: list[int] | ray.ObjectRef) -> list[int]: - if isinstance(routed_experts, ray.ObjectRef): - routed_experts = ray.get(routed_experts) +async def _resolve_routed_experts(routed_experts: list[int] | RayObjectRef) -> list[int]: + if isinstance(routed_experts, RayObjectRef): + routed_experts = await routed_experts if hasattr(routed_experts, "tolist"): routed_experts = routed_experts.tolist() assert isinstance(routed_experts, list), f"Unexpected routed_experts type: {type(routed_experts)}" @@ -326,7 +326,7 @@ def preprocess( ) return rollout_state - def postprocess( + async def postprocess( self, rollout_state: RolloutState, *, @@ -361,8 +361,12 @@ def postprocess( history_routed_experts = rollout_state.routed_experts or None if history_routed_experts is not None and routed_experts is not None: start_time = time.time() - history_routed_experts = _resolve_routed_experts(history_routed_experts) - cur_routed_experts = _resolve_routed_experts(routed_experts) + history_routed_experts, cur_routed_experts = await asyncio.gather( + _resolve_routed_experts(history_routed_experts), + _resolve_routed_experts(routed_experts), + ) + assert history_routed_experts, "History routed_experts should not be empty after resolution" + assert cur_routed_experts, "Current routed_experts should not be empty after resolution" cur_routed_experts_len = len(cur_routed_experts) history_routed_experts_len = len(history_routed_experts) assert history_routed_experts_len - 1 <= cur_routed_experts_len, ( diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index bf5a23c95..ce6b3b3ce 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -624,7 +624,7 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState: f"No generation needed for request {uid}: max_tokens={max_tokens} or last input_id={last_id} is in eos_token." ) finish_reason = "stop" if is_eos_reached else "length" - rollout_state = self.partial_rollout_handler.postprocess( + rollout_state = await self.partial_rollout_handler.postprocess( rollout_state, response="", response_ids=[], @@ -931,7 +931,7 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response rollout_state.error_msg = error_msg return rollout_state - rollout_state = self.partial_rollout_handler.postprocess( + rollout_state = await self.partial_rollout_handler.postprocess( rollout_state, response=returned_response, response_ids=response_ids, diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index d3d78d5c8..a4f48e53f 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -227,7 +227,7 @@ class BaseRLTrainerConfig(BaseModel): hf_max_keep: int | None = -1 checkpoint_no_save_optimizer: bool = False log_dir: Path | str | None = None - seed: int = 66 + seed: int = 42 debug_rollout: bool = False debug_rollout_dir: Path | str | None = None debug_train: bool = False