44import logging
55import os
66import pickle
7+ from dataclasses import dataclass
78from enum import IntEnum , unique
8- from typing import Any , Dict , List , Optional
9+ from typing import Any , Dict , Optional , TypeAlias , Union
910
1011import numpy as np
1112
# Module-wide logger; the bracketed name is the literal logger name.
LOGGER = logging.getLogger("[xgboost.collective]")
1617
1718
18- def init (** args : Any ) -> None :
19+ _ArgVals : TypeAlias = Optional [Union [int , str ]]
20+ _Args : TypeAlias = Dict [str , _ArgVals ]
21+
22+
23+ @dataclass
24+ class Config :
25+ """User configuration for the communicator context. This is used for easier
26+ integration with distributed frameworks. Users of the collective module can pass the
27+ parameters directly into tracker and the communicator.
28+
29+ .. versionadded:: 3.0
30+
31+ Attributes
32+ ----------
33+ retry : See `dmlc_retry` in :py:meth:`init`.
34+
35+ timeout :
36+ See `dmlc_timeout` in :py:meth:`init`. This is only used for communicators, not
37+ the tracker. They are different parameters since the timeout for tracker limits
38+ only the time for starting and finalizing the communication group, whereas the
39+ timeout for communicators limits the time used for collective operations.
40+
41+ tracker_host_ip : See :py:class:`~xgboost.tracker.RabitTracker`.
42+
43+ tracker_port : See :py:class:`~xgboost.tracker.RabitTracker`.
44+
45+ tracker_timeout : See :py:class:`~xgboost.tracker.RabitTracker`.
46+
47+ """
48+
49+ retry : Optional [int ] = None
50+ timeout : Optional [int ] = None
51+
52+ tracker_host_ip : Optional [str ] = None
53+ tracker_port : Optional [int ] = None
54+ tracker_timeout : Optional [int ] = None
55+
56+ def get_comm_config (self , args : _Args ) -> _Args :
57+ """Update the arguments for the communicator."""
58+ if self .retry is not None :
59+ args ["dmlc_retry" ] = self .retry
60+ if self .timeout is not None :
61+ args ["dmlc_timeout" ] = self .timeout
62+ return args
63+
64+
65+ def init (** args : _ArgVals ) -> None :
1966 """Initialize the collective library with arguments.
2067
2168 Parameters
@@ -36,9 +83,7 @@ def init(**args: Any) -> None:
3683 - dmlc_timeout: Timeout in seconds.
3784 - dmlc_nccl_path: Path to load (dlopen) nccl for GPU-based communication.
3885
39- Only applicable to the Federated communicator (use upper case for environment
40- variables, use lower case for runtime configuration):
41-
86+ Only applicable to the Federated communicator:
4287 - federated_server_address: Address of the federated server.
4388 - federated_world_size: Number of federated workers.
4489 - federated_rank: Rank of the current worker.
@@ -47,6 +92,9 @@ def init(**args: Any) -> None:
4792 - federated_client_key: Client key file path. Only needed for the SSL mode.
4893 - federated_client_cert: Client certificate file path. Only needed for the SSL
4994 mode.
95+
96+ Use upper case for environment variables, use lower case for runtime configuration.
97+
5098 """
5199 _check_call (_LIB .XGCommunicatorInit (make_jcargs (** args )))
52100
@@ -117,7 +165,6 @@ def get_processor_name() -> str:
117165 name_str = ctypes .c_char_p ()
118166 _check_call (_LIB .XGCommunicatorGetProcessorName (ctypes .byref (name_str )))
119167 value = name_str .value
120- assert value
121168 return py_str (value )
122169
123170
@@ -247,7 +294,7 @@ def signal_error() -> None:
247294class CommunicatorContext :
248295 """A context controlling collective communicator initialization and finalization."""
249296
250- def __init__ (self , ** args : Any ) -> None :
297+ def __init__ (self , ** args : _ArgVals ) -> None :
251298 self .args = args
252299 key = "dmlc_nccl_path"
253300 if args .get (key , None ) is not None :
@@ -275,12 +322,12 @@ def __init__(self, **args: Any) -> None:
275322 except ImportError :
276323 pass
277324
def __enter__(self) -> _Args:
    """Initialize the communicator and return the arguments it was given.

    Calls :py:func:`init` with the stored keyword arguments and then checks
    that :py:func:`is_distributed` reports true.
    """
    init(**self.args)
    assert is_distributed()
    LOGGER.debug("-------------- communicator say hello ------------------")
    return self.args
283330
def __exit__(self, *args: Any) -> None:
    """Finalize the communicator; the exception details passed in are ignored."""
    finalize()
    LOGGER.debug("--------------- communicator say bye ------------------")
0 commit comments