Federated learning (FL) integration to enable distributed training of the model while ensuring data privacy and security.

In [6]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.utils import to_categorical
import numpy as np

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = np.expand_dims(x_train, axis=-1), np.expand_dims(x_test, axis=-1)
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)

# Split dataset among 3 clients
def split_data(data, targets, num_clients):
    size = len(data) // num_clients
    return [(data[i*size:(i+1)*size], targets[i*size:(i+1)*size]) for i in range(num_clients)]

clients_data = split_data(x_train, y_train, 3)

# Build a CNN model
def create_cnn_model():
    model = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(10, activation='softmax')
    ])
    # Compile the model
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.CategoricalCrossentropy(),
                  metrics=['accuracy'])
    return model

# Initialize client models
client_models = [create_cnn_model() for _ in range(3)]
optimizers = [tf.keras.optimizers.Adam() for _ in range(3)]
loss_fn = tf.keras.losses.CategoricalCrossentropy()

# Federated averaging
def federated_averaging(models):
    global_weights = [model.get_weights() for model in models]
    new_weights = [np.mean([client_weights[layer] for client_weights in global_weights], axis=0)
                   for layer in range(len(global_weights[0]))]
    for model in models:
        model.set_weights(new_weights)

# Define metrics
metrics = {
    'precision': tf.keras.metrics.Precision(),
    'recall': tf.keras.metrics.Recall()
}

def calculate_f1_score(precision, recall):
    if precision + recall == 0:
        return 0
    return 2 * (precision * recall) / (precision + recall)

# Define train_model function
def train_model(model, optimizer, x, y, batch_size=32):
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
    epoch_loss_avg = tf.keras.metrics.Mean()
    for batch_x, batch_y in dataset:
        with tf.GradientTape() as tape:
            predictions = model(batch_x, training=True)
            loss = loss_fn(batch_y, predictions)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        epoch_loss_avg.update_state(loss)
    return epoch_loss_avg.result()

# Training loop
global_epochs = 100
for epoch in range(global_epochs):
    losses = []
    for client_data, model, optimizer in zip(clients_data, client_models, optimizers):
        x, y = client_data
        loss = train_model(model, optimizer, x, y)
        losses.append(loss)

    federated_averaging(client_models)

    # Evaluate the averaged model
    y_pred = []
    y_true = []
    for x, y in tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32):
        predictions = client_models[0](x, training=False)
        y_pred.extend(tf.argmax(predictions, axis=1))
        y_true.extend(tf.argmax(y, axis=1))

    for name, metric in metrics.items():
        metric.update_state(y_true, y_pred)

    precision = metrics['precision'].result().numpy()
    recall = metrics['recall'].result().numpy()
    f1 = calculate_f1_score(precision, recall)
    average_loss = np.mean(losses)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), dtype=tf.float32)).numpy()

    print(f"Global Epoch {epoch+1}: Loss = {average_loss}, Accuracy = {accuracy}, Precision = {precision}, Recall = {recall}, F1-Score = {f1}")

    # Reset metrics for the next epoch
    for metric in metrics.values():
        metric.reset_states()

# Test final model performance
final_loss, final_accuracy = client_models[0].evaluate(x_test, y_test)
print(f"Final Model: Loss = {final_loss}, Accuracy = {final_accuracy}")


Global Epoch 1: Loss = 0.3725675046443939, Accuracy = 0.6402999758720398, Precision = 0.9137791395187378, Recall = 0.9998891353607178, F1-Score = 0.9548967536499223
Global Epoch 2: Loss = 0.2871287763118744, Accuracy = 0.9745000004768372, Precision = 0.9990014433860779, Recall = 0.9982261657714844, F1-Score = 0.9986136411928881
Global Epoch 3: Loss = 0.1315055936574936, Accuracy = 0.9811000227928162, Precision = 0.9994440674781799, Recall = 0.996563196182251, F1-Score = 0.9980015793382141
Global Epoch 4: Loss = 0.09795280545949936, Accuracy = 0.9853000044822693, Precision = 0.999444305896759, Recall = 0.9970066547393799, F1-Score = 0.9982240274993605
Global Epoch 5: Loss = 0.08160554617643356, Accuracy = 0.9872999787330627, Precision = 0.9995555281639099, Recall = 0.9973392486572266, F1-Score = 0.9984461382036166
Global Epoch 6: Loss = 0.0678454339504242, Accuracy = 0.9879999756813049, Precision = 0.9994446635246277, Recall = 0.9976718425750732, F1-Score = 0.9985574548075727
Global Epo