diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py
new file mode 100644
index 00000000000..3a75ee71447
--- /dev/null
+++ b/tests/cache_manager/test_cache_messager.py
@@ -0,0 +1,895 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import importlib.util
+import math
+import sys
+import types
+import unittest
+from pathlib import Path
+from unittest import mock
+
+import numpy as np
+
+PROJECT_ROOT = Path(__file__).resolve().parents[2]
+
+
+def _ensure_module(name: str) -> types.ModuleType:
+    module = sys.modules.get(name)
+    if module is None:
+        module = types.ModuleType(name)
+        sys.modules[name] = module
+    return module
+
+
+class _FakePlace:
+    def __init__(self, device: str):
+        self._device = device
+
+    def __str__(self):  # pragma: no cover - representation helper
+        return f"Place({self._device})"
+
+
+class _FakeTensor:
+    def __init__(self, array, dtype="float32", device="gpu:0"):
+        self._array = np.array(array)
+        self.shape = tuple(self._array.shape)
+        self.dtype = dtype
+        self.place = _FakePlace(device)
+
+    def data_ptr(self):
+        return int(self._array.__array_interface__["data"][0])
+
+    def numel(self):
+        return int(self._array.size)
+
+    def numpy(self):
+        return self._array
+
+    def tolist(self):  # pragma: no cover - convenience helper
+        return self.numpy().tolist()
+
+    def __len__(self):
+        return len(self._array)
+
+    def __iter__(self):  # pragma: no cover - container helper
+        return iter(self._array)
+
+    def __getitem__(self, idx):
+        value = self._array[idx]
+        if isinstance(value, np.ndarray):
+            return _FakeTensor(value, dtype=self.dtype)
+        return _FakeScalar(value)
+
+    def __setitem__(self, idx, value):
+        self._array[idx] = value
+
+
+class _FakeScalar:
+    def __init__(self, value):
+        self._value = value.item() if hasattr(value, "item") else value
+
+    def numpy(self):
+        return np.array(self._value)
+
+    def tolist(self):  # pragma: no cover - compatibility helper
+        return self.numpy().tolist()
+
+    def __int__(self):
+        return int(self._value)
+
+    def __index__(self):  # pragma: no cover - required for range()
+        return int(self._value)
+
+    def __eq__(self, other):  # pragma: no cover - comparison helper
+        return int(self._value) == other
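+
+
+# The fakes above mirror only the slice of the paddle.Tensor API these tests
+# exercise: shape/dtype/place for bookkeeping, plus data_ptr()/numel() so the
+# messager can compute block addresses and sizes without a real GPU buffer.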
+
+
+class ParseArgsTest(unittest.TestCase):
+    def test_parse_args_reads_cli_values(self):
+        module = _load_cache_messager()
+        argv = [
+            "prog",
+            "--splitwise_role",
+            "decode",
+            "--rank",
+            "3",
+            "--device_id",
+            "5",
+            "--num_layers",
+            "4",
+            "--key_cache_shape",
+            "2,3,4,5",
+            "--value_cache_shape",
+            "2,3,4,5",
+            "--rdma_port",
+            "1234",
+            "--mp_num",
+            "2",
+            "--engine_pid",
+            "abc",
+            "--protocol",
+            "ipc,rdma",
+            "--pod_ip",
+            "1.2.3.4",
+            "--cache_queue_port",
+            "9100",
+            "--engine_worker_queue_port",
+            "9101",
+            "--cache_dtype",
+            "uint8",
+            "--speculative_config",
+            '{"num_extra_cache_layer":1}',
+            "--local_data_parallel_id",
+            "7",
+        ]
+        with mock.patch.object(sys, "argv", argv):
+            args = module.parse_args()
+
+        self.assertEqual(args.splitwise_role, "decode")
+        self.assertEqual(args.rank, 3)
+        self.assertEqual(args.device_id, 5)
+        self.assertEqual(args.num_layers, 4)
+        self.assertEqual(args.protocol, "ipc,rdma")
+        self.assertEqual(args.cache_dtype, "uint8")
+        self.assertEqual(args.local_data_parallel_id, 7)
+        self.assertEqual(args.speculative_config["num_extra_cache_layer"], 1)
+
+
+class _Barrier:
+    def __init__(self):
+        self.wait_calls = 0
+
+    def wait(self):
+        self.wait_calls += 1
+
+
+class _IPCCommManager:
+    def __init__(self, rank, gpu_id, cache_k, cache_v):  # pylint: disable=unused-argument
+        self.rank = rank
+        self.gpu_id = gpu_id
+        self.cache_k = cache_k
+        self.cache_v = cache_v
+        self.write_calls = []
+        self.sync_targets = []
+
+    def write_cache(self, target_ip, target_id, src_block_ids, dest_block_ids, layer_idx):
+        self.write_calls.append((target_ip, target_id, tuple(src_block_ids), tuple(dest_block_ids), layer_idx))
+        return 0
+
+    def write_block_by_sync(self, target_id):
+        self.sync_targets.append(target_id)
+
+
+class _RDMACommManager:
+    def __init__(
+        self,
+        splitwise_role,
+        rank,
+        gpu_id,
+        cache_k_ptr_list,
+        cache_v_ptr_list,
+        max_block_num,
+        block_bytes,
+        rdma_port,
+    ):  # pylint: disable=unused-argument
+        self.rank = rank
+        self.calls = []
+        self.connect_results = []
+
+    def connect(self, target_ip, target_id):
+        result = True if not self.connect_results else self.connect_results.pop(0)
+        self.calls.append((target_ip, target_id, result))
+        return result
+
+    def write_cache(self, *args, **kwargs):  # pragma: no cover - compatibility helper
+        return 0
+
+
+class _IPCSignal:
+    instances: dict[str, "_IPCSignal"] = {}
+
+    def __init__(self, name, array, dtype=None, suffix=None, create=False):  # noqa: D401
+        # pylint: disable=unused-argument
+        self.name = name
+        self.value = np.array(array)
+        _IPCSignal.instances[name if suffix is None else f"{name}_{suffix}"] = self
+
+
+class _EngineWorkerQueue:
+    def __init__(
+        self,
+        address,
+        is_server,
+        num_client,
+        client_id,
+        local_data_parallel_id,
+    ):
+        self.address = address
+        self.is_server = is_server
+        self.num_client = num_client
+        self.client_id = client_id
+        self.local_data_parallel_id = local_data_parallel_id
+        self.cache_info_barrier = _Barrier()
+        self.finish_send_cache_barrier = _Barrier()
+        self.finish_add_cache_task_barrier = _Barrier()
+        self.begin_send_cache_barrier = _Barrier()
+        self.connect_task_barrier = _Barrier()
+        self.connect_task_response_barrier = _Barrier()
+        self.cache_info_sequence = []
+        self.cache_info_calls = 0
+        self.stop_after_cache_info = False
+        self.signal_initializer = None
+        self.connect_tasks = []
+        self.connect_task_calls = 0
+        self.stop_after_connect_tasks = False
+        self.finished_requests = []
+        self.connect_responses = []
+        self.finished_add_cache_task_req = []
+
+    def get_cache_info(self):
+        if self.cache_info_calls == 0 and self.signal_initializer:
+            self.signal_initializer()
+        if self.cache_info_calls < len(self.cache_info_sequence):
+            info = self.cache_info_sequence[self.cache_info_calls]
+            self.cache_info_calls += 1
+            return info
+        if self.stop_after_cache_info:
+            raise SystemExit("stop cache info")
+        return []
+
+    def put_finished_req(self, request_payload):
+        self.finished_requests.append(request_payload)
+
+    def put_finished_add_cache_task_req(self, req_ids):
+        self.finished_add_cache_task_req.append(req_ids)
+
+    def get_connect_rdma_task(self):
+        if self.connect_task_calls < len(self.connect_tasks):
+            task = self.connect_tasks[self.connect_task_calls]
+            self.connect_task_calls += 1
+            return task, None
+        if self.stop_after_connect_tasks:
+            raise SystemExit("stop connect task")
+        return None, None
+
+    def put_connect_rdma_task_response(self, response):
+        self.connect_responses.append(response)
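+
+
+# The installers below register lightweight stand-ins in sys.modules; they
+# must run before cache_messager is imported so its top-level imports resolve
+# to these fakes instead of the real paddle/fastdeploy dependencies.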
+
+
+def _install_paddle_stubs():
+    paddle = _ensure_module("paddle")
+    paddle.Tensor = _FakeTensor
+    paddle.bfloat16 = "bfloat16"
+
+    def _full(shape, fill_value=0, dtype="float32"):
+        dtype_str = dtype if isinstance(dtype, str) else str(dtype)
+        return _FakeTensor(np.full(shape, fill_value), dtype=dtype_str)
+
+    def _to_tensor(data, dtype="float32", place=None):  # pylint: disable=unused-argument
+        dtype_str = dtype if isinstance(dtype, str) else str(dtype)
+        return _FakeTensor(np.array(data), dtype=dtype_str)
+
+    paddle.full = _full
+    paddle.to_tensor = _to_tensor
+    paddle.set_device = lambda _name: None
+
+    device_mod = types.ModuleType("paddle.device")
+    device_mod.set_device = lambda _name: None
+    cuda_mod = types.ModuleType("paddle.device.cuda")
+    cuda_mod.memory_allocated = lambda: 0
+    device_mod.cuda = cuda_mod
+    paddle.device = device_mod
+    sys.modules["paddle.device"] = device_mod
+    sys.modules["paddle.device.cuda"] = cuda_mod
+
+
+def _install_fastdeploy_core_stubs():
+    fastdeploy_pkg = _ensure_module("fastdeploy")
+    fastdeploy_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")]
+
+    utils_module = types.ModuleType("fastdeploy.utils")
+    envs_module = types.ModuleType("fastdeploy.utils.envs")
+    envs_module.FD_ENGINE_TASK_QUEUE_WITH_SHM = False
+    envs_module.ENABLE_V1_KVCACHE_SCHEDULER = False
+
+    class _Logger:
+        def __init__(self):
+            self.messages = {"info": [], "debug": [], "error": []}
+
+        def info(self, msg):
+            self.messages["info"].append(msg)
+
+        def debug(self, msg):
+            self.messages["debug"].append(msg)
+
+        def error(self, msg):
+            self.messages["error"].append(msg)
+
+    def _get_logger(_name, _filename=None):  # pylint: disable=unused-argument
+        return _Logger()
+
+    utils_module.envs = envs_module
+    utils_module.get_logger = _get_logger
+    sys.modules["fastdeploy.utils"] = utils_module
+    sys.modules["fastdeploy.utils.envs"] = envs_module
+    fastdeploy_pkg.utils = utils_module
+
+
+def _install_transfer_factory_stubs():
+    transfer_factory = types.ModuleType("fastdeploy.cache_manager.transfer_factory")
+    transfer_factory.IPCCommManager = _IPCCommManager
+    transfer_factory.RDMACommManager = _RDMACommManager
+    sys.modules["fastdeploy.cache_manager.transfer_factory"] = transfer_factory
+
+
+def _install_config_stubs():
+    fastdeploy_pkg = _ensure_module("fastdeploy")
+    config_module = types.ModuleType("fastdeploy.config")
+
+    class _SpeculativeConfig:
+        def __init__(self, config_dict):
+            self.num_extra_cache_layer = config_dict.get("num_extra_cache_layer", 0)
+            self.num_gpu_block_expand_ratio = config_dict.get("num_gpu_block_expand_ratio", 0)
+
+    config_module.SpeculativeConfig = _SpeculativeConfig
+    sys.modules["fastdeploy.config"] = config_module
+    fastdeploy_pkg.config = config_module
+
+
+def _install_inter_comm_stubs():
+    inter_comm_module = types.ModuleType("fastdeploy.inter_communicator")
+    inter_comm_module.EngineWorkerQueue = _EngineWorkerQueue
+    inter_comm_module.IPCSignal = _IPCSignal
+    inter_comm_module.shared_memory_exists = lambda _name: False
+    sys.modules["fastdeploy.inter_communicator"] = inter_comm_module
+
+
+def _install_ops_gpu_stubs():
+    ops_gpu_module = types.ModuleType("fastdeploy.model_executor.ops.gpu")
+
+    def _get_output_kv_signal(buffer, rank_id, flag):  # pylint: disable=unused-argument
+        sequence = getattr(_get_output_kv_signal, "sequence", None)
+        if not sequence:
+            raise SystemExit("kv signal stop")
+
+        step = sequence.pop(0)
+        if step.get("stop"):
+            raise SystemExit("kv signal stop")
+
+        data = buffer.numpy()
+        data.fill(-1)
+        tasks = step.get("tasks", -1)
+        data[0] = tasks
+        if tasks == -1:
+            return
+        data[1] = step.get("layer", 0)
+        data[2] = step.get("engine", 0)
+        data[3] = step.get("offset", 0)
+        data[4] = step.get("current", 0)
+
+    ops_gpu_module.get_output_kv_signal = _get_output_kv_signal
+    ops_gpu_module.set_data_ipc = lambda *args, **kwargs: None
+    sys.modules["fastdeploy.model_executor.ops.gpu"] = ops_gpu_module
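+
+
+# cache_messager is loaded from its file path rather than via a package
+# import, so the stubbed fastdeploy modules above (already in sys.modules)
+# satisfy its imports; the loaded module is cached so every test case shares
+# one module object.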
+
+
+def _install_dependency_stubs():
+    _install_paddle_stubs()
+    _install_fastdeploy_core_stubs()
+    _install_transfer_factory_stubs()
+    _install_config_stubs()
+    _install_inter_comm_stubs()
+    _install_ops_gpu_stubs()
+
+
+def _load_cache_messager():
+    module_name = "fastdeploy.cache_manager.cache_messager"
+    if module_name in sys.modules:
+        return sys.modules[module_name]
+
+    _install_dependency_stubs()
+
+    spec = importlib.util.spec_from_file_location(
+        module_name, PROJECT_ROOT / "fastdeploy" / "cache_manager" / "cache_messager.py"
+    )
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    sys.modules[module_name] = module
+    return module
+
+
+def _make_cache_tensors(num_layers, dtype="bfloat16"):
+    cache = {}
+    for layer in range(num_layers):
+        cache[f"key_caches_{layer}_rank0_device0"] = _FakeTensor(np.zeros((2, 3, 4, 5)), dtype=dtype)
+        cache[f"value_caches_{layer}_rank0_device0"] = _FakeTensor(np.zeros((2, 3, 4, 5)), dtype=dtype)
+    return cache
+
+
+class CacheMessagerInitTest(unittest.TestCase):
+    def setUp(self):
+        self.module = _load_cache_messager()
+        envs = sys.modules["fastdeploy.utils.envs"]
+        envs.FD_ENGINE_TASK_QUEUE_WITH_SHM = False
+        _IPCSignal.instances.clear()
+        ops_gpu = sys.modules["fastdeploy.model_executor.ops.gpu"]
+        ops_gpu.get_output_kv_signal.sequence = []
+
+    def tearDown(self):
+        _IPCSignal.instances.clear()
+
+    def test_initializes_with_ipc_and_rdma(self):
+        cache = _make_cache_tensors(num_layers=2)
+        messager = self.module.CacheMessager(
+            splitwise_role="mixed",
+            transfer_protocol="ipc,rdma",
+            pod_ip="127.0.0.1",
+            engine_worker_queue_port=9000,
+            local_data_parallel_id=0,
+            gpu_cache_kvs=cache,
+            rank=0,
+            nranks=1,
+            num_layers=2,
+            gpu_id=0,
+            rdma_port=55,
+        )
+
+        self.assertIsInstance(messager.engine_worker_queue, _EngineWorkerQueue)
+        self.assertEqual(messager.engine_worker_queue.address, ("127.0.0.1", 9000))
+        self.assertIn("ipc", messager.messager)
+        self.assertIn("rdma", messager.messager)
+        expected_block_bytes = math.prod(cache["key_caches_0_rank0_device0"].shape[1:]) * 2
+        self.assertEqual(messager.block_bytes, expected_block_bytes)
+
+    def test_shm_socket_address_and_uint8_dtype(self):
+        envs = sys.modules["fastdeploy.utils.envs"]
+        envs.FD_ENGINE_TASK_QUEUE_WITH_SHM = True
+        cache = _make_cache_tensors(num_layers=1, dtype="uint8")
+        messager = self.module.CacheMessager(
+            splitwise_role="mixed",
+            transfer_protocol="ipc",
+            pod_ip="127.0.0.1",
+            engine_worker_queue_port=9010,
+            local_data_parallel_id=0,
+            gpu_cache_kvs=cache,
+            rank=0,
+            nranks=1,
+            num_layers=1,
+            gpu_id=0,
+        )
+
+        self.assertTrue(str(messager.engine_worker_queue.address).startswith("/dev/shm/fd_task_queue_"))
+        expected_block_bytes = math.prod(cache["key_caches_0_rank0_device0"].shape[1:])
+        self.assertEqual(messager.block_bytes, expected_block_bytes)
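+
+
+# block_bytes scales with dtype width: the bfloat16 case above multiplies the
+# per-block element count by 2 bytes, while the uint8 case uses it unchanged.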
+
+
+class PrefillThreadTest(unittest.TestCase):
+    def setUp(self):
+        self.module = _load_cache_messager()
+        envs = sys.modules["fastdeploy.utils.envs"]
+        envs.FD_ENGINE_TASK_QUEUE_WITH_SHM = False
+        _IPCSignal.instances.clear()
+        ops_gpu = sys.modules["fastdeploy.model_executor.ops.gpu"]
+        ops_gpu.get_output_kv_signal.sequence = []
+
+    def tearDown(self):
+        _IPCSignal.instances.clear()
+
+    def test_prefill_thread_transfers_and_marks_finished(self):
+        cache = _make_cache_tensors(num_layers=1)
+        messager = self.module.CacheMessager(
+            splitwise_role="mixed",
+            transfer_protocol="ipc",
+            pod_ip="127.0.0.1",
+            engine_worker_queue_port=9001,
+            local_data_parallel_id=0,
+            gpu_cache_kvs=cache,
+            rank=0,
+            nranks=1,
+            num_layers=1,
+            gpu_id=0,
+        )
+
+        queue = messager.engine_worker_queue
+        queue.cache_info_sequence = [
+            [
+                {
+                    "request_id": "req-1",
+                    "transfer_protocol": "ipc",
+                    "src_block_ids": [0, 1],
+                    "dest_block_ids": [2, 3],
+                    "current_id": 0,
+                    "status": "init",
+                    "layer_idx": 0,
+                    "device_ids": {0: 0},
+                }
+            ]
+        ]
+        queue.stop_after_cache_info = True
+
+        def _set_signals(instance):
+            step_key = f"splitwise_complete_prefilled_step_{instance.rank_id}_{instance.gpu_id}"
+            layer_key = f"splitwise_complete_prefilled_layer_{instance.rank_id}_{instance.gpu_id}"
+            _IPCSignal.instances[step_key].value[0] = 0
+            _IPCSignal.instances[layer_key].value[0] = 0
+
+        queue.signal_initializer = lambda: _set_signals(messager)
+
+        with self.assertRaises(SystemExit):
+            messager.prefill_layerwise_send_cache_thread()
+
+        self.assertEqual(queue.finish_send_cache_barrier.wait_calls, 1)
+        self.assertEqual(queue.finished_requests, [[["req-1", "finished"]]])
+        self.assertEqual(
+            messager.messager["ipc"].write_calls,
+            [("0.0.0.0", 0, (0, 1), (2, 3), 0)],
+        )
+
+
+class HandleConnectTaskTest(unittest.TestCase):
+    def setUp(self):
+        self.module = _load_cache_messager()
+        envs = sys.modules["fastdeploy.utils.envs"]
+        envs.FD_ENGINE_TASK_QUEUE_WITH_SHM = False
+        _IPCSignal.instances.clear()
+        ops_gpu = sys.modules["fastdeploy.model_executor.ops.gpu"]
+        ops_gpu.get_output_kv_signal.sequence = []
+
+    def tearDown(self):
+        _IPCSignal.instances.clear()
+
+    def test_handle_connect_task_success_and_failure(self):
+        cache = _make_cache_tensors(num_layers=1)
+        messager = self.module.CacheMessager(
+            splitwise_role="decode",
+            transfer_protocol="rdma",
+            pod_ip="127.0.0.1",
+            engine_worker_queue_port=9002,
+            local_data_parallel_id=0,
+            gpu_cache_kvs=cache,
+            rank=0,
+            nranks=1,
+            num_layers=1,
+            gpu_id=0,
+            rdma_port=88,
+        )
+
+        rdma_manager = messager.messager["rdma"]
+        rdma_manager.connect_results = [True, False]
+
+        queue = messager.engine_worker_queue
+        queue.connect_tasks = [
+            {
+                "task_id": 1,
+                "ip": "10.0.0.1",
+                "rdma_ports": {0: 7},
+            },
+            {
+                "task_id": 2,
+                "ip": "10.0.0.2",
+                "rdma_ports": {0: 9},
+            },
+        ]
+        queue.stop_after_connect_tasks = True
+
+        with self.assertRaises(SystemExit):
+            messager._handle_connect_task()
+
+        self.assertEqual(
+            queue.connect_responses,
+            [
+                {"task_id": 1, "success": True},
+                {"task_id": 2, "success": False},
+            ],
+        )
+
+
+class CacheMessagerV1Test(unittest.TestCase):
+    def setUp(self):
+        self.module = _load_cache_messager()
+        envs = sys.modules["fastdeploy.utils.envs"]
+        envs.FD_ENGINE_TASK_QUEUE_WITH_SHM = False
+        _IPCSignal.instances.clear()
+        ops_gpu = sys.modules["fastdeploy.model_executor.ops.gpu"]
+        ops_gpu.get_output_kv_signal.sequence = []
+
+    def tearDown(self):
+        _IPCSignal.instances.clear()
+
+    def test_consume_signals_populates_queue(self):
+        cache = _make_cache_tensors(num_layers=1)
+        envs = sys.modules["fastdeploy.utils.envs"]
+        envs.ENABLE_V1_KVCACHE_SCHEDULER = True
+
+        with mock.patch("threading.Thread") as thread_cls:
+
+            def _fake_thread(*_args, **_kwargs):
+                return types.SimpleNamespace(start=lambda: None)
+
+            thread_cls.side_effect = _fake_thread
+            messager = self.module.CacheMessagerV1(
+                splitwise_role="prefill",
+                transfer_protocol="ipc",
+                pod_ip="127.0.0.1",
+                engine_worker_queue_port=9003,
+                local_data_parallel_id=0,
+                gpu_cache_kvs=cache,
+                rank=0,
+                nranks=1,
+                num_layers=1,
+                gpu_id=0,
+            )
+
+        ops_gpu = sys.modules["fastdeploy.model_executor.ops.gpu"]
+        ops_gpu.get_output_kv_signal.sequence = [
+            {"tasks": -1},
+            {"tasks": 1, "layer": 0, "engine": 0, "offset": 0, "current": 4},
+            {"stop": True},
+        ]
+        messager.cache_info = {"req": {"status": "init"}}
+
+        with self.assertRaises(SystemExit):
+            messager.consume_signals()
+
+        queued = messager.cache_prefilled_engine_ids_queue.get_nowait()
+        self.assertEqual(queued, [(0, 4)])
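+
+    # CacheMessagerV1 spawns worker threads in __init__; patching
+    # threading.Thread with a no-op stand-in keeps each test single-threaded
+    # so the loops can be driven synchronously.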
sys.modules["fastdeploy.utils.envs"] + envs.ENABLE_V1_KVCACHE_SCHEDULER = True + + with mock.patch("threading.Thread") as thread_cls: + + def _fake_thread(*_args, **_kwargs): + return types.SimpleNamespace(start=lambda: None) + + thread_cls.side_effect = _fake_thread + messager = self.module.CacheMessagerV1( + splitwise_role="prefill", + transfer_protocol="ipc", + pod_ip="127.0.0.1", + engine_worker_queue_port=9003, + local_data_parallel_id=0, + gpu_cache_kvs=cache, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + ) + + ops_gpu = sys.modules["fastdeploy.model_executor.ops.gpu"] + ops_gpu.get_output_kv_signal.sequence = [ + {"tasks": -1}, + {"tasks": 1, "layer": 0, "engine": 0, "offset": 0, "current": 4}, + {"stop": True}, + ] + messager.cache_info = {"req": {"status": "init"}} + + with self.assertRaises(SystemExit): + messager.consume_signals() + + queued = messager.cache_prefilled_engine_ids_queue.get_nowait() + self.assertEqual(queued, [(0, 4)]) + + def test_add_cache_task_thread_updates_state(self): + cache = _make_cache_tensors(num_layers=1) + envs = sys.modules["fastdeploy.utils.envs"] + envs.ENABLE_V1_KVCACHE_SCHEDULER = True + + with mock.patch("threading.Thread") as thread_cls: + + def _fake_thread(*_args, **_kwargs): + return types.SimpleNamespace(start=lambda: None) + + thread_cls.side_effect = _fake_thread + messager = self.module.CacheMessagerV1( + splitwise_role="prefill", + transfer_protocol="ipc", + pod_ip="127.0.0.1", + engine_worker_queue_port=9006, + local_data_parallel_id=0, + gpu_cache_kvs=cache, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + ) + + messager.cache_info = { + "req-existing": { + "request_id": "req-existing", + "src_block_ids": [0, 1, 2, 3], + "dest_block_ids": [0, 1], + "current_id": 5, + "transfer_protocol": "ipc", + "status": "pending", + "rdma_ports": {0: 0}, + } + } + + queue = messager.engine_worker_queue + queue.cache_info_sequence = [ + [ + { + "request_id": "req-existing", + "src_block_ids": [0, 1, 2, 3], + "dest_block_ids": [0, 1], + "current_id": 5, + "transfer_protocol": "ipc", + }, + { + "request_id": "req-new", + "src_block_ids": [10, 11], + "dest_block_ids": [12, 13], + "current_id": 7, + "transfer_protocol": "rdma", + "status": "pending", + "ip": "10.0.0.5", + "rdma_ports": {0: 4}, + "device_ids": {0: 1}, + }, + ] + ] + queue.stop_after_cache_info = True + + with self.assertRaises(SystemExit): + messager._add_cache_task_thread() + + self.assertEqual(queue.cache_info_barrier.wait_calls, 1) + self.assertEqual(queue.finish_add_cache_task_barrier.wait_calls, 1) + self.assertEqual(queue.finished_add_cache_task_req, [["req-existing"]]) + updated = messager.cache_info["req-existing"] + self.assertEqual(updated["decode_cached_tokens"], 2 * messager.block_size) + self.assertEqual(updated["sended_block_num"], 2) + self.assertIn(5, messager.idx_cache_task_dict) + self.assertIn("req-new", messager.cache_info) + + def test_prefill_layerwise_send_cache_thread_finishes_request(self): + cache = _make_cache_tensors(num_layers=1) + envs = sys.modules["fastdeploy.utils.envs"] + envs.ENABLE_V1_KVCACHE_SCHEDULER = True + + with mock.patch("threading.Thread") as thread_cls: + + def _fake_thread(*_args, **_kwargs): + return types.SimpleNamespace(start=lambda: None) + + thread_cls.side_effect = _fake_thread + messager = self.module.CacheMessagerV1( + splitwise_role="prefill", + transfer_protocol="ipc", + pod_ip="127.0.0.1", + engine_worker_queue_port=9007, + local_data_parallel_id=0, + gpu_cache_kvs=cache, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, 
+        class _QueueStub:
+            def __init__(self, payloads):
+                self._payloads = list(payloads)
+
+            def get(self):
+                if not self._payloads:
+                    raise SystemExit("stop prefill v1")
+                return self._payloads.pop(0)
+
+        task = {
+            "request_id": "req-1",
+            "transfer_protocol": "ipc",
+            "device_ids": {0: 0},
+            "rdma_ports": {0: 0},
+            "src_block_ids": [0, 1],
+            "dest_block_ids": [2, 3],
+            "status": "init",
+            "sended_layer_id": -1,
+            "sended_block_num": 0,
+            "current_id": 0,
+            "need_prefill_tokens": 4,
+        }
+
+        messager.idx_cache_task_dict = {0: task}
+        messager.cache_info = {"req-1": task}
+        messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 0, "prefilled_token_num": 4}
+        messager.cache_prefilled_engine_ids_queue = _QueueStub([[(0, 4)]])
+
+        with self.assertRaises(SystemExit):
+            messager.prefill_layerwise_send_cache_thread()
+
+        queue = messager.engine_worker_queue
+        self.assertEqual(queue.begin_send_cache_barrier.wait_calls, 1)
+        self.assertEqual(queue.finish_send_cache_barrier.wait_calls, 1)
+        self.assertEqual(queue.finished_requests, [[["req-1", "finished"]]])
+        self.assertEqual(messager.messager["ipc"].sync_targets, [0])
+        self.assertNotIn("req-1", messager.cache_info)
+
+
+class CacheMessagerV1ConnectTest(unittest.TestCase):
+    def setUp(self):
+        self.module = _load_cache_messager()
+        envs = sys.modules["fastdeploy.utils.envs"]
+        envs.FD_ENGINE_TASK_QUEUE_WITH_SHM = False
+        _IPCSignal.instances.clear()
+        ops_gpu = sys.modules["fastdeploy.model_executor.ops.gpu"]
+        ops_gpu.get_output_kv_signal.sequence = []
+
+    def tearDown(self):
+        _IPCSignal.instances.clear()
+
+    def test_handle_connect_task_rdma_paths(self):
+        cache = _make_cache_tensors(num_layers=1)
+        with mock.patch("threading.Thread") as thread_cls:
+
+            def _fake_thread(*_args, **_kwargs):
+                return types.SimpleNamespace(start=lambda: None)
+
+            thread_cls.side_effect = _fake_thread
+            messager = self.module.CacheMessagerV1(
+                splitwise_role="decode",
+                transfer_protocol="ipc,rdma",
+                pod_ip="127.0.0.1",
+                engine_worker_queue_port=9008,
+                local_data_parallel_id=0,
+                gpu_cache_kvs=cache,
+                rank=0,
+                nranks=1,
+                num_layers=1,
+                gpu_id=0,
+            )
+
+        rdma_manager = messager.messager["rdma"]
+        rdma_manager.connect_results = [True, False]
+
+        queue = messager.engine_worker_queue
+        queue.connect_tasks = [
+            {
+                "task_id": 11,
+                "ip": "10.0.0.1",
+                "rdma_ports": {0: 5},
+            },
+            {
+                "task_id": 12,
+                "ip": "10.0.0.2",
+                "rdma_ports": {0: 6},
+            },
+        ]
+        queue.stop_after_connect_tasks = True
+
+        with self.assertRaises(SystemExit):
+            messager._handle_connect_task()
+
+        self.assertEqual(
+            queue.connect_responses,
+            [
+                {"task_id": 11, "success": True},
+                {"task_id": 12, "success": False},
+            ],
+        )
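+
+
+# main() is smoke-tested only: the prefill loop is patched to raise
+# SystemExit immediately, so the test verifies wiring (argument handling and
+# messager construction) rather than transfer behaviour.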
protocol="ipc", + engine_pid="42", + local_data_parallel_id=0, + ) + self.module.args = args + + with mock.patch.object( + self.module.CacheMessager, + "prefill_layerwise_send_cache_thread", + side_effect=SystemExit("stop prefill"), + ) as prefill_mock: + with self.assertRaises(SystemExit): + self.module.main() + + prefill_mock.assert_called_once() + + +if __name__ == "__main__": + unittest.main()