to_hf.py + PromptFormatter: produce vLLM-ready SpeechLM checkpoints with backbone-native chat_template#15623

Merged: pzelasko merged 12 commits into NVIDIA-NeMo:main from DongjiGao:pr/to-hf-vllm on Apr 24, 2026

Contributor

@DongjiGao DongjiGao commented Apr 18, 2026

Background

Reopens #15617 after its base branch (speechlm2-with-nemo-automodel-merge) was deleted; now rebased on main. Incorporates @pzelasko's review comments from the closed thread.

Summary

Produces a single checkpoint artifact that serves both HF / NeMo SALM inference and vLLM (SpeechLM plugin) without an opt-in flag, and eliminates the drift between our hand-rolled jinja and the backbone LLM's canonical chat_template.

Commit 1 - PromptFormatter uses backbone chat_template as source of truth

Replaces the hardcoded TEMPLATE dict + to_jinja() reconstruction in QwenPromptFormatter and NemotronNanoV3PromptFormatter with a shared base helper that delegates to tokenizer.apply_chat_template on the backbone's native jinja. Training and exported-ckpt inference now share one jinja string shipped by the backbone's HF repo - no more train/infer drift from reconstruction mismatches.

  • formatter.py: add AUDIO_TOKEN, _turn_to_openai_message, _is_reasoning_divergence, _encode_dialog_from_chat_template. Drop the abstract to_jinja stub.
  • qwen.py: QwenPromptFormatter.encode_dialog delegates; hardcoded TEMPLATE, INFERENCE_PREFIX and to_jinja removed. Qwen3PromptFormatter (reasoning, out of refactor scope) unchanged.
  • nemotron_nano_v3.py: encode_dialog delegates and forwards enable_thinking to the backbone jinja. Default is False; reasoning training raises a clear 'not supported' error because the prefix-based loss mask can't align with the template's fresh <think> generation prompt.

Commit 2 - to_hf.py unconditionally produces vLLM-ready checkpoints

  • Removes the opt-in vllm: bool = False flag; every export is vLLM-ready (addresses @pzelasko's review comment on #15617, "to_hf.py: add --vllm flag for vLLM-ready checkpoint export").
  • _detect_vllm_architecture now raises ValueError on missing architectures instead of silently defaulting (addresses the 'broad except' review comment).
  • _try_prepare_for_vllm wrapper downgrades a ValueError to a warning so legacy SALM callers loading the same dir still get a valid HF-only checkpoint.
  • Tokenizer saved from the backbone (brings its native chat_template) + augmented with <|audio|>. extra_special_tokens normalized to dict (HF loaders crash on the list form transformers writes for additional_special_tokens).
  • For reasoning backbones (nemotron-nano-v3), the exported chat_template's enable_thinking default is flipped to False so vLLM's request-time render matches training-time render; otherwise vLLM silently prepends <think> and WER regresses from 1.57 to 5.92.
  • hf_hub.py: setdefault('model_type', ...) and setdefault('architectures', [...]) in HFHubMixin.save_pretrained so NeMo-saved checkpoints carry the metadata vLLM / transformers need.

Architecture mapping

| LLM backbone | Architecture name |
| --- | --- |
| Hybrid (NemotronH) | NeMoSpeechLMHybridForConditionalGeneration |
| Standard (Qwen3, etc.) | NeMoSpeechLMForConditionalGeneration |
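
A minimal sketch of how the mapping above could be applied, with the fail-fast behavior described for _detect_vllm_architecture. This is not the code from to_hf.py: the function name `detect_vllm_architecture` and the rule that a hybrid backbone is recognized by a `NemotronH` prefix in its `architectures` entry are assumptions made for illustration.

```python
from typing import Any, Mapping

# Plugin class names copied from the architecture mapping above.
HYBRID_ARCH = "NeMoSpeechLMHybridForConditionalGeneration"
STANDARD_ARCH = "NeMoSpeechLMForConditionalGeneration"


def detect_vllm_architecture(backbone_config: Mapping[str, Any]) -> str:
    """Pick a vLLM plugin class name for a backbone HF config (illustrative)."""
    architectures = backbone_config.get("architectures")
    if not architectures:
        # Fail fast rather than silently defaulting to the standard class.
        raise ValueError("backbone config has no 'architectures' entry")
    # Assumption: hybrid backbones are recognizable by a 'NemotronH' prefix.
    if any(name.startswith("NemotronH") for name in architectures):
        return HYBRID_ARCH
    return STANDARD_ARCH
```

Raising on a missing `architectures` entry keeps a misconfigured export from being served under the wrong plugin class.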

Usage

torchrun --nproc-per-node=8 --nnodes=4 \
    examples/speechlm2/to_hf.py \
    class_path=nemo.collections.speechlm2.models.salm_automodel.SALMAutomodel \
    ckpt_path=/path/to/checkpoint.ckpt \
    ckpt_config=/path/to/exp_config.yaml \
    output_dir=/path/to/hf_output

No vllm=True flag any more - output is always vLLM-ready.

Test plan

  • Parity test against backbone jinja for Qwen2.5-7B-Instruct and Nemotron-3-Nano-30B-A3B-BF16; rendered text byte-identical to the previous Python + to_jinja pipeline when enable_thinking=False.
  • Mask-construction invariants on training turns: len(context) + len(answer) == len(input) and sum(mask) == len(answer) across Qwen / Nemotron, plain / multi-turn, with / without system, multimodal list content.
  • End-to-end vLLM librispeech-pc eval on re-exported nemotron-nano-v3 ckpt: WER 1.57 (== baseline). Without the enable_thinking default flip, WER regressed to 5.92 - regression confirmed, fix verified.
  • End-to-end NeMo SALM librispeech-pc eval on same ckpt: WER 1.62, 0 empty hypotheses (previous pipeline had 75 empties and 5.50 WER-all).
  • Backward compat: legacy SALM caller exports fine -- _try_prepare_for_vllm downgrades missing-prereq ValueError to a warning.
  • CI: add 'Run CICD' label.
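
The mask-construction invariants in the test plan can be illustrated on toy token IDs. This is a sketch only: `build_training_example` and `mask_invariants_hold` are hypothetical helpers, and the real formatter produces these fields from tokenized dialog turns rather than from hand-written lists.

```python
from typing import Dict, List, Sequence


def build_training_example(context_ids: Sequence[int], answer_ids: Sequence[int]) -> Dict[str, List]:
    """Concatenate context and answer, masking only the answer tokens for the loss."""
    context = list(context_ids)
    answer = list(answer_ids)
    return {
        "context_ids": context,
        "answer_ids": answer,
        "input_ids": context + answer,
        "mask": [False] * len(context) + [True] * len(answer),
    }


def mask_invariants_hold(example: Dict[str, List]) -> bool:
    """Check len(context) + len(answer) == len(input) and sum(mask) == len(answer)."""
    lengths_ok = len(example["context_ids"]) + len(example["answer_ids"]) == len(example["input_ids"])
    mask_ok = sum(example["mask"]) == len(example["answer_ids"])
    return lengths_ok and mask_ok


# Toy IDs: four context tokens followed by three answer tokens.
example = build_training_example([101, 7, 8, 9], [42, 43, 102])
```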

Made with Cursor

Comment thread nemo/collections/common/prompts/formatter.py Outdated
Comment thread examples/speechlm2/to_hf.py Outdated
Collaborator

@pzelasko pzelasko left a comment

I like your thinking about using chat template directly but for the reasons below I don't think it is feasible in the current shape.

Let's approach this from two angles.

  1. In this PR don't add chat template. Revert the changes from nemotron and qwen prompt formatter. Then fix Nemotron...PromptFormatter so that its logic is equivalent to that of jinja, but using PromptFormatter.
  2. We can add ChatTemplatePromptFormatter in a separate PR which would be auto-selected when prompt_format: "tokenizer_chat_template" and make it generic for HF models.

Comment thread nemo/collections/common/prompts/formatter.py Outdated
Comment thread nemo/collections/common/prompts/formatter.py Outdated
@pzelasko
Collaborator

This PR also needs unit test coverage of the changes.

@DongjiGao
Contributor Author

I like your thinking about using chat template directly but for the reasons below I don't think it is feasible in the current shape.

Let's approach this from two angles.

  1. In this PR don't add chat template. Revert the changes from nemotron and qwen prompt formatter. Then fix Nemotron...PromptFormatter so that its logic is equivalent to that of jinja, but using PromptFormatter.
  2. We can add ChatTemplatePromptFormatter in a separate PR which would be auto-selected when prompt_format: "tokenizer_chat_template" and make it generic for HF models.

Could you please explain a bit why it is not feasible in current shape? Thanks!

@copy-pr-bot

copy-pr-bot Bot commented Apr 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Comment thread nemo/collections/common/prompts/qwen.py Outdated
Collaborator

@pzelasko pzelasko left a comment

LGTM

`to_hf.py` now always emits a checkpoint that can be served by vLLM's
SpeechLM plugin: architecture/model_type fields in config.json, the
backbone's canonical chat_template in tokenizer_config.json, audio
placeholder registered on the tokenizer, and a minimal generation
config. The previous `vllm: bool` opt-in flag is gone -- HF and vLLM
loaders now share the same on-disk artifact (pzelasko's ask on the
closed PR NVIDIA-NeMo#15617 thread).

Key bits:
- `_detect_vllm_architecture` inspects the backbone's HF config to
  pick the right vLLM plugin class. Fail-fast ValueError on missing
  `architectures` rather than silently defaulting to 'Std' (also
  addresses pzelasko's review comment about the broad except).
- `prepare_for_vllm` is invoked unconditionally after save, wrapped in
  `_try_prepare_for_vllm` which downgrades a `ValueError` to a warning
  so legacy callers that never needed vLLM (e.g., NeMo SALM loading
  the same dir) still get a clean HF-only checkpoint.
- Tokenizer is re-saved from the backbone (brings its native
  chat_template along) + augmented with `<|audio|>`; extra_special_tokens
  is normalized to a dict so vLLM's AutoTokenizer can load it.
- For reasoning backbones (nemotron-nano-v3), the exported
  chat_template's `enable_thinking` default is flipped to False so
  vLLM's request-time render matches training-time render; otherwise
  vLLM silently prepends `<think>\n` to every assistant turn and WER
  regresses. Verified librispeech-pc WER 1.57 (== baseline) after this
  fix; without it WER regressed to 5.92.

hf_hub.py: setdefault `model_type` and `architectures` in
`HFHubMixin.save_pretrained` so NeMo-saved SpeechLM checkpoints carry
the metadata vLLM / transformers need to identify them.
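
As a rough illustration of the setdefault approach: values already present in the config are preserved and only missing keys are filled in. The helper name `add_loader_metadata` and the `model_type` value used here are placeholders, not the ones written by HFHubMixin.save_pretrained.

```python
from typing import Any, Dict


def add_loader_metadata(config: Dict[str, Any], model_type: str, architecture: str) -> Dict[str, Any]:
    """Fill in model_type / architectures only when the config does not define them."""
    config.setdefault("model_type", model_type)
    config.setdefault("architectures", [architecture])
    return config


# "speechlm_placeholder" is a made-up model_type used only for this sketch.
fresh = add_loader_metadata({}, "speechlm_placeholder", "NeMoSpeechLMForConditionalGeneration")
existing = add_loader_metadata(
    {"architectures": ["NeMoSpeechLMHybridForConditionalGeneration"]},
    "speechlm_placeholder",
    "NeMoSpeechLMForConditionalGeneration",
)
```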

Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Five top-level helpers (load_checkpoint, setup_distributed_from_config,
consolidate_state_dict, save_hf_checkpoint, main) lacked return types.
Uses Any for setup_distributed_from_config's AutomodelParallelStrategy
return to avoid adding an import just for typing; concrete types
everywhere else. Same fix previously applied on pr/vllm-plugin.

Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Modern HuggingFace transformers (~4.42+) moves long chat_template
strings out of tokenizer_config.json into a separate
chat_template.jinja file to keep the JSON readable. Qwen3-1.7B's
4168-char template triggers this split; Nemotron-Nano's shorter
template stays inline.

The old code deleted chat_template.jinja before reading
tokenizer_config.json, assuming the inline copy was always complete.
For Qwen3 that meant the exported checkpoint shipped with an empty
chat_template -- vLLM's apply_chat_template returned a prompt
without the <|audio|> placeholder, which broke multimodal prompt
replacement (Failed to apply prompt replacement for mm_items['audio'][0]).

Now read chat_template.jinja, inline it into tokenizer_config.json
when non-empty, and only then delete the file. Nemotron's inline-only
path is unchanged because .jinja doesn't get written for small
templates.
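
The read, inline, then delete order described above can be sketched with the standard library alone. The helper name `inline_chat_template` is hypothetical, and the demo uses a toy template string rather than a real backbone template.

```python
import json
import tempfile
from pathlib import Path


def inline_chat_template(ckpt_dir: Path) -> None:
    """Copy chat_template.jinja into tokenizer_config.json, then remove the .jinja file."""
    jinja_path = ckpt_dir / "chat_template.jinja"
    config_path = ckpt_dir / "tokenizer_config.json"
    if not jinja_path.exists():
        return  # Small templates stay inline; nothing to rescue.
    template = jinja_path.read_text(encoding="utf-8")
    if template.strip():
        config = json.loads(config_path.read_text(encoding="utf-8"))
        config["chat_template"] = template
        config_path.write_text(json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8")
    # Delete only after the content has been read (and inlined when non-empty).
    jinja_path.unlink()


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp)
    (ckpt / "tokenizer_config.json").write_text(json.dumps({"model_max_length": 4096}), encoding="utf-8")
    (ckpt / "chat_template.jinja").write_text("{{ messages }}<|audio|>", encoding="utf-8")
    inline_chat_template(ckpt)
    rescued = json.loads((ckpt / "tokenizer_config.json").read_text(encoding="utf-8"))
    jinja_still_exists = (ckpt / "chat_template.jinja").exists()
```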

Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Newer NeMo containers (e.g. nemo-25.11-pytorch2.9-automodel-03apr26)
wrap AutoTokenizer.from_pretrained(trust_remote_code=True) in a
NeMo-internal TokenizersBackend class. save_pretrained then writes
'tokenizer_class: TokenizersBackend' to tokenizer_config.json --
not in HF transformers' registry, so vLLM's
AutoTokenizer.from_pretrained crashes at server load:

  ValueError: Tokenizer class TokenizersBackend does not exist or
  is not currently imported.

The underlying tokenizer.json is a valid HF fast tokenizer regardless
of which wrapper produced it; force the class name back to
PreTrainedTokenizerFast so downstream HF-based loaders (including
vLLM) can round-trip the config.
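
A small sketch of the class-name override, operating on an already-loaded tokenizer_config.json dictionary. The helper name `force_fast_tokenizer_class` is an assumption; only the two class names come from the description above.

```python
from typing import Any, Dict


def force_fast_tokenizer_class(tokenizer_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy whose tokenizer_class is one that HF AutoTokenizer can resolve."""
    patched = dict(tokenizer_config)
    patched["tokenizer_class"] = "PreTrainedTokenizerFast"
    return patched


original = {"tokenizer_class": "TokenizersBackend", "eos_token": "<|im_end|>"}
patched = force_fast_tokenizer_class(original)
```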

Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Qwen3's chat_template injects '<think>\n\n</think>\n\n' before assistant
content when enable_thinking=False (the 'no reasoning' mode, which is
what SpeechLM ASR wants). The old QwenPromptFormatter didn't include
this prefix, so SpeechLM fine-tunes trained through it showed the model
a turn shape that's different from Qwen3's pre-training convention.

Bake NO_THINK_PREFIX into both INFERENCE_PREFIX and (transitively via
INFERENCE_PREFIX) the assistant template, so future fine-tunes produce
training data byte-identical to 'apply_chat_template(enable_thinking=
False)'. Existing checkpoints are unaffected -- the change only kicks
in the next time you retrain with prompt_format=qwen.

Test: updated hardcoded expected strings to match Qwen3 jinja output
for single-turn training and inference.

Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Removes the hardcoded _AUDIO_TOKEN constant and reads the audio placeholder
from model_cfg["audio_locator_tag"], raising ValueError if missing. This
ensures the exported config.json, added tokenizer symbols, and
extra_special_tokens dict all reference the same source of truth as training,
avoiding silent drift between train-time and inference-time audio tokens.
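
In sketch form, the single-source-of-truth lookup might read as follows. The function name `get_audio_token` and the error wording are illustrative; only the `audio_locator_tag` key and the ValueError on a missing value come from the description above.

```python
from typing import Any, Mapping


def get_audio_token(model_cfg: Mapping[str, Any]) -> str:
    """Read the audio placeholder from the training config instead of a hardcoded constant."""
    tag = model_cfg.get("audio_locator_tag")
    if not tag:
        raise ValueError("model config is missing 'audio_locator_tag'")
    return tag


audio_token = get_audio_token({"audio_locator_tag": "<|audio|>"})
# The same value would then feed config.json and the tokenizer's added symbols.
exported_config = {"audio_locator_tag": audio_token}
tokens_to_register = [audio_token]
```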

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
Before this fix, the formatter only normalized past assistant turns that
already contained <think>...</think> tags, so a past assistant turn without
any think tags would emit as "<|im_start|>assistant\nTEST<|im_end|>\n" while
the HF jinja template emits "<|im_start|>assistant\n<think></think>TEST<|im_end|>\n"
(jinja unconditionally injects an empty think block for content lacking both
tags). This caused a silent train/inference-template divergence for
multi-turn dialogs without reasoning history.

Fix step 3 in encode_dialog to handle all three cases symmetrically:
  - both tags present -> truncate to "<think></think>" + post-</think> content
  - neither tag present -> prepend "<think></think>"
  - only one tag present -> leave as-is (matches jinja)
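
The three cases above can be sketched as a small string function. This is not the formatter's actual code: `normalize_past_assistant` is a hypothetical name, and whitespace handling around the tags is simplified (the text after `</think>` is kept verbatim).

```python
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"


def normalize_past_assistant(content: str) -> str:
    """Normalize think tags in a past assistant turn, following the three cases above."""
    has_open = THINK_OPEN in content
    has_close = THINK_CLOSE in content
    if has_open and has_close:
        # Both tags: drop the reasoning, keep what follows the first closing tag.
        return THINK_OPEN + THINK_CLOSE + content.split(THINK_CLOSE, 1)[1]
    if not has_open and not has_close:
        # Neither tag: prepend an empty think block.
        return THINK_OPEN + THINK_CLOSE + content
    # Exactly one tag: leave the content untouched.
    return content
```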

Also adds three tests to fill previously-missing coverage:
  - training multi-turn with past assistant missing think tags
    (regression test for the fix above)
  - inference multi-turn with enable_thinking=False
  - inference multi-turn with enable_thinking=True

All 12 non-HF tests pass in the NeMo 25.11 container.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
The previous in-place string-replace flipped Nemotron's
``enable_thinking`` default from True to False so that default vLLM
inference (with no ``chat_template_kwargs``) would match SpeechLM
training rendering. This approach is fragile (silently no-ops if
upstream changes the template) and surprising for downstream consumers
of the exported checkpoint.

Serving callers should instead pass ``chat_template_kwargs={"enable_thinking": False}``
(or the OpenAI-API equivalent) at inference time to opt out of thinking.
This keeps the exported chat_template byte-identical to the backbone's
canonical HF template.
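
For illustration, a request body for an OpenAI-compatible chat completions endpoint could carry the opt-out as shown below. The model name, audio URL, and prompt are placeholders, and the `audio_url` content part is an assumption about how the audio is passed; check the serving stack's documentation for the exact multimodal request shape.

```python
import json
from typing import Any, Dict


def build_chat_request(model: str, audio_url: str, prompt: str) -> Dict[str, Any]:
    """Build a chat-completions payload that disables thinking at request time."""
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    # Assumed content-part shape for audio input.
                    {"type": "audio_url", "audio_url": {"url": audio_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        # Forwarded to the chat template render instead of editing the template itself.
        "chat_template_kwargs": {"enable_thinking": False},
    }


payload = build_chat_request("speechlm-export", "https://example.com/sample.wav", "Transcribe the audio.")
body = json.dumps(payload)
```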

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
Covers the behavior introduced / changed in this PR:
  * Error paths for missing pretrained_llm and missing audio_locator_tag
  * config.json patching (model_type, architectures, audio_locator_tag SoT)
  * Audio token registration (add_special_tokens called only when missing
    from the backbone vocab)
  * tokenizer_config.json normalization (dict-form extra_special_tokens,
    forced tokenizer_class=PreTrainedTokenizerFast)
  * chat_template.jinja rescue (inlined back into tokenizer_config.json
    and the separate .jinja file removed)
  * chat_template is byte-identical after prep (regression guard for the
    removal of the enable_thinking default-flip)
  * generation_config.json carries the tokenizer's eos_token_id

The script lives under examples/ and is loaded via importlib; AutoTokenizer
and _detect_vllm_architecture are patched so the tests run fully offline.
9 tests pass in 0.35s in the NeMo 25.11 container.
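
Loading a script that is not an importable package and patching one of its functions can be sketched with importlib and unittest.mock. The stub script and the `answer` function below are invented for the demo and stand in for to_hf.py and its helpers.

```python
import importlib.util
import tempfile
from pathlib import Path
from types import ModuleType
from unittest import mock


def load_script_as_module(path: Path, module_name: str) -> ModuleType:
    """Import a standalone .py file (e.g. a script under examples/) as a module object."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    script = Path(tmp) / "to_hf_stub.py"
    script.write_text("def answer():\n    return 42\n", encoding="utf-8")
    module = load_script_as_module(script, "to_hf_stub")
    result = module.answer()
    # Patch a module-level function so a test never reaches the real implementation.
    with mock.patch.object(module, "answer", return_value=0):
        patched_result = module.answer()
```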

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
…re-training"

This reverts commit 0245881.

Piotr raised a concern that baking ``NO_THINK_PREFIX`` into
``QwenPromptFormatter`` changes the turn shape seen by in-flight fine-tunes
like canary-qwen-2.5b, which were trained on the prior (no-think-prefix)
formatter output. Re-rendering the same data through the updated formatter
would silently shift the prompt distribution and break those checkpoints.

Reverting the bake keeps the ``qwen`` prompt format byte-identical to the
version those checkpoints saw during training. Any future fine-tune that
actually wants the ``<think></think>`` empty-reasoning prefix should use
``Qwen3PromptFormatter`` (already handles it via ``enable_thinking=False``)
or explicitly include the prefix in training data, rather than flipping the
default for all ``qwen`` consumers.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
@pzelasko pzelasko merged commit 87ccac8 into NVIDIA-NeMo:main Apr 24, 2026
163 checks passed

3 participants