In [16]:
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np

In [17]:
# Define the model
def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=(2,)),
        tf.keras.layers.Dense(10, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])

In [18]:
# Wrap the Keras model for use in TFF
def model_fn():
    keras_model = create_keras_model()
    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=(tf.TensorSpec(shape=[None, 2], dtype=tf.float32),
                    tf.TensorSpec(shape=[None, 1], dtype=tf.float32)),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.keras.metrics.BinaryAccuracy()]
    )

In [19]:
# Create synthetic data for two clients
client1_data = {
    'x': np.array([[0.1, 0.2], [0.4, 0.6], [0.7, 0.8]], dtype=np.float32),
    'y': np.array([[0], [1], [1]], dtype=np.float32)
}

client2_data = {
    'x': np.array([[0.2, 0.3], [0.5, 0.7], [0.8, 0.9]], dtype=np.float32),
    'y': np.array([[0], [1], [1]], dtype=np.float32)
}

client3_data = {
    'x': np.array([[0.3, 0.4], [0.6, 0.8], [1.0, 1.2]], dtype=np.float32),
    'y': np.array([[0], [1], [1]], dtype=np.float32)
}

In [20]:
# Convert client data to TFF datasets
def create_tf_dataset_for_client(client_data):
    return tf.data.Dataset.from_tensor_slices((client_data['x'], client_data['y'])).batch(2)

federated_train_data = [
    create_tf_dataset_for_client(client1_data),
    create_tf_dataset_for_client(client2_data)
]

In [21]:
# Define a federated computation for training
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01)
)

In [22]:
# Initialize the process
state = iterative_process.initialize()

# Perform federated training
NUM_ROUNDS = 20
for round_num in range(NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    print(f'Round {round_num+1}, Metrics: {metrics}')
    

2024-05-16 12:15:48.657247: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2024-05-16 12:15:48.657434: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-05-16 12:15:48.665059: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1118 MB memory:  -> device: 0, name: NVIDIA L4, pci bus id: 0000:00:03.0, compute capability: 8.9
2024-05-16 12:15:48.685301: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
2024-05-16 12:15:48.685381: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-05-16 12:15:48.693082: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1118 MB memory:  -> device: 0, name: NVIDIA L4, pci bus id: 0000:00:03.0, compute capability: 8.9
2024

Round 1, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('binary_accuracy', 0.6666667), ('loss', 0.6651173), ('num_examples', 6), ('num_batches', 4)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 2, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('binary_accuracy', 0.6666667), ('loss', 0.66318184), ('num_examples', 6), ('num_batches', 4)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 3, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('binary_accuracy', 0.6666667), ('loss', 0.6612796), ('num_examples', 6), ('num_batches', 4)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])


In [23]:
# Extract the trained model weights
keras_model = create_keras_model()
trained_weights = iterative_process.get_model_weights(state)
keras_model.set_weights(trained_weights.trainable)
keras_model.save('federated_model.h5')
print("Model saved as federated_model.h5")





Model saved as federated_model.h5
