In [None]:
import os
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt

from utils import (
    download_all_data,
    get_coleds_dataframe,
    get_es_dataframe,
    get_wd_dataframe,
    filter_df,
    plot_with_std,
    get_wd_correlations_late_in_training,
    set_matplotlib_configuration,
)

In [None]:
# Base font size handed to the matplotlib configuration helper and reused for figure-level labels.
FONTSIZE = 8.1
# Directory below which every figure of this notebook is saved.
SAVE_FOLDER = Path("images/")

In [None]:
# Configure matplotlib once; the helper returns the keyword arguments that are
# later forwarded to the plotting calls and to savefig.
PLOTTING_KWARGS, SAVEFIG_KWARGS = set_matplotlib_configuration(FONTSIZE)

# Variant for when the figure takes 1/2 of the column width: identical settings
# except for a dedicated error-bar style. Built as a new dict so PLOTTING_KWARGS
# itself is left untouched.
HALF_PLOTTING_KWARGS = {
    **PLOTTING_KWARGS,
    "error_kw": {"capthick": 0.5, "elinewidth": 0.5, "capsize": 1.0},
}

In [None]:
SAVE_IMAGES_FOLDER = "images"
NUMBER_OF_SEEDS = 5 # how many times every experiment is run
# Create the output directory up front. `exist_ok=True` keeps the cell safe to
# re-run and avoids the check-then-create race of `os.path.exists` + `os.mkdir`.
os.makedirs(SAVE_IMAGES_FOLDER, exist_ok=True)

In [None]:
# Run the project's download helper before any of the data frames below are built.
# NOTE(review): presumably this fetches/caches the experiment results that the
# get_*_dataframe helpers read -- confirm in utils.
download_all_data()

# Comparing the best correlation of different methods

In [None]:
# Load one data frame per profiling method, all with the same loader settings.
loader_options = {"refresh": False, "number_of_seeds": NUMBER_OF_SEEDS}
wddf = get_wd_dataframe(**loader_options)      # weight difference
esdf = get_es_dataframe(**loader_options)      # embedding space
cldf = get_coleds_dataframe(**loader_options)  # cl for contrastive learning (CoLEDS)

In [None]:
# Columns that together identify one CoLEDS hyperparameter configuration.
coleds_config_columns = [
    "dataset", "model", "temperature", "batch_size", "fraction_fit", "num_client_updates",
]

# CoLEDS: mean / std of the max correlation for every configuration ...
correlation_per_conf = (
    cldf.groupby(coleds_config_columns)["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

# ... then keep, for every dataset, the configuration with the highest mean.
best_row_labels = correlation_per_conf.groupby("dataset")["mean"].idxmax()
best_per_dataset = correlation_per_conf.loc[best_row_labels].reset_index(drop=True)

# Only the summary statistics are needed from here on, not the winning configuration.
cldf_formatted = best_per_dataset[["dataset", "mean", "std"]].assign(profiler="CoLEDS")

# Weight difference profiler: one mean / std per dataset.
wddf_formatted = (
    wddf.groupby("dataset")["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
    .assign(profiler="WDP")
)

# Embedding space profilers: one entry per (model, dataset), labelled "ES - <model>".
esdf_formatted = (
    esdf.groupby(["model", "dataset"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
    .assign(profiler=lambda stats: stats["model"].apply(lambda mdl: f"ES - {mdl}"))
)[["dataset", "mean", "std", "profiler"]]

In [None]:
# Stack the per-method summaries into one frame and draw the grouped bar chart
# (one group per dataset, one bar per profiler).
combined_df = pd.concat([cldf_formatted, wddf_formatted, esdf_formatted], ignore_index=True)
fig, ax = plt.subplots(figsize=(5.1, 2))
index_order = ["MNIST", "Fashion-MNIST", "CIFAR10", "CINIC10", "CIFAR100"]
column_order = ["WDP", "CoLEDS", "ES - Classifier", "ES - AutoEncoder"]
plot_with_std(
    combined_df,
    index="dataset",
    columns="profiler",
    ax=ax,
    index_order=index_order,
    column_order=column_order,
    **PLOTTING_KWARGS,
)

wddf_late_training = get_wd_correlations_late_in_training()

# Locate the WDP bars among the drawn patches. The slice relies on the patches
# being stored profiler by profiler (all datasets of the first profiler, then
# all of the second, ...) -- NOTE(review): confirm against plot_with_std.
# The offset is derived from `column_order` instead of being hard-coded to 0,
# so the overlay stays on the WDP bars if the profiler order is ever changed.
n_datasets = combined_df["dataset"].nunique()
start = column_order.index("WDP") * n_datasets
end = start + n_datasets
wdp_bars = ax.patches[start:end]

# Overlay a hatched bar, whose height is the late-in-training mean, on each WDP bar.
for bar, dataset in zip(wdp_bars, index_order):
    x = bar.get_x()
    width = bar.get_width()
    height = wddf_late_training.loc[dataset]["mean"]

    ax.bar(
        x,
        height,
        width=width,
        align="edge",  # `x` is the left edge of the existing bar, not its centre
        color=bar.get_facecolor(),
        hatch="//",
        edgecolor="black",
        linewidth=1.0,
        zorder=bar.get_zorder() + 1,  # ensure it is drawn on top
    )

ax.set_xlabel("Dataset")
ax.set_ylabel("Max. Correlation")
# The leading spaces in the title are part of the layout; keep them as they are.
ax.legend(loc=(1.02, 0.3), title="         Profiler", alignment="left")

fig.savefig(SAVE_FOLDER / "best_correlations.pdf", **SAVEFIG_KWARGS)

# Analysis w.r.t. batch size and temperature

In [None]:
# Hyperparameters held fixed while batch size and temperature are varied.
model = "Set2Set"
fraction_fit = 0.5
num_client_updates = 4

In [None]:
# Restrict the CoLEDS results to the fixed hyperparameters and the two datasets
# of interest, then summarise the max correlation per (dataset, batch size, temperature).
# NOTE(review): the loader is called with its defaults here, whereas the first
# analysis passes refresh/number_of_seeds explicitly -- confirm the defaults match.
bs_temp_filters = {
    "fraction_fit": fraction_fit,
    "num_client_updates": num_client_updates,
    "model": model,
    "dataset": {"CIFAR10", "CINIC10"}
}
df = (
    filter_df(get_coleds_dataframe(), bs_temp_filters)
    .groupby(["dataset", "batch_size", "temperature"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

In [None]:
# Two stacked panels (one per dataset) sharing the batch-size axis.
fig, ax = plt.subplots(2, 1, figsize=(2.6, 2.5), sharex=True)

# Only these batch sizes are plotted; 4 and 48 are intentionally not among them.
# The selection does not depend on the dataset, so it is applied once up front.
plotted_batch_sizes = {1, 2, 8, 16, 32, 96}
df_ = df[df["batch_size"].isin(plotted_batch_sizes)]

for idx, dataset in enumerate(["CIFAR10", "CINIC10"]):
    tmp = df_[df_["dataset"] == dataset]
    plot_with_std(tmp, "batch_size", "temperature", ax=ax[idx], **HALF_PLOTTING_KWARGS)
    # The dataset name doubles as the per-panel y label; the shared quantity
    # name is added once for the whole figure below.
    ax[idx].set_ylabel(dataset)
    ax[idx].set_xlabel("")

# One legend for the whole figure, attached to the lower panel.
ax[0].legend().remove()
ax[1].legend(loc=(1.01, 0.62), title=r"Temperature $\tau$")

fig.supylabel("Max. Correlation", x=-0.07)
fig.subplots_adjust(hspace=0.1)
ax[1].set_xlabel("Batch Size")
fig.savefig(SAVE_FOLDER / "corr_bs_temp.pdf", **SAVEFIG_KWARGS)

# Analysis w.r.t. fraction fit and model

In [None]:
# Hyperparameters held fixed while model and fraction fit are varied.
temperature = 0.2
batch_size = 16
num_client_updates = 4

# Restrict the CoLEDS results to those settings and the two datasets of interest,
# then summarise the max correlation per (dataset, model, fraction fit).
model_ff_filters = {
    "batch_size": batch_size,
    "temperature": temperature,
    "num_client_updates": num_client_updates,
    "dataset": {"CIFAR10", "CINIC10"}
}
df = (
    filter_df(get_coleds_dataframe(), model_ff_filters)
    .groupby(["dataset", "model", "fraction_fit"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

In [None]:
# Two stacked panels (one per dataset) sharing the model axis.
fig, ax = plt.subplots(2, 1, figsize=(2.6, 2.5), sharex=True)

# Same column order in both panels: every fraction-fit value present in `df`, ascending.
fraction_fit_order = sorted(df["fraction_fit"].unique())

for panel, dataset in zip(ax, ["CIFAR10", "CINIC10"]):
    tmp = df[df["dataset"] == dataset]
    # The dataset name doubles as the per-panel y label.
    panel.set_ylabel(dataset)
    plot_with_std(
        tmp,
        index="model",
        columns="fraction_fit",
        ax=panel,
        index_order=["Set2Set", "Cl-Mean", "GRU"],
        column_order=fraction_fit_order,
        **HALF_PLOTTING_KWARGS
    )
    panel.set_xlabel("")

# One legend for the whole figure, attached to the lower panel.
ax[0].legend().remove()
ax[1].legend(loc=(1.01, 0.62), title="Fraction fit $q$")

ax[1].set_xlabel("Model")
fig.supylabel("Max. Correlation", x=-0.08)
plt.savefig(SAVE_FOLDER / "corr_model_ff.pdf", **SAVEFIG_KWARGS)

# Evolution of correlation through training w.r.t. fraction fit and number of client updates

In [None]:
# Settings held fixed for the correlation-evolution figure.
# NOTE(review): `model` is not read by the run filter in the following cells
# (it matches on the model's `_target_` string instead) -- confirm whether it is still needed.
model = "set2set"
batch_size = 16
temperature = 0.2

In [None]:
from collections import defaultdict
from utils import get_project_runs, get_run_dataframe

In [None]:
# Keep only the runs that belong to this figure: Set2Set model on cifar10 with
# the batch size and temperature fixed in the constants cell. The temperature is
# compared against the `temperature` variable (previously a hard-coded 0.2), so
# changing the constant actually changes the selection.
runs = [
    r
    for r in get_project_runs()
    if r.config["train_config.batch_size"] == batch_size
    and r.config["model._target_"] == "src.models.set2set_model.Set2SetModel"
    and r.config["train_config.temperature"] == temperature
    and r.config["dataset.dataset_name"] == "cifar10"
]

In [None]:
# Sanity check on the selection above before anything is plotted.
# NOTE(review): presumably 4 fraction-fit values x 4 client-update settings x 5 seeds
# -- confirm against the sweep definition.
expected_num_runs = 4 * 4 * 5
assert len(runs) == expected_num_runs, (
    f"expected {expected_num_runs} runs after filtering, found {len(runs)}"
)

In [None]:
# Collect the correlation trace of every selected run, grouped by its
# (fraction_fit, num_client_updates) setting.
data = defaultdict(list)
for r in runs:
    config = r.config
    # Stored under its own name: `df` already refers to the aggregated DataFrame
    # of the earlier analysis cells, and this value is a plain NumPy array.
    # Missing correlation entries are dropped, so traces may differ in length.
    correlation_trace = get_run_dataframe(r)["correlation"].dropna().to_numpy()
    ff = config["train_config.fraction_fit"]
    ncu = config["train_config.num_client_updates"]
    data[(ff, ncu)].append(correlation_trace)

In [None]:
import numpy as np

# One line style per number-of-client-updates setting (indexed like the colors).
linestyles = ["-", "--", "-."]

# One panel per fraction fit q; all panels share both axes.
fig, ax = plt.subplots(1, 4, figsize=(5, 1.6), sharex=True, sharey=True)
for idx, q in enumerate([0.05, 0.25, 0.5, 1.0]):
    ax[idx].set_title(f"$q={q}$")
    for color_idx, ncu in enumerate([1, 2, 4]):
        # `.get` avoids silently inserting an empty entry into the defaultdict
        # when a configuration is missing; fail with a clear message instead.
        list_of_arrays = data.get((q, ncu), [])
        if not list_of_arrays:
            raise ValueError(
                f"no runs found for fraction_fit={q}, num_client_updates={ncu}"
            )

        # Traces may differ in length: pad each one with NaN up to the longest
        # trace so they can be stacked into a single 2-D array.
        max_len = max(len(a) for a in list_of_arrays)
        padded = np.full((len(list_of_arrays), max_len), np.nan)
        for i, arr in enumerate(list_of_arrays):
            padded[i, : len(arr)] = arr

        # Average over runs at each step, ignoring the NaN padding.
        avg = np.nanmean(padded, axis=0)
        steps = np.arange(max_len)

        ax[idx].plot(
            steps,
            avg,
            label=str(ncu),
            linewidth=1.2,
            color=PLOTTING_KWARGS["color"][color_idx],
            linestyle=linestyles[color_idx],
        )

# The axes are shared, so ticks and limits set on the last panel apply to all panels.
ax[-1].set_xticks([0, 5, 10, 15])
ax[-1].set_xlim(0, 15)
ax[-1].set_ylim(0.45, 0.67)
fig.subplots_adjust(wspace=0.2)
ax[-1].legend(loc=(1.01, 0.2), title="Number of\nclient updates")
ax[0].set_ylabel("Correlation")
fig.supxlabel("Evaluation iteration", y=-0.06, fontsize=FONTSIZE)
fig.savefig(SAVE_FOLDER / "corr_evolution.pdf", **SAVEFIG_KWARGS)