Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/backend/server/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ async def _forward_request(self, request_data: Dict, request_id: str, received_t
except Exception as e:
logger.exception(f"get_routing_table error: {e}")
return JSONResponse(
content={"error": "Routing table not found"},
content={"error": "Get routing table error"},
status_code=500,
)

# None -> scheduler has not set yet; treat as hard error (no waiting here)
if routing_table is None:
return JSONResponse(
content={"error": "Routing not ready"},
content={"error": "Routing pipelines not ready"},
status_code=503,
)

Expand All @@ -89,7 +89,7 @@ async def _forward_request(self, request_data: Dict, request_id: str, received_t
# If still empty after retries, return 429 Too Many Requests
if routing_table is not None and len(routing_table) == 0:
return JSONResponse(
content={"error": "All pipelines are busy. Please retry later."},
content={"error": "All pipelines are busy or not ready. Please retry later."},
status_code=429,
)

Expand Down
21 changes: 10 additions & 11 deletions src/scheduling/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
from dataclasses import dataclass, field
from math import floor
from typing import Callable, Dict, List, Optional
from typing import Dict, List, Optional

from parallax_utils.logging_config import get_logger
from parallax_utils.utils import bytes_per_element, compute_max_batch_size
Expand Down Expand Up @@ -189,7 +189,6 @@ class Node:
load_compensator: float = 0.05

rtt_to_nodes: Optional[Dict[str, float]] = None
rtt_getter: Optional[Callable[["Node", "Node"], float]] = None

_force_max_concurrent_requests: bool = False

Expand Down Expand Up @@ -375,19 +374,19 @@ def update_rtt(self, target_node_id: str, rtt_ms: float):
self.rtt_to_nodes[target_node_id] = rtt_ms

def get_rtt_to(self, other: "Node") -> float:
    """Return the cached RTT from this node to `other`, in milliseconds.

    RTTs are only read from the `rtt_to_nodes` cache (populated via
    `update_rtt`); no measurement is performed here.

    Args:
        other: Destination node.

    Returns:
        0.0 when `other` is this node itself; the cached RTT in
        milliseconds when one exists; float("inf") when no RTT has been
        cached for `other` (callers treat the pair as unreachable).
    """
    if self == other:
        return 0.0
    # Guard before the membership test: the cache may not be initialized
    # yet, and `x in None` would raise TypeError.
    if self.rtt_to_nodes is None:
        return float("inf")
    if other.node_id not in self.rtt_to_nodes:
        logger.warning("Cannot find RTT from node %s to node %s", self.node_id, other.node_id)
        return float("inf")
    return self.rtt_to_nodes[other.node_id]

def hosts_layer(self, layer_id: int) -> bool:
"""Return True if this node hosts the given layer id.
Expand Down
5 changes: 3 additions & 2 deletions src/scheduling/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def update_node_info(
if layer_latency_ms is not None:
node.set_layer_latency_ms(layer_latency_ms)
if new_rtt_to_nodes is not None:
node.rtt_to_nodes.update(new_rtt_to_nodes)
node.rtt_to_nodes = new_rtt_to_nodes
if is_active is not None:
node.is_active = is_active
node.last_heartbeat = time.time()
Expand Down Expand Up @@ -422,7 +422,8 @@ def _dispatch_loop(self, poll_interval: float) -> None:
req = self._request_queue.get(timeout=poll_interval)
if req is None:
continue
path, _ = self.request_router.find_optimal_path(self.nodes, self.num_layers)
path, path_rtt = self.request_router.find_optimal_path(self.nodes, self.num_layers)
logger.debug(f"Path RTT: {path_rtt}")
req.routing_table = path
for node_id in path:
n = self.node_id_to_node[node_id]
Expand Down
22 changes: 22 additions & 0 deletions tests/scheduler_tests/test_request_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,28 @@ def test_optimal_path_single_node():
assert latency == pytest.approx(float(n.layer_latency_ms), rel=1e-6)


def test_optimal_path_missing_rtt():
    """A path whose inter-node RTT is unknown must be rejected as invalid."""
    layer_count = 12
    model = build_model(layer_count)
    first = build_node("n1", model, tflops=200.0, x=0.0, y=0.0)
    second = build_node("n2", model, tflops=200.0, x=1.0, y=0.0)
    first.set_layer_allocation(0, 6)
    second.set_layer_allocation(6, layer_count)
    pipeline = [first, second]
    set_rtt_from_coords(pipeline)

    # Drop the cached RTT between the two nodes to simulate a missing measurement.
    first.rtt_to_nodes.pop(second.node_id, None)

    path, cost = DynamicProgrammingRouting().find_optimal_path(pipeline, layer_count)

    assert path == []
    assert cost == float("inf")


@pytest.mark.parametrize(
"num_layers,segments,expected_path",
[
Expand Down
45 changes: 17 additions & 28 deletions tests/scheduler_tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,18 @@

from __future__ import annotations

from scheduling.model_info import ModelInfo
from scheduling.node import Node, NodeHardwareInfo, RequestSignal
from scheduling.node import RequestSignal
from scheduling.scheduler import Scheduler

from .test_utils import build_model_info


def _build_node(node_id: str, model: ModelInfo, *, tflops: float, mem_gb: float) -> Node:
hw = NodeHardwareInfo(
node_id=node_id,
tflops_fp16=tflops,
gpu_name="",
memory_gb=mem_gb,
memory_bandwidth_gbps=1000.0,
device="cuda",
)
n = Node(node_id=node_id, hardware=hw, model_info=model)
# Ensure latency estimation uses a defined speedup
setattr(n, "quantization_speedup", 1.0)
return n
from .test_utils import build_model_info, build_node, set_rtt_from_coords


def test_scheduler_initialize_and_dispatch():
"""Allocate, then enqueue one request and dispatch it."""
model = build_model_info(12)
n1 = _build_node("a100-0", model, tflops=312.0, mem_gb=80.0)
n2 = _build_node("a100-1", model, tflops=312.0, mem_gb=80.0)
n1 = build_node("a100-0", model, tflops=312.0, mem_gb=80.0, x=0, y=0)
n2 = build_node("a100-1", model, tflops=312.0, mem_gb=80.0, x=1, y=0)
set_rtt_from_coords([n1, n2])

sched = Scheduler(model, [n1, n2], strategy="greedy", min_nodes_bootstrapping=1)
sched.layer_allocator.global_allocation()
Expand All @@ -54,12 +39,13 @@ def test_scheduler_initialize_and_dispatch():
def test_scheduler_join_and_leave():
"""New node can join and be assigned; leave removes it and may rebalance."""
model = build_model_info(12)
n1 = _build_node("a100-0", model, tflops=312.0, mem_gb=80.0)
n2 = _build_node("a100-1", model, tflops=312.0, mem_gb=80.0)
n1 = build_node("a100-0", model, tflops=312.0, mem_gb=80.0, x=0, y=0)
n2 = build_node("a100-1", model, tflops=312.0, mem_gb=80.0, x=1, y=0)
set_rtt_from_coords([n1, n2])
sched = Scheduler(model, [n1, n2], strategy="greedy", min_nodes_bootstrapping=1)

# Join a new node
n3 = _build_node("rtx4090-x", model, tflops=82.6, mem_gb=24.0)
n3 = build_node("rtx4090-x", model, tflops=82.6, mem_gb=24.0, x=0, y=1)
sched.join(n3)
assert n3.start_layer is not None and n3.end_layer is not None

Expand All @@ -72,7 +58,7 @@ def test_scheduler_bootstrap_wait_and_dynamic_events():
"""Scheduler waits for min nodes, bootstraps, then handles join/leave events."""
model = build_model_info(12)
# Start with no nodes assigned yet; bootstrap needs 2
n1 = _build_node("a100-0", model, tflops=312.0, mem_gb=80.0)
n1 = build_node("a100-0", model, tflops=312.0, mem_gb=80.0, x=0, y=0)
sched = Scheduler(model, [], strategy="dp", min_nodes_bootstrapping=2)

# Enqueue one join; should not bootstrap yet (insufficient nodes)
Expand All @@ -83,15 +69,17 @@ def test_scheduler_bootstrap_wait_and_dynamic_events():
assert not sched.layer_allocator.has_full_pipeline()

# Add second node and process join; now bootstrap should succeed
n2 = _build_node("5090-1", model, tflops=165.0, mem_gb=32.0)
n2 = build_node("5090-1", model, tflops=165.0, mem_gb=32.0, x=1, y=0)
sched.enqueue_join(n2)
sched._process_joins() # type: ignore[attr-defined]
# RTTs are needed for DP routing strategy
set_rtt_from_coords(sched.nodes)
ok = sched.bootstrap()
assert ok
assert sched.layer_allocator.has_full_pipeline()

# Dynamic join after bootstrap should assign immediately
n3 = _build_node("rtx4090-x", model, tflops=82.6, mem_gb=24.0)
n3 = build_node("rtx4090-x", model, tflops=82.6, mem_gb=24.0, x=0, y=1)
sched.enqueue_join(n3)
sched._process_joins() # type: ignore[attr-defined]
assert n3.start_layer is not None and n3.end_layer is not None
Expand Down Expand Up @@ -124,7 +112,8 @@ def test_scheduler_single_node_leave_then_rejoin_reassigns_layers():
model = build_model_info(12)

# Start with a single capable node and bootstrap successfully
n1 = _build_node("solo-0", model, tflops=312.0, mem_gb=80.0)
n1 = build_node("solo-0", model, tflops=312.0, mem_gb=80.0, x=0, y=0)
set_rtt_from_coords([n1])
sched = Scheduler(model, [n1], strategy="dp", min_nodes_bootstrapping=1)
ok = sched.bootstrap()
assert ok
Expand All @@ -136,7 +125,7 @@ def test_scheduler_single_node_leave_then_rejoin_reassigns_layers():
assert not sched.layer_allocator.has_full_pipeline()

# Re-join the (same) node id; scheduler should re-assign layers
n1_rejoin = _build_node("solo-0", model, tflops=312.0, mem_gb=80.0)
n1_rejoin = build_node("solo-0", model, tflops=312.0, mem_gb=80.0, x=0, y=0)
sched.enqueue_join(n1_rejoin)
sched._process_joins() # type: ignore[attr-defined]

Expand Down
25 changes: 15 additions & 10 deletions tests/scheduler_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,21 @@ def to_latency(d: float) -> float:


def set_rtt_from_coords(nodes: List[Node]) -> None:
    """Populate each node's `rtt_to_nodes` cache from coordinate-based RTTs.

    Computes pairwise RTTs via `compute_rtts_from_coords` and writes each
    directed pair (src -> dst) into `src.rtt_to_nodes`, initializing the
    cache dict when it is still None. Self-pairs are skipped, and pairs
    absent from the computed RTT map are left uncached.

    Args:
        nodes: Nodes to wire up; mutated in place.
    """
    all_rtts = compute_rtts_from_coords(nodes)
    # Iterate the node list directly: building an id->node map and an id
    # list only to look the same nodes back up is redundant indirection.
    for src in nodes:
        if src.rtt_to_nodes is None:
            src.rtt_to_nodes = {}
        for dst in nodes:
            if dst.node_id == src.node_id:
                continue
            rtt = all_rtts.get((src.node_id, dst.node_id))
            if rtt is not None:
                src.rtt_to_nodes[dst.node_id] = rtt


def geo_rtt_provider(positions: Dict[str, Tuple[float, float]]):
Expand Down