In [None]:
import IPython.display as ipd
import librosa
import numpy as np
import resemblyzer
import torch

from melgan.model.generator import Generator

from hparams import create_hparams
from model import Parrot
from reader.symbols import ph2id

In [None]:
# Load every model used below once: MelGAN vocoder, Parrot (nonparaseq2seq2020)
# and the Resemblyzer speaker encoder.
hparam_string = ''
hparams = create_hparams(hparam_string)
checkpoint_path = '../runs/outdir_embeds_by_resemblyzer_23jul2020/checkpoint_51000'
melgan_checkpoint_path = '../runs/melgan_TEMP/librispeech_41cec78_0525.pt'

# init melgan
# Generator(80): presumably the number of mel channels; it matches the 80-band
# mels built in audio_to_spect_for_parrot — confirm against the melgan package.
ckpt = torch.load(melgan_checkpoint_path)
melgan_vocoder = Generator(80).cuda()
melgan_vocoder.load_state_dict(ckpt['model_g'])
melgan_vocoder.eval()

# init parrot (nonparaseq2seq2020)
parrot_model = Parrot(hparams).cuda()
parrot_model.load_state_dict(torch.load(checkpoint_path)["state_dict"])
# The model is only used for inference here. Without eval(), dropout / batch-norm
# layers stay in training mode, which makes outputs noisy and non-deterministic.
parrot_model.eval()

# init mean statistics for mel normalization
# NOTE(review): assumes the file holds two rows (mean, std) of per-mel-bin stats.
mel_mean, mel_std = np.load(hparams.mel_mean_std)

# init speaker embedder
speaker_embedder = resemblyzer.VoiceEncoder()

In [None]:
# Count parameters in the parrot model (all parameters, trainable or frozen).
def count_params(module):
    """Return the total number of scalar parameters held by `module`."""
    return sum(param.numel() for param in module.parameters())

print('parrot_model (total):', count_params(parrot_model))
# speaker_encoder is not counted. The original count for it was commented out;
# presumably this variant relies on external Resemblyzer embeddings — confirm.
for submodule_name in ['text_encoder', 'audio_seq2seq', 'speaker_classifier', 'decoder', 'postnet']:
    print(f'  {submodule_name}:', count_params(getattr(parrot_model, submodule_name)))

In [None]:
# Parrot helper functions and input selection.
# path_inp: utterance whose content is converted / used as the VC source.
# path_ref: utterance of the target speaker (used for the speaker embedding).
# Files used in earlier experiments:
#   '../runs/outdir_21jul2020/p246_018.flac'
#   '../runs/outdir_21jul2020/p257_012.flac'
#   '../runs/outdir_21jul2020/V3-6617-314-nl_sp.flac'
#   '../runs/outdir_21jul2020/s107u011n.flac'
SAMPLE_RATE = 16000

path_inp = '../runs/outdir_21jul2020/p257_012.flac'
# path_ref was previously only assigned in commented-out lines, so a fresh kernel
# raised NameError on the load below.
# NOTE(review): reference restored from the earlier candidate list — confirm the intended target speaker.
path_ref = '../runs/outdir_21jul2020/s107u011n.flac'

audio_ref, _ = librosa.load(path_ref, sr=SAMPLE_RATE)
audio_in, _ = librosa.load(path_inp, sr=SAMPLE_RATE)

def audio_to_spect_for_parrot(audio, sr=16000, n_fft=2048, hop_length=256, win_length=1024, n_mels=80):
    """Compute the normalized log-mel spectrogram fed to Parrot.

    Args:
        audio: 1-D numpy waveform.
        sr: sampling rate of `audio` (defaults to the 16 kHz used in this notebook).
        n_fft, hop_length, win_length, n_mels: STFT / mel settings. The defaults are
            the values this notebook always used. They should match the features the
            model was trained on.

    Returns:
        float32 CUDA tensor of shape (1, n_mels, frames), normalized per mel bin
        with the global mel_mean / mel_std.
    """
    magnitude = np.abs(librosa.stft(y=audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                                    window='hann', center=True, pad_mode='reflect'))
    mel = librosa.feature.melspectrogram(S=magnitude, sr=sr, n_mels=n_mels, htk=False)
    # Floor at 1e-5 before the log to avoid -inf on silent frames.
    log_mel = np.log(np.clip(mel, a_min=1e-5, a_max=None)).astype(np.float32)
    norm_log_mel = (log_mel - mel_mean[:, None]) / mel_std[:, None]
    # Explicit float32: if mel_mean/mel_std are float64, the arithmetic above upcasts.
    return torch.as_tensor(norm_log_mel, dtype=torch.float32, device='cuda')[None, :]

def mel_to_wav(mel_input):
    """Vocode a normalized log-mel spectrogram produced by Parrot with MelGAN.

    Args:
        mel_input: normalized log-mel tensor on the model's device; presumably
            shaped (1, 80, frames) — TODO confirm.

    Returns:
        numpy waveform. Samples are divided by 32768, so assuming the vocoder
        returns int16-scaled audio, values lie in roughly [-1, 1].
    """
    mean = torch.as_tensor(mel_mean, dtype=torch.float32, device=mel_input.device)[:, None]
    std = torch.as_tensor(mel_std, dtype=torch.float32, device=mel_input.device)[:, None]
    # Undo the z-normalization. The 1.2 factor scales the normalized values first.
    # NOTE(review): looks like an empirical contrast boost — confirm intent.
    log_mel = 1.2 * mel_input * std + mean
    # Floor at log(1e-5) directly in log space. This equals log(clamp(exp(x), 1e-5))
    # but avoids the exp/log round trip, which overflows to inf for large x.
    log_mel = torch.clamp(log_mel, min=float(np.log(1e-5)))

    with torch.no_grad():
        audio = melgan_vocoder.inference(log_mel).float() / 32768.0

    return audio.detach().cpu().numpy()

def text_to_phoneme_to_idc(text):
    """Phonemize text with Festival and map each phone to its Parrot symbol id.

    ONLY WORKS FOR US-EN (language='en-us').

    Args:
        text: input sentence(s) as a string.

    Returns:
        CUDA LongTensor of shape (1, n_phones) holding ids from ``ph2id``.

    Raises:
        ValueError: if phonemization yields no phones.
        KeyError: if Festival emits phones that are missing from ``ph2id``.
    """
    from phonemizer.phonemize import phonemize
    from phonemizer.separator import Separator

    # Phones are space-separated, with no syllable/word markers, so split()
    # presumably yields the flat phone list.
    phones = phonemize(
        text,
        language='en-us',
        backend='festival',
        separator=Separator(phone=' ', syllable='', word=''),
    )

    phone_list = phones.split()
    if not phone_list:
        raise ValueError(f'no phones produced for text: {text!r}')
    unknown = sorted({ph for ph in phone_list if ph not in ph2id})
    if unknown:
        raise KeyError(f'phones not found in ph2id: {unknown}')

    return torch.tensor([ph2id[ph] for ph in phone_list], dtype=torch.long, device='cuda')[None, :]

# Source mel features, a dummy mel reference, and the target speaker's embedding.
mel_in = audio_to_spect_for_parrot(audio_in)
mel_ref = torch.zeros_like(mel_in)  # DUMMY INPUT; speaker identity presumably comes from spkr_embed
ref_wav_for_embedding = resemblyzer.preprocess_wav(path_ref)
spkr_embed = torch.cuda.FloatTensor(speaker_embedder.embed_utterance(ref_wav_for_embedding)).unsqueeze(0)

In [None]:
# VC: convert the content of audio_in to the reference speaker's voice.
# Text is unused in this mode, so a dummy phoneme sequence is passed.
text_input_padded = torch.zeros(1, 3, dtype=torch.long, device='cuda')  # dummy input
mel_padded = mel_in  # single item, so no padding needed
text_lengths = None
mel_lengths = mel_padded.size(-1)

inp_list = [text_input_padded, mel_padded, text_lengths, mel_lengths, spkr_embed]

# Inference only: skip autograd bookkeeping to save GPU memory.
with torch.no_grad():
    # Second argument False presumably selects the audio-input (VC) path — confirm against Parrot.inference.
    y_pred = parrot_model.inference(inp_list, False, mel_ref, hparams.beam_width)
    mel_output = y_pred[1]
    audio_out = mel_to_wav(mel_output)

ipd.display(ipd.Audio(audio_in, rate=16000))
ipd.display(ipd.Audio(audio_ref, rate=16000))
ipd.display(ipd.Audio(audio_out, rate=16000))

In [None]:
# TTS: synthesize `text` in the reference speaker's voice.
text = "Puppies are cute. Cats are fluffy"

text_input_padded = text_to_phoneme_to_idc(text)
# Dummy audio input. It is built from mel_in rather than mel_padded so this cell
# does not depend on the VC cell having run first.
mel_padded = torch.zeros_like(mel_in)
text_lengths = [text_input_padded.size(-1)]
mel_lengths = mel_padded.size(-1)

inp_list = [text_input_padded, mel_padded, text_lengths, mel_lengths, spkr_embed]

# Inference only: skip autograd bookkeeping to save GPU memory.
with torch.no_grad():
    # Second argument True presumably selects the text-input (TTS) path — confirm against Parrot.inference.
    y_pred = parrot_model.inference(inp_list, True, mel_ref, hparams.beam_width)
    mel_output = y_pred[1]
    audio_out = mel_to_wav(mel_output)

ipd.display(ipd.Audio(audio_ref, rate=16000))
ipd.display(ipd.Audio(audio_out, rate=16000))