Skip to content

[TRTLLM-12527][feat] Parallelize multi-shard visual-gen checkpoint loading#14021

Open
yibinl-nvidia wants to merge 1 commit into
NVIDIA:mainfrom
yibinl-nvidia:dev-yibinl-TRT-12527-part-2
Open

[TRTLLM-12527][feat] Parallelize multi-shard visual-gen checkpoint loading#14021
yibinl-nvidia wants to merge 1 commit into
NVIDIA:mainfrom
yibinl-nvidia:dev-yibinl-TRT-12527-part-2

Conversation

@yibinl-nvidia
Copy link
Copy Markdown
Collaborator

@yibinl-nvidia yibinl-nvidia commented May 12, 2026

Summary by CodeRabbit

  • Refactor
    • Optimized weight loading for visual generation models to process sharded files concurrently, improving performance when multiple weight files are present.
    • Enhanced error reporting to clearly identify which weight file shard failed during loading.

Review Change Stack

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@yibinl-nvidia yibinl-nvidia requested a review from a team as a code owner May 12, 2026 02:42
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 12, 2026

📝 Walkthrough

Walkthrough

WeightLoader's load_weights method now delegates per-component weight file loading to a new _load_weight_files helper. The helper uses ThreadPoolExecutor to load multiple weight shards concurrently when present, falling back to sequential loading for single-file checkpoints. Shard-load failures are caught and re-raised as RuntimeError with the failing file name.

Changes

Concurrent Weight File Loading

Layer / File(s) Summary
Concurrent weight file loading with thread pool
tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py
Added ThreadPoolExecutor and as_completed imports; refactored load_weights to delegate per-component loading to new _load_weight_files helper that selects between sequential loading (0–1 files) and concurrent loading via ThreadPoolExecutor (2+ files), aggregates tensor dicts, and wraps per-shard exceptions into RuntimeError with file identifiers.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ⚠️ Warning The PR description consists entirely of the template with empty/unfilled sections and no substantive content provided by the author. Fill in the Description section explaining the issue and solution, provide Test Coverage details listing relevant tests, and ensure the PR Checklist items are reviewed and checked as appropriate.
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly describes the main change: parallelizing multi-shard visual-gen checkpoint loading, which matches the implementation details in the raw summary.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py (2)

125-128: 💤 Low value

Consider catching more specific exceptions.

Per coding guidelines, avoid broad exception handling. While catching Exception here is a common pattern for future.result() (to add context via from exc), you could narrow the scope to the likely exceptions from file I/O and tensor loading.

♻️ Suggested narrower exception handling
                 try:
                     loaded = future.result()
-                except Exception as exc:
+                except (OSError, RuntimeError, ValueError) as exc:
                     raise RuntimeError(f"Failed to load weight file {wf}") from exc

Note: If there are other exception types from safetensors.load_file or torch.load that should be caught, add them to the tuple.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py` around lines 125
- 128, The try/except around future.result() in weight_loader.py is too broad;
replace the blanket "except Exception" with a narrowed tuple of likely errors
(e.g., OSError, IOError, RuntimeError, ValueError and any specific exceptions
thrown by safetensors.load_file or torch.load) so only expected file/I/O or
tensor-loading failures are caught, still re-raising with RuntimeError(f"Failed
to load weight file {wf}") from exc to preserve context; add any additional
safetensors/torch-specific exception types to the tuple if discovered.

21-22: ⚡ Quick win

Outdated docstring contradicts new functionality.

The docstring says "no parallel loading optimization for now" but this PR adds parallel loading via ThreadPoolExecutor. Update the description to reflect the new concurrent loading capability.

📝 Suggested docstring update
     """
     Weight loader for diffusion models.

     Loads weights from safetensors/bin files, similar to HfWeightLoader
-    but simpler (no parallel loading optimization for now).
+    with parallel loading for multi-shard checkpoints.

     Supports loading multiple components (e.g., transformer and transformer_2):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py` around lines 21
- 22, The module docstring in
tensorrt_llm._torch.visual_gen.checkpoints.weight_loader still claims "no
parallel loading optimization for now" but the code now uses ThreadPoolExecutor
for concurrent loads; update the top-level docstring (and any class docstring
for HfWeightLoader or WeightLoader) to reflect that the loader supports
parallel/concurrent loading using ThreadPoolExecutor (mention ability to
configure worker count if applicable) and remove the outdated "no parallel
loading" statement so the description matches the implemented behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py`:
- Around line 91-93: Reformat the long lines flagged by ruff by wrapping
arguments and breaking the calls into multiple indented lines: split the call to
self._load_weight_files(weight_files, component, is_pipeline) across lines, wrap
the logger.info(...) parameters onto separate lines, and reflow the
tqdm.tqdm(as_completed(...)) invocation so its inner as_completed(...) call and
outer tqdm parameters are on separate lines; run ruff format on
tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py to apply consistent
formatting for the functions/methods _load_weight_files, logger.info, and
tqdm.tqdm(as_completed(...)).
- Line 1: Add the required NVIDIA copyright header at the top of the module
before the module-level docstring in the weight_loader.py file (i.e., the
module-level docstring in
tensorrt_llm._torch.visual_gen.checkpoints.weight_loader); insert the
current-year NVIDIA boilerplate header exactly as used in other project files,
then keep the existing docstring and code unchanged below it.

---

Nitpick comments:
In `@tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py`:
- Around line 125-128: The try/except around future.result() in weight_loader.py
is too broad; replace the blanket "except Exception" with a narrowed tuple of
likely errors (e.g., OSError, IOError, RuntimeError, ValueError and any specific
exceptions thrown by safetensors.load_file or torch.load) so only expected
file/I/O or tensor-loading failures are caught, still re-raising with
RuntimeError(f"Failed to load weight file {wf}") from exc to preserve context;
add any additional safetensors/torch-specific exception types to the tuple if
discovered.
- Around line 21-22: The module docstring in
tensorrt_llm._torch.visual_gen.checkpoints.weight_loader still claims "no
parallel loading optimization for now" but the code now uses ThreadPoolExecutor
for concurrent loads; update the top-level docstring (and any class docstring
for HfWeightLoader or WeightLoader) to reflect that the loader supports
parallel/concurrent loading using ThreadPoolExecutor (mention ability to
configure worker count if applicable) and remove the outdated "no parallel
loading" statement so the description matches the implemented behavior.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d1f9a9ae-cf2c-4d5d-a448-d13f1747813f

📥 Commits

Reviewing files that changed from the base of the PR and between 64260ba and a050e7d.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py

Comment thread tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py
Comment thread tensorrt_llm/_torch/visual_gen/checkpoints/weight_loader.py Outdated
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
@yibinl-nvidia yibinl-nvidia force-pushed the dev-yibinl-TRT-12527-part-2 branch from a050e7d to efb290f Compare May 12, 2026 05:53
@yibinl-nvidia
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47891 [ run ] triggered by Bot. Commit: efb290f Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47891 [ run ] completed with state SUCCESS. Commit: efb290f
/LLM/main/L0_MergeRequest_PR pipeline #37742 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants