
Add LLaVA audio (sound) model support #4402

Open
cuichenx wants to merge 24 commits into NVIDIA:main from cuichenx:llava-model-audio

Conversation

@cuichenx
Contributor

@cuichenx cuichenx commented Apr 21, 2026

Adds end-to-end support for an audio/sound modality in LLaVAModel, plus the helpers needed to use it at scale.

Summary of changes

  • megatron/core/models/huggingface/fastconformer_model.py (new): ParakeetHuggingFaceModel wrapper that loads either a NeMo Parakeet checkpoint (nemo://...) or an upstream Hugging Face FastConformer model (hf://...) via transformers.AutoModel / AutoFeatureExtractor, with optional gradient checkpointing. Returns (hidden_states, lengths) to match LLaVAModel's call site.
  • megatron/core/models/huggingface/module.py: route nemo://...parakeet... and hf://...parakeet... paths to ParakeetHuggingFaceModel from get_hf_model_type / build_hf_model. Non-parakeet nemo:// schemes raise a clear NotImplementedError.
  • megatron/core/models/multimodal/llava_model.py:
    • New sound_model / sound_projection / sound_token_index plumbing, DEFAULT_SOUND_TOKEN_INDEX, SOUND_TOKEN, freeze_sound_model / freeze_sound_projection, and corresponding wiring through forward / token combination in _preprocess_data.
    • dynamic_resolution plus RADIO knobs (force_eval_mode, force_cpe_eval_mode, interpolate_only_cpe, cpe_aspect_ratio_select, disable_cpe, temporal_patch_dim, separate_video_embedder, temporal_ckpt_compat).
    • Extend the SP/CP gating and language-model dispatch to also recognize nemotron6-moe; keep using the upstream HybridModel / hybrid_stack_spec.
    • Extend pixel_shuffle with optional h, w kwargs for non-square dynamic-resolution patch grids.
  • megatron/core/models/multimodal/context_parallel.py: add CP utilities for dynamic-resolution vision inputs (split / gather across CP ranks, tubelet-aware split points, padding for unequal media counts, FP8 recipe-aware tail padding).
  • megatron/core/models/vision/radio.py: dynamic-resolution support; temporal patch dim with optional separate video embedder; CPE controls (force-eval, interpolate-only, aspect-ratio select, disable); parameterized interpolate_align_corners / grid_sample_align_corners; state-dict pre-hooks to upgrade 2D embedder weights to 3D for temporal compression / from a non-temporal checkpoint.
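The nemo:// / hf:// routing described above can be sketched as follows (the function name and return labels are hypothetical; the real dispatch lives in get_hf_model_type / build_hf_model):

```python
def get_backend(model_path: str) -> str:
    """Route a model path to a loading backend by URI scheme.

    Illustrative sketch of the dispatch described above, not the
    actual huggingface/module.py implementation.
    """
    if model_path.startswith("nemo://"):
        if "parakeet" in model_path:
            return "parakeet-nemo"
        # Non-parakeet nemo:// schemes raise a clear error.
        raise NotImplementedError(f"Unsupported nemo:// path: {model_path}")
    if model_path.startswith("hf://") and "parakeet" in model_path:
        return "parakeet-hf"
    return "generic-hf"
```

The key property is that nemo:// paths never reach AutoConfig (which cannot resolve them), while hf:// Parakeet paths still go through the dedicated wrapper.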

Tests

Four new unit-test files (~1.7 k LOC), validated 84 / 84 passing on a 2-rank cw-dfw run:

  • tests/unit_tests/models/test_fastconformer_model.py: HF backend dtype propagation, sampling-rate plumbing, gradient_checkpointing, error paths; NeMo singleton cache keyed by model id; parakeet dispatch in module.py (positive + negative cases).
  • tests/unit_tests/models/test_llava_sound.py: SOUND_TOKEN / DEFAULT_SOUND_TOKEN_INDEX constants; freeze honouring the new flags; has_sounds sentinel decision logic; sound replacement in _preprocess_data for both sound_pad_to_clip_duration branches.
  • tests/unit_tests/models/test_multimodal_context_parallel.py: _compute_tubelet_aware_split_points invariants; _split_num_frames clipping; split_to_context_parallel_ranks_dynamic_res (non-temporal, tubelet-aware, padded, hidden-dim assertion); gather_from_context_parallel_ranks with global_pad; GatherFromContextParallelRanks autograd backward (CP gradient correctness regression guard).
  • tests/unit_tests/models/test_radio_model.py: constructor / forward / state-dict round-trip; _state_dict_pre_hook_init_embedder + _state_dict_pre_hook_init_video_embedder; train() override with force_eval_mode; pixel_shuffle square + non-square; constructor coverage for the new RADIO knobs; _apply_temporal_grouping direct unit tests.
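The square vs. non-square pixel_shuffle behavior exercised by test_radio_model.py can be illustrated with a minimal sketch: a standard 2D patch-merging shuffle over a `[N, L, C]` token grid, with optional h/w for non-square dynamic-resolution grids (not the exact mcore implementation):

```python
import torch

def pixel_shuffle(x, scale=2, h=None, w=None):
    """Merge scale x scale neighboring patches into the channel dim.

    x: [N, L, C] patch tokens. When h/w are omitted, a square grid
    (h = w = sqrt(L)) is assumed, matching the original square-only path.
    """
    n, l, c = x.shape
    if h is None or w is None:
        h = w = int(l ** 0.5)
    assert h * w == l and h % scale == 0 and w % scale == 0
    x = x.view(n, h, w, c)
    x = x.view(n, h, w // scale, c * scale)
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, w // scale, h // scale, c * scale * scale)
    x = x.permute(0, 2, 1, 3).contiguous()
    # [N, (h/scale)*(w/scale), C*scale^2]
    return x.view(n, (h // scale) * (w // scale), c * scale * scale)
```

Without the h/w kwargs, a non-square grid (say 4x6 patches) would be mis-reshaped as if it were square, which is the bug class the dynamic-resolution tests guard against.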

What does this PR do ?

End-to-end audio modality for LLaVAModel (Parakeet/FastConformer encoder + projection + sound-token replacement) plus dynamic-resolution / temporal-patch support for RADIO and the CP utilities needed for both.
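A minimal sketch of the sound-token replacement step (the constant's value and the helper name are hypothetical; the real logic lives in _preprocess_data):

```python
import torch

# Placeholder value for illustration only; the real DEFAULT_SOUND_TOKEN_INDEX
# is defined in llava_model.py.
DEFAULT_SOUND_TOKEN_INDEX = -300

def combine_embeddings(text_emb, input_ids, sound_emb,
                       sound_token_index=DEFAULT_SOUND_TOKEN_INDEX):
    """Scatter projected audio embeddings into the text-token embeddings
    at every position holding the sound sentinel token."""
    out = text_emb.clone()
    mask = input_ids == sound_token_index           # [B, S] bool
    # One projected audio embedding per sound-token slot.
    assert mask.sum().item() == sound_emb.shape[0]
    out[mask] = sound_emb
    return out
```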

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch, the proposed review process is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

cuichenx and others added 2 commits April 20, 2026 17:34
…aware vision plumbing for RADIO

Adds end-to-end support for an audio/sound modality in LLaVAModel, plus the
helpers needed to use it at scale:

- megatron/core/models/huggingface/fastconformer: HuggingFace-style
  FastConformer config, modeling, feature extractor, and a NeMo->HF
  conversion script (covers Parakeet/Canary FastConformer variants).
- megatron/core/models/huggingface/fastconformer_model.py: ParakeetHuggingFaceModel
  wrapper that loads either a NeMo Parakeet checkpoint (nemo://...) or an HF
  FastConformer model (hf://...), with optional gradient checkpointing.
- megatron/core/models/multimodal/llava_model.py:
  - New sound_model / sound_projection / sound_token_index plumbing,
    DEFAULT_SOUND_TOKEN_INDEX, SOUND_TOKEN, freeze_sound_model/projection,
    and corresponding wiring through forward / token combination.
  - dynamic_resolution + RADIO knobs (force_eval_mode, force_cpe_eval_mode,
    interpolate_only_cpe, cpe_aspect_ratio_select, disable_cpe,
    temporal_patch_dim, separate_video_embedder, temporal_ckpt_compat).
  - Extend SP/CP gating and language-model dispatch to also recognize
    nemotron6-moe; keep using the upstream HybridModel/hybrid_stack_spec.
- megatron/core/models/multimodal/context_parallel.py: add CP utilities
  for dynamic-resolution vision inputs (split/gather across CP ranks,
  tubelet-aware split points, padding for unequal media counts).
- megatron/core/models/vision/radio.py: dynamic-resolution support,
  temporal patch dim with optional separate video embedder, CPE controls
  (force-eval, interpolate-only, aspect-ratio select, disable), and
  state-dict pre-hooks to upgrade 2D embedder weights to 3D for temporal
  compression / from a non-temporal checkpoint.

Co-authored-by: Tuomas Rintamaki <trintamaki@nvidia.com>
Co-authored-by: Tyler Poon <tylerpoon@gmail.com>
Co-authored-by: Collin McCarthy <cmccarthy@nvidia.com>
Co-authored-by: Matthieu Le <matthieul@nvidia.com>
Co-authored-by: Piotr Zelasko <pzelasko@nvidia.com>
Co-authored-by: Ehsan Hosseini Asl <ehosseiniasl@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
…ecipe

- megatron/core/models/vision/radio.py: add interpolate_align_corners
  (default False) and grid_sample_align_corners (default True) on
  RADIOViTModel and use them in place of the hardcoded align_corners
  values in _get_pos_embeddings (4x F.interpolate, 1x F.grid_sample).
  Defaults preserve current behavior.
- megatron/core/models/multimodal/context_parallel.py: drop the
  redundant `patch_dim = 16` reassignment inside
  split_to_context_parallel_ranks_dynamic_res, and forward a new
  fp8_recipe arg through to get_padding so the FP8 padding multiple
  matches the active recipe (32 for "mxfp8", 16 otherwise). Also add
  short docstrings on the GatherFromContextParallelRanks autograd
  helpers and gather_from_context_parallel_ranks (lint cleanup).
- megatron/core/models/multimodal/llava_model.py: pass fp8_recipe
  (read from the language transformer config) at the
  split_to_context_parallel_ranks_dynamic_res call site.

Signed-off-by: Chen Cui <chcui@nvidia.com>
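The recipe-aware padding rule (32 for "mxfp8", 16 otherwise) reduces to a one-liner; this sketch assumes the helper's shape rather than reproducing the real get_padding signature:

```python
def get_padding(seq_len, fp8_recipe=None):
    """Tail padding that rounds seq_len up to the FP8-friendly multiple:
    32 when the active recipe is "mxfp8", 16 otherwise. Illustrative
    sketch of the rule described above."""
    multiple = 32 if fp8_recipe == "mxfp8" else 16
    return (-seq_len) % multiple
```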
@copy-pr-bot

copy-pr-bot Bot commented Apr 21, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Signed-off-by: Chen Cui <chcui@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 22, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx cuichenx force-pushed the llava-model-audio branch from b5d35ae to eb8b092 Compare April 22, 2026 18:42
@cuichenx
Contributor Author

/claude review

Contributor

@claude claude Bot left a comment


Review Summary

This is a large PR (+1081/-116 lines) adding sound/audio modality support, dynamic resolution, temporal compression, and CP utilities. A few issues:

Bug: fp8_pad_hook regression (radio.py)

The padding multiple is now hardcoded to 16, but the model constructor sets class_token_len = 32 when fp8_recipe == "mxfp8". This will cause a shape mismatch when loading checkpoints under mxfp8. See inline comment for fix.

Missing test coverage

The PR adds ~1100 lines of new functionality across three files with no accompanying tests. The PR checklist also has unchecked boxes for unit and functional tests. Key areas that would benefit from test coverage:

  • context_parallel.py: The new CP split/gather utilities (split_to_context_parallel_ranks, split_to_context_parallel_ranks_dynamic_res, _compute_tubelet_aware_split_points) have non-trivial logic around padding, tubelet-aware splitting, and edge cases (fewer images than CP ranks, empty splits). These are good candidates for unit tests that can run without GPU.
  • radio.py: The _apply_temporal_grouping method and the state-dict pre-hooks (_state_dict_pre_hook_init_embedder, _state_dict_pre_hook_init_video_embedder) have branching logic for 2D→3D weight conversion that should be tested.
  • llava_model.py: The sound embedding integration in _preprocess_data and the new SP/CP padding logic in _process_embedding_token_parallel.

Existing test files (tests/unit_tests/models/test_llava_model.py, tests/unit_tests/models/test_radio_model.py) could be extended.
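The "fewer images than CP ranks" edge case called out above can be captured in a GPU-free helper; the function and its return convention are hypothetical, for illustration only:

```python
def pad_media_counts(num_imgs, cp_size):
    """Return (padded_count, num_padded) so every CP rank receives at
    least one (possibly dummy) media item when there are fewer images
    than ranks. Hypothetical sketch of the edge case flagged above."""
    num_padded = max(0, cp_size - num_imgs)
    return num_imgs + num_padded, num_padded
```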

Signed-off-by: Chen Cui <chcui@nvidia.com>
The PR changed the TE import block from the main-branch pattern
(import HAVE_TE from extensions.transformer_engine + if/else) to a
local try/except that reimplements the same detection. The change is
unrelated to LLaVA audio support and a regression:

- duplicates HAVE_TE that extensions.transformer_engine already exports
- drops the `else: TEDotProductAttention = None; tex = None` fallbacks,
  leaving those names undefined on the TE-absent path (latent NameError
  at the `== TEDotProductAttention` check)
- uses bare `except:` which swallows KeyboardInterrupt/SystemExit
- needlessly gates `is_te_min_version` (which lives in core.utils and
  already handles missing TE internally) behind the TE import

Signed-off-by: Chen Cui <chcui@nvidia.com>
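The restored pattern generalizes to any optional dependency; this sketch uses json as a runnable stand-in for transformer_engine (the real code imports HAVE_TE from megatron.core.extensions.transformer_engine rather than re-detecting):

```python
# Optional-dependency pattern: a single HAVE_X flag, a narrow except
# clause (never bare `except:`, which would swallow KeyboardInterrupt /
# SystemExit), and explicit None fallbacks so names are always defined.
try:
    import json as optional_dep  # stand-in for transformer_engine
    HAVE_DEP = True
except ImportError:
    optional_dep = None
    HAVE_DEP = False

if HAVE_DEP:
    special_loads = optional_dep.loads
else:
    # Name is defined on the dependency-absent path, so later
    # `is None` / identity checks cannot raise NameError.
    special_loads = None
```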
New tests covering logic added in this PR:

- tests/unit_tests/models/test_multimodal_context_parallel.py
  - TestComputeTubeletAwareSplitPoints: invariants (length, endpoints,
    monotonicity, tubelet-aligned boundaries) + concrete expected
    values for even split, T=1 degenerate case, num_tubelets<=1
    snapping, and pathological monotonicity clamping.
  - TestSplitNumFrames: full / partial / boundary-spanning / empty /
    out-of-range / empty-input cases for the [lb, ub) clipper.
  - TestDynamicResCPDistributed (skipif WORLD_SIZE<2): variable-size
    gather round-trip, gather with num_padded_imgs dropping trailing
    ranks, and split with num_imgs==cp_size==2 asserting each rank
    owns exactly one image and no padding path is triggered.

- tests/unit_tests/models/test_radio_model.py
  - TestRADIOStateDictPreHooks: exercises
    _state_dict_pre_hook_init_embedder and
    _state_dict_pre_hook_init_video_embedder via a SimpleNamespace
    stub (only patch_dim / temporal_patch_dim are read). Covers 2D->3D
    expansion with /T rescale, no-op on already-3D or missing key,
    prefix handling, creating video_embedder from 2D image weights,
    splitting existing combined 3D into 2D-averaged image + 3D-copied
    video, bias cloning (independence verified), preserving an
    already-present video_embedder, and no-op when the image key is
    absent.

Signed-off-by: Chen Cui <chcui@nvidia.com>
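The 2D-to-3D expansion with /T rescale that these hooks perform can be sketched in isolation (function name hypothetical); dividing by T means a frame-repeated static clip embeds identically to the original 2D path:

```python
import torch

def expand_2d_to_3d(weight_2d, t):
    """Upgrade a 2D patch-embedder conv weight [out, in, H, W] to a 3D
    tubelet weight [out, in, T, H, W]: repeat across the new temporal
    axis and divide by T so the temporal slices sum back to the 2D
    weight. Illustrative sketch of the pre-hook behavior above."""
    return weight_2d.unsqueeze(2).repeat(1, 1, t, 1, 1) / t
```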
@cuichenx cuichenx marked this pull request as ready for review April 22, 2026 22:08
@cuichenx cuichenx requested review from a team as code owners April 22, 2026 22:08
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team April 22, 2026 22:08
@cuichenx cuichenx marked this pull request as draft April 28, 2026 22:06
cuichenx added 2 commits May 1, 2026 12:20
- Restore self.pg_collection / self.vp_stage assignments and pass them
  through to TransformerBlock so the inner block keeps its VP/PP awareness.
- Restore the HAVE_EINOPS check at the top of forward() so a missing einops
  install raises a clear ImportError instead of NameError on the
  rearrange call.

Signed-off-by: Chen Cui <chcui@nvidia.com>
- Re-add fastconformer_model.py with the vlm2-style wrapper that returns
  (hidden_states, lengths) so it matches LLaVAModel's call site
  `sound_embeddings, sound_embeddings_len = self.sound_model(...)`.
- For the hf:// backend, load via transformers.AutoModel /
  AutoFeatureExtractor since the FastConformer model is now upstreamed
  to HF (no longer need the local megatron.core.models.huggingface.fastconformer
  package).
- Wire 'parakeet' into huggingface/module.py: special-case it in
  get_hf_model_type() before AutoConfig (nemo:// paths can't be
  resolved by AutoConfig) and dispatch to ParakeetHuggingFaceModel
  in build_hf_model().

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx
Contributor Author

cuichenx commented May 1, 2026

/ok to test b373b26

CI runs isort 5.13.2, which complains about the double blank line between the last import and the first comment block in two of the new test files. The local pre-commit isort version is more permissive and let it through. The fix matches the formatter that NVIDIA/Megatron-LM CI enforces via tools/autoformat.sh.

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx
Contributor Author

cuichenx commented May 1, 2026

/ok to test b6b1d03

Two regressions introduced when PR NVIDIA#4402 simplified the else-branch of
the non-temporal image encoding path:

1. Packed dynamic-resolution path was dropped.  When imgs_sizes and
   vision_packed_seq_params are both present, RADIO returns a packed
   [1, sum(patches_i+ct_len), h_vision] tensor.  The previous code
   split it per image, stripped class tokens, and called pixel_shuffle
   with the per-image (ps_h, ps_w).  The simplified code called
   pixel_shuffle on the full tensor without h/w, producing a wrong
   reshape for non-square patch grids.  Restore the full per-image
   split + per-image pixel-shuffle, and set num_image_tiles from the
   actual patch counts.  Also add is_packed_dynamic_res kwarg to
   _preprocess_data so img_seq_len is set to 1 (not self.img_seq_len)
   on this path, matching the token count in the image embeddings.

2. is_packed_dynamic_res was not initialised before the
   use_inference_kv_cache / has_images guard, causing an UnboundLocalError
   on audio-only and cached-inference forward passes.  Initialise to
   False at the top of the forward body so all code paths are covered.

Signed-off-by: Chen Cui <chcui@nvidia.com>
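The restored per-image split of the packed tensor can be sketched as follows; the assumption that each image's class tokens precede its patch tokens is illustrative, not confirmed by the PR:

```python
import torch

def split_packed_vision(packed, patch_counts, class_token_len):
    """Split a packed [1, sum(patches_i + class_token_len), H] tensor
    into per-image tensors with class tokens stripped, before per-image
    pixel_shuffle. Illustrative sketch of the restored path above."""
    outs, offset = [], 0
    for p in patch_counts:
        chunk = packed[0, offset:offset + p + class_token_len]
        outs.append(chunk[class_token_len:])  # drop this image's class tokens
        offset += p + class_token_len
    assert offset == packed.shape[1]  # whole packed sequence consumed
    return outs
```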
@cuichenx
Contributor Author

cuichenx commented May 4, 2026

/claude review

@cuichenx
Contributor Author

cuichenx commented May 4, 2026

/ok to test 6091e22

Contributor

@claude claude Bot left a comment


Review

Solid PR — comprehensive tests (~1.7k LOC across 4 new test files), clean separation of the sound/audio modality from existing vision plumbing, and the new RADIO dynamic-resolution + temporal-patch support is well structured.

One bug found: coordinate-system mismatch in _split_num_frames when called from split_to_context_parallel_ranks_dynamic_res with temporal_patch_size > 1. The function operates in tubelet-index space internally, but lb/ub from _compute_tubelet_aware_split_points are in frame-index space (since total_frames = len(global_imgs_sizes) = sum(num_frames)). This causes rank 0 to claim all frames and rank 1 to receive none. See the inline comment on context_parallel.py:278-300 for a worked example and suggested fix. The corresponding unit test (test_split_dynamic_res_temporal_aware_tubelets) masks this because it sets up global_imgs_sizes with per-tubelet entries rather than per-frame entries (which is what the real forward path provides).

… splits

_split_num_frames had been rewritten to operate on tubelet indices, but
_compute_tubelet_aware_split_points returns frame-space split points and the
surrounding code slices per-frame seqlens/imgs_sizes with the same lb/ub.
With temporal_patch_size>1, rank 0 ended up reporting all frames in
local_num_frames while rank 1 got an empty list, and the downstream
_apply_temporal_grouping assert (sum(num_frames)==imgs_sizes.shape[0])
would fail. Revert _split_num_frames to frame space so all consumers
agree on coordinates.

Also let _compute_tubelet_aware_split_points short-circuit when the
target_split lands exactly on a media boundary, instead of forcing
tubelet_idx>=1 into the next media. Both old and new outputs are
tubelet-aligned, but the new one keeps the per-rank load balanced for
e.g. num_frames=[4,4], T=2, cp=2 (now [0,4,8] instead of [0,6,8]).

Update test_split_dynamic_res_temporal_aware_tubelets to feed per-frame
imgs_sizes / cu_seqlens to match the real LLaVAModel.forward call site
(_apply_temporal_grouping enforces sum(num_frames)==imgs_sizes.shape[0]).

Re-flow llava_model.py to satisfy black after the change.

Signed-off-by: Chen Cui <chcui@nvidia.com>
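The balanced boundary snapping described above (num_frames=[4,4], T=2, cp=2 giving [0,4,8]) can be sketched like this; the function is illustrative, not the mcore implementation:

```python
def tubelet_aware_split_points(num_frames, t, cp_size):
    """Frame-space split points aligned to tubelet boundaries, with the
    short-circuit described above: when the ideal split already lands on
    a media boundary, keep it instead of pushing into the next media."""
    total = sum(num_frames)
    media_bounds, acc = {0}, 0
    for n in num_frames:
        acc += n
        media_bounds.add(acc)
    points = [0]
    for rank in range(1, cp_size):
        target = rank * total // cp_size
        if target in media_bounds:
            points.append(target)              # exact media boundary
        else:
            points.append(round(target / t) * t)  # snap to tubelet boundary
    points.append(total)
    return points
```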
@cuichenx
Contributor Author

cuichenx commented May 4, 2026

/ok to test 03d8a57

@cuichenx
Contributor Author

cuichenx commented May 4, 2026

/claude review

Contributor

@claude claude Bot left a comment


LGTM. Thorough test coverage (~1.7k LOC across 4 test files) and clean implementation. Left one minor comment on the pixel_shuffle non-square path — it does a 1D sequential grouping rather than a 2D spatial shuffle, which differs from the square path semantics. Not a bug if intentional, but worth a clarifying comment.

…uping

Add TestApplyTemporalGrouping.test_mixed_video_lengths_first_clean_second_with_last_frame_dup:
two videos (4 frames + 3 frames, T=2) verify that the first video compresses
cleanly while the second video's final tubelet contains the last frame repeated.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
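The last-frame duplication that this test verifies can be sketched directly (function name hypothetical):

```python
import torch

def apply_temporal_grouping(frames, num_frames, t):
    """Group per-video frames [total_frames, ...] into tubelets of size
    `t`. A clip whose length is not a multiple of `t` repeats its final
    frame to fill the last tubelet, matching the behavior the new test
    above checks. Illustrative sketch."""
    out, start = [], 0
    for n in num_frames:
        clip = frames[start:start + n]
        start += n
        if n % t:
            # Pad the tail tubelet with copies of the final frame.
            pad = clip[-1:].repeat(t - n % t, *[1] * (clip.dim() - 1))
            clip = torch.cat([clip, pad], dim=0)
        out.append(clip.view(-1, t, *clip.shape[1:]))
    return torch.cat(out, dim=0)  # [num_tubelets, t, ...]
```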
@cuichenx
Contributor Author

cuichenx commented May 5, 2026

/ok to test f75d38c

@cuichenx cuichenx linked an issue May 5, 2026 that may be closed by this pull request
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label May 7, 2026
@ericharper ericharper enabled auto-merge May 7, 2026 18:07

Labels

Approved (all necessary approvals have been made), complexity: high

Development

Successfully merging this pull request may close these issues: Nemotron 3 Nano Omni Training support

5 participants