diff --git a/examples/v1/config/rl_dapo_math.py b/examples/v1/config/rl_dapo_math.py index 5a61e6f2cb..e81666bb80 100644 --- a/examples/v1/config/rl_dapo_math.py +++ b/examples/v1/config/rl_dapo_math.py @@ -143,6 +143,7 @@ tasks=TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ), @@ -178,6 +179,7 @@ tasks=TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ), ) @@ -191,7 +193,6 @@ def dapo_compute_metric(samples): resources=resources, train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=SyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/examples/v1/config/rl_dapo_math_async.py b/examples/v1/config/rl_dapo_math_async.py index 98fd3fccdf..5cc7e07ee4 100644 --- a/examples/v1/config/rl_dapo_math_async.py +++ b/examples/v1/config/rl_dapo_math_async.py @@ -146,6 +146,7 @@ tasks=TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ), @@ -181,6 +182,7 @@ tasks=TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ), ) @@ -194,7 +196,6 @@ def dapo_compute_metric(samples): resources=resources, train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=AsyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/examples/v1/config/rl_dapo_math_async_filter.py b/examples/v1/config/rl_dapo_math_async_filter.py index c572668a1b..dcd6c2d553 100644 --- a/examples/v1/config/rl_dapo_math_async_filter.py +++ b/examples/v1/config/rl_dapo_math_async_filter.py @@ -161,6 +161,7 @@ def group_samples_filter_func(rollout_states): tasks=TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ), @@ -196,6 +197,7 @@ def group_samples_filter_func(rollout_states): tasks=TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ), ) @@ -209,7 +211,6 @@ def dapo_compute_metric(samples): resources=resources, train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=AsyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/examples/v1/config/rl_grpo_geo3k_judge.py b/examples/v1/config/rl_grpo_geo3k_judge.py index bb04598aff..caf228322a 100644 --- a/examples/v1/config/rl_grpo_geo3k_judge.py +++ b/examples/v1/config/rl_grpo_geo3k_judge.py @@ -147,6 +147,7 @@ tasks=TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ), @@ -192,6 +193,7 @@ tasks=TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ), ) @@ -204,7 +206,6 @@ resources=resources, train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=SyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/examples/v1/config/rl_grpo_gsm8k_async.py b/examples/v1/config/rl_grpo_gsm8k_async.py index 9008ea530b..2efcbd94c0 100644 --- a/examples/v1/config/rl_grpo_gsm8k_async.py +++ b/examples/v1/config/rl_grpo_gsm8k_async.py @@ -139,6 +139,7 @@ tasks=TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ), @@ -174,6 +175,7 @@ tasks=TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ), ) @@ -186,7 +188,6 @@ resources=resources, train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=AsyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/examples/v1/config/rl_grpo_gsm8k_judge.py b/examples/v1/config/rl_grpo_gsm8k_judge.py index a034371cf1..582a261abd 100644 --- a/examples/v1/config/rl_grpo_gsm8k_judge.py +++ b/examples/v1/config/rl_grpo_gsm8k_judge.py @@ -134,6 +134,7 @@ tasks=TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ), @@ -169,6 +170,7 @@ tasks=TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ), ) @@ -181,7 +183,6 @@ resources=resources, train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=SyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/examples/v1/config/rl_grpo_gsm8k_with_tool.py b/examples/v1/config/rl_grpo_gsm8k_with_tool.py index 41ea77bbc9..8ce95d404c 100644 --- a/examples/v1/config/rl_grpo_gsm8k_with_tool.py +++ b/examples/v1/config/rl_grpo_gsm8k_with_tool.py @@ -155,6 +155,7 @@ tasks=TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ), @@ -191,6 +192,7 @@ tasks=TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ), ) @@ -203,7 +205,6 @@ resources=resources, train_worker_cfg=train_worker_cfg, # TODO: uniform naming of cfg and config rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=SyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py b/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py index 83e257ade9..0373187c68 100644 --- a/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py +++ b/examples/v1/config/rl_multi_task_gsm8k_dapo_math.py @@ -200,6 +200,7 @@ task_name="train_task:dapo_math", weight=dapo_task_weight, agent_loop_config=dapo_train_agent_loop_config, + judger_config=judger_config, produce_strategy_config=SyncProduceStrategyConfig(), sampler_config=dapo_train_sampler_config, ), @@ -269,6 +270,7 @@ task_name="eval_task:dapo_math", weight=dapo_task_weight, agent_loop_config=dapo_eval_agent_loop_config, + judger_config=judger_config, sampler_config=dapo_eval_sampler_config, ), TaskSpecConfig( @@ -291,7 +293,6 @@ def compute_metric(samples): resources=resources, train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=SyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/tests/rl/test_agent_loop.py b/tests/rl/test_agent_loop.py index c9ce01370a..e33126379d 100644 --- a/tests/rl/test_agent_loop.py +++ b/tests/rl/test_agent_loop.py @@ -1,14 +1,20 @@ import os import unittest -import asyncio +import copy import ray import tempfile import torch from transformers import AutoTokenizer from xtuner.v1.rl.rollout.worker import RolloutConfig from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.rl.agent_loop import SingleTurnAgentLoopConfig, AgentLoopManagerConfig, TaskSpecConfig, SyncProduceStrategyConfig, SamplerConfig -from xtuner.v1.data_proto import RolloutState, Status, SampleParams +from xtuner.v1.rl.agent_loop import ( + SingleTurnAgentLoopConfig, + AgentLoopManagerConfig, + TaskSpecConfig, + SyncProduceStrategyConfig, + SamplerConfig, +) +from xtuner.v1.data_proto import RolloutState, Status, SampleParams from xtuner.v1.rl.rollout import RolloutController from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig @@ -79,14 +85,16 @@ async def test_gsm8k_agent_loop(self): judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router") agent_loop_cfg = SingleTurnAgentLoopConfig( hf_checkpoint=self.model_path, - sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0) + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), ) - # 2. 创建 rollout_controller, judger + # 2. 创建 rollout_controller pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) - gsm8k_judger = judger_config.build() # 3. 创建 AgentLoop - agent_loop = agent_loop_cfg.build(rollout_controller=rollout_controller, judger=gsm8k_judger) + agent_loop = agent_loop_cfg.build( + rollout_controller=rollout_controller, + judger=judger_config.build(), + ) # 4. 构造输入数据 prompt_repeat_k = 4 rollout_state = FAKE_INPUT_ITEM @@ -104,6 +112,51 @@ async def test_gsm8k_agent_loop(self): self.assertGreater(len(single_rollout_state.response_ids), 0) self.assertEqual(single_rollout_state.reward["score"], 1) + async def test_gsm8k_agent_loop_with_ray_actor_judger(self): + self.init_config() + rollout_config = RolloutConfig( + env="test_agent_loop_ray_actor", + model_path=self.model_path, + model_name=os.path.basename(self.model_path).lower(), + tokenizer_path=self.model_path, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + ) + judger_config = GSM8KJudgerConfig( + judger_name="openai/gsm8k", + judger_type="ray.actor", + num_cpus_per_actor=1, + ) + agent_loop_cfg = SingleTurnAgentLoopConfig( + hf_checkpoint=self.model_path, + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), + num_ray_actors=1, + num_cpus=1, + ) + + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + agent_loop = agent_loop_cfg.build( + rollout_controller=rollout_controller, + judger=judger_config.build(), + ) + + prompt_repeat_k = 2 + rollout_state = copy.deepcopy(FAKE_INPUT_ITEM) + group_in_rollout_state = [copy.deepcopy(FAKE_INPUT_ITEM) for _ in range(prompt_repeat_k)] + + group_rollout_state = await agent_loop.generate_group.remote(group_in_rollout_state) + single_rollout_state = await agent_loop.generate_sample.remote(rollout_state) + + self.assertEqual(len(group_rollout_state), prompt_repeat_k) + for state in group_rollout_state: + self.assertEqual(state.status, Status.COMPLETED) + self.assertGreater(len(state.response_ids), 0) + self.assertEqual(state.reward["score"], 1) + self.assertEqual(single_rollout_state.status, Status.COMPLETED) + self.assertGreater(len(single_rollout_state.response_ids), 0) + self.assertEqual(single_rollout_state.reward["score"], 1) + async def test_gsm8k_agent_loop_manager(self): # 1. 初始化 config self.init_config() @@ -118,7 +171,7 @@ async def test_gsm8k_agent_loop_manager(self): judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router") agent_loop_cfg = SingleTurnAgentLoopConfig( hf_checkpoint=self.model_path, - sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0) + sample_params=SampleParams(max_tokens=self.max_response_length, temperature=0.0), ) sampler_config = SamplerConfig( dataloader_cfg=DataloaderConfig( @@ -141,21 +194,20 @@ async def test_gsm8k_agent_loop_manager(self): TaskSpecConfig( task_name="test_gsm8k", agent_loop_config=agent_loop_cfg, + judger_config=judger_config, produce_strategy_config=SyncProduceStrategyConfig(), sampler_config=sampler_config, ) ], ) - # 2. 创建 rollout_controller, judger + # 2. 创建 rollout_controller pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) - gsm8k_judger = judger_config.build() # 3. 创建 AgentLoopManager replay_buffer_cfg = SyncReplayBufferConfig() replay_buffer = replay_buffer_cfg.build() agent_loop_manager = agent_loop_manager_cfg.build( rollout_controller=rollout_controller, - judger=gsm8k_judger, tokenizer=self.tokenizer, replay_buffer=replay_buffer, ) diff --git a/tests/rl/test_async_rollout.py b/tests/rl/test_async_rollout.py index 6e80150cf2..32cd372aea 100644 --- a/tests/rl/test_async_rollout.py +++ b/tests/rl/test_async_rollout.py @@ -118,7 +118,6 @@ def _build_agent_loop_manager( ) manager = manager_cfg.build( rollout_controller=rollout_ctl, - judger=None, tokenizer=tokenizer, replay_buffer=replay_buffer, logger=None, diff --git a/tests/rl/test_rl_colocate_trainer_integration.py b/tests/rl/test_rl_colocate_trainer_integration.py index 86b866e95d..a2133235c2 100644 --- a/tests/rl/test_rl_colocate_trainer_integration.py +++ b/tests/rl/test_rl_colocate_trainer_integration.py @@ -166,6 +166,7 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke TaskSpecConfig( task_name="train_task", agent_loop_config=agent_loop_config, + judger_config=judger_config, produce_strategy_config=produce_strategy_config, sampler_config=sampler_config, ) @@ -186,6 +187,7 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke TaskSpecConfig( task_name="eval_task", agent_loop_config=eval_agent_loop_config, + judger_config=judger_config, sampler_config=eval_sampler_config, ) ], @@ -198,7 +200,6 @@ def build_trainer_config(self, work_dir, checkpoint_interval=1, checkpoint_maxke resources=resources, train_worker_cfg=train_worker_cfg, rollout_config=rollout_config, - judger_config=judger_config, tokenizer_path=model_path, replay_buffer_config=SyncReplayBufferConfig(), agent_loop_manager_cfg=agent_loop_manager_cfg, diff --git a/xtuner/v1/rl/agent_loop/__init__.py b/xtuner/v1/rl/agent_loop/__init__.py index ba0d351b6c..1931f99302 100644 --- a/xtuner/v1/rl/agent_loop/__init__.py +++ b/xtuner/v1/rl/agent_loop/__init__.py @@ -1,4 +1,14 @@ -from .agent_loop import AgentLoop, AgentLoopConfig +from xtuner.v1.rl.judger import JudgerConfigSpec, JudgerLike, JudgerSpec, JudgerSpecConfig + +from .agent_loop import ( + AgentLoop, + AgentLoopActor, + AgentLoopConfig, + AgentLoopSpec, + RayAgentLoop, + RayAgentLoopProxy, + RouterAgentLoop, +) from .agent_loop_manager import ( AgentLoopManager, AgentLoopManagerConfig, @@ -21,7 +31,16 @@ "AgentLoopConfig", "SingleTurnAgentLoopConfig", "AgentLoop", + "AgentLoopSpec", + "AgentLoopActor", + "RouterAgentLoop", + "RayAgentLoop", + "RayAgentLoopProxy", "SingleTurnAgentLoop", + "JudgerLike", + "JudgerSpec", + "JudgerConfigSpec", + "JudgerSpecConfig", "AgentLoopManagerConfig", "AgentLoopManager", "TaskSpecConfig", diff --git a/xtuner/v1/rl/agent_loop/agent_loop.py b/xtuner/v1/rl/agent_loop/agent_loop.py index dcd87eb714..b0bf468440 100644 --- a/xtuner/v1/rl/agent_loop/agent_loop.py +++ b/xtuner/v1/rl/agent_loop/agent_loop.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import asyncio from abc import ABC, abstractmethod -from typing import Callable +from typing import TypeAlias, cast -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, model_validator +from ray.actor import ActorClass, ActorProxy +from ray.util.placement_group import PlacementGroup from xtuner.v1.data_proto import RolloutState, SampleParams -from xtuner.v1.rl.judger import NativeJudger, RouterJudger +from xtuner.v1.rl.judger import JudgerSpec from xtuner.v1.rl.rollout import RolloutController -from xtuner.v1.rl.utils import create_task -from xtuner.v1.utils import get_logger +from xtuner.v1.rl.utils import CPUActorLauncher, create_task +from xtuner.v1.utils import get_logger, ray_method from xtuner.v1.utils.processing_utils import load_processor, load_tokenizer @@ -16,9 +20,116 @@ class AgentLoopConfig(ABC, BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) hf_checkpoint: str sample_params: SampleParams + num_ray_actors: int = Field( + default=0, + ge=0, + description="Number of AgentLoop Ray actor instances. 0 means local mode.", + ) + num_cpus: float = Field(default=1, gt=0, description="CPU cores required by the AgentLoop actor itself.") + cpu_memory: int = Field(default=1024**3, gt=0, description="CPU memory in bytes required by AgentLoop.") + + @model_validator(mode="after") + def _validate_ray_actor_config(self) -> AgentLoopConfig: + if self.num_ray_actors == 0 and (self.num_cpus != 1 or self.cpu_memory != 1024**3): + logger = get_logger() + logger.warning("num_cpus and cpu_memory are ignored when AgentLoop runs in local mode.") + return self + + def build(self, rollout_controller, judger: JudgerSpec = None, logger=None) -> AgentLoopSpec: + if self.num_ray_actors == 0: + return self.build_local( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + if self.num_ray_actors > 1: + return self._build_router( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + return self._build_ray_actor( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) @abstractmethod - def build(self, rollout_controller, judger=None, logger=None) -> "AgentLoop": ... + def build_local( + self, + rollout_controller, + judger: JudgerSpec = None, + logger=None, + ) -> AgentLoop: ... + + def _build_ray_actor( + self, + rollout_controller: RolloutController, + pg: PlacementGroup | None = None, + judger: JudgerSpec = None, + logger=None, + ) -> RayAgentLoopProxy: + return cast( + "RayAgentLoopProxy", + CPUActorLauncher.build_actor( + AgentLoopActor, + self, + rollout_controller, + judger, + logger, + pg=pg, + bundle_idx=0, + actor_num_cpus=self.num_cpus, + actor_memory=self.cpu_memory, + capture_child_tasks=True, + ), + ) + + def _build_ray_actors( + self, + rollout_controller: RolloutController, + num_actors: int, + pg: PlacementGroup | None = None, + judger: JudgerSpec = None, + logger=None, + start_bundle_idx: int = 0, + ) -> list[RayAgentLoopProxy]: + return cast( + list["RayAgentLoopProxy"], + CPUActorLauncher.build_actors( + AgentLoopActor, + self, + rollout_controller, + judger, + logger, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=num_actors, + actor_num_cpus_per_worker=self.num_cpus, + actor_memory_per_worker=self.cpu_memory, + capture_child_tasks=True, + ), + ) + + def _build_router( + self, + rollout_controller: RolloutController, + pg: PlacementGroup | None = None, + judger: JudgerSpec = None, + logger=None, + start_bundle_idx: int = 0, + ) -> RouterAgentLoop: + return RouterAgentLoop( + workers=self._build_ray_actors( + rollout_controller=rollout_controller, + num_actors=self.num_ray_actors, + pg=pg, + judger=judger, + logger=logger, + start_bundle_idx=start_bundle_idx, + ), + rollout_ctl=rollout_controller, + ) class AgentLoop(ABC): @@ -27,7 +138,7 @@ def __init__( rollout_ctl: RolloutController, sample_params: SampleParams, hf_checkpoint: str, - judger: Callable | NativeJudger | RouterJudger | None = None, + judger: JudgerSpec = None, logger=None, ) -> None: self.rollout_ctl = rollout_ctl @@ -54,13 +165,84 @@ async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> l group_samples = await generated_samples return group_samples - async def judge_sample(self, rollout_state: RolloutState) -> RolloutState: - if self.judger is None: - return rollout_state - if callable(self.judger): - rollout_state = await self.judger(rollout_state) - elif isinstance(self.judger, RouterJudger) or isinstance(self.judger, NativeJudger): - rollout_state = await self.judger.judge(rollout_state) # type: ignore[operator] - else: - raise ValueError(f"Invalid judger type: {type(self.judger)}") - return rollout_state + +class RouterAgentLoop: + def __init__(self, workers: list[RayAgentLoopProxy], rollout_ctl: RolloutController): + self.workers = workers + self.rollout_ctl = rollout_ctl + self._worker_loads = dict.fromkeys(workers, 0) + self._rr_index = 0 + self._lock = asyncio.Lock() + + async def _pick_worker(self) -> RayAgentLoopProxy: + async with self._lock: + min_load = min(self._worker_loads.values()) + candidates = [worker for worker in self.workers if self._worker_loads[worker] == min_load] + worker = candidates[self._rr_index % len(candidates)] + self._rr_index = (self._rr_index + 1) % len(self.workers) + self._worker_loads[worker] += 1 + return worker + + async def _release_worker(self, worker: RayAgentLoopProxy) -> None: + async with self._lock: + self._worker_loads[worker] -= 1 + + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: + worker = await self._pick_worker() + try: + return await worker.generate_sample.remote(rollout_state, **kwargs) + finally: + await self._release_worker(worker) + + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + worker = await self._pick_worker() + try: + return await worker.generate_group.remote(rollout_state, **kwargs) + finally: + await self._release_worker(worker) + + def get_worker_status(self) -> dict[str, int]: + return {str(worker): load for worker, load in self._worker_loads.items()} + + +async def get_agent_loop_rollout_ctl(agent_loop: AgentLoopSpec) -> RolloutController: + rollout_ctl = getattr(agent_loop, "rollout_ctl", None) + if rollout_ctl is not None: + return rollout_ctl + + get_rollout_ctl = getattr(agent_loop, "get_rollout_ctl", None) + if get_rollout_ctl is None or not hasattr(get_rollout_ctl, "remote"): + raise AttributeError(f"Agent loop {type(agent_loop)} does not expose rollout_ctl or get_rollout_ctl().") + return await get_rollout_ctl.remote() + + +class AgentLoopActor: + def __init__( + self, + agent_loop_config: AgentLoopConfig, + rollout_controller: RolloutController, + judger: JudgerSpec = None, + logger=None, + ): + self.agent_loop = agent_loop_config.build_local( + rollout_controller=rollout_controller, + judger=judger, + logger=logger, + ) + + @ray_method + async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> RolloutState: + return await self.agent_loop.generate_sample(rollout_state, **kwargs) + + @ray_method + async def generate_group(self, rollout_state: list[RolloutState], **kwargs) -> list[RolloutState]: + return await self.agent_loop.generate_group(rollout_state, **kwargs) + + @ray_method + async def get_rollout_ctl(self): + return self.agent_loop.rollout_ctl + + +RayAgentLoop = cast(ActorClass[AgentLoopActor], CPUActorLauncher.to_actor_class(AgentLoopActor)) +RayAgentLoopProxy: TypeAlias = ActorProxy[AgentLoopActor] +AgentLoopSpec: TypeAlias = AgentLoop | RayAgentLoopProxy | RouterAgentLoop diff --git a/xtuner/v1/rl/agent_loop/agent_loop_manager.py b/xtuner/v1/rl/agent_loop/agent_loop_manager.py index b14fd65692..fb27bd9d82 100644 --- a/xtuner/v1/rl/agent_loop/agent_loop_manager.py +++ b/xtuner/v1/rl/agent_loop/agent_loop_manager.py @@ -4,17 +4,17 @@ from dataclasses import dataclass from pathlib import Path -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from xtuner.v1.data_proto import RolloutState, Status -from xtuner.v1.rl.judger import Judger +from xtuner.v1.rl.judger import JudgerCallable, JudgerConfig, JudgerConfigLike, JudgerSpecConfig from xtuner.v1.rl.replay_buffer import ReplayBuffer from xtuner.v1.rl.rollout import RolloutController, continue_generation, pause_generation from xtuner.v1.rl.utils import asyncio_run from xtuner.v1.utils import get_logger -from .agent_loop import AgentLoop, AgentLoopConfig +from .agent_loop import AgentLoopConfig, AgentLoopSpec, get_agent_loop_rollout_ctl from .producer import ProducerTimings, ProduceStrategy, ProduceStrategyConfig, SyncProduceStrategyConfig from .sampler import Sampler, SamplerConfig @@ -55,7 +55,7 @@ class ProduceBatchResult: @dataclass(frozen=True) class _TaskRunner: task_name: str - agent_loop: AgentLoop + agent_loop: AgentLoopSpec produce_strategy: ProduceStrategy sampler: Sampler weight: float = 1.0 @@ -137,9 +137,17 @@ class TaskSpecConfig(BaseModel): task_name: str weight: float = Field(default=1.0, ge=0.0) agent_loop_config: AgentLoopConfig + judger_config: JudgerConfig | dict[str, JudgerConfigLike] | JudgerCallable | JudgerSpecConfig | None = None produce_strategy_config: ProduceStrategyConfig = SyncProduceStrategyConfig() sampler_config: SamplerConfig + @field_validator("judger_config", mode="after") + @classmethod + def _normalize_judger_config(cls, value): + if value is None or isinstance(value, JudgerSpecConfig): + return value + return JudgerSpecConfig.from_value(value) + class AgentLoopManagerConfig(BaseModel): model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) @@ -149,7 +157,6 @@ class AgentLoopManagerConfig(BaseModel): def build( self, rollout_controller: RolloutController, - judger: Judger, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, replay_buffer: ReplayBuffer, logger=None, @@ -167,7 +174,7 @@ def build( agent_loop = task_cfg.agent_loop_config.build( rollout_controller=rollout_controller, - judger=judger, + judger=task_cfg.judger_config.build() if task_cfg.judger_config is not None else None, logger=logger, ) produce_strategy = task_cfg.produce_strategy_config.build() @@ -328,7 +335,7 @@ async def produce_batch(self, batch_size: int, rollout_step: int = 0) -> Produce if len(self.task_runners) == 1: task = self.task_runners[0] - rollout_ctl = task.agent_loop.rollout_ctl + rollout_ctl = await get_agent_loop_rollout_ctl(task.agent_loop) await continue_generation(rollout_ctl) try: return await _produce_single_task_batch( @@ -348,7 +355,7 @@ async def produce_batch(self, batch_size: int, rollout_step: int = 0) -> Produce results: list[ProduceBatchResult] = [] if active_tasks: - rollout_ctl = active_tasks[0].agent_loop.rollout_ctl + rollout_ctl = await get_agent_loop_rollout_ctl(active_tasks[0].agent_loop) await continue_generation(rollout_ctl) try: results = await asyncio.gather( diff --git a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py index 39a37cb820..2000cd85cd 100644 --- a/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py +++ b/xtuner/v1/rl/agent_loop/gsm8k_with_tool.py @@ -6,7 +6,8 @@ from pydantic import BaseModel, ConfigDict from xtuner.v1.data_proto import RolloutState, SampleParams -from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig +from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig, JudgerSpec +from xtuner.v1.rl.judger import judge_sample from xtuner.v1.rl.rollout import RolloutController from xtuner.v1.utils import get_logger @@ -17,7 +18,7 @@ class GSM8KToolAgentLoopConfig(AgentLoopConfig): max_turns: int - def build(self, rollout_controller, judger=None, logger=None) -> "GSM8KToolAgentLoop": + def build_local(self, rollout_controller, judger: JudgerSpec = None, logger=None) -> "GSM8KToolAgentLoop": return GSM8KToolAgentLoop( max_turns=self.max_turns, rollout_ctl=rollout_controller, @@ -41,7 +42,7 @@ def __init__( rollout_ctl: RolloutController, hf_checkpoint: str, sample_params: SampleParams, - judger=None, + judger: JudgerSpec = None, ): super().__init__( rollout_ctl=rollout_ctl, hf_checkpoint=hf_checkpoint, sample_params=sample_params, judger=judger @@ -151,5 +152,5 @@ async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> Rollou assert len(rollout_state.response_ids) == len(rollout_state.response_mask) == len(rollout_state.logprobs), ( f"{len(rollout_state.response_ids)} vs {len(rollout_state.response_mask)} vs {len(rollout_state.logprobs)}" ) - rollout_state = await self.judge_sample(rollout_state) + rollout_state = await judge_sample(self.judger, rollout_state) return rollout_state diff --git a/xtuner/v1/rl/agent_loop/producer.py b/xtuner/v1/rl/agent_loop/producer.py index 4cd0f2efd3..ddf3e8e78f 100644 --- a/xtuner/v1/rl/agent_loop/producer.py +++ b/xtuner/v1/rl/agent_loop/producer.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from typing import Protocol, runtime_checkable +import ray from pydantic import BaseModel, ConfigDict from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_expired_status @@ -12,7 +13,7 @@ from xtuner.v1.rl.utils import create_task from xtuner.v1.utils import get_logger -from .agent_loop import AgentLoop +from .agent_loop import AgentLoopSpec, get_agent_loop_rollout_ctl from .sampler import Sampler @@ -33,10 +34,13 @@ class ProducerTimings: async def _timed_generate_group( - agent_loop: AgentLoop, rollout_state: list[RolloutState], **kwargs + agent_loop: AgentLoopSpec, rollout_state: list[RolloutState], **kwargs ) -> tuple[list[RolloutState], float]: start = time.perf_counter() - result = await agent_loop.generate_group(rollout_state, **kwargs) + if isinstance(agent_loop, ray.actor.ActorHandle): + result = await agent_loop.generate_group.remote(rollout_state, **kwargs) + else: + result = await agent_loop.generate_group(rollout_state, **kwargs) return result, time.perf_counter() - start @@ -103,7 +107,7 @@ def __init__( @abstractmethod async def produce_batch( self, - agent_loop: AgentLoop, + agent_loop: AgentLoopSpec, sampler: Sampler, replay_buffer: ReplayBuffer, batch_size: int, @@ -115,7 +119,7 @@ async def produce_batch( class SyncProduceStrategy(ProduceStrategy): async def produce_batch( self, - agent_loop: AgentLoop, + agent_loop: AgentLoopSpec, sampler: Sampler, replay_buffer: ReplayBuffer, batch_size: int, @@ -194,10 +198,10 @@ async def _process_leftover_samples(self, replay_buffer: ReplayBuffer, task_name await replay_buffer.put(group, task_name) async def _cleanup_pending_tasks( - self, pending_tasks: set, agent_loop: AgentLoop, replay_buffer: ReplayBuffer, task_name: str + self, pending_tasks: set, agent_loop: AgentLoopSpec, replay_buffer: ReplayBuffer, task_name: str ) -> float: pause_start = time.perf_counter() - rollout_ctl = agent_loop.rollout_ctl + rollout_ctl = await get_agent_loop_rollout_ctl(agent_loop) await pause_generation(rollout_ctl) while len(pending_tasks) > 0: done_task, pending_tasks = await asyncio.wait( @@ -221,7 +225,7 @@ async def _cleanup_pending_tasks( async def produce_batch( self, - agent_loop: AgentLoop, + agent_loop: AgentLoopSpec, sampler: Sampler, replay_buffer: ReplayBuffer, batch_size: int, diff --git a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py index d82494ba8a..04c9451985 100644 --- a/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py +++ b/xtuner/v1/rl/agent_loop/single_turn_agent_loop.py @@ -1,12 +1,13 @@ from xtuner.v1.data_proto import RolloutState, SampleParams, Status +from xtuner.v1.rl.judger import judge_sample from xtuner.v1.rl.rollout import RolloutController -from .agent_loop import AgentLoop, AgentLoopConfig +from .agent_loop import AgentLoop, AgentLoopConfig, JudgerSpec from .utils import PartialRolloutHandler class SingleTurnAgentLoopConfig(AgentLoopConfig): - def build(self, rollout_controller, judger=None, logger=None) -> "SingleTurnAgentLoop": + def build_local(self, rollout_controller, judger: JudgerSpec = None, logger=None) -> "SingleTurnAgentLoop": return SingleTurnAgentLoop( rollout_ctl=rollout_controller, sample_params=self.sample_params, @@ -18,7 +19,12 @@ def build(self, rollout_controller, judger=None, logger=None) -> "SingleTurnAgen class SingleTurnAgentLoop(AgentLoop): def __init__( - self, rollout_ctl: RolloutController, sample_params: SampleParams, hf_checkpoint: str, judger=None, logger=None + self, + rollout_ctl: RolloutController, + sample_params: SampleParams, + hf_checkpoint: str, + judger: JudgerSpec = None, + logger=None, ): super().__init__(rollout_ctl, sample_params, hf_checkpoint, judger, logger) self.max_tokens = self.sample_params.max_tokens @@ -40,5 +46,5 @@ async def generate_sample(self, rollout_state: RolloutState, **kwargs) -> Rollou # 非 COMPLETED 状态(如被截断、放弃等)直接早退,不触发打分 if rollout_state.status != Status.COMPLETED: return rollout_state - rollout_state = await self.judge_sample(rollout_state) + rollout_state = await judge_sample(self.judger, rollout_state) return rollout_state diff --git a/xtuner/v1/rl/judger/__init__.py b/xtuner/v1/rl/judger/__init__.py index 76c291d99b..73030fe35a 100644 --- a/xtuner/v1/rl/judger/__init__.py +++ b/xtuner/v1/rl/judger/__init__.py @@ -1,4 +1,13 @@ from .dapo_math import DapoMathJudgerConfig +from .factory import ( + JudgerCallable, + JudgerConfigLike, + JudgerConfigSpec, + JudgerLike, + JudgerSpec, + JudgerSpecConfig, + judge_sample, +) from .geo3k import GEO3KJudgerConfig from .gsm8k import GSM8KJudgerConfig from .native import ( diff --git a/xtuner/v1/rl/judger/factory.py b/xtuner/v1/rl/judger/factory.py new file mode 100644 index 0000000000..204e84e1d1 --- /dev/null +++ b/xtuner/v1/rl/judger/factory.py @@ -0,0 +1,119 @@ +import inspect +from typing import Awaitable, Callable, TypeAlias + +import ray + +from xtuner.v1.data_proto import RolloutState + +from .native import Judger, JudgerConfig, RayJudgerProxy + + +JudgerCallable: TypeAlias = Callable[[RolloutState], RolloutState | Awaitable[RolloutState]] +JudgerLike: TypeAlias = Judger | RayJudgerProxy | JudgerCallable +JudgerSpec: TypeAlias = JudgerLike | dict[str, JudgerLike] | None +JudgerConfigLike: TypeAlias = JudgerConfig | JudgerCallable +JudgerConfigSpec: TypeAlias = JudgerConfigLike | dict[str, JudgerConfigLike] | None + + +class JudgerSpecConfig: + def __init__(self, judger_config: JudgerConfigSpec): + self.judger_config = judger_config + + @classmethod + def from_judger_config(cls, judger_config: JudgerConfig) -> "JudgerSpecConfig": + return cls(judger_config) + + @classmethod + def from_judger_config_dict(cls, judger_config: dict[str, JudgerConfigLike]) -> "JudgerSpecConfig": + return cls(judger_config) + + @classmethod + def from_judger_callable(cls, judger_callable: JudgerCallable) -> "JudgerSpecConfig": + return cls(judger_callable) + + @classmethod + def from_value(cls, judger_config: JudgerConfigSpec) -> "JudgerSpecConfig": + return cls(judger_config) + + def build(self) -> JudgerSpec: + judger_config = self.judger_config + if judger_config is None: + return None + + if isinstance(judger_config, dict): + judger_dict = {} + for key, config in judger_config.items(): + if isinstance(config, JudgerConfig): + judger_dict[key] = config.build() + elif callable(config): + judger_dict[key] = config + else: + raise ValueError(f"Invalid judger config type: {type(config)} for key {key}") + return judger_dict + + if isinstance(judger_config, JudgerConfig): + return judger_config.build() + + if callable(judger_config): + return judger_config + + raise ValueError(f"Invalid judger config type: {type(judger_config)}") + + +def _resolve_judger_from_dict(judger_dict: dict[str, JudgerLike], rollout_state: RolloutState) -> JudgerLike: + if not judger_dict: + raise ValueError("judger dict must not be empty.") + + candidate_keys: list[str] = [] + if rollout_state.task_name: + candidate_keys.append(rollout_state.task_name) + + data_source = rollout_state.data_source + if isinstance(data_source, str): + candidate_keys.append(data_source) + elif isinstance(data_source, dict): + for field in ("name", "id", "type", "data_source"): + value = data_source.get(field) + if isinstance(value, str): + candidate_keys.append(value) + + for key in candidate_keys: + if key in judger_dict: + return judger_dict[key] + + if "default" in judger_dict: + return judger_dict["default"] + + if len(judger_dict) == 1: + return next(iter(judger_dict.values())) + + raise KeyError( + "Unable to resolve judger from dict with " + f"task_name={rollout_state.task_name!r}, data_source={rollout_state.data_source!r}, " + f"available_keys={sorted(judger_dict)}" + ) + + +async def judge_sample(judger: JudgerSpec, rollout_state: RolloutState) -> RolloutState: + if judger is None: + return rollout_state + + if isinstance(judger, dict): + judger = _resolve_judger_from_dict(judger, rollout_state) + + if isinstance(judger, Judger): + rollout_state = await judger.judge(rollout_state) + elif isinstance(judger, ray.actor.ActorHandle): + rollout_state = await judger.judge.remote(rollout_state) + elif callable(judger): + judger_result = judger(rollout_state) + if inspect.isawaitable(judger_result): + rollout_state = await judger_result + else: + rollout_state = judger_result + else: + raise ValueError(f"Invalid judger type: {type(judger)}") + + if not isinstance(rollout_state, RolloutState): + raise TypeError(f"Judger must return RolloutState, but got {type(rollout_state)}") + return rollout_state diff --git a/xtuner/v1/rl/judger/native.py b/xtuner/v1/rl/judger/native.py index f6c0af2644..956ad70982 100644 --- a/xtuner/v1/rl/judger/native.py +++ b/xtuner/v1/rl/judger/native.py @@ -1,15 +1,17 @@ +from __future__ import annotations + import asyncio import inspect from abc import ABC, abstractmethod -from typing import Callable, List, Literal, TypeAlias, cast +from typing import Callable, Literal, TypeAlias, cast import httpx -import ray from pydantic import BaseModel, ConfigDict, Field, model_validator from ray.actor import ActorClass, ActorProxy from ray.util.placement_group import PlacementGroup from xtuner.v1.data_proto.rl_data import RolloutState +from xtuner.v1.rl.utils import CPUActorLauncher from xtuner.v1.utils.logger import get_logger from xtuner.v1.utils.type_helper import ray_method @@ -72,7 +74,9 @@ async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ign else: judger_response = self.reward_handler(**input_kwargs) assert judger_response is not None, "Reward handler did not return a response." - # native postprocess + assert isinstance(judger_response, dict), ( + f"Reward handler must return a dict, but got {type(judger_response)}." + ) rollout_state.reward = judger_response return rollout_state @@ -85,13 +89,6 @@ def get_judger_name(self) -> str: return self._judger_name -# For type hint and IDE support. For more info, please refer to: -# 1. https://docs.ray.io/en/latest/ray-core/actors.html#type-hints-and-static-typing-for-actors -# 2. https://github.com/InternLM/xtuner/pull/1349 -RayJudger = cast(ActorClass[NativeJudger], ray.remote(NativeJudger)) -RayJudgerProxy: TypeAlias = ActorProxy[NativeJudger] - - class RouterJudger(Judger): """NativeJudger 路由管理器。 @@ -100,14 +97,15 @@ class RouterJudger(Judger): 2. 当负载相同时,通过轮询(Round-robin)分配任务。 """ - def __init__(self, workers: List[RayJudgerProxy], judger_name: str): + def __init__(self, workers: list[RayJudgerProxy], judger_name: str): self.workers = workers self._worker_loads = dict.fromkeys(workers, 0) self._rr_index = 0 self._lock = asyncio.Lock() self._judger_name = judger_name - async def judge(self, rollout_state: RolloutState) -> RolloutState: + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: # type: ignore[override] async with self._lock: min_load = min(self._worker_loads.values()) candidates: list[RayJudgerProxy] = [w for w in self.workers if self._worker_loads[w] == min_load] @@ -147,10 +145,7 @@ class JudgerConfig(BaseModel): ) @model_validator(mode="after") - def _validate_ray_actor_config(self) -> "JudgerConfig": - if self.judger_type == "ray.actor" and self.num_ray_actors > 1: - logger.warning("num_ray_actors will be set to 1 when judger_type is 'ray.actor'.") - self.num_ray_actors = 1 + def _validate_ray_actor_config(self) -> JudgerConfig: if self.judger_type == "native": if self.num_ray_actors > 1 or self.num_cpus_per_actor > 1 or self.cpu_memory_per_actor != 1024**3: logger.warning( @@ -158,89 +153,89 @@ def _validate_ray_actor_config(self) -> "JudgerConfig": ) return self - def _build_worker(self, pg: PlacementGroup | None = None, bundle_idx: int = 0) -> RayJudgerProxy: - pg_options = {"num_cpus": self.num_cpus_per_actor, "memory": self.cpu_memory_per_actor} - if pg is None: - # NOTE: 保持与 router 构建逻辑一致,默认创建 PlacementGroup。 - from xtuner.v1.rl.utils.ray_worker import CPUResourcesConfig - - cpu_resource_cfg = CPUResourcesConfig( - num_workers=self.num_ray_actors, - num_cpus_per_worker=self.num_cpus_per_actor, - cpu_memory_per_worker=self.cpu_memory_per_actor, - ) - pg = cpu_resource_cfg.build_placement_group() - ray.get(pg.ready()) - bundle_idx = 0 - - assert len(pg.bundle_specs) > bundle_idx, "Placement group does not have enough bundles for ray actor." - assert pg.bundle_specs[bundle_idx].get("CPU", 1) >= self.num_cpus_per_actor, ( - f"Placement group bundle {bundle_idx} does not have enough CPU resources." - ) - assert pg.bundle_specs[bundle_idx].get("memory", 0) >= self.cpu_memory_per_actor, ( - f"Placement group bundle {bundle_idx} does not have enough memory resources." - ) - return RayJudger.options( - placement_group=pg, - placement_group_bundle_index=bundle_idx, - **pg_options, - ).remote( + def get_num_placement_group_bundles(self) -> int: + if self.judger_type == "native": + return 0 + return self.num_ray_actors + + def get_cpu_bundles(self) -> list[dict[str, float | int]]: + return [ + { + "CPU": self.num_cpus_per_actor, + "memory": self.cpu_memory_per_actor, + } + for _ in range(self.get_num_placement_group_bundles()) + ] + + def build_local(self) -> Judger: + return NativeJudger( judger_name=self.judger_name, reward_handler=self.reward_handler, request_timeout=self.request_timeout, extra_info=self.extra_info, ) - def _build_workers(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> list[RayJudgerProxy]: - """Create and launch Ray actor instances for router workers. + def _build_ray_actor(self, pg: PlacementGroup | None = None, bundle_idx: int = 0) -> RayJudgerProxy: + return CPUActorLauncher.build_actor( + JudgerActor, + self, + pg=pg, + bundle_idx=bundle_idx, + actor_num_cpus=self.num_cpus_per_actor, + actor_memory=self.cpu_memory_per_actor, + ) - This method instantiates multiple NativeJudger Ray actors according to `num_ray_actors`, - assigning each to a specific bundle in the provided placement group for resource isolation. - Each actor is initialized with the judger's configuration and reward function. + def _build_ray_actor_list( + self, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + ) -> list[RayJudgerProxy]: + return CPUActorLauncher.build_actors( + JudgerActor, + self, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=self.num_ray_actors, + actor_num_cpus_per_worker=self.num_cpus_per_actor, + actor_memory_per_worker=self.cpu_memory_per_actor, + ) - Args: - pg: The Ray PlacementGroup used to allocate resources for the actors. - start_bundle_idx: The starting bundle index in the placement group for actor placement. + def _build_router_workers( + self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0 + ) -> list[RayJudgerProxy]: + return self._build_ray_actor_list(pg=pg, start_bundle_idx=start_bundle_idx) - Returns: - List[ActorClass]: A list of Ray actor handles representing the launched judger workers. - """ - if pg is None: - # NOTE: 这里直接在build_workers里创建PlacementGroup是为了简化用户使用,用户不需要关心PlacementGroup的细节。 - from xtuner.v1.rl.utils.ray_worker import CPUResourcesConfig - - cpu_resource_cfg = CPUResourcesConfig( - num_workers=self.num_ray_actors, - num_cpus_per_worker=self.num_cpus_per_actor, - cpu_memory_per_worker=self.cpu_memory_per_actor, - ) - pg = cpu_resource_cfg.build_placement_group() - ray.get(pg.ready()) - start_bundle_idx = 0 - - workers_list = [] - assert len(pg.bundle_specs) >= self.num_ray_actors, ( - "Placement group does not have enough bundles for the number of ray actors." - ) - for idx in range(self.num_ray_actors): - workers_list.append(self._build_worker(pg=pg, bundle_idx=start_bundle_idx + idx)) - return workers_list + def _build_router(self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0) -> RouterJudger: + workers_list = self._build_router_workers(pg=pg, start_bundle_idx=start_bundle_idx) + return RouterJudger(workers=workers_list, judger_name=self.judger_name) def build( self, pg: PlacementGroup | None = None, start_bundle_idx: int = 0, - ) -> NativeJudger | RayJudgerProxy | RouterJudger: + ) -> Judger | RayJudgerProxy: if self.judger_type == "native": - return NativeJudger( - judger_name=self.judger_name, - reward_handler=self.reward_handler, - request_timeout=self.request_timeout, - extra_info=self.extra_info, - ) + return self.build_local() if self.judger_type == "ray.actor": - return self._build_worker(pg=pg, bundle_idx=start_bundle_idx) + if self.num_ray_actors > 1: + return self._build_router(pg=pg, start_bundle_idx=start_bundle_idx) + return self._build_ray_actor(pg=pg, bundle_idx=start_bundle_idx) - workers_list = self._build_workers(pg=pg, start_bundle_idx=start_bundle_idx) - return RouterJudger(workers=workers_list, judger_name=self.judger_name) + return self._build_router(pg=pg, start_bundle_idx=start_bundle_idx) + + +class JudgerActor: + def __init__(self, judger_config: JudgerConfig): + self.judger = judger_config.build_local() + + @ray_method + async def judge(self, rollout_state: RolloutState) -> RolloutState: + return await self.judger.judge(rollout_state) + + +# For type hint and IDE support. For more info, please refer to: +# 1. https://docs.ray.io/en/latest/ray-core/actors.html#type-hints-and-static-typing-for-actors +# 2. https://github.com/InternLM/xtuner/pull/1349 +RayJudger = cast(ActorClass[JudgerActor], CPUActorLauncher.to_actor_class(JudgerActor)) +RayJudgerProxy: TypeAlias = ActorProxy[JudgerActor] diff --git a/xtuner/v1/rl/utils/__init__.py b/xtuner/v1/rl/utils/__init__.py index 8a0a44799d..d87ef41407 100644 --- a/xtuner/v1/rl/utils/__init__.py +++ b/xtuner/v1/rl/utils/__init__.py @@ -29,6 +29,7 @@ AutoAcceleratorWorkers, AutoCPUWorkers, BaseCPUWorker, + CPUActorLauncher, CPUResourcesConfig, SingleAcceleratorWorker, ) @@ -39,6 +40,7 @@ "SingleAcceleratorWorker", "AutoAcceleratorWorkers", "CPUResourcesConfig", + "CPUActorLauncher", "BaseCPUWorker", "AutoCPUWorkers", "get_ray_accelerator", diff --git a/xtuner/v1/rl/utils/ray_worker.py b/xtuner/v1/rl/utils/ray_worker.py index 5a416c8f10..859d6cc649 100644 --- a/xtuner/v1/rl/utils/ray_worker.py +++ b/xtuner/v1/rl/utils/ray_worker.py @@ -1,4 +1,5 @@ import os +import threading from typing import Any, Dict, List, Literal, Tuple, TypeVar import ray @@ -13,6 +14,7 @@ placement_group, placement_group_table, ) +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from typing_extensions import Annotated from .ray_utils import find_master_addr_and_port, get_accelerator_ids @@ -98,7 +100,7 @@ def build_placement_group(self) -> PlacementGroup: Returns: PlacementGroup: The created Ray PlacementGroup. """ - return AutoCPUWorkers.build_placement_group(self) + return CPUActorLauncher.build_placement_group(self) class AcceleratorResourcesConfig(BaseModel): @@ -562,9 +564,16 @@ def __init__(self, config, num_cpus: float | int = 1): self.num_cpus = num_cpus -class AutoCPUWorkers: - """A utility class for automatically creating and managing cpu actors - within a Ray PlacementGroup.""" +class CPUActorLauncher: + """Infrastructure for launching CPU Ray actors from plain Python classes. + + This class owns the generic actorization flow for CPU-only components: + building homogeneous CPU placement groups, converting plain classes into + Ray actor classes, validating bundle resources, and launching one or more + actors on specific bundles. + """ + + _ACTOR_CLASS_CACHE: dict[type, ActorClass] = {} @staticmethod def build_placement_group(resources_config: CPUResourcesConfig): @@ -608,6 +617,164 @@ def get_pg_options(pg: PlacementGroup, num_cpus: int | float = -1) -> Dict: default_cpu = pg.bundle_specs[0].get("CPU", 1) return {"num_cpus": num_cpus if num_cpus >= 0 else default_cpu} + @classmethod + def to_actor_class(cls, worker_cls): + """Convert a plain Python class into a Ray actor class. + + If ``worker_cls`` is already a Ray actor class, it is returned as-is. + """ + if hasattr(worker_cls, "remote") and hasattr(worker_cls, "options"): + return worker_cls + + if worker_cls not in cls._ACTOR_CLASS_CACHE: + cls._ACTOR_CLASS_CACHE[worker_cls] = ray.remote(worker_cls) + return cls._ACTOR_CLASS_CACHE[worker_cls] + + @staticmethod + def _get_bundle_resources(pg: PlacementGroup, bundle_idx: int) -> dict[str, float | int]: + assert len(pg.bundle_specs) > bundle_idx, f"Placement group does not have bundle index {bundle_idx}." + return pg.bundle_specs[bundle_idx] + + @classmethod + def _resolve_actor_resources( + cls, + pg: PlacementGroup, + bundle_idx: int, + actor_num_cpus: int | float | None = None, + actor_memory: int | None = None, + ) -> tuple[float | int, int]: + bundle = cls._get_bundle_resources(pg, bundle_idx) + resolved_num_cpus = actor_num_cpus if actor_num_cpus is not None else bundle.get("CPU", 1) + resolved_memory = actor_memory if actor_memory is not None else int(bundle.get("memory", 0)) + assert bundle.get("CPU", 1) >= resolved_num_cpus, ( + f"Placement group bundle {bundle_idx} does not have enough CPU resources." + ) + assert bundle.get("memory", 0) >= resolved_memory, ( + f"Placement group bundle {bundle_idx} does not have enough memory resources." + ) + return resolved_num_cpus, resolved_memory + + @classmethod + def build_actor( + cls, + worker_cls, + *init_args, + pg: PlacementGroup | None = None, + bundle_idx: int = 0, + actor_num_cpus: int | float | None = None, + actor_memory: int | None = None, + capture_child_tasks: bool = False, + **init_kwargs, + ): + """Build a single CPU actor from a plain class or Ray actor class.""" + resolved_num_cpus = 1 if actor_num_cpus is None else actor_num_cpus + resolved_memory = actor_memory + + actor_cls = cls.to_actor_class(worker_cls) + actor_options = { + "num_cpus": resolved_num_cpus, + } + if resolved_memory is not None and resolved_memory > 0: + actor_options["memory"] = resolved_memory + + if pg is None: + return actor_cls.options(**actor_options).remote(*init_args, **init_kwargs) + + resolved_num_cpus, resolved_memory = cls._resolve_actor_resources( + pg=pg, + bundle_idx=bundle_idx, + actor_num_cpus=actor_num_cpus, + actor_memory=actor_memory, + ) + actor_options["num_cpus"] = resolved_num_cpus + if resolved_memory > 0: + actor_options["memory"] = resolved_memory + actor_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=bundle_idx, + placement_group_capture_child_tasks=capture_child_tasks, + ) + return actor_cls.options(**actor_options).remote(*init_args, **init_kwargs) + + @classmethod + def build_actors( + cls, + worker_cls, + *init_args, + pg: PlacementGroup | None = None, + start_bundle_idx: int = 0, + num_workers: int = 1, + actor_num_cpus_per_worker: int | float | None = None, + actor_memory_per_worker: int | None = None, + capture_child_tasks: bool = False, + **init_kwargs, + ): + """Build multiple homogeneous CPU actors from a plain class or Ray + actor class.""" + workers_list = [] + for idx in range(num_workers): + workers_list.append( + cls.build_actor( + worker_cls, + *init_args, + pg=pg, + bundle_idx=start_bundle_idx + idx, + actor_num_cpus=actor_num_cpus_per_worker, + actor_memory=actor_memory_per_worker, + capture_child_tasks=capture_child_tasks, + **init_kwargs, + ) + ) + return workers_list + + +class AutoCPUWorkers(CPUActorLauncher): + """Convenience wrapper for BaseCPUWorker-style homogeneous worker pools. + + `CPUActorLauncher` is the generic actorization layer. `AutoCPUWorkers` + keeps the legacy worker-centric API that instantiates one worker per bundle + using the conventional `(worker_config, num_cpus=...)` constructor shape. + """ + + _PG_NEXT_BUNDLE_INDEX: dict[str, int] = {} + _PG_NEXT_BUNDLE_INDEX_LOCK = threading.Lock() + + @staticmethod + def _get_pg_key(pg: PlacementGroup) -> str: + """Build a stable placement-group identifier for local bundle + tracking.""" + return str(pg.id) + + @classmethod + def _reserve_bundle_range( + cls, + pg: PlacementGroup, + num_workers: int, + start_bundle_idx: int | None, + ) -> tuple[int, int]: + """Reserve a contiguous bundle range for worker creation. + + When ``start_bundle_idx`` is omitted, the next unconsumed bundle range + in this process is used. Explicit bundle reservations still advance the + local cursor so later auto-allocation does not reuse the same bundles. + """ + pg_key = cls._get_pg_key(pg) + + with cls._PG_NEXT_BUNDLE_INDEX_LOCK: + current_cursor = cls._PG_NEXT_BUNDLE_INDEX.get(pg_key, 0) + resolved_start_bundle_idx = current_cursor if start_bundle_idx is None else start_bundle_idx + resolved_num_workers = num_workers if num_workers > 0 else pg.bundle_count - resolved_start_bundle_idx + + assert resolved_num_workers > 0, "At least one worker must be created from the placement group." + assert resolved_start_bundle_idx >= 0, "start_bundle_idx must be non-negative." + assert resolved_start_bundle_idx + resolved_num_workers <= pg.bundle_count, ( + "Placement group does not have enough remaining bundles for the requested CPU workers." + ) + + cls._PG_NEXT_BUNDLE_INDEX[pg_key] = max(current_cursor, resolved_start_bundle_idx + resolved_num_workers) + + return resolved_start_bundle_idx, resolved_num_workers + @classmethod def from_config(cls, worker_cls, worker_config, cpu_config: CPUResourcesConfig): """Create workers and a placement group from configuration objects. @@ -621,13 +788,20 @@ def from_config(cls, worker_cls, worker_config, cpu_config: CPUResourcesConfig): Returns: List[T]: List of created worker instances. """ - pg = AutoCPUWorkers.build_placement_group(cpu_config) + pg = cls.build_placement_group(cpu_config) workers_list = cls.from_placement_group(worker_cls, worker_config, pg) return workers_list, pg @classmethod - def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup, num_workers: int = -1): + def from_placement_group( + cls, + worker_cls, + worker_config, + pg: PlacementGroup, + num_workers: int = -1, + start_bundle_idx: int | None = None, + ): """Create workers from an existing placement group. Args: @@ -635,19 +809,25 @@ def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup, num worker_config: The configuration for each worker instance. pg (PlacementGroup): The existing placement group to use. num_workers (int): The number of workers to create. Defaults to -1, - the number of bundles in the placement group will be used. + the remaining bundles in the placement group will be used. + start_bundle_idx (int | None): Bundle index to start from. If + omitted, the next unconsumed local bundle range for this + placement group will be used. Returns: List[T]: List of created worker instances. """ - pg_options = cls.get_pg_options(pg) - - num_workers = num_workers if num_workers > 0 else pg.bundle_count - workers_list = [] - for _ in range(num_workers): - worker = worker_cls.options(placement_group=pg, **pg_options).remote( - worker_config, num_cpus=pg_options.get("num_cpus", 1) - ) # type: ignore[attr-defined] - workers_list.append(worker) - - return workers_list + start_bundle_idx, num_workers = cls._reserve_bundle_range( + pg=pg, num_workers=num_workers, start_bundle_idx=start_bundle_idx + ) + default_cpu = cls._get_bundle_resources(pg, start_bundle_idx).get("CPU", 1) + return cls.build_actors( + worker_cls, + worker_config, + num_cpus=default_cpu, + pg=pg, + start_bundle_idx=start_bundle_idx, + num_workers=num_workers, + actor_num_cpus_per_worker=default_cpu, + actor_memory_per_worker=None, + ) diff --git a/xtuner/v1/train/rl_colocate_trainer.py b/xtuner/v1/train/rl_colocate_trainer.py index df56c56462..a4e142a4c3 100644 --- a/xtuner/v1/train/rl_colocate_trainer.py +++ b/xtuner/v1/train/rl_colocate_trainer.py @@ -19,7 +19,6 @@ from xtuner.v1.patch import patch_default_save_plan from xtuner.v1.rl.agent_loop import AgentLoopManagerConfig, ProduceBatchResult from xtuner.v1.rl.evaluator import EvaluatorConfig -from xtuner.v1.rl.judger import JudgerConfig from xtuner.v1.rl.replay_buffer import AsyncReplayBufferConfig, SyncReplayBufferConfig from xtuner.v1.rl.rollout.controller import RolloutControllerProxy from xtuner.v1.rl.rollout.worker import RolloutConfig @@ -156,7 +155,6 @@ class RLColocateTrainerConfig(BaseModel): resources: AcceleratorResourcesConfig train_worker_cfg: WorkerConfig rollout_config: RolloutConfig - judger_config: JudgerConfig tokenizer_path: Union[str, Path] replay_buffer_config: SyncReplayBufferConfig | AsyncReplayBufferConfig = SyncReplayBufferConfig() agent_loop_manager_cfg: AgentLoopManagerConfig @@ -188,7 +186,6 @@ def build(self) -> "RLColocateTrainer": resources=self.resources, train_worker_cfg=self.train_worker_cfg, rollout_config=self.rollout_config, - judger_config=self.judger_config, tokenizer_path=self.tokenizer_path, replay_buffer_config=self.replay_buffer_config, agent_loop_manager_cfg=self.agent_loop_manager_cfg, @@ -231,7 +228,6 @@ def __init__( resources: AcceleratorResourcesConfig, train_worker_cfg: WorkerConfig, rollout_config: RolloutConfig, - judger_config: JudgerConfig, # Sampler config # sampler_config: SamplerConfig, tokenizer_path: str | Path, @@ -345,15 +341,11 @@ def __init__( self.rollout_controller = rollout_config.build(self._pg) - # build judger - judger = judger_config.build() - replay_buffer = replay_buffer_config.build() self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) # build agnet_loop_manager self.agent_loop_manager = agent_loop_manager_cfg.build( rollout_controller=self.rollout_controller, - judger=judger, tokenizer=self.tokenizer, replay_buffer=replay_buffer, logger=self.logger, @@ -362,7 +354,6 @@ def __init__( # build eval agent loop manager self.eval_agent_loop_manager = eval_agent_loop_manager_cfg.build( rollout_controller=self.rollout_controller, - judger=judger, tokenizer=self.tokenizer, replay_buffer=replay_buffer, logger=self.logger,