## Federated Analysis

Here we take the weights from each of the models, aggregate them and then use these with each of the studies


In [1]:
import numpy as np

def aggregate_weights_median(weights_dict):
    """
    Aggregate weights layer-wise across all studies using the median.
    """
    studies = list(weights_dict.keys())
    n_layers = len([w for w in weights_dict[studies[0]] if w is not None])

    aggregated_weights = []
    for layer_idx in range(n_layers):
        # Collect all weights for this layer across studies
        layer_weights = [weights_dict[study][layer_idx] for study in studies if weights_dict[study] is not None]

        # Stack and take median across the first axis (studies)
        stacked = np.stack(layer_weights, axis=0)
        median_weight = np.median(stacked, axis=0)
        aggregated_weights.append(median_weight)

    return aggregated_weights


In [2]:
def extract_model_weights_per_study(mse_summaries):
    weights_dict = {}

    for study in mse_summaries.keys():
        model_path = f"models/{study}_model.keras"

        try:
            model = load_model(model_path)
            weights = model.get_weights()
            weights_dict[study] = weights
        except Exception as e:
            print(f"Could not load model for {study}: {e}")
            weights_dict[study] = None

    return weights_dict


In [3]:
from tensorflow.keras.models import load_model
from tensorflow.keras.models import model_from_json
import os

def build_model_from_file(model_path):
    """
    Load a Keras model structure from an existing model file (without weights).
    """
    model = load_model(model_path)
    model_json = model.to_json()
    new_model = model_from_json(model_json)
    return new_model


In [4]:
def build_federated_model(aggregated_weights, reference_model_path):
    """
    Build a model with aggregated weights.
    """
    model = build_model_from_file(reference_model_path)
    model.set_weights(aggregated_weights)
    return model


In [5]:
from sklearn.metrics import mean_squared_error
import pandas as pd

def evaluate_federated_model(model, X_test, y_true, antibody_labels):
    """
    Evaluate reconstructed output from federated model.
    """
    reconstructed = model.predict(X_test)
    mse_per_ab = {}

    for i, ab in enumerate(antibody_labels):
        mse = mean_squared_error(y_true[:, i], reconstructed[:, i])
        mse_per_ab[ab] = mse

    mse_median = np.median(list(mse_per_ab.values()))
    mse_iqr = np.subtract(*np.percentile(list(mse_per_ab.values()), [75, 25]))

    print(f"Federated Model MSE (median): {mse_median:.4f}")
    print(f"IQR: {mse_iqr:.4f}")
    print("Per-Autoantibody MSE:")
    print(pd.Series(mse_per_ab).sort_values())

    return {
        "mse_median": mse_median,
        "mse_iqr": mse_iqr,
        "per_antibody_mse": mse_per_ab,
        "reconstructed": reconstructed
    }


In [6]:
def evaluate_federated_model(federated_model, mse_summaries, study):
    summary = mse_summaries[study]
    
    X_test = summary["X_test"]
    antibodies = summary["antibody_labels"]
    age_groups = summary["Age_Group"]
    sex = summary["Sex"]

    # Reshape for CNN
    X_test_reshaped = X_test.reshape((-1, X_test.shape[1], 1, 1))
    
    # Predict
    reconstructed = federated_model.predict(X_test_reshaped)

    # Flatten CNN output
    reconstructed_flat = reconstructed.reshape(X_test.shape)

    # Overall MSE
    mse = mean_squared_error(X_test, reconstructed_flat)
    
    # Per-antibody MSE
    per_ab_mse = {
        ab: mean_squared_error(X_test[:, i], reconstructed_flat[:, i])
        for i, ab in enumerate(antibodies)
    }

    # Build DataFrames (MUST use flattened arrays)
    df = pd.DataFrame(X_test, columns=antibodies)
    df["Age_Group"] = age_groups
    df["Sex"] = sex

    df_recon = pd.DataFrame(reconstructed_flat, columns=antibodies)
    df_recon["Age_Group"] = age_groups
    df_recon["Sex"] = sex

    # MSE by age group
    mse_by_age = df.groupby("Age_Group").apply(
        lambda g: mean_squared_error(
            g[antibodies], df_recon.loc[g.index, antibodies]
        )
    ).to_dict()

    # MSE by sex
    mse_by_sex = df.groupby("Sex").apply(
        lambda g: mean_squared_error(
            g[antibodies], df_recon.loc[g.index, antibodies]
        )
    ).to_dict()

    return {
        "overall_mse": mse,
        "per_antibody_mse": per_ab_mse,
        "mse_by_age_group": mse_by_age,
        "mse_by_sex": mse_by_sex,
        "N_test": len(X_test),
    }


In [7]:
import os
# ========== STEP 0: Set working directory (for running locally on laptop) =========
os.getcwd()
os.chdir("/Users/adeslatt/Scitechcon Dropbox/Anne DeslattesMays/projects/oadr-autoantibody")
os.getcwd()

'/Users/adeslatt/Scitechcon Dropbox/Anne DeslattesMays/projects/oadr-autoantibody'

In [8]:
import pickle

# Load mse_summaries from file
with open("data/mse_summaries.pkl", "rb") as f:
    mse_summaries = pickle.load(f)

# Confirm structure
print(mse_summaries.keys())


dict_keys(['SDY569', 'SDY1625', 'SDY524', 'SDY797', 'SDY1737'])


In [9]:
# === Run this
study_list = ["SDY569", "SDY1625", "SDY524", "SDY797", "SDY1737"]
weights_dict = extract_model_weights_per_study(mse_summaries)

In [10]:
# 1. Aggregate the weights
aggregated_weights = aggregate_weights_median(weights_dict)

# 2. Build federated model using any of the saved models as architecture reference
reference_model_path = "models/SDY569_model.keras"  # or any other
federated_model = build_federated_model(aggregated_weights, reference_model_path)



In [11]:
comparison_data = []

for study in mse_summaries:
    local_mse = mse_summaries[study]["mse_median"]
    fed_mse = federated_results[study]["overall_mse"]
    delta = local_mse - fed_mse
    pct_improvement = (delta / local_mse) * 100

    comparison_data.append({
        "Study": study,
        "Local MSE": round(local_mse, 4),
        "Federated MSE": round(fed_mse, 4),
        "Δ MSE": round(delta, 4),
        "% Improvement": f"{pct_improvement:.2f}%"
    })

comparison_df = pd.DataFrame(comparison_data)
display(comparison_df)


NameError: name 'federated_results' is not defined

## Retrain Federated Model on Each Study and Evaluate Stratified MSE

In [None]:
from tensorflow.keras.models import clone_model
from tensorflow.keras.optimizers import Adam
import numpy as np

# Store results after retraining
retrain_detailed_results = {}

for study in study_list:
    print(f"\n--- Retraining Federated Model on Study: {study} ---")

    summary = mse_summaries[study]
    
    # Extract X_train and reshape only relevant features
    X_train = summary["X_train"]
    antibodies = summary["antibody_labels"]

    # Number of features used in training (including one-hot encoded Age_Group and Sex)
    n_features_used = federated_model.input_shape[1]

    # Defensive check
    if X_train.shape[1] != n_features_used:
        print(f"  ⚠ Reshaping X_train: expected {n_features_used}, got {X_train.shape[1]}")
        # Try to truncate or slice appropriately
        X_train = X_train[:, :n_features_used]

    # Reshape for CNN input
    X_train_reshaped = X_train.reshape((-1, n_features_used, 1, 1))

    # Clone the federated model and copy weights
    local_model = clone_model(federated_model)
    local_model.set_weights(federated_model.get_weights())
    local_model.compile(optimizer=Adam(), loss='mse')

    # Retrain
    local_model.fit(X_train_reshaped, X_train_reshaped, epochs=10, batch_size=8, verbose=1)

    # Evaluate after retraining
    results_post = evaluate_federated_model(local_model, mse_summaries, study)
    retrain_detailed_results[study] = results_post


In [None]:
def evaluate_federated_model(model, mse_summaries, study):
    summary = mse_summaries[study]
    X_test = summary["X_test"]
    n_features_used = model.input_shape[1]

    if X_test.shape[1] != n_features_used:
        X_test = X_test[:, :n_features_used]

    X_test_reshaped = X_test.reshape((-1, n_features_used, 1, 1))
    preds = model.predict(X_test_reshaped, verbose=0)

    mse = np.mean(np.square(X_test_reshaped - preds))

    return {
        "mse": mse,
        "y_true": X_test_reshaped,
        "y_pred": preds
    }


In [None]:
retrain_detailed_results = {}

for study in study_list:
    summary = mse_summaries[study]
    X_train = summary["X_train"]
    n_features_used = federated_model.input_shape[1]
    
    if X_train.shape[1] != n_features_used:
        X_train = X_train[:, :n_features_used]
    
    X_train_reshaped = X_train.reshape((-1, n_features_used, 1, 1))
    
    # Clone & compile model
    local_model = clone_model(federated_model)
    local_model.set_weights(federated_model.get_weights())
    local_model.compile(optimizer=Adam(), loss='mse')
    
    # Retrain
    local_model.fit(X_train_reshaped, X_train_reshaped, epochs=10, batch_size=8, verbose=0)
    
    # Evaluate and save
    results_post = evaluate_federated_model(local_model, mse_summaries, study)
    retrain_detailed_results[study] = results_post


In [None]:
from pprint import pprint
pprint(retrain_detailed_results[study_list[0]])


In [None]:
pprint(mse_summaries[study_list[0]])


In [None]:
import pandas as pd

rows = []

for study in study_list:
    try:
        fed_mse = mse_summaries[study]["mse_median"]
        retrain_mse = retrain_detailed_results[study]["mse"]

        rows.append({
            "Study": study,
            "Federated MSE (Median)": round(fed_mse, 6),
            "Retrained MSE": round(retrain_mse, 6),
            "Δ MSE": round(retrain_mse - fed_mse, 6),
            "% Change": round(100 * (retrain_mse - fed_mse) / fed_mse, 2)
        })
    except KeyError as e:
        print(f"Missing key {e} for study {study}")

df_summary = pd.DataFrame(rows)
df_summary


In [None]:
import pandas as pd

rows = []

for study in study_list:
    local_mse = mse_summaries[study]["mse_median"]
    federated_mse = federated_results[study]["overall_mse"]
    retrained_mse = retrain_detailed_results[study]["mse"]

    rows.append({
        "Study": study,
        "Local MSE": round(local_mse, 6),
        "Federated MSE": round(federated_mse, 6),
        "Retrained MSE": round(retrained_mse, 6),
        "Δ Fed - Local": round(federated_mse - local_mse, 6),
        "Δ Retrain - Fed": round(retrained_mse - federated_mse, 6),
    })

df_mse_comparison = pd.DataFrame(rows)
df_mse_comparison


In [None]:
# Save final MSE summary DataFrame to CSV
df_mse_comparison.to_csv("mse_comparison_summary.csv", index=False)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Assuming df is already defined and includes the MSE values

# Plot grouped bar chart
df_plot = df_mse_comparison[["Study", "Local MSE", "Federated MSE", "Retrained MSE"]].set_index("Study")
ax = df_plot.plot(kind="bar", figsize=(10, 6), width=0.8)
plt.title("MSE Comparison per Study (Lower is Better)")
plt.ylabel("Mean Squared Error")
plt.xticks(rotation=0)
plt.tight_layout()
plt.grid(axis="y")
plt.legend(title="Model")
plt.show()


In [None]:
mse_summaries[study]["per_antibody_mse"]


In [None]:
pprint (mse_summaries)