From 9499853b10f76a245098ca86f9907f6eb63ef665 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 5 Nov 2025 21:11:42 +0800 Subject: [PATCH 1/2] fix offline bug --- .../weight_loader_filter.py | 4 +- src/parallax/utils/selective_download.py | 80 ++++++++++++------- 2 files changed, 52 insertions(+), 32 deletions(-) diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 5cc529f..48e1823 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import List -from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range +from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range_for_load logger = logging.getLogger(__name__) @@ -34,7 +34,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: is_first_shard = pp_start_layer == 0 is_last_shard = pp_end_layer >= num_hidden_layers - filtered_files = filter_weight_files_by_layer_range( + filtered_files = filter_weight_files_by_layer_range_for_load( model_path=model_path, weight_files=hf_weights_files, pp_start_layer=pp_start_layer, diff --git a/src/parallax/utils/selective_download.py b/src/parallax/utils/selective_download.py index 8a98617..7e8f22a 100644 --- a/src/parallax/utils/selective_download.py +++ b/src/parallax/utils/selective_download.py @@ -25,6 +25,11 @@ def download_metadata_only( cache_dir: Optional[str] = None, force_download: bool = False, ) -> Path: + # If a local path is provided, return it directly without contacting HF Hub + local_path = Path(repo_id) + if local_path.exists(): + return local_path + path = snapshot_download( repo_id=repo_id, cache_dir=cache_dir, @@ -41,14 +46,21 @@ def selective_model_download( cache_dir: Optional[str] = None, force_download: bool = False, ) -> Path: - logger.debug(f"Downloading model metadata for {repo_id}") - - model_path = download_metadata_only( - repo_id=repo_id, - cache_dir=cache_dir, - force_download=force_download, - ) - logger.debug(f"Downloaded model metadata to {model_path}") + # Handle local model directory + local_path = Path(repo_id) + if local_path.exists(): + model_path = local_path + logger.debug(f"Using local model path: {model_path}") + is_remote = False + else: + logger.debug(f"Downloading model metadata for {repo_id}") + model_path = download_metadata_only( + repo_id=repo_id, + cache_dir=cache_dir, + force_download=force_download, + ) + logger.debug(f"Downloaded model metadata to {model_path}") + is_remote = True if start_layer is not None and end_layer is not None: logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})") @@ -59,34 +71,42 @@ def selective_model_download( end_layer=end_layer, ) - if not needed_weight_files: - logger.debug("Could not determine specific weight files, downloading all") + if is_remote: + if not needed_weight_files: + logger.debug("Could not determine specific weight files, downloading all") + snapshot_download( + repo_id=repo_id, + cache_dir=cache_dir, + force_download=force_download, + ) + else: + # Step 3: Download only the needed weight files + logger.info(f"Downloading {len(needed_weight_files)} weight files") + + for weight_file in needed_weight_files: + logger.debug(f"Downloading {weight_file}") + hf_hub_download( + repo_id=repo_id, + filename=weight_file, + cache_dir=cache_dir, + force_download=force_download, + ) + + logger.debug(f"Downloaded weight files for layers [{start_layer}, {end_layer})") + else: + # Local path: skip any downloads + logger.debug("Local model path detected; skipping remote weight downloads") + else: + # No layer range specified + if is_remote: + logger.debug("No layer range specified, downloading all model files") snapshot_download( repo_id=repo_id, cache_dir=cache_dir, force_download=force_download, ) else: - # Step 3: Download only the needed weight files - logger.info(f"Downloading {len(needed_weight_files)} weight files") - - for weight_file in needed_weight_files: - logger.debug(f"Downloading {weight_file}") - hf_hub_download( - repo_id=repo_id, - filename=weight_file, - cache_dir=cache_dir, - force_download=force_download, - ) - - logger.debug(f"Downloaded weight files for layers [{start_layer}, {end_layer})") - else: - logger.debug("No layer range specified, downloading all model files") - snapshot_download( - repo_id=repo_id, - cache_dir=cache_dir, - force_download=force_download, - ) + logger.debug("No layer range specified and using local path; nothing to download") return model_path From a7492b61eb3db51a5b4dc8a97f5aef6361ee1dd9 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Wed, 5 Nov 2025 21:13:16 +0800 Subject: [PATCH 2/2] pre-commit --- .../sglang/monkey_patch_utils/weight_loader_filter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 48e1823..029db28 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -2,7 +2,9 @@ from pathlib import Path from typing import List -from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range_for_load +from parallax.utils.weight_filter_utils import ( + filter_weight_files_by_layer_range_for_load, +) logger = logging.getLogger(__name__)