
[skyrl] Pass through profiler_config to vLLM engine #1622

Open

pcmoritz wants to merge 1 commit into NovaSky-AI:main from pcmoritz:torch-profile

Conversation

Collaborator

@pcmoritz pcmoritz commented May 5, 2026

Can be used via

uv run --extra tinker --extra megatron -m skyrl.tinker.api --base-model "Qwen/Qwen3.5-4B" --backend megatron --backend-config '{"trainer": {"train_batch_size":128,"policy_mini_batch_size":128,"micro_train_batch_size_per_gpu":1,"micro_forward_batch_size_per_gpu":1,"use_sample_packing":false,"policy":{"megatron_config":{"optimizer_config_kwargs":{"use_precision_aware_optimizer":true},"transformer_config_kwargs":{"gradient_accumulation_fusion":false}}},"placement":{"colocate_all":false,"policy_num_gpus_per_node":4,"ref_num_gpus_per_node":4}},"generator":{"inference_engine":{"tensor_parallel_size":4,"engine_init_kwargs":{"profiler_config":{"profiler":"torch","torch_profiler_dir":"/mnt/shared_storage/vllm_profile"}}}}}' > out.log
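
For reference, the profiler-related portion of that --backend-config JSON nests as shown below; the profiler_config block is forwarded through generator.inference_engine.engine_init_kwargs to the vLLM engine, and torch_profiler_dir is presumably where the torch profiler traces end up (trainer settings and the other generator keys from the command above are omitted):

{
  "generator": {
    "inference_engine": {
      "tensor_parallel_size": 4,
      "engine_init_kwargs": {
        "profiler_config": {
          "profiler": "torch",
          "torch_profiler_dir": "/mnt/shared_storage/vllm_profile"
        }
      }
    }
  }
}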

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements profiling support for vLLM inference engines and remote clients, adding start_profile and stop_profile functionality across the stack and improving error reporting in the remote client by capturing response bodies on failure. Feedback indicates that AsyncVLLMInferenceEngine requires a lock to serialize profiling during concurrent generate calls, as vLLM's profiler is global and does not support overlapping spans.
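
As a rough illustration of the error-reporting change, a remote client can include the response body in the raised error rather than just the status code. The sketch below assumes an aiohttp-based client; the helper name and error type are hypothetical, not the PR's actual code:

import aiohttp

async def post_or_raise(session: aiohttp.ClientSession, url: str, payload: dict) -> dict:
    # Hypothetical helper: on an HTTP error, read and include the response body
    # so remote failures report more than a bare status code.
    async with session.post(url, json=payload) as resp:
        if resp.status >= 400:
            body = await resp.text()
            raise RuntimeError(f"POST {url} failed with status {resp.status}: {body}")
        return await resp.json()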

self._weight_loader = VLLMWeightLoader(self.llm, is_async=True)
# vLLM raises if profile() is called without profiler_config; gate on it.
self._profile_enabled = self.llm.vllm_config.profiler_config.profiler is not None
self._profile_counter = 0
Contributor


Severity: high

Since AsyncVLLMInferenceEngine is an asynchronous engine, multiple generate calls can run concurrently. Because vLLM's profiler is global, concurrent profiling attempts will conflict. We should initialize a lock here to serialize profiling in the generate method.

Suggested change
- self._profile_counter = 0
+ self._profile_counter = 0
+ self._profile_lock = asyncio.Lock()

Comment on lines +495 to +509
if self._profile_enabled:
    await self.llm.start_profile(profile_prefix=f"sample_{self._profile_counter}")
    self._profile_counter += 1
try:
    tasks = []
    for prompt in prompt_token_ids:
        # Schedule the collection of outputs for each prompt.
        # Avoid duplicate request_ids
        request_id = str(uuid4().hex)
        task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params))
        tasks.append(task)
    outputs = await asyncio.gather(*tasks)
finally:
    if self._profile_enabled:
        await self.llm.stop_profile()
Contributor


Severity: high

As mentioned in the __init__ comment, vLLM's profiler is global and does not support concurrent profiling spans. If multiple generate calls occur simultaneously while profiling is enabled, they will interfere with each other. This block should be protected by a lock to ensure that only one profiling session is active at a time.

Suggested change
- if self._profile_enabled:
-     await self.llm.start_profile(profile_prefix=f"sample_{self._profile_counter}")
-     self._profile_counter += 1
- try:
-     tasks = []
-     for prompt in prompt_token_ids:
-         # Schedule the collection of outputs for each prompt.
-         # Avoid duplicate request_ids
-         request_id = str(uuid4().hex)
-         task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params))
-         tasks.append(task)
-     outputs = await asyncio.gather(*tasks)
- finally:
-     if self._profile_enabled:
-         await self.llm.stop_profile()
+ async def _do_generate():
+     tasks = [
+         asyncio.create_task(self._collect_outputs(p, str(uuid4().hex), sampling_params))
+         for p in prompt_token_ids
+     ]
+     return await asyncio.gather(*tasks)
+ if self._profile_enabled:
+     async with self._profile_lock:
+         await self.llm.start_profile(profile_prefix=f"sample_{self._profile_counter}")
+         self._profile_counter += 1
+         try:
+             outputs = await _do_generate()
+         finally:
+             await self.llm.stop_profile()
+ else:
+     outputs = await _do_generate()
