31 changes: 30 additions & 1 deletion cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -986,6 +986,8 @@ class CacheReceiver::Impl
}

bool isCancelled = false;
std::unique_ptr<std::promise<void>> queuedPromise;
LlmRequest::RequestIdType cancelledId{0};
auto& asyncResource = mInstanceToAsyncResource.at(processInfo);
{
std::unique_lock<std::mutex> lck(asyncResource->mMtxForQueue);
@@ -994,6 +996,8 @@ class CacheReceiver::Impl
{ return requestAndPromise.mRequest->mRequestId == llmRequest.mRequestId; });
if (it != asyncResource->mRequestsQueue.end())
{
cancelledId = it->mRequest->mRequestId;
queuedPromise = std::move(it->mPromise);
asyncResource->mRequestsQueue.erase(it);
isCancelled = true;
}
@@ -1002,6 +1006,31 @@ class CacheReceiver::Impl
TLLM_LOG_WARNING("Cannot cancel request %zu", llmRequest.mRequestId);
}
}
if (queuedPromise)
{
// The future returned by request_and_receive_async() is the only
// signal the disagg gen event loop has that this request is done.
// If we erase the queued (request, promise) pair here without
// first fulfilling the promise, the std::promise destructor stores a
// "future_error: Broken promise" in the future, which the polling
// loop in CacheTransceiver::checkGenTransferStatus() then surfaces
// as a generic exception with no actionable diagnostic. Mirror
// what the sender-side cancellation path in
// CacheSender::Impl::sendResponse() does: fulfill the promise with
// a structured kNETWORK_ERROR exception so the consumer sees a
// real cancellation instead of a broken promise.
try
{
auto cancelledException
= TLLM_REQUEST_EXCEPTION(cancelledId, tensorrt_llm::common::RequestErrorCode::kNETWORK_ERROR,
"Generation KV cache request cancelled before send for request %zu", cancelledId);
queuedPromise->set_exception(std::make_exception_ptr(cancelledException));
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("Failed to fulfill cancelled gen request promise %zu: %s", cancelledId, e.what());
}
}
return isCancelled;
}

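For context on the failure mode the comment in `cancelRequest()` describes, here is a minimal standalone sketch, using only the standard library, of the difference between letting a `std::promise` die unfulfilled and calling `set_exception()` first. None of this is TRT-LLM code; the messages and names are illustrative only.

```cpp
// Minimal sketch of the broken-promise failure mode. Not TRT-LLM code.
#include <exception>
#include <future>
#include <iostream>
#include <stdexcept>

int main()
{
    // Pre-fix behaviour: the queued promise is destroyed without being
    // fulfilled. The std::promise destructor stores
    // std::future_errc::broken_promise in the shared state, so the
    // consumer only sees a generic error.
    std::future<void> orphaned;
    {
        std::promise<void> p;
        orphaned = p.get_future();
    } // p destroyed here, unfulfilled
    try
    {
        orphaned.get();
    }
    catch (std::future_error const& e)
    {
        std::cout << "pre-fix:  " << e.what() << '\n'; // typically "broken promise"
    }

    // Post-fix behaviour: fulfill the promise with a structured exception
    // before discarding it, so the consumer sees the real cancellation.
    std::promise<void> p2;
    std::future<void> cancelled = p2.get_future();
    p2.set_exception(std::make_exception_ptr(
        std::runtime_error("request cancelled before send")));
    try
    {
        cancelled.get();
    }
    catch (std::exception const& e)
    {
        std::cout << "post-fix: " << e.what() << '\n';
    }
    return 0;
}
```

The fix above follows the second pattern: it moves the queued promise out of the entry, `set_exception()`-s the structured `TLLM_REQUEST_EXCEPTION`, and only then erases the entry.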
149 changes: 149 additions & 0 deletions tests/unittest/others/test_kv_cache_transceiver.py
@@ -229,6 +229,155 @@ def test_cancel_request_in_transmission(attention_type):
assert gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR


@pytest.mark.timeout(120)
@pytest.mark.parametrize("attention_type",
[AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
ids=["mha", "mla"])
def test_cancel_queued_gen_request_fulfills_receiver_future(
attention_type, capfd):
"""Reproduce the receiver-side queued-cancel broken-promise.

Mirror of the sender-side reproducer
``test_cancel_request_in_transmission_fulfills_sender_future`` but on
    ``CacheReceiver::Impl::cancelRequest()``. Pre-fix, that function
erases the queued ``(request, promise)`` pair without first
fulfilling the promise, so the ``std::promise`` destructor sets a
``future_error: Broken promise`` on the future returned by
``request_and_receive_async()``. The disagg gen polling loop in
``CacheTransceiver::checkGenTransferStatus()`` then surfaces the
cancellation as a generic exception with no actionable diagnostic
rather than as a structured ``kNETWORK_ERROR``.

To exercise the queued-cancel path the test holds the receiver
worker thread busy with a first generation request that has no
matching context counterpart (so it blocks inside
``sendRequestInfo()`` waiting for a ready signal that will never
arrive), then enqueues a second generation request that sits in the
queue, then cancels it. The post-fix
``CacheReceiver::Impl::cancelRequest()`` extracts the queued promise
and ``set_exception()``-s a structured ``kNETWORK_ERROR`` before
erasing the entry, so the polling loop sees a real cancellation and
no ``Broken promise`` ever appears on stderr.
"""
tensorrt_llm.logger.set_level("info")
mapping = Mapping(world_size=1, rank=0)
dist = Distributed.get(mapping)
kv_cache_manager_ctx = create_kv_cache_manager(mapping, DataType.HALF)
kv_cache_manager_gen = create_kv_cache_manager(mapping, DataType.HALF)

cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
max_tokens_in_buffer=512)

kv_cache_transceiver_ctx = create_kv_cache_transceiver(
mapping, dist, kv_cache_manager_ctx, attention_type,
cache_transceiver_config)
kv_cache_transceiver_gen = create_kv_cache_transceiver(
mapping, dist, kv_cache_manager_gen, attention_type,
cache_transceiver_config)

fill_kv_cache_buffer(kv_cache_manager_ctx)
sampling_params = SamplingParams()

def make_request(request_id, llm_request_type, context_phase_params=None):
kwargs = dict(
request_id=request_id,
max_new_tokens=1,
input_tokens=list(range(256)),
sampling_config=tensorrt_llm.bindings.SamplingConfig(
sampling_params._get_sampling_config()),
is_streaming=False,
llm_request_type=llm_request_type,
)
if context_phase_params is not None:
kwargs["context_phase_params"] = context_phase_params
return LlmRequest(**kwargs)

def add_sequence(kv_cache_manager, request):
kv_cache_manager.impl.add_sequence(request.py_request_id,
request.prompt_len, 1, request)

# Drive one full ctx/gen handshake to completion so we can reuse a real
# opaque comm/cache state for the cancellation request below.
template_ctx_request = make_request(
100, LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)
add_sequence(kv_cache_manager_ctx, template_ctx_request)
kv_cache_transceiver_ctx.respond_and_send_async(template_ctx_request)

template_gen_request = make_request(
100, LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
template_ctx_request.context_phase_params)
add_sequence(kv_cache_manager_gen, template_gen_request)
kv_cache_transceiver_gen.request_and_receive_async(template_gen_request)
kv_cache_transceiver_ctx.check_context_transfer_status(1)
kv_cache_transceiver_gen.check_gen_transfer_status(1)

opaque_state = template_ctx_request.context_phase_params.opaque_state
assert opaque_state is not None

kv_cache_manager_ctx.free_resources(template_ctx_request)
kv_cache_manager_gen.free_resources(template_gen_request)

def make_orphan_gen_request(request_id):
ctx_phase_params = trtllm.ContextPhaseParams(
list(template_ctx_request.context_phase_params.first_gen_tokens),
request_id,
bytes(opaque_state),
template_ctx_request.context_phase_params.draft_tokens,
template_ctx_request.context_phase_params.ctx_dp_rank,
template_ctx_request.context_phase_params.disagg_info_endpoint,
)
gen_request = make_request(
request_id, LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
ctx_phase_params)
add_sequence(kv_cache_manager_gen, gen_request)
return gen_request

# Submit a first orphan gen request whose context counterpart will never
# respond_and_send_async. The receiver worker dequeues it and parks
# inside requestSync() / sendRequestInfo(), tying up the worker thread.
blocking_gen_request = make_orphan_gen_request(101)
kv_cache_transceiver_gen.request_and_receive_async(blocking_gen_request)

# Submit a second orphan gen request. The first one is still in
# requestSync(), so this one stays in mRequestsQueue and is the actual
# subject of the queued-cancel reproducer.
queued_gen_request = make_orphan_gen_request(102)
kv_cache_transceiver_gen.request_and_receive_async(queued_gen_request)

# Wait briefly so the receiver worker has had time to dequeue the first
# request and block on it, leaving the second one queued.
time.sleep(1)

# Cancel the queued request. Pre-fix this erases the (request, promise)
# pair without fulfilling the promise; post-fix it set_exception()s a
# structured kNETWORK_ERROR before erasing.
is_cancelled = kv_cache_transceiver_gen.cancel_request(queued_gen_request)
assert is_cancelled, (
"queued_gen_request must still be in the receiver queue when we "
"call cancel_request(); if this fails, the receiver worker may "
"have dequeued faster than expected and the test setup needs to "
"be tightened")

# Poll the gen-side polling loop and assert the cancelled request lands
# in DISAGG_TRANS_ERROR within a reasonable window. Pre-fix this returns
# via a Broken-promise exception with no useful diagnostic; post-fix it
# returns via the structured kNETWORK_ERROR set by the fix.
deadline = time.time() + 10
while time.time() < deadline and (queued_gen_request.state
!= LlmRequestState.DISAGG_TRANS_ERROR):
kv_cache_transceiver_gen.check_gen_transfer_status(1)
time.sleep(0.1)

assert queued_gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR

captured = capfd.readouterr()
merged = captured.out + captured.err
assert "Broken promise" not in merged, (
"signature #5 reproduced: cancelling a queued generation request "
"left its std::promise unresolved and the destructor surfaced as "
"Broken promise on the consumer side")


def create_hybrid_cache_manager(mapping,
dtype,
mamba_conv_dtype=torch.float16,