diff --git a/tests/unittest/others/test_kv_cache_transceiver.py b/tests/unittest/others/test_kv_cache_transceiver.py
index cf9d544a9ce0..fb09f46636d2 100644
--- a/tests/unittest/others/test_kv_cache_transceiver.py
+++ b/tests/unittest/others/test_kv_cache_transceiver.py
@@ -1,3 +1,4 @@
+import threading
 import time
 import uuid
 
@@ -229,6 +230,176 @@ def test_cancel_request_in_transmission(attention_type):
         assert gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR
 
 
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize("attention_type",
+                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
+                         ids=["mha", "mla"])
+def test_check_gen_transfer_status_at_least_one_does_not_block_on_unready_future(
+        attention_type):
+    """Reproduce the gen-side blocking hang in checkGenTransferStatus(1).
+
+    On stock ``rc11`` the polling path called from the PyExecutor disagg
+    loop unconditionally ``future.get()``s the first selected requester
+    future, even when its ``wait_for(0)`` is still ``timeout``. A single
+    in-flight generation request whose context-side ready signal has not
+    yet arrived therefore blocks the entire decoder event loop, which is
+    indistinguishable from a wedge.
+
+    The test exercises the same shape as the wedge: drive one full
+    ctx/gen handshake to completion to capture an opaque comm/cache
+    state, then enqueue a generation request whose context counterpart
+    has not yet been ``respond_and_send_async()``-ed and call
+    ``check_gen_transfer_status(1)`` from a separate thread. The call
+    must return promptly (we use a 1-second probe timeout) instead of
+    blocking on the unresolved future.
+    """
+    tensorrt_llm.logger.set_level("info")
+    mapping = Mapping(world_size=1, rank=0)
+    dist = Distributed.get(mapping)
+    kv_cache_manager_ctx = create_kv_cache_manager(mapping, DataType.HALF)
+    kv_cache_manager_gen = create_kv_cache_manager(mapping, DataType.HALF)
+
+    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
+                                                      max_tokens_in_buffer=512)
+
+    kv_cache_transceiver_ctx = create_kv_cache_transceiver(
+        mapping, dist, kv_cache_manager_ctx, attention_type,
+        cache_transceiver_config)
+    kv_cache_transceiver_gen = create_kv_cache_transceiver(
+        mapping, dist, kv_cache_manager_gen, attention_type,
+        cache_transceiver_config)
+
+    fill_kv_cache_buffer(kv_cache_manager_ctx)
+    sampling_params = SamplingParams()
+
+    def make_request(request_id, llm_request_type, context_phase_params=None):
+        kwargs = dict(
+            request_id=request_id,
+            max_new_tokens=1,
+            input_tokens=list(range(256)),
+            sampling_config=tensorrt_llm.bindings.SamplingConfig(
+                sampling_params._get_sampling_config()),
+            is_streaming=False,
+            llm_request_type=llm_request_type,
+        )
+        if context_phase_params is not None:
+            kwargs["context_phase_params"] = context_phase_params
+        return LlmRequest(**kwargs)
+
+    def add_sequence(kv_cache_manager, request):
+        kv_cache_manager.impl.add_sequence(request.py_request_id,
+                                           request.prompt_len, 1, request)
+
+    # Complete one normal transfer first so we can reuse its opaque comm/cache
+    # state for the second (intentionally unresolved) generation request.
+    template_ctx_request = make_request(
+        100, LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)
+    add_sequence(kv_cache_manager_ctx, template_ctx_request)
+    kv_cache_transceiver_ctx.respond_and_send_async(template_ctx_request)
+
+    template_gen_request = make_request(
+        100, LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
+        template_ctx_request.context_phase_params)
+    add_sequence(kv_cache_manager_gen, template_gen_request)
+    kv_cache_transceiver_gen.request_and_receive_async(template_gen_request)
+    kv_cache_transceiver_ctx.check_context_transfer_status(1)
+    kv_cache_transceiver_gen.check_gen_transfer_status(1)
+
+    opaque_state = template_ctx_request.context_phase_params.opaque_state
+    assert opaque_state is not None
+
+    kv_cache_manager_ctx.free_resources(template_ctx_request)
+    kv_cache_manager_gen.free_resources(template_gen_request)
+
+    # Build a generation request for a different ctx_request_id before the
+    # sender has any matching ready response. This leaves a real unresolved
+    # future in the C++ transceiver and reproduces the blocking pattern
+    # behind signature #4.
+    blocked_request_id = 101
+    blocked_ctx_request = make_request(
+        blocked_request_id, LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)
+    add_sequence(kv_cache_manager_ctx, blocked_ctx_request)
+
+    blocked_context_phase_params = trtllm.ContextPhaseParams(
+        list(template_ctx_request.context_phase_params.first_gen_tokens),
+        blocked_request_id,
+        bytes(opaque_state),
+        template_ctx_request.context_phase_params.draft_tokens,
+        template_ctx_request.context_phase_params.ctx_dp_rank,
+        template_ctx_request.context_phase_params.disagg_info_endpoint,
+    )
+    blocked_gen_request = make_request(
+        blocked_request_id, LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
+        blocked_context_phase_params)
+    add_sequence(kv_cache_manager_gen, blocked_gen_request)
+    kv_cache_transceiver_gen.request_and_receive_async(blocked_gen_request)
+
+    # Sanity: at_least_request_num=0 must not block under any circumstance.
+    start = time.time()
+    kv_cache_transceiver_gen.check_gen_transfer_status(0)
+    assert time.time() - start < 1.0, (
+        "check_gen_transfer_status(0) must be non-blocking even when futures "
+        "are unresolved")
+
+    # Real reproducer: at_least_request_num=1 must NOT hang on an unresolved
+    # future. Run it on a worker thread so that pre-fix this thread stays
+    # blocked past the 1-second probe timeout (the failure signature), and
+    # post-fix it returns immediately because the unready future is skipped.
+    check_result = {"returned": False, "error": None}
+    probe_timeout_s = 1.0
+
+    def call_blocking_check():
+        try:
+            kv_cache_transceiver_gen.check_gen_transfer_status(1)
+        except BaseException as exc:  # noqa: BLE001
+            check_result["error"] = exc
+        finally:
+            check_result["returned"] = True
+
+    blocked_check = threading.Thread(target=call_blocking_check, daemon=True)
+    blocked_check.start()
+    blocked_check.join(timeout=probe_timeout_s)
+    blocked_during_probe = blocked_check.is_alive()
+
+    # Allow the wedged context request to complete cleanly so the worker
+    # thread can finish (in either pre- or post-fix behaviour) and we can
+    # tear down the test without leaking threads.
+    kv_cache_transceiver_ctx.respond_and_send_async(blocked_ctx_request)
+
+    deadline = time.time() + 10
+    completed_ids, error_ids = [], []
+    while time.time(
+    ) < deadline and blocked_ctx_request.py_request_id not in completed_ids:
+        completed_ids, error_ids = (
+            kv_cache_transceiver_ctx.check_context_transfer_status(1))
+        assert blocked_ctx_request.py_request_id not in error_ids
+        if blocked_ctx_request.py_request_id in completed_ids:
+            break
+        time.sleep(0.1)
+    assert blocked_ctx_request.py_request_id in completed_ids
+
+    if blocked_during_probe:
+        blocked_check.join(timeout=10)
+        assert not blocked_check.is_alive()
+    else:
+        deadline = time.time() + 10
+        while time.time() < deadline and (
+                not kv_cache_transceiver_gen.check_gen_transfer_complete()):
+            kv_cache_transceiver_gen.check_gen_transfer_status(1)
+            time.sleep(0.1)
+    if check_result["error"] is not None:
+        raise check_result["error"]
+    assert check_result["returned"]
+    assert kv_cache_transceiver_gen.check_gen_transfer_complete()
+    assert not blocked_during_probe, (
+        "signature #4 reproduced: check_gen_transfer_status(1) blocked on an "
+        "unresolved generation future")
+
+    assert torch.equal(
+        kv_cache_manager_gen.get_buffers(0),
+        kv_cache_manager_ctx.get_buffers(0)), "different kv-cache values"
+
+
 def create_hybrid_cache_manager(mapping, dtype, mamba_conv_dtype=torch.float16,
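
Note on the invariant the test pins down: per the docstring above, the fixed polling path may only issue a blocking `get()` on a future whose zero-timeout probe reports ready, and it skips (rather than blocks on) any future still pending, even if that leaves fewer than `at_least_request_num` completions for this poll. Below is a minimal Python sketch of that shape, using `concurrent.futures` as a stand-in for the C++ transceiver futures; every name in it is hypothetical, not the TensorRT-LLM API.

```python
import concurrent.futures
from typing import Dict, List


def check_gen_transfer_status_model(
        futures: Dict[int, concurrent.futures.Future],
        at_least_request_num: int) -> List[int]:
    """Hypothetical model of the post-fix poller.

    ``future.done()`` plays the role of the C++ ``wait_for(0)`` probe.
    """
    completed_ids: List[int] = []
    for request_id, future in list(futures.items()):
        # Best-effort: stop early once the caller's minimum is met; the
        # remaining futures are reaped on a later poll.
        if 0 < at_least_request_num <= len(completed_ids):
            break
        if not future.done():
            # The pre-fix path effectively called future.get() here and
            # wedged the event loop; the fix skips the unready future.
            continue
        future.result()  # ready, so this cannot block; surfaces errors
        completed_ids.append(request_id)
        futures.pop(request_id)
    return completed_ids
```

Under this model, `at_least_request_num=0` never touches an unready future and `at_least_request_num=1` returns immediately when the only in-flight future is still pending, which is exactly the pair of probes the test performs.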