In [7]:
%run constants.ipynb

In [44]:
def get_results_dir(config):
    """Return ``RESULTS_DIR``, creating that directory first if it is missing.

    ``config`` is accepted but not currently used.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    return RESULTS_DIR

def plot_filepath(config, suffix=None):
    """Return the PNG path for a results plot.

    Layout: ``<results dir>/<task_type>_<coreset_alg_name><coreset_size>_<suffix>_plot.png``.
    The algorithm name and coreset size are joined with no separator, and a
    ``suffix`` of ``None`` appears literally as ``None`` in the file name.
    The results directory is obtained from ``get_results_dir`` (which creates it).
    """
    directory = get_results_dir(config)
    stem = "{}_{}{}_{}".format(
        config.task_type, config.coreset_alg_name, config.coreset_size, suffix
    )
    return "{}/{}_plot.png".format(directory, stem)

def data_filepath(config, suffix=None):
    """Return the HDF5 path used to store exported results.

    Layout: ``<results dir>/<task_type>_<coreset_alg_name>_<suffix>.h5``.
    NOTE(review): unlike ``plot_filepath``, ``config.coreset_size`` is not part
    of this name, so configs that differ only in coreset size map to the same
    file — confirm this is intended.
    """
    directory = get_results_dir(config)
    name = "_".join(
        "{}".format(part)
        for part in (config.task_type, config.coreset_alg_name, suffix)
    )
    return "{}/{}.h5".format(directory, name)

In [45]:
import h5py

def export_results(config, data, datatype='mean'):
    """Write every entry of the dict ``data`` to an HDF5 file, one dataset per key.

    The target path is ``data_filepath(config, suffix=datatype)``; the file is
    opened in ``'w'`` mode, so any existing file at that path is replaced.
    """
    path = data_filepath(config, suffix=datatype)
    with h5py.File(path, 'w') as h5_file:
        for dataset_name, values in data.items():
            h5_file.create_dataset(dataset_name, data=values)

def import_results(config, datatype='mean'):
    """Load every dataset from the HDF5 file at ``data_filepath(config, suffix=datatype)``.

    Returns:
        dict mapping each top-level key in the file to the data read from it.
    """
    with h5py.File(data_filepath(config, suffix=datatype), 'r') as h5_file:
        # ``[()]`` reads the whole dataset for any shape, including 0-d (scalar)
        # datasets, for which slicing with ``[:]`` raises. ``export_results``
        # writes scalar datasets whenever it is handed a scalar value.
        return {name: h5_file[name][()] for name in h5_file.keys()}

In [46]:
def plot_results(config, results, results_std=None, export=True, title='Mean'):
    """Plot one line per model (with optional error bars) and optionally save the figure.

    Args:
        config: experiment config. ``len(config.tasks)`` sets the x-ticks,
            ``config.eval_metric`` is used in the title and as the y-axis label,
            and ``config`` is passed to ``plot_filepath`` when ``export`` is true.
        results: mapping of model name -> sequence of values; element ``i`` is
            plotted at x = ``i + 1``.
        results_std: optional mapping of model name -> values used as symmetric
            error bars (``yerr``). When given, it must contain every model in
            ``results``.
        export: if true, save the figure to ``plot_filepath(config, suffix=title.lower())``.
        title: prefix of the plot title; its lowercase form is the file suffix.
    """
    marker_size = 8
    capsize = 4
    elinewidth = 1

    fig, ax = plt.subplots(figsize=(18, 6))
    for model, values in results.items():
        x_vals = np.arange(1, len(values) + 1)
        if results_std is not None:
            ax.errorbar(x_vals, values, yerr=results_std[model],
                        label=model, marker='o',
                        markersize=marker_size, capsize=capsize, elinewidth=elinewidth)
        else:
            ax.plot(x_vals, values, label=model, marker='o', markersize=marker_size)

    ax.set_xticks(range(1, len(config.tasks) + 1))
    ax.set_title(f"{title} {config.eval_metric}", fontsize=18)
    ax.set_xlabel("# Tasks", fontsize=15)
    ax.set_ylabel(config.eval_metric, fontsize=15)
    ax.legend(fontsize=15)

    # Act on this figure explicitly rather than on pyplot's "current figure".
    fig.tight_layout()
    if export:
        fig.savefig(plot_filepath(config, suffix=title.lower()))

    plt.show()

In [47]:
def plot_mean_results(config, results, results_std=None, export=True, threshold=1e-6):
    """Plot each model's masked mean over axis 0 via ``plot_results`` (title 'Mean').

    Entries of a results array that are ``<= threshold`` are excluded from the
    mean (presumably zero-filled slots for tasks not evaluated yet — confirm
    against whoever builds ``results``).

    Args:
        config: experiment config, forwarded to ``plot_results``.
        results: mapping of model name -> numpy array of results.
        results_std: optional mapping of model name -> std array. Each std array
            is averaged over the SAME entries as the matching results array, so
            an error bar always describes the same cells as its point. If a model
            has no results array of the same shape, its std is masked by its own
            ``> threshold`` test instead. ``None`` plots without error bars.
        export: forwarded to ``plot_results``.
        threshold: values ``<= threshold`` count as missing (default ``1e-6``).

    NOTE(review): the mean runs over axis 0 (down each column). If rows index the
    training stage — as ``plot_final_results`` taking ``res[-1]`` suggests — a
    per-stage average would be a mean over axis 1; confirm which is intended.
    """
    def valid_mask(values):
        return np.asarray(values) > threshold

    def masked_mean(values, mask):
        # Columns whose entries are all masked out come back as NaN.
        return np.mean(values, axis=0, where=mask)

    masks = {m: valid_mask(res) for m, res in results.items()}
    mean_results = {m: masked_mean(res, masks[m]) for m, res in results.items()}

    mean_results_std = None
    if results_std is not None:
        mean_results_std = {}
        for m, s in results_std.items():
            mask = masks.get(m)
            if mask is None or mask.shape != np.shape(s):
                mask = valid_mask(s)
            mean_results_std[m] = masked_mean(s, mask)

    return plot_results(config, mean_results, results_std=mean_results_std, export=export, title='Mean')

def plot_final_results(config, results, results_std=None, export=True):
    """Plot the last row (``res[-1]``) of each model's results via ``plot_results``.

    Args:
        config: experiment config, forwarded to ``plot_results``.
        results: mapping of model name -> indexable results (last entry is plotted).
        results_std: optional mapping of model name -> std values; the last entry
            of each is used for error bars. ``None`` plots without error bars.
        export: forwarded to ``plot_results``.
    """
    final_results = {m: res[-1] for m, res in results.items()}
    # ``results_std`` defaults to None, so it must not be iterated unconditionally.
    final_results_std = (
        None if results_std is None
        else {m: s[-1] for m, s in results_std.items()}
    )
    return plot_results(config, final_results, results_std=final_results_std, export=export, title='Final')