Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xtuner/v1/rl/rollout/lmdeploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,12 @@ def _wake_up(self, tags: List[str] | None = None):
assert response.status_code == 200, response.status_code
return response.text

def _decode_routed_experts(self, routed_experts: Any) -> Any:
async def _decode_routed_experts(self, routed_experts: Any) -> Any:
if isinstance(routed_experts, str):
if self.lmdeploy_actor is None:
self.lmdeploy_actor = ray.get_actor(SHARED_STORE, namespace=SHARED_STORE_NAMESPACE)
assert self.lmdeploy_actor is not None, "LMDeploy actor should be available in the shared store."
routed_experts_data = ray.get(self.lmdeploy_actor.get.remote(routed_experts))
routed_experts_data = await self.lmdeploy_actor.get.remote(routed_experts)
return ray.put(routed_experts_data)
return torch.tensor(routed_experts)

Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/rl/rollout/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def reset_prefix_cache(self):
self.flush_cache()
return self._make_request("release_memory_occupation")

def _decode_routed_experts(self, routed_experts: Any):
async def _decode_routed_experts(self, routed_experts: Any):
if isinstance(routed_experts, str):
routed_experts_flat = np.frombuffer(base64.b64decode(routed_experts), dtype=np.int32)
routed_experts_array = routed_experts_flat.reshape(
Expand Down
2 changes: 1 addition & 1 deletion xtuner/v1/rl/rollout/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def offload(self):
def reset_prefix_cache(self, tags: List[str] | None = None):
raise NotImplementedError("The 'reset_prefix_cache' API is not yet implemented in the vLLM server.")

def _decode_routed_experts(self, routed_experts: Any):
async def _decode_routed_experts(self, routed_experts: Any):
raise NotImplementedError

def _transform_rollout_config_to_server_configs(self) -> Namespace:
Expand Down
31 changes: 25 additions & 6 deletions xtuner/v1/rl/rollout/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def check_health(self) -> bool:
self.logger.error(f"Health check failed for server {self.server_url}: {e}")
return False

def _decode_routed_experts(self, routed_experts: Any) -> Any:
async def _decode_routed_experts(self, routed_experts: Any) -> Any:
return routed_experts

async def generate(self, rollout_state: RolloutState) -> RolloutState:
Expand Down Expand Up @@ -835,11 +835,30 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response
logprobs: list[float] = []
routed_experts = None
returned_response = ""
finish_reason = response["meta_info"]["finish_reason"]["type"]
if finish_reason == "abort" and self.receive_abort_request.is_set() is False:
self.receive_abort_request.set()
self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}")
try:
meta_info = response.get("meta_info") or {}
finish_reason_info = meta_info.get("finish_reason") or {}
finish_reason = finish_reason_info.get("type")
if finish_reason is None:
if self.receive_abort_request.is_set():
rollout_state.finish_reason = "abort"
rollout_state.status = Status.ABORTED
self.logger.warning(
f"finish_reason is missing in response meta_info when waiting for aborted message {uid}, defaulting to 'abort'. Response: {response}"
)
else:
rollout_state.finish_reason = "error"
rollout_state.status = Status.FAILED
self.logger.warning(
f"finish_reason is missing in response meta_info for message {uid}, defaulting to 'error'. Response: {response}"
)
rollout_state.error_msg = "Missing finish_reason in response meta_info"
return rollout_state

if finish_reason == "abort" and self.receive_abort_request.is_set() is False:
self.receive_abort_request.set()
self.logger.info(f"Setting receive_abort_request to True for rank {self.rank}")

returned_response = response.get("text", "")
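# Get response_ids (the original comment repeated the name as "response_ids && respoonse_ids"; the second item presumably meant logprobs — TODO confirm)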
if (
Expand All @@ -859,7 +878,7 @@ async def _safe_handle_response(self, rollout_state: RolloutState, http_response
)
routed_experts = response["meta_info"]["routed_experts"] # token[layer[expert]]
if routed_experts is not None:
routed_experts = self._decode_routed_experts(routed_experts)
routed_experts = await self._decode_routed_experts(routed_experts)
if not isinstance(routed_experts, ray.ObjectRef):
routed_experts = ray.put(routed_experts)

Expand Down
Loading