In [3]:
%matplotlib inline
import IPython.display as ipd

import torch
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence

from scipy.io.wavfile import write as write_wav


def get_text(text, hps):
    """Turn a raw input string into a 1-D LongTensor of symbol ids for the synthesizer.

    Args:
        text: The string to synthesize. It is cleaned and mapped to ids by
            ``text_to_sequence`` using ``hps.data.text_cleaners``.
        hps: Hyperparameter object. Only ``hps.data.text_cleaners`` and
            ``hps.data.add_blank`` are read here.

    Returns:
        torch.LongTensor of symbol ids. When ``hps.data.add_blank`` is truthy,
        the id sequence is first passed through ``commons.intersperse(ids, 0)``
        (presumably inserting blank id 0 between symbols — see ``commons``).
    """
    symbol_ids = text_to_sequence(text, hps.data.text_cleaners)
    if not hps.data.add_blank:
        return torch.LongTensor(symbol_ids)
    return torch.LongTensor(commons.intersperse(symbol_ids, 0))

# Hyperparameters of the trained run; its `data`, `train` and `model` sections
# are used throughout this notebook.
HPARAMS_PATH = "local/12-21_12-30.json"
hps = utils.get_hparams_from_file(HPARAMS_PATH)

## Single-speaker synthesis

#### Load checkpoint

In [4]:
# Build the generator with the same architecture hyperparameters used for training,
# then restore the trained weights. Use the GPU when one is available, but fall back
# to CPU instead of crashing on a hard-coded `.cuda()` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net_g = SynthesizerTrn(
    len(symbols),                                   # size of the text symbol vocabulary
    hps.data.filter_length // 2 + 1,                # presumably linear-spectrogram bins (n_fft // 2 + 1)
    hps.train.segment_size // hps.data.hop_length,  # presumably training segment length in frames
    **hps.model
).to(device)
_ = net_g.eval()  # switch to inference mode; assign to `_` to suppress the module repr

CHECKPOINT_PATH = "logs/12-21_12-30/G_18000.pth"
# Third argument is presumably the optimizer slot; None restores model weights only.
_ = utils.load_checkpoint(CHECKPOINT_PATH, net_g, None)

INFO:root:Loaded checkpoint 'logs/12-21_12-30/G_18000.pth' (iteration 368)


#### Run inference

In [13]:
# Synthesize one utterance, play it inline and save it as a WAV file.
TEXT = "嗨。嗨。嗨。小。黑。子。露。出。鸡。脚。了。吧。"
# Sampling knobs passed to `net_g.infer`; values kept from the original run.
NOISE_SCALE = 0.667
NOISE_SCALE_W = 0.8
LENGTH_SCALE = 1
OUTPUT_PATH = "local/12-21_12-30-output.wav"

# Put the inputs on whatever device the model's weights live on (CPU or GPU),
# rather than assuming CUDA.
device = next(net_g.parameters()).device

stn_tst = get_text(TEXT, hps)
with torch.no_grad():
    x_tst = stn_tst.to(device).unsqueeze(0)  # add a leading batch dimension
    x_tst_lengths = torch.tensor([stn_tst.size(0)], dtype=torch.long, device=device)
    audio = (
        net_g.infer(
            x_tst,
            x_tst_lengths,
            noise_scale=NOISE_SCALE,
            noise_scale_w=NOISE_SCALE_W,
            length_scale=LENGTH_SCALE,
        )[0][0, 0]  # first output, first batch item / first channel (presumably (B, C, T))
        # ipd.Audio(normalize=False) raises ValueError if any sample exceeds [-1, 1],
        # and the decoder output is not guaranteed to stay in range; clamp so playback
        # never fails and the saved file matches what is heard.
        .clamp(-1.0, 1.0)
        .cpu()
        .float()
        .numpy()
    )
ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))
write_wav(OUTPUT_PATH, hps.data.sampling_rate, audio)  # float32 array -> IEEE-float WAV