Merged

Changes from all commits (34 commits)
63f43fc  debug log (yuhao-zh, Oct 31, 2025)
e7a8c92  update (yuhao-zh, Nov 3, 2025)
4e09501  remove debug logger (yuhao-zh, Nov 3, 2025)
4c0b820  Merge remote-tracking branch 'origin/main' into feat/read_less_weights (yuhao-zh, Nov 3, 2025)
7ffe147  update (yuhao-zh, Nov 3, 2025)
1798c1b  Merge branch 'main' into feat/read_less_weights (yuhao-zh, Nov 3, 2025)
f7a0261  update (yuhao-zh, Nov 3, 2025)
cbf198b  update gpu load (Nov 3, 2025)
321fe22  update (Nov 3, 2025)
52db26f  update gpu load (Nov 3, 2025)
aec66ff  update (Nov 3, 2025)
fb4447f  update (yuhao-zh, Nov 4, 2025)
131f59b  update (yuhao-zh, Nov 4, 2025)
b77a902  Merge branch 'feat/support_specified_layers' into feat/read_less_weights (yuhao-zh, Nov 4, 2025)
614facb  update (yuhao-zh, Nov 4, 2025)
23c1d62  Merge branch 'main' into feat/read_less_weights (yuhao-zh, Nov 4, 2025)
2521a39  update (yuhao-zh, Nov 4, 2025)
9472cc1  update (yuhao-zh, Nov 4, 2025)
3c36dc1  update (yuhao-zh, Nov 4, 2025)
be88288  pre-commit (yuhao-zh, Nov 4, 2025)
cc9880a  update (yuhao-zh, Nov 4, 2025)
ea0a3a9  update (yuhao-zh, Nov 4, 2025)
1f33287  update (yuhao-zh, Nov 4, 2025)
22edf79  update (yuhao-zh, Nov 4, 2025)
ced07a4  update (yuhao-zh, Nov 4, 2025)
08c6643  update (yuhao-zh, Nov 4, 2025)
441d103  update (yuhao-zh, Nov 4, 2025)
5232786  update (yuhao-zh, Nov 4, 2025)
b19eeb2  log gpu (yuhao-zh, Nov 4, 2025)
be44084  add log (yuhao-zh, Nov 4, 2025)
3bc4e94  update model Qwen3-30B-A3B (yuhao-zh, Nov 5, 2025)
a92ac6f  fix gpu bug and modify files (yuhao-zh, Nov 5, 2025)
137db69  Merge branch 'main' into feat/read_less_weights (yuhao-zh, Nov 5, 2025)
66742ef  update (yuhao-zh, Nov 5, 2025)
2 changes: 1 addition & 1 deletion src/backend/server/static_config.py
@@ -30,7 +30,7 @@
     "Qwen/Qwen3-14B-FP8": "Qwen/Qwen3-14B-FP8",
     "Qwen/Qwen3-32B": "Qwen/Qwen3-32B",
     "Qwen/Qwen3-32B-FP8": "Qwen/Qwen3-32B-FP8",
-    "Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B",
+    "Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B-MLX-8bit",
     "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8": "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8",
     "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8": "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8",
     "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit",
4 changes: 2 additions & 2 deletions src/parallax/launch.py
@@ -87,8 +87,8 @@
         initial_peers=args.initial_peers,
         scheduler_addr=args.scheduler_addr,
         relay_servers=args.relay_servers,
-        pp_start_layer=None,
-        pp_end_layer=None,
+        pp_start_layer=args.start_layer,
+        pp_end_layer=args.end_layer,
         hidden_layers=None,
         tcp_port=args.tcp_port,
         udp_port=args.udp_port,
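Note on the launch.py change: the hard-coded `pp_start_layer=None, pp_end_layer=None` is replaced with user-supplied bounds, so a node can pin itself to a fixed layer shard. The parser definitions are not part of this diff; a hypothetical sketch of what the corresponding flags might look like:

```python
# Hypothetical argparse wiring for args.start_layer / args.end_layer;
# the real parser lives elsewhere in parallax and may differ.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--start-layer", dest="start_layer", type=int, default=None,
                    help="first decoder layer this peer hosts (inclusive)")
parser.add_argument("--end-layer", dest="end_layer", type=int, default=None,
                    help="last decoder layer this peer hosts (exclusive)")

args = parser.parse_args(["--start-layer", "0", "--end-layer", "24"])
assert (args.start_layer, args.end_layer) == (0, 24)
```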
6 changes: 5 additions & 1 deletion src/parallax/p2p/server.py
@@ -248,6 +248,7 @@ def __init__(
         self.announcer = None
         self.connection_handler = None
         self.stop_event = threading.Event()
+        logger.debug(f"manual_layer_assignment: {self.manual_layer_assignment}")
         self._layer_allocation_changed = False

     def build_lattica(self):
@@ -775,7 +776,10 @@ def launch_p2p_server(
     thread = threading.Thread(target=server.run, daemon=True)
     thread.start()

-    while server.block_start_index is None:
+    # Wait for layer allocation and model_name to be set
+    while server.block_start_index is None or (
+        scheduler_addr is not None and server.model_name is None
+    ):
         time.sleep(1)

     return server
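The gate in `launch_p2p_server` now also blocks until the scheduler has pushed a model name, since the model identity can arrive together with the layer assignment and the selective downloader needs it before loading. A minimal sketch of the same gate with a deadline added (an assumption on my part; the merged code polls forever, one second at a time):

```python
# Sketch only: the merged code has no timeout.
import time


def wait_until_assigned(server, scheduler_addr, timeout_s: float = 600.0) -> None:
    deadline = time.monotonic() + timeout_s
    while server.block_start_index is None or (
        scheduler_addr is not None and server.model_name is None
    ):
        if time.monotonic() > deadline:
            raise TimeoutError("no layer/model assignment received")
        time.sleep(1)
```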
15 changes: 12 additions & 3 deletions src/parallax/server/http_server.py
@@ -30,10 +30,11 @@
 import zmq.asyncio
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from mlx_lm.tokenizer_utils import StreamingDetokenizer
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import load_config
 from pydantic import BaseModel
 from starlette.datastructures import State

+from parallax.utils.selective_download import download_metadata_only
 from parallax.utils.tokenizer_utils import load_detokenizer, load_tokenizer
 from parallax.utils.utils import get_zmq_socket
 from parallax_utils.logging_config import get_logger
@@ -104,8 +105,16 @@ def __init__(
         self.send_to_executor = get_zmq_socket(context, zmq.PUSH, executor_input_ipc_name, True)
         self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True)
         self.processing_requests: Dict[str, HTTPRequestInfo] = {}
-        # Load tokenizer for separate detokenizers
-        model_path = get_model_path(model_path_str)[0]
+
+        # Load tokenizer for separate detokenizers.
+        # Important: avoid triggering full weight downloads here.
+        # Only download metadata/config/tokenizer files.
+        from pathlib import Path
+
+        if Path(model_path_str).exists():
+            model_path = Path(model_path_str)
+        else:
+            model_path = download_metadata_only(model_path_str)
         config = load_config(model_path)
         self.model_path_str = model_path_str
         self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
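`download_metadata_only` comes from `parallax/utils/selective_download.py`, which this diff does not show. A minimal sketch of the idea, assuming `huggingface_hub.snapshot_download` with its `allow_patterns` filter; the function name and pattern list here are illustrative:

```python
# Illustrative sketch, not the actual parallax implementation.
from pathlib import Path

from huggingface_hub import snapshot_download


def download_metadata_only_sketch(repo_id: str) -> Path:
    # Pull config/tokenizer/index files; skip the multi-GB *.safetensors shards.
    local_dir = snapshot_download(
        repo_id,
        allow_patterns=["*.json", "*.txt", "tokenizer.model"],
    )
    return Path(local_dir)
```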
52 changes: 43 additions & 9 deletions src/parallax/server/shard_loader.py
@@ -85,7 +85,7 @@ def register_block_class(self):
                 logger.warning(f"Failed to load model from {model_file}: {e}")

     def load(
-        self, lazy: bool = False, strict: bool = True
+        self, lazy: bool = False, strict: bool = True, use_selective_download: bool = True
     ) -> Tuple[nn.Module, Dict[str, Any], Any]:
         """
         Loads the specified model shard by loading only the necessary weights
@@ -96,10 +96,27 @@ def load(
                 into memory. Defaults to False.
             strict (bool): If True, raises an exception if weights do not match.
                 Defaults to True.
+            use_selective_download (bool): If True, only download necessary weight files
+                from Hugging Face. Defaults to True.
         Returns:
             A tuple containing the loaded sharded MLX model and its configuration dictionary.
         """
-        model_path = get_model_path(self.model_path_str)[0]
+        if use_selective_download and self.start_layer is not None and self.end_layer is not None:
+            from parallax.utils.selective_download import (
+                get_model_path_with_selective_download,
+            )
+
+            logger.info(
+                f"Using selective download for layers [{self.start_layer}, {self.end_layer})"
+            )
+            model_path = get_model_path_with_selective_download(
+                self.model_path_str,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+            )
+        else:
+            model_path = get_model_path(self.model_path_str)[0]
+
         config = load_config(model_path)
         tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))

@@ -157,6 +174,24 @@ def load(
         if not weight_files:
             weight_files = glob.glob(str(model_path / "weight*.safetensors"))

+        # Sort weight files by name for consistent loading order
+        weight_files = sorted(weight_files)
+
+        # Use shared utility to filter weight files
+        from parallax.utils.weight_filter_utils import (
+            filter_weight_files_by_layer_range_for_load,
+        )
+
+        weight_files = filter_weight_files_by_layer_range_for_load(
+            model_path=model_path,
+            weight_files=weight_files,
+            start_layer=current_start_layer,
+            end_layer=current_end_layer,
+            is_first_shard=model_shard.is_first_shard,
+            is_last_shard=model_shard.is_last_shard,
+            config=config,
+        )
+
         if not weight_files and strict:
             raise FileNotFoundError(f"No safetensors found in {model_path}")

@@ -165,8 +200,11 @@ def load(
         shard_weights = {}
         layer_key_prefix = "model.layers"  # Common prefix

-        for wf in weight_files:
-            # For bf16 models, we need torch tensors as a bridge
+        for file_idx, wf in enumerate(weight_files):
+            logger.debug(
+                f"Scanning weight file {file_idx + 1}/{len(weight_files)}: {pathlib.Path(wf).name}"
+            )
+
             with safetensors.safe_open(wf, framework="pt") as f:
                 for key in f.keys():
                     is_needed = False
@@ -215,7 +253,7 @@ def load(
                         shard_weights[remapped_key] = mx.array(f.get_tensor(key))

         if (quantization := config.get("quantization", None)) is not None:
-            logger.info("Model is quantized. Applying quantization parameters...")
+            logger.debug("Model is quantized. Applying quantization parameters...")

             def class_predicate(p, m):
                 # Handle custom per-layer quantizations from the config
@@ -232,10 +270,6 @@ def class_predicate(p, m):
                 prefixed = f"model.{p}"
                 if prefixed in qcfg:
                     override = qcfg[prefixed]
-                    if isinstance(override, dict):
-                        logger.debug(
-                            f"[quantize] Using override for '{prefixed}' (mapped to '{p}'): bits={override.get('bits')} group_size={override.get('group_size')}"
-                        )
                     return override
                 if not hasattr(m, "to_quantized"):
                     return False
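`filter_weight_files_by_layer_range_for_load` lives in the new `parallax/utils/weight_filter_utils.py`, which is not shown in this diff. The usual way to map layers to shard files is the `weight_map` inside `model.safetensors.index.json`; a rough sketch under that assumption (the helper name and edge-case handling below are illustrative):

```python
# Rough sketch: the real helper also handles first/last shards, tied
# embeddings, and models without an index file.
import json
from pathlib import Path
from typing import List


def filter_files_for_layer_range(
    model_path: Path, weight_files: List[str], start_layer: int, end_layer: int
) -> List[str]:
    index_file = model_path / "model.safetensors.index.json"
    if not index_file.exists():
        return weight_files  # single-file models: nothing to filter

    weight_map = json.loads(index_file.read_text())["weight_map"]
    needed = set()
    for tensor_name, shard_file in weight_map.items():
        if tensor_name.startswith("model.layers."):
            layer_idx = int(tensor_name.split(".")[2])
            if start_layer <= layer_idx < end_layer:
                needed.add(shard_file)
        else:
            # embeddings, final norm, lm_head, etc.
            needed.add(shard_file)
    return [wf for wf in weight_files if Path(wf).name in needed]
```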
26 changes: 22 additions & 4 deletions src/parallax/sglang/model_runner.py
@@ -11,7 +11,7 @@
 import sglang
 import sglang.srt.distributed.parallel_state
 import torch
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import load_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
@@ -35,6 +35,9 @@
 )

 from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch
+from parallax.sglang.monkey_patch_utils.weight_loader_filter import (
+    set_layer_range_for_filtering,
+)
 from parallax.utils.tokenizer_utils import load_tokenizer

 logger = logging.getLogger(__name__)
@@ -67,6 +70,9 @@ def __init__(
         """Add pp_start_layer and pp_end_layer for decentralized model"""
         self.pp_start_layer = pp_start_layer
         self.pp_end_layer = pp_end_layer
+        num_hidden_layers = model_config.hf_config.num_hidden_layers
+        set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers)
+
         super().__init__(
             model_config=model_config,
             mem_fraction_static=mem_fraction_static,
@@ -230,7 +236,19 @@ def initialize_sgl_model_runner(
         - tokenizer: tokenizer driven by mlx-lm
     """
     apply_parallax_sglang_monkey_patch()
-    model_path = get_model_path(original_model_path)[0]
+
+    # Use selective download for GPU models to save bandwidth and disk space
+    from parallax.utils.selective_download import get_model_path_with_selective_download
+
+    logger.info(
+        f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})"
+    )
+    model_path = get_model_path_with_selective_download(
+        original_model_path,
+        start_layer=start_layer,
+        end_layer=end_layer,
+    )
+
     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")
@@ -251,7 +269,7 @@ def initialize_sgl_model_runner(
     kv_block_size = 1

     server_args = form_sgl_server_args(
-        original_model_path,
+        str(model_path),
         dtype,
         attention_backend,
         kv_block_size,
@@ -262,7 +280,7 @@ def initialize_sgl_model_runner(
     if (quantization_config := config.get("quantization_config", None)) is not None:
         quant_method = quantization_config.get("quant_method")
         model_config = ModelConfig(
-            model_path=original_model_path,
+            model_path=str(model_path),
             model_override_args="{}",
             dtype=dtype,
             quantization=quant_method,
8 changes: 6 additions & 2 deletions src/parallax/sglang/monkey_patch.py
@@ -17,15 +17,19 @@
 from parallax.sglang.monkey_patch_utils.triton_backend import (
     apply_triton_backend_init_monkey_patch,
 )
+from parallax.sglang.monkey_patch_utils.weight_loader_filter import (
+    apply_weight_loader_filter_patch,
+)


 ## Here is some patch func for sglang
 ## Hopefully, when sglang support pipeline parallelism natively, we can remove these patches
 def apply_parallax_sglang_monkey_patch():
+    apply_model_parallel_monkey_patch()
+    apply_triton_backend_init_monkey_patch()
+    apply_weight_loader_filter_patch()
     apply_qwen3_next_monkey_patch()
     apply_qwen3_next_config_monkey_patch()
     apply_gpt_oss_monkey_patch()
     apply_minimax_m2_monkey_patch()
     apply_glm4_moe_monkey_patch()
-    apply_triton_backend_init_monkey_patch()
-    apply_model_parallel_monkey_patch()
70 changes: 70 additions & 0 deletions src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py
@@ -0,0 +1,70 @@
+import logging
+from pathlib import Path
+from typing import List
+
+from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range
+
+logger = logging.getLogger(__name__)
+
+_layer_range_cache = {}
+
+
+def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, num_hidden_layers: int):
+    global _layer_range_cache
+    _layer_range_cache["pp_start_layer"] = pp_start_layer
+    _layer_range_cache["pp_end_layer"] = pp_end_layer
+    _layer_range_cache["num_hidden_layers"] = num_hidden_layers
+
+
+def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]:
+    global _layer_range_cache
+
+    pp_start_layer = _layer_range_cache.get("pp_start_layer")
+    pp_end_layer = _layer_range_cache.get("pp_end_layer")
+    num_hidden_layers = _layer_range_cache.get("num_hidden_layers")
+
+    if pp_start_layer is None or pp_end_layer is None:
+        logger.debug("No layer range set, loading all weight files")
+        return hf_weights_files
+
+    if not hf_weights_files:
+        return hf_weights_files
+
+    model_path = Path(hf_weights_files[0]).parent
+    is_first_shard = pp_start_layer == 0
+    is_last_shard = pp_end_layer >= num_hidden_layers
+
+    filtered_files = filter_weight_files_by_layer_range(
+        model_path=model_path,
+        weight_files=hf_weights_files,
+        pp_start_layer=pp_start_layer,
+        pp_end_layer=pp_end_layer,
+        is_first_shard=is_first_shard,
+        is_last_shard=is_last_shard,
+    )
+
+    return filtered_files
+
+
+def apply_weight_loader_filter_patch():
+    import glob as glob_module
+
+    original_glob = glob_module.glob
+
+    def patched_glob(pathname, **kwargs):
+        files = original_glob(pathname, **kwargs)
+        if (
+            isinstance(files, list)
+            and files
+            and any(f.endswith((".safetensors", ".bin", ".pt")) for f in files)
+        ):
+
+            # Filter if we have layer range set
+            global _layer_range_cache
+            if _layer_range_cache.get("pp_start_layer") is not None:
+                filtered = _filter_weight_files_by_cache(files)
+                return filtered
+
+        return files
+
+    glob_module.glob = patched_glob
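Taken together, the patch works in two steps: `set_layer_range_for_filtering` records the shard bounds in module state before sglang's loader runs, and `patched_glob` intercepts every glob over weight files and drops shards outside the range. A hypothetical usage demo (model path and layer numbers are made up):

```python
# Hypothetical demo of the patch's effect.
import glob

from parallax.sglang.monkey_patch_utils.weight_loader_filter import (
    apply_weight_loader_filter_patch,
    set_layer_range_for_filtering,
)

apply_weight_loader_filter_patch()
set_layer_range_for_filtering(pp_start_layer=0, pp_end_layer=24, num_hidden_layers=48)

# Any subsequent glob over weight files, e.g. inside sglang's model loader,
# now returns only the shards containing layers [0, 24).
shards = glob.glob("/models/Qwen3-30B-A3B/*.safetensors")
```

Replacing the global `glob.glob` is a broad hook; the extension check and the cache guard are what keep it from affecting unrelated globs.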