In [328]:
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import glob
import pandas as pd
import os
import matplotlib as mpl
from numpy import concatenate as cat
from collections import defaultdict
from sklearn import metrics
mpl.rcParams['figure.dpi'] = 250
pd.set_option('precision', 3)
DATASET = 'tbi'

if DATASET == 'tbi':
    seeds = range(7)
elif DATASET == 'csi' or DATASET == 'csi_gbp':
    seeds = range(7)
elif DATASET == 'iai':
    seeds = list(range(7))
    seeds.remove(3)

if DATASET == 'sim':
    seeds = range(10)
    sens_levels = []
    metrics_ = ['auc', 'aps', 'acc', 'f1']
else:
    sens_levels = [0.92, 0.94, 0.96, 0.98]
    metrics_ = ['high_spec_avg', 'spec_0.92', 'spec_0.94', 'spec_0.96', 'spec_0.98', 'auc', 'aps', 'acc']

In [329]:
def plott(model_name, sens, spec, ppv, ax):
    if 'pecarn' in model_name.lower():
        ax.plot(sens[0], spec[0], '.-', label=model_name)
    else:
        ax.plot(sens, spec, '.-', label=model_name)

In [330]:
def multiplot(paths, ax, suffix=""):
    for model_file in paths:
        basename = os.path.basename(model_file).split('.')[0]
        dct = pkl.load(open(model_file, 'rb'))
        plott(basename, dct['sens_tune'], dct['spec_tune'], dct['ppv_tune'], ax)
    ax.legend(frameon=False, loc='best')
    ax.set_xlim(0.5, 1.05)
    # ax.set_ylim(0, 0.2)

In [331]:
def multiseedplot(group='all'):
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    for k, i in enumerate(seeds[:6]):
        paths = sorted(glob.glob(f'results/{DATASET}/seed_{i}/{group}/*.pkl'))
        multiplot(paths, axes[k // 3, k % 3], group)
    plt.xlabel('sens')
    plt.ylabel('spec')
    plt.tight_layout()

In [332]:
# multiseedplot()

In [333]:
# multiseedplot(group='young')

In [334]:
# multiseedplot(group='old')

In [337]:
def pkl_to_table(group):
    table = defaultdict(lambda:[])
    for i in seeds:
        seed_paths = sorted(glob.glob(f'results/{DATASET}/seed_{i}/{group}/*.pkl'))
        table_index = [os.path.basename(f).split('.')[0] for f in seed_paths]
        for model_file in seed_paths:
            dct = pkl.load(open(model_file, 'rb'))
            specs = np.array(dct['spec_tune'])
            senses = np.array(dct['sens_tune'])
            precisions = np.array(dct['ppv_tune'])
            for sens in sens_levels:
                table[f'spec_{sens}_seed_{i}'].append(
                    np.max(specs[senses > sens]) if specs[senses > sens].shape[0] > 0 else 0.0)
            table[f'auc_seed_{i}'].append(metrics.auc(1 - specs, senses))
            # table[f'auprc_seed_{i}'].append(metrics.auc(senses, np.nan_to_num(precisions, nan=1)))
            table[f'aps_seed_{i}'].append(-np.sum(np.diff(senses) * np.array(precisions)[:-1]))
            table[f'acc_seed_{i}'].append(dct['acc'])
            # table[f'f1_seed_{i}'].append(dct['f1'])
        
    res_table = pd.DataFrame(table, index=table_index)
    if sens_levels:
        for i in seeds:
            res_table[f'high_spec_avg_seed_{i}'] = res_table.loc[
                :, [f'spec_0.92_seed_{i}', f'spec_0.94_seed_{i}', f'spec_0.96_seed_{i}', f'spec_0.98_seed_{i}']].mean(axis=1)

    for metric in metrics_:
        res_table_metric = res_table.loc[:, res_table.columns.str.contains(metric)]
        res_table[metric] = res_table_metric.mean(axis=1)

    # res_table['high_spec_avg'] = res_table[
    #     [f'spec_0.94', f'spec_0.96', f'spec_0.98']].mean(axis=1)

    for metric in metrics_:
        res_table_metric = res_table.loc[:, res_table.columns.str.contains(metric)]
        res_table[f'{metric}_std_err'] = res_table_metric.std(axis=1) / np.sqrt(10)
        # res_table[f'{metric}_std'] = res_table_metric.std(axis=1)
        
    return res_table

In [338]:
pkl_to_table('all').to_csv(f'results/{DATASET}/all_average.csv')
pkl_to_table('young').to_csv(f'results/{DATASET}/young_average.csv')
pkl_to_table('old').to_csv(f'results/{DATASET}/old_average.csv')

In [346]:
pkl_to_table('all').iloc[[0, 3, 7, 6], -16:].style.highlight_max(color='blue')

Unnamed: 0,high_spec_avg,spec_0.92,spec_0.94,spec_0.96,spec_0.98,auc,aps,acc,high_spec_avg_std_err,spec_0.92_std_err,spec_0.94_std_err,spec_0.96_std_err,spec_0.98_std_err,auc_std_err,aps_std_err,acc_std_err
cart_all,0.111,0.198,0.118,0.105,0.024,0.728,0.061,0.776,0.052,0.085,0.069,0.067,0.006,0.058,0.003,0.006
figs_all,0.145,0.197,0.185,0.175,0.022,0.628,0.043,0.676,0.063,0.079,0.082,0.083,0.011,0.066,0.003,0.011
tao_all,0.088,0.115,0.115,0.102,0.021,0.55,0.049,0.774,0.055,0.072,0.072,0.07,0.009,0.086,0.003,0.009
pfigs_all,0.215,0.387,0.27,0.187,0.014,0.684,0.046,0.718,0.047,0.069,0.061,0.068,0.009,0.03,0.004,0.008


In [347]:
pkl_to_table('young').iloc[[0, 1, 2, 3, 7, 8, 6], -16:].style.highlight_max(color='blue')

Unnamed: 0,high_spec_avg,spec_0.92,spec_0.94,spec_0.96,spec_0.98,auc,aps,acc,high_spec_avg_std_err,spec_0.92_std_err,spec_0.94_std_err,spec_0.96_std_err,spec_0.98_std_err,auc_std_err,aps_std_err,acc_std_err
cart_all,0.115,0.181,0.181,0.092,0.008,0.648,0.08,0.812,0.057,0.09,0.09,0.064,0.005,0.083,0.009,0.01
cart_young,0.064,0.121,0.117,0.009,0.009,0.434,0.054,0.803,0.041,0.082,0.082,0.005,0.005,0.1,0.01,0.018
figs_all,0.213,0.356,0.348,0.115,0.031,0.507,0.045,0.659,0.061,0.097,0.095,0.065,0.023,0.067,0.005,0.008
figs_young,0.191,0.241,0.241,0.186,0.096,0.425,0.056,0.786,0.074,0.086,0.086,0.087,0.067,0.087,0.013,0.011
tao_all,0.052,0.096,0.096,0.007,0.007,0.452,0.055,0.798,0.039,0.074,0.074,0.005,0.005,0.092,0.006,0.016
tao_young,0.071,0.139,0.135,0.004,0.004,0.482,0.051,0.805,0.04,0.081,0.082,0.003,0.003,0.095,0.008,0.017
pfigs_young,0.266,0.343,0.343,0.24,0.137,0.522,0.052,0.758,0.087,0.105,0.105,0.1,0.08,0.08,0.012,0.015


In [348]:
pkl_to_table('old').iloc[[0, 1, 2, 3, 7, 8, 6], -16:].style.highlight_max(color='blue')

Unnamed: 0,high_spec_avg,spec_0.92,spec_0.94,spec_0.96,spec_0.98,auc,aps,acc,high_spec_avg_std_err,spec_0.92_std_err,spec_0.94_std_err,spec_0.96_std_err,spec_0.98_std_err,auc_std_err,aps_std_err,acc_std_err
cart_all,0.12,0.215,0.121,0.121,0.024,0.723,0.056,0.764,0.049,0.085,0.065,0.065,0.008,0.058,0.002,0.007
cart_old,0.093,0.183,0.093,0.093,0.004,0.475,0.046,0.755,0.048,0.084,0.064,0.064,0.003,0.085,0.003,0.011
figs_all,0.172,0.282,0.201,0.189,0.014,0.639,0.045,0.682,0.06,0.091,0.079,0.081,0.008,0.068,0.003,0.012
figs_old,0.224,0.348,0.277,0.267,0.004,0.488,0.046,0.686,0.068,0.093,0.095,0.092,0.003,0.061,0.003,0.005
tao_all,0.117,0.211,0.117,0.117,0.021,0.554,0.05,0.766,0.054,0.089,0.07,0.07,0.011,0.083,0.003,0.007
tao_old,0.092,0.183,0.092,0.092,0.0,0.468,0.046,0.77,0.048,0.083,0.063,0.063,0.0,0.082,0.003,0.011
pfigs_old,0.3,0.467,0.463,0.172,0.097,0.673,0.045,0.705,0.058,0.067,0.067,0.079,0.061,0.054,0.001,0.01
