[TRTLLM-7081][refactor] Add MultimodalModelMixin #13866
moraxu left a comment:
I'd add two clarifications to "Flowchart for new mixin":
- Encoder work needed? --> No: when is that the case?
- Any active MM rows? --> No: when is that the case?
@moraxu Done (updated PR description). I also color-coded the flowchart to better illustrate what the mixin provides, what the model has to implement, and what it can optionally implement.
📝 Walkthrough
This PR introduces a standardized multimodal input preparation pipeline.
Changes: Multimodal Input Pipeline Standardization
🧹 Nitpick comments (3)
tensorrt_llm/_torch/models/modeling_multimodal_mixin.py (1)
307-312: 💤 Low value. Consider guarding against unexpected return lengths from `fuse_input_embeds`.
The code assumes `fuse_input_embeds` returns either 2 or 3 elements. If the upstream function's signature changes, this could raise an unpacking error with a confusing traceback. Consider adding an `else` branch that raises a clear error.
♻️ Optional: Add explicit else branch
```diff
 if len(result) == 3:
     fused_input_ids, inputs_embeds, fused_extra_embeds = result
     return fused_input_ids, inputs_embeds, fused_extra_embeds or ()
-
-fused_input_ids, inputs_embeds = result
-return fused_input_ids, inputs_embeds, ()
+elif len(result) == 2:
+    fused_input_ids, inputs_embeds = result
+    return fused_input_ids, inputs_embeds, ()
+else:
+    raise ValueError(
+        f"fuse_input_embeds returned unexpected tuple length {len(result)}"
+    )
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/models/modeling_multimodal_mixin.py` around lines 307 - 312, the code that handles the output of `fuse_input_embeds` assumes the return length is either 2 or 3, which may cause an unpacking error if this changes. To fix this, add an explicit else branch after the existing if and fallback conditions that raises a clear and descriptive error indicating an unexpected return length was encountered. Make this change in the method or function where the `fuse_input_embeds` result is unpacked, for safer future-proofing and easier debugging.
tests/unittest/_torch/multimodal/test_multimodal_mixin.py (1)
55-86: ⚡ Quick win. Test coverage is limited to the cached-embeddings happy path.
The test validates precomputed indices forwarding but does not cover:
- The encode path (when embeddings are not cached)
- Multiple context requests / batched scenarios
- Error conditions (e.g., mismatched embedding dimensions, missing indices)
- The `after_full_multimodal_embeddings` and `after_active_multimodal_embeddings` hooks
Since this is a new mixin that will be adopted by multiple models, consider adding at least one test exercising the encode path and one for the `num_context_requests=0` (no multimodal work) scenario.
QA test list: This is a unit test under `tests/unittest/`, so no QA list updates are necessary.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/multimodal/test_multimodal_mixin.py` around lines 55 - 86, the test `test_prepare_multimodal_inputs_forwards_precomputed_indices` only covers the cached-embeddings path; add at least two new unit tests. The first should exercise the encode path by passing `multimodal_params` that are not cached (use a maker helper like `make_multimodal_param` or modify `make_cached_multimodal_param` to simulate uncached entries) so that `prepare_multimodal_inputs` on `DummyMultimodalModel` calls the internal encoder branch and returns `inputs_embeds` computed from `emb(input_ids)` for text indices. The second should cover `num_context_requests=0` and verify that no multimodal embedding work is done (`inputs_embeds` remains based solely on text embeddings, and the hooks `after_full_multimodal_embeddings`/`after_active_multimodal_embeddings` on `DummyMultimodalModel` are invoked or left untouched). Also add cases for mismatched dimensions and missing indices to assert appropriate errors from `prepare_multimodal_inputs`.
tensorrt_llm/_torch/models/modeling_mistral.py (1)
708-728: 💤 Low value. The `mm_inputs` parameter is unused.
`get_language_model_forward_kwargs` accepts `mm_inputs: PreparedLlmInputs` but does not use it in the returned dict. If this is intentional (e.g., for future hooks or subclass overrides), consider documenting it. Otherwise, remove the unused parameter.
♻️ Option 1: Remove unused parameter
```diff
 def get_language_model_forward_kwargs(
     self,
     *,
     attn_metadata: AttentionMetadata,
     input_ids: torch.Tensor | None,
     position_ids: torch.Tensor | None,
     inputs_embeds: torch.Tensor | None,
-    mm_inputs: PreparedLlmInputs,
     return_context_logits: bool,
     spec_metadata: SpecMetadata | None,
     resource_manager: Any | None,
 ) -> dict[str, Any]:
```
♻️ Option 2: Add docstring noting future extensibility
```diff
 def get_language_model_forward_kwargs(
     self,
     *,
     attn_metadata: AttentionMetadata,
     input_ids: torch.Tensor | None,
     position_ids: torch.Tensor | None,
     inputs_embeds: torch.Tensor | None,
     mm_inputs: PreparedLlmInputs,
     return_context_logits: bool,
     spec_metadata: SpecMetadata | None,
     resource_manager: Any | None,
 ) -> dict[str, Any]:
+    """Build kwargs for the language model forward.
+
+    Args:
+        mm_inputs: Prepared multimodal inputs; available for subclass
+            overrides that need access to extra_embeds or other fields.
+    """
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/models/modeling_mistral.py` around lines 708 - 728, the parameter `mm_inputs` in `get_language_model_forward_kwargs` is unused. Either remove it from the function signature and from all call sites (update type hints and any overrides/implementations that expect `PreparedLlmInputs`) or, if it was intentionally kept for extensibility, add `mm_inputs` to the returned dict and add a short docstring on `get_language_model_forward_kwargs` describing that `mm_inputs` is preserved for future hooks/overrides. Locate the symbol `get_language_model_forward_kwargs` and adjust its signature/return value and all callers/overrides accordingly.
* Why?
All of our existing multimodal models follow a similar pattern:
1. select context requests from the current batch
2. select those that have multimodal parameters
3. run `get_multimodal_embeddings`, which hides encoder
caching and chunked-prefill reuse behavior with an implicit
API contract.
4. run `find_input_mm_embeds`, which slices embeddings
for the current forward chunk.
5. run `fuse_input_embeds`, which places multimodal embeddings
into text embeddings.
6. Call the underlying language model's `forward` method with
the fused embeddings.
This means these features are all opt-in, and the code pattern is
duplicated, with developers needing to be aware of several implicit
contracts.
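As a rough illustration of the six steps above, the per-model pattern can be sketched with plain Python lists standing in for tensors. Everything here (`ToyRequest`, `toy_encode`, `toy_model_forward`, `MM_PLACEHOLDER`) is a hypothetical stand-in; the real `get_multimodal_embeddings`, `find_input_mm_embeds`, and `fuse_input_embeds` have their own signatures, which this sketch does not attempt to reproduce.

```python
from dataclasses import dataclass
from typing import Optional

MM_PLACEHOLDER = -1  # hypothetical token id marking a multimodal row


@dataclass
class ToyRequest:
    """Hypothetical stand-in for one request in the batch."""
    token_ids: list
    is_context: bool
    mm_data: Optional[list] = None        # raw multimodal payload, if any
    cached_embeds: Optional[list] = None  # filled in after the first encode


def toy_encode(mm_data):
    """Toy encoder: one 'embedding' (a float) per multimodal item."""
    return [2.0 * x for x in mm_data]


def toy_model_forward(batch):
    # 1. Select context requests from the current batch.
    context = [r for r in batch if r.is_context]
    # 2. Select those that have multimodal parameters.
    mm_requests = [r for r in context if r.mm_data is not None]
    # 3. Get embeddings, reusing cached ones where available.
    for r in mm_requests:
        if r.cached_embeds is None:
            r.cached_embeds = toy_encode(r.mm_data)
    # 4. Slice embeddings for the current chunk (toy: the chunk is everything).
    mm_embeds = [e for r in mm_requests for e in r.cached_embeds]
    # 5. Fuse: placeholders take multimodal rows, other tokens take text rows.
    rows = iter(mm_embeds)
    fused = [next(rows) if t == MM_PLACEHOLDER else float(t)
             for r in batch for t in r.token_ids]
    # 6. Call the language model (toy: return the fused inputs unchanged).
    return fused
```

Every model that repeats this sequence has to get the caching in step 3 and the slicing in step 4 right on its own, which is the duplication the commit message describes.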
* What?
This commit adds a new `MultimodalModelMixin` that specifies
an interface that hides most of this, allowing model developers to
just implement the forward pass for multimodal embeddings.
  As an illustration, it refactors the Mistral Small model to use this
  new mixin.
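A minimal sketch of the idea, assuming a toy interface: the mixin owns caching and fusion, and the model only supplies the encoder. The method names `prepare_multimodal_inputs`, `encode_multimodal_inputs`, and `after_full_multimodal_embeddings` are borrowed from the PR's flowchart, but the signatures, the cache keyed on the payload, and the return convention below are assumptions made for illustration only.

```python
PLACEHOLDER = -1  # hypothetical token id marking a multimodal row


class ToyMultimodalMixin:
    """Toy mixin: owns caching and fusion; subclasses supply the encoder."""

    def encode_multimodal_inputs(self, mm_data):
        # The one piece a model has to implement in this sketch.
        raise NotImplementedError

    def after_full_multimodal_embeddings(self, embeds):
        # Optional hook; identity by default.
        return embeds

    def prepare_multimodal_inputs(self, token_ids, mm_data, cache):
        if mm_data is None:
            # No multimodal params: return input_ids with no inputs_embeds.
            return token_ids, None
        key = tuple(mm_data)
        if key not in cache:
            # Encoder work is needed only for uncached payloads.
            cache[key] = self.encode_multimodal_inputs(mm_data)
        embeds = iter(self.after_full_multimodal_embeddings(cache[key]))
        fused = [next(embeds) if t == PLACEHOLDER else float(t)
                 for t in token_ids]
        # Toy return convention (an assumption): embeddings replace input_ids.
        return None, fused


class ToyModel(ToyMultimodalMixin):
    def __init__(self):
        self.encode_calls = 0

    def encode_multimodal_inputs(self, mm_data):
        self.encode_calls += 1
        return [10.0 * x for x in mm_data]
```

Calling `prepare_multimodal_inputs` twice with the same payload and a shared `cache` dict runs the toy encoder once, which is the kind of behavior a model gets without writing it itself.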
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Mirror the new `MultimodalModelMixin` (NVIDIA#13866) at the encoder layer.
The two are parallel mixins: `MultimodalModelMixin` orchestrates the LM-side embedding flow, while `MultimodalEncoderMixin` provides the engine-driven `AttentionMetadata` setup for the encoder module.
Also renames `_set_up_mm_encoder_attn_metadata` to match. No behavior change.
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
…ltimodal
Introduces `PyTorchModelEngine.is_multimodal` so call sites query the engine directly instead of poking at the input processor's class hierarchy. The property combines two signals:
* Primary: `isinstance(model, MultimodalModelMixin)`, the canonical marker introduced in PR NVIDIA#13866. Multimodal LM classes declare themselves by inheriting from it.
* Fallback: `isinstance(input_processor, BaseMultimodalInputProcessor)`. Every multimodal model necessarily ships a multimodal input processor, so this catches models that haven't migrated to `MultimodalModelMixin` yet (only Mistral has so far).
A TODO calls out that the fallback should be dropped once the `MultimodalModelMixin` migration of the remaining multimodal models (Qwen-VL, Nemotron Nano VL2, Gemma, Phi-4-MM, ...) lands.
Call-site cleanup: `_create_dummy_context_requests` now reads `self._model_engine.is_multimodal` and no longer imports `BaseMultimodalInputProcessor` directly.
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
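The two-signal check described in this commit message could look roughly like the sketch below. The two marker classes are empty stand-ins defined locally so the example runs on its own, and `ToyModelEngine` with its constructor is hypothetical; only the primary/fallback `isinstance` logic comes from the text.

```python
class MultimodalModelMixin:
    """Empty stand-in for the real mixin (assumption for this sketch)."""


class BaseMultimodalInputProcessor:
    """Empty stand-in for the real input processor base class."""


class ToyModelEngine:
    """Hypothetical engine holding a model and its input processor."""

    def __init__(self, model, input_processor):
        self.model = model
        self.input_processor = input_processor

    @property
    def is_multimodal(self) -> bool:
        # Primary signal: the model declares itself via the mixin.
        if isinstance(self.model, MultimodalModelMixin):
            return True
        # Fallback for models not yet migrated to the mixin; per the commit
        # message, this branch is meant to be dropped after the migration.
        return isinstance(self.input_processor, BaseMultimodalInputProcessor)
```

Call sites then read `engine.is_multimodal` and never need to import either marker class themselves.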
Description
All of our existing multimodal models follow a similar pattern:
1. Select context requests from the current batch.
2. Select those that have multimodal parameters.
3. Run `get_multimodal_embeddings`, which hides encoder caching and chunked-prefill reuse behavior with an implicit API contract.
4. Run `find_input_mm_embeds`, which slices embeddings for the current forward chunk.
5. Run `fuse_input_embeds`, which places multimodal embeddings into text embeddings.
6. Call the underlying language model's `forward` method with the fused embeddings.

This means these features are all opt-in, and the code pattern is duplicated, with developers needing to be aware of several implicit contracts.

This commit adds a new `MultimodalModelMixin` that specifies an interface that hides most of this, allowing model developers to just implement the forward pass for multimodal embeddings.

As an illustration, it refactors the Mistral Small model to use this new mixin.
Flowchart for new mixin
```mermaid
flowchart TD
    A[Mixin forward] --> B[Count context requests]
    B --> C[Mixin prepare_multimodal_inputs]
    C --> D[Select context multimodal params]
    D --> E{Any params?}
    E -- No --> Z[Return input_ids with no inputs_embeds]
    E -- Yes --> F[Build encoder kwargs via model hook]
    F --> G[Find uncached params]
    G --> H{Encoder work needed?}
    H -- No: all selected params already have cached embeddings --> J[Gather cached embeddings]
    H -- Yes --> I[Call model encode_multimodal_inputs]
    I --> K[Validate primary tensor contract]
    K --> L[Cache per-request embedding slices]
    L --> J
    J --> M[after_full_multimodal_embeddings hook]
    M --> N[Find active chunk embeddings]
    N --> O{Any active MM rows?}
    O -- No: current chunk has no MM rows to fuse --> P[Return text-only chunk inputs]
    O -- Yes --> Q[after_active_multimodal_embeddings hook]
    Q --> R[fuse_input_embeds]
    R --> S[Return PreparedLlmInputs]
    S --> U[Build LLM forward kwargs via model hook]
    U --> T[Call language model forward]
    classDef base fill:#d8f5d0,stroke:#2f8f46,color:#143d1f;
    classDef model fill:#d7e9ff,stroke:#2f6fb3,color:#12365d;
    classDef optional fill:#fff2b8,stroke:#b38b00,color:#4d3b00;
    classDef decision fill:#f5f5f5,stroke:#707070,color:#202020;
    class A,B,C,D,G,J,K,L,N,P,R,S,T,Z base;
    class I model;
    class F,M,Q,U optional;
    class E,H,O decision;
```
Legend: `base` nodes (green) are provided by the mixin, `model` nodes (blue) must be implemented by the model, `optional` nodes (yellow) are hooks the model can optionally implement, and `decision` nodes (gray) are branch points.
Clarifications:
- `Encoder work needed? -> No` means every selected context multimodal request already has cached embeddings, so the mixin gathers cached rows without calling the encoder.
- `Any active MM rows? -> No` means `find_input_mm_embeds` returned no rows for this forward chunk. This can happen when all MM rows are covered by KV-cache reuse or when the current chunk does not overlap the request's multimodal placeholder rows.
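The two "No" branches can be illustrated with a small sketch over token positions. Both helpers are hypothetical and only model the decisions; they are not how `find_input_mm_embeds` or the mixin's cache lookup are actually implemented. KV-cache reuse is modeled simply as a chunk that starts after the reused prefix.

```python
def needs_encoder(cached_flags):
    """True if at least one selected request lacks cached embeddings."""
    return not all(cached_flags)


def active_mm_rows(mm_positions, chunk_start, chunk_end):
    """Indices of multimodal rows whose token position lies in
    [chunk_start, chunk_end), i.e. rows the current chunk must fuse."""
    return [i for i, pos in enumerate(mm_positions)
            if chunk_start <= pos < chunk_end]


# A request whose multimodal placeholders sit at token positions 2..5.
mm_positions = [2, 3, 4, 5]

first_chunk = active_mm_rows(mm_positions, 0, 4)    # overlaps rows 0 and 1
later_chunk = active_mm_rows(mm_positions, 8, 12)   # past the placeholders
after_reuse = active_mm_rows(mm_positions, 6, 10)   # prefix 0..5 was reused
```

In this toy model, `later_chunk` and `after_reuse` are both empty, matching the two situations the second clarification lists.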