In [None]:
import glob
import os
import pickle
import pprint
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Helper functions

In [None]:
def get_batchsize(path: str) -> int:
    """
    Extract the max batch size encoded in a results filename.

    Filenames look like ``ray_replicas_<n>_maxbatch_<b>.pkl``; the batch size is the fifth
    underscore-separated field of the file name once the extension is removed.

    Args:
        path: Path to a results file. Any directory component is ignored.

    Returns:
        The max batch size as an int.
    """
    # os.path copes with the platform separator (glob returns backslashes on Windows) and
    # splitext strips the extension regardless of its length, unlike a fixed [:-4] slice.
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split("_")[4])

def get_replicas(path: str) -> int:
    """
    Extract the number of replicas encoded in a results filename.

    Filenames look like ``ray_replicas_<n>_maxbatch_<b>.pkl``; the replica count is the third
    underscore-separated field of the file name.

    Args:
        path: Path to a results file. Any directory component is ignored.

    Returns:
        The number of replicas as an int.
    """
    # Use os.path so the directory is stripped correctly with the platform's separator.
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split("_")[2])

def unpack_data(folder: str, batch_sizes: List[int], replicas_range: range) -> List[List[Dict[str, Any]]]:
    """
    Load the pickled experiment results for the requested batch sizes and replica counts.

    The returned data is nested as follows:

        - The outer list has one sub-list per entry of `batch_sizes`, in the same order.
        - Each sub-list holds one element per results file found for that batch size, ordered by
          number of replicas (ascending).
        - Each element is the unpickled dict for one configuration, with keys `t_elapsed` and
          `explanations`. Their values are lists whose length is the number of experimental runs for
          that batch size / replicas combination.

    Note: ``pickle.load`` can execute arbitrary code, so only point `folder` at trusted results.
    """

    def _load_pickle(path):
        # Read one results file fully into memory.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    grouped_paths = filter_filenames(folder, replicas_range, batch_sizes)
    return [[_load_pickle(path) for path in paths] for paths in grouped_paths]


def filter_filenames(folder: str, cpu_range: range, batch_sizes: List[int]) -> List[List[str]]:
    """
    Return the results filenames in `folder` for the requested `batch_sizes` and replica counts `cpu_range`.

    The response:

        - Contains ``len(batch_sizes)`` sub-lists, in the order of `batch_sizes`
        - Each sublist is expected to contain ``len(cpu_range)`` filenames, sorted by number of
          replicas (ascending order)

    A ``UserWarning`` is emitted for every batch size whose sublist does not hold exactly
    ``len(cpu_range)`` files (e.g. a missing run). Downstream plotting assumes one value per replica
    count, so a short sublist would otherwise surface later as a confusing plotting error.

    Args:
        folder: Directory containing ``ray_*.pkl`` results files.
        cpu_range: Replica counts to keep.
        batch_sizes: Max batch sizes to keep; defines the order of the returned sub-lists.

    Returns:
        One list of file paths per batch size.
    """
    batch_size_mapping = {batch_size: i for i, batch_size in enumerate(batch_sizes)}
    fname_lists = [[] for _ in batch_sizes]
    for file in glob.glob(f'{folder}/ray_*.pkl'):
        batch_size = get_batchsize(file)
        if batch_size not in batch_size_mapping:
            continue
        if get_replicas(file) in cpu_range:
            fname_lists[batch_size_mapping[batch_size]].append(file)

    expected = len(cpu_range)
    for batch_size, lst in zip(batch_sizes, fname_lists):
        lst.sort(key=get_replicas)
        if len(lst) != expected:
            warnings.warn(
                f"Found {len(lst)} result file(s) for batch size {batch_size} in {folder!r}, "
                f"expected {expected} (one per replica count in {cpu_range}); plots may be misaligned."
            )

    return fname_lists

def get_stats(run, field):
    """
    Return the mean and standard deviation of the measurements stored under `field` for one
    (batch size, ncpu) setting.

    Args:
        run: Results dict for a single configuration.
        field: Key whose value is the list of per-run measurements (e.g. ``'t_elapsed'``).

    Returns:
        ``(mean, std)``. The std is ``np.std`` with its default ``ddof=0`` (population std).
        An empty measurement list raises ``ZeroDivisionError``.
    """
    samples = run[field]
    return sum(samples) / len(samples), np.std(samples)


def extract_timeseries(experiments: List[List[Dict[str, Any]]], field: str) -> \
        Tuple[List[List[float]], List[List[float]]]:
    """
    Compute the mean and standard deviation of `field` for every (batch size, replicas) configuration.

    See `unpack_data` for the structure of `experiments`. The means and standard deviations are
    nested the same way:

        - Each entry in the top level list corresponds to a fixed batch size
        - Each entry in the sublists corresponds to a different number of replicas

    Args:
        experiments: Nested experiment results, as returned by `unpack_data`.
        field: Key of the per-run measurement list to summarise (e.g. ``'t_elapsed'``).

    Returns:
        ``(means, stds)``, whose leaves are the scalar values returned by `get_stats`. The previous
        annotation described the stds as arrays, but `get_stats` returns a scalar ``np.std``.
    """
    means, stds = [], []
    for batch_runs in experiments:
        stats = [get_stats(run, field) for run in batch_runs]
        means.append([mu for mu, _ in stats])
        stds.append([sigma for _, sigma in stats])

    return means, stds


def prepare_data(data, field='t_elapsed', decimals=3):
    """
    Summarise loaded experiments into rounded arrays ready for a grouped bar plot.

    Args:
        data: Nested experiment data, as returned by `unpack_data`.
        field: Key of the per-run measurements to summarise.
        decimals: Number of decimals the means and standard deviations are rounded to.

    Returns:
        ``(mean_arrays, std_arrays)``: one rounded numpy array per batch size, each with one entry
        per replica count.
    """
    means, stds = extract_timeseries(data, field)
    mean_arrays = [np.around(row, decimals=decimals) for row in means]
    std_arrays = [np.around(row, decimals=decimals) for row in stds]
    return mean_arrays, std_arrays

def compare_timing(means, stds, labels, cpu_range, bar_width=0.1, y_min=0, y_max=34.0, y_step=0.5,
                   legend_pos=(1.24, -0.09), yticks_fontsize=20):
    """
    Draw a grouped bar chart of mean runtimes per number of replicas, with standard deviations as error bars.

    Each series (one entry of `means` / `stds` / `labels`, e.g. one batch size) contributes one bar
    to every group. There is one group per value in `cpu_range`, and each group's x tick sits at the
    centre of its bars.

    Args:
        means: Per-series sequences of mean runtimes, one value per entry of `cpu_range`. 2-D arrays
            (with a singleton dimension) are squeezed.
        stds: Per-series standard deviations, laid out like `means`.
        labels: Legend label for each series.
        cpu_range: Replica counts shown on the x axis (any step is supported).
        bar_width: Width of a single bar in x-axis units; consecutive groups are 1 unit apart.
        y_min: Lower y-axis limit and first y tick.
        y_max: Upper y-axis limit; y ticks stop below this value.
        y_step: Spacing between y ticks.
        legend_pos: Currently not used; kept so existing calls remain valid.
        yticks_fontsize: Font size of the y tick labels.

    Returns:
        ``(ax, bps)``: the Axes, and the list of objects returned by ``ax.bar``, one per series.
    """
    fig, ax = plt.subplots(figsize=(40, 20))

    x_values = list(cpu_range)
    # One group per replica count at x = 1, 2, ...; series i is shifted right by i * bar_width.
    group_starts = np.arange(1, len(x_values) + 1)
    series = list(zip(means, stds, labels))
    bps = []
    with sns.axes_style("white"):
        sns.set_style("ticks")
        sns.set_context("talk")
        for i, (m, s, label) in enumerate(series):
            m = np.asarray(m)
            s = np.asarray(s)
            if m.ndim == 2:
                m = m.squeeze()
                s = s.squeeze()
            bar_plot = ax.bar(group_starts + i * bar_width,
                              m,
                              bar_width,
                              yerr=s,
                              label=label,
                              capsize=10
                              )
            bps.append(bar_plot)

        ax.set_xlabel(r'ncpu', fontsize=30)
        # Bars are centre-aligned on their x position, so the middle of a group of n bars lies
        # (n - 1) / 2 bar widths right of the group start. The old hard-coded "- 2 * bar_width"
        # was only correct for exactly three series.
        group_offset = bar_width * max(len(series) - 1, 0) / 2
        ax.set_xticks(group_starts + group_offset)
        ax.set_xticklabels([str(x) for x in x_values], rotation=45, fontsize=25)

        ax.set_ylabel('Time (s)', fontsize=30)
        ax.set_ylim(bottom=y_min, top=y_max)
        y_ticks = np.arange(y_min, y_max, y_step)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(y_ticks, fontsize=yticks_fontsize)

        ax.grid(True)
        # Collapse legend entries that share a label; use `ax` rather than the implicit current
        # axes, and avoid shadowing the `labels` argument.
        handles, legend_labels = ax.get_legend_handles_labels()
        by_label = OrderedDict(zip(legend_labels, handles))
        ax.legend(by_label.values(), by_label.keys(), ncol=2, fontsize=20)
        sns.despine()

    return ax, bps


def autolabel(ax, rects):
    """
    Write each bar's height as a text label just above that bar.

    Args:
        ax: Axes to add the annotations to.
        rects: Iterable of bar patches, e.g. the container returned by ``ax.bar``.
    """
    # Shared styling: centred text, lifted 3 points above the top of the bar.
    label_style = dict(xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=20)
    for bar in rects:
        bar_height = bar.get_height()
        x_centre = bar.get_x() + bar.get_width() / 2
        ax.annotate('{}'.format(bar_height), xy=(x_centre, bar_height), **label_style)

# Results

In [None]:
# Experiment grid: which max batch sizes and replica counts to load from `folder`.
batch_sizes = [1, 5, 10]
replicas_range = range(1, 4)

folder = 'results'
# One legend label per batch size, in the same order as `batch_sizes`.
labels = list(map('ray_batch_{}'.format, batch_sizes))
# Load all runs, reduce them to rounded means/stds, then draw the grouped bars and label each bar.
data = unpack_data(folder=folder, batch_sizes=batch_sizes, replicas_range=replicas_range)
mean, std = prepare_data(data)
ax, bps = compare_timing(means=mean, stds=std, labels=labels, cpu_range=replicas_range, y_max=40, y_step=8)
for bp in bps:
    autolabel(ax, bp)

# Sanity check

In [None]:
# Recompute one configuration's mean straight from its pickle and check it matches the plotted value.
# `mean` (from the Results cell) is rounded to 3 decimals, hence the absolute tolerance.
sanity_batch_size, sanity_replicas = 1, 3
with open(f'{folder}/ray_replicas_{sanity_replicas}_maxbatch_{sanity_batch_size}.pkl', 'rb') as f:
    times = pickle.load(f)['t_elapsed']
sanity_mean = np.mean(times)
plotted_mean = mean[batch_sizes.index(sanity_batch_size)][list(replicas_range).index(sanity_replicas)]
assert np.isclose(sanity_mean, plotted_mean, rtol=0, atol=1e-3), (sanity_mean, plotted_mean)
print(sanity_mean)