In [None]:
# Author: Sebastian Collins (Newcastle University)
# Date: May 2025
# Description: Federated learning model for medical diagnosis, trained with or without differential privacy (simple model)

In [None]:
# Find the problematic package
!ls -la /usr/local/lib/python3.11/dist-packages | grep cipy

# Remove the corrupted package
!rm -rf /usr/local/lib/python3.11/dist-packages/~cipy*

# Uninstall problematic versions
!pip uninstall -y jax jaxlib flax optax orbax-checkpoint ml-dtypes scipy tensorflow tensorflow-federated chex numpy pandas matplotlib scikit-learn imbalanced-learn

# Install packages with correct versions
!pip install tensorflow==2.14.0 tensorflow-federated==0.84.0 tensorflow-addons==0.22.0 tensorflow-privacy==0.9.0 scipy==1.9.3 ml-dtypes==0.2.0
!pip install jax==0.4.27 jaxlib==0.4.27 flax==0.7.2 optax==0.1.7 orbax-checkpoint==0.11.10
!pip install numpy==1.25.2 pandas==2.2.3 matplotlib==3.10.1 scikit-learn==1.6.1 imbalanced-learn==0.13.0

# Install imbalanced-learn
!pip install -U imbalanced-learn

!pip install faker

# Imports
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_addons as tfa
import tensorflow_privacy as tfp

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter

from tensorflow_federated import aggregators
from tensorflow_federated.python.aggregators import clipping_factory
from tensorflow_federated.python.aggregators import zeroing_factory
from tensorflow_privacy.privacy.dp_query import gaussian_query

from imblearn.over_sampling import SMOTE
from imblearn.over_sampling import BorderlineSMOTE

from sklearn.feature_selection import mutual_info_classif

from faker import Faker

import copy

from google.colab import drive

import os
from PIL import Image

drwxr-xr-x 23 root root     4096 May  6 21:41 scipy
drwxr-xr-x  2 root root     4096 May  6 21:41 scipy-1.15.2.dist-info
drwxr-xr-x  2 root root     4096 May  6 21:41 scipy.libs
Found existing installation: jax 0.4.14
Uninstalling jax-0.4.14:
  Successfully uninstalled jax-0.4.14
Found existing installation: jaxlib 0.4.14
Uninstalling jaxlib-0.4.14:
  Successfully uninstalled jaxlib-0.4.14
[0mFound existing installation: ml-dtypes 0.2.0
Uninstalling ml-dtypes-0.2.0:
  Successfully uninstalled ml-dtypes-0.2.0
Found existing installation: scipy 1.15.2
Uninstalling scipy-1.15.2:
  Successfully uninstalled scipy-1.15.2
Found existing installation: tensorflow 2.14.0
Uninstalling tensorflow-2.14.0:
  Successfully uninstalled tensorflow-2.14.0
Found existing installation: tensorflow_federated 0.84.0
Uninstalling tensorflow_federated-0.84.0:
  Successfully uninstalled tensorflow_federated-0.84.0
[0mFound existing installation: numpy 1.25.2
Uninstalling numpy-1.25.2:
  Successfully uninstalle

Collecting jax==0.4.27
  Using cached jax-0.4.27-py3-none-any.whl.metadata (23 kB)
Collecting jaxlib==0.4.27
  Using cached jaxlib-0.4.27-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting flax==0.7.2
  Using cached flax-0.7.2-py3-none-any.whl.metadata (10.0 kB)
Collecting optax==0.1.7
  Using cached optax-0.1.7-py3-none-any.whl.metadata (13 kB)
Collecting orbax-checkpoint==0.11.10
  Using cached orbax_checkpoint-0.11.10-py3-none-any.whl.metadata (2.0 kB)
Collecting chex>=0.1.5 (from optax==0.1.7)
  Using cached chex-0.1.89-py3-none-any.whl.metadata (17 kB)
INFO: pip is looking at multiple versions of orbax-checkpoint to determine which version is compatible with other requirements. This could take a while.
[31mERROR: Cannot install flax==0.7.2, jax==0.4.27, optax==0.1.7 and orbax-checkpoint==0.11.10 because these package versions have conflicting dependencies.[0m[31m
[0m
The conflict is caused by:
    The user requested jax==0.4.27
    flax 0.7.2 depends on jax>=0.4

ERROR:jax._src.xla_bridge:Jax plugin configuration error: Plugin module %s could not be loaded
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/xla_bridge.py", line 428, in discover_pjrt_plugins
    plugin_module = importlib.import_module(plugin_module_name)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_r

In [None]:
# Mount Google Drive, which holds the dataset; force_remount refreshes an existing mount
drive.mount('/content/drive', force_remount=True)

# Read the CSV, treating its first row as the header
stroke_data = pd.read_csv(
    '/content/drive/My Drive/your_dataset_folder/your_file.csv',  # <-- Replace with actual path
    header=0,
)


Mounted at /content/drive


In [None]:
# Hyper parameters
BATCH_SIZE = 64  # Batch size applied to every dataset built by create_federated_data
NUM_ROUNDS = 200  # Upper bound on federated training rounds (early stopping can end training sooner)
NUM_CLIENTS = 15  # Number of client shards the training data is split into
NUM_SELECTED = 6 # Number of clients selected in each round (also passed to the DP aggregator as clients_per_round)
LEARNING_RATE_CLIENT = 0.001  # Learning rate of the client-side Adam optimizer
LEARNING_RATE_SERVER = 0.05  # Learning rate of the server-side SGD optimizer
MOMENTUM = 0.95 # Momentum of the server-side SGD optimizer
MIN_DELTA = 0.0005  # Minimum improvement in AUC to consider as progress
PATIENCE = 8 # Maximum number of consecutive rounds without improvement before early stopping
NOISE_MULTIPLIER = 0.1 # DP noise multiplier passed to the Gaussian DP aggregator
L2_NORM_CLIP = 2.0 # DP clipping value passed to the Gaussian DP aggregator

In [None]:
# Inspect the raw data: column names, non-null counts and dtypes
stroke_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4981 entries, 0 to 4980
Data columns (total 11 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   gender             4981 non-null   object 
 1   age                4981 non-null   float64
 2   hypertension       4981 non-null   int64  
 3   heart_disease      4981 non-null   int64  
 4   ever_married       4981 non-null   object 
 5   work_type          4981 non-null   object 
 6   Residence_type     4981 non-null   object 
 7   avg_glucose_level  4981 non-null   float64
 8   bmi                4981 non-null   float64
 9   smoking_status     4981 non-null   object 
 10  stroke             4981 non-null   int64  
dtypes: float64(3), int64(3), object(5)
memory usage: 428.2+ KB


In [None]:
# Prepare data for federated learning
# Work on a copy so the raw frame loaded earlier stays untouched
stroke_data_cleaned = stroke_data.copy()

# Data Masking
# Integer codes for each categorical column (keys are lower-cased category names)
gender_map = {'female': 0, 'male': 1}
ever_married_map = {'no': 0, 'yes': 1}
work_type_map = {'private': 0, 'self-employed': 1, 'govt_job': 2, 'children': 3, 'never_worked': 4}
res_type_map = {'urban': 0, 'rural': 1}
smoking_map = {'formerly smoked': 0, 'never smoked': 1, 'smokes': 2, 'unknown': 3}

categorical_maps = {
    'gender': gender_map,
    'ever_married': ever_married_map,
    'work_type': work_type_map,
    'Residence_type': res_type_map,
    'smoking_status': smoking_map,
}

# Lower-case every categorical column and swap its values for the integer codes.
# Encoding replaces the original strings; any value that is not a key of the
# mapping ends up as code 5.
for col, m in categorical_maps.items():
    lowered = stroke_data_cleaned[col].astype(str).str.lower()
    stroke_data_cleaned[col] = lowered.map(m).fillna(5).astype('Int64')

# Assign correct data type: plain int64 for codes/flags, float32 for measurements
int_cols = ['stroke', 'hypertension', 'heart_disease','gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']
float_cols = ['avg_glucose_level', 'bmi', 'age']
stroke_data_cleaned = stroke_data_cleaned.astype(
    {**dict.fromkeys(int_cols, 'int64'), **dict.fromkeys(float_cols, 'float32')}
)

# Create interaction terms: each measurement multiplied by age
for source_col, interaction_col in [
    ('bmi', 'bmi_age_interaction'),
    ('avg_glucose_level', 'avg_glucose_age_interaction'),
]:
    stroke_data_cleaned[interaction_col] = stroke_data_cleaned[source_col] * stroke_data_cleaned['age']

In [None]:
# Separate the target column from the feature columns
y = stroke_data_cleaned['stroke']
X = stroke_data_cleaned.drop(columns=['stroke'])

# Standardise every feature column.
# NOTE(review): the scaler is fitted on the full dataset, before the train/test
# split made later in the notebook, so held-out rows contribute to the scaling
# statistics - confirm this is intended.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Rebuild a DataFrame with the original feature names and append the unscaled target.
# `.values` attaches the labels by position, so no index alignment takes place.
stroke_data_final = pd.DataFrame(X_scaled, columns=X.columns).assign(stroke=y.values)

# Confirm data has been assigned correctly
stroke_data_final.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4981 entries, 0 to 4980
Data columns (total 13 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   gender                       4981 non-null   float64
 1   age                          4981 non-null   float64
 2   hypertension                 4981 non-null   float64
 3   heart_disease                4981 non-null   float64
 4   ever_married                 4981 non-null   float64
 5   work_type                    4981 non-null   float64
 6   Residence_type               4981 non-null   float64
 7   avg_glucose_level            4981 non-null   float64
 8   bmi                          4981 non-null   float64
 9   smoking_status               4981 non-null   float64
 10  bmi_age_interaction          4981 non-null   float64
 11  avg_glucose_age_interaction  4981 non-null   float64
 12  stroke                       4981 non-null   int64  
dtypes: float64(12), in

In [None]:
# Class balancing with synthetic data
# Classes are imbalanced: print how many rows each 'stroke' class has
stroke_counts = stroke_data_final['stroke'].value_counts()
print(stroke_counts)

# Use SMOTE to balance them by oversampling
# NOTE(review): resampling is applied to the whole dataset here, i.e. before the
# train/test split in the following cell, so the held-out split is drawn from the
# oversampled data and can contain synthetic rows. Consider resampling only the
# training split so the test metrics reflect real data.
smote = SMOTE(random_state=42)
x = stroke_data_final.drop(columns=['stroke'])
y = stroke_data_final['stroke']  # rebinds `y` to the target column of stroke_data_final
x_res, y_res = smote.fit_resample(x, y)

# This dilutes the amount of real data being used in the model

# Reassemble features and labels into a single balanced DataFrame
stroke_data_balanced = pd.DataFrame(x_res, columns=x.columns)
stroke_data_balanced['stroke'] = y_res

# Check the new class distribution
print(stroke_data_balanced['stroke'].value_counts())

stroke
0    4733
1     248
Name: count, dtype: int64
stroke
1    4733
0    4733
Name: count, dtype: int64


In [None]:
# Hold out 20% of the balanced data for testing (fixed seed for a repeatable split)
train_data, test_data = train_test_split(
    stroke_data_balanced,
    test_size=0.2,
    random_state=42,
)

# Report the size of each split
for split_name, split_df in (("Train", train_data), ("Test", test_data)):
    print(f"{split_name} data size: {len(split_df)}")

Train data size: 7572
Test data size: 1894


In [None]:
def create_federated_data(df, num_clients, batch_size=None, random_state=None):
    """Shuffle ``df`` and split it into ``num_clients`` batched ``tf.data.Dataset`` shards.

    Every shard receives ``len(df) // num_clients`` consecutive rows of the
    shuffled frame; the last shard additionally receives the remainder. If
    ``df`` has fewer rows than ``num_clients``, all shards except the last are
    empty.

    Args:
        df: DataFrame containing the feature columns and a 'stroke' label column.
        num_clients: Number of shards to create; must be at least 1.
        batch_size: Batch size applied to every shard. Defaults to the
            notebook-level ``BATCH_SIZE`` when None.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            shuffle (and therefore the shard contents) can be reproduced.
            None keeps the shuffle unseeded.

    Returns:
        A list of ``num_clients`` datasets. Each element is a tuple
        ``({'x': features}, labels)`` with both parts cast to float32.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Shuffle the data randomly (reproducibly when random_state is given)
    df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    # Calculate the size of each client's dataset
    client_size = len(df) // num_clients

    # Create a list to store the client data
    client_dfs = []

    # Distribute data into each client
    for i in range(num_clients):
        start = i * client_size
        # The last client also takes the rows left over by the integer division
        end = (i + 1) * client_size if i < num_clients - 1 else len(df)
        client_data = df.iloc[start:end].reset_index(drop=True)

        # Convert to tf.data.Dataset
        client_dataset = tf.data.Dataset.from_tensor_slices(
            (
                {'x': client_data.drop(columns=['stroke']).values.astype(np.float32)}, # Features
                client_data['stroke'].values.astype(np.float32) # Labels
            )
        ).batch(batch_size)  # Apply batching

        client_dfs.append(client_dataset)

    return client_dfs

# Split the training data into one dataset per client
federated_data = create_federated_data(train_data, NUM_CLIENTS)

# Report how many batches each client holds (clients are numbered from 1)
for client_number, client_dataset in enumerate(federated_data, start=1):
    print(f"Client {client_number} amount of batches: {len(client_dataset)}")

Client 1 amount of batches: 8
Client 2 amount of batches: 8
Client 3 amount of batches: 8
Client 4 amount of batches: 8
Client 5 amount of batches: 8
Client 6 amount of batches: 8
Client 7 amount of batches: 8
Client 8 amount of batches: 8
Client 9 amount of batches: 8
Client 10 amount of batches: 8
Client 11 amount of batches: 8
Client 12 amount of batches: 8
Client 13 amount of batches: 8
Client 14 amount of batches: 8
Client 15 amount of batches: 9


In [None]:
def create_keras_model(num_features=12):
    """Build the uncompiled feed-forward binary classifier.

    Architecture: three swish-activated Dense layers (128, 64, 32 units), each
    followed by LayerNormalization, then Dropout(0.4) and a single sigmoid
    output unit.

    Args:
        num_features: Width of the input vector. Defaults to 12, the number of
            feature columns in the prepared data (the 10 original columns plus
            the 2 interaction terms).

    Returns:
        A ``tf.keras.Sequential`` model whose input layer is named 'x'.
    """
    # Only the first Dense layer uses this seeded initializer; the remaining
    # layers keep the Keras default initializer.
    initializer = tf.keras.initializers.HeNormal(seed = 42)
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(num_features,), name='x'),  # One input per feature column
        tf.keras.layers.Dense(128, activation='swish', kernel_initializer=initializer),
        tf.keras.layers.LayerNormalization(),
        tf.keras.layers.Dense(64, activation='swish'),
        tf.keras.layers.LayerNormalization(),
        tf.keras.layers.Dense(32, activation='swish'),
        tf.keras.layers.LayerNormalization(),
        tf.keras.layers.Dropout(0.4), # Dropout rate of 0.4 (active during training only)
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    return model

In [None]:
def model_fn():
    """Wrap a freshly built Keras model as a TFF learning model.

    Returns:
        The result of ``tff.learning.models.from_keras_model`` for a new model
        from ``create_keras_model()``, using binary cross-entropy as the loss
        and tracking binary accuracy, precision, recall and AUC.
    """
    keras_model = create_keras_model()

    # Element structure of the client datasets: ({'x': features}, labels)
    feature_spec = {'x': tf.TensorSpec(shape=(None, 12), dtype=tf.float32)}
    label_spec = tf.TensorSpec(shape=(None,), dtype=tf.float32)

    # Cross-entropy loss function for binary classification
    bce_loss = tf.keras.losses.BinaryCrossentropy()

    tracked_metrics = [
        tf.keras.metrics.BinaryAccuracy(),
        # Share of predicted positives that are true positives
        tf.keras.metrics.Precision(name='precision'),
        # Share of actual positives that are predicted positive
        tf.keras.metrics.Recall(name='recall'),
        # Area under the ROC curve
        tf.keras.metrics.AUC(name='auc'),
    ]

    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=(feature_spec, label_spec),
        loss=bce_loss,
        metrics=tracked_metrics,
    )

In [None]:
def select_clients(federated_data, num_selected):
    """Return the ``num_selected`` client datasets with the best class balance.

    A client's balance score is the count of its rarest label divided by the
    count of its most common label. A client whose labels yield fewer than two
    count bins (for example only label 0, or no labels at all) scores 0.
    Clients are ranked by score, highest first; ties keep their original order.

    NOTE(review): the score depends only on each client's data, so as long as
    ``federated_data`` is unchanged the same clients are returned on every call.

    Args:
        federated_data: Indexable collection of client datasets. Each dataset
            yields ``(features, label_batch)`` tuples where ``label_batch``
            exposes ``.numpy()``.
        num_selected: Maximum number of clients to return.

    Returns:
        List of the selected client datasets, best balanced first.
    """

    def balance_score(client_index, dataset):
        # Gather every label held by this client
        labels = []
        for element in dataset:
            if not (isinstance(element, tuple) and len(element) == 2):
                # Anything that is not a (features, labels) pair is reported and skipped
                print(f"Warning: Unexpected data format in client {client_index}: {element}")
                continue
            _, label_batch = element
            labels.extend(label_batch.numpy())

        # Occurrences of each integer label
        class_counts = np.bincount(np.array(labels).astype(int))
        if len(class_counts) < 2:
            return 0
        # Ratio of the smaller class count to the larger one
        return class_counts.min() / class_counts.max()

    scores = [balance_score(i, dataset) for i, dataset in enumerate(federated_data)]

    # Stable sort on the score, descending, so equally scored clients keep their order
    ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [federated_data[i] for i in ranked_indices[:num_selected]]

In [None]:
# Differential Privacy

# Gaussian differential-privacy aggregator with a fixed clipping norm, configured from
# NOISE_MULTIPLIER, NUM_SELECTED and L2_NORM_CLIP.
# This process provides formal differential privacy guarantees
# The presence or absence of any individual client barely changes the outcome of training, making it provably private
# NOTE(review): how strong the guarantee is depends on the noise multiplier - confirm the chosen value.
dp_agg = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
    noise_multiplier=NOISE_MULTIPLIER,
    clients_per_round=NUM_SELECTED, # Number of clients participating in each training round
    clip=L2_NORM_CLIP
)

# Wrap the DP factory so it can be passed where a weighted aggregator is expected.
# It only takes effect if it is supplied as `model_aggregator` when the training process is built.
weighted_dp_agg = tff.aggregators.as_weighted_aggregator(dp_agg)

In [None]:
# Build the weighted federated averaging process from the model and the two optimizers
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_CLIENT),
    # Adam is used for the client optimiser as it adapts learning rates based on gradient moments, ideal for heterogeneous client data.
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE_SERVER, momentum=MOMENTUM),
    # SGD with momentum is used for the server optimiser to stabilize and accelerate parameter updates during model aggregation.
    model_aggregator=None, # None = no DP aggregator (presumably TFF's default aggregation - confirm in the API docs); pass weighted_dp_agg here to train with differential privacy
)

# Create the initial server state consumed by the training loop
state = iterative_process.initialize()

In [None]:
# Per-round training metrics, plus the bookkeeping for early stopping
metrics_history = []
best_auc = 0.0
wait = 0

for round_num in range(1, NUM_ROUNDS + 1):
    # Pick this round's participants and run one round of federated training
    selected_clients = select_clients(federated_data, NUM_SELECTED)
    state, metrics = iterative_process.next(state, selected_clients)

    # Extract the training metrics reported for this round
    train_metrics = metrics['client_work']['train']
    loss, acc, prec, rec, auc = (
        train_metrics[metric_key]
        for metric_key in ('loss', 'binary_accuracy', 'precision', 'recall', 'auc')
    )

    # Save metrics to list
    metrics_history.append(
        dict(round=round_num, loss=loss, accuracy=acc, precision=prec, recall=rec, auc=auc)
    )

    # Formatted output
    print(f"\nRound {round_num}")
    print("-------------------------------")
    print(f"Train Metrics:")
    print(f"Loss:     {loss:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision:{prec:.4f}")
    print(f"Recall:   {rec:.4f}")
    print(f"AUC:      {auc:.4f}")

    # Early stopping on AUC: a round counts as progress only if it beats the best
    # AUC so far by more than MIN_DELTA.
    # NOTE(review): the monitored AUC is the training metric, not a validation metric.
    improved = auc > best_auc + MIN_DELTA
    if improved:
        best_auc, wait = auc, 0
        continue

    wait += 1
    print(f"Early stopping counter: {wait}/{PATIENCE}")
    if wait >= PATIENCE:
        print("Early stopping triggered.")
        break


Round 1
-------------------------------
Train Metrics:
Loss:     0.7275
Accuracy: 0.6453
Precision:0.6344
Recall:   0.6955
AUC:      0.6954

Round 2
-------------------------------
Train Metrics:
Loss:     0.7140
Accuracy: 0.6413
Precision:0.6319
Recall:   0.6870
AUC:      0.7018

Round 3
-------------------------------
Train Metrics:
Loss:     0.6970
Accuracy: 0.6571
Precision:0.6443
Recall:   0.7105
AUC:      0.7167

Round 4
-------------------------------
Train Metrics:
Loss:     0.6598
Accuracy: 0.6785
Precision:0.6612
Recall:   0.7400
AUC:      0.7385

Round 5
-------------------------------
Train Metrics:
Loss:     0.6291
Accuracy: 0.7029
Precision:0.6833
Recall:   0.7629
AUC:      0.7625

Round 6
-------------------------------
Train Metrics:
Loss:     0.5906
Accuracy: 0.7148
Precision:0.6879
Recall:   0.7924
AUC:      0.7852

Round 7
-------------------------------
Train Metrics:
Loss:     0.5779
Accuracy: 0.7260
Precision:0.7038
Recall:   0.7859
AUC:      0.7942

Round 8
----

In [None]:
# Directory to save data
directory = '/content/drive/My Drive/your_metric_folder/' # <-- Replace with actual path
# exist_ok=True creates the folder when missing and is a no-op when it already exists,
# without a separate existence check beforehand.
os.makedirs(directory, exist_ok=True)

# Save the per-round training metrics to CSV
metrics_df = pd.DataFrame(metrics_history)
metrics_df.to_csv(os.path.join(directory, 'metrics_no_dp(2).csv'), index=False) # Change depending on use of differential privacy


In [None]:
# Pull the trained model weights out of the final server state
final_model_weights = iterative_process.get_model_weights(state)

# Fresh Keras model with the same architecture, compiled so that it can be
# evaluated with the same loss and metrics as during training
keras_eval_model = create_keras_model()
keras_eval_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.SGD(),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc'),
    ],
)

# Load in the federated-trained weights
final_model_weights.assign_weights_to(keras_eval_model)

# Build the test set as a single batched dataset by treating it as one "client"
[test_dataset] = create_federated_data(test_data, 1)

# Evaluate on the test data; results come back as [loss, accuracy, precision, recall, auc]
print("-"*40)
results = keras_eval_model.evaluate(test_dataset)
print(f"Loss:          {results[0]:.4f}")
print(f"Accuracy:      {results[1]:.4f}")
print(f"Precision:     {results[2]:.4f}")
print(f"Recall:        {results[3]:.4f}")
print(f"AUC:           {results[4]:.4f}")

# Collect the test results in a two-column table (metric name, value)
results_df = pd.DataFrame(
    {
        'metric': ['Loss', 'Accuracy', 'Precision', 'Recall', 'AUC'],
        'value': results,
    }
)

# Store test results in CSV
results_df.to_csv(os.path.join(directory, 'test_results_no_dp(2).csv'), index=False)  # Change depending on use of differential privacy

----------------------------------------
Loss:          0.3389
Accuracy:      0.8596
Precision:     0.8128
Recall:        0.9346
AUC:           0.9233
