In [None]:
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Load the per-seed experiment results of both datasets into a single frame.
RESULT_FILES = ['experiment_output_mnist-multiseed.csv',
                'experiment_output_cifar10-multiseed.csv']
df = pd.concat([pd.read_csv(path) for path in RESULT_FILES]).reset_index(drop=True)
# Drop the first column (presumably the unnamed index written by `to_csv` — confirm).
df = df.drop(columns=df.columns[0])
# Remove constant columns: keep only columns with at least one value differing from row 0.
df = df.loc[:, (df != df.iloc[0]).any()]

# First floating-point literal in a string such as "tensor(0.1234)" or
# "tensor(1.2345e-05, device='cuda:0')". A fixed-width slice would truncate values
# >= 10 and drop the exponent of scientific-notation values.
_FLOAT_RE = re.compile(r'[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?')


def _parse_loss(value):
    """Convert a logged loss (stringified tensor or plain number) to a float.

    Strings are parsed by extracting their first float literal; any other value
    (already-numeric cells, NaN) is passed through `float`.

    Raises:
        ValueError: if a string contains no float literal.
    """
    if isinstance(value, str):
        match = _FLOAT_RE.search(value)
        if match is None:
            raise ValueError(f'could not parse a float from {value!r}')
        return float(match.group())
    return float(value)


df['Train xe loss'] = df['Train xe loss'].apply(_parse_loss)
df

In [None]:
def binKL(p, q):
    return p * np.log(p/q) + (1-p) * np.log((1-p)/(1-q))

def KLinv(x, k, d=1e-10):
    # sup{b in [x, 1] | kl(x||b)<k}
    
    #assert x<1
    b0 = x
    b1 = 1
    while b1 - b0 > d:
        tmp = (b0 + b1)/2
        if binKL(x, tmp) < k:
            b0 = tmp
        else:
            b1 = tmp
            
    return b0

In [None]:
# kl-inverse bounds on the 0-1 and cross-entropy risks, one value per run:
#   target_*: bound that treats the training loss as the exact expected empirical
#             risk, i.e. as if infinitely many Monte Carlo samples had been drawn.
#   150k_*:   the training loss is first inflated by a sampling bound for
#             MC_SAMPLES Monte Carlo samples at confidence MC_DELTA, then the kl
#             is inverted again with row['KL'].
# NOTE(review): row['KL'] is used directly as the right-hand side k of
# kl(. || .) < k, so it presumably already includes the log term and the division
# by n — confirm against the experiment code.
MC_SAMPLES = 150_000  # the '150k_*' column names below assume this value
MC_DELTA = 0.01
mc_slack = np.log(2 / MC_DELTA) / MC_SAMPLES  # loop-invariant sampling term

for loss_tag, train_col in [('01', 'Train 01 loss'), ('xe', 'Train xe loss')]:
    target_bounds, mc_bounds = [], []
    for _, row in df.iterrows():
        target_bounds.append(KLinv(row[train_col], row['KL']))
        emp_bound = KLinv(row[train_col], mc_slack)
        mc_bounds.append(KLinv(emp_bound, row['KL']))
    df[f'target_{loss_tag}_bounds'] = target_bounds
    df[f'150k_{loss_tag}_bounds'] = mc_bounds

display(df.sort_values(['target_01_bounds']))

In [None]:
# Columns whose per-seed values are summarised (mean / 2-sigma) in the next cell.
cols_of_interest = ['Train 01 loss', 'Train xe loss', '150k_01_bounds', '150k_xe_bounds', 'KL']
# One row per configuration; every other column becomes the list of its values
# across seeds. NaNs are replaced by '-' beforehand so that configurations with a
# missing key are not dropped by groupby. The replacement applies to all columns,
# not just the keys.
group_keys = ['Dataset', 'objective', 'sigma_prior', 'layers']
df2 = (
    df.fillna('-')
      .groupby(group_keys)
      .agg(list)
)

In [None]:
# Collapse each per-seed list into its mean and a 2-sigma spread. np.std is the
# population standard deviation (ddof=0).
# NOTE(review): this cell overwrites df2. Re-running it without first re-running the
# groupby cell raises KeyError, because the list columns no longer exist.
summary_columns = {}
for col in cols_of_interest:
    per_seed = df2[col]
    summary_columns[f'{col} - mean'] = per_seed.apply(np.mean)
    summary_columns[f'{col} - 2sigma'] = 2 * per_seed.apply(np.std)
df2 = pd.DataFrame(summary_columns, index=df2.index).reset_index()

## MNIST RESULTS

In [None]:
# MNIST configurations only. For each objective, keep the row with the lowest mean
# 150k cross-entropy bound: drop_duplicates keeps the first row, which is the
# smallest after the ascending sort.
df_mnist = df2.loc[df2['Dataset'].eq('mnist')]
df_mnist_xe = (
    df_mnist
    .sort_values(by='150k_xe_bounds - mean')
    .drop_duplicates(subset=['objective'], keep='first')
)
df_mnist_xe

In [None]:
# For each MNIST objective, keep the row with the lowest mean 150k 0-1 bound.
# drop_duplicates keeps the first row, which is the smallest after the ascending sort.
df_mnist_01 = (
    df_mnist
    .sort_values(by='150k_01_bounds - mean')
    .drop_duplicates(subset=['objective'], keep='first')
)
df_mnist_01

## CIFAR10 RESULTS

In [None]:
# CIFAR-10 configurations only. For each (objective, layers) pair, keep the row with
# the lowest mean 150k 0-1 bound. The selection is displayed but not stored.
df_cifar = df2.loc[df2['Dataset'].eq('cifar10')]
(
    df_cifar
    .sort_values(by='150k_01_bounds - mean')
    .drop_duplicates(subset=['objective', 'layers'], keep='first')
)