In [None]:
import numpy as np
from tqdm import tqdm
import sys
import os
from collections import Counter
import math
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import time
import importlib
from collections import namedtuple
import random
import json

In [None]:
# Twenty-colour palette, written as 8-bit (0-255) RGB triples.
tableau20 = [
    (31, 119, 180), (174, 199, 232),
    (255, 127, 14), (255, 187, 120),
    (44, 160, 44), (152, 223, 138),
    (214, 39, 40), (255, 152, 150),
    (148, 103, 189), (197, 176, 213),
    (140, 86, 75), (196, 156, 148),
    (227, 119, 194), (247, 182, 210),
    (127, 127, 127), (199, 199, 199),
    (188, 189, 34), (219, 219, 141),
    (23, 190, 207), (158, 218, 229),
]
# matplotlib takes colour channels as floats in [0, 1], so rescale every
# triple in place (the list length never changes while we walk it).
for i, (r, g, b) in enumerate(tableau20):
    tableau20[i] = (r / 255., g / 255., b / 255.)

In [None]:
def play_utt(utt, m_dict):
    """Play the audio covered by utterance `utt` as an inline notebook widget.

    The wav file is read from the notebook-level `wavs_path` directory; its
    name is `utt` with the last "-"-separated field removed, plus ".wav".
    The clip runs from the earliest segment 'start' to the latest segment
    'end' in `m_dict[utt]['seg']`; those values are multiplied by the sample
    rate, i.e. they are treated as times in seconds.

    Args:
        utt: utterance id, used both as a key of `m_dict` and to derive the
            wav file name.
        m_dict: mapping where `m_dict[utt]['seg']` is an iterable of dicts
            carrying 'start' and 'end' entries.

    Raises:
        ValueError: from `min`/`max` if the utterance has no segments.

    NOTE(review): `wavs_path`, `Audio` and `display` are not defined in this
    cell; they are expected to exist in the notebook namespace (presumably
    `Audio`/`display` come from IPython.display -- confirm).
    """
    # Function-scope import: the notebook's import cell never imports scipy,
    # so on a fresh kernel the name would otherwise be unresolved here.
    import scipy.io.wavfile

    wav_name = utt.rsplit("-", 1)[0] + '.wav'
    sr, y = scipy.io.wavfile.read(os.path.join(wavs_path, wav_name))

    segments = m_dict[utt]['seg']
    start_t = min(seg['start'] for seg in segments)
    end_t = max(seg['end'] for seg in segments)
    print(start_t, end_t)

    # Convert the time span into sample indices for slicing the waveform.
    start_t_samples, end_t_samples = int(start_t*sr), int(end_t*sr)
    display(Audio(y[start_t_samples:end_t_samples], rate=sr))

In [None]:
def get_google_data(data_dir="../chainer2/speech2text/both_fbank_out/"):
    """Load the Google speech-to-text hypotheses, references and eval references.

    Three pickle files are looked up inside `data_dir`:
    "google_s2t_hyps.dict", "google_s2t_refs.dict" and the cache file
    "google_s2t_refs_for_eval.dict". If the cache file is missing it is built
    from the references and written to disk: for every utterance key of
    `refs['fisher_dev_ref_0']`, the entries `refs[ref_set][utterance]` of all
    reference sets are collected into one list.

    Args:
        data_dir: directory holding the pickle files. Defaults to the
            location the notebook originally hard-coded.

    Returns:
        Tuple `(google_s2t_hyps, google_s2t_refs, google_s2t_refs_for_eval)`.

    Raises:
        OSError: if the hypotheses or references file cannot be opened.
        KeyError: when the cache has to be built and 'fisher_dev_ref_0' is
            absent, or a reference set lacks one of its utterances.
    """
    google_s2t_refs_path = os.path.join(data_dir, "google_s2t_refs.dict")
    google_s2t_hyps_path = os.path.join(data_dir, "google_s2t_hyps.dict")
    google_s2t_refs_for_eval_path = os.path.join(data_dir, "google_s2t_refs_for_eval.dict")

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point this function at trusted data.
    with open(google_s2t_hyps_path, "rb") as hyps_file:
        google_s2t_hyps = pickle.load(hyps_file)

    with open(google_s2t_refs_path, "rb") as refs_file:
        google_s2t_refs = pickle.load(refs_file)

    if os.path.exists(google_s2t_refs_for_eval_path):
        print("eval refs found, loading")
        with open(google_s2t_refs_for_eval_path, "rb") as eval_file:
            google_s2t_refs_for_eval = pickle.load(eval_file)
    else:
        print("eval refs not found, creating")
        google_dev_ref_0 = google_s2t_refs['fisher_dev_ref_0']
        google_s2t_refs_for_eval = {
            u: [google_s2t_refs[ref][u] for ref in google_s2t_refs]
            for u in google_dev_ref_0
        }
        # pickle.dump returns None, so its result must not be assigned back
        # to the dict we are about to return.
        with open(google_s2t_refs_for_eval_path, "wb") as eval_file:
            pickle.dump(google_s2t_refs_for_eval, eval_file)

    return google_s2t_hyps, google_s2t_refs, google_s2t_refs_for_eval

In [None]:
def get_words_in_bow_vocab(words, bow_dict):
    """Return the distinct words of `words` that occur in the bag-of-words vocabulary.

    Membership is tested against `bow_dict['w2i']` using the bytes form of
    each word, so `str` words are encoded before the lookup. The result
    always holds `str` words, so `bytes` words that match are decoded.

    Args:
        words: sequence of words, each either `str` or `bytes` (a mix is
            accepted).
        bow_dict: dict whose 'w2i' entry is a container keyed by bytes words.

    Returns:
        List of unique in-vocabulary words as `str`, in order of first
        occurrence (deterministic across runs, unlike a set). An empty
        `words` is returned unchanged.
    """
    if len(words) == 0:
        return words

    vocab = bow_dict['w2i']
    common_words = []
    already_added = set()
    for word in words:
        is_encoded = isinstance(word, bytes)
        encoded = word if is_encoded else word.encode()
        if encoded not in vocab:
            continue
        decoded = word.decode() if is_encoded else word
        # De-duplicate while keeping the order in which words first appear.
        if decoded not in already_added:
            already_added.add(decoded)
            common_words.append(decoded)
    return common_words

In [None]:
def display_bow_words(refs, hyps, bow_dict, m_dict, display_num=100, play_audio=False):
    """Print, per utterance, the in-vocabulary reference and predicted words.

    Only utterances present in `refs`, `hyps` and `m_dict` are considered.
    For each of them the words are filtered through `get_words_in_bow_vocab`
    and joined with ' --- '. Utterances with at least one non-empty reference
    or prediction string are printed, in sorted utterance order, as a small
    table with up to four reference rows and one prediction row.

    Args:
        refs: mapping utterance -> either a flat list of words (a single
            reference) or a list of word lists (several references; the
            first four are shown).
        hyps: mapping utterance -> list of predicted words. A value that is
            not a list is shown as an empty prediction.
        bow_dict: vocabulary dict forwarded to `get_words_in_bow_vocab`.
        m_dict: mapping used to filter utterances; also forwarded to
            `play_utt` when `play_audio` is set.
        display_num: maximum number of utterances to print.
        play_audio: if True, call `play_utt` after each printed table.

    Returns:
        Tuple `(total_utts, total_utts_with_bag_words)`: the number of
        utterances considered and the number actually printed (at most
        `display_num`).

    NOTE(review): `PrettyTable`, `get_words_in_bow_vocab` and `play_utt` are
    expected to be available in the notebook namespace; `PrettyTable` is not
    imported in the visible import cell -- confirm it is imported elsewhere.
    """
    # Function-scope import: `textwrap` is not among the imports visible at
    # the top of the notebook.
    import textwrap

    join_str = ' --- '
    ref_labels = ["en ref", "en ref2", "en ref3", "en ref4"]

    def _bow_str(words):
        # In-vocabulary words of `words`, joined into one display string.
        return join_str.join(get_words_in_bow_vocab(words, bow_dict))

    rows = []
    for u in set(refs.keys()) & set(hyps.keys()):
        if u not in m_dict:
            continue

        u_refs = refs[u]
        if len(u_refs) == 0:
            ref_word_lists = []
        elif isinstance(u_refs[0], (str, bytes)):
            # A flat list of words: one single reference.
            ref_word_lists = [u_refs]
        else:
            # Several references: keep as many as there are table rows.
            ref_word_lists = list(u_refs[:len(ref_labels)])

        ref_strs = [_bow_str(words) for words in ref_word_lists]
        # Pad to a fixed number of reference columns so every utterance yields
        # a full row; otherwise rows with fewer references would be dropped
        # or shifted against other utterances.
        ref_strs.extend([""] * (len(ref_labels) - len(ref_strs)))

        pred_str = _bow_str(hyps[u]) if isinstance(hyps[u], list) else ""
        rows.append((u, ref_strs, pred_str))

    # Utterance ids are unique, so ordering by id alone fixes the row order.
    rows.sort(key=lambda row: row[0])

    total_utts = len(rows)
    total_utts_with_bag_words = 0
    for u, ref_strs, pred_str in rows:
        if not (pred_str or any(ref_strs)):
            continue
        # Stop before printing more than `display_num` utterances.
        if total_utts_with_bag_words >= display_num:
            break
        total_utts_with_bag_words += 1

        print("Utterance: {0:s}".format(u))
        display_pp = PrettyTable(["cat","sent"], hrules=True)
        display_pp.align = "l"
        display_pp.header = False
        for label, ref_str in zip(ref_labels, ref_strs):
            display_pp.add_row([label, textwrap.fill(ref_str,50)])
        display_pp.add_row(["en pred", textwrap.fill(pred_str,50)])
        print(display_pp)

        if play_audio:
            play_utt(u, m_dict)

    print("total utts={0:d}, utts with bag words={1:d}".format(total_utts, total_utts_with_bag_words))
    return total_utts, total_utts_with_bag_words
