diff --git a/examples/speechlm2/to_hf.py b/examples/speechlm2/to_hf.py index f2e091854623..71a4c74a79f1 100644 --- a/examples/speechlm2/to_hf.py +++ b/examples/speechlm2/to_hf.py @@ -13,6 +13,7 @@ # limitations under the License. import json import os +from copy import deepcopy from dataclasses import dataclass from pathlib import Path from typing import Any @@ -22,7 +23,9 @@ from omegaconf import DictConfig, OmegaConf from safetensors.torch import save_file +from nemo.collections.speechlm2.parts.hf_hub import LLM_BACKBONE_DIR from nemo.core.config import hydra_runner +from nemo.utils.dtype import str_to_dtype from nemo.utils.model_utils import import_class_by_path @@ -92,19 +95,45 @@ def consolidate_state_dict(model: torch.nn.Module) -> dict[str, torch.Tensor]: return consolidated +def _canonical_torch_dtype_name(dtype: str | torch.dtype) -> str: + """Return the PyTorch dtype name accepted by Transformers configs.""" + return str(str_to_dtype(dtype)).replace("torch.", "") + + +def _hf_export_config(model: torch.nn.Module, dtype: str | torch.dtype) -> dict[str, Any]: + """Build the exported root config without mutating the training config.""" + config = OmegaConf.to_container(model.cfg) if isinstance(model.cfg, DictConfig) else deepcopy(model.cfg) + dtype_name = _canonical_torch_dtype_name(dtype) + config["dtype"] = dtype_name + config["torch_dtype"] = dtype_name + return config + + def save_hf_checkpoint(model: torch.nn.Module, state_dict: dict, cfg: HfExportConfig) -> None: """Save a consolidated state dict and model config in HuggingFace Hub format.""" output_dir = Path(cfg.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - target_dtype = getattr(torch, cfg.dtype) + target_dtype = str_to_dtype(cfg.dtype) state_dict = {k: v.to(target_dtype) for k, v in state_dict.items()} save_file(state_dict, output_dir / "model.safetensors") - config = OmegaConf.to_container(model.cfg) if isinstance(model.cfg, DictConfig) else model.cfg + config = _hf_export_config(model, cfg.dtype) with open(output_dir / "config.json", "w") as f: json.dump(config, f, indent=2) + save_llm_backbone_config(model, output_dir) + + +def save_llm_backbone_config(model: torch.nn.Module, output_dir: str | Path) -> None: + """Save the original LLM config separately from the NeMo wrapper config.""" + llm_config = getattr(getattr(model, "llm", None), "config", None) + if llm_config is None: + return + + llm_backbone_dir = Path(output_dir) / LLM_BACKBONE_DIR + llm_backbone_dir.mkdir(parents=True, exist_ok=True) + llm_config.save_pretrained(str(llm_backbone_dir)) def _detect_vllm_architecture(model_cfg: dict) -> str: @@ -165,7 +194,11 @@ def prepare_for_vllm(output_dir: str, model_cfg: dict) -> None: raise ValueError("model config has no 'audio_locator_tag' (set it in the training YAML).") # 1. Patch config.json (arch, model_type, audio_locator_tag for vLLM plugin). - arch = _detect_vllm_architecture(model_cfg) + arch_model_cfg = dict(model_cfg) + llm_backbone_dir = output_dir / LLM_BACKBONE_DIR + if (llm_backbone_dir / "config.json").exists(): + arch_model_cfg["pretrained_llm"] = str(llm_backbone_dir) + arch = _detect_vllm_architecture(arch_model_cfg) config_path = output_dir / "config.json" config = json.loads(config_path.read_text()) config["model_type"] = "nemo_speechlm" @@ -274,7 +307,7 @@ def main(cfg: HfExportConfig) -> None: full_cfg = OmegaConf.to_container(OmegaConf.load(cfg.ckpt_config), resolve=True) model_cfg = full_cfg["model"] - model_cfg["torch_dtype"] = cfg.dtype + model_cfg["torch_dtype"] = _canonical_torch_dtype_name(cfg.dtype) cls = import_class_by_path(cfg.class_path) strategy_cfg = full_cfg.get("trainer", {}).get("strategy", {}) @@ -317,9 +350,10 @@ def main(cfg: HfExportConfig) -> None: model_cfg["init_configure_model"] = True model = cls(model_cfg) load_checkpoint(model, cfg.ckpt_path) - model = model.to(getattr(torch, cfg.dtype)) + model = model.to(str_to_dtype(cfg.dtype)) model_cfg["pretrained_weights"] = False - model.save_pretrained(cfg.output_dir) + model.save_pretrained(cfg.output_dir, config=_hf_export_config(model, cfg.dtype)) + save_llm_backbone_config(model, cfg.output_dir) _try_prepare_for_vllm(cfg.output_dir, model_cfg) diff --git a/nemo/collections/speechlm2/models/duplex_ear_tts.py b/nemo/collections/speechlm2/models/duplex_ear_tts.py index bab8bbf8810f..a11e2fce6bdf 100644 --- a/nemo/collections/speechlm2/models/duplex_ear_tts.py +++ b/nemo/collections/speechlm2/models/duplex_ear_tts.py @@ -92,8 +92,9 @@ def __init__(self, cfg: dict) -> None: self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "float32"), torch.float32) # Load tokenizer + tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_lm_name self.tokenizer = AutoTokenizer( - self.cfg.pretrained_lm_name, + tokenizer_src, use_fast=True, trust_remote_code=True, bos_token=self.cfg.get("bos_token", None), diff --git a/nemo/collections/speechlm2/models/duplex_s2s_model.py b/nemo/collections/speechlm2/models/duplex_s2s_model.py index f3e8c730e837..71385f4331f4 100644 --- a/nemo/collections/speechlm2/models/duplex_s2s_model.py +++ b/nemo/collections/speechlm2/models/duplex_s2s_model.py @@ -63,7 +63,8 @@ def __init__(self, cfg: dict) -> None: # pretrained LM head weights. # However, for S2S we need to access the activations before LM head directly # to feed them to the audio codec head. - self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True) + tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm + self.tokenizer = AutoTokenizer(tokenizer_src, use_fast=True) llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights).train() self.llm = llm.model # fetch PretrainedBaseModel from model "ForCausalLM" self.lm_head = llm.lm_head diff --git a/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py b/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py index 66dac2bcb2f1..5138cb3038e8 100644 --- a/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py +++ b/nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py @@ -64,7 +64,8 @@ def __init__(self, cfg: dict) -> None: # pretrained LM head weights. # However, for S2S we need to access the activations before LM head directly # to feed them to the audio codec head. - self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True) + tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm + self.tokenizer = AutoTokenizer(tokenizer_src, use_fast=True) llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights).train() self.llm = llm.model # fetch PretrainedBaseModel from model "ForCausalLM" self.lm_head = llm.lm_head diff --git a/nemo/collections/speechlm2/models/duplex_stt_model.py b/nemo/collections/speechlm2/models/duplex_stt_model.py index d15becb2cbc8..1cd7492e39aa 100644 --- a/nemo/collections/speechlm2/models/duplex_stt_model.py +++ b/nemo/collections/speechlm2/models/duplex_stt_model.py @@ -87,8 +87,9 @@ def __init__(self, cfg: dict) -> None: ).train() # Initialize tokenizer with optional special tokens from config + tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm self.tokenizer = AutoTokenizer( - self.cfg.pretrained_llm, + tokenizer_src, use_fast=True, bos_token=self.cfg.get("bos_token", None), eos_token=self.cfg.get("eos_token", None), diff --git a/nemo/collections/speechlm2/models/salm.py b/nemo/collections/speechlm2/models/salm.py index 5e1533993ce0..c47e9c114e5d 100644 --- a/nemo/collections/speechlm2/models/salm.py +++ b/nemo/collections/speechlm2/models/salm.py @@ -63,8 +63,9 @@ def __init__(self, cfg) -> None: self.cfg = DictConfig(cfg) self.audio_locator_tag = self.cfg.audio_locator_tag + tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm self.tokenizer = AutoTokenizer( - self.cfg.pretrained_llm, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False) + tokenizer_src, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False) ) self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]}) self.llm = load_pretrained_hf( diff --git a/nemo/collections/speechlm2/models/salm_asr_decoder.py b/nemo/collections/speechlm2/models/salm_asr_decoder.py index 02665c00480a..6d7d8450dfd7 100644 --- a/nemo/collections/speechlm2/models/salm_asr_decoder.py +++ b/nemo/collections/speechlm2/models/salm_asr_decoder.py @@ -63,7 +63,8 @@ def __init__(self, cfg) -> None: self.cfg = DictConfig(cfg) self.audio_locator_tag = self.cfg.audio_locator_tag - self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True) + tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm + self.tokenizer = AutoTokenizer(tokenizer_src, use_fast=True) self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]}) self.llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights) if not hasattr(self.llm, "model") and hasattr(self.llm, "backbone"): diff --git a/nemo/collections/speechlm2/models/salm_automodel.py b/nemo/collections/speechlm2/models/salm_automodel.py index f759ac01bcc7..513fe61f93a1 100644 --- a/nemo/collections/speechlm2/models/salm_automodel.py +++ b/nemo/collections/speechlm2/models/salm_automodel.py @@ -53,8 +53,9 @@ def __init__(self, cfg) -> None: self.cfg = DictConfig(cfg) self.audio_locator_tag = self.cfg.audio_locator_tag + tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm self.tokenizer = AutoTokenizer( - self.cfg.pretrained_llm, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False) + tokenizer_src, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False) ) self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]}) self.llm = None # populated by configure_model diff --git a/nemo/collections/speechlm2/parts/hf_hub.py b/nemo/collections/speechlm2/parts/hf_hub.py index 29cb80ddacb4..d8e6fafc0bd8 100644 --- a/nemo/collections/speechlm2/parts/hf_hub.py +++ b/nemo/collections/speechlm2/parts/hf_hub.py @@ -20,6 +20,7 @@ from transformers.utils import cached_file SAFETENSORS_SINGLE_FILE = "model.safetensors" +LLM_BACKBONE_DIR = "llm_backbone" class HFHubMixin( @@ -80,6 +81,7 @@ def _from_pretrained( if resolved_config_file is None: raise RuntimeError(f"Missing {CONFIG_NAME} file for {model_id=}") model_kwargs['cfg'] = OmegaConf.to_container(OmegaConf.load(resolved_config_file)) + _inject_local_artifact_paths(model_kwargs['cfg'], model_id, _cached_file_kwargs) # The setting below tells the model's __init__ not to load the original pretrained weights # for individual children modules. # To illustrate: if you trained a new model M using a pretrained ASR and a pretrained LLM, @@ -252,3 +254,28 @@ def _load_state_dict_with_dtensors(model, weight_dir): # the planner narrows each tensor to the local DTensor shard, # and copies directly into model parameter storage. dcp.load(state_dict, storage_reader=reader) + + +def _inject_local_artifact_paths(cfg: dict, model_id: str, cached_file_kwargs: dict) -> None: + """ + Redirect a loaded SpeechLM2 checkpoint config to artifacts saved beside it. + + The root checkpoint directory keeps NeMo's wrapper ``config.json``. When it + also contains a root tokenizer and ``llm_backbone/config.json``, point + tokenizer construction to the root directory and LLM config construction to + ``llm_backbone`` by mutating ``tokenizer_path`` plus ``pretrained_llm`` or + ``pretrained_lm_name`` in-place. + """ + resolved_tokenizer_file = cached_file(model_id, "tokenizer_config.json", **cached_file_kwargs) + if resolved_tokenizer_file is not None and ("pretrained_llm" in cfg or "pretrained_lm_name" in cfg): + cfg["tokenizer_path"] = str(Path(resolved_tokenizer_file).parent) + + resolved_llm_config_file = cached_file(model_id, f"{LLM_BACKBONE_DIR}/{CONFIG_NAME}", **cached_file_kwargs) + if resolved_llm_config_file is None: + return + + llm_backbone_path = str(Path(resolved_llm_config_file).parent) + if "pretrained_llm" in cfg: + cfg["pretrained_llm"] = llm_backbone_path + if "pretrained_lm_name" in cfg: + cfg["pretrained_lm_name"] = llm_backbone_path diff --git a/tests/collections/speechlm2/test_hf_hub.py b/tests/collections/speechlm2/test_hf_hub.py new file mode 100644 index 000000000000..85f512292daf --- /dev/null +++ b/tests/collections/speechlm2/test_hf_hub.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.speechlm2.parts.hf_hub import _inject_local_artifact_paths + + +def _cached_file_kwargs(): + return { + "cache_dir": None, + "force_download": False, + "local_files_only": True, + "token": None, + "revision": None, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_raise_exceptions_for_connection_errors": False, + } + + +def _write_local_export_artifacts(tmp_path): + (tmp_path / "tokenizer_config.json").write_text("{}") + (tmp_path / "llm_backbone").mkdir() + (tmp_path / "llm_backbone" / "config.json").write_text("{}") + + +def test_inject_local_artifact_paths_salm_config(tmp_path): + _write_local_export_artifacts(tmp_path) + cfg = { + "pretrained_llm": "remote-llm", + "pretrained_asr": "remote-asr", + } + + _inject_local_artifact_paths(cfg, str(tmp_path), _cached_file_kwargs()) + + assert cfg["pretrained_llm"] == str(tmp_path / "llm_backbone") + assert cfg["pretrained_asr"] == "remote-asr" + assert cfg["tokenizer_path"] == str(tmp_path) + + +def test_inject_local_artifact_paths_duplex_eartts_config(tmp_path): + _write_local_export_artifacts(tmp_path) + cfg = { + "pretrained_lm_name": "remote-llm", + "tts_config": {}, + } + + _inject_local_artifact_paths(cfg, str(tmp_path), _cached_file_kwargs()) + + assert cfg["pretrained_lm_name"] == str(tmp_path / "llm_backbone") + assert cfg["tokenizer_path"] == str(tmp_path) + + +def test_inject_local_artifact_paths_no_artifacts_keeps_old_config(tmp_path): + cfg = { + "pretrained_llm": "remote-llm", + "pretrained_weights": True, + } + + _inject_local_artifact_paths(cfg, str(tmp_path), _cached_file_kwargs()) + + assert cfg == { + "pretrained_llm": "remote-llm", + "pretrained_weights": True, + } diff --git a/tests/collections/speechlm2/test_to_hf.py b/tests/collections/speechlm2/test_to_hf.py index 985f48d6bb38..8cd3b08bd7cb 100644 --- a/tests/collections/speechlm2/test_to_hf.py +++ b/tests/collections/speechlm2/test_to_hf.py @@ -23,6 +23,8 @@ from unittest.mock import patch import pytest +import torch +from safetensors.torch import load_file _TO_HF_PATH = Path(__file__).parents[3] / "examples" / "speechlm2" / "to_hf.py" _spec = importlib.util.spec_from_file_location("to_hf_for_test", _TO_HF_PATH) @@ -101,6 +103,73 @@ def _seed_output_dir(tmp_path, llm_arch="Qwen2ForCausalLM"): return tmp_path +class _FakeLLMConfig: + def save_pretrained(self, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "config.json").write_text( + json.dumps( + { + "model_type": "qwen2", + "architectures": ["Qwen2ForCausalLM"], + "hidden_size": 2048, + } + ) + ) + + +class _FakeExportModel: + cfg = { + "pretrained_llm": "fake-model", + "pretrained_asr": "fake-asr", + "pretrained_weights": False, + "dtype": "bf16", + "torch_dtype": "bf16", + "audio_locator_tag": AUDIO_TOKEN, + } + llm = type("_FakeLLM", (), {"config": _FakeLLMConfig()})() + + +def test_save_hf_checkpoint_writes_llm_backbone_config(tmp_path): + cfg = to_hf.HfExportConfig( + class_path="fake.Class", + ckpt_path="fake.ckpt", + ckpt_config="fake.yaml", + output_dir=str(tmp_path), + dtype="bfloat16", + ) + to_hf.save_hf_checkpoint(_FakeExportModel(), {"weight": torch.zeros(1)}, cfg) + + root_cfg = json.loads((tmp_path / "config.json").read_text()) + llm_cfg = json.loads((tmp_path / "llm_backbone" / "config.json").read_text()) + + assert "llm_config" not in root_cfg + assert root_cfg["pretrained_llm"] == "fake-model" + assert root_cfg["dtype"] == "bfloat16" + assert root_cfg["torch_dtype"] == "bfloat16" + assert llm_cfg["model_type"] == "qwen2" + assert llm_cfg["architectures"] == ["Qwen2ForCausalLM"] + assert _FakeExportModel.cfg["dtype"] == "bf16" + + +def test_save_hf_checkpoint_accepts_bf16_export_dtype(tmp_path): + cfg = to_hf.HfExportConfig( + class_path="fake.Class", + ckpt_path="fake.ckpt", + ckpt_config="fake.yaml", + output_dir=str(tmp_path), + dtype="bf16", + ) + to_hf.save_hf_checkpoint(_FakeExportModel(), {"weight": torch.zeros(1)}, cfg) + + root_cfg = json.loads((tmp_path / "config.json").read_text()) + state_dict = load_file(tmp_path / "model.safetensors") + + assert root_cfg["dtype"] == "bfloat16" + assert root_cfg["torch_dtype"] == "bfloat16" + assert state_dict["weight"].dtype == torch.bfloat16 + + # ────────────────────────────────────────────────────────────────────── # Error paths (no mocking required — checks run before any HF calls) # ──────────────────────────────────────────────────────────────────────