In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Union, Tuple, Dict

In [None]:
# Use the 'retina' (high-DPI) output format for inline matplotlib figures in this notebook.
%config InlineBackend.figure_format = 'retina'

# Plotting Routine

This notebook was used to generate all plots in the paper. Its input is the JSON file output by `run_expr.py`, optionally combined using `combine_results.py`. 

In [None]:
filename: Union[None, str] = None  # Path of the results JSON to plot (output of run_expr.py, optionally combined by combine_results.py)

In [None]:
# Load the result statistics produced by run_expr.py / combine_results.py.
# The notebook-level `filename` takes precedence when it is set; otherwise fall back to
# the results file that was previously hard-coded here, so existing runs behave the same.
results_path = filename if filename is not None else 'k500_all.json'
with open(results_path) as f:
    statistics = json.load(f)

In [None]:
def plot_metric_per_time_or_iter(result_stats: Dict, 
                                 order: List[str], 
                                 metric: str, 
                                 x_axis: str, 
                                 y_log: bool = False, 
                                 x_log: bool = False, 
                                 ylabel: Union[None, str] = None, 
                                 xlabel: Union[None, str] = None, 
                                 title: Union[None, str] = None, 
                                 filename: Union[None, str] = None, 
                                 xrange: Union[None, Tuple[float, float]] = None, 
                                 yrange: Union[None, Tuple[float, float]] = None, 
                                 linewidth: float = 1,
                                 method_styles: Union[None, Dict[str, Dict]] = None):
    """
    Plots specified metric(s) over time or iterations given result statistics. 
    
    Parameters
    result_stats: A dictionary mapping from config name to the result statistics for the config.
        Each config's statistics hold a list of values under the `metric` key and, when x_axis
        is 'sec' or 'hour', a 'time' list (in seconds) of the same length.
    order: Config names to plot, in plotting (and legend) order. Names not present in
        result_stats are skipped silently.
    metric: The metric to use for the y-axis. One of 'STK', 'KLS', 'Precision@K', 'Recall@K', 'AvgRank' or 'WorstRank'. 
    x_axis: The metric to use for the x-axis. Either 'sec', 'hour', or 'iteration'. 
    y_log: Whether the y-axis should be logarithmically scaled. 
    x_log: Whether the x-axis should be logarithmically scaled. 
    ylabel: The label for the y-axis. Defaults to '<metric> Value' when None or empty.
    xlabel: The label for the x-axis. Defaults to 'Time (s)', 'Time (hr)' or 'Iteration' to match x_axis.
    title: The title of the plot. 
    filename: The file that this plot should be saved to (always written in PDF format).
    xrange: The range to limit the x-axis to. 
    yrange: The range to limit the y-axis to. 
    linewidth: The width of the lines in the plot. 
    method_styles: Mapping from config name to a dict with keys 'name' (legend label),
        'color' and 'style' (matplotlib linestyle). Defaults to the notebook-level `methods`.

    Raises
    ValueError: If x_axis is not 'sec', 'hour' or 'iteration'.
    KeyError: If a plotted config is missing from method_styles, or its stats lack a required key.
    """
    # Default x-axis label per supported x_axis value. Validating up front means a typo in
    # x_axis fails even when none of the configs in `order` are present in result_stats.
    default_xlabels = {'sec': 'Time (s)', 'hour': 'Time (hr)', 'iteration': 'Iteration'}
    if x_axis not in default_xlabels:
        raise ValueError(f"x_axis must be 'sec', 'hour', or 'iteration', got {x_axis!r}")

    # `methods` is looked up at call time, so its cell must run before any plotting cell
    # unless method_styles is passed explicitly.
    styles = methods if method_styles is None else method_styles

    plt.figure(figsize=(8, 5))

    # Plot each config specified in order
    for config_name in order:
        if config_name not in result_stats:
            continue  # not every config is run for every dataset
        stats: Dict = result_stats[config_name]
        y_values = list(stats[metric])

        # Obtain the x values
        if x_axis == 'iteration':
            x_values = list(range(1, len(y_values) + 1))
        else:
            x_values = stats['time']
            if x_axis == 'hour':  # The 'time' field logs time incurred in seconds
                x_values = np.array(x_values) / 3600.0

        style = styles[config_name]
        plt.plot(x_values, y_values, 
                 label=style['name'], 
                 color=style['color'], 
                 linestyle=style['style'], 
                 linewidth=linewidth)

    # Label the axes; an explicitly provided label overrides the default.
    plt.xlabel(xlabel if xlabel is not None else default_xlabels[x_axis], fontsize=14)
    plt.ylabel(ylabel if ylabel else f'{metric} Value', fontsize=14)

    # Limit axis values optionally
    if xrange is not None:
        plt.xlim(xrange)
    if yrange is not None:
        plt.ylim(yrange)

    # Title, legend, grid
    if title is not None:
        plt.title(title)
    plt.legend(loc='best', fontsize=14)
    plt.grid(visible=True, color="#f0f0f0")

    # Log scale the axis optionally
    if y_log:
        plt.yscale('log')
    if x_log:
        plt.xscale('log')

    # Save plot to file optionally
    if filename is not None:
        plt.savefig(filename, format="pdf", bbox_inches="tight")
    
    plt.show()

In [None]:
# A mapping from method (config) name to its legend label, line color, and matplotlib line style.
# Kept as a single global so every plot draws a given method with the same label/color/style.
# Each row below is: (config name, legend label, color, line style).
methods = {
    config: {"name": label, "color": color, "style": linestyle}
    for config, label, color, linestyle in [
        ("ScanBestOrder", "Scan (Best Case)", "#BBBBBB", (0, (5, 0))),
        ("ScanWorstOrder", "Scan (Worst Case)", "#000000", (0, (5, 0))),
        ("UniformSample", "UniformSample", "#EE3377", "dashdot"),
        ("UCB", "UCB", "#EE7733", "dotted"),
        ("UniformExploration", "ExplorationOnly", "#CC3311", "dashed"),
        ("EpsGreedy", "Ours", "#0077BB", "solid"),
        ("EpsGreedy Without Rebinning", "Ours (No Re-binning)", "#33BBEE", (5, (10, 3))),
        ("EpsGreedy Without Subtraction", "Ours (No Subtraction)", "#009988", (0, (5, 1))),
        ("EpsGreedy (B=4)", "Ours (B=4)", "#BAE6FF", "solid"),
        # NOTE(review): same color and style as "EpsGreedy"; presumably B=8 is the default
        # configuration. Confirm this, otherwise the two curves are indistinguishable when plotted together.
        ("EpsGreedy (B=8)", "Ours (B=8)", "#0077BB", "solid"),
        ("EpsGreedy (B=16)", "Ours (B=16)", "#00507D", "solid"),
        ("EpsGreedy (B=32)", "Ours (B=32)", "#002d47", "solid"),
    ]
}

## Plots

Seven plots are pre-defined below: STK vs iteration, and STK, KLS, Precision@K, Recall@K, AvgRank and WorstRank vs time. 

For different datasets, you may need to modify:
1. The methods included in "order".
2. Name of the saved file.
3. Line width for readability.
4. Whether x-axis (time) is "sec" or "hour". 

Note that although the `ScanBestOrder` and `ScanWorstOrder` configs have a "time" field, it is not comparable to the other methods' times, because the "time" field for the GT run does not include the priority-queue maintenance overhead. Hence, in the paper, `ScanBestOrder` and `ScanWorstOrder` are included only in Figure 4 (synthetic data). 

## STK vs Iteration

In [None]:
# Sum of Top-k per iteration for our method (including the B=4..32 variants) and the baselines.
# The Scan configs are left commented out: their times are only comparable on synthetic data.
plot_metric_per_time_or_iter(
    statistics,
    order=[
        'EpsGreedy',
        'EpsGreedy (B=4)',
        'EpsGreedy (B=8)',
        'EpsGreedy (B=16)',
        'EpsGreedy (B=32)',
        'EpsGreedy Without Rebinning',
        'EpsGreedy Without Subtraction',
        'UCB',
        'UniformExploration',
        'UniformSample',
        # 'ScanBestOrder',
        # 'ScanWorstOrder',
    ],
    metric='STK',
    x_axis='iteration',
    ylabel="Sum of Top-k (STK)",
    xlabel=None,
    title=None,
    filename=None,
    xrange=None,
    yrange=None,
    x_log=False,
    linewidth=1,
)

## STK vs Time

In [None]:
# Sum of Top-k against elapsed time in seconds for our method, its variants, and the baselines.
plot_metric_per_time_or_iter(
    statistics,
    order=[
        'EpsGreedy',
        'EpsGreedy Without Rebinning',
        'EpsGreedy Without Subtraction',
        'UCB',
        'UniformExploration',
        'UniformSample',
    ],
    metric='STK',
    x_axis='sec',
    ylabel="Sum of Top-k (STK)",
    xlabel=None,
    title=None,
    filename=None,
    xrange=None,
    yrange=None,
    x_log=False,
    y_log=False,
    linewidth=1.0,
)

## KLS vs Time

In [None]:
# k-th largest score against elapsed time in seconds.
plot_metric_per_time_or_iter(
    statistics,
    order=[
        'EpsGreedy',
        'EpsGreedy Without Rebinning',
        'EpsGreedy Without Subtraction',
        'UCB',
        'UniformExploration',
        'UniformSample',
    ],
    metric='KLS',
    x_axis='sec',
    ylabel="kth Largest Score (KLS)",
    xlabel=None,
    title=None,
    filename=None,
    xrange=None,
    yrange=None,
    linewidth=1.0,
)

## Precision vs Time

In [None]:
# Precision@K against elapsed time in seconds.
plot_metric_per_time_or_iter(
    statistics,
    order=[
        'EpsGreedy',
        'EpsGreedy Without Rebinning',
        'EpsGreedy Without Subtraction',
        'UCB',
        'UniformExploration',
        'UniformSample',
    ],
    metric='Precision@K',
    x_axis='sec',
    ylabel="Precision@K",
    title=None,
    filename=None,
    linewidth=1.0,
)

## Recall vs Time

In [None]:
# Recall@K against elapsed time in seconds.
plot_metric_per_time_or_iter(
    statistics,
    order=[
        'EpsGreedy',
        'EpsGreedy Without Rebinning',
        'EpsGreedy Without Subtraction',
        'UCB',
        'UniformExploration',
        'UniformSample',
    ],
    metric='Recall@K',
    x_axis='sec',
    ylabel="Recall@K",
    title=None,
    filename=None,
    linewidth=1.0,
)

## Average Rank vs Time

In [None]:
# Average rank against elapsed time in seconds; the y-axis is log-scaled.
plot_metric_per_time_or_iter(
    statistics,
    order=[
        'EpsGreedy',
        'EpsGreedy Without Rebinning',
        'EpsGreedy Without Subtraction',
        'UCB',
        'UniformExploration',
        'UniformSample',
    ],
    metric='AvgRank',
    x_axis='sec',
    y_log=True,
    ylabel='Log Average Rank',
    title=None,
    filename=None,
    linewidth=1.0,
)

## WorstRank vs Time

In [None]:
# Worst rank against elapsed time in seconds; the y-axis is log-scaled.
plot_metric_per_time_or_iter(
    statistics,
    order=[
        'EpsGreedy',
        'EpsGreedy Without Rebinning',
        'EpsGreedy Without Subtraction',
        'UCB',
        'UniformExploration',
        'UniformSample',
    ],
    metric='WorstRank',
    x_axis='sec',
    y_log=True,
    ylabel='Log kth Rank',
    title=None,
    filename=None,
    linewidth=1.0,
)