In [129]:
%run constants.ipynb
%run config.ipynb

In [130]:
def get_results_dir(config):
    results_dir = RESULTS_DIR
    os.makedirs(results_dir, exist_ok=True)
    return results_dir

def plot_filepath(task_type, suffix=None):
    return f"{get_results_dir(config)}/{config.task_type}_{suffix}_plot.png"

def data_filepath(task_type, suffix=None):
    return f"{get_results_dir(config)}/{config.task_type}_{suffix}_data.h5"

In [131]:
import h5py

def export_results(data, task_type, datatype='mean'):
    """ Exports the dictionary `data` to an HDF5 file. """
    with h5py.File(data_filepath(task_type, suffix=datatype), 'w') as f:
        # Iterate over the dictionary and save each array as a dataset
        for key, array in data.items():
            f.create_dataset(key, data=array)

def import_results(task_type, datatype='mean'):
    """ Imports the data dictionary from an HDF5 file. """
    imported_data = {}
    with h5py.File(data_filepath(task_type, suffix=datatype), 'r') as f:
        # Iterate through the keys in the file and load the datasets
        for key in f.keys():
            imported_data[key] = f[key][:]
    return imported_data

In [132]:
# plot styling based on model name
def get_markerstyle(m):
    if m.startswith('Gaussian'):
        return 'o'
    elif m.startswith('ExpGaussian'):
        return 'v'
    elif m.startswith('Exp'):
        return 's'
    else:
        return 'p'

def get_linestyle(m):
    if m == VANILLA_MODEL:
        return (0,(3,5,1,5))
    elif 'None' in m:
        return 'dashed'
    elif '50' in m:
        return 'dashdot'
    elif '100' in m:
        return (0,(1,1))
    return 'solid'

In [133]:
def plot_results(config, results, results_std=None, export=True, title='Mean'):
    """Plot comparison of results with optional error bars."""
    MARKER_SIZE = 8
    CAPSIZE = 4
    ELINEWIDTH = 1
    fig, ax = plt.subplots(figsize=(18, 6))
    x_ticks = range(1, len(config.tasks)+1)
    model_names = sorted(results.keys(), reverse=True)
    
    for model in model_names:
        x_vals = np.arange(len(results[model])) + 1
        # If std info is provided, add error bars
        if results_std is not None:
            ax.errorbar(x_vals, results[model], yerr=results_std[model], 
                         label=model, marker=get_markerstyle(model), linestyle=get_linestyle(model),
                         markersize=MARKER_SIZE, capsize=CAPSIZE, elinewidth=ELINEWIDTH)
        else:
            ax.plot(x_vals, results[model], 
                    label=model, marker=get_markerstyle(model), linestyle=get_linestyle(model),
                    markersize=MARKER_SIZE)
    
    ax.set_xticks(x_ticks)
    ax.set_title(f"{title} {config.eval_metric}", fontsize=18)
    ax.set_xlabel("# Tasks", fontsize=15)
    ax.legend(fontsize=15, loc='upper left', bbox_to_anchor=(1, 1))

    plt.tight_layout()
    if export:
        plt.savefig(plot_filepath(config, suffix=title.lower()))
    
    plt.show()

In [157]:
def aggr_results(model_to_results, mname_filter=None, aggr='final'): 
    """
    Mean results for the given list of models, using the provided aggregation.
    `aggr` is one of: 
    - 'all' (average across the lifetime of all tasks), 
    - 'final' (average across the final result of all tasks)
    """
    if model_to_results is None:
        return None
    if mname_filter is None:
        mname_filter = (lambda _ : True)
    aggregate = (lambda res: res[-1]) if aggr == 'final' else \
                (lambda res: np.mean(res, axis=0, where=(res > 1e-6)))
    return {m:aggregate(res) for m,res in model_to_results.items() if mname_filter(m)}

def mean_aggr_results(model_to_results, mname_filter=None, aggr='final'): 
    aggregated_results = aggr_results(model_to_results, mname_filter=mname_filter, aggr=aggr)
    return { m: np.mean(res) for m,res in aggregated_results.items() }

In [155]:
def plot_mean_results(config, results, results_std=None, export=True, mname_filter=None):    
    mean_results = aggr_results(results, mname_filter=mname_filter, aggr='all')
    mean_results_std = aggr_results(results_std, mname_filter=mname_filter, aggr='all')
    if mname_filter is not None:
        mean_results = {m: res for m, res in mean_results.items() if mname_filter(m)}
        if mean_results_std is not None:
            mean_results_std = {m: res for m, res in mean_results_std.items() if mname_filter(m)}
    return plot_results(config, mean_results, results_std=mean_results_std, export=export, title='Mean')

def plot_final_results(config, results, results_std=None, export=True, mname_filter=None):
    final_results = aggr_results(results, mname_filter=mname_filter, aggr='final')
    final_results_std = aggr_results(results_std, mname_filter=mname_filter, aggr='final')
    if mname_filter is not None:
        final_results = {m: res for m, res in final_results.items() if mname_filter(m)}
        if final_results_std is not None:
            final_results_std = {m: res for m, res in final_results_std.items() if mname_filter(m)}
    return plot_results(config, final_results, results_std=final_results_std, export=export, title='Final')

In [135]:
def print_if(msg, print_progress):
    if print_progress:
        print(msg)

In [140]:
from itertools import product

def get_all_configs(task_type, config_filter=None, 
                    init_prior_scale=0.1, coreset_sizes = [50, 100, 200]):

    configs = []
    combos = list(product(['gaussian', 'exponential'], [None, 'random', 'kcenter'], coreset_sizes))
    # no coreset
    combos = [(p,ca,0) if ca is None else (p,ca,cs) for p,ca,cs in combos]
    combos = list(set(combos))
    for combo in combos:
        prior_type, coreset_alg_name, coreset_size = combo
        configs.append(ExperimentConfig(task_type=task_type,
                                          prior_type=prior_type,
                                          init_prior_scale=init_prior_scale,
                                          coreset_alg_name=coreset_alg_name,
                                          coreset_size=coreset_size))
    return list(filter(config_filter, configs))

In [141]:
[c.name for c in get_all_configs('classification')]

['GaussianVCL (None)',
 'ExpVCL (Kcenter, 50)',
 'GaussianVCL (Kcenter, 50)',
 'ExpVCL (Random, 50)',
 'GaussianVCL (Random, 50)',
 'ExpVCL (Kcenter, 200)',
 'ExpVCL (None)',
 'ExpVCL (Kcenter, 100)',
 'GaussianVCL (Kcenter, 200)',
 'ExpVCL (Random, 200)',
 'GaussianVCL (Kcenter, 100)',
 'ExpVCL (Random, 100)',
 'GaussianVCL (Random, 200)',
 'GaussianVCL (Random, 100)']