In [None]:
import os
import time
import numpy as np
import collections
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm as cm
from IPython.display import Audio, display, clear_output, Markdown, Image
#import librosa
#import librosa.display
import ipywidgets as widgets
# 
from tacotron2.text import text_to_sequence as text_to_sequence_internal
from tacotron2.text.symbols import symbols
# 
import tritonhttpclient as thc

# notebook settings; everything except 'autoplay' is marked as fixed
defaults = dict(
    sigma_infer=0.6,           # don't touch this
    sampling_rate=22050,       # don't touch this
    stft_hop_length=256,       # don't touch this
    url='localhost:8000',      # don't touch this
    autoplay=True,             # start playing synthesized audio immediately
    character_limit_min=4,     # don't touch this
    character_limit_max=340,   # don't touch this
)


# create args object
class Struct:
    """Expose keyword arguments as instance attributes.

    Every keyword passed to the constructor becomes an attribute of the
    same name, e.g. ``Struct(url='x').url == 'x'``.
    """

    def __init__(self, **entries):
        for name, value in entries.items():
            setattr(self, name, value)

# settings object: the entries of `defaults` as attributes (args.url, args.stft_hop_length, ...)
args = Struct(**defaults)

# Triton HTTP client for the server at args.url; shared by sequence_to_mel and mel_to_signal
triton_client = thc.InferenceServerClient(args.url)

def display_sound(signal, title, color):
    ''' displays signal as an amplitude envelope, one point per hop
        ::signal:: waveform np.array; a leading batch axis is allowed
                   (only its first row is drawn)
        ::title:: plot title
        ::color:: matplotlib color of the envelope
    '''
    clear_output(wait=True)
    plt.figure(figsize=(10, 2.5))
    plt.title(title)
    plt.tick_params(
        axis='both',
        which='both',
        bottom=True,
        top=False,
        left=False,
        right=False,
        labelbottom=True,
        labelleft=False)
    # accept both (T,) and (batch, T); the caller passes waveglow's batched output
    sig = np.atleast_2d(np.asarray(signal))[0]
    hop = args.stft_hop_length
    if sig.size:
        # Summarise each hop-sized chunk (one mel frame's worth of samples) by
        # its min and max. A plain average of an oscillating waveform largely
        # cancels out to ~0 and hides the loudness; the min/max band shows it.
        starts = np.arange(0, sig.size, hop)
        highs = np.maximum.reduceat(sig, starts)
        lows = np.minimum.reduceat(sig, starts)
        plt.fill_between(np.arange(starts.size), lows, highs, color=color)
    plt.show()


def display_spectrogram(mel, title):
    ''' displays mel spectrogram
        ::mel:: batched spectrogram; only mel[0] is drawn, with its second
                axis along the x ("Time") axis
        ::title:: plot title
    '''
    clear_output(wait=True)
    fig = plt.figure(figsize=(10, 2.5))
    ax = fig.add_subplot(111)
    plt.title(title)
    plt.tick_params(
        axis='both',
        which='both',
        bottom=True,
        top=False,
        left=False,
        right=False,
        labelbottom=True,
        labelleft=False)
    plt.xlabel('Time')
    # pyplot.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, which was
    # deprecated in Matplotlib 3.7 and removed in 3.9
    cmap = plt.get_cmap('jet', 30)
    # convert to float32 for plotting; aspect='auto' stretches the
    # (few rows x many frames) image over the whole wide figure instead of
    # drawing it as a thin strip with square pixels
    ax.imshow(np.asarray(mel[0], dtype=np.float32),
              interpolation="nearest", cmap=cmap, aspect='auto')
    ax.grid(True)
    plt.show()


def text_to_sequence(text):
    ''' tacotron2 preprocessing: clean the text and map it to symbol ids
        ::text:: the input str
        ::returns:: int64 np.array of symbol ids
    '''
    cleaner_names = ['english_cleaners']
    symbol_ids = text_to_sequence_internal(text, cleaner_names)
    return np.array(symbol_ids, dtype=np.int64)


def sequence_to_mel(sequence):
    ''' calls tacotron2
        ::sequence:: 1-D sequence of symbol ids (e.g. from text_to_sequence);
                     sent to the server as an int64 array of shape (1, n)
        ::returns:: (mel, mel_lengths, alignments) tuple
                     mel is the mel-spectrogram, np.array
                     mel_lengths contains the length of the unpadded mel, np.array
                     alignments contains attention weights, np.array
    '''
    # Both inputs are declared as INT64 below and set_data_from_numpy requires
    # the array dtype to match the declared datatype, so normalise it here
    # (a plain list or an int32 array would otherwise be rejected).
    sequence = np.asarray(sequence, dtype=np.int64).reshape(1, -1)
    input_lengths = np.array([[sequence.shape[1]]], dtype=np.int64)
    # prepare input/output
    inputs = [
        thc.InferInput('input__0', sequence.shape, 'INT64'),
        thc.InferInput('input__1', input_lengths.shape, 'INT64'),
    ]
    inputs[0].set_data_from_numpy(sequence, binary_data=True)
    inputs[1].set_data_from_numpy(input_lengths, binary_data=True)
    outputs = [
        thc.InferRequestedOutput(name, binary_data=True)
        for name in ('output__0', 'output__1', 'output__2')
    ]
    # call tacotron2
    result = triton_client.infer(model_name="tacotron2-ts-script", inputs=inputs, outputs=outputs)
    # get results
    mel = result.as_numpy('output__0')
    mel_lengths = result.as_numpy('output__1')
    alignments = result.as_numpy('output__2')
    return mel, mel_lengths, alignments


def mel_to_signal(mel, mel_lengths):
    ''' calls waveglow
        ::mel:: mel spectrogram from sequence_to_mel; axis 2 is the frame axis
                and the batch size is assumed to be 1
        ::mel_lengths:: original (unpadded) length of the mel spectrogram in
                        frames; the waveform is cut to that many frames' worth
                        of samples
        ::returns:: waveform, float32 np.array, trimmed along its last axis
    '''
    hop = args.stft_hop_length  # audio samples per mel frame
    n_group = 8  # channel groups of the z input; presumably fixed by the deployed model
    # Both inputs are declared as FP16 below and set_data_from_numpy requires
    # the array dtype to match, so cast explicitly instead of relying on
    # tacotron2 having returned float16.
    mel = np.asarray(mel, dtype=np.float16)[:, :, :, None]
    n_frames = mel.shape[2]
    z_size = n_frames * hop // n_group
    z = np.random.normal(0.0, 1.0, (1, n_group, z_size, 1)).astype(np.float16)

    # prepare input/output
    inputs = [
        thc.InferInput('mel', mel.shape, 'FP16'),
        thc.InferInput('z', z.shape, 'FP16'),
    ]
    inputs[0].set_data_from_numpy(mel, binary_data=True)
    inputs[1].set_data_from_numpy(z, binary_data=True)
    outputs = [thc.InferRequestedOutput('audio', binary_data=True)]
    # call waveglow
    result = triton_client.infer(model_name="waveglow-tensorrt", inputs=inputs, outputs=outputs)
    signal = result.as_numpy('audio')

    # postprocessing of waveglow: trim the signal to the unpadded mel length
    n_valid = n_frames
    if mel_lengths is not None and np.size(mel_lengths) > 0:
        n_valid = min(int(np.max(mel_lengths)), n_frames)
    # the output keeps a leading batch axis (display_sound reads signal[0]),
    # so trim the last (sample) axis; slicing axis 0 would not shorten anything
    signal = signal[..., :n_valid * hop]
    return signal.astype(np.float32)


# widgets
def get_output_widget(width, height):
    ''' creates an output widget with default values and returns it
        ::width:: CSS width of the widget, e.g. '10in'
        ::height:: CSS height of the widget
        ::returns:: widgets.Output with that layout
    '''
    layout = widgets.Layout(width=width,
                            height=height,
                            object_fit='fill',
                            # CSS object-position takes plain keywords such as
                            # 'center center'; a brace-wrapped '{center}' is
                            # not a valid CSS value and is ignored by browsers
                            object_position='center center')
    return widgets.Output(layout=layout)


# text input; its value is what text_area_change synthesizes
text_area = widgets.Textarea(
    layout=widgets.Layout(height='80px', width='550px'),
    continuous_update=True,
    disabled=False,
    description='',
    placeholder='',
    value='type here',
)


# output areas for the spectrogram, the waveform plot and the audio player
plot_spectrogram, plot_signal, plot_play = (
    get_output_widget(width='10in', height=size)
    for size in ('2.1in', '2.1in', '1in')
)


def text_area_change(change):
    ''' this gets called each time text_area.value changes
        ::change:: traitlets change dict; change['new'] holds the new text
    '''
    raw = change['new']
    # strip every kind of surrounding whitespace (newlines and tabs too), so
    # whitespace-only input such as "\n\n\n\n" does not pass the length check
    text = raw.strip()
    length = len(text)
    if length < args.character_limit_min: # too short text
        return
    if length > args.character_limit_max: # too long text
        # assigning the value fires this observer again with the shortened text
        text_area.value = text[:args.character_limit_max]
        return
    # preprocess tacotron2
    sequence = text_to_sequence(text)
    # run tacotron2
    mel, mel_lengths, alignments = sequence_to_mel(sequence)
    with plot_spectrogram:
        display_spectrogram(mel, raw)
    # run waveglow
    signal = mel_to_signal(mel, mel_lengths)
    with plot_signal:
        display_sound(signal, raw, 'green')
    with plot_play:
        clear_output(wait=True)
        display(Audio(signal, rate=args.sampling_rate, autoplay=args.autoplay))
        # related issue: https://github.com/ipython/ipython/issues/11316


# setup callback: text_area_change runs whenever text_area.value changes.
# It is registered before the default text is assigned at the end of this
# cell, so that assignment triggers the first synthesis.
text_area.observe(text_area_change, names='value')

# decorative widgets: a vertical spacer and the section headings
empty = widgets.VBox([], layout=widgets.Layout(height='1in'))
markdown_4 = Markdown('**tacotron2 input**')
markdown_6 = Markdown('**tacotron2 output / waveglow input**')
markdown_7 = Markdown('**waveglow output**')
markdown_8 = Markdown('**play**')

# display widgets top to bottom; the same spacer instance is shown both
# above and below the UI
display(
    empty, 
    markdown_4, text_area, 
    markdown_6, plot_spectrogram, 
    markdown_7, plot_signal, 
    markdown_8, plot_play, 
    empty
)

# default text; it differs from the initial 'type here', so assigning it fires
# text_area_change and runs text -> spectrogram -> waveform -> audio
text_area.value = "The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."