In [26]:
import os
import sys
import json
import copy
import socket
import getpass
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
sys.path.append('../')
import pickle
import warnings
from sklearn.metrics import recall_score, matthews_corrcoef, roc_auc_score, f1_score
from collections import defaultdict
import json

# Suppress pandas PerformanceWarning messages for the rest of the session.
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

def flatten(l):
    """Return a single list holding the items of every sub-iterable in ``l``, in order.

    Only one level of nesting is removed.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

def compute_opt_thres(target, pred):
    """Grid-search the decision threshold that maximizes F1.

    Every candidate in ``np.arange(0.001, 0.999, 0.001)`` is tried: ``pred`` is
    binarized as ``pred >= threshold`` and scored against ``target`` with
    ``f1_score``.

    Args:
        target: Ground-truth labels, passed unchanged to ``f1_score``.
        pred: Array-like of scores. It is coerced with ``np.asarray`` so that
            plain Python lists work as well as arrays (a bare list would make
            ``pred >= threshold`` raise ``TypeError``).

    Returns:
        The candidate threshold with the highest F1. Ties are resolved in
        favour of the largest tied threshold because the comparison is ``>=``;
        consequently, if F1 is 0 for every candidate, the last candidate is
        returned. The value comes straight from ``np.arange`` with a float
        step, so it may carry floating-point error; round it if an exact
        three-decimal value is needed.
    """
    pred = np.asarray(pred)
    opt_thres = 0
    opt_f1 = 0
    for thres in np.arange(0.001, 0.999, 0.001):
        f1 = f1_score(target, pred >= thres)
        # `>=` (not `>`) deliberately keeps the highest threshold among ties.
        if f1 >= opt_f1:
            opt_thres = thres
            opt_f1 = f1
    return opt_thres

# Set matplotlib's default font size to 20 for figures drawn after this point.
plt.rcParams.update({'font.size': 20})

In [2]:
# Directory that is searched recursively for finished runs.
# Placeholder value: replace it with the real output root before running.
root_dir = Path('/path/to/your/root')

In [3]:
# Keys that already have a cached threshold in opt_thres.pkl; an empty set
# when no cache file exists yet.
cache_path = Path('opt_thres.pkl')
if cache_path.is_file():
    # NOTE(review): pickle.load can execute arbitrary code; only load a cache
    # file that this notebook produced itself.
    with cache_path.open('rb') as f:  # `with` closes the handle after loading
        already_computed = set(pickle.load(f).keys())
else:
    already_computed = set()

In [5]:
# Build one record per run directory that contains an entry named `done`:
# the run's args.json, extended with selected AUROC metrics and the validation
# labels/predictions taken from final_results.pkl. Runs whose
# (dataset, task, attr, algorithm) key is in `already_computed` are skipped.
results = []
for done_file in tqdm(root_dir.glob('**/done')):
    run_dir = done_file.parent

    # Context managers close each file right away instead of leaving two open
    # handles per run for the garbage collector.
    with (run_dir/'args.json').open('r') as f:
        args = json.load(f)
    if (args['dataset'][0], args['task'], args['attr'], args['algorithm']) in already_computed:
        continue

    # NOTE(review): pickle.load can execute arbitrary code; root_dir should
    # only point at trusted experiment outputs.
    with (run_dir/'final_results.pkl').open('rb') as f:
        final_res = pickle.load(f)

    # Evaluation splits whose overall AUROC is recorded when present.
    ssets = ['va', 'te', 'MIMIC-sex-te', 'CheXpert-sex-te', 'NIH-sex-te',
             'PadChest-sex-te', 'VinDr-sex-te']
    if args['task'] == 'Pneumothorax':
        ssets.append('SIIM-sex-te')

    for sset in ssets:
        if sset in final_res:
            args[f'{sset}_auroc'] = final_res[sset]['overall']['AUROC']
            if sset == 'va':
                # Used further down as the model-selection criterion.
                args[f'{sset}_min_attr_auroc'] = final_res[sset]['min_attr']['AUROC']

    # The validation split is required here: a missing 'va' raises KeyError.
    args['va_y'] = final_res['va']['y']
    args['va_preds'] = final_res['va']['preds']

    results.append(args)
df = pd.DataFrame(results)

6155it [00:14, 430.52it/s]


In [6]:
# `dataset` is loaded from args.json as a list; keep only its first entry.
# The isinstance guard makes this cell safe to re-run: without it, a second
# execution would index into the already-unwrapped string and keep only its
# first character.
df['dataset'] = df['dataset'].apply(lambda x: x[0] if isinstance(x, (list, tuple)) else x)

In [7]:
# Quick size check: (rows, columns) of the collected-runs frame.
df.shape

(663, 32)

## Optimal Threshold

In [8]:
# Model selection: within each (dataset, task, attr, algorithm) group keep the
# single run with the highest `va_min_attr_auroc`.
group_keys = ['dataset', 'task', 'attr', 'algorithm']


def _select_best_run(group):
    """Return the row of ``group`` whose ``va_min_attr_auroc`` is largest."""
    best_label = group['va_min_attr_auroc'].idxmax()
    return group.loc[best_label]


best_models = df.groupby(group_keys).apply(_select_best_run)

In [30]:
# For every selected model, tune the decision threshold on its validation
# labels/predictions and store it, rounded to 3 decimals, under the
# (dataset, task, attr, algorithm) key.
opt_thres = {}
n_models = len(best_models)
for group_key, best_run in tqdm(best_models.iterrows(), total=n_models):
    dataset, task, attr, algorithm = group_key
    raw_thres = compute_opt_thres(best_run['va_y'], best_run['va_preds'])
    opt_thres[(dataset, task, attr, algorithm)] = np.round(raw_thres, 3)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:29<00:00,  1.23s/it]


In [31]:
# Load the thresholds cached by earlier executions; an empty dict when no
# cache file exists yet.
cache_path = Path('opt_thres.pkl')
if cache_path.is_file():
    # NOTE(review): pickle.load can execute arbitrary code; only load a cache
    # file that this notebook produced itself.
    with cache_path.open('rb') as f:  # `with` closes the handle after loading
        old_file = pickle.load(f)
else:
    old_file = {}

In [32]:
# Combine cached and freshly computed thresholds; on a key clash the freshly
# computed value replaces the cached one.
merged_thres = dict(old_file)
merged_thres.update(opt_thres)
opt_thres = merged_thres

In [33]:
# Display the threshold mapping for a visual check.
opt_thres

{('CheXpert', 'Cardiomegaly', 'age', 'CDANN'): 0.305,
 ('CheXpert', 'Cardiomegaly', 'age', 'DANN'): 0.217,
 ('CheXpert', 'Cardiomegaly', 'age', 'ERM'): 0.342,
 ('CheXpert', 'Cardiomegaly', 'age', 'GroupDRO'): 0.759,
 ('CheXpert', 'Cardiomegaly', 'age', 'MA'): 0.342,
 ('CheXpert', 'Cardiomegaly', 'age', 'ReSample'): 0.811,
 ('CheXpert', 'Cardiomegaly', 'ethnicity', 'CDANN'): 0.255,
 ('CheXpert', 'Cardiomegaly', 'ethnicity', 'DANN'): 0.332,
 ('CheXpert', 'Cardiomegaly', 'ethnicity', 'ERM'): 0.333,
 ('CheXpert', 'Cardiomegaly', 'ethnicity', 'GroupDRO'): 0.83,
 ('CheXpert', 'Cardiomegaly', 'ethnicity', 'MA'): 0.342,
 ('CheXpert', 'Cardiomegaly', 'ethnicity', 'ReSample'): 0.753,
 ('CheXpert', 'Cardiomegaly', 'sex', 'CDANN'): 0.207,
 ('CheXpert', 'Cardiomegaly', 'sex', 'DANN'): 0.197,
 ('CheXpert', 'Cardiomegaly', 'sex', 'ERM'): 0.255,
 ('CheXpert', 'Cardiomegaly', 'sex', 'GroupDRO'): 0.822,
 ('CheXpert', 'Cardiomegaly', 'sex', 'MA'): 0.345,
 ('CheXpert', 'Cardiomegaly', 'sex', 'ReSample'): 

In [34]:
# Write the thresholds to opt_thres.pkl. The context manager flushes and closes
# the file as soon as the dump finishes, rather than leaving an open handle
# whose buffered data may not reach disk until it is garbage-collected.
with open('opt_thres.pkl', 'wb') as f:
    pickle.dump(opt_thres, f)