In [1]:
import os
import random
import warnings

# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when it is first imported, so the variable must be
# set *before* `import tensorflow` below; setting it afterwards does not silence start-up logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# NOTE(review): a blanket ignore also hides useful deprecation warnings (e.g. NumPy's
# ragged-array warning); consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
from itertools import islice
from IPython.display import display

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from keras.utils import to_categorical

# Seed every RNG the notebook touches so "Restart & Run All" gives repeatable results
# (GPU kernels may still introduce some nondeterminism).
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [2]:
from sklearn.datasets import fetch_openml

# Ask for pandas output explicitly: the next cell calls `.values` on `mnist.data` and
# `mnist.target`, which needs DataFrame/Series results (older scikit-learn defaulted to
# NumPy arrays). Pinning version=1 keeps the download reproducible if OpenML ever
# publishes another active version of this dataset.
mnist = fetch_openml(name='mnist_784', version=1, as_frame=True)

In [3]:
# One row per image (784 pixel columns for mnist_784) and one integer class label per image.
all_images = mnist.data.to_numpy()
nr_images = len(all_images)
all_labels = mnist.target.astype(int).to_numpy()
# Divide by 255 to rescale pixel values into [0, 1] (assumes raw values in 0-255).
all_images = all_images / 255.0

### Split data into train and test sets

Hold out 30% of the shuffled images as a common test set for every client and for the global model.

In [4]:
# Shuffle the dataset and hold out 30% of it as a test set shared by all clients and the
# global model; the fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    all_images, all_labels, test_size=0.3, shuffle=True, random_state=1337
)

## Create clients with an even data distribution

The training set is split into 8 equal, consecutive shards (12.5% each), one per client.

In [12]:
NR_CLIENTS = 8

# Every client receives an equal 1/NR_CLIENTS share of the training data
# (0.125 each for 8 clients).
distribution = [1 / NR_CLIENTS] * NR_CLIENTS
clients = create_clients(X_train, y_train, distribution)
# Optional: run plot_client_data(clients) to inspect each client's label histogram.

Metal device set to: Apple M1 Max


## Train Clients Individually

In [14]:
NR_LOCAL_CYCLES = 40

# Baseline: each client trains only on its own shard, with no weight sharing.
for nr_cycles in range(1, NR_LOCAL_CYCLES + 1):
    for client in clients:
        fit_model_to_data(client.model, client.data, client.labels, n_classes=10, epochs=2)

# Score each stand-alone client model now: the federated simulation below overwrites every
# client's weights with the global model's, so this is the only point where they can be measured.
for client_nr, client in enumerate(clients, start=1):
    score = evaluate_model(client.model, X_test, y_test)
    print("Client", client_nr, "stand-alone weighted F1:", score)

## Federated Learning Simulation

In [15]:
# Create global model
global_model = create_simple_model()

NR_FL_CYCLES = 40   # communication rounds
LOCAL_EPOCHS = 2    # local training epochs per client per round


def federated_average(weight_sets, scaling_factors):
    """Combine several clients' model weights into one weight list (FedAvg).

    Parameters
    ----------
    weight_sets : list of list of np.ndarray
        One ``model.get_weights()`` result per client.
    scaling_factors : list of float
        Per-client multiplier, in the same order as ``weight_sets``.

    Returns
    -------
    list of np.ndarray
        For every layer, ``sum_k scaling_factors[k] * weight_sets[k][layer]``.

    The layers have different shapes (kernels vs. biases), so they are combined
    layer by layer; wrapping a whole weight list in ``np.array`` would build a
    ragged array, which NumPy >= 1.24 rejects with a ValueError.
    """
    return [
        sum(factor * layer for factor, layer in zip(scaling_factors, layers))
        for layers in zip(*weight_sets)
    ]


assert clients, "create the clients before running the simulation"

# Weight each client by its share of the samples that actually took part in training.
# Summing the shard sizes (instead of using len(X_train)) keeps the factors summing to 1
# even if chunking left some training samples unassigned. The factors never change
# between rounds, so they are computed once.
global_amount_data = sum(len(c.data) for c in clients)
assert global_amount_data > 0, "clients hold no training data"
scaling_factor_clients = [len(c.data) / global_amount_data for c in clients]

for nr_cycles in range(1, NR_FL_CYCLES + 1):
    # Simulate sending the current global model to every client.
    global_weights = global_model.get_weights()
    for client in clients:
        client.model.set_weights(global_weights)

    # Train clients on their own data.
    for client in clients:
        fit_model_to_data(
            client.model,
            client.data,
            client.labels,
            n_classes=10,
            epochs=LOCAL_EPOCHS,
        )

    # Simulate the clients sending their weights back, then aggregate them centrally.
    client_weights = [client.model.get_weights() for client in clients]
    new_global_weights = federated_average(client_weights, scaling_factor_clients)
    global_model.set_weights(new_global_weights)

    # evaluate_model returns a weighted F1 score, not plain accuracy.
    acc = evaluate_model(global_model, X_test, y_test)
    print("Cycle ", nr_cycles, " complete. Global model weighted F1:", acc)

Cycle  1  complete. Global model accuracy: 0.9390023594858363
Cycle  2  complete. Global model accuracy: 0.9630327515530116
Cycle  3  complete. Global model accuracy: 0.9687151481600054
Cycle  4  complete. Global model accuracy: 0.972371244901408
Cycle  5  complete. Global model accuracy: 0.9743610367917538
Cycle  6  complete. Global model accuracy: 0.9760370577577574
Cycle  7  complete. Global model accuracy: 0.9778896749638151
Cycle  8  complete. Global model accuracy: 0.9786127346195561
Cycle  9  complete. Global model accuracy: 0.9790854221334874
Cycle  10  complete. Global model accuracy: 0.9796523549796179
Cycle  11  complete. Global model accuracy: 0.9801385526045032
Cycle  12  complete. Global model accuracy: 0.9807075511931114
Cycle  13  complete. Global model accuracy: 0.9809919903482788
Cycle  14  complete. Global model accuracy: 0.9806596685579072
Cycle  15  complete. Global model accuracy: 0.9803727123951217
Cycle  16  complete. Global model accuracy: 0.9813259780501161
Cy

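# Helper functions

**Note:** the cells in this section must be executed before the cells above that call them (they were originally run as `In [6]`–`In [11]`). Move this section above "Create clients" so that *Restart Kernel → Run All* works on a fresh kernel.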

In [6]:
class Client:
    """One simulated federated-learning participant.

    Holds the client's local training shard (samples and matching labels) and
    the model it trains. ``model`` starts out as ``None`` and is attached after
    construction.
    """

    def __init__(self, data, labels):
        # Local training samples and their matching labels.
        self.data, self.labels = data, labels
        # Placeholder until a model is assigned to this client.
        self.model = None

In [7]:
def create_simple_model(n_features=784, n_hidden=784, n_classes=10):
    """Build and compile a one-hidden-layer dense classifier.

    Parameters
    ----------
    n_features : int, default 784
        Length of each flattened input sample (784 for mnist_784).
    n_hidden : int, default 784
        Number of ReLU units in the hidden layer.
    n_classes : int, default 10
        Number of output classes (softmax units).

    Returns
    -------
    Sequential
        A compiled model using ``categorical_crossentropy`` loss, so labels must be
        one-hot encoded. Each call compiles with its own ``'adam'`` optimizer.
    """
    model = Sequential(name="MNIST_Classifier")
    model.add(Dense(n_hidden, input_shape=(n_features,), activation="relu"))
    model.add(Dense(n_classes, activation="softmax"))
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam'
    )
    return model

In [8]:
def fit_model_to_data(model, X, y, n_classes, epochs, batch_size=32, verbose=0):
    """Train ``model`` on ``X`` with integer class labels ``y``.

    The labels are one-hot encoded with ``to_categorical`` so they match a model
    compiled with ``categorical_crossentropy`` (as ``create_simple_model`` does).

    Parameters
    ----------
    model : compiled Keras model
        Updated by ``model.fit``; nothing is returned.
    X : array-like
        Training samples.
    y : array-like of int
        Class labels in ``[0, n_classes)``.
    n_classes : int
        Width of the one-hot encoding.
    epochs : int
        Number of passes over ``X``.
    batch_size : int, default 32
        Mini-batch size passed to ``model.fit``.
    verbose : int, default 0
        Keras verbosity; 0 keeps training silent.
    """
    y_one_hot = to_categorical(y, n_classes)
    model.fit(
        X,          # Samples
        y_one_hot,  # Labels
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose
    )

def evaluate_model(model, X_test, y_test, average='weighted'):
    """Score ``model`` on a held-out set with the F1 metric.

    Note that this returns an F1 score (support-weighted by default), not plain
    accuracy.

    Parameters
    ----------
    model : Keras model
        Model whose ``predict`` returns one row of per-class scores per sample.
    X_test : array-like
        Evaluation samples.
    y_test : array-like of int
        True class labels for ``X_test``.
    average : str, default 'weighted'
        Averaging mode forwarded to ``sklearn.metrics.f1_score``.

    Returns
    -------
    float
        The F1 score of the arg-max predictions.
    """
    # verbose=0 stops Keras from printing a progress bar on every evaluation.
    class_scores = model.predict(X_test, verbose=0)
    label_predictions = np.argmax(class_scores, axis=1)
    return f1_score(y_test, label_predictions, average=average)

In [9]:
def chunk_array(array, distribution):
    """Split ``array`` into consecutive chunks sized by ``distribution``.

    Each entry of ``distribution`` is the fraction of ``len(array)`` assigned to
    the corresponding chunk. Chunks are taken in order from the start of the
    array, without shuffling.

    Chunk boundaries come from the *cumulative* fractions rounded to the nearest
    index. Floating-point error (e.g. ``0.29 * 100 == 28.999...``) therefore no
    longer drops samples, and a distribution summing to 1 assigns every sample.
    Fractions summing to less than 1 leave the tail unused. Fractions summing to
    more than 1 are clipped at the end of the array, so later chunks come out
    shorter or empty.

    Returns
    -------
    list of np.ndarray
        One independent copy per entry of ``distribution``.
    """
    values = np.asarray(array)
    n_items = len(values)
    # End index of every chunk, clipped so an over-allocating distribution cannot overrun.
    bounds = np.clip(np.rint(np.cumsum(distribution) * n_items).astype(int), 0, n_items)
    starts = np.concatenate(([0], bounds[:-1]))
    return [values[start:end].copy() for start, end in zip(starts, bounds)]

def chunk_data(data, labels, distribution):
    """Split samples and labels into aligned per-client chunks.

    Parameters
    ----------
    data, labels : array-like
        Samples and their labels; must have the same length so that chunk ``i``
        of ``data`` lines up with chunk ``i`` of ``labels``.
    distribution : sequence of float
        Fraction of the dataset per chunk, passed to ``chunk_array``.

    Returns
    -------
    list of (data_chunk, label_chunk) tuples
        One tuple per entry of ``distribution``.

    Raises
    ------
    ValueError
        If ``data`` and ``labels`` have different lengths (the chunks would
        otherwise be silently misaligned).
    """
    if len(data) != len(labels):
        raise ValueError(
            "data and labels must have the same length, got {} and {}".format(
                len(data), len(labels)
            )
        )
    data_chunks = chunk_array(data, distribution)
    label_chunks = chunk_array(labels, distribution)
    return list(zip(data_chunks, label_chunks))

In [10]:
def create_clients(data, labels, distribution):
    """Split ``data``/``labels`` into shards and wrap each shard in a ``Client``.

    Parameters
    ----------
    data, labels : array-like
        Training samples and matching labels to distribute across clients.
    distribution : sequence of float
        Fraction of the dataset given to each client (one entry per client),
        passed through to ``chunk_data``.

    Returns
    -------
    list of Client
        One client per entry of ``distribution``, each holding its own freshly
        created model from ``create_simple_model``.
    """
    # Use the function's arguments, not the X_train/y_train globals, so the
    # function actually splits whatever dataset the caller passes in.
    client_data = chunk_data(data, labels, distribution)
    clients = []

    for client_samples, client_labels in client_data:
        # Give a client its portion of data.
        new_client = Client(client_samples, client_labels)

        # Each client gets its own model; its weights are replaced with the
        # global model's weights during the federated simulation.
        new_client.model = create_simple_model()

        clients.append(new_client)
    return clients

In [11]:
def plot_client_data(clients, ncols=4, ymax=4500):
    """Draw one label histogram per client in a grid of subplots.

    Parameters
    ----------
    clients : list of Client
        Clients whose ``labels`` are plotted, one panel per client.
    ncols : int, default 4
        Maximum number of panels per row; rows are added as needed, so any
        number of clients fits (the original layout was a fixed 2x4 grid).
    ymax : float, default 4500
        Shared upper y-limit so the panels are directly comparable.
    """
    if not clients:
        return
    ncols = min(ncols, len(clients))
    nrows = (len(clients) + ncols - 1) // ncols
    # One explicit axes per panel; squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 7), squeeze=False)
    flat_axes = axes.ravel()

    for i, (client, ax) in enumerate(zip(clients, flat_axes)):
        sns.histplot(client.labels, bins=10, discrete=True, ax=ax)
        # Axis settings are applied after plotting so seaborn cannot override them.
        ax.set_title('Owner {}'.format(i + 1))
        ax.set_ylim([0, ymax])
        ax.set_xlim([-1, 10])
        ax.set_ylabel(" ")
        ax.set_yticks([])
        ax.set_xticks(list(np.arange(0, 10)))

    # Hide grid cells left empty when the client count does not fill the last row.
    for ax in flat_axes[len(clients):]:
        ax.set_visible(False)

    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.9,
                        top=0.9,
                        wspace=0.1,
                        hspace=0.3)