-
Notifications
You must be signed in to change notification settings - Fork 45
/
synthesize.wave.py
111 lines (92 loc) · 4.29 KB
/
synthesize.wave.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python
"""Synthetize sentences into speech."""
__author__ = 'Atomicoo'
import argparse
import os
import os.path as osp
import time
from scipy.io.wavfile import write
import torch
from utils.hparams import HParam
from utils.transform import StandardNorm
from helpers.synthesizer import Synthesizer
import vocoder.models
from vocoder.layers import PQMF
from utils.audio import dynamic_range_decompression
from datasets.dataset import TextProcessor
from models import ParallelText2Mel
from utils.utils import select_device, get_last_chkpt_path
try:
from helpers.manager import GPUManager
except ImportError as err:
print(err); gm = None
else:
gm = GPUManager()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--batch_size", default=8, type=int, help="Batch size")
parser.add_argument("--checkpoint", default=None, type=str, help="Checkpoint file path")
parser.add_argument("--melgan_checkpoint", default=None, type=str, help="Checkpoint file path of melgan")
parser.add_argument("--input_texts", default=None, type=str, help="Input text file path")
parser.add_argument("--outputs_dir", default=None, type=str, help="Output wave file directory")
parser.add_argument("--device", default=None, help="cuda device or cpu")
parser.add_argument("--name", default="parallel", type=str, help="Append to logdir name")
parser.add_argument("--config", default=None, type=str, help="Config file path")
args = parser.parse_args()
if torch.cuda.is_available():
index = args.device if args.device else str(0 if gm is None else gm.auto_choice())
else:
index = 'cpu'
device = select_device(index)
hparams = HParam(args.config) \
if args.config else HParam(osp.join(osp.abspath(os.getcwd()), "config", "default.yaml"))
logdir = osp.join(hparams.trainer.logdir, f"%s-%s" % (hparams.data.dataset, args.name))
checkpoint = args.checkpoint or get_last_chkpt_path(logdir)
normalizer = StandardNorm(hparams.audio.spec_mean, hparams.audio.spec_std)
processor = TextProcessor(hparams.text)
text2mel = ParallelText2Mel(hparams.parallel)
text2mel.eval()
synthesizer = Synthesizer(
model=text2mel,
checkpoint=checkpoint,
processor=processor,
normalizer=normalizer,
device=device
)
print('Synthesizing...')
since = time.time()
text_file = args.input_texts or hparams.synthesizer.inputs_file_path
with open(text_file, 'r', encoding='utf-8') as fr:
texts = fr.read().strip().split('\n')
melspecs = synthesizer.inference(texts)
print(f"Inference {len(texts)} spectrograms, total elapsed {time.time()-since:.3f}s. Done.")
vocoder_checkpoint = args.melgan_checkpoint or \
osp.join(hparams.trainer.logdir, f"{hparams.data.dataset}-melgan", hparams.melgan.checkpoint)
ckpt = torch.load(vocoder_checkpoint, map_location=device)
decompressed = dynamic_range_decompression(melspecs)
decompressed_log10 = torch.log10(decompressed)
mu = torch.tensor(ckpt['stats']['mu']).unsqueeze(1)
var = torch.tensor(ckpt['stats']['var']).unsqueeze(1)
sigma = torch.sqrt(var)
decompressed_log10_norm = (decompressed_log10 - mu) / sigma
melspecs = (decompressed_log10 - mu) / sigma
Generator = getattr(vocoder.models, ckpt['gtype'])
vocoder = Generator(**ckpt['config']).to(device)
vocoder.remove_weight_norm()
if ckpt['config']['out_channels'] > 1:
vocoder.pqmf = PQMF()
vocoder.load_state_dict(ckpt['model'])
if ckpt['config']['out_channels'] > 1:
waves = vocoder.pqmf.synthesis(vocoder(melspecs)).squeeze(1)
else:
waves = vocoder(melspecs).squeeze(1)
print(f"Generate {len(texts)} audios, total elapsed {time.time()-since:.3f}s. Done.")
print('Saving audio...')
outputs_dir = args.outputs_dir or hparams.synthesizer.outputs_dir
os.makedirs(outputs_dir, exist_ok=True)
for i, wav in enumerate(waves, start=1):
wav = wav.cpu().detach().numpy()
filename = osp.join(outputs_dir, f"{time.strftime('%Y-%m-%d')}_{i:03d}.wav")
write(filename, hparams.audio.sampling_rate, wav)
print(f"Audios saved to {outputs_dir}. Done.")
print(f'Done. ({time.time()-since:.3f}s)')