In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import re, os, csv, pathlib, ast
import pandas as pd
from statistics import mean, variance

accent_short_forms = {"hindi":"HIN", "korean":"KOR", "vietnamese":"VTN", "arabic":"ARB", "chinese":"CHN", "spanish":"ESP"}
accent_map = {"ABA":"arabic","SKA":"arabic","YBAA":"arabic","ZHAA":"arabic",
              "BWC":"chinese","LXC":"chinese","NCC":"chinese","TXHC":"chinese",
              "ASI":"hindi","RRBI":"hindi","SVBI":"hindi","TNI":"hindi",
              "HJK":"korean","HKK":"korean","YDCK":"korean","YKWK":"korean",
              "EBVS":"spanish","ERMS":"spanish","MBMPS":"spanish","NJS":"spanish",
              "HQTV":"vietnamese","PNV":"vietnamese","THV":"vietnamese","TLV":"vietnamese"
              }
raw_string="""|ABA|M|Arabic|1129|150|\n|SKA|F|Arabic|974|150|\n|YBAA|M|Arabic|1130|149|\n|ZHAA|F|Arabic|1132|150|\n|BWC|M|Chinese|1130|150|\n|LXC|F|Chinese|1131|150|\n|NCC|F|Chinese|1131|150|\n|TXHC|M|Chinese|1132|150|\n|ASI|M|Hindi|1131|150|\n|RRBI|M|Hindi|1130|150|\n|SVBI|F|Hindi|1132|150|\n|TNI|F|Hindi|1131|150|\n|HJK|F|Korean|1131|150|\n|HKK|M|Korean|1131|150|\n|YDCK|F|Korean|1131|150|\n|YKWK|M|Korean|1131|150|\n|EBVS|M|Spanish|1007|150|\n|ERMS|M|Spanish|1132|150|\n|MBMPS|F|Spanish|1132|150|\n|NJS|F|Spanish|1131|150|\n|HQTV|M|Vietnamese|1132|150|\n|PNV|F|Vietnamese|1132|150|\n|THV|F|Vietnamese|1132|150|\n|TLV|M|Vietnamese|1132|150|"""
raw_strings=raw_string.split('\n')
gender_map={}
for lne in raw_strings:
    attrs=lne.split('|')
    gender_map[attrs[1]]=attrs[2]

composed_accent_map = {k: accent_short_forms.get(v) for k, v in accent_map.items()}

def replace_with_short_forms(s):
    for key, value in accent_short_forms.items():
        s = s.replace(key, value)
    return s

def last_name(pth):
    return pathlib.PurePath(pth).name

def get_dirs(pth):
    return [last_name(f.name) for f in os.scandir(pth) if f.is_dir()]

def get_each_run(lne):
    return list(map(float, re.findall(': (.+?) -> ', lne)[0].split(' ')))

def get_selection_counts(s):
    return list(map(replace_with_short_forms, re.findall('Counter\\((.+?)\\)', s)))

def get_test_file_from_stats_path(run_number, stats_file_opened):
    return stats_file_opened.name[:-9]+"run_{}/output/test_infer_log.txt".format(run_number)

def WER_test_file(test_file):
    txt_file = open(test_file, 'r')
    lines = txt_file.readlines()
    matched = ""
    for line in lines:
        if "==========>>>>>>Evaluation Greedy WER: " in line:
            txt_file.close()
            return float(line.rstrip().split(": ")[1])
    txt_file.close()
    return ""

def get_eta(func, eta):
    return "-n:"+str(float(eta[4:]))

In [3]:
budget = 100
# target = 50
target = 10
# budget = 200
# target = 80
features = '39'
# features = 'TRILL'
csv_name = "report_{}_{}_{}.csv".format(budget, target, features)

In [8]:
# sample_path = 'Error-Driven-ASR-Personalization/CMU_expts/accent/hindi/manifests/TSS_output/all/budget_100/target_50/FL1MI/eta_1.0/euclidean/39/stats.txt'
# CMU_expts/speaker_without/ABA/manifests/TSS_output/all/budget_100/target_50/FL1MI/eta_1.0/euclidean/39/run_1/
# budget = 200
# target = 80

cols = ['accent', 'ground', 'function', 'similarity', 'duration', 'samples', 
        'WER-r1', 'WER-r2', 'WER-r3', 'WER-mean', 'WER-var', 'accents_run1', 'accents_run2', 'accents_run3']
df = pd.DataFrame(columns = cols)

accents = [f.name for f in os.scandir('./') if f.is_dir() and f.name != '.ipynb_checkpoints' and f.name != 'reserved_TSS_output']

# not random
for accent in accents:
    if not(pathlib.Path('./{}/manifests/TSS_output_NEW/'.format(accent)).is_dir()):
        print("no results for accent {}".format(accent))
        continue
    if 'within' not in get_dirs('./{}/manifests/TSS_output_NEW/'.format(accent)):
        print("no within results for {}".format(accent))
        continue
    if not(os.path.isdir('./{}/manifests/TSS_output_NEW/within/budget_{}/target_{}/'.format(accent, budget, target))):
        continue
#     for function in get_dirs('./{}/manifests/TSS_output/within/budget_{}/target_{}/'.format(accent, budget, target)):        
    for function in ['equal_random']:
        if 'deprecated' in function:
            continue
        stats_file_path='./{}/manifests/TSS_output_NEW/within/budget_{}/target_{}/{}/stats.txt'.format(accent, budget, target, function)
        stats_file = open(stats_file_path, 'r')
        lines = stats_file.readlines()
        df_selections = get_selection_counts(lines[5])
        total_selections, total_durations, accented_selections, accented_durations = map(get_each_run, lines[:4])
        sample_frac = mean([x[0]/x[1] for x in zip(accented_selections, total_selections)])
        sample_total = mean(total_selections)
        duration_frac = mean([x[0]/x[1] for x in zip(accented_durations, total_durations)])
        duration_total = mean(total_durations)
        df_duration = "{:.2f}/{:.2f}".format(duration_total*duration_frac, duration_total)
        df_samples = "{:.2f}/{:.2f}".format(sample_total*sample_frac, sample_total)
        try:
            wers = [WER_test_file(get_test_file_from_stats_path(i, stats_file)) for i in range(1,4)]
            wers = [x for x in wers if type(x)==float or type(x)==int]
            df_wer_mean = round(mean(wers), 2)
            df_wer_var = round(variance(wers), 3)
        except:
            print("no WER's in file", get_test_file_from_stats_path(1, stats_file))
            wers = [0,0,0]
            df_wer_mean = 0
            df_wer_var = 999
        df = df.append(dict(zip(cols, [accent, "within", function, "NA", df_duration, df_samples]+
                                       wers+[df_wer_mean, round(df_wer_var**0.5, 3)] + df_selections)), 
                       ignore_index=True)
        stats_file.close()

no WER's in file ./hindi/manifests/TSS_output_NEW/within/budget_100/target_10/equal_random/run_1/output/test_infer_log.txt
no WER's in file ./chinese/manifests/TSS_output_NEW/within/budget_100/target_10/equal_random/run_1/output/test_infer_log.txt
no WER's in file ./spanish/manifests/TSS_output_NEW/within/budget_100/target_10/equal_random/run_1/output/test_infer_log.txt
no WER's in file ./arabic/manifests/TSS_output_NEW/within/budget_100/target_10/equal_random/run_1/output/test_infer_log.txt
no WER's in file ./korean/manifests/TSS_output_NEW/within/budget_100/target_10/equal_random/run_1/output/test_infer_log.txt
no WER's in file ./vietnamese/manifests/TSS_output_NEW/within/budget_100/target_10/equal_random/run_1/output/test_infer_log.txt


In [9]:
df = df.sort_values(by=['accent', 'similarity', 'ground', 'function'], ascending=True, ignore_index=True)
display(df)

Unnamed: 0,accent,ground,function,similarity,duration,samples,WER-r1,WER-r2,WER-r3,WER-mean,WER-var,accents_run1,accents_run2,accents_run3
0,arabic,within,equal_random,,354.85/354.85,99.33/99.33,0,0,0,0,31.607,"{'ZHAA': 742, 'YBAA': 741, 'ABA': 740, 'SKA': ...","{'ZHAA': 742, 'YBAA': 741, 'ABA': 740, 'SKA': ...","{'ZHAA': 742, 'YBAA': 741, 'ABA': 740, 'SKA': ..."
1,chinese,within,equal_random,,351.27/351.27,90.33/90.33,0,0,0,0,31.607,"{'TXHC': 742, 'BWC': 741, 'LXC': 741, 'NCC': 741}","{'TXHC': 742, 'BWC': 741, 'LXC': 741, 'NCC': 741}","{'TXHC': 742, 'BWC': 741, 'LXC': 741, 'NCC': 741}"
2,hindi,within,equal_random,,353.12/353.12,115.67/115.67,0,0,0,0,31.607,"{'SVBI': 742, 'ASI': 741, 'TNI': 741, 'RRBI': ...","{'SVBI': 742, 'ASI': 741, 'TNI': 741, 'RRBI': ...","{'SVBI': 742, 'ASI': 741, 'TNI': 741, 'RRBI': ..."
3,korean,within,equal_random,,350.45/350.45,98.33/98.33,0,0,0,0,31.607,"{'YKWK': 741, 'YDCK': 741, 'HJK': 741, 'HKK': ...","{'YKWK': 741, 'YDCK': 741, 'HJK': 741, 'HKK': ...","{'YKWK': 741, 'YDCK': 741, 'HJK': 741, 'HKK': ..."
4,spanish,within,equal_random,,351.18/351.18,94.67/94.67,0,0,0,0,31.607,"{'ERMS': 742, 'MBMPS': 742, 'NJS': 741, 'EBVS'...","{'ERMS': 742, 'MBMPS': 742, 'NJS': 741, 'EBVS'...","{'ERMS': 742, 'MBMPS': 742, 'NJS': 741, 'EBVS'..."
5,vietnamese,within,equal_random,,347.80/347.80,90.00/90.00,0,0,0,0,31.607,"{'THV': 742, 'TLV': 742, 'HQTV': 742, 'PNV': 742}","{'THV': 742, 'TLV': 742, 'HQTV': 742, 'PNV': 742}","{'THV': 742, 'TLV': 742, 'HQTV': 742, 'PNV': 742}"


In [10]:
# total selection : 100 100 100 -> 100.00
# total selection duration: 357.0149433106577 357.0149433106577 357.0149433106577 -> 357.01
# accented selection: 76 76 76 -> 76.00
# accented duration: 254.74947845804974 254.74947845804974 254.74947845804974 -> 254.75

# all selections: [Counter({'hindi': 76, 'korean': 8, 'spanish': 7, 'arabic': 3, 'chinese': 3, 'vietnamese': 3}), Counter({'hindi': 76, 'korean': 8, 'spanish': 7, 'arabic': 3, 'chinese': 3, 'vietnamese': 3}), Counter({'hindi': 76, 'korean': 8, 'spanish': 7, 'arabic': 3, 'chinese': 3, 'vietnamese': 3})]

#Evaluation Greedy WER: 16.19

df2 = pd.read_csv(csv_name)
df2=df2.append(df)
df2 = df2.sort_values(by=['accent', 'ground', 'similarity', 'function'], ascending=True, ignore_index=True)
display(df2)

Unnamed: 0,accent,ground,function,similarity,duration,samples,WER-r1,WER-r2,WER-r3,WER-mean,WER-std-dev,accents_run1,accents_run2,accents_run3,WER-var
0,arabic,all,FL2MI-n:1.0,euclidean,348.93/353.97,85.00/86.00,0,0,0,0,31.607,"{'ABA': 41, 'ZHAA': 31, 'SKA': 9, 'YBAA': 4, '...","{'ABA': 41, 'ZHAA': 31, 'SKA': 9, 'YBAA': 4, '...","{'ABA': 41, 'ZHAA': 31, 'SKA': 9, 'YBAA': 4, '...",
1,arabic,all,GCMI-n:1.0,euclidean,304.95/357.89,82.00/94.00,0,0,0,0,31.607,"{'SKA': 73, 'YDCK': 11, 'ABA': 5, 'YBAA': 4, '...","{'SKA': 73, 'YDCK': 11, 'ABA': 5, 'YBAA': 4, '...","{'SKA': 73, 'YDCK': 11, 'ABA': 5, 'YBAA': 4, '...",
2,arabic,all,LogDMI-n:1.0,euclidean,340.44/356.38,94.00/99.00,0,0,0,0,31.607,"{'SKA': 33, 'ABA': 21, 'ZHAA': 21, 'YBAA': 19,...","{'SKA': 33, 'ABA': 21, 'ZHAA': 21, 'YBAA': 19,...","{'SKA': 33, 'ABA': 21, 'ZHAA': 21, 'YBAA': 19,...",
3,arabic,all,random,,71.24/358.35,19.39/97.67,0,0,0,0,31.607,"{'TLV': 11, 'NJS': 8, 'HQTV': 7, 'ASI': 6, 'NC...","{'YBAA': 7, 'ABA': 7, 'YDCK': 6, 'RRBI': 6, 'Y...","{'RRBI': 8, 'TXHC': 7, 'NJS': 7, 'ZHAA': 7, 'Y...",
4,arabic,within,equal_random,,354.85/354.85,99.33/99.33,0,0,0,0,,"{'ZHAA': 742, 'YBAA': 741, 'ABA': 740, 'SKA': ...","{'ZHAA': 742, 'YBAA': 741, 'ABA': 740, 'SKA': ...","{'ZHAA': 742, 'YBAA': 741, 'ABA': 740, 'SKA': ...",31.607
5,chinese,all,FL2MI-n:1.0,euclidean,359.75/359.75,85.00/85.00,0,0,0,0,31.607,"{'LXC': 66, 'BWC': 13, 'TXHC': 6}","{'LXC': 66, 'BWC': 13, 'TXHC': 6}","{'LXC': 66, 'BWC': 13, 'TXHC': 6}",
6,chinese,all,GCMI-n:1.0,euclidean,322.62/354.03,70.00/78.00,0,0,0,0,31.607,"{'LXC': 67, 'NJS': 8, 'TXHC': 3}","{'LXC': 67, 'NJS': 8, 'TXHC': 3}","{'LXC': 67, 'NJS': 8, 'TXHC': 3}",
7,chinese,all,LogDMI-n:1.0,euclidean,350.81/357.64,90.00/92.00,0,0,0,0,31.607,"{'TXHC': 39, 'LXC': 29, 'BWC': 21, 'NCC': 1, '...","{'TXHC': 39, 'LXC': 29, 'BWC': 21, 'NCC': 1, '...","{'TXHC': 39, 'LXC': 29, 'BWC': 21, 'NCC': 1, '...",
8,chinese,all,random,,53.65/358.35,14.27/97.67,0,0,0,0,31.607,"{'TLV': 11, 'NJS': 8, 'HQTV': 7, 'ASI': 6, 'NC...","{'YBAA': 7, 'ABA': 7, 'YDCK': 6, 'RRBI': 6, 'Y...","{'RRBI': 8, 'TXHC': 7, 'NJS': 7, 'ZHAA': 7, 'Y...",
9,chinese,within,equal_random,,351.27/351.27,90.33/90.33,0,0,0,0,,"{'TXHC': 742, 'BWC': 741, 'LXC': 741, 'NCC': 741}","{'TXHC': 742, 'BWC': 741, 'LXC': 741, 'NCC': 741}","{'TXHC': 742, 'BWC': 741, 'LXC': 741, 'NCC': 741}",31.607


In [11]:
df2.to_csv(csv_name, index=False)