In [None]:
import json
import yaml

import matplotlib.pyplot as plt
import torch

from LightGrad import LightGrad


import IPython.display as ipd

In [None]:
def convert_phn_to_id(phonemes, phn2id):
    """Map a phoneme sequence to ids, wrapped in ``<bos>`` / ``<eos>``.

    phonemes: either a single string of phonemes separated by ' ', or an
        iterable of phoneme tokens (e.g. a list).
    phn2id: dict mapping each phoneme token (including '<bos>' and '<eos>')
        to its integer id.

    Returns a list of ids: [id('<bos>'), *ids(phonemes), id('<eos>')].

    Raises KeyError naming the offending token if it is missing from phn2id.
    """
    if isinstance(phonemes, str):
        tokens = phonemes.split(' ')
    else:
        tokens = list(phonemes)
    ids = []
    for token in ['<bos>'] + tokens + ['<eos>']:
        try:
            ids.append(phn2id[token])
        except KeyError:
            # Re-raise with the token named so unknown phonemes are easy to spot.
            raise KeyError(f'phoneme {token!r} not found in phn2id') from None
    return ids


def text2phnid(text, phn2id, language='zh', add_blank=True):
    """Convert raw text to a phoneme string and its id sequence.

    text: input sentence.
    phn2id: dict mapping phoneme tokens (plus '<bos>', '<eos>' and, when
        add_blank is set, '<blank>') to ids.
    language: 'zh' (Chinese, via text.G2pZh) or 'en' (English, via text.G2pEn).
    add_blank: if True, insert a '<blank>' token between consecutive phonemes.

    Returns:
        zh: (pinyin, phonemes, ids)
        en: (phonemes, ids)
    where ``phonemes`` is always a space-separated string.

    Raises ValueError for an unsupported language.
    """
    if language == 'zh':
        from text import G2pZh
        character2phn = G2pZh()
        pinyin, phonemes = character2phn.character2phoneme(text)
        if add_blank:
            phonemes = ' <blank> '.join(phonemes.split(' '))
        return pinyin, phonemes, convert_phn_to_id(phonemes, phn2id)
    elif language == 'en':
        from text import G2pEn
        word2phn = G2pEn()
        phonemes = word2phn(text)
        # NOTE(review): G2pEn appears to return a list of phoneme tokens (the
        # original code joined it directly); a plain string is split on ' '
        # so both forms are handled the same way.
        if isinstance(phonemes, str):
            tokens = phonemes.split(' ')
        else:
            tokens = list(phonemes)
        # Always produce a space-separated string, so convert_phn_to_id can
        # split it even when add_blank is False.
        separator = ' <blank> ' if add_blank else ' '
        phonemes = separator.join(tokens)
        return phonemes, convert_phn_to_id(phonemes, phn2id)
    else:
        raise ValueError(
            'Language should be zh (for Chinese) or en (for English)!')


def plot_mel(tensors, titles):
    """Plot spectrograms stacked vertically with a shared x range.

    tensors: sequence of 2-D arrays or torch tensors; ``shape[1]`` is used as
        the x (frame) extent. Torch tensors are detached and moved to CPU first,
        so outputs that require grad or live on GPU can be passed directly.
    titles: one title per entry of ``tensors``.

    Returns the ``matplotlib.pyplot`` module (kept for backward compatibility).

    Raises ValueError if ``tensors`` is empty or there are fewer titles than
    tensors.
    """
    if len(tensors) == 0:
        raise ValueError('plot_mel needs at least one spectrogram')
    if len(titles) < len(tensors):
        raise ValueError(
            f'plot_mel got {len(tensors)} spectrograms but only {len(titles)} titles')

    # imshow needs a plain CPU array; torch tensors expose .detach().
    arrays = [
        t.detach().cpu().numpy() if hasattr(t, 'detach') else t
        for t in tensors
    ]
    xlim = max(a.shape[1] for a in arrays)
    # squeeze=False keeps `axs` 2-D even for a single row, so axs[i, 0] is
    # always an Axes (with the default squeeze, one row gives a bare Axes).
    fig, axs = plt.subplots(nrows=len(arrays),
                            ncols=1,
                            figsize=(12, 9),
                            constrained_layout=True,
                            squeeze=False)
    for ax, data, title in zip(axs[:, 0], arrays, titles):
        im = ax.imshow(data,
                       aspect="auto",
                       origin="lower",
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_title(title)
        ax.set_xlim([0, xlim])
    fig.canvas.draw()
    return plt

In [None]:
# Setup HiFi-GAN

from hifi_gan import models, env

# Fill these in before running: the HiFi-GAN JSON config and checkpoint paths.
HiFiGAN_CONFIG = ''
HiFiGAN_ckpt = ''
if not HiFiGAN_CONFIG or not HiFiGAN_ckpt:
    # Fail with a clear message instead of FileNotFoundError on ''.
    raise ValueError(
        'Set HiFiGAN_CONFIG and HiFiGAN_ckpt to the HiFi-GAN config and '
        'checkpoint paths before running this cell.')

with open(HiFiGAN_CONFIG) as f:
    hifigan_hparams = env.AttrDict(json.load(f))

generator = models.Generator(hifigan_hparams)

# The checkpoint is a dict; only its 'generator' entry (the state dict) is used.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
generator.load_state_dict(torch.load(
    HiFiGAN_ckpt, map_location='cpu')['generator'])
generator = generator.eval()
generator.remove_weight_norm()


def convert_mel_to_audio(mel):
    """Vocode a mel-spectrogram batch into audio with the global HiFi-GAN `generator`.

    mel: tensor whose first dimension is the batch; only a batch size of 1 is
        supported. Presumably shaped (1, n_mels, frames) — TODO confirm against
        the HiFi-GAN config.

    Returns the generator output with dim 1 squeezed, i.e. (batch, samples)
    assuming the generator emits (batch, 1, samples).

    Raises ValueError if the batch size is not 1.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and gives no message.
    if mel.shape[0] != 1:
        raise ValueError(
            f'convert_mel_to_audio only supports batch size 1, got {mel.shape[0]}')
    with torch.no_grad():
        audio = generator(mel).squeeze(1)  # (b,t)
    return audio

In [None]:
# inference for bznsyp

N_STEP = 4  # passed to the model as `n_timesteps`
TEMP = 1.5  # passed to the model as `temperature`
STREAMING_CLIP_SIZE = 0.5  # in seconds

config_path = 'config/bznsyp_config.yaml'
ckpt_path = ''  # fill in: path to a trained LightGrad checkpoint
if not ckpt_path:
    # Fail with a clear message instead of FileNotFoundError on ''.
    raise ValueError(
        'Set ckpt_path to a LightGrad checkpoint before running this cell.')

print('loading ', ckpt_path)
# The checkpoint unpacks into three items; only the last (the model state
# dict) is used here.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
_, _, state_dict = torch.load(ckpt_path, map_location='cpu')
with open(config_path) as f:
    config = yaml.safe_load(f)

with open(config['phn2id_path']) as f:
    phn2id = json.load(f)
vocab_size = len(phn2id) + 1

model = LightGrad.build_model(config, vocab_size)
model.load_state_dict(state_dict)
# The model is only used for inference below: disable training-mode layers
# (e.g. dropout). Assigning avoids echoing the module repr as cell output.
model = model.eval()

In [None]:
text = "做一个测试"

pinyin, phonemes, phnid = text2phnid(text, phn2id, 'zh')
print(f'pinyin seq: {pinyin}')
print(f'phoneme seq: {phonemes}')
phnid_len = torch.tensor([len(phnid)], dtype=torch.long)  # shape (1,)
phnid = torch.tensor(phnid).unsqueeze(0)  # shape (1, n_tokens)

# STREAMING_CLIP_SIZE is a float, so `//` still yields a float (e.g. 43.0);
# cast so the clip size is passed as an integer frame count.
streaming_clip_frames = int(STREAMING_CLIP_SIZE * config['sample_rate'] //
                            config['hop_size'])

# Pure inference: no autograd graph is needed for the outputs.
with torch.no_grad():
    mel_clips = [
        mel_clip for _, mel_clip, _ in model.forward_streaming(
            phnid,
            phnid_len,
            n_timesteps=N_STEP,
            temperature=TEMP,
            out_size=streaming_clip_frames,
            solver='dpm')
    ]
    # Stitch the streamed clips back together along dim 2.
    mel_prediction_streaming = torch.cat(mel_clips, dim=2)

    _, mel_prediction, _ = model.forward(phnid,
                                         phnid_len,
                                         n_timesteps=N_STEP,
                                         temperature=TEMP,
                                         solver='dpm')

plot_mel([mel_prediction_streaming[0], mel_prediction[0]],
         ['streaming inference', 'non-streaming inference'])

# NOTE(review): the playback rate is hard-coded to 22050 Hz; it must match the
# HiFi-GAN vocoder's sampling rate — confirm against the HiFi-GAN config.
ipd.display(
    ipd.Audio(convert_mel_to_audio(mel_prediction_streaming), rate=22050))
ipd.display(ipd.Audio(convert_mel_to_audio(mel_prediction), rate=22050))

In [None]:
# inference for ljspeech

N_STEP = 4  # passed to the model as `n_timesteps`
TEMP = 1.5  # passed to the model as `temperature`
STREAMING_CLIP_SIZE = 0.5  # in seconds

config_path = 'config/ljspeech_config.yaml'
ckpt_path = ''  # fill in: path to a trained LightGrad checkpoint
if not ckpt_path:
    # Fail with a clear message instead of FileNotFoundError on ''.
    raise ValueError(
        'Set ckpt_path to a LightGrad checkpoint before running this cell.')

print('loading ', ckpt_path)
# The checkpoint unpacks into three items; only the last (the model state
# dict) is used here.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
_, _, state_dict = torch.load(ckpt_path,
                              map_location='cpu')

with open(config_path) as f:
    config = yaml.safe_load(f)

with open(config['phn2id_path']) as f:
    phn2id = json.load(f)
vocab_size = len(phn2id) + 1

model = LightGrad.build_model(config, vocab_size)
model.load_state_dict(state_dict)
# The model is only used for inference below: disable training-mode layers
# (e.g. dropout). Assigning avoids echoing the module repr as cell output.
model = model.eval()

In [None]:
text = "This is a test"

phonemes, phnid = text2phnid(text, phn2id, 'en')
print(f'phoneme seq: {phonemes}')
phnid_len = torch.tensor([len(phnid)], dtype=torch.long)  # shape (1,)
phnid = torch.tensor(phnid).unsqueeze(0)  # shape (1, n_tokens)

# STREAMING_CLIP_SIZE is a float, so `//` still yields a float (e.g. 43.0);
# cast so the clip size is passed as an integer frame count.
streaming_clip_frames = int(STREAMING_CLIP_SIZE * config['sample_rate'] //
                            config['hop_size'])

# Pure inference: no autograd graph is needed for the outputs.
with torch.no_grad():
    mel_clips = [
        mel_clip for _, mel_clip, _ in model.forward_streaming(
            phnid,
            phnid_len,
            n_timesteps=N_STEP,
            temperature=TEMP,
            out_size=streaming_clip_frames,
            solver='dpm')
    ]
    # Stitch the streamed clips back together along dim 2.
    mel_prediction_streaming = torch.cat(mel_clips, dim=2)

    _, mel_prediction, _ = model.forward(phnid,
                                         phnid_len,
                                         n_timesteps=N_STEP,
                                         temperature=TEMP,
                                         solver='dpm')

plot_mel([mel_prediction_streaming[0], mel_prediction[0]],
         ['streaming inference', 'non-streaming inference'])

# NOTE(review): the playback rate is hard-coded to 22050 Hz; it must match the
# HiFi-GAN vocoder's sampling rate — confirm against the HiFi-GAN config.
ipd.display(ipd.Audio(convert_mel_to_audio(
    mel_prediction_streaming), rate=22050))
ipd.display(ipd.Audio(convert_mel_to_audio(mel_prediction), rate=22050))