[tx] Lazy inference engine initialization #1069
Code Review
This pull request refactors the initialization of inference engines to be lazy, deferring their creation until they are first needed for sampling. This is a good optimization that avoids allocating resources for inference engines in training-only workflows and introduces a mechanism to sleep inference engines during training passes to free up GPU memory. A medium-severity Denial of Service vulnerability was identified, as the sample method does not validate user-provided sampling parameters, which could be abused to cause excessive resource consumption. Additionally, a critical issue exists where the new lazy initialization logic can lead to a crash if num_inference_engines is configured to 0, requiring graceful handling in sample() and save_sampler_checkpoint().
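To make the review concrete, here is a minimal sketch of the lazy-initialization pattern the PR introduces, including the zero-engine guard the review asks for. Class and method names (`Sampler`, `_ensure_inference_engines`) mirror the diff, but this is an illustrative reduction, not the PR's actual implementation.

```python
class Sampler:
    """Minimal sketch of lazy inference-engine initialization.

    Engines are not allocated at construction time; training-only
    workflows that never call sample() pay no inference-engine cost.
    """

    def __init__(self, num_inference_engines: int):
        self.num_inference_engines = num_inference_engines
        self._engines = None  # deferred: nothing allocated yet

    def _ensure_inference_engines(self):
        # Create engines on first use only.
        if self._engines is None:
            self._engines = [f"engine-{i}" for i in range(self.num_inference_engines)]
        return self._engines

    def sample(self):
        engines = self._ensure_inference_engines()
        if not engines:
            # Guard for num_inference_engines == 0; without it, lazy init
            # builds an empty client and sampling crashes (the issue flagged
            # in the review comments).
            return {"status": "error", "error": "Sampling not enabled."}
        return {"status": "ok", "engines": len(engines)}
```

With `num_inference_engines=0` the first `sample()` call returns an error response instead of crashing; with a positive count, engines materialize only on that first call.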
return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}
# 1. Ensure inference engines are initialized
self._ensure_inference_engines()
While lazy initialization is a great improvement, this change introduces a regression where sample() will crash if num_inference_engines is 0. Previously, there was a check (though likely buggy) to handle this. With lazy init, _ensure_inference_engines() will create an InferenceEngineClient with 0 engines, which then causes a crash inside _inference_engine_client.sample().
Please add a check to ensure there are engines available before proceeding with sampling.
if not self._inference_engine_client or not self._inference_engine_client.engines:
    error = types.ErrorResponse(
        error="Sampling not enabled. Inference engines were not initialized (num_inference_engines=0 in SkyRL config).",
        status="error",
    )
    return {req_id: error for req_id, _, _, _, _ in prepared_batch.request_batch_slices}

asyncio.run(self._dispatch.save_weights_for_sampler())
logger.info(f"Synced weights for {model_id} to inference engines via NCCL")
Similar to the sample method, this will crash if num_inference_engines is 0 because _dispatch.save_weights_for_sampler() will fail when there are no engines in the InferenceEngineClient. The weight sync should be skipped if no inference engines are configured.
if self._inference_engine_client and self._inference_engine_client.engines:
    asyncio.run(self._dispatch.save_weights_for_sampler())
    logger.info(f"Synced weights for {model_id} to inference engines via NCCL")
else:
    logger.info("Skipping sampler weight sync: no inference engines configured.")
Defer vLLM inference engine creation from create_model() to the first sampling-related call (save_sampler_checkpoint or sample). SFT scripts that never sample no longer pay the inference engine memory cost. Sleep inference engines before forward/forward_backward when colocate_all=True so the training model can load without OOM.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
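The sleep-before-training behavior from this commit can be sketched as follows. This is an assumed shape: `ColocatedEngine`, its `sleep()`/`wake_up()` methods, and `forward_backward` are illustrative stand-ins for the PR's actual colocation machinery.

```python
class ColocatedEngine:
    """Illustrative colocated inference engine that can release GPU memory."""

    def __init__(self):
        self.asleep = False

    def sleep(self):
        # Free the engine's GPU memory so the training model can load.
        self.asleep = True

    def wake_up(self):
        # Reacquire resources before the next sampling pass.
        self.asleep = False


def forward_backward(engines, train_step):
    # With colocate_all=True, put every inference engine to sleep before the
    # training pass runs, avoiding OOM when the training model loads.
    for engine in engines:
        engine.sleep()
    return train_step()
```

The design choice here is to trade a sleep/wake round trip per training pass for the ability to fit both the training model and the inference engines on the same GPUs.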
Map lora_config.rank and alpha to the SkyRL-Train LoRA config so that LoRA requests actually create LoRA adapters instead of silently doing full fine-tuning.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
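The mapping described in this commit can be sketched as below. The field and key names (`rank`, `alpha`, `lora_rank`, `lora_alpha`) follow the commit message, but the helper itself is a hypothetical reduction, not SkyRL-Train's real config schema.

```python
from dataclasses import dataclass


@dataclass
class LoraConfig:
    """User-facing LoRA request parameters."""
    rank: int
    alpha: int


def to_skyrl_lora_config(lora_config):
    """Map user LoRA fields onto the trainer-side config dict.

    If rank is never propagated, the trainer sees rank 0 and silently
    falls back to full fine-tuning -- the bug this commit fixes.
    """
    if lora_config is None or lora_config.rank <= 0:
        return {"lora_rank": 0}  # full fine-tuning
    return {"lora_rank": lora_config.rank, "lora_alpha": lora_config.alpha}
```

A request with rank 16 / alpha 32 now yields a config that actually creates a LoRA adapter, while a missing config still means full fine-tuning explicitly rather than by accident.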
Summary
- Defer vLLM inference engine creation from create_model() to the first sampling-related call (save_sampler_checkpoint or sample)
- save_weights_and_get_sampling_client() call