# Flower Quickstart (Simulation with TensorFlow/Keras on tabular data)

Welcome to Flower, a friendly federated learning framework!

In this notebook, we'll simulate a federated learning system with 5 clients. The clients will use TensorFlow/Keras on tabular data to define model training and evaluation. Let's start by installing Flower (published as `flwr` on PyPI) with the `simulation` extra:

In [None]:
!pip install -q flwr["simulation"] tensorflow
!pip install -q flwr_datasets

Let's also install Matplotlib, pandas, and scikit-learn so that we can plot the results once the simulation has completed, work with dataframes, and use scikit-learn's data-processing utilities.

In [None]:
!pip install matplotlib pandas scikit-learn

Next, we import the required dependencies. The most important imports are Flower (`flwr`) and TensorFlow:

In [None]:
from typing import Dict, List, Tuple

import tensorflow as tf

import flwr as fl
from flwr.common import Metrics
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth

from datasets import Dataset, concatenate_datasets
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

import numpy as np
import pandas as pd

from datasets import load_dataset
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

VERBOSE = 0
NUM_CLIENTS = 5
RANDOM_STATE = 42
NUM_ROUNDS = 30

# Define Alpha to use for DirichletPartitioner
ALPHA = 0.5

# Define features to employ
RELEVANT_FEATURES = ['Pclass', 'Sex', 'Age', 'Siblings/Spouses Aboard', 'Parents/Children Aboard', 'Fare']

# Name of the label (target) column
LABEL_NAME = 'Survived'

Next, let's define the model we want to federate. Since we will be working with the Titanic dataset, a small fully connected network is sufficient. You can, of course, customize this model.

In [None]:
def get_model(num_features):
    """Constructs a simple model architecture suitable for Titanic dataset."""
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Dense(32, input_dim=num_features, activation='relu'),
            tf.keras.layers.Dense(16, activation='relu'),
            tf.keras.layers.Dense(8, activation='relu'),
            tf.keras.layers.Dense(1, activation='sigmoid'),
        ]
    )
    model.compile("adam", "binary_crossentropy", metrics=["accuracy"])
    return model
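
The model is small. With the six features in `RELEVANT_FEATURES`, each Dense layer holds `n_in * n_out` kernel weights plus `n_out` biases, and `model.get_weights()` returns them as 8 NumPy arrays (one kernel and one bias per layer). The stdlib sketch below counts them by hand, which gives a rough idea of how many values are exchanged between each client and the server per round (ignoring serialization overhead). The helper `dense_param_counts` is hypothetical and written only for this illustration; it assumes Keras' default `float32` weights.

```python
def dense_param_counts(num_features, units=(32, 16, 8, 1)):
    """Trainable parameters per Dense layer: (n_in x n_out) kernel + n_out bias."""
    counts = []
    n_in = num_features
    for n_out in units:
        counts.append(n_in * n_out + n_out)  # kernel + bias
        n_in = n_out
    return counts


counts = dense_param_counts(num_features=6)  # the six RELEVANT_FEATURES
total = sum(counts)
print(counts, total)                    # [224, 528, 136, 9] 897
print(f"{total * 4} bytes as float32")  # 3588 bytes as float32
```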

With that out of the way, let's move on to the interesting bits. Federated learning systems consist of a server and multiple clients. In Flower, we create clients by implementing subclasses of `flwr.client.Client` or `flwr.client.NumPyClient`. We use `NumPyClient` in this tutorial because it is easier to implement and requires us to write less boilerplate.

To implement the Flower client, we create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`:

- `get_parameters`: Return the current local model parameters
- `fit`: Receive model parameters from the server, train the model parameters on the local data, and return the (updated) model parameters to the server
- `evaluate`: Receive model parameters from the server, evaluate the model parameters on the local data, and return the evaluation result to the server

We mentioned that our clients will use TensorFlow/Keras for model training and evaluation. Keras models provide methods that make the implementation straightforward: we can update the local model with server-provided parameters through `model.set_weights`, train and evaluate the model through `model.fit` and `model.evaluate`, and get the updated model parameters through `model.get_weights`.

Let's see a simple implementation:

In [None]:
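class FlowerClient(fl.client.NumPyClient):
    def __init__(self, trainset, valset, num_features) -> None:
        # Create model
        self.model = get_model(num_features)
        self.trainset = trainset
        self.valset = valset
        # len() of a batched tf.data.Dataset returns the number of batches, so
        # count examples explicitly; FedAvg and weighted_average weight each
        # client by these counts
        self.num_train_examples = sum(int(x.shape[0]) for x, _ in trainset)
        self.num_val_examples = sum(int(x.shape[0]) for x, _ in valset)

    def get_parameters(self, config):
        # Return the local model weights as a list of NumPy arrays
        return self.model.get_weights()

    def fit(self, parameters, config):
        # Load the global parameters, train for one local epoch, return the update
        self.model.set_weights(parameters)
        self.model.fit(self.trainset, epochs=1, verbose=VERBOSE)
        return self.model.get_weights(), self.num_train_examples, {}

    def evaluate(self, parameters, config):
        # Load the global parameters and evaluate them on the local validation set
        self.model.set_weights(parameters)
        loss, acc = self.model.evaluate(self.valset, verbose=VERBOSE)
        return loss, self.num_val_examples, {"accuracy": acc}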
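
FedAvg weights each client's update by the `num_examples` value returned from `fit`, and `weighted_average` does the same with the value returned from `evaluate`. Calling `len()` on a `tf.data.Dataset` batched with `batch_size=64` returns the number of batches, so the reported count should be derived from examples instead. The stdlib sketch below, with two hypothetical clients and hypothetical helper names, shows how much batch counts can skew the weights when partitions are uneven, as they tend to be under Dirichlet partitioning.

```python
import math


def num_batches(num_examples: int, batch_size: int = 64) -> int:
    # Number of batches when the last, partial batch is kept
    return math.ceil(num_examples / batch_size)


def client_weights(counts):
    # Share of the aggregate each client receives under count-weighted averaging
    total = sum(counts.values())
    return {cid: round(c / total, 3) for cid, c in counts.items()}


# Two hypothetical clients with very uneven partitions
examples = {"client_a": 10, "client_b": 500}
batches = {cid: num_batches(n) for cid, n in examples.items()}

print(batches)                   # {'client_a': 1, 'client_b': 8}
print(client_weights(examples))  # {'client_a': 0.02, 'client_b': 0.98}
print(client_weights(batches))   # {'client_a': 0.111, 'client_b': 0.889}
```

Weighting by batches gives the small client more than five times its fair share of the aggregate in this toy case.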

Our class `FlowerClient` defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through `fit` and `evaluate`. Each instance of `FlowerClient` represents a *single client* in our federated learning system. Federated learning systems have multiple clients (otherwise, there's not much to federate, is there?), so each client will be represented by its own instance of `FlowerClient`. If we have, for example, three clients in our workload, we'd have three instances of `FlowerClient`. Flower calls `FlowerClient.fit` on the respective instance when the server selects a particular client for training (and `FlowerClient.evaluate` for evaluation).

In this notebook, we want to simulate a federated learning system with 5 clients on a single machine. This means that the server and all 5 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 5 clients would mean having 5 instances of `FlowerClient` in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.

In addition to the regular capabilities where server and clients run on multiple machines, Flower, therefore, provides special simulation capabilities that create `FlowerClient` instances only when they are actually necessary for training or evaluation. To enable the Flower framework to create clients when necessary, we need to implement a function called `client_fn` that creates a `FlowerClient` instance on demand. Flower calls `client_fn` whenever it needs an instance of one particular client to call `fit` or `evaluate` (those instances are usually discarded after use). Clients are identified by a client ID, or `cid` for short. The `cid` can be used, for example, to load a different local data partition for each client.
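
The on-demand pattern can be illustrated without Flower at all. In the stdlib sketch below, `ToyClient`, `PARTITIONS`, and the sampling loop are hypothetical stand-ins for `FlowerClient`, the partitioned dataset, and the VirtualClientEngine: a client object is built from its `cid` only when it is sampled, used once, and dropped, so memory scales with the clients in use rather than with the total number of clients.

```python
import random

# Hypothetical stand-ins: each "partition" is just a list of example ids
PARTITIONS = {0: [0, 1, 2], 1: [3, 4], 2: [5, 6, 7, 8]}


class ToyClient:
    """Minimal stand-in for FlowerClient: holds only its own partition."""

    def __init__(self, data):
        self.data = data

    def fit(self):
        return len(self.data)  # pretend to train and report num_examples


created = []  # records every client instantiation


def client_fn(cid: str) -> ToyClient:
    created.append(cid)
    return ToyClient(PARTITIONS[int(cid)])  # the cid selects the partition


rng = random.Random(0)
for server_round in range(1, 3):
    sampled = rng.sample(sorted(PARTITIONS), k=2)  # sample 2 of the 3 clients
    for cid in sampled:
        client = client_fn(str(cid))  # created only when needed...
        client.fit()
        del client                    # ...and discarded after use

print(created)  # 4 instantiations over 2 rounds, one client alive at a time
```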

We now define three auxiliary functions for this example (note the last two are entirely optional):
* `get_client_fn()`: A function that returns another function. The returned `client_fn` will be executed by Flower's VirtualClientEngine each time a new _virtual_ client (i.e. a client that is simulated in a Python process) needs to be spawned. When are virtual clients spawned? Each time the strategy samples them to do either `fit()` (i.e. train the global model on the local data of a particular client) or `evaluate()` (i.e. evaluate the global model on the validation set of a given client).

* `weighted_average()`: This is an optional function to pass to the strategy. It will be executed after an evaluation round (i.e. when clients run `evaluate()`) and will aggregate the metrics the clients return. In this example, we use this function to compute the accuracy of the clients doing `evaluate()`, averaged with weights given by each client's number of evaluation examples.

* `get_evaluate_fn()`: This is again a function that returns another function. The returned function will be executed by the strategy once before the first round (on the initial parameters) and then at the end of each `fit()` round, after a new global model has been obtained through aggregation. This is an optional argument for Flower strategies. In this example, we perform this server-side evaluation on a centralized Titanic test set built by concatenating the 10% test split held out from each client's partition (created with the Hugging Face `Dataset.train_test_split` method).
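
To make the weighting in `weighted_average()` concrete, the sketch below applies the same logic as the `weighted_average` function in the next code cell to two hypothetical clients. The metric values are made up purely to show the arithmetic; a client with more evaluation examples pulls the average toward its own accuracy.

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Example-weighted mean of the "accuracy" values reported by clients."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


# Hypothetical clients: 10 evaluation examples at 80%, 30 at 60%
client_metrics = [(10, {"accuracy": 0.8}), (30, {"accuracy": 0.6})]
result = weighted_average(client_metrics)
# (10 * 0.8 + 30 * 0.6) / 40 = 26 / 40 = 0.65, versus 0.70 for a plain mean
print(result)
```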

In [None]:
def get_client_fn(dataset: List[Dataset]):
    """Return a function to construct a client.

    The VirtualClientEngine will execute this function whenever a client is sampled by
    the strategy to participate.
    """

    def client_fn(cid: str) -> fl.client.Client:
        """Construct a FlowerClient with its own dataset partition."""

        # Extract partition for client with id = cid
        client_dataset = dataset[int(cid)]

        # Now let's split it into train (90%) and validation (10%)
        client_dataset_splits = client_dataset.train_test_split(test_size=0.1, seed=RANDOM_STATE)

        trainset = client_dataset_splits["train"].to_tf_dataset(
            columns="features", label_cols="labels", batch_size=64
        )

        valset = client_dataset_splits["test"].to_tf_dataset(
            columns="features", label_cols="labels", batch_size=64
        )

        # Extract the number of features
        element_spec = trainset.element_spec
        num_features = element_spec[0].shape[1]
        
        # Create and return client
        return FlowerClient(trainset, valset, num_features).to_client()

    return client_fn


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregation function for (federated) evaluation metrics, i.e. those returned by
    the client's evaluate() method."""
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}


def get_evaluate_fn(testset: tf.data.Dataset):
    """Return an evaluation function for server-side (i.e. centralised) evaluation."""

    # The `evaluate` function will be called after every round by the strategy
    def evaluate(
        server_round: int,
        parameters: fl.common.NDArrays,
        config: Dict[str, fl.common.Scalar],
    ):
        # Extract the number of features
        element_spec = testset.element_spec
        num_features = element_spec[0].shape[1]
        model = get_model(num_features)  # Construct the model
        model.set_weights(parameters)  # Update model with the latest parameters
        loss, accuracy = model.evaluate(testset, verbose=VERBOSE)
        return loss, {"accuracy": accuracy}

    return evaluate

We now have `FlowerClient` which defines client-side training and evaluation, and `client_fn`, which allows Flower to create `FlowerClient` instances whenever it needs to call `fit` or `evaluate` on one particular client. The last step is to start the actual simulation using `flwr.simulation.start_simulation`. 

The function `start_simulation` accepts a number of arguments, amongst them the `client_fn` used to create `FlowerClient` instances, the number of clients to simulate `num_clients`, the number of rounds `num_rounds`, and the strategy. The strategy encapsulates the federated learning approach/algorithm, for example, *Federated Averaging* (FedAvg).
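
In FedAvg, the server replaces the global model with the average of the parameters returned by the clients, weighted by the number of examples each client reports from `fit`. The sketch below is a plain-Python illustration of that weighting only; `fedavg_aggregate` and the list-of-lists parameter format are hypothetical simplifications (Flower's built-in `FedAvg` operates on lists of NumPy arrays).

```python
from typing import List, Tuple

Params = List[List[float]]  # one flat list of floats per layer (toy stand-in)


def fedavg_aggregate(results: List[Tuple[Params, int]]) -> Params:
    """Average each parameter across clients, weighted by num_examples."""
    total_examples = sum(n for _, n in results)
    num_layers = len(results[0][0])
    aggregated: Params = []
    for layer in range(num_layers):
        layer_size = len(results[0][0][layer])
        aggregated.append([
            sum(params[layer][i] * n for params, n in results) / total_examples
            for i in range(layer_size)
        ])
    return aggregated


# Two hypothetical clients with a one-layer, two-weight "model"
client_a = ([[1.0, 2.0]], 100)
client_b = ([[3.0, 6.0]], 300)
print(fedavg_aggregate([client_a, client_b]))  # [[2.5, 5.0]]
```

Client B holds three times as many examples, so the result lies three quarters of the way from A's parameters to B's.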

Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation.

We can use [Flower Datasets](https://flower.ai/docs/datasets/) to effortlessly obtain an off-the-shelf partitioned dataset or to partition one that isn't pre-partitioned. Since we are using the Titanic dataset (which is not part of the Flower Datasets suite), we download it from [Hugging Face](https://huggingface.co/datasets/julien-c/titanic-survival) and then partition it using the `DirichletPartitioner`.
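
The idea behind `DirichletPartitioner` is label skew. For every class, the share of its examples assigned to each partition is drawn from a Dirichlet distribution with concentration `alpha`, so smaller `alpha` values yield more unbalanced (more non-IID) partitions. The stdlib sketch below illustrates only that core idea. The function `dirichlet_partition` is hypothetical, and, as we understand it, Flower's implementation adds options such as a minimum partition size that are not reproduced here.

```python
import random
from collections import defaultdict
from typing import Dict, List


def dirichlet_partition(labels: List[int], num_partitions: int, alpha: float,
                        seed: int = 42) -> Dict[int, List[int]]:
    """Simplified label-skew partitioning: for each class, draw partition shares
    from Dirichlet(alpha) and split that class's example indices accordingly."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    partitions: Dict[int, List[int]] = {p: [] for p in range(num_partitions)}
    for indices in by_class.values():
        rng.shuffle(indices)
        # A Dirichlet sample is a vector of Gamma(alpha, 1) draws, normalized
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_partitions)]
        shares = [d / sum(draws) for d in draws]
        # Turn shares into contiguous slices; the last partition takes the rest
        start = 0
        for p in range(num_partitions):
            if p == num_partitions - 1:
                end = len(indices)
            else:
                end = start + int(shares[p] * len(indices))
            partitions[p].extend(indices[start:end])
            start = end
    return partitions


# Hypothetical binary labels: 60 "did not survive" and 40 "survived"
labels = [0] * 60 + [1] * 40
parts = dirichlet_partition(labels, num_partitions=5, alpha=0.5)
for p, idx in parts.items():
    print(f"partition {p}: {len(idx)} examples, {sum(labels[i] for i in idx)} survived")
```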

Notice that when working with tabular datasets, you may need to preprocess the data (select features, impute missing values, scale the data, etc.). For this purpose, we define the `preprocess_data` function, which we apply to each client's partition of the Titanic dataset.
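
Two of these steps are simple enough to spell out: mean imputation (used for `Age` below) and standardization to zero mean and unit variance (what `StandardScaler` does per feature, using the population standard deviation). The stdlib sketch below mirrors that arithmetic for a single feature column; `impute_and_standardize` is a hypothetical helper, not part of this notebook. Because `preprocess_data` runs separately on each client's partition, each client computes these statistics from its own local data.

```python
import statistics
from typing import List, Optional


def impute_and_standardize(values: List[Optional[float]]) -> List[float]:
    """Fill missing values with the mean, then z-score with the population std."""
    observed = [v for v in values if v is not None]
    mean_observed = statistics.fmean(observed)
    filled = [v if v is not None else mean_observed for v in values]
    mu = statistics.fmean(filled)
    sigma = statistics.pstdev(filled)  # ddof=0, like StandardScaler
    if sigma == 0:
        # A constant column is only centered (StandardScaler then uses scale 1)
        return [0.0 for _ in filled]
    return [(v - mu) / sigma for v in filled]


ages = [22.0, None, 38.0, 26.0]  # hypothetical Age column with one missing value
scaled = impute_and_standardize(ages)
print([round(v, 3) for v in scaled])
```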

In [None]:
def preprocess_data(dataset):
    """Preprocess the dataset."""
    # Select relevant features and target
    dataset = dataset.remove_columns([col for col in dataset.column_names if col not in RELEVANT_FEATURES + [LABEL_NAME]])

    # Convert 'Sex' to binary values
    dataset = dataset.map(lambda x: {'Sex': 0 if x['Sex'] == 'male' else 1})

    # Fill missing 'Age' values with the mean age
    mean_age = np.mean([x['Age'] for x in dataset if x['Age'] is not None])
    dataset = dataset.map(lambda x: {'Age': x['Age'] if x['Age'] is not None else mean_age})

    # Extract features and target
    features = np.array([tuple(x[feature] for feature in RELEVANT_FEATURES) for x in dataset])
    
    target = [x[LABEL_NAME] for x in dataset]

    # Scale the features
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)

    # Convert scaled features back to Dataset and add target column
    dataset = Dataset.from_dict({
        **{RELEVANT_FEATURES[i]: features_scaled[:, i] for i in range(len(RELEVANT_FEATURES))},
        LABEL_NAME: target
    })

    return dataset

In [None]:
# Set partitioner to federate data
partitioner = DirichletPartitioner(num_partitions=NUM_CLIENTS, partition_by=LABEL_NAME,
                                    alpha=ALPHA)

# Load dataset from Hugging Face 
fds = FederatedDataset(dataset="julien-c/titanic-survival", partitioners={"train": partitioner})

# Create a list to store the preprocessed federated dataset
fds_preprocessed = []

# Initialize the centralized dataset
centralized_testset = None

for cid in range(NUM_CLIENTS):
    # Extract partition for client with id = cid
    client_dataset = fds.load_partition(int(cid))
    
    # Preprocess the client dataset
    client_dataset = preprocess_data(client_dataset)

    # Split into train (90%) and validation (10%)
    client_dataset_splits = client_dataset.train_test_split(test_size=0.1, seed=RANDOM_STATE)
    
    # Extract the train data and set it as Dataset format
    train_dataset = client_dataset_splits["train"]
    train_dataset_features = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in RELEVANT_FEATURES])
    train_dataset_labels = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in [LABEL_NAME]])
    train_dataset = {"features": [list(row.values()) for row in train_dataset_features], "labels": [list(row.values()) for row in train_dataset_labels]}
    train_dataset = Dataset.from_dict(train_dataset)

    # Extract the test data and set it as Dataset format
    test_dataset = client_dataset_splits["test"]
    test_dataset_features = test_dataset.remove_columns([col for col in test_dataset.column_names if col not in RELEVANT_FEATURES])
    test_dataset_labels = test_dataset.remove_columns([col for col in test_dataset.column_names if col not in [LABEL_NAME]])
    test_dataset = {"features": [list(row.values()) for row in test_dataset_features], "labels": [list(row.values()) for row in test_dataset_labels]}
    test_dataset = Dataset.from_dict(test_dataset)

    # Store this client's preprocessed training split
    fds_preprocessed.append(train_dataset)

    # Concatenate the test partitions of each client into the centralized_testset
    if centralized_testset is None:
        centralized_testset = test_dataset
    else:
        centralized_testset = concatenate_datasets([centralized_testset, test_dataset])

In [None]:
# Enable GPU growth in your main process
enable_tf_gpu_growth()

# Convert centralized dataset to tf tensor format
centralized_testset = centralized_testset.to_tf_dataset(
    columns="features", label_cols="labels", batch_size=64
)

# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,  # Sample 100% of available clients for training
    fraction_evaluate=0.5,  # Sample 50% of available clients for evaluation
    min_fit_clients=2,  # Never sample less than 2 clients for training
    min_evaluate_clients=2,  # Never sample less than 2 clients for evaluation
    min_available_clients=int(
        NUM_CLIENTS * 0.75
    ),  # Wait until at least 3 clients are available
    evaluate_metrics_aggregation_fn=weighted_average,  # aggregates federated metrics
    evaluate_fn=get_evaluate_fn(centralized_testset),  # global evaluation function
)

# With a dictionary, you tell Flower's VirtualClientEngine that each
# client needs exclusive access to these many resources in order to run
client_resources = {"num_cpus": 1, "num_gpus": 0.0}

# Start simulation
history = fl.simulation.start_simulation(
    client_fn=get_client_fn(fds_preprocessed),
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources=client_resources,
    actor_kwargs={
        "on_actor_init_fn": enable_tf_gpu_growth  # Enable GPU growth upon actor init.
    },
)
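
To see what the strategy arguments above mean for 5 clients, the sketch below reproduces, as we understand it, the arithmetic Flower's `FedAvg` uses to decide how many clients to sample: `int(num_available * fraction)`, raised to the configured minimum, while waiting for `min_available_clients` to connect. `sample_sizes` is a hypothetical helper written for illustration, not a Flower API.

```python
from typing import Tuple


def sample_sizes(num_available: int, fraction: float, min_clients: int,
                 min_available: int) -> Tuple[int, int]:
    """(clients to sample, clients to wait for) for a FedAvg-style config."""
    return max(int(num_available * fraction), min_clients), min_available


NUM_CLIENTS = 5
min_available = int(NUM_CLIENTS * 0.75)  # int(3.75) == 3

fit_sizes = sample_sizes(NUM_CLIENTS, fraction=1.0, min_clients=2,
                         min_available=min_available)
eval_sizes = sample_sizes(NUM_CLIENTS, fraction=0.5, min_clients=2,
                          min_available=min_available)
print(fit_sizes)   # (5, 3): all 5 clients train in every round
print(eval_sizes)  # (2, 3): int(2.5) == 2 clients run evaluate()
```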

You can then use the returned `History` object to either save the results to disk or do some visualisation (or both, of course, or neither if you like chaos). Below you can see how to plot the centralised accuracy obtained at the end of each round (including at the very beginning of the experiment) for the global model. This is what the function `evaluate_fn()` that we passed to the strategy reports.

In [None]:
import matplotlib.pyplot as plt

print(f"{history.metrics_centralized = }")

global_accuracy_centralised = history.metrics_centralized["accuracy"]
rounds = [data[0] for data in global_accuracy_centralised]  # avoid shadowing built-in round()
acc = [100.0 * data[1] for data in global_accuracy_centralised]
plt.plot(rounds, acc)
plt.grid()
plt.ylabel("Accuracy (%)")
plt.xlabel("Round")
plt.title("Titanic - non-IID - "+str(NUM_CLIENTS)+" clients with 5 clients per round")
plt.show()
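
As mentioned above, the `History` object can also be saved to disk. `history.metrics_centralized` is a dictionary mapping each metric name to a list of `(round, value)` tuples (the structure the plotting code indexes into), so it serializes naturally to JSON. Below is a stdlib sketch with hypothetical helper names; the values in `toy_metrics` are made up and only exercise the round trip. In the notebook you would pass `history.metrics_centralized` instead, assuming all metric values are plain numbers.

```python
import json
import os
import tempfile
from pathlib import Path


def save_metrics(metrics_centralized, path):
    """Write {"metric": [(round, value), ...]} to JSON (tuples become lists)."""
    Path(path).write_text(json.dumps(metrics_centralized, indent=2))


def load_metrics(path):
    """Read the JSON back, restoring the (round, value) tuples."""
    data = json.loads(Path(path).read_text())
    return {name: [tuple(entry) for entry in values] for name, values in data.items()}


toy_metrics = {"accuracy": [(0, 0.41), (1, 0.62), (2, 0.7)]}  # made-up values
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "metrics_centralized.json")
    save_metrics(toy_metrics, path)
    restored = load_metrics(path)

print(restored == toy_metrics)  # True
```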

Congratulations! With that, you built a Flower client, customized its instantiation through the `client_fn`, customized the server-side execution through a `FedAvg` strategy configured for this workload, and started a simulation with 5 clients (each holding its own individual partition of the Titanic dataset).

Next, you can continue to explore more advanced Flower topics:

- Deploy server and clients on different machines using `start_server` and `start_client`
- Customize the server-side execution through custom strategies
- Customize the client-side execution through `config` dictionaries

Get all resources you need!

* **[DOCS]** Our complete documentation: https://flower.ai/docs/
* **[Examples]** All Flower examples: https://flower.ai/docs/examples/
* **[VIDEO]** Our YouTube channel: https://www.youtube.com/@flowerlabs

Don't forget to join our Slack channel: https://flower.ai/join-slack/