In [None]:
import os
import time
import numpy as np
import collections
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm as cm
from IPython.display import Audio, display, clear_output, Markdown, Image
import librosa
import librosa.display
import ipywidgets as widgets
# 
from tacotron2.text import text_to_sequence as text_to_sequence_internal
from tacotron2.text.symbols import symbols
# 
from tensorrtserver.api import *


# Notebook-wide settings; they are exposed as attributes of `args` further down.
defaults = dict(
    # --- fixed values: leave these as they are ---
    sigma_infer=0.6,          # not referenced anywhere else in this cell
    sampling_rate=22050,      # playback rate handed to the Audio widget
    stft_hop_length=256,      # multiplied by the mel length to trim the waveform
    url='localhost:8000',     # address passed to both InferContext objects
    # --- user-adjustable values ---
    protocol=0,               # 0: http, 1: grpc
    autoplay=True,            # start playing the generated audio automatically
    # --- fixed values: leave these as they are ---
    character_limit_min=4,    # shorter (stripped) inputs are ignored
    character_limit_max=340,  # longer (stripped) inputs are clipped to this length
)


# create args object
class Struct:
    ''' minimal attribute container: every keyword argument becomes an attribute
        ::entries:: arbitrary keyword arguments, stored as instance attributes
    '''
    def __init__(self, **entries):
        # copy the keyword arguments straight into the instance namespace
        self.__dict__.update(entries)

    def __repr__(self):
        # readable representation so that inspecting an instance in the
        # notebook shows its contents instead of a bare object address
        fields = ', '.join(
            '{}={!r}'.format(name, value) for name, value in vars(self).items())
        return '{}({})'.format(type(self).__name__, fields)


# expose the entries of `defaults` as attributes (args.url, args.protocol, ...)
args = Struct(**defaults)


# create the inference context for the models
# Both contexts are built at load time from the same url/protocol settings and
# differ only in the model name.
# NOTE(review): the last argument (-1) is presumably the model version, with -1
# selecting the latest one -- confirm against the tensorrtserver client API.
infer_ctx_tacotron2 = InferContext(args.url, args.protocol, 'tacotron2', -1)
infer_ctx_waveglow = InferContext(args.url, args.protocol, 'waveglow', -1)


def display_heatmap(sequence, title='preprocessed text'):
    ''' draws sequence as a one-row heatmap in the current output area
        ::sequence:: numpy array to visualise (expected to be 1-D)
        ::title:: caption shown above the plot
    '''
    # replace whatever was drawn before, but only once the new plot is ready
    clear_output(wait=True)
    # add a leading axis so the sequence becomes a single image row
    row = sequence[None, :]
    plt.figure(figsize=(10, 2.5))
    plt.title(title)
    # hide every tick mark and tick label on both axes
    hidden = dict.fromkeys(
        ('bottom', 'top', 'left', 'right', 'labelbottom', 'labelleft'), False)
    plt.tick_params(axis='both', which='both', **hidden)
    plt.imshow(row, cmap='BrBG_r', interpolation='nearest')
    plt.show()


def display_sound(signal, title, color):
    ''' plots the waveform of signal in the current output area
        ::signal:: the audio samples to plot
        ::title:: caption shown above the plot
        ::color:: colour forwarded to the waveform plot
    '''
    # replace whatever was drawn before, but only once the new plot is ready
    clear_output(wait=True)
    plt.figure(figsize=(10, 2.5))
    plt.title(title)
    # keep only the bottom ticks and their labels
    tick_style = dict(
        bottom=True,
        labelbottom=True,
        top=False,
        left=False,
        right=False,
        labelleft=False,
    )
    plt.tick_params(axis='both', which='both', **tick_style)
    librosa.display.waveplot(signal, color=color)
    plt.show()


def display_spectrogram(mel, title):
    ''' shows the mel spectrogram as an image in the current output area
        ::mel:: numpy array holding the spectrogram, cast to float32 for display
        ::title:: caption shown above the plot
    '''
    # replace whatever was drawn before, but only once the new plot is ready
    clear_output(wait=True)
    figure = plt.figure(figsize=(10, 2.5))
    axes = figure.add_subplot(111)
    axes.set_title(title)
    # keep only the bottom ticks and their labels
    axes.tick_params(
        axis='both',
        which='both',
        bottom=True,
        labelbottom=True,
        top=False,
        left=False,
        right=False,
        labelleft=False)
    axes.set_xlabel('Time')
    # 'jet' colormap resampled to 30 colours
    palette = cm.get_cmap('jet', 30)
    axes.imshow(mel.astype(np.float32), interpolation="nearest", cmap=palette)
    axes.grid(True)
    plt.show()


def text_to_sequence(text):
    ''' preprocessor of tacotron2
        ::text:: the input str
        ::returns:: int64 numpy array holding the preprocessed text
    '''
    # the tacotron2 text front-end does the work, using the 'english_cleaners' cleaner
    symbol_ids = text_to_sequence_internal(text, ['english_cleaners'])
    return np.array(symbol_ids, dtype=np.int64)


def sequence_to_mel(sequence):
    ''' calls tacotron2
        ::sequence:: int64 numpy array, contains the preprocessed text
        ::returns:: (mel, mel_lengths, alignments) tuple
                     mel is the mel-spectrogram, np.array
                     mel_lengths contains the length of the unpadded mel, np.array
                     alignments contains attention weights, np.array
    '''
    # the model takes the sequence together with its length
    lengths = np.array([len(sequence)], dtype=np.int64)
    raw = InferContext.ResultFormat.RAW
    # each input is wrapped in a one-element tuple (batch of one)
    inputs = {
        'sequence__0': (sequence,),
        'input_lengths__1': (lengths,),
    }
    # all three outputs are requested in RAW format
    outputs = {
        'mel_outputs_postnet__0': raw,
        'mel_lengths__1': raw,
        'alignments__2': raw,
    }
    # call tacotron2 with batch size 1
    response = infer_ctx_tacotron2.run(inputs, outputs, 1)
    # keep only the first instance of every output batch
    names = ('mel_outputs_postnet__0', 'mel_lengths__1', 'alignments__2')
    mel, mel_lengths, alignments = (response[name][0] for name in names)
    return mel, mel_lengths, alignments


def force_to_shape(mel, length):
    ''' preprocessor of waveglow
        :: mel :: 2-D numpy array
        :: length :: int, the wanted size of dimension 1
        :: return :: mel padded (or trimmed) to length in dimension 1,
                     with an extra trailing axis of size 1
    '''
    current = mel.shape[1]
    if length > current:
        # too short: extend the end of dimension 1 with the smallest value of mel
        fill = mel.min()
        widths = ((0, 0), (0, length - current))
        resized = np.pad(mel, widths, mode='constant', constant_values=fill)
    else:
        # long enough: keep the first `length` columns
        resized = mel[:, :length]
    return resized[:, :, np.newaxis]


def mel_to_signal(mel, mel_lengths):
    ''' calls waveglow
        ::mel:: mel spectrogram, 2-D numpy array
        ::mel_lengths:: original length of mel spectrogram; only element 0 is read
        ::returns:: waveform as a float32 numpy array
    '''
    # the mel is sent with an extra trailing axis of size 1 (no padding/trimming)
    mel = mel[:, :, None]
    stride, kernel_size, n_group = 256, 1024, 8
    n_frames = mel.shape[1]
    # number of noise values needed; the expression reduces to n_frames * stride
    n_samples = (n_frames - 1) * stride + (kernel_size - 1) + 1 - (kernel_size - stride)
    noise_shape = (n_group, n_samples // n_group, 1)
    # NOTE(review): the noise is drawn with unit standard deviation and
    # args.sigma_infer is not applied here -- confirm the served model scales it
    noise = np.random.normal(0.0, 1.0, noise_shape).astype(mel.dtype)
    # each input is wrapped in a one-element tuple (batch of one)
    inputs = {'mel': (mel,), 'z': (noise,)}
    outputs = {'audio': InferContext.ResultFormat.RAW}
    # call waveglow with batch size 1
    response = infer_ctx_waveglow.run(inputs, outputs, 1)
    # keep only the first instance of the output batch
    audio = response['audio'][0]
    # postprocessing of waveglow: cut the signal down to its actual size
    n_valid = mel_lengths[0] * args.stft_hop_length
    return audio[:n_valid].astype(np.float32)


# widgets
def get_output_widget(width, height):
    ''' creates an output widget with default values and returns it
        ::width:: CSS width of the widget, e.g. '10in'
        ::height:: CSS height of the widget, e.g. '1in'
        ::returns:: widgets.Output using a layout with the given size
    '''
    layout = widgets.Layout(width=width,
                            height=height,
                            object_fit='fill',
                            # plain CSS keywords (horizontal, vertical); a
                            # brace-wrapped template is not a valid CSS value
                            object_position='center center')
    return widgets.Output(layout=layout)


# input box for the text to synthesise; text_area_change is attached to its
# `value` further down in this cell
text_area = widgets.Textarea(
    value='type here',
    placeholder='',
    description='',
    disabled=False,
    continuous_update=True,
    layout=widgets.Layout(width='550px', height='80px')
)


# one output area per pipeline stage; text_area_change draws into each of them
# (plot_text_area_preprocessed is drawn into as well, but it is not part of
# the display(...) call at the end of this cell)
plot_text_area_preprocessed = get_output_widget(width='10in',height='1in')
plot_spectrogram = get_output_widget(width='10in',height='2.1in')
plot_signal = get_output_widget(width='10in',height='2.1in')
plot_play = get_output_widget(width='10in',height='1in')


def text_area_change(change):
    ''' observer callback, runs each time text_area.value changes
        ::change:: change notification; only change['new'] (the new text) is read
        runs text -> tacotron2 -> waveglow and refreshes the plots and the player
    '''
    raw_text = change['new']
    # only surrounding spaces are removed before the length checks
    text = raw_text.strip(' ')
    n_chars = len(text)
    if n_chars < args.character_limit_min:
        # too short text: nothing to synthesise yet
        return
    if n_chars > args.character_limit_max:
        # too long text: write the clipped text back and stop; that assignment
        # changes text_area.value, so this callback runs again with the clipped text
        text_area.value = text[:args.character_limit_max]
        return

    # preprocess tacotron2
    sequence = text_to_sequence(text)
    with plot_text_area_preprocessed:
        display_heatmap(sequence)

    # run tacotron2 (the plot is titled with the text exactly as typed)
    mel, mel_lengths, alignments = sequence_to_mel(sequence)
    with plot_spectrogram:
        display_spectrogram(mel, raw_text)

    # run waveglow
    signal = mel_to_signal(mel, mel_lengths)
    with plot_signal:
        display_sound(signal, raw_text, 'green')

    # swap in a fresh audio player for the new signal
    with plot_play:
        clear_output(wait=True)
        player = Audio(signal, rate=args.sampling_rate, autoplay=args.autoplay)
        display(player)
        # related issue: https://github.com/ipython/ipython/issues/11316


# setup callback: text_area_change runs on every change of text_area.value
text_area.observe(text_area_change, names='value')

# decorative widgets
# empty is a fixed-height spacer used at the top and bottom of the layout;
# the markdown_* objects are the bold captions of the individual sections
empty = widgets.VBox([], layout=widgets.Layout(height='1in'))
markdown_4 = Markdown('**tacotron2 input**')
markdown_5 = Markdown('**tacotron2 preprocessing**')
markdown_6 = Markdown('**tacotron2 output / waveglow input**')
markdown_7 = Markdown('**waveglow output**')
markdown_8 = Markdown('**play**')

# display widgets
# each row is a caption followed by the widget it describes
display(
    empty, 
    markdown_4, text_area, 
    # the preprocessing heatmap row below is commented out, so it is not shown
#    markdown_5, plot_text_area_preprocessed, 
    markdown_6, plot_spectrogram, 
    markdown_7, plot_signal, 
    markdown_8, plot_play, 
    empty
)

# default text
# the observer is already registered, so this assignment triggers
# text_area_change and the whole pipeline runs once when the cell is executed
text_area.value = "The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves."