In [None]:
!pip install tensorflow-federated

Collecting tensorflow-federated
  Downloading tensorflow_federated-0.78.0-py3-none-manylinux_2_31_x86_64.whl (70.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.3/70.3 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
Collecting dp-accounting==0.4.3 (from tensorflow-federated)
  Downloading dp_accounting-0.4.3-py3-none-any.whl (104 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.8/104.8 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting farmhashpy==0.4.0 (from tensorflow-federated)
  Downloading farmhashpy-0.4.0.tar.gz (98 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.7/98.7 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting google-vizier==0.1.11 (from tensorflow-federated)
  Downloading google_vizier-0.1.11-py3-none-any.whl (721 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m721.6/721.6 kB[0m [31m45.6 MB/s[0m eta 

# Load the dataset and create the model for federated learning

In [None]:
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np
import collections

# Load MNIST data
# load_data() yields two (images, labels) pairs: the training split and the test split.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
mnist_train_images, mnist_train_labels = mnist_train
mnist_test_images, mnist_test_labels = mnist_test

# Preprocessing and federation-topology constants
# Number of client shards the training split is divided into.
NUM_CLIENTS = 15
# Number of times each client dataset is repeated inside preprocess().
NUM_EPOCHS = 5
# Examples per batch in both the training and the test pipelines.
BATCH_SIZE = 20
# Size of the shuffle buffer used inside preprocess().
SHUFFLE_BUFFER = 100
# Buffer size passed to prefetch() in the input pipelines.
PREFETCH_BUFFER = 10
# Number of intermediate servers between the global server and the clients.
NUM_MIDDLE_SERVERS = 3
# Floor division: if NUM_CLIENTS is not a multiple of NUM_MIDDLE_SERVERS, the
# trailing clients are not assigned to any middle server by the training loop.
CLIENTS_PER_MIDDLE_SERVER = NUM_CLIENTS // NUM_MIDDLE_SERVERS

def preprocess(dataset):
    """Turn one client's raw dataset into shuffled, batched training input.

    The examples are repeated `NUM_EPOCHS` times, shuffled with a fixed seed
    through a buffer of `SHUFFLE_BUFFER` elements, grouped into batches of
    `BATCH_SIZE`, converted to an `x`/`y` `OrderedDict`, and prefetched.
    """

    def to_features(images, labels):
        """Flatten a batch of images and shape the labels as a column."""
        flat_images = tf.reshape(images, [-1, 784])  # one 784-value row per image
        column_labels = tf.reshape(labels, [-1, 1])  # one label per row
        return collections.OrderedDict(x=flat_images, y=column_labels)

    repeated = dataset.repeat(NUM_EPOCHS)
    shuffled = repeated.shuffle(SHUFFLE_BUFFER, seed=1)
    batched = shuffled.batch(BATCH_SIZE)
    formatted = batched.map(to_features)
    return formatted.prefetch(PREFETCH_BUFFER)

# Create federated data by partitioning the MNIST data into clients
def create_tf_dataset_for_client(images, labels):
    """Build a `tf.data.Dataset` that slices `images` and `labels` in parallel."""
    example_pairs = (images, labels)
    return tf.data.Dataset.from_tensor_slices(example_pairs)

def make_federated_data(images, labels, num_clients):
    """Split `images`/`labels` into `num_clients` equal, contiguous shards.

    Every shard holds `len(images) // num_clients` consecutive examples; any
    examples left over after the last full shard are not used.

    Returns:
        A list with one preprocessed dataset per client, in shard order.
    """
    shard_size = len(images) // num_clients
    shard_starts = (client * shard_size for client in range(num_clients))
    return [
        preprocess(
            create_tf_dataset_for_client(
                images[start:start + shard_size],
                labels[start:start + shard_size],
            )
        )
        for start in shard_starts
    ]

# Build one preprocessed training dataset per client from the MNIST training split.
federated_train_data = make_federated_data(mnist_train_images, mnist_train_labels, NUM_CLIENTS)

# Sanity check: report how many client datasets were built and show the first one.
print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')

# Define the model and training process
def create_keras_model():
    """Build a fresh single-layer softmax classifier for flattened images.

    The network maps a 784-value input to 10 class probabilities: one linear
    `Dense` layer produces the logits and a `Softmax` layer normalizes them.
    The logits layer deliberately has no activation: a ReLU there would clamp
    every negative logit to zero and block its gradient, so the model could
    never score a class below that floor before the softmax.

    NOTE(review): the input pipelines in this notebook do not rescale pixel
    values before they reach this model; if normalization is added, apply it
    to the training and test pipelines together and re-tune the learning rate.

    Returns:
        An uncompiled `tf.keras.models.Sequential` model.
    """
    return tf.keras.models.Sequential([
        tf.keras.layers.InputLayer(input_shape=(784,)),
        tf.keras.layers.Dense(10),  # linear logits; Softmax below normalizes them
        tf.keras.layers.Softmax(),
    ])

def model_fn():
    """Wrap a newly built Keras model as a TFF model.

    A new Keras model is created on every call. The input spec is read from
    the first client dataset in `federated_train_data`; the loss and metric
    objects are also created fresh on each call.
    """
    fresh_model = create_keras_model()
    batch_spec = federated_train_data[0].element_spec
    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    return tff.learning.models.from_keras_model(
        fresh_model,
        input_spec=batch_spec,
        loss=crossentropy,
        metrics=[accuracy],
    )

# Build the weighted federated averaging process from `model_fn`.
# Clients train with SGD at learning rate 0.0001; the server optimizer is SGD
# with learning rate 1.0.
training_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.0001),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)


Number of client datasets: 15
First dataset: <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.uint8, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.uint8, name=None))])>


# Training loop for the global server and 3 middle (local) servers


*   Three-tier architecture (global server → middle servers → clients)



In [None]:
# Initialize states: one for the global server and one per middle server.
global_state = training_process.initialize()
middle_states = [training_process.initialize() for _ in range(NUM_MIDDLE_SERVERS)]

# Evaluation-only process. Calling `training_process.next` to "evaluate" would
# run another full training round on every client and report training metrics;
# the federated evaluation process only measures the given weights.
evaluation_process = tff.learning.algorithms.build_fed_eval(model_fn)
evaluation_state = evaluation_process.initialize()

NUM_GLOBAL_ROUNDS = 2
NUM_LOCAL_ROUNDS = 5


def _average_weight_lists(weight_lists):
    """Return the element-wise, unweighted mean of parallel lists of weights.

    `weight_lists` holds one list of weight arrays per middle server, all in
    the same order. Position k of the result is the mean over servers of the
    k-th weight. Every server counts equally.
    """
    return [
        tf.reduce_mean(list(per_server_copies), axis=0)
        for per_server_copies in zip(*weight_lists)
    ]


for global_round in range(1, NUM_GLOBAL_ROUNDS + 1):
    print(f'Global round {global_round}/{NUM_GLOBAL_ROUNDS}')

    # Local rounds of communication between middle servers and clients
    for local_round in range(1, NUM_LOCAL_ROUNDS + 1):
        print(f'  Local round {local_round}/{NUM_LOCAL_ROUNDS}')

        for middle_server in range(NUM_MIDDLE_SERVERS):
            # Each middle server owns a contiguous slice of the client datasets.
            start_idx = middle_server * CLIENTS_PER_MIDDLE_SERVER
            end_idx = (middle_server + 1) * CLIENTS_PER_MIDDLE_SERVER
            client_data = federated_train_data[start_idx:end_idx]

            middle_result = training_process.next(middle_states[middle_server], client_data)
            middle_states[middle_server] = middle_result.state
            middle_metrics = middle_result.metrics
            print(f'    Middle server {middle_server + 1}, metrics={middle_metrics}')

    # Aggregate middle server updates at the global server.
    # Extract each server's weights once instead of once per weight tensor.
    middle_weights = [training_process.get_model_weights(state) for state in middle_states]
    aggregated_trainable_weights = _average_weight_lists(
        [weights.trainable for weights in middle_weights]
    )
    aggregated_non_trainable_weights = _average_weight_lists(
        [weights.non_trainable for weights in middle_weights]
    )

    # Update global state with aggregated weights
    new_global_model_weights = tff.learning.models.ModelWeights(
        trainable=aggregated_trainable_weights,
        non_trainable=aggregated_non_trainable_weights
    )
    global_state = training_process.set_model_weights(global_state, new_global_model_weights)

    # Distribute the global model weights back to middle servers
    middle_states = [
        training_process.set_model_weights(state, new_global_model_weights)
        for state in middle_states
    ]

    # Evaluate the aggregated global model on all client datasets without
    # training on them; the global weights are left untouched by this step.
    round_evaluation_state = evaluation_process.set_model_weights(
        evaluation_state, new_global_model_weights
    )
    global_evaluation = evaluation_process.next(round_evaluation_state, federated_train_data)
    global_metrics = global_evaluation.metrics
    print(f'  Global server round {global_round}, metrics: {global_metrics}')




Global round 1/2
  Local round 1/5
    Middle server 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.43462), ('loss', 8.935236), ('num_examples', 100000), ('num_batches', 5000)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
    Middle server 2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.44088), ('loss', 8.825633), ('num_examples', 100000), ('num_batches', 5000)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
    Middle server 3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.44002), ('loss', 8.843071), ('num_examples', 100000), ('num_batches', 5000)]))])), ('aggregato

# Check accuracy on the test data




In [None]:
# Preprocess the test dataset
def preprocess_test(dataset):
    """Batch the test dataset into `(features, labels)` tuples.

    Examples are grouped into batches of `BATCH_SIZE` in their original order
    (no repeat, no shuffle); each batch is reshaped to 784 features per row
    and one label per row, then prefetched.
    """

    def to_feature_label_pair(images, labels):
        """Reshape one batch into a `(features, labels)` tuple."""
        flat_images = tf.reshape(images, [-1, 784])
        column_labels = tf.reshape(labels, [-1, 1])
        return (flat_images, column_labels)

    batched = dataset.batch(BATCH_SIZE)
    formatted = batched.map(to_feature_label_pair)
    return formatted.prefetch(PREFETCH_BUFFER)

# Build the batched test pipeline from the MNIST test split.
test_dataset = preprocess_test(create_tf_dataset_for_client(mnist_test_images.reshape([-1, 784]), mnist_test_labels))

# Convert the final federated model to a Keras model.
# The model is compiled only so that `evaluate` can report loss and accuracy;
# no training step is run in this cell.
final_keras_model = create_keras_model()
final_keras_model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.0001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)

# Update the Keras model with the final trained state.
# `assign_weights_to` copies both the trainable and the non-trainable weights,
# whereas `set_weights(final_weights.trainable)` passes only the trainable ones
# and fails for any model that also has non-trainable variables.
final_weights = training_process.get_model_weights(global_state)
final_weights.assign_weights_to(final_keras_model)

# Evaluate the model on the test dataset
test_loss, test_accuracy = final_keras_model.evaluate(test_dataset)
print(f'Test loss: {test_loss:.4f}')
print(f'Test accuracy: {test_accuracy:.4f}')

Test loss: 34.5854
Test accuracy: 0.8178
