### Imports and data paths

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pickle5 as pickle

from auxiliary import load_pickle
%matplotlib inline

# Apply seaborn's default theme globally to subsequent matplotlib figures.
# graph_mean_epochs later switches the global style to "whitegrid" and the
# palette to "colorblind" each time it is called.
sns.set_theme()

In [None]:
# Pickled TestSuite results: one file per experimental setup
# (random / uniform x iid / non-iid), relative to the notebook directory.
PATH_random_iid = "./results/TestSuite-random_iid_06_15_2021.pkl"
PATH_uniform_iid = "./results/TestSuite-uniform_iid_06_16_2021.pkl"
PATH_uniform_non_iid = "./results/TestSuite-uniform_non_iid_06_17_2021.pkl"
PATH_random_non_iid = "./results/TestSuite-random_non_iid_06_17_2021.pkl"

### Load the data

In [None]:
def load_pickle(path):
    """Open the pickle file at ``path`` in binary mode and return its object.

    NOTE: this definition shadows the ``load_pickle`` imported from
    ``auxiliary`` at the top of the notebook; after this cell runs, this
    local version is the one in use.

    Only load result files you produced yourself: unpickling can execute
    arbitrary code embedded in the file.
    """
    with open(path, mode="rb") as handle:
        return pickle.load(handle)

# Load the four result suites. The load_pickle used here is the one defined
# in this cell, which replaces the version imported from auxiliary.
data_uniform_iid = load_pickle(PATH_uniform_iid)
data_uniform_non_iid = load_pickle(PATH_uniform_non_iid)
data_random_iid = load_pickle(PATH_random_iid)
data_random_non_iid = load_pickle(PATH_random_non_iid)

### Helper functions for plotting and wrangling the data

In [None]:
def get_mean_test_acc_per_epoch(results):
    """Average the nodes' test accuracies position by position (epoch by epoch).

    ``results`` is one run's dictionary. It holds ``'nr_nodes'`` and, for
    every node index ``i`` in ``range(nr_nodes)``, an entry ``'node_{i}'``
    whose ``'test_accuracies'`` is a sequence of accuracies.

    Returns a list whose k-th element is ``np.mean`` over all nodes of their
    k-th test accuracy. Because the per-node sequences are combined with
    ``zip``, the result is only as long as the shortest node's sequence; with
    zero nodes it is an empty list.
    """
    per_node = [
        results[f'node_{node_idx}']['test_accuracies']
        for node_idx in range(results['nr_nodes'])
    ]

    # zip(*per_node) regroups the data so each tuple holds one epoch's
    # accuracies across all nodes.
    epoch_means = []
    for epoch_accs in zip(*per_node):
        epoch_means.append(np.mean(epoch_accs))
    return epoch_means

In [None]:
def graph_mean_epochs(results, setup="not_private", n_subplots=3):
    """Plot mean test accuracy per epoch for every run of one privacy setting.

    Creates a new figure with ``n_subplots`` side-by-side panels (shared x
    and y axes) and draws one line per selected run, labelled with
    ``run['graph']``. The new figure becomes pyplot's current figure, so a
    following ``plt.savefig(...)`` call saves it.

    Args:
        results: iterable of run dictionaries. Each run must provide
            ``'add_privacy_list'`` (its truthiness marks the run as private),
            ``'epoch_list'`` (x values), ``'graph'`` (legend label) and the
            keys read by ``get_mean_test_acc_per_epoch``.
        setup: ``"not_private"`` plots runs whose ``add_privacy_list`` is
            falsy; ``"private"`` plots runs whose ``add_privacy_list`` is
            truthy.
        n_subplots: number of panels. The run at position ``idx`` of
            ``results`` is drawn on panel ``idx % n_subplots``.

    Returns:
        ``(fig, axes)`` where ``axes`` is a 1-D array of ``n_subplots`` Axes.

    Raises:
        ValueError: if ``setup`` is neither ``"not_private"`` nor ``"private"``.
    """
    if setup not in ("not_private", "private"):
        raise ValueError(
            f"setup must be 'not_private' or 'private', got {setup!r}")
    want_private = setup == "private"

    sns.set_style("whitegrid")
    sns.set_palette("colorblind")

    # squeeze=False always yields a 2-D array of Axes (even for a single
    # panel), so the indexing below works for any n_subplots >= 1.
    fig, ax_grid = plt.subplots(1, n_subplots, sharex=True, sharey=True,
                                figsize=(13, 5), squeeze=False)
    ax = ax_grid[0]

    for idx, run in enumerate(results):
        # Skip runs of the other privacy setting. idx still counts them, so
        # the panel choice depends on each run's position in the full list.
        # NOTE(review): presumably results is ordered so that runs cycle
        # through the panels as intended; confirm against the TestSuite output.
        if bool(run['add_privacy_list']) != want_private:
            continue
        panel = ax[idx % n_subplots]
        panel.plot(run['epoch_list'], get_mean_test_acc_per_epoch(run),
                   label=f"{run['graph']}")
        panel.tick_params(axis='x', labelsize=15)
        panel.tick_params(axis='y', labelsize=13)

    for panel in ax:
        panel.set_xlabel("Epochs", fontsize=18)
        panel.set_ylabel("Mean Test Accuracy", fontsize=18)
        panel.set_xlim(0, 51)
        panel.set_ylim(0, 1)

    # Legend on the right-most panel (ax[2] for the default three panels).
    ax[-1].legend(loc="lower right")
    return fig, ax

### Create and save the plots for each setup

In [None]:
# Non-private runs from the uniform_iid suite.
graph_mean_epochs(data_uniform_iid, setup="not_private")
plt.savefig("./plots/topology_testrun_iid_uniform_data.pdf")

In [None]:
# Private runs from the uniform_iid suite.
graph_mean_epochs(data_uniform_iid, setup="private")
plt.savefig("./plots/topology_testrun_iid_uniform_private.pdf")

In [None]:
# Non-private runs from the random_iid suite.
graph_mean_epochs(data_random_iid, setup="not_private")
plt.savefig("./plots/topology_testrun_iid_random.pdf")

In [None]:
# Private runs from the random_iid suite.
graph_mean_epochs(data_random_iid, setup="private")
plt.savefig("./plots/topology_testrun_iid_random_private.pdf")

In [None]:
# Non-private runs from the uniform_non_iid suite.
graph_mean_epochs(data_uniform_non_iid, setup="not_private")
plt.savefig("./plots/topology_testrun_non_iid_uniform.pdf")

In [None]:
# Private runs from the uniform_non_iid suite.
graph_mean_epochs(data_uniform_non_iid, setup="private")
plt.savefig("./plots/topology_testrun_non_iid_uniform_private.pdf")

In [None]:
# Non-private runs from the random_non_iid suite.
graph_mean_epochs(data_random_non_iid, setup="not_private")
plt.savefig("./plots/topology_testrun_non_iid_random.pdf")

In [None]:
# Private runs from the random_non_iid suite.
graph_mean_epochs(data_random_non_iid, setup="private")
plt.savefig("./plots/topology_testrun_non_iid_random_private.pdf")