In [1]:
import sys
import os
import tensorflow as tf
import numpy as np
import argparse
from time import time

module_path = os.path.expanduser("~/martin/wavenet")
if module_path not in sys.path:
    sys.path.append(module_path)
    
from apps.vocoder.model import Vocoder
from apps.vocoder.hparams import hparams
from apps.vocoder.datasets.data_feeder import ensure_divisible
import apps.vocoder.audio as audio

import IPython
from IPython.display import Audio
from tqdm import tqdm

import matplotlib.pyplot as plt
import librosa
import librosa.display
%matplotlib inline 

from glob import glob

# CMU ARCTIC speaker name -> integer speaker id, numbered in the order listed.
# NOTE(review): presumably these ids must match the ones used when the
# multi-speaker model was trained — confirm before changing the order.
speakers = {
    name: idx
    for idx, name in enumerate(('awb', 'bdl', 'clb', 'jmk', 'ksp', 'rms', 'slt'))
}

In [2]:
def load_data(fpath, trim=False, top_db=20):
    """Load a wav file and compute its mel spectrogram, aligned to the hop size.

    Args:
        fpath: path to the audio file, read with ``audio.load_wav``.
        trim: if True, strip leading/trailing silence with
            ``librosa.effects.trim``.
        top_db: silence threshold passed to ``librosa.effects.trim`` when
            ``trim`` is True (defaults to the previously hard-coded 20).

    Returns:
        (wav, mel): ``mel`` is the float32 mel spectrogram transposed so that
        axis 0 indexes frames; ``wav`` is the waveform padded/truncated to
        exactly ``mel.shape[0] * audio.get_hop_size()`` samples.
    """
    wav = audio.load_wav(fpath)
    if trim:
        wav, _ = librosa.effects.trim(wav, top_db=top_db)

    if hparams.rescaling:
        # Peak-normalise to hparams.rescaling_max. An all-zero (silent) clip
        # has peak 0; dividing by it would fill the waveform with NaN, so such
        # input is left unscaled instead.
        peak = np.abs(wav).max() if wav.size else 0.0
        if peak > 0:
            wav = wav / peak * hparams.rescaling_max
    mel = audio.melspectrogram(wav).astype(np.float32).T

    hop = audio.get_hop_size()
    # Pad the waveform (amounts from audio.lws_pad_lr) so it covers every mel frame.
    pad_left, pad_right = audio.lws_pad_lr(wav, hparams.fft_size, hop)
    wav = np.pad(wav, (pad_left, pad_right), mode="constant", constant_values=0.)
    n_frames = mel.shape[0]
    assert len(wav) >= n_frames * hop
    # Keep exactly one hop of samples per mel frame.
    wav = wav[:n_frames * hop]
    return wav, mel

def show_spectrum(spectrum, is_transpose=False):
    """Plot a 2-D spectrogram with ``librosa.display.specshow`` and a colour bar.

    Args:
        spectrum: 2-D array. By default it is transposed before plotting, for
            arrays laid out frames-first like the ``mel`` returned by
            ``load_data`` (specshow draws axis 0 on the vertical axis).
        is_transpose: pass True when ``spectrum`` is already oriented for
            specshow and must be plotted as-is.
    """
    # Select the orientation with a normal expression instead of using a
    # conditional expression as a statement for its side effect.
    data = spectrum if is_transpose else spectrum.T
    plt.figure(figsize=[16, 4])
    librosa.display.specshow(data)
    plt.colorbar()
    plt.show()
    
def get_lc(fpath, start=0, sample_size=None, trim=False):
    """Load a clip and return an aligned (local_condition, wav) window.

    Args:
        fpath: audio file path, passed to ``load_data``.
        start: index of the first mel frame in the window; must be >= 0.
        sample_size: requested window length in samples; ``None`` means the
            whole clip. It is passed through ``ensure_divisible`` with the hop
            size before being converted to a frame count.
        trim: forwarded to ``load_data``.

    Returns:
        (local_condition, wav): mel frames ``[start, start + max_frames)`` and
        the corresponding waveform samples (``audio.get_hop_size()`` samples
        per frame). Both slices stop at the end of the clip.

    Raises:
        ValueError: if ``start`` is negative. Python's negative slicing would
            otherwise return a silently misaligned wav/mel pair.
    """
    if start < 0:
        raise ValueError("start must be >= 0, got {}".format(start))

    data_raw, data_lc = load_data(fpath, trim=trim)
    hop = audio.get_hop_size()

    if sample_size is None:
        sample_size = len(data_raw)

    # NOTE(review): presumably rounds sample_size to a multiple of the hop
    # size — confirm ensure_divisible's semantics for the third argument.
    sample_size = ensure_divisible(sample_size, hop, True)
    max_frames = sample_size // hop

    sample_start = start * hop
    wav = data_raw[sample_start:sample_start + hop * max_frames]
    local_condition = data_lc[start:start + max_frames, :]

    return local_condition, wav

def get_batch(files, trim=False):
    """Load several audio files into a zero-padded batch.

    Args:
        files: non-empty sequence of audio file paths.
        trim: forwarded to ``get_lc`` (and on to ``load_data``).

    Returns:
        (batch_x, batch_lc, input_len):
            batch_x: waveforms zero-padded at the end, shape (B, T_max, 1).
            batch_lc: mel frames zero-padded at the end,
                shape (B, F_max, hparams.num_mels).
            input_len: list of each waveform's unpadded length in samples.

    Raises:
        ValueError: if ``files`` is empty (np.argmax would otherwise fail with
            an unhelpful message).
    """
    if len(files) == 0:
        raise ValueError("get_batch needs at least one file")

    batch_lc = []
    batch_x = []
    input_len = []
    for fpath in files:
        local_condition, wav = get_lc(fpath, trim=trim)
        batch_lc.append(local_condition)
        batch_x.append(wav)
        input_len.append(wav.shape[0])

    # Pad everything to the longest clip. Its mel length gives the frame
    # padding target, because each waveform holds exactly hop samples per frame.
    idx_max = int(np.argmax(input_len))
    max_x_len = batch_x[idx_max].shape[0]
    max_lc_len = batch_lc[idx_max].shape[0]

    padded_x = [
        np.pad(x, (0, max_x_len - len(x)), mode='constant').reshape(-1, 1)
        for x in batch_x
    ]
    padded_lc = [
        np.pad(c, ((0, max_lc_len - len(c)), (0, 0)), mode='constant').reshape(-1, hparams.num_mels)
        for c in batch_lc
    ]
    return np.stack(padded_x), np.stack(padded_lc), input_len

In [3]:
# Build a fresh TF1 graph: the vocoder model plus the op that turns a mel
# batch into the vocoder's local-condition input.
tf.reset_default_graph()
# Set the graph-level seed after the reset so it applies to the new default graph.
tf.set_random_seed(123)
vocoder = Vocoder(hparams)

# NOTE(review): `inputs` is not referenced by any other cell in this notebook.
inputs = tf.placeholder(tf.float32)
# `l` is fed the padded mel batch (`batch_lc`) in the synthesis cell.
l = tf.placeholder(tf.float32)
# Presumably upsamples the mel frames to the vocoder's local-condition resolution — confirm in Vocoder.create_upsample.
lc = vocoder.create_upsample(l)

In [4]:
# --- Multi-speaker input (CMU ARCTIC): pick a speaker and one utterance ---
speaker_name = 'awb'
speaker_id = speakers[speaker_name]
fpath = os.path.expanduser(
    "~/data/CMU_ARCTIC/cmu_us_{}_arctic/wav/arctic_b0385.wav".format(speaker_name)
)

# --- Single-speaker input (LJSpeech): uncomment both lines to use instead ---
# fpath = os.path.expanduser("~/data/ljspeech/test_wavs/LJ044-0079.wav")
# speaker_id = None

# Alternative multi-file inputs (replace `files` below):
#   Keith Ito's Tacotron samples (https://keithito.github.io/audio-samples/)
#     files = glob("keith/*.wav")
#   Targets from r9y9's tests (https://r9y9.github.io/blog/2018/01/28/wavenet_vocoder/)
#     files = glob("r9y9/*.wav")
files = [fpath]

# Trimmed, padded ground-truth waveforms, mel conditions and unpadded lengths.
batch_x, batch_lc, input_lengths = get_batch(files, trim=True)

In [5]:
# Play the first ground-truth clip at the pipeline's sample rate instead of a
# hard-coded value, cut to its unpadded length so batch padding isn't heard.
# NOTE(review): assumes hparams.sample_rate is the rate audio.load_wav loads at — confirm.
Audio(batch_x[0].reshape(-1)[:input_lengths[0]], rate=hparams.sample_rate)

In [None]:
# Synthesize audio for every item in the batch. GPUs are hidden via
# device_count, so this runs on CPU; drop `config=sess_config` to allow a GPU.
vocoder.init_synthesizer(len(batch_lc))
sess_config = tf.ConfigProto(device_count={'GPU': 0})
with tf.Session(config=sess_config) as sess:
    sess.run(tf.global_variables_initializer())
    # Restore trained weights over the freshly initialised variables.
    vocoder.load(sess, "../logs_tmp")
    # Run the upsampling op on the padded mel batch.
    _lc = sess.run(lc, feed_dict={l: batch_lc})
    # One speaker id per batch item, so multi-file batches are handled too.
    # NOTE(review): assumes synthesize expects one (possibly None) id per
    # batch item — confirm against Vocoder.synthesize.
    speaker_ids = [speaker_id] * len(batch_lc)
    results = vocoder.synthesize(sess, input_lengths, _lc, speaker_ids)

Trying to restore saved checkpoints from ../logs_tmp ...  Checkpoint found: ../logs_tmp/model.ckpt-10000
  Global step was: 10000
  Restoring...INFO:tensorflow:Restoring parameters from ../logs_tmp/model.ckpt-10000
 Done.


 69%|██████▉   | 39861/57600 [11:52<05:16, 55.97it/s]

In [6]:
# For each batch item, play the ground truth (first) and then the vocoder
# output (second). Both use the pipeline's sample rate, and the ground truth
# is cut to its unpadded length so batch padding isn't played.
# NOTE(review): assumes hparams.sample_rate is the rate of both the loaded audio and the model output — confirm.
for i, wav in enumerate(results):
    IPython.display.display(Audio(batch_x[i].reshape(-1)[:input_lengths[i]], rate=hparams.sample_rate))
    IPython.display.display(Audio(wav, rate=hparams.sample_rate))