# Flower Quickstart (Simulation with TensorFlow/Keras)

Welcome to Flower, a friendly federated learning framework!

In this notebook, we'll simulate a federated learning system with 100 clients. The clients will use TensorFlow/Keras to define model training and evaluation. Let's start by installing Flower with the `simulation` extra, here from the `release/0.17` branch on GitHub:

In [None]:
# !pip install git+https://github.com/adap/flower.git@release/0.17#egg=flwr["simulation"]  # For a specific branch (release/0.17) w/ extra ("simulation")
# !pip install -U flwr["simulation"]  # Once 0.17.1 is released

Next, we import the required dependencies. The most important imports are Flower (`flwr`) and TensorFlow:

In [None]:
import os
import math

# Make TensorFlow logs less verbose (set before TensorFlow is imported)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import flwr as fl

# Project-local helpers: data loaders and model builders used below
from shared.shared import *

# Explicit imports for names used directly in this notebook
import numpy as np
import tensorflow as tf

# Silence TensorFlow deprecation warnings
import tensorflow.python.util.deprecation as deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False

# `tf.logging` was removed in TensorFlow 2; use the standard logger instead
tf.get_logger().setLevel("ERROR")

import tensorflow.keras as keras
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras import regularizers
from tensorflow.keras.datasets import mnist
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import RMSprop

# Ax is used for the Bayesian optimization experiments further down
import ax
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.metrics.branin import branin
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting

init_notebook_plotting()

With that out of the way, let's move on to the interesting bits. Federated learning systems consist of a server and multiple clients. In Flower, we create clients by implementing subclasses of `flwr.client.Client` or `flwr.client.NumPyClient`. We use `NumPyClient` in this tutorial because it is easier to implement and requires us to write less boilerplate.

To implement the Flower client, we create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`:

- `get_parameters`: Return the current local model parameters
- `fit`: Receive model parameters from the server, train the model parameters on the local data, and return the (updated) model parameters to the server 
- `evaluate`: Receive model parameters from the server, evaluate the model parameters on the local data, and return the evaluation result to the server

We mentioned that our clients will use TensorFlow/Keras for the model training and evaluation. Keras models provide methods that make the implementation straightforward: we can update the local model with server-provided parameters through `model.set_weights`, we can train and evaluate the model through `model.fit` and `model.evaluate`, and we can get the updated model parameters through `model.get_weights`.

Let's see a simple implementation:

In [None]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, x_train, y_train, x_val, y_val) -> None:
        self.model = model
        self.x_train, self.y_train = x_train, y_train
        self.x_val, self.y_val = x_val, y_val

    def get_parameters(self):
        # Return the local model weights as a list of NumPy arrays
        return self.model.get_weights()

    def fit(self, parameters, config):
        # Load the server-provided weights, then train on the local data
        self.model.set_weights(parameters)
        epochs = int(config["epochs"])
        self.model.fit(self.x_train, self.y_train, epochs=epochs, verbose=0, shuffle=True)
        # Updated weights, number of training examples, and (empty) metrics
        return self.model.get_weights(), len(self.x_train), {}

    def evaluate(self, parameters, config):
        # Load the server-provided weights, then evaluate on the local validation data
        self.model.set_weights(parameters)
        loss, acc = self.model.evaluate(self.x_val, self.y_val, verbose=0)
        return loss, len(self.x_val), {"accuracy": acc}

    def get_properties(self, ins):
        # This client reports no properties
        return {}

Our class `FlowerClient` defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through `fit` and `evaluate`. Each instance of `FlowerClient` represents a *single client* in our federated learning system. Federated learning systems have multiple clients (otherwise there's not much to federate, is there?), so each client will be represented by its own instance of `FlowerClient`. If we have, for example, three clients in our workload, we'd have three instances of `FlowerClient`. Flower calls `FlowerClient.fit` on the respective instance when the server selects a particular client for training (and `FlowerClient.evaluate` for evaluation).

In this notebook, we want to simulate a federated learning system with 100 clients on a single machine. This means that the server and all 100 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 100 clients would mean having 100 instances of `FlowerClient` in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.

In addition to the regular capabilities where server and clients run on multiple machines, Flower therefore provides special simulation capabilities that create `FlowerClient` instances only when they are actually necessary for training or evaluation. To enable the Flower framework to create clients when necessary, we need to implement a function called `client_fn` that creates a `FlowerClient` instance on demand. Flower calls `client_fn` whenever it needs an instance of one particular client to call `fit` or `evaluate` (those instances are usually discarded after use). Clients are identified by a client ID, or `cid` for short. The `cid` can be used, for example, to load different local data partitions for each client:

In [None]:
def get_client_fn_and_eval_fn(data_generator=load_partitioned_FEMINST_data):
    X_trains, X_tests, Y_trains, Y_tests, X_eval_test, Y_eval_test = data_generator()

    def client_fn(cid: str) -> fl.client.Client:
        # Map the client ID onto a partition index in 0..9
        s = int(cid) % 10
        model = get_epoch_opt_compiled_original_CBO_CNN(X_trains[s].shape[1:])
        return FlowerClient(model, X_trains[s], Y_trains[s], X_tests[s], Y_tests[s])

    # Server-side model for centralized evaluation. It is built here so that
    # eval_fn does not depend on a global `model`, which later cells rebind.
    eval_model = get_epoch_opt_compiled_original_CBO_CNN(X_trains[0].shape[1:])

    def eval_fn(weights: fl.common.Weights):
        eval_model.set_weights(weights)
        loss, acc = eval_model.evaluate(X_eval_test, Y_eval_test, verbose=0)
        return loss, {"accuracy": acc}

    return client_fn, eval_fn

In [None]:
X_trains, X_tests, Y_trains, Y_tests, X_eval_test, Y_eval_test = load_partitioned_FEMINST_data()
client_fn, eval_fn  = get_client_fn_and_eval_fn()
model = get_epoch_opt_compiled_original_CBO_CNN(X_trains[1].shape[1:])
client_fn("1").evaluate(model.get_weights(),{})
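The on-demand pattern behind `client_fn` can be illustrated without Flower or TensorFlow. The following is a minimal sketch under stated assumptions: `ToyClient` and `make_client_fn` are hypothetical names invented for this illustration, and the partitions are made-up integers rather than image data. It only shows how a string `cid` is mapped onto a partition with a modulo, and that a client object is built only when the factory is called.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class ToyClient:
    """Hypothetical stand-in for FlowerClient; it only remembers its data."""
    cid: str
    samples: Sequence[int]


def make_client_fn(partitions: List[Sequence[int]]) -> Callable[[str], ToyClient]:
    """Return a factory that builds a client only when it is asked for."""

    def client_fn(cid: str) -> ToyClient:
        # Client IDs arrive as strings; wrap around the available partitions.
        index = int(cid) % len(partitions)
        return ToyClient(cid=cid, samples=partitions[index])

    return client_fn


# Ten toy partitions of three samples each (made-up data).
partitions = [list(range(i * 3, i * 3 + 3)) for i in range(10)]
client_fn = make_client_fn(partitions)

# Clients "3" and "13" share partition 3 because of the modulo mapping.
print(client_fn("3").samples)   # [9, 10, 11]
print(client_fn("13").samples)  # [9, 10, 11]
```

One consequence of the modulo mapping is that simulating more clients than there are partitions makes several clients train on identical data.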

In [None]:
def generate_config(first_epoch, discount_factor):
    def fit_config(rnd: int):
        # `rnd` is the 1-based round number (named to avoid shadowing the built-in `round`)
        print(f"Configuring round {rnd}...")
        return {
            "epochs": str(int(first_epoch * (discount_factor ** (rnd - 1)))),
        }
    return fit_config 
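To see what the configuration function above produces over several rounds, here is a small standalone sketch of the same formula. The helper name `epochs_for_round` and the example settings (8 starting epochs, a discount factor of 0.5) are assumptions made for illustration; they do not come from the experiments in this notebook.

```python
def epochs_for_round(first_epoch: int, discount_factor: float, rnd: int) -> int:
    """Local epochs for the 1-based round `rnd`, truncated with int()."""
    return int(first_epoch * discount_factor ** (rnd - 1))


# Made-up settings: start at 8 local epochs and halve every round.
schedule = [epochs_for_round(8, 0.5, rnd) for rnd in range(1, 6)]
print(schedule)  # [8, 4, 2, 1, 0]

# A decaying schedule can truncate to zero epochs; one possible guard is a floor of 1.
guarded = [max(1, epochs) for epochs in schedule]
print(guarded)  # [8, 4, 2, 1, 1]
```

With a discount factor of 1.0, as used in the experiments later on, the schedule is constant and the truncation to zero cannot occur.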


We now have `FlowerClient` which defines client-side training and evaluation and `client_fn` which allows Flower to create `FlowerClient` instances whenever it needs to call `fit` or `evaluate` on one particular client. The last step is to start the actual simulation using `flwr.simulation.start_simulation`. 

The function `start_simulation` accepts a number of arguments, amongst them the `client_fn` used to create `FlowerClient` instances, the number of clients to simulate `num_clients`, the number of rounds `num_rounds`, and the strategy. The strategy encapsulates the federated learning approach/algorithm, for example, *Federated Averaging* (FedAvg).
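To make the idea behind Federated Averaging concrete, here is a minimal sketch in plain Python. It is not Flower's implementation: `federated_average` is a hypothetical helper, and model parameters are flattened to a list of floats for readability. Each `(weights, num_examples)` pair mirrors the first two values our `fit` method returns, and each client contributes in proportion to its number of training examples.

```python
from typing import List, Tuple

Weights = List[float]  # flattened model parameters, for illustration only


def federated_average(results: List[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_params = len(results[0][0])
    averaged = [0.0] * num_params
    for weights, num_examples in results:
        for i, value in enumerate(weights):
            averaged[i] += value * num_examples / total_examples
    return averaged


# Two made-up clients: the second holds three times as much data.
results = [([1.0, 2.0], 100), ([5.0, 6.0], 300)]
print(federated_average(results))  # [4.0, 5.0]
```

The result sits closer to the second client's weights because that client reported more training examples.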

Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - actually starts the simulation.

In [None]:
cnt = 0  # Number of simulations run so far


def optimise_epochs(initial_epoch_count, epoch_decay, rounds=5, name="hist_total_epochs_opt", NUM_CLIENTS=10, data_generator=load_partitioned_FEMINST_data):
    global cnt
    cnt += 1
    config_fn = generate_config(initial_epoch_count, epoch_decay)
    client_fn, eval_fn = get_client_fn_and_eval_fn(data_generator)
    print(client_fn)
    # Create FedAvg strategy
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # Sample 100% of available clients for training
        fraction_eval=1.0,  # Sample 100% of available clients for evaluation
        min_fit_clients=NUM_CLIENTS,  # Never sample fewer than NUM_CLIENTS clients for training
        min_eval_clients=NUM_CLIENTS,  # Never sample fewer than NUM_CLIENTS clients for evaluation
        min_available_clients=NUM_CLIENTS,  # Wait until all NUM_CLIENTS clients are available
        eval_fn=eval_fn,  # Centralized (server-side) evaluation
        on_fit_config_fn=config_fn,
        on_evaluate_config_fn=config_fn,
    )

    # Start simulation
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        num_rounds=rounds,
        strategy=strategy,
    )
    # filename = f"/home/ubuntu/r244_alex/R244_Project/results/{name}_{cnt}"
    # with open(filename,"w+b") as f:
    #     pickle.dump(history, f)
    # Centralized accuracy is stored as (round, value) pairs; return the last value
    return history.metrics_centralized["accuracy"][-1][1]
    


In [None]:
#Run Uniform Epoch count BO experiment with iid data

full_iid = get_iid_controlled_data_generator(iid_frac=0.5)


total_epochs = 120
best_parameters, values, experiment, model = optimize(
    parameters=[
        {
            "name":"num_rounds",
            "type": "range",
            "bounds": [1, 40],
            "value_type": "int",  # Optional, defaults to inference from type of "bounds".
            "log_scale": False,  # Optional, defaults to False.
        },
    ],
    experiment_name="acc",
    objective_name="acc",
    evaluation_function=lambda p:optimise_epochs(int(float(total_epochs)/p["num_rounds"]), 1.0, p["num_rounds"], data_generator=full_iid),
    minimize=False,  # Optional, defaults to False.
    # parameter_constraints=["x1 + x2 <= 20"],  # Optional.
    # outcome_constraints=["l2norm <= 1.25"],  # Optional.
    total_trials=15, # Optional.
)

In [None]:
#Run Uniform Epoch count BO experiment

total_epochs = 120
best_parameters, values, experiment, model = optimize(
    parameters=[
        {
            "name":"num_rounds",
            "type": "range",
            "bounds": [1, 40],
            "value_type": "int",  # Optional, defaults to inference from type of "bounds".
            "log_scale": False,  # Optional, defaults to False.
        },
    ],
    experiment_name="acc",
    objective_name="acc",
    evaluation_function=lambda p:optimise_epochs(int(float(total_epochs)/p["num_rounds"]), 1.0, p["num_rounds"]),
    minimize=False,  # Optional, defaults to False.
    # parameter_constraints=["x1 + x2 <= 20"],  # Optional.
    # outcome_constraints=["l2norm <= 1.25"],  # Optional.
    total_trials=15, # Optional.
)
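Both experiments above search over `num_rounds` while deriving the local epochs per round from a fixed budget of `total_epochs`. The sketch below reproduces only that arithmetic; `epochs_per_round` is a hypothetical helper name, and the three round counts are arbitrary picks from the search range. It shows that the truncating division does not always spend the full budget.

```python
def epochs_per_round(total_epochs: int, num_rounds: int) -> int:
    """Local epochs per round under a fixed budget, truncated with int()."""
    return int(float(total_epochs) / num_rounds)


total_epochs = 120
for num_rounds in (1, 7, 40):
    per_round = epochs_per_round(total_epochs, num_rounds)
    used = per_round * num_rounds
    # Prints: rounds, epochs per round, epochs actually used
    print(num_rounds, per_round, used)
# 1 120 120
# 7 17 119
# 40 3 120
```

Trials whose `num_rounds` does not divide the budget evenly therefore train for slightly fewer total epochs, which is worth keeping in mind when comparing their accuracies.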

In [None]:
[trial.objective_mean for trial in experiment.trials.values()]

In [None]:
best_objectives = np.array([[trial.objective_mean for trial in experiment.trials.values() ][1:]   ])

plot = optimization_trace_single_method(
    y =best_objectives,
    title="",
    ylabel="Accuracy",
    plot_trial_points = True
)
render(plot)

In [None]:
from ax.storage.json_store.load import load_experiment

experiment2 = load_experiment("BO_epochs_uniform.json")
best_objectives = np.array([[trial.objective_mean for trial in experiment2.trials.values() ][1:]   ])

plot = optimization_trace_single_method(
    y =best_objectives,
    title="",
    ylabel="Accuracy",
    plot_trial_points = True
)
render(plot)

In [None]:
from ax.plot.slice import plot_slice
from copy import deepcopy
from collections import defaultdict
exp2 = deepcopy(experiment)
seen = defaultdict(int)
arm_names = []
data = []
ele = experiment.trials[0]

for i in range(len(experiment.trials.values())):
    if seen[experiment.trials[i].arm.parameters["num_rounds"]] == 0:
        seen[experiment.trials[i].arm.parameters["num_rounds"]] = 1
    
    else:
        exp2.trials.pop(i)

print(exp2.trials)


plot = plot_slice(model, param_name="num_rounds", metric_name="acc", generator_runs_dict = exp2.trials)
render(plot)

In [None]:
data = plot[0]['data']
lay = plot[0]['layout']

import plotly.graph_objects as go
fig = {
    "data": data,
    "layout": lay,
}
go.Figure(fig).write_image("Uniform_Epoch_Count_Opt.pdf")

In [None]:
from ax.storage.json_store.save import save_experiment


whole_experiment = best_parameters, values, experiment, model
save_experiment(experiment, "iid_0.5_BO_epochs_uniform.json")
