funasr/auto/auto_model.py (57 changes: 56 additions & 1 deletion)
@@ -30,6 +30,16 @@
 from funasr.utils import export_utils
 from funasr.utils import misc

+
+def _resolve_ncpu(config, fallback=4):
+    """Return a positive integer representing CPU threads from config."""
+    value = config.get("ncpu", fallback)
+    try:
+        value = int(value)
+    except (TypeError, ValueError):
+        value = fallback
+    return max(value, 1)
+
 try:
     from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
     from funasr.models.campplus.cluster_backend import ClusterBackend
@@ -132,6 +142,7 @@ def __init__(self, **kwargs):
vad_kwargs["model"] = vad_model
vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
vad_kwargs["device"] = kwargs["device"]
vad_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
vad_model, vad_kwargs = self.build_model(**vad_kwargs)

# if punc_model is not None, build punc model else None
@@ -142,6 +153,7 @@
punc_kwargs["model"] = punc_model
punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
punc_kwargs["device"] = kwargs["device"]
punc_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
punc_model, punc_kwargs = self.build_model(**punc_kwargs)

# if spk_model is not None, build spk model else None
@@ -155,6 +167,7 @@
spk_kwargs["model"] = spk_model
spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
spk_kwargs["device"] = kwargs["device"]
spk_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4))
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend(**cb_kwargs).to(kwargs["device"])
spk_mode = kwargs.get("spk_mode", "punc_segment")
@@ -171,6 +184,7 @@ def __init__(self, **kwargs):
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
         self.model_path = kwargs.get("model_path")
+        self._store_base_configs()

     @staticmethod
     def build_model(**kwargs):
@@ -190,7 +204,10 @@ def build_model(**kwargs):
kwargs["batch_size"] = 1
kwargs["device"] = device

torch.set_num_threads(kwargs.get("ncpu", 4))
ncpu = _resolve_ncpu(kwargs, 4)
kwargs["ncpu"] = ncpu
if torch.get_num_threads() != ncpu:
torch.set_num_threads(ncpu)

# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
@@ -302,6 +319,7 @@ def __call__(self, *args, **cfg):
         return res

     def generate(self, input, input_len=None, progress_callback=None, **cfg):
+        self._reset_runtime_configs()
         if self.vad_model is None:
             return self.inference(
                 input, input_len=input_len, progress_callback=progress_callback, **cfg
@@ -322,6 +340,8 @@ def inference(
         progress_callback=None,
         **cfg,
     ):
+        if kwargs is None:
+            self._reset_runtime_configs()
         kwargs = self.kwargs if kwargs is None else kwargs
         if "cache" in kwargs:
             kwargs.pop("cache")
@@ -397,6 +417,7 @@ def inference(
         return asr_result_list

     def inference_with_vad(self, input, input_len=None, **cfg):
+        self._reset_runtime_configs()
         kwargs = self.kwargs
         # step.1: compute the vad model
         deep_update(self.vad_kwargs, cfg)
@@ -691,3 +712,37 @@ def export(self, input=None, **cfg):
         export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)

         return export_dir
+
+    def _store_base_configs(self):
+        """Snapshot base kwargs for all submodules to allow reset before inference."""
+        baseline = {}
+        for name in dir(self):
+            if not name.endswith("kwargs"):
+                continue
+            value = getattr(self, name, None)
+            if isinstance(value, dict):
+                baseline[name] = copy.deepcopy(value)
+        # include primary kwargs explicitly
+        baseline["kwargs"] = copy.deepcopy(self.kwargs)
+        self._base_kwargs_map = baseline
+
+    def _reset_runtime_configs(self):
+        """Ensure runtime kwargs reset to baseline defaults before inference."""
+        base_map = getattr(self, "_base_kwargs_map", None)
+        if not base_map:
+            return
+
+        for name, base in base_map.items():
+            restored = copy.deepcopy(base)
+            setattr(self, name, restored)
+
+        ncpu = _resolve_ncpu(self.kwargs, 4)
+        self.kwargs["ncpu"] = ncpu
+        for name, value in base_map.items():
+            if name == "kwargs":
+                continue
+            config = getattr(self, name, None)
+            if isinstance(config, dict):
+                config.setdefault("ncpu", ncpu)
+        if torch.get_num_threads() != ncpu:
+            torch.set_num_threads(ncpu)
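
A quick standalone check of the coercion rules the new helper implements (the helper body is copied from the hunk above; the inputs are illustrative):

def _resolve_ncpu(config, fallback=4):
    """Return a positive integer representing CPU threads from config."""
    value = config.get("ncpu", fallback)
    try:
        value = int(value)
    except (TypeError, ValueError):
        value = fallback
    return max(value, 1)

assert _resolve_ncpu({"ncpu": "8"}) == 8   # numeric strings are coerced
assert _resolve_ncpu({"ncpu": None}) == 4  # un-castable values fall back
assert _resolve_ncpu({"ncpu": -2}) == 1    # clamped to at least one thread
assert _resolve_ncpu({}) == 4              # missing key uses the fallback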
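
The _store_base_configs/_reset_runtime_configs pair is a snapshot-and-restore pattern: deep_update merges per-call cfg into the shared kwargs dicts in place, so without a reset one call's overrides would leak into the next. A minimal sketch of the idea, using hypothetical names rather than FunASR's API:

import copy

class Runner:
    def __init__(self):
        self.kwargs = {"batch_size": 1, "ncpu": 4}
        # snapshot once after construction, like _store_base_configs
        self._base = copy.deepcopy(self.kwargs)

    def generate(self, **cfg):
        # restore the pristine baseline before merging this call's
        # overrides, like _reset_runtime_configs
        self.kwargs = copy.deepcopy(self._base)
        self.kwargs.update(cfg)
        return self.kwargs

r = Runner()
assert r.generate(batch_size=16)["batch_size"] == 16  # override applies
assert r.generate()["batch_size"] == 1                # and does not persist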
funasr/models/fsmn_vad_streaming/model.py (2 changes: 1 addition & 1 deletion)
@@ -724,7 +724,7 @@ def inference(
             if len(segments_i) > 0:
                 segments.extend(*segments_i)

-        cache["prev_samples"] = audio_sample[:-m]
+        cache["prev_samples"] = audio_sample[-m:] if m > 0 else torch.empty(0)
         if _is_final:
             self.init_cache(cache)

funasr/models/paraformer_streaming/model.py (2 changes: 1 addition & 1 deletion)
@@ -642,7 +642,7 @@ def inference(
result_i = {"key": key[0], "text": text_postprocessed}
result = [result_i]

cache["prev_samples"] = audio_sample[:-m]
cache["prev_samples"] = audio_sample[-m:] if m > 0 else torch.empty(0)
if _is_final:
self.init_cache(cache, **kwargs)

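Both streaming hunks fix the same slice: assuming m counts the trailing samples that did not fill a complete chunk in this call, the old expression audio_sample[:-m] kept the already-consumed head of the buffer, while audio_sample[-m:] keeps exactly the unconsumed tail that must be prepended to the next chunk. A minimal illustration:

import torch

audio_sample = torch.arange(10)  # ten samples arrived in this chunk
m = 3                            # three samples left over after chunking

old_prev = audio_sample[:-m]  # tensor([0, ..., 6]): the processed head, wrongly re-fed
new_prev = audio_sample[-m:] if m > 0 else torch.empty(0)

assert new_prev.tolist() == [7, 8, 9]  # only the leftover tail is carried over
# the explicit m > 0 guard matters: audio_sample[-0:] would return the
# whole buffer rather than an empty tensor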