# Flower Quickstart (Simulation with TensorFlow/Keras)

Welcome to Flower, a friendly federated learning framework!

In this notebook, we'll simulate a federated learning system with 100 clients. The clients will use TensorFlow/Keras to define model training and evaluation. Let's start by installing Flower Nightly, published as `flwr-nightly` on PyPI:

```python
!pip install -U flwr-nightly[simulation]  # For the latest flwr-nightly release
# !pip install -U flwr[simulation]  # Once 0.17 is released
# !pip install git+https://github.com/adap/flower.git@branchname#egg=flwr[simulation]  # For a specific branch
```

Next, we import the required dependencies:

```python
import os
import math

# Make TensorFlow logs less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import flwr as fl
import tensorflow as tf
```

With the boring parts out of the way, let's move on to the interesting bits. Flower federates existing machine learning projects: you wrap your training and evaluation code in a subclass of `flwr.client.NumPyClient`.

```python
class FlwrClient(fl.client.NumPyClient):
    def __init__(self, model, x_train, y_train) -> None:
        super().__init__()
        self.model = model
        split_idx = math.floor(len(x_train) * 0.9)  # Use 10% of x_train for validation
        self.x_train, self.y_train = x_train[:split_idx], y_train[:split_idx]
        self.x_val, self.y_val = x_train[split_idx:], y_train[split_idx:]

    def get_parameters(self):
        # Return the current model weights as a list of NumPy arrays
        return self.model.get_weights()

    def fit(self, parameters, config):
        # Load the global parameters, train locally, and return the updated
        # weights together with the number of training examples used
        self.model.set_weights(parameters)
        self.model.fit(self.x_train, self.y_train, epochs=2, verbose=2)
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        # Evaluate the global parameters on the local validation split
        self.model.set_weights(parameters)
        loss, acc = self.model.evaluate(self.x_val, self.y_val, verbose=2)
        return loss, len(self.x_val), {"accuracy": acc}
```
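The train/validation split in `__init__` is plain index arithmetic on the client's data partition. The sketch below reproduces it with a hypothetical helper, `split_sizes`, which is not part of Flower. It assumes each client holds 600 examples, which is what `client_fn` produces for MNIST's 60,000 training images split across 100 clients:

```python
import math
from typing import Tuple


def split_sizes(num_examples: int, train_fraction: float = 0.9) -> Tuple[int, int]:
    """Hypothetical helper: return (num_train, num_val) using FlwrClient's floor-based split."""
    split_idx = math.floor(num_examples * train_fraction)
    return split_idx, num_examples - split_idx


num_train, num_val = split_sizes(600)
print(num_train, num_val)  # 540 60
```

Because `fit` returns `len(self.x_train)` and `evaluate` returns `len(self.x_val)`, each client in this setup would report 540 training examples and 60 validation examples to the server.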

`FlwrClient` implements three methods: `get_parameters`, `fit`, and `evaluate`. An instance of `FlwrClient` represents a **single client** in your federated learning system. Federated learning systems have multiple clients (otherwise there's not much to federate, is there?), and each client is represented by its own instance of `FlwrClient`. With three clients in your workload, for example, you'd have three instances. When the server selects a particular client for training, Flower calls `fit` on that client's instance (and `evaluate` when the client is selected for evaluation). In a simulation, Flower creates these instances on demand by calling a function, `client_fn`, with the ID (`cid`) of the client it needs:
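To make the one-instance-per-client contract concrete, here is a framework-free sketch. `ToyClient` is a hypothetical stand-in for `FlwrClient`. It does not subclass Flower's `NumPyClient`, so it runs without Flower or TensorFlow. Its "model" is a single NumPy vector, and `fit` moves that vector halfway toward the mean of the client's local data. A hand-written loop plays the role of the server by calling `fit` on each selected instance:

```python
from typing import Dict, List, Tuple

import numpy as np


class ToyClient:
    """Hypothetical stand-in for FlwrClient; one instance per simulated client."""

    def __init__(self, data: np.ndarray) -> None:
        self.data = data  # Local dataset; it never leaves the client
        self.weights = [np.zeros(data.shape[1])]

    def get_parameters(self) -> List[np.ndarray]:
        return self.weights

    def fit(self, parameters: List[np.ndarray], config: Dict) -> Tuple[List[np.ndarray], int, Dict]:
        # Start from the global parameters, then move halfway toward the local mean
        self.weights = [p.copy() for p in parameters]
        self.weights[0] += 0.5 * (self.data.mean(axis=0) - self.weights[0])
        return self.weights, len(self.data), {}

    def evaluate(self, parameters: List[np.ndarray], config: Dict) -> Tuple[float, int, Dict]:
        # "Loss" = mean squared difference between the parameters and the local examples
        loss = float(np.mean((self.data - parameters[0]) ** 2))
        return loss, len(self.data), {}


# Three clients -> three instances, each holding its own data
clients = [ToyClient(np.full((n, 2), v)) for n, v in [(10, 1.0), (20, 2.0), (30, 3.0)]]
global_params = [np.zeros(2)]

# The "server" selects two clients and calls fit on each selected instance
selected = clients[:2]
results = [c.fit(global_params, config={}) for c in selected]
for weights, num_examples, _ in results:
    print(weights[0], num_examples)
```

The real `FlwrClient` follows the same pattern with a Keras model. `set_weights` loads the global parameters, `model.fit` trains locally, and the returned example count tells the server how much data contributed to each update.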

```python
def client_fn(cid: str) -> fl.client.NumPyClient:
    # Load model
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )
    model.compile("adam", "sparse_categorical_crossentropy", metrics=["accuracy"])

    # Load data partition (divide MNIST into NUM_CLIENTS distinct partitions).
    # NUM_CLIENTS is defined in the simulation cell below, before client_fn is first called.
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    partition_size = math.floor(len(x_train) / NUM_CLIENTS)
    idx_from, idx_to = int(cid) * partition_size, (int(cid) + 1) * partition_size
    x_train_cid = x_train[idx_from:idx_to] / 255.0  # Scale pixel values to [0, 1]
    y_train_cid = y_train[idx_from:idx_to]

    # Create and return client
    return FlwrClient(model, x_train_cid, y_train_cid)
```
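`client_fn` turns the string `cid` into a contiguous slice of MNIST's 60,000 training images. The sketch below uses a hypothetical helper, `partition_bounds`, that reproduces the index arithmetic from `client_fn`. It checks that with `NUM_CLIENTS = 100`, every client gets 600 examples and the slices do not overlap:

```python
import math
from typing import Tuple


def partition_bounds(cid: str, num_examples: int, num_clients: int) -> Tuple[int, int]:
    """Hypothetical helper mirroring client_fn: return (idx_from, idx_to) for client `cid`."""
    partition_size = math.floor(num_examples / num_clients)
    idx_from = int(cid) * partition_size
    return idx_from, idx_from + partition_size


bounds = [partition_bounds(str(cid), 60_000, 100) for cid in range(100)]
print(bounds[0], bounds[99])  # (0, 600) (59400, 60000)

# Consecutive partitions touch but never overlap
assert all(bounds[i][1] == bounds[i + 1][0] for i in range(99))
```

Because of the `floor`, a client count that does not divide 60,000 evenly leaves the remainder unused. With 7 clients, for example, the last partition ends at index 59,997, so 3 examples are never assigned.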

The last step is to configure a strategy (here, federated averaging via `FedAvg`) and start the simulation.

```python
NUM_CLIENTS = 100

# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.1,  # Sample 10% of available clients for training
    fraction_eval=0.05,  # Sample 5% of available clients for evaluation
    min_fit_clients=10,  # Never sample fewer than 10 clients for training
    min_eval_clients=5,  # Never sample fewer than 5 clients for evaluation
    min_available_clients=int(NUM_CLIENTS * 0.75),  # Wait until at least 75 clients are available
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    num_rounds=5,
    strategy=strategy,
)
```
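The strategy parameters interact. To our understanding, Flower's `FedAvg` samples `max(int(fraction * num_available), min_clients)` clients per round, then averages the returned weights in proportion to each client's number of examples. The sketch below reproduces both steps with hypothetical helpers, `num_sampled` and `weighted_average`, rather than Flower's own code. It assumes all 100 clients are available:

```python
from typing import List, Tuple

import numpy as np


def num_sampled(num_available: int, fraction: float, min_clients: int) -> int:
    """Hypothetical helper: clients sampled per round under a max(fraction, minimum) rule."""
    return max(int(num_available * fraction), min_clients)


def weighted_average(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    """Hypothetical FedAvg-style aggregation: average each layer, weighted by num_examples."""
    total = sum(n for _, n in results)
    num_layers = len(results[0][0])
    return [sum(weights[i] * n for weights, n in results) / total for i in range(num_layers)]


print(num_sampled(100, 0.1, 10))   # 10 clients train each round
print(num_sampled(100, 0.05, 5))   # 5 clients evaluate each round
print(num_sampled(100, 0.05, 10))  # 10: a larger minimum overrides the fraction

# Two clients with one layer each; the client with 3x more data pulls the average toward it
aggregated = weighted_average([([np.array([0.0, 0.0])], 540), ([np.array([4.0, 8.0])], 1620)])
print(aggregated[0])  # [3. 6.]
```

Weighting by example count means that clients with larger partitions have a proportionally larger influence on the global model. In this notebook, every partition has the same size, so all sampled clients contribute equally.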