In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from pathlib import Path
import numpy as np
import pandas as pd
import pickle
import json
import torch
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns

# Algorithms in the column order used by `pivot` for the result tables.
# 'OracleID' / 'OracleMerged' are the display names `extract_stats` gives to
# the ERMID / ERMMerged runs.
methods = [
    'OracleID', 'OracleMerged',
    'ERM', 'GroupDRO', 'IRM', 'VREx', 'RVP', 'IGA', 'CORAL', 'MLDG',
]
# Human-readable labels for the two `es_method` values ('train' / 'val').
es_mapping = dict(
    train="Model selection method: training domains",
    val="Model selection method: validation domain",
)

In [2]:
def load_final_results(folder):
    """Load the final stats saved alongside a finished run.

    Parameters
    ----------
    folder : pathlib.Path
        Path of the run's ``done`` marker file; ``stats.pkl`` is read from
        its parent directory.

    Returns
    -------
    object
        Whatever ``torch.load`` deserializes from ``stats.pkl`` (consumed
        downstream by ``extract_stats`` as a dict of run statistics).
    """
    # torch.load unpickles the file: only point this at trusted result dirs.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which may refuse
    # arbitrary Python objects stored in stats.pkl -- confirm for the installed torch.
    # The context manager closes the handle (the original leaked one per run).
    with (folder.parent / 'stats.pkl').open('rb') as fh:
        return torch.load(fh)

In [3]:
# Raw algorithm names found in the runs -> display names used in the tables.
_ALGORITHM_RENAMES = {'ERMID': 'OracleID', 'ERMMerged': 'OracleMerged'}


def extract_stats(lst, keep):
    """Flatten per-run stats dicts into a DataFrame.

    Parameters
    ----------
    lst : list of dict
        One stats dict per run (e.g. as returned by ``load_saved_data``).
    keep : list of str
        Fields to extract. A plain name ``k`` copies ``run[k]``. A nested name
        ``'section/key'`` copies ``run[section][key]`` into column ``key``,
        except that:

        * ``key`` in {'penalty_anneal_iters', 'lambda'} reads the
          algorithm-specific entry ``run[section][f'{alg.lower()}_{key}']``,
          where ``alg`` is ``run['args']['algorithm']``;
        * ``key == 'algorithm'`` renames ERMID -> OracleID and
          ERMMerged -> OracleMerged.

    Returns
    -------
    pandas.DataFrame
        One row per run (original positional index kept). Runs with
        ``algorithm == 'OracleMerged'`` and ``es_method == 'val'`` are
        dropped. When either of those columns is absent (e.g. ``lst`` is
        empty), no rows are filtered instead of raising.
    """
    records = []
    for run in lst:
        record = {}
        for k in keep:
            if '/' not in k:
                record[k] = run[k]
                continue
            section, key = k.split('/')
            if key in ('penalty_anneal_iters', 'lambda'):
                # Hyper-parameter names are prefixed with the (raw) algorithm name.
                prefix = run['args']['algorithm'].lower()
                record[key] = run[section][f'{prefix}_{key}']
            elif key == 'algorithm':
                value = run[section][key]
                record[key] = _ALGORITHM_RENAMES.get(value, value)
            else:
                record[key] = run[section][key]
        records.append(record)

    df = pd.DataFrame(records)
    # The OracleMerged + validation-domain selection combination is excluded.
    if {'algorithm', 'es_method'}.issubset(df.columns):
        drop = (df['algorithm'] == 'OracleMerged') & (df['es_method'] == 'val')
        df = df[~drop]
    return df

def load_saved_data(path):
    """Load the stats of every finished run directly under ``path``.

    A run counts as finished when ``<path>/<run>/done`` exists. Runs are
    visited in sorted path order so the returned list (and therefore the row
    order seen by ``idxmax`` tie-breaking downstream) is reproducible,
    independent of the filesystem's directory listing order.

    Parameters
    ----------
    path : pathlib.Path
        Directory containing one sub-directory per run.

    Returns
    -------
    list
        One ``load_final_results`` result per finished run.
    """
    done_markers = sorted(path.glob('*/done'))
    # Passing a list gives tqdm a known total, so it shows a real progress bar.
    return [load_final_results(marker) for marker in tqdm(done_markers)]

In [4]:
def agg_and_group(df, distinct_hparams, eval_metric, other_metrics):
    """Select the best run per trial, then summarize across trials.

    For every combination of ``distinct_hparams`` and ``trial_seed``, the row
    with the highest ``eval_metric`` is kept (model selection over all other
    columns, e.g. ``hparams_seed``); ties go to the first such row. The kept
    rows are then averaged over ``trial_seed``.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-run rows; must contain ``distinct_hparams``, ``'trial_seed'`` and
        every metric column.
    distinct_hparams : list of str
        Columns identifying one reported setting.
    eval_metric : str
        Column maximized for model selection.
    other_metrics : list of str
        Further metric columns reported for the selected rows.

    Returns
    -------
    pandas.DataFrame
        One row per ``distinct_hparams`` combination with, for each metric in
        ``[eval_metric] + other_metrics``, columns ``<metric>_mean``,
        ``<metric>_std`` and a display string ``<metric>`` formatted as
        ``'mean±std'`` with three decimals (std is NaN for a single trial).
    """
    metrics = [eval_metric] + list(other_metrics)
    trial_keys = list(distinct_hparams) + ['trial_seed']

    # Best model for each (setting, trial): vectorized idxmax instead of a
    # Python-level apply over every group.
    best_idx = df.groupby(trial_keys)[eval_metric].idxmax()
    best = df.loc[best_idx, trial_keys + metrics]

    # Mean / std over all available trials of each setting.
    summary = (
        best.groupby(list(distinct_hparams))
        .agg({m: ['mean', 'std'] for m in metrics})
        .reset_index()
    )
    summary.columns = ['_'.join(pair) if pair[1] != '' else pair[0] for pair in summary.columns]
    for m in metrics:
        summary[m] = (
            summary[f'{m}_mean'].apply('{:.3f}'.format)
            + '\u00B1'
            + summary[f'{m}_std'].apply('{:.3f}'.format)
        )
    return summary

def pivot(df, values, index, methods = methods):
    """Pivot formatted results into a table with one column per algorithm.

    Parameters
    ----------
    df : pandas.DataFrame
        Typically the output of ``agg_and_group``; must contain an
        ``algorithm`` column and the ``index`` / ``values`` columns.
    values : str or list of str
        Column placed in the cells (e.g. a formatted 'mean±std' string).
        Intended for a single value column: with several, the flattened
        column labels repeat per algorithm.
    index : str or list of str
        Column(s) forming the row index.
    methods : list of str
        Preferred algorithm column order. Algorithms present in ``df`` but
        not listed are placed first; listed algorithms absent from ``df``
        are skipped (previously this raised or mis-ordered columns).

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If several rows share the same ``index`` + ``algorithm`` key, since a
        cell would then be ambiguous.
    """
    keys = ([index] if isinstance(index, str) else list(index)) + ['algorithm']
    if df.duplicated(subset=keys).any():
        raise ValueError(f'pivot: rows are not unique over {keys}; cells would be ambiguous')

    # Each cell holds exactly one (string) value, so 'first' just extracts it;
    # numeric aggregators such as the default 'mean' cannot handle strings.
    table = df.pivot_table(values=values, index=index, columns=['algorithm'], aggfunc='first')
    if table.columns.nlevels > 1:
        table.columns = [pair[1] if pair[1] != '' else pair[0] for pair in table.columns]

    present = list(table.columns)
    extras = [c for c in present if c not in methods]
    ordered = extras + [m for m in methods if m in present]
    return table[ordered]

## eICU CorrLabel

In [5]:
# The results root can be overridden with the CLINICALDG_RESULTS environment
# variable; it defaults to the original cluster location.
path = Path(os.environ.get('CLINICALDG_RESULTS', '/scratch/ssd001/home/haoran/clinicaldg_results')) / 'eICUCorrLabel'
lst = load_saved_data(path)

2997it [00:02, 1432.41it/s]


In [6]:
# Fields pulled from each run's stats dict; a 'section/key' entry is read from
# run[section][key] and stored in column `key`.
keep = [
    'es_roc',
    'test_results/South_roc',
    'model_hparams/eicu_architecture',
    'args/algorithm',
    'args/es_method',
    'args/hparams_seed',
    'args/trial_seed',
    'model_hparams/corr_label_train_corrupt_mean',
]

df = extract_stats(lst, keep)

In [7]:
# Select the best run per (setting, trial) by es_roc, then summarize over trials.
eval_metric = 'es_roc'
other_metrics = ['South_roc']
distinct_hparams = ['algorithm', 'es_method', 'corr_label_train_corrupt_mean']

# NOTE: this rebinds `df` to the aggregated table, which has no 'trial_seed'
# column; re-run the previous cell before running this one again.
df = agg_and_group(df, distinct_hparams, eval_metric, other_metrics)

In [8]:
pivot(df, values=['South_roc'], index=['es_method', 'corr_label_train_corrupt_mean'])

Unnamed: 0_level_0,Unnamed: 1_level_0,OracleID,OracleMerged,ERM,GroupDRO,IRM,VREx,RVP,IGA,CORAL,MLDG
es_method,corr_label_train_corrupt_mean,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
train,0.1,0.963±0.006,0.767±0.021,0.305±0.048,0.317±0.022,0.347±0.041,0.349±0.067,0.399±0.026,0.400±0.057,0.359±0.030,0.350±0.051
train,0.3,0.961±0.009,0.862±0.011,0.694±0.034,0.660±0.046,0.704±0.031,0.702±0.025,0.688±0.028,0.726±0.028,0.709±0.037,0.687±0.025
train,0.5,0.963±0.006,0.911±0.005,0.865±0.012,0.845±0.015,0.862±0.013,0.871±0.018,0.862±0.008,0.869±0.008,0.872±0.010,0.857±0.014
val,0.1,0.963±0.006,,0.678±0.087,0.677±0.065,0.733±0.016,0.612±0.142,0.715±0.045,0.683±0.069,0.689±0.082,0.690±0.056
val,0.3,0.963±0.006,,0.697±0.045,0.724±0.024,0.748±0.032,0.716±0.025,0.684±0.052,0.757±0.035,0.699±0.021,0.717±0.015
val,0.5,0.963±0.006,,0.865±0.013,0.862±0.014,0.868±0.016,0.845±0.023,0.855±0.017,0.865±0.010,0.862±0.007,0.860±0.009


## eICU CorrNoise

In [9]:
# The results root can be overridden with the CLINICALDG_RESULTS environment
# variable; it defaults to the original cluster location.
path = Path(os.environ.get('CLINICALDG_RESULTS', '/scratch/ssd001/home/haoran/clinicaldg_results')) / 'eICUGaussianNoise'
lst = load_saved_data(path)

3994it [00:02, 1500.64it/s]


In [10]:
# Fields pulled from each run's stats dict; a 'section/key' entry is read from
# run[section][key] and stored in column `key`.
keep = [
    'es_roc',
    'test_results/South_roc',
    'model_hparams/eicu_architecture',
    'args/algorithm',
    'args/es_method',
    'args/hparams_seed',
    'args/trial_seed',
    'model_hparams/corr_noise_train_corrupt_mean',
    'model_hparams/corr_noise_train_corrupt_dist',
]

df = extract_stats(lst, keep)

In [11]:
eval_metric = 'es_roc'
other_metrics = ['South_roc']
# Columns identifying one reported setting (algorithm, selection method, noise parameters).
distinct_hparams = [
    'algorithm',
    'es_method',
    'corr_noise_train_corrupt_mean',
    'corr_noise_train_corrupt_dist',
]

In [12]:
# NOTE: rebinds `df` to the aggregated table; re-run the extraction cell first if repeating.
df = agg_and_group(df, distinct_hparams=distinct_hparams, eval_metric=eval_metric, other_metrics=other_metrics)

In [13]:
# Restrict the reported table to noise means 1.0 and 2.0.
noise_means_shown = [1.0, 2.0]
df = df.loc[df['corr_noise_train_corrupt_mean'].isin(noise_means_shown)]

In [14]:
print('South ROC')
pivot(
    df,
    values=['South_roc'],
    index=['es_method', 'corr_noise_train_corrupt_mean', 'corr_noise_train_corrupt_dist'],
)

South ROC


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,OracleID,OracleMerged,ERM,GroupDRO,IRM,VREx,RVP,IGA,CORAL,MLDG
es_method,corr_noise_train_corrupt_mean,corr_noise_train_corrupt_dist,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
train,1.0,0.5,0.959±0.008,0.794±0.023,0.388±0.047,0.422±0.046,0.386±0.021,0.374±0.031,0.418±0.037,0.404±0.042,0.440±0.050,0.410±0.020
train,1.0,1.0,0.959±0.008,0.826±0.026,0.556±0.032,0.614±0.080,0.557±0.041,0.571±0.013,0.655±0.030,0.565±0.064,0.600±0.033,0.548±0.014
train,2.0,0.5,0.954±0.008,0.717±0.022,0.209±0.027,0.214±0.048,0.207±0.020,0.193±0.023,0.191±0.018,0.200±0.026,0.234±0.027,0.199±0.028
train,2.0,1.0,0.954±0.008,0.730±0.024,0.244±0.025,0.253±0.031,0.254±0.028,0.245±0.025,0.251±0.019,0.263±0.027,0.279±0.033,0.281±0.027
val,1.0,0.5,0.959±0.008,,0.446±0.139,0.415±0.057,0.463±0.097,0.494±0.085,0.466±0.049,0.596±0.132,0.497±0.101,0.414±0.071
val,1.0,1.0,0.954±0.008,,0.561±0.090,0.585±0.057,0.522±0.020,0.566±0.099,0.669±0.062,0.641±0.084,0.591±0.058,0.541±0.053
val,2.0,0.5,0.955±0.008,,0.492±0.181,0.489±0.110,0.423±0.181,0.395±0.075,0.385±0.060,0.513±0.087,0.503±0.053,0.467±0.112
val,2.0,1.0,0.959±0.008,,0.506±0.149,0.417±0.169,0.436±0.171,0.405±0.110,0.414±0.128,0.546±0.152,0.366±0.111,0.347±0.087
