# Federated Random Forest using the FLEX library


In this notebook we show how to use the *Federated Random Forest* model described in this [paper](https://academic.oup.com/bioinformatics/article/38/8/2278/6525214).

First, we import everything we need.

In [None]:
import numpy as np  # needed for the random non-IID weights below

from flex.data import FedDataDistribution, FedDatasetConfig
from flex.pool import FlexPool

from flextrees.datasets.tabular_datasets import ildp

from flextrees.pool import (
    init_server_model_rf,
    deploy_server_config_rf,
    deploy_server_model_rf,
    aggregate_trees_from_rf,
    evaluate_global_rf_model,
    evaluate_global_rf_model_at_clients,
    evaluate_local_rf_model_at_clients,
    train_rf,
    collect_clients_trees_rf,
    set_aggregated_trees_rf,
)


## Loading the data using FLEX

In this tutorial we use the **ildp** dataset. It can be loaded with the `ildp` function from the flextrees library.

In [None]:
train_data, test_data = ildp(ret_feature_names=False, categorical=False)
n_clients = 2

## Federating the data using FLEX

Once the data is loaded, we have to federate it using the FLEX library. We show two ways of doing so: with an IID distribution or with a non-IID distribution. For the IID case, we can simply call the `iid_distribution` function from `FedDataDistribution`. For the non-IID case, we build a custom `FedDatasetConfig`, setting the seed, the number of clients, and the weights; here the weights are drawn at random, but they can be set to any values the user wants. For more information, see the notebook *Federating data with FLEXible* in the FLEX library notebooks.

In [None]:
dist = 'iid'

if dist == 'iid':
    federated_data = FedDataDistribution.iid_distribution(centralized_data=train_data,
                                                        n_nodes=n_clients)
else:
    # Generate random weights (full non-IID)
    weights = np.random.dirichlet(np.repeat(1, n_clients), 1)[0]
    config_nidd = FedDatasetConfig(seed=0, n_nodes=n_clients,
                                replacement=False, weights=weights)

    federated_data = FedDataDistribution.from_config(centralized_data=train_data,
                                                        config=config_nidd)
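
To make the role of the Dirichlet weights concrete, the following is a minimal standard-library sketch, not FLEX code. It assumes that each weight is read as the client's share of the training samples; the helper names `dirichlet_weights` and `samples_per_client` and the figure of 400 samples are hypothetical and chosen only for illustration.

```python
import random


def dirichlet_weights(n_clients, alpha=1.0, seed=0):
    """Sample one weight vector from a symmetric Dirichlet(alpha) distribution."""
    rng = random.Random(seed)
    # Normalizing independent Gamma(alpha, 1) draws yields a Dirichlet sample.
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(n_clients)]
    total = sum(draws)
    return [d / total for d in draws]


def samples_per_client(n_samples, weights):
    """Turn fractional weights into integer counts that sum to n_samples."""
    counts = [int(n_samples * w) for w in weights]
    # Give the samples lost to rounding down to the largest fractional parts.
    by_remainder = sorted(
        range(len(weights)),
        key=lambda i: n_samples * weights[i] - counts[i],
        reverse=True,
    )
    for i in by_remainder[: n_samples - sum(counts)]:
        counts[i] += 1
    return counts


# Hypothetical setting: 2 clients sharing 400 training samples.
weights = dirichlet_weights(n_clients=2)
counts = samples_per_client(n_samples=400, weights=weights)
print(weights, counts)
```

With `alpha=1.0` every split of the data is equally likely, so one client may end up with most of the samples. Larger values of `alpha` would pull the shares toward an even split.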

## Creating the federated architecture

To create the federated architecture we use `FlexPool`. Since we are running a client-server architecture, we call `client_server_pool`, passing it the federated data and `init_server_model_rf`, the function used to initialize the server model.

In [None]:
# Build the client-server pool; init_server_model_rf initializes the server model
pool = FlexPool.client_server_pool(federated_data, init_server_model_rf)

clients = pool.clients
aggregator = pool.aggregators
server = pool.servers

We set the total number of estimators (trees) in the federated model and the number of estimators that each client contributes.

In [None]:
# Total number of estimators
total_estimators = 100
# Number of estimators per client
nr_estimators = total_estimators // n_clients
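
Floor division discards any remainder: with 100 estimators and 3 clients, `100 // 3` gives 33 per client, so only 99 trees would be collected. The sketch below shows one way to spread the remainder so the counts add up to the requested total. The helper `split_estimators` is hypothetical; the notebook itself passes a single `nr_estimators` value to `collect_clients_trees_rf`, so per-client counts are an illustration of the arithmetic rather than something the library is known to accept.

```python
def split_estimators(total_estimators, n_clients):
    """Split total_estimators across clients so that the counts sum exactly."""
    base, remainder = divmod(total_estimators, n_clients)
    # The first `remainder` clients contribute one extra tree each.
    return [base + 1 if i < remainder else base for i in range(n_clients)]


print(split_estimators(100, 2))  # [50, 50]
print(split_estimators(100, 3))  # [34, 33, 33]
```

With the two clients used in this notebook the division is exact, so the issue only appears for client counts that do not divide the total.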

Lastly, the server sends the training configuration to all the clients.

In [None]:
# Deploy clients config
server.map(func=deploy_server_config_rf, dst_pool=pool.clients)

## Training the model

Now we use the primitives to build the estimators. Each client trains its local random forest and evaluates it on its own data. The aggregator then collects `nr_estimators` trees from each client, aggregates them, and sets the result on the server. Finally, the server sends the aggregated forest to the clients.

In [None]:
clients.map(func=train_rf)
clients.map(func=evaluate_local_rf_model_at_clients)
aggregator.map(func=collect_clients_trees_rf, dst_pool=pool.clients, nr_estimators=nr_estimators)
aggregator.map(func=aggregate_trees_from_rf)
aggregator.map(func=set_aggregated_trees_rf, dst_pool=pool.servers)
server.map(func=deploy_server_model_rf, dst_pool=pool.clients)
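
The collect-and-aggregate step can be pictured with a small standard-library sketch. This is not the flextrees implementation: the one-split "trees", the toy data, and the rule of taking the first `nr_estimators` trees from each client are all assumptions made for illustration. The idea it shows is that a federated forest is the concatenation of locally trained trees and predicts by majority vote.

```python
import random
from collections import Counter


def train_stump(rows, labels, rng):
    """Fit a one-split 'tree' on a bootstrap sample of the client's data."""
    idx = [rng.randrange(len(rows)) for _ in rows]  # bootstrap sample
    feature = rng.randrange(len(rows[0]))           # random feature
    threshold = sum(rows[i][feature] for i in idx) / len(idx)
    left = [labels[i] for i in idx if rows[i][feature] <= threshold]
    right = [labels[i] for i in idx if rows[i][feature] > threshold]
    fallback = Counter(labels).most_common(1)[0][0]
    left_label = Counter(left).most_common(1)[0][0] if left else fallback
    right_label = Counter(right).most_common(1)[0][0] if right else fallback
    return lambda row: left_label if row[feature] <= threshold else right_label


def train_local_forest(rows, labels, n_trees, seed):
    """Each client trains its own forest on its own data only."""
    rng = random.Random(seed)
    return [train_stump(rows, labels, rng) for _ in range(n_trees)]


def aggregate_forests(local_forests, nr_estimators):
    """Concatenate the first nr_estimators trees of every client (assumed rule)."""
    global_forest = []
    for forest in local_forests:
        global_forest.extend(forest[:nr_estimators])
    return global_forest


def predict(forest, row):
    """Majority vote over all trees in the forest."""
    return Counter(tree(row) for tree in forest).most_common(1)[0][0]


# Toy data for two hypothetical clients: (rows, labels).
client_data = [
    ([[0.1, 1.0], [0.2, 0.9], [0.9, 0.1], [0.8, 0.2]], [0, 0, 1, 1]),
    ([[0.15, 0.8], [0.3, 0.7], [0.7, 0.3], [0.95, 0.05]], [0, 0, 1, 1]),
]
local_forests = [
    train_local_forest(rows, labels, n_trees=5, seed=i)
    for i, (rows, labels) in enumerate(client_data)
]
global_forest = aggregate_forests(local_forests, nr_estimators=5)
print(len(global_forest))  # 10
```

Only trained trees leave the clients in this picture; the raw rows stay where they are.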

## Evaluating the models

We use the primitives to evaluate the federated model on a global test set at the server and on each client's own data. Comparing these results with the local evaluation run during training shows whether the federated model actually improves on the local model built by each client.

In [None]:
server.map(func=evaluate_global_rf_model, test_data=test_data)
clients.map(func=evaluate_global_rf_model_at_clients)
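
The local-versus-global comparison comes down to scoring both models on the same labels. The sketch below uses plain accuracy, which is an assumption: the notebook does not state which metrics the flextrees evaluation primitives report. The helper names and the toy predictions are hypothetical.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def compare_models(y_true, local_pred, global_pred):
    """Report local and global accuracy and the gain from federating."""
    local_acc = accuracy(y_true, local_pred)
    global_acc = accuracy(y_true, global_pred)
    return {"local": local_acc, "global": global_acc, "gain": global_acc - local_acc}


# Made-up predictions for one client, for illustration only.
y_true = [0, 0, 1, 1, 1]
local_pred = [0, 1, 1, 0, 1]   # 3 of 5 correct
global_pred = [0, 0, 1, 0, 1]  # 4 of 5 correct
report = compare_models(y_true, local_pred, global_pred)
print(report)
```

A positive `gain` for a client means the aggregated forest classifies that client's data better than the forest it trained alone.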
