In [1]:
%load_ext autoreload
%autoreload 2

import os
import pandas as pd

def load_result_csv_files(directory):
    """Recursively load every file whose name ends with ``results.csv``.

    The tree below ``directory`` is traversed in sorted order so that the
    returned lists are deterministic across runs and file systems
    (``os.walk`` otherwise yields entries in directory-listing order).

    Args:
        directory: Root directory to search.

    Returns:
        A ``(result_dfs, file_paths)`` tuple of two parallel lists: the
        DataFrame parsed from each matching file, and that file's path
        relative to ``directory``. Both lists are empty when nothing
        matches, including when ``directory`` does not exist, because
        ``os.walk`` ignores listing errors by default.
    """
    result_dfs = []
    file_paths = []
    for root, dirs, files in os.walk(directory):
        # Sorting in place steers the order in which os.walk descends
        # (top-down traversal honours modifications to this list).
        dirs.sort()
        for filename in sorted(files):
            if filename.endswith("results.csv"):
                file_path = os.path.join(root, filename)
                result_dfs.append(pd.read_csv(file_path))
                file_paths.append(os.path.relpath(file_path, directory))
    return result_dfs, file_paths

# Root directory that is searched recursively for "*results.csv" files.
directory_path = "/workdir/optimal-summaries-public/_models_ablation"
result_dfs, file_paths = load_result_csv_files(directory_path)

# A missing or mistyped directory yields no files rather than an error, so
# fail here with a clear message instead of later inside pd.concat.
if not result_dfs:
    raise FileNotFoundError(
        f"No '*results.csv' files found under {directory_path!r}"
    )

len(result_dfs)

15

In [2]:
# Lift pandas' display truncation so frames render every row and column.
pd.options.display.max_rows = None
pd.options.display.max_columns = None

In [3]:
# Display labels for the raw dataset / model identifiers found in the CSVs.
dataset_labels = {
    "mimic": "MIMIC",
    "spoken_arabic_digits": "SpokenArabicDigits",
    "tiselac": "Tiselac",
}
model_labels = {
    "atomics_sum2atomics_False": "Atomics A",
    "atomics_sum2atomics_True": "Atomics B",
    "original": "Original",
    "shared_encode_time_dim_False": "Shared A",
    "shared_encode_time_dim_True": "Shared B",
}

# Stack all loaded frames, keep only the test-split rows (dropping the
# now-constant "Split" column), then relabel datasets and models.
combined_df = (
    pd.concat(result_dfs, ignore_index=True)
    .loc[lambda frame: frame["Split"] == "test"]
    .drop(columns=["Split"])
    .assign(
        Dataset=lambda frame: frame["Dataset"].replace(dataset_labels),
        Model=lambda frame: frame["Model"].replace(model_labels),
    )
)

### Ablation


In [4]:
# Model,Indicators,Summaries,Dataset,Seed,Split,AUC,ACC,F1,Cutoff,Lower threshold,Upper threshold
# Mean / std of every remaining column per (Dataset, Model, Indicators,
# Summaries) group, at full precision.
per_set_stats = (
    combined_df
    .drop(columns=["Seed", "Cutoff", "Lower threshold", "Upper threshold"])
    .groupby(["Dataset", "Model", "Indicators", "Summaries"])
    .agg(["mean", "std"])
)
# Rounded copy used only for the on-screen table below.
per_set = per_set_stats.round(3)

with open('ablation_per_dataset.tex', 'w') as f:
    # Export the unrounded stats: "{:.2%}" prints four decimal digits of the
    # fraction, so rounding to three first would pin the last printed digit
    # to 0 and discard real precision.
    tex = per_set_stats.to_latex(escape=True, float_format="{:.2%}".format)
    # '%' starts a comment in LaTeX and must be escaped. Un-escape first so
    # the step is idempotent: if escape=True has already produced "\%"
    # (this appears to depend on the pandas version - TODO confirm), a plain
    # replace would yield "\\%", i.e. a line break followed by a comment.
    tex = tex.replace(r'\%', '%').replace('%', r'\%')
    f.write(tex)

per_set

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,AUC,AUC,ACC,ACC,F1,F1
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,mean,std,mean,std,mean,std
Dataset,Model,Indicators,Summaries,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
MIMIC,Atomics A,False,False,0.91,0.005,0.832,0.008,0.84,0.006
MIMIC,Atomics A,False,True,0.868,0.016,0.852,0.01,0.855,0.006
MIMIC,Atomics A,True,False,0.91,0.009,0.843,0.002,0.849,0.004
MIMIC,Atomics A,True,True,0.877,0.004,0.856,0.011,0.858,0.007
MIMIC,Atomics B,False,False,0.908,0.002,0.828,0.004,0.837,0.004
MIMIC,Atomics B,False,True,0.909,0.006,0.831,0.01,0.84,0.008
MIMIC,Atomics B,True,False,0.907,0.014,0.848,0.014,0.853,0.009
MIMIC,Atomics B,True,True,0.91,0.006,0.834,0.013,0.843,0.011
MIMIC,Original,False,False,0.908,0.002,0.832,0.001,0.84,0.001
MIMIC,Original,False,True,0.875,0.007,0.853,0.011,0.857,0.008


In [5]:
# Mean / std of every remaining column per (Model, Indicators, Summaries)
# group, pooled over datasets and seeds, at full precision.
grouped_stats = (
    combined_df
    .drop(columns=["Seed", "Dataset", "Cutoff", "Lower threshold", "Upper threshold"])
    .groupby(["Model", "Indicators", "Summaries"])
    .agg(["mean", "std"])
)
# Rounded copy used only for the on-screen table below.
grouped = grouped_stats.round(3)

with open('ablation_grouped.tex', 'w') as f:
    # Export the unrounded stats: "{:.2%}" prints four decimal digits of the
    # fraction, so rounding to three first would pin the last printed digit
    # to 0 and discard real precision.
    tex = grouped_stats.to_latex(escape=True, float_format="{:.2%}".format)
    # '%' starts a comment in LaTeX and must be escaped. Un-escape first so
    # the step is idempotent: if escape=True has already produced "\%"
    # (this appears to depend on the pandas version - TODO confirm), a plain
    # replace would yield "\\%", i.e. a line break followed by a comment.
    tex = tex.replace(r'\%', '%').replace('%', r'\%')
    f.write(tex)

grouped

# best auc 0.966000 acc 0.871000 f1 0.857000

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,AUC,AUC,ACC,ACC,F1,F1
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,mean,std,mean,std
Model,Indicators,Summaries,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Atomics A,False,False,0.954,0.038,0.831,0.096,0.82,0.115
Atomics A,False,True,0.725,0.213,0.489,0.367,0.444,0.404
Atomics A,True,False,0.954,0.038,0.833,0.102,0.82,0.124
Atomics A,True,True,0.729,0.217,0.494,0.37,0.45,0.405
Atomics B,False,False,0.954,0.039,0.83,0.096,0.819,0.115
Atomics B,False,True,0.955,0.037,0.828,0.083,0.815,0.106
Atomics B,True,False,0.953,0.04,0.835,0.103,0.821,0.124
Atomics B,True,True,0.951,0.034,0.808,0.072,0.792,0.094
Original,False,False,0.951,0.036,0.813,0.08,0.803,0.1
Original,False,True,0.691,0.201,0.436,0.373,0.383,0.415
