[TRTLLM-12527][feat] Parallelize multi-shard visual-gen checkpoint loading #14021
yibinl-nvidia wants to merge 1 commit into
Conversation
📝 Walkthrough
Changes to WeightLoader: Concurrent Weight File Loading
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 2
🧹 Nitpick comments (2)
tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py (2)
125-128: 💤 Low value: Consider catching more specific exceptions.
Per coding guidelines, avoid broad exception handling. While catching `Exception` here is a common pattern for `future.result()` (to add context via `from exc`), you could narrow the scope to the likely exceptions from file I/O and tensor loading.
♻️ Suggested narrower exception handling:

```diff
     try:
         loaded = future.result()
-    except Exception as exc:
+    except (OSError, RuntimeError, ValueError) as exc:
         raise RuntimeError(f"Failed to load weight file {wf}") from exc
```

Note: If there are other exception types from `safetensors.load_file` or `torch.load` that should be caught, add them to the tuple.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py` around lines 125 - 128, The try/except around future.result() in weight_loader.py is too broad; replace the blanket "except Exception" with a narrowed tuple of likely errors (e.g., OSError, IOError, RuntimeError, ValueError and any specific exceptions thrown by safetensors.load_file or torch.load) so only expected file/I/O or tensor-loading failures are caught, still re-raising with RuntimeError(f"Failed to load weight file {wf}") from exc to preserve context; add any additional safetensors/torch-specific exception types to the tuple if discovered.
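The narrowed handling above can be sketched end to end. This is a minimal, self-contained illustration of the pattern, not the PR's actual code: `load_shard` is a hypothetical stand-in for `safetensors.load_file`/`torch.load`, and the worker count is arbitrary.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def load_shard(path):
    # Hypothetical stand-in for safetensors.load_file / torch.load.
    if path.endswith(".bad"):
        raise OSError(f"cannot read {path}")
    return {f"{path}:weight": [0.0]}


def load_weight_files(weight_files, max_workers=4):
    """Load checkpoint shards concurrently, catching only expected failures."""
    weights = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        future_to_file = {pool.submit(load_shard, wf): wf for wf in weight_files}
        for future in as_completed(future_to_file):
            wf = future_to_file[future]
            try:
                # Narrowed tuple instead of a blanket `except Exception`.
                loaded = future.result()
            except (OSError, RuntimeError, ValueError) as exc:
                raise RuntimeError(f"Failed to load weight file {wf}") from exc
            weights.update(loaded)
    return weights
```

Any exception type raised inside a worker that is not in the tuple (e.g. a `KeyError`) would propagate unwrapped, which is exactly the trade-off the reviewer notes.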
21-22: ⚡ Quick win: Outdated docstring contradicts new functionality.
The docstring says "no parallel loading optimization for now" but this PR adds parallel loading via `ThreadPoolExecutor`. Update the description to reflect the new concurrent loading capability.
📝 Suggested docstring update:

```diff
 """
 Weight loader for diffusion models.
 Loads weights from safetensors/bin files, similar to HfWeightLoader
-but simpler (no parallel loading optimization for now).
+with parallel loading for multi-shard checkpoints.
 Supports loading multiple components (e.g., transformer and transformer_2):
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py` around lines 21 - 22, The module docstring in tensorrt_llm._torch.visual_gen.checkpoints.weight_loader still claims "no parallel loading optimization for now" but the code now uses ThreadPoolExecutor for concurrent loads; update the top-level docstring (and any class docstring for HfWeightLoader or WeightLoader) to reflect that the loader supports parallel/concurrent loading using ThreadPoolExecutor (mention ability to configure worker count if applicable) and remove the outdated "no parallel loading" statement so the description matches the implemented behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py`:
- Around line 91-93: Reformat the long lines flagged by ruff by wrapping
arguments and breaking the calls into multiple indented lines: split the call to
self._load_weight_files(weight_files, component, is_pipeline) across lines, wrap
the logger.info(...) parameters onto separate lines, and reflow the
tqdm.tqdm(as_completed(...)) invocation so its inner as_completed(...) call and
outer tqdm parameters are on separate lines; run ruff format on
tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py to apply consistent
formatting for the functions/methods _load_weight_files, logger.info, and
tqdm.tqdm(as_completed(...)).
- Line 1: Add the required NVIDIA copyright header at the top of the module
before the module-level docstring in the weight_loader.py file (i.e., the
module-level docstring in
tensorrt_llm._torch.visual_gen.checkpoints.weight_loader); insert the
current-year NVIDIA boilerplate header exactly as used in other project files,
then keep the existing docstring and code unchanged below it.
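The line-wrapping the ruff fix asks for can be sketched as follows. This is an assumed shape, not the PR's code: `load_shard` is hypothetical, and `tqdm` is imported with a no-op fallback so the sketch runs even where tqdm is not installed.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

try:
    from tqdm import tqdm
except ImportError:  # no-op fallback so the sketch runs without tqdm
    def tqdm(iterable, **kwargs):
        return iterable


def load_shard(path):
    # Hypothetical stand-in for the real per-file loader.
    return {path: 0}


def load_all(weight_files):
    with ThreadPoolExecutor(max_workers=2) as pool:
        futures = {pool.submit(load_shard, wf): wf for wf in weight_files}
        results = {}
        # Reflowed so the inner as_completed(...) call and the outer tqdm
        # parameters each sit on their own line, keeping lines short.
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Loading weight files",
        ):
            results.update(future.result())
    return results
```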
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: d1f9a9ae-cf2c-4d5d-a448-d13f1747813f
📒 Files selected for processing (1)
tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
Force-pushed from a050e7d to efb290f
/bot run
PR_Github #47891 [ run ] triggered by Bot. Commit:
PR_Github #47891 [ run ] completed with state
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.