In [1]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import collections

In [2]:
def collect_losses(root_dir, seeds=(123, 456, 789), depth=2):
    """Aggregate the test losses of several seeded fine-tuning runs.

    For every seed, the tree ``{root_dir}/{seed}`` is scanned for files named
    ``finetuning_results.json`` that sit fewer than ``depth`` directory levels
    below the seed directory. The values of the ``"losses_ts"`` mapping of
    each file (in file order) form one run, and runs are grouped by the name
    of the directory that contains the file.

    Args:
        root_dir: Directory holding one sub-directory per seed.
        seeds: Seed directory names to scan. Seed directories that do not
            exist are silently skipped (``os.walk`` yields nothing for them).
        depth: Number of directory levels, counted from the seed directory
            itself, in which result files are looked for.

    Returns:
        dict mapping each dataset (directory) name to ``[means, stds]``:
        two lists with, for every loss position, the mean and the population
        standard deviation (``np.std`` default) over the collected runs.

    Raises:
        IndexError: if a later run of a dataset has fewer losses than the
            first run collected for it (the first run fixes the length).
    """
    results_filename = "finetuning_results.json"

    # {"dataset_name": [losses of run 1, losses of run 2, ...]}
    runs_per_dataset = {}

    # collect ts losses
    for seed_ in seeds:
        seed_rootdir = f"{root_dir}/{seed_}"

        for subdir, dirs, files in os.walk(seed_rootdir):
            level = subdir[len(seed_rootdir):].count(os.sep)

            # Nothing below this level can satisfy `level < depth`, so stop
            # os.walk from descending any further.
            if level >= depth - 1:
                dirs[:] = []

            if level >= depth or results_filename not in files:
                continue

            with open(os.path.join(subdir, results_filename)) as f:
                losses_ts = list(json.load(f)["losses_ts"].values())

            dataset_name = os.path.basename(subdir)
            runs_per_dataset.setdefault(dataset_name, []).append(losses_ts)

    # compute mean and std of every loss position across the runs
    # {"dataset_name": [means, stds]}
    dict_res = {}
    for dataset, losses_list in runs_per_dataset.items():
        losses_len = len(losses_list[0])
        per_position = [
            [losses[i] for losses in losses_list] for i in range(losses_len)
        ]
        dict_res[dataset] = [
            [np.mean(values) for values in per_position],
            [np.std(values) for values in per_position],
        ]

    return dict_res


def avg_losses_over_datasets(datasets_losses):
    """Average the per-position mean losses over all datasets.

    Args:
        datasets_losses: Mapping ``dataset -> [means, stds]`` as returned by
            ``collect_losses``. Only the ``means`` entry (index 0) is used.

    Returns:
        ``numpy.ndarray`` of floats: element-wise average of the ``means``
        sequences over the datasets.

    Raises:
        ValueError: if ``datasets_losses`` is empty.
    """
    if not datasets_losses:
        raise ValueError("datasets_losses is empty: there is nothing to average")

    avg_losses_dts = None

    for mean_std in datasets_losses.values():
        # Force a float copy: accumulating in place into an integer array
        # fails as soon as a later dataset contributes float means, and the
        # copy guarantees the caller's data is never modified.
        means = np.array(mean_std[0], dtype=float)

        if avg_losses_dts is None:
            avg_losses_dts = means
        else:
            avg_losses_dts += means

    return avg_losses_dts / len(datasets_losses)


In [3]:
def collect_mean_baseline(root_dir):
    """Load every file under ``root_dir`` as JSON and split its entries.

    Each file found by walking ``root_dir`` is parsed as a JSON object. The
    part of the file name before its first ``"."`` is used as the key in both
    results. Entries whose key contains ``"mae"`` go to the first result;
    all remaining entries go to the second.

    Args:
        root_dir: Directory tree containing the baseline result files.

    Returns:
        Tuple ``(mae_results, other_results)`` of dicts mapping
        ``file stem -> {metric key: value}``.
    """
    mae_by_file = {}
    rest_by_file = {}

    for folder, _subfolders, filenames in os.walk(root_dir):
        for filename in filenames:
            with open(f"{folder}/{filename}") as handle:
                metrics = json.load(handle)

            stem = filename.split(".")[0]
            mae_by_file[stem] = {
                key: value for key, value in metrics.items() if "mae" in key
            }
            rest_by_file[stem] = {
                key: value for key, value in metrics.items() if "mae" not in key
            }

    return mae_by_file, rest_by_file

In [4]:
# Load the baseline files: entries whose key contains "mae" go to res_mae,
# every other entry goes to res_mse.
res_mae, res_mse = collect_mean_baseline("../mean_baseline")

In [5]:
# Transposed so that each baseline file is a row and each MAE metric a column.
pd.DataFrame(res_mae).T

Unnamed: 0,mae_loss0,mae_loss1,mae_loss2,mae_loss3,mae_loss4,mae_loss5,mae_loss6,mae_loss7,mae_loss_all
mean_baseline_results_ge_2,23.631326,14.618062,7.006784,8.147557,7.051513,7.868829,18.58737,17.797943,13.088672
mean_baseline_results_en_1,38.517199,14.621716,10.738502,11.030411,13.813458,11.51782,18.290697,17.886902,17.052091
mean_baseline_results_sp_1,38.517199,14.621716,10.738502,11.030411,13.813458,11.51782,18.290697,17.886902,17.052091
mean_baseline_results_sp_0,24.549716,14.705048,11.912802,11.594023,12.410025,13.139339,16.958553,22.415786,15.960663
mean_baseline_results_en_0,24.549716,14.705048,11.912802,11.594023,12.410025,13.139339,16.958553,22.415786,15.960663
mean_baseline_results_en_all,21.930317,17.84561,8.207685,6.764501,9.042601,7.670029,12.825503,13.182144,12.18355
mean_baseline_results_it_2,26.582837,16.653791,6.200814,4.721693,6.028809,6.263821,22.267081,17.07749,13.224541
mean_baseline_results_it_1,48.835523,14.575395,9.585082,7.415826,13.107302,6.873666,10.576857,26.213612,17.147907
mean_baseline_results_sp_all,24.260191,16.306235,11.068098,11.078702,11.807236,13.443232,16.17763,18.831468,15.371598
mean_baseline_results_en_2,29.188814,12.345532,9.174236,8.970618,8.856351,9.950365,28.189191,28.300166,16.871909


In [6]:
# Transposed so that each baseline file is a row and each non-MAE metric a column.
pd.DataFrame(res_mse).T

Unnamed: 0,mse_loss0,mse_loss1,mse_loss2,mse_loss3,mse_loss4,mse_loss5,mse_loss6,mse_loss7,mse_loss_all
mean_baseline_results_ge_2,806.65155,343.989716,83.463081,132.983429,84.457062,121.285454,651.978821,530.145752,344.369354
mean_baseline_results_en_1,1810.848999,274.370972,149.066116,168.991058,254.935638,201.284912,563.093567,598.750977,502.667816
mean_baseline_results_sp_1,1810.848999,274.370972,149.066116,168.991058,254.935638,201.284912,563.093567,598.750977,502.667816
mean_baseline_results_sp_0,811.969116,304.091217,217.723511,218.394775,229.068481,264.102142,449.343048,713.752563,401.055634
mean_baseline_results_en_0,811.969116,304.091217,217.723511,218.394775,229.068481,264.102142,449.343048,713.752563,401.055634
mean_baseline_results_en_all,640.336487,448.258118,100.318092,71.553055,116.193436,89.375259,273.781036,289.469513,253.660629
mean_baseline_results_it_2,984.338318,389.863342,61.180115,36.295246,59.142437,63.611099,746.830872,433.831909,346.886658
mean_baseline_results_it_1,2473.706543,266.036621,119.365463,87.648254,209.708557,83.170052,452.522308,1166.272705,607.303833
mean_baseline_results_sp_all,757.955078,360.768005,173.150589,188.085358,193.072662,268.179321,382.582245,512.253784,354.50592
mean_baseline_results_en_2,1056.15271,238.791077,142.295715,128.062286,154.608597,166.107605,1062.52063,1383.977783,541.564575


## Finetuning Results

### Not Pretraining

In [7]:
# Positional column index of the results table -> column header used when the
# table is rendered (see the `rename(columns=...)` call in process_to_present).
mapping_columns_names = dict(enumerate([
    "MSE",
    "Avg-Acc",
    "prob-skip-Acc",
    "firstfix-dur-Acc",
    "firstrun-dur-Acc",
    "dur-Acc",
    "firstrun-nfix-Acc",
    "nfix-Acc",
    "prob-refix-Acc",
    "prob-reread-Acc",
]))

def process_to_present(res, column_width="1.1cm", plot_transpose=True):
    """Print ``res`` as a LaTeX ``tabular`` with one decimal per value.

    Args:
        res: Mapping ``row label -> sequence of values``; the values are
            laid out in positional columns 0, 1, 2, ...
        column_width: Width used in the ``p{...}`` spec of every data column.
        plot_transpose: If True (default) the labels of ``res`` are the table
            rows. If False the table is flipped so the labels become columns.

    Returns:
        None. The LaTeX source is written to stdout.
    """
    table = pd.DataFrame.from_dict(res).T

    # Positional columns 1..9 are replaced by `100 - value`; column 0 is left
    # as is. The "-Acc" headers suggest this turns an error percentage into an
    # accuracy -- NOTE(review): confirm the unit of the stored losses.
    table.iloc[:, 1:10] = 100 - table.iloc[:, 1:10]

    table = table.rename(columns=mapping_columns_names)
    if not plot_transpose:
        table = table.T

    styler = table.style.format(na_rep='MISS', precision=1)

    # One left-aligned label column followed by a fixed-width column per
    # data column of the (possibly transposed) table.
    column_spec = 'l' + f'p{{{column_width}}}' * len(table.columns)
    print(styler.to_latex(column_format=column_spec))

In [8]:
res_ = collect_losses("../finetuning/notpretraining")

# Label each dataset with the two "_"-separated tokens that precede the last
# token of its directory name (e.g. "en all", "ge 0") and keep only the means.
res = {" ".join(k.split("_")[-3:-1]): v[0] for k, v in res_.items()}

# Sort by label so the table rows come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [9]:
# Print the LaTeX table for `res` in the default layout (one row per dataset label).
process_to_present(res)

\begin{tabular}{lp{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}}
 & MSE & Avg-Acc & prob-skip-Acc & firstfix-dur-Acc & firstrun-dur-Acc & dur-Acc & firstrun-nfix-Acc & nfix-Acc & prob-refix-Acc & prob-reread-Acc \\
en all & 103.0 & 84.5 & 79.0 & 65.6 & 88.2 & 91.7 & 86.1 & 90.2 & 89.5 & 86.0 \\
ge 0 & 164.1 & 85.2 & 83.3 & 81.0 & 92.6 & 96.3 & 92.2 & 95.4 & 73.5 & 67.1 \\
ge 1 & 94.6 & 87.5 & 85.4 & 78.0 & 93.1 & 93.6 & 91.9 & 92.7 & 80.4 & 84.9 \\
ge 2 & 98.8 & 86.8 & 80.4 & 81.8 & 91.4 & 91.5 & 90.1 & 91.2 & 84.5 & 83.5 \\
ge all & 106.0 & 86.4 & 85.2 & 72.4 & 92.1 & 92.3 & 90.7 & 91.1 & 83.2 & 84.4 \\
it 0 & 228.7 & 81.9 & 78.8 & 78.3 & 90.7 & 91.8 & 87.8 & 94.1 & 71.5 & 62.5 \\
it 1 & 236.3 & 84.6 & 59.5 & 85.7 & 90.4 & 92.8 & 86.0 & 92.3 & 90.8 & 79.4 \\
it 2 & 142.3 & 86.0 & 79.0 & 72.5 & 92.9 & 95.0 & 92.0 & 92.4 & 81.3 & 82.9 \\
it all & 147.9 & 85.0 & 79.9 & 69.5 & 90.6 & 93.7 & 89.8 & 92.3 & 82.6 & 81.5 \\
sp 0 & 175.7 & 81.4 & 76.6 & 76.1 & 

### Pretraining

In [10]:
res_ = collect_losses("../finetuning/pretraining")

# Label each dataset with the two "_"-separated tokens that precede the last
# token of its directory name (e.g. "en all", "ge 0") and keep only the means.
res = {" ".join(k.split("_")[-3:-1]): v[0] for k, v in res_.items()}

# Sort by label so the table rows come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [11]:
# Print the LaTeX table for `res` in the default layout (one row per dataset label).
process_to_present(res)

\begin{tabular}{lp{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}}
 & MSE & Avg-Acc & prob-skip-Acc & firstfix-dur-Acc & firstrun-dur-Acc & dur-Acc & firstrun-nfix-Acc & nfix-Acc & prob-refix-Acc & prob-reread-Acc \\
en all & 120.0 & 83.0 & 77.2 & 62.1 & 86.7 & 90.7 & 84.3 & 89.0 & 89.2 & 84.9 \\
ge 0 & 176.1 & 84.8 & 84.3 & 78.8 & 91.9 & 96.4 & 91.2 & 95.3 & 74.2 & 66.5 \\
ge 1 & 107.4 & 86.5 & 85.2 & 74.6 & 92.2 & 92.8 & 90.5 & 91.8 & 80.1 & 84.4 \\
ge 2 & 110.8 & 85.8 & 80.1 & 79.2 & 90.3 & 90.6 & 88.7 & 90.2 & 84.5 & 83.0 \\
ge all & 122.6 & 85.1 & 84.7 & 68.7 & 90.9 & 91.2 & 89.0 & 89.9 & 82.9 & 83.5 \\
it 0 & 249.8 & 81.3 & 79.6 & 76.2 & 89.9 & 91.3 & 86.2 & 93.8 & 71.8 & 61.6 \\
it 1 & 245.2 & 84.8 & 60.3 & 85.7 & 90.5 & 93.0 & 86.0 & 92.5 & 91.0 & 79.8 \\
it 2 & 158.6 & 85.2 & 78.4 & 69.7 & 92.2 & 94.6 & 91.1 & 91.8 & 81.1 & 82.4 \\
it all & 168.3 & 83.8 & 79.0 & 66.3 & 89.6 & 93.1 & 88.5 & 91.6 & 82.1 & 80.4 \\
sp 0 & 197.7 & 80.1 & 75.6 & 73.8 

### Not Pretraining Not Full

In [12]:
res_ = collect_losses("../finetuning/notpretraining_notfull")

# Label each dataset with the two "_"-separated tokens that precede the last
# two tokens of its directory name (e.g. "en all", "ge 0"); keep only the means.
res = {" ".join(k.split("_")[-4:-2]): v[0] for k, v in res_.items()}

# Sort by label so the table rows come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [13]:
# Print the LaTeX table for `res` in the default layout (one row per dataset label).
process_to_present(res)

\begin{tabular}{lp{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}}
 & MSE & Avg-Acc & prob-skip-Acc & firstfix-dur-Acc & firstrun-dur-Acc & dur-Acc & firstrun-nfix-Acc & nfix-Acc & prob-refix-Acc & prob-reread-Acc \\
en all & 164.6 & 79.5 & 72.1 & 54.7 & 83.5 & 88.6 & 80.8 & 86.6 & 87.8 & 82.2 \\
ge 0 & 213.6 & 83.1 & 83.7 & 73.7 & 89.9 & 95.8 & 89.1 & 94.6 & 72.8 & 65.2 \\
ge 1 & 149.4 & 83.4 & 82.5 & 67.0 & 89.4 & 90.6 & 87.4 & 89.2 & 77.7 & 83.1 \\
ge 2 & 151.6 & 82.9 & 76.8 & 73.4 & 87.3 & 88.1 & 85.4 & 87.6 & 83.3 & 81.6 \\
ge all & 171.3 & 81.6 & 81.7 & 60.7 & 87.9 & 88.7 & 85.8 & 87.0 & 80.6 & 80.5 \\
it 0 & 307.0 & 79.3 & 79.0 & 71.7 & 87.5 & 89.8 & 83.3 & 92.7 & 71.4 & 59.1 \\
it 1 & 282.1 & 83.8 & 58.6 & 83.7 & 89.0 & 91.9 & 84.0 & 91.4 & 91.2 & 80.5 \\
it 2 & 212.4 & 82.5 & 75.2 & 63.0 & 90.0 & 93.1 & 88.9 & 89.8 & 79.0 & 80.7 \\
it all & 227.2 & 80.7 & 75.5 & 59.4 & 86.9 & 91.4 & 85.8 & 89.6 & 79.8 & 77.3 \\
sp 0 & 256.7 & 77.0 & 71.5 & 69.2 

### Pretraining Not Full

In [14]:
res_ = collect_losses("../finetuning/pretraining_notfull")

# Label each dataset with the two "_"-separated tokens that precede the last
# two tokens of its directory name (e.g. "en all", "ge 0"); keep only the means.
res = {" ".join(k.split("_")[-4:-2]): v[0] for k, v in res_.items()}

# Sort by label so the table rows come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [15]:
# Print the LaTeX table for `res` in the default layout (one row per dataset label).
process_to_present(res)

\begin{tabular}{lp{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}p{1.1cm}}
 & MSE & Avg-Acc & prob-skip-Acc & firstfix-dur-Acc & firstrun-dur-Acc & dur-Acc & firstrun-nfix-Acc & nfix-Acc & prob-refix-Acc & prob-reread-Acc \\
en all & 194.1 & 76.8 & 69.5 & 51.3 & 80.7 & 85.9 & 77.6 & 83.5 & 86.2 & 79.9 \\
ge 0 & 244.4 & 81.7 & 85.4 & 70.2 & 87.1 & 94.0 & 86.0 & 92.0 & 73.7 & 64.9 \\
ge 1 & 181.8 & 81.0 & 82.1 & 63.2 & 86.3 & 87.6 & 83.9 & 85.7 & 77.0 & 82.3 \\
ge 2 & 184.1 & 80.8 & 76.3 & 69.8 & 84.4 & 85.2 & 82.0 & 84.2 & 83.7 & 80.9 \\
ge all & 208.1 & 78.8 & 80.5 & 56.8 & 84.7 & 85.6 & 82.2 & 83.4 & 79.2 & 78.2 \\
it 0 & 344.0 & 77.9 & 79.9 & 68.9 & 85.3 & 87.6 & 80.7 & 90.3 & 71.8 & 58.4 \\
it 1 & 306.8 & 84.0 & 59.1 & 83.1 & 88.5 & 91.4 & 83.3 & 90.8 & 93.4 & 82.1 \\
it 2 & 244.7 & 80.8 & 75.1 & 60.0 & 87.8 & 91.1 & 86.2 & 87.2 & 78.7 & 80.0 \\
it all & 263.8 & 78.6 & 74.5 & 56.2 & 84.4 & 89.1 & 82.9 & 86.8 & 78.9 & 75.7 \\
sp 0 & 296.3 & 74.9 & 70.6 & 66.0 

### Not Pretraining Dur

In [25]:
res_ = collect_losses("../finetuning/notpretraining_dur")

# Label each dataset with the two "_"-separated tokens that precede the last
# two tokens of its directory name; keep the means except the final entry.
res = {" ".join(k.split("_")[-4:-2]): v[0][:-1] for k, v in res_.items()}

# Sort by label so the table columns come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [26]:
# Narrower columns and a flipped layout: metrics as rows, dataset labels as columns.
process_to_present(res, column_width="1cm", plot_transpose=False)

\begin{tabular}{lp{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}}
 & en all & ge 0 & ge 1 & ge 2 & ge all & it 0 & it 1 & it 2 & it all & sp 0 & sp 1 & sp 2 & sp all \\
MSE & 15.4 & 12.1 & 16.5 & 26.4 & 19.3 & 33.2 & 34.1 & 12.6 & 14.9 & 51.2 & 43.7 & 29.7 & 41.5 \\
Avg-Acc & 94.7 & 96.2 & 95.0 & 93.3 & 94.5 & 92.9 & 92.8 & 96.0 & 95.5 & 91.0 & 91.4 & 92.7 & 91.7 \\
\end{tabular}



### Pretraining Dur

In [18]:
res_ = collect_losses("../finetuning/pretraining_dur")

# Label each dataset with the two "_"-separated tokens that precede the last
# two tokens of its directory name; keep the means except the final entry.
res = {" ".join(k.split("_")[-4:-2]): v[0][:-1] for k, v in res_.items()}

# Sort by label so the table columns come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [19]:
# Narrower columns and a flipped layout: metrics as rows, dataset labels as columns.
process_to_present(res, column_width="1cm", plot_transpose=False)

\begin{tabular}{lp{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}}
 & en all & ge 0 & ge 1 & ge 2 & ge all & it 0 & it 1 & it 2 & it all & sp 0 & sp 1 & sp 2 & sp all \\
MSE & 15.8 & 9.7 & 14.6 & 26.0 & 18.9 & 31.2 & 27.9 & 9.7 & 12.2 & 62.3 & 44.7 & 30.9 & 51.7 \\
Avg-Acc & 95.0 & 96.7 & 95.7 & 93.7 & 95.0 & 93.5 & 93.8 & 96.9 & 96.2 & 90.2 & 91.7 & 92.9 & 90.9 \\
\end{tabular}



### Not Pretraining Prob Skip

In [20]:
res_ = collect_losses("../finetuning/notpretraining_prob_skip")

# Label each dataset with the two "_"-separated tokens that precede the last
# three tokens of its directory name; keep the means except the final entry.
res = {" ".join(k.split("_")[-5:-3]): v[0][:-1] for k, v in res_.items()}

# Sort by label so the table columns come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [21]:
# Narrower columns and a flipped layout: metrics as rows, dataset labels as columns.
process_to_present(res, column_width="1cm", plot_transpose=False)

\begin{tabular}{lp{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}}
 & en all & ge 0 & ge 1 & ge 2 & ge all & it 0 & it 1 & it 2 & it all & sp 0 & sp 1 & sp 2 & sp all \\
MSE & 147.2 & 200.6 & 107.0 & 199.4 & 102.9 & 270.9 & 861.4 & 302.4 & 248.6 & 267.8 & 579.5 & 316.0 & 249.7 \\
Avg-Acc & 79.8 & 84.7 & 86.5 & 77.6 & 86.2 & 80.6 & 55.4 & 74.8 & 77.2 & 74.5 & 61.2 & 70.9 & 75.3 \\
\end{tabular}



### Pretraining Prob Skip

In [22]:
res_ = collect_losses("../finetuning/pretraining_prob_skip")

# Label each dataset with the two "_"-separated tokens that precede the last
# three tokens of its directory name; keep the means except the final entry.
res = {" ".join(k.split("_")[-5:-3]): v[0][:-1] for k, v in res_.items()}

# Sort by label so the table columns come out in a stable order.
res = collections.OrderedDict(sorted(res.items()))

In [23]:
# Narrower columns and a flipped layout: metrics as rows, dataset labels as columns.
process_to_present(res, column_width="1cm", plot_transpose=False)

\begin{tabular}{lp{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}p{1cm}}
 & en all & ge 0 & ge 1 & ge 2 & ge all & it 0 & it 1 & it 2 & it all & sp 0 & sp 1 & sp 2 & sp all \\
MSE & 175.2 & 199.0 & 110.4 & 200.0 & 107.6 & 285.9 & 976.6 & 306.2 & 259.9 & 288.9 & 641.5 & 331.0 & 273.0 \\
Avg-Acc & 79.4 & 85.4 & 87.3 & 82.3 & 87.0 & 81.6 & 60.8 & 79.9 & 80.4 & 77.5 & 65.0 & 77.4 & 77.5 \\
\end{tabular}

