[tinker] Abort/retry recovery for forwarded multi-tenant LoRA samples#1669
erictang000 wants to merge 2 commits
Adds three control-plane endpoints on the vLLM server actor and a submission-gate middleware so out-of-process callers (e.g. the Tinker ``SkyRLTrainInferenceForwardingClient``) can observe the same pause window as in-process ``RemoteInferenceClient.sample_with_retry`` callers.

* ``POST /skyrl/v1/abort_lora_requests``: now also clears a per-LoRA ``asyncio.Event`` in ``app.state.paused_loras`` before the ``engine.abort()`` fan-out.
* ``POST /skyrl/v1/resume_lora_requests``: sets the event.
* ``POST /skyrl/v1/wait_lora_unpaused``: long-polls until the event is set or the caller-supplied timeout elapses.
* A pure-ASGI middleware blocks fresh ``/v1/completions`` and ``/v1/chat/completions`` requests for a paused LoRA until resume, closing the race between ``abort_lora_requests`` and ``load_lora_adapter`` where a new request could observe torn adapter weights.

``RemoteInferenceClient.resume_generation`` now also POSTs the new ``/skyrl/v1/resume_lora_requests`` so the server-side flag matches the in-process gate. Existing ``test_pause_lora.py`` callers don't observe a return-value contract change.

All endpoints carry the same TRANSIENT marker as the existing abort endpoint: delete the stack when vLLM ships native per-LoRA pause.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
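The per-LoRA gate described above can be sketched with plain ``asyncio``. This is an illustrative, framework-free model of the semantics (set == unpaused, cleared == paused), not the PR's actual server code; the ``PauseGate`` class and method names are hypothetical stand-ins for the three endpoints.

```python
import asyncio


class PauseGate:
    """One asyncio.Event per LoRA; set == unpaused, cleared == paused."""

    def __init__(self):
        self._events: dict[str, asyncio.Event] = {}

    def _event(self, lora: str) -> asyncio.Event:
        ev = self._events.get(lora)
        if ev is None:
            ev = asyncio.Event()
            ev.set()  # LoRAs start out unpaused
            self._events[lora] = ev
        return ev

    def abort(self, lora: str) -> None:
        # Models /skyrl/v1/abort_lora_requests: clear before the engine.abort() fan-out.
        self._event(lora).clear()

    def resume(self, lora: str) -> None:
        # Models /skyrl/v1/resume_lora_requests.
        self._event(lora).set()

    async def wait_unpaused(self, lora: str, timeout: float) -> bool:
        # Models /skyrl/v1/wait_lora_unpaused: True if unpaused before the timeout.
        try:
            await asyncio.wait_for(self._event(lora).wait(), timeout)
            return True
        except asyncio.TimeoutError:
            return False


async def demo() -> list:
    gate = PauseGate()
    gate.abort("meow")
    first = await gate.wait_unpaused("meow", timeout=0.05)   # times out: still paused
    gate.resume("meow")
    second = await gate.wait_unpaused("meow", timeout=0.05)  # returns immediately
    return [first, second]


print(asyncio.run(demo()))  # → [False, True]
```

The submission-gate middleware would simply await the same event before forwarding a fresh ``/v1/completions`` request, which is what closes the torn-weights race.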
Builds on the server-side pause-state endpoints in the parent PR to make
``SkyRLTrainInferenceForwardingClient`` participate in per-LoRA pause
cycles. Previously, an in-flight Tinker ``asample`` for a LoRA being
weight-synced by the trainer would:
1. Receive ``finish_reason="abort"`` from vLLM and silently map it to
``stop_reason="length"`` in the forwarding client (truncating result),
or
2. Race ``load_lora_adapter`` if submitted during the pause window.
Changes:
* ``EngineStateDB.inference_server_urls``: new JSON column carrying the
vLLM workers' direct URLs. vllm-router is data-plane-only (verified at
remote_inference_client.py:12-30) and does not forward control-plane
endpoints, so the forwarding client needs at least one worker URL to
reach ``/skyrl/v1/wait_lora_unpaused``.
* ``SkyRLTrainBackend._publish_inference_state`` and
``TinkerEngine._write_inference_state_to_db`` now take
``(proxy_url, server_urls)`` and persist both. ``Callable`` signature
on the publisher updated accordingly.
* ``SkyRLTrainInferenceForwardingClient.call_and_store_result`` now
mirrors ``RemoteInferenceClient.sample_with_retry`` for the per-LoRA
``num_samples=1`` path: dispatch ``/v1/completions``, and on
``finish_reason="abort"`` accumulate partial tokens, long-poll
``/skyrl/v1/wait_lora_unpaused``, then resubmit with
``prompt + accumulated`` and remaining ``max_tokens``. Loops until a
non-abort finish. The base-model and ``num_samples > 1`` paths stay
single-shot — an abort there now surfaces as ``FutureDB.status=FAILED``
rather than silently corrupting the result.
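The accumulate/wait/resubmit loop above can be sketched in isolation. The ``sample_fn`` and ``wait_unpaused_fn`` callables below stand in for the HTTP calls to ``/v1/completions`` and ``/skyrl/v1/wait_lora_unpaused``; all names here are illustrative, not the actual forwarding-client API.

```python
from typing import Callable, List, Tuple


def sample_with_abort_retry(
    prompt: List[int],
    max_tokens: int,
    sample_fn: Callable[[List[int], int], Tuple[List[int], str]],
    wait_unpaused_fn: Callable[[], None],
) -> List[int]:
    """Accumulate partial tokens across aborts, resubmitting with the remaining budget."""
    accumulated: List[int] = []
    remaining = max_tokens
    while remaining > 0:
        # Resubmissions extend the prompt with everything generated so far.
        tokens, finish_reason = sample_fn(prompt + accumulated, remaining)
        accumulated.extend(tokens)
        remaining = max_tokens - len(accumulated)
        if finish_reason != "abort":
            return accumulated
        # Aborted mid-generation: block until the LoRA is resumed, then loop.
        wait_unpaused_fn()
    return accumulated  # budget exhausted across retries


# Demo: the first call aborts after two tokens, the retry completes.
_calls = []
def _fake_sample(prompt, budget):
    _calls.append((list(prompt), budget))
    return ([1, 2], "abort") if len(_calls) == 1 else ([3, 4], "stop")

print(sample_with_abort_retry([9], 8, _fake_sample, lambda: None))  # → [1, 2, 3, 4]
```

Note how the second dispatch sees prompt ``[9, 1, 2]`` with budget ``6``: the decrement is what makes ``max_tokens`` act as a total budget rather than a per-attempt limit.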
Tests: ``tests/tinker/skyrl_train/test_pause_async_sample.py`` mirrors
the four tests in ``test_pause_lora.py`` but drives the forwarding client
directly:
1. Pausing LoRA A doesn't affect forwarding samples for LoRA B.
2. In-flight forwarding samples for paused LoRA A recover from abort
and contain the right LoRA content after retry.
3. A single forwarding sample spans a real Meow→Woof weight swap;
merged tokens contain both adapters' signatures in order.
4. A FRESH sample submitted while a LoRA is paused (no in-flight to
abort) blocks at the server-side submission gate until resume,
closing the torn-weights race that the gate is there to prevent.
All new code carries the same TRANSIENT marker as ``sample_with_retry``:
delete the abort/retry plumbing when vLLM ships native per-LoRA pause.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Code Review
This pull request implements a multi-tenant, LoRA-aware pause/resume mechanism across the inference server and the Tinker forwarding client. Key changes include a submission-gate middleware that blocks requests for paused LoRAs, new control-plane endpoints for resuming and waiting on LoRA pause state, and an abort-aware retry loop in ``SkyRLTrainInferenceForwardingClient`` that uses long-polling. Feedback highlights the need for safer indexing when accessing server URLs and completion choices, as well as validation of ``max_tokens`` to avoid type errors during the retry process.
```python
if row.inference_server_urls:
    server_url = row.inference_server_urls[0]
```

Accessing the first element of ``inference_server_urls`` without checking if the list is non-empty will raise an ``IndexError`` if the backend published an empty list of worker URLs. While unlikely in normal operation, it's safer to handle this case.

Suggested change:

```python
if row.inference_server_urls and len(row.inference_server_urls) > 0:
    server_url = row.inference_server_urls[0]
```
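For reference, an empty list is already falsy in Python, so the truthiness guard alone covers both ``None`` and ``[]``; a minimal illustration (``first_server_url`` is a hypothetical helper, not code from this PR):

```python
def first_server_url(urls):
    """Return the first worker URL, or None when none were published."""
    if urls:  # falsy for both None and [], so no IndexError is possible
        return urls[0]
    return None


print(first_server_url(["http://worker-0:8000"]))  # → http://worker-0:8000
print(first_server_url([]))                        # → None
```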
```python
wait_url = f"{server_url}/skyrl/v1/wait_lora_unpaused" if server_url else None
retry_eligible = base_model is None and num_samples == 1 and wait_url is not None

original_max_tokens = sp.max_tokens
```
The retry loop logic relies on original_max_tokens being a positive integer to bound the generation and decrement the budget correctly. If sp.max_tokens is None or non-integer, line 186 will raise a TypeError. Consider adding validation similar to RemoteInferenceClient.sample_with_retry to ensure max_tokens is a positive integer when retry_eligible is true.
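A minimal sketch of the kind of guard the comment asks for, assuming the budget arithmetic needs a positive integer; the ``validate_retry_budget`` name and error message are illustrative, not the actual SkyRL-train API:

```python
def validate_retry_budget(max_tokens) -> int:
    """Reject None, bools, floats, and non-positive values before enabling retry."""
    # bool is a subclass of int, so exclude it explicitly.
    if not isinstance(max_tokens, int) or isinstance(max_tokens, bool) or max_tokens <= 0:
        raise ValueError(
            f"abort/retry requires a positive integer max_tokens, got {max_tokens!r}"
        )
    return max_tokens


print(validate_retry_budget(256))  # → 256
```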
```python
return self._build_sample_output(choices, abort_recovery=False)

# num_samples == 1, LoRA — abort-aware retry path.
choice = choices[0]
```
Summary
Builds on #1666 to make ``SkyRLTrainInferenceForwardingClient`` recover from per-LoRA aborts the same way the in-process ``RemoteInferenceClient.sample_with_retry`` (#1657) does. Closes the fully-async multi-tenant Tinker path: a Tinker SDK ``asample`` for a LoRA adapter whose weights are being synced now transparently waits for the sync to finish instead of failing or silently returning truncated tokens.

This PR is stacked on #1666 — please review #1666 first.
What was broken
For a LoRA adapter being weight-synced via ``RemoteInferenceClient.pause_generation(lora_name=X)``, in-flight Tinker ``asample`` requests for X (routed through ``SkyRLTrainInferenceForwardingClient`` in non-colocated Megatron/FSDP) were not handled correctly:

1. vLLM returns ``finish_reason="abort"`` in the ``/v1/completions`` response. The forwarding client's ``finish_reason → stop_reason`` map (``skyrl_train_inference_forwarding.py:169`` before this PR) lumps any non-``stop``/``stop_token`` reason into ``"length"``. Result: the SDK consumer sees a truncated, half-finished completion with no indication that it was aborted.
2. A fresh request submitted during the pause window could race ``load_lora_adapter`` and observe torn adapter weights.

How
* ``EngineStateDB`` schema: new ``inference_server_urls`` (JSON list) column. The vllm-router forwards data-plane endpoints only (``remote_inference_client.py:12-30`` is explicit about this), so the forwarding client needs direct worker URLs to call ``/skyrl/v1/wait_lora_unpaused``.
* Publisher signature: ``set_inference_state_publisher`` and the engine-side writer now take ``(proxy_url, server_urls)``; ``SkyRLTrainBackend._publish_inference_state`` is updated to pass ``server_setup.server_urls``.
* Forwarding client retry loop: ``call_and_store_result`` for the per-LoRA, ``num_samples=1`` path mirrors ``RemoteInferenceClient.sample_with_retry``:
  * Dispatch ``/v1/completions``. On ``finish_reason="abort"``, append the partial ``token_ids`` and ``logprobs`` to local accumulators.
  * Long-poll ``/skyrl/v1/wait_lora_unpaused`` (new endpoint in [multi-tenant-lora] Expose per-LoRA pause state over HTTP #1666) on a worker URL; loop while ``{paused: true}`` returns.
  * Resubmit with ``prompt + accumulated`` and ``max_tokens = original - len(accumulated)``. Loops until a non-abort finish or the budget is exhausted.
  * Trim ``prompt_logprobs`` back to ``len(original_prompt)`` to match the in-process implementation.

Base-model sampling and ``num_samples > 1`` keep the single-shot path. An abort there raises, so the consumer sees ``FutureDB.status=FAILED`` instead of silent corruption.

Test plan
New file ``tests/tinker/skyrl_train/test_pause_async_sample.py`` mirrors the four tests in ``tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_pause_lora.py`` (#1657) but drives the forwarding client directly rather than ``RemoteInferenceClient``:

* ``test_pause_lora_does_not_affect_other_lora_via_forwarding`` — while LoRA A is paused, forwarding-client samples for LoRA B complete promptly.
* ``test_forwarding_client_recovers_from_abort`` — 4 concurrent meow + 4 concurrent woof forwarding samples; pause meow mid-flight, await woof first (cross-LoRA isolation), then resume and verify meow tasks complete with ``"meow"`` content.
* ``test_forwarding_client_observes_mid_sample_weight_swap`` — one forwarding sample for ``lora-target`` spans a Meow→Woof weight swap; merged output contains both adapter signatures with "meow" before "woof".
* ``test_fresh_request_during_pause_blocks_until_resume`` — verifies the submission-gate middleware from [multi-tenant-lora] Expose per-LoRA pause state over HTTP #1666: a sample submitted after pause (with no in-flight to abort) blocks server-side until resume, then completes normally.

``ruff`` + ``black`` clean.

Existing non-GPU Tinker suites (``test_db``, ``test_api_validation``, ``test_api``) pass with the publisher signature change.

The four GPU integration tests require 1 H100/A100/B200 and the Qwen3-0.6B Meow/Woof LoRA fixtures (auto-downloaded via ``snapshot_download``). Same shape as ``test_pause_lora.py``.

🤖 Generated with Claude Code