From 63f43fcc4391a267c3b82570200ef0b85c021d43 Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Fri, 31 Oct 2025 14:13:31 +0800
Subject: [PATCH 01/29] debug log

---
 src/parallax/server/shard_loader.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py
index be03e00f..af7bbe54 100644
--- a/src/parallax/server/shard_loader.py
+++ b/src/parallax/server/shard_loader.py
@@ -165,7 +165,16 @@ def load(
         shard_weights = {}
         layer_key_prefix = "model.layers"  # Common prefix

-        for wf in weight_files:
+        logger.debug(
+            f"Loading shard layers [{current_start_layer}, {current_end_layer}) from {len(weight_files)} weight files"
+        )
+        loaded_keys_count = 0
+
+        for file_idx, wf in enumerate(weight_files):
+            logger.debug(
+                f"Scanning weight file {file_idx + 1}/{len(weight_files)}: {pathlib.Path(wf).name}"
+            )
+            file_loaded_count = 0
             # For bf16 models, we need torch tensors as a bridge
             with safetensors.safe_open(wf, framework="pt") as f:
                 for key in f.keys():
@@ -213,6 +222,13 @@ def load(
                     # If the key is needed, load only that tensor from the file
                     if is_needed:
                         shard_weights[remapped_key] = mx.array(f.get_tensor(key))
+                        loaded_keys_count += 1
+                        file_loaded_count += 1
+
+            if file_loaded_count > 0:
+                logger.debug(f"  Loaded {file_loaded_count} tensors from {pathlib.Path(wf).name}")
+            else:
+                logger.debug(f"  Skipped {pathlib.Path(wf).name} (no relevant layers)")

         if (quantization := config.get("quantization", None)) is not None:
             logger.info("Model is quantized. Applying quantization parameters...")
@@ -250,6 +266,10 @@ def class_predicate(p, m):
                 class_predicate=class_predicate,
             )

+            logger.debug(
+                f"Loaded {loaded_keys_count} weight tensors for shard layers [{current_start_layer}, {current_end_layer})"
+            )
+
             model_shard.load_weights(list(shard_weights.items()), strict=strict)

             if not lazy:
From e7a8c9226e9f4a69afd100cfc2497ad8962871aa Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Mon, 3 Nov 2025 15:02:19 +0800
Subject: [PATCH 02/29] update

---
 src/backend/server/rpc_connection_handler.py |  1 +
 src/parallax/p2p/server.py                   | 28 ++++++++++--
 src/scheduling/node.py                       | 48 ++++++++-------
 src/scheduling/scheduler.py                  | 40 ++++++++++++++--
 4 files changed, 82 insertions(+), 35 deletions(-)

diff --git a/src/backend/server/rpc_connection_handler.py b/src/backend/server/rpc_connection_handler.py
index 0288f26c..f7ae40eb 100644
--- a/src/backend/server/rpc_connection_handler.py
+++ b/src/backend/server/rpc_connection_handler.py
@@ -164,6 +164,7 @@ def build_node(self, node_json: dict):
             max_concurrent_requests=node_json.get("max_concurrent_requests"),
             max_sequence_length=node_json.get("max_sequence_length"),
             is_active=node_json.get("is_active", True),
+            manual_layer_assignment=node_json.get("manual_layer_assignment", False),
         )
         if node_json.get("start_layer", None) is not None:
             node.start_layer = node_json.get("start_layer")
diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py
index d373820e..6eda755d 100644
--- a/src/parallax/p2p/server.py
+++ b/src/parallax/p2p/server.py
@@ -240,6 +240,7 @@ def __init__(
         self.rtt_last_update = 0
         self.rtt_update_interval = 60
         self.status = ServerState.JOINING
+        self.manual_layer_assignment = block_end_index is not None and block_start_index is not None

         self.scheduler_stub = None
         self.scheduler_peer_id = None
@@ -327,6 +328,10 @@ def run(self):
                     self.lattica = None
                     time.sleep(10)
                     return self.run()
+
+                if self.manual_layer_assignment:
+                    node_info["manual_layer_assignment"] = True
+
                 response = self.scheduler_stub.node_join(node_info)
                 response = response.result(timeout=300)
                 if response == {}:
@@ -335,8 +340,9 @@ def run(self):

                 logger.info(f"Join scheduler response: {response}")

-                self.block_start_index = response.get("start_layer")
-                self.block_end_index = response.get("end_layer")
+                if not self.manual_layer_assignment:
+                    self.block_start_index = response.get("start_layer")
+                    self.block_end_index = response.get("end_layer")
                 self.model_name = response.get("model_name")

         # Publish executor metrics to backend on each update
@@ -637,13 +643,27 @@ def get_node_info(self, is_update: bool = False):
             "is_active": self.status == ServerState.READY,
         }

+        # For manual layer assignment, always include start_layer and end_layer
+        if self.manual_layer_assignment:
+            info["start_layer"] = self.block_start_index
+            info["end_layer"] = self.block_end_index
+            logger.info(
+                f"Manual assignment: sending start_layer={self.block_start_index}, "
+                f"end_layer={self.block_end_index}"
+            )
+
         if is_update:
             metrics = get_metrics()
             info["current_requests"] = metrics.get("current_requests", 0)
             if metrics.get("layer_latency_ms") is not None:
                 info["layer_latency_ms"] = metrics.get("layer_latency_ms")
-            info["start_layer"] = self.block_start_index
-            info["end_layer"] = self.block_end_index
+
+            logger.debug(f"start_index: {self.block_start_index}")
+            logger.debug(f"end_index: {self.block_end_index}")
+            # In update mode, include current allocation unless it was assigned manually
+            if not self.manual_layer_assignment:
+                info["start_layer"] = self.block_start_index
+                info["end_layer"] = self.block_end_index

         return info
diff --git a/src/scheduling/node.py b/src/scheduling/node.py
index d33276c7..2094ccb2 100644
--- a/src/scheduling/node.py
+++ b/src/scheduling/node.py
@@ -72,6 +72,7 @@ def __init__(
         batch_size: int = 1,
         target_seq_len: int = 1,
         source_seq_len: int = 256,
+        using_mlx: bool = False,
     ) -> None:
         self.tflops = hardware.tflops_fp16
         self.io_bandwidth = hardware.memory_bandwidth_gbps
@@ -80,6 +81,7 @@ def __init__(
         self.batch_size = batch_size
         self.target_seq_len = target_seq_len
         self.source_seq_len = source_seq_len
+        self.using_mlx = using_mlx

     def get_compute_roofline_latency_ms(self, flops: int) -> float:
         """Compute-bound latency in milliseconds for the given floating-point ops."""
@@ -108,14 +110,14 @@ def roofline_layer_latency_ms(
         self,
         include_input_embed: bool = False,
         include_lm_head: bool = False,
-        num_decoder_layers: int = 1,
+        num_current_layers: int = 1,
     ) -> float:
         """Estimate latency to execute the specified layer set on this node.

         Args:
             include_input_embed: Whether to include input embedding I/O
             include_lm_head: Whether to include LM head compute and I/O
-            num_decoder_layers: Number of decoder layers included
+            num_current_layers: Number of decoder layers included

         Returns:
             Total latency (ms) combining decoder layers and optional endpoints.
@@ -127,14 +129,15 @@ def roofline_layer_latency_ms(
                 source_seq_len=self.source_seq_len,
             )
         )
-        decoder_layer_io_latency = self.get_io_roofline_latency_ms(
-            self.model_info.decoder_layer_io_bytes(
-                roofline=True,
-                batch_size=self.batch_size,
-                target_seq_len=self.target_seq_len,
-                source_seq_len=self.source_seq_len,
-            )
+        model_bytes = self.model_info.decoder_layer_io_bytes(
+            roofline=True,
+            batch_size=self.batch_size,
+            target_seq_len=self.target_seq_len,
+            source_seq_len=self.source_seq_len,
         )
+        if self.using_mlx:
+            model_bytes *= self.model_info.mlx_bit_factor
+        decoder_layer_io_latency = self.get_io_roofline_latency_ms(model_bytes)

         # For first / last layers
         flops, io_bytes = 0, 0
@@ -149,9 +152,9 @@ def roofline_layer_latency_ms(
         compute_time_ms = self.get_compute_roofline_latency_ms(flops)
         io_time_ms = self.get_io_roofline_latency_ms(io_bytes)
         return (
-            num_decoder_layers * max(decoder_layer_compute_latency, decoder_layer_io_latency)
+            num_current_layers * max(decoder_layer_compute_latency, decoder_layer_io_latency)
             + max(compute_time_ms, io_time_ms)
-        ) / num_decoder_layers
+        ) / num_current_layers


 @dataclass
@@ -176,6 +179,7 @@ class Node:

     max_concurrent_requests: int = 16
     max_sequence_length: int = 4096
+    manual_layer_assignment: bool = False
     start_layer: Optional[int] = None  # inclusive
     end_layer: Optional[int] = None  # exclusive
     current_requests: int = 0
@@ -218,7 +222,7 @@ def max_requests(self) -> int:
             max_sequence_len=self.max_sequence_length,
             device=None,
             kv_cache_memory_fraction=self.kv_cache_ratio,
-            num_shard_layers=self.num_decoder_layers,
+            num_shard_layers=self.num_current_layers,
             num_key_value_heads=self.model_info.num_kv_heads,
             head_dim=self.model_info.head_size,
             elem_bytes=elem_bytes,
@@ -242,19 +246,6 @@ def num_current_layers(self) -> int:
             return 0
         return self.end_layer - self.start_layer

-    @property
-    def num_decoder_layers(self) -> int:
-        """Number of decoder layers."""
-        if self.start_layer is None or self.end_layer is None:
-            return 0
-        start_layer = self.start_layer + 1 if self.has_embedding else self.start_layer
-        end_layer = self.end_layer - 1 if self.has_lm_head else self.end_layer
-        if start_layer >= end_layer:
-            raise ValueError(
-                f"Node {self.node_id} has invalid decoder layer range: start_layer {start_layer} <= end_layer {end_layer}"
-            )
-        return end_layer - start_layer
-
     @property
     def has_embedding(self) -> bool:
         """Check if this node hosts the embedding layer (layer 0)."""
@@ -312,11 +303,11 @@ def get_decoder_layer_capacity(
     @property
     def per_decoder_layer_kv_cache_memory(self) -> Optional[int]:
         """Return the available memory for kv cache per layer."""
-        if self.num_decoder_layers == 0:
+        if self.num_current_layers == 0:
             return None
         return floor(
             (self.hardware.memory_gb * 1024 * 1024 * 1024 * self.kv_cache_ratio)
-            / self.num_decoder_layers
+            / self.num_current_layers
         )

     def set_layer_allocation(self, start_layer: int, end_layer: int) -> None:
@@ -349,11 +340,12 @@ def roofline_layer_latency_ms(self) -> float:
             batch_size=self.current_requests,
             target_seq_len=1,
             source_seq_len=self.max_sequence_length,
+            using_mlx=self.hardware.device == "mlx",
         )
         return perf_model.roofline_layer_latency_ms(
             include_input_embed=self.has_embedding,
             include_lm_head=self.has_lm_head,
-            num_decoder_layers=self.num_decoder_layers,
+            num_current_layers=self.num_current_layers,
         )

     @property
diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py
index 2112f3c6..ef120a32 100644
--- a/src/scheduling/scheduler.py
+++ b/src/scheduling/scheduler.py
@@ -250,14 +250,43 @@ def checking_node_heartbeat(self) -> None:
     def join(self, node: Node, bootstrap: bool = False) -> None:
         """Add a node to allocation and refresh plan and materialized nodes."""
         logger.debug(
-            "Joining node %s (kv_ratio=%.2f, param_ratio=%.2f)",
+            "Joining node %s (kv_ratio=%.2f, param_ratio=%.2f, manual_assignment=%s)",
             node.node_id,
             node.kv_cache_ratio,
             node.param_hosting_ratio,
+            node.manual_layer_assignment,
         )
         self.layer_allocator.declare(node)
-        if not bootstrap:
+
+        # Manual layer assignment bypasses bootstrap waiting
+        if node.manual_layer_assignment:
+            # Manual layer assignment: use the layers specified by the node
+            if node.start_layer is None or node.end_layer is None:
+                raise ValueError(
+                    f"Node {node.node_id} has manual_layer_assignment=True "
+                    f"but start_layer ({node.start_layer}) or end_layer ({node.end_layer}) is None"
+                )
+            logger.info(
+                f"Manual layer assignment for node {node.node_id}: "
+                f"layers [{node.start_layer}, {node.end_layer})"
+            )
+            # Directly allocate the specified layers without automatic assignment
+            self.layer_allocator.allocate(node, node.start_layer, node.end_layer)
+
+            # Check if manual allocations now cover the full pipeline
+            if self.layer_allocator.has_full_pipeline():
+                if not self._bootstrapped:
+                    logger.info(
+                        "Manual layer assignments have established a full pipeline; "
+                        "marking scheduler as bootstrapped"
+                    )
+                    self._bootstrapped = True
+                    self._bootstrapped_event.set()
+        elif not bootstrap:
+            # Automatic layer assignment (only after bootstrap)
             self.layer_allocator.join(node)
+        # If bootstrap=True and not manual, node is only declared (allocation deferred to bootstrap())
+
+        # Notify waiters that node count changed
         with self._node_count_cv:
             self._node_count_cv.notify_all()
@@ -470,6 +499,7 @@ def _process_node_updates(self) -> None:
     def _process_joins(self) -> None:
         """Handle pending join events, honoring bootstrap state for assignment."""
         joined_any = False
+        had_manual_assignment = False
         while True:
             try:
                 node = self._pending_joins.get_nowait()
@@ -477,14 +507,18 @@ def _process_joins(self) -> None:
                 break
             # During bootstrap (no full pipeline yet), only declare nodes; no dynamic assignment.
             # After bootstrap, allow dynamic light-weight joins.
+            # Exception: manual layer assignments are processed immediately regardless of bootstrap state.
            self.join(node, bootstrap=not self._bootstrapped_event.is_set())
             joined_any = True
+            if node.manual_layer_assignment:
+                had_manual_assignment = True

         # If we are not bootstrapped (e.g., after a leave-triggered rebalance) and
         # new nodes just joined, attempt a greedy bootstrap immediately when we have
         # enough nodes. If it doesn't produce a full pipeline, we'll try again on
         # subsequent joins.
-        if joined_any and not self._bootstrapped_event.is_set():
+        # Skip bootstrap if manual assignments were used (they handle bootstrapping internally).
+        if joined_any and not self._bootstrapped_event.is_set() and not had_manual_assignment:
             if len(self.nodes) >= self.min_nodes_bootstrapping:
                 try:
                     ok = self.bootstrap()
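Note: the manual-assignment flow above is driven entirely by the node's join payload. A minimal sketch of what build_node() receives when a worker pins its own layer range — field names follow the handler in this patch, the node id and values are illustrative:

    # Illustrative node_join payload for a worker pinning layers [0, 16).
    # Keys mirror rpc_connection_handler.build_node(); values are made up.
    node_json = {
        "node_id": "worker-0",
        "max_concurrent_requests": 16,
        "max_sequence_length": 4096,
        "is_active": True,
        "manual_layer_assignment": True,  # sent because both block indices were set
        "start_layer": 0,   # inclusive
        "end_layer": 16,    # exclusive
    }
    # Scheduler.join() then allocates exactly [0, 16) for this node instead of
    # invoking the automatic layer allocator.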
From 4e09501d2c9984706de070f52d7710836f424917 Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Mon, 3 Nov 2025 15:14:32 +0800
Subject: [PATCH 03/29] remove debug logger

---
 src/parallax/p2p/server.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py
index 6eda755d..5d9b43df 100644
--- a/src/parallax/p2p/server.py
+++ b/src/parallax/p2p/server.py
@@ -657,9 +657,6 @@ def get_node_info(self, is_update: bool = False):
             info["current_requests"] = metrics.get("current_requests", 0)
             if metrics.get("layer_latency_ms") is not None:
                 info["layer_latency_ms"] = metrics.get("layer_latency_ms")
-
-            logger.debug(f"start_index: {self.block_start_index}")
-            logger.debug(f"end_index: {self.block_end_index}")
             # In update mode, include current allocation unless it was assigned manually
             if not self.manual_layer_assignment:
                 info["start_layer"] = self.block_start_index
                 info["end_layer"] = self.block_end_index
From 7ffe14760324a273df235533b0d97c220db32660 Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Mon, 3 Nov 2025 19:51:09 +0800
Subject: [PATCH 04/29] update

---
 src/parallax/server/shard_loader.py | 58 +++++++++++++++++++++++++++--
 1 file changed, 54 insertions(+), 4 deletions(-)

diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py
index af7bbe54..157f6d1c 100644
--- a/src/parallax/server/shard_loader.py
+++ b/src/parallax/server/shard_loader.py
@@ -157,6 +157,56 @@ def load(
         if not weight_files:
             weight_files = glob.glob(str(model_path / "weight*.safetensors"))

+        # Sort weight files by name for consistent loading order
+        weight_files = sorted(weight_files)
+
+        # Filter weight files based on layer range using index file
+        index_file = model_path / "model.safetensors.index.json"
+        if index_file.exists():
+            import json
+
+            with open(index_file, "r") as f:
+                index_data = json.load(f)
+
+            weight_map = index_data.get("weight_map", {})
+            needed_files = set()
+
+            # Check which files contain layers we need
+            for key, filename in weight_map.items():
+                if filename in needed_files:
+                    continue
+                should_include = False
+                if model_shard.is_first_shard and "embed_tokens" in key:
+                    should_include = True
+                elif model_shard.is_last_shard:
+                    if "model.norm" in key or "lm_head" in key:
+                        should_include = True
+                    elif config.get("tie_word_embeddings", False) and "embed_tokens" in key:
+                        should_include = True
+
+                if "model.layers." in key:
+                    parts = key.split(".")
+                    layer_idx = int(parts[2])
+                    if current_start_layer <= layer_idx < current_end_layer:
+                        should_include = True
+
+                if should_include:
+                    full_path = str(model_path / filename)
+                    needed_files.add(full_path)
+
+            # Filter weight_files to only include needed ones
+            if needed_files:
+                weight_files = [wf for wf in weight_files if wf in needed_files]
+                logger.info(
+                    f"Filtered to {len(weight_files)} weight files (out of {len(glob.glob(str(model_path / 'model*.safetensors')))} total) "
+                    f"for layers [{current_start_layer}, {current_end_layer})"
+                )
+            else:
+                logger.warning("No relevant weight files found in index, will scan all files")
+
+        else:
+            logger.debug("No index file found, will scan all weight files")
+
         if not weight_files and strict:
             raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -248,10 +298,10 @@ def class_predicate(p, m):
                 prefixed = f"model.{p}"
                 if prefixed in qcfg:
                     override = qcfg[prefixed]
-                    if isinstance(override, dict):
-                        logger.debug(
-                            f"[quantize] Using override for '{prefixed}' (mapped to '{p}'): bits={override.get('bits')} group_size={override.get('group_size')}"
-                        )
+                    # if isinstance(override, dict):
+                    #     logger.debug(
+                    #         f"[quantize] Using override for '{prefixed}' (mapped to '{p}'): bits={override.get('bits')} group_size={override.get('group_size')}"
+                    #     )
                     return override
                 if not hasattr(m, "to_quantized"):
                     return False
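Note: the filtering above keys off the standard Hugging Face model.safetensors.index.json, whose weight_map maps every tensor name to the shard file storing it. A self-contained sketch of the same idea against a toy index — file names and tensor keys here are illustrative, not from the repo:

    toy_weight_map = {
        "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
        "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
        "model.norm.weight": "model-00003-of-00003.safetensors",
        "lm_head.weight": "model-00003-of-00003.safetensors",
    }

    def needed_shards(weight_map, start, end, first_shard, last_shard):
        # Keep a shard file if it holds the embedding (first shard), the
        # norm/lm_head (last shard), or any decoder layer in [start, end).
        needed = set()
        for key, filename in weight_map.items():
            if first_shard and "embed_tokens" in key:
                needed.add(filename)
            elif last_shard and ("model.norm" in key or "lm_head" in key):
                needed.add(filename)
            elif "model.layers." in key and start <= int(key.split(".")[2]) < end:
                needed.add(filename)
        return needed

    # A first shard serving layers [0, 16) only needs the first file here.
    assert needed_shards(toy_weight_map, 0, 16, True, False) == {
        "model-00001-of-00003.safetensors"
    }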
From f7a02613b1d8fceeb3f6500dab15de15fecfb2e9 Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Mon, 3 Nov 2025 20:13:12 +0800
Subject: [PATCH 05/29] update

---
 src/parallax/server/shard_loader.py           |  11 -
 src/parallax/sglang/monkey_patch.py           |   4 +
 .../weight_loader_filter.py                   | 195 ++++++++++++++++++
 3 files changed, 199 insertions(+), 11 deletions(-)
 create mode 100644 src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py

diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py
index 157f6d1c..63217748 100644
--- a/src/parallax/server/shard_loader.py
+++ b/src/parallax/server/shard_loader.py
@@ -157,10 +157,6 @@ def load(
         if not weight_files:
             weight_files = glob.glob(str(model_path / "weight*.safetensors"))

-        # Sort weight files by name for consistent loading order
-        weight_files = sorted(weight_files)
-
-        # Filter weight files based on layer range using index file
         index_file = model_path / "model.safetensors.index.json"
         if index_file.exists():
             import json
@@ -171,7 +167,6 @@ def load(
             weight_map = index_data.get("weight_map", {})
             needed_files = set()

-            # Check which files contain layers we need
             for key, filename in weight_map.items():
                 if filename in needed_files:
                     continue
@@ -193,14 +188,8 @@ def load(
                 if should_include:
                     full_path = str(model_path / filename)
                     needed_files.add(full_path)
-
-            # Filter weight_files to only include needed ones
             if needed_files:
                 weight_files = [wf for wf in weight_files if wf in needed_files]
-                logger.info(
-                    f"Filtered to {len(weight_files)} weight files (out of {len(glob.glob(str(model_path / 'model*.safetensors')))} total) "
-                    f"for layers [{current_start_layer}, {current_end_layer})"
-                )
             else:
                 logger.warning("No relevant weight files found in index, will scan all files")
diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py
index 3bf53067..ca422438 100644
--- a/src/parallax/sglang/monkey_patch.py
+++ b/src/parallax/sglang/monkey_patch.py
@@ -17,6 +17,9 @@
 from parallax.sglang.monkey_patch_utils.triton_backend import (
     apply_triton_backend_init_monkey_patch,
 )
+from parallax.sglang.monkey_patch_utils.weight_loader_filter import (
+    apply_weight_loader_filter_patch,
+)


 ## Here is some patch func for sglang
@@ -29,3 +32,4 @@ def apply_parallax_sglang_monkey_patch():
     apply_glm4_moe_monkey_patch()
     apply_triton_backend_init_monkey_patch()
     apply_model_parallel_monkey_patch()
+    apply_weight_loader_filter_patch()
diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
new file mode 100644
index 00000000..c8fefbaf
--- /dev/null
+++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
@@ -0,0 +1,195 @@
+"""
+Monkey patch for SGLang/vLLM weight loader to filter safetensors files based on layer range.
+
+This reduces I/O and memory usage by only loading weight files that contain layers
+in the [pp_start_layer, pp_end_layer) range.
+"""
+
+import json
+import logging
+from pathlib import Path
+from typing import List, Optional, Set
+
+logger = logging.getLogger(__name__)
+
+
+def filter_weight_files_by_layer_range(
+    model_path: Path,
+    weight_files: List[str],
+    pp_start_layer: int,
+    pp_end_layer: int,
+    is_first_shard: bool,
+    is_last_shard: bool,
+) -> List[str]:
+    """
+    Filter weight files based on layer range using model.safetensors.index.json.
+
+    Args:
+        model_path: Path to the model directory
+        weight_files: List of all weight file paths
+        pp_start_layer: Starting layer index (inclusive)
+        pp_end_layer: Ending layer index (exclusive)
+        is_first_shard: Whether this is the first shard (needs embedding)
+        is_last_shard: Whether this is the last shard (needs lm_head and norm)
+
+    Returns:
+        Filtered list of weight files containing only relevant layers
+    """
+    index_file = model_path / "model.safetensors.index.json"
+
+    if not index_file.exists():
+        logger.debug(f"No index file found at {index_file}, will load all weight files")
+        return weight_files
+
+    try:
+        with open(index_file, "r") as f:
+            index_data = json.load(f)
+
+        weight_map = index_data.get("weight_map", {})
+        if not weight_map:
+            logger.warning("weight_map is empty in index file, will load all weight files")
+            return weight_files
+
+        needed_files: Set[str] = set()
+
+        # Check which files contain layers/weights we need
+        for key, filename in weight_map.items():
+            should_include = False
+
+            # Check for embedding layer (first shard)
+            if is_first_shard and "embed_tokens" in key:
+                should_include = True
+                logger.debug(f"Including {filename} for embedding layer: {key}")
+
+            # Check for lm_head and norm (last shard)
+            if is_last_shard:
+                if "model.norm" in key or "lm_head" in key:
+                    should_include = True
+                    logger.debug(f"Including {filename} for lm_head/norm: {key}")
+
+            # Check for decoder layers in range
+            # Common patterns: "model.layers.0.", "layers.0.", "model.decoder.layers.0."
+            if "layers." in key:
+                try:
+                    # Try to extract layer number from key
+                    parts = key.split(".")
+
+                    # Find the "layers" index and get the next part as layer number
+                    for i, part in enumerate(parts):
+                        if part == "layers" and i + 1 < len(parts):
+                            layer_idx = int(parts[i + 1])
+                            if pp_start_layer <= layer_idx < pp_end_layer:
+                                should_include = True
+                                logger.debug(f"Including {filename} for layer {layer_idx}: {key}")
+                            break
+                except (ValueError, IndexError):
+                    # If we can't parse the layer number, include it to be safe
+                    logger.debug(f"Could not parse layer number from {key}, including to be safe")
+                    should_include = True
+
+            if should_include:
+                # Convert relative filename to full path
+                full_path = str(model_path / filename)
+                needed_files.add(full_path)
+
+        # Filter weight_files to only include needed ones
+        if needed_files:
+            filtered_files = [wf for wf in weight_files if wf in needed_files]
+            logger.info(
+                f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} "
+                f"for layers [{pp_start_layer}, {pp_end_layer})"
+            )
+            logger.debug(f"Needed files: {[Path(f).name for f in filtered_files]}")
+            return filtered_files
+        else:
+            logger.warning(
+                f"No relevant weight files found in index for layers [{pp_start_layer}, {pp_end_layer}), "
+                "will load all files"
+            )
+            return weight_files
+
+    except Exception as e:
+        logger.warning(f"Failed to filter weight files using index file: {e}, will load all files")
+        return weight_files
+
+
+def apply_weight_loader_filter_patch():
+    """
+    Apply monkey patch to filter weight files before loading.
+
+    This patches the get_model_filenames function in vLLM/SGLang to filter
+    out weight files that don't contain layers in the current shard's range.
+    """
+    try:
+        from sglang.srt.model_loader import weight_utils
+        from sglang.srt.distributed import get_pp_group
+
+        original_get_model_filenames = weight_utils.get_model_filenames
+
+        def patched_get_model_filenames(model_name_or_path: str, **kwargs):
+            """Patched version that filters weight files by layer range."""
+            # Get original file list
+            weight_files = original_get_model_filenames(model_name_or_path, **kwargs)
+
+            # Try to get PP group info
+            try:
+                pp_group = get_pp_group()
+                if pp_group is None:
+                    logger.debug("No PP group found, skipping weight file filtering")
+                    return weight_files
+
+                pp_start_layer = getattr(pp_group, "pp_start_layer", None)
+                pp_end_layer = getattr(pp_group, "pp_end_layer", None)
+
+                if pp_start_layer is None or pp_end_layer is None:
+                    logger.debug(
+                        f"PP layer range not set (start={pp_start_layer}, end={pp_end_layer}), "
+                        "skipping weight file filtering"
+                    )
+                    return weight_files
+
+                model_path = Path(model_name_or_path)
+                is_first_shard = pp_start_layer == 0
+
+                # We need to know the total number of layers to determine if this is the last shard
+                # For now, we'll assume if end_layer is very large, it's the last shard
+                # A more robust solution would read the config file
+                is_last_shard = False
+                try:
+                    config_file = model_path / "config.json"
+                    if config_file.exists():
+                        with open(config_file, "r") as f:
+                            config = json.load(f)
+                        num_hidden_layers = config.get("num_hidden_layers", 0)
+                        is_last_shard = pp_end_layer >= num_hidden_layers
+                except Exception as e:
+                    logger.debug(f"Could not determine if last shard: {e}")
+
+                logger.info(
+                    f"Filtering weight files for shard: layers [{pp_start_layer}, {pp_end_layer}), "
+                    f"is_first={is_first_shard}, is_last={is_last_shard}"
+                )
+
+                filtered_files = filter_weight_files_by_layer_range(
+                    model_path=model_path,
+                    weight_files=weight_files,
+                    pp_start_layer=pp_start_layer,
+                    pp_end_layer=pp_end_layer,
+                    is_first_shard=is_first_shard,
+                    is_last_shard=is_last_shard,
+                )
+
+                return filtered_files
+
+            except Exception as e:
+                logger.warning(f"Error in weight file filtering: {e}, using all files")
+                return weight_files
+
+        # Apply the patch
+        weight_utils.get_model_filenames = patched_get_model_filenames
+        logger.info("Applied weight loader filter patch")
+
+    except ImportError as e:
+        logger.warning(f"Could not import SGLang weight_utils, skipping patch: {e}")
+    except Exception as e:
+        logger.error(f"Failed to apply weight loader filter patch: {e}")
From cbf198b7d61f3e9067a06910c8132c997f099ab1 Mon Sep 17 00:00:00 2001
From: yuhao-zh
Date: Mon, 3 Nov 2025 13:19:15 +0000
Subject: [PATCH 06/29] update gpu load

---
 src/parallax/sglang/model_runner.py           |  10 +
 .../weight_loader_filter.py                   | 178 ++++++++----------
 2 files changed, 84 insertions(+), 104 deletions(-)

diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index 2bf8a0f5..d629d01c 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -67,6 +67,16 @@ def __init__(
         """Add pp_start_layer and pp_end_layer for decentralized model"""
         self.pp_start_layer = pp_start_layer
         self.pp_end_layer = pp_end_layer
+
+        # Set layer range for weight file filtering before model loading
+        from parallax.sglang.monkey_patch_utils.weight_loader_filter import (
+            set_layer_range_for_filtering,
+        )
+
+        num_hidden_layers = model_config.hf_config.num_hidden_layers
+        is_last_shard = pp_end_layer >= num_hidden_layers
+        set_layer_range_for_filtering(pp_start_layer, pp_end_layer, is_last_shard)
+
         super().__init__(
             model_config=model_config,
             mem_fraction_static=mem_fraction_static,
diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
index c8fefbaf..b7d3dbfd 100644
--- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
+++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
@@ -3,6 +3,11 @@

 This reduces I/O and memory usage by only loading weight files that contain layers
 in the [pp_start_layer, pp_end_layer) range.
+
+Usage:
+    1. Call apply_weight_loader_filter_patch() during initialization
+    2. Call set_layer_range_for_filtering() before loading model weights
+    3. Model weights will be automatically filtered based on the layer range
 """

 import json
@@ -21,24 +26,15 @@ def filter_weight_files_by_layer_range(
     is_first_shard: bool,
     is_last_shard: bool,
 ) -> List[str]:
+    """Filter weight files to only those containing layers in the specified range.
+
+    Supports both safetensors (.safetensors) and PyTorch (.bin/.pt) formats.
     """
-    Filter weight files based on layer range using model.safetensors.index.json.
-
-    Args:
-        model_path: Path to the model directory
-        weight_files: List of all weight file paths
-        pp_start_layer: Starting layer index (inclusive)
-        pp_end_layer: Ending layer index (exclusive)
-        is_first_shard: Whether this is the first shard (needs embedding)
-        is_last_shard: Whether this is the last shard (needs lm_head and norm)
-
-    Returns:
-        Filtered list of weight files containing only relevant layers
-    """
+    # Try safetensors index first
     index_file = model_path / "model.safetensors.index.json"

     if not index_file.exists():
-        logger.debug(f"No index file found at {index_file}, will load all weight files")
+        logger.debug(f"No index file found at {model_path}, will load all weight files")
         return weight_files

     try:
@@ -52,29 +48,21 @@ def filter_weight_files_by_layer_range(

         needed_files: Set[str] = set()

-        # Check which files contain layers/weights we need
         for key, filename in weight_map.items():
             should_include = False

-            # Check for embedding layer (first shard)
             if is_first_shard and "embed_tokens" in key:
                 should_include = True
-                logger.debug(f"Including {filename} for embedding layer: {key}")

-            # Check for lm_head and norm (last shard)
             if is_last_shard:
                 if "model.norm" in key or "lm_head" in key:
                     should_include = True
-                    logger.debug(f"Including {filename} for lm_head/norm: {key}")

-            # Check for decoder layers in range
-            # Common patterns: "model.layers.0.", "layers.0.", "model.decoder.layers.0."
             if "layers." in key:
                 try:
-                    # Try to extract layer number from key
                     parts = key.split(".")
-
-                    # Find the "layers" index and get the next part as layer number
                     for i, part in enumerate(parts):
                         if part == "layers" and i + 1 < len(parts):
                             layer_idx = int(parts[i + 1])
                             if pp_start_layer <= layer_idx < pp_end_layer:
                                 should_include = True
-                                logger.debug(f"Including {filename} for layer {layer_idx}: {key}")
                             break
                 except (ValueError, IndexError):
-                    # If we can't parse the layer number, include it to be safe
                     logger.debug(f"Could not parse layer number from {key}, including to be safe")
                     should_include = True

         if should_include:
-                # Convert relative filename to full path
                 full_path = str(model_path / filename)
                 needed_files.add(full_path)

-        # Filter weight_files to only include needed ones
         if needed_files:
             filtered_files = [wf for wf in weight_files if wf in needed_files]
             logger.info(
@@ -113,83 +98,68 @@ def filter_weight_files_by_layer_range(
-def apply_weight_loader_filter_patch():
-    """
-    Apply monkey patch to filter weight files before loading.
-
-    This patches the get_model_filenames function in vLLM/SGLang to filter
-    out weight files that don't contain layers in the current shard's range.
-    """
-    try:
-        from sglang.srt.model_loader import weight_utils
-        from sglang.srt.distributed import get_pp_group
-
-        original_get_model_filenames = weight_utils.get_model_filenames
-
-        def patched_get_model_filenames(model_name_or_path: str, **kwargs):
-            """Patched version that filters weight files by layer range."""
-            # Get original file list
-            weight_files = original_get_model_filenames(model_name_or_path, **kwargs)
-
-            # Try to get PP group info
-            try:
-                pp_group = get_pp_group()
-                if pp_group is None:
-                    logger.debug("No PP group found, skipping weight file filtering")
-                    return weight_files
-
-                pp_start_layer = getattr(pp_group, "pp_start_layer", None)
-                pp_end_layer = getattr(pp_group, "pp_end_layer", None)
-
-                if pp_start_layer is None or pp_end_layer is None:
-                    logger.debug(
-                        f"PP layer range not set (start={pp_start_layer}, end={pp_end_layer}), "
-                        "skipping weight file filtering"
-                    )
-                    return weight_files
-
-                model_path = Path(model_name_or_path)
-                is_first_shard = pp_start_layer == 0
-
-                # We need to know the total number of layers to determine if this is the last shard
-                # For now, we'll assume if end_layer is very large, it's the last shard
-                # A more robust solution would read the config file
-                is_last_shard = False
-                try:
-                    config_file = model_path / "config.json"
-                    if config_file.exists():
-                        with open(config_file, "r") as f:
-                            config = json.load(f)
-                        num_hidden_layers = config.get("num_hidden_layers", 0)
-                        is_last_shard = pp_end_layer >= num_hidden_layers
-                except Exception as e:
-                    logger.debug(f"Could not determine if last shard: {e}")
-
-                logger.info(
-                    f"Filtering weight files for shard: layers [{pp_start_layer}, {pp_end_layer}), "
-                    f"is_first={is_first_shard}, is_last={is_last_shard}"
-                )
-
-                filtered_files = filter_weight_files_by_layer_range(
-                    model_path=model_path,
-                    weight_files=weight_files,
-                    pp_start_layer=pp_start_layer,
-                    pp_end_layer=pp_end_layer,
-                    is_first_shard=is_first_shard,
-                    is_last_shard=is_last_shard,
-                )
-
-                return filtered_files
-
-            except Exception as e:
-                logger.warning(f"Error in weight file filtering: {e}, using all files")
-                return weight_files
-
-        # Apply the patch
-        weight_utils.get_model_filenames = patched_get_model_filenames
-        logger.info("Applied weight loader filter patch")
-
-    except ImportError as e:
-        logger.warning(f"Could not import SGLang weight_utils, skipping patch: {e}")
-    except Exception as e:
-        logger.error(f"Failed to apply weight loader filter patch: {e}")
+_layer_range_cache = {}
+
+
+def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, is_last_shard: bool):
+    global _layer_range_cache
+    _layer_range_cache["pp_start_layer"] = pp_start_layer
+    _layer_range_cache["pp_end_layer"] = pp_end_layer
+    _layer_range_cache["is_last_shard"] = is_last_shard
+    logger.info(
+        f"Set layer range for weight filtering: [{pp_start_layer}, {pp_end_layer}), "
+        f"is_last={is_last_shard}"
+    )
+
+
+def apply_weight_loader_filter_patch():
+    from sglang.srt.model_loader import weight_utils
+
+    original_safetensors_iterator = weight_utils.safetensors_weights_iterator
+
+    def patched_safetensors_weights_iterator(
+        hf_weights_files: List[str],
+        is_all_weights_sharded: bool = False,
+        decryption_key: Optional[str] = None,
+        disable_mmap: bool = False,
+    ):
+        filtered_files = _filter_weight_files_by_cache(hf_weights_files)
+        return original_safetensors_iterator(
+            filtered_files, is_all_weights_sharded, decryption_key, disable_mmap
+        )
+
+    def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]:
+        global _layer_range_cache
+
+        pp_start_layer = _layer_range_cache.get("pp_start_layer")
+        pp_end_layer = _layer_range_cache.get("pp_end_layer")
+        is_last_shard = _layer_range_cache.get("is_last_shard", False)
+
+        if pp_start_layer is None or pp_end_layer is None:
+            logger.debug("No layer range set, loading all weight files")
+            return hf_weights_files
+
+        if not hf_weights_files:
+            return hf_weights_files
+
+        model_path = Path(hf_weights_files[0]).parent
+        is_first_shard = pp_start_layer == 0
+
+        logger.info(
+            f"Filtering weight files for layers [{pp_start_layer}, {pp_end_layer}), "
+            f"is_first={is_first_shard}, is_last={is_last_shard}"
+        )
+
+        filtered_files = filter_weight_files_by_layer_range(
+            model_path=model_path,
+            weight_files=hf_weights_files,
+            pp_start_layer=pp_start_layer,
+            pp_end_layer=pp_end_layer,
+            is_first_shard=is_first_shard,
+            is_last_shard=is_last_shard,
+        )
+
+        return filtered_files
+
+    weight_utils.safetensors_weights_iterator = patched_safetensors_weights_iterator
+    logger.debug("Applied weight loader filter patch to safetensors and pt weight iterators")
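Note: the cache introduced above splits the work between patch time and load time. A minimal sketch of the intended call order — the functions are the ones defined in this patch, the layer numbers are illustrative:

    from parallax.sglang.monkey_patch_utils.weight_loader_filter import (
        apply_weight_loader_filter_patch,
        set_layer_range_for_filtering,
    )

    apply_weight_loader_filter_patch()           # install the patched iterator once
    set_layer_range_for_filtering(8, 16, False)  # record this shard's range [8, 16)
    # Any later call to sglang's safetensors_weights_iterator now only opens the
    # shard files that hold layers 8-15 (plus embeddings/head when applicable).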
From 321fe22841e1d5647b0a631f81ca36e05b7ef3a2 Mon Sep 17 00:00:00 2001
From: yuhao-zh
Date: Mon, 3 Nov 2025 13:26:17 +0000
Subject: [PATCH 07/29] update

---
 src/parallax/sglang/monkey_patch.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py
index ca422438..c7de6f13 100644
--- a/src/parallax/sglang/monkey_patch.py
+++ b/src/parallax/sglang/monkey_patch.py
@@ -25,11 +25,11 @@
 ## Here is some patch func for sglang
 ## Hopefully, when sglang support pipeline parallelism natively, we can remove these patches
 def apply_parallax_sglang_monkey_patch():
+    apply_model_parallel_monkey_patch()
+    apply_weight_loader_filter_patch()
+    apply_triton_backend_init_monkey_patch()
     apply_qwen3_next_monkey_patch()
     apply_qwen3_next_config_monkey_patch()
     apply_gpt_oss_monkey_patch()
     apply_minimax_m2_monkey_patch()
     apply_glm4_moe_monkey_patch()
-    apply_triton_backend_init_monkey_patch()
-    apply_model_parallel_monkey_patch()
-    apply_weight_loader_filter_patch()
From 52db26f3927b69d5b87cf80a6cc4d16eb6143ed5 Mon Sep 17 00:00:00 2001
From: yuhao-zh
Date: Mon, 3 Nov 2025 13:48:16 +0000
Subject: [PATCH 08/29] update gpu load

---
 src/parallax/sglang/monkey_patch.py           |   2 +-
 .../weight_loader_filter.py                   | 129 +++++-----------
 2 files changed, 51 insertions(+), 80 deletions(-)

diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py
index c7de6f13..4b2d45c9 100644
--- a/src/parallax/sglang/monkey_patch.py
+++ b/src/parallax/sglang/monkey_patch.py
@@ -26,8 +26,8 @@
 ## Hopefully, when sglang support pipeline parallelism natively, we can remove these patches
 def apply_parallax_sglang_monkey_patch():
     apply_model_parallel_monkey_patch()
-    apply_weight_loader_filter_patch()
     apply_triton_backend_init_monkey_patch()
+    apply_weight_loader_filter_patch()
     apply_qwen3_next_monkey_patch()
     apply_qwen3_next_config_monkey_patch()
     apply_gpt_oss_monkey_patch()
diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
index b7d3dbfd..c4d36cfa 100644
--- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
+++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
@@ -1,15 +1,3 @@
-"""
-Monkey patch for SGLang/vLLM weight loader to filter safetensors files based on layer range.
-
-This reduces I/O and memory usage by only loading weight files that contain layers
-in the [pp_start_layer, pp_end_layer) range.
-
-Usage:
-    1. Call apply_weight_loader_filter_patch() during initialization
-    2. Call set_layer_range_for_filtering() before loading model weights
-    3. Model weights will be automatically filtered based on the layer range
-"""
-
 import json
 import logging
 from pathlib import Path
@@ -26,11 +14,6 @@ def filter_weight_files_by_layer_range(
     is_first_shard: bool,
     is_last_shard: bool,
 ) -> List[str]:
-    """Filter weight files to only those containing layers in the specified range.
-
-    Supports both safetensors (.safetensors) and PyTorch (.bin/.pt) formats.
-    """
-    # Try safetensors index first
     index_file = model_path / "model.safetensors.index.json"

     if not index_file.exists():
@@ -53,12 +36,10 @@ def filter_weight_files_by_layer_range(

         if is_first_shard and "embed_tokens" in key:
             should_include = True
-            logger.debug(f"Including {filename} for embedding layer: {key}")

         if is_last_shard:
             if "model.norm" in key or "lm_head" in key:
                 should_include = True
-                logger.debug(f"Including {filename} for lm_head/norm: {key}")

         if "layers." in key:
             try:
@@ -68,10 +49,9 @@ def filter_weight_files_by_layer_range(
                         layer_idx = int(parts[i + 1])
                         if pp_start_layer <= layer_idx < pp_end_layer:
                             should_include = True
                             break
                 except (ValueError, IndexError):
-                    logger.debug(f"Could not parse layer number from {key}, including to be safe")
+                    # Could not parse layer number, include to be safe
                     should_include = True

         if should_include:
@@ -80,11 +60,6 @@ def filter_weight_files_by_layer_range(

         if needed_files:
             filtered_files = [wf for wf in weight_files if wf in needed_files]
-            logger.info(
-                f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} "
-                f"for layers [{pp_start_layer}, {pp_end_layer})"
-            )
-            logger.debug(f"Needed files: {[Path(f).name for f in filtered_files]}")
             return filtered_files
         else:
             logger.warning(
@@ -106,60 +81,56 @@ def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, is_last_shard: bool):
     _layer_range_cache["pp_start_layer"] = pp_start_layer
     _layer_range_cache["pp_end_layer"] = pp_end_layer
     _layer_range_cache["is_last_shard"] = is_last_shard
-    logger.info(
-        f"Set layer range for weight filtering: [{pp_start_layer}, {pp_end_layer}), "
-        f"is_last={is_last_shard}"
-    )
+
+
+def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]:
+    global _layer_range_cache
+
+    pp_start_layer = _layer_range_cache.get("pp_start_layer")
+    pp_end_layer = _layer_range_cache.get("pp_end_layer")
+    is_last_shard = _layer_range_cache.get("is_last_shard", False)
+
+    if pp_start_layer is None or pp_end_layer is None:
+        logger.debug("No layer range set, loading all weight files")
+        return hf_weights_files
+
+    if not hf_weights_files:
+        return hf_weights_files
+
+    model_path = Path(hf_weights_files[0]).parent
+    is_first_shard = pp_start_layer == 0
+
+    filtered_files = filter_weight_files_by_layer_range(
+        model_path=model_path,
+        weight_files=hf_weights_files,
+        pp_start_layer=pp_start_layer,
+        pp_end_layer=pp_end_layer,
+        is_first_shard=is_first_shard,
+        is_last_shard=is_last_shard,
+    )
+
+    return filtered_files


 def apply_weight_loader_filter_patch():
-    from sglang.srt.model_loader import weight_utils
-
-    original_safetensors_iterator = weight_utils.safetensors_weights_iterator
-
-    def patched_safetensors_weights_iterator(
-        hf_weights_files: List[str],
-        is_all_weights_sharded: bool = False,
-        decryption_key: Optional[str] = None,
-        disable_mmap: bool = False,
-    ):
-        filtered_files = _filter_weight_files_by_cache(hf_weights_files)
-        return original_safetensors_iterator(
-            filtered_files, is_all_weights_sharded, decryption_key, disable_mmap
-        )
-
-    def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]:
-        global _layer_range_cache
-
-        pp_start_layer = _layer_range_cache.get("pp_start_layer")
-        pp_end_layer = _layer_range_cache.get("pp_end_layer")
-        is_last_shard = _layer_range_cache.get("is_last_shard", False)
-
-        if pp_start_layer is None or pp_end_layer is None:
-            logger.debug("No layer range set, loading all weight files")
-            return hf_weights_files
-
-        if not hf_weights_files:
-            return hf_weights_files
-
-        model_path = Path(hf_weights_files[0]).parent
-        is_first_shard = pp_start_layer == 0
-
-        logger.info(
-            f"Filtering weight files for layers [{pp_start_layer}, {pp_end_layer}), "
-            f"is_first={is_first_shard}, is_last={is_last_shard}"
-        )
-
-        filtered_files = filter_weight_files_by_layer_range(
-            model_path=model_path,
-            weight_files=hf_weights_files,
-            pp_start_layer=pp_start_layer,
-            pp_end_layer=pp_end_layer,
-            is_first_shard=is_first_shard,
-            is_last_shard=is_last_shard,
-        )
-
-        return filtered_files
-
-    weight_utils.safetensors_weights_iterator = patched_safetensors_weights_iterator
-    logger.debug("Applied weight loader filter patch to safetensors and pt weight iterators")
+    import glob as glob_module
+
+    original_glob = glob_module.glob
+
+    def patched_glob(pathname, **kwargs):
+        files = original_glob(pathname, **kwargs)
+        if (
+            isinstance(files, list)
+            and files
+            and any(f.endswith((".safetensors", ".bin", ".pt")) for f in files)
+        ):
+
+            # Filter if we have layer range set
+            global _layer_range_cache
+            if _layer_range_cache.get("pp_start_layer") is not None:
+                filtered = _filter_weight_files_by_cache(files)
+                return filtered
+
+        return files
+
+    glob_module.glob = patched_glob
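Note: patching glob.glob is deliberately blunt — every caller in the process sees the filtered view, which is what lets the filter work without knowing sglang's loader internals. A standalone sketch of the interception pattern; the predicate here is a stand-in for the layer filter, and the path is illustrative:

    import glob

    _original_glob = glob.glob

    def _patched_glob(pathname, **kwargs):
        files = _original_glob(pathname, **kwargs)
        # Drop weight shards the current shard does not need; pass everything
        # else through untouched.
        if isinstance(files, list) and any(f.endswith(".safetensors") for f in files):
            return [f for f in files if "-00001-" in f]  # stand-in predicate
        return files

    glob.glob = _patched_glob
    # From here on, glob.glob("ckpt/model-*.safetensors") returns only kept shards.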
From aec66ff0de1144dd947fffa5f501683e62d3e514 Mon Sep 17 00:00:00 2001
From: yuhao-zh
Date: Mon, 3 Nov 2025 14:02:56 +0000
Subject: [PATCH 09/29] update

---
 src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
index c4d36cfa..baefb9a9 100644
--- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
+++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
@@ -32,6 +32,8 @@ def filter_weight_files_by_layer_range(
     needed_files: Set[str] = set()

     for key, filename in weight_map.items():
+        if filename in needed_files:
+            continue
         should_include = False

         if is_first_shard and "embed_tokens" in key:
From fb4447f33288b8e40d209f9dade061fb7a10fb4d Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Tue, 4 Nov 2025 09:47:42 +0800
Subject: [PATCH 10/29] update

---
 src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
index baefb9a9..5cd4e6e2 100644
--- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
+++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
@@ -1,7 +1,7 @@
 import json
 import logging
 from pathlib import Path
-from typing import List, Optional, Set
+from typing import List, Set

 logger = logging.getLogger(__name__)
From 131f59b71b5c1ab766f7e42a9b8df32c6a83a961 Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Tue, 4 Nov 2025 09:54:02 +0800
Subject: [PATCH 11/29] update

---
 src/parallax/server/shard_loader.py | 23 +----------------------
 1 file changed, 1 insertion(+), 22 deletions(-)

diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py
index 63217748..e0c12fa5 100644
--- a/src/parallax/server/shard_loader.py
+++ b/src/parallax/server/shard_loader.py
@@ -204,17 +204,11 @@ def load(
         shard_weights = {}
         layer_key_prefix = "model.layers"  # Common prefix

-        logger.debug(
-            f"Loading shard layers [{current_start_layer}, {current_end_layer}) from {len(weight_files)} weight files"
-        )
-        loaded_keys_count = 0
-
         for file_idx, wf in enumerate(weight_files):
             logger.debug(
                 f"Scanning weight file {file_idx + 1}/{len(weight_files)}: {pathlib.Path(wf).name}"
             )
-            file_loaded_count = 0
-            # For bf16 models, we need torch tensors as a bridge
+
             with safetensors.safe_open(wf, framework="pt") as f:
                 for key in f.keys():
                     is_needed = False
@@ -261,13 +255,6 @@ def load(
                     # If the key is needed, load only that tensor from the file
                     if is_needed:
                         shard_weights[remapped_key] = mx.array(f.get_tensor(key))
-                        loaded_keys_count += 1
-                        file_loaded_count += 1
-
-            if file_loaded_count > 0:
-                logger.debug(f"  Loaded {file_loaded_count} tensors from {pathlib.Path(wf).name}")
-            else:
-                logger.debug(f"  Skipped {pathlib.Path(wf).name} (no relevant layers)")

         if (quantization := config.get("quantization", None)) is not None:
             logger.info("Model is quantized. Applying quantization parameters...")
@@ -287,10 +274,6 @@ def class_predicate(p, m):
                 prefixed = f"model.{p}"
                 if prefixed in qcfg:
                     override = qcfg[prefixed]
-                    # if isinstance(override, dict):
-                    #     logger.debug(
-                    #         f"[quantize] Using override for '{prefixed}' (mapped to '{p}'): bits={override.get('bits')} group_size={override.get('group_size')}"
-                    #     )
                     return override
                 if not hasattr(m, "to_quantized"):
                     return False
@@ -305,10 +288,6 @@ def class_predicate(p, m):
                 class_predicate=class_predicate,
             )

-            logger.debug(
-                f"Loaded {loaded_keys_count} weight tensors for shard layers [{current_start_layer}, {current_end_layer})"
-            )
-
             model_shard.load_weights(list(shard_weights.items()), strict=strict)

             if not lazy:
From 614facbbc96b2059ef246c0816d982263657116c Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Tue, 4 Nov 2025 14:14:25 +0800
Subject: [PATCH 12/29] update

---
 src/parallax/server/shard_loader.py           |  72 ++++----
 src/parallax/sglang/model_runner.py           |  14 +-
 .../weight_loader_filter.py                   |  73 ++------
 src/parallax/utils/selective_download.py      | 103 ++++++++++++
 src/parallax/utils/weight_filter_utils.py     | 157 ++++++++++++++++++
 5 files changed, 317 insertions(+), 102 deletions(-)
 create mode 100644 src/parallax/utils/selective_download.py
 create mode 100644 src/parallax/utils/weight_filter_utils.py

diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py
index e0c12fa5..6edc125b 100644
--- a/src/parallax/server/shard_loader.py
+++ b/src/parallax/server/shard_loader.py
@@ -85,7 +85,7 @@ def register_block_class(self):
                 logger.warning(f"Failed to load model from {model_file}: {e}")

     def load(
-        self, lazy: bool = False, strict: bool = True
+        self, lazy: bool = False, strict: bool = True, use_selective_download: bool = True
     ) -> Tuple[nn.Module, Dict[str, Any], Any]:
         """
         Loads the specified model shard by loading only the necessary weights
@@ -96,10 +96,25 @@ def load(
                 into memory. Defaults to False.
             strict (bool): If True, raises an exception if weights do not match.
                 Defaults to True.
+            use_selective_download (bool): If True, only download necessary weight files
+                from Hugging Face. Defaults to True.

         Returns:
             A tuple containing the loaded sharded MLX model and its configuration dictionary.
         """
-        model_path = get_model_path(self.model_path_str)[0]
+        if use_selective_download and self.start_layer is not None and self.end_layer is not None:
+            from parallax.utils.selective_download import get_model_path_with_selective_download
+
+            logger.info(
+                f"Using selective download for layers [{self.start_layer}, {self.end_layer})"
+            )
+            model_path = get_model_path_with_selective_download(
+                self.model_path_str,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+            )
+        else:
+            model_path = get_model_path(self.model_path_str)[0]
+
         config = load_config(model_path)
         tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
@@ -157,44 +172,21 @@ def load(
         if not weight_files:
             weight_files = glob.glob(str(model_path / "weight*.safetensors"))

-        index_file = model_path / "model.safetensors.index.json"
-        if index_file.exists():
-            import json
-
-            with open(index_file, "r") as f:
-                index_data = json.load(f)
-
-            weight_map = index_data.get("weight_map", {})
-            needed_files = set()
-
-            for key, filename in weight_map.items():
-                if filename in needed_files:
-                    continue
-                should_include = False
-                if model_shard.is_first_shard and "embed_tokens" in key:
-                    should_include = True
-                elif model_shard.is_last_shard:
-                    if "model.norm" in key or "lm_head" in key:
-                        should_include = True
-                    elif config.get("tie_word_embeddings", False) and "embed_tokens" in key:
-                        should_include = True
-
-                if "model.layers." in key:
-                    parts = key.split(".")
-                    layer_idx = int(parts[2])
-                    if current_start_layer <= layer_idx < current_end_layer:
-                        should_include = True
-
-                if should_include:
-                    full_path = str(model_path / filename)
-                    needed_files.add(full_path)
-            if needed_files:
-                weight_files = [wf for wf in weight_files if wf in needed_files]
-            else:
-                logger.warning("No relevant weight files found in index, will scan all files")
-
-        else:
-            logger.debug("No index file found, will scan all weight files")
+        # Sort weight files by name for consistent loading order
+        weight_files = sorted(weight_files)
+
+        # Use shared utility to filter weight files
+        from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range
+
+        weight_files = filter_weight_files_by_layer_range(
+            model_path=model_path,
+            weight_files=weight_files,
+            start_layer=current_start_layer,
+            end_layer=current_end_layer,
+            is_first_shard=model_shard.is_first_shard,
+            is_last_shard=model_shard.is_last_shard,
+            config=config,
+        )

         if not weight_files and strict:
             raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -249,7 +241,7 @@ def load(
                         shard_weights[remapped_key] = mx.array(f.get_tensor(key))

         if (quantization := config.get("quantization", None)) is not None:
-            logger.info("Model is quantized. Applying quantization parameters...")
+            logger.debug("Model is quantized. Applying quantization parameters...")

             def class_predicate(p, m):
                 # Handle custom per-layer quantizations from the config
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index d629d01c..775a9e34 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -240,7 +240,19 @@ def initialize_sgl_model_runner(
     - tokenizer: tokenizer driven by mlx-lm
     """
     apply_parallax_sglang_monkey_patch()
-    model_path = get_model_path(original_model_path)[0]
+
+    # Use selective download for GPU models to save bandwidth and disk space
+    from parallax.utils.selective_download import get_model_path_with_selective_download
+
+    logger.info(
+        f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})"
+    )
+    model_path = get_model_path_with_selective_download(
+        original_model_path,
+        start_layer=start_layer,
+        end_layer=end_layer,
+    )
+
     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")
diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
index 5cd4e6e2..45982ad8 100644
--- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
+++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
@@ -1,7 +1,8 @@
-import json
 import logging
 from pathlib import Path
-from typing import List, Set
+from typing import List
+
+from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range as shared_filter

 logger = logging.getLogger(__name__)
@@ -14,65 +15,15 @@ def filter_weight_files_by_layer_range(
     is_first_shard: bool,
     is_last_shard: bool,
 ) -> List[str]:
-    index_file = model_path / "model.safetensors.index.json"
-
-    if not index_file.exists():
-        logger.debug(f"No index file found at {model_path}, will load all weight files")
-        return weight_files
-
-    try:
-        with open(index_file, "r") as f:
-            index_data = json.load(f)
-
-        weight_map = index_data.get("weight_map", {})
-        if not weight_map:
-            logger.warning("weight_map is empty in index file, will load all weight files")
-            return weight_files
-
-        needed_files: Set[str] = set()
-
-        for key, filename in weight_map.items():
-            if filename in needed_files:
-                continue
-            should_include = False
-
-            if is_first_shard and "embed_tokens" in key:
-                should_include = True
-
-            if is_last_shard:
-                if "model.norm" in key or "lm_head" in key:
-                    should_include = True
-
-            if "layers." in key:
-                try:
-                    parts = key.split(".")
-                    for i, part in enumerate(parts):
-                        if part == "layers" and i + 1 < len(parts):
-                            layer_idx = int(parts[i + 1])
-                            if pp_start_layer <= layer_idx < pp_end_layer:
-                                should_include = True
-                            break
-                except (ValueError, IndexError):
-                    # Could not parse layer number, include to be safe
-                    should_include = True
-
-            if should_include:
-                full_path = str(model_path / filename)
-                needed_files.add(full_path)
-
-        if needed_files:
-            filtered_files = [wf for wf in weight_files if wf in needed_files]
-            return filtered_files
-        else:
-            logger.warning(
-                f"No relevant weight files found in index for layers [{pp_start_layer}, {pp_end_layer}), "
-                "will load all files"
-            )
-            return weight_files
-
-    except Exception as e:
-        logger.warning(f"Failed to filter weight files using index file: {e}, will load all files")
-        return weight_files
+    return shared_filter(
+        model_path=model_path,
+        weight_files=weight_files,
+        start_layer=pp_start_layer,
+        end_layer=pp_end_layer,
+        is_first_shard=is_first_shard,
+        is_last_shard=is_last_shard,
+        config={},
+    )
diff --git a/src/parallax/utils/selective_download.py b/src/parallax/utils/selective_download.py
new file mode 100644
index 00000000..724f95e9
--- /dev/null
+++ b/src/parallax/utils/selective_download.py
@@ -0,0 +1,103 @@
+import logging
+from pathlib import Path
+from typing import Optional
+
+from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub.utils import HfHubHTTPError
+
+logger = logging.getLogger(__name__)
+
+
+def determine_needed_weight_files(model_path: Path, start_layer: int, end_layer: int):
+    from parallax.utils.weight_filter_utils import (
+        determine_needed_weight_files as determine_files,
+    )
+
+    return determine_files(model_path, start_layer, end_layer)
+
+
+def selective_model_download(
+    repo_id: str,
+    start_layer: Optional[int] = None,
+    end_layer: Optional[int] = None,
+    cache_dir: Optional[str] = None,
+    force_download: bool = False,
+) -> Path:
+    logger.debug(f"Downloading model metadata for {repo_id}")
+
+    ignore_patterns = [
+        "*.safetensors",
+        "*.bin",
+        "*.pt",
+        "*.pth",
+        "pytorch_model*.bin",
+        "model*.safetensors",
+    ]
+
+    model_path = snapshot_download(
+        repo_id=repo_id,
+        cache_dir=cache_dir,
+        ignore_patterns=ignore_patterns,
+        force_download=force_download,
+    )
+    model_path = Path(model_path)
+    logger.debug(f"Downloaded model metadata to {model_path}")
+
+    if start_layer is not None and end_layer is not None:
+        logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})")
+
+        needed_weight_files = determine_needed_weight_files(
+            model_path=model_path,
+            start_layer=start_layer,
+            end_layer=end_layer,
+        )
+
+        if not needed_weight_files:
+            logger.debug("Could not determine specific weight files, downloading all")
+            snapshot_download(
+                repo_id=repo_id,
+                cache_dir=cache_dir,
+                force_download=force_download,
+            )
+        else:
+            # Step 3: Download only the needed weight files
+            logger.info(f"Downloading {len(needed_weight_files)} weight files")
+
+            for weight_file in needed_weight_files:
+                logger.debug(f"Downloading {weight_file}")
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=weight_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                )
+
+            logger.debug(f"Downloaded weight files for layers [{start_layer}, {end_layer})")
+    else:
+        logger.debug("No layer range specified, downloading all model files")
+        snapshot_download(
+            repo_id=repo_id,
+            cache_dir=cache_dir,
+            force_download=force_download,
+        )
+
+    return model_path
+
+
+def get_model_path_with_selective_download(
+    model_path_or_repo: str,
+    start_layer: Optional[int] = None,
+    end_layer: Optional[int] = None,
+) -> Path:
+    path = Path(model_path_or_repo)
+
+    if path.exists():
+        logger.debug(f"Using local model path: {path}")
+        return path
+
+    logger.debug(f"Treating as Hugging Face repo: {model_path_or_repo}")
+    return selective_model_download(
+        repo_id=model_path_or_repo,
+        start_layer=start_layer,
+        end_layer=end_layer,
+    )
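Note: a sketch of how the helper above is meant to be driven — the repo id and layer range are illustrative, and a local path short-circuits the download entirely:

    from parallax.utils.selective_download import get_model_path_with_selective_download

    # Fetch metadata plus only the shards covering layers [0, 16).
    model_path = get_model_path_with_selective_download(
        "org/some-model",  # illustrative Hugging Face repo id
        start_layer=0,
        end_layer=16,
    )
    print(model_path)  # snapshot dir with config, tokenizer, and the needed shards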
diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py
new file mode 100644
index 00000000..5a5a0f37
--- /dev/null
+++ b/src/parallax/utils/weight_filter_utils.py
@@ -0,0 +1,157 @@
+import json
+import logging
+from pathlib import Path
+from typing import Dict, List, Optional, Set
+
+logger = logging.getLogger(__name__)
+
+
+def should_include_weight_key(
+    key: str,
+    start_layer: int,
+    end_layer: int,
+    is_first_shard: bool,
+    is_last_shard: bool,
+    tie_word_embeddings: bool = False,
+) -> bool:
+    if is_first_shard and "embed_tokens" in key and key.startswith("model."):
+        return True
+
+    if is_last_shard:
+        if "model.norm" in key or "lm_head" in key:
+            return True
+        if tie_word_embeddings and "embed" in key and key.startswith("model.embed_tokens"):
+            return True
+
+    if "layers." in key:
+        parts = key.split(".")
+        for i, part in enumerate(parts):
+            if part == "layers" and i + 1 < len(parts):
+                layer_idx = int(parts[i + 1])
+                return start_layer <= layer_idx < end_layer
+
+    return False
+
+
+def filter_weight_files_by_layer_range(
+    model_path: Path,
+    weight_files: List[str],
+    start_layer: int,
+    end_layer: int,
+    is_first_shard: bool,
+    is_last_shard: bool,
+    config: Optional[Dict] = None,
+) -> List[str]:
+    index_file = model_path / "model.safetensors.index.json"
+
+    if not index_file.exists():
+        logger.debug(f"No index file found at {index_file}, cannot filter weight files")
+        return weight_files
+
+    with open(index_file, "r") as f:
+        index_data = json.load(f)
+
+    weight_map = index_data.get("weight_map", {})
+    if not weight_map:
+        logger.debug("weight_map is empty in index file")
+        return weight_files
+
+    tie_word_embeddings = False
+    if config:
+        tie_word_embeddings = config.get("tie_word_embeddings", False)
+    else:
+        config_file = model_path / "config.json"
+        if config_file.exists():
+            with open(config_file, "r") as f:
+                cfg = json.load(f)
+            tie_word_embeddings = cfg.get("tie_word_embeddings", False)
+
+    needed_files: Set[str] = set()
+
+    for key, filename in weight_map.items():
+        if should_include_weight_key(
+            key=key,
+            start_layer=start_layer,
+            end_layer=end_layer,
+            is_first_shard=is_first_shard,
+            is_last_shard=is_last_shard,
+            tie_word_embeddings=tie_word_embeddings,
+        ):
+            needed_files.add(filename)
+
+    if not needed_files:
+        logger.debug(
+            f"No relevant weight files found in index for layers [{start_layer}, {end_layer})"
+        )
+        return weight_files
+
+    filtered_files = []
+    for wf in weight_files:
+        wf_name = Path(wf).name
+        if wf_name in needed_files:
+            filtered_files.append(wf)
+
+    logger.debug(
+        f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} "
+        f"for layers [{start_layer}, {end_layer})"
+    )
+
+    return filtered_files
return model_path + + +def get_model_path_with_selective_download( + model_path_or_repo: str, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, +) -> Path: + path = Path(model_path_or_repo) + + if path.exists(): + logger.debug(f"Using local model path: {path}") + return path + + logger.debug(f"Treating as Hugging Face repo: {model_path_or_repo}") + return selective_model_download( + repo_id=model_path_or_repo, + start_layer=start_layer, + end_layer=end_layer, + ) diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py new file mode 100644 index 00000000..5a5a0f37 --- /dev/null +++ b/src/parallax/utils/weight_filter_utils.py @@ -0,0 +1,157 @@ +import json +import logging +from pathlib import Path +from typing import Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + + +def should_include_weight_key( + key: str, + start_layer: int, + end_layer: int, + is_first_shard: bool, + is_last_shard: bool, + tie_word_embeddings: bool = False, +) -> bool: + if is_first_shard and "embed_tokens" in key and key.startswith("model."): + return True + + if is_last_shard: + if "model.norm" in key or "lm_head" in key: + return True + if tie_word_embeddings and "embed" in key and key.startswith("model.embed_tokens"): + return True + + if "layers." in key: + parts = key.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts): + layer_idx = int(parts[i + 1]) + return start_layer <= layer_idx < end_layer + + return False + + +def filter_weight_files_by_layer_range( + model_path: Path, + weight_files: List[str], + start_layer: int, + end_layer: int, + is_first_shard: bool, + is_last_shard: bool, + config: Optional[Dict] = None, +) -> List[str]: + index_file = model_path / "model.safetensors.index.json" + + if not index_file.exists(): + logger.debug(f"No index file found at {index_file}, cannot filter weight files") + return weight_files + + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + if not weight_map: + logger.debug("weight_map is empty in index file") + return weight_files + + tie_word_embeddings = False + if config: + tie_word_embeddings = config.get("tie_word_embeddings", False) + else: + config_file = model_path / "config.json" + if config_file.exists(): + with open(config_file, "r") as f: + cfg = json.load(f) + tie_word_embeddings = cfg.get("tie_word_embeddings", False) + + needed_files: Set[str] = set() + + for key, filename in weight_map.items(): + if should_include_weight_key( + key=key, + start_layer=start_layer, + end_layer=end_layer, + is_first_shard=is_first_shard, + is_last_shard=is_last_shard, + tie_word_embeddings=tie_word_embeddings, + ): + needed_files.add(filename) + + if not needed_files: + logger.debug( + f"No relevant weight files found in index for layers [{start_layer}, {end_layer})" + ) + return weight_files + + filtered_files = [] + for wf in weight_files: + wf_name = Path(wf).name + if wf_name in needed_files: + filtered_files.append(wf) + + logger.debug( + f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} " + f"for layers [{start_layer}, {end_layer})" + ) + + return filtered_files + + +def determine_needed_weight_files( + model_path: Path, + start_layer: int, + end_layer: int, + config: Optional[Dict] = None, +) -> List[str]: + is_first_shard = start_layer == 0 + + is_last_shard = False + if config: + num_hidden_layers = config.get("num_hidden_layers", 0) + is_last_shard = end_layer 
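# The key-routing rule above can be sanity-checked with representative tensor
# names; a minimal sketch assuming a 24-layer model split at layer 12 (key
# names follow the model.layers.<idx>. convention used in safetensors indexes).
from parallax.utils.weight_filter_utils import should_include_weight_key

# First shard, layers [0, 12): input embedding plus its own decoder layers.
assert should_include_weight_key("model.embed_tokens.weight", 0, 12, True, False)
assert should_include_weight_key("model.layers.11.self_attn.q_proj.weight", 0, 12, True, False)
# The end bound is exclusive, so layer 12 belongs to the next shard.
assert not should_include_weight_key("model.layers.12.self_attn.q_proj.weight", 0, 12, True, False)
# Last shard, layers [12, 24): the final norm and lm_head ride along.
assert should_include_weight_key("model.norm.weight", 12, 24, False, True)
assert should_include_weight_key("lm_head.weight", 12, 24, False, True)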
>= num_hidden_layers + else: + config_file = model_path / "config.json" + if config_file.exists(): + with open(config_file, "r") as f: + cfg = json.load(f) + num_hidden_layers = cfg.get("num_hidden_layers", 0) + is_last_shard = end_layer >= num_hidden_layers + + index_file = model_path / "model.safetensors.index.json" + + if not index_file.exists(): + logger.debug(f"Index file not found at {index_file}") + return [] + + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + if not weight_map: + logger.debug("weight_map is empty in index file") + return [] + + tie_word_embeddings = False + if config: + tie_word_embeddings = config.get("tie_word_embeddings", False) + + needed_files: Set[str] = set() + + for key, filename in weight_map.items(): + if should_include_weight_key( + key=key, + start_layer=start_layer, + end_layer=end_layer, + is_first_shard=is_first_shard, + is_last_shard=is_last_shard, + tie_word_embeddings=tie_word_embeddings, + ): + needed_files.add(filename) + + result = sorted(list(needed_files)) + logger.debug( + f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})" + ) + return result From 2521a39b064f8935e98e90d86a9cca3afc79ccc1 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 14:48:13 +0800 Subject: [PATCH 13/29] update --- src/parallax/launch.py | 6 +++--- src/parallax/p2p/server.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index d12c8948..611ecf04 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -84,9 +84,9 @@ initial_peers=args.initial_peers, scheduler_addr=args.scheduler_addr, relay_servers=args.relay_servers, - pp_start_layer=None, - pp_end_layer=None, - hidden_layers=None, + pp_start_layer=args.start_layer, + pp_end_layer=args.end_layer, + hidden_layers=executor.config.get("num_hidden_layers"), tcp_port=args.tcp_port, udp_port=args.udp_port, dht_prefix=args.dht_prefix, diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py index 073d7c6d..dc5e5ade 100644 --- a/src/parallax/p2p/server.py +++ b/src/parallax/p2p/server.py @@ -248,6 +248,7 @@ def __init__( self.announcer = None self.connection_handler = None self.stop_event = threading.Event() + logger.debug(f"manual_layer_assignment: {self.manual_layer_assignment}") def build_lattica(self): self.lattica = Lattica.builder().with_listen_addrs(self.host_maddrs) From 9472cc1db9f140c3ab40c1deecb1553c640941ab Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 14:50:05 +0800 Subject: [PATCH 14/29] update --- src/parallax/launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index 611ecf04..e351549b 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -86,7 +86,7 @@ relay_servers=args.relay_servers, pp_start_layer=args.start_layer, pp_end_layer=args.end_layer, - hidden_layers=executor.config.get("num_hidden_layers"), + hidden_layers=None, tcp_port=args.tcp_port, udp_port=args.udp_port, dht_prefix=args.dht_prefix, From 3c36dc188e14383688b19b7131304d94ba919180 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 15:14:35 +0800 Subject: [PATCH 15/29] update --- src/parallax/p2p/server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py index dc5e5ade..34e8ded7 100644 --- a/src/parallax/p2p/server.py +++ b/src/parallax/p2p/server.py @@ 
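# Patch 13 forwards an optional CLI layer range into the p2p server, and
# patch 14 leaves hidden_layers to be discovered from the model config later.
# A minimal argparse sketch of the assumed flags (names hypothetical, chosen
# to match args.start_layer / args.end_layer in the diff):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--start-layer", dest="start_layer", type=int, default=None)
parser.add_argument("--end-layer", dest="end_layer", type=int, default=None)
args = parser.parse_args(["--start-layer", "12", "--end-layer", "24"])

# Both bounds present => the node will request a pinned (manual) shard.
assert args.start_layer is not None and args.end_layer is not None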
-751,7 +751,10 @@ def launch_p2p_server( thread = threading.Thread(target=server.run, daemon=True) thread.start() - while server.block_start_index is None: + # Wait for layer allocation and model_name to be set + while server.block_start_index is None or ( + scheduler_addr is not None and server.model_name is None + ): time.sleep(1) return server From be8828872227f50b7192351bd0ce6b5984ae1d83 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 15:41:25 +0800 Subject: [PATCH 16/29] pre-commit --- src/parallax/server/shard_loader.py | 8 ++++++-- src/parallax/sglang/model_runner.py | 2 +- .../sglang/monkey_patch_utils/weight_loader_filter.py | 4 +++- src/parallax/utils/selective_download.py | 1 - 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 6edc125b..e53a308f 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -102,7 +102,9 @@ def load( A tuple containing the loaded sharded MLX model and its configuration dictionary. """ if use_selective_download and self.start_layer is not None and self.end_layer is not None: - from parallax.utils.selective_download import get_model_path_with_selective_download + from parallax.utils.selective_download import ( + get_model_path_with_selective_download, + ) logger.info( f"Using selective download for layers [{self.start_layer}, {self.end_layer})" @@ -176,7 +178,9 @@ def load( weight_files = sorted(weight_files) # Use shared utility to filter weight files - from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range + from parallax.utils.weight_filter_utils import ( + filter_weight_files_by_layer_range, + ) weight_files = filter_weight_files_by_layer_range( model_path=model_path, diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 775a9e34..8011e0f6 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -11,7 +11,7 @@ import sglang import sglang.srt.distributed.parallel_state import torch -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import load_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import ( get_tp_group, diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 45982ad8..cc0cb2e8 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -2,7 +2,9 @@ from pathlib import Path from typing import List -from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range as shared_filter +from parallax.utils.weight_filter_utils import ( + filter_weight_files_by_layer_range as shared_filter, +) logger = logging.getLogger(__name__) diff --git a/src/parallax/utils/selective_download.py b/src/parallax/utils/selective_download.py index 724f95e9..1686646d 100644 --- a/src/parallax/utils/selective_download.py +++ b/src/parallax/utils/selective_download.py @@ -3,7 +3,6 @@ from typing import Optional from huggingface_hub import hf_hub_download, snapshot_download -from huggingface_hub.utils import HfHubHTTPError logger = logging.getLogger(__name__) From cc9880aac5108ae11b52e1c0f6f76cd1afd98902 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 20:08:50 +0800 Subject: [PATCH 17/29] update --- src/scheduling/scheduler.py | 46 
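The readiness poll added in patch 15 blocks until the scheduler has assigned both a layer range and a model name. A timeout-guarded variant of the same pattern, sketched here only to make the loop's exit condition explicit (the in-tree loop deliberately waits forever):

import time


def wait_until(predicate, timeout_s=300.0, interval_s=1.0):
    """Poll predicate until it returns True or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while not predicate():
        if time.monotonic() >= deadline:
            raise TimeoutError("no layer allocation received from scheduler")
        time.sleep(interval_s)


# wait_until(lambda: server.block_start_index is not None
#            and (scheduler_addr is None or server.model_name is not None))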
++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py index ef120a32..30c70763 100644 --- a/src/scheduling/scheduler.py +++ b/src/scheduling/scheduler.py @@ -302,14 +302,44 @@ def leave(self, node_id: str) -> None: self.layer_allocator.leave(node_id) if self.layer_allocator.should_global_rebalance(): logger.debug("Global rebalance triggered due to node leave") - # TODO: send a signal to the nodes to stop running requests - # and re-assign start/end layers so nodes can re-shard - self._bootstrapped = False - self._bootstrapped_event.clear() - for n in self.nodes: - if n.start_layer is not None and n.end_layer is not None: - self.layer_allocator.deallocate(n) - self.layer_allocator.global_allocation() + + # Check if all remaining nodes are manual + all_manual = all(n.manual_layer_assignment for n in self.nodes) + if all_manual: + logger.debug("All nodes are manual assignment, skipping global rebalance") + else: + # TODO: send a signal to the nodes to stop running requests + # and re-assign start/end layers so nodes can re-shard + self._bootstrapped = False + self._bootstrapped_event.clear() + + # Separate manual and automatic nodes + manual_nodes = [] + for n in self.nodes: + if n.manual_layer_assignment: + logger.debug( + f"Preserving manual node {n.node_id}: [{n.start_layer}, {n.end_layer})" + ) + manual_nodes.append(n) + elif n.start_layer is not None and n.end_layer is not None: + self.layer_allocator.deallocate(n) + + # Temporarily remove manual nodes from allocator + for n in manual_nodes: + if n in self.layer_allocator.nodes: + self.layer_allocator.nodes.remove(n) + + # Reallocate only automatic nodes + self.layer_allocator.global_allocation() + + # Add manual nodes back to allocator + for n in manual_nodes: + if n not in self.layer_allocator.nodes: + self.layer_allocator.nodes.append(n) + # Re-sort after adding back manual nodes + self.layer_allocator.nodes.sort( + key=lambda node: node.get_decoder_layer_capacity(), reverse=True + ) with self._node_count_cv: self._node_count_cv.notify_all() From ea0a3a9c6067cc494f54c38edf004acfb3cc2d75 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 20:33:58 +0800 Subject: [PATCH 18/29] update --- src/parallax/utils/selective_download.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/parallax/utils/selective_download.py b/src/parallax/utils/selective_download.py index 1686646d..0ff5578b 100644 --- a/src/parallax/utils/selective_download.py +++ b/src/parallax/utils/selective_download.py @@ -88,13 +88,6 @@ def get_model_path_with_selective_download( start_layer: Optional[int] = None, end_layer: Optional[int] = None, ) -> Path: - path = Path(model_path_or_repo) - - if path.exists(): - logger.debug(f"Using local model path: {path}") - return path - - logger.debug(f"Treating as Hugging Face repo: {model_path_or_repo}") return selective_model_download( repo_id=model_path_or_repo, start_layer=start_layer, From 1f332870fa944448c69313e00623265d9a57cac7 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 20:40:03 +0800 Subject: [PATCH 19/29] update --- src/scheduling/scheduler.py | 42 ++++++++++++------------------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py index 30c70763..14e9f9da 100644 --- a/src/scheduling/scheduler.py +++ b/src/scheduling/scheduler.py @@ -303,43 +303,27 @@ def leave(self, node_id: str) -> None: if 
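# For reference, the partition patch 17 performs before reallocating: pinned
# (manual) nodes are set aside and only automatic ones are redistributed.
# Sketch with a stand-in node type, since scheduling.node.Node carries far
# more state than shown here.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeNode:
    node_id: str
    manual_layer_assignment: bool = False
    start_layer: Optional[int] = None
    end_layer: Optional[int] = None


nodes = [FakeNode("a", True, 0, 12), FakeNode("b"), FakeNode("c")]
manual = [n for n in nodes if n.manual_layer_assignment]
automatic = [n for n in nodes if not n.manual_layer_assignment]
assert [n.node_id for n in manual] == ["a"]
assert [n.node_id for n in automatic] == ["b", "c"]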
self.layer_allocator.should_global_rebalance(): logger.debug("Global rebalance triggered due to node leave") - # Check if all remaining nodes are manual - all_manual = all(n.manual_layer_assignment for n in self.nodes) - if all_manual: + # Count manual vs automatic nodes + manual_count = sum(1 for n in self.nodes if n.manual_layer_assignment) + total_count = len(self.nodes) + logger.debug( + f"Node count: {manual_count} manual, {total_count - manual_count} automatic" + ) + if manual_count == total_count: logger.debug("All nodes are manual assignment, skipping global rebalance") + elif manual_count > 0: + logger.error( + f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance" + ) else: - # TODO: send a signal to the nodes to stop running requests - # and re-assign start/end layers so nodes can re-shard + # All nodes are automatic, proceed with rebalance self._bootstrapped = False self._bootstrapped_event.clear() - - # Separate manual and automatic nodes - manual_nodes = [] for n in self.nodes: - if n.manual_layer_assignment: - logger.debug( - f"Preserving manual node {n.node_id}: [{n.start_layer}, {n.end_layer})" - ) - manual_nodes.append(n) - elif n.start_layer is not None and n.end_layer is not None: + if n.start_layer is not None and n.end_layer is not None: self.layer_allocator.deallocate(n) - - # Temporarily remove manual nodes from allocator - for n in manual_nodes: - if n in self.layer_allocator.nodes: - self.layer_allocator.nodes.remove(n) - - # Reallocate only automatic nodes self.layer_allocator.global_allocation() - # Add manual nodes back to allocator - for n in manual_nodes: - if n not in self.layer_allocator.nodes: - self.layer_allocator.nodes.append(n) - # Re-sort after adding back manual nodes - self.layer_allocator.nodes.sort( - key=lambda node: node.get_decoder_layer_capacity(), reverse=True - ) with self._node_count_cv: self._node_count_cv.notify_all() From 22edf792640c4eb828535d52721650ac98c105cc Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 20:53:31 +0800 Subject: [PATCH 20/29] update --- src/parallax/sglang/model_runner.py | 12 ++++-------- .../monkey_patch_utils/weight_loader_filter.py | 7 ++++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 8011e0f6..f8e3c316 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -36,6 +36,9 @@ from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch from parallax.utils.tokenizer_utils import load_tokenizer +from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( + set_layer_range_for_filtering, +) logger = logging.getLogger(__name__) @@ -67,15 +70,8 @@ def __init__( """Add pp_start_layer and pp_end_layer for decentralized model""" self.pp_start_layer = pp_start_layer self.pp_end_layer = pp_end_layer - - # Set layer range for weight file filtering before model loading - from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( - set_layer_range_for_filtering, - ) - num_hidden_layers = model_config.hf_config.num_hidden_layers - is_last_shard = pp_end_layer >= num_hidden_layers - set_layer_range_for_filtering(pp_start_layer, pp_end_layer, is_last_shard) + set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers) super().__init__( model_config=model_config, diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py 
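# Patch 19 collapses the preserve-and-restore logic into a three-way decision
# on the manual-node count; the policy table, written out as a pure function:
def rebalance_action(manual_count: int, total_count: int) -> str:
    # Mirrors Scheduler.leave() after this patch.
    if manual_count == total_count:
        return "skip"  # fully manual fleet: every shard is pinned
    if manual_count > 0:
        return "skip-with-error"  # mixed fleets are rejected outright
    return "rebalance"  # fully automatic: deallocate and re-run allocation


assert rebalance_action(3, 3) == "skip"
assert rebalance_action(1, 3) == "skip-with-error"
assert rebalance_action(0, 3) == "rebalance"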
b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index cc0cb2e8..f2333736 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -31,11 +31,11 @@ def filter_weight_files_by_layer_range( _layer_range_cache = {} -def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, is_last_shard: bool): +def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, num_hidden_layers: int): global _layer_range_cache _layer_range_cache["pp_start_layer"] = pp_start_layer _layer_range_cache["pp_end_layer"] = pp_end_layer - _layer_range_cache["is_last_shard"] = is_last_shard + _layer_range_cache["num_hidden_layers"] = num_hidden_layers def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: @@ -43,7 +43,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: pp_start_layer = _layer_range_cache.get("pp_start_layer") pp_end_layer = _layer_range_cache.get("pp_end_layer") - is_last_shard = _layer_range_cache.get("is_last_shard", False) + num_hidden_layers = _layer_range_cache.get("num_hidden_layers") if pp_start_layer is None or pp_end_layer is None: logger.debug("No layer range set, loading all weight files") @@ -54,6 +54,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: model_path = Path(hf_weights_files[0]).parent is_first_shard = pp_start_layer == 0 + is_last_shard = pp_end_layer >= num_hidden_layers filtered_files = filter_weight_files_by_layer_range( model_path=model_path, From ced07a425b66e7f9a0e19a4eb91132d26320b595 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 21:06:45 +0800 Subject: [PATCH 21/29] update --- .../weight_loader_filter.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index f2333736..5775e691 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -54,7 +54,13 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: model_path = Path(hf_weights_files[0]).parent is_first_shard = pp_start_layer == 0 - is_last_shard = pp_end_layer >= num_hidden_layers + is_last_shard = num_hidden_layers is not None and pp_end_layer >= num_hidden_layers + + logger.debug( + f"Filtering weight files: start_layer={pp_start_layer}, end_layer={pp_end_layer}, " + f"is_first_shard={is_first_shard}, is_last_shard={is_last_shard}, " + f"input_files={len(hf_weights_files)}" + ) filtered_files = filter_weight_files_by_layer_range( model_path=model_path, @@ -65,6 +71,9 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: is_last_shard=is_last_shard, ) + logger.debug( + f"Filtered to {len(filtered_files)} files: {[Path(f).name for f in filtered_files]}" + ) return filtered_files @@ -75,17 +84,27 @@ def apply_weight_loader_filter_patch(): def patched_glob(pathname, **kwargs): files = original_glob(pathname, **kwargs) + logger.debug( + f"patched_glob called: pathname={pathname}, num_files={len(files) if isinstance(files, list) else 'N/A'}" + ) + if ( isinstance(files, list) and files and any(f.endswith((".safetensors", ".bin", ".pt")) for f in files) ): - + logger.debug(f"Found weight files, checking layer range cache...") # Filter if we have layer range set global _layer_range_cache if 
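# With patch 20 the filter cache stores num_hidden_layers instead of a
# precomputed is_last_shard, so the shard tests live next to the cached
# values (patch 21 then additionally guards a missing num_hidden_layers).
# The derivation, restated outside the module for clarity:
cache = {"pp_start_layer": 12, "pp_end_layer": 24, "num_hidden_layers": 24}

is_first_shard = cache["pp_start_layer"] == 0
is_last_shard = (
    cache["num_hidden_layers"] is not None
    and cache["pp_end_layer"] >= cache["num_hidden_layers"]
)
assert (is_first_shard, is_last_shard) == (False, True)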
_layer_range_cache.get("pp_start_layer") is not None: + logger.debug( + f"Layer range set: start={_layer_range_cache.get('pp_start_layer')}, end={_layer_range_cache.get('pp_end_layer')}" + ) filtered = _filter_weight_files_by_cache(files) + logger.debug(f"Filtered from {len(files)} to {len(filtered)} weight files") return filtered + else: + logger.debug("Layer range not set, loading all weight files") return files From 08c6643bdecc91611f8c2ee3463b36398ea3fe01 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 21:11:16 +0800 Subject: [PATCH 22/29] update --- .../weight_loader_filter.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 5775e691..edc3b688 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -109,3 +109,94 @@ def patched_glob(pathname, **kwargs): return files glob_module.glob = patched_glob + + # Patch os.listdir + import os + + original_listdir = os.listdir + + def patched_listdir(path): + files = original_listdir(path) + logger.debug(f"patched_listdir called: path={path}, num_files={len(files)}") + + # Convert to full paths for filtering + if any(f.endswith((".safetensors", ".bin", ".pt")) for f in files): + logger.debug(f"Found weight files in listdir") + global _layer_range_cache + if _layer_range_cache.get("pp_start_layer") is not None: + full_paths = [ + os.path.join(path, f) + for f in files + if f.endswith((".safetensors", ".bin", ".pt")) + ] + if full_paths: + filtered_paths = _filter_weight_files_by_cache(full_paths) + filtered_names = [os.path.basename(f) for f in filtered_paths] + # Keep non-weight files + result = [ + f for f in files if not f.endswith((".safetensors", ".bin", ".pt")) + ] + filtered_names + logger.debug(f"Filtered listdir from {len(files)} to {len(result)} files") + return result + + return files + + os.listdir = patched_listdir + + # Patch Path.glob + from pathlib import Path as PathlibPath + + original_path_glob = PathlibPath.glob + + def patched_path_glob(self, pattern): + files = list(original_path_glob(self, pattern)) + logger.debug(f"patched_path_glob called: pattern={pattern}, num_files={len(files)}") + + if files and any(str(f).endswith((".safetensors", ".bin", ".pt")) for f in files): + logger.debug(f"Found weight files in Path.glob") + global _layer_range_cache + if _layer_range_cache.get("pp_start_layer") is not None: + str_files = [ + str(f) for f in files if str(f).endswith((".safetensors", ".bin", ".pt")) + ] + if str_files: + filtered_strs = _filter_weight_files_by_cache(str_files) + filtered_paths = [PathlibPath(f) for f in filtered_strs] + # Keep non-weight files + result = [ + f for f in files if not str(f).endswith((".safetensors", ".bin", ".pt")) + ] + filtered_paths + logger.debug(f"Filtered Path.glob from {len(files)} to {len(result)} files") + return iter(result) + + return iter(files) + + PathlibPath.glob = patched_path_glob + + # Patch safetensors.torch.load_file to intercept actual file loading + try: + import safetensors.torch + + original_load_file = safetensors.torch.load_file + + def patched_load_file(filename, *args, **kwargs): + logger.debug(f"patched_load_file called: filename={filename}") + return original_load_file(filename, *args, **kwargs) + + safetensors.torch.load_file = patched_load_file + except ImportError: + logger.debug("safetensors module not 
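# The hooks above stay installed for the life of the process. The same glob
# filter can also be scoped with a context manager; this is a sketch of an
# alternative shape, not what the series ships:
import glob as glob_module
from contextlib import contextmanager


@contextmanager
def scoped_glob_filter(keep):
    """Temporarily drop weight files from glob.glob results unless keep(f)."""
    original = glob_module.glob

    def patched(pathname, **kwargs):
        files = original(pathname, **kwargs)
        if isinstance(files, list):
            return [
                f
                for f in files
                if not f.endswith((".safetensors", ".bin", ".pt")) or keep(f)
            ]
        return files

    glob_module.glob = patched
    try:
        yield
    finally:
        glob_module.glob = original  # always restore the real glob


# with scoped_glob_filter(lambda f: "00001" in f):
#     ...  # load the model; only matching shard files are visible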
available for patching")
+
+    # Patch json.load to intercept index file reading
+    import json
+    import builtins
+
+    original_open = builtins.open
+
+    def patched_open(file, mode="r", *args, **kwargs):
+        result = original_open(file, mode, *args, **kwargs)
+        if isinstance(file, str) and file.endswith(".index.json"):
+            logger.debug(f"patched_open called for index file: {file}")
+        return result
+
+    builtins.open = patched_open

From 441d1032712f4f0d2d98aa41e948a567c1a8da01 Mon Sep 17 00:00:00 2001
From: yuhao_zhang
Date: Tue, 4 Nov 2025 21:15:59 +0800
Subject: [PATCH 23/29] filter safetensors index weight_map in json.load hook

---
 .../weight_loader_filter.py | 56 +++++++++++++++++--
 1 file changed, 50 insertions(+), 6 deletions(-)

diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
index 
8819b83c..71f16949 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -143,17 +143,20 @@ def patched_listdir(path): os.listdir = patched_listdir - # Patch Path.glob + # Patch Path.glob and Path.rglob from pathlib import Path as PathlibPath original_path_glob = PathlibPath.glob + original_path_rglob = PathlibPath.rglob def patched_path_glob(self, pattern): files = list(original_path_glob(self, pattern)) - logger.debug(f"patched_path_glob called: pattern={pattern}, num_files={len(files)}") + logger.debug( + f"patched_path_glob called: self={self}, pattern={pattern}, num_files={len(files)}" + ) if files and any(str(f).endswith((".safetensors", ".bin", ".pt")) for f in files): - logger.debug(f"Found weight files in Path.glob") + logger.debug(f"Found {len(files)} weight files in Path.glob, filtering...") global _layer_range_cache if _layer_range_cache.get("pp_start_layer") is not None: str_files = [ @@ -171,7 +174,33 @@ def patched_path_glob(self, pattern): return iter(files) + def patched_path_rglob(self, pattern): + files = list(original_path_rglob(self, pattern)) + logger.debug( + f"patched_path_rglob called: self={self}, pattern={pattern}, num_files={len(files)}" + ) + + if files and any(str(f).endswith((".safetensors", ".bin", ".pt")) for f in files): + logger.debug(f"Found {len(files)} weight files in Path.rglob, filtering...") + global _layer_range_cache + if _layer_range_cache.get("pp_start_layer") is not None: + str_files = [ + str(f) for f in files if str(f).endswith((".safetensors", ".bin", ".pt")) + ] + if str_files: + filtered_strs = _filter_weight_files_by_cache(str_files) + filtered_paths = [PathlibPath(f) for f in filtered_strs] + # Keep non-weight files + result = [ + f for f in files if not str(f).endswith((".safetensors", ".bin", ".pt")) + ] + filtered_paths + logger.debug(f"Filtered Path.rglob from {len(files)} to {len(result)} files") + return iter(result) + + return iter(files) + PathlibPath.glob = patched_path_glob + PathlibPath.rglob = patched_path_rglob # Patch safetensors.torch.load_file to intercept actual file loading try: From b19eeb289e029ef5eda1c1cc45d806a149a8b9cf Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 21:24:25 +0800 Subject: [PATCH 25/29] log gpu --- .../weight_loader_filter.py | 64 +++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 71f16949..5d6eb962 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -117,27 +117,35 @@ def patched_glob(pathname, **kwargs): def patched_listdir(path): files = original_listdir(path) - logger.debug(f"patched_listdir called: path={path}, num_files={len(files)}") - # Convert to full paths for filtering - if any(f.endswith((".safetensors", ".bin", ".pt")) for f in files): - logger.debug(f"Found weight files in listdir") + # Check if this directory contains weight files + weight_files = [f for f in files if f.endswith((".safetensors", ".bin", ".pt"))] + + if weight_files: + logger.debug( + f"patched_listdir: path={path}, total_files={len(files)}, weight_files={len(weight_files)}" + ) + global _layer_range_cache if _layer_range_cache.get("pp_start_layer") is not None: - full_paths = [ - os.path.join(path, f) - for f in files - if f.endswith((".safetensors", 
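# The json.load hook in patch 23 prunes the index in place; the same step as
# a pure function over the parsed index, which is easier to unit-test:
def prune_weight_map(index: dict, keep_files: set) -> dict:
    """Drop weight_map entries whose shard file was filtered out."""
    weight_map = index.get("weight_map", {})
    index["weight_map"] = {k: v for k, v in weight_map.items() if v in keep_files}
    return index


index = {
    "weight_map": {
        "model.layers.0.mlp.w": "model-00001-of-00002.safetensors",
        "model.layers.40.mlp.w": "model-00002-of-00002.safetensors",
    }
}
pruned = prune_weight_map(index, {"model-00001-of-00002.safetensors"})
assert set(pruned["weight_map"].values()) == {"model-00001-of-00002.safetensors"}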
".bin", ".pt")) - ] - if full_paths: + # Build full paths for filtering + full_paths = [os.path.join(path, f) for f in weight_files] + + try: filtered_paths = _filter_weight_files_by_cache(full_paths) filtered_names = [os.path.basename(f) for f in filtered_paths] - # Keep non-weight files + + # Keep non-weight files + filtered weight files result = [ f for f in files if not f.endswith((".safetensors", ".bin", ".pt")) ] + filtered_names - logger.debug(f"Filtered listdir from {len(files)} to {len(result)} files") + + logger.debug( + f"Filtered listdir: {len(weight_files)} → {len(filtered_names)} weight files" + ) return result + except Exception as e: + logger.warning(f"Error filtering listdir: {e}, returning all files") return files @@ -273,3 +281,35 @@ def patched_json_load(fp, *args, **kwargs): return result json.load = patched_json_load + + # Patch huggingface_hub.hf_hub_download to prevent downloading unwanted files + try: + from huggingface_hub import hf_hub_download as original_hf_hub_download + + def patched_hf_hub_download( + repo_id, filename, *args, subfolder=None, repo_type=None, **kwargs + ): + # Check if this is a weight file download + if filename and filename.endswith((".safetensors", ".bin", ".pt")): + logger.debug( + f"patched_hf_hub_download called: repo_id={repo_id}, filename={filename}" + ) + + global _layer_range_cache + if _layer_range_cache.get("pp_start_layer") is not None: + # Get model path to check if file should be filtered + # We need to check if this file is needed for our layer range + # For now, just log and let it through + logger.warning( + f"Weight file download requested: {filename} - this should have been filtered!" + ) + + return original_hf_hub_download( + repo_id, filename, *args, subfolder=subfolder, repo_type=repo_type, **kwargs + ) + + import huggingface_hub + + huggingface_hub.hf_hub_download = patched_hf_hub_download + except ImportError: + logger.debug("huggingface_hub not available for patching") From be440845e6bf8d37cd7385db859bff7993ccb9ea Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 4 Nov 2025 21:28:59 +0800 Subject: [PATCH 26/29] add log --- .../sglang/monkey_patch_utils/weight_loader_filter.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 5d6eb962..6098dd78 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -122,8 +122,10 @@ def patched_listdir(path): weight_files = [f for f in files if f.endswith((".safetensors", ".bin", ".pt"))] if weight_files: - logger.debug( - f"patched_listdir: path={path}, total_files={len(files)}, weight_files={len(weight_files)}" + logger.info( + f"patched_listdir found weight files: path={path}, " + f"total={len(files)}, weight_files={len(weight_files)}, " + f"first_files={weight_files[:5]}" ) global _layer_range_cache @@ -140,8 +142,8 @@ def patched_listdir(path): f for f in files if not f.endswith((".safetensors", ".bin", ".pt")) ] + filtered_names - logger.debug( - f"Filtered listdir: {len(weight_files)} → {len(filtered_names)} weight files" + logger.info( + f"✂️ Filtered listdir: {len(weight_files)} → {len(filtered_names)} weight files: {filtered_names}" ) return result except Exception as e: From 3bc4e940349b134cdf9a736f7139ad64e0bf31fe Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 5 Nov 2025 10:16:16 +0800 Subject: [PATCH 27/29] 
update model Qwen3-30B-A3B --- src/backend/server/static_config.py | 2 +- src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 5974e221..263d9e88 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -30,7 +30,7 @@ "Qwen/Qwen3-14B-FP8": "Qwen/Qwen3-14B-FP8", "Qwen/Qwen3-32B": "Qwen/Qwen3-32B", "Qwen/Qwen3-32B-FP8": "Qwen/Qwen3-32B-FP8", - "Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B", + "Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B-MLX-8bit", "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8": "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8", "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8": "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8", "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit", diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 6098dd78..e6b7e4c4 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -143,7 +143,7 @@ def patched_listdir(path): ] + filtered_names logger.info( - f"✂️ Filtered listdir: {len(weight_files)} → {len(filtered_names)} weight files: {filtered_names}" + f"Filtered listdir: {len(weight_files)} → {len(filtered_names)} weight files: {filtered_names}" ) return result except Exception as e: From a92ac6fc8168741abf2002d7530dea1cc6080835 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 5 Nov 2025 12:22:07 +0800 Subject: [PATCH 28/29] fix gpu bug and modify files --- src/parallax/sglang/model_runner.py | 6 +- .../weight_loader_filter.py | 229 +----------------- 2 files changed, 5 insertions(+), 230 deletions(-) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index f8e3c316..ff8f9f4a 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -35,10 +35,10 @@ ) from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch -from parallax.utils.tokenizer_utils import load_tokenizer from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( set_layer_range_for_filtering, ) +from parallax.utils.tokenizer_utils import load_tokenizer logger = logging.getLogger(__name__) @@ -269,7 +269,7 @@ def initialize_sgl_model_runner( kv_block_size = 1 server_args = form_sgl_server_args( - original_model_path, + str(model_path), dtype, attention_backend, kv_block_size, @@ -280,7 +280,7 @@ def initialize_sgl_model_runner( if (quantization_config := config.get("quantization_config", None)) is not None: quant_method = quantization_config.get("quant_method") model_config = ModelConfig( - model_path=original_model_path, + model_path=str(model_path), model_override_args="{}", dtype=dtype, quantization=quant_method, diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index e6b7e4c4..f2333736 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -54,13 +54,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: model_path = Path(hf_weights_files[0]).parent is_first_shard = pp_start_layer == 0 - is_last_shard = num_hidden_layers is not None and pp_end_layer >= num_hidden_layers - - logger.debug( - f"Filtering 
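# Patch 28 also switches form_sgl_server_args and ModelConfig from the
# original repo id to str(model_path), presumably so sglang resolves the
# locally materialized snapshot rather than re-resolving the repo id and
# fetching the full checkpoint again. The distinction, sketched:
from pathlib import Path


def model_path_arg(resolved: Path, repo_id: str) -> str:
    # Prefer the materialized snapshot; handing back the repo id would let a
    # downstream loader download the whole checkpoint a second time.
    return str(resolved) if resolved.exists() else repo_id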
weight files: start_layer={pp_start_layer}, end_layer={pp_end_layer}, " - f"is_first_shard={is_first_shard}, is_last_shard={is_last_shard}, " - f"input_files={len(hf_weights_files)}" - ) + is_last_shard = pp_end_layer >= num_hidden_layers filtered_files = filter_weight_files_by_layer_range( model_path=model_path, @@ -71,9 +65,6 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: is_last_shard=is_last_shard, ) - logger.debug( - f"Filtered to {len(filtered_files)} files: {[Path(f).name for f in filtered_files]}" - ) return filtered_files @@ -84,234 +75,18 @@ def apply_weight_loader_filter_patch(): def patched_glob(pathname, **kwargs): files = original_glob(pathname, **kwargs) - logger.debug( - f"patched_glob called: pathname={pathname}, num_files={len(files) if isinstance(files, list) else 'N/A'}" - ) - if ( isinstance(files, list) and files and any(f.endswith((".safetensors", ".bin", ".pt")) for f in files) ): - logger.debug(f"Found weight files, checking layer range cache...") + # Filter if we have layer range set global _layer_range_cache if _layer_range_cache.get("pp_start_layer") is not None: - logger.debug( - f"Layer range set: start={_layer_range_cache.get('pp_start_layer')}, end={_layer_range_cache.get('pp_end_layer')}" - ) filtered = _filter_weight_files_by_cache(files) - logger.debug(f"Filtered from {len(files)} to {len(filtered)} weight files") return filtered - else: - logger.debug("Layer range not set, loading all weight files") return files glob_module.glob = patched_glob - - # Patch os.listdir - import os - - original_listdir = os.listdir - - def patched_listdir(path): - files = original_listdir(path) - - # Check if this directory contains weight files - weight_files = [f for f in files if f.endswith((".safetensors", ".bin", ".pt"))] - - if weight_files: - logger.info( - f"patched_listdir found weight files: path={path}, " - f"total={len(files)}, weight_files={len(weight_files)}, " - f"first_files={weight_files[:5]}" - ) - - global _layer_range_cache - if _layer_range_cache.get("pp_start_layer") is not None: - # Build full paths for filtering - full_paths = [os.path.join(path, f) for f in weight_files] - - try: - filtered_paths = _filter_weight_files_by_cache(full_paths) - filtered_names = [os.path.basename(f) for f in filtered_paths] - - # Keep non-weight files + filtered weight files - result = [ - f for f in files if not f.endswith((".safetensors", ".bin", ".pt")) - ] + filtered_names - - logger.info( - f"Filtered listdir: {len(weight_files)} → {len(filtered_names)} weight files: {filtered_names}" - ) - return result - except Exception as e: - logger.warning(f"Error filtering listdir: {e}, returning all files") - - return files - - os.listdir = patched_listdir - - # Patch Path.glob and Path.rglob - from pathlib import Path as PathlibPath - - original_path_glob = PathlibPath.glob - original_path_rglob = PathlibPath.rglob - - def patched_path_glob(self, pattern): - files = list(original_path_glob(self, pattern)) - logger.debug( - f"patched_path_glob called: self={self}, pattern={pattern}, num_files={len(files)}" - ) - - if files and any(str(f).endswith((".safetensors", ".bin", ".pt")) for f in files): - logger.debug(f"Found {len(files)} weight files in Path.glob, filtering...") - global _layer_range_cache - if _layer_range_cache.get("pp_start_layer") is not None: - str_files = [ - str(f) for f in files if str(f).endswith((".safetensors", ".bin", ".pt")) - ] - if str_files: - filtered_strs = _filter_weight_files_by_cache(str_files) - filtered_paths 
= [PathlibPath(f) for f in filtered_strs] - # Keep non-weight files - result = [ - f for f in files if not str(f).endswith((".safetensors", ".bin", ".pt")) - ] + filtered_paths - logger.debug(f"Filtered Path.glob from {len(files)} to {len(result)} files") - return iter(result) - - return iter(files) - - def patched_path_rglob(self, pattern): - files = list(original_path_rglob(self, pattern)) - logger.debug( - f"patched_path_rglob called: self={self}, pattern={pattern}, num_files={len(files)}" - ) - - if files and any(str(f).endswith((".safetensors", ".bin", ".pt")) for f in files): - logger.debug(f"Found {len(files)} weight files in Path.rglob, filtering...") - global _layer_range_cache - if _layer_range_cache.get("pp_start_layer") is not None: - str_files = [ - str(f) for f in files if str(f).endswith((".safetensors", ".bin", ".pt")) - ] - if str_files: - filtered_strs = _filter_weight_files_by_cache(str_files) - filtered_paths = [PathlibPath(f) for f in filtered_strs] - # Keep non-weight files - result = [ - f for f in files if not str(f).endswith((".safetensors", ".bin", ".pt")) - ] + filtered_paths - logger.debug(f"Filtered Path.rglob from {len(files)} to {len(result)} files") - return iter(result) - - return iter(files) - - PathlibPath.glob = patched_path_glob - PathlibPath.rglob = patched_path_rglob - - # Patch safetensors.torch.load_file to intercept actual file loading - try: - import safetensors.torch - - original_load_file = safetensors.torch.load_file - - def patched_load_file(filename, *args, **kwargs): - logger.debug(f"patched_load_file called: filename={filename}") - return original_load_file(filename, *args, **kwargs) - - safetensors.torch.load_file = patched_load_file - except ImportError: - logger.debug("safetensors module not available for patching") - - # Patch json.load to intercept and modify index file content - import json - import builtins - - original_open = builtins.open - original_json_load = json.load - - def patched_json_load(fp, *args, **kwargs): - result = original_json_load(fp, *args, **kwargs) - - # Check if this is a safetensors index file - if isinstance(result, dict) and "weight_map" in result: - logger.debug( - f"Intercepted safetensors index file with {len(result.get('weight_map', {}))} weight mappings" - ) - - global _layer_range_cache - if _layer_range_cache.get("pp_start_layer") is not None: - # Get all weight files from the index - weight_map = result.get("weight_map", {}) - all_files = list(set(weight_map.values())) - - logger.debug(f"Index contains {len(all_files)} unique weight files") - - # We need to get the model path from somewhere - # Try to infer it from the file pointer - try: - file_path = fp.name - model_path = Path(file_path).parent - - # Build full paths - full_paths = [str(model_path / f) for f in all_files] - - # Filter files - filtered_paths = _filter_weight_files_by_cache(full_paths) - filtered_files = [Path(f).name for f in filtered_paths] - - logger.debug( - f"Filtered index from {len(all_files)} to {len(filtered_files)} files: {filtered_files}" - ) - - # Rebuild weight_map with only filtered files - new_weight_map = { - key: value for key, value in weight_map.items() if value in filtered_files - } - - result["weight_map"] = new_weight_map - logger.debug( - f"Modified weight_map from {len(weight_map)} to {len(new_weight_map)} entries" - ) - - except Exception as e: - logger.warning(f"Failed to filter index file: {e}") - - return result - - json.load = patched_json_load - - # Patch huggingface_hub.hf_hub_download to prevent 
downloading unwanted files - try: - from huggingface_hub import hf_hub_download as original_hf_hub_download - - def patched_hf_hub_download( - repo_id, filename, *args, subfolder=None, repo_type=None, **kwargs - ): - # Check if this is a weight file download - if filename and filename.endswith((".safetensors", ".bin", ".pt")): - logger.debug( - f"patched_hf_hub_download called: repo_id={repo_id}, filename={filename}" - ) - - global _layer_range_cache - if _layer_range_cache.get("pp_start_layer") is not None: - # Get model path to check if file should be filtered - # We need to check if this file is needed for our layer range - # For now, just log and let it through - logger.warning( - f"Weight file download requested: {filename} - this should have been filtered!" - ) - - return original_hf_hub_download( - repo_id, filename, *args, subfolder=subfolder, repo_type=repo_type, **kwargs - ) - - import huggingface_hub - - huggingface_hub.hf_hub_download = patched_hf_hub_download - except ImportError: - logger.debug("huggingface_hub not available for patching") From 66742ef2c69b95b3a0e0543d15c0f3e5227a7d51 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 5 Nov 2025 20:28:01 +0800 Subject: [PATCH 29/29] update --- src/parallax/server/http_server.py | 15 ++++-- src/parallax/server/shard_loader.py | 4 +- .../weight_loader_filter.py | 24 +--------- src/parallax/utils/selective_download.py | 48 +++++++++++-------- src/parallax/utils/weight_filter_utils.py | 8 +++- 5 files changed, 49 insertions(+), 50 deletions(-) diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index b8d354c7..8d036d77 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -30,10 +30,11 @@ import zmq.asyncio from fastapi.responses import ORJSONResponse, StreamingResponse from mlx_lm.tokenizer_utils import StreamingDetokenizer -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import load_config from pydantic import BaseModel from starlette.datastructures import State +from parallax.utils.selective_download import download_metadata_only from parallax.utils.tokenizer_utils import load_detokenizer, load_tokenizer from parallax.utils.utils import get_zmq_socket from parallax_utils.logging_config import get_logger @@ -104,8 +105,16 @@ def __init__( self.send_to_executor = get_zmq_socket(context, zmq.PUSH, executor_input_ipc_name, True) self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True) self.processing_requests: Dict[str, HTTPRequestInfo] = {} - # Load tokenizer for separate detokenizers - model_path = get_model_path(model_path_str)[0] + + # Load tokenizer for separate detokenizers. + # Important: avoid triggering full weight downloads here. + # Only download metadata/config/tokenizer files. 
+ from pathlib import Path + + if Path(model_path_str).exists(): + model_path = Path(model_path_str) + else: + model_path = download_metadata_only(model_path_str) config = load_config(model_path) self.model_path_str = model_path_str self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index e53a308f..3ecc6dcf 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -179,10 +179,10 @@ def load( # Use shared utility to filter weight files from parallax.utils.weight_filter_utils import ( - filter_weight_files_by_layer_range, + filter_weight_files_by_layer_range_for_load, ) - weight_files = filter_weight_files_by_layer_range( + weight_files = filter_weight_files_by_layer_range_for_load( model_path=model_path, weight_files=weight_files, start_layer=current_start_layer, diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index f2333736..5cc529ff 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -2,32 +2,10 @@ from pathlib import Path from typing import List -from parallax.utils.weight_filter_utils import ( - filter_weight_files_by_layer_range as shared_filter, -) +from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range logger = logging.getLogger(__name__) - -def filter_weight_files_by_layer_range( - model_path: Path, - weight_files: List[str], - pp_start_layer: int, - pp_end_layer: int, - is_first_shard: bool, - is_last_shard: bool, -) -> List[str]: - return shared_filter( - model_path=model_path, - weight_files=weight_files, - start_layer=pp_start_layer, - end_layer=pp_end_layer, - is_first_shard=is_first_shard, - is_last_shard=is_last_shard, - config={}, - ) - - _layer_range_cache = {} diff --git a/src/parallax/utils/selective_download.py b/src/parallax/utils/selective_download.py index 0ff5578b..8a986177 100644 --- a/src/parallax/utils/selective_download.py +++ b/src/parallax/utils/selective_download.py @@ -5,14 +5,33 @@ from huggingface_hub import hf_hub_download, snapshot_download logger = logging.getLogger(__name__) - - -def determine_needed_weight_files(model_path: Path, start_layer: int, end_layer: int): - from parallax.utils.weight_filter_utils import ( - determine_needed_weight_files as determine_files, +from parallax.utils.weight_filter_utils import ( + determine_needed_weight_files_for_download, +) + +EXCLUDE_WEIGHT_PATTERNS = [ + "*.safetensors", + "*.bin", + "*.pt", + "*.pth", + "pytorch_model*.bin", + "model*.safetensors", + "weight*.safetensors", +] + + +def download_metadata_only( + repo_id: str, + cache_dir: Optional[str] = None, + force_download: bool = False, +) -> Path: + path = snapshot_download( + repo_id=repo_id, + cache_dir=cache_dir, + ignore_patterns=EXCLUDE_WEIGHT_PATTERNS, + force_download=force_download, ) - - return determine_files(model_path, start_layer, end_layer) + return Path(path) def selective_model_download( @@ -24,28 +43,17 @@ def selective_model_download( ) -> Path: logger.debug(f"Downloading model metadata for {repo_id}") - ignore_patterns = [ - "*.safetensors", - "*.bin", - "*.pt", - "*.pth", - "pytorch_model*.bin", - "model*.safetensors", - ] - - model_path = snapshot_download( + model_path = download_metadata_only( repo_id=repo_id, cache_dir=cache_dir, - 
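# With the exclude list factored out, the HTTP server can materialize config
# and tokenizer files without touching weights. Usage sketch (repo id
# illustrative; the assert assumes a fresh cache, relying on
# EXCLUDE_WEIGHT_PATTERNS above):
from parallax.utils.selective_download import download_metadata_only

meta_path = download_metadata_only("Qwen/Qwen3-0.6B")
assert not list(meta_path.glob("*.safetensors"))  # weights were never fetched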
ignore_patterns=ignore_patterns, force_download=force_download, ) - model_path = Path(model_path) logger.debug(f"Downloaded model metadata to {model_path}") if start_layer is not None and end_layer is not None: logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})") - needed_weight_files = determine_needed_weight_files( + needed_weight_files = determine_needed_weight_files_for_download( model_path=model_path, start_layer=start_layer, end_layer=end_layer, diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index 5a5a0f37..f11811b0 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -33,7 +33,7 @@ def should_include_weight_key( return False -def filter_weight_files_by_layer_range( +def filter_weight_files_by_layer_range_for_load( model_path: Path, weight_files: List[str], start_layer: int, @@ -69,6 +69,8 @@ def filter_weight_files_by_layer_range( needed_files: Set[str] = set() for key, filename in weight_map.items(): + if filename in needed_files: + continue if should_include_weight_key( key=key, start_layer=start_layer, @@ -99,7 +101,7 @@ def filter_weight_files_by_layer_range( return filtered_files -def determine_needed_weight_files( +def determine_needed_weight_files_for_download( model_path: Path, start_layer: int, end_layer: int, @@ -140,6 +142,8 @@ def determine_needed_weight_files( needed_files: Set[str] = set() for key, filename in weight_map.items(): + if filename in needed_files: + continue if should_include_weight_key( key=key, start_layer=start_layer,
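Taken together, the end state of the series resolves a model with a metadata-first download and then materializes only the shards whose keys survive the layer-range filter. A minimal end-to-end sketch of that final flow (repo id and layer range illustrative; network access and a safetensors index in the repo are assumed):

from parallax.utils.selective_download import (
    get_model_path_with_selective_download,
)
from parallax.utils.weight_filter_utils import (
    determine_needed_weight_files_for_download,
)

# Downloads metadata plus only the shard files covering layers [12, 24).
model_path = get_model_path_with_selective_download(
    "Qwen/Qwen3-0.6B", start_layer=12, end_layer=24
)
print(
    determine_needed_weight_files_for_download(
        model_path, start_layer=12, end_layer=24
    )
)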