From eab4c2d9b61628299b5b55b7ac33eb0aa1d49f7a Mon Sep 17 00:00:00 2001 From: xunyoyo <33387866+xunyoyo@users.noreply.github.com> Date: Sat, 15 Nov 2025 18:27:26 +0800 Subject: [PATCH 1/3] Add unit tests for SplitWiseScheduler module --- tests/scheduler/test_splitwise_scheduler.py | 976 ++++++++++++++++++++ 1 file changed, 976 insertions(+) create mode 100644 tests/scheduler/test_splitwise_scheduler.py diff --git a/tests/scheduler/test_splitwise_scheduler.py b/tests/scheduler/test_splitwise_scheduler.py new file mode 100644 index 00000000000..d04d7c9304b --- /dev/null +++ b/tests/scheduler/test_splitwise_scheduler.py @@ -0,0 +1,976 @@ +"""Unit tests for :mod:`fastdeploy.scheduler.splitwise_scheduler`. + +To generate a focused coverage report for this module, run:: + + python -m coverage run -m unittest tests.scheduler.test_splitwise_scheduler + python -m coverage report -m --include='fastdeploy/scheduler/splitwise_scheduler.py' +""" + +from __future__ import annotations + +import argparse +import importlib +import json +import sys +import time +import types +import unittest +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +PROJECT_ROOT = Path(__file__).resolve().parents[2] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +_MODULE_CACHE = {} + + +def _install_stub_modules() -> None: + """Install lightweight stand-ins for the external dependencies.""" + + if getattr(_install_stub_modules, "_installed", False): + return + + # ------------------------------------------------------------------ orjson + orjson_mod = types.ModuleType("orjson") + + def _dumps(obj: Any) -> bytes: + return json.dumps(obj).encode("utf-8") + + def _loads(data: Any) -> Any: + if isinstance(data, (bytes, bytearray)): + data = data.decode("utf-8") + return json.loads(data) + + orjson_mod.dumps = _dumps # type: ignore[attr-defined] + orjson_mod.loads = _loads # type: ignore[attr-defined] + sys.modules.setdefault("orjson", orjson_mod) + + # ----------------------------------------------------- scheduler logger stub + logger_mod = types.ModuleType("fastdeploy.utils.scheduler_logger") + + def _log(*_args: Any, **_kwargs: Any) -> None: + return None + + logger_mod.info = _log # type: ignore[attr-defined] + logger_mod.error = _log # type: ignore[attr-defined] + logger_mod.debug = _log # type: ignore[attr-defined] + logger_mod.warning = _log # type: ignore[attr-defined] + sys.modules["fastdeploy.utils.scheduler_logger"] = logger_mod + + utils_mod = types.ModuleType("fastdeploy.utils") + utils_mod.scheduler_logger = logger_mod # type: ignore[attr-defined] + sys.modules["fastdeploy.utils"] = utils_mod + + # --------------------------------------------------------------- Redis stubs + class _FakePipeline: + def __init__(self, client: "_FakeRedis") -> None: + self._client = client + self._commands: list[tuple[str, tuple[Any, ...]]] = [] + + def __enter__(self) -> "_FakePipeline": + return self + + def __exit__(self, exc_type, exc, tb) -> None: # type: ignore[override] + return None + + def multi(self) -> "_FakePipeline": + return self + + def lpush(self, key: str, *values: Any) -> "_FakePipeline": + self._commands.append(("lpush", (key, values))) + return self + + def expire(self, key: str, ttl: int) -> "_FakePipeline": + self._commands.append(("expire", (key, ttl))) + return self + + def execute(self) -> None: + for name, params in self._commands: + if name == "lpush": + key, values = params + self._client.lpush(key, *values) + elif name == "expire": + key, ttl = params + self._client.expire(key, ttl) + self._commands.clear() + + class _FakeRedis: + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.storage: dict[str, list[Any]] = {} + self.hashes: dict[str, dict[Any, Any]] = {} + self.expirations: dict[str, int] = {} + + # ------------------------------- list operations used by the scheduler + def lpush(self, key: str, *values: Any) -> None: + items = list(values) + if not items: + return + bucket = self.storage.setdefault(key, []) + for value in items: + bucket.insert(0, value) + + def rpop(self, key: str, count: Optional[int] = None) -> Optional[list[Any]]: + bucket = self.storage.get(key) + if not bucket: + return None + if count is None: + return [bucket.pop()] + count = min(count, len(bucket)) + values = [bucket.pop() for _ in range(count)] + return values + + def brpop(self, keys: Iterable[str], timeout: int = 0): # type: ignore[override] + for key in keys: + bucket = self.storage.get(key) + if bucket: + return (key, bucket.pop()) + return None + + # ------------------------------------------ hash operations for cluster + def hset(self, key: str, field: str, value: Any) -> None: + self.hashes.setdefault(key, {})[field] = value + + def hgetall(self, key: str) -> dict[Any, Any]: + return {k: v for k, v in self.hashes.get(key, {}).items()} + + def hdel(self, key: str, field: str) -> None: + if key in self.hashes: + self.hashes[key].pop(field, None) + + # -------------------------------------------------------------- misc ops + def expire(self, key: str, ttl: int) -> None: + self.expirations[key] = ttl + + def pipeline(self) -> _FakePipeline: + return _FakePipeline(self) + + redis_mod = types.ModuleType("redis") + redis_mod.Redis = _FakeRedis # type: ignore[attr-defined] + sys.modules.setdefault("redis", redis_mod) + + # ------------------------------------------- fastdeploy.engine.request stub + request_mod = types.ModuleType("fastdeploy.engine.request") + + @dataclass + class CompletionOutput: + index: int + send_idx: int + token_ids: List[int] + finished: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "index": self.index, + "send_idx": self.send_idx, + "token_ids": list(self.token_ids), + "finished": self.finished, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CompletionOutput": + return cls( + index=data.get("index", 0), + send_idx=data.get("send_idx", 0), + token_ids=list(data.get("token_ids", [])), + finished=data.get("finished", False), + ) + + @dataclass + class RequestMetrics: + arrival_time: float + inference_start_time: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "arrival_time": self.arrival_time, + "inference_start_time": self.inference_start_time, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RequestMetrics": + return cls( + arrival_time=data.get("arrival_time", time.time()), + inference_start_time=data.get("inference_start_time"), + ) + + class Request: + def __init__( + self, + request_id: str, + prompt: Optional[str] = None, + prompt_token_ids: Optional[List[int]] = None, + prompt_token_ids_len: int = 0, + arrival_time: Optional[float] = None, + disaggregate_info: Optional[Dict[str, Any]] = None, + ) -> None: + self.request_id = request_id + self.prompt = prompt or "" + self.prompt_token_ids = prompt_token_ids or [] + self.prompt_token_ids_len = prompt_token_ids_len + self.arrival_time = arrival_time if arrival_time is not None else time.time() + self.disaggregate_info = disaggregate_info + + def to_dict(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "prompt": self.prompt, + "prompt_token_ids": list(self.prompt_token_ids), + "prompt_token_ids_len": self.prompt_token_ids_len, + "arrival_time": self.arrival_time, + "disaggregate_info": self.disaggregate_info, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Request": + return cls( + request_id=data["request_id"], + prompt=data.get("prompt"), + prompt_token_ids=data.get("prompt_token_ids"), + prompt_token_ids_len=data.get("prompt_token_ids_len", 0), + arrival_time=data.get("arrival_time", time.time()), + disaggregate_info=data.get("disaggregate_info"), + ) + + class RequestOutput: + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + outputs: CompletionOutput, + metrics: RequestMetrics, + finished: bool = False, + error_code: int = 200, + error_msg: Optional[str] = None, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs + self.metrics = metrics + self.finished = finished + self.error_code = error_code + self.error_msg = error_msg + + def to_dict(self) -> Dict[str, Any]: + return { + "request_id": self.request_id, + "prompt": self.prompt, + "prompt_token_ids": list(self.prompt_token_ids), + "outputs": self.outputs.to_dict(), + "metrics": self.metrics.to_dict(), + "finished": self.finished, + "error_code": self.error_code, + "error_msg": self.error_msg, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RequestOutput": + return cls( + request_id=data["request_id"], + prompt=data.get("prompt", ""), + prompt_token_ids=list(data.get("prompt_token_ids", [])), + outputs=CompletionOutput.from_dict(data.get("outputs", {})), + metrics=RequestMetrics.from_dict(data.get("metrics", {})), + finished=data.get("finished", False), + error_code=data.get("error_code", 200), + error_msg=data.get("error_msg"), + ) + + request_mod.CompletionOutput = CompletionOutput # type: ignore[attr-defined] + request_mod.RequestMetrics = RequestMetrics # type: ignore[attr-defined] + request_mod.Request = Request # type: ignore[attr-defined] + request_mod.RequestOutput = RequestOutput # type: ignore[attr-defined] + sys.modules["fastdeploy.engine.request"] = request_mod + + # --------------------------------------------------------------- package stubs + fd_pkg = types.ModuleType("fastdeploy") + fd_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy")] + sys.modules.setdefault("fastdeploy", fd_pkg) + + scheduler_pkg = types.ModuleType("fastdeploy.scheduler") + scheduler_pkg.__path__ = [str(PROJECT_ROOT / "fastdeploy" / "scheduler")] + sys.modules.setdefault("fastdeploy.scheduler", scheduler_pkg) + + _install_stub_modules._installed = True + + +def _import_splitwise_scheduler(): + """Import the scheduler module with the stub environment.""" + + if "module" in _MODULE_CACHE: + return _MODULE_CACHE["module"] + + _install_stub_modules() + module = importlib.import_module("fastdeploy.scheduler.splitwise_scheduler") + _MODULE_CACHE["module"] = module + return module + + +class _PatchedThread: + def __init__(self, *args: Any, target=None, **kwargs: Any) -> None: # type: ignore[override] + self._target = target + self.started = False + + def start(self) -> None: + self.started = True + + +class SplitWiseSchedulerTestCase(unittest.TestCase): + def setUp(self) -> None: + self.module = _import_splitwise_scheduler() + self._orig_thread = self.module.threading.Thread + self.module.threading.Thread = _PatchedThread # type: ignore[assignment] + + def tearDown(self) -> None: + self.module.threading.Thread = self._orig_thread # type: ignore[assignment] + + +class SplitWiseSchedulerConfigTest(SplitWiseSchedulerTestCase): + def test_threshold_defaults_to_model_ratio(self) -> None: + config = self.module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=5, + max_long_partial_prefills=3, + max_model_len=1000, + ) + self.assertEqual(config.long_prefill_token_threshold, 40) + self.assertEqual(config.expire_period, 3.0) + + def test_check_and_print_cover_logging(self) -> None: + config = self.module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=1, + max_long_partial_prefills=1, + max_model_len=50, + ) + config.check() + config.print() + + +class NodeInfoTest(SplitWiseSchedulerTestCase): + def test_serialization_and_expiration(self) -> None: + node = self.module.NodeInfo( + nodeid="node-1", + role="prefill", + host="localhost", + disaggregated={"transfer_protocol": ["ipc", "rdma"]}, + load=2, + ) + + payload = node.serialize() + loaded = self.module.NodeInfo.load_from("node-1", payload) + self.assertFalse(loaded.expired(10)) + + loaded.ts -= 20 + self.assertTrue(loaded.expired(1)) + + loaded.add_req("req-1", 4) + self.assertIn("req-1", loaded.reqs) + + loaded.update_req_timestamp(["req-1"]) + before = loaded.reqs["req-1"][1] + loaded.reqs["req-1"][1] -= 1000 + loaded.expire_reqs(ttl=1) + self.assertNotIn("req-1", loaded.reqs) + + loaded.add_req("req-2", 2) + loaded.finish_req("req-2") + self.assertNotIn("req-2", loaded.reqs) + self.assertNotEqual(before, loaded.ts) + + def test_comparisons(self) -> None: + low = self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1) + high = self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5) + self.assertTrue(low < high) + self.assertIn("a(1)", repr(low)) + + +class ResultReaderTest(SplitWiseSchedulerTestCase): + def test_read_groups_partial_outputs(self) -> None: + client = sys.modules["redis"].Redis() + reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="group-a") + + req = self.module.Request("req-A", prompt_token_ids_len=3) + reader.add_req(req) + + metrics = self.module.RequestMetrics(arrival_time=time.time()) + first = self.module.RequestOutput( + request_id="req-A", + prompt="", + prompt_token_ids=[], + outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1, 2]), + metrics=metrics, + finished=False, + ) + follow = self.module.RequestOutput( + request_id="req-A", + prompt="", + prompt_token_ids=[], + outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[3]), + metrics=metrics, + finished=True, + ) + + reader.data.appendleft(follow) + reader.data.appendleft(first) + + outputs = reader.read() + self.assertIn("req-A", outputs) + self.assertEqual(len(outputs["req-A"]), 2) + + def test_sync_results_converts_payloads(self) -> None: + client = sys.modules["redis"].Redis() + reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="") + + metrics = self.module.RequestMetrics(arrival_time=time.time()) + ro = self.module.RequestOutput( + request_id="req-B", + prompt="p", + prompt_token_ids=[1], + outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[4]), + metrics=metrics, + finished=True, + ) + + payload = self.module.orjson.dumps(ro.to_dict()) + client.storage.setdefault("req-key", []).append(payload) + + total = reader.sync_results(["req-key"]) + self.assertEqual(total, 1) + self.assertTrue(reader.data) + + def test_read_uses_out_buffer(self) -> None: + client = sys.modules["redis"].Redis() + reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp") + + req = self.module.Request("req-out", prompt_token_ids_len=2) + reader.add_req(req) + + metrics = self.module.RequestMetrics(arrival_time=time.time()) + head = self.module.RequestOutput( + request_id="req-out", + prompt="", + prompt_token_ids=[], + outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]), + metrics=metrics, + finished=False, + ) + tail = self.module.RequestOutput( + request_id="req-out", + prompt="", + prompt_token_ids=[], + outputs=self.module.CompletionOutput(index=0, send_idx=2, token_ids=[2, 3]), + metrics=metrics, + finished=True, + ) + + with reader.lock: + reader.out_buffer[req.request_id] = [tail] + reader.data.appendleft(head) + + outputs = reader.read() + self.assertEqual(len(outputs["req-out"]), 2) + + def test_sync_results_with_group_override(self) -> None: + client = sys.modules["redis"].Redis() + reader = self.module.ResultReader(client, idx=0, batch=10, ttl=30, group="grp") + + metrics = self.module.RequestMetrics(arrival_time=time.time()) + ro = self.module.RequestOutput( + request_id="req-group", + prompt="", + prompt_token_ids=[], + outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[7]), + metrics=metrics, + finished=True, + ) + payload = self.module.orjson.dumps(ro.to_dict()) + client.storage.setdefault("grp", []).append(payload) + + total = reader.sync_results(["unused"]) + self.assertEqual(total, 1) + self.assertEqual(reader.data[-1].request_id, "req-group") + + def test_run_emits_expired_placeholder(self) -> None: + client = sys.modules["redis"].Redis() + reader = self.module.ResultReader(client, idx=0, batch=10, ttl=1, group="") + reader.reqs["old"] = {"arrival_time": time.time() - 5} + original_sleep = self.module.time.sleep + self.module.time.sleep = lambda *_args, **_kwargs: (_ for _ in ()).throw(SystemExit()) + try: + with self.assertRaises(SystemExit): + reader.run() + finally: + self.module.time.sleep = original_sleep + self.assertNotIn("old", reader.reqs) + self.assertTrue(reader.data) + + +class APISchedulerTest(SplitWiseSchedulerTestCase): + def _make_config(self) -> Any: + return self.module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=5, + max_long_partial_prefills=3, + max_model_len=200, + ) + + def test_schedule_mixed_node_uses_single_queue(self) -> None: + config = self._make_config() + scheduler = self.module.APIScheduler(config) + + req = self.module.Request("req-1", prompt_token_ids_len=10) + mixed = self.module.NodeInfo("mixed", "mixed", "host-a", {"transfer_protocol": ["ipc"]}, load=1) + scheduler.select_pd = lambda *args, **kwargs: mixed # type: ignore[assignment] + + scheduler.schedule(req, [mixed], [], [], group="g0") + key = f"ReqQ_{mixed.nodeid}" + self.assertIn(key, scheduler.client.storage) + stored = scheduler.client.storage[key][0] + decoded = self.module.orjson.loads(stored) + self.assertEqual(decoded["group"], "g0") + self.assertIsNone(decoded["disaggregate_info"]) + + def test_schedule_disaggregated_updates_protocol(self) -> None: + config = self._make_config() + scheduler = self.module.APIScheduler(config) + + req = self.module.Request("req-2", prompt_token_ids_len=10) + prefill = self.module.NodeInfo("prefill", "prefill", "host-a", {"transfer_protocol": ["ipc"]}, load=1) + decode = self.module.NodeInfo( + "decode", + "decode", + "host-b", + {"transfer_protocol": ["ipc", "rdma"]}, + load=1, + ) + + def _select(req_obj, nodes, role): + return nodes[0] + + scheduler.select_pd = _select # type: ignore[assignment] + + scheduler.schedule(req, [prefill], [decode], [], group="") + self.assertIn("ReqQ_prefill", scheduler.client.storage) + self.assertIn("ReqQ_decode", scheduler.client.storage) + + decoded = self.module.orjson.loads(scheduler.client.storage["ReqQ_prefill"][0]) + self.assertEqual(decoded["disaggregate_info"]["transfer_protocol"], "rdma") + + def test_sync_cluster_filters_expired_nodes(self) -> None: + config = self._make_config() + scheduler = self.module.APIScheduler(config) + + fresh = self.module.NodeInfo("n1", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=1) + scheduler.client.hset(scheduler.cluster_key, fresh.nodeid.encode(), fresh.serialize()) + + stale_payload = self.module.orjson.dumps( + { + "ts": time.time() - (config.expire_period + 1), + "role": "prefill", + "load": 1, + "host": "h", + "disaggregated": {"transfer_protocol": ["ipc"]}, + } + ) + scheduler.client.hset(scheduler.cluster_key, b"n2", stale_payload) + + pnodes, _, _ = scheduler.sync_cluster() + self.assertEqual([node.nodeid for node in pnodes], ["n1"]) + + def test_start_put_and_get_results(self) -> None: + config = self._make_config() + scheduler = self.module.APIScheduler(config) + scheduler.start() + + reqs = [self.module.Request(f"req-{i}", prompt_token_ids_len=1) for i in range(2)] + result = scheduler.put_requests(reqs) + self.assertEqual(len(result), 2) + + fake_output = {"a": ["value"]} + scheduler.readers = [types.SimpleNamespace(read=lambda: fake_output)] + outputs = scheduler.get_results() + self.assertEqual(outputs, fake_output) + + def test_select_pd_prefill_and_decode(self) -> None: + config = self._make_config() + scheduler = self.module.APIScheduler(config) + + req = self.module.Request("req-select", prompt_token_ids_len=50) + prefill_nodes = [ + self.module.NodeInfo("a", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=5), + self.module.NodeInfo("b", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=20), + ] + decode_nodes = [ + self.module.NodeInfo("c", "decode", "h", {"transfer_protocol": ["ipc"]}, load=1), + self.module.NodeInfo("d", "decode", "h", {"transfer_protocol": ["ipc"]}, load=2), + ] + + original_choice = self.module.random.choice + self.module.random.choice = lambda seq: seq[-1] # type: ignore[assignment] + try: + picked_prefill = scheduler.select_pd(req, prefill_nodes, "prefill") + picked_decode = scheduler.select_pd(req, decode_nodes, "decode") + finally: + self.module.random.choice = original_choice + + self.assertEqual(picked_prefill.nodeid, "b") + self.assertEqual(picked_decode.nodeid, "d") + + with self.assertRaises(Exception): + scheduler.select_pd(req, prefill_nodes, "unknown") + + +class InferSchedulerTest(SplitWiseSchedulerTestCase): + def _make_config(self, **overrides: Any) -> Any: + base = dict( + enable_chunked_prefill=True, + max_num_partial_prefills=3, + max_long_partial_prefills=1, + max_model_len=200, + ) + base.update(overrides) + return self.module.SplitWiseSchedulerConfig(**base) + + def test_get_requests_limits_partial_prefills(self) -> None: + config = self._make_config(long_prefill_token_threshold=5) + infer = self.module.InferScheduler(config) + infer.role = "prefill" + infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0) + + long = self.module.Request("req-long", prompt_token_ids_len=10) + longer = self.module.Request("req-longer", prompt_token_ids_len=12) + infer.reqs_queue.extend([longer, long]) + + picked = infer.get_requests( + available_blocks=100, + block_size=4, + reserved_output_blocks=1, + max_num_batched_tokens=100, + batch=5, + ) + self.assertEqual([req.request_id for req in picked], ["req-longer"]) + self.assertEqual([req.request_id for req in infer.reqs_queue], ["req-long"]) + + def test_get_requests_non_chunked_uses_token_cap(self) -> None: + config = self._make_config(enable_chunked_prefill=False) + infer = self.module.InferScheduler(config) + infer.role = "prefill" + infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0) + + infer.reqs_queue.extend( + [ + self.module.Request("req-1", prompt_token_ids_len=10), + self.module.Request("req-2", prompt_token_ids_len=20), + ] + ) + + picked = infer.get_requests( + available_blocks=100, + block_size=4, + reserved_output_blocks=1, + max_num_batched_tokens=15, + batch=5, + ) + self.assertEqual([req.request_id for req in picked], ["req-1"]) + self.assertEqual(len(infer.reqs_queue), 1) + + def test_put_results_groups_by_writer_index(self) -> None: + config = self._make_config() + infer = self.module.InferScheduler(config) + infer.role = "prefill" + infer.node = self.module.NodeInfo("n", "prefill", "h", {"transfer_protocol": ["ipc"]}, load=0) + + class _Writer: + def __init__(self) -> None: + self.items: list[tuple[str, list[bytes]]] = [] + + def put(self, key: str, items: list[bytes]) -> None: + self.items.append((key, items)) + + infer.writers = [_Writer(), _Writer()] + infer.node.add_req("req#0#g", 1) + + metrics = self.module.RequestMetrics(arrival_time=time.time()) + result = self.module.RequestOutput( + request_id="req#0#g", + prompt="", + prompt_token_ids=[], + outputs=self.module.CompletionOutput(index=0, send_idx=0, token_ids=[1]), + metrics=metrics, + finished=True, + ) + + infer.put_results([result]) + self.assertEqual(len(infer.writers[0].items), 1) + key, payloads = infer.writers[0].items[0] + self.assertEqual(key, "g") + decoded = self.module.orjson.loads(payloads[0]) + self.assertFalse(decoded["finished"]) + + def test_put_results_handles_errors(self) -> None: + config = self._make_config() + infer = self.module.InferScheduler(config) + infer.role = "decode" + infer.node = self.module.NodeInfo("n", "decode", "h", {"transfer_protocol": ["ipc"]}, load=0) + + class _Writer: + def __init__(self) -> None: + self.items = [] + + def put(self, key: str, items: list[bytes]) -> None: + self.items.append((key, items)) + + infer.writers = [_Writer()] + infer.node.add_req("bad#0#", 1) + + metrics = self.module.RequestMetrics(arrival_time=time.time()) + result = self.module.RequestOutput( + request_id="bad#0#", + prompt="", + prompt_token_ids=[], + outputs=self.module.CompletionOutput(index=0, send_idx=1, token_ids=[1]), + metrics=metrics, + finished=True, + error_code=500, + ) + + infer.put_results([result]) + self.assertFalse(infer.node.reqs) + + def test_start_initializes_writers(self) -> None: + config = self._make_config() + infer = self.module.InferScheduler(config) + infer.start("prefill", "host", {"transfer_protocol": ["ipc"]}) + self.assertEqual(len(infer.writers), config.writer_parallel) + + +class SplitWiseSchedulerFacadeTest(SplitWiseSchedulerTestCase): + def test_facade_delegates_to_components(self) -> None: + module = self.module + + class _FakeAPI: + def __init__(self, _config: Any) -> None: + self.started = False + self.reqs: List[Any] = [] + + def start(self) -> None: + self.started = True + + def put_requests(self, reqs: List[Any]): + self.reqs.extend(reqs) + return [(req.request_id, None) for req in reqs] + + def get_results(self): + return {"x": 1} + + class _FakeInfer: + def __init__(self, _config: Any) -> None: + self.started = False + self.nodeid = None + + def start(self, role, host, disaggregated): + self.started = True + + def get_requests(self, *args, **kwargs): + return ["scheduled"] + + def put_results(self, results): + return list(results) + + original_api = module.APIScheduler + original_infer = module.InferScheduler + module.APIScheduler = _FakeAPI # type: ignore[assignment] + module.InferScheduler = _FakeInfer # type: ignore[assignment] + + try: + config = module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=1, + max_long_partial_prefills=1, + max_model_len=10, + ) + facade = module.SplitWiseScheduler(config) + + facade.start("prefill", "host", {"tp": "ipc"}) + self.assertTrue(facade.scheduler.started) + self.assertTrue(facade.infer.started) + + reqs = [module.Request("req", prompt_token_ids_len=1)] + result = facade.put_requests(reqs) + self.assertEqual(result[0][0], "req") + self.assertEqual(facade.get_results(), {"x": 1}) + + scheduled = facade.get_requests(10, 1, 1, 10, batch=1) + self.assertEqual(scheduled, ["scheduled"]) + + outputs = facade.put_results([1, 2]) + self.assertEqual(outputs, [1, 2]) + finally: + module.APIScheduler = original_api # type: ignore[assignment] + module.InferScheduler = original_infer # type: ignore[assignment] + + def test_get_requests_with_insufficient_resources(self) -> None: + module = self.module + config = module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=1, + max_long_partial_prefills=1, + max_model_len=10, + ) + facade = module.SplitWiseScheduler(config) + facade.infer = types.SimpleNamespace(get_requests=lambda *args, **kwargs: ["should not reach"]) + facade.scheduler = types.SimpleNamespace() + + result = facade.get_requests( + available_blocks=1, block_size=1, reserved_output_blocks=2, max_num_batched_tokens=10 + ) + self.assertEqual(result, []) + + result = facade.get_requests( + available_blocks=10, block_size=1, reserved_output_blocks=2, max_num_batched_tokens=10, batch=0 + ) + self.assertEqual(result, []) + + def test_start_uses_real_components(self) -> None: + module = self.module + config = module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=1, + max_long_partial_prefills=1, + max_model_len=10, + ) + facade = module.SplitWiseScheduler(config) + + infer_flags = {} + scheduler_flags = {} + + facade.infer = types.SimpleNamespace( + start=lambda role, host, disagg: infer_flags.setdefault("called", (role, host, disagg)), + ) + facade.scheduler = types.SimpleNamespace(start=lambda: scheduler_flags.setdefault("called", True)) + + facade.start("prefill", "host", {"mode": "ipc"}) + self.assertEqual(infer_flags["called"], ("prefill", "host", {"mode": "ipc"})) + self.assertTrue(scheduler_flags["called"]) + facade.reset_nodeid("new-id") + self.assertEqual(facade.scheduler.nodeid, "new-id") + + +class BackgroundWorkerTest(SplitWiseSchedulerTestCase): + def test_result_writer_run_single_iteration(self) -> None: + client = sys.modules["redis"].Redis() + writer = self.module.ResultWriter(client, idx=0, batch=5, ttl=10) + with writer.cond: + writer.data.appendleft(("key", b"payload")) + + class _Pipeline: + def __init__(self, parent): + self.parent = parent + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return None + + def multi(self): + return self + + def lpush(self, key, *items): + self.parent.lpush(key, *items) + return self + + def expire(self, key, ttl): + raise SystemExit() + + def execute(self): + return None + + client.pipeline = lambda: _Pipeline(client) # type: ignore[assignment] + + with self.assertRaises(SystemExit): + writer.run() + + def test_infer_scheduler_routine_report(self) -> None: + config = self.module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=1, + max_long_partial_prefills=1, + max_model_len=10, + ) + infer = self.module.InferScheduler(config) + infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0) + + def _fake_hset(*_args, **_kwargs): + raise SystemExit() + + infer.client.hset = _fake_hset # type: ignore[assignment] + + with self.assertRaises(SystemExit): + infer.routine_report() + + def test_infer_scheduler_loop_expire_reqs(self) -> None: + config = self.module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=1, + max_long_partial_prefills=1, + max_model_len=10, + ) + infer = self.module.InferScheduler(config) + infer.node = self.module.NodeInfo("nid", "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0) + + def _raise_exit(ttl): + raise SystemExit() + + infer.node.expire_reqs = _raise_exit # type: ignore[assignment] + + with self.assertRaises(SystemExit): + infer.loop_expire_reqs() + + def test_infer_scheduler_loop_get_reqs(self) -> None: + config = self.module.SplitWiseSchedulerConfig( + enable_chunked_prefill=True, + max_num_partial_prefills=1, + max_long_partial_prefills=1, + max_model_len=10, + ) + infer = self.module.InferScheduler(config) + infer.role = "prefill" + infer.node = self.module.NodeInfo(infer.nodeid, "prefill", "host", {"transfer_protocol": ["ipc"]}, load=0) + infer.writers = [types.SimpleNamespace(put=lambda key, items: None)] + + req = self.module.Request("rq", prompt_token_ids_len=3) + payload = self.module.orjson.dumps(dict(req.to_dict(), group="")) + key = f"ReqQ_{infer.nodeid}" + infer.client.storage[key] = [payload] + + state = {"called": False} + + def _fake_rpop(k, batch): + if not state["called"]: + state["called"] = True + return infer.client.storage[k][:] + raise SystemExit() + + infer.client.rpop = _fake_rpop # type: ignore[assignment] + infer.client.brpop = lambda *_args, **_kwargs: None # type: ignore[assignment] + + with self.assertRaises(SystemExit): + infer.loop_get_reqs() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--print-coverage-command", action="store_true") + known_args, remaining = parser.parse_known_args() + + if known_args.print_coverage_command: + print("python -m coverage run -m unittest tests.scheduler.test_splitwise_scheduler") + print("python -m coverage report -m --include='fastdeploy/scheduler/splitwise_scheduler.py'") + + unittest.main(argv=[sys.argv[0]] + remaining) From 559e4919f0f287d8478c4e49303a89ae333451c1 Mon Sep 17 00:00:00 2001 From: xunyoyo <33387866+xunyoyo@users.noreply.github.com> Date: Mon, 17 Nov 2025 19:40:02 +0800 Subject: [PATCH 2/3] Add info and ping to fake redis client for tests --- tests/scheduler/test_splitwise_scheduler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/scheduler/test_splitwise_scheduler.py b/tests/scheduler/test_splitwise_scheduler.py index d04d7c9304b..d8909de9d03 100644 --- a/tests/scheduler/test_splitwise_scheduler.py +++ b/tests/scheduler/test_splitwise_scheduler.py @@ -147,6 +147,12 @@ def expire(self, key: str, ttl: int) -> None: def pipeline(self) -> _FakePipeline: return _FakePipeline(self) + def info(self) -> dict[str, str]: + return {"redis_version": "6.2.0"} + + def ping(self) -> bool: + return True + redis_mod = types.ModuleType("redis") redis_mod.Redis = _FakeRedis # type: ignore[attr-defined] sys.modules.setdefault("redis", redis_mod) From 30ba10104afe06fa754e0139228b0bf7ab19f763 Mon Sep 17 00:00:00 2001 From: xunyoyo <33387866+xunyoyo@users.noreply.github.com> Date: Mon, 17 Nov 2025 19:47:11 +0800 Subject: [PATCH 3/3] Document fake redis metadata methods in tests --- tests/scheduler/test_splitwise_scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/scheduler/test_splitwise_scheduler.py b/tests/scheduler/test_splitwise_scheduler.py index d8909de9d03..71ba808631a 100644 --- a/tests/scheduler/test_splitwise_scheduler.py +++ b/tests/scheduler/test_splitwise_scheduler.py @@ -147,9 +147,11 @@ def expire(self, key: str, ttl: int) -> None: def pipeline(self) -> _FakePipeline: return _FakePipeline(self) + # Metadata required by InferScheduler.check_redis_version def info(self) -> dict[str, str]: return {"redis_version": "6.2.0"} + # Health check used by InferScheduler.start def ping(self) -> bool: return True