From 2b33690f946070c9ec1c21bcc88daf3a967b2cd7 Mon Sep 17 00:00:00 2001 From: hbybyyang <2451759073@qq.com> Date: Mon, 20 Jan 2025 01:33:05 +0800 Subject: [PATCH] feat: Add Support for v2 Model in Web UI - Added support for the v2 model in the web UI. - Implemented logic to handle v2-specific features, including the handling of prompt audio and disabling streaming inference for v2 models. - Updated UI instructions to ensure users are properly guided when selecting the v2 model. --- webui.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/webui.py b/webui.py index 3552cd92d..60684e042 100644 --- a/webui.py +++ b/webui.py @@ -30,9 +30,10 @@ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮', '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮', '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮', - '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'} + '自然语言控制': '1. 选择预训练音色(v2模型需要选择或录入prompt音频)\n2. 输入instruct文本\n3. 点击生成音频按钮'} stream_mode_list = [('否', False), ('是', True)] max_val = 0.8 +model_versions = None def generate_seed(): @@ -61,6 +62,10 @@ def change_instruction(mode_checkbox_group): def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed): + if model_versions == 'v2': + if stream: + stream = False + gr.Warning('您正在使用v2版本模型, 不支持流式推理, 将使用非流式模式.') if prompt_wav_upload is not None: prompt_wav = prompt_wav_upload elif prompt_wav_record is not None: @@ -69,13 +74,13 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro prompt_wav = None # if instruct mode, please make sure that model is iic/CosyVoice-300M-Instruct and not cross_lingual mode if mode_checkbox_group in ['自然语言控制']: - if cosyvoice.instruct is False: + if cosyvoice.instruct is False and model_versions == 'v1': gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir)) yield (cosyvoice.sample_rate, default_data) if instruct_text == '': gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本') yield (cosyvoice.sample_rate, default_data) - if prompt_wav is not None or prompt_text != '': + if (prompt_wav is not None or prompt_text != '') and model_versions == 'v1': gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略') # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language if mode_checkbox_group in ['跨语种复刻']: @@ -128,11 +133,20 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro set_all_random_seed(seed) for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed): yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) - else: + elif mode_checkbox_group == '自然语言控制': logging.info('get instruct inference request') set_all_random_seed(seed) - for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed): - yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) + if model_versions == 'v1': + for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed): + yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) + elif model_versions == 'v2': + prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr)) + for i in cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k, stream=stream): + yield (cosyvoice.sample_rate, i['tts_speech'].numpy().flatten()) + else: + gr.Warning('非预期的模型版本!') + else: + gr.Warning('非预期的选项!') def main(): @@ -186,9 +200,11 @@ def main(): args = parser.parse_args() try: cosyvoice = CosyVoice(args.model_dir) + model_versions = 'v1' except Exception: try: cosyvoice = CosyVoice2(args.model_dir) + model_versions = 'v2' except Exception: raise TypeError('no valid model_type!')