In [11]:
# import necessary libs
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np

In [12]:
# Load the federated EMNIST simulation data shipped with TFF.
# `only_digits=True` is already the default of `load_data`; it is spelled out
# so the digits-only choice is visible next to the 10-unit output layer of the
# classifier defined later in this notebook.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True,
)

In [13]:
# Configuration: how many clients are sampled from the EMNIST training set
# to take part in the simulated federated training.
NUM_CLIENTS = 10  # You can increase the number of clients


In [14]:
def preprocess(dataset, batch_size=20, prefetch_buffer=10):
    """Convert one client's raw dataset into batched (image, label) pairs.

    Args:
        dataset: Dataset-like object exposing `map`, `batch` and `prefetch`
            whose elements are mappings with 'pixels' and 'label' entries
            (presumably a `tf.data.Dataset` for a single EMNIST client --
            confirm against the caller).
        batch_size: Number of examples per batch. Defaults to 20, the value
            that used to be hard-coded.
        prefetch_buffer: Number of batches to prefetch. Defaults to 10, the
            value that used to be hard-coded.

    Returns:
        The dataset after the per-example mapping, batching and prefetching
        have been applied, in that order.
    """
    def to_image_and_label(element):
        # Applied to each example *before* batching: give the pixels an
        # explicit trailing channel axis and pair them with the label.
        return (tf.reshape(element['pixels'], [28, 28, 1]), element['label'])

    return (
        dataset
        .map(to_image_and_label)
        .batch(batch_size)
        .prefetch(prefetch_buffer)
    )


In [15]:
# Select a subset of clients.
# The sample is drawn from a seeded generator so that Restart & Run All picks
# the same clients every time; change the seed to draw a different subset.
CLIENT_SAMPLING_SEED = 42
client_rng = np.random.default_rng(CLIENT_SAMPLING_SEED)
client_ids = client_rng.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

# One preprocessed dataset per selected client, in the order they were sampled.
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(client_id))
    for client_id in client_ids
]


In [16]:
# Define the model
def create_keras_model():
    """Build a new, uncompiled convolutional classifier.

    The network takes 28x28x1 inputs and applies, in order: a 3x3 convolution
    with 32 filters (ReLU), 2x2 max pooling, flattening, a 128-unit ReLU dense
    layer, and a 10-unit softmax output layer.

    Returns:
        A `tf.keras.models.Sequential` instance holding those layers.
    """
    layers = tf.keras.layers
    stack = [
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(units=128, activation='relu'),
        layers.Dense(units=10, activation='softmax'),
    ]
    return tf.keras.models.Sequential(stack)


In [17]:
# Wrap the Keras model for use in TFF
def model_fn():
    """Return a TFF model wrapping a newly built Keras classifier.

    Every call constructs a new Keras model via `create_keras_model()` and
    hands it to `tff.learning.models.from_keras_model` together with the
    batch specification, a sparse categorical cross-entropy loss and a sparse
    categorical accuracy metric.

    Returns:
        Whatever `tff.learning.models.from_keras_model` returns for that
        configuration.
    """
    # A batch is an (images, labels) pair with an unknown leading batch size.
    image_spec = tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32)
    label_spec = tf.TensorSpec(shape=[None], dtype=tf.int32)

    return tff.learning.models.from_keras_model(
        create_keras_model(),
        input_spec=(image_spec, label_spec),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )


In [18]:
# Define a federated computation for training
def _client_sgd():
    """Return a new SGD optimizer (learning rate 0.01) for client training."""
    return tf.keras.optimizers.SGD(learning_rate=0.01)


# Weighted federated averaging over `model_fn`; only the client optimizer is
# specified here, every other option keeps the library default.
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=_client_sgd,
)

In [19]:
# Create the initial state of the federated process.
state = iterative_process.initialize()

# Run the federated training rounds. Every round uses the same client
# datasets and feeds the state produced by the previous round back in.
NUM_ROUNDS = 10
for round_num in range(NUM_ROUNDS):
    round_output = iterative_process.next(state, federated_train_data)
    state = round_output.state
    metrics = round_output.metrics
    # round_num is zero-based; the printed round number starts at 1.
    print(f'Round {round_num+1}, Metrics: {metrics}')


2024-05-16 11:47:55.643465: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2024-05-16 11:47:55.643586: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-05-16 11:47:55.650740: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 32 MB memory:  -> device: 0, name: NVIDIA L4, pci bus id: 0000:00:03.0, compute capability: 8.9
2024-05-16 11:47:55.675774: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2024-05-16 11:47:55.675853: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-05-16 11:47:55.682663: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 32 MB memory:  -> device: 0, name: NVIDIA L4, pci bus id: 0000:00:03.0, compute capability: 8.9
2024-05-

Round 1, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.06432161), ('loss', 2.4024725), ('num_examples', 995), ('num_batches', 54)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 2, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.061306532), ('loss', 2.3603761), ('num_examples', 995), ('num_batches', 54)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 3, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.067336686), ('loss', 2.3463943), ('num_examples', 995), ('num_batches', 54)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finali

In [20]:
# Extract the trained model weights from the final training state
trained_weights = iterative_process.get_model_weights(state)

# Rebuild the Keras architecture on the coordinator and load the federated
# weights into it.
keras_model = create_keras_model()
# assign_weights_to copies both the trainable and the non-trainable weights.
# The previous set_weights(trained_weights.trainable) only worked because this
# architecture has no non-trainable variables; it would fail with a length
# mismatch as soon as one is introduced (e.g. by a BatchNormalization layer).
trained_weights.assign_weights_to(keras_model)

# Persist the model to disk and report where it was written.
keras_model.save('federated_model_emnist.h5')
print("Model saved as federated_model_emnist.h5")





Model saved as federated_model_emnist.h5
