In [None]:
import os

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns

from utils import (
    download_all_data,
    get_project_runs,
    get_run_dataframe,
    get_coleds_dataframe,
    filter_df,
    plot_with_std,
)

In [None]:
# Global figure style: seaborn "whitegrid" look, a 4-colour colour-blind-safe
# palette, and LaTeX text rendering with Latin Modern fonts. text.usetex needs a
# working LaTeX installation, and any label containing "_" outside math mode
# will then fail to render.
plt.style.use("seaborn-v0_8-whitegrid")
cmap = sns.color_palette("colorblind", 4)

# Keyword arguments shared by every plot_with_std call below.
PLOTTING_KWARGS = dict(
    capsize=5,
    color=list(cmap[:4]),
)

mpl.rcParams["text.usetex"] = True
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.serif"] = ["Latin Modern Roman"]  # alternative name: "LMRoman10"
mpl.rcParams["text.latex.preamble"] = r"""
        \usepackage{lmodern}   % Load Latin Modern
        \usepackage[T1]{fontenc}
    """


In [None]:
# Directory for saved images, relative to the notebook's working directory.
# makedirs(exist_ok=True) is idempotent on re-runs, creates missing parents, and
# (unlike an exists()/mkdir() pair) cannot race with another process creating it.
# It raises if "images" already exists as a regular file, instead of silently
# continuing.
SAVE_IMAGES_FOLDER = "images"
os.makedirs(SAVE_IMAGES_FOLDER, exist_ok=True)

In [None]:
# Fetch the experiment data via utils.download_all_data (presumably stored
# locally for the get_*_dataframe helpers used below — TODO confirm in utils).
download_all_data()

# Comparing the best correlation of different methods

In [None]:
# TODO: compare the best correlation achieved by the different methods.

# Analysis w.r.t. batch size and temperature

In [None]:
# Hyper-parameters held fixed while batch size and temperature are varied.
fraction_fit, num_client_updates, model = 0.5, 4, "Set2Set"

In [None]:
# Restrict the runs to the fixed configuration above on both datasets
# (filter_df lives in utils; presumably a set value means "any of" — confirm).
# Then summarise max_correlation as mean and std for every
# (dataset, batch_size, temperature) group, as a flat frame.
df = (
    filter_df(
        get_coleds_dataframe(refresh=True),
        {
            "fraction_fit": fraction_fit,
            "num_client_updates": num_client_updates,
            "model": model,
            "dataset": {"CIFAR10", "CINIC10"},
        },
    )
    .groupby(["dataset", "batch_size", "temperature"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

In [None]:
# One row per dataset; both rows share the batch-size x-axis.
fig, ax = plt.subplots(2, 1, sharex=True)
for idx, dataset in enumerate(["CIFAR10", "CINIC10"]):
    dataset_df = df[df["dataset"] == dataset]
    plot_with_std(dataset_df, "batch_size", "temperature", ax=ax[idx], **PLOTTING_KWARGS)
    # Blank the x-label plot_with_std sets (presumably the raw column name, whose
    # "_" LaTeX rejects under text.usetex). Title each panel with its dataset so
    # the two rows can be told apart.
    ax[idx].set_xlabel("")
    ax[idx].set_title(dataset)
# A single LaTeX-safe x-label on the bottom, shared axis.
ax[1].set_xlabel("Batch size")
# Keep one legend, outside the bottom panel. get_legend() removes any existing
# top legend without first building a throwaway one.
top_legend = ax[0].get_legend()
if top_legend is not None:
    top_legend.remove()
ax[1].legend(loc=(1.04, 0.8), title="Temperature")
plt.show()

# Analysis w.r.t. fraction fit and model

In [None]:
# Hyper-parameters held fixed while model and fraction_fit are varied.
batch_size, num_client_updates, temperature = 16, 4, 0.2

# Restrict the runs to the fixed configuration on both datasets
# (filter_df lives in utils; presumably a set value means "any of" — confirm).
# Then summarise max_correlation as mean and std for every
# (dataset, model, fraction_fit) group.
df = (
    filter_df(
        get_coleds_dataframe(refresh=True),
        {
            "batch_size": batch_size,
            "temperature": temperature,
            "num_client_updates": num_client_updates,
            "dataset": {"CIFAR10", "CINIC10"},
        },
    )
    .groupby(["dataset", "model", "fraction_fit"])["max_correlation"]
    .agg(["mean", "std"])
    .reset_index()
)

In [None]:
# One row per dataset; both rows share the model x-axis.
fig, ax = plt.subplots(2, 1, sharex=True)
for idx, dataset in enumerate(["CIFAR10", "CINIC10"]):
    dataset_df = df[df["dataset"] == dataset]
    plot_with_std(dataset_df, "model", "fraction_fit", ax=ax[idx], **PLOTTING_KWARGS)
    # Blank the x-label plot_with_std sets (presumably the raw column name, whose
    # "_" LaTeX rejects under text.usetex). Title each panel with its dataset so
    # the two rows can be told apart.
    ax[idx].set_xlabel("")
    ax[idx].set_title(dataset)
# A single LaTeX-safe x-label on the bottom, shared axis.
ax[1].set_xlabel("Model")
# Keep one legend, outside the bottom panel. get_legend() removes any existing
# top legend without first building a throwaway one.
top_legend = ax[0].get_legend()
if top_legend is not None:
    top_legend.remove()
# The series are split by the third plot_with_std argument ("fraction_fit"),
# so the legend title must name that column, not the temperature.
ax[1].legend(loc=(1.04, 0.8), title="Fraction fit")
plt.show()

# Evolution of correlation during training (TODO: specify the parameter it is compared against)

In [None]:
# TODO: plot how the correlation evolves during training.