
[https://nvbugs/6015329][fix] Use model-level warmup cache key for visual gen pipelines#12516

Merged
karljang merged 2 commits into NVIDIA:main from karljang:feat/warmup-cache-key
Mar 26, 2026

Conversation


@karljang karljang commented Mar 24, 2026

Summary

  • Add warmup_cache_key() method to BasePipeline so each model controls which dimensions matter for warmup cache matching
  • FLUX (image-only) overrides to (height, width), ignoring num_frames
  • Video models (WAN, LTX2) use the default (height, width, num_frames)
  • Executor uses pipeline.warmup_cache_key() instead of hardcoded (h, w, num_frames) tuple
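The design in the bullets above can be sketched as follows. This is a hypothetical illustration, not the actual TensorRT-LLM code: only the warmup_cache_key() method name comes from the PR; the class names and signatures here are stand-ins.

```python
# Sketch of the model-level cache key design (illustrative class names).

class BasePipeline:
    def warmup_cache_key(self, height: int, width: int, num_frames: int) -> tuple:
        # Default: video models (WAN, LTX2) differentiate warmup
        # by all three dimensions.
        return (height, width, num_frames)


class FluxPipeline(BasePipeline):
    def warmup_cache_key(self, height: int, width: int, **kwargs) -> tuple:
        # FLUX is image-only: num_frames does not affect its computation,
        # so it is absorbed by **kwargs and excluded from the key.
        return (height, width)


base = BasePipeline()
flux = FluxPipeline()
assert base.warmup_cache_key(512, 512, num_frames=16) == (512, 512, 16)
assert flux.warmup_cache_key(512, 512, num_frames=16) == (512, 512)
```

With this shape, adding a new model only requires overriding one method rather than touching the executor's cache logic.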

Motivation

FLUX is image-only — num_frames has no effect on its computation. But the warmup cache key included num_frames, so requests with a non-default num_frames value at the same resolution would trigger a false "not warmed up" warning and potentially unnecessary recompilation.
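The false cache miss can be shown concretely. The values below are hypothetical; the point is only that the old three-element key misses on a num_frames change while the new two-element FLUX key hits.

```python
# Old behavior: the warmup cache keyed every model, including image-only
# FLUX, on (height, width, num_frames).
warmed_up_shapes = {(1024, 1024, 1)}  # warmed with the default num_frames

request = (1024, 1024, 16)  # same resolution, non-default num_frames
# Key mismatch despite identical FLUX computation -> spurious warning.
assert request not in warmed_up_shapes

# New behavior: FLUX's key drops num_frames, so the lookup hits.
warmed_keys = {(h, w) for (h, w, _) in warmed_up_shapes}
assert (request[0], request[1]) in warmed_keys
```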

Test plan

  • Verify FLUX warmup with different num_frames values no longer triggers warning
  • Verify WAN/LTX2 warmup still differentiates by num_frames

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Refactor
    • Optimized the caching mechanism for visual generation models by introducing a unified cache key computation system. Different model implementations can now customize their caching strategy based on spatial resolution, improving cache efficiency across pipelines.

@karljang karljang requested a review from a team as a code owner March 24, 2026 21:14

coderabbitai bot commented Mar 24, 2026

📝 Walkthrough


This change introduces a warmup_cache_key() method on the base pipeline class to compute cache keys for warmup shape tracking. Specific pipeline implementations override this method to customize which dimensions are included in the key. The executor was updated to use the derived cache key instead of hardcoded tuples.

Changes

  • Base Pipeline Infrastructure (tensorrt_llm/_torch/visual_gen/pipeline.py): Added warmup_cache_key() method to centralize cache key computation. Updated the _warmed_up_shapes type annotation from Set[Tuple[int, int, int]] to Set[tuple] to support flexible key formats. Warmup tracking now stores computed cache keys instead of raw tuples.
  • FLUX Model Implementations (tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py, tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py): Added a warmup_cache_key() override to both FLUX pipelines, returning (height, width) only and ignoring the num_frames dimension in the cache key.
  • Executor Integration (tensorrt_llm/_torch/visual_gen/executor.py): Updated process_request() to compute the cache key via self.pipeline.warmup_cache_key() and check membership in _warmed_up_shapes. The warning log message now reports the computed cache key instead of individual request dimensions.
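The executor integration can be sketched as follows. Everything here except warmup_cache_key() and _warmed_up_shapes is a hypothetical stand-in (check_warmup, FakeFluxPipeline, and the logger are illustrative, not the actual executor code).

```python
import logging

logger = logging.getLogger("executor")


def check_warmup(pipeline, height, width, num_frames):
    """Return True if the request's shape was covered by warmup."""
    # Pass num_frames by keyword so overrides using **kwargs can absorb it.
    cache_key = pipeline.warmup_cache_key(height, width, num_frames=num_frames)
    if cache_key not in pipeline._warmed_up_shapes:
        logger.warning("Shape %s not warmed up; may recompile", cache_key)
        return False
    return True


class FakeFluxPipeline:
    _warmed_up_shapes = {(512, 512)}

    def warmup_cache_key(self, height, width, **kwargs):
        return (height, width)


pipe = FakeFluxPipeline()
assert check_warmup(pipe, 512, 512, num_frames=8)       # hit despite num_frames
assert not check_warmup(pipe, 768, 768, num_frames=1)   # genuine miss
```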

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 58.33%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check (⚠️ Warning): The PR description clearly explains the issue, solution, and test plan, but is missing several required template sections, including PR title format, Test Coverage details, and the PR Checklist. Resolution: add a properly formatted PR title following the [type] Summary format, expand the Test Coverage section with specific test cases, and complete the PR Checklist to confirm adherence to coding guidelines and test requirements.
✅ Passed checks (1 passed)
  • Title check (✅ Passed): The title clearly describes the main change, introducing model-level warmup cache keys for visual generation pipelines, which directly reflects the core modification across all changed files.




@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tensorrt_llm/_torch/visual_gen/pipeline.py (1)

443-448: Consider deduplicating warmup runs by warmup_cache_key to avoid redundant work.

When multiple warmup shapes map to the same key (e.g., FLUX with different num_frames), startup may run duplicate warmups.

♻️ Suggested refactor
-        for height, width, num_frames in shapes:
-            logger.info(f"Warmup: {height}x{width}, {num_frames} frames, {steps} steps")
-            self._run_warmup(height, width, num_frames, steps)
-            torch.cuda.synchronize()
-
-        self._warmed_up_shapes = set(self.warmup_cache_key(h, w, f) for h, w, f in shapes)
+        warmed_keys = set()
+        for height, width, num_frames in shapes:
+            cache_key = self.warmup_cache_key(height, width, num_frames)
+            if cache_key in warmed_keys:
+                continue
+            logger.info(f"Warmup: {height}x{width}, {num_frames} frames, {steps} steps")
+            self._run_warmup(height, width, num_frames, steps)
+            torch.cuda.synchronize()
+            warmed_keys.add(cache_key)
+
+        self._warmed_up_shapes = warmed_keys
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/pipeline.py` around lines 443 - 448, Multiple
warmup shapes can map to the same warmup_cache_key causing duplicate warmup
runs; deduplicate by computing warmup_cache_key(h,w,f) for each shape first,
keep one representative (e.g., first) per key, then iterate only over those
unique representatives to call self._run_warmup(height, width, num_frames,
steps) and torch.cuda.synchronize(), and finally set self._warmed_up_shapes to
the set of keys (using warmup_cache_key) that were actually run; refer to
warmup_cache_key, _run_warmup, and _warmed_up_shapes to locate where to perform
the deduplication.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f27d8d6a-4409-48c8-afc8-b6192e4c5e88

📥 Commits

Reviewing files that changed from the base of the PR and between 94175a8 and eb5c714.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/visual_gen/executor.py
  • tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
  • tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
  • tensorrt_llm/_torch/visual_gen/pipeline.py

@karljang karljang changed the title [None][fix] Use model-level warmup cache key for visual gen pipelines [ http://nvbugs/6015329][fix] Use model-level warmup cache key for visual gen pipelines Mar 24, 2026
@karljang karljang marked this pull request as draft March 24, 2026 22:10

@chang-l chang-l left a comment


LGTM

@karljang karljang force-pushed the feat/warmup-cache-key branch from eb5c714 to e1e0b9e Compare March 24, 2026 22:32
@karljang

/bot run

@karljang karljang marked this pull request as ready for review March 24, 2026 22:36
@karljang karljang changed the title [ http://nvbugs/6015329][fix] Use model-level warmup cache key for visual gen pipelines [ https://nvbugs/6015329][fix] Use model-level warmup cache key for visual gen pipelines Mar 24, 2026
@karljang karljang changed the title [ https://nvbugs/6015329][fix] Use model-level warmup cache key for visual gen pipelines [https://nvbugs/6015329][fix] Use model-level warmup cache key for visual gen pipelines Mar 24, 2026
@tensorrt-cicd

PR_Github #40177 [ run ] triggered by Bot. Commit: e1e0b9e Link to invocation

@tensorrt-cicd

PR_Github #40177 [ run ] completed with state SUCCESS. Commit: e1e0b9e
/LLM/main/L0_MergeRequest_PR pipeline #31320 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@karljang karljang force-pushed the feat/warmup-cache-key branch from e1e0b9e to 7f2e2ea Compare March 25, 2026 03:26
FLUX is image-only but the warmup cache checked (height, width,
num_frames), causing false cache misses when num_frames differed from
the default. Add a warmup_cache_key() method that each pipeline can
override — FLUX returns (height, width), video models keep the default
(height, width, num_frames).

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
FLUX overrides warmup_cache_key with **kwargs to ignore num_frames.
Callers must pass num_frames as a keyword arg so **kwargs can absorb
it — positional args fail with "takes 3 positional arguments but 4
were given".

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
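The keyword-argument pitfall described in the commit message can be reproduced in isolation. The class below is a minimal stand-in, not the actual FLUX pipeline; only the warmup_cache_key() signature shape is taken from the description.

```python
class FluxLike:
    def warmup_cache_key(self, height, width, **kwargs):
        # **kwargs silently absorbs num_frames only when passed by keyword.
        return (height, width)


flux = FluxLike()

# Keyword call works: num_frames lands in **kwargs and is ignored.
assert flux.warmup_cache_key(512, 512, num_frames=16) == (512, 512)

# Positional call raises TypeError:
# "warmup_cache_key() takes 3 positional arguments but 4 were given"
try:
    flux.warmup_cache_key(512, 512, 16)
    raise AssertionError("expected TypeError")
except TypeError:
    pass
```

This is why the executor must call pipeline.warmup_cache_key(height, width, num_frames=num_frames) rather than passing all three values positionally.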
@karljang karljang force-pushed the feat/warmup-cache-key branch from 7f2e2ea to 0e917ab Compare March 25, 2026 03:39
@karljang

/bot run

@tensorrt-cicd

PR_Github #40232 [ run ] triggered by Bot. Commit: 0e917ab Link to invocation

@tensorrt-cicd

PR_Github #40232 [ run ] completed with state FAILURE. Commit: 0e917ab
/LLM/main/L0_MergeRequest_PR pipeline #31368 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@karljang

/bot help

@github-actions

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@karljang

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #40348 [ run ] triggered by Bot. Commit: 0e917ab Link to invocation

@karljang karljang enabled auto-merge (squash) March 25, 2026 23:03
@tensorrt-cicd

PR_Github #40348 [ run ] completed with state FAILURE. Commit: 0e917ab
/LLM/main/L0_MergeRequest_PR pipeline #31452 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@karljang

/bot run

@tensorrt-cicd

PR_Github #40417 [ run ] triggered by Bot. Commit: 0e917ab Link to invocation

@tensorrt-cicd

PR_Github #40417 [ run ] completed with state SUCCESS. Commit: 0e917ab
/LLM/main/L0_MergeRequest_PR pipeline #31510 completed with status: 'SUCCESS'

CI Report

Link to invocation

@karljang karljang merged commit 9cc7584 into NVIDIA:main Mar 26, 2026
5 checks passed
@karljang karljang deleted the feat/warmup-cache-key branch March 26, 2026 15:14