In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from pathlib import Path
import numpy as np
import pandas as pd
import pickle
import json
import torch
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns

# Algorithm names in the order the result tables should show them;
# `pivot` uses this list as its default column ordering.
methods = [
    'Oracle',
    'ERM',
    'GroupDRO',
    'IRM',
    'VREx',
    'RVP',
    'IGA',
    'CORAL',
    'MLDG',
]

# Display captions keyed by model-selection mode.
# NOTE(review): not referenced by any cell visible here; presumably used for plot titles.
es_mapping = dict(
    train="Model selection method: training domains",
    val="Model selection method: validation domain",
)

In [2]:
def load_final_results(folder):
    """Load the final statistics of one run.

    Parameters
    ----------
    folder : pathlib.Path
        A path inside the run directory; only ``folder.parent`` is used.
        ``load_saved_data`` passes the run's ``done`` entry here.

    Returns
    -------
    dict
        The object stored in ``<run_dir>/stats.pkl``. If it has no ``'es_roc'``
        key, the value is filled in from the row of ``<run_dir>/results.jsonl``
        whose ``'step'`` equals ``data['es_step']``.

    Raises
    ------
    IndexError
        If ``results.jsonl`` has no row for the early-stopping step.
    """
    run_dir = folder.parent

    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only point this at run directories you trust.
    with (run_dir / 'stats.pkl').open('rb') as stats_file:
        data = torch.load(stats_file)

    if 'es_roc' not in data:
        # Older runs did not store es_roc in stats.pkl; recover it from the
        # per-step log instead.
        with (run_dir / 'results.jsonl').open('r') as results_file:
            results = [json.loads(line) for line in results_file]
        matching = [row for row in results if row['step'] == data['es_step']]
        if not matching:
            raise IndexError(
                f"no row with step == {data['es_step']!r} in {run_dir / 'results.jsonl'}"
            )
        data['es_roc'] = matching[0]['es_roc']
    return data

In [3]:
def extract_stats(lst, keep):
    """Flatten a list of run records into a DataFrame with one row per run.

    Parameters
    ----------
    lst : iterable of dict
        Run records (as returned by ``load_final_results``).
    keep : iterable of str
        Fields to extract. A plain name ``'k'`` reads ``record['k']``.
        A ``'section/field'`` name reads ``record['section'][field]`` and is
        stored under ``field``, with two special cases:

        * ``field`` in ``('penalty_anneal_iters', 'lambda')`` is looked up as
          ``'<algorithm>_<field>'``, where ``<algorithm>`` is the lower-cased
          ``record['args']['algorithm']``;
        * ``field == 'algorithm'`` with value ``'ERMID'`` is reported as
          ``'Oracle'``.

    Returns
    -------
    pandas.DataFrame
    """
    algo_prefixed_fields = ('penalty_anneal_iters', 'lambda')

    def _lookup(record, spec):
        # Resolve one `keep` entry to a (column name, value) pair.
        if '/' not in spec:
            return spec, record[spec]
        section, field = spec.split('/')
        if field in algo_prefixed_fields:
            prefixed = record['args']['algorithm'].lower() + '_' + field
            return field, record[section][prefixed]
        value = record[section][field]
        if field == 'algorithm' and value == 'ERMID':
            value = 'Oracle'
        return field, value

    rows = [dict(_lookup(record, spec) for spec in keep) for record in lst]
    return pd.DataFrame(rows)

def load_saved_data(path):
    """Load the final results of every finished run under ``path``.

    A run counts as finished when ``path/<run>/done`` exists; each such entry
    is handed to ``load_final_results``.

    Parameters
    ----------
    path : pathlib.Path
        Directory whose immediate sub-directories are individual runs.

    Returns
    -------
    list
        One result object per finished run, in glob order.
    """
    # Start from a fresh local list: the previous version appended to a global
    # `lst`, which raised NameError on a clean kernel and mixed the results of
    # successive datasets together.
    lst = []
    for done_marker in tqdm(path.glob('*/done')):
        lst.append(load_final_results(done_marker))
    return lst

In [4]:
def agg_and_group(df, distinct_hparams, eval_metric, other_metrics):
    """Select the best run per trial, then summarise across trials.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per run; must contain ``distinct_hparams``, ``'trial_seed'``,
        ``eval_metric`` and ``other_metrics`` columns.
    distinct_hparams : list of str
        Columns identifying one experimental configuration.
    eval_metric : str
        Column used to pick the best run within each
        (configuration, trial_seed) group (highest value wins).
    other_metrics : list of str
        Further columns reported from the selected run.

    Returns
    -------
    pandas.DataFrame
        One row per configuration with ``<metric>_mean`` and ``<metric>_std``
        columns plus a ``<metric>`` column formatted as ``'mean±std'``
        (three decimals).
    """
    metrics = [eval_metric] + other_metrics

    def _best_run(group):
        # Metrics of the row with the highest eval_metric in this group.
        return group.loc[group[eval_metric].idxmax(), metrics]

    # Best model for each trial of each configuration.
    per_trial = df.groupby(distinct_hparams + ['trial_seed']).apply(_best_run).reset_index()

    # Mean and standard deviation across the trials of a configuration.
    stats_spec = {metric: ['mean', 'std'] for metric in metrics}
    summary = per_trial.groupby(distinct_hparams).agg(stats_spec).reset_index()

    # Flatten the (column, statistic) MultiIndex; grouping columns keep their plain name.
    summary.columns = [
        top if stat == '' else "_".join((top, stat))
        for top, stat in summary.columns
    ]

    for metric in metrics:
        mean_text = summary[f'{metric}_mean'].apply('{:.3f}'.format)
        std_text = summary[f'{metric}_std'].apply('{:.3f}'.format)
        summary[metric] = mean_text + u"\u00B1" + std_text
    return summary

def pivot(df, values, index, methods = methods):
    """Reshape an aggregated results frame into one column per algorithm.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with an ``'algorithm'`` column, e.g. the output of ``agg_and_group``.
    values : list of str
        Column(s) whose entries fill the table cells.
    index : list of str
        Column(s) forming the row index.
    methods : list of str
        Algorithm names, in the order they should appear as the trailing
        columns. Defaults to the module-level ``methods`` list as it was when
        this function was defined.

    Returns
    -------
    pandas.DataFrame
    """
    # The identity aggfunc passes each cell through unchanged.
    # NOTE(review): this presumably relies on every (index, algorithm) group
    # holding exactly one row -- confirm against the data.
    table = df.pivot_table(
        values = values,
        index = index,
        columns = ['algorithm'],
        aggfunc = lambda cell: cell,
    )

    # Collapse the two-level column labels to the second level (the algorithm),
    # falling back to the first level where the second is empty.
    table.columns = [col[0] if col[1] == '' else col[1] for col in table.columns]

    # Leading columns beyond len(methods) keep their position; the rest follow `methods`.
    n_leading = len(table.columns) - len(methods)
    column_order = list(table.columns[:n_leading]) + methods
    return table[column_order]

## eICU CorrLabel

In [49]:
# Root directory of the eICU CorrLabel runs; load_saved_data looks for '<run>/done' beneath it.
# NOTE(review): hard-coded absolute cluster path -- the notebook only runs where this mount exists.
path = Path('/scratch/hdd001/home/haoran/domainbed/eICUCorrLabel')
# One result record per finished run.
lst = load_saved_data(path)   

In [50]:
# Fields to pull out of each run record. 'section/field' entries are read from the
# nested dict `section` and stored under the column name `field` (see extract_stats).
keep = ['es_roc', 'test_results/South_roc', 'test_results/South_tpr_gap_thres', 
        'model_hparams/eicu_architecture', 'args/algorithm', 'args/es_method', 'args/hparams_seed',
       'args/trial_seed', 'model_hparams/corr_label_train_corrupt_mean']

# One row per run, one column per kept field.
df = extract_stats(lst, keep)

In [51]:
# Columns that identify one experimental configuration; runs sharing them are summarised together.
distinct_hparams = ['algorithm', 'es_method', 'corr_label_train_corrupt_mean']
# Metric used by agg_and_group to pick the best run per (configuration, trial_seed).
eval_metric = 'es_roc'
# Additional metric(s) reported from the selected run.
other_metrics = ['South_roc']

# NOTE: `df` is rebound to the aggregated table (which no longer has a 'trial_seed' column),
# so re-running this cell requires re-running the extract_stats cell first.
df = agg_and_group(df, distinct_hparams, eval_metric, other_metrics)

In [52]:
# South_roc ('mean±std' strings) by es_method and label-corruption mean, one column per algorithm.
pivot(df, ['South_roc'], ['es_method', 'corr_label_train_corrupt_mean'])

Unnamed: 0_level_0,Unnamed: 1_level_0,Oracle,ERM,GroupDRO,IRM,VREx,RVP,IGA,CORAL,MLDG
es_method,corr_label_train_corrupt_mean,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
train,0.1,0.958±0.008,0.327±0.035,0.333±0.045,0.340±0.023,0.343±0.041,0.458±0.014,0.488±0.157,0.331±0.032,0.317±0.043
train,0.3,0.958±0.006,0.692±0.009,0.681±0.022,0.689±0.020,0.703±0.023,0.705±0.017,0.707±0.012,0.688±0.035,0.669±0.032
train,0.5,0.958±0.008,0.863±0.009,0.864±0.009,0.869±0.002,0.861±0.005,0.856±0.009,0.794±0.072,0.863±0.011,0.867±0.007
val,0.1,0.958±0.008,0.716±0.047,0.723±0.024,0.672±0.034,0.707±0.043,0.653±0.070,0.714±0.010,0.707±0.030,0.716±0.017
val,0.3,0.958±0.008,0.713±0.021,0.699±0.016,0.708±0.023,0.695±0.018,0.708±0.025,0.740±0.026,0.723±0.017,0.733±0.012
val,0.5,0.958±0.008,0.865±0.005,0.862±0.003,0.867±0.004,0.861±0.005,0.854±0.007,0.826±0.054,0.860±0.009,0.864±0.008


## eICU CorrNoise

In [53]:
# Root directory of the eICU CorrNoise runs; load_saved_data looks for '<run>/done' beneath it.
# NOTE(review): this path is under /scratch/ssd001 while the CorrLabel cell reads from
# /scratch/hdd001 -- confirm both locations are intended.
path = Path('/scratch/ssd001/home/haoran/domainbed/eICUCorrNoise')
# One result record per finished run.
lst = load_saved_data(path) 

In [54]:
# Fields to pull out of each run record. 'section/field' entries are read from the
# nested dict `section` and stored under the column name `field` (see extract_stats).
keep = ['es_roc', 'test_results/South_roc',  
        'test_results/South_tpr_gap_thres_wb',
        'model_hparams/eicu_architecture', 'args/algorithm', 'args/es_method', 'args/hparams_seed',
       'args/trial_seed', 'model_hparams/corr_noise_train_corrupt_mean', 'model_hparams/corr_noise_train_corrupt_dist']

# One row per run, one column per kept field.
df = extract_stats(lst, keep)

In [55]:
# Columns that identify one experimental configuration; runs sharing them are summarised together.
distinct_hparams = ['algorithm', 'es_method', 'corr_noise_train_corrupt_mean', 'corr_noise_train_corrupt_dist']
# Metric used by agg_and_group to pick the best run per (configuration, trial_seed).
eval_metric = 'es_roc'
# Additional metrics reported from the selected run.
other_metrics = ['South_roc', 'South_tpr_gap_thres_wb']

In [56]:
# NOTE: `df` is rebound to the aggregated table (which no longer has a 'trial_seed' column),
# so re-running this cell requires re-running the extract_stats cell first.
df = agg_and_group(df, distinct_hparams, eval_metric, other_metrics)

In [57]:
# Keep only the configurations whose corr_noise_train_corrupt_mean is 1.0 or 2.0.
df = df[df.corr_noise_train_corrupt_mean.isin([1.0, 2.0])]

In [58]:
# Printed label for the table displayed below.
print('South ROC')
# South_roc ('mean±std' strings) by es_method, noise-corruption mean and dist, one column per algorithm.
pivot(df, ['South_roc'], ['es_method', 'corr_noise_train_corrupt_mean', 'corr_noise_train_corrupt_dist'])

South ROC


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Oracle,ERM,GroupDRO,IRM,VREx,RVP,IGA,CORAL,MLDG
es_method,corr_noise_train_corrupt_mean,corr_noise_train_corrupt_dist,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
train,1.0,0.5,0.956±0.004,0.371±0.036,0.389±0.040,0.373±0.038,0.402±0.028,0.440±0.059,0.487±0.089,0.397±0.014,0.371±0.036
train,1.0,1.0,0.956±0.004,0.523±0.047,0.562±0.020,0.525±0.033,0.593±0.023,0.669±0.018,0.595±0.048,0.589±0.048,0.566±0.061
train,2.0,0.5,0.956±0.004,0.208±0.030,0.205±0.017,0.195±0.021,0.197±0.022,0.198±0.042,0.281±0.094,0.223±0.031,0.190±0.013
train,2.0,1.0,0.956±0.004,0.248±0.024,0.247±0.029,0.242±0.017,0.238±0.041,0.247±0.034,0.349±0.109,0.274±0.032,0.252±0.015
val,1.0,0.5,0.956±0.004,0.396±0.055,0.438±0.042,0.400±0.046,0.416±0.022,0.501±0.086,0.550±0.096,0.397±0.059,0.438±0.032
val,1.0,1.0,0.956±0.004,0.574±0.058,0.625±0.055,0.547±0.056,0.648±0.101,0.690±0.030,0.607±0.051,0.623±0.031,0.534±0.075
val,2.0,0.5,0.956±0.004,0.439±0.115,0.390±0.109,0.292±0.049,0.397±0.116,0.265±0.037,0.490±0.054,0.380±0.121,0.430±0.051
val,2.0,1.0,0.956±0.004,0.262±0.020,0.263±0.019,0.305±0.041,0.482±0.088,0.347±0.064,0.497±0.107,0.389±0.101,0.409±0.082
