[TRTLLM-8922][feat] py cache transceiver for gen-first workflow #11941

reasonsolo merged 3 commits into NVIDIA:main
Conversation
/bot run --disable-fail-fast
📝 Walkthrough

This pull request refactors the disaggregation transfer system by replacing SessionArgsBase with a new SessionBase abstraction hierarchy (SessionBase, TxSessionBase, RxSessionBase), adding auxiliary buffer slot management with AuxSlot return types, updating transfer implementations to support generation-first scheduling workflows, and extending test coverage to validate multiple transfer order strategies.
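The abstractions named in the walkthrough can be sketched roughly as follows. This is a hypothetical outline based only on the names above (SessionBase, TxSessionBase, RxSessionBase, AuxSlot, SessionStatus); the signatures and fields are assumptions, not the PR's actual code:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto


class SessionStatus(Enum):
    INIT = auto()
    READY = auto()
    TRANSFERRED = auto()
    AUX_TRANSFERRED = auto()
    ERROR = auto()


@dataclass
class AuxSlot:
    """A reserved slot in the auxiliary buffer: a slot id plus its backing buffer."""
    id: int
    buffer: bytearray


class SessionBase(ABC):
    """State shared by the send (Tx) and receive (Rx) sides of one transfer."""

    def __init__(self, unique_rid: int):
        self._unique_rid = unique_rid
        self._status = SessionStatus.INIT

    @property
    def unique_rid(self) -> int:
        return self._unique_rid


class TxSessionBase(SessionBase):
    @abstractmethod
    def send(self, data: bytes) -> int:
        """Queue a KV/aux transfer and return a task id."""


class RxSessionBase(SessionBase):
    @abstractmethod
    def alloc_slot(self, size: int) -> AuxSlot:
        """Reserve an aux-buffer slot for incoming data."""


# Tiny demo subclass so the sketch is exercisable.
class _DemoTx(TxSessionBase):
    def send(self, data: bytes) -> int:
        return 0


demo = _DemoTx(unique_rid=42)
```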
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Transceiver as PyNativeCacheTransceiver
    participant Scheduler as Scheduler
    participant Transfer as Transfer System<br/>(TxSession/RxSession)
    participant Buffer as Aux Buffer
    Client->>Transceiver: Submit context request
    Transceiver->>Transfer: prepare_context_requests()
    Transfer->>Transfer: Create TxSession<br/>allocate AuxSlot
    Transfer->>Buffer: alloc_slot() → AuxSlot
    Buffer-->>Transfer: AuxSlot{id, buffer}
    Scheduler->>Transceiver: check_context_transfer_status()
    Transceiver->>Transfer: Poll session readiness
    alt Session READY
        Transceiver->>Transfer: dispatch_all_tasks()
        Transfer->>Transfer: Start KV transfer
    end
    Client->>Transceiver: Submit generation request
    Transceiver->>Transfer: respond_and_send_async()
    Transfer->>Transfer: Send KV slice
    alt Generation-First Mode
        Transceiver->>Transfer: _need_aux_transfer()
        Transfer->>Transfer: send_aux() if needed
    end
    Scheduler->>Transceiver: check_gen_transfer_status()
    Transceiver->>Transfer: wait_complete(task_id,<br/>wait_aux=true)
    Transfer->>Buffer: Verify transfer complete
    Transfer-->>Transceiver: Completion status
```

```mermaid
sequenceDiagram
    participant App as Application
    participant TxSession
    participant RxSession
    participant Messenger as Messenger<br/>(IPC)
    participant Buffer as Aux Buffer
    App->>TxSession: dispatch_all_tasks(req_info)
    TxSession->>TxSession: Create KVSendTask<br/>Create AuxSendTask
    TxSession->>Messenger: Send KV metadata<br/>Send aux metadata
    Messenger-->>RxSession: Receive metadata
    RxSession->>RxSession: Parse metadata via<br/>create_req_info()
    RxSession->>Buffer: Allocate slot via<br/>alloc_slot()
    Buffer-->>RxSession: AuxSlot{id, buffer}
    App->>TxSession: wait_complete(task_id,<br/>wait_aux=true)
    TxSession->>Messenger: Poll transfer status
    Messenger-->>TxSession: Status update
    App->>RxSession: wait_complete(task_id,<br/>wait_aux=true)
    RxSession->>RxSession: unpack_aux()<br/>materialize data
    RxSession-->>App: Ready
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ❌ 2 failed (2 warnings) | ✅ 1 passed

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 5
🧹 Nitpick comments (4)
tests/unittest/disaggregated/test_py_cache_transceiver_mp.py (1)
941-942: Hardcoded `time.sleep(3)` may cause flaky tests.

A 3-second sleep is fragile: it may be too short under load or unnecessarily slow otherwise. Consider polling with a timeout loop similar to the pattern used in `_run_gen_first1_transfer` (lines 906-915).

♻️ Proposed polling approach

```diff
 dist.barrier()
-time.sleep(3)  # wait for the receive requests to be submitted
+# Poll until the context side detects peer info (or a timeout elapses);
+# this replaces the fragile fixed sleep.
```

However, given the test structure where gen submits first and ctx prepares after, the barrier should suffice for synchronization. If additional delay is truly needed, document the reason or use a shorter sleep with retry logic.
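A bounded polling helper of the kind this comment suggests might look like the sketch below; the predicate name and timings are placeholders, not the test's actual status flag:

```python
import time


def wait_until(predicate, timeout_s: float = 10.0, interval_s: float = 0.1) -> bool:
    """Poll `predicate` until it returns True or `timeout_s` elapses.

    Returns True if the predicate succeeded, False on timeout.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval_s)
    return False


# Hypothetical usage in the test, replacing the fixed time.sleep(3):
#   assert wait_until(lambda: ctx_side_sees_peer_info(), timeout_s=10.0)
```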
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/unittest/disaggregated/test_py_cache_transceiver_mp.py` around lines 941-942: Replace the fragile fixed sleep after dist.barrier() with a bounded polling loop that checks the condition used elsewhere (reuse the pattern from _run_gen_first1_transfer) to wait until receive requests are submitted or a timeout elapses; specifically, remove time.sleep(3) and implement a retry/polling loop that queries the same status/flag that indicates receives are ready, with a short interval and an overall timeout, logging or failing the test if the timeout is hit (alternatively, if barrier truly guarantees sync, replace the 3s sleep with a short documented sleep + retry to reduce flakiness).

tensorrt_llm/_torch/disaggregation/base/transfer.py (1)
109-113: Consider a fallback when `get_unique_rid(request)` returns `None`.

If `request` is provided but `request.py_disaggregated_params` is `None` or `disagg_request_id` is `None`, `_unique_rid` will be `None` even if the `unique_rid` parameter was explicitly passed. Consider using the parameter as a fallback:

♻️ Proposed fix

```diff
 def __init__(self, request: Optional[LlmRequest], unique_rid: Optional[int] = None):
     self._request = request
-    self._unique_rid: int = get_unique_rid(request) if request else unique_rid
+    self._unique_rid: int = get_unique_rid(request) if request else None
+    if self._unique_rid is None:
+        self._unique_rid = unique_rid
     self._state = SessionState(status=SessionStatus.INIT, finished_tasks=[])
     self._exception: Optional[Exception] = None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/disaggregation/base/transfer.py` around lines 109-113: In `__init__`, `get_unique_rid(request)` can return None even when the caller passed unique_rid, so change the `_unique_rid` assignment to use the unique_rid parameter as a fallback: call `get_unique_rid(request)` only when request is present, check its return value for None, and if it is None (or request is None) set `self._unique_rid` to the unique_rid argument; update references to `_unique_rid` accordingly so callers receive the explicit unique_rid when provided.

tensorrt_llm/_torch/disaggregation/native/transfer.py (1)
931-949: Broad exception handling in `wait_complete`.

The `except Exception` catches all exceptions, which is flagged by static analysis (BLE001). While this defensive approach prevents crashes, consider narrowing to specific expected exceptions if possible.

♻️ Consider narrowing exception scope

```diff
 except concurrent.futures.TimeoutError:
     logger.warning(
         f"TxSession {self.unique_rid} timed out waiting for completion after {timeout_ms} milliseconds."
     )
     return False
-except Exception as e:
+except (RuntimeError, concurrent.futures.CancelledError) as e:
     logger.error(f"Exception in TxSession.wait_complete: {e}")
     return False
```

If other exception types are expected, add them explicitly to the tuple.
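In generic form, the narrowing this comment proposes keeps explicit handlers for the expected failure modes and lets anything unexpected propagate. A standalone sketch (not the PR's `wait_complete`; the function and parameter names here are illustrative):

```python
import concurrent.futures
import logging
import time

logger = logging.getLogger(__name__)


def wait_for_future(future: "concurrent.futures.Future", timeout_s: float, rid: int) -> bool:
    """Wait for a transfer future; False on known failures, True on success.

    Unexpected exception types (e.g. KeyError from a missing task id)
    deliberately propagate instead of being swallowed.
    """
    try:
        future.result(timeout=timeout_s)
        return True
    except concurrent.futures.TimeoutError:
        logger.warning("session %s timed out after %.2fs", rid, timeout_s)
        return False
    except (RuntimeError, concurrent.futures.CancelledError) as e:
        logger.error("session %s failed: %s", rid, e)
        return False
```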
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/disaggregation/native/transfer.py` around lines 931-949: The broad except Exception in wait_complete should be narrowed to only the expected error types: replace the generic except Exception as e with explicit handlers for the likely cases (e.g., KeyError for missing task_id in _kv_tasks, concurrent.futures.CancelledError for cancelled futures, and RuntimeError for other runtime issues) and handle/log them similarly; keep a fallback that re-raises truly unexpected exceptions (or let them propagate) instead of swallowing them silently; reference function wait_complete, attributes _kv_tasks and _aux_task, and use unique_rid in the log messages when keeping the specific exception handlers.

tensorrt_llm/_torch/disaggregation/native/py_cache_transceiver.py (1)
241-253: Consider simplifying the status checking logic.

The `SessionStatus.ERROR` check is duplicated in both branches. This can be consolidated:

♻️ Proposed simplification

```diff
 for request_id, session in self.recv_sessions.items():
     req = self.recv_req_id_to_request[request_id]
     need_aux_transfer = self._need_aux_transfer(req)
     session_status = session.state.status
-    if need_aux_transfer:
-        if session_status == SessionStatus.AUX_TRANSFERRED:
-            local_completed_request_ids.append(request_id)
-        elif session_status == SessionStatus.ERROR:
-            local_failed_request_ids.append(request_id)
-    elif session_status == SessionStatus.TRANSFERRED:
+
+    if session_status == SessionStatus.ERROR:
+        local_failed_request_ids.append(request_id)
+    elif need_aux_transfer and session_status == SessionStatus.AUX_TRANSFERRED:
+        local_completed_request_ids.append(request_id)
+    elif not need_aux_transfer and session_status == SessionStatus.TRANSFERRED:
         local_completed_request_ids.append(request_id)
-    elif session_status == SessionStatus.ERROR:
-        local_failed_request_ids.append(request_id)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/disaggregation/native/py_cache_transceiver.py` around lines 241-253: The loop duplicates the SessionStatus.ERROR check in both the aux-transfer and non-aux branches; simplify by evaluating need_aux_transfer = self._need_aux_transfer(req) and session_status = session.state.status, then handle the success/completion cases (e.g., if need_aux_transfer and session_status == SessionStatus.AUX_TRANSFERRED -> append to local_completed_request_ids; elif not need_aux_transfer and session_status == SessionStatus.TRANSFERRED -> append to local_completed_request_ids), and after those checks do a single if session_status == SessionStatus.ERROR -> append to local_failed_request_ids; use the existing names self.recv_sessions, self.recv_req_id_to_request, _need_aux_transfer, SessionStatus, local_completed_request_ids, and local_failed_request_ids to locate and update the code.
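The consolidated decision above boils down to a three-way classification per session, which can be isolated as a small pure function. This is a sketch with assumed names; the real code appends to the two request-id lists directly:

```python
from enum import Enum, auto
from typing import Optional


class SessionStatus(Enum):
    INIT = auto()
    TRANSFERRED = auto()
    AUX_TRANSFERRED = auto()
    ERROR = auto()


def classify(status: SessionStatus, need_aux_transfer: bool) -> Optional[str]:
    """Map a session status to 'completed', 'failed', or None (still in flight).

    The ERROR check runs once, regardless of whether an aux transfer is needed;
    the completion status differs only in which terminal state counts as done.
    """
    if status is SessionStatus.ERROR:
        return "failed"
    done = SessionStatus.AUX_TRANSFERRED if need_aux_transfer else SessionStatus.TRANSFERRED
    return "completed" if status is done else None
```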
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/disaggregation/base/transfer.py`:
- Around line 128-136: The setter decorated with `@request.setter` is incorrectly
named set_request so it doesn't register as the property setter; rename the
method from set_request to request(self, request: LlmRequest) (or remove the
decorator and use a plain method) so the `@request.setter` pairs with the property
getter, keep the existing logic using get_unique_rid(request) to assert equality
with self.unique_rid and assign self._request = request to set the session
request.
- Around line 188-194: The abstract method send currently lacks a method body
placeholder; add a placeholder implementation (e.g., a single-line ellipsis or
raise NotImplementedError) inside def send(self, slice: KVSlice) -> TaskIdType
so the abstract method is syntactically valid and consistent with other abstract
methods in the class; ensure the signature (send, KVSlice, TaskIdType) is
unchanged.
In `@tensorrt_llm/_torch/disaggregation/native/transfer.py`:
- Around line 878-885: The constructor parameter unique_rid is implicitly typed
as int but defaults to None; update its annotation to Optional[int] to comply
with PEP 484 (e.g., change "unique_rid: int = None" to "unique_rid:
Optional[int]") in the TxSession.__init__ signature and adjust any related type
hints or docstrings referencing unique_rid in the TxSession class if present.
- Around line 1339-1356: The function unpack_aux currently returns request but
has no return type annotation; update the signature of unpack_aux to include the
correct return type (the request object type used in this module) so it reads
like def unpack_aux(self) -> <RequestType>: and keep the final return request,
or if callers never use the returned value, remove the trailing return request
and leave the function as returning None; locate unpack_aux and adjust either
its signature or remove the return, and ensure callers of unpack_aux (who expect
the request) are updated to match the chosen change (symbols to check:
unpack_aux, self.request, request, ContextPhaseParams, request.py_draft_tokens).
- Around line 951-952: The TxSession class is missing initialization of
self._expected_transfers referenced by is_peer_info_ready; update
TxSession.__init__ to set self._expected_transfers to an appropriate default
(e.g., 0 or an empty structure consistent with _sender._peer_reqs.is_ready
expectations) and ensure any code path that should set a real expected transfer
count updates this attribute (search for TxSession.__init__, is_peer_info_ready,
unique_rid, and _sender to place the initialization and any subsequent
assignments).
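The setter-naming issue in the first inline comment is a general Python property rule: the method decorated with `@<prop>.setter` must share the property's name, or it never registers as the setter. A minimal illustration with hypothetical names (not the PR's class):

```python
class Session:
    def __init__(self):
        self._request = None

    @property
    def request(self):
        return self._request

    @request.setter
    def request(self, value):  # must be named `request`, not `set_request`
        # Validation (e.g. comparing unique rids) would go here.
        self._request = value


s = Session()
s.request = "req-1"  # dispatches through the setter
```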
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: fa92f678-2e6a-4240-ae16-2b5275fb6b66
📒 Files selected for processing (6)
tensorrt_llm/_torch/disaggregation/base/transfer.py
tensorrt_llm/_torch/disaggregation/native/messenger.py
tensorrt_llm/_torch/disaggregation/native/py_cache_transceiver.py
tensorrt_llm/_torch/disaggregation/native/region/aux_.py
tensorrt_llm/_torch/disaggregation/native/transfer.py
tests/unittest/disaggregated/test_py_cache_transceiver_mp.py
PR_Github #38200 [ run ] triggered by Bot. Commit:
Thanks for the reminder, I've run
PR_Github #38673 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38803 [ run ] triggered by Bot. Commit:
PR_Github #38803 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38856 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast
/bot run --disable-fail-fast
PR_Github #38874 [ run ] triggered by Bot. Commit:
PR_Github #38875 [ run ] triggered by Bot. Commit:
PR_Github #38875 [ run ] completed with state
…eq for tp ranks, fix hang and lock usage
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
/bot run --disable-fail-fast
PR_Github #38959 [ run ] triggered by Bot. Commit:
PR_Github #38959 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #38976 [ run ] triggered by Bot. Commit:
PR_Github #38976 [ run ] completed with state
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.