In [39]:
import argparse
import warnings
from typing import Union
from logging import INFO
from datasets import Dataset, DatasetDict
import xgboost as xgb

import flwr as fl
from flwr_datasets import FederatedDataset
from flwr.common.logger import log
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Parameters,
    Status,
)
from flwr_datasets.partitioner import IidPartitioner


warnings.filterwarnings("ignore", category=UserWarning)


# Define data partitioning related functions
def train_test_split(partition: Dataset, test_fraction: float, seed: int):
    """Split the data into train and validation set given split rate."""
    train_test = partition.train_test_split(test_size=test_fraction, seed=seed)
    partition_train = train_test["train"]
    partition_test = train_test["test"]

    num_train = len(partition_train)
    num_test = len(partition_test)

    return partition_train, partition_test, num_train, num_test


def transform_dataset_to_dmatrix(data: Union[Dataset, DatasetDict]) -> xgb.core.DMatrix:
    """Transform dataset to DMatrix format for xgboost."""
    x = data["inputs"]
    y = data["label"]
    new_data = xgb.DMatrix(x, label=y)
    return new_data


# Load (HIGGS) dataset and conduct partitioning
# We use a small subset (num_partitions=30) of the dataset for demonstration to speed up the data loading process.
partitioner = IidPartitioner(num_partitions=30)
fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})

# Load the partition for this `partition_id`
log(INFO, "Loading partition...")
partition = fds.load_partition(node_id=0, split="train")
partition.set_format("numpy")

# Train/test splitting
train_data, valid_data, num_train, num_val = train_test_split(
    partition, test_fraction=0.2, seed=42
)
import numpy as np
# Reformat data to DMatrix for xgboost
log(INFO, "Reformatting data...")
train_dmatrix = transform_dataset_to_dmatrix(train_data)
valid_dmatrix = transform_dataset_to_dmatrix(valid_data)

# Hyper-parameters for xgboost training
num_local_round = 1
params = {
    "objective": "binary:logistic",
    "eta": 0.1,  # Learning rate
    "max_depth": 8,
    "eval_metric": "error",
    "nthread": 16,
    "num_parallel_tree": 1,
    "subsample": 1,
    "tree_method": "hist",
}
from sklearn.metrics import accuracy_score
predictions = 0

INFO flwr 2024-05-03 14:28:27,330 | 3147490696.py:55 | Loading partition...
INFO flwr 2024-05-03 14:28:48,337 | 3147490696.py:65 | Reformatting data...


In [42]:

# Define Flower client
class XgbClient(fl.client.Client):
    def __init__(self):
        self.bst = None
        self.config = None

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        _ = (self, ins)
        return GetParametersRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            parameters=Parameters(tensor_type="", tensors=[]),
        )

    def _local_boost(self):
        # Update trees based on local training data.
        for i in range(num_local_round):
            self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

        # Extract the last N=num_local_round trees for sever aggregation
        bst = self.bst[
            self.bst.num_boosted_rounds()
            - num_local_round : self.bst.num_boosted_rounds()
        ]

        return bst

    def fit(self, ins: FitIns) -> FitRes:
        if not self.bst:
            # First round local training
            log(INFO, "Start training at round 1")
            bst = xgb.train(
                params,
                train_dmatrix,
                num_boost_round=num_local_round,
                evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
            )
            self.config = bst.save_config()
            self.bst = bst
        else:
            for item in ins.parameters.tensors:
                global_model = bytearray(item)

            # Load global model into booster
            self.bst.load_model(global_model)
            self.bst.load_config(self.config)

            bst = self._local_boost()

        local_model = bst.save_raw("json")
        local_model_bytes = bytes(local_model)

        return FitRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
            num_examples=num_train,
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        eval_results = self.bst.eval_set(
            evals=[(valid_dmatrix, "valid")],
            iteration=self.bst.num_boosted_rounds() - 1,
        )
        predictions = self.bst.predict(valid_dmatrix)
        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
        print(predictions)
        # num_train, num_val                train_dmatrix  valid_dmatrix
        # predictions = np.round(predictions)
        # print('XGBoost model accuracy score: {0:0.4f}'. format(accuracy_score(num_val, predictions)))
        # print(num_val)
        # print(self.bst.get_fscore(),"this is f1")
        print(self.bst.eval_set(evals=[(valid_dmatrix, "valid")],
            iteration=self.bst.num_boosted_rounds() - 1,))
        
        return EvaluateRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            loss=0.0,
            num_examples=num_val,
            metrics={"AUC": auc},
        )


# Start Flower client
fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client())

INFO flwr 2024-05-03 14:38:57,101 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2024-05-03 14:38:57,114 | connection.py:55 | ChannelConnectivity.IDLE
DEBUG flwr 2024-05-03 14:38:57,116 | connection.py:55 | ChannelConnectivity.READY
INFO flwr 2024-05-03 14:38:57,119 | 894417073.py:33 | Start training at round 1


[0]	validate-error:0.39006	train-error:0.38980
[0.50720346 0.54078406 0.5326804  ... 0.5407039  0.5270872  0.54282004]
[0]	valid-error:0.39005714285714288
[0.47250575 0.56874794 0.54889095 ... 0.5443925  0.5141186  0.58131874]
[2]	valid-error:0.31319999999999998
[0.4390707  0.59106904 0.568166   ... 0.5377155  0.50377566 0.62653947]
[4]	valid-error:0.29881428571428570
[0.40488988 0.58968246 0.5863348  ... 0.5465034  0.49662665 0.65756816]
[6]	valid-error:0.29345714285714286


DEBUG flwr 2024-05-03 14:38:58,104 | connection.py:220 | gRPC channel closed
INFO flwr 2024-05-03 14:38:58,106 | app.py:398 | Disconnect and shut down


[0.38830832 0.5818097  0.6075606  ... 0.5509021  0.4885172  0.68229896]
[8]	valid-error:0.29018571428571427


In [25]:
train_dmatrix.feature_names
print(train_dmatrix)

<xgboost.core.DMatrix object at 0x000001EA59EE6FD0>


In [33]:
train_data['inputs']

array([[-0.13093917,  1.7058353 ,  0.5175378 , ..., -0.47649968,
        -0.649102  , -0.74050415],
       [-0.87608576, -1.4434521 ,  0.05769784, ..., -1.2853564 ,
        -0.08515155,  0.39095634],
       [ 0.12607476,  0.23739064,  1.0722127 , ...,  0.5466734 ,
         3.48156   ,  2.5582342 ],
       ...,
       [-0.8540744 ,  0.3049719 , -0.40996313, ...,  0.06947917,
        -0.17963345, -0.20951043],
       [-0.50254136,  0.86879283, -0.22084437, ..., -0.7747355 ,
         2.6738822 ,  3.0330305 ],
       [-0.45625293, -1.129682  , -0.00470868, ...,  0.53176695,
         1.586368  ,  1.44395   ]], dtype=float32)

In [34]:
train_data['label']

array([0., 0., 0., ..., 1., 0., 0.], dtype=float32)

In [31]:
num_train

280000