In [None]:
import pickle, yaml
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import os

# All paths are anchored on the notebook's working directory; the project root
# is taken to be its parent folder (presumably the notebook lives one level
# below the project root — TODO confirm if the notebook is moved).
PROJECT_DIRECTORY = Path(os.path.abspath(os.curdir)).parent
RESULTS_PATH = PROJECT_DIRECTORY / "results"  # where run outputs (*.pkl + config.json) are searched
SAVE_FIG_PATH = PROJECT_DIRECTORY / "_static"  # where generated figures are written

In [None]:
def saveFig(name, fig):
    """Save ``fig`` as a PNG at ``name`` using the notebook's standard export settings.

    The parent directory of ``name`` is created if it does not exist yet, so a
    fresh checkout without the output folder does not fail at save time.

    Args:
        name: Destination file path (``str`` or ``pathlib.Path``).
        fig: A matplotlib ``Figure`` (anything with ``savefig`` and ``get_facecolor``).
    """
    # Figure.savefig does not create missing directories; do it up front.
    Path(name).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        name,
        dpi=None,  # defer to matplotlib's savefig.dpi setting
        facecolor=fig.get_facecolor(),  # keep the figure's own background colour
        edgecolor="none",
        orientation="portrait",
        format="png",
        transparent=False,
        bbox_inches="tight",  # crop surrounding whitespace ...
        pad_inches=0.2,  # ... while keeping a small margin around the content
        metadata=None,
    )

In [None]:
def read_pickle(path_to_pickle):
    """Deserialize and return the object stored in the pickle file at ``path_to_pickle``.

    Security: unpickling can execute arbitrary code, so only call this on
    result files produced by trusted runs.
    """
    with open(path_to_pickle, mode="rb") as pickle_file:
        return pickle.load(pickle_file)

def read_config(config_directory):
    """Load the run configuration stored as ``config.json`` inside ``config_directory``.

    The file is parsed with ``yaml.safe_load``, which generally accepts JSON
    documents as well.

    Args:
        config_directory: Directory (``str`` or ``pathlib.Path``) containing ``config.json``.

    Returns:
        The parsed configuration (a ``dict`` when the file holds a JSON object).
    """
    # Path(...) so that plain string directories work too (``str / str`` would raise).
    config_file = Path(config_directory) / "config.json"
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can fail on non-ASCII configs, and JSON files are UTF-8 by specification.
    with open(config_file, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

In [None]:
def fuse_by_dataset(losses):
    """Regroup a per-round history into a single dict of lists.

    Args:
        losses: Sequence of ``(round, {key: value})`` pairs; the first element
            of each pair (presumably the round number) is ignored.

    Returns:
        ``{key: [value, value, ...]}`` with values appended in the order the
        pairs are given. A key absent from some rounds just gets a shorter list.
    """
    fused = {}
    for _round, values_by_key in losses:
        for key, value in values_by_key.items():
            fused.setdefault(key, []).append(value)
    return fused

In [None]:
def process_multirun_data(path_multirun):
    """Load the pre-training loss history of every run found under ``path_multirun``.

    Every ``*.pkl`` file found recursively is treated as one run; the
    ``config.json`` in the same directory provides its strategy name
    (``algorithm-name`` key).

    Args:
        path_multirun: Root directory (``str`` or ``pathlib.Path``) of the multirun output.

    Returns:
        One ``{"strategy": str, "train_losses": dict}`` entry per ``*.pkl`` file,
        where ``train_losses`` is the output of ``fuse_by_dataset``.
    """
    res_list = []
    # glob() yields files in filesystem order, which differs between machines;
    # sorting makes the run order (and hence strategy/legend/colour order
    # downstream) reproducible.
    for results in sorted(Path(path_multirun).glob("**/*.pkl")):
        config = read_config(results.parent)
        data = read_pickle(results)
        # NOTE(review): assumes data["history"] exposes `metrics_distributed_fit`
        # holding (round, {dataset: loss}) pairs — confirm against the run script.
        pre_train_loss = data["history"].metrics_distributed_fit["pre_train_losses"]
        res_list.append(
            {
                "strategy": config["algorithm-name"],
                "train_losses": fuse_by_dataset(pre_train_loss),
            }
        )
    return res_list

In [None]:
# Load, for every run found under RESULTS_PATH, its strategy name and its
# per-dataset pre-training loss curves.
all_losses = process_multirun_data(RESULTS_PATH)

In [None]:
def average_by_client_type(all_fused_lossed):
    """Group run results by strategy and dataset, summing loss curves across runs.

    When the same configuration was run several times (e.g. by repeating a
    ``--multirun``), the per-round losses of those runs are added together and
    the runs are counted, so the mean curve is ``train_loss / run_count``.

    Args:
        all_fused_lossed: Iterable of
            ``{"strategy": str, "train_losses": {dataset: [loss, ...]}}`` dicts,
            as returned by ``process_multirun_data``.

    Returns:
        ``{strategy: {dataset: {"train_loss": np.ndarray, "run_count": int}}}``,
        where ``train_loss`` is the element-wise *sum* over runs (float64).

    Raises:
        ValueError: if two runs of the same strategy and dataset have loss
            curves of different shapes (e.g. a different number of rounds).
    """
    to_plot = {}
    for res in all_fused_lossed:
        strategy = res["strategy"]
        per_dataset = to_plot.setdefault(strategy, {})
        for dataset, train_loss in res["train_losses"].items():
            # Force float: an integer first run would make the in-place `+=`
            # below fail with a casting error once float losses are added.
            losses = np.asarray(train_loss, dtype=float)
            if dataset not in per_dataset:
                # Copy so the accumulator never aliases the caller's data.
                per_dataset[dataset] = {"train_loss": losses.copy(), "run_count": 1}
                continue
            entry = per_dataset[dataset]
            # Without this check a length-1 run would silently broadcast onto
            # every round of the accumulated curve.
            if entry["train_loss"].shape != losses.shape:
                raise ValueError(
                    f"Run of strategy {strategy!r} on dataset {dataset!r} has loss "
                    f"shape {losses.shape}, expected {entry['train_loss'].shape}"
                )
            entry["train_loss"] += losses
            entry["run_count"] += 1
    return to_plot

In [None]:
# Group runs by strategy and dataset: `train_loss` holds the sum over runs and
# `run_count` the number of runs, so the mean curve is train_loss / run_count.
to_plot = average_by_client_type(all_losses)

In [None]:
# One panel per dataset, one mean train-loss curve per strategy in each panel.
if not to_plot:
    raise ValueError(f"No *.pkl results found under {RESULTS_PATH}")

# Union of datasets over *all* strategies (first-seen order), so a strategy
# that ran on an extra dataset still gets a panel instead of a KeyError.
datasets = list(dict.fromkeys(name for res in to_plot.values() for name in res))
num_datasets = len(datasets)
XTICK_STEP = 25  # spacing of the round ticks on the x-axis

# squeeze=False keeps `axs` indexable even when there is only one dataset.
fig, axs = plt.subplots(figsize=(14, 4), nrows=1, ncols=num_datasets, squeeze=False)
axs = axs.ravel()

num_rounds = 0
for strategy, results in to_plot.items():
    for ax, dataset in zip(axs, datasets):
        if dataset not in results:
            continue  # this strategy was not run on this dataset
        # `train_loss` is the sum over runs; divide to obtain the mean curve.
        loss = results[dataset]["train_loss"] / results[dataset]["run_count"]
        ax.plot(range(len(loss)), loss, label=strategy)
        num_rounds = max(num_rounds, len(loss))

# Decorate each panel once, after all curves are drawn.
for i, (ax, dataset) in enumerate(zip(axs, datasets)):
    ax.set_title(dataset)
    ax.set_xlabel("Round")
    if i == 0:
        ax.set_ylabel("Train Loss")
    ax.grid(True)
    ax.set_xticks(np.arange(0, num_rounds + 1, XTICK_STEP))
    ax.legend()

saveFig(SAVE_FIG_PATH / "train_loss.png", fig)