In [120]:
import json
import matplotlib.pyplot as plt
import numpy as np
import operator

from functools import reduce

In [121]:
def plot_distribution(data, model, output_directory, generator_name,
                      plot_type, plot_id, show_plot=False,
                      xlim=(-0.25, 1.25), ylim=(-0.25, 1.25), grid_step=0.01):
    """
    Generates a plot of a two-class dataset over the model's decision surface and saves it to a file.

    The background shows the model's predicted probability of class 1, evaluated on a
    regular grid covering the data range padded by 1 unit on each side. The records are
    drawn on top, one scatter series per class.

    Args:
        data (pandas.DataFrame):
            Records along with their labels. Columns 0 and 1 are used as the two
            features, column 2 as the class label (compared against 0 and 1).
        model (MLModelCatalog):
            Classifier implementing a `predict_proba()` method; column 1 of its
            output is drawn as the decision surface.
        output_directory (str):
            Name of the directory where images are saved.
        generator_name (str):
            Name of the applied recourse generator (file-name prefix).
        plot_type (str):
            Type of the created plot (part of the file name).
        plot_id (int):
            ID for the generated plot (e.g. consecutive numbers for different
            distributions); zero-padded to six digits in the file name.
        show_plot (bool):
            If True the plot will also be outputted directly to the notebook.
        xlim (tuple[float, float]):
            Limits of the horizontal axis.
        ylim (tuple[float, float]):
            Limits of the vertical axis.
        grid_step (float):
            Spacing of the grid on which the model's probabilities are evaluated.
    """
    records = data.to_numpy()
    features = records[:, :2]
    labels = records[:, 2]

    # Evaluate the classifier on a regular grid spanning the data, padded by 1 on each side.
    x_min = features.min(axis=0) - 1
    x_max = features.max(axis=0) + 1
    x0, x1 = np.meshgrid(np.arange(x_min[0], x_max[0], grid_step),
                         np.arange(x_min[1], x_max[1], grid_step))
    grid = np.c_[x0.ravel(), x1.ravel()]
    z = model.predict_proba(grid)[:, 1].reshape(x0.shape)

    fig, ax = plt.subplots(dpi=150)
    ax.axis('equal')
    ax.grid(True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xlabel('$feature1$')
    ax.set_ylabel('$feature2$')

    ax.contourf(x0, x1, z, cmap='plasma', alpha=0.8)

    # One series per class so each gets its own colour from the default property cycle.
    # No `cmap` here: without a `c` argument matplotlib ignores it and emits a warning.
    for label in (0, 1):
        mask = labels == label
        ax.scatter(features[mask, 0], features[mask, 1], s=60,
                   linewidth=1, edgecolor='black')

    fig.savefig(f"{output_directory}/{generator_name}_{plot_type}_{plot_id:06}.png",
                bbox_inches='tight')

    if show_plot:
        plt.show()

    # Close this specific figure so repeated calls (e.g. one per epoch) don't accumulate figures.
    plt.close(fig)

In [134]:
def get_by_path(root, items):
    """
    Access a nested dictionary based on a sequence of keys in the provided order.

    Args:
        root (dict):
            Top-level of the nested dictionary.
        items (List[str]):
            Keys to follow, outermost first. An empty list returns `root` itself.

    Returns:
        object: Value corresponding to the last key in the `items` list.

    Raises:
        ValueError: If a key along the path is missing, an index is out of range, or an
            intermediate value cannot be indexed. The underlying lookup error is chained
            as the cause.
    """
    try:
        return reduce(operator.getitem, items, root)
    except (KeyError, IndexError, TypeError) as err:
        # Catch only lookup failures so KeyboardInterrupt/SystemExit still propagate,
        # and chain the original error so the failing key is visible in the traceback.
        raise ValueError("Specified path does not exist in the dictionary.") from err


def plot_experiment_data(output_directory, generators, plot_type, dict_path,
                         file_name='measurements.json', show_plot=False):
    """
    Plots a specified component of the experiment data gathered over all epochs of recourse.

    The data file must map each generator name to a dictionary whose keys are epoch
    numbers (stored as strings, as JSON requires). `dict_path` is followed inside each
    epoch's entry to obtain the plotted value. One line per generator is drawn against
    the epoch's position in chronological order. The figure is saved as
    `{output_directory}/{plot_type}.png`.

    Args:
        output_directory (str):
            Directory containing the data file and where the image is saved.
        generators (List[str]):
            Names of all generators which should be plotted.
        plot_type (str):
            Type of the created plot (used as the image file name).
        dict_path (List[str]):
            Location of the measurements of interest within each epoch's dictionary.
        file_name (str):
            Name of the file containing the experiment data dictionary.
        show_plot (bool):
            If True the plot will also be outputted directly to the notebook.

    Raises:
        ValueError: If `generators` is empty, a generator has no measurements in the
            file, or `dict_path` does not exist in an epoch's entry.
    """
    if not generators:
        raise ValueError('At least one generator must be specified')

    # Only hold the file open while parsing it.
    with open(f'{output_directory}/{file_name}') as data_file:
        data = json.load(data_file)

    # Gather every series before plotting so the axis limits account for all
    # generators, not just the last one plotted.
    series = []
    for g in generators:
        # Check if the generators have been correctly specified
        if g not in data:
            raise ValueError(f'No measurements available for {g}')
        # JSON keys are strings; convert to int so epochs sort numerically ('10' after '9').
        by_epoch = {int(k): v for k, v in data[g].items()}
        values = [get_by_path(entry, dict_path) for _, entry in sorted(by_epoch.items())]
        series.append((g, values))

    fig, ax = plt.subplots(dpi=150)
    ax.grid(True)

    # Apply consistent theme over all plots generated for the project
    colormap = plt.cm.plasma
    colors = [colormap(int(i * colormap.N / len(generators))) for i in range(len(generators))]

    for color, (g, values) in zip(colors, series):
        ax.plot(range(len(values)), values, linewidth=2,
                label=f'{g.capitalize()}', color=color)

    # Format the plot. Limits that would collapse to a single value (one epoch, or an
    # all-zero/negative peak) are skipped and left to matplotlib's autoscaling.
    longest = max(len(values) for _, values in series)
    if longest > 1:
        ax.set_xlim([0, longest - 1])
    peak = max((max(values) for _, values in series if values), default=0)
    if peak > 0:
        ax.set_ylim([-0.2 * peak, 1.2 * peak])
    ax.legend()
    fig.savefig(f"{output_directory}/{plot_type}.png", bbox_inches='tight')

    # Only show if asked
    if show_plot:
        plt.show()

    # Actually close the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

In [141]:
experiment_path = '../experiment_data/20220512173229_poster'

# Measurements that can be plotted from the experiment's `measurements.json`:
# `type` names the output image, `dict_path` is the key path inside each epoch's entry.
plot_configs = [
    {'type': 'MMD', 'dict_path': ['MMD', 'positive']},
    {'type': 'disagreement', 'dict_path': ['disagreement']},
    {'type': 'num_clusters', 'dict_path': ['distribution', 'num_clusters']},
]
plot_generators = ['DICE', 'wachter']

# Plotting writes image files into `experiment_path`; set to True to regenerate them.
RUN_PLOTS = False

if RUN_PLOTS:
    for config in plot_configs:
        plot_experiment_data(experiment_path, plot_generators,
                             config['type'], config['dict_path'])