In [2]:
%matplotlib inline

In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE
from ipywidgets import interact
import matplotlib.cm as cm
from tqdm import tqdm
from nltk.corpus import reuters
from nltk import bigrams, trigrams
from collections import Counter, defaultdict

if 'cd' not in globals():
    cd = True
    os.chdir('..')


In [4]:
def get_cer(plot_dir):
    data = []
    for folder in os.listdir(plot_dir):
            if folder.startswith('snapshot.ep.'):
                ep = int(folder.strip('snapshot.ep.'))
                with open(f'{plot_dir}/{folder}/result.txt', 'r', encoding="utf-8") as f:
                    for l in f:
                        if 'Sum/Avg' in l:
                            pter = float(l.strip().strip('|').strip().split()[-2])
                            data.append((ep, pter))
                            break

    data = sorted(data, key=lambda x: x[0])
    eps, pters = zip(*data)
    return np.array(eps), np.array(pters)

def get_cer_with_cache(plot_dir, cache_data):
    if plot_dir in cache_data:
        eps, pters = cache_data[plot_dir]
    else:
        eps, pters = get_cer(plot_dir)
        cache_data[plot_dir] = eps, pters
    return eps, pters


In [5]:
cache_data = {}
langs = ['Croatian', 'Polish', 'Spanish', 'Portuguese', 'Turkish', 'German', 'Bulgarian', 'Thai', 'Mandarin', 'French', 'Czech', '203', '101', '505', '404', '402', '307', '206', '107', '103']
lang_labels = ['CR', 'PL', 'SP', 'PO', 'TU', 'GE', 'BG', 'TH', 'CH', 'FR', 'CZ', '203', '101', 'N', '404', '402', '307', '206', '107', '103']
train_langs = ['Portuguese', 'Turkish', 'German', 'Bulgarian', 'Thai', 'Mandarin', 'French', 'Czech', '505', '404', '402', '307', '206', '107', '103']
lang2label = {lang: label for lang, label in zip(langs, lang_labels)}

In [21]:
max_epoch=30
test_langs = ['Croatian', 'Polish', 'Spanish', '203', '101']
chosen_ep = 29
colors = cm.rainbow(np.linspace(0, 1, len(test_langs)))
# plt.figure(figsize=(20,15))
cross_pters = []
for i, lang in enumerate(test_langs):
    label = lang2label[lang]
    plot_dir = f'exp/train_pytorch_wav2vecfexlemb/plot_lm0.7_mask_eval_{lang}_{label}_decode'
    base_plot_dir = f'exp/train_pytorch_wav2vecfex/plot_eval_{lang}_{label}_decode'
    if not os.path.exists(base_plot_dir): 
        continue
    eps, pters = get_cer_with_cache(plot_dir, cache_data)
#     plt.plot(eps[:max_epoch], pters[:max_epoch], label=f'{lang}_'+label, color=colors[i])
    min_ep = eps[np.argmin(pters)]
    min_pter = min(pters)
#     print(f'lgcn Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
    print(f'lgcn  Lang: {lang} ep: {min_ep} min_pter {min_pter}')

#     print(f'base Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
print(','.join(test_langs))
print(','.join(cross_pters))
# plt.legend()
# plt.title(f'LGCN MASK PTER {eps[chosen_ep]}')
# plt.show()



lgcn  Lang: Croatian ep: 30 min_pter 28.2
lgcn  Lang: Polish ep: 30 min_pter 53.1
lgcn  Lang: Spanish ep: 30 min_pter 29.6
lgcn  Lang: 203 ep: 30 min_pter 65.6
lgcn  Lang: 101 ep: 30 min_pter 65.3
Croatian,Polish,Spanish,203,101



In [11]:
max_epoch=30
test_langs = ['Croatian', 'Polish', 'Spanish', '203', '101'] + train_langs
chosen_ep = 29
colors = cm.rainbow(np.linspace(0, 1, len(test_langs)))
# plt.figure(figsize=(20,15))
cross_pters = []
for i, lang in enumerate(test_langs):
    label = lang2label[lang]
    plot_dir = f'exp/train_pytorch_wav2vecfexlembadv/plot_mask_eval_{lang}_{label}_decode'
#     base_plot_dir = f'exp/train_pytorch_wav2vecfexlembadv/plot_eval_{lang}_{label}_decode'
#     if not os.path.exists(base_plot_dir): 
#         continue
    eps, pters = get_cer_with_cache(plot_dir, cache_data)
#     plt.plot(eps[:max_epoch], pters[:max_epoch], label=f'{lang}_'+label, color=colors[i])
    min_ep = eps[np.argmin(pters)]
    min_pter = min(pters)
#     print(f'lgcn Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
    print(f'lemb  Lang: {lang} ep: {min_ep} min_pter {min_pter}')

#     print(f'base Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
print(','.join(test_langs))
print(','.join(cross_pters))
# plt.legend()
# plt.title(f'LGCN MASK PTER {eps[chosen_ep]}')
# plt.show()



lemb  Lang: Croatian ep: 30 min_pter 35.2
lemb  Lang: Polish ep: 30 min_pter 54.1
lemb  Lang: Spanish ep: 30 min_pter 34.3
lemb  Lang: 203 ep: 30 min_pter 72.3
lemb  Lang: 101 ep: 30 min_pter 74.8
lemb  Lang: Portuguese ep: 30 min_pter 17.4
lemb  Lang: Turkish ep: 30 min_pter 20.4
lemb  Lang: German ep: 30 min_pter 24.4
lemb  Lang: Bulgarian ep: 30 min_pter 28.9
lemb  Lang: Thai ep: 30 min_pter 20.9
lemb  Lang: Mandarin ep: 30 min_pter 16.4
lemb  Lang: French ep: 30 min_pter 13.7
lemb  Lang: Czech ep: 30 min_pter 9.7
lemb  Lang: 505 ep: 30 min_pter 14.9
lemb  Lang: 404 ep: 30 min_pter 40.3
lemb  Lang: 402 ep: 30 min_pter 46.6
lemb  Lang: 307 ep: 30 min_pter 41.4
lemb  Lang: 206 ep: 30 min_pter 37.1
lemb  Lang: 107 ep: 30 min_pter 33.3
lemb  Lang: 103 ep: 30 min_pter 39.7
Croatian,Polish,Spanish,203,101,Portuguese,Turkish,German,Bulgarian,Thai,Mandarin,French,Czech,505,404,402,307,206,107,103



In [5]:
max_epoch=30
test_langs = ['Croatian', 'Polish', 'Spanish', '203', '101'] + train_langs
chosen_ep = 29
colors = cm.rainbow(np.linspace(0, 1, len(test_langs)))
# plt.figure(figsize=(20,15))
cross_pters = []
for i, lang in enumerate(test_langs):
    label = lang2label[lang]
    plot_dir = f'exp/train_pytorch_wav2vecfexlembft/plot_mask_eval_{lang}_{label}_decode'
#     base_plot_dir = f'exp/train_pytorch_wav2vecfexlembadv/plot_eval_{lang}_{label}_decode'
#     if not os.path.exists(base_plot_dir): 
#         continue
    eps, pters = get_cer_with_cache(plot_dir, cache_data)
#     plt.plot(eps[:max_epoch], pters[:max_epoch], label=f'{lang}_'+label, color=colors[i])
    min_ep = eps[np.argmin(pters)]
    min_pter = min(pters)
#     print(f'lgcn Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
    print(f'lemb  Lang: {lang} ep: {min_ep} min_pter {min_pter}')

#     print(f'base Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
print(','.join(test_langs))
print(','.join(cross_pters))
# plt.legend()
# plt.title(f'LGCN MASK PTER {eps[chosen_ep]}')
# plt.show()



lemb  Lang: Croatian ep: 11 min_pter 35.3
lemb  Lang: Polish ep: 11 min_pter 54.2
lemb  Lang: Spanish ep: 11 min_pter 34.4
lemb  Lang: 203 ep: 11 min_pter 72.8
lemb  Lang: 101 ep: 11 min_pter 73.1
lemb  Lang: Portuguese ep: 11 min_pter 16.4
lemb  Lang: Turkish ep: 11 min_pter 19.3
lemb  Lang: German ep: 11 min_pter 23.1
lemb  Lang: Bulgarian ep: 11 min_pter 28.1
lemb  Lang: Thai ep: 11 min_pter 19.9
lemb  Lang: Mandarin ep: 11 min_pter 15.9
lemb  Lang: French ep: 11 min_pter 12.8
lemb  Lang: Czech ep: 11 min_pter 9.1
lemb  Lang: 505 ep: 11 min_pter 14.1
lemb  Lang: 404 ep: 11 min_pter 39.3
lemb  Lang: 402 ep: 11 min_pter 45.2
lemb  Lang: 307 ep: 11 min_pter 39.2
lemb  Lang: 206 ep: 11 min_pter 35.8
lemb  Lang: 107 ep: 11 min_pter 32.6
lemb  Lang: 103 ep: 11 min_pter 39.0
Croatian,Polish,Spanish,203,101,Portuguese,Turkish,German,Bulgarian,Thai,Mandarin,French,Czech,505,404,402,307,206,107,103



In [7]:
max_epoch=30
test_langs = ['Croatian', 'Polish', 'Spanish', '203', '101'] + train_langs
chosen_ep = 29
colors = cm.rainbow(np.linspace(0, 1, len(test_langs)))
# plt.figure(figsize=(20,15))
cross_pters = []
for i, lang in enumerate(test_langs):
    label = lang2label[lang]
    plot_dir = f'exp/train_pytorch_wav2vecfexlembadv2selfatt0.001c/plot_mask_eval_{lang}_{label}_decode'
#     base_plot_dir = f'exp/train_pytorch_wav2vecfexlembadv/plot_eval_{lang}_{label}_decode'
#     if not os.path.exists(base_plot_dir): 
#         continue
    eps, pters = get_cer_with_cache(plot_dir, cache_data)
#     plt.plot(eps[:max_epoch], pters[:max_epoch], label=f'{lang}_'+label, color=colors[i])
    min_ep = eps[np.argmin(pters)]
    min_pter = min(pters)
#     print(f'lgcn Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
    print(f'lemb  Lang: {lang} ep: {min_ep} min_pter {min_pter}')

#     print(f'base Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
print(','.join(test_langs))
print(','.join(cross_pters))
# plt.legend()
# plt.title(f'LGCN MASK PTER {eps[chosen_ep]}')
# plt.show()



lemb  Lang: Croatian ep: 44 min_pter 40.3
lemb  Lang: Polish ep: 44 min_pter 57.2
lemb  Lang: Spanish ep: 44 min_pter 35.6
lemb  Lang: 203 ep: 44 min_pter 69.3
lemb  Lang: 101 ep: 44 min_pter 72.6
lemb  Lang: Portuguese ep: 44 min_pter 18.3
lemb  Lang: Turkish ep: 44 min_pter 21.2
lemb  Lang: German ep: 44 min_pter 25.5
lemb  Lang: Bulgarian ep: 44 min_pter 29.7
lemb  Lang: Thai ep: 44 min_pter 21.9
lemb  Lang: Mandarin ep: 44 min_pter 17.4
lemb  Lang: French ep: 44 min_pter 14.4
lemb  Lang: Czech ep: 44 min_pter 10.6
lemb  Lang: 505 ep: 44 min_pter 15.9
lemb  Lang: 404 ep: 44 min_pter 42.4
lemb  Lang: 402 ep: 44 min_pter 48.6
lemb  Lang: 307 ep: 44 min_pter 43.0
lemb  Lang: 206 ep: 44 min_pter 39.2
lemb  Lang: 107 ep: 44 min_pter 36.4
lemb  Lang: 103 ep: 44 min_pter 42.1
Croatian,Polish,Spanish,203,101,Portuguese,Turkish,German,Bulgarian,Thai,Mandarin,French,Czech,505,404,402,307,206,107,103



In [8]:
max_epoch=30
test_langs = ['Croatian', 'Polish', 'Spanish', '203', '101'] + train_langs
chosen_ep = 29
colors = cm.rainbow(np.linspace(0, 1, len(test_langs)))
# plt.figure(figsize=(20,15))
cross_pters = []
for i, lang in enumerate(test_langs):
    label = lang2label[lang]
    plot_dir = f'exp/train_pytorch_wav2vecfexlembadv2selfatt0.001c/plot_eval_{lang}_{label}_decode'
#     base_plot_dir = f'exp/train_pytorch_wav2vecfexlembadv/plot_eval_{lang}_{label}_decode'
#     if not os.path.exists(base_plot_dir): 
#         continue
    eps, pters = get_cer_with_cache(plot_dir, cache_data)
#     plt.plot(eps[:max_epoch], pters[:max_epoch], label=f'{lang}_'+label, color=colors[i])
    min_ep = eps[np.argmin(pters)]
    min_pter = min(pters)
#     print(f'lgcn Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
    print(f'lemb  Lang: {lang} ep: {min_ep} min_pter {min_pter}')

#     print(f'base Lang: {lang} min_ep: {min_ep} min_pter {min_pter}')
print(','.join(test_langs))
print(','.join(cross_pters))
# plt.legend()
# plt.title(f'LGCN MASK PTER {eps[chosen_ep]}')
# plt.show()



lemb  Lang: Croatian ep: 44 min_pter 51.2
lemb  Lang: Polish ep: 44 min_pter 62.3
lemb  Lang: Spanish ep: 44 min_pter 37.6
lemb  Lang: 203 ep: 44 min_pter 73.9
lemb  Lang: 101 ep: 44 min_pter 74.4
lemb  Lang: Portuguese ep: 44 min_pter 18.3
lemb  Lang: Turkish ep: 44 min_pter 21.2
lemb  Lang: German ep: 44 min_pter 25.5
lemb  Lang: Bulgarian ep: 44 min_pter 29.7
lemb  Lang: Thai ep: 44 min_pter 21.9
lemb  Lang: Mandarin ep: 44 min_pter 17.4
lemb  Lang: French ep: 44 min_pter 14.4
lemb  Lang: Czech ep: 44 min_pter 10.6
lemb  Lang: 505 ep: 44 min_pter 15.9
lemb  Lang: 404 ep: 44 min_pter 42.4
lemb  Lang: 402 ep: 44 min_pter 48.6
lemb  Lang: 307 ep: 44 min_pter 43.0
lemb  Lang: 206 ep: 44 min_pter 39.2
lemb  Lang: 107 ep: 44 min_pter 36.4
lemb  Lang: 103 ep: 44 min_pter 42.1
Croatian,Polish,Spanish,203,101,Portuguese,Turkish,German,Bulgarian,Thai,Mandarin,French,Czech,505,404,402,307,206,107,103

