# Practical Federated Gradient Boosting Decision Trees using the FLEX library


In this notebook we show how to use the *Practical Federated Gradient Boosting Decision Trees* model, from the [paper](https://ojs.aaai.org/index.php/AAAI/article/view/5895).

Note: the preprocessing stage takes a long time to finish.

First, we import everything we need.

In [None]:
import time
import argparse

import numpy as np

from flex.data import FedDataDistribution, FedDatasetConfig, one_hot_encoding
from flex.pool import FlexPool

from flextrees.datasets.tabular_datasets import adult

from flextrees.pool import (
    init_server_model_gbdt,
    init_hash_tables,
    compute_hash_values,
    deploy_server_config_gbdt,
    evaluate_global_model,
    evaluate_global_model_clients_gbdt,
    train_n_estimators,
    preprocessing_stage,
)

## Loading the data using FLEX

In this tutorial we use the Adult dataset, which can be loaded through the flextrees library.

In [None]:
train_data, test_data = adult(ret_feature_names=False, categorical=False)
n_labels = len(np.unique(train_data.y_data.to_numpy()))  # Total number of labels, needed for the softmax.
dataset_dim = train_data.to_numpy()[0].shape[1]  # Feature dimension, needed to create the LSH hyperplanes.

## Federating the data using FLEX

Once the data is loaded, we have to federate it using the FLEX library. We show two ways of federating the data: an IID distribution or a non-IID distribution. For the IID distribution we can simply use the `iid_distribution` function from `FedDataDistribution`. For a non-IID distribution we have to use a custom configuration; in this case we just set the seed and the number of clients, and the weights can also be set manually, for example by drawing them at random. For more information, see the FLEX library notebooks, in particular *Federating data with FLEXible*.

In [None]:
dist = 'iid'
n_clients = 2  # Number of clients (assumed value; adjust as needed)

if dist == 'iid':
    federated_data = FedDataDistribution.iid_distribution(centralized_data=train_data,
                                                          n_nodes=n_clients)
else:
    config_nidd = FedDatasetConfig(seed=0, n_nodes=n_clients, replacement=False)

    federated_data = FedDataDistribution.from_config(centralized_data=train_data,
                                                     config=config_nidd)
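To make the difference between the two distributions concrete, here is a minimal NumPy sketch of how sample indices could be split across nodes, either evenly or according to user-chosen weights. It is only an illustration: `split_indices` and its parameters are hypothetical names and are not part of the FLEX API, which handles the partitioning internally.

```python
import numpy as np

def split_indices(n_samples, n_nodes, weights=None, seed=0):
    """Partition sample indices across nodes without replacement.

    Hypothetical helper, not a FLEX function. With weights=None every node
    gets an equal share (an IID-style split); otherwise weights[i] is the
    fraction of the data assigned to node i.
    """
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_samples)
    if weights is None:
        weights = np.full(n_nodes, 1.0 / n_nodes)
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()  # normalise so the shares sum to 1
    # Cut points inside the shuffled index array, one fewer than the nodes
    cuts = np.rint(np.cumsum(weights)[:-1] * n_samples).astype(int)
    return np.split(shuffled, cuts)

# Equal shares for every node
iid_parts = split_indices(1000, n_nodes=4)
# One node holds most of the data, the rest hold small shares
skewed_parts = split_indices(1000, n_nodes=4, weights=[0.7, 0.1, 0.1, 0.1])

print([len(p) for p in iid_parts])
print([len(p) for p in skewed_parts])
```

Note that this sketch only skews the amount of data per node; a non-IID split may also skew the label distribution of each node.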

To use the model, the data needs a small preprocessing step: the labels must be one-hot encoded. After federating the data, we can use the `apply` function of the `FedDataset` to apply the selected function to all the federated data.

In [None]:
# One hot encode the labels for using softmax
federated_data.apply(one_hot_encoding, n_labels=n_labels)
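As a reminder of what one-hot encoding does to the labels, the following is a small standalone NumPy sketch. The `one_hot` helper is hypothetical and is not the `one_hot_encoding` function from FLEX, whose exact behaviour may differ.

```python
import numpy as np

def one_hot(labels, n_labels):
    """Return an (n_samples, n_labels) matrix with a single 1 per row.

    Hypothetical helper for illustration only.
    """
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, n_labels), dtype=float)
    # Put a 1 in the column given by each sample's label
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

encoded = one_hot([0, 1, 1, 0], n_labels=2)
print(encoded)
```

Each row then lines up with the per-class probabilities that a softmax produces.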

## Creating the federated architecture

To create the federated architecture we use `FlexPool`. Since we are running a client-server architecture, we use the `client_server_architecture` function. We pass it the dimension of the dataset, which is needed to create the LSH functions, i.e. the planes used to hash all the clients' data.

In [None]:
pool = FlexPool.client_server_architecture(federated_data, init_server_model_gbdt, dataset_dim=dataset_dim)
clients = pool.clients
aggregator = pool.aggregators
server = pool.servers

In [None]:
# Total number of estimators
total_estimators = 10
print(f"Number of trees to build: {total_estimators}")
estimators_built = 0

Lastly, the server deploys the training configuration to all the clients.

In [None]:
server.map(func=deploy_server_config_gbdt, dst_pool=pool.clients)

Now that everything is set, we can begin with the code for the boosting model. First, each client has to create the hash tables for its data.

In [None]:
clients.map(func=init_hash_tables) # Init hash tables
clients.map(func=compute_hash_values) # Calculate the hash tables on the clients
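To give an intuition for what a hash table over hyperplanes contains, here is a minimal sketch of sign-random-projection LSH in NumPy. This is an assumption made for illustration: the LSH family, the function names and the parameters below are hypothetical, and the paper and flextrees may use a different hash family and data layout.

```python
import numpy as np

def make_hyperplanes(n_tables, n_bits, dim, seed=0):
    """Draw random hyperplanes; every party must use the same ones."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((n_tables, n_bits, dim))

def hash_instances(X, planes):
    """Return an (n_samples, n_tables) array of integer bucket ids."""
    # One bit per hyperplane: which side of the plane the instance lies on
    bits = (np.einsum("tbd,nd->ntb", planes, X) > 0).astype(np.int64)
    # Pack the n_bits bits of each table into a single integer
    powers = 1 << np.arange(planes.shape[1], dtype=np.int64)
    return bits @ powers

# The second row is a scaled copy of the first, so it falls in the same buckets
X = np.array([[1.0, 2.0, -1.0], [2.0, 4.0, -2.0], [-3.0, 0.5, 4.0]])
planes = make_hyperplanes(n_tables=4, n_bits=8, dim=3)
codes = hash_instances(X, planes)
print(codes)
```

Because only the planes and the resulting bucket ids are needed, instances can be compared across parties without exchanging the raw feature values.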

## Preprocessing stage

This phase uses the clients' hash tables to find, for each instance, the similar instances held by other clients without sharing the data itself; only the planes are shared. This phase is complex; its pseudo-code is available in the paper, and we have wrapped it all into a single primitive function.

The `preprocessing_stage` function receives the clients, the server and the aggregator, and creates the global hash table with the similar instances for each client.

In [None]:
preprocessing_stage(
    clients=clients,
    server=server,
    aggregator=aggregator,
)
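The core idea of the similarity search can be sketched as follows: for each instance of one client, pick the instance of another client that shares the most hash values. This is a simplified illustration under that assumption, with a hypothetical `most_similar` helper and toy hash codes; it is not the flextrees implementation, which also coordinates the server and the aggregator.

```python
import numpy as np

def most_similar(codes_a, codes_b):
    """For each row of codes_a, return the index of the row in codes_b
    that shares the most hash values (ties go to the lowest index).

    Both inputs have shape (n_samples, n_tables). Hypothetical helper.
    """
    # matches[i, j] = number of hash tables in which a_i and b_j collide
    matches = (codes_a[:, None, :] == codes_b[None, :, :]).sum(axis=2)
    return matches.argmax(axis=1)

# Toy hash codes for two clients, three hash tables each
codes_a = np.array([[3, 7, 1], [0, 2, 5]])
codes_b = np.array([[0, 2, 9], [3, 7, 1], [4, 4, 4]])
similar = most_similar(codes_a, codes_b)
print(similar)  # [1 0]
```

The pairwise comparison above is quadratic in the number of instances, which helps explain why this stage is slow on larger datasets.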

## Training stage

The second phase is the training phase. As with the preprocessing stage, a single primitive function trains all the selected estimators.

In [None]:
train_n_estimators(
    clients=clients,
    server=server,
    aggregator=aggregator,
    total_estimators=total_estimators,
)
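Since the labels are one-hot encoded for a softmax, each boosting round needs the gradients and Hessians of the softmax cross-entropy loss with respect to the current scores. The sketch below shows the standard formulas in NumPy; it assumes a diagonal Hessian approximation, uses hypothetical function names, and does not reproduce how `train_n_estimators` aggregates gradients across clients.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def softmax_grad_hess(scores, y_onehot):
    """First and (diagonal) second derivatives of the cross-entropy loss."""
    p = softmax(scores)
    grad = p - y_onehot       # dL/dscore for each class
    hess = p * (1.0 - p)      # diagonal of the Hessian
    return grad, hess

# Two samples, two classes, all scores still at zero (before the first tree)
scores = np.zeros((2, 2))
y_onehot = np.array([[1.0, 0.0], [0.0, 1.0]])
grad, hess = softmax_grad_hess(scores, y_onehot)
print(grad)
print(hess)
```

Each new tree is then fitted to these statistics, and its predictions are added to the scores before the next round.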

## Evaluating the model

After the model is trained, we evaluate it both at the server level and at the client level.

In [None]:
# On the server side
server.map(evaluate_global_model, test_data=test_data)
# On the clients' side
clients.map(evaluate_global_model_clients_gbdt)