
Commit bb2e701

[Dask] Sort and partition for ranking. (dmlc#11007)
- Implement automatic local sort.
- Implement partitioning by query ID.
- Document for distributed ranking.
1 parent 544a52e commit bb2e701

17 files changed: +699 -57 lines

demo/dask/dask_learning_to_rank.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
"""
Learning to rank with the Dask Interface
========================================

.. versionadded:: 3.0.0

This is a demonstration of using XGBoost for learning to rank tasks with the
MSLR_10k_letor dataset. For more information about the dataset, please visit its
`description page <https://www.microsoft.com/en-us/research/project/mslr/>`_.

See :ref:`ltr-dist` for a general description of distributed learning to rank and
:ref:`ltr-dask` for Dask-specific features.

"""
from __future__ import annotations

import argparse
import os
from contextlib import contextmanager
from typing import Generator

import dask
import numpy as np
from dask import dataframe as dd
from distributed import Client, LocalCluster, wait
from sklearn.datasets import load_svmlight_file

from xgboost import dask as dxgb


def load_mslr_10k(
    device: str, data_path: str, cache_path: str
) -> tuple[dd.DataFrame, dd.DataFrame, dd.DataFrame]:
    """Load the MSLR10k dataset from data_path and save parquet files in the cache_path."""
    root_path = os.path.expanduser(data_path)
    cache_path = os.path.expanduser(cache_path)

    # Use only the Fold1 for demo:
    # Train, Valid, Test
    # {S1,S2,S3}, S4, S5
    fold = 1

    if not os.path.exists(cache_path):
        os.mkdir(cache_path)
        fold_path = os.path.join(root_path, f"Fold{fold}")
        train_path = os.path.join(fold_path, "train.txt")
        valid_path = os.path.join(fold_path, "vali.txt")
        test_path = os.path.join(fold_path, "test.txt")

        X_train, y_train, qid_train = load_svmlight_file(
            train_path, query_id=True, dtype=np.float32
        )
        columns = [f"f{i}" for i in range(X_train.shape[1])]
        X_train = dd.from_array(X_train.toarray(), columns=columns)
        y_train = y_train.astype(np.int32)
        qid_train = qid_train.astype(np.int32)

        X_train["y"] = dd.from_array(y_train)
        X_train["qid"] = dd.from_array(qid_train)
        X_train.to_parquet(os.path.join(cache_path, "train"), engine="pyarrow")

        X_valid, y_valid, qid_valid = load_svmlight_file(
            valid_path, query_id=True, dtype=np.float32
        )
        X_valid = dd.from_array(X_valid.toarray(), columns=columns)
        y_valid = y_valid.astype(np.int32)
        qid_valid = qid_valid.astype(np.int32)

        X_valid["y"] = dd.from_array(y_valid)
        X_valid["qid"] = dd.from_array(qid_valid)
        X_valid.to_parquet(os.path.join(cache_path, "valid"), engine="pyarrow")

        X_test, y_test, qid_test = load_svmlight_file(
            test_path, query_id=True, dtype=np.float32
        )

        X_test = dd.from_array(X_test.toarray(), columns=columns)
        y_test = y_test.astype(np.int32)
        qid_test = qid_test.astype(np.int32)

        X_test["y"] = dd.from_array(y_test)
        X_test["qid"] = dd.from_array(qid_test)
        X_test.to_parquet(os.path.join(cache_path, "test"), engine="pyarrow")

    df_train = dd.read_parquet(
        os.path.join(cache_path, "train"), calculate_divisions=True
    )
    df_valid = dd.read_parquet(
        os.path.join(cache_path, "valid"), calculate_divisions=True
    )
    df_test = dd.read_parquet(
        os.path.join(cache_path, "test"), calculate_divisions=True
    )

    return df_train, df_valid, df_test


def ranking_demo(client: Client, args: argparse.Namespace) -> None:
    """Learning to rank with data sorted locally."""
    df_tr, df_va, _ = load_mslr_10k(args.device, args.data, args.cache)

    X_train: dd.DataFrame = df_tr[df_tr.columns.difference(["y", "qid"])]
    y_train = df_tr[["y", "qid"]]
    Xy_train = dxgb.DaskQuantileDMatrix(client, X_train, y_train.y, qid=y_train.qid)

    X_valid: dd.DataFrame = df_va[df_va.columns.difference(["y", "qid"])]
    y_valid = df_va[["y", "qid"]]
    Xy_valid = dxgb.DaskQuantileDMatrix(
        client, X_valid, y_valid.y, qid=y_valid.qid, ref=Xy_train
    )
    # Upon training, you will see a performance warning about sorting data based on
    # query groups.
    dxgb.train(
        client,
        {"objective": "rank:ndcg", "device": args.device},
        Xy_train,
        evals=[(Xy_train, "Train"), (Xy_valid, "Valid")],
        num_boost_round=100,
    )


def ranking_wo_split_demo(client: Client, args: argparse.Namespace) -> None:
    """Learning to rank with data partitioned according to query groups."""
    df_tr, df_va, df_te = load_mslr_10k(args.device, args.data, args.cache)

    X_tr = df_tr[df_tr.columns.difference(["y", "qid"])]
    X_va = df_va[df_va.columns.difference(["y", "qid"])]

    # `allow_group_split=False` makes sure data is partitioned according to the query
    # groups.
    ltr = dxgb.DaskXGBRanker(allow_group_split=False, device=args.device)
    ltr.client = client
    ltr = ltr.fit(
        X_tr,
        df_tr.y,
        qid=df_tr.qid,
        eval_set=[(X_tr, df_tr.y), (X_va, df_va.y)],
        eval_qid=[df_tr.qid, df_va.qid],
        verbose=True,
    )

    df_te = df_te.persist()
    wait([df_te])

    X_te = df_te[df_te.columns.difference(["y", "qid"])]
    predt = ltr.predict(X_te)
    y = client.compute(df_te.y)
    wait([predt, y])


@contextmanager
def gen_client(device: str) -> Generator[Client, None, None]:
    match device:
        case "cuda":
            from dask_cuda import LocalCUDACluster

            with LocalCUDACluster() as cluster:
                with Client(cluster) as client:
                    with dask.config.set(
                        {
                            "array.backend": "cupy",
                            "dataframe.backend": "cudf",
                        }
                    ):
                        yield client
        case "cpu":
            with LocalCluster() as cluster:
                with Client(cluster) as client:
                    yield client


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Demonstration of learning to rank using XGBoost."
    )
    parser.add_argument(
        "--data",
        type=str,
        help="Root directory of the MSLR-WEB10K data.",
        required=True,
    )
    parser.add_argument(
        "--cache",
        type=str,
        help="Directory for caching processed data.",
        required=True,
    )
    parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
    parser.add_argument(
        "--no-split",
        action="store_true",
        help="Flag to indicate query groups should not be split.",
    )
    args = parser.parse_args()

    with gen_client(args.device) as client:
        if args.no_split:
            ranking_wo_split_demo(client, args)
        else:
            ranking_demo(client, args)

demo/guide-python/learning_to_rank.py

Lines changed: 5 additions & 5 deletions
@@ -12,8 +12,8 @@
 train on relevance degree, and the second part simulates click data and enable the
 position debiasing training.
 
-For an overview of learning to rank in XGBoost, please see
-:doc:`Learning to Rank </tutorials/learning_to_rank>`.
+For an overview of learning to rank in XGBoost, please see :doc:`Learning to Rank
+</tutorials/learning_to_rank>`.
 """
 
 from __future__ import annotations
@@ -31,7 +31,7 @@
 from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples
 
 
-def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
+def load_mslr_10k(data_path: str, cache_path: str) -> RelDataCV:
     """Load the MSLR10k dataset from data_path and cache a pickle object in cache_path.
 
     Returns
@@ -89,7 +89,7 @@ def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
 
 def ranking_demo(args: argparse.Namespace) -> None:
     """Demonstration for learning to rank with relevance degree."""
-    data = load_mlsr_10k(args.data, args.cache)
+    data = load_mslr_10k(args.data, args.cache)
 
     # Sort data according to query index
     X_train, y_train, qid_train = data.train
@@ -123,7 +123,7 @@ def ranking_demo(args: argparse.Namespace) -> None:
 
 def click_data_demo(args: argparse.Namespace) -> None:
     """Demonstration for learning to rank with click data."""
-    data = load_mlsr_10k(args.data, args.cache)
+    data = load_mslr_10k(args.data, args.cache)
     train, test = simulate_clicks(data)
     assert test is not None

doc/tutorials/dask.rst

Lines changed: 53 additions & 9 deletions
@@ -355,15 +355,18 @@ Working with asyncio
 
 .. versionadded:: 1.2.0
 
-XGBoost's dask interface supports the new ``asyncio`` in Python and can be integrated into
-asynchronous workflows. For using dask with asynchronous operations, please refer to
-`this dask example <https://examples.dask.org/applications/async-await.html>`_ and document in
-`distributed <https://distributed.dask.org/en/latest/asynchronous.html>`_. To use XGBoost's
-dask interface asynchronously, the ``client`` which is passed as an argument for training and
-prediction must be operating in asynchronous mode by specifying ``asynchronous=True`` when the
-``client`` is created (example below). All functions (including ``DaskDMatrix``) provided
-by the functional interface will then return coroutines which can then be awaited to retrieve
-their result.
+XGBoost's Dask interface supports Python's :py:mod:`asyncio` and can be integrated into
+asynchronous workflows. For using Dask with asynchronous operations, please refer to `this
+dask example <https://examples.dask.org/applications/async-await.html>`_ and the
+documentation of `distributed <https://distributed.dask.org/en/latest/asynchronous.html>`_.
+To use XGBoost's Dask interface asynchronously, the ``client`` passed as an argument for
+training and prediction must be operating in asynchronous mode, which is done by specifying
+``asynchronous=True`` when the ``client`` is created (example below). All functions
+(including ``DaskDMatrix``) provided by the functional interface will then return
+coroutines that can be awaited to retrieve their results. Please note that XGBoost is a
+compute-bound application, where parallelism matters more than concurrency; the support for
+:py:mod:`asyncio` is mainly about compatibility rather than performance gains.
 
 Functional interface:
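
As a concrete illustration of the paragraph rewritten above (not part of this commit), here
is a minimal sketch of asynchronous usage with a local cluster and toy random data; the key
point is that every functional-interface call is awaited:

import asyncio

from dask import array as da
from distributed import Client, LocalCluster

from xgboost import dask as dxgb


async def main() -> None:
    # Both the cluster and the client run in asynchronous mode.
    async with LocalCluster(asynchronous=True) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=(100,))
            # With an asynchronous client, the functional interface returns
            # coroutines, so each call is awaited.
            Xy = await dxgb.DaskDMatrix(client, X, y)
            output = await dxgb.train(client, {"objective": "reg:squarederror"}, Xy)
            predt = await dxgb.predict(client, output, Xy)
            print(await client.compute(predt))


if __name__ == "__main__":
    asyncio.run(main())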

@@ -526,6 +529,47 @@ See https://github.com/coiled/dask-xgboost-nyctaxi for a set of examples of usin
 with dask and optuna.
 
 
+.. _ltr-dask:
+
+****************
+Learning to Rank
+****************
+
+.. versionadded:: 3.0.0
+
+.. note::
+
+   Position debiasing is not yet supported.
+
+There are two operation modes for Dask learning to rank, for performance reasons. The
+difference is whether a distributed global sort is needed. Please see :ref:`ltr-dist` for
+how ranking works with distributed training in general. Below we discuss some of the
+Dask-specific features.
+
+First, if you use the :py:class:`~xgboost.dask.DaskQuantileDMatrix` interface or the
+:py:class:`~xgboost.dask.DaskXGBRanker` with ``allow_group_split`` set to ``True``,
+XGBoost tries to sort and group the samples on each worker based on the query ID. This
+mode skips the global sort and sorts only worker-local data, and hence needs no
+inter-worker data shuffle. Please note that even a worker-local sort is costly,
+particularly in terms of memory usage, since there is no spilling when
+:py:meth:`~pandas.DataFrame.sort_values` is used and the data needs to be concatenated.
+XGBoost first checks whether the query IDs are already sorted before actually performing
+the sort. Choose this mode if the query groups are relatively consecutive, meaning most of
+the samples within a query group are close to each other and are likely to reside on the
+same worker. Don't use it if you have performed a random shuffle on your data.
+
+If the input data is randomly distributed, there is no way to guarantee that most of the
+data within the same group ends up on the same worker. For large query groups this might
+not be an issue, but for small query groups it's possible that each worker gets only one
+or two samples from every group, which can lead to disastrous performance. In that case,
+we can partition the data according to query group, which is the default behavior of the
+:py:class:`~xgboost.dask.DaskXGBRanker` unless ``allow_group_split`` is set to ``True``.
+This mode performs a sort and a groupby on the entire dataset in addition to an encoding
+operation for the query group IDs. Along with partition fragmentation, this option can
+lead to slow performance. See
+:ref:`sphx_glr_python_dask-examples_dask_learning_to_rank.py` for a worked example.
+
 .. _tracker-ip:
 
 ***************
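
To make the two modes described in the new tutorial section above concrete, here is a
minimal sketch distilled from the demo added by this commit. The helper names are mine,
and ``df_train``/``df_valid`` are assumed to be Dask DataFrames with feature columns plus
``y`` and ``qid`` columns, as produced by ``load_mslr_10k`` in the demo:

from dask import dataframe as dd
from distributed import Client

from xgboost import dask as dxgb


def fit_with_local_sort(client: Client, df_train: dd.DataFrame, df_valid: dd.DataFrame):
    """Mode 1: keep the existing partitions; XGBoost sorts the worker-local data by qid."""
    X_tr = df_train[df_train.columns.difference(["y", "qid"])]
    X_va = df_valid[df_valid.columns.difference(["y", "qid"])]
    Xy_tr = dxgb.DaskQuantileDMatrix(client, X_tr, df_train.y, qid=df_train.qid)
    Xy_va = dxgb.DaskQuantileDMatrix(
        client, X_va, df_valid.y, qid=df_valid.qid, ref=Xy_tr
    )
    return dxgb.train(
        client,
        {"objective": "rank:ndcg"},
        Xy_tr,
        evals=[(Xy_va, "Valid")],
        num_boost_round=10,
    )


def fit_with_group_partition(client: Client, df_train: dd.DataFrame) -> dxgb.DaskXGBRanker:
    """Mode 2: let XGBoost repartition the data so that no query group is split."""
    X_tr = df_train[df_train.columns.difference(["y", "qid"])]
    ranker = dxgb.DaskXGBRanker(allow_group_split=False)
    ranker.client = client
    return ranker.fit(X_tr, df_train.y, qid=df_train.qid)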

doc/tutorials/learning_to_rank.rst

Lines changed: 17 additions & 1 deletion
@@ -165,10 +165,26 @@ On the other hand, if you have comparatively small amount of training data:
 
 For any method chosen, you can modify ``lambdarank_num_pair_per_sample`` to control the amount of pairs generated.
 
+.. _ltr-dist:
+
 ********************
 Distributed Training
 ********************
-XGBoost implements distributed learning-to-rank with integration of multiple frameworks including Dask, Spark, and PySpark. The interface is similar to the single-node counterpart. Please refer to document of the respective XGBoost interface for details. Scattering a query group onto multiple workers is theoretically sound but can affect the model accuracy. For most of the use cases, the small discrepancy is not an issue, as the amount of training data is usually large when distributed training is used. As a result, users don't need to partition the data based on query groups. As long as each data partition is correctly sorted by query IDs, XGBoost can aggregate sample gradients accordingly.
+
+XGBoost implements distributed learning-to-rank with integration of multiple frameworks,
+including :doc:`Dask </tutorials/dask>`, :doc:`Spark </jvm/xgboost4j_spark_tutorial>`, and
+:doc:`PySpark </tutorials/spark_estimator>`. The interface is similar to the single-node
+counterpart. Please refer to the documentation of the respective XGBoost interface for
+details.
+
+.. warning::
+
+   Position-debiasing is not yet supported for existing distributed interfaces.
+
+XGBoost works with collective operations, which means data is scattered to multiple
+workers. We can divide the data partitions by query group and ensure that no query group
+is split across workers. However, this requires a costly sort and groupby operation and
+might only be necessary for selected use cases. Splitting and scattering a query group
+over multiple workers is theoretically sound but can affect the model's accuracy. If only
+a small number of groups sit at the boundaries between workers, the small discrepancy is
+not an issue, as the amount of training data is usually large when distributed training
+is used.
+
+For a longer explanation, assume the pairwise ranking method is used: we calculate the
+gradient based on relevance degree by constructing pairs within a query group. If a single
+query group is split among workers and we use worker-local data for the gradient
+calculation, then we are simply sampling pairs from a smaller group on each worker to
+calculate the gradient and the evaluation metric. The comparison within each pair doesn't
+change because a group is split into sub-groups; what changes is the number of total and
+effective pairs, and normalizers like ``IDCG``. One can generate more pairs from one large
+group than from two smaller subgroups. As a result, the obtained gradient is still valid
+from a theoretical standpoint but might not be optimal. As long as the data partitions
+within a worker are correctly sorted by query IDs, XGBoost can aggregate sample gradients
+accordingly. Both the (Py)Spark interface and the Dask interface can sort the data
+according to query ID; please see the respective tutorials for more information.
+
+However, it's possible that a distributed framework shuffles the data during map-reduce
+and splits every query group across multiple workers. In that case, the performance would
+be disastrous. Consequently, whether a sorted groupby is needed depends on the data and
+the framework.
 
 *******************
 Reproducible Result
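
The requirement above, that the data in each partition be sorted by query ID, can be
illustrated with a small single-node sketch (a hypothetical helper, not part of this
commit); the same layout is what the distributed interfaces arrange per worker, so that
group boundaries can be recovered from the ``qid`` column alone:

import numpy as np
import xgboost as xgb


def make_sorted_dmatrix(X: np.ndarray, y: np.ndarray, qid: np.ndarray) -> xgb.DMatrix:
    """Sort samples by query ID so that the rows of each query group are contiguous."""
    # A stable sort keeps the original relative order within each query group.
    order = np.argsort(qid, kind="stable")
    return xgb.DMatrix(X[order], label=y[order], qid=qid[order])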

python-package/xgboost/core.py

Lines changed: 1 addition & 5 deletions
@@ -3215,11 +3215,7 @@ def trees_to_dataframe(self, fmap: PathLike = "") -> DataFrame:
             }
         )
 
-        if callable(getattr(df, "sort_values", None)):
-            # pylint: disable=no-member
-            return df.sort_values(["Tree", "Node"]).reset_index(drop=True)
-        # pylint: disable=no-member
-        return df.sort(["Tree", "Node"]).reset_index(drop=True)
+        return df.sort_values(["Tree", "Node"]).reset_index(drop=True)
 
     def _assign_dmatrix_features(self, data: DMatrix) -> None:
         if data.num_row() == 0:
