In [3]:
%matplotlib inline

In [2]:
import os
import numpy as np
import json
from ipywidgets import interact
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Move the working directory up one level exactly once per kernel session.
# The global `cd` acts as a sentinel: once it exists, re-running this cell
# does not climb any further up the directory tree. The relative 'exp/...'
# paths used in later cells are resolved against the resulting directory.
if 'cd' not in globals():
    cd = True
    os.chdir('..')

In [5]:
def get_cer(plot_dir):
    """Collect the per-epoch score stored under ``plot_dir``.

    Every entry of ``plot_dir`` whose name starts with ``snapshot.ep.`` is
    treated as one epoch; the text after that prefix must be the integer
    epoch number. The ``result.txt`` inside it is scanned for the first line
    containing ``Sum/Avg``; after stripping surrounding whitespace and ``|``
    characters, the second-to-last whitespace-separated field of that line
    is parsed as the score. Snapshots without a ``Sum/Avg`` line are skipped.

    Args:
        plot_dir: Directory holding the ``snapshot.ep.<N>`` folders.

    Returns:
        Tuple ``(eps, pters)`` of NumPy arrays sorted by epoch number. Both
        arrays are empty when no snapshot produced a score.

    Raises:
        FileNotFoundError: If ``plot_dir`` or a snapshot's ``result.txt``
            does not exist.
        ValueError: If the folder suffix is not an integer or the score field
            is not a number.
    """
    prefix = 'snapshot.ep.'
    data = []
    for folder in os.listdir(plot_dir):
        if not folder.startswith(prefix):
            continue
        # Slice the prefix off instead of str.strip(): strip() removes a *set
        # of characters* from both ends, so e.g. 'snapshot.ep.3.test' would
        # silently be read as epoch 3.
        ep = int(folder[len(prefix):])
        with open(f'{plot_dir}/{folder}/result.txt', 'r', encoding="utf-8") as f:
            for l in f:
                if 'Sum/Avg' in l:
                    pter = float(l.strip().strip('|').strip().split()[-2])
                    data.append((ep, pter))
                    # Only the first summary line of each file is used.
                    break

    if not data:
        # zip(*[]) cannot be unpacked into two names; return empty arrays so
        # callers can still plot / test the length.
        return np.array([], dtype=int), np.array([], dtype=float)

    data.sort(key=lambda x: x[0])
    eps, pters = zip(*data)
    return np.array(eps), np.array(pters)

def get_cer_with_cache(plot_dir, cache_data):
    """Memoised front-end for ``get_cer``.

    Args:
        plot_dir: Directory to read; also used as the cache key.
        cache_data: Mutable mapping from ``plot_dir`` to ``(eps, pters)``.
            It is updated in place when ``plot_dir`` is not yet present.

    Returns:
        The ``(eps, pters)`` pair for ``plot_dir``.
    """
    if plot_dir not in cache_data:
        # Cache miss: read from disk once and remember the result.
        epochs, scores = get_cer(plot_dir)
        cache_data[plot_dir] = (epochs, scores)
        return epochs, scores

    epochs, scores = cache_data[plot_dir]
    return epochs, scores


In [4]:
# Languages the evaluation cells iterate over.
recog_langs = ['Spanish', 'Polish', 'Croatian', '203', '101']

# Every language name, with its short label at the same position in `lang_labels`.
langs = [
    'Croatian', 'Polish', 'Spanish', 'Portuguese', 'Turkish',
    'German', 'Bulgarian', 'Thai', 'Mandarin', 'French',
    'Czech', '203', '101', '505', '404',
    '402', '307', '206', '107', '103',
]
lang_labels = [
    'CR', 'PL', 'SP', 'PO', 'TU',
    'GE', 'BG', 'TH', 'CH', 'FR',
    'CZ', '203', '101', 'N', '404',
    '402', '307', '206', '107', '103',
]

# Labels substituted into the decode directory name by the comparison plot.
# Same contents as `lang_labels`, but an independent list.
fake_lang_labels = list(lang_labels)

# All languages that are not in `recog_langs`, in `langs` order.
train_langs = [name for name in langs if name not in recog_langs]

# Language name -> short label.
lang2label = dict(zip(langs, lang_labels))

In [26]:
# Shared memo passed to get_cer_with_cache: plot_dir -> (eps, pters).
# Re-running this cell discards everything read so far.
cache_data = dict()

In [28]:
@interact
def update(i_recog_lang=(0,len(recog_langs)-1)):
    """Plot the masked-eval score curve of each model variant for one language.

    One line is drawn per training tag, read from
    ``exp/train_pytorch_<tag>/plot_mask_eval_<lang>_<label>_decode``.

    Args:
        i_recog_lang: Index into ``recog_langs``. The ``(min, max)`` tuple
            default makes ``interact`` expose it as an integer slider.
    """
    recog_lang = recog_langs[i_recog_lang]
    ll = lang2label[recog_lang]
    plt.figure(figsize=(10,6))
    for tag in ['wav2vecfexlemb', 'wav2vecfexlembglottoonly', 'wav2vecfexlembphonly']:
        plot_dir = f'exp/train_pytorch_{tag}/plot_mask_eval_{recog_lang}_{ll}_decode'

        eps, pters = get_cer_with_cache(plot_dir, cache_data)

        plt.plot(eps, pters, label=f'{recog_lang}_{ll}_{tag}')

    # The x values are the epoch numbers parsed from the snapshot folder names.
    plt.xlabel('epoch')
    # Build the legend once, after every curve has been added, instead of
    # rebuilding it on each loop iteration.
    plt.legend()


interactive(children=(IntSlider(value=2, description='i_recog_lang', max=4), Output()), _dom_classes=('widget-…

In [35]:
# For each model variant, print the score of every language in `recog_langs`
# read from the `plot_mask_eval_*` decode directories, then the mean over
# those languages.
# NOTE(review): `ep` is used as a positional index into the epoch-sorted
# `pters` array; it is not matched against the epoch numbers in `eps`.
# Confirm that position 29 is the snapshot intended here.
ep = 29
for tag in ['wav2vecfexlemb', 'wav2vecfexlembglottoonly', 'wav2vecfexlembphonly']:
    test_pters = []  # scores of this variant, one per language
    for recog_lang in recog_langs:
        ll = lang2label[recog_lang]
        plot_dir = f'exp/train_pytorch_{tag}/plot_mask_eval_{recog_lang}_{ll}_decode'

        eps, pters = get_cer_with_cache(plot_dir, cache_data)
        print(f'{tag} \t{recog_lang} {ll} {pters[ep]}')
        test_pters.append(pters[ep])
    print(f'mean {np.mean(test_pters)}')

wav2vecfexlemb 	Spanish SP 34.4
wav2vecfexlemb 	Polish PL 54.0
wav2vecfexlemb 	Croatian CR 35.2
wav2vecfexlemb 	203 203 72.8
wav2vecfexlemb 	101 101 73.1
mean 53.9
wav2vecfexlembglottoonly 	Spanish SP 34.8
wav2vecfexlembglottoonly 	Polish PL 55.7
wav2vecfexlembglottoonly 	Croatian CR 35.1
wav2vecfexlembglottoonly 	203 203 69.5
wav2vecfexlembglottoonly 	101 101 73.4
mean 53.7
wav2vecfexlembphonly 	Spanish SP 38.8
wav2vecfexlembphonly 	Polish PL 53.4
wav2vecfexlembphonly 	Croatian CR 36.6
wav2vecfexlembphonly 	203 203 76.0
wav2vecfexlembphonly 	101 101 71.9
mean 55.339999999999996


In [37]:
# For each model variant, print the score of every language in `recog_langs`
# read from the `plot_eval_*` decode directories, then the mean over those
# languages.
# NOTE(review): `ep` is used as a positional index into the epoch-sorted
# `pters` array; it is not matched against the epoch numbers in `eps`.
# Confirm that position 29 is the snapshot intended here.
ep = 29
for tag in ['wav2vecfexlemb', 'wav2vecfexlembglottoonly', 'wav2vecfexlembphonly']:
    test_pters = []  # scores of this variant, one per language
    for recog_lang in recog_langs:
        ll = lang2label[recog_lang]
        plot_dir = f'exp/train_pytorch_{tag}/plot_eval_{recog_lang}_{ll}_decode'

        eps, pters = get_cer_with_cache(plot_dir, cache_data)
        print(f'{tag} \t{recog_lang} {ll} {pters[ep]}')
        test_pters.append(pters[ep])
    print(f'mean {np.mean(test_pters)}')

wav2vecfexlemb 	Spanish SP 37.3
wav2vecfexlemb 	Polish PL 59.8
wav2vecfexlemb 	Croatian CR 41.3
wav2vecfexlemb 	203 203 76.3
wav2vecfexlemb 	101 101 74.6
mean 57.85999999999999
wav2vecfexlembglottoonly 	Spanish SP 37.1
wav2vecfexlembglottoonly 	Polish PL 59.7
wav2vecfexlembglottoonly 	Croatian CR 42.4
wav2vecfexlembglottoonly 	203 203 73.0
wav2vecfexlembglottoonly 	101 101 76.2
mean 57.68000000000001
wav2vecfexlembphonly 	Spanish SP 46.7
wav2vecfexlembphonly 	Polish PL 57.8
wav2vecfexlembphonly 	Croatian CR 46.0
wav2vecfexlembphonly 	203 203 77.7
wav2vecfexlembphonly 	101 101 73.5
mean 60.339999999999996


In [None]:
@interact
def update(i_recog_lang=(0,len(recog_langs)-1), others=True):
    """Unfinished draft of the interactive comparison plot; does nothing."""
    # TODO(review): this definition originally had no body, which is a syntax
    # error and stops a full re-run of the notebook. The following cell
    # defines `update` again with the same signature, so this cell is
    # probably safe to delete.

In [None]:
# Upper bound on the point index at which a curve's text label is placed.
max_epoch = 30
# Debug override: uncomment to restrict the comparison to a single label.
# fake_lang_labels=['CR']
@interact
def update(i_recog_lang=(0,len(recog_langs)-1), others=True):
    """Compare decoding one language under every label against the baselines.

    For each entry of ``fake_lang_labels`` whose decode directory exists under
    ``exp/train_pytorch_wav2vecfexlemb`` a score curve is drawn, followed by
    two dashed baseline curves (plain and masked eval) read from
    ``exp/train_pytorch_wav2vecfex``. The baseline directories are read
    without an existence check.

    Args:
        i_recog_lang: Index into ``recog_langs``. The ``(min, max)`` tuple
            default makes ``interact`` expose it as an integer slider.
        others: When true, every label's curve is visible and annotated with
            its label. When false, only the curve for the language's own
            label is visible.
    """
    recog_lang = recog_langs[i_recog_lang]
    print(recog_lang)
    ll = lang2label[recog_lang]
    plt.figure(figsize=(20,12))

    for ii, fll in enumerate(fake_lang_labels):
        plot_dir = f'exp/train_pytorch_wav2vecfexlemb/plot_eval_{recog_lang}_{fll}_decode'
        if not os.path.exists(plot_dir):
            continue
        eps, pters = get_cer_with_cache(plot_dir, cache_data)
        # Hidden curves are still plotted, just fully transparent, so they
        # keep their legend entry and their contribution to the axis limits.
        alpha = 0 if not others and fll != ll else 1
        plt.plot(eps, pters, label=f'{recog_lang}_'+fll, alpha=alpha)
        if others and len(eps):
            # Spread the labels over different x positions. Clamp the modulus
            # to the curve length so a run with fewer than `max_epoch`
            # snapshots cannot be indexed out of range.
            pos = ii % min(max_epoch, len(eps))
            plt.text(eps[pos], pters[pos], f'{fll}')

    # baseline: same language and true label, plain and masked evaluation
    for eval_kind, suffix in (('plot_eval', 'base'), ('plot_mask_eval', 'basemask')):
        plot_dir = f'exp/train_pytorch_wav2vecfex/{eval_kind}_{recog_lang}_{ll}_decode'
        eps, pters = get_cer_with_cache(plot_dir, cache_data)
        plt.plot(eps, pters, label=f'{recog_lang}_{suffix}', linestyle='--')

    plt.legend()