In [25]:
import json
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [26]:

def get_results(model_name, context=""):
    if context:
        iter_files =glob.glob("results/{}/{}/results_checkpoint_{}_{}_iter_*.json".format(model_name, context, model_name, context))
    else:
        iter_files = glob.glob("results/{}/results_checkpoint_{}__iter_*.json".format(model_name, model_name))
    assert len(iter_files) == 10
    iters_results_df = []
    for iter_file in iter_files:
        with open(iter_file) as iter_f:
            iter_file_json = json.load(iter_f)
            context_keys = sorted(list(dict(iter_file_json).keys()))
            results_dict = dict()
            for key in context_keys:
                results_dict[key] = iter_file_json[key]["AUC"] 
            results_df = pd.DataFrame.from_dict(results_dict, orient='index', columns=[model_name])
            iters_results_df.append(results_df)
    return pd.concat(iters_results_df).groupby(level=0).mean()

In [27]:
model_names = [
               'AE_all', 
               'AE_FiLM_one_hot',
               'AE_FiLM_embed_32',
               'AE_FiLM_embed_64',
               'AE_FiLM_embed_128',
               'AE_FiLM_embed_256'
            ]
results_list = []
for model_name in model_names:
    results_list.append(get_results(model_name))
single_model_list = []
for i in [0, 1, 2]:
    single_model_list.append(get_results("AE", context=str(i)))

single_model_results = pd.concat(single_model_list)

results_df = pd.concat([single_model_results]+results_list, axis=1)
results_df


Unnamed: 0,AE,AE_all,AE_FiLM_one_hot,AE_FiLM_embed_32,AE_FiLM_embed_64,AE_FiLM_embed_128,AE_FiLM_embed_256
0,0.944576,0.608726,0.921225,0.859631,0.864815,0.869195,0.926043
1,0.944406,0.491309,0.916138,0.788101,0.802105,0.824765,0.829804
2,0.934978,0.518067,0.893411,0.804082,0.787087,0.815886,0.87414


In [28]:
rank_results_df = results_df[
          ["AE_all", 
           "AE_FiLM_one_hot", 
           "AE_FiLM_embed_32",
           "AE_FiLM_embed_64",
           "AE_FiLM_embed_128",
           "AE_FiLM_embed_256"
          ]
        ].rank(1, ascending=False, method='first')
mean_rank_row = rank_results_df.mean().to_frame().T
mean_rank_row = mean_rank_row.rename(index={0: 'average rank'})
results_and_ranks_df = pd.concat([results_df, mean_rank_row])
results_and_ranks_df

Unnamed: 0,AE,AE_all,AE_FiLM_one_hot,AE_FiLM_embed_32,AE_FiLM_embed_64,AE_FiLM_embed_128,AE_FiLM_embed_256
0,0.944576,0.608726,0.921225,0.859631,0.864815,0.869195,0.926043
1,0.944406,0.491309,0.916138,0.788101,0.802105,0.824765,0.829804
2,0.934978,0.518067,0.893411,0.804082,0.787087,0.815886,0.87414
average rank,,6.0,1.333333,4.666667,4.333333,3.0,1.666667


In [33]:
table_cols = [
           "AE",
           "AE_all", 
           "AE_FiLM_one_hot", 
           "AE_FiLM_embed_32",
           "AE_FiLM_embed_64",
           "AE_FiLM_embed_128",
           "AE_FiLM_embed_256"
          ]

results_table = results_and_ranks_df[table_cols]
formatted_col_map= {
           "AE": "AE separate (reference)",
           "AE_all": "AE no cond.", 
           "AE_FiLM_one_hot"  :"AE FiLM one hot"  , 
           "AE_FiLM_embed_32" :"AE FiLM 32 embed" ,
           "AE_FiLM_embed_64" :"AE FiLM 64 embed" ,
           "AE_FiLM_embed_128":"AE FiLM 128 embed",
           "AE_FiLM_embed_256":"AE FiLM 256 embed"
}
formatted_results_table = results_table.rename(columns=formatted_col_map, errors="raise")
formatted_results_table

Unnamed: 0,AE separate (reference),AE no cond.,AE FiLM one hot,AE FiLM 32 embed,AE FiLM 64 embed,AE FiLM 128 embed,AE FiLM 256 embed
0,0.944576,0.608726,0.921225,0.859631,0.864815,0.869195,0.926043
1,0.944406,0.491309,0.916138,0.788101,0.802105,0.824765,0.829804
2,0.934978,0.518067,0.893411,0.804082,0.787087,0.815886,0.87414
average rank,,6.0,1.333333,4.666667,4.333333,3.0,1.666667


In [35]:
print(formatted_results_table.to_latex(index=False, columns=formatted_results_table.columns, float_format='%.3f'))

\begin{tabular}{rrrrrrr}
\toprule
 AE separate (reference) &  AE no cond. &  AE FiLM one hot &  AE FiLM 32 embed &  AE FiLM 64 embed &  AE FiLM 128 embed &  AE FiLM 256 embed \\
\midrule
                   0.945 &        0.609 &            0.921 &             0.860 &             0.865 &              0.869 &              0.926 \\
                   0.944 &        0.491 &            0.916 &             0.788 &             0.802 &              0.825 &              0.830 \\
                   0.935 &        0.518 &            0.893 &             0.804 &             0.787 &              0.816 &              0.874 \\
                     NaN &        6.000 &            1.333 &             4.667 &             4.333 &              3.000 &              1.667 \\
\bottomrule
\end{tabular}



In [32]:
print(results_df.to_latex())

\begin{tabular}{lrrrrrrr}
\toprule
{} &        AE &    AE\_all &  AE\_FiLM\_one\_hot &  AE\_FiLM\_embed\_32 &  AE\_FiLM\_embed\_64 &  AE\_FiLM\_embed\_128 &  AE\_FiLM\_embed\_256 \\
\midrule
0 &  0.944576 &  0.608726 &         0.921225 &          0.859631 &          0.864815 &           0.869195 &           0.926043 \\
1 &  0.944406 &  0.491309 &         0.916138 &          0.788101 &          0.802105 &           0.824765 &           0.829804 \\
2 &  0.934978 &  0.518067 &         0.893411 &          0.804082 &          0.787087 &           0.815886 &           0.874140 \\
\bottomrule
\end{tabular}

