<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Federated_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install --upgrade jax jaxlib

In [None]:
pip install tensorflow tensorflow-federated

In [None]:
pip install --upgrade tensorflow tensorflow-federated

In [None]:
import collections

import tensorflow as tf
import tensorflow_federated as tff

# Load the federated EMNIST dataset as a (train, test) pair of ClientData.
try:
    federated_data = tff.simulation.datasets.emnist.load_data()
except Exception as e:
    # Every cell below depends on this data. Report the problem and stop,
    # instead of continuing into a confusing NameError on `federated_data`.
    print("Error loading federated data:", e)
    raise

emnist_train = federated_data[0]

# The training split has thousands of clients, so show a count and a small
# sample rather than dumping the whole list.
client_ids = emnist_train.client_ids
print(f"Number of clients: {len(client_ids)}")
print("First client IDs:", client_ids[:5])

# Choose a valid client ID
client_id = client_ids[0]  # For example, choose the first client ID
raw_client_dataset = emnist_train.create_tf_dataset_for_client(client_id)

BATCH_SIZE = 20

# Display the first raw batch of data
for example in raw_client_dataset.batch(BATCH_SIZE).take(1):
    print(example)


def _to_model_inputs(batch):
    """Convert a raw EMNIST batch into the ``x``/``y`` layout for TFF.

    Per the TFF EMNIST docs, raw elements are OrderedDicts with a ``'label'``
    entry and a 28x28 ``'pixels'`` entry. ``from_keras_model`` expects
    ``x``/``y`` keys. The pixels are flattened to 784-vectors so they match
    the model's ``input_shape=(784,)``, and the labels are reshaped into a
    column.
    """
    return collections.OrderedDict(
        x=tf.reshape(batch['pixels'], [-1, 784]),
        y=tf.reshape(batch['label'], [-1, 1]),
    )


# Batched and reformatted client data. Its element_spec is used as the
# model's input_spec.
sample_batch = raw_client_dataset.batch(BATCH_SIZE).map(_to_model_inputs)

# Build a simple model for federated learning
def create_model():
    """Return an uncompiled Keras CNN classifier.

    Each input row is a 784-element vector. It is reshaped to a 28x28x1
    image and passed through one 3x3 convolution (32 filters, ReLU) and a
    2x2 max-pool. The result is flattened into a 10-way softmax output.
    """
    layers = tf.keras.layers
    stack = [
        layers.Reshape(target_shape=[28, 28, 1], input_shape=(784,)),
        layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(10, activation='softmax'),
    ]
    return tf.keras.Sequential(stack)

# Factory that wraps a fresh Keras model as a TFF model; the federated training process below calls it
def model_fn():
    """Build a new TFF model that wraps a freshly created Keras model.

    TFF calls this factory itself, possibly more than once and inside its
    own graph context. The Keras model, loss and metrics are therefore all
    constructed here on every call rather than shared.

    ``input_spec`` is taken from ``sample_batch.element_spec``, so it
    matches the layout of the client datasets passed to the training
    process.
    """
    keras_model = create_model()
    # Current TFF exposes the Keras wrapper under `tff.learning.models`.
    # The old `tff.learning.from_keras_model` alias is gone in recent
    # releases.
    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=sample_batch.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

# Build a weighted Federated Averaging process. The old
# `tff.learning.build_federated_averaging_process` required a
# `client_optimizer_fn` argument (missing here before) and has been removed
# from current TFF. The learning rates are illustrative choices, not tuned
# values.
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.02),
    server_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=1.0),
)

# Initialize the process
state = iterative_process.initialize()

# Perform one round of federated learning. The single sampled client's
# dataset forms the whole cohort. `next` returns a LearningProcessOutput,
# so read its fields by name rather than tuple-unpacking it.
result = iterative_process.next(state, [sample_batch])
state = result.state
metrics = result.metrics

print(f'Metrics: {metrics}')