4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -586,6 +586,10 @@ class CacheSender::Impl
// not be removed from mCancelledRequests. This should be handled by timeout.
auto it = mReadyResponses.find(mCurrentRequest.value());
TLLM_CHECK(it != mReadyResponses.end());
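// Fulfil the pending future with an explicit error before erasing the response,
// so the waiting receiver observes kNETWORK_ERROR rather than a
// std::future_error ("Broken promise") from a destroyed promise.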
auto cancelledException
= TLLM_REQUEST_EXCEPTION(reqId, tensorrt_llm::common::RequestErrorCode::kNETWORK_ERROR,
"Context KV cache transfer cancelled after ready-signal for request %zu", reqId);
it->second.mPromise.set_exception(std::make_exception_ptr(cancelledException));
{
std::scoped_lock lkResp(mSenderMutex);
mReadyResponses.erase(it);
83 changes: 83 additions & 0 deletions tests/unittest/others/test_kv_cache_transceiver.py
@@ -229,6 +229,89 @@ def test_cancel_request_in_transmission(attention_type):
assert gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR


@pytest.mark.timeout(120)
@pytest.mark.parametrize("attention_type",
[AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
ids=["mha", "mla"])
def test_cancel_request_in_transmission_does_not_break_sender_future(
attention_type, capfd):
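"""Cancelling a context transfer after the ready-signal must surface as a
transfer error on both sides, not as a std::future_error ("Broken promise")."""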
tensorrt_llm.logger.set_level("info")
mapping = Mapping(world_size=1, rank=0)
dist = Distributed.get(mapping)
ctx_kv_cache_dtype, gen_kv_cache_dtype = DataType.HALF, DataType.HALF
kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)

cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
max_tokens_in_buffer=512)

kv_cache_transceiver_ctx = create_kv_cache_transceiver(
mapping, dist, kv_cache_manager_ctx, attention_type,
cache_transceiver_config)

kv_cache_transceiver_gen = create_kv_cache_transceiver(
mapping, dist, kv_cache_manager_gen, attention_type,
cache_transceiver_config)

fill_kv_cache_buffer(kv_cache_manager_ctx)

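# Context-only request whose KV cache will be offered to the generation side.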
sampling_params = SamplingParams()
ctx_request = LlmRequest(
request_id=0,
max_new_tokens=1,
input_tokens=list(range(256)),
sampling_config=tensorrt_llm.bindings.SamplingConfig(
sampling_params._get_sampling_config()),
is_streaming=False,
llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)

kv_cache_manager_ctx.impl.add_sequence_batch(
[(ctx_request.py_request_id, ctx_request.prompt_len, 1)], [ctx_request])
kv_cache_transceiver_ctx.respond_and_send_async(ctx_request)

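# Let the sender reach the ready state before cancelling the transfer.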
time.sleep(2)
is_cancelled = kv_cache_transceiver_ctx.cancel_request(ctx_request)
assert is_cancelled

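# Generation-side request that tries to pull the now-cancelled transfer.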
gen_request = LlmRequest(
request_id=0,
max_new_tokens=1,
input_tokens=list(range(256)),
sampling_config=tensorrt_llm.bindings.SamplingConfig(
sampling_params._get_sampling_config()),
is_streaming=False,
llm_request_type=LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
context_phase_params=ctx_request.context_phase_params)

kv_cache_manager_gen.impl.add_sequence_batch(
[(gen_request.py_request_id, gen_request.prompt_len, 1)], [gen_request])
kv_cache_transceiver_gen.request_and_receive_async(gen_request)

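# The cancelled request must be reported as errored, never as completed.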
completed_ids, error_ids = [], []
deadline = time.time() + 10
while time.time() < deadline:
completed_ids, error_ids = kv_cache_transceiver_ctx.check_context_transfer_status(1)
if error_ids:
break
time.sleep(0.1)

assert ctx_request.py_request_id not in completed_ids
assert ctx_request.py_request_id in error_ids

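# The generation side should observe the failure as DISAGG_TRANS_ERROR too.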
deadline = time.time() + 10
while (time.time() < deadline
and gen_request.state != LlmRequestState.DISAGG_TRANS_ERROR):
kv_cache_transceiver_gen.check_gen_transfer_status(1)
time.sleep(0.1)

assert gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR

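# With set_exception in place, no "Broken promise" future error is logged.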
captured = capfd.readouterr()
merged = captured.out + captured.err
assert "Broken promise" not in merged


def create_hybrid_cache_manager(mapping,
dtype,
mamba_conv_dtype=torch.float16,