to_hf.py + PromptFormatter: produce vLLM-ready SpeechLM checkpoints with backbone-native chat_template #15623
Merged
DongjiGao
commented
Apr 18, 2026
pzelasko
reviewed
Apr 20, 2026
pzelasko
requested changes
Apr 20, 2026
Collaborator
I like your thinking about using chat template directly but for the reasons below I don't think it is feasible in the current shape.
Let's approach this from two angles.
- In this PR don't add chat template. Revert the changes from nemotron and qwen prompt formatter. Then fix `Nemotron...PromptFormatter` so that its logic is equivalent to that of jinja, but using PromptFormatter.
- We can add `ChatTemplatePromptFormatter` in a separate PR which would be auto-selected when `prompt_format: "tokenizer_chat_template"` and make it generic for HF models.
Collaborator
This PR also needs unit test coverage of the changes.
Contributor
Author
Could you please explain a bit why it is not feasible in current shape? Thanks!
b046e3d to
9641fe0
Compare
d073a98 to
849564a
Compare
pzelasko
reviewed
Apr 23, 2026
ebfa643 to
91b9af2
Compare
`to_hf.py` now always emits a checkpoint that can be served by vLLM's SpeechLM plugin: architecture/model_type fields in config.json, the backbone's canonical chat_template in tokenizer_config.json, audio placeholder registered on the tokenizer, and a minimal generation config. The previous `vllm: bool` opt-in flag is gone -- HF and vLLM loaders now share the same on-disk artifact (pzelasko's ask on the closed PR NVIDIA-NeMo#15617 thread).
Key bits:
- `_detect_vllm_architecture` inspects the backbone's HF config to pick the right vLLM plugin class. Fail-fast ValueError on missing `architectures` rather than silently defaulting to 'Std' (also addresses pzelasko's review comment about the broad except).
- `prepare_for_vllm` is invoked unconditionally after save, wrapped in `_try_prepare_for_vllm` which downgrades a `ValueError` to a warning so legacy callers that never needed vLLM (e.g., NeMo SALM loading the same dir) still get a clean HF-only checkpoint.
- Tokenizer is re-saved from the backbone (brings its native chat_template along) + augmented with `<|audio|>`; extra_special_tokens is normalized to a dict so vLLM's AutoTokenizer can load it.
- For reasoning backbones (nemotron-nano-v3), the exported chat_template's `enable_thinking` default is flipped to False so vLLM's request-time render matches training-time render; otherwise vLLM silently prepends `<think>\n` to every assistant turn and WER regresses. Verified librispeech-pc WER 1.57 (== baseline) after this fix; without it WER regressed to 5.92.
hf_hub.py: setdefault `model_type` and `architectures` in `HFHubMixin.save_pretrained` so NeMo-saved SpeechLM checkpoints carry the metadata vLLM / transformers need to identify them.
Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
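The fail-fast check and the warning downgrade described in this commit can be sketched as below. This is a minimal illustration of the pattern, not the code in `to_hf.py`: `require_architectures` and `try_prepare` are hypothetical names, and the real `_detect_vllm_architecture` additionally maps the architecture to a plugin class, which is not modelled here.

```python
import logging
from typing import Any, Callable, List, Mapping

logger = logging.getLogger(__name__)


def require_architectures(backbone_config: Mapping[str, Any]) -> List[str]:
    """Fail fast when the backbone config does not declare `architectures`."""
    architectures = backbone_config.get("architectures")
    if not architectures:
        # Raising here avoids silently falling back to a default plugin class.
        raise ValueError("backbone config has no 'architectures' entry")
    return list(architectures)


def try_prepare(prepare: Callable[[], object]) -> bool:
    """Run `prepare`; downgrade a ValueError to a warning.

    Returns True when preparation succeeded, False when it was skipped.
    Only ValueError is caught, so unrelated failures still propagate.
    """
    try:
        prepare()
    except ValueError as err:
        logger.warning("vLLM preparation skipped, checkpoint stays HF-only: %s", err)
        return False
    return True
```

Catching only `ValueError` keeps the wrapper narrow: a missing prerequisite is tolerated, while any other exception still surfaces to the caller.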
Five top-level helpers (load_checkpoint, setup_distributed_from_config, consolidate_state_dict, save_hf_checkpoint, main) lacked return types. Uses Any for setup_distributed_from_config's AutomodelParallelStrategy return to avoid adding an import just for typing; concrete types everywhere else. Same fix previously applied on pr/vllm-plugin. Made-with: Cursor Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Modern HuggingFace transformers (~4.42+) moves long chat_template strings out of tokenizer_config.json into a separate chat_template.jinja file to keep the JSON readable. Qwen3-1.7B's 4168-char template triggers this split; Nemotron-Nano's shorter template stays inline.
The old code deleted chat_template.jinja before reading tokenizer_config.json, assuming the inline copy was always complete. For Qwen3 that meant the exported checkpoint shipped with an empty chat_template -- vLLM's apply_chat_template returned a prompt without the <|audio|> placeholder, which broke multimodal prompt replacement (Failed to apply prompt replacement for mm_items['audio'][0]).
Now read chat_template.jinja, inline it into tokenizer_config.json when non-empty, and only then delete the file. Nemotron's inline-only path is unchanged because .jinja doesn't get written for small templates.
Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
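The read-inline-delete order described above can be sketched with the standard library alone. The function name `inline_chat_template` is hypothetical and the sketch assumes both files sit directly in the checkpoint directory; it is an illustration of the ordering, not the code in `to_hf.py`.

```python
import json
from pathlib import Path


def inline_chat_template(ckpt_dir: Path) -> bool:
    """Copy chat_template.jinja into tokenizer_config.json, then delete the file.

    Returns True when a non-empty template was inlined.
    """
    jinja_path = ckpt_dir / "chat_template.jinja"
    config_path = ckpt_dir / "tokenizer_config.json"
    if not jinja_path.exists():
        # Short templates stay inline, so there is nothing to rescue.
        return False

    template = jinja_path.read_text(encoding="utf-8")
    inlined = False
    if template.strip():
        config = json.loads(config_path.read_text(encoding="utf-8"))
        config["chat_template"] = template
        config_path.write_text(
            json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8"
        )
        inlined = True

    # Delete only after the content has been preserved in the JSON config.
    jinja_path.unlink()
    return inlined
```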
Newer NeMo containers (e.g. nemo-25.11-pytorch2.9-automodel-03apr26) wrap AutoTokenizer.from_pretrained(trust_remote_code=True) in a NeMo-internal TokenizersBackend class. save_pretrained then writes 'tokenizer_class: TokenizersBackend' to tokenizer_config.json -- not in HF transformers' registry, so vLLM's AutoTokenizer.from_pretrained crashes at server load: ValueError: Tokenizer class TokenizersBackend does not exist or is not currently imported. The underlying tokenizer.json is a valid HF fast tokenizer regardless of which wrapper produced it; force the class name back to PreTrainedTokenizerFast so downstream HF-based loaders (including vLLM) can round-trip the config. Made-with: Cursor Signed-off-by: Dongji Gao <dongjig@nvidia.com>
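As a rough sketch of the class-name reset described above, the helper below patches a loaded `tokenizer_config.json` dict. `force_fast_tokenizer_class` is a hypothetical name; only the `tokenizer_class` key and the `PreTrainedTokenizerFast` value come from the commit message.

```python
from typing import Any, Dict


def force_fast_tokenizer_class(tokenizer_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the config whose tokenizer_class is a registered HF class."""
    patched = dict(tokenizer_config)
    # The wrapper class name written by the container is not in the HF registry,
    # so it is replaced regardless of its previous value.
    patched["tokenizer_class"] = "PreTrainedTokenizerFast"
    return patched
```

Returning a copy rather than mutating in place keeps the original dict available for comparison in tests.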
Qwen3's chat_template injects '<think>\n\n</think>\n\n' before assistant content when enable_thinking=False (the 'no reasoning' mode, which is what SpeechLM ASR wants). The old QwenPromptFormatter didn't include this prefix, so SpeechLM fine-tunes trained through it showed the model a turn shape that's different from Qwen3's pre-training convention.
Bake NO_THINK_PREFIX into both INFERENCE_PREFIX and (transitively via INFERENCE_PREFIX) the assistant template, so future fine-tunes produce training data byte-identical to 'apply_chat_template(enable_thinking=False)'. Existing checkpoints are unaffected -- the change only kicks in the next time you retrain with prompt_format=qwen.
Test: updated hardcoded expected strings to match Qwen3 jinja output for single-turn training and inference.
Made-with: Cursor
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Removes the hardcoded _AUDIO_TOKEN constant and reads the audio placeholder from model_cfg["audio_locator_tag"], raising ValueError if missing. This ensures the exported config.json, added tokenizer symbols, and extra_special_tokens dict all reference the same source of truth as training, avoiding silent drift between train-time and inference-time audio tokens. Signed-off-by: Dongji Gao <dongjig@nvidia.com> Made-with: Cursor
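A minimal sketch of the single-source-of-truth lookup, assuming `model_cfg` behaves like a mapping. The helper name `get_audio_locator_tag` is hypothetical; the key name and the `ValueError` on a missing value come from the commit message.

```python
from typing import Any, Mapping


def get_audio_locator_tag(model_cfg: Mapping[str, Any]) -> str:
    """Read the audio placeholder from the training config, or fail loudly."""
    tag = model_cfg.get("audio_locator_tag")
    if not tag:
        # No hardcoded fallback: a default here could drift from the token
        # the model was actually trained with.
        raise ValueError("model config is missing 'audio_locator_tag'")
    return str(tag)
```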
Before this fix, the formatter only normalized past assistant turns that
already contained <think>...</think> tags, so a past assistant turn without
any think tags would emit as "<|im_start|>assistant\nTEST<|im_end|>\n" while
the HF jinja template emits "<|im_start|>assistant\n<think></think>TEST<|im_end|>\n"
(jinja unconditionally injects an empty think block for content lacking both
tags). This caused a silent train/inference-template divergence for
multi-turn dialogs without reasoning history.
Fix step 3 in encode_dialog to handle all three cases symmetrically:
- both tags present -> truncate to "<think></think>" + post-</think> content
- neither tag present -> prepend "<think></think>"
- only one tag present -> leave as-is (matches jinja)
Also adds three tests to fill previously-missing coverage:
- training multi-turn with past assistant missing think tags
(regression test for the fix above)
- inference multi-turn with enable_thinking=False
- inference multi-turn with enable_thinking=True
All 12 non-HF tests pass in the NeMo 25.11 container.
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
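The three cases listed in the commit above can be sketched as a small pure function. This is an illustration only: `normalize_past_assistant` is a hypothetical name, taking the text after the first `</think>` is an assumption, and any whitespace handling done by the real formatter or jinja template is not modelled.

```python
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"
EMPTY_THINK = THINK_OPEN + THINK_CLOSE


def normalize_past_assistant(content: str) -> str:
    """Normalize a past assistant turn according to the three cases above."""
    has_open = THINK_OPEN in content
    has_close = THINK_CLOSE in content
    if has_open and has_close:
        # Both tags: drop the reasoning text, keep what follows the closing tag.
        after = content.split(THINK_CLOSE, 1)[1]
        return EMPTY_THINK + after
    if not has_open and not has_close:
        # Neither tag: prepend an empty think block.
        return EMPTY_THINK + content
    # Exactly one tag: leave the content untouched.
    return content
```

For example, `normalize_past_assistant("TEST")` yields `<think></think>TEST`, which is the content part of the jinja output quoted in the commit message.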
The previous in-place string-replace flipped Nemotron's
``enable_thinking`` default from True to False so that default vLLM
inference (with no ``chat_template_kwargs``) would match SpeechLM
training rendering. This approach is fragile (silently no-ops if
upstream changes the template) and surprising for downstream consumers
of the exported checkpoint.
Serving callers should instead pass ``chat_template_kwargs={"enable_thinking": False}``
(or the OpenAI-API equivalent) at inference time to opt out of thinking.
This keeps the exported chat_template byte-identical to the backbone's
canonical HF template.
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
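For callers of vLLM's OpenAI-compatible server, the opt-out recommended in the commit above amounts to adding a `chat_template_kwargs` field to the chat completion request body. The sketch below only builds the JSON payload; `build_chat_request` and the model name are hypothetical, the message is a text-only placeholder, and how audio is attached to the request is left out.

```python
import json
from typing import Any, Dict, List


def build_chat_request(
    model: str, messages: List[Dict[str, Any]], enable_thinking: bool = False
) -> Dict[str, Any]:
    """Build a chat completion body that forwards enable_thinking to the template."""
    return {
        "model": model,
        "messages": messages,
        # Forwarded to the chat template at render time.
        "chat_template_kwargs": {"enable_thinking": enable_thinking},
    }


payload = build_chat_request(
    "speechlm-export",  # hypothetical served model name
    [{"role": "user", "content": "Transcribe the audio."}],
)
print(json.dumps(payload, indent=2))
```

With the official `openai` Python client the same field is usually passed through `extra_body={"chat_template_kwargs": {"enable_thinking": False}}`.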
Covers the behavior introduced / changed in this PR:
* Error paths for missing pretrained_llm and missing audio_locator_tag
* config.json patching (model_type, architectures, audio_locator_tag SoT)
* Audio token registration (add_special_tokens called only when missing
from the backbone vocab)
* tokenizer_config.json normalization (dict-form extra_special_tokens,
forced tokenizer_class=PreTrainedTokenizerFast)
* chat_template.jinja rescue (inlined back into tokenizer_config.json
and the separate .jinja file removed)
* chat_template is byte-identical after prep (regression guard for the
removal of the enable_thinking default-flip)
* generation_config.json carries the tokenizer's eos_token_id
The script lives under examples/ and is loaded via importlib; AutoTokenizer
and _detect_vllm_architecture are patched so the tests run fully offline.
9 tests pass in 0.35s in the NeMo 25.11 container.
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
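Loading a script that lives outside a package, as the tests above do for the file under examples/, can be sketched with `importlib.util`. `load_script_as_module` is a hypothetical helper name; the actual test fixtures may differ.

```python
import importlib.util
from pathlib import Path
from types import ModuleType


def load_script_as_module(path: Path, name: str) -> ModuleType:
    """Import a standalone .py file as a module without touching sys.path."""
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {path}")
    module = importlib.util.module_from_spec(spec)
    # Runs the script's top-level code; an `if __name__ == "__main__"` guard
    # does not fire because the module name is `name`, not "__main__".
    spec.loader.exec_module(module)
    return module
```

Once the module object exists, attributes such as `AutoTokenizer` can be replaced for the duration of a test with `unittest.mock.patch.object(module, "AutoTokenizer", fake)`, which keeps the run offline.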
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
…re-training"
This reverts commit 0245881.
Piotr raised a concern that baking ``NO_THINK_PREFIX`` into ``QwenPromptFormatter`` changes the turn shape seen by in-flight fine-tunes like canary-qwen-2.5b, which were trained on the prior (no-think-prefix) formatter output. Re-rendering the same data through the updated formatter would silently shift the prompt distribution and break those checkpoints. Reverting the bake keeps the ``qwen`` prompt format byte-identical to the version those checkpoints saw during training.
Any future fine-tune that actually wants the ``<think></think>`` empty-reasoning prefix should use ``Qwen3PromptFormatter`` (already handles it via ``enable_thinking=False``) or explicitly include the prefix in training data, rather than flipping the default for all ``qwen`` consumers.
Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Made-with: Cursor
91b9af2 to
8d8dab5
Compare
pzelasko
approved these changes
Apr 23, 2026
Background
Reopens #15617 after its base branch (`speechlm2-with-nemo-automodel-merge`) was deleted; now rebased on `main`. Incorporates @pzelasko's review comments from the closed thread.
Summary
Produces a single checkpoint artifact that serves both HF / NeMo SALM inference and vLLM (SpeechLM plugin) without an opt-in flag, and eliminates the drift between our hand-rolled jinja and the backbone LLM's canonical chat_template.
Commit 1 - PromptFormatter uses backbone `chat_template` as source of truth
Replaces the hardcoded `TEMPLATE` dict + `to_jinja()` reconstruction in `QwenPromptFormatter` and `NemotronNanoV3PromptFormatter` with a shared base helper that delegates to `tokenizer.apply_chat_template` on the backbone's native jinja. Training and exported-ckpt inference now share one jinja string shipped by the backbone's HF repo - no more train/infer drift from reconstruction mismatches.
- `formatter.py`: add `AUDIO_TOKEN`, `_turn_to_openai_message`, `_is_reasoning_divergence`, `_encode_dialog_from_chat_template`. Drop the abstract `to_jinja` stub.
- `qwen.py`: `QwenPromptFormatter.encode_dialog` delegates; hardcoded `TEMPLATE`, `INFERENCE_PREFIX` and `to_jinja` removed. `Qwen3PromptFormatter` (reasoning, out of refactor scope) unchanged.
- `nemotron_nano_v3.py`: `encode_dialog` delegates and forwards `enable_thinking` to the backbone jinja. Default is `False`; reasoning training raises a clear 'not supported' error because the prefix-based loss mask can't align with the template's fresh `<think>` generation prompt.
Commit 2 - `to_hf.py` unconditionally produces vLLM-ready checkpoints
- Removes the `vllm: bool = False` flag; every export is vLLM-ready (addresses @pzelasko's review comment on "to_hf.py: add --vllm flag for vLLM-ready checkpoint export" #15617).
- `_detect_vllm_architecture` now raises `ValueError` on missing `architectures` instead of silently defaulting (addresses the 'broad except' review comment).
- `_try_prepare_for_vllm` wrapper downgrades a `ValueError` to a warning so legacy SALM callers loading the same dir still get a valid HF-only checkpoint.
- Tokenizer is re-saved from the backbone (brings its native `chat_template`) + augmented with `<|audio|>`. `extra_special_tokens` normalized to dict (HF loaders crash on the list form `transformers` writes for `additional_special_tokens`).
- For reasoning backbones (nemotron-nano-v3), the exported `chat_template`'s `enable_thinking` default is flipped to `False` so vLLM's request-time render matches training-time render; otherwise vLLM silently prepends `<think>` and WER regresses from 1.57 to 5.92.
- `hf_hub.py`: `setdefault('model_type', ...)` and `setdefault('architectures', [...])` in `HFHubMixin.save_pretrained` so NeMo-saved checkpoints carry the metadata vLLM / transformers need.
Architecture mapping
- `NeMoSpeechLMHybridForConditionalGeneration`
- `NeMoSpeechLMForConditionalGeneration`
Usage
    torchrun --nproc-per-node=8 --nnodes=4 \
      examples/speechlm2/to_hf.py \
      class_path=nemo.collections.speechlm2.models.salm_automodel.SALMAutomodel \
      ckpt_path=/path/to/checkpoint.ckpt \
      ckpt_config=/path/to/exp_config.yaml \
      output_dir=/path/to/hf_output
No `vllm=True` flag any more - output is always vLLM-ready.
Test plan
- … `enable_thinking=False`.
- `len(context) + len(answer) == len(input)` and `sum(mask) == len(answer)` across Qwen / Nemotron, plain / multi-turn, with / without system, multimodal list content.
- Without the `enable_thinking` default flip, WER regressed to 5.92 - regression confirmed, fix verified.
- `_try_prepare_for_vllm` downgrades missing-prereq `ValueError` to a warning.
Made with Cursor
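The two length invariants from the test plan can be expressed as a small checker. This is a sketch under assumptions: `check_encoding_invariants` is a hypothetical helper, the arguments are plain sequences of token ids, and the mask is assumed to hold truthy values on answer positions.

```python
from typing import Sequence


def check_encoding_invariants(
    context_ids: Sequence[int],
    answer_ids: Sequence[int],
    input_ids: Sequence[int],
    mask: Sequence[int],
) -> None:
    """Raise ValueError when an encoded dialog violates the length invariants."""
    if len(context_ids) + len(answer_ids) != len(input_ids):
        raise ValueError("len(context) + len(answer) != len(input)")
    # Count the positions that contribute to the loss.
    if sum(1 for m in mask if m) != len(answer_ids):
        raise ValueError("sum(mask) != len(answer)")
```

Running the checker over every formatter and dialog shape in a parametrized test gives one failure message per violated invariant.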