[TRTLLM-11657][feat] Conversation affinity disagg router #12526
longlee0622 merged 9 commits into NVIDIA:main
Extract common load-balancing logic shared between LoadBalancingRouter and KvCacheAwareRouter into a reusable LoadBalancingMixin class, reducing duplication of server state management and request tracking. Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
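The extraction described in the commit above might look roughly like the sketch below. Only `LoadBalancingMixin`, `LoadBalancingRouter`, and `KvCacheAwareRouter` are names from the PR; the method names and internals here are hypothetical, shown just to illustrate how shared server-state management and request tracking can live in one mixin:

```python
class LoadBalancingMixin:
    """Hypothetical shape of the extracted mixin: shared server-state
    bookkeeping reused by both router classes."""

    def _init_load_balancing(self, servers):
        # One in-flight-request counter per backend server.
        self._num_active_requests = {s: 0 for s in servers}

    def _select_least_loaded(self):
        # Pick the server with the fewest in-flight requests.
        return min(self._num_active_requests,
                   key=self._num_active_requests.get)

    def _on_request_start(self, server):
        self._num_active_requests[server] += 1

    def _on_request_finish(self, server):
        self._num_active_requests[server] -= 1


class LoadBalancingRouter(LoadBalancingMixin):
    def __init__(self, servers):
        self._init_load_balancing(servers)


router = LoadBalancingRouter(["s1", "s2"])
router._on_request_start("s1")
assert router._select_least_loaded() == "s2"  # s2 has fewer active requests
```

A `KvCacheAwareRouter` would inherit the same mixin and layer its cache-match scoring on top of this request tracking.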
📝 Walkthrough
These changes introduce conversation-aware routing to the TensorRT-LLM serving infrastructure. Key additions include: (1) extracting conversation identity from request headers, (2) propagating conversation IDs through disaggregated payloads, (3) refactoring router logic into reusable mixins for load-balancing and KV-cache handling, and (4) implementing a new `ConversationRouter`.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant OpenAIServer as OpenAI Disagg<br/>Server
    participant OpenAIService as OpenAI Disagg<br/>Service
    participant ConvRouter as Conversation<br/>Router
    participant BackendServers as Backend<br/>Servers
    Client->>OpenAIServer: HTTP Request<br/>(x-correlation-id header)
    OpenAIServer->>OpenAIServer: _extract_conversation_id()
    Note over OpenAIServer: Extract conversation_id<br/>from header, populate<br/>disaggregated_params
    OpenAIServer->>OpenAIService: Route request<br/>(with conversation_id)
    alt Session Affinity Match
        OpenAIService->>ConvRouter: Route with explicit<br/>conversation_id
        ConvRouter->>ConvRouter: Check session_table<br/>for conversation_id
        ConvRouter->>BackendServers: Send to affinity<br/>mapped server
    else Implicit Continuation Match
        OpenAIService->>ConvRouter: Route with prompt tokens
        ConvRouter->>ConvRouter: Compute block hash,<br/>search trie for<br/>longest prefix match
        ConvRouter->>BackendServers: Send to best<br/>matching server
    else No Match
        OpenAIService->>ConvRouter: Route new conversation
        ConvRouter->>ConvRouter: Select least-loaded<br/>server
        ConvRouter->>BackendServers: Send to selected<br/>server
    end
    BackendServers->>ConvRouter: Response + usage
    ConvRouter->>OpenAIService: Return response
    OpenAIService->>OpenAIServer: Return response
    OpenAIServer->>Client: HTTP Response
```
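The three-way decision in the diagram can be sketched as a small fallback chain. This is a toy model, not the PR's code: `session_table`, `prefix_match`, and `load` are stand-ins for the router's real session table, trie lookup, and server-state tracking:

```python
def route(conversation_id, prompt_tokens, session_table, prefix_match, load):
    """Hypothetical three-tier routing decision mirroring the diagram."""
    # 1) Explicit affinity: a known conversation returns to its server.
    if conversation_id is not None and conversation_id in session_table:
        return session_table[conversation_id]
    # 2) Implicit continuation: longest cached-prefix match wins, if any.
    server = prefix_match(prompt_tokens)
    if server is not None:
        return server
    # 3) New conversation: fall back to the least-loaded server.
    return min(load, key=load.get)


table = {"conv-1": "gen0"}
loads = {"gen0": 3, "gen1": 1}
assert route("conv-1", [], table, lambda t: None, loads) == "gen0"
assert route(None, [1, 2], table, lambda t: "gen1", loads) == "gen1"
assert route(None, [9], table, lambda t: None, loads) == "gen1"
```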
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/serve/router.py (1)
739-769: ⚠️ Potential issue | 🟠 Major — Align `workloads` with the filtered candidate list.
`servers` removes `exclude_server`, but `workloads` still iterates over every `_server_state`. If the excluded server is not the last entry, the load from server N gets applied to server N+1, so the KV/load score is computed against the wrong host exactly on the path that uses `exclude_server`. 🛠️ Suggested fix
```diff
 async with self._lock:
     servers = [
         server for server in self._server_state.keys()
         if server != exclude_server
     ]
+    if not servers:
+        raise ValueError(
+            f"No available servers after excluding {exclude_server}")
     token_lists = self._tokenize(request)
     block_hashes = self._compute_block_hashes(token_lists)
@@
-    workloads = [
-        state.num_active_requests()
-        for state in self._server_state.values()
-    ]
+    workloads = [
+        self._server_state[server].num_active_requests()
+        for server in servers
+    ]
```

🤖 Prompt for AI Agents
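The misalignment is easy to reproduce with toy data (the numbers below are illustrative, not from the real router): filtering `servers` while still reading loads from the unfiltered dict shifts every entry after the excluded server by one slot.

```python
# Toy reproduction of the index misalignment described above.
server_state = {"s1": 5, "s2": 1, "s3": 9}  # server -> active requests
exclude_server = "s1"

servers = [s for s in server_state if s != exclude_server]  # ["s2", "s3"]
buggy_workloads = list(server_state.values())               # [5, 1, 9]
# Index 0 now pairs "s2" with s1's load of 5 -- the wrong host.
assert list(zip(servers, buggy_workloads))[0] == ("s2", 5)

# Fixed: index workloads by the filtered candidate list.
workloads = [server_state[s] for s in servers]
assert list(zip(servers, workloads)) == [("s2", 1), ("s3", 9)]
```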
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/serve/router.py` around lines 739-769: the `workloads` list is misaligned because it iterates over `self._server_state.values()` while `servers` was filtered to exclude `exclude_server`; change `workloads` to compute per-candidate by iterating over the filtered `servers` list (e.g., `workloads = [self._server_state[server].num_active_requests() for server in servers]`) so each workload entry corresponds to the same index in `servers` used later for matches/score; ensure the lengths of `workloads`, matches, and scores all match `servers` before selecting the winner (symbols: `servers`, `workloads`, `self._server_state`, `matched_tokens`, `_tokens_per_block`, `_max_batch_size`, `_rr_counter`).

tensorrt_llm/serve/openai_protocol.py (1)
118-130: ⚠️ Potential issue | 🟡 Minor — Document `conversation_id` on the public schema.
`DisaggregatedParams` is part of the request/response surface, so the new field should use `Field(..., description=...)` instead of a bare attribute to keep generated API docs complete. 📝 Suggested change

```diff
-    conversation_id: Optional[str] = None
+    conversation_id: Optional[str] = Field(
+        default=None,
+        description="Conversation/session identifier used for affinity-aware disaggregated routing.",
+    )
```

As per coding guidelines, "Add descriptions to all user-facing Pydantic fields via `Field(description="...")` rather than using comments". 🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/serve/openai_protocol.py` around lines 118 - 130, The public schema field conversation_id on the DisaggregatedParams model is missing a Field description; update the declaration in class DisaggregatedParams to use pydantic.Field(..., description="...") (e.g., Field(None, description="Human-readable conversation identifier used to correlate disaggregated requests/responses")) instead of a bare Optional[str] so generated API docs include the purpose of conversation_id; keep the type and optionality the same and follow the same description style used by other fields in OpenAIBaseModel-derived classes.
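Why the review asks for `Field(description=...)`: Pydantic propagates the description into the generated JSON schema, which is what OpenAPI docs are built from. A minimal standalone sketch (assuming Pydantic v2; this model is a trimmed stand-in, not the real `DisaggregatedParams`):

```python
from typing import Optional

from pydantic import BaseModel, Field


class DisaggregatedParams(BaseModel):
    """Trimmed-down stand-in; only the new field is shown."""
    conversation_id: Optional[str] = Field(
        default=None,
        description="Conversation/session identifier used for "
        "affinity-aware disaggregated routing.")


# The description surfaces in the JSON schema used for API docs.
schema = DisaggregatedParams.model_json_schema()
assert "description" in schema["properties"]["conversation_id"]
```

A bare `conversation_id: Optional[str] = None` would produce a schema entry with no description, leaving the field undocumented in generated API references.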
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/serve/openai_disagg_server.py`:
- Around line 181-188: Do not inject a synthetic DisaggregatedParams into req;
instead keep the header-derived conversation_id out-of-band and only attach real
per-phase disaggregated_params inside _get_ctx_request() and _get_gen_request(),
or if you must set req.disaggregated_params for other callers, ensure you remove
or clear req.disaggregated_params before forwarding to the generation server in
the ctx-first fast path (the branch controlled by need_ctx=False in the
service). Locate uses of req.disaggregated_params and the DisaggregatedParams
constructor in openai_disagg_server.py and change the logic so that
conversation_id is stored separately (e.g., a local variable or a
request-metadata field) and only build/assign DisaggregatedParams inside
_get_ctx_request/_get_gen_request(), or explicitly strip
req.disaggregated_params right before the direct-gen send to avoid mislabeling
the worker execution path.
In `@tensorrt_llm/serve/router.py`:
- Around line 893-898: The constructor currently sets max_sessions=10000 which
contradicts the feature contract expecting 100000; update the default value for
the max_sessions parameter in the constructor signature (the router/__init__ or
relevant constructor where max_sessions is declared) to 100000, and adjust any
related docstring/comments/tests that assert the old default so the affinity
table defaults to 100k sessions as specified.
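A bounded affinity table like the one discussed (default 100k sessions) is typically implemented with LRU eviction so the oldest conversations drop out first. The class below is a hypothetical sketch using `collections.OrderedDict`, not the PR's actual session table:

```python
from collections import OrderedDict


class SessionTable:
    """Hypothetical bounded conversation->server affinity table with
    least-recently-used eviction."""

    def __init__(self, max_sessions=100000):
        self._max_sessions = max_sessions
        self._table = OrderedDict()  # conversation_id -> server

    def get(self, conversation_id):
        server = self._table.get(conversation_id)
        if server is not None:
            self._table.move_to_end(conversation_id)  # refresh recency
        return server

    def put(self, conversation_id, server):
        self._table[conversation_id] = server
        self._table.move_to_end(conversation_id)
        while len(self._table) > self._max_sessions:
            self._table.popitem(last=False)  # evict least-recently used


t = SessionTable(max_sessions=2)
t.put("a", "s1")
t.put("b", "s2")
t.put("c", "s1")          # capacity 2: "a" is evicted
assert t.get("a") is None
assert t.get("b") == "s2"
```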
- Around line 901-902: ConversationRouter records per-server content weight in
_server_content_load but fallback placement still uses _select_least_loaded
which reads _num_active_requests because _init_load_balancing was called with
use_tokens=False; update the fallback selection to consider content weight (or
reinitialize load balancing with use_tokens=True) so routing uses the new
metric. Locate _init_load_balancing and either set use_tokens=True when
tokens_per_block/content-weighting is enabled, or modify _select_least_loaded
(and any callers at lines around 923–930 and 1215–1220) to combine
_server_content_load with _num_active_requests (e.g., compute a weighted score)
so servers with higher estimated token load are deprioritized. Ensure references
to _server_content_load and _select_least_loaded are updated consistently.
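One way to combine the two metrics as the comment suggests is a weighted score. This is a sketch of the idea only; `alpha` and the function shape are assumptions, and the real fix overrides `_get_server_load` rather than adding a new selector:

```python
def select_least_loaded(servers, active_requests, content_load, alpha=0.5):
    """Hypothetical blended load score: alpha weights the in-flight request
    count against the estimated token (content) load per server."""
    def score(server):
        return (alpha * active_requests[server] +
                (1 - alpha) * content_load[server])
    return min(servers, key=score)


servers = ["s1", "s2"]
active = {"s1": 1, "s2": 1}                 # tied on request count...
content = {"s1": 4096.0, "s2": 128.0}       # ...but s1 holds far more tokens
assert select_least_loaded(servers, active, content) == "s2"
```

With request counts tied, the content weight breaks the tie toward the server holding fewer estimated tokens, which is the behavior the review asks for.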
In `@tests/unittest/disaggregated/test_router.py`:
- Around line 747-773: The test test_conversation_router_hash_skip_count assumes
the shared system prefix will push ConversationRouter's implicit-match logic
over match_threshold, but with defaults (match_threshold=0.75,
tokens_per_block=128) it does not; update the test to align with the
implementation by changing the router instantiation (ConversationRouter in this
test) or inputs so the shared prefix actually meets the threshold — e.g., set a
lower match_threshold (e.g., 0.2) or reduce tokens_per_block, or increase
sys_prompt length so the number of shared blocks reaches ~75%; adjust the r1/r2
constructors (where servers, sys_prompt and hash_skip_count are used)
accordingly so the assertions reflect the router's real behavior.
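The threshold arithmetic behind this comment can be checked directly. Assuming block-granular matching (whole blocks of `tokens_per_block=128` must match, which is an inference from the comment, not verified against the implementation):

```python
def match_ratio(shared_prefix_tokens, prompt_tokens, tokens_per_block=128):
    """Block-granular match ratio under the assumed matching rule."""
    matched_blocks = shared_prefix_tokens // tokens_per_block
    total_blocks = -(-prompt_tokens // tokens_per_block)  # ceil division
    return matched_blocks / total_blocks


# A short shared prefix stays well under the 0.75 default threshold...
assert match_ratio(300, 1000) < 0.75
# ...while a ~2000-token shared prefix of a 2400-token prompt exceeds it.
assert match_ratio(2000, 2400) > 0.75   # 15/19 ~= 0.79
```

This matches the later fix, which grows the shared prefix until the ratio (~0.79) clears the 0.75 threshold.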
---
Outside diff comments:
In `@tensorrt_llm/serve/openai_protocol.py`:
- Around line 118-130: The public schema field conversation_id on the
DisaggregatedParams model is missing a Field description; update the declaration
in class DisaggregatedParams to use pydantic.Field(..., description="...")
(e.g., Field(None, description="Human-readable conversation identifier used to
correlate disaggregated requests/responses")) instead of a bare Optional[str] so
generated API docs include the purpose of conversation_id; keep the type and
optionality the same and follow the same description style used by other fields
in OpenAIBaseModel-derived classes.
In `@tensorrt_llm/serve/router.py`:
- Around line 739-769: The workloads list is misaligned because it iterates over
self._server_state.values() while servers was filtered to exclude
exclude_server; change workloads to compute per-candidate by iterating over the
filtered servers list (e.g., workloads =
[self._server_state[server].num_active_requests() for server in servers]) so
each workload entry corresponds to the same index in servers used later for
matches/score; ensure lengths of workloads, matches, and scores all match
servers before selecting winner (symbols: servers, workloads,
self._server_state, matched_tokens, _tokens_per_block, _max_batch_size,
_rr_counter).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a9fbc542-4c3b-4135-a8e2-9922e1962418
📒 Files selected for processing (5)
- tensorrt_llm/serve/openai_disagg_server.py
- tensorrt_llm/serve/openai_disagg_service.py
- tensorrt_llm/serve/openai_protocol.py
- tensorrt_llm/serve/router.py
- tests/unittest/disaggregated/test_router.py
- Remove synthetic DisaggregatedParams creation in _extract_conversation_id to prevent context_only label leaking to workers on the need_ctx=False path
- Fix max_sessions default from 10000 to 100000 as documented
- Override _get_server_load in ConversationRouter to use content weight so _select_least_loaded balances by estimated tokens, not request count
- Fix test_conversation_router_hash_skip_count: increase shared prefix to 2000 chars so match ratio (0.79) exceeds the 0.75 threshold

Signed-off-by: Lizhi Zhou <lizhiz@nvidia.com>
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
…rect The service layer always rebuilds disaggregated_params in _get_ctx_request / _get_gen_request before forwarding to workers, so the synthetic params never leak downstream. Restored original logic that creates a minimal DisaggregatedParams to carry the header-derived conversation_id. Signed-off-by: Lizhi Zhou <lizhiz@nvidia.com> Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
When _extract_conversation_id injects a synthetic DisaggregatedParams(request_type="context_only") from the X-Correlation-ID header, and the conditional disagg path decides need_ctx=False, strip the synthetic params before forwarding to the gen worker. This prevents a stale context_only label from reaching the worker on the full-generation path. Signed-off-by: Lizhi Zhou <lizhiz@nvidia.com> Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
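The stripping logic this commit describes might look roughly like the following. This is a sketch over a plain dict, not the real request model; `prepare_gen_request` and the field names are illustrative:

```python
def prepare_gen_request(req, need_ctx):
    """Hypothetical sketch: drop the synthetic context_only params when the
    request goes straight to a generation worker (need_ctx=False path)."""
    params = req.get("disaggregated_params")
    if not need_ctx and params is not None \
            and params.get("request_type") == "context_only":
        # The synthetic params existed only to carry the header-derived
        # conversation_id; strip them so the worker doesn't see a stale
        # context_only label on the full-generation path.
        req = {**req, "disaggregated_params": None}
    return req


req = {"prompt": "hi",
       "disaggregated_params": {"request_type": "context_only",
                                "conversation_id": "conv-1"}}
assert prepare_gen_request(req, need_ctx=False)["disaggregated_params"] is None
assert prepare_gen_request(req, need_ctx=True)["disaggregated_params"] is not None
```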
/bot run --disable-fail-fast
PR_Github #42086 [ run ] triggered by Bot. Commit:
PR_Github #42086 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42256 [ run ] triggered by Bot. Commit:
PR_Github #42256 [ run ] completed with state
The LoadBalancingMixin's _select_least_loaded uses round-robin tie-breaking instead of the old heap's alphabetical ordering. Update test_request_balancing_router, test_tokens_balancing_router, and test_kv_cache_aware_router to assert on load-balancing invariants (uniquely-least-loaded selection, set membership, cache-hit routing) rather than hardcoded server names that depend on tie-breaking order. Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
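Why hardcoded server names become brittle under round-robin tie-breaking: with all loads equal, the pick rotates instead of always landing on the alphabetically first server. A hypothetical sketch (`_rr_counter` is the only name taken from the review comments):

```python
import itertools


class RoundRobinTieBreak:
    """Hypothetical tie-break: among equally least-loaded servers, rotate."""

    def __init__(self, servers):
        self._servers = list(servers)
        self._loads = {s: 0 for s in servers}
        self._rr_counter = itertools.count()

    def select(self):
        least = min(self._loads.values())
        candidates = [s for s in self._servers if self._loads[s] == least]
        return candidates[next(self._rr_counter) % len(candidates)]


r = RoundRobinTieBreak(["b", "a", "c"])
picks = [r.select() for _ in range(3)]
# All ties: the rotation visits every server, so a test asserting a fixed
# name (e.g. always "a") would flake; assert set membership instead.
assert set(picks) == {"a", "b", "c"}
```

This is exactly why the tests were updated to assert invariants (uniquely-least-loaded selection, set membership) rather than specific server names.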
Use dict(zip(all_servers, matches)) to look up match counts by server name directly, eliminating the positional idx_of indirection. Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
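The simplification described above, shown side by side with the positional lookup it replaces (toy data; `all_servers` and `matches` mirror the names in the commit message):

```python
all_servers = ["s1", "s2", "s3"]
matches = [0, 12, 4]  # matched-token counts, positionally aligned

# Before: positional indirection through an index map.
idx_of = {s: i for i, s in enumerate(all_servers)}
assert matches[idx_of["s2"]] == 12

# After: a direct server-name -> match-count mapping.
match_by_server = dict(zip(all_servers, matches))
assert match_by_server["s2"] == 12
```

The `dict(zip(...))` form removes one level of indirection and makes the server/match pairing explicit at the point of construction.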
/bot run --disable-fail-fast
PR_Github #42458 [ run ] triggered by Bot. Commit:
PR_Github #42458 [ run ] completed with state
/bot run --disable-fail-fast
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
/bot run --disable-fail-fast
PR_Github #42613 [ run ] triggered by Bot. Commit:
PR_Github #42613 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42908 [ run ] triggered by Bot. Commit:
PR_Github #42908 [ run ] completed with state
Summary
Add a ConversationRouter for multi-turn session affinity in disaggregated serving, and refactor shared logic out of the existing router classes, similar to vLLM's ConsistentHashPolicy.
Refactoring
New: ConversationRouter
Routes requests with the same conversation_id to the same server for KV cache reuse across multi-turn conversations.
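The affinity property — the same `conversation_id` always lands on the same server — can be illustrated with a minimal stable-hash placement. This is a sketch in the spirit of the consistent-hash policy mentioned above, not the PR's implementation (a real consistent-hash ring also limits remapping when the server set changes, which this modulo scheme does not):

```python
import hashlib


def server_for_conversation(conversation_id, servers):
    """Deterministic placement: hash the conversation id to a server."""
    digest = hashlib.sha256(conversation_id.encode()).digest()
    return servers[int.from_bytes(digest[:8], "big") % len(servers)]


servers = ["gen0", "gen1", "gen2"]
first = server_for_conversation("conv-42", servers)
# Same conversation always maps to the same server, so each follow-up
# turn can reuse the KV cache built by the earlier turns.
assert all(server_for_conversation("conv-42", servers) == first
           for _ in range(5))
```

The PR's router additionally keeps an explicit session table and a prefix-match fallback, so placement survives even when the hash-based guess would miss.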
Protocol
Test plan
Summary by CodeRabbit
Release Notes
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.