In [None]:
import os

import pandas as pd
import matplotlib.pyplot as plt

from utils import (
    download_all_data,
    get_coleds_dataframe,
    get_es_dataframe,
    get_wd_dataframe,
    filter_df,
    plot_with_std,
    get_wd_correlations_late_in_training,
    set_matplotlib_configuration,
)

In [None]:
# Presumably applies the project's matplotlib style (see utils.set_matplotlib_configuration — TODO confirm);
# the returned mapping is unpacked as keyword arguments into every plot_with_std call in this notebook.
PLOTTING_KWARGS = set_matplotlib_configuration()

In [None]:
SAVE_IMAGES_FOLDER = "images"
NUMBER_OF_SEEDS = 5  # how many times every experiment is run

# exist_ok=True avoids the check-then-create race of exists()/mkdir() and makes re-running this cell a no-op
os.makedirs(SAVE_IMAGES_FOLDER, exist_ok=True)

In [None]:
# Fetch the experiment results analysed below; presumably the get_*_dataframe helpers read what this
# downloads — TODO confirm in utils.
download_all_data()

# Comparing the best correlation achieved by each profiler

In [None]:
# Load the per-run results of the three profilers, in order: weight difference (wddf), embedding
# space (esdf) and CoLEDS (cldf, cl for contrastive learning). refresh=True presumably forces each
# frame to be rebuilt rather than reused from a cache — TODO confirm in utils.
wddf, esdf, cldf = (
    load(refresh=True, number_of_seeds=NUMBER_OF_SEEDS)
    for load in (get_wd_dataframe, get_es_dataframe, get_coleds_dataframe)
)

In [None]:
# CoLEDS: mean/std of the max correlation across runs for every hyperparameter configuration
correlation_per_conf = (
    cldf
    .groupby(["dataset", "model", "temperature", "batch_size", "fraction_fit", "num_client_updates"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

# for each dataset, keep the configuration with the highest mean correlation
best_per_dataset = (
    correlation_per_conf
    .loc[correlation_per_conf.groupby("dataset")["mean"].idxmax()]
    .reset_index(drop=True)
)

# the comparison only needs the summary statistics, not which configuration won
cldf_formatted = best_per_dataset.loc[:, ["dataset", "mean", "std"]].assign(profiler="CoLEDS")

# WDP (weight difference): one (mean, std) row per dataset
wddf_formatted = (
    wddf
    .groupby("dataset")["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
    .assign(profiler="WDP")
)

# ES (embedding space): one row per (model, dataset); the profiler label names the ES model, e.g. "ES - Classifier"
esdf_formatted = (
    esdf
    .groupby(["model", "dataset"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)
esdf_formatted = esdf_formatted.assign(
    profiler=esdf_formatted["model"].map("ES - {}".format)
).loc[:, ["dataset", "mean", "std", "profiler"]]

In [None]:
combined_df = pd.concat([cldf_formatted, wddf_formatted, esdf_formatted], ignore_index=True)
fig, ax = plt.subplots(figsize=(12, 4))
index_order = ["MNIST", "Fashion-MNIST", "CIFAR10", "CINIC10", "CIFAR100"]
column_order = ["WDP", "CoLEDS", "ES - Classifier", "ES - AutoEncoder"]
plot_with_std(
    combined_df,
    index="dataset",
    columns="profiler",
    ax=ax,
    index_order=index_order,
    column_order=column_order,
    **PLOTTING_KWARGS
)

# Overlay the WDP correlation obtained late in training as hatched bars on top of the WDP bars.
wddf_late_training = get_wd_correlations_late_in_training()

# NOTE(review): assumes plot_with_std draws one bar per (profiler, dataset), grouped per profiler in
# `column_order` and ordered by `index_order` within each group — confirm against utils.
# Derive the WDP slice from the orders actually used for plotting instead of hard-coding it.
n_datasets = len(index_order)
start = column_order.index("WDP") * n_datasets
end = start + n_datasets
wdp_bars = ax.patches[start:end]

for bar_idx, (bar, dataset) in enumerate(zip(wdp_bars, index_order)):
    ax.bar(
        bar.get_x(),
        wddf_late_training.loc[dataset, "mean"],
        width=bar.get_width(),
        align="edge",
        color=bar.get_facecolor(),
        hatch="//",
        edgecolor="black",
        linewidth=1.0,
        zorder=bar.get_zorder() + 1,  # ensure it is drawn on top
        # one legend entry explains the hatching; the remaining overlays stay out of the legend
        label="WDP (late in training)" if bar_idx == 0 else "_nolegend_",
    )

ax.set_xlabel("Dataset")
ax.set_ylabel("Max. Correlation")
ax.legend(loc=(1.02, 0.3), title="Profiler")
plt.show()

# Analysis w.r.t. batch size and temperature

In [None]:
# Hyperparameters held fixed while batch size and temperature vary
fraction_fit, num_client_updates, model = 0.5, 4, "Set2Set"

In [None]:
# CoLEDS results on CIFAR10/CINIC10 with fraction_fit, num_client_updates and model fixed, aggregated
# (mean/std of the max correlation) per (dataset, batch_size, temperature).
# number_of_seeds matches the value used when loading `cldf`, so both analyses aggregate the same runs.
df = (
    filter_df(get_coleds_dataframe(number_of_seeds=NUMBER_OF_SEEDS), {
        "fraction_fit": fraction_fit,
        "num_client_updates": num_client_updates,
        "model": model,
        "dataset": {"CIFAR10", "CINIC10"},
    })
    .groupby(["dataset", "batch_size", "temperature"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

# batch sizes shown in the figure; rows with any other batch size are dropped
plotted_batch_sizes = {1, 2, 4, 8, 16, 32, 48, 96}
df_plot = df[df["batch_size"].isin(plotted_batch_sizes)]

for idx, dataset in enumerate(["CIFAR10", "CINIC10"]):
    df_dataset = df_plot[df_plot["dataset"] == dataset]
    plot_with_std(df_dataset, "batch_size", "temperature", ax=ax[idx], **PLOTTING_KWARGS)
    ax[idx].set_xlabel("")
    ax[idx].set_title(dataset)  # without it the two panels are indistinguishable

# keep a single legend (bottom panel); remove the top one only if plotting created it
top_legend = ax[0].get_legend()
if top_legend is not None:
    top_legend.remove()
ax[1].legend(loc=(1.01, 0.7), title="Temperature")

fig.supylabel("Max. Correlation", x=0.07)
ax[1].set_xlabel("Batch Size")
plt.show()

# Analysis w.r.t. fraction fit and model

In [None]:
# Hyperparameters held fixed while the model and fraction_fit vary
batch_size = 16
num_client_updates = 4
temperature = 0.2

# CoLEDS results on CIFAR10/CINIC10 with the parameters above fixed, aggregated (mean/std of the
# max correlation) per (dataset, model, fraction_fit).
# number_of_seeds matches the value used when loading `cldf`, so both analyses aggregate the same runs.
df = (
    filter_df(get_coleds_dataframe(number_of_seeds=NUMBER_OF_SEEDS), {
        "batch_size": batch_size,
        "temperature": temperature,
        "num_client_updates": num_client_updates,
        "dataset": {"CIFAR10", "CINIC10"},
    })
    .groupby(["dataset", "model", "fraction_fit"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

In [None]:
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(12, 6))

# computed once from both datasets so the two panels use the same column order
fraction_fit_order = sorted(df["fraction_fit"].unique())

for idx, dataset in enumerate(["CIFAR10", "CINIC10"]):
    df_dataset = df[df["dataset"] == dataset]
    plot_with_std(
        df_dataset,
        index="model",
        columns="fraction_fit",
        ax=ax[idx],
        index_order=["Set2Set", "Cl-Mean", "GRU"],
        column_order=fraction_fit_order,
        **PLOTTING_KWARGS
    )
    ax[idx].set_xlabel("")
    ax[idx].set_title(dataset)  # without it the two panels are indistinguishable

# keep a single legend (bottom panel); remove the top one only if plotting created it
top_legend = ax[0].get_legend()
if top_legend is not None:
    top_legend.remove()
# bars are grouped by fraction_fit (the `columns` argument), so that is what the legend describes
ax[1].legend(loc=(1.01, 0.5), title="Fraction Fit")

fig.supxlabel("Model")
fig.supylabel("Max. Correlation")
plt.show()

# Evolution of correlation throughout training w.r.t. TODO (parameter not yet specified)

In [None]:
# TODO: implement the analysis of how the correlation evolves during training