# Federated Machine Learning

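This demo illustrates federated learning on the Fashion-MNIST dataset. A central server first pre-trains a model on its own shard of the training data. It then coordinates several clients, each of which trains locally on its own separate shard.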

In [1]:
import keras
from util import *
from federated import CentralServer

In [2]:
# Load Fashion-MNIST. keras ships it already split into a training partition
# and a held-out test partition; both are unpacked into (images, labels) pairs.
dataset = keras.datasets.fashion_mnist
train_split, test_split = dataset.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [3]:
# Seed Python's `random`, NumPy and the Keras backend in one call so reruns are
# repeatable. Model weight initialisation (and any shuffling inside split_data)
# is presumably random. NOTE: some GPU kernels can still be nondeterministic.
SEED = 42
keras.utils.set_random_seed(SEED)

# Split the training data into num_clients + 1 shards.
# Shard 0 stays on the central server for pre-training; the remaining
# num_clients shards go to the clients, one shard each.
num_clients = 5
shards = split_data(X_train, y_train, num_clients + 1)
server_X, server_y = shards[0]

# Create one client per remaining shard. create_model_fn is presumably used to
# build each client's local model (confirm in util.create_clients).
clients = create_clients(shards=shards[1:], create_model_fn=build_and_compile_simple_model)

In [4]:
# Split the original Fashion-MNIST test partition 50/50 into a validation half
# and a test half.
# Re-read the untouched partition from `dataset` instead of slicing X_test/y_test
# in place. Slicing the already-rebound names would halve the test set again
# (and shift X_valid) every time this cell is re-run.
_, (X_holdout, y_holdout) = dataset.load_data()

# NOTE(review): this is an unshuffled split. It assumes the classes are spread
# roughly evenly through the test partition; consider a stratified split otherwise.
half = len(y_holdout) // 2
X_valid, y_valid = X_holdout[:half], y_holdout[:half]
X_test, y_test = X_holdout[half:], y_holdout[half:]


In [5]:
# Pre-train the server model on the server's own shard to obtain the initial global weights
SERVER_PRETRAIN_EPOCHS = 10

server_model = build_and_compile_simple_model()
# Store the History object rather than leaving fit() as the cell's last
# expression. This keeps its bare repr out of the output and lets later cells
# plot the pre-training curves (server_pretrain_history.history).
server_pretrain_history = server_model.fit(server_X, server_y, epochs=SERVER_PRETRAIN_EPOCHS)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x24847832be0>

In [6]:
# Wrap the pre-trained server model and the clients in a CentralServer, then run
# the federated rounds. The log below shows every client fitting its local model
# in each server iteration.
FEDERATED_ROUNDS = 20
CLIENT_EPOCHS = 3


def evaluate_on_test(model):
    """Evaluate ``model`` on the test half of the held-out data.

    Passed to ``server.train`` as ``evaluate_fn``. It is presumably called on
    the aggregated model after each round (confirm in federated.py).
    """
    return model.evaluate(X_test, y_test)


server = CentralServer(server_model, clients, client_epochs=CLIENT_EPOCHS)
server.train(FEDERATED_ROUNDS, evaluate_fn=evaluate_on_test)


Server iteration 0
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Fitting local model for client client_5
Server iteration 1
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Fitting local model for client client_5
Server iteration 2
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Fitting local model for client client_5
Server iteration 3
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Fitting local model for client client_5
Server iteration 4
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for c

In [7]:
# Final score on the validation half (X_valid, y_valid). No earlier cell passes
# this data to fit() or to the federated training loop.
# NOTE(review): assumes CentralServer updates server_model's weights in place
# during server.train(). If it keeps its own copy, evaluate that model instead.
# NOTE(review): unpacking assumes the model was compiled with exactly one
# metric (accuracy), so evaluate() returns [loss, accuracy]. Confirm in util.
valid_scores = server_model.evaluate(X_valid, y_valid)
loss, accuracy = valid_scores
print(f" accuracy: {accuracy} | loss: {loss}")

 accuracy: 0.870199978351593 | loss: 0.36102667450904846
