In [None]:
import os
import pickle

import numpy as np
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt

from utils import set_matplotlib_configuration, get_accuracy_with_clustering, filter_df

In [None]:
# Notebook-wide configuration.
# DATA_PARTITION selects which accuracy file is read from each run folder:
# f"{DATA_PARTITION}_accuracy.csv".
DATA_PARTITION = "val"
WEIGHTING = "weighted_average" # options -- average / weighted_average (weight by dataset size)
# Run folders are looked up under ../outputs/<FOLDER>.
FOLDER = "cifar10_ablation"
BASE_FOLDER = f"../outputs/{FOLDER}"

# Keyword arguments forwarded to the pandas/matplotlib plotting calls and to
# plt.savefig below. NOTE(review): the meaning of the 8.1 argument is defined in
# utils.set_matplotlib_configuration and is not visible here — confirm.
PLOTTING_KWARGS, SAVEFIG_KWARGS = set_matplotlib_configuration(8.1)

In [None]:
# Override the error-bar cap thickness and cap size used by every bar plot below.
PLOTTING_KWARGS["error_kw"].update(capthick=0.4, capsize=1.2)

## Accuracy without clustering

In [None]:
# Baseline: accuracy of the runs trained without clustering, one value per
# matching run folder, aggregated into a mean and a standard error.
baseline_pattern = f"{BASE_FOLDER}/without_clustering_seed*/"
accuracies = []
for folder in glob(baseline_pattern):
    df = pd.read_csv(folder + f"{DATA_PARTITION}_accuracy.csv")
    if WEIGHTING == "average":
        # unweighted mean over the rows of the CSV
        acc = df["accuracy"].mean()
    else:
        # mean over the rows, weighted by the `dataset_size` column
        acc = (df["accuracy"] * df["dataset_size"]).sum() / df["dataset_size"].sum()
    accuracies.append(acc)

# Fail loudly instead of silently propagating a NaN baseline into the plots.
if not accuracies:
    raise FileNotFoundError(f"No run folders match {baseline_pattern!r}")

# Standard error of the mean: divide by the number of runs actually found
# rather than assuming exactly three seeds.
WO_CLUSTERING_MEAN = np.mean(accuracies) * 100
std = np.std(accuracies) * 100 / np.sqrt(len(accuracies))
f"Average accuracy without clustering: {WO_CLUSTERING_MEAN:.2f} +/- {std:.2f}"

## Analysis w.r.t. clustering algorithm

Plot data

In [None]:
# Accuracy of every clustered run, then mean/std (in percent) per
# (method, number of clusters, clustering algorithm) for the selected subset.
fedprox_filter = {"num_ho_clients": 0, "algorithm": "FedProx"}
group_keys = ["method", "num_clusters", "clustering_algorithm"]

all_runs = get_accuracy_with_clustering(
    BASE_FOLDER,
    "val",
    WEIGHTING,
    analysis_keys=["num_clusters", "seed", "clustering_algorithm", "mu", "ho"],
)

# Sanity check: average number of runs that fall into each group.
print(filter_df(all_runs, fedprox_filter).reset_index(drop=True).groupby(group_keys).size().mean())

df = (
    filter_df(all_runs, fedprox_filter)
    .reset_index(drop=True)
    .drop(columns="meta")
    .groupby(group_keys)["acc"]
    .agg(["mean", "std"])
    .reset_index()
)
# Convert fractions to percentages.
df[["mean", "std"]] = df[["mean", "std"]] * 100

In [None]:
# One panel per profiling method; bars are grouped by number of clusters and
# coloured by clustering algorithm, with the no-clustering baseline as a dashed line.
algorithm_labels = {
    "complete": "Hierarchical\ncomplete linkage",
    "kmeans": "KMeans",
    "ward": "Hierarchical\nward linkage",
}

fig, ax = plt.subplots(2, 2, figsize=(5.5, 3.5), sharex=True, sharey=True)
# Iterating the transposed grid fills the panels column by column.
for method, panel in zip(["CoLEDS", "WDP", "REPA", "LbP"], ax.T.flat):
    panel.set_title(method)
    stats = (
        filter_df(df, {"method": method})
        .pivot(index="num_clusters", columns="clustering_algorithm", values=["mean", "std"])
        .rename(columns=algorithm_labels, level=1)
        .sort_index(axis=1, level=0, sort_remaining=True)
    )
    stats["mean"].plot(kind="bar", yerr=stats["std"], ax=panel, rot=0, legend=None, **PLOTTING_KWARGS)
    panel.axhline(WO_CLUSTERING_MEAN, linestyle="--", color=PLOTTING_KWARGS["color"][4], linewidth=1.0)
    panel.set_xlabel("")

fig.supylabel("Validation Accuracy", x=0.06)
fig.supxlabel("Number of clusters")
plt.subplots_adjust(wspace=0.1)
ax[0, 0].set_ylim(50, 80)
ax[1, 1].legend(loc=(1.03, 0.85), title="Clustering algorithm")
plt.savefig("images/clustering_algos_cifar10.pdf", **SAVEFIG_KWARGS)

## Accuracy with respect to training algorithm

In [None]:
# Two panels (one per evaluation partition) comparing the profiling methods,
# with the no-clustering baseline drawn as a dashed line.
panel_titles = {
    "val": "Validation sets\nof training clients",
    "ho": "Validation sets\nof holdout clients",
}
fedavg_filter = {"num_ho_clients": 50, "clustering_algorithm": "kmeans", "algorithm": "FedAvg"}
method_columns = pd.MultiIndex.from_product(
    [["mean", "std"], ["LbP", "WDP", "CoLEDS", "REPA", "AESP"]]
)

fig, ax = plt.subplots(1, 2, figsize=(5.5, 1.6), sharey=True)
for panel, data_partition in zip(ax, ["val", "ho"]):
    panel.set_title(panel_titles[data_partition])

    # Baseline without clustering: mean over the matching runs, in percent.
    baseline_accs = []
    for csv_path in glob(f"{BASE_FOLDER}/without_clustering_seed_*_mu_-1.0_ho_50/{data_partition}_accuracy.csv"):
        run_df = pd.read_csv(csv_path)
        if WEIGHTING == "average":
            run_acc = run_df["accuracy"].mean()
        else:
            run_acc = (run_df["accuracy"] * run_df["dataset_size"]).sum() / run_df["dataset_size"].sum()
        baseline_accs.append(run_acc)
    wo_clustering = np.mean(baseline_accs) * 100

    # Accuracy with clustering, restricted to the subset selected by `fedavg_filter`.
    clustered_runs = get_accuracy_with_clustering(
        BASE_FOLDER,
        data_partition,
        WEIGHTING,
        analysis_keys=["num_clusters", "seed", "clustering_algorithm", "mu", "ho"],
    )
    # Sanity check: average number of runs per (method, num_clusters) group.
    print(
        filter_df(clustered_runs, fedavg_filter)
        .reset_index(drop=True)
        .drop(columns="meta")
        .groupby(["method", "num_clusters"])
        .size()
        .mean()
    )
    fedavg_df = (
        filter_df(clustered_runs, fedavg_filter)
        .reset_index(drop=True)
        .drop(columns="meta")
        .groupby(["method", "num_clusters"])["acc"]
        .agg(["mean", "std"])
        .reset_index()
        .pivot(index="num_clusters", columns="method", values=["mean", "std"])
        .reindex(columns=method_columns)  # fixed method order within each statistic
        .mul(100)
    )

    # Grouped bars with error bars, plus the baseline.
    fedavg_df["mean"].plot(kind="bar", yerr=fedavg_df["std"], rot=0, ax=panel, legend=None, **PLOTTING_KWARGS)
    panel.axhline(wo_clustering, color=PLOTTING_KWARGS["color"][5], linestyle="--", linewidth=1.2)
    panel.set_xlabel("")

ax[1].legend(loc=(1.02, 0.1), title="Profiling\nmethod")
plt.subplots_adjust(wspace=0.1)
ax[0].set_ylabel("Accuracy")
ax[0].set_ylim(40, 80)
fig.supxlabel("Number of clusters", y=-0.06)
plt.savefig("images/fedavg_cifar10.pdf", **SAVEFIG_KWARGS)

## Accuracy gains table

In [None]:
# Partition used by the cells of this section. The section is meant to be
# executed once per partition: set this to "val" for one pass and to "ho" for the other.
DATA_PARTITION = "val" # repeat twice, once with `ho` and once with `val`

In [None]:
# Baseline for the gains table: accuracy of the runs trained without clustering,
# one value per matching run folder, aggregated into a mean and a standard error.
baseline_pattern = f"{BASE_FOLDER}/without_clustering_seed*_mu_-1.0_ho_50/"
accuracies = []
for folder in glob(baseline_pattern):
    df = pd.read_csv(folder + f"{DATA_PARTITION}_accuracy.csv")
    if WEIGHTING == "average":
        # unweighted mean over the rows of the CSV
        acc = df["accuracy"].mean()
    else:
        # mean over the rows, weighted by the `dataset_size` column
        acc = (df["accuracy"] * df["dataset_size"]).sum() / df["dataset_size"].sum()
    accuracies.append(acc)

# Fail loudly instead of silently propagating a NaN baseline into the table.
if not accuracies:
    raise FileNotFoundError(f"No run folders match {baseline_pattern!r}")

# Standard error of the mean: divide by the number of runs actually found
# rather than assuming exactly three seeds.
WO_CLUSTERING_MEAN = np.mean(accuracies) * 100
std = np.std(accuracies) * 100 / np.sqrt(len(accuracies))
f"Average accuracy without clustering: {WO_CLUSTERING_MEAN:.2f} +/- {std:.2f}"

In [None]:
# Mean/std accuracy (in percent) per number of clusters and profiling method,
# for the subset of runs selected by `table_filter`.
table_filter = {"clustering_algorithm": "kmeans", "num_ho_clients": 50, "algorithm": "FedProx"}
method_columns = pd.MultiIndex.from_product(
    [["mean", "std"], ["WDP", "LbP", "CoLEDS", "REPA", "AESP"]]
)

df = filter_df(
    get_accuracy_with_clustering(
        BASE_FOLDER,
        DATA_PARTITION,
        WEIGHTING,
        analysis_keys=["num_clusters", "seed", "clustering_algorithm", "mu", "ho"],
    ),
    table_filter,
)
# Disabled check that every group holds exactly three runs:
# assert ((df.groupby(["method", "num_clusters", "meta"]).size()) == 3).all()
df = (
    df.groupby(["method", "num_clusters", "meta"])["acc"]
    .agg(["mean", "std"])
    .reset_index()
    .pivot(index="num_clusters", columns="method", values=["mean", "std"])
    .mul(100)
    .reindex(columns=method_columns)  # fixed method order within each statistic
)
df

In [None]:
# Format `df` for LaTeX with per-row highlighting:
# best -> \textbf, second -> \underline, third -> \emph.
# Only the mean value is styled, not the "Â± std".
def _format_cell(mean_val, std_val, style=None):
    mean_str = f"{mean_val:.2f}"
    if style == "best":
        mean_str = r"\bm1{" + mean_str + "}"
    elif style == "second":
        mean_str = r"\bm2{" + mean_str + "}"
    elif style == "third":
        mean_str = r"\bm3{" + mean_str + "}"

    if pd.isna(std_val):
        return mean_str
    return mean_str + r" \tiny{$\pm " + f"{std_val:.2f}" + r"$}"

# Accuracy gains over the no-clustering baseline. Computed into new frames so
# that `df` is left untouched and re-running this cell gives the same table
# (an in-place subtraction would remove the baseline again on every run).
mean_gains = df["mean"] - WO_CLUSTERING_MEAN
stds = df["std"]

order = ["LbP", "WDP", "CoLEDS", "AESP", "REPA"]
formatted_df = pd.DataFrame(index=df.index, columns=order, dtype=object)

for num_clusters in df.index:
    # Methods that actually have a result in this row; missing (NaN) entries
    # are left out so they can neither be ranked nor disturb the sort.
    row_means = {
        m: mean_gains.loc[num_clusters, m]
        for m in order
        if m in mean_gains.columns and pd.notna(mean_gains.loc[num_clusters, m])
    }

    # Rank by mean gain (descending, ties keep `order`) and tag the top three.
    ranked = sorted(row_means, key=row_means.get, reverse=True)
    styles = dict(zip(ranked, ("best", "second", "third")))

    # Fill the formatted row; methods without a result get an empty cell.
    for m in order:
        if m not in row_means or m not in stds.columns:
            formatted_df.loc[num_clusters, m] = ""
            continue
        formatted_df.loc[num_clusters, m] = _format_cell(
            row_means[m], stds.loc[num_clusters, m], styles.get(m)
        )

formatted_df.index.name = "Number of clusters"
latex_table = formatted_df.reset_index().to_latex(
    escape=False,  # cells already contain LaTeX markup
    index=False,
    column_format="r||ccccc",
)
print(latex_table)

In [None]:
# df["mean"].plot(kind="bar", figsize=(8, 3), **PLOTTING_KWARGS)
# plt.gca().axhline(WO_CLUSTERING_MEAN)
# # plt.ylim(75, 85)
# plt.legend(loc=(1.04, 0.1))