Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from pathlib import Path
from typing import List

from parallax.utils.weight_filter_utils import filter_weight_files_by_layer_range
from parallax.utils.weight_filter_utils import (
filter_weight_files_by_layer_range_for_load,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -34,7 +36,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]:
is_first_shard = pp_start_layer == 0
is_last_shard = pp_end_layer >= num_hidden_layers

filtered_files = filter_weight_files_by_layer_range(
filtered_files = filter_weight_files_by_layer_range_for_load(
model_path=model_path,
weight_files=hf_weights_files,
pp_start_layer=pp_start_layer,
Expand Down
80 changes: 50 additions & 30 deletions src/parallax/utils/selective_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def download_metadata_only(
cache_dir: Optional[str] = None,
force_download: bool = False,
) -> Path:
# If a local path is provided, return it directly without contacting HF Hub
local_path = Path(repo_id)
if local_path.exists():
return local_path

path = snapshot_download(
repo_id=repo_id,
cache_dir=cache_dir,
Expand All @@ -41,14 +46,21 @@ def selective_model_download(
cache_dir: Optional[str] = None,
force_download: bool = False,
) -> Path:
logger.debug(f"Downloading model metadata for {repo_id}")

model_path = download_metadata_only(
repo_id=repo_id,
cache_dir=cache_dir,
force_download=force_download,
)
logger.debug(f"Downloaded model metadata to {model_path}")
# Handle local model directory
local_path = Path(repo_id)
if local_path.exists():
model_path = local_path
logger.debug(f"Using local model path: {model_path}")
is_remote = False
else:
logger.debug(f"Downloading model metadata for {repo_id}")
model_path = download_metadata_only(
repo_id=repo_id,
cache_dir=cache_dir,
force_download=force_download,
)
logger.debug(f"Downloaded model metadata to {model_path}")
is_remote = True

if start_layer is not None and end_layer is not None:
logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})")
Expand All @@ -59,34 +71,42 @@ def selective_model_download(
end_layer=end_layer,
)

if not needed_weight_files:
logger.debug("Could not determine specific weight files, downloading all")
if is_remote:
if not needed_weight_files:
logger.debug("Could not determine specific weight files, downloading all")
snapshot_download(
repo_id=repo_id,
cache_dir=cache_dir,
force_download=force_download,
)
else:
# Step 3: Download only the needed weight files
logger.info(f"Downloading {len(needed_weight_files)} weight files")

for weight_file in needed_weight_files:
logger.debug(f"Downloading {weight_file}")
hf_hub_download(
repo_id=repo_id,
filename=weight_file,
cache_dir=cache_dir,
force_download=force_download,
)

logger.debug(f"Downloaded weight files for layers [{start_layer}, {end_layer})")
else:
# Local path: skip any downloads
logger.debug("Local model path detected; skipping remote weight downloads")
else:
# No layer range specified
if is_remote:
logger.debug("No layer range specified, downloading all model files")
snapshot_download(
repo_id=repo_id,
cache_dir=cache_dir,
force_download=force_download,
)
else:
# Step 3: Download only the needed weight files
logger.info(f"Downloading {len(needed_weight_files)} weight files")

for weight_file in needed_weight_files:
logger.debug(f"Downloading {weight_file}")
hf_hub_download(
repo_id=repo_id,
filename=weight_file,
cache_dir=cache_dir,
force_download=force_download,
)

logger.debug(f"Downloaded weight files for layers [{start_layer}, {end_layer})")
else:
logger.debug("No layer range specified, downloading all model files")
snapshot_download(
repo_id=repo_id,
cache_dir=cache_dir,
force_download=force_download,
)
logger.debug("No layer range specified and using local path; nothing to download")

return model_path

Expand Down