diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 5974e221..263d9e88 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -30,7 +30,7 @@ "Qwen/Qwen3-14B-FP8": "Qwen/Qwen3-14B-FP8", "Qwen/Qwen3-32B": "Qwen/Qwen3-32B", "Qwen/Qwen3-32B-FP8": "Qwen/Qwen3-32B-FP8", - "Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B", + "Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B-MLX-8bit", "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8": "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8", "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8": "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8", "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit", diff --git a/src/parallax/launch.py b/src/parallax/launch.py index f41f5133..ac380e01 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -87,8 +87,8 @@ initial_peers=args.initial_peers, scheduler_addr=args.scheduler_addr, relay_servers=args.relay_servers, - pp_start_layer=None, - pp_end_layer=None, + pp_start_layer=args.start_layer, + pp_end_layer=args.end_layer, hidden_layers=None, tcp_port=args.tcp_port, udp_port=args.udp_port, diff --git a/src/parallax/p2p/server.py b/src/parallax/p2p/server.py index b7be6173..85e5d091 100644 --- a/src/parallax/p2p/server.py +++ b/src/parallax/p2p/server.py @@ -248,6 +248,7 @@ def __init__( self.announcer = None self.connection_handler = None self.stop_event = threading.Event() + logger.debug(f"manual_layer_assignment: {self.manual_layer_assignment}") self._layer_allocation_changed = False def build_lattica(self): @@ -775,7 +776,10 @@ def launch_p2p_server( thread = threading.Thread(target=server.run, daemon=True) thread.start() - while server.block_start_index is None: + # Wait for layer allocation and model_name to be set + while server.block_start_index is None or ( + scheduler_addr is not None and server.model_name is None + ): time.sleep(1) return server diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index b8d354c7..8d036d77 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -30,10 +30,11 @@ import zmq.asyncio from fastapi.responses import ORJSONResponse, StreamingResponse from mlx_lm.tokenizer_utils import StreamingDetokenizer -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import load_config from pydantic import BaseModel from starlette.datastructures import State +from parallax.utils.selective_download import download_metadata_only from parallax.utils.tokenizer_utils import load_detokenizer, load_tokenizer from parallax.utils.utils import get_zmq_socket from parallax_utils.logging_config import get_logger @@ -104,8 +105,16 @@ def __init__( self.send_to_executor = get_zmq_socket(context, zmq.PUSH, executor_input_ipc_name, True) self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True) self.processing_requests: Dict[str, HTTPRequestInfo] = {} - # Load tokenizer for separate detokenizers - model_path = get_model_path(model_path_str)[0] + + # Load tokenizer for separate detokenizers. + # Important: avoid triggering full weight downloads here. + # Only download metadata/config/tokenizer files. 
+ from pathlib import Path + + if Path(model_path_str).exists(): + model_path = Path(model_path_str) + else: + model_path = download_metadata_only(model_path_str) config = load_config(model_path) self.model_path_str = model_path_str self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index be03e00f..3ecc6dcf 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -85,7 +85,7 @@ def register_block_class(self): logger.warning(f"Failed to load model from {model_file}: {e}") def load( - self, lazy: bool = False, strict: bool = True + self, lazy: bool = False, strict: bool = True, use_selective_download: bool = True ) -> Tuple[nn.Module, Dict[str, Any], Any]: """ Loads the specified model shard by loading only the necessary weights @@ -96,10 +96,27 @@ def load( into memory. Defaults to False. strict (bool): If True, raises an exception if weights do not match. Defaults to True. + use_selective_download (bool): If True, only download necessary weight files + from Hugging Face. Defaults to True. Returns: A tuple containing the loaded sharded MLX model and its configuration dictionary. """ - model_path = get_model_path(self.model_path_str)[0] + if use_selective_download and self.start_layer is not None and self.end_layer is not None: + from parallax.utils.selective_download import ( + get_model_path_with_selective_download, + ) + + logger.info( + f"Using selective download for layers [{self.start_layer}, {self.end_layer})" + ) + model_path = get_model_path_with_selective_download( + self.model_path_str, + start_layer=self.start_layer, + end_layer=self.end_layer, + ) + else: + model_path = get_model_path(self.model_path_str)[0] + config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) @@ -157,6 +174,24 @@ def load( if not weight_files: weight_files = glob.glob(str(model_path / "weight*.safetensors")) + # Sort weight files by name for consistent loading order + weight_files = sorted(weight_files) + + # Use shared utility to filter weight files + from parallax.utils.weight_filter_utils import ( + filter_weight_files_by_layer_range_for_load, + ) + + weight_files = filter_weight_files_by_layer_range_for_load( + model_path=model_path, + weight_files=weight_files, + start_layer=current_start_layer, + end_layer=current_end_layer, + is_first_shard=model_shard.is_first_shard, + is_last_shard=model_shard.is_last_shard, + config=config, + ) + if not weight_files and strict: raise FileNotFoundError(f"No safetensors found in {model_path}") @@ -165,8 +200,11 @@ def load( shard_weights = {} layer_key_prefix = "model.layers" # Common prefix - for wf in weight_files: - # For bf16 models, we need torch tensors as a bridge + for file_idx, wf in enumerate(weight_files): + logger.debug( + f"Scanning weight file {file_idx + 1}/{len(weight_files)}: {pathlib.Path(wf).name}" + ) + with safetensors.safe_open(wf, framework="pt") as f: for key in f.keys(): is_needed = False @@ -215,7 +253,7 @@ def load( shard_weights[remapped_key] = mx.array(f.get_tensor(key)) if (quantization := config.get("quantization", None)) is not None: - logger.info("Model is quantized. Applying quantization parameters...") + logger.debug("Model is quantized. 
Applying quantization parameters...") def class_predicate(p, m): # Handle custom per-layer quantizations from the config @@ -232,10 +270,6 @@ def class_predicate(p, m): prefixed = f"model.{p}" if prefixed in qcfg: override = qcfg[prefixed] - if isinstance(override, dict): - logger.debug( - f"[quantize] Using override for '{prefixed}' (mapped to '{p}'): bits={override.get('bits')} group_size={override.get('group_size')}" - ) return override if not hasattr(m, "to_quantized"): return False diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 2bf8a0f5..ff8f9f4a 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -11,7 +11,7 @@ import sglang import sglang.srt.distributed.parallel_state import torch -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import load_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import ( get_tp_group, @@ -35,6 +35,9 @@ ) from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch +from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( + set_layer_range_for_filtering, +) from parallax.utils.tokenizer_utils import load_tokenizer logger = logging.getLogger(__name__) @@ -67,6 +70,9 @@ def __init__( """Add pp_start_layer and pp_end_layer for decentralized model""" self.pp_start_layer = pp_start_layer self.pp_end_layer = pp_end_layer + num_hidden_layers = model_config.hf_config.num_hidden_layers + set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers) + super().__init__( model_config=model_config, mem_fraction_static=mem_fraction_static, @@ -230,7 +236,19 @@ def initialize_sgl_model_runner( - tokenizer: tokenizer driven by mlx-lm """ apply_parallax_sglang_monkey_patch() - model_path = get_model_path(original_model_path)[0] + + # Use selective download for GPU models to save bandwidth and disk space + from parallax.utils.selective_download import get_model_path_with_selective_download + + logger.info( + f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})" + ) + model_path = get_model_path_with_selective_download( + original_model_path, + start_layer=start_layer, + end_layer=end_layer, + ) + config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") @@ -251,7 +269,7 @@ def initialize_sgl_model_runner( kv_block_size = 1 server_args = form_sgl_server_args( - original_model_path, + str(model_path), dtype, attention_backend, kv_block_size, @@ -262,7 +280,7 @@ def initialize_sgl_model_runner( if (quantization_config := config.get("quantization_config", None)) is not None: quant_method = quantization_config.get("quant_method") model_config = ModelConfig( - model_path=original_model_path, + model_path=str(model_path), model_override_args="{}", dtype=dtype, quantization=quant_method, diff --git a/src/parallax/sglang/monkey_patch.py b/src/parallax/sglang/monkey_patch.py index 3bf53067..4b2d45c9 100644 --- a/src/parallax/sglang/monkey_patch.py +++ b/src/parallax/sglang/monkey_patch.py @@ -17,15 +17,19 @@ from parallax.sglang.monkey_patch_utils.triton_backend import ( apply_triton_backend_init_monkey_patch, ) +from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( + apply_weight_loader_filter_patch, +) ## Here is some patch func for sglang ## Hopefully, when sglang support pipeline parallelism natively, we can remove these patches 
def apply_parallax_sglang_monkey_patch(): + apply_model_parallel_monkey_patch() + apply_triton_backend_init_monkey_patch() + apply_weight_loader_filter_patch() apply_qwen3_next_monkey_patch() apply_qwen3_next_config_monkey_patch() apply_gpt_oss_monkey_patch() apply_minimax_m2_monkey_patch() apply_glm4_moe_monkey_patch() - apply_triton_backend_init_monkey_patch() - apply_model_parallel_monkey_patch() diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py new file mode 100644 index 00000000..5cc529ff --- /dev/null +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -0,0 +1,70 @@ +import logging +from pathlib import Path +from typing import List + +from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range + +logger = logging.getLogger(__name__) + +_layer_range_cache = {} + + +def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, num_hidden_layers: int): + global _layer_range_cache + _layer_range_cache["pp_start_layer"] = pp_start_layer + _layer_range_cache["pp_end_layer"] = pp_end_layer + _layer_range_cache["num_hidden_layers"] = num_hidden_layers + + +def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: + global _layer_range_cache + + pp_start_layer = _layer_range_cache.get("pp_start_layer") + pp_end_layer = _layer_range_cache.get("pp_end_layer") + num_hidden_layers = _layer_range_cache.get("num_hidden_layers") + + if pp_start_layer is None or pp_end_layer is None: + logger.debug("No layer range set, loading all weight files") + return hf_weights_files + + if not hf_weights_files: + return hf_weights_files + + model_path = Path(hf_weights_files[0]).parent + is_first_shard = pp_start_layer == 0 + is_last_shard = pp_end_layer >= num_hidden_layers + + filtered_files = filter_weight_files_by_layer_range( + model_path=model_path, + weight_files=hf_weights_files, + pp_start_layer=pp_start_layer, + pp_end_layer=pp_end_layer, + is_first_shard=is_first_shard, + is_last_shard=is_last_shard, + ) + + return filtered_files + + +def apply_weight_loader_filter_patch(): + import glob as glob_module + + original_glob = glob_module.glob + + def patched_glob(pathname, **kwargs): + files = original_glob(pathname, **kwargs) + if ( + isinstance(files, list) + and files + and any(f.endswith((".safetensors", ".bin", ".pt")) for f in files) + ): + + # Filter if we have layer range set + global _layer_range_cache + if _layer_range_cache.get("pp_start_layer") is not None: + filtered = _filter_weight_files_by_cache(files) + return filtered + + return files + + glob_module.glob = patched_glob diff --git a/src/parallax/utils/selective_download.py b/src/parallax/utils/selective_download.py new file mode 100644 index 00000000..8a986177 --- /dev/null +++ b/src/parallax/utils/selective_download.py @@ -0,0 +1,103 @@ +import logging +from pathlib import Path +from typing import Optional + +from huggingface_hub import hf_hub_download, snapshot_download + +logger = logging.getLogger(__name__) +from parallax.utils.weight_filter_utils import ( + determine_needed_weight_files_for_download, +) + +EXCLUDE_WEIGHT_PATTERNS = [ + "*.safetensors", + "*.bin", + "*.pt", + "*.pth", + "pytorch_model*.bin", + "model*.safetensors", + "weight*.safetensors", +] + + +def download_metadata_only( + repo_id: str, + cache_dir: Optional[str] = None, + force_download: bool = False, +) -> Path: + path = snapshot_download( + repo_id=repo_id, + cache_dir=cache_dir, + 
ignore_patterns=EXCLUDE_WEIGHT_PATTERNS, + force_download=force_download, + ) + return Path(path) + + +def selective_model_download( + repo_id: str, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + cache_dir: Optional[str] = None, + force_download: bool = False, +) -> Path: + logger.debug(f"Downloading model metadata for {repo_id}") + + model_path = download_metadata_only( + repo_id=repo_id, + cache_dir=cache_dir, + force_download=force_download, + ) + logger.debug(f"Downloaded model metadata to {model_path}") + + if start_layer is not None and end_layer is not None: + logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})") + + needed_weight_files = determine_needed_weight_files_for_download( + model_path=model_path, + start_layer=start_layer, + end_layer=end_layer, + ) + + if not needed_weight_files: + logger.debug("Could not determine specific weight files, downloading all") + snapshot_download( + repo_id=repo_id, + cache_dir=cache_dir, + force_download=force_download, + ) + else: + # Step 3: Download only the needed weight files + logger.info(f"Downloading {len(needed_weight_files)} weight files") + + for weight_file in needed_weight_files: + logger.debug(f"Downloading {weight_file}") + hf_hub_download( + repo_id=repo_id, + filename=weight_file, + cache_dir=cache_dir, + force_download=force_download, + ) + + logger.debug(f"Downloaded weight files for layers [{start_layer}, {end_layer})") + else: + logger.debug("No layer range specified, downloading all model files") + snapshot_download( + repo_id=repo_id, + cache_dir=cache_dir, + force_download=force_download, + ) + + return model_path + + +def get_model_path_with_selective_download( + model_path_or_repo: str, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, +) -> Path: + return selective_model_download( + repo_id=model_path_or_repo, + start_layer=start_layer, + end_layer=end_layer, + ) diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py new file mode 100644 index 00000000..f11811b0 --- /dev/null +++ b/src/parallax/utils/weight_filter_utils.py @@ -0,0 +1,161 @@ +import json +import logging +from pathlib import Path +from typing import Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + + +def should_include_weight_key( + key: str, + start_layer: int, + end_layer: int, + is_first_shard: bool, + is_last_shard: bool, + tie_word_embeddings: bool = False, +) -> bool: + if is_first_shard and "embed_tokens" in key and key.startswith("model."): + return True + + if is_last_shard: + if "model.norm" in key or "lm_head" in key: + return True + if tie_word_embeddings and "embed" in key and key.startswith("model.embed_tokens"): + return True + + if "layers." 
in key: + parts = key.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts): + layer_idx = int(parts[i + 1]) + return start_layer <= layer_idx < end_layer + + return False + + +def filter_weight_files_by_layer_range_for_load( + model_path: Path, + weight_files: List[str], + start_layer: int, + end_layer: int, + is_first_shard: bool, + is_last_shard: bool, + config: Optional[Dict] = None, +) -> List[str]: + index_file = model_path / "model.safetensors.index.json" + + if not index_file.exists(): + logger.debug(f"No index file found at {index_file}, cannot filter weight files") + return weight_files + + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + if not weight_map: + logger.debug("weight_map is empty in index file") + return weight_files + + tie_word_embeddings = False + if config: + tie_word_embeddings = config.get("tie_word_embeddings", False) + else: + config_file = model_path / "config.json" + if config_file.exists(): + with open(config_file, "r") as f: + cfg = json.load(f) + tie_word_embeddings = cfg.get("tie_word_embeddings", False) + + needed_files: Set[str] = set() + + for key, filename in weight_map.items(): + if filename in needed_files: + continue + if should_include_weight_key( + key=key, + start_layer=start_layer, + end_layer=end_layer, + is_first_shard=is_first_shard, + is_last_shard=is_last_shard, + tie_word_embeddings=tie_word_embeddings, + ): + needed_files.add(filename) + + if not needed_files: + logger.debug( + f"No relevant weight files found in index for layers [{start_layer}, {end_layer})" + ) + return weight_files + + filtered_files = [] + for wf in weight_files: + wf_name = Path(wf).name + if wf_name in needed_files: + filtered_files.append(wf) + + logger.debug( + f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} " + f"for layers [{start_layer}, {end_layer})" + ) + + return filtered_files + + +def determine_needed_weight_files_for_download( + model_path: Path, + start_layer: int, + end_layer: int, + config: Optional[Dict] = None, +) -> List[str]: + is_first_shard = start_layer == 0 + + is_last_shard = False + if config: + num_hidden_layers = config.get("num_hidden_layers", 0) + is_last_shard = end_layer >= num_hidden_layers + else: + config_file = model_path / "config.json" + if config_file.exists(): + with open(config_file, "r") as f: + cfg = json.load(f) + num_hidden_layers = cfg.get("num_hidden_layers", 0) + is_last_shard = end_layer >= num_hidden_layers + + index_file = model_path / "model.safetensors.index.json" + + if not index_file.exists(): + logger.debug(f"Index file not found at {index_file}") + return [] + + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + if not weight_map: + logger.debug("weight_map is empty in index file") + return [] + + tie_word_embeddings = False + if config: + tie_word_embeddings = config.get("tie_word_embeddings", False) + + needed_files: Set[str] = set() + + for key, filename in weight_map.items(): + if filename in needed_files: + continue + if should_include_weight_key( + key=key, + start_layer=start_layer, + end_layer=end_layer, + is_first_shard=is_first_shard, + is_last_shard=is_last_shard, + tie_word_embeddings=tie_word_embeddings, + ): + needed_files.add(filename) + + result = sorted(list(needed_files)) + logger.debug( + f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})" + ) + return result diff --git 
a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py index af43c55e..35faea5b 100644 --- a/src/scheduling/scheduler.py +++ b/src/scheduling/scheduler.py @@ -302,14 +302,28 @@ def leave(self, node_id: str) -> None: self.layer_allocator.leave(node_id) if self.layer_allocator.should_global_rebalance(): logger.debug("Global rebalance triggered due to node leave") - # TODO: send a signal to the nodes to stop running requests - # and re-assign start/end layers so nodes can re-shard - self._bootstrapped = False - self._bootstrapped_event.clear() - for n in self.nodes: - if n.start_layer is not None and n.end_layer is not None: - self.layer_allocator.deallocate(n) - self.layer_allocator.global_allocation() + + # Count manual vs automatic nodes + manual_count = sum(1 for n in self.nodes if n.manual_layer_assignment) + total_count = len(self.nodes) + logger.debug( + f"Node count: {manual_count} manual, {total_count - manual_count} automatic" + ) + if manual_count == total_count: + logger.debug("All nodes are manual assignment, skipping global rebalance") + elif manual_count > 0: + logger.error( + f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance" + ) + else: + # All nodes are automatic, proceed with rebalance + self._bootstrapped = False + self._bootstrapped_event.clear() + for n in self.nodes: + if n.start_layer is not None and n.end_layer is not None: + self.layer_allocator.deallocate(n) + self.layer_allocator.global_allocation() + with self._node_count_cv: self._node_count_cv.notify_all()
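
Usage sketch (illustrative, not part of the patch): how the selective-download and weight-filter utilities introduced above are expected to compose for a peer serving a middle shard. The repo id and layer range below are placeholders, and the calls assume Hugging Face Hub access and a repo that ships a model.safetensors.index.json; only names defined in this change are used.

# Illustrative sketch: placeholder repo id and layer range; assumes HF Hub access
# and a checkpoint that provides model.safetensors.index.json.
from pathlib import Path

from mlx_lm.utils import load_config

from parallax.utils.selective_download import get_model_path_with_selective_download
from parallax.utils.weight_filter_utils import (
    filter_weight_files_by_layer_range_for_load,
    should_include_weight_key,
)

start_layer, end_layer = 8, 16  # hypothetical middle shard holding layers [8, 16)

# 1) Resolve the model path: config/tokenizer/metadata are always fetched; only the
#    safetensors files that the index maps to [start_layer, end_layer) are downloaded.
model_path: Path = get_model_path_with_selective_download(
    "some-org/some-sharded-model",  # placeholder repo id
    start_layer=start_layer,
    end_layer=end_layer,
)
config = load_config(model_path)

# 2) At load time, narrow the local weight-file list the same way shard_loader.py does,
#    driven by model.safetensors.index.json.
weight_files = sorted(str(p) for p in model_path.glob("model*.safetensors"))
weight_files = filter_weight_files_by_layer_range_for_load(
    model_path=model_path,
    weight_files=weight_files,
    start_layer=start_layer,
    end_layer=end_layer,
    is_first_shard=(start_layer == 0),
    is_last_shard=(end_layer >= config.get("num_hidden_layers", 0)),
    config=config,
)

# 3) The per-tensor rule shared by both paths: keep keys whose layer index falls in
#    [start_layer, end_layer), plus embeddings on the first shard and norm/lm_head on the last.
assert should_include_weight_key(
    "model.layers.10.self_attn.q_proj.weight",
    start_layer,
    end_layer,
    is_first_shard=False,
    is_last_shard=False,
)

On the GPU path, model_runner.py reaches the same filtering indirectly: set_layer_range_for_filtering() caches the [pp_start_layer, pp_end_layer) range and the hidden-layer count, and the patched glob.glob narrows safetensors listings to that range before sglang's weight loader opens them.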