# Distributed Training of an XGBoost Model on Anyscale


<div align="left">
<a target="_blank" href="https://console.anyscale.com/"><img src="https://img.shields.io/badge/🚀 Run_on-Anyscale-9hf"></a>&nbsp;
<a href="https://github.com/anyscale/e2e-xgboost" role="button"><img src="https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d"></a>&nbsp;
</div>

This tutorial demonstrates how to execute a distributed training workload, connecting the following heterogeneous components:
- Preprocessing the dataset with Ray Data
- Distributed training of an XGBoost model with Ray Train
- Saving model artifacts to a model registry, such as MLflow

**Note**: This tutorial does not cover model tuning. Refer to [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) for experiment execution and hyperparameter tuning at any scale.

<img src="https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/distributed_training.png" width=800>


Before you start, follow the instructions in README.md to install the dependencies.

In [None]:
%load_ext autoreload
%autoreload all

In [None]:
# enable importing from dist_xgboost module
import os
import sys

sys.path.append(os.path.abspath(".."))

In [None]:
# Enable Ray Train v2. This will be the default in an upcoming release.
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
# now it's safe to import from ray.train

In [None]:
import ray

from dist_xgboost.constants import local_storage_path, preprocessor_path

ray.init(runtime_env={"env_vars": {"TRAIN_ENABLE_SHARE_CUDA_VISIBLE_DEVICES": "1"}})

In [None]:
# make ray data less verbose
ray.data.DataContext.get_current().enable_progress_bars = False
ray.data.DataContext.get_current().print_on_execution_start = False

## Dataset Preparation

This tutorial uses the ["Breast Cancer Wisconsin (Diagnostic)"](https://archive.ics.uci.edu/dataset/17/breast+cancer+wisconsin+diagnostic) dataset, which contains features computed from digitized images of breast mass cell nuclei.

The data will be split as follows:
- 70% for training
- 15% for validation
- 15% for testing
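
The 70/15/15 proportions come from two successive splits: the first holds out 30% of the rows, and the second divides that held-out portion in half. The arithmetic can be sketched as follows; the helper name `two_stage_split_fractions` is hypothetical and is not part of Ray Data.

```python
def two_stage_split_fractions(
    first_test_size: float, second_test_size: float
) -> tuple[float, float, float]:
    """Return the (train, validation, test) fractions of the full dataset."""
    train = 1.0 - first_test_size
    # The second split only sees the portion held out by the first split.
    held_out = first_test_size
    valid = held_out * (1.0 - second_test_size)
    test = held_out * second_test_size
    return train, valid, test


train, valid, test = two_stage_split_fractions(0.3, 0.5)
print(f"train={train:.2f}, valid={valid:.2f}, test={test:.2f}")
# prints: train=0.70, valid=0.15, test=0.15
```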

In [None]:
from ray.data import Dataset


def prepare_data() -> tuple[Dataset, Dataset, Dataset]:
    """Load and split the dataset into train, validation, and test sets."""
    # Load the dataset from S3
    dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")
    seed = 42

    # Split 70% for training
    train_dataset, rest = dataset.train_test_split(test_size=0.3, shuffle=True, seed=seed)
    # Split the remaining 30% evenly into 15% validation and 15% testing
    valid_dataset, test_dataset = rest.train_test_split(test_size=0.5, shuffle=True, seed=seed)
    return train_dataset, valid_dataset, test_dataset

In [None]:
# Load and split the dataset
train_dataset, valid_dataset, _test_dataset = prepare_data()
train_dataset.take(1)

Looking at the output, you can see the dataset contains features characterizing cell nuclei in breast mass, such as radius, texture, concavity, and symmetry.

## Data Preprocessing

Notice that the features have different magnitudes and ranges. While tree-based models like XGBoost aren't as sensitive to this, feature scaling can still improve numerical stability in some cases.

Ray Data offers built-in preprocessors that simplify common feature preprocessing tasks, especially for tabular data. These can be seamlessly integrated with Ray Datasets, allowing you to preprocess your data in a fault-tolerant and distributed way.

This example uses Ray's built-in `StandardScaler` to zero-center and normalize the features:

In [None]:
from ray.data.preprocessors import StandardScaler


def train_preprocessor(train_dataset: ray.data.Dataset) -> StandardScaler:
    # Scale every feature column (everything except the "target" label)
    columns_to_scale = [c for c in train_dataset.columns() if c != "target"]

    # Initialize the preprocessor
    preprocessor = StandardScaler(columns=columns_to_scale)
    # train the preprocessor on the training set
    preprocessor.fit(train_dataset)

    return preprocessor


preprocessor = train_preprocessor(train_dataset)

After fitting the preprocessor, the next step is to save it to a file. Later, this artifact can be registered in MLflow, allowing for reuse in downstream pipelines:

In [None]:
import pickle

with open(preprocessor_path, "wb") as f:
    pickle.dump(preprocessor, f)

Next, the datasets are transformed using the fitted preprocessor. It's important to note that the `transform()` operation is lazy; it's only applied to the data when required by the train workers:

In [None]:
train_dataset = preprocessor.transform(train_dataset)
valid_dataset = preprocessor.transform(valid_dataset)
train_dataset.take(1)

Using `take()`, you can observe that the values are now zero-centered and scaled to unit variance, so most of them fall within a few units of zero.
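
To make the effect of standard scaling concrete, here is a minimal pure-Python sketch of the per-column computation: subtract the column mean, then divide by the standard deviation. It assumes the population standard deviation and uses made-up sample values; `standard_scale` is a hypothetical helper, and Ray's `StandardScaler` may differ in implementation details.

```python
from statistics import fmean, pstdev


def standard_scale(values: list[float]) -> list[float]:
    """Zero-center values and scale them to unit standard deviation."""
    mean = fmean(values)
    std = pstdev(values)  # population standard deviation (an assumption)
    if std == 0:
        # A constant column has no spread, so map it to zeros.
        return [0.0 for _ in values]
    return [(v - mean) / std for v in values]


radii = [10.0, 12.0, 14.0, 20.0]  # made-up values for illustration
scaled = standard_scale(radii)
print([round(v, 3) for v in scaled])
```

The scaled column has mean 0 and standard deviation 1, but individual values are not bounded: an outlier such as the last entry above scales to a magnitude greater than 1.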

> **Data Processing Note**:  
> For more advanced data loading and preprocessing techniques, check out the [comprehensive guide](https://docs.ray.io/en/latest/train/user-guides/data-loading-preprocessing.html). Ray Data also supports performant joins, filters, aggregations, and other operations for the more structured data processing that your workloads may require.

## Model Training with XGBoost

Next, configure the training run using [`RunConfig`](https://docs.ray.io/en/latest/train/api/doc/ray.train.RunConfig.html#ray.train.RunConfig). For XGBoost, checkpointing is configured through `RayTrainReportCallback` rather than a `CheckpointConfig`.

In [None]:
from ray.train import Result, RunConfig, ScalingConfig

run_config = RunConfig(
    ## For multi-node clusters, configure storage that is accessible
    ## across all worker nodes with `storage_path="s3://..."`
    storage_path=local_storage_path,
)

### Training with XGBoost

The training parameters are passed as a dictionary, similar to the original [`xgboost.train()`](https://xgboost.readthedocs.io/en/stable/parameter.html) function:

In [None]:
USE_GPU = True

config = {
    "model_config": {"objective": "binary:logistic", "eval_metric": ["logloss", "error"], "max_depth": 2, "verbosity": 3},
    "checkpoint_frequency": 10,
}

### Checkpointing Configuration

Checkpointing is a powerful feature that enables you to resume training from the last checkpoint in case of interruptions. This is particularly useful for long-running training sessions.

[`XGBoostTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.xgboost.XGBoostTrainer.html) implements checkpointing out of the box. You just need to configure [`RayTrainReportCallback`](https://docs.ray.io/en/latest/train/api/doc/ray.train.xgboost.RayTrainReportCallback.html) to set the checkpointing frequency.

> **Note**: Once you enable checkpointing, you can follow [this guide](https://docs.ray.io/en/latest/train/user-guides/fault-tolerance.html) to enable fault tolerance.
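
To build intuition for how the checkpointing frequency interacts with the number of boosting rounds, the following is a simplified model of the schedule. It is a sketch, not Ray Train's implementation: it assumes a checkpoint is written every `frequency` rounds, that `frequency=0` disables periodic checkpoints, and that `checkpoint_at_end=True` adds one final checkpoint if the last round didn't already produce one. The helper `checkpoint_rounds` is hypothetical.

```python
def checkpoint_rounds(
    num_boost_round: int, frequency: int, checkpoint_at_end: bool = True
) -> list[int]:
    """Return the 1-based boosting rounds after which a checkpoint is saved."""
    rounds: list[int] = []
    if frequency > 0:
        rounds = [r for r in range(1, num_boost_round + 1) if r % frequency == 0]
    # Add a final checkpoint unless the last round already produced one.
    last_round_saved = bool(rounds) and rounds[-1] == num_boost_round
    if checkpoint_at_end and num_boost_round > 0 and not last_round_saved:
        rounds.append(num_boost_round)
    return rounds


print(checkpoint_rounds(10, 10))  # [10]
print(checkpoint_rounds(25, 10))  # [10, 20, 25]
print(checkpoint_rounds(25, 0))   # [25]
```

With the values used in this tutorial (10 boosting rounds and a frequency of 10), this model yields a single checkpoint at the end of training.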

In [None]:
import xgboost
from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer

NUM_WORKERS = 4


def train_fn_per_worker(config: dict):
    """Training function that runs on each worker.

    This function:
    1. Gets the dataset shard for this worker
    2. Converts to pandas for XGBoost
    3. Separates features and labels
    4. Creates DMatrix objects
    5. Trains the model using distributed communication
    """
    # Get this worker's dataset shard
    train_ds, val_ds = (
        ray.train.get_dataset_shard("train"),
        ray.train.get_dataset_shard("validation"),
    )
    if USE_GPU:
        # Set the device in the params that are passed to xgboost.train()
        config["model_config"]["device"] = f"cuda:{ray.train.get_context().get_local_rank()}"

    # Materialize the data and convert to pandas
    train_ds = train_ds.materialize().to_pandas()
    val_ds = val_ds.materialize().to_pandas()

    # Separate the labels from the features
    train_X, train_y = train_ds.drop("target", axis=1), train_ds["target"]
    eval_X, eval_y = val_ds.drop("target", axis=1), val_ds["target"]

    # Convert the data into DMatrix format for XGBoost
    dtrain = xgboost.DMatrix(train_X, label=train_y)
    deval = xgboost.DMatrix(eval_X, label=eval_y)

    # Do distributed data-parallel training
    # Ray Train sets up the necessary coordinator processes and
    # environment variables for workers to communicate with each other
    _booster = xgboost.train(
        config["model_config"],
        dtrain=dtrain,
        evals=[(dtrain, "train"), (deval, "validation")],
        num_boost_round=10,
        # Handles metric logging and checkpointing
        callbacks=[
            RayTrainReportCallback(
                frequency=config["checkpoint_frequency"],
                checkpoint_at_end=True,
                metrics=config["model_config"]["eval_metric"],
            )
        ],
    )


trainer = XGBoostTrainer(
    train_fn_per_worker,
    train_loop_config=config,
    # Register the data subsets
    datasets={"train": train_dataset, "validation": valid_dataset},
    # See the "Scaling Strategies" section for more details
    scaling_config=ScalingConfig(
        # Number of workers for data parallelism.
        num_workers=NUM_WORKERS,
        # Set to True to use GPU acceleration
        use_gpu=USE_GPU,
    ),
    run_config=run_config,
)

> **Ray Train Benefits**:
> 
> - **Multi-node orchestration**: Automatically handles multi-node, multi-GPU setup without manual SSH or hostfile configurations
> - **Built-in fault tolerance**: Supports automatic retry of failed workers and can continue from the last checkpoint
> - **Flexible training strategies**: Supports various parallelism strategies beyond just data parallel training
> - **Heterogeneous cluster support**: Define per-worker resource requirements and run on mixed hardware
> 
> Ray Train integrates with popular frameworks like PyTorch, TensorFlow, XGBoost, and more. For enterprise needs, [RayTurbo Train](https://docs.anyscale.com/rayturbo/rayturbo-train) offers additional features like elastic training, advanced monitoring, and performance optimization.
>
> <img src="https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/train_integrations.png" width=500>

Now it's time to train the model:

In [None]:
result: Result = trainer.fit()
result

Observe that at the beginning of the training job, Ray started requesting GPU nodes. This happens automatically to satisfy the training job's requirement of four GPU workers.

Ray Train returns a [`ray.train.Result`](https://docs.ray.io/en/latest/train/api/doc/ray.train.Result.html) object, which contains important properties such as metrics, checkpoint info, and error details:

In [None]:
metrics = result.metrics
metrics

Expected output (your values may differ):

```python
OrderedDict([('train-logloss', 0.05463397157248817),
             ('train-error', 0.00506329113924051),
             ('validation-logloss', 0.06741214815308066),
             ('validation-error', 0.01176470588235294)])
```

As shown in the output, Ray Train logged metrics based on the values configured in `eval_metric` and `evals`.
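
The metric names in the output combine each evaluation set name from `evals` with each entry in `eval_metric`. A small sketch of that naming pattern follows; `expected_metric_keys` is a hypothetical helper written for illustration and isn't part of Ray Train or XGBoost.

```python
def expected_metric_keys(eval_names: list[str], eval_metrics: list[str]) -> list[str]:
    """Combine evaluation set names and metric names as '<set>-<metric>'."""
    return [f"{name}-{metric}" for name in eval_names for metric in eval_metrics]


keys = expected_metric_keys(["train", "validation"], ["logloss", "error"])
print(keys)
# ['train-logloss', 'train-error', 'validation-logloss', 'validation-error']
```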

It's also possible to reconstruct the trained model directly from the checkpoint directory:

In [None]:
booster = RayTrainReportCallback.get_model(result.checkpoint)
booster

## Model Registry

Now that the model is trained, the next step is to save it to a model registry for future use. This tutorial uses MLflow for this purpose, storing the model artifacts in [Anyscale user storage](https://docs.anyscale.com/configuration/storage/#user-storage). It's worth noting that Ray also integrates with [other experiment trackers](https://docs.ray.io/en/latest/train/user-guides/experiment-tracking.html).

In [None]:
import shutil
from tempfile import TemporaryDirectory

import mlflow

from dist_xgboost.constants import (
    experiment_name,
    model_fname,
    model_registry,
    preprocessor_fname,
)


def clean_up_old_runs():
    # Clean up old MLflow runs
    if os.path.isdir(model_registry):
        shutil.rmtree(model_registry)
    # mlflow.delete_experiment(experiment_name)
    os.makedirs(model_registry, exist_ok=True)


def log_run_to_mlflow(model_config, result, preprocessor_path):
    # create a model registry in our user storage
    mlflow.set_tracking_uri(f"file:{model_registry}")

    # create a new experiment and log metrics and artifacts
    mlflow.set_experiment(experiment_name)
    with mlflow.start_run(description="xgboost breast cancer classifier on all features"):
        mlflow.log_params(model_config)
        mlflow.log_metrics(result.metrics)

        # Selectively log just the preprocessor and model weights
        with TemporaryDirectory() as tmp_dir:
            shutil.copy(
                os.path.join(result.checkpoint.path, model_fname),
                os.path.join(tmp_dir, model_fname),
            )
            shutil.copy(
                preprocessor_path,
                os.path.join(tmp_dir, preprocessor_fname),
            )

            mlflow.log_artifacts(tmp_dir)


clean_up_old_runs()
log_run_to_mlflow(config["model_config"], result, preprocessor_path)

You can start the MLflow server to view the logged experiments:

`mlflow server -h 0.0.0.0 -p 8080 --backend-store-uri {model_registry}`

To view the dashboard, go to the **Overview tab** → **Open Ports** → `8080`.

<img src="https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/mlflow.png" width=685>

You can also view the Ray Dashboard and Train workload dashboards:

<img src="https://raw.githubusercontent.com/anyscale/e2e-xgboost/refs/heads/main/images/train_metrics.png" width=700>

The best model can then be retrieved from the registry:

In [None]:
from dist_xgboost.data import get_best_model_from_registry

best_model, artifacts_dir = get_best_model_from_registry()
artifacts_dir
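
The helper `get_best_model_from_registry` is defined in the repository's `dist_xgboost.data` module, and its selection rule isn't shown in this tutorial. As a rough illustration only, the sketch below assumes that "best" means the run with the lowest `validation-error`. The function `pick_best_run` and the run records are hypothetical stand-ins, not MLflow objects.

```python
def pick_best_run(runs: list[dict], metric: str = "validation-error") -> dict:
    """Return the run with the lowest value for the given metric."""
    # Ignore runs that never logged the metric.
    candidates = [run for run in runs if metric in run["metrics"]]
    if not candidates:
        raise ValueError(f"No run reports the metric {metric!r}")
    return min(candidates, key=lambda run: run["metrics"][metric])


# Made-up run records for illustration
runs = [
    {"run_id": "run-a", "metrics": {"validation-error": 0.035}},
    {"run_id": "run-b", "metrics": {"validation-error": 0.012}},
    {"run_id": "run-c", "metrics": {}},
]
best = pick_best_run(runs)
print(best["run_id"])  # run-b
```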

### Production Deployment with Anyscale Jobs

For production scenarios, the training workload can be wrapped as an [Anyscale Job](https://docs.anyscale.com/platform/jobs/) ([API ref](https://docs.anyscale.com/reference/job-api/)):

In [None]:
from dist_xgboost.constants import root_dir

os.environ["WORKING_DIR"] = root_dir

In [None]:
%%bash

# Production batch job
anyscale job submit --name=train-xgboost-breast-cancer-model \
  --containerfile="${WORKING_DIR}/containerfile" \
  --working-dir="${WORKING_DIR}" \
  --exclude="" \
  --max-retries=0 \
  -- python dist_xgboost/train.py

> **Note**: 
> - This example uses a `containerfile` to define dependencies, but using a pre-built image is also an option.
> - You can specify compute requirements as a [compute config](https://docs.anyscale.com/configuration/compute-configuration/) or inline in a [job config](https://docs.anyscale.com/reference/job-api#job-cli).
> - When launched from a workspace without specifying compute, the job defaults to using the workspace's compute configuration.

## Scaling Strategies

One of the key advantages of Ray Train is its ability to effortlessly scale your training workloads. By adjusting the [`ScalingConfig`](https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html), you can optimize resource utilization and reduce training time.

### Scaling Examples

**Multi-node CPU Example** (4 nodes with 8 CPUs each):

```python
scaling_config = ScalingConfig(
    num_workers=4,
    resources_per_worker={"CPU": 8},
)
```

**Single-node multi-GPU Example** (1 node with 8 CPUs and 4 GPUs):

```python
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
)
```

**Multi-node multi-GPU Example** (4 nodes with 8 CPUs and 4 GPUs each):

```python
scaling_config = ScalingConfig(
    num_workers=16,
    use_gpu=True,
)
```

> **Important**: Keep in mind that for multi-node clusters, specifying a shared storage location (like cloud storage or NFS) in the `run_config` is necessary. Using a local path will cause an error during checkpointing.
>
> ```python
> trainer = XGBoostTrainer(
>     ..., run_config=ray.train.RunConfig(storage_path="s3://...")
> )
> ```

### Worker Configuration Guidelines

The optimal number of workers depends on your workload and cluster setup:

- For **CPU-only training**, generally use one worker per node (XGBoost can leverage multiple CPUs with threading)
- For **multi-GPU training**, use one worker per GPU
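- For **heterogeneous clusters**, set the CPUs per worker to the greatest common divisor of the nodes' CPU counts so that workers pack evenly onto every node (for example, nodes with 4, 8, and 12 CPUs fit six workers with 4 CPUs each)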
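
The greatest-common-divisor guideline for heterogeneous clusters can be sketched with the standard library. The helper `plan_cpu_workers` is hypothetical; it only does the arithmetic and doesn't inspect a real cluster.

```python
from functools import reduce
from math import gcd


def plan_cpu_workers(cpus_per_node: list[int]) -> tuple[int, int]:
    """Return (cpus_per_worker, num_workers) so workers pack evenly on every node."""
    # The largest worker size that divides every node's CPU count.
    cpus_per_worker = reduce(gcd, cpus_per_node)
    # Each node hosts as many workers as fit into its CPU count.
    num_workers = sum(cpus // cpus_per_worker for cpus in cpus_per_node)
    return cpus_per_worker, num_workers


print(plan_cpu_workers([4, 8, 12]))    # (4, 6)
print(plan_cpu_workers([8, 8, 8, 8]))  # (8, 4)
```

The resulting values map onto `num_workers` and `resources_per_worker={"CPU": ...}` in a `ScalingConfig`. On a homogeneous cluster the rule reduces to one worker per node.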

### GPU Acceleration

To use GPUs for training:

1. Start one worker per GPU by setting `use_gpu=True`
2. Set GPU-compatible XGBoost parameters, such as the `device` parameter
3. Divide CPUs evenly across the workers on each machine

Example:

```python
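trainer = XGBoostTrainer(
    train_fn_per_worker,
    train_loop_config={
        "model_config": {
            # XGBoost-specific params
            "eval_metric": ["logloss", "error"],
        },
        "checkpoint_frequency": 10,
    },
    scaling_config=ScalingConfig(
        # Number of workers to use for data parallelism.
        num_workers=2,
        # Whether to use GPU acceleration.
        use_gpu=True,
    ),
    # Remaining arguments (datasets, run_config) omitted for brevity
)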
```
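
The worker count and the even CPU split from the steps above come down to simple arithmetic. The sketch below uses a hypothetical helper, `plan_gpu_workers`, and assumes every node has the same number of CPUs and GPUs.

```python
def plan_gpu_workers(
    num_nodes: int, cpus_per_node: int, gpus_per_node: int
) -> tuple[int, dict[str, int]]:
    """Return (num_workers, resources_per_worker) for one worker per GPU."""
    if gpus_per_node <= 0:
        raise ValueError("Each node needs at least one GPU")
    # One worker per GPU across the whole cluster.
    num_workers = num_nodes * gpus_per_node
    # Split each node's CPUs evenly across the workers it hosts.
    resources = {"CPU": cpus_per_node // gpus_per_node, "GPU": 1}
    return num_workers, resources


print(plan_gpu_workers(num_nodes=4, cpus_per_node=8, gpus_per_node=4))
# (16, {'CPU': 2, 'GPU': 1})
```

For the multi-node multi-GPU example earlier in this section (4 nodes with 8 CPUs and 4 GPUs each), this gives 16 workers with 2 CPUs and 1 GPU apiece. The dictionary has the same shape as the `resources_per_worker` argument of `ScalingConfig`.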

For more advanced topics, explore:
- [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) for hyperparameter optimization
- [Ray Serve](https://docs.ray.io/en/latest/serve/index.html) for model deployment
- [Ray Data](https://docs.ray.io/en/latest/data/data.html) for more advanced data processing