In [88]:
%run constants.ipynb

In [89]:
def get_results_dir(config):
    """Return the results directory (RESULTS_DIR), creating it if it is missing.

    `config` is accepted for call-site compatibility but is not used: every
    config maps to the same RESULTS_DIR.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)  # no-op when the directory already exists
    return RESULTS_DIR

def plot_filepath(task_type, suffix=None):
    """Return the PNG path for a results plot inside the results directory.

    Args:
        task_type: either a task-type value or an object exposing a
            `.task_type` attribute (e.g. the experiment config that
            plot_results passes in).
        suffix: label inserted before "_plot.png"; formatted with str(), so
            None yields "..._None_plot.png".

    Returns:
        "<results_dir>/<task_type>_<suffix>_plot.png"
    """
    # Derive the task type from the argument rather than the notebook-global
    # `config`, so the path no longer depends on hidden kernel state.
    task = getattr(task_type, 'task_type', task_type)
    return f"{get_results_dir(task_type)}/{task}_{suffix}_plot.png"

def data_filepath(task_type, suffix=None):
    """Return the HDF5 path for exported results inside the results directory.

    Args:
        task_type: either a task-type value or an object exposing a
            `.task_type` attribute (e.g. an experiment config).
        suffix: label inserted before "_data.h5" (export/import_results pass
            their `datatype` here); formatted with str(), so None yields
            "..._None_data.h5".

    Returns:
        "<results_dir>/<task_type>_<suffix>_data.h5"
    """
    # Derive the task type from the argument rather than the notebook-global
    # `config`, so the path no longer depends on hidden kernel state.
    task = getattr(task_type, 'task_type', task_type)
    return f"{get_results_dir(task_type)}/{task}_{suffix}_data.h5"

In [91]:
import h5py

def export_results(data, task_type, datatype='mean'):
    """Write every (name, array) pair of the dict `data` as an HDF5 dataset.

    The target path is data_filepath(task_type, suffix=datatype). The file is
    opened in 'w' mode, so an existing file at that path is overwritten.
    """
    target_path = data_filepath(task_type, suffix=datatype)
    with h5py.File(target_path, 'w') as h5_out:
        for dataset_name, values in data.items():
            h5_out.create_dataset(dataset_name, data=values)

def import_results(task_type, datatype='mean'):
    """Load every top-level dataset of the results HDF5 file into a dict.

    Reads data_filepath(task_type, suffix=datatype), the same path that
    export_results writes.

    Returns:
        dict mapping dataset name -> its contents read fully into memory.
    """
    with h5py.File(data_filepath(task_type, suffix=datatype), 'r') as f:
        # `[()]` reads a dataset of any rank. `[:]` raises for scalar (0-d)
        # datasets, which export_results produces when given scalar values.
        return {key: f[key][()] for key in f.keys()}

In [76]:
# plot styling based on model name
def get_markerstyle(m):
    """Return the matplotlib marker for model name `m`, chosen by name prefix.

    Prefixes are checked in order: 'Gaussian' -> 'o', 'ExpGaussian' -> 'v',
    'Exp' -> 's'; any other name gets 'p'. 'ExpGaussian' has to be tested
    before 'Exp' because every 'ExpGaussian...' name also starts with 'Exp'.
    """
    prefix_to_marker = (
        ('Gaussian', 'o'),
        ('ExpGaussian', 'v'),
        ('Exp', 's'),
    )
    return next(
        (marker for prefix, marker in prefix_to_marker if m.startswith(prefix)),
        'p',
    )

def get_linestyle(m):
    """Return 'dashed' when `m` equals VANILLA_MODEL, otherwise 'solid'."""
    return 'dashed' if m == VANILLA_MODEL else 'solid'

In [85]:
def plot_results(config, results, results_std=None, export=True, title='Mean'):
    """Plot one line per model, with error bars when `results_std` is given.

    Args:
        config: experiment config; reads `.tasks` (number of x ticks),
            `.eval_metric` (title and y label) and `.task_type` (export path).
        results: dict model name -> 1-D sequence; element i is drawn at x = i + 1.
        results_std: optional dict with the same keys as `results`, giving the
            per-point error-bar sizes (passed as `yerr`).
        export: if True, also save the figure to plot_filepath(...).
        title: title prefix; its lower-cased form is the export filename suffix.
    """
    MARKER_SIZE = 8
    CAPSIZE = 4
    ELINEWIDTH = 1
    fig, ax = plt.subplots(figsize=(18, 6))
    x_ticks = range(1, len(config.tasks) + 1)
    # models are drawn (and listed in the legend) in reverse-sorted name order
    model_names = sorted(results.keys(), reverse=True)

    for model in model_names:
        x_vals = np.arange(len(results[model])) + 1
        # Same per-model styling in both branches, so a plot without error
        # bars is visually consistent with one that has them.
        style = dict(label=model,
                     marker=get_markerstyle(model),
                     linestyle=get_linestyle(model),
                     markersize=MARKER_SIZE)
        if results_std is not None:
            ax.errorbar(x_vals, results[model], yerr=results_std[model],
                        capsize=CAPSIZE, elinewidth=ELINEWIDTH, **style)
        else:
            ax.plot(x_vals, results[model], **style)

    ax.set_xticks(x_ticks)
    ax.set_title(f"{title} {config.eval_metric}", fontsize=18)
    ax.set_xlabel("# Tasks", fontsize=15)
    ax.set_ylabel(f"{config.eval_metric}", fontsize=15)
    ax.legend(fontsize=15, loc='upper left', bbox_to_anchor=(1, 1))

    fig.tight_layout()
    if export:
        # plot_filepath's first parameter is the task type, so pass it explicitly.
        fig.savefig(plot_filepath(config.task_type, suffix=title.lower()))

    plt.show()

In [81]:
def plot_mean_results(config, results, results_std=None, export=True, mname_filter=None):
    """Plot, per model, the mean over axis 0 of its results, ignoring entries <= 1e-6.

    Args:
        config: experiment config, forwarded to plot_results.
        results: dict model name -> array-like of results.
        results_std: optional dict model name -> array-like of stds; reduced
            the same way (with its own > 1e-6 mask). None plots without
            error bars.
        export: forwarded to plot_results.
        mname_filter: optional predicate on model names; only models for which
            it returns a truthy value are plotted. None keeps all models.

    Returns:
        whatever plot_results returns.
    """
    def get_mean_val(res):
        res = np.asarray(res)
        # Entries <= 1e-6 are excluded from the mean. Presumably they mark
        # not-yet-evaluated cells; TODO confirm against the producer of `results`.
        return np.mean(res, axis=0, where=(res > 1e-6))

    def keep(m):
        return mname_filter is None or mname_filter(m)

    # Filter before reducing so excluded models are never processed.
    mean_results = {m: get_mean_val(res) for m, res in results.items() if keep(m)}
    mean_results_std = None
    if results_std is not None:  # default None previously crashed on .items()
        mean_results_std = {m: get_mean_val(s) for m, s in results_std.items() if keep(m)}
    return plot_results(config, mean_results, results_std=mean_results_std, export=export, title='Mean')

def plot_final_results(config, results, results_std=None, export=True, mname_filter=None):
    """Plot, per model, the last entry along axis 0 of its results (`res[-1]`).

    Args:
        config: experiment config, forwarded to plot_results.
        results: dict model name -> indexable results; `res[-1]` is plotted.
        results_std: optional dict model name -> stds; `s[-1]` gives the error
            bars. None plots without error bars.
        export: forwarded to plot_results.
        mname_filter: optional predicate on model names; only models for which
            it returns a truthy value are plotted. None keeps all models.

    Returns:
        whatever plot_results returns.
    """
    def keep(m):
        return mname_filter is None or mname_filter(m)

    final_results = {m: res[-1] for m, res in results.items() if keep(m)}
    final_results_std = None
    if results_std is not None:  # default None previously crashed on .items()
        final_results_std = {m: s[-1] for m, s in results_std.items() if keep(m)}
    return plot_results(config, final_results, results_std=final_results_std, export=export, title='Final')

In [64]:
def print_if(msg, print_progress):
    """Print `msg` only when `print_progress` is truthy; always returns None."""
    if not print_progress:
        return
    print(msg)

In [57]:
from itertools import product

def get_all_configs(task_type, config_filter=None,
                    init_prior_scale=0.1, coreset_sizes=(100, 200)):
    """Build one ExperimentConfig per distinct experiment setting for `task_type`.

    The grid spans prior type ('gaussian' / 'exponential'), coreset algorithm
    (None / 'random' / 'kcenter'), coreset size and update_prior_type, with
    two collapsing rules:
      * no coreset algorithm -> coreset_size is 0 (sizes are irrelevant there);
      * 'gaussian' prior -> update_prior_type is False, since the gaussian
        prior is not affected by that attribute.

    Args:
        task_type: forwarded unchanged to every ExperimentConfig.
        config_filter: optional predicate on ExperimentConfig; only configs for
            which it returns a truthy value are kept. None keeps every config
            (that is truthy).
        init_prior_scale: forwarded unchanged to every ExperimentConfig.
        coreset_sizes: iterable of coreset sizes used with a coreset
            algorithm; duplicates are ignored. An empty iterable still yields
            the no-coreset configs.

    Returns:
        list of ExperimentConfig in a deterministic order.
    """
    # Materialise once: it is reused for every (prior, algorithm) pair, and a
    # one-shot iterator would otherwise be exhausted after the first pass.
    sizes = tuple(coreset_sizes)
    # dict.fromkeys de-duplicates while keeping insertion order. set() did not:
    # its order over str-containing tuples changes with PYTHONHASHSEED, which
    # shuffled the experiment order between kernel restarts.
    combos = dict.fromkeys(
        (prior_type, coreset_alg_name, coreset_size, update_prior_type)
        for prior_type in ('gaussian', 'exponential')
        for coreset_alg_name in (None, 'random', 'kcenter')
        # no coreset -> a single size of 0
        for coreset_size in ((0,) if coreset_alg_name is None else sizes)
        # gaussian is not affected by update_prior_type -> pin it to False
        for update_prior_type in ((False,) if prior_type == 'gaussian' else (True, False))
    )
    configs = [
        ExperimentConfig(task_type=task_type,
                         prior_type=prior_type,
                         init_prior_scale=init_prior_scale,
                         coreset_alg_name=coreset_alg_name,
                         coreset_size=coreset_size,
                         update_prior_type=update_prior_type)
        for prior_type, coreset_alg_name, coreset_size, update_prior_type in combos
    ]
    return list(filter(config_filter, configs))