In [1]:
import pandas as pd
import json
from IPython.display import display

In [2]:
def load_data_to_df(filepath):
    """Read a JSON results file into a frame indexed by (rollout, config_id).

    The file must contain a top-level ``'results'`` list. Every entry of that
    list is turned into its own DataFrame; the frames are stacked vertically
    with the entry's position in the list as the outer index level.

    The ``'hyperparam'`` column (presumably one dict per row -- confirm
    against the files being loaded) is expanded into one column per key and
    then removed.

    Parameters
    ----------
    filepath : str or path-like
        Location of the JSON file.

    Returns
    -------
    pandas.DataFrame
        Index levels are named ``'rollout'`` (outer) and ``'config_id'``
        (inner, the row index of each per-entry frame).
    """
    with open(filepath, 'rb') as handle:
        rollouts = json.load(handle)['results']

    # One frame per rollout, keyed by the rollout's position in the list.
    per_rollout = [pd.DataFrame(rollout) for rollout in rollouts]
    stacked = pd.concat(per_rollout, axis=0, keys=range(len(rollouts)))

    # Spread the hyperparameter entries over their own columns.
    hparam_columns = stacked['hyperparam'].apply(pd.Series)
    flat = pd.concat([stacked, hparam_columns], axis=1).drop(columns='hyperparam')

    return flat.rename_axis(['rollout', 'config_id'], axis='index')


def get_mean_and_sem(df, test_metric='test_risk', val_metric='val_loss', hparam_config=None):
    """Summarise the test metric of the best configuration of every rollout.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose index has a level named ``'rollout'``.
    test_metric : str
        Column that is averaged.
    val_metric : str
        Column used for model selection; the row with the smallest value is
        kept for each rollout.
    hparam_config : dict, optional
        ``{column: value}`` filters; only rows matching every pair are
        considered. Ignored when empty or ``None``.

    Returns
    -------
    tuple
        ``(mean, standard error of the mean)`` of ``test_metric`` over the
        selected rows, one row per rollout.
    """
    candidates = df
    # Restrict to the requested hyperparameter values, if any.
    for column, wanted in (hparam_config or {}).items():
        candidates = candidates[candidates[column] == wanted]

    # Index label of the lowest-validation-metric row within each rollout.
    best_labels = candidates.groupby('rollout')[val_metric].idxmin()
    best_scores = candidates.loc[best_labels][test_metric]

    return best_scores.mean(), best_scores.sem()

In [3]:
# Merge datasets
def load_and_merge_datasets(filepaths, property_dict=None):
    """Load several result files and stack them into a single frame.

    Every file is read with ``load_data_to_df``, tagged with a per-file value
    in an extra column, and its ``config_id`` values are shifted so that ids
    remain unique across files (the second file continues numbering where the
    first one stopped, and so on).

    The columns ``'test_risk_optim'``, ``'parameter_mse_optim'`` and
    ``'best_index'`` are removed from each file when present.

    Parameters
    ----------
    filepaths : sequence of str or path-like
        Files to load, in merge order.
    property_dict : dict, optional
        Mapping ``{column_name: [value_for_file_0, value_for_file_1, ...]}``.
        Only the first key is used. When omitted or empty, a ``'version'``
        column holding each file's position in ``filepaths`` is added instead.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all files, indexed by ``('rollout', 'config_id')``.
    """
    if not property_dict:
        prop_name = 'version'
        vals = range(len(filepaths))
    else:
        prop_name = next(iter(property_dict))
        vals = property_dict[prop_name]

    unused_columns = ['test_risk_optim', 'parameter_mse_optim', 'best_index']
    dfs = []
    start_config_id = 0

    for filepath, prop in zip(filepaths, vals):
        data_frame = load_data_to_df(filepath)
        # DataFrame.drop returns a new frame, so the result must be kept.
        # errors='ignore' tolerates files that lack some of these columns.
        data_frame = data_frame.drop(columns=unused_columns, errors='ignore')
        data_frame[prop_name] = prop

        # Turn the (rollout, config_id) index levels into ordinary columns.
        data_frame = data_frame.reset_index()

        if not data_frame.empty:
            # Offset the per-file ids instead of rebuilding them from a
            # repeated range: same result when every rollout holds configs
            # 0..k-1, and still valid when rollouts differ in length.
            data_frame['config_id'] = data_frame['config_id'] + start_config_id
            start_config_id = int(data_frame['config_id'].max()) + 1

        dfs.append(data_frame)

    df = pd.concat(dfs, ignore_index=True)
    return df.set_index(['rollout', 'config_id'])

In [4]:
# Two result files; in the merged frame they are tagged 'kl' and 'log'
# in the 'divergence' column.
filepath1 = 'bennet_hetero/bennet_hetero_method=KMM-RF-1x-ref-kl_n=2000.json'
filepath2 = 'bennet_hetero/bennet_hetero_method=KMM-RF-1x-ref-log_n=2000.json'
fps = [filepath1, filepath2]
property_dict = {'divergence': ['kl', 'log']}

# Each file on its own, then both merged into one frame.
df1, df2 = map(load_data_to_df, fps)
df3 = load_and_merge_datasets(fps, property_dict)

# One (mean, sem) tuple per line: file 1, file 2, merged.
print(*(get_mean_and_sem(frame) for frame in (df1, df2, df3)), sep='\n')

(1.2280433942901667, 0.05006893609756681)
(1.227631812576569, 0.05283398339489353)
(1.233904024195733, 0.05280786517835165)
