From dafe7c365733d10c36143e8ce70f7dc9ed204617 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Mon, 18 Apr 2022 04:09:55 +0000 Subject: [PATCH 1/2] add fastspeech2 cnndecoder onnx model, test=tts --- examples/csmsc/tts2/local/inference.sh | 15 +- examples/csmsc/tts3/README.md | 5 + examples/csmsc/tts3/local/inference.sh | 16 +- .../csmsc/tts3/local/inference_streaming.sh | 47 ++++ .../csmsc/tts3/local/ort_predict_streaming.sh | 19 ++ .../csmsc/tts3/local/synthesize_streaming.sh | 3 +- examples/csmsc/tts3/run_cnndecoder.sh | 59 ++++- paddlespeech/t2s/exps/inference.py | 85 +------ paddlespeech/t2s/exps/inference_streaming.py | 224 +++++++++++++++++ paddlespeech/t2s/exps/ort_predict.py | 29 +-- paddlespeech/t2s/exps/ort_predict_e2e.py | 29 +-- .../t2s/exps/ort_predict_streaming.py | 233 +++++++++++++++++ paddlespeech/t2s/exps/syn_utils.py | 237 ++++++++++++++++++ paddlespeech/t2s/exps/synthesize_streaming.py | 90 ++++--- .../t2s/models/fastspeech2/fastspeech2.py | 15 +- .../t2s/modules/transformer/encoder.py | 2 +- setup.py | 1 + 17 files changed, 908 insertions(+), 201 deletions(-) create mode 100755 examples/csmsc/tts3/local/inference_streaming.sh create mode 100755 examples/csmsc/tts3/local/ort_predict_streaming.sh create mode 100644 paddlespeech/t2s/exps/inference_streaming.py create mode 100644 paddlespeech/t2s/exps/ort_predict_streaming.py diff --git a/examples/csmsc/tts2/local/inference.sh b/examples/csmsc/tts2/local/inference.sh index d78c3eb3273..ed92136cde1 100755 --- a/examples/csmsc/tts2/local/inference.sh +++ b/examples/csmsc/tts2/local/inference.sh @@ -30,21 +30,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --tones_dict=dump/tone_id_map.txt fi -# style melgan -# style melgan's Dygraph to Static Graph is not ready now -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - python3 ${BIN_DIR}/../inference.py \ - --inference_dir=${train_output_path}/inference \ - --am=speedyspeech_csmsc \ - --voc=style_melgan_csmsc \ - --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/pd_infer_out \ - --phones_dict=dump/phone_id_map.txt \ - --tones_dict=dump/tone_id_map.txt -fi - # hifigan -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=speedyspeech_csmsc \ diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md index bc672f66f1e..c734199b46d 100644 --- a/examples/csmsc/tts3/README.md +++ b/examples/csmsc/tts3/README.md @@ -231,14 +231,19 @@ Pretrained FastSpeech2 model with no silence in the edge of audios: The static model can be downloaded here: - [fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip) - [fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip) +- [fastspeech2_cnndecoder_csmsc_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_static_1.0.0.zip) +- [fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_static_1.0.0.zip) The ONNX model can be downloaded here: - [fastspeech2_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip) +- [fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_onnx_1.0.0.zip) +- [fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip) Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss :-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------: default| 2(gpu) x 76000|1.0991|0.59132|0.035815|0.31915|0.15287| conformer| 2(gpu) x 76000|1.0675|0.56103|0.035869|0.31553|0.15509| +cnndecoder| 1(gpu) x 153000|1.1153|0.61475|0.03380|0.30414|0.14707| FastSpeech2 checkpoint contains files listed below. ```text diff --git a/examples/csmsc/tts3/local/inference.sh b/examples/csmsc/tts3/local/inference.sh index 9322cfd6979..7052b347dd0 100755 --- a/examples/csmsc/tts3/local/inference.sh +++ b/examples/csmsc/tts3/local/inference.sh @@ -5,6 +5,7 @@ train_output_path=$1 stage=0 stop_stage=0 +# pwgan if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ @@ -27,20 +28,9 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --phones_dict=dump/phone_id_map.txt fi -# style melgan -# style melgan's Dygraph to Static Graph is not ready now -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - python3 ${BIN_DIR}/../inference.py \ - --inference_dir=${train_output_path}/inference \ - --am=fastspeech2_csmsc \ - --voc=style_melgan_csmsc \ - --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/pd_infer_out \ - --phones_dict=dump/phone_id_map.txt -fi # hifigan -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=fastspeech2_csmsc \ @@ -51,7 +41,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi # wavernn -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=fastspeech2_csmsc \ diff --git a/examples/csmsc/tts3/local/inference_streaming.sh b/examples/csmsc/tts3/local/inference_streaming.sh new file mode 100755 index 00000000000..70bd489dff6 --- /dev/null +++ b/examples/csmsc/tts3/local/inference_streaming.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +train_output_path=$1 + +stage=2 +stop_stage=2 + +# pwgan +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../inference_streaming.py \ + --inference_dir=${train_output_path}/inference_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=pwgan_csmsc \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out_streaming \ + --phones_dict=dump/phone_id_map.txt \ + --am_streaming=True +fi + +# for more GAN Vocoders +# multi band melgan +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + python3 ${BIN_DIR}/../inference_streaming.py \ + --inference_dir=${train_output_path}/inference_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=mb_melgan_csmsc \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out_streaming \ + --phones_dict=dump/phone_id_map.txt \ + --am_streaming=True +fi + +# hifigan +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + python3 ${BIN_DIR}/../inference_streaming.py \ + --inference_dir=${train_output_path}/inference_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=hifigan_csmsc \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out_streaming \ + --phones_dict=dump/phone_id_map.txt \ + --am_streaming=True +fi + diff --git a/examples/csmsc/tts3/local/ort_predict_streaming.sh b/examples/csmsc/tts3/local/ort_predict_streaming.sh new file mode 100755 index 00000000000..502ec912a23 --- /dev/null +++ b/examples/csmsc/tts3/local/ort_predict_streaming.sh @@ -0,0 +1,19 @@ +train_output_path=$1 + +stage=0 +stop_stage=0 + +# e2e, synthesize from text +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../ort_predict_streaming.py \ + --inference_dir=${train_output_path}/inference_onnx_streaming \ + --am=fastspeech2_csmsc \ + --am_stat=dump/train/speech_stats.npy \ + --voc=hifigan_csmsc \ + --output_dir=${train_output_path}/onnx_infer_out_streaming \ + --text=${BIN_DIR}/../csmsc_test.txt \ + --phones_dict=dump/phone_id_map.txt \ + --device=cpu \ + --cpu_threads=2 \ + --am_streaming=True +fi diff --git a/examples/csmsc/tts3/local/synthesize_streaming.sh b/examples/csmsc/tts3/local/synthesize_streaming.sh index 7606c23857f..b135db76d42 100755 --- a/examples/csmsc/tts3/local/synthesize_streaming.sh +++ b/examples/csmsc/tts3/local/synthesize_streaming.sh @@ -88,5 +88,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --text=${BIN_DIR}/../sentences.txt \ --output_dir=${train_output_path}/test_e2e_streaming \ --phones_dict=dump/phone_id_map.txt \ - --am_streaming=True + --am_streaming=True \ + --inference_dir=${train_output_path}/inference_streaming fi diff --git a/examples/csmsc/tts3/run_cnndecoder.sh b/examples/csmsc/tts3/run_cnndecoder.sh index 5cccef01610..20c0fef0005 100755 --- a/examples/csmsc/tts3/run_cnndecoder.sh +++ b/examples/csmsc/tts3/run_cnndecoder.sh @@ -9,7 +9,7 @@ stop_stage=100 conf_path=conf/cnndecoder.yaml train_output_path=exp/cnndecoder -ckpt_name=snapshot_iter_153.pdz +ckpt_name=snapshot_iter_153000.pdz # with the following command, you can choose the stage range you want to run # such as `./run.sh --stage 0 --stop-stage 0` @@ -31,18 +31,75 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi +# synthesize_e2e non-streaming if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # synthesize_e2e, vocoder is pwgan CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi +# inference non-streaming if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # inference with static model CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1 fi +# synthesize_e2e streaming if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # synthesize_e2e, vocoder is pwgan CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_streaming.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 fi +# inference streaming +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # inference with static model + CUDA_VISIBLE_DEVICES=${gpus} ./local/inference_streaming.sh ${train_output_path} || exit -1 +fi + +# paddle2onnx non streaming +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # install paddle2onnx + version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '0.9.4' ]]; then + pip install paddle2onnx==0.9.4 + fi + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx fastspeech2_csmsc + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_csmsc +fi + + +# onnxruntime non streaming +# inference with onnxruntime, use fastspeech2 + hifigan by default +if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + # install onnxruntime + version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '1.10.0' ]]; then + pip install onnxruntime==1.10.0 + fi + ./local/ort_predict.sh ${train_output_path} +fi + +# paddle2onnx streaming +if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then + # install paddle2onnx + version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '0.9.4' ]]; then + pip install paddle2onnx==0.9.4 + fi + # streaming acoustic model + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming fastspeech2_csmsc_am_encoder_infer + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming fastspeech2_csmsc_am_decoder + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming fastspeech2_csmsc_am_postnet + # vocoder + ./local/paddle2onnx.sh ${train_output_path} inference_streaming inference_onnx_streaming hifigan_csmsc +fi + +# onnxruntime streaming +if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then + # install onnxruntime + version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '1.10.0' ]]; then + pip install onnxruntime==1.10.0 + fi + ./local/ort_predict_streaming.sh ${train_output_path} +fi + diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index c5b64ac726f..3e7c11f2209 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -14,92 +14,17 @@ import argparse from pathlib import Path -import numpy import soundfile as sf -from paddle import inference from timer import timer +from paddlespeech.t2s.exps.syn_utils import get_am_output from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_predictor from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_voc_output from paddlespeech.t2s.utils import str2bool -def get_predictor(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_name = full_name[:full_name.rindex('_')] - config = inference.Config( - str(Path(args.inference_dir) / (full_name + ".pdmodel")), - str(Path(args.inference_dir) / (full_name + ".pdiparams"))) - if args.device == "gpu": - config.enable_use_gpu(100, 0) - elif args.device == "cpu": - config.disable_gpu() - config.enable_memory_optim() - predictor = inference.create_predictor(config) - return predictor - - -def get_am_output(args, am_predictor, frontend, merge_sentences, input): - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] - am_input_names = am_predictor.get_input_names() - get_tone_ids = False - get_spk_id = False - if am_name == 'speedyspeech': - get_tone_ids = True - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: - get_spk_id = True - spk_id = numpy.array([args.spk_id]) - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - elif args.lang == 'en': - input_ids = frontend.get_input_ids( - input, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in {'zh', 'en'}!") - - if get_tone_ids: - tone_ids = input_ids["tone_ids"] - tones = tone_ids[0].numpy() - tones_handle = am_predictor.get_input_handle(am_input_names[1]) - tones_handle.reshape(tones.shape) - tones_handle.copy_from_cpu(tones) - if get_spk_id: - spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) - spk_id_handle.reshape(spk_id.shape) - spk_id_handle.copy_from_cpu(spk_id) - phones = phone_ids[0].numpy() - phones_handle = am_predictor.get_input_handle(am_input_names[0]) - phones_handle.reshape(phones.shape) - phones_handle.copy_from_cpu(phones) - - am_predictor.run() - am_output_names = am_predictor.get_output_names() - am_output_handle = am_predictor.get_output_handle(am_output_names[0]) - am_output_data = am_output_handle.copy_to_cpu() - return am_output_data - - -def get_voc_output(args, voc_predictor, input): - voc_input_names = voc_predictor.get_input_names() - mel_handle = voc_predictor.get_input_handle(voc_input_names[0]) - mel_handle.reshape(input.shape) - mel_handle.copy_from_cpu(input) - - voc_predictor.run() - voc_output_names = voc_predictor.get_output_names() - voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0]) - wav = voc_output_handle.copy_to_cpu() - return wav - - def parse_args(): parser = argparse.ArgumentParser( description="Paddle Infernce with acoustic model & vocoder.") @@ -204,7 +129,7 @@ def main(): merge_sentences=merge_sentences, input=sentence) wav = get_voc_output( - args, voc_predictor=voc_predictor, input=am_output_data) + voc_predictor=voc_predictor, input=am_output_data) speed = wav.size / t.elapse rtf = fs / speed print( @@ -224,7 +149,7 @@ def main(): merge_sentences=merge_sentences, input=sentence) wav = get_voc_output( - args, voc_predictor=voc_predictor, input=am_output_data) + voc_predictor=voc_predictor, input=am_output_data) N += wav.size T += t.elapse diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py new file mode 100644 index 00000000000..0e58056c27a --- /dev/null +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -0,0 +1,224 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import numpy as np +import soundfile as sf +from timer import timer + +from paddlespeech.t2s.exps.syn_utils import denorm +from paddlespeech.t2s.exps.syn_utils import get_am_sublayer_output +from paddlespeech.t2s.exps.syn_utils import get_chunks +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_predictor +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output +from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor +from paddlespeech.t2s.exps.syn_utils import get_voc_output +from paddlespeech.t2s.utils import str2bool + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Paddle Infernce with acoustic model & vocoder.") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=['fastspeech2_csmsc'], + help='Choose acoustic model type of tts task.') + parser.add_argument( + "--am_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + parser.add_argument( + "--speaker_dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + '--spk_id', + type=int, + default=0, + help='spk id for multi speaker acoustic model') + # voc + parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=['pwgan_csmsc', 'mb_melgan_csmsc', 'hifigan_csmsc'], + help='Choose vocoder type of tts task.') + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line") + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument("--output_dir", type=str, help="output dir") + # inference + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu"], + help="Device selected for inference.", ) + # streaming related + parser.add_argument( + "--am_streaming", + type=str2bool, + default=False, + help="whether use streaming acoustic model") + parser.add_argument( + "--chunk_size", type=int, default=42, help="chunk size of am streaming") + parser.add_argument( + "--pad_size", type=int, default=12, help="pad size of am streaming") + + args, _ = parser.parse_known_args() + return args + + +# only inference for models trained with csmsc now +def main(): + args = parse_args() + # frontend + frontend = get_frontend(args) + + # am_predictor + am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor( + args) + am_mu, am_std = np.load(args.am_stat) + # model: {model_name}_{dataset} + am_dataset = args.am[args.am.rindex('_') + 1:] + + # voc_predictor + voc_predictor = get_predictor(args, filed='voc') + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + sentences = get_sentences(args) + + merge_sentences = True + + fs = 24000 if am_dataset != 'ljspeech' else 22050 + # warmup + for utt_id, sentence in sentences[:3]: + with timer() as t: + normalized_mel = get_streaming_am_output( + args, + am_encoder_infer_predictor=am_encoder_infer_predictor, + am_decoder_predictor=am_decoder_predictor, + am_postnet_predictor=am_postnet_predictor, + frontend=frontend, + merge_sentences=merge_sentences, + input=sentence) + mel = denorm(normalized_mel, am_mu, am_std) + wav = get_voc_output(voc_predictor=voc_predictor, input=mel) + speed = wav.size / t.elapse + rtf = fs / speed + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + + print("warm up done!") + + N = 0 + T = 0 + chunk_size = args.chunk_size + pad_size = args.pad_size + get_tone_ids = False + for utt_id, sentence in sentences: + with timer() as t: + # frontend + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + else: + print("lang should be 'zh' here!") + phones = phone_ids[0].numpy() + # acoustic model + orig_hs = get_am_sublayer_output( + am_encoder_infer_predictor, input=phones) + + if args.am_streaming: + hss = get_chunks(orig_hs, chunk_size, pad_size) + chunk_num = len(hss) + mel_list = [] + for i, hs in enumerate(hss): + am_decoder_output = get_am_sublayer_output( + am_decoder_predictor, input=hs) + am_postnet_output = get_am_sublayer_output( + am_postnet_predictor, + input=np.transpose(am_decoder_output, (0, 2, 1))) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output, (0, 2, 1)) + normalized_mel = am_output_data[0] + + sub_mel = denorm(normalized_mel, am_mu, am_std) + # clip output part of pad + if i == 0: + sub_mel = sub_mel[:-pad_size] + elif i == chunk_num - 1: + # 最后一块的右侧一定没有 pad 够 + sub_mel = sub_mel[pad_size:] + else: + # 倒数几块的右侧也可能没有 pad 够 + sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel.shape[0]] + mel_list.append(sub_mel) + mel = np.concatenate(mel_list, axis=0) + + else: + am_decoder_output = get_am_sublayer_output( + am_decoder_predictor, input=orig_hs) + + am_postnet_output = get_am_sublayer_output( + am_postnet_predictor, + input=np.transpose(am_decoder_output, (0, 2, 1))) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output, (0, 2, 1)) + normalized_mel = am_output_data[0] + mel = denorm(normalized_mel, am_mu, am_std) + # vocoder + wav = get_voc_output(voc_predictor=voc_predictor, input=mel) + + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = fs / speed + + sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000) + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index 2ca9b5be7cf..d1f03710b69 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -16,39 +16,14 @@ import jsonlines import numpy as np -import onnxruntime as ort import soundfile as sf from timer import timer +from paddlespeech.t2s.exps.syn_utils import get_sess from paddlespeech.t2s.exps.syn_utils import get_test_dataset from paddlespeech.t2s.utils import str2bool -def get_sess(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - - if args.device == "gpu": - # fastspeech2/mb_melgan can't use trt now! - if args.use_trt: - providers = ['TensorrtExecutionProvider'] - else: - providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": - providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads - sess = ort.InferenceSession( - model_dir, providers=providers, sess_options=sess_options) - return sess - - def ort_predict(args): # construct dataset for evaluation with jsonlines.open(args.test_metadata, 'r') as reader: @@ -131,7 +106,7 @@ def parse_args(): '--voc', type=str, default='hifigan_csmsc', - choices=['hifigan_csmsc', 'mb_melgan_csmsc'], + choices=['hifigan_csmsc', 'mb_melgan_csmsc', 'pwgan_csmsc'], help='Choose vocoder type of tts task.') # other parser.add_argument( diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index c62b7ecd87c..366a2902702 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -15,40 +15,15 @@ from pathlib import Path import numpy as np -import onnxruntime as ort import soundfile as sf from timer import timer from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_sess from paddlespeech.t2s.utils import str2bool -def get_sess(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - - if args.device == "gpu": - # fastspeech2/mb_melgan can't use trt now! - if args.use_trt: - providers = ['TensorrtExecutionProvider'] - else: - providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": - providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads - sess = ort.InferenceSession( - model_dir, providers=providers, sess_options=sess_options) - return sess - - def ort_predict(args): # frontend @@ -156,7 +131,7 @@ def parse_args(): '--voc', type=str, default='hifigan_csmsc', - choices=['hifigan_csmsc', 'mb_melgan_csmsc'], + choices=['hifigan_csmsc', 'mb_melgan_csmsc', 'pwgan_csmsc'], help='Choose vocoder type of tts task.') # other parser.add_argument( diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py new file mode 100644 index 00000000000..1b486d19df9 --- /dev/null +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import numpy as np +import soundfile as sf +from timer import timer + +from paddlespeech.t2s.exps.syn_utils import denorm +from paddlespeech.t2s.exps.syn_utils import get_chunks +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_sess +from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess +from paddlespeech.t2s.utils import str2bool + + +def ort_predict(args): + + # frontend + frontend = get_frontend(args) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + sentences = get_sentences(args) + + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + fs = 24000 if am_dataset != 'ljspeech' else 22050 + + # am + am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess( + args) + am_mu, am_std = np.load(args.am_stat) + + # vocoder + voc_sess = get_sess(args, filed='voc') + + # frontend warmup + # Loading model cost 0.5+ seconds + if args.lang == 'zh': + frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True) + else: + print("lang should in be 'zh' here!") + + # am warmup + for T in [27, 38, 54]: + phone_ids = np.random.randint(1, 266, size=(T, )) + am_encoder_infer_sess.run(None, input_feed={'text': phone_ids}) + + am_decoder_input = np.random.rand(1, T * 15, 384).astype('float32') + am_decoder_sess.run(None, input_feed={'xs': am_decoder_input}) + + am_postnet_input = np.random.rand(1, 80, T * 15).astype('float32') + am_postnet_sess.run(None, input_feed={'xs': am_postnet_input}) + + # voc warmup + for T in [227, 308, 544]: + data = np.random.rand(T, 80).astype("float32") + voc_sess.run(None, input_feed={"logmel": data}) + print("warm up done!") + + N = 0 + T = 0 + merge_sentences = True + get_tone_ids = False + chunk_size = args.chunk_size + pad_size = args.pad_size + + for utt_id, sentence in sentences: + with timer() as t: + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in be 'zh' here!") + # merge_sentences=True here, so we only use the first item of phone_ids + phone_ids = phone_ids[0].numpy() + orig_hs = am_encoder_infer_sess.run( + None, input_feed={'text': phone_ids}) + if args.am_streaming: + hss = get_chunks(orig_hs[0], chunk_size, pad_size) + chunk_num = len(hss) + mel_list = [] + for i, hs in enumerate(hss): + am_decoder_output = am_decoder_sess.run( + None, input_feed={'xs': hs}) + am_postnet_output = am_postnet_sess.run( + None, + input_feed={ + 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) + }) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output[0], (0, 2, 1)) + normalized_mel = am_output_data[0][0] + + sub_mel = denorm(normalized_mel, am_mu, am_std) + # clip output part of pad + if i == 0: + sub_mel = sub_mel[:-pad_size] + elif i == chunk_num - 1: + # 最后一块的右侧一定没有 pad 够 + sub_mel = sub_mel[pad_size:] + else: + # 倒数几块的右侧也可能没有 pad 够 + sub_mel = sub_mel[pad_size:(chunk_size + pad_size) - + sub_mel.shape[0]] + mel_list.append(sub_mel) + mel = np.concatenate(mel_list, axis=0) + else: + am_decoder_output = am_decoder_sess.run( + None, input_feed={'xs': orig_hs[0]}) + am_postnet_output = am_postnet_sess.run( + None, + input_feed={ + 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) + }) + am_output_data = am_decoder_output + np.transpose( + am_postnet_output[0], (0, 2, 1)) + normalized_mel = am_output_data[0] + mel = denorm(normalized_mel, am_mu, am_std) + mel = mel[0] + # vocoder + + wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) + + N += len(wav[0]) + T += t.elapse + speed = len(wav[0]) / t.elapse + rtf = fs / speed + sf.write( + str(output_dir / (utt_id + ".wav")), + np.array(wav)[0], + samplerate=fs) + print( + f"{utt_id}, mel: {mel.shape}, wave: {len(wav[0])}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Infernce with onnxruntime.") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=['fastspeech2_csmsc'], + help='Choose acoustic model type of tts task.') + parser.add_argument( + "--am_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + + # voc + parser.add_argument( + '--voc', + type=str, + default='hifigan_csmsc', + choices=['hifigan_csmsc', 'mb_melgan_csmsc', 'pwgan_csmsc'], + help='Choose vocoder type of tts task.') + # other + parser.add_argument( + "--inference_dir", type=str, help="dir to save inference models") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line") + parser.add_argument("--output_dir", type=str, help="output dir") + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + + # inference + parser.add_argument( + "--use_trt", + type=str2bool, + default=False, + help="Whether to use inference engin TensorRT.", ) + + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu"], + help="Device selected for inference.", ) + parser.add_argument('--cpu_threads', type=int, default=1) + + # streaming related + parser.add_argument( + "--am_streaming", + type=str2bool, + default=False, + help="whether use streaming acoustic model") + parser.add_argument( + "--chunk_size", type=int, default=42, help="chunk size of am streaming") + parser.add_argument( + "--pad_size", type=int, default=12, help="pad size of am streaming") + + args, _ = parser.parse_known_args() + return args + + +def main(): + args = parse_args() + + ort_predict(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index c52cb372710..21aa5bf8cbd 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math import os +from pathlib import Path import numpy as np +import onnxruntime as ort import paddle +from paddle import inference from paddle import jit from paddle.static import InputSpec @@ -62,6 +66,21 @@ } +def denorm(data, mean, std): + return data * std + mean + + +def get_chunks(data, chunk_size, pad_size): + data_len = data.shape[1] + chunks = [] + n = math.ceil(data_len / chunk_size) + for i in range(n): + start = max(0, i * chunk_size - pad_size) + end = min((i + 1) * chunk_size + pad_size, data_len) + chunks.append(data[:, start:end, :]) + return chunks + + # input def get_sentences(args): # construct dataset for evaluation @@ -241,3 +260,221 @@ def voc_to_static(args, voc_inference): paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc)) voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc)) return voc_inference + + +# inference +def get_predictor(args, filed='am'): + full_name = '' + if filed == 'am': + full_name = args.am + elif filed == 'voc': + full_name = args.voc + config = inference.Config( + str(Path(args.inference_dir) / (full_name + ".pdmodel")), + str(Path(args.inference_dir) / (full_name + ".pdiparams"))) + if args.device == "gpu": + config.enable_use_gpu(100, 0) + elif args.device == "cpu": + config.disable_gpu() + config.enable_memory_optim() + predictor = inference.create_predictor(config) + return predictor + + +def get_am_output(args, am_predictor, frontend, merge_sentences, input): + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + am_input_names = am_predictor.get_input_names() + get_tone_ids = False + get_spk_id = False + if am_name == 'speedyspeech': + get_tone_ids = True + if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + get_spk_id = True + spk_id = np.array([args.spk_id]) + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + tones = tone_ids[0].numpy() + tones_handle = am_predictor.get_input_handle(am_input_names[1]) + tones_handle.reshape(tones.shape) + tones_handle.copy_from_cpu(tones) + if get_spk_id: + spk_id_handle = am_predictor.get_input_handle(am_input_names[1]) + spk_id_handle.reshape(spk_id.shape) + spk_id_handle.copy_from_cpu(spk_id) + phones = phone_ids[0].numpy() + phones_handle = am_predictor.get_input_handle(am_input_names[0]) + phones_handle.reshape(phones.shape) + phones_handle.copy_from_cpu(phones) + + am_predictor.run() + am_output_names = am_predictor.get_output_names() + am_output_handle = am_predictor.get_output_handle(am_output_names[0]) + am_output_data = am_output_handle.copy_to_cpu() + return am_output_data + + +def get_voc_output(voc_predictor, input): + voc_input_names = voc_predictor.get_input_names() + mel_handle = voc_predictor.get_input_handle(voc_input_names[0]) + mel_handle.reshape(input.shape) + mel_handle.copy_from_cpu(input) + + voc_predictor.run() + voc_output_names = voc_predictor.get_output_names() + voc_output_handle = voc_predictor.get_output_handle(voc_output_names[0]) + wav = voc_output_handle.copy_to_cpu() + return wav + + +# streaming am +def get_streaming_am_predictor(args): + full_name = args.am + am_encoder_infer_config = inference.Config( + str( + Path(args.inference_dir) / + (full_name + "_am_encoder_infer" + ".pdmodel")), + str( + Path(args.inference_dir) / + (full_name + "_am_encoder_infer" + ".pdiparams"))) + am_decoder_config = inference.Config( + str( + Path(args.inference_dir) / + (full_name + "_am_decoder" + ".pdmodel")), + str( + Path(args.inference_dir) / + (full_name + "_am_decoder" + ".pdiparams"))) + am_postnet_config = inference.Config( + str( + Path(args.inference_dir) / + (full_name + "_am_postnet" + ".pdmodel")), + str( + Path(args.inference_dir) / + (full_name + "_am_postnet" + ".pdiparams"))) + if args.device == "gpu": + am_encoder_infer_config.enable_use_gpu(100, 0) + am_decoder_config.enable_use_gpu(100, 0) + am_postnet_config.enable_use_gpu(100, 0) + elif args.device == "cpu": + am_encoder_infer_config.disable_gpu() + am_decoder_config.disable_gpu() + am_postnet_config.disable_gpu() + + am_encoder_infer_config.enable_memory_optim() + am_decoder_config.enable_memory_optim() + am_postnet_config.enable_memory_optim() + + am_encoder_infer_predictor = inference.create_predictor( + am_encoder_infer_config) + am_decoder_predictor = inference.create_predictor(am_decoder_config) + am_postnet_predictor = inference.create_predictor(am_postnet_config) + return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor + + +def get_am_sublayer_output(am_sublayer_predictor, input): + am_sublayer_input_names = am_sublayer_predictor.get_input_names() + input_handle = am_sublayer_predictor.get_input_handle( + am_sublayer_input_names[0]) + input_handle.reshape(input.shape) + input_handle.copy_from_cpu(input) + + am_sublayer_predictor.run() + am_sublayer_names = am_sublayer_predictor.get_output_names() + am_sublayer_handle = am_sublayer_predictor.get_output_handle( + am_sublayer_names[0]) + am_sublayer_output = am_sublayer_handle.copy_to_cpu() + return am_sublayer_output + + +def get_streaming_am_output(args, am_encoder_infer_predictor, + am_decoder_predictor, am_postnet_predictor, + frontend, merge_sentences, input): + get_tone_ids = False + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + else: + print("lang should be 'zh' here!") + + phones = phone_ids[0].numpy() + am_encoder_infer_output = get_am_sublayer_output( + am_encoder_infer_predictor, input=phones) + + am_decoder_output = get_am_sublayer_output( + am_decoder_predictor, input=am_encoder_infer_output) + + am_postnet_output = get_am_sublayer_output( + am_postnet_predictor, input=np.transpose(am_decoder_output, (0, 2, 1))) + am_output_data = am_decoder_output + np.transpose(am_postnet_output, + (0, 2, 1)) + normalized_mel = am_output_data[0] + return normalized_mel + + +def get_sess(args, filed='am'): + full_name = '' + if filed == 'am': + full_name = args.am + elif filed == 'voc': + full_name = args.voc + model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + + if args.device == "gpu": + # fastspeech2/mb_melgan can't use trt now! + if args.use_trt: + providers = ['TensorrtExecutionProvider'] + else: + providers = ['CUDAExecutionProvider'] + elif args.device == "cpu": + providers = ['CPUExecutionProvider'] + sess_options.intra_op_num_threads = args.cpu_threads + sess = ort.InferenceSession( + model_dir, providers=providers, sess_options=sess_options) + return sess + + +# streaming am +def get_streaming_am_sess(args): + full_name = args.am + am_encoder_infer_model_dir = str( + Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx")) + am_decoder_model_dir = str( + Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx")) + am_postnet_model_dir = str( + Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx")) + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + if args.device == "gpu": + # fastspeech2/mb_melgan can't use trt now! + if args.use_trt: + providers = ['TensorrtExecutionProvider'] + else: + providers = ['CUDAExecutionProvider'] + elif args.device == "cpu": + providers = ['CPUExecutionProvider'] + sess_options.intra_op_num_threads = args.cpu_threads + am_encoder_infer_sess = ort.InferenceSession( + am_encoder_infer_model_dir, + providers=providers, + sess_options=sess_options) + am_decoder_sess = ort.InferenceSession( + am_decoder_model_dir, providers=providers, sess_options=sess_options) + am_postnet_sess = ort.InferenceSession( + am_postnet_model_dir, providers=providers, sess_options=sess_options) + return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index 7b9906c1076..4f7a84e91f0 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -12,39 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. import argparse -import math +import os from pathlib import Path import numpy as np import paddle import soundfile as sf import yaml +from paddle import jit +from paddle.static import InputSpec from timer import timer from yacs.config import CfgNode from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.exps.syn_utils import denorm +from paddlespeech.t2s.exps.syn_utils import get_chunks from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import model_alias +from paddlespeech.t2s.exps.syn_utils import voc_to_static from paddlespeech.t2s.utils import str2bool -def denorm(data, mean, std): - return data * std + mean - - -def get_chunks(data, chunk_size, pad_size): - data_len = data.shape[1] - chunks = [] - n = math.ceil(data_len / chunk_size) - for i in range(n): - start = max(0, i * chunk_size - pad_size) - end = min((i + 1) * chunk_size + pad_size, data_len) - chunks.append(data[:, start:end, :]) - return chunks - - def evaluate(args): # Init body. @@ -84,9 +74,49 @@ def evaluate(args): am_mu = paddle.to_tensor(am_mu) am_std = paddle.to_tensor(am_std) + # am sub layers + am_encoder_infer = am.encoder_infer + am_decoder = am.decoder + am_postnet = am.postnet + # vocoder voc_inference = get_voc_inference(args, voc_config) + # whether dygraph to static + if args.inference_dir: + # fastspeech2 cnndecoder to static + # am.encoder_infer + am_encoder_infer = jit.to_static( + am_encoder_infer, input_spec=[InputSpec([-1], dtype=paddle.int64)]) + paddle.jit.save(am_encoder_infer, + os.path.join(args.inference_dir, + args.am + "_am_encoder_infer")) + am_encoder_infer = paddle.jit.load( + os.path.join(args.inference_dir, args.am + "_am_encoder_infer")) + + # am.decoder + am_decoder = jit.to_static( + am_decoder, + input_spec=[InputSpec([1, -1, 384], dtype=paddle.float32)]) + paddle.jit.save(am_decoder, + os.path.join(args.inference_dir, + args.am + "_am_decoder")) + am_decoder = paddle.jit.load( + os.path.join(args.inference_dir, args.am + "_am_decoder")) + + # am.postnet + am_postnet = jit.to_static( + am_postnet, + input_spec=[InputSpec([1, 80, -1], dtype=paddle.float32)]) + paddle.jit.save(am_postnet, + os.path.join(args.inference_dir, + args.am + "_am_postnet")) + am_postnet = paddle.jit.load( + os.path.join(args.inference_dir, args.am + "_am_postnet")) + + # vocoder + voc_inference = voc_to_static(args, voc_inference) + output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) merge_sentences = True @@ -107,20 +137,19 @@ def evaluate(args): phone_ids = input_ids["phone_ids"] else: - print("lang should in be 'zh' here!") + print("lang should be 'zh' here!") # merge_sentences=True here, so we only use the first item of phone_ids phone_ids = phone_ids[0] with paddle.no_grad(): # acoustic model - orig_hs, h_masks = am.encoder_infer(phone_ids) - + orig_hs = am_encoder_infer(phone_ids) if args.am_streaming: hss = get_chunks(orig_hs, chunk_size, pad_size) chunk_num = len(hss) mel_list = [] for i, hs in enumerate(hss): - before_outs, _ = am.decoder(hs) - after_outs = before_outs + am.postnet( + before_outs = am_decoder(hs) + after_outs = before_outs + am_postnet( before_outs.transpose((0, 2, 1))).transpose( (0, 2, 1)) normalized_mel = after_outs[0] @@ -139,8 +168,8 @@ def evaluate(args): mel = paddle.concat(mel_list, axis=0) else: - before_outs, _ = am.decoder(orig_hs) - after_outs = before_outs + am.postnet( + before_outs = am_decoder(orig_hs) + after_outs = before_outs + am_postnet( before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) normalized_mel = after_outs[0] mel = denorm(normalized_mel, am_mu, am_std) @@ -201,16 +230,9 @@ def parse_args(): default='pwgan_csmsc', choices=[ 'pwgan_csmsc', - 'pwgan_ljspeech', - 'pwgan_aishell3', - 'pwgan_vctk', 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc', - 'hifigan_ljspeech', - 'hifigan_aishell3', - 'hifigan_vctk', - 'wavernn_csmsc', ], help='Choose vocoder type of tts task.') parser.add_argument( @@ -233,13 +255,19 @@ def parse_args(): default='zh', help='Choose model language. zh or en') + parser.add_argument( + "--inference_dir", + type=str, + default=None, + help="dir to save inference models") + parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") parser.add_argument( "--text", type=str, help="text to synthesize, a 'utt_id sentence' pair per line.") - + # streaming related parser.add_argument( "--am_streaming", type=str2bool, diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index 8e52f91625e..48595bb25ca 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -590,15 +590,17 @@ def _forward(self, h_masks = self._source_mask(olens_in) else: h_masks = None - if return_after_enc: return hs, h_masks - # (B, Lmax, adim) - zs, _ = self.decoder(hs, h_masks) - # (B, Lmax, odim) + if self.decoder_type == 'cnndecoder': + # remove output masks for dygraph to static graph + zs = self.decoder(hs, h_masks) before_outs = zs else: + # (B, Lmax, adim) + zs, _ = self.decoder(hs, h_masks) + # (B, Lmax, odim) before_outs = self.feat_out(zs).reshape( (paddle.shape(zs)[0], -1, self.odim)) @@ -633,7 +635,8 @@ def encoder_infer( tone_id = tone_id.unsqueeze(0) # (1, L, odim) - hs, h_masks = self._forward( + # use *_ to avoid bug in dygraph to static graph + hs, *_ = self._forward( xs, ilens, is_inference=True, @@ -642,7 +645,7 @@ def encoder_infer( spk_emb=spk_emb, spk_id=spk_id, tone_id=tone_id) - return hs, h_masks + return hs def inference( self, diff --git a/paddlespeech/t2s/modules/transformer/encoder.py b/paddlespeech/t2s/modules/transformer/encoder.py index d05516c2280..11986360a30 100644 --- a/paddlespeech/t2s/modules/transformer/encoder.py +++ b/paddlespeech/t2s/modules/transformer/encoder.py @@ -602,7 +602,7 @@ def forward(self, xs, masks=None): if masks is not None: outputs = outputs * masks outputs = outputs.transpose([0, 2, 1]) - return outputs, masks + return outputs class CNNPostnet(nn.Layer): diff --git a/setup.py b/setup.py index 82ff6341265..1bdf1e6bafe 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "loguru", "matplotlib", "nara_wpe", + "onnxruntime", "pandas", "paddleaudio", "paddlenlp", From da93f944e6fd73e5aede534774f7da900796a0d6 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Mon, 18 Apr 2022 06:58:24 +0000 Subject: [PATCH 2/2] update, test=doc --- examples/csmsc/tts3/local/inference_streaming.sh | 4 ++-- examples/csmsc/tts3/run_cnndecoder.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/csmsc/tts3/local/inference_streaming.sh b/examples/csmsc/tts3/local/inference_streaming.sh index 70bd489dff6..719f46c620a 100755 --- a/examples/csmsc/tts3/local/inference_streaming.sh +++ b/examples/csmsc/tts3/local/inference_streaming.sh @@ -2,8 +2,8 @@ train_output_path=$1 -stage=2 -stop_stage=2 +stage=0 +stop_stage=0 # pwgan if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then diff --git a/examples/csmsc/tts3/run_cnndecoder.sh b/examples/csmsc/tts3/run_cnndecoder.sh index 20c0fef0005..61cd02a9320 100755 --- a/examples/csmsc/tts3/run_cnndecoder.sh +++ b/examples/csmsc/tts3/run_cnndecoder.sh @@ -9,7 +9,7 @@ stop_stage=100 conf_path=conf/cnndecoder.yaml train_output_path=exp/cnndecoder -ckpt_name=snapshot_iter_153000.pdz +ckpt_name=snapshot_iter_153.pdz # with the following command, you can choose the stage range you want to run # such as `./run.sh --stage 0 --stop-stage 0`