In [None]:
# print the audio devices reported by sounddevice so one can be picked in the next cell
import sounddevice as sd
print(sd.query_devices())

In [None]:
# NOTE(review): hard-coded, machine-specific device choice — presumably an entry of the
# device list printed by sd.query_devices(); adjust it to the local setup
sd.default.device = 11

In [None]:
import sys
import os
import time
import numpy as np
import collections
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm as cm
from IPython.display import Audio, display, clear_output, Markdown, Image
import librosa
import librosa.display
import ipywidgets as widgets
# 
# import tacotron2 preprocessing utilities
from utils.tacotron2.symbols import symbols
from utils.tacotron2 import text_to_sequence as text_to_sequence_internal
# import bert pre- and postprocessing utilities
from utils.bert.preprocessing import convert_example_to_feature, read_squad_example, get_predictions
from utils.bert.tokenization import BertTokenizer
# import jasper pre- and postprocessing utilities
from utils.jasper.speech_utils import AudioSegment, SpeechClient
# import trtis api
from tensorrtserver.api import *


# demo-wide settings; every key becomes an attribute of the `args` object built from this dict
defaults = {
    # settings
    'sigma_infer': 0.6,                         # don't touch this
    'sampling_rate': 22050,                     # don't touch this
    'stft_hop_length': 256,                     # don't touch this
    'url': 'localhost:8000',                    # don't touch this
    'protocol': 0,                              # 0: http, 1: grpc
    'autoplay': True,                           # autoplay
    'character_limit_min': 4,                   # don't touch this
    'character_limit_max': 124,                 # don't touch this
    'vocab_file': "./utils/bert/vocab.txt",     # don't touch this
    'do_lower_case': True,                      # don't touch this (was listed twice; a repeated key silently overrides the first)
    'version_2_with_negative': False,           # if true, the model may give 'i don't know' as an answer. the model has to be trained for it.
    'max_seq_length': 384,                      # the maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded.
    'doc_stride': 128,                          # when splitting up a long document into chunks, how much stride to take between chunks
    'max_query_length': 64,                     # the maximum number of tokens in the question. Questions longer than this will be truncated to this length
    'n_best_size': 10,                          # don't touch this
    'max_answer_length': 30,                    # don't touch this
    'null_score_diff_threshold': 0.0,           # don't touch this
    'jasper_batch_size': 1,                     # don't touch this
    'jasper_sampling_rate': 44100,              # don't touch this
    'record_maximum_seconds': 4.0,              # maximum number of seconds to record
}


# attribute-style container for a settings mapping
class Struct:
    ''' stores every keyword argument as an instance attribute of the same name '''

    def __init__(self, **entries):
        namespace = vars(self)
        for key, value in entries.items():
            namespace[key] = value


# settings exposed as attributes, e.g. args.url, args.protocol
args = Struct(**defaults)


# create the inference context for the models
# every client below is bound to the same server (args.url / args.protocol) and to one model name
# NOTE(review): the -1 argument is presumably the model version ("latest") — confirm against the TRTIS client API
infer_ctx_bert = InferContext(args.url, args.protocol, 'bertQA-ts-script', -1)
infer_ctx_tacotron2 = InferContext(args.url, args.protocol, 'tacotron2', -1)
infer_ctx_waveglow = InferContext(args.url, args.protocol, 'waveglow-trt', -1)
# speech recognition goes through the project's SpeechClient wrapper instead of a bare InferContext
infer_jasper = SpeechClient(args.url, args.protocol, 'jasper-trt-ensemble', -1, 
                            args.jasper_batch_size, 'pyt', verbose=False, 
                            mode='asynchronous', from_features=False)


def display_sequences(sequences, labels, colors):
    ''' draws all sequences into one figure without ticks, with a legend
        ::sequences:: the series to draw
        ::labels:: legend entry of each series
        ::colors:: matplotlib format string of each series
    '''
    plt.figure(figsize=(10, 2.5))
    # hide every tick and tick label
    tick_switches = dict.fromkeys(
        ('bottom', 'top', 'left', 'right', 'labelbottom', 'labelleft'), False)
    plt.tick_params(axis='both', which='both', **tick_switches)
    for series, fmt, name in zip(sequences, colors, labels):
        plt.plot(series, fmt, label=name)
    plt.legend(loc='upper right')
    plt.show()


def display_heatmap(sequence, title='preprocessed text'):
    ''' renders the sequence as a single-row heatmap, replacing the previous cell output
        ::sequence:: numpy array; a leading axis is added before plotting
        ::title:: figure title
    '''
    clear_output(wait=True)
    as_row = sequence[np.newaxis, :]
    plt.figure(figsize=(10, 2.5))
    plt.title(title)
    # hide every tick and tick label
    tick_switches = dict.fromkeys(
        ('bottom', 'top', 'left', 'right', 'labelbottom', 'labelleft'), False)
    plt.tick_params(axis='both', which='both', **tick_switches)
    plt.imshow(as_row, cmap='BrBG_r', interpolation='nearest')
    plt.show()


def display_sound(signal, title, color):
    ''' displays signal as a waveform, replacing the previous cell output
        ::signal:: the audio samples to draw
        ::title:: figure title
        ::color:: color forwarded to the librosa waveform plot
    '''
    clear_output(wait=True)
    plt.figure(figsize=(10, 2.5))
    plt.title(title)
    # only the bottom (time) ticks and labels stay visible
    plt.tick_params(
        axis='both',
        which='both',
        bottom=True,
        top=False,
        left=False,
        right=False,
        labelbottom=True,
        labelleft=False)
    # librosa 0.10 removed waveplot in favour of waveshow; keep using waveplot
    # where it still exists so older installs render exactly as before
    plot_wave = getattr(librosa.display, 'waveplot', None) or librosa.display.waveshow
    plot_wave(signal, color=color)
    plt.show()


def display_spectrogram(mel, title):
    ''' displays mel spectrogram, replacing the previous cell output
        ::mel:: the spectrogram array; cast to float32 for rendering
        ::title:: accepted for call compatibility, currently not drawn
    '''
    clear_output(wait=True)
    fig = plt.figure(figsize=(10, 2.5))
    ax = fig.add_subplot(111)
    # only the bottom (time) ticks and labels stay visible
    plt.tick_params(
        axis='both',
        which='both',
        bottom=True,
        top=False,
        left=False,
        right=False,
        labelbottom=True,
        labelleft=False)
    plt.xlabel('Time')
    # pyplot.get_cmap is available in old and new matplotlib alike, whereas
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9
    cmap = plt.get_cmap('jet', 30)
    ax.imshow(mel.astype(np.float32), interpolation="nearest", cmap=cmap)
    ax.grid(True)
    plt.show()


def text_to_sequence(text):
    ''' tacotron2 preprocessing
        ::text:: the input str
        ::returns:: int64 numpy array produced by the tacotron2 text utility
                    with the 'english_cleaners' cleaner
    '''
    cleaner_names = ['english_cleaners']
    symbol_ids = text_to_sequence_internal(text, cleaner_names)
    return np.array(symbol_ids, dtype=np.int64)


def sequence_to_mel(sequence):
    ''' runs the tacotron2 model on the server
        ::sequence:: int64 numpy array, contains the preprocessed text
        ::returns:: (mel, mel_lengths) pair, both taken from the first
                     instance of the output batch
                     mel is the mel-spectrogram, np.array
                     mel_lengths contains the length of the unpadded mel, np.array
    '''
    raw = InferContext.ResultFormat.RAW
    lengths = np.array([len(sequence)], dtype=np.int64)
    # one-element tuples: a batch holding a single instance
    inputs = {
        'sequence__0': (sequence,),
        'input_lengths__1': (lengths,),
    }
    outputs = {
        'mel_outputs_postnet__0': raw,
        'mel_lengths__1': raw,
    }
    response = infer_ctx_tacotron2.run(inputs, outputs, 1)
    return response['mel_outputs_postnet__0'][0], response['mel_lengths__1'][0]


def mel_to_signal(mel, mel_lengths):
    ''' runs the waveglow model on the server
        ::mel:: mel spectrogram
        ::mel_lengths:: original length of mel spectrogram
        ::returns:: float32 waveform, cut to mel_lengths[0] * args.stft_hop_length samples
    '''
    batched_mel = np.expand_dims(mel, axis=0)
    # the noise input 'z' is sized from the last mel axis
    stride, n_group = 256, 8
    noise_shape = (1, n_group, batched_mel.shape[2] * stride // n_group)
    noise = np.random.normal(0.0, 1.0, noise_shape).astype(batched_mel.dtype)
    inputs = {'mel': (batched_mel,), 'z': (noise,)}
    outputs = {'audio': InferContext.ResultFormat.RAW}
    response = infer_ctx_waveglow.run(inputs, outputs)
    # first instance of the output batch, trimmed to the unpadded length
    waveform = response['audio'][0]
    n_samples = mel_lengths[0] * args.stft_hop_length
    return waveform[:n_samples].astype(np.float32)


def question_and_context_to_feature(question_text, context):
    ''' bert preprocessing
        ::question_text:: the question, str
        ::context:: the text the answer is looked up in, str
        ::returns:: (example, feature) pair as produced by read_squad_example
                     and convert_example_to_feature
    '''
    squad_example = read_squad_example(
        question_text, context,
        version_2_with_negative=args.version_2_with_negative)
    # max_len=512: for bert large
    bert_tokenizer = BertTokenizer(
        args.vocab_file, do_lower_case=args.do_lower_case, max_len=512)
    squad_feature = convert_example_to_feature(
        example=squad_example,
        tokenizer=bert_tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length)
    return squad_example, squad_feature


def button_rec_clicked(change):
    ''' RECORD button callback: records for the number of seconds set on the
        slider, plots the recording, transcribes it with jasper and writes the
        transcript followed by '? ' into question_text
    '''
    if not record_seconds.value > 0.0:
        return
    rate = args.jasper_sampling_rate
    with plot_jasper_audio:
        clear_output(wait=True)
        n_frames = int(record_seconds.value * rate)
        recording = sd.rec(n_frames, samplerate=rate, channels=1)
        # tick the slider down to zero in 10 ms steps as a visual countdown
        while record_seconds.value > 0:
            time.sleep(0.01)
            record_seconds.value -= 0.01
        sd.wait()
        mono = recording.squeeze()
        display_sound(mono, 'recorded audio', 'orange')
        audio = AudioSegment(mono, rate).samples
    hypotheses = infer_jasper.recognize([audio], ['audio recording'])
    question_text.value = str(hypotheses[0]) + '? '


# recording controls: the duration slider and the button that starts a recording
record_seconds = widgets.FloatSlider(
    min=0.0,
    max=args.record_maximum_seconds,
    value=args.record_maximum_seconds,
    step=0.1,
    continuous_update=True,
    description="seconds",
)
button_rec = widgets.Button(description="RECORD")
button_rec.on_click(button_rec_clicked)
buttons = widgets.HBox([button_rec, record_seconds])


def _make_textarea(initial_value, height):
    ''' builds an enabled, 550px wide Textarea with empty placeholder and description '''
    return widgets.Textarea(
        value=initial_value,
        placeholder='',
        description='',
        disabled=False,
        continuous_update=True,
        layout=widgets.Layout(width='550px', height=height),
    )


# question (filled by jasper or typed) and the context bert searches for the answer
question_text = _make_textarea('jasper output / bert input question', '40px')
context = _make_textarea('bert input context', '80px')
question_context = widgets.HBox([question_text, context])

# the answer text; starts empty
response_text = _make_textarea('', '40px')


def text_to_logits(input_ids_data, segment_ids_data, input_mask_data):
    ''' runs the bert model on the server
        ::input_ids_data:: numpy array, sent as int64 under 'input__0'
        ::segment_ids_data:: numpy array, sent as int64 under 'input__1'
        ::input_mask_data:: numpy array, sent as int64 under 'input__2'
        ::returns:: (start_logits, end_logits) pair, flat lists of float taken
                     from the first instance of 'output__0' and 'output__1'
    '''
    raw = InferContext.ResultFormat.RAW
    # one-element tuples: a batch holding a single instance
    inputs = {
        'input__0': (input_ids_data.astype(np.int64),),
        'input__1': (segment_ids_data.astype(np.int64),),
        'input__2': (input_mask_data.astype(np.int64),),
    }
    outputs = {'output__0': raw, 'output__1': raw}
    response = infer_ctx_bert.run(inputs, outputs, 1)
    start_logits = list(map(float, response["output__0"][0].flat))
    end_logits = list(map(float, response["output__1"][0].flat))
    return start_logits, end_logits


def question_text_change(change):
    ''' callback for a new question_text value: once the stripped text is long
        enough and ends with '?', runs bert on it and the current context,
        plots the inputs and logits, and writes the best answer into response_text
    '''
    question = change['new'].strip(' ')
    # wait until the question is long enough and terminated by '?'
    if len(question) < args.character_limit_min or question[-1] != '?':
        return
    # preprocess bert
    example, feature = question_and_context_to_feature(question, context.value)
    ids, mask, segments = (
        np.array(values, dtype=np.int64)
        for values in (feature.input_ids, feature.input_mask, feature.segment_ids))
    # shorten the plotted range in steps of 20 while the entry 20 positions back is 0
    shown = segments.shape[0] - 1
    while shown > 20 and segments[shown - 20] == 0:
        shown -= 20
    with plot_tensor:
        clear_output(wait=True)
        # scale mask and segment ids so they are visible next to the token ids
        peak = ids.max()
        display_sequences(
            (ids[:shown], peak // 2 * mask[:shown], peak * segments[:shown]),
            ('input', 'mask', 'segment'),
            ('r.', 'b.', 'g.'))

    # call bert
    start_logits, end_logits = text_to_logits(ids, segments, mask)
    with plot_logits:
        clear_output(wait=True)
        curves = tuple(
            np.array(logits, dtype=np.float32)[:shown]
            for logits in (start_logits, end_logits))
        display_sequences(curves, ('start_logits', 'end_logits'), ('black', 'violet'))
    # postprocess bert
    prediction = get_predictions(
        example, feature, start_logits, end_logits,
        args.n_best_size, args.max_answer_length, args.do_lower_case,
        args.version_2_with_negative, args.null_score_diff_threshold)
    best_answer = prediction[0]["text"]
    response_text.value = best_answer + '. \n'


def context_change(change):
    ''' callback for a new context value: when the context is long enough,
        appends a blank to question_text so that its value changes
    '''
    new_context = change['new']
    if len(new_context) >= args.character_limit_min:
        # inference
        question_text.value += ' '

def response_text_change(change):
    ''' callback for a new response_text value: synthesizes speech for the text
        ::change:: the change notification; change['new'] is the new text
        texts shorter than args.character_limit_min are ignored; texts longer
        than args.character_limit_max are cut and written back to response_text
    '''
    text = change['new']
    text = text.strip(' ')
    length = len(text)
    if length < args.character_limit_min: # too short text
        return
    if length > args.character_limit_max: # too long text
        # write the cut text back into the widget (the former target name
        # `text_area` is not defined anywhere and raised NameError); the new
        # value fits the limit, so a registered observer call can proceed with it
        response_text.value = text[:args.character_limit_max]
        return
    # preprocess tacotron2
    sequence = text_to_sequence(text)
    with plot_response_text_preprocessed:
        display_heatmap(sequence)
    # run tacotron2
    mel, mel_lengths = sequence_to_mel(sequence)
    with plot_spectrogram:
        display_spectrogram(mel, change['new'])
    # run waveglow
    signal = mel_to_signal(mel, mel_lengths)
    with plot_signal:
        display_sound(signal, change['new'], 'green')
    with plot_play:
        clear_output(wait=True)
        display(Audio(signal, rate=args.sampling_rate, autoplay=args.autoplay))

def get_output_widget(width, height, object_fit='fill'):
    ''' creates an output widget with default values and returns it
        ::width:: CSS width of the widget, e.g. '5in'
        ::height:: CSS height of the widget
        ::object_fit:: CSS object-fit value stored in the layout
        ::returns:: a widgets.Output using that layout
    '''
    layout = widgets.Layout(width=width,
                            height=height,
                            object_fit=object_fit,
                            # a valid CSS keyword pair; the former value
                            # '{center} {center}' was an unformatted template
                            object_position='center center')
    return widgets.Output(layout=layout)


# output areas the callbacks render into
plot_tensor = get_output_widget('5in', '1.75in')
plot_logits = get_output_widget('5in', '1.75in')
plot_response_text_preprocessed = get_output_widget('10in', '1in')
plot_spectrogram = get_output_widget('10in', '2.0in', 'scale-down')
plot_jasper_audio = get_output_widget('10in', '2.0in')
plot_signal = get_output_widget('10in', '2.0in')
plot_play = get_output_widget('4in', '1in')

# vertical spacer and section headings
empty = widgets.VBox([], layout=widgets.Layout(height='1in'))
markdown_z0 = Markdown('**Jasper input**')
markdown_m0 = Markdown('**Jasper output / BERT input**')
markdown_bert = Markdown('**BERT**')
markdown_tacotron2 = Markdown('**Tacotron 2**')
markdown_3 = Markdown('**WaveGlow**')

# side-by-side rows
bert_widgets = widgets.HBox([plot_tensor, plot_logits])
tacotron2_widgets = widgets.HBox([response_text, plot_spectrogram])

# the page, top to bottom
page_rows = (
    empty,
    markdown_z0, buttons,
    markdown_m0, question_context,
    markdown_bert, bert_widgets,
    markdown_tacotron2, tacotron2_widgets,
    markdown_3, plot_play,
    empty,
)
display(*page_rows)


def fill_initial_values():
    ''' puts the demo into its start state: a flat placeholder waveform in the
        jasper plot, a sample context and an empty question
    '''
    placeholder_audio = np.zeros(100)
    with plot_jasper_audio:
        display_sound(placeholder_audio, "input audio", 'orange')
    context.value = "The man holding the telescope went into a shop to purchase some flowers on the occasion of all saints day. "
    # alternative sample context:
    # context.value = "William Shakespeare was an English poet, playwright and actor, widely regarded as the greatest writer in the English language and the world's greatest dramatist. He is often called England's national poet and the \"Bard of Avon\"."
    question_text.value = ""
    
# seed the widgets first: the observers below are not registered yet, so these
# assignments do not invoke any of the callbacks
fill_initial_values()

# from here on, a change of a widget's value invokes the matching callback
response_text.observe(response_text_change, names='value')
question_text.observe(question_text_change, names='value')
context.observe(context_change, names='value')