Skip to content

Commit 6891f51

Browse files
authored
[coll] Expose configuration. (dmlc#10983)
1 parent b835917 commit 6891f51

File tree

20 files changed

+437
-189
lines changed

20 files changed

+437
-189
lines changed

doc/python/python_api.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,17 @@ PySpark API
192192
:members:
193193
:inherited-members:
194194
:show-inheritance:
195+
196+
197+
Collective
198+
----------
199+
200+
.. automodule:: xgboost.collective
201+
202+
.. autoclass:: xgboost.collective.Config
203+
204+
.. autofunction:: xgboost.collective.init
205+
206+
.. automodule:: xgboost.tracker
207+
208+
.. autoclass:: xgboost.tracker.RabitTracker

doc/tutorials/dask.rst

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -536,25 +536,22 @@ Troubleshooting
536536
- In some environments XGBoost might fail to resolve the IP address of the scheduler; a
537537
symptom is the user receiving an ``OSError: [Errno 99] Cannot assign requested address`` error
538538
during training. A quick workaround is to specify the address explicitly. To do that
539-
dask config is used:
539+
the collective :py:class:`~xgboost.collective.Config` is used:
540540

541-
.. versionadded:: 1.6.0
541+
.. versionadded:: 3.0.0
542542

543543
.. code-block:: python
544544
545545
import dask
546546
from distributed import Client
547547
from xgboost import dask as dxgb
548+
from xgboost.collective import Config
549+
548550
# let xgboost know the scheduler address
549-
dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
551+
coll_cfg = Config(retry=1, timeout=20, tracker_host_ip="10.23.170.98", tracker_port=0)
550552
551553
with Client(scheduler_file="sched.json") as client:
552-
reg = dxgb.DaskXGBRegressor()
553-
554-
# We can specify the port for XGBoost as well
555-
with dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}):
556-
reg = dxgb.DaskXGBRegressor()
557-
554+
reg = dxgb.DaskXGBRegressor(coll_cfg=coll_cfg)
558555
559556
- Please note that XGBoost requires a different port than dask. By default, on a unix-like
560557
system XGBoost uses port 0 to find available ports, which may fail if a user is

python-package/xgboost/collective.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import logging
55
import os
66
import pickle
7+
from dataclasses import dataclass
78
from enum import IntEnum, unique
8-
from typing import Any, Dict, List, Optional
9+
from typing import Any, Dict, Optional, TypeAlias, Union
910

1011
import numpy as np
1112

@@ -15,7 +16,53 @@
1516
LOGGER = logging.getLogger("[xgboost.collective]")
1617

1718

18-
def init(**args: Any) -> None:
19+
_ArgVals: TypeAlias = Optional[Union[int, str]]
20+
_Args: TypeAlias = Dict[str, _ArgVals]
21+
22+
23+
@dataclass
24+
class Config:
25+
"""User configuration for the communicator context. This is used for easier
26+
integration with distributed frameworks. Users of the collective module can pass the
27+
parameters directly into the tracker and the communicator.
28+
29+
.. versionadded:: 3.0
30+
31+
Attributes
32+
----------
33+
retry : See `dmlc_retry` in :py:func:`init`.
34+
35+
timeout :
36+
See `dmlc_timeout` in :py:func:`init`. This is only used for communicators, not
37+
the tracker. They are different parameters since the timeout for tracker limits
38+
only the time for starting and finalizing the communication group, whereas the
39+
timeout for communicators limits the time used for collective operations.
40+
41+
tracker_host_ip : See :py:class:`~xgboost.tracker.RabitTracker`.
42+
43+
tracker_port : See :py:class:`~xgboost.tracker.RabitTracker`.
44+
45+
tracker_timeout : See :py:class:`~xgboost.tracker.RabitTracker`.
46+
47+
"""
48+
49+
retry: Optional[int] = None
50+
timeout: Optional[int] = None
51+
52+
tracker_host_ip: Optional[str] = None
53+
tracker_port: Optional[int] = None
54+
tracker_timeout: Optional[int] = None
55+
56+
def get_comm_config(self, args: _Args) -> _Args:
57+
"""Update the arguments for the communicator."""
58+
if self.retry is not None:
59+
args["dmlc_retry"] = self.retry
60+
if self.timeout is not None:
61+
args["dmlc_timeout"] = self.timeout
62+
return args
63+
64+
65+
def init(**args: _ArgVals) -> None:
1966
"""Initialize the collective library with arguments.
2067
2168
Parameters
@@ -36,9 +83,7 @@ def init(**args: Any) -> None:
3683
- dmlc_timeout: Timeout in seconds.
3784
- dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
3885
39-
Only applicable to the Federated communicator (use upper case for environment
40-
variables, use lower case for runtime configuration):
41-
86+
Only applicable to the Federated communicator:
4287
- federated_server_address: Address of the federated server.
4388
- federated_world_size: Number of federated workers.
4489
- federated_rank: Rank of the current worker.
@@ -47,6 +92,9 @@ def init(**args: Any) -> None:
4792
- federated_client_key: Client key file path. Only needed for the SSL mode.
4893
- federated_client_cert: Client certificate file path. Only needed for the SSL
4994
mode.
95+
96+
Use upper case for environment variables and lower case for runtime configuration.
97+
5098
"""
5199
_check_call(_LIB.XGCommunicatorInit(make_jcargs(**args)))
52100

@@ -117,7 +165,6 @@ def get_processor_name() -> str:
117165
name_str = ctypes.c_char_p()
118166
_check_call(_LIB.XGCommunicatorGetProcessorName(ctypes.byref(name_str)))
119167
value = name_str.value
120-
assert value
121168
return py_str(value)
122169

123170

@@ -247,7 +294,7 @@ def signal_error() -> None:
247294
class CommunicatorContext:
248295
"""A context controlling collective communicator initialization and finalization."""
249296

250-
def __init__(self, **args: Any) -> None:
297+
def __init__(self, **args: _ArgVals) -> None:
251298
self.args = args
252299
key = "dmlc_nccl_path"
253300
if args.get(key, None) is not None:
@@ -275,12 +322,12 @@ def __init__(self, **args: Any) -> None:
275322
except ImportError:
276323
pass
277324

278-
def __enter__(self) -> Dict[str, Any]:
325+
def __enter__(self) -> _Args:
279326
init(**self.args)
280327
assert is_distributed()
281328
LOGGER.debug("-------------- communicator say hello ------------------")
282329
return self.args
283330

284-
def __exit__(self, *args: List) -> None:
331+
def __exit__(self, *args: Any) -> None:
285332
finalize()
286333
LOGGER.debug("--------------- communicator say bye ------------------")

python-package/xgboost/compat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
# pylint: disable= invalid-name, unused-import
1+
# pylint: disable=invalid-name,unused-import
22
"""For compatibility and optional dependencies."""
33
import importlib.util
44
import logging
55
import sys
66
import types
7-
from typing import Any, Dict, List, Optional, Sequence, cast
7+
from typing import Any, Sequence, cast
88

99
import numpy as np
1010

@@ -13,8 +13,9 @@
1313
assert sys.version_info[0] == 3, "Python 2 is no longer supported."
1414

1515

16-
def py_str(x: bytes) -> str:
16+
def py_str(x: bytes | None) -> str:
1717
"""convert c string back to python string"""
18+
assert x is not None # ctypes might return None
1819
return x.decode("utf-8") # type: ignore
1920

2021

0 commit comments

Comments
 (0)