In [None]:
# federated_analysis_simulation_3x3.py

import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.models import load_model, clone_model
from sklearn.metrics import mean_squared_error

# --- Set up working directory ---
# The project root can be overridden with the OADR_PROJECT_DIR environment
# variable so the notebook also runs on machines other than the author's.
# When the variable is unset, the original hard-coded location is used.
PROJECT_DIR = os.environ.get(
    "OADR_PROJECT_DIR",
    "/Users/adeslatt/Scitechcon Dropbox/Anne DeslattesMays/projects/oadr-autoantibody",
)
os.chdir(PROJECT_DIR)

# --- Load 3x3 MSE summaries ---
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load summaries produced by this project's own pipeline.
with open("data/mse_summaries_3x3.pkl", "rb") as f:
    mse_summaries_3x3 = pickle.load(f)

# The summary is keyed by study; these keys drive the per-study loops below.
study_list = list(mse_summaries_3x3.keys())
study_list

In [None]:
# --- Extract local model weights ---
def extract_model_weights_per_study(mse_summaries_3x3):
    """Load every study's saved 3x3 model and collect its weights.

    Args:
        mse_summaries_3x3: Mapping whose keys are study names. Only the keys
            are used; each one selects ``models/<study>_3x3_model.keras``.

    Returns:
        Dict mapping each study name to the list returned by the model's
        ``get_weights()``, or to ``None`` when loading that study's model
        raised an exception (the error is printed, not re-raised).
    """
    per_study_weights = {}
    for study in mse_summaries_3x3:
        try:
            loaded = load_model(f"models/{study}_3x3_model.keras")
            study_weights = loaded.get_weights()
        except Exception as e:
            # Best effort: report the failure and keep going with the
            # remaining studies; the study is still recorded, as None.
            print(f"Could not load model for {study}: {e}")
            study_weights = None
        per_study_weights[study] = study_weights
    return per_study_weights

# --- Median weight aggregation ---
def aggregate_weights_median(weights_dict):
    """Combine per-study weights into one set via the element-wise median.

    Args:
        weights_dict: Mapping of study name to a list of weight arrays (as
            produced by ``model.get_weights()``), or ``None`` for a study
            whose model could not be loaded. ``None`` entries are skipped.

    Returns:
        A list with one array per weight position, each the element-wise
        median of that position's arrays across all studies that have weights.

    Raises:
        ValueError: If no study has weights, or if the studies with weights do
            not all have the same number of weight arrays.
    """
    # Drop failed studies *before* inspecting any entry, so a study that
    # failed to load in first position no longer breaks the aggregation.
    available = {s: w for s, w in weights_dict.items() if w is not None}
    if not available:
        raise ValueError("No study has model weights to aggregate")

    layer_counts = {s: len(w) for s, w in available.items()}
    if len(set(layer_counts.values())) != 1:
        raise ValueError(
            f"Studies disagree on the number of weight arrays: {layer_counts}"
        )

    aggregated_weights = []
    # zip(*...) walks the studies in lock-step, one weight position at a time.
    for layer_weights in zip(*available.values()):
        stacked = np.stack(layer_weights, axis=0)
        aggregated_weights.append(np.median(stacked, axis=0))
    return aggregated_weights

# --- Build new federated model from reference ---
from tensorflow.keras.models import model_from_json

def build_model_from_file(model_path):
    """Build a new model from the JSON description of a saved model.

    The model stored at ``model_path`` is loaded, serialized with
    ``to_json()``, and a new model is constructed from that JSON. Only what
    ``to_json()`` captures is carried over to the returned model.

    Args:
        model_path: Path of the saved model to use as the template.

    Returns:
        The model produced by ``model_from_json``.
    """
    template_json = load_model(model_path).to_json()
    return model_from_json(template_json)

def build_federated_model(aggregated_weights, reference_model_path):
    """Create a model shaped like the reference and load the aggregated weights.

    Args:
        aggregated_weights: List of weight arrays to install via
            ``set_weights``; it must match the reference model's layout.
        reference_model_path: Path of the saved model whose structure is
            rebuilt by ``build_model_from_file``.

    Returns:
        The rebuilt model with ``aggregated_weights`` set.
    """
    federated = build_model_from_file(reference_model_path)
    federated.set_weights(aggregated_weights)
    return federated

In [None]:
from sklearn.metrics import mean_squared_error
import numpy as np
import pandas as pd

def _grouped_mean(row_values, labels):
    """Return ``{label: mean of row_values}`` with rows grouped by position."""
    # np.asarray discards any pandas index carried by ``labels`` so rows are
    # matched by position rather than by index alignment.
    return pd.Series(row_values).groupby(np.asarray(labels)).mean().to_dict()


def evaluate_federated_model_3x3(
    federated_model,
    mse_summaries_3x3,
    study
):
    """Score a model's reconstruction of one study's test set.

    The first 9 feature columns of the study's ``X_test`` are reshaped to
    ``(-1, 3, 3, 1)``, passed through ``federated_model.predict``, and the
    prediction is compared with the input.

    Args:
        federated_model: Object exposing ``predict(x, verbose=0)`` whose output
            can be reshaped back to ``(n, 9)``.
        mse_summaries_3x3: Mapping of study name to a summary dict providing
            ``"X_test"`` (2-D, at least 9 columns), ``"antibody_labels"``,
            ``"Age_Group"`` and ``"Sex"`` (one label per test row).
        study: Key of the study to evaluate.

    Returns:
        Dict with:
            ``mse``: mean squared error over all 9 columns.
            ``mse_median`` / ``mse_iqr``: median and inter-quartile range of
                the per-antibody MSEs.
            ``per_antibody_mse``: ``{antibody label: MSE}`` for the leading
                columns, in label order.
            ``mse_by_age_group`` / ``mse_by_sex``: MSE over the antibody
                columns, grouped by the respective labels.
            ``N_test``: number of test rows.
            ``y_true`` / ``y_pred``: the ``(n, 9)`` input and reconstruction.

    Raises:
        ValueError: If ``X_test`` has fewer than 9 columns, if there are more
            antibody labels than columns, or if ``Age_Group`` / ``Sex`` do not
            have one label per test row.
    """
    summary = mse_summaries_3x3[study]

    X_test = np.asarray(summary["X_test"])
    antibodies = list(summary["antibody_labels"])
    age_groups = summary["Age_Group"]
    sex = summary["Sex"]
    n_ab = len(antibodies)

    # ---- Ensure exactly 9 features (extra trailing columns are dropped) ----
    if X_test.shape[1] < 9:
        raise ValueError(f"{study}: Expected ≥9 features, got {X_test.shape[1]}")
    X_full9 = X_test[:, :9]
    n_rows = X_full9.shape[0]

    if n_ab > 9:
        raise ValueError(f"{study}: {n_ab} antibody labels but only 9 features")
    for label_name, labels in (("Age_Group", age_groups), ("Sex", sex)):
        if len(labels) != n_rows:
            raise ValueError(
                f"{study}: {label_name} has {len(labels)} labels for {n_rows} test rows"
            )

    # ---- Reshape to 3×3 single-channel input and predict ----
    X_test_img = X_full9.reshape(-1, 3, 3, 1)
    reconstructed = federated_model.predict(X_test_img, verbose=0)
    reconstructed_flat = np.asarray(reconstructed).reshape(X_full9.shape)

    # One squared-error matrix feeds every metric below.
    sq_err = (X_full9 - reconstructed_flat) ** 2

    # ---- Overall MSE: uniform average of the 9 per-column MSEs ----
    col_mse = sq_err.mean(axis=0)
    mse = float(col_mse.mean())

    # ---- Per-antibody MSE (leading n_ab columns) ----
    per_antibody_mse = {ab: float(col_mse[i]) for i, ab in enumerate(antibodies)}

    # ---- Median + IQR across antibodies ----
    ab_mse_values = list(per_antibody_mse.values())
    mse_median = np.median(ab_mse_values)
    q25, q75 = np.percentile(ab_mse_values, [25, 75])
    mse_iqr = q75 - q25

    # ---- Stratified MSE over the antibody columns ----
    # Every row contributes the same number of columns, so the mean of the
    # per-row MSEs inside a group equals the MSE over that group's elements.
    row_mse = sq_err[:, :n_ab].mean(axis=1)
    mse_by_age = _grouped_mean(row_mse, age_groups)
    mse_by_sex = _grouped_mean(row_mse, sex)

    return {
        "mse": mse,
        "mse_median": mse_median,
        "mse_iqr": mse_iqr,
        "per_antibody_mse": per_antibody_mse,
        "mse_by_age_group": mse_by_age,
        "mse_by_sex": mse_by_sex,
        "N_test": n_rows,
        "y_true": X_full9,
        "y_pred": reconstructed_flat,
    }


In [None]:
 # === Build and evaluate federated model ===
# Collect each study's saved model weights (None where a model failed to load).
weights_dict = extract_model_weights_per_study(mse_summaries_3x3)
# Element-wise median across studies gives the federated weights.
agg_weights = aggregate_weights_median(weights_dict)
# The reference model is used as the template the aggregated weights are set on.
ref_path = "models/SDY569_3x3_model.keras"
federated_model = build_federated_model(agg_weights, ref_path)
# Persist both the full model and a weights-only file.
federated_model.save("models/federated_3x3_model.keras")
federated_model.save_weights("models/federated_3x3.weights.h5")

# Evaluate the federated model (no further training) on every study's test set.
federated_results = {
    study: evaluate_federated_model_3x3(federated_model, mse_summaries_3x3, study)
    for study in study_list
}

In [None]:
# === Retraining from federated ===
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import clone_model

# Per-study evaluation results of the federated model after retraining on
# that study's own training data.
retrain_detailed_results = {}

for study in study_list:
    s = mse_summaries_3x3[study]
    X_train = s["X_train"][:, :9]  # Use only 9 features
    # Same (-1, 3, 3, 1) layout that evaluate_federated_model_3x3 feeds the model.
    X_train_img = X_train.reshape(-1, 3, 3, 1)
    # Each study starts from its own copy of the federated model, so
    # retraining on one study does not carry over to the next.
    model = clone_model(federated_model)
    model.set_weights(federated_model.get_weights())
    model.compile(optimizer=Adam(), loss="mse")
    # Input and target are the same array: the model is trained to reconstruct its input.
    model.fit(X_train_img, X_train_img, epochs=10, batch_size=8, verbose=1)
    retrain_detailed_results[study] = evaluate_federated_model_3x3(model, mse_summaries_3x3, study)

# === Comparison summary ===
# One row per study comparing the `mse_median` of the stored local summary,
# the federated model, and the retrained federated model.
rows = []
for study in study_list:
    local = mse_summaries_3x3[study]["mse_median"]
    fed = federated_results[study]["mse_median"]
    retr = retrain_detailed_results[study]["mse_median"]
    rows.append({
        "Study": study,
        "Local MSE": local,
        "Federated MSE": fed,
        "Retrained MSE": retr,
        "Δ Fed - Local": fed - local,
        "Δ Retrain - Local": retr - local,
        # Relative improvement is undefined when the local baseline is not
        # positive; report NaN instead of dividing by zero (same guard as the
        # per-antibody table).
        "% Improvement (Fed vs Local)": (100 * (local - fed) / local) if local > 0 else np.nan,
        "% Improvement (Retrain vs Local)": (100 * (local - retr) / local) if local > 0 else np.nan
    })
df_mse_comparison = pd.DataFrame(rows)
df_mse_comparison.to_csv("data/mse_comparison_summary_3x3.csv", index=False)

# === Save per-antibody MSE ===
# One row per (study, antibody). Antibodies come from the local summary; a
# missing federated or retrained value is recorded as NaN.
rows = []
for study in study_list:
    local_by_ab = mse_summaries_3x3[study]["per_antibody_mse"]
    fed_by_ab = federated_results[study]["per_antibody_mse"]
    retr_by_ab = retrain_detailed_results[study]["per_antibody_mse"]
    for ab, local in local_by_ab.items():
        fed = fed_by_ab.get(ab, np.nan)
        retr = retr_by_ab.get(ab, np.nan)
        # Relative improvement is only defined for a positive local baseline.
        if local > 0:
            pct_fed = (local - fed) / local * 100
            pct_retr = (local - retr) / local * 100
        else:
            pct_fed = np.nan
            pct_retr = np.nan
        rows.append({
            "Study": study,
            "Antibody": ab,
            "Local MSE (median)": local,
            "Federated MSE (median)": fed,
            "Retrained MSE (median)": retr,
            "Δ Fed − Local": fed - local,
            "Δ Retrain − Local": retr - local,
            "% Improvement Fed vs Local": pct_fed,
            "% Improvement Retrain vs Local": pct_retr
        })
df_ab = pd.DataFrame(rows)
df_ab.to_csv("data/per_antibody_mse_median_3x3.csv", index=False)