In [None]:
import os
import pickle

import numpy as np
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt

from utils import set_matplotlib_configuration, get_accuracy_with_clustering

## FEMNIST accuracy gains

Notebook for generating the table with accuracy gains. Set `DATA_PARTITION` to `ho` to get the accuracy gain on the holdout clients, or to `val` to get it on the validation datasets of the training clients.

In [None]:
# Configure matplotlib through the project helper. The two returned objects are
# reused below: PLOTTING_KWARGS is forwarded to the pandas plot call (and its
# "color" entry supplies line colours), SAVEFIG_KWARGS is forwarded to plt.savefig.
# NOTE(review): the meaning of the 8.1 argument is not visible in this notebook
# (presumably a font size) -- confirm against utils.set_matplotlib_configuration.
PLOTTING_KWARGS, SAVEFIG_KWARGS = set_matplotlib_configuration(8.1)

In [None]:
# Which per-client accuracy file to read: "ho" (holdout clients) or "val"
# (validation datasets of the training clients). Used to build the file name
# f"{DATA_PARTITION}_accuracy.csv".
DATA_PARTITION = "ho"
# How per-client accuracies are aggregated:
#   "average"          -- plain mean over clients
#   "weighted_average" -- mean weighted by each client's dataset size
WEIGHTING = "weighted_average" # options -- average / weighted_average (weight by dataset size)
# Experiment folder name and the resulting root directory of all run outputs.
FOLDER = "accuracy_gains"
BASE_FOLDER = f"../outputs/{FOLDER}"

In [None]:
# Baseline: accuracy of the runs trained without clustering, one value per seed folder.
accuracies = []
for folder in glob(f"{BASE_FOLDER}/without_clustering_seed*/"):
    df = pd.read_csv(folder + f"{DATA_PARTITION}_accuracy.csv")
    if WEIGHTING == "average":
        # Plain mean over clients.
        acc = df["accuracy"].mean()
    else:
        # Mean over clients weighted by each client's dataset size.
        acc = (df["accuracy"] * df["dataset_size"]).sum() / df["dataset_size"].sum()
    accuracies.append(acc)

if not accuracies:
    # Fail loudly instead of silently reporting "nan +/- nan".
    raise FileNotFoundError(
        f"No 'without_clustering_seed*' run folders found under {BASE_FOLDER!r}"
    )

# Mean accuracy in percent. `std` holds the standard error of the mean, computed
# from the number of runs actually found rather than a hard-coded count of 3.
WO_CLUSTERING_MEAN = np.mean(accuracies) * 100
std = np.std(accuracies) * 100 / np.sqrt(len(accuracies))
f"Average accuracy without clustering: {WO_CLUSTERING_MEAN:.2f} +/- {std:.2f}"

In [None]:
# CoLEDS configuration selection:
#   True  -> keep only the CoLEDS runs matching the fixed configuration below
#   False -> for every number of clusters keep the CoLEDS configuration with
#            the highest mean accuracy
fixed_coleds_configuration = False
coleds_bs = 32
coleds_ff = 0.25
coleds_save_always = "false"

runs = get_accuracy_with_clustering(BASE_FOLDER, DATA_PARTITION, WEIGHTING, coleds_analysis_keys=["num_clusters", "seed", "bs", "ff"])
# Fill missing CoLEDS-specific values so the rows survive the groupby below.
# Assign the filled column back instead of calling `fillna(..., inplace=True)`
# on `runs[col]`: the chained in-place form is deprecated and does not modify
# the frame when pandas Copy-on-Write is active.
# NOTE(review): "save" is not listed in `coleds_analysis_keys`; presumably the
# helper still returns that column -- confirm against utils.
runs["bs"] = pd.to_numeric(runs["bs"].fillna(0)).astype(int)
runs["ff"] = pd.to_numeric(runs["ff"].fillna(0.)).astype(float)
runs["save"] = runs["save"].fillna("")

is_coleds = runs["method"] == "CoLEDS"

if fixed_coleds_configuration:
    matches_fixed = (
        (runs["bs"] == coleds_bs)
        & (runs["ff"] == coleds_ff)
        & (runs["save"] == coleds_save_always)
    )
    selected = runs[~is_coleds | matches_fixed]
else:
    config_cols = ["num_clusters", "save", "bs", "ff"]
    non_coleds = runs[~is_coleds]
    coleds = runs[is_coleds].reset_index(drop=True)

    # Mean accuracy of every CoLEDS configuration, then the best one per cluster count.
    config_means = coleds.groupby(config_cols)["acc"].mean().reset_index()
    best_rows = config_means.groupby("num_clusters")["acc"].idxmax()
    best_configs = config_means.loc[best_rows][config_cols]
    best_coleds = coleds.merge(best_configs, on=config_cols, how="inner")

    selected = pd.concat([non_coleds, best_coleds], ignore_index=True)
# assert (selected.groupby(["method", "num_clusters", "meta"]).size() == 5).all()

# Mean / std accuracy (in percent) per cluster count and method, with a fixed column order.
summary = selected.groupby(["method", "num_clusters"])["acc"].agg(["mean", "std"]).reset_index()
df = summary.pivot(index="num_clusters", columns="method", values=["mean", "std"]) * 100
df = df.reindex(columns=pd.MultiIndex.from_product([["mean", "std"], ["WDP", "LbP", "CoLEDS", "AESP", "REPA"]]))

In [None]:
# Grouped bar chart of the mean accuracy per number of clusters, with the
# no-clustering baseline drawn as a horizontal reference line.
bar_ax = df["mean"].plot(kind="bar", figsize=(5, 2), **PLOTTING_KWARGS)
bar_ax.axhline(WO_CLUSTERING_MEAN)
bar_ax.set_ylim(75, 85)
bar_ax.legend(loc=(1.04, 0.1))

In [None]:
# Format `df` for LaTeX with per-row highlighting:
# best -> \textbf, second -> \underline, third -> \emph.
# Only the mean value is styled, not the "Â± std".
def _format_cell(mean_val, std_val, style=None):
    mean_str = f"{mean_val:.2f}"
    if style == "best":
        mean_str = r"\bm1{" + mean_str + "}"
    elif style == "second":
        mean_str = r"\bm2{" + mean_str + "}"
    elif style == "third":
        mean_str = r"\bm3{" + mean_str + "}"

    if pd.isna(std_val):
        return mean_str
    return mean_str + r" \tiny{$\pm " + f"{std_val:.2f}" + r"$}"

# Column order of the LaTeX table and the highlight styles handed out by rank.
order = ["LbP", "WDP", "CoLEDS", "AESP", "REPA"]
rank_styles = ("best", "second", "third")
formatted_df = pd.DataFrame(index=df.index, columns=order, dtype=object)

for idx in df.index:
    # Collect the means available for this row. NaN means (methods without any
    # result for this number of clusters) are skipped: they would otherwise make
    # the ranking order-dependent and be printed as "nan".
    row_means = {
        m: df.loc[idx, ("mean", m)]
        for m in order
        if ("mean", m) in df.columns and pd.notna(df.loc[idx, ("mean", m)])
    }

    # Rank methods by mean (descending); the top three get a highlight style.
    ranked = sorted(row_means, key=row_means.get, reverse=True)
    styles = dict(zip(ranked, rank_styles))

    # Fill the formatted table row; methods without data stay empty.
    for m in order:
        if m in row_means and ("std", m) in df.columns:
            std_val = df.loc[idx, ("std", m)]
            formatted_df.loc[idx, m] = _format_cell(row_means[m], std_val, styles.get(m))
        else:
            formatted_df.loc[idx, m] = ""

formatted_df.index.name = "Number of clusters"
latex_table = formatted_df.reset_index().to_latex(
    escape=False,  # cells already contain raw LaTeX markup
    index=False,
    column_format="r||ccccc",
)
print(latex_table)

Note: the LaTeX table above highlights the best, second-best and third-best value in each row. For the highlighting to render, add the following to the preamble of your LaTeX document:

```latex
\usepackage{tikz}
\definecolor{gold}{HTML}{FBF2D2}
\definecolor{silver}{HTML}{DDDDDD}
\definecolor{bronze}{HTML}{EED2B8}
\definecolor{goldD}{HTML}{D9AE13}
\definecolor{silverD}{HTML}{909090}
\definecolor{bronzeD}{HTML}{9A5F26}
\newcommand{\medal}[3]{\tikz[baseline=(char.base)]{\node[rounded corners=2pt,fill=#1,draw=#2,inner sep=1.5pt] (char) {#3};}}

\newcommand{\bm}[2]{
    \ifcase#1\or% case 1
      {\medal{gold}{goldD}{\textbf{#2}}}
    \or % case 2
      {\medal{silver}{silverD}{#2}}
    \or % case 3
      {\medal{bronze}{bronzeD}{#2}}
    \else % default case
      #2
    \fi\ignorespaces
}
```

## Evolution of the accuracy

In [None]:
# Seed of the runs to inspect.
# NOTE: the plotting cell at the end of this notebook reassigns SEED = 8
# before using it, so this value is not the one that ends up in the figure.
SEED = 4
# NOTE(review): NUMBER_OF_CLUSTERS is not referenced by any cell visible in
# this notebook -- confirm whether it is still needed.
NUMBER_OF_CLUSTERS = 4

In [None]:
import pickle
def get_accuracy_evolution(path):
    """Read the test-accuracy curve stored in a pickled metrics file.

    Args:
        path: Path to a pickle file holding a mapping whose "test_accuracy"
            entry is a sequence of two-element pairs.

    Returns:
        A pair of tuples: the first elements of all pairs and the second
        elements of all pairs, in file order.
    """
    with open(path, "rb") as handle:
        metrics = pickle.load(handle)
    # Transpose the list of pairs into two parallel tuples.
    steps, values = zip(*metrics["test_accuracy"])
    return steps, values

In [None]:

def get_all_accuracy_evolution(seed, number_or_clusters):
    """Collect the accuracy curves of every profiling method for one setting.

    For each method, all ``*.pkl`` files below
    ``BASE_FOLDER/<run folder pattern>/<number_or_clusters>/`` are read with
    ``get_accuracy_evolution`` and their accuracy curves are averaged
    element-wise. The curve of the run without clustering is added unaveraged.

    Args:
        seed: Seed that is part of the run folder names.
        number_or_clusters: Name of the sub-folder holding the per-cluster
            pickle files (the number of clusters).

    Returns:
        dict mapping "CoLEDS", "WDP", "LbP", "REPA", "AESP" to the averaged
        curve (``np.nan`` when no file matched) and "Without clustering" to
        the curve of the baseline run.

    Raises:
        FileNotFoundError: if no baseline pickle exists for ``seed``.
    """
    # Run-folder glob patterns, relative to BASE_FOLDER, per profiling method.
    run_folder_patterns = {
        "CoLEDS": f"/coleds_bs_32_ff_0.25*_save_false_seed_{seed}",
        "WDP": f"/wd_seed_{seed}",
        "LbP": f"/label_seed_{seed}",
        "REPA": f"/es_model_sim*_seed_{seed}",
        "AESP": f"/es_model_bet*_seed_{seed}",
    }

    all_accuracies = {}
    for method, pattern in run_folder_patterns.items():
        files = glob(BASE_FOLDER + f"{pattern}/{number_or_clusters}/*.pkl")
        curves = [get_accuracy_evolution(file)[1] for file in files]
        # Without any matching file the mean is undefined; return NaN directly
        # rather than triggering numpy's "Mean of empty slice" warning.
        all_accuracies[method] = np.mean(np.array(curves), axis=0) if curves else np.nan

    baseline_files = glob(BASE_FOLDER + f"/without_clustering_seed_{seed}/*.pkl")
    if not baseline_files:
        raise FileNotFoundError(
            f"No pickle file found in {BASE_FOLDER}/without_clustering_seed_{seed}/"
        )
    # If several files match, the last one returned by glob is used.
    all_accuracies["Without clustering"] = get_accuracy_evolution(baseline_files[-1])[1]
    return all_accuracies

In [None]:
# NOTE: this overrides the SEED chosen in the configuration cell of this section.
SEED = 8
# Spacing of the x-axis values: the i-th point of a curve is drawn at
# (i + 1) * EPOCHS_PER_POINT, i.e. 5, 10, ..., 250 for a 50-point curve.
# NOTE(review): presumably accuracy is recorded every 5 epochs -- confirm.
EPOCHS_PER_POINT = 5
CLUSTER_COUNTS = [2, 4, 8, 16, 32, 64]
PLOTTED_METHODS = ["CoLEDS", "WDP", "Without clustering"]

fig, ax = plt.subplots(2, 3, figsize=(5.2, 3.2), sharey=True, sharex=True)
colors = PLOTTING_KWARGS["color"]
for tmp_ax, num_clusters in zip(ax.flat, CLUSTER_COUNTS):
    tmp_ax.set_title(f"{num_clusters} clusters")
    all_accuracies = get_all_accuracy_evolution(SEED, num_clusters)
    for idx, method in enumerate(PLOTTED_METHODS):
        linestyle = "-" if method != "Without clustering" else "--"
        # Accuracy in percent; the x-axis is derived from the curve length so
        # that runs with a different number of recorded points still plot.
        curve = np.atleast_1d(all_accuracies[method]) * 100
        epochs = np.arange(1, len(curve) + 1) * EPOCHS_PER_POINT
        tmp_ax.plot(epochs, curve, color=colors[idx], label=method, linestyle=linestyle, linewidth=1.2)
ax[0][0].set_xticks([0, 50, 100, 150, 200, 250])
ax[-1][-1].legend(loc=(1.02, 0.9), title="Profiling method")
fig.supxlabel("Epoch", y=0.01)
fig.supylabel("Validation Accuracy", x=0.06)
fig.subplots_adjust(wspace=0.1)
# The y-axis is shared, so setting the limit on one subplot applies to all.
ax[-1][-1].set_ylim(bottom=50)
# Make sure the output directory exists before saving.
os.makedirs("images", exist_ok=True)
fig.savefig("images/femnist_accuracy_evolution.pdf", **SAVEFIG_KWARGS)