<a href="https://colab.research.google.com/github/UmarBalak/Federated-Learning-with-TensorFlow/blob/main/FL_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow_federated

# Data Preprocessing

In [None]:
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models

# Fetch the MNIST train/test arrays through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

# Give every image an explicit single-channel axis, then convert to float32
# and divide by 255.
x_train = x_train.reshape((60000, 28, 28, 1))
x_train = x_train.astype('float32') / 255
x_test = x_test.reshape((10000, 28, 28, 1))
x_test = x_test.astype('float32') / 255

# Split the training data into four equal, contiguous parts; each part is
# later used for its own training phase.
x_train1, x_train2, x_train3, x_train4 = np.split(x_train, 4)
y_train1, y_train2, y_train3, y_train4 = np.split(y_train, 4)

# Number of simulated clients each training part is sharded across.
NUM_CLIENTS = 10

def split_data_among_clients(x_data, y_data, num_clients=None):
    """Partition features and labels into equal, contiguous per-client shards.

    Args:
        x_data: Sliceable sequence of features (anything supporting ``len``
            and ``[start:end]`` slicing).
        y_data: Sliceable sequence of labels, aligned index-for-index with
            ``x_data``.
        num_clients: Number of shards to produce. Defaults to the
            module-level ``NUM_CLIENTS`` when omitted.

    Returns:
        A list of ``num_clients`` ``(x_shard, y_shard)`` tuples. Every shard
        holds ``len(x_data) // num_clients`` samples; any remainder at the
        end of the input is not assigned to a client.

    Raises:
        ValueError: If ``num_clients`` is less than 1.
    """
    if num_clients is None:
        # Resolved at call time so later changes to the global are honoured.
        num_clients = NUM_CLIENTS
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    data_per_client = len(x_data) // num_clients
    client_data = []
    for client_index in range(num_clients):
        start = client_index * data_per_client
        end = start + data_per_client
        client_data.append((x_data[start:end], y_data[start:end]))
    return client_data

def create_tf_dataset_for_client(data):
    """Turn one client's ``(x, y)`` pair into a ``tf.data.Dataset`` batched by 32.

    Args:
        data: Two-element ``(features, labels)`` pair; both are sliced along
            their first axis by ``from_tensor_slices``.

    Returns:
        A dataset yielding ``(features, labels)`` batches of up to 32 samples.
    """
    features, labels = data
    per_sample_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return per_sample_dataset.batch(32)

# Shard each of the four training parts across the clients, then wrap every
# client shard in its own batched dataset.
client_data1 = split_data_among_clients(x_train1, y_train1)
federated_train_data1 = list(map(create_tf_dataset_for_client, client_data1))

client_data2 = split_data_among_clients(x_train2, y_train2)
federated_train_data2 = list(map(create_tf_dataset_for_client, client_data2))

client_data3 = split_data_among_clients(x_train3, y_train3)
federated_train_data3 = list(map(create_tf_dataset_for_client, client_data3))

client_data4 = split_data_among_clients(x_train4, y_train4)
federated_train_data4 = list(map(create_tf_dataset_for_client, client_data4))

# Functions and models

In [None]:
def create_keras_model():
    """Build the (uncompiled) CNN classifier used by every client.

    The network takes ``(28, 28, 1)`` inputs and stacks two 3x3 convolutions
    (with one 2x2 max-pool between them), a 32-unit dense layer and a
    10-way softmax output.

    Returns:
        A fresh ``Sequential`` model instance.
    """
    return models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(32, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])

class MulticlassMetrics(tf.keras.metrics.Metric):
    """Macro-averaged precision, recall and F1 for sparse multiclass labels.

    True positives, false positives and false negatives are accumulated
    separately for every class, so the three counts are always consistent
    with each other. (Previously ``tp`` counted correct predictions of *all*
    classes while ``fp``/``fn`` only counted errors involving class 1, which
    made the derived precision/recall/F1 meaningless.)

    Args:
        name: Metric name; also the key under which results are reported.
        num_classes: Number of classes; labels and predictions are expected
            to lie in ``[0, num_classes)``. Values outside that range one-hot
            encode to all zeros and therefore do not contribute to any count.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name='multiclass_metrics', num_classes=10, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        # One counter per class for each confusion-matrix quantity.
        self.tp = self.add_weight(name='tp', shape=(num_classes,), initializer='zeros')
        self.fp = self.add_weight(name='fp', shape=(num_classes,), initializer='zeros')
        self.fn = self.add_weight(name='fn', shape=(num_classes,), initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate per-class TP/FP/FN counts for one batch.

        Args:
            y_true: Integer class ids, one per example.
            y_pred: Per-class scores with the class axis last; the predicted
                class is the arg-max over that axis.
            sample_weight: Optional per-example (or scalar) weights applied
                to every count.
        """
        # Flatten both to 1-D so a (batch, 1) label tensor cannot broadcast
        # against (batch,) predictions into a (batch, batch) comparison.
        y_pred = tf.reshape(tf.argmax(y_pred, axis=-1), [-1])
        y_true = tf.reshape(tf.cast(y_true, 'int64'), [-1])

        true_one_hot = tf.one_hot(y_true, self.num_classes, dtype='float32')
        pred_one_hot = tf.one_hot(y_pred, self.num_classes, dtype='float32')

        # Column c: tp = predicted c and is c; fp = predicted c but is not c;
        # fn = is c but predicted something else.
        tp = true_one_hot * pred_one_hot
        fp = pred_one_hot - tp
        fn = true_one_hot - tp

        if sample_weight is not None:
            # Shape (N, 1) (or (1, 1) for a scalar) so it broadcasts over classes.
            weights = tf.reshape(tf.cast(sample_weight, 'float32'), [-1, 1])
            tp = tf.multiply(tp, weights)
            fp = tf.multiply(fp, weights)
            fn = tf.multiply(fn, weights)

        self.tp.assign_add(tf.reduce_sum(tp, axis=0))
        self.fp.assign_add(tf.reduce_sum(fp, axis=0))
        self.fn.assign_add(tf.reduce_sum(fn, axis=0))

    def result(self):
        """Return ``[precision, recall, f1_score]``, each averaged over classes.

        A class with no predictions (or no true examples) contributes 0 to
        the corresponding average; epsilon keeps the divisions finite.
        """
        eps = tf.keras.backend.epsilon()
        per_class_precision = tf.divide(self.tp, self.tp + self.fp + eps)
        per_class_recall = tf.divide(self.tp, self.tp + self.fn + eps)
        per_class_f1 = tf.divide(2.0 * self.tp, 2.0 * self.tp + self.fp + self.fn + eps)

        precision = tf.reduce_mean(per_class_precision)
        recall = tf.reduce_mean(per_class_recall)
        f1_score = tf.reduce_mean(per_class_f1)
        return [precision, recall, f1_score]

    def reset_state(self):
        """Zero all accumulated counts."""
        for counter in (self.tp, self.fp, self.fn):
            counter.assign(tf.zeros_like(counter))

    # Older Keras releases call ``reset_states``; newer ones call
    # ``reset_state``. Expose both names.
    reset_states = reset_state

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config['num_classes'] = self.num_classes
        return config

def model_fn():
    """Build a fresh TFF model wrapping a new Keras CNN.

    A new Keras model, loss and metric objects are created on every call.
    The input spec is taken from the first client dataset of training part 1.

    Returns:
        The object produced by ``tff.learning.models.from_keras_model``.
    """
    batch_spec = federated_train_data1[0].element_spec
    tracked_metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(),
        MulticlassMetrics(),
    ]
    return tff.learning.models.from_keras_model(
        create_keras_model(),
        input_spec=batch_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=tracked_metrics,
    )

# FL Training

In [None]:
# Build the weighted federated-averaging training process. Both the client
# and the server side use Adam with a learning rate of 0.001.
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=0.001),
    server_optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=0.001),
)

# Create the initial state that every training round starts from.
state = iterative_process.initialize()

In [None]:
# Number of federated training rounds run on each training-data part.
NUM_ROUNDS = 3

def evaluate_model(iterative_process, state, model_fn, x_test, y_test):
    """Evaluate the current federated model on the whole test set and print the results.

    A new federated evaluation process is built on each call, loaded with the
    model weights extracted from ``state``, and run once with the entire test
    set presented as a single client dataset.

    Args:
        iterative_process: Training process used to extract the model weights.
        state: Current training state holding those weights.
        model_fn: No-argument model constructor for the evaluation process.
        x_test: Test features.
        y_test: Test labels.

    Returns:
        The ``current_round_metrics`` structure of the evaluation output.
    """
    eval_process = tff.learning.algorithms.build_fed_eval(model_fn)
    trained_weights = iterative_process.get_model_weights(state)
    eval_state = eval_process.set_model_weights(eval_process.initialize(), trained_weights)

    # The full test set acts as the data of one lone "client".
    single_client_test_data = [create_tf_dataset_for_client((x_test, y_test))]
    eval_output = eval_process.next(eval_state, single_client_test_data)

    client_metrics = eval_output.metrics['client_work']['eval']['current_round_metrics']
    multiclass_values = client_metrics['multiclass_metrics']
    round_loss = client_metrics['loss']
    round_accuracy = client_metrics['sparse_categorical_accuracy']
    round_precision = multiclass_values[0]
    round_recall = multiclass_values[1]
    round_f1_score = multiclass_values[2]

    print(f'Evaluation on entire test set: Loss={round_loss:.4f}, Accuracy={round_accuracy:.4f}, Precision={round_precision:.4f}, Recall={round_recall:.4f}, F1 Score={round_f1_score:.4f}')

    return client_metrics


# Train on the four data parts one after another, carrying `state` forward
# from part to part. Each part gets NUM_ROUNDS federated rounds followed by
# an evaluation on the full test set. A single loop replaces four copies of
# the same code that differed only in the dataset and the part number.
training_parts = (
    federated_train_data1,
    federated_train_data2,
    federated_train_data3,
    federated_train_data4,
)

for part_num, federated_train_data in enumerate(training_parts, start=1):
    for round_num in range(NUM_ROUNDS):
        state, metrics = iterative_process.next(state, federated_train_data)
        train_metrics = metrics["client_work"]["train"]
        print(f'Round {round_num+1} on part {part_num}: Loss={train_metrics["loss"]:.4f}, Accuracy={train_metrics["sparse_categorical_accuracy"]:.4f}')

    eval_metrics = evaluate_model(iterative_process, state, model_fn, x_test, y_test)
