# Federated Machine Learning

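This demo illustrates the federated learning algorithm on the Fashion-MNIST dataset: the training data is split between a central server and several clients, each client fits a local model on its own shard, and the server runs the training iterations.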

In [1]:
import keras
from util import *
from server import FederatedServer




In [2]:
# Fetch Fashion-MNIST through Keras. load_data() hands back two pairs,
# (train inputs, train labels) and (test inputs, test labels), which are
# unpacked into the names used by the rest of the notebook.
dataset = keras.datasets.fashion_mnist
train_split, test_split = dataset.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [3]:
# Shuffle the training data.
# NOTE(review): the return value of shuffle_together is discarded, so it
# presumably shuffles X_train and y_train in place while keeping each
# input/label pair aligned — confirm in util. No random seed is set in this
# notebook, so the shard contents may differ from run to run.
shuffle_together(X_train, y_train)

# Split the training data between the server and the clients.
# One ratio per shard: the first entry belongs to the server, the remaining
# four to the clients. Presumably each ratio is a fraction of the training
# set (the values sum to 1.0) — confirm against split_data.
shard_ratios = [0.2, 0.3, 0.1, 0.2, 0.2]
shards = split_data(X_train, y_train, shard_ratios)
# Keep the first shard on the server. In this notebook it is only used as the
# evaluation data handed to server.train(); no pre-training step is visible.
X_server, y_server = shards[0]

# Create one client per remaining shard; each client builds its own local
# model through create_model_fn.
clients = create_clients(shards=shards[1:], create_model_fn=build_and_compile_simple_model)





In [4]:
# Federated-training settings, named so they can be tuned in one place and so
# the comments below cannot drift away from the values actually used.
NUM_SERVER_ITERATIONS = 10  # first argument to server.train()
BATCH_SIZE = 10             # forwarded to FederatedServer as batch_size
CLIENT_EPOCHS = 1           # forwarded to FederatedServer as client_epochs

# Create the model handed to the server. It is evaluated on the test split in
# the last cell, so FederatedServer presumably updates it in place — confirm
# in server.py.
server_model = build_and_compile_simple_model()

# Create the central server and run federated learning for
# NUM_SERVER_ITERATIONS iterations. evaluate_fn scores a model on the
# server's own shard (X_server, y_server).
server = FederatedServer(
    server_model,
    clients,
    batch_size=BATCH_SIZE,
    client_epochs=CLIENT_EPOCHS,
)
server.train(
    NUM_SERVER_ITERATIONS,
    evaluate_fn=lambda model: model.evaluate(X_server, y_server),
)

Server iteration 0
Fitting local model for client client_1


Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Server iteration 1
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Server iteration 2
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Server iteration 3
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Server iteration 4
Fitting local model for client client_1
Fitting local model for client client_2
Fitting local model for client client_3
Fitting local model for client client_4
Server iteration 5
Fitting local model for client client_1
Fitting local model for client client_2
Fitt

In [5]:
# Score the final server model on the held-out test split. X_test / y_test
# come straight from load_data() and were never part of the shards given to
# the server or the clients.
test_metrics = server_model.evaluate(X_test, y_test)

# Unpacking into two names assumes evaluate() yields exactly (loss, accuracy),
# which depends on how build_and_compile_simple_model compiles the model.
loss, accuracy = test_metrics
print(f" accuracy: {accuracy} | loss: {loss}")

 accuracy: 0.859000027179718 | loss: 0.3907892107963562
