In [1]:
import pandas as pd
import json
from IPython.display import display

In [111]:
def load_data_to_df(filepath):
    """Read a JSON results file into a DataFrame indexed by (rollout, config_id).

    The file must contain a JSON object with a ``'results'`` entry holding a
    list. Every list item is converted with ``pd.DataFrame`` and has to provide
    a ``'hyperparam'`` column; each entry of that column is expanded through
    ``pd.Series`` (presumably a dict of hyperparameter name -> value; verify
    against the code that writes the result files).

    Parameters
    ----------
    filepath : str or path-like
        Location of the JSON results file.

    Returns
    -------
    pandas.DataFrame
        All list items stacked vertically. The outer index level ``'rollout'``
        is the position of the item inside ``'results'``; the inner level
        ``'config_id'`` is the row label within that item's frame. The
        ``'hyperparam'`` column is replaced by one column per expanded key.
    """
    with open(filepath, 'rb') as handle:
        rollouts = json.load(handle)['results']

    # One frame per rollout, stacked with the rollout position as outer key.
    frames = []
    for item in rollouts:
        frames.append(pd.DataFrame(item))
    stacked = pd.concat(frames, axis=0, keys=range(len(frames)))

    # Spread the hyperparameter entries into columns of their own.
    hparam_columns = stacked['hyperparam'].apply(pd.Series)
    expanded = pd.concat([stacked, hparam_columns], axis=1).drop(columns='hyperparam')

    return expanded.rename_axis(['rollout', 'config_id'], axis='index')


def get_mean_and_sem(df, test_metric='test_risk', val_metric='val_loss', hparam_config=None):
    """Mean and standard error of a test metric after per-rollout model selection.

    Parameters
    ----------
    df : pandas.DataFrame
        Results frame that can be grouped by ``'rollout'`` (e.g. an index level
        of that name) and contains the ``val_metric`` and ``test_metric``
        columns.
    test_metric : str
        Column that is averaged over the selected rows.
    val_metric : str
        Column that is minimised inside every rollout to pick one row.
    hparam_config : dict, optional
        Mapping column name -> required value. When given (and non-empty), only
        rows matching every entry take part in the selection.

    Returns
    -------
    tuple
        ``(mean, sem)`` of ``test_metric`` across the rows selected per rollout.
    """
    candidates = df

    # Optional restriction to one specific hyperparameter setting.
    for column, wanted in (hparam_config or {}).items():
        candidates = candidates[candidates[column] == wanted]

    # For every rollout keep only the row with the smallest validation metric.
    best_labels = candidates.groupby('rollout')[val_metric].idxmin()
    scores = candidates.loc[best_labels][test_metric]

    return scores.mean(), scores.sem()


# Best hparam configs
def get_best_hparam_results(df, metric='val_loss', num_best=5):
    """Return the hyperparameter configurations with the lowest average metric.

    Every numeric column is averaged over all rows that share a ``config_id``;
    the configurations are then sorted ascending by ``metric`` and the first
    ``num_best`` are returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Results frame that can be grouped by ``'config_id'`` (e.g. an index
        level of that name).
    metric : str
        Numeric column used for ranking (smaller is better).
    num_best : int
        Number of configurations to return.

    Returns
    -------
    pandas.DataFrame
        One row per selected ``config_id`` holding the per-configuration means
        of the numeric columns.
    """
    # numeric_only=True: non-numeric columns (such as a string label column
    # added when merging datasets) cannot be averaged, and pandas >= 2.0 raises
    # a TypeError for them instead of silently dropping them.
    per_config = df.groupby('config_id').mean(numeric_only=True)
    return per_config.sort_values(by=metric).iloc[:num_best]

In [112]:
# Merge datasets
def load_and_merge_datasets(filepaths, property_dict=None, merge='hparam_configs'):
    """Load several result files and concatenate them into one results frame.

    Each file is read with ``load_data_to_df``, stripped of the columns
    ``'test_risk_optim'``, ``'parameter_mse_optim'`` and ``'best_index'`` (which
    must be present), and tagged with a per-file label. To keep entries of
    different files apart, the ids of one index level are shifted so that every
    file starts right after the largest id used by the files before it.

    Parameters
    ----------
    filepaths : sequence of str
        Result files to merge, in order.
    property_dict : dict, optional
        ``{column_name: [label_for_file_0, label_for_file_1, ...]}``. Only the
        first key is used. If omitted or empty, a ``'version'`` column holding
        the position of the file in ``filepaths`` is added instead.
    merge : {'hparam_configs', 'rollouts'}
        Which id is shifted per file: ``'hparam_configs'`` shifts ``config_id``
        (files share rollouts and contribute additional configurations),
        ``'rollouts'`` shifts ``rollout`` (files contribute additional rollouts).

    Returns
    -------
    pandas.DataFrame
        Concatenation of all files, indexed by ``('rollout', 'config_id')``.

    Raises
    ------
    NotImplementedError
        If ``merge`` is not one of the supported modes.
    ValueError
        If ``filepaths`` is empty or the number of labels differs from the
        number of files.
    """
    if merge == 'rollouts':
        merge_property = 'rollout'
    elif merge == 'hparam_configs':
        merge_property = 'config_id'
    else:
        raise NotImplementedError(
            "Unknown merge mode {!r}; expected 'rollouts' or 'hparam_configs'.".format(merge)
        )

    filepaths = list(filepaths)
    if not property_dict:
        prop_name = 'version'
        vals = list(range(len(filepaths)))
    else:
        prop_name = next(iter(property_dict))
        vals = list(property_dict[prop_name])

    # zip() below stops at the shorter input, so a mismatch would silently
    # skip files (or labels); fail loudly instead.
    if len(vals) != len(filepaths):
        raise ValueError(
            "Got {} value(s) for property {!r} but {} file path(s).".format(
                len(vals), prop_name, len(filepaths))
        )
    if not filepaths:
        raise ValueError("filepaths must contain at least one result file.")

    next_free_id = 0
    frames = []

    for filepath, prop in zip(filepaths, vals):
        frame = load_data_to_df(filepath)
        frame = frame.drop(columns=['test_risk_optim', 'parameter_mse_optim', 'best_index'])
        frame[prop_name] = prop

        # Turn the index levels into columns so the chosen id can be shifted
        # past the ids already taken by the previously loaded files.
        frame = frame.reset_index()
        frame[merge_property] += next_free_id
        if not frame.empty:
            # max() of an empty column is NaN and would poison later offsets.
            next_free_id = frame[merge_property].max() + 1
        frames.append(frame)

    merged = pd.concat(frames, ignore_index=True)
    return merged.set_index(['rollout', 'config_id'])

In [113]:
# Two result files of the same experiment; the file names differ only in the
# "ref-kl" / "ref-log" part.
filepath1 = '../results/bennet_hetero/bennet_hetero_method=KMM-RF-1x-ref-kl_n=2000.json'
filepath2 = '../results/bennet_hetero/bennet_hetero_method=KMM-RF-1x-ref-log_n=2000.json'
fps = [filepath1, filepath2]
# Label added to the rows of each file when merging, in the same order as `fps`.
property_dict = {'divergence': ['kl', 'log']}

# df1 / df2: each file on its own. df3: both files merged with the default
# merge='hparam_configs', i.e. the config ids of the second file are shifted
# past those of the first so both sets of configurations compete per rollout.
df1 = load_data_to_df(filepath1)
df2 = load_data_to_df(filepath2)
df3 = load_and_merge_datasets(fps, property_dict)

# (mean, sem) of test_risk after picking, per rollout, the row with the
# lowest val_loss.
print(get_mean_and_sem(df1))
print(get_mean_and_sem(df2))
print(get_mean_and_sem(df3))

# Configurations of the merged frame ranked by their average test_risk
# (default num_best=5); the bare last expression renders the table.
best = get_best_hparam_results(df3, metric='test_risk')
best

(1.2280433942901667, 0.05006893609756681)
(1.227631812576569, 0.05283398339489353)
(1.233904024195733, 0.05280786517835165)


Unnamed: 0_level_0,test_risk,mse,val_loss,entropy_reg_param,reg_param,kde_bw
config_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
45,1.226791,2.23915,0.000919,100.0,0.01,0.5
21,1.231208,2.247808,0.000869,100.0,0.01,0.5
9,1.265049,2.2626,0.000884,10.0,0.0,0.5
20,1.265474,2.265193,0.000951,100.0,0.01,0.1
46,1.272552,2.272989,0.000863,100.0,1.0,0.1
