In [None]:
import os
import pickle
import re
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from utils import set_matplotlib_configuration, get_accuracy_with_clustering

## FEMNIST accuracy gains

Notebook for generating the table of accuracy gains, reported in percentage points relative to training without clustering. Set `DATA_PARTITION` to `ho` to get the accuracy gain on the holdout clients, or to `val` to get it on the validation dataset of the training clients.

In [None]:
# Configure matplotlib through the project helper and unpack the two values it
# returns: PLOTTING_KWARGS is a mapping of keyword arguments unpacked into
# `DataFrame.plot`; SAVEFIG_KWARGS is not used in the cells shown here
# (presumably meant for `savefig` — confirm in `utils`).
(PLOTTING_KWARGS, SAVEFIG_KWARGS) = set_matplotlib_configuration()

In [None]:
DATA_PARTITION = "ho"  # "ho": holdout clients, "val": validation data of the training clients
WEIGHTING = "weighted_average"  # "average": plain mean over clients, "weighted_average": weighted by dataset size
FOLDER = "accuracy_gains"
BASE_FOLDER = f"../outputs/{FOLDER}"

# Fail early on typos: any WEIGHTING other than "average" would otherwise
# silently fall back to the dataset-size-weighted mean.
assert DATA_PARTITION in {"ho", "val"}, f"unknown DATA_PARTITION {DATA_PARTITION!r}"
assert WEIGHTING in {"average", "weighted_average"}, f"unknown WEIGHTING {WEIGHTING!r}"

In [None]:
# Baseline: accuracy of the runs trained without clustering, one folder per seed.
# Each folder holds `<DATA_PARTITION>_accuracy.csv` with per-client
# `accuracy` and `dataset_size` columns.
wo_clustering_folders = sorted(glob(f"{BASE_FOLDER}/without_clustering_seed*/"))
assert wo_clustering_folders, f"no without_clustering_seed* runs found in {BASE_FOLDER}"

accuracies = []
for folder in wo_clustering_folders:
    client_acc = pd.read_csv(folder + f"{DATA_PARTITION}_accuracy.csv")
    if WEIGHTING == "average":
        acc = client_acc["accuracy"].mean()
    else:
        # weight each client's accuracy by the size of its dataset
        acc = (client_acc["accuracy"] * client_acc["dataset_size"]).sum() / client_acc["dataset_size"].sum()
    accuracies.append(acc)

n_runs = len(accuracies)
WO_CLUSTERING_MEAN = np.mean(accuracies) * 100
# Standard error of the mean across seeds. ddof=1 (sample std) matches the
# pandas `std` used for the clustered runs, so both sides of the gain agree.
WO_CLUSTERING_SEM = np.std(accuracies, ddof=1) * 100 / np.sqrt(n_runs)
f"Average accuracy without clustering ({n_runs} runs): {WO_CLUSTERING_MEAN:.2f} +/- {WO_CLUSTERING_SEM:.2f}"

In [None]:
N_SEEDS = 3  # number of runs (seeds) expected for every clustering configuration
GROUP_KEYS = ["method", "num_clusters", "meta"]

runs_df = get_accuracy_with_clustering(BASE_FOLDER, DATA_PARTITION, WEIGHTING)
runs_per_config = runs_df.groupby(GROUP_KEYS).size()
assert (runs_per_config == N_SEEDS).all(), (
    f"expected {N_SEEDS} runs per configuration, got:\n"
    f"{runs_per_config[runs_per_config != N_SEEDS]}"
)

# Mean and sample std (pandas default ddof=1) of the accuracy across seeds.
stats_df = runs_df.groupby(GROUP_KEYS)["acc"].agg(["mean", "std"]).reset_index()

# CoLEDS is run with several `meta` settings: per number of clusters keep only
# the setting with the highest mean accuracy.
is_coleds = stats_df["method"] == "CoLEDS"
best_coleds = stats_df.loc[stats_df[is_coleds].groupby("num_clusters")["mean"].idxmax()]
stats_df = pd.concat([stats_df[~is_coleds], best_coleds]).drop(columns="meta")

# One row per number of clusters, (statistic, method) columns, in percent.
df = stats_df.pivot(index="num_clusters", columns="method", values=["mean", "std"]) * 100
df["mean"] -= WO_CLUSTERING_MEAN  # gain in percentage points over training without clustering
df["std"] /= np.sqrt(N_SEEDS)  # std -> standard error of the mean
df = df.reindex(columns=pd.MultiIndex.from_product([["mean", "std"], ["WDP", "LbP", "CoLEDS", "AESP", "REPA"]]))

In [None]:
# Bar chart of the mean accuracy gain (percentage points over training without
# clustering) per number of clusters, one bar per method.
ax = df["mean"].plot(kind="bar", rot=0, **PLOTTING_KWARGS, figsize=(5, 3))
ax.set_xlabel("Number of clusters")
ax.set_ylabel("Accuracy gain [pp]")
ax.legend(loc=(1.05, 0.1));

In [None]:
# Format `df` as a LaTeX tabular with per-row highlighting of the three largest
# mean gains, using the `\bm{<rank>}{<value>}` macro from the note below:
# best -> \bm1{...}, second -> \bm2{...}, third -> \bm3{...}.
# Only the mean value is styled, not the "± std" part.
_STYLE_MACROS = {"best": r"\bm1", "second": r"\bm2", "third": r"\bm3"}


def _format_cell(mean_val, std_val, style=None):
    """Return the LaTeX text of one table cell.

    Args:
        mean_val: mean accuracy gain; a missing (NaN) value is rendered as "--".
        std_val: its standard error; omitted from the cell when NaN.
        style: "best", "second", "third", or None for no highlighting.

    Returns:
        The cell content as a LaTeX string.
    """
    if pd.isna(mean_val):
        return "--"
    mean_str = f"{mean_val:.2f}"
    if style in _STYLE_MACROS:
        mean_str = _STYLE_MACROS[style] + "{" + mean_str + "}"
    if pd.isna(std_val):
        return mean_str
    return mean_str + r" \tiny{$\pm " + f"{std_val:.2f}" + r"$}"


def _rank_styles(row_means):
    """Map the methods with the three largest non-NaN means to "best"/"second"/"third".

    NaN means are excluded: NaN does not compare consistently, so `sorted`
    could otherwise place a missing method at the top of the ranking.
    """
    valid = [(method, mean) for method, mean in row_means.items() if not pd.isna(mean)]
    ranked = sorted(valid, key=lambda item: item[1], reverse=True)
    return {method: style for (method, _mean), style in zip(ranked, ("best", "second", "third"))}


order = ["LbP", "WDP", "CoLEDS", "AESP", "REPA"]
# Only methods with both a mean and a std column are ranked and rendered;
# the remaining cells stay empty.
available = [m for m in order if ("mean", m) in df.columns and ("std", m) in df.columns]
formatted_df = pd.DataFrame("", index=df.index, columns=order, dtype=object)

for idx in df.index:
    styles = _rank_styles({m: df.loc[idx, ("mean", m)] for m in available})
    for m in available:
        formatted_df.loc[idx, m] = _format_cell(
            df.loc[idx, ("mean", m)], df.loc[idx, ("std", m)], styles.get(m)
        )

formatted_df.index.name = "Number of clusters"
latex_table = formatted_df.reset_index().to_latex(
    escape=False,  # cells already contain LaTeX markup
    index=False,
    column_format="r||" + "c" * len(order),  # index column, then one column per method
)
print(latex_table)

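Note: the LaTeX tabular above highlights the best, second-best and third-best value of each row with the `\bm` macro. To render the highlighting, add the following to the preamble of your LaTeX document (it defines `\bm` itself, so do not also load the `bm` package):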

```latex
\usepackage{tikz}
\definecolor{gold}{HTML}{FBF2D2}
\definecolor{silver}{HTML}{DDDDDD}
\definecolor{bronze}{HTML}{EED2B8}
\definecolor{goldD}{HTML}{D9AE13}
\definecolor{silverD}{HTML}{909090}
\definecolor{bronzeD}{HTML}{9A5F26}
\newcommand{\medal}[3]{\tikz[baseline=(char.base)]{\node[rounded corners=2pt,fill=#1,draw=#2,inner sep=1.5pt] (char) {#3};}}

% The trailing % signs stop line ends from inserting spurious spaces into the table cells.
\newcommand{\bm}[2]{%
    \ifcase#1\or% case 1
      {\medal{gold}{goldD}{\textbf{#2}}}%
    \or% case 2
      {\medal{silver}{silverD}{#2}}%
    \or% case 3
      {\medal{bronze}{bronzeD}{#2}}%
    \else% default case
      #2%
    \fi\ignorespaces
}
```

In [None]:
# Sanity check: every run of a given seed must evaluate on the same set of
# holdout clients (the `client_idx` column of its ho_accuracy.csv).
# NOTE(review): this scans HOLDOUT_CHECK_ROOT (the working directory by default,
# not BASE_FOLDER) for `<run folder ending in the seed>/<sub folder>/ho_accuracy.csv`
# — confirm this is the intended layout.
HOLDOUT_CHECK_ROOT = "."
seed = "4"
# Match the seed only as a whole trailing number, so "4" does not also match "..._seed14".
seed_suffix = re.compile(rf"(?<!\d){re.escape(str(seed))}$")

holdout_clients = None
checked_files = []
for folder in sorted(os.listdir(HOLDOUT_CHECK_ROOT)):
    folder_path = os.path.join(HOLDOUT_CHECK_ROOT, folder)
    if not (os.path.isdir(folder_path) and seed_suffix.search(folder)):
        continue
    for file in sorted(glob(os.path.join(folder_path, "*", "ho_accuracy.csv"))):
        clients = set(pd.read_csv(file).client_idx.tolist())
        if holdout_clients is None:
            holdout_clients = clients
        assert clients == holdout_clients, f"holdout clients in {file} differ from the first run checked"
        checked_files.append(file)

# Fail loudly instead of passing vacuously when no run matched.
assert checked_files, f"no ho_accuracy.csv found for seed {seed!r} under {HOLDOUT_CHECK_ROOT!r}"
f"{len(checked_files)} runs share the same {len(holdout_clients)} holdout clients"