
[Rollout] Part 2: adapt RolloutController and Worker to RolloutState and adjust interface#1488

Merged
YanhuiDua merged 6 commits into InternLM:rl_design from YanhuiDua:dyh/rl_design
Feb 10, 2026

Conversation

@YanhuiDua (Collaborator) commented Feb 10, 2026

RolloutController interface definition

class RolloutWorkerMetadata(TypedDict):
    # Mainly used for updating weights and exposing worker URLs; more fields
    # may be added later, so this stays a TypedDict for now.
    engine_rank_mesh_array: List[List[int]]
    server_url_dict: Dict[str, List[str]]  # rank -> url
    rollout_config: RolloutConfig
    worker_server_urls_status: Dict[str, bool]

class RolloutController:
    def __init__(self, infer_config: RolloutConfig, placement_group: PlacementGroup): ...
    # Health check; not implemented yet. Planned as a background process
    # that runs periodically.
    def check_health(self) -> None: ...
    # Stop rollout: abort in-flight requests and/or set a paused status flag.
    def pause_generation(self) -> None: ...
    # Resume rollout: restore the relevant status flags.
    def continue_generation(self) -> None: ...
    # offload
    def offload(self) -> None: ...
    # onload weights + kv-cache. Since they are almost always onloaded
    # together, a single combined interface is exposed.
    def onload(self) -> None: ...
    def onload_weights(self) -> None: ...
    def onload_kvcache(self) -> None: ...
    # The main generation entry point.
    def generate(self, state: RolloutState) -> RolloutState: ...
    def get_rollout_metadata(self) -> RolloutWorkerMetadata: ...

RolloutWorker interface definition

class RolloutWorker:
    def __init__(self, config, rank, master_addr, master_port, world_size, accelerator: str = "GPU"): ...

    # Common API for RolloutController
    def init(self, dist_init_addr: str) -> tuple[int, str]: ...
    def init_dist_port(self) -> str: ...
    def shutdown(self) -> None: ...
    def pause_generation(self) -> None: ...
    def continue_generation(self) -> None: ...
    def check_health(self) -> bool: ...
    async def generate(self, rollout_state: RolloutState) -> RolloutState: ...

    # Abstract API; must be implemented per backend (lmdeploy/vllm/sglang)
    def _transform_sample_params(self, sample_params: Dict) -> dict: ...
    def _transform_rollout_config_to_server_configs(self) -> Namespace: ...
    def offload(self): ...
    def onload_weights(self): ...
    def onload_kvcache(self): ...
    async def _get_request_payload(self, url: str, rollout_state: RolloutState) -> dict: ...

    # Important and common internal API used in the init and generate logic
    def _launch_server(self): ...
    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: ...
    async def _safe_handle_response(self, rollout_state: RolloutState, http_response: httpx.Response) -> RolloutState: ...
   

TODO:

  • @YanhuiDua: the current worker_info management is poorly written; the printing and counting are unreasonable and it is of little practical use (only enabled when debugging). Remove this logic for now and add it back later.
  • @YanhuiDua: for now, drop the logic that judges worker availability from the number of failed samples; it is effectively unused. A better check_health mechanism is needed to decide whether a worker is available.
  • @YanhuiDua: support claude format input
  • @YanhuiDua: work out how the new input/output (RolloutState) fits the PartialRollout logic; get it running first.
  • @YanhuiDua: remove streaming responses for now since they are not yet needed; add them back when required.
  • @YanhuiDua: merge test_rollout_v2 into test_rollout; test_mock_rollout has already been updated.
  • @hhaAndroid: once mm_tokenize_fn is supported, add image_data handling to the _get_request_payload function.
  • After all Rollout features are supported, review the internal interfaces and merge or delete the ones that are not needed.

Comment thread xtuner/v1/data_proto/rl_data.py
@YanhuiDua YanhuiDua merged commit dd00ae7 into InternLM:rl_design Feb 10, 2026
0 of 3 checks passed
response_ids: list[int] | None = None
logprobs: list[float] | None = None
routed_experts: list[int] | RayObjectRef | None = None # type: ignore[valid-type]
finish_reason: str | None = None
@YanhuiDua (Collaborator, Author) commented Feb 11, 2026
Considering multi-turn: if one prompt produces different trajectories via different tool calls, how should they be recorded? As a new RolloutState per trajectory, or is it better for one state to carry the multi-turn information?

error_msg: str | None = None
seq_staleness: int = 0  # staleness of the whole sequence, usually the max token_staleness
token_staleness: list[int] | None = None  # per-token staleness, same length as tokens
loss_mask: list[int] | None = None  # length of tokens + response_ids
@YanhuiDua (Collaborator, Author) commented:

In the multi-turn case, e.g. when the second turn has to drop the thinking content from the first turn, the loss_mask must be updated at the same time.
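
A minimal illustration of that invariant, with hypothetical names: whenever a token span is removed, the same slice must be applied to loss_mask so the two stay aligned.

```python
def drop_span(tokens: list[int], loss_mask: list[int],
              start: int, end: int) -> tuple[list[int], list[int]]:
    """Remove tokens[start:end] (e.g. a thinking span) and apply the
    identical slice to loss_mask so both sequences stay the same length."""
    assert len(tokens) == len(loss_mask)
    return tokens[:start] + tokens[end:], loss_mask[:start] + loss_mask[end:]

tokens = [10, 11, 12, 13, 14]
mask   = [0,  0,  1,  1,  1]
tokens2, mask2 = drop_span(tokens, mask, 1, 3)  # drop the span at [1, 3)
```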

YanhuiDua added a commit that referenced this pull request Apr 27, 2026
…and adjust interface (#1488)

* [Rollout] Part 1.1: add return_routed_experts in sample_params and add update_status_from_finish_reason

* [Rollout] Part 2: refactor RolloutController interface and use RolloutState

* [Rollout] Part 2.1: adapt RolloutWorker to RolloutController

* [Rollout] Part 2.2: add rollout ut

* [Rollout] fix comments: 1. support error_msg in RolloutState; 2. adjust interface to pause_generation and offload;

* [Rollout] fix comments: delete useless return
