In [1]:
import os

import pandas as pd
from sklearn.inspection import PartialDependenceDisplay

DATA_PATH = "vm/fl_architectural_dataset.csv"
PLOT_DIR = "plots/rq4"

os.makedirs(PLOT_DIR, exist_ok=True)

print("Setup complete.")


Setup complete.


In [2]:
df = pd.read_csv(DATA_PATH)

# Remove first round (delta undefined)
df = df[df["FL Round"] > 1].copy()

# Remove extreme ΔF1 outliers
lower = df["delta_val_f1"].quantile(0.01)
upper = df["delta_val_f1"].quantile(0.99)

# Standardize Dataset names
df["Dataset"] = df["Dataset"].replace({
    "CIFAR-10": "CIFAR10",
    "CIFAR 10": "CIFAR10"
})

df_clean = df[
    (df["delta_val_f1"] >= lower) &
    (df["delta_val_f1"] <= upper)
    ].copy()

required_cols = [
    "delta_val_f1",
    "Val F1",
    "Total Time of FL Round",
    "client_selector",
    "heterogeneous_data_handler",
    "message_compressor",
    "Nhigh",
    "Nlow",
    "iid",
    "JSD",
    "FL Round"
]

df_clean = df_clean.dropna(subset=required_cols).copy()

print("Dataset cleaned.")
print(df_clean.shape)


Dataset cleaned.
(161882, 33)


In [3]:
features = [
    "client_selector",
    "heterogeneous_data_handler",
    "message_compressor",
    "Nhigh",
    "Nlow",
    "iid",
    "JSD",
    # "FL Round"
]

group_cols = ["Model Type", "Dataset"]

print("Features defined.")


Features defined.


In [4]:
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline

linear_models_delta = {}
linear_coeffs_delta = {}

linear_models_delta_strat = {}
linear_coeffs_delta_strat = {}

MIN_SAMPLES = 200  # avoid unstable subgroup fits

for (model_type, dataset), subdf in df_clean.groupby(["Model Type", "Dataset"]):

    if len(subdf) < 500:
        continue

    features = [
        "client_selector",
        "heterogeneous_data_handler",
        "message_compressor",
        "Nhigh",
        "Nlow",
        "iid",
        "JSD",
    ]

    key = f"{model_type}_{dataset}"

    # --------------------------
    # 1️⃣ Full model (as before)
    # --------------------------
    X_full = subdf[features]
    y_full = subdf["delta_val_f1"]

    model_full = Pipeline([
        ("scaler", StandardScaler()),
        ("ridge", Ridge(alpha=1.0))
    ])

    model_full.fit(X_full, y_full)

    coef_full = pd.Series(
        model_full.named_steps["ridge"].coef_,
        index=features
    )

    linear_models_delta[key] = model_full
    linear_coeffs_delta[key] = coef_full

    print(f"Fitted FULL ΔValF1 model for {key}")

    # --------------------------------
    # 2️⃣ Stratified models by IID
    # --------------------------------
    linear_models_delta_strat[key] = {}
    linear_coeffs_delta_strat[key] = {}

    for iid_value in [0, 1]:

        subdf_strat = subdf[subdf["iid"] == iid_value]

        if len(subdf_strat) < MIN_SAMPLES:
            print(f"Skipped stratified fit for {key}, iid={iid_value} (too few samples)")
            continue

        # Remove 'iid' from features since it is constant in subgroup
        strat_features = [f for f in features if f != "iid"]

        X_strat = subdf_strat[strat_features]
        y_strat = subdf_strat["delta_val_f1"]

        model_strat = Pipeline([
            ("scaler", StandardScaler()),
            ("ridge", Ridge(alpha=1.0))
        ])

        model_strat.fit(X_strat, y_strat)

        coef_strat = pd.Series(
            model_strat.named_steps["ridge"].coef_,
            index=strat_features
        )

        linear_models_delta_strat[key][f"iid_{iid_value}"] = model_strat
        linear_coeffs_delta_strat[key][f"iid_{iid_value}"] = coef_strat

        print(f"Fitted STRATIFIED ΔValF1 model for {key}, iid={iid_value}")


Fitted FULL ΔValF1 model for CNN 16k_CIFAR10
Fitted STRATIFIED ΔValF1 model for CNN 16k_CIFAR10, iid=0
Fitted STRATIFIED ΔValF1 model for CNN 16k_CIFAR10, iid=1
Fitted FULL ΔValF1 model for CNN 64k_CIFAR10
Fitted STRATIFIED ΔValF1 model for CNN 64k_CIFAR10, iid=0
Fitted STRATIFIED ΔValF1 model for CNN 64k_CIFAR10, iid=1
Fitted FULL ΔValF1 model for TextMLP_AGNEWS
Fitted STRATIFIED ΔValF1 model for TextMLP_AGNEWS, iid=0
Fitted STRATIFIED ΔValF1 model for TextMLP_AGNEWS, iid=1


In [5]:
key = "CNN 64k_CIFAR10"

print("FULL model")
print(linear_coeffs_delta[key])

print("\nIID = 0")
print(linear_coeffs_delta_strat[key].get("iid_0"))

print("\nIID = 1")
print(linear_coeffs_delta_strat[key].get("iid_1"))

FULL model
client_selector              -0.000819
heterogeneous_data_handler    0.010637
message_compressor            0.000122
Nhigh                         0.000001
Nlow                         -0.000001
iid                           0.052220
JSD                           0.001293
dtype: float64

IID = 0
client_selector              -0.000914
heterogeneous_data_handler    0.011084
message_compressor            0.000645
Nhigh                         0.000178
Nlow                         -0.000178
JSD                           0.001508
dtype: float64

IID = 1
client_selector               0.000000
heterogeneous_data_handler    0.000000
message_compressor           -0.002339
Nhigh                        -0.001186
Nlow                          0.001186
JSD                           0.000000
dtype: float64


In [6]:
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline

linear_models_time = {}
linear_coeffs_time = {}

linear_models_time_strat = {}
linear_coeffs_time_strat = {}

MIN_SAMPLES = 200  # safety threshold

for (model_type, dataset), subdf in df_clean.groupby(["Model Type", "Dataset"]):

    if len(subdf) < 500:
        continue

    features = [
        "client_selector",
        "heterogeneous_data_handler",
        "message_compressor",
        "Nhigh",
        "Nlow",
        "iid",
        "JSD",
    ]

    key = f"{model_type}_{dataset}"

    # --------------------------
    # 1️⃣ Full model
    # --------------------------
    X_full = subdf[features]
    y_full = subdf["Total Time of FL Round"]

    model_full = Pipeline([
        ("scaler", StandardScaler()),
        ("ridge", Ridge(alpha=1.0, solver="svd"))  # deterministic
    ])

    model_full.fit(X_full, y_full)

    coef_full = pd.Series(
        model_full.named_steps["ridge"].coef_,
        index=features
    )

    linear_models_time[key] = model_full
    linear_coeffs_time[key] = coef_full

    print(f"Fitted FULL Round Time model for {key}")

    # --------------------------------
    # 2️⃣ Stratified models by IID
    # --------------------------------
    linear_models_time_strat[key] = {}
    linear_coeffs_time_strat[key] = {}

    for iid_value in [0, 1]:

        subdf_strat = subdf[subdf["iid"] == iid_value]

        if len(subdf_strat) < MIN_SAMPLES:
            print(f"Skipped stratified time fit for {key}, iid={iid_value}")
            continue

        strat_features = [f for f in features if f != "iid"]

        X_strat = subdf_strat[strat_features]
        y_strat = subdf_strat["Total Time of FL Round"]

        model_strat = Pipeline([
            ("scaler", StandardScaler()),
            ("ridge", Ridge(alpha=1.0, solver="svd"))
        ])

        model_strat.fit(X_strat, y_strat)

        coef_strat = pd.Series(
            model_strat.named_steps["ridge"].coef_,
            index=strat_features
        )

        linear_models_time_strat[key][f"iid_{iid_value}"] = model_strat
        linear_coeffs_time_strat[key][f"iid_{iid_value}"] = coef_strat

        print(f"Fitted STRATIFIED Round Time model for {key}, iid={iid_value}")


Fitted FULL Round Time model for CNN 16k_CIFAR10
Fitted STRATIFIED Round Time model for CNN 16k_CIFAR10, iid=0
Fitted STRATIFIED Round Time model for CNN 16k_CIFAR10, iid=1
Fitted FULL Round Time model for CNN 64k_CIFAR10
Fitted STRATIFIED Round Time model for CNN 64k_CIFAR10, iid=0
Fitted STRATIFIED Round Time model for CNN 64k_CIFAR10, iid=1
Fitted FULL Round Time model for TextMLP_AGNEWS
Fitted STRATIFIED Round Time model for TextMLP_AGNEWS, iid=0
Fitted STRATIFIED Round Time model for TextMLP_AGNEWS, iid=1


In [7]:
import matplotlib.pyplot as plt
import os

PLOT_DIR = "plots/rq4"
os.makedirs(PLOT_DIR, exist_ok=True)

all_values = []
for coef_df in linear_coeffs_delta.values():
    all_values.extend(coef_df.values)

max_abs = max(abs(v) for v in all_values)

for key, coef_df in linear_coeffs_delta.items():
    coef_sorted = coef_df.sort_values()

    plt.figure(figsize=(8, 6))

    colors = ["green" if v > 0 else "red" for v in coef_sorted.values]

    coef_sorted.plot(kind="barh", color=colors)

    plt.title(f"Linear Coefficients (ΔValF1)\n{key}")
    plt.xlabel("Standardized Coefficient")
    plt.xlim(-max_abs * 1.1, max_abs * 1.1)
    plt.tight_layout()

    save_path = os.path.join(PLOT_DIR, f"linear_delta_{key}.pdf")
    plt.savefig(save_path, dpi=300)
    plt.close()

    print("Saved:", save_path)

Saved: plots/rq4/linear_delta_CNN 16k_CIFAR10.pdf
Saved: plots/rq4/linear_delta_CNN 64k_CIFAR10.pdf
Saved: plots/rq4/linear_delta_TextMLP_AGNEWS.pdf


In [8]:
import matplotlib.pyplot as plt
import os

PLOT_DIR = "plots/rq4"
os.makedirs(PLOT_DIR, exist_ok=True)

# ------------------------------
# 1️⃣ Global axis scaling
# ------------------------------
all_values = []
for coef_df in linear_coeffs_time.values():
    all_values.extend(coef_df.values)

max_abs = max(abs(v) for v in all_values)

# ------------------------------
# 2️⃣ Plot per task
# ------------------------------
for key, coef_df in linear_coeffs_time.items():

    coef_sorted = coef_df.sort_values()

    plt.figure(figsize=(8, 6))

    # Red = increases time (bad), Green = reduces time (good)
    colors = ["red" if v > 0 else "green" for v in coef_sorted.values]

    coef_sorted.plot(kind="barh", color=colors)

    # Symmetric axis for fair visual comparison
    plt.xlim(-max_abs * 1.1, max_abs * 1.1)

    plt.axvline(0, color="black", linewidth=1)
    plt.title(f"Linear Coefficients (Round Time)\n{key}")
    plt.xlabel("Standardized Coefficient")
    plt.tight_layout()

    save_path = os.path.join(PLOT_DIR, f"linear_time_{key}.pdf")
    plt.savefig(save_path, dpi=300)
    plt.close()

    print("Saved:", save_path)


Saved: plots/rq4/linear_time_CNN 16k_CIFAR10.pdf
Saved: plots/rq4/linear_time_CNN 64k_CIFAR10.pdf
Saved: plots/rq4/linear_time_TextMLP_AGNEWS.pdf


In [9]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# Analysing the interactions only makes sense on non-IID data,
# that's where HDH is applied
df_clean = df_clean[df_clean["Data Distr. Type"] == 'non-IID'].copy()

models_delta = {}
models_time = {}

features = [
    "client_selector",
    "heterogeneous_data_handler",
    "message_compressor",
    # "Nhigh",
    # "Nlow",
    # "iid",
    # "JSD",
    # "FL Round"
]

for (model_type, dataset), subdf in df_clean.groupby(["Model Type", "Dataset"]):

    if len(subdf) < 500:
        continue

    key = f"{model_type}_{dataset}"

    if "64k" in model_type:
        features = [
            "client_selector",
            "heterogeneous_data_handler",
            "message_compressor",
        ]
    else:
        features = [
            "client_selector",
            "heterogeneous_data_handler",
        ]

    print("\n======================================")
    print(f"Training RF models for {model_type} | {dataset}")
    print("======================================")

    X = subdf[features]

    # -------- ΔValF1 model --------
    y_delta = subdf["Val F1"]

    X_train_d, X_test_d, y_train_d, y_test_d = train_test_split(
        X, y_delta, test_size=0.2, random_state=42
    )

    rf_delta = RandomForestRegressor(
        n_estimators=200,
        max_depth=12,
        random_state=42,
        n_jobs=-1
    )

    rf_delta.fit(X_train_d, y_train_d)

    key = f"{model_type}_{dataset}"
    models_delta[key] = (rf_delta, X_train_d)

    print("ΔValF1 R²:", rf_delta.score(X_test_d, y_test_d))

    # -------- Round Time model --------
    y_time = subdf["Total Time of FL Round"]

    X_train_t, X_test_t, y_train_t, y_test_t = train_test_split(
        X, y_time, test_size=0.2, random_state=42
    )

    rf_time = RandomForestRegressor(
        n_estimators=200,
        max_depth=12,
        random_state=42,
        n_jobs=-1
    )

    rf_time.fit(X_train_t, y_train_t)

    models_time[key] = (rf_time, X_train_t)

    print("Round Time R²:", rf_time.score(X_test_t, y_test_t))



Training RF models for CNN 16k | CIFAR10
ΔValF1 R²: 0.013469143361005131
Round Time R²: 0.07535715419360411

Training RF models for CNN 64k | CIFAR10
ΔValF1 R²: 0.003922409872294308
Round Time R²: 0.030662524121086387

Training RF models for TextMLP | AGNEWS
ΔValF1 R²: 0.005445266230077128
Round Time R²: 0.18502388369869882


In [10]:
for (model_type, dataset), subdf in df_clean.groupby(["Model Type", "Dataset"]):
    print(f"\n{model_type} | {dataset}")
    print(subdf.groupby([
        "client_selector",
        "heterogeneous_data_handler",
        "message_compressor"
    ]).size())


CNN 16k | CIFAR10
client_selector  heterogeneous_data_handler  message_compressor
0                0                           0                     20712
                                             1                      3084
                 1                           0                      8531
1                0                           0                      6007
                 1                           0                      2054
dtype: int64

CNN 64k | CIFAR10
client_selector  heterogeneous_data_handler  message_compressor
0                0                           0                     6369
                                             1                     2173
                 1                           0                     1153
                                             1                      917
1                0                           0                     1654
                                             1                      356
                 1      

In [11]:
interactions = [
    ("client_selector", "heterogeneous_data_handler"),
    ("client_selector", "message_compressor"),
    ("heterogeneous_data_handler", "message_compressor")
]

for key, (rf, X_train) in models_delta.items():

    for pair in interactions:

        if key != "CNN 64k_CIFAR10" and "message_compressor" == pair[1]:
            continue

        fig, ax = plt.subplots(figsize=(6, 4))

        PartialDependenceDisplay.from_estimator(
            rf,
            X_train,
            [pair],
            ax=ax,
            contour_kw={"cmap": "RdYlGn"}  # Red (low) → Green (high)
        )

        plt.title(f"ΔValF1 — {key}")
        plt.tight_layout()

        save_path = os.path.join(
            PLOT_DIR,
            f"interaction_delta_{pair[0]}x{pair[1]}_{key}.png"
        )

        plt.savefig(save_path, dpi=300)
        plt.close()

        print("Saved:", save_path)


Saved: plots/rq4/interaction_delta_client_selectorxheterogeneous_data_handler_CNN 16k_CIFAR10.png
Saved: plots/rq4/interaction_delta_client_selectorxheterogeneous_data_handler_CNN 64k_CIFAR10.png
Saved: plots/rq4/interaction_delta_client_selectorxmessage_compressor_CNN 64k_CIFAR10.png
Saved: plots/rq4/interaction_delta_heterogeneous_data_handlerxmessage_compressor_CNN 64k_CIFAR10.png
Saved: plots/rq4/interaction_delta_client_selectorxheterogeneous_data_handler_TextMLP_AGNEWS.png


In [12]:
for key, (rf, X_train) in models_time.items():

    for pair in interactions:

        if key != "CNN 64k_CIFAR10" and "message_compressor" == pair[1]:
            continue


        fig, ax = plt.subplots(figsize=(6, 4))

        PartialDependenceDisplay.from_estimator(
            rf,
            X_train,
            [pair],
            ax=ax,
            contour_kw={"cmap": "RdYlGn"}  # Red (low) → Green (high)
        )

        plt.title(f"Round Time — {key}")
        plt.tight_layout()

        save_path = os.path.join(
            PLOT_DIR,
            f"interaction_time_{pair[0]}_{pair[1]}_{key}.png"
        )

        plt.savefig(save_path, dpi=300)
        plt.close()

        print("Saved:", save_path)


Saved: plots/rq4/interaction_time_client_selector_heterogeneous_data_handler_CNN 16k_CIFAR10.png
Saved: plots/rq4/interaction_time_client_selector_heterogeneous_data_handler_CNN 64k_CIFAR10.png
Saved: plots/rq4/interaction_time_client_selector_message_compressor_CNN 64k_CIFAR10.png
Saved: plots/rq4/interaction_time_heterogeneous_data_handler_message_compressor_CNN 64k_CIFAR10.png
Saved: plots/rq4/interaction_time_client_selector_heterogeneous_data_handler_TextMLP_AGNEWS.png
