In [None]:
# Author: Sebastian Collins (Newcastle University)
# Date: May 2025
# Description: Federated learning model for medical diagnosis, with or without differential privacy (Complex Model)

In [None]:
# Remove potential problematic packages
!ls -la /usr/local/lib/python3.11/dist-packages | grep cipy

# Remove the corrupted packages
!rm -rf /usr/local/lib/python3.11/dist-packages/~cipy*

# Uninstall problematic versions
!pip uninstall -y jax jaxlib flax optax orbax-checkpoint ml-dtypes scipy tensorflow tensorflow-federated chex numpy pandas matplotlib scikit-learn imbalanced-learn

# Install packages with correct versions
!pip install tensorflow==2.14.0 tensorflow-federated==0.84.0 tensorflow-addons==0.22.0 tensorflow-privacy==0.9.0 scipy==1.9.3 ml-dtypes==0.2.0
!pip install jax==0.4.27 jaxlib==0.4.27 flax==0.7.2 optax==0.1.7 orbax-checkpoint==0.11.10
!pip install numpy==1.25.2 pandas==2.2.3 matplotlib==3.10.1 scikit-learn==1.6.1 imbalanced-learn==0.13.0

!pip install -U imbalanced-learn

!pip install faker

# Imports
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_addons as tfa
import tensorflow_privacy as tfp

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter

from tensorflow_federated import aggregators
from tensorflow_federated.python.aggregators import clipping_factory
from tensorflow_federated.python.aggregators import zeroing_factory
from tensorflow_privacy.privacy.dp_query import gaussian_query

from imblearn.over_sampling import SMOTE
from imblearn.over_sampling import BorderlineSMOTE

from faker import Faker

from google.colab import drive

import os
from PIL import Image

drwxr-xr-x 21 root root     4096 Apr 24 23:02 scipy
drwxr-xr-x  2 root root     4096 Apr 24 23:02 scipy-1.9.3.dist-info
drwxr-xr-x  2 root root     4096 Apr 24 23:02 scipy.libs
Found existing installation: jax 0.4.14
Uninstalling jax-0.4.14:
  Successfully uninstalled jax-0.4.14
Found existing installation: jaxlib 0.4.14
Uninstalling jaxlib-0.4.14:
  Successfully uninstalled jaxlib-0.4.14
[0mFound existing installation: ml-dtypes 0.2.0
Uninstalling ml-dtypes-0.2.0:
  Successfully uninstalled ml-dtypes-0.2.0
Found existing installation: scipy 1.9.3
Uninstalling scipy-1.9.3:
  Successfully uninstalled scipy-1.9.3
Found existing installation: tensorflow 2.14.0
Uninstalling tensorflow-2.14.0:
  Successfully uninstalled tensorflow-2.14.0
Found existing installation: tensorflow_federated 0.84.0
Uninstalling tensorflow_federated-0.84.0:
  Successfully uninstalled tensorflow_federated-0.84.0
[0mFound existing installation: numpy 1.25.2
Uninstalling numpy-1.25.2:
  Successfully uninstalled nu

Collecting jax==0.4.27
  Using cached jax-0.4.27-py3-none-any.whl.metadata (23 kB)
Collecting jaxlib==0.4.27
  Using cached jaxlib-0.4.27-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting flax==0.7.2
  Using cached flax-0.7.2-py3-none-any.whl.metadata (10.0 kB)
Collecting optax==0.1.7
  Using cached optax-0.1.7-py3-none-any.whl.metadata (13 kB)
Collecting orbax-checkpoint==0.11.10
  Using cached orbax_checkpoint-0.11.10-py3-none-any.whl.metadata (2.0 kB)
Collecting chex>=0.1.5 (from optax==0.1.7)
  Using cached chex-0.1.89-py3-none-any.whl.metadata (17 kB)
INFO: pip is looking at multiple versions of orbax-checkpoint to determine which version is compatible with other requirements. This could take a while.
[31mERROR: Cannot install flax==0.7.2, jax==0.4.27, optax==0.1.7 and orbax-checkpoint==0.11.10 because these package versions have conflicting dependencies.[0m[31m
[0m
The conflict is caused by:
    The user requested jax==0.4.27
    flax 0.7.2 depends on jax>=0.4

ERROR:jax._src.xla_bridge:Jax plugin configuration error: Plugin module %s could not be loaded
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/xla_bridge.py", line 428, in discover_pjrt_plugins
    plugin_module = importlib.import_module(plugin_module_name)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1147, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 690, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_r

In [None]:
# Attach Google Drive to the Colab filesystem; the dataset lives there.
drive.mount(
    '/content/drive',
    force_remount=True,
)

# Read the X-ray metadata table (first CSV row is the header) into a DataFrame.
x_ray_info = pd.read_csv(
    '/content/drive/My Drive/your_dataset_folder/your_file.csv',  # <-- Replace with actual path
    header=0,
)

Mounted at /content/drive


In [None]:
# Hyperparameters and constants shared by the cells below.
IMAGE_SIZE       = (224, 224)   # Size passed to tf.image.resize when decoding images
BATCH_SIZE       = 16           # Smaller batch size for CPU
AUTOTUNE         = tf.data.AUTOTUNE  # Lets tf.data choose parallelism / prefetch depth
CLIENT_AMOUNT    = 4            # Number of simulated federated clients
PROX             = 0.05         # Proximal strength for FedProx; not referenced in the visible cells
NUM_ROUNDS       = 50           # Maximum number of federated rounds
PATIENCE         = 8            # Rounds without AUC improvement before early stopping
CLIENT_LR        = 0.001        # Lower LR for CPU stability
SERVER_LR        = 0.3          # Moderate server LR with momentum
EPOCHS           = 1            # Passes over each client's data per round (dataset repeat count)
NOISE_MULTIPLIER = 0.02         # Gaussian noise multiplier for the DP aggregator
L2_NORM_CLIP     = 2.0          # Clipping norm handed to the DP aggregator
MIN_DELTA        = 0.001        # Minimum AUC gain that counts as an improvement

# Learning-rate schedule settings; none of these are referenced in the visible cells.
INITIAL_LR       = 0.1          # unused
WARM_UP          = 500
TOTAL_STEPS      = 800
DECAY_STEPS      = 800
DECAY_RATE = 0.96


# Class name -> numeric label. 'Pnemonia' (sic) matches the spelling in the CSV's Label column.
LABEL_MAP        = {'Normal': 0, 'Pnemonia': 1}

In [None]:
# Paths to the Drive folders holding the training and test images.
# File names from the metadata's 'X_ray_image_name' column are joined onto these.
image_folder_train = '/content/drive/My Drive/CS work/uni/year 3/3094 (dis)/coding/Complex_Dataset/train'
image_folder_test = '/content/drive/My Drive/CS work/uni/year 3/3094 (dis)/coding/Complex_Dataset/test'

In [None]:
# Tag every metadata row with the split ("Train" / "Test") whose Drive folder
# actually contains its image, then discard the rows whose image is unavailable.
train_image_count = 0
test_image_count = 0
null_count = 0
both_count = 0

# Start every row as "not found". Without this the column is created lazily by
# the first .at assignment and unmatched rows hold NaN, which the `!= ""` filter
# below does not remove (NaN != "" evaluates to True).
x_ray_info['in_data'] = ""

# Iterate through each image name in the DataFrame
for index, image_name in x_ray_info['X_ray_image_name'].items():
    in_train = os.path.exists(os.path.join(image_folder_train, image_name))
    in_test = os.path.exists(os.path.join(image_folder_test, image_name))

    if in_train and in_test:  # Validation: ambiguous rows stay untagged and are dropped
        both_count += 1
    elif in_train:  # Differentiate between train and Test data
        train_image_count += 1
        x_ray_info.at[index, 'in_data'] = "Train"
    elif in_test:
        test_image_count += 1
        x_ray_info.at[index, 'in_data'] = "Test"
    else:
        null_count += 1

print("there are " + str(train_image_count) + " images in train folder")
print("there are " + str(test_image_count) + " images in test folder")
print("there are " + str(null_count) + " images not being used")
print("there are " + str(both_count) + " images in both folders")

print("-"*40)

# Drop all the data that was not stored in the drive
x_ray_info = x_ray_info[x_ray_info['in_data'] != ""].reset_index(drop=True)

print(x_ray_info['Label'].value_counts())

print("-"*40)

there are 3049 images in train folder
there are 338 images in test folder
there are 2523 images not being used
there are 0 images in both folders
----------------------------------------
Label
Pnemonia    4334
Normal      1576
Name: count, dtype: int64
----------------------------------------


In [None]:
# Number of training images kept per class when building the balanced subset
SAMPLES_PER_CLASS = 800

# Get rid of unused data. errors='ignore' keeps the cell re-runnable after the
# columns have already been removed by an earlier execution.
x_ray_info = x_ray_info.drop(
    columns=['Label_1_Virus_category', 'Label_2_Virus_category'],
    errors='ignore',
)

# Filter only the training data (since test data should not be modified)
train_data = x_ray_info[x_ray_info['in_data'] == "Train"]

# Separate into Normal and Pneumonia classes
normal_samples = train_data[train_data['Label'] == 'Normal']
pneumonia_samples = train_data[train_data['Label'] == 'Pnemonia']  # Match your label spelling

# Randomly select up to SAMPLES_PER_CLASS rows from each class (fewer if not available)
selected_normal = normal_samples.sample(
    n=min(SAMPLES_PER_CLASS, len(normal_samples)), random_state=42
)
selected_pneumonia = pneumonia_samples.sample(
    n=min(SAMPLES_PER_CLASS, len(pneumonia_samples)), random_state=42
)

# Combine into a new balanced DataFrame
balanced_df = pd.concat([selected_normal, selected_pneumonia]).reset_index(drop=True)

# Verify the counts
print(balanced_df['Label'].value_counts())
print("-"*40)
print(f"\nNew balanced DataFrame shape: {balanced_df.shape}")

Label
Normal      800
Pnemonia    800
Name: count, dtype: int64
----------------------------------------

New balanced DataFrame shape: (1600, 5)


In [None]:
# Attach synthetic patient metadata to every row: a random age between 18 and 80
# and a generated full name. All ages are drawn first, then all names.
fake = Faker()
balanced_df["age"] = [fake.random_int(min=18, max=80) for _ in balanced_df.index]
balanced_df["patient_name"] = [fake.name() for _ in balanced_df.index]

# Summarise the resulting columns and dtypes
balanced_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1600 entries, 0 to 1599
Data columns (total 7 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   Unnamed: 0        1600 non-null   int64 
 1   X_ray_image_name  1600 non-null   object
 2   Label             1600 non-null   object
 3   Dataset_type      1600 non-null   object
 4   in_data           1600 non-null   object
 5   age               1600 non-null   int64 
 6   patient_name      1600 non-null   object
dtypes: int64(2), object(5)
memory usage: 87.6+ KB


In [None]:
 # Pseudonymisation
 # Pseudonymise sensitive data by replacing personally identifiable information

patient_data = {}   # patient name -> pseudonymous 10-digit ID
issued_ids = set()  # IDs already handed out; needed for the uniqueness check

# Give every distinct name exactly one ID, in first-appearance order.
for name in balanced_df["patient_name"].unique():

    # Generate a unique patient ID
    patient_id = fake.random_int(min=1000000000, max=9999999999) # 10 digit code

    # Ensure uniqueness against previously issued IDs. Checking against
    # `patient_data` itself would compare with its keys (the names) and
    # therefore never detect a collision.
    while patient_id in issued_ids:
        patient_id = fake.random_int(min=1000000000, max=9999999999)

    issued_ids.add(patient_id)

    # Add the patient name and ID to the dictionary
    patient_data[name] = patient_id

# Check: show the first few name -> ID pairs
for key, value in list(patient_data.items())[:3]:
    print(f"Patient: {key}, ID: {value}")

# Replace the names in the DataFrame by their pseudonymous IDs
balanced_df["patient_name"] = balanced_df["patient_name"].map(patient_data)
print("-"*40)
balanced_df["patient_name"].values

Patient: Charles Wallace, ID: 6655802004
Patient: Marco Norton, ID: 9647342437
Patient: Ashley Blanchard, ID: 7706561040
----------------------------------------


array([6655802004, 9647342437, 7706561040, ..., 5838400377, 9648292850,
       9601499362])

In [None]:
# Generalisation
# To generalise the data, map exact values into broader categories

# Lookup filled later in this cell: pseudonymised patient ID -> original age
patient_age_data = {}

# Define age groups
# NOTE(review): these empty lists are not filled or read anywhere in the visible
# cells; the keys mirror the labels returned by generalize_age below.
age_groups = {
    '18-30': [],
    '31-40': [],
    '41-50': [],
    '51-60': [],
    '61+': []
}

# Function to categorize ages
def generalize_age(age):
    """Map an exact age in years onto a coarse age-band label.

    Args:
        age: Age in years (int or float).

    Returns:
        One of '18-30', '31-40', '41-50', '51-60' or '61+' for ages of 18 and
        above, or '<18' for younger ages. For integer ages of 18 or more the
        result is unchanged from the previous implementation.
    """
    # Exclusive upper bounds keep fractional ages in the band of their completed
    # years (30.5 -> '18-30'), and the explicit '<18' band stops minors from
    # falling through to '61+'.
    if age < 18:
        return '<18'
    if age < 31:
        return '18-30'
    if age < 41:
        return '31-40'
    if age < 51:
        return '41-50'
    if age < 61:
        return '51-60'
    return '61+'

# Derive the coarse age band for every patient
balanced_df['age_group'] = balanced_df['age'].map(generalize_age)

# Keep the exact ages in a separate lookup keyed by pseudonymised patient ID
patient_age_data.update(zip(balanced_df['patient_name'], balanced_df['age']))

# Show a before/after sample; the exact age column is removed in between
print(balanced_df[['patient_name', 'age']].head(3))
balanced_df = balanced_df.drop(columns='age')
separator = "-"*40
print(separator)
print(len(balanced_df['age_group']))
print(separator)
print(balanced_df[['patient_name', 'age_group']].head(3))

   patient_name  age
0    6655802004   44
1    9647342437   21
2    7706561040   38
----------------------------------------
1600
----------------------------------------
   patient_name age_group
0    6655802004     41-50
1    9647342437     18-30
2    7706561040     31-40


In [None]:
def load_and_preprocess(path):
    """Read a JPEG file and return it as a resized, rescaled grayscale tensor.

    Args:
        path: Location of the JPEG file to read.

    Returns:
        A float32 tensor with one channel, resized to IMAGE_SIZE and divided
        by 255.
    """
    raw_bytes = tf.io.read_file(path)
    # channels=1 forces a single grayscale channel
    pixels = tf.image.decode_jpeg(raw_bytes, channels=1)
    resized = tf.image.resize(pixels, IMAGE_SIZE)
    return tf.cast(resized, tf.float32) / 255.0

def parse_image(path, label):
    """Load the image at `path` and pair it with its label.

    Decoding is delegated to load_and_preprocess so that pixels are scaled by
    1/255. The previous implementation called tf.image.convert_image_dtype
    after tf.image.resize; resize already yields float32, so that conversion
    did nothing and the pixel values stayed on the 0-255 scale.

    Args:
        path: Location of the JPEG file to read.
        label: Label belonging to the image; passed through unchanged.

    Returns:
        Tuple (image, label), where image is the tensor produced by
        load_and_preprocess.
    """
    return load_and_preprocess(path), label

# Random perturbations so the model concentrates on the essential image features
def augment_image(img, label):
    """Apply random flip, brightness, contrast and rotation to an image.

    Args:
        img: Image tensor to perturb.
        label: Label belonging to the image; passed through unchanged.

    Returns:
        Tuple (augmented image, label).
    """
    flipped = tf.image.random_flip_left_right(img)  # horizontal mirror
    brightened = tf.image.random_brightness(flipped, max_delta=0.1)
    contrasted = tf.image.random_contrast(brightened, lower=0.9, upper=1.1)
    # Rotation angle in radians, roughly +/-10 degrees
    theta = tf.random.uniform([], -0.174, 0.174)
    return tfa.image.rotate(contrasted, theta), label

def extract_path_and_label(row):
    """Resolve a metadata row to its image path and numeric label.

    Args:
        row: Mapping with the keys 'X_ray_image_name', 'Label' and 'in_data'.

    Returns:
        Tuple (path, label). The path points into the training folder when
        row['in_data'] is "Train" and into the test folder for any other value;
        the label is looked up in LABEL_MAP.
    """
    file_name = row['X_ray_image_name']
    numeric_label = LABEL_MAP[row['Label']]
    # Anything that is not tagged "Train" is resolved against the test folder
    folder = image_folder_train if row['in_data'] == "Train" else image_folder_test
    return os.path.join(folder, file_name), numeric_label

In [None]:
def create_federated_datasets(df, client_count, epochs, augment):
    """Split `df` across simulated clients and build one tf.data pipeline per client.

    Args:
        df: DataFrame whose rows carry the columns read by extract_path_and_label.
        client_count: Number of clients to split the shuffled rows across.
        epochs: How many times each client's data is repeated per iteration.
        augment: When true, augment_image is applied to every image.

    Returns:
        A list with one batched tf.data.Dataset of (image, label) pairs per
        client. Rows whose image file does not exist on disk are skipped.
    """
    # Shuffle and split indices across clients
    shuffled_indices = df.index.tolist()
    np.random.shuffle(shuffled_indices)
    client_data = []

    for client_indices in np.array_split(shuffled_indices, client_count):
        # Gather file paths and numeric labels for this client
        client_rows = df.loc[client_indices].reset_index(drop=True)
        image_paths = []
        labels = []

        for row in client_rows.to_dict(orient='records'):
            path, label = extract_path_and_label(row)
            if os.path.exists(path):
                image_paths.append(path)
                labels.append(label)

        # Explicit dtypes keep the element types stable even for an empty client,
        # where they could not be inferred from an empty Python list.
        ds = tf.data.Dataset.from_tensor_slices((
            tf.constant(image_paths, dtype=tf.string),
            tf.constant(labels, dtype=tf.int32),
        ))

        # Decode and process the image
        ds = ds.map(parse_image, num_parallel_calls=AUTOTUNE)

        # Cache the decoded images BEFORE augmenting: caching afterwards would
        # store the first pass's random augmentations and replay them forever.
        ds = ds.cache()

        # Potentially augment data (re-drawn on every pass over the cache)
        if augment:
            ds = ds.map(augment_image, num_parallel_calls=AUTOTUNE)

        # Shuffle within an epoch, then repeat, so every epoch is a full
        # permutation. The buffer must be at least 1 even for an empty client.
        ds = ds.shuffle(buffer_size=max(len(image_paths), 1))
        ds = ds.repeat(epochs)
        ds = ds.batch(BATCH_SIZE) # Group examples into batches
        ds = ds.prefetch(AUTOTUNE) # Prepare the next batch while the current one trains

        client_data.append(ds)

        # Report per-client statistics
        print(f"Client {len(client_data)} has {len(image_paths)} images")
        print(f"Labels distribution: {Counter(labels)}")

    return client_data

In [None]:
def create_keras_model_comlex(): #(complex isn't as effective)
    """Build the regularised CNN variant (group normalisation, L2 decay, dropout).

    Returns:
        An uncompiled tf.keras.Sequential taking (224, 224, 1) inputs and
        producing a single sigmoid output in float32.
    """
    # He-normal initialisation (first convolution only) and shared L2 weight decay
    he_init = tf.keras.initializers.HeNormal()
    weight_decay = tf.keras.regularizers.L2(1e-4)

    def group_norm(groups):
        return tfa.layers.GroupNormalization(groups=groups, axis=-1, epsilon=1e-5)

    # Input + first convolutional block (32 filters, no dropout)
    stack = [
        tf.keras.layers.Input(shape=(224, 224, 1)),
        tf.keras.layers.Conv2D(32, 3, padding='same',
                               kernel_initializer=he_init,
                               kernel_regularizer=weight_decay,
                               use_bias=False),
        group_norm(8),
        tf.keras.layers.Activation('swish'),
        tf.keras.layers.MaxPooling2D(),
    ]

    # Blocks 2-4: separable convolutions with 64, 128 and 256 filters
    for filters in (64, 128, 256):
        stack += [
            tf.keras.layers.SeparableConv2D(filters, 3, padding='same',
                                            kernel_regularizer=weight_decay,
                                            use_bias=False),
            group_norm(8),
            tf.keras.layers.Activation('swish'),
            tf.keras.layers.SpatialDropout2D(0.15),
            tf.keras.layers.MaxPooling2D(),
        ]

    # Pool + Dense head
    stack += [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(128,
                              kernel_regularizer=weight_decay,
                              use_bias=False),
        group_norm(4),
        tf.keras.layers.Activation('swish'),
        tf.keras.layers.Dropout(0.7),
        # float32 output avoids dtype mismatches with the loss function
        tf.keras.layers.Dense(1, activation='sigmoid', dtype='float32'),
    ]

    return tf.keras.Sequential(stack)

In [None]:
def create_keras_model(): #(simple)
    """Build the simple CNN: four conv blocks with layer normalisation and a dense head.

    Returns:
        An uncompiled tf.keras.Sequential taking (224, 224, 1) inputs and
        producing a single sigmoid output in float32.
    """
    he_init = tf.keras.initializers.HeNormal()

    # Input + first convolutional block (32 filters, He-normal initialised)
    stack = [
        tf.keras.layers.Input(shape=(224, 224, 1)),
        tf.keras.layers.Conv2D(32, (3, 3), activation='swish', padding='same',
                               kernel_initializer=he_init),
        tf.keras.layers.LayerNormalization(),
        tf.keras.layers.MaxPooling2D(),
    ]

    # Blocks 2-4: separable convolutions with 64, 128 and 256 filters
    for filters in (64, 128, 256):
        stack.extend([
            tf.keras.layers.SeparableConv2D(filters, (3, 3), activation='swish', padding='same'),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.MaxPooling2D(),
        ])

    # Pool + Dense head
    stack.extend([
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(128, activation='swish'),
        tf.keras.layers.Dropout(0.5),
        # float32 output avoids dtype mismatches with the loss function
        tf.keras.layers.Dense(1, activation='sigmoid', dtype='float32'),
    ])

    return tf.keras.Sequential(stack)

In [None]:
def model_fn():
    """Wrap a freshly built Keras model for TensorFlow Federated.

    A new Keras model is constructed on every call via create_keras_model().

    Returns:
        The TFF model produced by tff.learning.models.from_keras_model, with
        binary cross-entropy loss and accuracy / precision / recall / AUC metrics.
    """
    keras_model = create_keras_model()

    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=(
            tf.TensorSpec(shape=(None, 224, 224, 1), dtype=tf.float32),  # Batch of images
            # (None,) is a 1-D batch of labels. The former `(None)` was just
            # None, i.e. a spec of unknown rank rather than a vector.
            tf.TensorSpec(shape=(None,), dtype=tf.int32)                 # Batch of labels
        ),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[
            tf.keras.metrics.BinaryAccuracy(),
            tf.keras.metrics.Precision(name='precision'), # Precision of positive class
            tf.keras.metrics.Recall(name='recall'), # Sensitivity/recall of positives
            tf.keras.metrics.AUC(name='auc') # Area under ROC curve
        ]
    )

In [None]:
# Differential Privacy
#
# Gaussian-mechanism aggregator with a fixed clipping norm: each client update is
# clipped to L2_NORM_CLIP and noise scaled by NOISE_MULTIPLIER is added, so that
# the presence or absence of a single client barely changes the aggregate.
dp_agg = tff.aggregators.DifferentiallyPrivateFactory.gaussian_fixed(
    clip=L2_NORM_CLIP,
    clients_per_round=CLIENT_AMOUNT,  # clients participating in each training round
    noise_multiplier=NOISE_MULTIPLIER,
)

# Weighted-aggregator wrapper around the DP factory
weighted_dp_agg = tff.aggregators.as_weighted_aggregator(dp_agg)

In [None]:
# Build the federated-averaging process and the per-client training datasets.
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    # Factory returning a fresh TFF model on every call
    model_fn=model_fn,

    # Optimiser used by clients for local training
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(
        learning_rate=CLIENT_LR, momentum=0.9,
    ),

    # Optimiser used by the server to update the global model
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(
        learning_rate=SERVER_LR, momentum=0.9,
    ),
    # Aggregation strategy. With None the differentially private aggregator
    # (weighted_dp_agg) is NOT used; pass model_aggregator=weighted_dp_agg to
    # train with differential privacy.
    model_aggregator=None
)

# One augmented dataset per client; note this rebinds train_data from the
# training DataFrame to a list of tf.data datasets.
train_data = create_federated_datasets(balanced_df, CLIENT_AMOUNT, EPOCHS, True)
# Initial server state for the first round
state = iterative_process.initialize()


Client 1 has 400 images
Labels distribution: Counter({1: 206, 0: 194})
Client 2 has 400 images
Labels distribution: Counter({0: 204, 1: 196})
Client 3 has 400 images
Labels distribution: Counter({0: 208, 1: 192})
Client 4 has 400 images
Labels distribution: Counter({1: 206, 0: 194})


In [None]:
# Per-round metric records plus the early-stopping bookkeeping
metrics_history = []
best_auc = 0.0
wait = 0

for round_num in range(1, NUM_ROUNDS + 1):
    # Run one federated round over all client datasets
    state, metrics = iterative_process.next(state, train_data)

    # Pull the aggregated client training metrics for this round
    train_metrics = metrics['client_work']['train']
    loss, acc, prec, rec, auc = (
        train_metrics[key]
        for key in ('loss', 'binary_accuracy', 'precision', 'recall', 'auc')
    )

    # Record the round for later export
    metrics_history.append(dict(
        round=round_num,
        loss=loss,
        accuracy=acc,
        precision=prec,
        recall=rec,
        auc=auc,
    ))

    # Formatted output
    print(f"\nRound {round_num}")
    print("-"*40)
    print(f"Train Metrics:")
    for caption, value in (
        ("Loss:     ", loss),
        ("Accuracy: ", acc),
        ("Precision:", prec),
        ("Recall:   ", rec),
        ("AUC:      ", auc),
    ):
        print(f"{caption}{value:.4f}")

    # Early stopping on training AUC: an improvement resets the counter
    if auc > best_auc + MIN_DELTA:
        best_auc, wait = auc, 0
        continue

    wait += 1
    print(f"Early stopping counter: {wait}/{PATIENCE}")
    if wait >= PATIENCE:
        print("Early stopping triggered.")
        break


Round 1
----------------------------------------
Train Metrics:
Loss:     0.7370
Accuracy: 0.5225
Precision:0.5221
Recall:   0.5312
AUC:      0.5179

Round 2
----------------------------------------
Train Metrics:
Loss:     0.7263
Accuracy: 0.5150
Precision:0.5146
Recall:   0.5275
AUC:      0.5211

Round 3
----------------------------------------
Train Metrics:
Loss:     0.7083
Accuracy: 0.5200
Precision:0.5196
Recall:   0.5300
AUC:      0.5299

Round 4
----------------------------------------
Train Metrics:
Loss:     0.7023
Accuracy: 0.5156
Precision:0.5175
Recall:   0.4625
AUC:      0.5392

Round 5
----------------------------------------
Train Metrics:
Loss:     0.6949
Accuracy: 0.5369
Precision:0.5369
Recall:   0.5362
AUC:      0.5598

Round 6
----------------------------------------
Train Metrics:
Loss:     0.6941
Accuracy: 0.5500
Precision:0.5532
Recall:   0.5200
AUC:      0.5617

Round 7
----------------------------------------
Train Metrics:
Loss:     0.6883
Accuracy: 0.5569
P

In [None]:

# Directory to save data
directory = '/content/drive/My Drive/your_metric_folder/' # <-- Replace with actual path
# exist_ok=True creates the folder only when missing, without a separate
# check-then-create step.
os.makedirs(directory, exist_ok=True)

# Save the per-round training metrics to CSV
metrics_df = pd.DataFrame(metrics_history)
metrics_df.to_csv(os.path.join(directory, 'Complex_Model_with_dp(2).csv'), index=False) # Change depending on use of differential privacy



In [None]:
# Extract final model weights
final_model_weights = iterative_process.get_model_weights(state)

# Create a fresh Keras model with the same architecture used for training
keras_eval_model = create_keras_model()

# Compile it for evaluation (the optimiser is required by compile but not used)
keras_eval_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=tf.keras.optimizers.SGD(),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc')
    ]
)

# Load in the federated-trained weights
final_model_weights.assign_weights_to(keras_eval_model)


# Isolate test data 'not balanced'
test_df = x_ray_info[x_ray_info['in_data'] == 'Test'].reset_index(drop=True)

# Apply preprocessing and create a single dataset. The test set is evaluated
# exactly once (epochs=1) and WITHOUT augmentation, so the reported metrics
# describe the real images rather than randomly perturbed copies.
test_dataset = create_federated_datasets(test_df, 1, 1, False)[0]

# Evaluate on the test data
print("-"*40)
results = keras_eval_model.evaluate(test_dataset)
print(f"Loss:          {results[0]:.4f}")
print(f"Accuracy:      {results[1]:.4f}")
print(f"Precision:     {results[2]:.4f}")
print(f"Recall:        {results[3]:.4f}")
print(f"AUC:           {results[4]:.4f}")

# Create data frame to store test results (same order as the compiled metrics)
results_df = pd.DataFrame({
    'metric': ['Loss', 'Accuracy', 'Precision', 'Recall', 'AUC'],
    'value': results
})

# Store test Results in csv
results_df.to_csv(os.path.join(directory, 'test_results_with_dp(2).csv'), index=False) # Change depending on use of differential privacy


Client 1 has 338 images
Labels distribution: Counter({0: 188, 1: 150})
----------------------------------------
Loss:          0.6010
Accuracy:      0.6834
Precision:     0.6215
Recall:        0.7333
AUC:           0.7809
