
[Draft] Add NIXL transfer release cancellation hook#13495

Draft
yifjiang wants to merge 4 commits into NVIDIA:main from
yifjiang:codex/nixl-transfer-cancel-release

Conversation

yifjiang (Contributor) commented Apr 27, 2026

Summary

This draft uses the same intended base/merge point as #13439:
4e69c14f732a6e6afce4f71616db5b5cd2b10530.

It keeps the conservative request-lifetime hardening from #13439, then adds a
NIXL transfer-release hook so TRT-LLM can release backend transfer handles when
cancellation is observed.

The key safety boundary is intentional: release() means the backend accepted
release of the transfer handle. It is not treated as proof that remote KV memory
is quiesced and immediately safe to recycle, especially for UCX-backed
one-sided RMA paths.

What Changed

Inherited from #13439:

  • C++ generation receive tracking keeps std::shared_ptr<LlmRequest> while the
    async receive future is outstanding.
  • Receiver worker queues keep std::shared_ptr<LlmRequest> while queued or
    executing.
  • Python global error handling fails closed while a generation KV receive is
    still in flight, instead of freeing request resources out from under raw
    C++ pointer users.
  • Python termination avoids freeing resources while context transfer is still in
    progress.
  • Buffer-index slots use deterministic RAII cleanup.
  • Diagnostics around futures, buffer pools, request IDs, and transfer workers
    were expanded.

Added in this PR:

  • Added TransferStatus::release() to the C++ transfer-status interface.
  • Implemented NixlTransferStatus::release() with
    nixlAgent::releaseXferReq().
  • Made NixlTransferStatus release outstanding handles in its destructor as a
    final cleanup guard.
  • Changed sender-side transfer wait to poll in bounded intervals, observe
    getTransferTerminate(), and call status->release() before surfacing
    cancellation.
  • Changed sync and ready notification waits to return whether the expected
    notification actually arrived.
  • Made receive paths fail when termination stops a notification wait instead of
    treating that as a successful receive.
  • Exposed release() through nanobind and Python transfer-status wrappers.
  • Made sender cancellation complete the moved-out promise with an explicit
    request-specific cancellation exception after sender bookkeeping has been
    erased. This keeps the working v1 timing shape without relying on
    std::future_error: Broken promise.
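The release hook and destructor guard can be sketched as below. This is a
minimal illustration only: the agent type is a stand-in for nixlAgent, and the
class layout does not mirror the real TRT-LLM transfer-status implementation.

```cpp
#include <cassert>
#include <memory>

// Stand-in for the NIXL agent; releaseXferReq() here only records that the
// backend accepted release of the handle -- it does not (and cannot) prove
// that remote KV memory is quiesced.
struct FakeAgent
{
    int released = 0;
    void releaseXferReq(void* /*handle*/) { ++released; }
};

// Sketch of the new hook on the transfer-status interface.
class TransferStatus
{
public:
    virtual ~TransferStatus() = default;
    virtual void release() = 0;
};

class NixlTransferStatus : public TransferStatus
{
public:
    NixlTransferStatus(FakeAgent& agent, void* handle)
        : mAgent(agent)
        , mHandle(handle)
    {
    }

    void release() override
    {
        if (mHandle)
        {
            mAgent.releaseXferReq(mHandle);
            mHandle = nullptr; // idempotent: a second release() is a no-op
        }
    }

    // Destructor acts as a final cleanup guard for outstanding handles.
    ~NixlTransferStatus() override
    {
        release();
    }

private:
    FakeAgent& mAgent;
    void* mHandle;
};
```

The idempotence matters because the same handle may be released explicitly on
the cancellation path and again by the destructor guard.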

Request Lifetime Before This Branch

Before the request-lifetime hardening, generation receive was vulnerable because
C++ tracked raw request pointers while Python could free resources through broad
error cleanup.

```mermaid
sequenceDiagram
    participant Py as Python executor
    participant CT as CacheTransceiver
    participant CR as CacheReceiver worker
    participant Req as LlmRequest
    participant RM as ResourceManager

    Py->>CT: start generation receive
    CT->>CR: queue raw LlmRequest pointer
    CT->>CT: store raw pointer plus future
    Py->>Py: broad error path handles active requests
    Py->>RM: free resources for request
    RM-->>Req: KV blocks and request resources may be released
    CR->>Req: later worker access through stale raw pointer
    CT->>Req: later status check through stale raw pointer
```

Impact: if the request or its KV resources were freed while the transfer worker
or status checker still had only raw references, a later access could crash or
corrupt memory.

Request Lifetime After This Branch

The current branch pins the request object in C++ until transfer completion or
error, and Python avoids freeing resources on broad generation-transfer errors.

```mermaid
sequenceDiagram
    participant Py as Python executor
    participant CT as CacheTransceiver
    participant CR as CacheReceiver worker
    participant Req as shared LlmRequest
    participant RM as ResourceManager

    Py->>CT: start generation receive
    CT->>CR: queue shared_ptr LlmRequest
    CT->>CT: store shared_ptr plus future
    Py->>Py: broad error while generation receive is in flight
    Py->>Py: fail closed without freeing active request resources
    CR-->>CT: future resolves or reports error
    CT->>Req: set transfer complete or transfer error
    Py->>RM: free resources only after C++ tracking has drained
```

Impact: the LlmRequest object remains valid while C++ workers and future
tracking can still touch it. Ambiguous generation receive failures remain
fail-closed because receiver-side cancellation cannot prove that a remote sender
is no longer writing into the target KV blocks.
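The pinning change can be illustrated with a minimal sketch. LlmRequest and
the queue type here are simplified stand-ins for the real CacheReceiver worker
structures; the point is only that the queue holds a shared_ptr, so the request
outlives any caller-side cleanup that drops its own reference.

```cpp
#include <cassert>
#include <deque>
#include <memory>

// Simplified stand-in for LlmRequest.
struct LlmRequest
{
    explicit LlmRequest(int id)
        : requestId(id)
    {
    }
    int requestId;
};

// Sketch of a receiver worker queue that pins requests with shared_ptr while
// they are queued or executing, instead of holding raw pointers.
struct ReceiverWorkerQueue
{
    std::deque<std::shared_ptr<LlmRequest>> pending;

    void enqueue(std::shared_ptr<LlmRequest> req)
    {
        pending.push_back(std::move(req));
    }

    int processNext()
    {
        // The worker takes ownership here, so the access below is safe even
        // if every other reference to the request has already been dropped.
        auto req = std::move(pending.front());
        pending.pop_front();
        return req->requestId;
    }
};
```

With raw pointers, dropping the caller's reference before `processNext()` ran
would have been exactly the stale-pointer access shown in the first diagram.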

Context Send Cancellation In Current v3

The first version of this PR recovered by abandoning a std::promise, which made
the waiting future ready with std::future_error: Broken promise. A later patch
made the future ready too early and caused a regression. The current v3 keeps the
working ordering but uses an explicit cancellation exception.

```mermaid
sequenceDiagram
    participant Py as Python timeout path
    participant CS as CacheSender
    participant Resp as queued Response
    participant Fut as sender future
    participant CT as CacheTransceiver status

    Py->>CS: cancel context request
    CS->>CS: send not-ready signal
    CS->>Resp: move response out of ready map
    CS->>CS: erase ready response and cancel bookkeeping
    CS->>CS: clear current request and ready state
    CS->>Fut: set explicit cancellation exception
    CT->>Fut: future.get observes cancellation
    CT->>CT: mark request DISAGG_TRANS_ERROR and erase future
```

This preserves the observed working cleanup path without depending on abandoned
promise semantics or racing the status checker against sender-side bookkeeping.
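The ordering can be sketched as follows. Response, the ready map, and the
exception type are illustrative stand-ins for the real CacheSender types; the
essential shape is that bookkeeping is erased first and the moved-out promise
is completed last, so the status checker only observes the cancellation after
sender-side state is consistent.

```cpp
#include <future>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Stand-in for the queued Response holding the sender-side promise.
struct Response
{
    std::promise<void> promise;
};

// Illustrative request-specific cancellation exception.
struct CancelledByRequest : std::runtime_error
{
    using std::runtime_error::runtime_error;
};

std::future<void> cancelContextRequest(std::map<int, Response>& readyResponses, int requestId)
{
    auto it = readyResponses.find(requestId);
    auto fut = it->second.promise.get_future();

    // 1. Move the Response (and its promise) out of the ready map.
    Response resp = std::move(it->second);

    // 2. Erase sender-side bookkeeping first, while the future is still
    //    not ready, so nothing races against a half-cleaned sender state.
    readyResponses.erase(it);

    // 3. Only then complete the moved-out promise with an explicit,
    //    request-specific cancellation exception (instead of relying on
    //    std::future_error: Broken promise from an abandoned promise).
    resp.promise.set_exception(std::make_exception_ptr(
        CancelledByRequest("request " + std::to_string(requestId) + " cancelled")));
    return fut;
}
```

In the real code the future is held by the status checker rather than returned
here; it is returned in this sketch only so the ordering can be exercised.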

Generation Timeout In Current v3

The generation side still does not have a clean in-progress cancel path. When the
worker already owns the receive request, CacheReceiver::cancelRequest() can
return false and log Cannot cancel request. In v3, the loop is bounded because
the worker or future status path eventually resolves or errors and Python removes
the request from the active transfer path.

```mermaid
sequenceDiagram
    participant Py as Python timeout path
    participant CT as CacheTransceiver
    participant CR as CacheReceiver
    participant Fut as receiver future

    Py->>Py: KV transfer timeout flag is set
    Py->>CT: cancel generation request
    CT->>CR: cancelRequest
    CR-->>CT: false if worker already owns request
    CT-->>Py: cancellation still pending
    Py->>Py: later iterations retry while request remains active
    CR-->>Fut: worker completes or reports error
    CT->>Fut: future.get
    CT->>CT: set complete or DISAGG_TRANS_ERROR
    Py->>Py: active request cleanup stops retry loop
```

This is functionally acceptable in the current e2e burst harness, but it is not
as clean as the PR13301 deadline-driven path because it can still emit
Cannot cancel request log noise.
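The sender-side bounded wait described under "What Changed" (poll in bounded
slices, observe the terminate flag between slices, release the handle before
surfacing cancellation) can be sketched as a small helper. The callable names
and the enum are stand-ins, not the real TRT-LLM signatures.

```cpp
#include <chrono>
#include <thread>

enum class WaitResult
{
    kDone,      // transfer completed normally
    kCancelled  // terminate was observed and the handle was released
};

// Sketch: poll completion in bounded intervals instead of blocking forever,
// checking a termination flag (cf. getTransferTerminate()) between slices and
// calling release() before surfacing the cancellation to the caller.
template <typename IsDoneFn, typename TerminateFn, typename ReleaseFn>
WaitResult boundedWait(IsDoneFn isDone, TerminateFn getTransferTerminate, ReleaseFn release,
    std::chrono::milliseconds interval = std::chrono::milliseconds(1))
{
    while (!isDone())
    {
        if (getTransferTerminate())
        {
            release(); // release the backend transfer handle first
            return WaitResult::kCancelled;
        }
        std::this_thread::sleep_for(interval); // bounded poll slice
    }
    return WaitResult::kDone;
}
```

Because the wait is sliced rather than indefinite, a terminate flag set at any
point is observed within one interval, which is what bounds the retry loop
above.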

E2E Observations

Tested with pr13495 + pr13359 v3 on the 1P1D burst harness at concurrency
16, 48, and 128:

  • All three concurrencies recovered.
  • The v2 regression, where decode pods could emit millions of
    Cannot cancel request warnings and fail to recover, is fixed.
  • v3 recovery is a hybrid path: bounded Cannot cancel request markers, sparse
    Broken promise markers, and decode-side KV timeout markers.
  • The clean PR13301 reference path still has cleaner logs because its deadline
    cleanup avoids the Cannot cancel request loop.

Safety Notes

  • releaseXferReq() is treated as backend handle release or cancellation
    request, not as a proof that remote KV memory can be immediately recycled.
  • Receiver-side cancellation still cannot directly abort an already-issued
    remote sender RMA.
  • Ambiguous in-flight generation receive failures remain fail-closed to avoid
    freeing and reusing KV memory while a sender may still be writing.
  • [https://nvbugs/6104831][fix] Detach pruned trie children #13572 is a separate KV-cache trie/block-reuse correctness fix. It addresses a
    cascade-prune assertion in block reuse and eviction, not transport
    cancellation. It should land independently or be included in test stacks that
    exercise the NVBugs 6104831 burst workload.

Remaining Limitations

  • In-progress generation receive cancellation is still not a single clean
    per-request cancel path.
  • Cannot cancel request may still appear when the receiver worker has already
    popped the request from its queue.
  • Full scale or shadow-traffic testing has not been run from this PR branch.

Validation

  • git diff --check
  • PYTHONPYCACHEPREFIX=/tmp/trtllm-cancel-pycache python3 -m py_compile tensorrt_llm/_torch/disaggregation/base/agent.py tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py
  • User-provided e2e burst testing for pr13495 + pr13359 v3 at concurrency
    16, 48, and 128 recovered successfully.

Not yet run from this branch: full C++ build or TRT-LLM runtime test suite.

svc-trtllm-gh-bot added the "Community want to contribute" label (PRs initiated from Community) on Apr 27, 2026
yifjiang force-pushed the codex/nixl-transfer-cancel-release branch from 037258e to 1ce7e8a (April 27, 2026, 19:40)
yifjiang (Contributor, Author) commented:

Follow-up from E2E testing:

The previous version recovered from context-transfer cancellation because the sender erased a queued Response whose std::promise had not been fulfilled. That made the stored future become ready with std::future_error: Broken promise, and checkContextTransferStatus() then caught it and marked the request DISAGG_TRANS_ERROR.

This branch now makes that unwind explicit: when the sender takes the cancelled isReady=false path, it sets a request-specific cancellation exception on the promise before erasing the response. The intended per-request cleanup path is unchanged, but the recovery no longer depends on abandoned-promise semantics or logs misleading Broken promise errors.

yifjiang force-pushed the codex/nixl-transfer-cancel-release branch from 1ce7e8a to a0bbb57 (April 27, 2026, 23:56)
yifjiang (Contributor, Author) commented:

Correction to the previous follow-up:

The first explicit-exception patch changed the ordering: it made the sender future ready before the cancelled response was removed from mReadyResponses and before sender bookkeeping was cleaned. That can race the status checker/error cleanup against sender-side cancellation cleanup.

The branch has been amended again so cancellation now moves the Response out, removes mReadyResponses / mCancelledRequests / mRemainSendCount, updates sender state, and only then sets the explicit cancellation exception on the moved-out promise. This keeps the old broken-promise timing shape while making the future error intentional.
