In [72]:
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import glob
import pandas as pd
import os
import copy
import matplotlib as mpl
from numpy import concatenate as cat
from collections import defaultdict
from sklearn import metrics
mpl.rcParams['figure.dpi'] = 250
pd.set_option('precision', 3)
DATASET = 'iai'

# if 'tbi' in DATASET or 'csi' in DATASET:
seeds = sorted([
    int(path.split('_')[-1]) for path in glob.glob(f'results/{DATASET}/seed_*')])

if DATASET == 'sim':
    sens_levels = []
    METRICS = ['auc', 'aps', 'acc', 'f1']
else:
    sens_levels = [0.92, 0.94, 0.96, 0.98]
    METRICS = ['spec_0.92', 'spec_0.94', 'spec_0.96', 'spec_0.98', 'auc', 'aps', 'acc', 'f1']

In [73]:
# def plott(model_name, sens, spec, ppv, ax):
#     if 'pecarn' in model_name.lower():
#         # ax.plot(sens[0], spec[0], '.-', label=model_name)
#         pass
#     elif 'mix' not in model_name.lower() and 'pcart' not in model_name.lower():
#         ax.plot(sens, spec, '.-', label=model_name)

In [74]:
# def multiplot(paths, ax, suffix=""):
#     for model_file in paths:
#         basename = os.path.basename(model_file).split('.')[0]
#         dct = pkl.load(open(model_file, 'rb'))
#         plott(basename, dct['sens_tune'], dct['spec_tune'], dct['ppv_tune'], ax)
#     ax.legend(frameon=False, loc='best')
#     ax.set_xlim(0.9, 1)
    # ax.set_ylim(0, 0.2)

In [75]:
# def multiseedplot(group='all'):
#     fig, axes = plt.subplots(3, 3, figsize=(15, 15))
#     for k, seed in enumerate(seeds[:9]):
#         paths = sorted(glob.glob(f'results/{DATASET}/seed_{seed}/{group}/*.pkl'))
#         multiplot(paths, axes[k // 3, k % 3], group)
#     plt.xlabel('sens')
#     plt.ylabel('spec')
#     plt.tight_layout()

In [76]:
# multiseedplot()

In [77]:
# multiseedplot(group='young')

In [78]:
# multiseedplot(group='old')

In [89]:
def pkl_to_table(group):
    table = defaultdict(lambda:[])
    metrics_ = copy.copy(METRICS)
    if DATASET == 'iai' and group == 'young':
        metrics_.remove('auc')

    for i in seeds:

        seed_paths = sorted(glob.glob(f'results/{DATASET}/seed_{i}/{group}/*.pkl'))
        table_index = [os.path.basename(f).split('.')[0] for f in seed_paths]
        for model_file in seed_paths:
            dct = pkl.load(open(model_file, 'rb'))
            specs = np.array(dct['spec_tune'])
            senses = np.array(dct['sens_tune'])
            precisions = np.array(dct['ppv_tune'])
            for sens in sens_levels:
                table[f'spec_{sens}_seed_{i}'].append(
                    np.max(specs[senses > sens]) if specs[senses > sens].shape[0] > 0 else 0.0)
            if DATASET != 'iai' or group != 'young':
                table[f'auc_seed_{i}'].append(metrics.auc(1 - specs, senses))
            table[f'aps_seed_{i}'].append(-np.sum(np.diff(senses) * np.array(precisions)[:-1]))
            table[f'acc_seed_{i}'].append(dct['acc'])
            table[f'f1_seed_{i}'].append(dct['f1'])
        
    res_table = pd.DataFrame(table, index=table_index)
    if sens_levels:
        for i in seeds:
            res_table[f'high_spec_avg_seed_{i}'] = res_table.loc[
                :, [f'spec_0.92_seed_{i}', f'spec_0.94_seed_{i}', f'spec_0.96_seed_{i}', f'spec_0.98_seed_{i}']].mean(axis=1)

    for metric in metrics_:
        res_table_metric = res_table.loc[:, res_table.columns.str.contains(metric)]
        res_table[metric] = res_table_metric.mean(axis=1)

    # res_table['high_spec_avg'] = res_table[
    #     [f'spec_0.94', f'spec_0.96', f'spec_0.98']].mean(axis=1)

    for metric in metrics_:
        res_table_metric = res_table.loc[:, res_table.columns.str.contains(metric)]
        res_table[f'{metric}_std_err'] = res_table_metric.std(axis=1) / np.sqrt(10)
        # res_table[f'{metric}_std'] = res_table_metric.std(axis=1)
        
    return res_table

In [95]:
# pkl_to_table('all').to_csv(f'results/{DATASET}/all_average.csv')
# pkl_to_table('young').to_csv(f'results/{DATASET}/young_average.csv')
# pkl_to_table('old').to_csv(f'results/{DATASET}/old_average.csv')

In [91]:
pkl_to_table('all').iloc[:, -17:].style.background_gradient()#highlight_max(color='blue')

Unnamed: 0,high_spec_avg_seed_16,spec_0.92,spec_0.94,spec_0.96,spec_0.98,auc,aps,acc,f1,spec_0.92_std_err,spec_0.94_std_err,spec_0.96_std_err,spec_0.98_std_err,auc_std_err,aps_std_err,acc_std_err,f1_std_err
cart_all,0.112,0.118,0.027,0.016,0.014,0.688,0.081,0.833,0.134,0.05,0.01,0.005,0.005,0.057,0.002,0.006,0.004
cart_combine,0.023,0.11,0.093,0.028,0.0,0.688,0.08,0.833,0.13,0.016,0.018,0.014,0.0,0.012,0.003,0.005,0.005
cart_mix,0.03,0.104,0.057,0.057,0.017,0.734,0.086,0.827,0.129,0.047,0.013,0.013,0.01,0.039,0.003,0.005,0.005
figs_all,0.056,0.321,0.137,0.014,0.0,0.541,0.06,0.693,0.094,0.055,0.06,0.008,0.0,0.044,0.005,0.041,0.01
figs_combine,0.028,0.188,0.092,0.025,0.009,0.653,0.05,0.691,0.08,0.044,0.022,0.017,0.008,0.018,0.005,0.011,0.003
figs_mix,0.248,0.307,0.133,0.088,0.007,0.653,0.057,0.678,0.08,0.064,0.051,0.037,0.003,0.047,0.007,0.016,0.003
pcart_combine,0.065,0.117,0.101,0.038,0.007,0.732,0.084,0.815,0.125,0.013,0.016,0.013,0.004,0.015,0.004,0.015,0.007
pcart_mix,0.054,0.141,0.119,0.057,0.006,0.767,0.086,0.832,0.135,0.049,0.045,0.015,0.003,0.017,0.002,0.006,0.005
pecarn_combine,0.314,0.425,0.425,0.299,0.084,0.28,0.028,0.435,0.055,0.003,0.003,0.062,0.053,0.002,0.001,0.003,0.002
pfigs_combine,0.311,0.297,0.188,0.117,0.03,0.671,0.064,0.713,0.091,0.069,0.066,0.051,0.013,0.035,0.005,0.022,0.005


In [92]:
s = '7'
pkl_to_table('all').iloc[[0, 1, 11, 12, 3, 4, 9], :].loc[:, [f'spec_0.92_seed_{s}', f'spec_0.94_seed_{s}', f'spec_0.96_seed_{s}', f'spec_0.98_seed_{s}']
    ].style.background_gradient()#highlight_max(color='blue')

Unnamed: 0,spec_0.92_seed_7,spec_0.94_seed_7,spec_0.96_seed_7,spec_0.98_seed_7
cart_all,0.0,0.0,0.0,0.0
cart_combine,0.096,0.096,0.0,0.0
tao_all,0.0,0.0,0.0,0.0
tao_combine,0.096,0.096,0.0,0.0
figs_all,0.097,0.0,0.0,0.0
figs_combine,0.096,0.096,0.0,0.0
pfigs_combine,0.0,0.0,0.0,0.0


In [93]:
pkl_to_table('young').iloc[:, -15:].drop('pecarn_young').style.background_gradient()#highlight_max(color='blue')

Unnamed: 0,high_spec_avg_seed_16,spec_0.92,spec_0.94,spec_0.96,spec_0.98,aps,acc,f1,spec_0.92_std_err,spec_0.94_std_err,spec_0.96_std_err,spec_0.98_std_err,aps_std_err,acc_std_err,f1_std_err
cart_all,0.183,0.296,0.296,0.296,0.296,0.044,0.821,0.059,0.097,0.097,0.097,0.097,0.007,0.011,0.01
cart_young,0.0,0.0,0.0,0.0,0.0,0.007,0.886,0.025,0.0,0.0,0.0,0.0,0.003,0.007,0.01
figs_all,0.21,0.394,0.394,0.394,0.394,0.045,0.574,0.04,0.107,0.107,0.107,0.107,0.013,0.1,0.007
figs_young,0.0,0.09,0.09,0.09,0.09,0.008,0.85,0.02,0.074,0.074,0.074,0.074,0.004,0.022,0.011
pcart_young,0.0,0.044,0.044,0.044,0.044,0.019,0.847,0.017,0.019,0.019,0.019,0.019,0.006,0.017,0.008
pfigs_young,0.603,0.243,0.243,0.243,0.243,0.027,0.868,0.044,0.086,0.086,0.086,0.086,0.011,0.03,0.017
tao_all,0.215,0.333,0.333,0.333,0.333,0.047,0.846,0.067,0.092,0.092,0.092,0.092,0.007,0.008,0.011
tao_young,0.0,0.0,0.0,0.0,0.0,0.007,0.886,0.025,0.0,0.0,0.0,0.0,0.003,0.007,0.01


In [94]:
pkl_to_table('old').iloc[:, -15:].drop('pecarn_old').style.background_gradient()#highlight_max(color='blue')

Unnamed: 0,spec_0.94,spec_0.96,spec_0.98,auc,aps,acc,f1,spec_0.92_std_err,spec_0.94_std_err,spec_0.96_std_err,spec_0.98_std_err,auc_std_err,aps_std_err,acc_std_err,f1_std_err
cart_all,0.028,0.016,0.013,0.691,0.086,0.834,0.142,0.039,0.011,0.005,0.005,0.057,0.002,0.006,0.005
cart_old,0.042,0.042,0.013,0.633,0.089,0.828,0.137,0.052,0.012,0.012,0.009,0.067,0.004,0.006,0.006
figs_all,0.146,0.014,0.0,0.533,0.063,0.706,0.1,0.068,0.064,0.008,0.0,0.047,0.005,0.036,0.01
figs_old,0.19,0.092,0.005,0.617,0.055,0.674,0.083,0.063,0.062,0.039,0.003,0.055,0.007,0.012,0.003
pcart_old,0.105,0.039,0.007,0.629,0.092,0.811,0.133,0.061,0.051,0.014,0.004,0.068,0.004,0.017,0.007
pfigs_old,0.221,0.138,0.024,0.696,0.065,0.696,0.094,0.069,0.07,0.057,0.01,0.04,0.006,0.025,0.006
tao_all,0.003,0.0,0.0,0.353,0.083,0.842,0.146,0.002,0.002,0.0,0.0,0.044,0.002,0.007,0.006
tao_old,0.055,0.042,0.013,0.672,0.088,0.828,0.135,0.051,0.014,0.012,0.009,0.056,0.002,0.006,0.005
