## Federated Analysis

Here we take the weights from each of the models, aggregate them and then use these with each of the studies


In [1]:
import numpy as np

def aggregate_weights_median(weights_dict):
    """
    Aggregate weights layer-wise across all studies using the median.
    """
    studies = list(weights_dict.keys())
    n_layers = len([w for w in weights_dict[studies[0]] if w is not None])

    aggregated_weights = []
    for layer_idx in range(n_layers):
        # Collect all weights for this layer across studies
        layer_weights = [weights_dict[study][layer_idx] for study in studies if weights_dict[study] is not None]

        # Stack and take median across the first axis (studies)
        stacked = np.stack(layer_weights, axis=0)
        median_weight = np.median(stacked, axis=0)
        aggregated_weights.append(median_weight)

    return aggregated_weights


In [2]:
def extract_model_weights_per_study(mse_summaries):
    weights_dict = {}

    for study in mse_summaries.keys():
        model_path = f"models/{study}_model.keras"

        try:
            model = load_model(model_path)
            weights = model.get_weights()
            weights_dict[study] = weights
        except Exception as e:
            print(f"Could not load model for {study}: {e}")
            weights_dict[study] = None

    return weights_dict


In [3]:
from tensorflow.keras.models import load_model
from tensorflow.keras.models import model_from_json
import os

def build_model_from_file(model_path):
    """
    Load a Keras model structure from an existing model file (without weights).
    """
    model = load_model(model_path)
    model_json = model.to_json()
    new_model = model_from_json(model_json)
    return new_model


In [4]:
def build_federated_model(aggregated_weights, reference_model_path):
    """
    Build a model with aggregated weights.
    """
    model = build_model_from_file(reference_model_path)
    model.set_weights(aggregated_weights)
    return model


In [5]:
from sklearn.metrics import mean_squared_error
import pandas as pd

def evaluate_federated_model(model, X_test, y_true, antibody_labels):
    """
    Evaluate reconstructed output from federated model.
    """
    reconstructed = model.predict(X_test)
    mse_per_ab = {}

    for i, ab in enumerate(antibody_labels):
        mse = mean_squared_error(y_true[:, i], reconstructed[:, i])
        mse_per_ab[ab] = mse

    mse_median = np.median(list(mse_per_ab.values()))
    mse_iqr = np.subtract(*np.percentile(list(mse_per_ab.values()), [75, 25]))

    print(f"Federated Model MSE (median): {mse_median:.4f}")
    print(f"IQR: {mse_iqr:.4f}")
    print("Per-Autoantibody MSE:")
    print(pd.Series(mse_per_ab).sort_values())

    return {
        "mse_median": mse_median,
        "mse_iqr": mse_iqr,
        "per_antibody_mse": mse_per_ab,
        "reconstructed": reconstructed
    }


In [6]:
def evaluate_federated_model(federated_model, mse_summaries, study):
    summary = mse_summaries[study]
    
    X_test = summary["X_test"]
    antibodies = summary["antibody_labels"]
    age_groups = summary["Age_Group"]
    sex = summary["Sex"]

    # Reshape for CNN
    X_test_reshaped = X_test.reshape((-1, X_test.shape[1], 1, 1))
    
    # Predict
    reconstructed = federated_model.predict(X_test_reshaped)

    # Flatten CNN output
    reconstructed_flat = reconstructed.reshape(X_test.shape)

    # Overall MSE
    mse = mean_squared_error(X_test, reconstructed_flat)
    
    # Per-antibody MSE
    per_ab_mse = {
        ab: mean_squared_error(X_test[:, i], reconstructed_flat[:, i])
        for i, ab in enumerate(antibodies)
    }

    # Build DataFrames (MUST use flattened arrays)
    df = pd.DataFrame(X_test, columns=antibodies)
    df["Age_Group"] = age_groups
    df["Sex"] = sex

    df_recon = pd.DataFrame(reconstructed_flat, columns=antibodies)
    df_recon["Age_Group"] = age_groups
    df_recon["Sex"] = sex

    # MSE by age group
    mse_by_age = df.groupby("Age_Group").apply(
        lambda g: mean_squared_error(
            g[antibodies], df_recon.loc[g.index, antibodies]
        )
    ).to_dict()

    # MSE by sex
    mse_by_sex = df.groupby("Sex").apply(
        lambda g: mean_squared_error(
            g[antibodies], df_recon.loc[g.index, antibodies]
        )
    ).to_dict()

    return {
        "overall_mse": mse,
        "per_antibody_mse": per_ab_mse,
        "mse_by_age_group": mse_by_age,
        "mse_by_sex": mse_by_sex,
        "N_test": len(X_test),
    }


In [7]:
import os
# ========== STEP 0: Set working directory (for running locally on laptop) =========
os.getcwd()
os.chdir("/Users/adeslatt/Scitechcon Dropbox/Anne DeslattesMays/projects/oadr-autoantibody")
os.getcwd()

'/Users/adeslatt/Scitechcon Dropbox/Anne DeslattesMays/projects/oadr-autoantibody'

In [8]:
import pickle

# Load mse_summaries from file
with open("data/mse_summaries.pkl", "rb") as f:
    mse_summaries = pickle.load(f)

# Confirm structure
print(mse_summaries.keys())


dict_keys(['SDY569', 'SDY1625', 'SDY524', 'SDY797', 'SDY1737'])


In [9]:
# === Run this
study_list = ["SDY569", "SDY1625", "SDY524", "SDY797", "SDY1737"]
weights_dict = extract_model_weights_per_study(mse_summaries)

In [10]:
# 1. Aggregate the weights
aggregated_weights = aggregate_weights_median(weights_dict)

# 2. Build federated model using any of the saved models as architecture reference
reference_model_path = "models/SDY569_model.keras"  # or any other
federated_model = build_federated_model(aggregated_weights, reference_model_path)



In [11]:
federated_results = {}

for study in mse_summaries:
    print(f"Evaluating federated model on {study}...")
    federated_results[study] = evaluate_federated_model(federated_model, mse_summaries, study)


Evaluating federated model on SDY569...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
Evaluating federated model on SDY1625...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
Evaluating federated model on SDY524...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step
Evaluating federated model on SDY797...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step
Evaluating federated model on SDY1737...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step


  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(
  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(
  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(
  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(
  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(


In [12]:
comparison_data = []

for study in mse_summaries:
    local_mse = mse_summaries[study]["mse_median"]
    fed_mse = federated_results[study]["overall_mse"]
    delta = local_mse - fed_mse
    pct_improvement = (delta / local_mse) * 100

    comparison_data.append({
        "Study": study,
        "Local MSE": round(local_mse, 4),
        "Federated MSE": round(fed_mse, 4),
        "Δ MSE": round(delta, 4),
        "% Improvement": f"{pct_improvement:.2f}%"
    })

comparison_df = pd.DataFrame(comparison_data)
display(comparison_df)


Unnamed: 0,Study,Local MSE,Federated MSE,Δ MSE,% Improvement
0,SDY569,0.0009,0.0476,-0.0467,-5139.68%
1,SDY1625,0.001,0.0327,-0.0317,-3188.99%
2,SDY524,0.0427,0.1015,-0.0588,-137.68%
3,SDY797,0.126,0.3485,-0.2225,-176.61%
4,SDY1737,0.0404,0.036,0.0044,10.94%


## Retrain Federated Model on Each Study and Evaluate Stratified MSE

In [13]:
from tensorflow.keras.models import clone_model
from tensorflow.keras.optimizers import Adam
import numpy as np

# Store results after retraining
retrain_detailed_results = {}

for study in study_list:
    print(f"\n--- Retraining Federated Model on Study: {study} ---")

    summary = mse_summaries[study]
    
    # Extract X_train and reshape only relevant features
    X_train = summary["X_train"]
    antibodies = summary["antibody_labels"]

    # Number of features used in training (including one-hot encoded Age_Group and Sex)
    n_features_used = federated_model.input_shape[1]

    # Defensive check
    if X_train.shape[1] != n_features_used:
        print(f"  ⚠ Reshaping X_train: expected {n_features_used}, got {X_train.shape[1]}")
        # Try to truncate or slice appropriately
        X_train = X_train[:, :n_features_used]

    # Reshape for CNN input
    X_train_reshaped = X_train.reshape((-1, n_features_used, 1, 1))

    # Clone the federated model and copy weights
    local_model = clone_model(federated_model)
    local_model.set_weights(federated_model.get_weights())
    local_model.compile(optimizer=Adam(), loss='mse')

    # Retrain
    local_model.fit(X_train_reshaped, X_train_reshaped, epochs=10, batch_size=8, verbose=1)

    # Evaluate after retraining
    results_post = evaluate_federated_model(local_model, mse_summaries, study)
    retrain_detailed_results[study] = results_post



--- Retraining Federated Model on Study: SDY569 ---
  ⚠ Reshaping X_train: expected 7, got 12
Epoch 1/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 330ms/step - loss: 0.0752
Epoch 2/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - loss: 0.0746
Epoch 3/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0740
Epoch 4/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0734
Epoch 5/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0728
Epoch 6/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0722
Epoch 7/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0717
Epoch 8/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0711
Epoch 9/10
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 0.0705
Epoch 10/10
[1m1/1[0m [32m━

  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0262  
Epoch 2/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0258
Epoch 3/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.0255
Epoch 4/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.0253
Epoch 5/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.0251
Epoch 6/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.0249
Epoch 7/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0247
Epoch 8/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.0246
Epoch 9/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.0245
Epoch 10/10
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0244
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━

  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0758  
Epoch 2/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0694
Epoch 3/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0633
Epoch 4/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0571
Epoch 5/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0520
Epoch 6/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0474
Epoch 7/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0437
Epoch 8/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0400
Epoch 9/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.0376
Epoch 10/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0356
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━

  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(


[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.3907  
Epoch 2/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.3755
Epoch 3/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.3597
Epoch 4/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.3431
Epoch 5/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.3269
Epoch 6/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.3093
Epoch 7/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.2918
Epoch 8/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.2750
Epoch 9/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.2567
Epoch 10/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.2387
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━

  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0864  
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0851
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0836
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0826
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0812
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0802
Epoch 7/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.0792
Epoch 8/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0779
Epoch 9/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0769
Epoch 10/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0759
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━

  mse_by_age = df.groupby("Age_Group").apply(
  mse_by_sex = df.groupby("Sex").apply(
