In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import glob
import os
import sys
sys.path.append("../")
import matplotlib.pyplot as plt
import shutil
from pathlib import Path
from tqdm import tqdm
import json
import hashlib
from cxr_fairness.data import Constants
import torch
from cxr_fairness.metrics import StandardEvaluator
import hashlib

# Render wide/long frames without truncation up to these limits.
pd.options.display.max_columns = 50
pd.options.display.max_rows = 100

# Validation metric(s) used to pick hyperparameters -> aggregation direction ('max' = higher is better).
# Other options used previously: 'roc': 'max', 'roc_gap': 'min'.
select_metrics = dict(worst_roc='max')

# The best hparams are chosen independently for each combination of these columns...
separating_factors = ['dataset', 'model', 'task']
# ...and of these grouping columns.
default_group_vars = ['protected_attr', 'subset_group']

In [2]:
# Directories searched recursively for trained-model run folders (each containing
# `results.pkl` and `args.json`). The first entry is also where `selected_configs.pkl` is written.
# To point the notebook elsewhere without editing it, set CXR_DEBIAS_PROJECT_DIRS to an
# os.pathsep-separated list of directories; otherwise the defaults below are used.
_default_project_dirs = [
    Path('/scratch/hdd001/home/haoran/cxr_debias/'),
    Path('/scratch/ssd001/home/haoran/cxr_debias/'),
]
project_dirs = [
    Path(p) for p in os.environ.get('CXR_DEBIAS_PROJECT_DIRS', '').split(os.pathsep) if p
] or _default_project_dirs

In [3]:
# Keys whose values identify a hyperparameter configuration. Runs that differ only in keys
# NOT listed here (e.g. val_fold) get the same hparams_id, so folds can be averaged later.
# NOTE: the order of this list feeds the hash below -- reordering it changes every hparams_id.
hparams = ['model', 'checkpoint_freq', 'es_patience', 'lr', 'batch_size', 'clf_head_ratio',
           'groupdro_eta', 'distmatch_penalty_weight', 'match_type', 'adv_alpha', 'es_metric',
           'algorithm', 'data_type', 'fairalm_threshold', 'fairalm_surrogate', 'fairalm_eta',
           'JTT_weight', 'JTT_threshold']

result_paths = [p for project_dir in project_dirs for p in project_dir.glob('**/results.pkl')]
res = []  # one dict per finished run: its args.json plus config_filename, selected metrics, hparams_id
for result_path in tqdm(result_paths):
    # `with` closes args.json; the previous json.load(path.open()) leaked one handle per run.
    with (result_path.parent / 'args.json').open('r') as args_file:
        args_i = json.load(args_file)
    args_i['config_filename'] = result_path.parent.name
    # torch.load unpickles arbitrary objects -- only run this on trusted experiment output.
    # map_location='cpu' lets runs saved from a GPU process load on a CPU-only kernel.
    metrics = torch.load(result_path, map_location='cpu')['val_metrics']
    for metric in select_metrics:
        args_i[metric] = metrics[metric]
    # NOTE(review): the hash covers the *values* of present keys only (absent keys are skipped,
    # key names are not hashed), so configs with different sets of present keys could collide.
    # Kept as-is so ids stay identical to previously saved selections.
    args_i['hparams_id'] = hashlib.md5(
        str([args_i[j] for j in hparams if j in args_i]).encode('utf-8')
    ).hexdigest()
    res.append(args_i)

100%|██████████| 5609/5609 [20:56<00:00,  4.46it/s]  


In [4]:
# One row per run; columns are the args.json keys plus config_filename, selected metrics and hparams_id.
df_all = pd.DataFrame(data=res)

In [5]:
# Exclude ARL and JTT runs whose data_type is 'balanced'.
df_all = df_all[~((df_all['data_type'] == 'balanced') & df_all['exp_name'].isin(['arl', 'jtt']))]

In [6]:
unique_exps = df_all['exp_name'].unique()  # experiment names in order of first appearance
unique_exps

array(['balanced_concat', 'MMD', 'mean_match', 'single_group',
       'erm_baseline_concat', 'dro', 'erm_baseline', 'balanced', 'arl',
       'simple_adv', 'jtt', 'fairalm'], dtype=object)

In [7]:
# Sanity check: runs (non-null val_fold) per configuration. Most configs have 5;
# a min below that flags configurations with missing folds.
_fold_count_keys = ['hparams_id', 'exp_name'] + separating_factors + default_group_vars
df_all.groupby(_fold_count_keys, dropna=False)['val_fold'].count().describe()

count    1109.000000
mean        4.976555
std         0.277610
min         1.000000
25%         5.000000
50%         5.000000
75%         5.000000
max         5.000000
Name: val_fold, dtype: float64

### Select best hyperparameter configs per (dataset, model, task, protected_attr, subset_group)

In [3]:
# Recomputed here so the selection section can be re-run on its own from df_all.
unique_exps = pd.unique(df_all['exp_name'])
unique_exps

array(['balanced_concat', 'MMD', 'mean_match', 'single_group',
       'erm_baseline_concat', 'dro', 'erm_baseline', 'balanced', 'arl',
       'simple_adv', 'jtt', 'fairalm'], dtype=object)

In [4]:
# For every selection metric and experiment: average the metric over folds for each hparams_id,
# then keep the best hparams_id per (dataset, model, task, protected_attr, subset_group).
# On ties, drop_duplicates keeps the first matching row produced by the merge.
group_keys = separating_factors + default_group_vars
selected_configs_raw = []  # one frame of selected runs per (metric, experiment)
for select_metric, direction in select_metrics.items():
    for exp in unique_exps:
        # A gap metric is not meaningful for single-group training.
        if select_metric == 'roc_gap' and exp == 'single_group':
            continue
        df = df_all[df_all.exp_name == exp]

        mean_performance = (
            df.groupby(['hparams_id'] + group_keys, dropna=False)
            .agg(performance=(select_metric, 'mean'))
            .reset_index()
        )

        # Best mean score per group; merging back recovers the hparams_id(s) that achieve it.
        best_per_group = (
            mean_performance.groupby(group_keys, dropna=False)
            .agg(performance=('performance', direction))
            .reset_index()
        )
        best_model = best_per_group.merge(mean_performance).drop_duplicates(subset=group_keys)

        # All-NaN key columns are dropped before joining back to the individual runs.
        selected_config = (
            best_model[['hparams_id'] + group_keys]
            .dropna(axis=1, how='all')
            .merge(df)
        )
        selected_config['select_metric'] = select_metric
        selected_configs_raw.append(selected_config)

In [5]:
selected_configs = pd.concat(selected_configs_raw)

In [12]:
# Penalty-strength sweep (MIMIC only): for MMD / mean_match / simple_adv, keep the best
# hparams_id separately for every value of distmatch_penalty_weight / adv_alpha, per
# (protected_attr, task). Selected runs are labelled select_metric='vary_lambda_exp'.
#
# The metric and direction are named explicitly instead of reading `select_metric` leaked
# from an earlier loop: that made this cell depend on which metric was iterated last
# (and the fixed 'max' would be wrong for a 'min' metric such as 'roc_gap').
vary_lambda_metric = 'worst_roc'  # must be a key of select_metrics -- only those metric columns are loaded
vary_lambda_direction = 'max'     # higher worst_roc is better, as in select_metrics['worst_roc']
assert vary_lambda_metric in df_all.columns, f'{vary_lambda_metric!r} was not loaded into df_all'

selected_configs_raw = []  # name reused: consumed by the following `pd.concat(selected_configs_raw)` cell
for exp in ['MMD', 'mean_match', 'simple_adv']:
    df = df_all[(df_all.exp_name == exp) & (df_all.dataset == 'MIMIC')]
    # dropna=False keeps groups whose penalty column is NaN (presumably the one not used by `exp`).
    mean_performance = (
        df
        .groupby(['hparams_id', 'protected_attr', 'distmatch_penalty_weight', 'adv_alpha', 'task'], dropna=False)
        .agg(performance=(vary_lambda_metric, 'mean'))
        .reset_index()
    )
    # Best mean score per (penalty value, protected_attr, task); the merge recovers the
    # hparams_id(s) that achieve it and drop_duplicates keeps one of them on ties.
    best_model = (
        mean_performance.groupby(['distmatch_penalty_weight', 'adv_alpha', 'protected_attr', 'task'], dropna=False)
        .agg(performance=('performance', vary_lambda_direction)).reset_index()
        .merge(mean_performance)
        .drop_duplicates(subset=['protected_attr', 'distmatch_penalty_weight', 'adv_alpha', 'task'])
    )
    selected_config = (
        best_model[['hparams_id', 'task', 'protected_attr']].dropna(axis=1, how='all')
        .merge(df)
    )
    selected_config['select_metric'] = 'vary_lambda_exp'
    selected_configs_raw.append(selected_config)

In [13]:
def add_exp_name(x):
    """Return ``exp_name`` suffixed with the run's sweep value, e.g. ``'simple_adv_0.1'``.

    ``x`` is a row exposing ``exp_name``, ``adv_alpha`` and ``distmatch_penalty_weight``.
    ``simple_adv`` runs use ``adv_alpha``; every other experiment uses
    ``distmatch_penalty_weight``. The value is rendered with ``str()``.
    """
    sweep_value = x.adv_alpha if x.exp_name == 'simple_adv' else x.distmatch_penalty_weight
    return x.exp_name + '_' + str(sweep_value)

In [14]:
# Stack the sweep selections and suffix exp_name with each run's sweep value (e.g. 'MMD_2.0').
temp = pd.concat(selected_configs_raw)
temp['exp_name'] = temp.apply(add_exp_name, axis=1)

In [15]:
# Main selections followed by the sweep selections; the row index is not reset.
selected_configs_final = pd.concat([selected_configs, temp])

In [16]:
# The *_concat experiments are left out of the saved selection.
drop_exps = ['balanced_concat', 'erm_baseline_concat']
keep_mask = ~selected_configs_final['exp_name'].isin(drop_exps)
selected_configs_final = selected_configs_final.loc[keep_mask]

In [17]:
# Write the selection into the first project directory.
selected_configs_path = project_dirs[0] / 'selected_configs.pkl'
selected_configs_final.to_pickle(selected_configs_path)

In [18]:
selected_configs_final['exp_name'].unique()

array(['MMD', 'mean_match', 'single_group', 'dro', 'erm_baseline',
       'balanced', 'arl', 'simple_adv', 'jtt', 'fairalm', 'MMD_0.1',
       'MMD_0.5', 'MMD_0.75', 'MMD_1.0', 'MMD_2.0', 'MMD_3.0', 'MMD_4.0',
       'MMD_5.0', 'MMD_10.0', 'MMD_20.0', 'MMD_25.0', 'MMD_30.0',
       'MMD_50.0', 'MMD_100.0', 'mean_match_0.1', 'mean_match_0.5',
       'mean_match_0.75', 'mean_match_1.0', 'mean_match_2.0',
       'mean_match_3.0', 'mean_match_4.0', 'mean_match_5.0',
       'mean_match_10.0', 'mean_match_20.0', 'mean_match_25.0',
       'mean_match_30.0', 'mean_match_50.0', 'mean_match_100.0',
       'simple_adv_0.01', 'simple_adv_0.05', 'simple_adv_0.1',
       'simple_adv_1.0', 'simple_adv_2.0', 'simple_adv_5.0',
       'simple_adv_10.0', 'simple_adv_20.0', 'simple_adv_25.0',
       'simple_adv_30.0', 'simple_adv_50.0', 'simple_adv_100.0'],
      dtype=object)