In [10]:
import pandas as pd
import json
from IPython.display import display

In [11]:
def load_data_to_df(filepath):
    """Read an experiment result file into a DataFrame indexed by (rollout, config_id).

    The file is parsed as JSON and its ``'results'`` entry is iterated; every
    item is turned into its own ``pd.DataFrame`` and the frames are stacked.
    The dicts stored in the ``'hyperparam'`` column are expanded into one
    column per key and the original column is dropped.

    Args:
        filepath: Path of the JSON result file.

    Returns:
        DataFrame whose two row-index levels are named ``'rollout'`` (position
        of the item in ``'results'``) and ``'config_id'`` (row label inside
        that item's frame).
    """
    with open(filepath, 'rb') as handle:
        payload = json.load(handle)
    rollouts = payload['results']

    # keys=... adds the outer index level that identifies the rollout.
    per_rollout = [pd.DataFrame(rollout) for rollout in rollouts]
    stacked = pd.concat(per_rollout, axis=0, keys=range(len(rollouts)))

    # One column per hyperparameter name instead of a single column of dicts.
    hparam_columns = stacked['hyperparam'].apply(pd.Series)
    expanded = pd.concat([stacked, hparam_columns], axis=1).drop(columns='hyperparam')
    return expanded.rename_axis(['rollout', 'config_id'], axis='index')


def select_hparams(df, hparam_config):
    """Keep only the rows whose hyperparameter columns match ``hparam_config``.

    Args:
        df: DataFrame with one column per hyperparameter.
        hparam_config: Mapping ``column name -> allowed value`` or
            ``column name -> list of allowed values``. A falsy value
            (``None`` or an empty dict) disables filtering.

    Returns:
        ``df`` itself when no filter is given, otherwise the filtered frame.
        All conditions are combined with a logical AND.
    """
    if not hparam_config:
        return df

    for column, allowed in hparam_config.items():
        # A single non-list value is treated as a one-element list of allowed values.
        allowed_values = allowed if isinstance(allowed, list) else [allowed]
        df = df[df[column].isin(allowed_values)]
    return df


def get_mean_and_sem(df, test_metric='test_risk', val_metric='val_loss', hparam_config=None):
    """Mean and standard error of ``test_metric`` after per-rollout model selection.

    For every rollout the row with the smallest ``val_metric`` is picked, and
    the ``test_metric`` values of the picked rows are aggregated.

    Args:
        df: Result frame with an index level named ``'rollout'``.
        test_metric: Column that is averaged.
        val_metric: Column that is minimised to pick one row per rollout.
        hparam_config: Optional filter forwarded to ``select_hparams``.

    Returns:
        Tuple ``(mean, sem)`` of the selected ``test_metric`` values.
    """
    if hparam_config:
        df = select_hparams(df, hparam_config)

    # Index label of the best (lowest validation metric) row of each rollout.
    best_rows = df.groupby('rollout')[val_metric].idxmin()
    selected_scores = df.loc[best_rows][test_metric]
    return selected_scores.mean(), selected_scores.sem()


# Best hparam configs
def get_best_hparam_results(df, metric='val_loss', num_best=20, hparam_config=None):
    """Rank hyperparameter configurations by their mean ``metric`` across rollouts.

    Args:
        df: Result frame with an index level (or column) named ``'config_id'``.
        metric: Numeric column used for ranking; smaller is better.
        num_best: Maximum number of configurations to return.
        hparam_config: Optional filter forwarded to ``select_hparams``.

    Returns:
        DataFrame indexed by ``config_id`` holding the per-configuration mean
        of every numeric column, sorted ascending by ``metric`` and truncated
        to the first ``num_best`` rows.
    """
    if hparam_config:
        df = select_hparams(df, hparam_config)

    # numeric_only=True: result frames also carry non-numeric columns (strings,
    # dicts). pandas >= 2.0 no longer drops those silently and raises a
    # TypeError instead, so exclude them explicitly.
    config_means = df.groupby('config_id').mean(numeric_only=True)
    return config_means.sort_values(by=metric)[:num_best]

In [12]:
# Merge datasets
def load_and_merge_datasets(filepaths, property_dict=None, merge='hparam_configs'):
    """Load several result files and stack them into one (rollout, config_id) frame.

    Every file is read with ``load_data_to_df``. Files that do not exist are
    skipped, and the number of files found is printed. To keep entries of
    different files apart, the ids of the merged index level are shifted so
    that each file continues counting where the previous one stopped.

    Args:
        filepaths: A single path or a list of paths to JSON result files.
        property_dict: Optional one-entry mapping ``{column_name: values}``
            with one value per file; the value is written to ``column_name``
            for all rows of that file. Only the first key is used. When
            omitted, a ``'version'`` column holding the file's position in
            ``filepaths`` is added instead.
        merge: ``'hparam_configs'`` to renumber ``config_id`` across files
            or ``'rollouts'`` to renumber ``rollout`` across files.

    Returns:
        DataFrame indexed by ``['rollout', 'config_id']``.

    Raises:
        NotImplementedError: If ``merge`` is not one of the supported modes.
        ValueError: If none of the given files could be found.
    """
    if isinstance(filepaths, str):
        filepaths = [filepaths]

    if merge == 'rollouts':
        merge_property = 'rollout'
    elif merge == 'hparam_configs':
        merge_property = 'config_id'
    else:
        raise NotImplementedError(
            f"Unknown merge mode {merge!r}; expected 'rollouts' or 'hparam_configs'."
        )

    if not property_dict:
        prop_name = 'version'
        vals = range(len(filepaths))
    else:
        prop_name = next(iter(property_dict))
        vals = property_dict[prop_name]

    start_merge_id = 0
    dfs = []
    for filepath, prop in zip(filepaths, vals):
        try:
            data_frame = load_data_to_df(filepath)
        except FileNotFoundError:
            # Best effort: missing result files are skipped and only counted below.
            continue

        # These bookkeeping columns are not needed downstream; errors='ignore'
        # keeps files that never stored them loadable.
        data_frame = data_frame.drop(
            columns=['test_risk_optim', 'parameter_mse_optim', 'best_index'],
            errors='ignore',
        )
        data_frame[prop_name] = prop

        # Shift the ids of this file past the ids already used by earlier files.
        data_frame = data_frame.reset_index()
        data_frame[merge_property] += start_merge_id
        start_merge_id = data_frame[merge_property].max() + 1
        dfs.append(data_frame)

    print(f'Found {len(dfs)}/{len(filepaths)} files.')

    if not dfs:
        # pd.concat would fail with an unspecific "No objects to concatenate".
        raise ValueError(f'No result files could be loaded from: {list(filepaths)}')

    df = pd.concat(dfs, ignore_index=True)
    df = df.set_index(['rollout', 'config_id'])
    return df

# Load the KMM-RF result file individually and through the merge helper, then
# print (mean, SEM) of the test metric for each resulting frame.
# NOTE(review): filepath1 and filepath2 are the same string although
# property_dict labels them as two different divergences ('kl' / 'log') —
# confirm which file the second entry should point to.
filepath1 = '../results/bennet_hetero/bennet_hetero_method=KMM-RF_n=2000.json'
filepath2 = '../results/bennet_hetero/bennet_hetero_method=KMM-RF_n=2000.json'
fps = [filepath1, filepath2]
property_dict = {'divergence': ['kl', 'log']}

df1 = load_data_to_df(filepath1)
df2 = load_data_to_df(filepath2)
df3 = load_and_merge_datasets(fps, property_dict)

print(get_mean_and_sem(df1))
print(get_mean_and_sem(df2))
print(get_mean_and_sem(df3))

# Configs of the merged frame ranked by mean test risk.
best = get_best_hparam_results(df3, metric='test_risk')
# NOTE(review): this bare expression is not the last statement of the cell, so
# the notebook does not render it; only the final `best` below is displayed.
best

# Same ranking for the single file loaded through the merge helper; overwrites `best`.
df = load_and_merge_datasets('../results/bennet_hetero/bennet_hetero_method=KMM-RF_n=2000.json')
best = get_best_hparam_results(df, metric='test_risk', num_best=20)
best

# EXP1: Bennett Hetero

In [22]:
# Common prefix of the EXP1 result files; the cells below append
# '<method>_n=2000_seed0=12345.json' to it.
# NOTE(review): absolute, user-specific path — only resolves on the author's machine.
exp_path_1 = "/Users/hkremer/code/kmm/wasserstein-method-of-moments/results/bennet_hetero/bennet_hetero_method="

## VMM

In [23]:
# Load the VMM-neural results and rank the hyperparameter configs by their
# mean validation loss across rollouts (lowest first).
vmm = load_and_merge_datasets(exp_path_1 + "VMM-neural_n=2000_seed0=12345.json")
best_vmm = get_best_hparam_results(vmm, metric='val_loss',)
best_vmm

Found 1/1 files.


Unnamed: 0_level_0,test_risk,mse,val_loss,reg_param,version
config_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
3,0.270553,0.253146,0.819964,1.0,0.0
2,0.239885,0.167199,0.82677,0.01,0.0
0,0.226658,0.150483,0.828801,0.0,0.0
1,0.235654,0.188516,0.828997,0.0001,0.0


## FGEL

In [24]:
# Load the FGEL-neural results, keep only rows with divergence == 'log', and
# rank the remaining configs by their mean validation loss (lowest first).
fgel = load_and_merge_datasets(exp_path_1 + "FGEL-neural_n=2000_seed0=12345.json")
best_fgel = get_best_hparam_results(fgel, metric='val_loss', hparam_config={'divergence': 'log'})
best_fgel

Found 1/1 files.


Unnamed: 0_level_0,test_risk,mse,val_loss,reg_param,version
config_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
11,0.18729,0.167554,0.832788,1.0,0.0
8,0.448142,0.796722,0.908296,0.01,0.0
2,0.47352,0.859253,0.909634,0.0,0.0
5,0.49831,0.890296,0.918164,0.0001,0.0


## KMM

In [15]:
# Build one result-file path per KMM method and merge all files that exist
# into a single frame (missing files are skipped by load_and_merge_datasets).
filepaths = []

# kmm_methods comes from the project config; its elements are interpolated
# into the file names below.
from cmr.default_config import kmm_methods

for method in kmm_methods:
    path = (exp_path_1 + f'{method}_n=2000_seed0=12345.json')
    filepaths.append(path)
    
# No property_dict is passed, so each file is tagged with a 'version' column
# holding its position in `filepaths`.
df = load_and_merge_datasets(filepaths)

In [17]:
# Render the merged KMM frame (index: rollout, config_id) for inspection.
df

Unnamed: 0_level_0,Unnamed: 1_level_0,test_risk,mse,val_loss,train_stats,n_reference_samples,entropy_reg_param,reg_param,kde_bw,n_random_features,val_loss_func,version
rollout,config_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
0,0,2.144071,4.868012,0.000125,"{'epochs': 240, 'val_loss': [4.353307303972542...",0,1.0,10.0,1.0,5000,mmr,28
1,0,0.761713,1.186237,0.000113,"{'epochs': 240, 'val_loss': [3.074266351177357...",0,1.0,10.0,1.0,5000,mmr,28
2,0,0.971003,2.106306,8.4e-05,"{'epochs': 280, 'val_loss': [6.448858039220795...",0,1.0,10.0,1.0,5000,mmr,28
3,0,1.147568,2.219241,0.000396,"{'epochs': 240, 'val_loss': [0.000220777743379...",0,1.0,10.0,1.0,5000,mmr,28
4,0,1.266955,2.195853,0.000178,"{'epochs': 240, 'val_loss': [9.748094453243539...",0,1.0,10.0,1.0,5000,mmr,28
5,0,0.840369,1.557705,0.000318,"{'epochs': 240, 'val_loss': [0.000116533803520...",0,1.0,10.0,1.0,5000,mmr,28
6,0,1.033562,1.317618,0.000103,"{'epochs': 240, 'val_loss': [7.977530913194641...",0,1.0,10.0,1.0,5000,mmr,28
7,0,0.39817,0.85515,0.000411,"{'epochs': 240, 'val_loss': [0.000278413703199...",0,1.0,10.0,1.0,5000,mmr,28
8,0,0.961727,2.065455,0.000287,"{'epochs': 240, 'val_loss': [0.000134243484353...",0,1.0,10.0,1.0,5000,mmr,28
9,0,1.391199,2.747568,0.000125,"{'epochs': 240, 'val_loss': [7.540157821495086...",0,1.0,10.0,1.0,5000,mmr,28


In [26]:
# Rank the KMM configs by mean validation loss, restricted to the given
# hyperparameter values (all conditions must hold at once).
# NOTE(review): the output recorded for this cell is a table without rows,
# i.e. no run matched this filter — confirm the values.
get_best_hparam_results(df, metric='val_loss', num_best=20, 
                        hparam_config={'val_loss_func': 'moment_violation',
                                       #"n_reference_samples": [200], 
                                       "entropy_reg_param": [1000],
                                       #"reg_param": [0.01],
                                       'kde_bw': 0.1,
                                      })

Unnamed: 0_level_0,test_risk,mse,val_loss,n_reference_samples,entropy_reg_param,reg_param,kde_bw,n_random_features,version
config_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1


In [28]:
# (mean, SEM) of the test risk over rollouts for one fixed KMM setting; per
# rollout the row with the lowest validation loss is used.
# NOTE(review): the output recorded for this cell is (nan, nan), which is what
# the function yields when the filter leaves no rows — confirm the values.
get_mean_and_sem(df, test_metric='test_risk', val_metric='val_loss', hparam_config={
                                       'val_loss_func': 'moment_violation',
                                       "n_reference_samples": [200], 
                                       "entropy_reg_param": [1],
                                       "reg_param": [1],
                                       'kde_bw': 0.1}
                )

(nan, nan)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from cmr.utils.plot import NEURIPS_RCPARAMS, LINE_WIDTH

# Apply the seaborn theme first: sns.set_theme() rewrites matplotlib's global
# rcParams, so calling it after the update would reset every NEURIPS_RCPARAMS
# entry that the theme also sets.
sns.set_theme()
plt.rcParams.update(NEURIPS_RCPARAMS)
figsize = (LINE_WIDTH/1.3, LINE_WIDTH / 1.8)

fig, ax = plt.subplots(1, 1, figsize=figsize)

# Test risk as a function of the number of reference samples, one curve per divergence.
# NOTE(review): select_hparams indexes df by every key of hparam_config; make
# sure df actually has 'rkhs_func_z_dependent' and 'divergence' columns,
# otherwise this raises a KeyError.
n_refs = [0, 100, 200]
for divergence in ['kl', 'log']:
    mean = []
    std = []  # holds the standard error of the mean returned by get_mean_and_sem
    for n_ref in n_refs:
        m, s = get_mean_and_sem(
            df,
            test_metric='test_risk',
            val_metric='val_loss',
            hparam_config={
                'rkhs_func_z_dependent': 1.0,
                'val_loss_func': 'moment_violation',
                "n_reference_samples": n_ref,
                "entropy_reg_param": [1],
                "reg_param": [1],
                'kde_bw': 0.1,
                'divergence': divergence,
            },
        )
        mean.append(m)
        std.append(s)
    # Label each curve so the two divergences can be told apart in the legend.
    ax.errorbar(n_refs, mean, std, label=divergence)

ax.set_xlabel('n_reference_samples')
ax.set_ylabel('test_risk')
ax.legend(title='divergence');
    

# EXP2: Network IV

In [101]:
# Settings shared by the EXP2 cells below.
# NOTE(review): absolute, user-specific path — only resolves on the author's machine.
exp_path_2 = "/Users/hkremer/code/kmm/wasserstein-method-of-moments/results/network_iv/network_iv_method="
func = 'sin'  # inserted into the result file names ('..._{func}.json')
metric = 'val_loss'  # column the configs are ranked by
val_loss_func = 'moment_violation'  # value the 'val_loss_func' column is filtered on

## Load Data

In [102]:
from cmr.default_config import fgel_methods, kmm_methods, vmm_methods


def _exp2_filepaths(methods):
    """Return the EXP2 result-file path of every method name in ``methods``.

    Uses the notebook-level ``exp_path_2`` prefix and ``func`` suffix, so both
    must be defined before this is called.
    """
    return [exp_path_2 + f'{method}_n=2000_seed0=12345_{func}.json' for method in methods]


# One merged frame per method family; files that do not exist are skipped and
# the number of files found is printed for each family (VMM, FGEL, KMM order).
vmm = load_and_merge_datasets(_exp2_filepaths(vmm_methods))

fgel = load_and_merge_datasets(_exp2_filepaths(fgel_methods))

# `filepaths` is kept as a notebook-level name holding the KMM paths.
filepaths = _exp2_filepaths(kmm_methods)
kmm = load_and_merge_datasets(filepaths)

Found 9/10 files.
Found 28/30 files.
Found 456/512 files.


## VMM

In [103]:
# Rank the EXP2 VMM configs by mean `metric` (lowest first), restricted to
# rows whose 'val_loss_func' equals `val_loss_func`.
best_vmm = get_best_hparam_results(vmm, metric=metric, hparam_config={'val_loss_func': val_loss_func})
best_vmm

Unnamed: 0_level_0,test_risk,mse,val_loss,reg_param,version
config_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
8,0.08648,0.0,0.977255,10.0,9.0
6,0.083299,0.0,1.011242,1.0,7.0
5,0.084308,0.0,1.03499,0.01,5.0
1,0.081761,0.0,1.036589,0.0,1.0
3,0.087129,0.0,1.048433,0.0001,3.0


## FGEL

In [104]:
# Rank the EXP2 FGEL configs by mean `metric` (lowest first), restricted to
# rows whose 'val_loss_func' equals `val_loss_func`.
best_fgel = get_best_hparam_results(fgel, metric=metric, hparam_config={'val_loss_func': val_loss_func})
best_fgel

Unnamed: 0_level_0,test_risk,mse,val_loss,reg_param,version
config_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
24,0.086449,0.0,0.969225,10.0,25.0
27,0.086634,0.0,0.972099,10.0,29.0
23,0.082209,0.0,1.011087,1.0,23.0
19,0.083907,0.0,1.013788,1.0,19.0
21,0.085495,0.0,1.016115,1.0,21.0
3,0.086074,0.0,1.062634,0.0,3.0
11,0.090044,0.0,1.064716,0.0001,11.0
5,0.089608,0.0,1.065058,0.0,5.0
15,0.090452,0.0,1.06514,0.01,15.0
7,0.09002,0.0,1.065865,0.0001,7.0


## KMM

In [106]:
# Rank the EXP2 KMM configs by mean `metric` (lowest first) for one fixed
# setting of n_reference_samples, reg_param and entropy_reg_param; the other
# hyperparameters are left free.
best_kmm = get_best_hparam_results(kmm, metric=metric, hparam_config={'val_loss_func': val_loss_func,
                                                                      'n_reference_samples': 400,
                                                                      "reg_param": 1,
                                                                      'entropy_reg_param': 1,
                                                                     })
best_kmm

Unnamed: 0_level_0,test_risk,mse,val_loss,n_reference_samples,entropy_reg_param,reg_param,kde_bw,n_random_features,version
config_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
355,0.083886,0.0,1.008171,400.0,1.0,1.0,0.1,10000.0,403.0
359,0.08652,0.0,1.010481,400.0,1.0,1.0,1.0,10000.0,407.0
353,0.082441,0.0,1.011908,400.0,1.0,1.0,0.1,5000.0,401.0
357,0.082441,0.0,1.011908,400.0,1.0,1.0,1.0,5000.0,405.0
