In [1]:
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import glob
import pandas as pd
import os
import matplotlib as mpl
from numpy import concatenate as cat
from collections import defaultdict
from sklearn import metrics
from common import PECARNModel, TransferTree
mpl.rcParams['figure.dpi'] = 250
# pd.set_option('precision', 2)

METRICS = ['spec_0.92', 'spec_0.94', 'spec_0.96', 'spec_0.98', 'auc', 'aps', 'f1', 'acc',]

In [2]:
def make_res_table(name, errs=False):
    result_df = pd.read_csv(f'results/{name}_average.csv').set_index('Unnamed: 0')

    for metric in METRICS:
        if metric not in result_df.columns:
            continue

        if metric in {'auc', 'aps'}:
            result_df[metric] = result_df[metric].round(3).astype(str).map(lambda x: x[1:])
        else:
            result_df[metric] = (result_df[metric]*100).round(1)
        
        if errs:
            if metric in {'auc', 'f1'}:
                std_err_string = result_df[f'{metric}_std_err'].round(2).astype(str).map(lambda x: f' ({x[1:]})')
            else:
                std_err_string = (result_df[f'{metric}_std_err'] * 100).round(1).map(lambda x: f' ({x})')
            
            result_df[f'{metric}_final'] = result_df[metric].astype(str) + std_err_string
        else:
            result_df[f'{metric}_final'] = result_df[metric].astype(str)

    result_df = result_df.rename({
        'spec_0.92_final': '0.92',
        'spec_0.94_final': '0.94',
        'spec_0.96_final': '0.96',
        'spec_0.98_final': '0.98',
        'auc_final': 'AUC of ROC',
        'aps_final': 'Avg precision score',
        'acc_final': 'Accuracy',
        'f1_final': 'F1',
    }, axis=1)
    result_df = result_df.rename(lambda x: x.replace('pfigs_combine', 'pfigs').replace('pcart_combine', 'pcart'), axis=0)
    result_df.index = result_df.index.rename('Model')

    return result_df

In [3]:
# results = {}
# for dataset in ['csi', 'tbi', 'iai']:
#     for group in ['all']:
#         results[f'{dataset}_{group}'] = pd.read_csv(
#             f'results/{dataset}/{group}_average.csv').set_index('Unnamed: 0')
# results['sim'] = pd.read_csv(f'results/sim/all_average.csv').set_index('Unnamed: 0')
# for key in results:


# csi

In [4]:
order = [11, 12, 0, 1, 6, 3, 4, 9]

#### all

In [5]:
make_res_table('csi/all', True).iloc[order, -8:-2]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,Avg precision score
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
tao_all,41.5 (0.9),21.2 (6.6),0.2 (0.2),0.2 (0.2),.422 (.04),.351 (1.4)
tao_combine,32.5 (4.9),7.0 (1.6),5.4 (0.7),2.5 (1.0),.702 (.01),.359 (1.6)
cart_all,38.6 (3.6),13.7 (5.7),1.5 (0.6),1.1 (0.4),.617 (.06),.36 (1.5)
cart_combine,32.1 (5.1),7.8 (1.5),5.4 (0.7),2.5 (1.0),.707 (.0),.358 (1.5)
pcart,38.5 (3.4),15.2 (4.8),4.9 (1.0),3.9 (1.1),.751 (.01),.369 (1.5)
figs_all,39.1 (3.0),33.8 (2.4),24.1 (3.2),16.7 (3.9),.664 (.03),.372 (1.5)
figs_combine,38.7 (1.6),33.1 (2.0),20.1 (2.6),3.9 (2.2),.643 (.02),.351 (1.8)
pfigs,42.2 (1.3),36.2 (2.3),28.4 (3.8),15.7 (3.9),.7 (.01),.375 (1.3)


#### young

In [147]:
order_group = [0, 1, 7, 8, 2, 3, 6]

In [148]:
make_res_table('csi/young', True).iloc[order_group, -8:-2]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,F1
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
cart_all,45.8 (12.0),45.8 (12.0),45.8 (12.0),45.8 (12.0),.231 (.04),53.4 (.07)
cart_young,0.0 (0.0),0.0 (0.0),0.0 (0.0),0.0 (0.0),.116 (.02),31.7 (.04)
tao_all,45.8 (12.0),45.8 (12.0),45.8 (12.0),45.8 (12.0),.201 (.02),53.8 (.06)
tao_young,0.0 (0.0),0.0 (0.0),0.0 (0.0),0.0 (0.0),.116 (.02),31.7 (.04)
figs_all,56.1 (10.6),56.1 (10.6),56.1 (10.6),56.1 (10.6),.544 (.05),47.5 (.06)
figs_young,6.9 (6.5),6.9 (6.5),6.9 (6.5),6.9 (6.5),.236 (.06),32.5 (.03)
pfigs_young,65.9 (7.9),65.9 (7.9),65.9 (7.9),65.9 (7.9),.608 (.05),49.5 (.07)


#### old

In [149]:
make_res_table('csi/old', True).iloc[order_group, -8:-2]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,F1
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
cart_all,37.1 (3.4),16.8 (5.6),1.6 (0.6),1.2 (0.5),.615 (.05),45.5 (.01)
cart_old,33.0 (4.5),17.3 (5.7),2.8 (1.5),1.4 (0.5),.721 (.03),44.8 (.01)
tao_all,39.5 (0.9),20.1 (6.3),0.2 (0.2),0.2 (0.2),.433 (.04),44.2 (.01)
tao_old,33.6 (4.2),17.9 (5.6),3.5 (1.6),1.4 (0.5),.716 (.03),45.0 (.01)
figs_all,37.8 (2.8),33.6 (2.4),25.5 (3.0),14.1 (4.3),.657 (.03),42.7 (.01)
figs_old,39.5 (2.1),33.8 (1.2),22.0 (3.2),11.6 (3.4),.653 (.02),41.8 (.01)
pfigs_old,40.7 (1.3),33.5 (2.5),23.8 (4.2),13.5 (4.0),.675 (.02),42.2 (.01)


# tbi

### all

In [6]:
make_res_table('tbi/all', True).iloc[order, -8:-2]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,Avg precision score
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
tao_all,6.2 (5.9),6.2 (5.9),0.4 (0.4),0.4 (0.4),.294 (.05),.039 (0.2)
tao_combine,26.7 (6.4),13.9 (5.4),10.4 (5.5),2.4 (1.5),.748 (.02),.049 (0.4)
cart_all,20.9 (8.8),14.8 (7.6),7.8 (5.8),2.1 (0.6),.702 (.06),.057 (0.4)
cart_combine,26.6 (6.4),13.8 (5.4),10.3 (5.5),2.4 (1.5),.753 (.02),.049 (0.4)
pcart,15.5 (5.5),13.5 (5.7),6.4 (2.2),3.0 (1.5),.758 (.01),.044 (0.2)
figs_all,23.8 (9.0),18.2 (8.5),12.1 (7.3),0.4 (0.3),.38 (.07),.04 (0.3)
figs_combine,39.9 (7.9),19.7 (6.8),17.5 (7.0),2.6 (1.6),.619 (.05),.045 (0.3)
pfigs,41.9 (6.6),23.0 (7.8),14.7 (6.5),6.4 (2.8),.696 (.04),.041 (0.2)


### young

In [151]:
make_res_table('tbi/young', True).iloc[order_group, -8:-2]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,F1
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
cart_all,19.0 (8.7),19.0 (8.7),7.1 (5.5),1.2 (0.6),.64 (.08),6.9 (.01)
cart_young,20.6 (9.2),14.3 (8.4),8.3 (6.9),8.3 (6.9),.496 (.08),7.0 (.01)
tao_all,7.7 (6.5),7.7 (6.5),0.0 (0.0),0.0 (0.0),.267 (.05),5.4 (.0)
tao_young,20.6 (9.2),14.3 (8.4),8.3 (6.9),8.3 (6.9),.494 (.09),7.0 (.01)
figs_all,36.3 (9.4),30.3 (9.6),5.9 (5.5),0.1 (0.1),.351 (.06),5.0 (.0)
figs_young,31.1 (8.5),25.1 (8.4),13.0 (7.8),6.7 (5.8),.48 (.08),6.4 (.01)
pfigs_young,17.2 (8.3),17.2 (8.3),13.7 (8.6),7.5 (7.0),.579 (.08),5.6 (.0)


### old

In [152]:
make_res_table('tbi/old', True).iloc[order_group, -8:-2]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,F1
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
cart_all,22.0 (8.8),15.4 (7.7),14.8 (7.3),2.1 (0.7),.701 (.06),5.4 (.0)
cart_old,19.7 (8.6),13.4 (7.6),7.3 (5.4),0.8 (0.4),.627 (.08),5.4 (.0)
tao_all,12.2 (7.7),6.1 (5.8),6.1 (5.8),0.3 (0.3),.304 (.06),5.2 (.0)
tao_old,19.8 (8.6),13.4 (7.5),7.4 (5.4),0.5 (0.3),.622 (.08),5.5 (.0)
figs_all,24.5 (9.2),18.1 (8.4),18.1 (8.4),0.5 (0.3),.389 (.07),4.8 (.0)
figs_old,25.3 (8.9),20.3 (8.6),18.7 (8.4),6.3 (5.1),.557 (.06),4.7 (.0)
pfigs_old,44.1 (9.0),31.5 (9.8),19.5 (8.8),5.3 (3.4),.585 (.08),4.5 (.0)


# simulation

In [153]:
make_res_table('sim/all', True).iloc[[6, 7, 0, 1, 2, 3, 4], -4:]#.to_latex()

Unnamed: 0_level_0,AUC of ROC,F1,Avg precision score,Accuracy
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
tao_all,.376 (.07),58.0 (.04),.498 (4.4),59.0 (2.4)
tao_combine,.475 (.04),60.4 (.03),.573 (3.4),58.2 (2.0)
cart_all,.37 (.07),54.7 (.03),.495 (4.2),56.5 (1.9)
cart_combine,.475 (.04),60.4 (.03),.573 (3.4),58.2 (2.0)
figs_all,.47 (.04),55.5 (.03),.539 (3.7),58.5 (1.7)
figs_combine,.475 (.04),60.4 (.03),.573 (3.4),58.2 (2.0)
pcart,.55 (.03),63.9 (.04),.644 (3.4),65.8 (2.8)


# iai

In [7]:
make_res_table('iai/all', True).iloc[order, -8:-2]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,Avg precision score
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
tao_all,0.2 (0.2),0.2 (0.2),0.0 (0.0),0.0 (0.0),.372 (.04),.079 (0.2)
tao_combine,12.1 (1.7),8.4 (2.0),2.0 (1.3),0.0 (0.0),.675 (.01),.078 (0.1)
cart_all,11.8 (5.0),2.7 (1.0),1.6 (0.5),1.4 (0.5),.688 (.06),.081 (0.2)
cart_combine,11.0 (1.6),9.3 (1.8),2.8 (1.4),0.0 (0.0),.688 (.01),.08 (0.3)
pcart,11.7 (1.3),10.1 (1.6),3.8 (1.3),0.7 (0.4),.732 (.02),.084 (0.4)
figs_all,32.1 (5.5),13.7 (6.0),1.4 (0.8),0.0 (0.0),.541 (.04),.06 (0.5)
figs_combine,18.8 (4.4),9.2 (2.2),2.5 (1.7),0.9 (0.8),.653 (.02),.05 (0.5)
pfigs,29.7 (6.9),18.8 (6.6),11.7 (5.1),3.0 (1.3),.671 (.03),.064 (0.5)


In [155]:
make_res_table('iai/young', True).iloc[order_group, -7:]#.to_latex()

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,F1,Avg precision score,Accuracy
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
cart_all,29.6 (9.7),29.6 (9.7),29.6 (9.7),29.6 (9.7),5.9 (.01),.044 (0.7),82.1 (1.1)
cart_young,0.0 (0.0),0.0 (0.0),0.0 (0.0),0.0 (0.0),2.5 (.01),.007 (0.3),88.6 (0.7)
tao_all,33.3 (9.2),33.3 (9.2),33.3 (9.2),33.3 (9.2),6.7 (.01),.047 (0.7),84.6 (0.8)
tao_young,0.0 (0.0),0.0 (0.0),0.0 (0.0),0.0 (0.0),2.5 (.01),.007 (0.3),88.6 (0.7)
figs_all,39.4 (10.7),39.4 (10.7),39.4 (10.7),39.4 (10.7),4.0 (.01),.045 (1.3),57.4 (10.0)
figs_young,9.0 (7.4),9.0 (7.4),9.0 (7.4),9.0 (7.4),2.0 (.01),.008 (0.4),85.0 (2.2)
pfigs_young,24.3 (8.6),24.3 (8.6),24.3 (8.6),24.3 (8.6),4.4 (.02),.027 (1.1),86.8 (3.0)


In [156]:
make_res_table('iai/old', True).iloc[order_group, -8:]

Unnamed: 0_level_0,0.92,0.94,0.96,0.98,AUC of ROC,F1,Avg precision score,Accuracy
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
cart_all,7.6 (3.9),2.8 (1.1),1.6 (0.5),1.3 (0.5),.691 (.06),14.2 (.0),.086 (0.2),83.4 (0.6)
cart_old,9.0 (5.2),4.2 (1.2),4.2 (1.2),1.3 (0.9),.633 (.07),13.7 (.01),.089 (0.4),82.8 (0.6)
tao_all,0.3 (0.2),0.3 (0.2),0.0 (0.0),0.0 (0.0),.353 (.04),14.6 (.01),.083 (0.2),84.2 (0.7)
tao_old,10.2 (5.1),5.5 (1.4),4.2 (1.2),1.3 (0.9),.672 (.06),13.5 (.0),.088 (0.2),82.8 (0.6)
figs_all,24.9 (6.8),14.6 (6.4),1.4 (0.8),0.0 (0.0),.533 (.05),10.0 (.01),.063 (0.5),70.6 (3.6)
figs_old,28.0 (6.3),19.0 (6.2),9.2 (3.9),0.5 (0.3),.617 (.05),8.3 (.0),.055 (0.7),67.4 (1.2)
pfigs_old,27.9 (6.9),22.1 (7.0),13.8 (5.7),2.4 (1.0),.696 (.04),9.4 (.01),.065 (0.6),69.6 (2.5)


# fairness?

# CV table

In [232]:
def get_best_args(val_df_group, model_name, val_metrics):
    return val_df_group.filter(regex=model_name, axis=0).round(2).sort_values(
        by=val_metrics, kind='mergesort', ascending=False)['args'].iloc[0]

In [233]:
def make_average_val_table(dataset, name):
    seeds = sorted([int(path.split('_')[-1]) for path in glob.glob(f'results/{dataset}/seed_*')])
    tbi_val_df_avg = pd.read_csv(f'results/{dataset}/seed_0/{name}.csv').set_index('Unnamed: 0').iloc[:, :-1]
    tbi_val_df_std = []
    for seed in seeds:
        tbi_val_df_curr = pd.read_csv(f'results/{dataset}/seed_{seed}/{name}.csv').set_index('Unnamed: 0').iloc[:, :-1]
        tbi_val_df_avg += tbi_val_df_curr
        tbi_val_df_std.append(tbi_val_df_curr.values)
    tbi_val_df_std = pd.DataFrame(np.std(tbi_val_df_std, axis=0), index=tbi_val_df_avg.index, columns=tbi_val_df_avg.columns)
    return tbi_val_df_avg / 10, tbi_val_df_std / np.sqrt(10)

In [238]:
def get_hp_table(dataset, match, metric='spec94'):
    table, std = make_average_val_table(dataset, 'val')
    tbi_val_df_avg_old, tbi_val_df_std_old = table[table.index.str.contains(match)], std[std.index.str.contains(match)]
    hp_table = pd.DataFrame(
        {'8': tbi_val_df_avg_old.iloc[::3][metric].tolist(),
        '12': tbi_val_df_avg_old.iloc[1::3][metric].tolist(),
        '16': tbi_val_df_avg_old.iloc[2::3][metric].tolist()}, 
        index = tbi_val_df_avg_old.iloc[::3].index)
    
    hp_table_std = pd.DataFrame(
        {'8': tbi_val_df_std_old.iloc[::3][metric].tolist(),
        '12': tbi_val_df_std_old.iloc[1::3][metric].tolist(),
        '16': tbi_val_df_std_old.iloc[2::3][metric].tolist()}, 
        index = tbi_val_df_std_old.iloc[::3].index)
    
    for column in ['8', '12', '16']:
        hp_table[column] = ((hp_table[column] * 100).round(1).astype(str) +
            (hp_table_std[column] * 100).round(1).map(lambda x: f' ({x})'))
    
    return hp_table.rename(
        {'LLPCART_<2_8': 'G-CART w/ LR (C = 2.8)',
        'LSPCART_<2_8': 'G-CART w/ LR (C = 0.1)',
        'GBLPCART_<2_8': 'G-CART w/ GB (N = 100)',
        'GBSPCART_<2_8': 'G-CART w/ GB (N = 50)',
        'LLPFIGS_<2_8': '\methodabbrv~w/ LR (C = 2.8)',
        'LSPFIGS_<2_8': '\methodabbrv~w/ LR (C = 0.1)',
        'GBLPFIGS_<2_8': '\methodabbrv~w/ GB (N = 100)',
        'GBSPFIGS_<2_8': '\methodabbrv~w/ GB (N = 50)',
        'TAO_<2_8_1': 'TAO (1 iter)',
        'TAO_<2_12_5': 'TAO (5 iter)'}
    )

In [239]:
# get_hp_table('tbi', 'all', 'aps').style.highlight_max(axis=1)
# get_hp_table('tbi', 'all', 'spec90').round(2)

In [240]:
table = pd.concat((
    get_hp_table('tbi', '<2'),
    get_hp_table('tbi', '>2').set_index(get_hp_table('tbi', '<2').index)), axis=1)
table
print(table.to_latex())

\begin{tabular}{lllllll}
\toprule
{} &           8 &          12 &          16 &           8 &          12 &          16 \\
Unnamed: 0                   &             &             &             &             &             &             \\
\midrule
CART\_<2\_8                    &  15.1 (6.7) &  14.4 (6.1) &   0.0 (0.0) &  14.0 (7.8) &   8.9 (5.9) &   3.1 (0.9) \\
G-CART w/ LR (C = 2.8)       &   7.9 (6.7) &   3.1 (2.1) &   3.5 (1.7) &  19.0 (8.8) &  21.8 (8.4) &   2.1 (0.6) \\
G-CART w/ LR (C = 0.1)       &  20.4 (8.6) &   8.3 (6.6) &  10.1 (6.7) &  12.7 (7.6) &  14.9 (7.1) &   3.6 (0.9) \\
G-CART w/ GB (N = 100)       &  19.8 (8.3) &   7.2 (6.3) &   7.6 (6.1) &  13.3 (8.0) &  21.4 (8.5) &   9.0 (5.6) \\
G-CART w/ GB (N = 50)        &  26.8 (9.7) &   8.1 (6.3) &   8.4 (6.1) &  13.3 (8.0) &  21.4 (8.5) &   9.7 (5.6) \\
FIGS\_<2\_8                    &  13.7 (5.9) &   0.0 (0.0) &   0.0 (0.0) &  23.1 (8.8) &  13.0 (7.4) &   7.8 (5.6) \\
\textbackslash methodabbrv\textasciitilde w/ LR (C 

In [241]:
table, table_std = make_average_val_table('tbi', 'pmodel_val')
tbi_pm_df_avg, tbi_pm_df_std = table[table.index.str.startswith('P')], table_std[table_std.index.str.startswith('P')]

hp_table = pd.DataFrame(
        {'LR (C = 2.8)': tbi_pm_df_avg.iloc[::4]['spec94'].tolist(),
        'LR (C = 0.1)': tbi_pm_df_avg.iloc[1::4]['spec94'].tolist(),
        'GB (100 trees)': tbi_pm_df_avg.iloc[2::4]['spec94'].tolist(),
        'GB (50 trees)': tbi_pm_df_avg.iloc[3::4]['spec94'].tolist()}, 
        index = tbi_pm_df_avg.iloc[::4].index)

hp_table_std = pd.DataFrame(
        {'LR (C = 2.8)': tbi_pm_df_std.iloc[::4]['spec94'].tolist(),
        'LR (C = 0.1)': tbi_pm_df_std.iloc[1::4]['spec94'].tolist(),
        'GB (100 trees)': tbi_pm_df_std.iloc[2::4]['spec94'].tolist(),
        'GB (50 trees)': tbi_pm_df_std.iloc[3::4]['spec94'].tolist()}, 
        index = tbi_pm_df_std.iloc[::4].index)

for column in ['LR (C = 2.8)', 'LR (C = 0.1)', 'GB (100 trees)', 'GB (50 trees)']:
        hp_table[column] = ((hp_table[column] * 100).round(1).astype(str) +
                (hp_table_std[column] * 100).round(1).map(lambda x: f' ({x})'))

In [242]:
# hp_table.style.highlight_max(axis=1)
print(hp_table.to_latex())

\begin{tabular}{lllll}
\toprule
{} & LR (C = 2.8) & LR (C = 0.1) & GB (100 trees) & GB (50 trees) \\
Unnamed: 0   &              &              &                &               \\
\midrule
PFIGS\_LL\_all &   51.3 (5.8) &   54.5 (6.2) &     57.4 (5.6) &    44.6 (7.4) \\
PCART\_LL\_all &   27.8 (6.0) &   21.5 (5.9) &     19.0 (5.7) &    27.1 (6.5) \\
\bottomrule
\end{tabular}

