Skip to content

Conversation

@MrGeva
Copy link
Collaborator

@MrGeva MrGeva commented Jan 15, 2026

When max_batch_size=1, the example sequence used during export had batch_size=1. This caused torch.export to specialize the batch dimension to a static value of 1 instead of keeping it dynamic, resulting in:

  ValueError: Found the following conflicts between user-specified ranges
  and inferred ranges from model tracing:
  - Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['input_ids'].shape[0].
  - Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['position_ids'].shape[0].

The fix ensures the example batch size is always >= 2 during export, even when max_batch_size=1, to prevent torch.export from specializing the batch dimension.

Added tests to verify the fix.

Summary by CodeRabbit

  • Improvements

    • Optimized default batch size generation during model export to ensure minimum batch size of 2, improving compatibility with minimal batch configurations.
  • Tests

    • Added test coverage validating batch size behavior during model export operations across different maximum batch size configurations.

✏️ Tip: You can customize this high-level summary in your review settings.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

…batch_size=1

When max_batch_size=1, the example sequence used during export had
batch_size=1. This caused torch.export to specialize the batch dimension
to a static value of 1 instead of keeping it dynamic, resulting in:

  ValueError: Found the following conflicts between user-specified ranges
  and inferred ranges from model tracing:
  - Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but
    export 0/1 specialized due to hint of 1 for dimension
    inputs['input_ids'].shape[0].
  - Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but
    export 0/1 specialized due to hint of 1 for dimension
    inputs['position_ids'].shape[0].

The fix ensures the example batch size is always >= 2 during export,
even when max_batch_size=1, to prevent torch.export from specializing
the batch dimension.

Added tests to verify the fix.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
@MrGeva MrGeva requested a review from a team as a code owner January 15, 2026 08:32
@MrGeva MrGeva requested a review from suyoggupta January 15, 2026 08:33
@MrGeva MrGeva changed the title [][fix] AutoDeploy prevent torch.export from specializing batch dimension when max_batch_size=1 [#10696][fix] AutoDeploy prevent torch.export from specializing batch dimension when max_batch_size=1 Jan 15, 2026
@MrGeva
Copy link
Collaborator Author

MrGeva commented Jan 15, 2026

/bot run

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 15, 2026

📝 Walkthrough

Walkthrough

The change modifies the default batch size calculation in set_example_sequence to ensure a minimum batch size of 2, even when max_batch_size is 1. Corresponding test cases validate this new behavior.

Changes

Cohort / File(s) Summary
Batch size calculation fix
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
Modified batch size computation from min(2, max_batch_size) to max(2, min(2, max_batch_size)), guaranteeing a minimum of 2 for default input generation during export
Test suite for batch size validation
tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
Added TestSequenceInfoExampleBatchSize with two test cases: one validating batch size ≥ 2 when max_batch_size=1, and another validating batch size = 2 for max_batch_size=32

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The PR title accurately describes the main change: preventing torch.export from specializing the batch dimension when max_batch_size=1 in AutoDeploy.
Description check ✅ Passed The description explains the issue, solution, and test coverage. However, the PR description template sections (Description and Test Coverage) are not explicitly filled out as separate sections in the provided description.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

🧹 Recent nitpick comments
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)

728-730: Simplify the batch size expression.

The expression max(2, min(2, self.max_batch_size)) always evaluates to 2 for any positive max_batch_size:

  • When max_batch_size=1: min(2,1)=1, then max(2,1)=2
  • When max_batch_size>=2: min(2,x)=2, then max(2,2)=2

The fix correctly prevents torch.export from specializing the batch dimension. Consider simplifying for clarity:

♻️ Proposed simplification
         if input_ids is None:
             # Use batch_size >= 2 for export to prevent torch.export from specializing
             # the batch dimension when max_batch_size=1 (dimension value 1 triggers static optimization)
-            bs, seq_len = max(2, min(2, self.max_batch_size)), min(4, self.max_seq_len)
+            bs, seq_len = 2, min(4, self.max_seq_len)
             input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist()

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cd55fb4 and 83a7198.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., some_file.py)
Python classes should use PascalCase (e.g., class SomeClass)
Python functions and methods should use snake_case (e.g., def my_awesome_function():)
Python local variables should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL)
Python constants should use upper snake_case (e.g., MY_CONSTANT)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format """<type>: Description"""
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification

Files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (1)
tests/unittest/llmapi/apps/_test_openai_misc.py (1)
  • max_batch_size (30-31)
tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py (2)
  • SequenceInfo (322-1049)
  • set_example_sequence (719-751)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py (1)

235-278: LGTM!

The tests are well-structured and properly validate the fix:

  • Test 1 correctly verifies that example batch size is at least 2 when max_batch_size=1, which is the critical edge case that was causing torch.export to specialize the batch dimension.
  • Test 2 ensures the behavior is consistent with larger max_batch_size values.

The docstrings clearly explain the rationale for each test case.

✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #32103 [ run ] triggered by Bot. Commit: 83a7198

@tensorrt-cicd
Copy link
Collaborator

PR_Github #32103 [ run ] completed with state SUCCESS. Commit: 83a7198
/LLM/main/L0_MergeRequest_PR pipeline #24883 completed with status: 'SUCCESS'

@MrGeva MrGeva requested a review from lucaslie January 15, 2026 15:30
@MrGeva MrGeva merged commit a11f0db into NVIDIA:main Jan 18, 2026
8 of 11 checks passed
greg-kwasniewski1 pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Jan 18, 2026
… batch dimension when max_batch_size=1 (NVIDIA#10697)

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

3 participants