Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions tests/unittest/others/test_kv_cache_transceiver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
import time
import uuid

Expand Down Expand Up @@ -229,6 +230,176 @@ def test_cancel_request_in_transmission(attention_type):
assert gen_request.state == LlmRequestState.DISAGG_TRANS_ERROR


@pytest.mark.timeout(120)
@pytest.mark.parametrize("attention_type",
                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                         ids=["mha", "mla"])
def test_check_gen_transfer_status_at_least_one_does_not_block_on_unready_future(
        attention_type):
    """Reproduce the gen-side blocking hang in checkGenTransferStatus(1).

    On stock ``rc11`` the polling path called from the PyExecutor disagg
    loop unconditionally ``future.get()``s the first selected requester
    future, even when its ``wait_for(0)`` is still ``timeout``. A single
    in-flight generation request whose context-side ready signal has not
    yet arrived therefore blocks the entire decoder event loop, which is
    indistinguishable from a wedge.

    The test exercises the same shape as the wedge: drive one full
    ctx/gen handshake to completion to capture an opaque comm/cache
    state, then enqueue a generation request whose context counterpart
    has not yet been ``respond_and_send_async()``-ed and call
    ``check_gen_transfer_status(1)`` from a separate thread. The call
    must return promptly (we use a 1-second probe timeout) instead of
    blocking on the unresolved future.
    """
    tensorrt_llm.logger.set_level("info")
    # Single-rank world: ctx and gen transceivers run in-process, each with
    # its own independent KV-cache manager.
    mapping = Mapping(world_size=1, rank=0)
    dist = Distributed.get(mapping)
    kv_cache_manager_ctx = create_kv_cache_manager(mapping, DataType.HALF)
    kv_cache_manager_gen = create_kv_cache_manager(mapping, DataType.HALF)

    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
                                                      max_tokens_in_buffer=512)

    kv_cache_transceiver_ctx = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_ctx, attention_type,
        cache_transceiver_config)
    kv_cache_transceiver_gen = create_kv_cache_transceiver(
        mapping, dist, kv_cache_manager_gen, attention_type,
        cache_transceiver_config)

    # Give the ctx-side cache recognizable contents so the final
    # buffer-equality check at the bottom is meaningful.
    fill_kv_cache_buffer(kv_cache_manager_ctx)
    sampling_params = SamplingParams()

    def make_request(request_id: int,
                     llm_request_type,
                     context_phase_params=None):
        # Build an LlmRequest with a fixed 256-token prompt; the optional
        # context_phase_params turns it into a generation-side request that
        # knows how to reach its context-side counterpart.
        kwargs = dict(
            request_id=request_id,
            max_new_tokens=1,
            input_tokens=list(range(256)),
            sampling_config=tensorrt_llm.bindings.SamplingConfig(
                sampling_params._get_sampling_config()),
            is_streaming=False,
            llm_request_type=llm_request_type,
        )
        if context_phase_params is not None:
            kwargs["context_phase_params"] = context_phase_params
        return LlmRequest(**kwargs)

    def add_sequence(kv_cache_manager, request) -> None:
        # Register the request's prompt with the cache manager so transfer
        # buffers exist for it (beam width 1).
        kv_cache_manager.impl.add_sequence(request.py_request_id,
                                           request.prompt_len, 1, request)

    # Complete one normal transfer first so we can reuse its opaque comm/cache
    # state for the second (intentionally unresolved) generation request.
    template_ctx_request = make_request(
        100, LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)
    add_sequence(kv_cache_manager_ctx, template_ctx_request)
    kv_cache_transceiver_ctx.respond_and_send_async(template_ctx_request)

    template_gen_request = make_request(
        100, LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
        template_ctx_request.context_phase_params)
    add_sequence(kv_cache_manager_gen, template_gen_request)
    kv_cache_transceiver_gen.request_and_receive_async(template_gen_request)
    # Drive both sides to completion; after this the handshake's opaque
    # state is valid and reusable.
    kv_cache_transceiver_ctx.check_context_transfer_status(1)
    kv_cache_transceiver_gen.check_gen_transfer_status(1)

    opaque_state = template_ctx_request.context_phase_params.opaque_state
    assert opaque_state is not None

    kv_cache_manager_ctx.free_resources(template_ctx_request)
    kv_cache_manager_gen.free_resources(template_gen_request)

    # Build a generation request for a different ctx_request_id before the
    # sender has any matching ready response. This leaves a real unresolved
    # future in the C++ transceiver and reproduces the blocking pattern
    # behind signature #4.
    blocked_request_id = 101
    blocked_ctx_request = make_request(
        blocked_request_id, LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)
    add_sequence(kv_cache_manager_ctx, blocked_ctx_request)

    # Clone the template's ContextPhaseParams but retarget it at the
    # not-yet-sent context request id; only the id differs from the
    # completed handshake's params.
    blocked_context_phase_params = trtllm.ContextPhaseParams(
        list(template_ctx_request.context_phase_params.first_gen_tokens),
        blocked_request_id,
        bytes(opaque_state),
        template_ctx_request.context_phase_params.draft_tokens,
        template_ctx_request.context_phase_params.ctx_dp_rank,
        template_ctx_request.context_phase_params.disagg_info_endpoint,
    )
    blocked_gen_request = make_request(
        blocked_request_id, LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY,
        blocked_context_phase_params)
    add_sequence(kv_cache_manager_gen, blocked_gen_request)
    kv_cache_transceiver_gen.request_and_receive_async(blocked_gen_request)

    # Sanity: at_least_request_num=0 must not block under any circumstance.
    start = time.time()
    kv_cache_transceiver_gen.check_gen_transfer_status(0)
    assert time.time() - start < 1.0, (
        "check_gen_transfer_status(0) must be non-blocking even when futures "
        "are unresolved")

    # Real reproducer: at_least_request_num=1 must NOT hang on an unresolved
    # future. Run it on a worker thread so that pre-fix this thread stays
    # blocked past the 1-second probe timeout (the failure signature), and
    # post-fix it returns immediately because the unready future is skipped.
    check_result = {"returned": False, "error": None}
    probe_timeout_s = 1.0

    def call_blocking_check() -> None:
        # Worker body: record any exception (including SystemExit et al.,
        # hence BaseException) and always flag completion for the probe.
        try:
            kv_cache_transceiver_gen.check_gen_transfer_status(1)
        except BaseException as exc:  # noqa: BLE001
            check_result["error"] = exc
        finally:
            check_result["returned"] = True

    blocked_check = threading.Thread(target=call_blocking_check, daemon=True)
    blocked_check.start()
    blocked_check.join(timeout=probe_timeout_s)
    # is_alive() after a timed join tells us whether the call was still
    # blocked when the probe expired — that is the pre-fix signature.
    blocked_during_probe = blocked_check.is_alive()

    # Allow the wedged context request to complete cleanly so the worker
    # thread can finish (in either pre- or post-fix behaviour) and we can
    # tear down the test without leaking threads.
    kv_cache_transceiver_ctx.respond_and_send_async(blocked_ctx_request)

    # Poll the ctx side (10 s budget) until the blocked request is reported
    # complete; any appearance in error_ids is an immediate failure.
    deadline = time.time() + 10
    completed_ids, error_ids = [], []
    while time.time(
    ) < deadline and blocked_ctx_request.py_request_id not in completed_ids:
        completed_ids, error_ids = (
            kv_cache_transceiver_ctx.check_context_transfer_status(1))
        assert blocked_ctx_request.py_request_id not in error_ids
        if blocked_ctx_request.py_request_id in completed_ids:
            break
        time.sleep(0.1)
    assert blocked_ctx_request.py_request_id in completed_ids

    if blocked_during_probe:
        # Pre-fix path: the worker was wedged; now that the ctx side has
        # responded it must unblock within the grace period.
        blocked_check.join(timeout=10)
        assert not blocked_check.is_alive()
    else:
        # Post-fix path: the probe returned promptly; keep polling the gen
        # side until the transfer actually finishes, then surface any error
        # the worker thread captured.
        deadline = time.time() + 10
        while time.time() < deadline and (
                not kv_cache_transceiver_gen.check_gen_transfer_complete()):
            kv_cache_transceiver_gen.check_gen_transfer_status(1)
            time.sleep(0.1)
        if check_result["error"] is not None:
            raise check_result["error"]
        assert check_result["returned"]
        assert kv_cache_transceiver_gen.check_gen_transfer_complete()
    # The actual regression assertion: a blocked probe means the unready
    # future was waited on.
    assert not blocked_during_probe, (
        "signature #4 reproduced: check_gen_transfer_status(1) blocked on an "
        "unresolved generation future")

    # Finally, the gen cache must hold exactly what the ctx cache sent.
    assert torch.equal(
        kv_cache_manager_gen.get_buffers(0),
        kv_cache_manager_ctx.get_buffers(0)), "different kv-cache values"


def create_hybrid_cache_manager(mapping,
dtype,
mamba_conv_dtype=torch.float16,
Expand Down
Loading