diff --git a/horovod/ray/elastic_v2.py b/horovod/ray/elastic_v2.py
new file mode 100644
index 0000000000..623408ea4c
--- /dev/null
+++ b/horovod/ray/elastic_v2.py
@@ -0,0 +1,304 @@
+from typing import Dict, Callable, Any, Optional, List
+import logging
+import ray.exceptions
+import socket
+
+import time
+import threading
+
+from horovod.runner.http.http_server import RendezvousServer
+from horovod.ray.utils import detect_nics
+from horovod.runner.elastic.rendezvous import create_rendezvous_handler
+from horovod.runner.gloo_run import (create_slot_env_vars, create_run_env_vars,
+                                     _get_min_start_hosts)
+from horovod.ray.worker import BaseHorovodWorker
+from horovod.ray.elastic import RayHostDiscovery
+from horovod.runner.elastic.driver import ElasticDriver
+
+logger = logging.getLogger(__name__)
+
+if hasattr(ray.exceptions, "GetTimeoutError"):
+    GetTimeoutError = ray.exceptions.GetTimeoutError
+elif hasattr(ray.exceptions, "RayTimeoutError"):
+    GetTimeoutError = ray.exceptions.RayTimeoutError
+else:
+    raise ImportError("Unable to find a Ray timeout error class "
+                      "(GetTimeoutError, RayTimeoutError). "
+                      "This is likely due to the installed Ray version "
+                      "not being compatible with Horovod-Ray.")
+
+class ElasticAdapter:
+    """Adapter for executing Ray calls for elastic Horovod jobs."""
+    def __init__(self,
+                 settings,
+                 min_np: int,
+                 max_np: Optional[int] = None,
+                 use_gpu: bool = False,
+                 cpus_per_worker: int = 1,
+                 gpus_per_worker: Optional[int] = None,
+                 env_vars: dict = None,
+                 override_discovery: bool = True,
+                 reset_limit: int = None,
+                 elastic_timeout: int = 600,
+                 **kwargs: Optional[Dict]):
+        self.settings = settings
+        if override_discovery:
+            settings.discovery = RayHostDiscovery(
+                use_gpu=use_gpu,
+                cpus_per_slot=cpus_per_worker,
+                gpus_per_slot=gpus_per_worker)
+        self.cpus_per_worker = cpus_per_worker
+        self.gpus_per_worker = gpus_per_worker
+        self.use_gpu = use_gpu
+        # moved from settings
+        self.min_np = min_np
+        self.max_np = max_np
+        self.np = min_np
+        self.reset_limit = reset_limit
+        self.elastic_timeout = elastic_timeout
+        self.driver = None
+        self.rendezvous = None
+        self.env_vars = env_vars or {}
+
+    def start(self,
+              executable_cls: type = None,
+              executable_args: Optional[List] = None,
+              executable_kwargs: Optional[Dict] = None):
+
+        self.rendezvous = RendezvousServer(self.settings.verbose)
+        self.driver = ElasticDriver(
+            rendezvous=self.rendezvous,
+            discovery=self.settings.discovery,
+            min_np=self.min_np,
+            max_np=self.max_np,
+            timeout=self.elastic_timeout,
+            reset_limit=self.reset_limit,
+            verbose=self.settings.verbose)
+        handler = create_rendezvous_handler(self.driver)
+        logger.debug("[ray] starting rendezvous")
+        global_rendezv_port = self.rendezvous.start(handler)
+
+        logger.debug(f"[ray] waiting for {self.np} slots to become available.")
+        self.driver.wait_for_available_slots(self.np)
+
+        # Host-to-host common interface detection
+        # requires at least 2 hosts in an elastic job.
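+        # min_hosts comes from _get_min_start_hosts; wait_for_available_slots
+        # blocks until discovery reports that many hosts, and only then are
+        # the common network interfaces (NICs) across those hosts detected.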
+ min_hosts = _get_min_start_hosts(self.settings) + current_hosts = self.driver.wait_for_available_slots( + self.np, min_hosts=min_hosts) + logger.debug("[ray] getting common interfaces") + nics = detect_nics( + self.settings, + all_host_names=current_hosts.host_assignment_order, + ) + logger.debug("[ray] getting driver IP") + server_ip = socket.gethostbyname(socket.gethostname()) + self.run_env_vars = create_run_env_vars( + server_ip, nics, global_rendezv_port, elastic=True) + + self.executable_cls = executable_cls + self.executable_args = executable_args + self.executable_kwargs = executable_kwargs + + + def _create_resources(self, hostname: str): + resources = dict( + num_cpus=self.cpus_per_worker, + num_gpus=int(self.use_gpu) * self.gpus_per_worker, + resources={f"node:{hostname}": 0.01}) + return resources + + def _create_remote_worker(self, slot_info, worker_env_vars): + hostname = slot_info.hostname + loaded_worker_cls = self.remote_worker_cls.options( + **self._create_resources(hostname)) + + worker = loaded_worker_cls.remote() + worker.update_env_vars.remote(worker_env_vars) + worker.update_env_vars.remote(create_slot_env_vars(slot_info)) + if self.use_gpu: + visible_devices = ",".join( + [str(i) for i in range(slot_info.local_size)]) + worker.update_env_vars.remote({ + "CUDA_VISIBLE_DEVICES": + visible_devices + }) + return worker + + def _create_spawn_worker_fn(self, return_results: List, + worker_fn: Callable, + queue: "ray.util.Queue") -> Callable: + self.remote_worker_cls = ray.remote(BaseHorovodWorker) + # event = register_shutdown_event() + worker_env_vars = {} + worker_env_vars.update(self.run_env_vars.copy()) + worker_env_vars.update(self.env_vars.copy()) + worker_env_vars.update({"PYTHONUNBUFFERED": "1"}) + + def worker_loop(slot_info, events): + def ping_worker(worker): + # There is an odd edge case where a node can be removed + # before the remote worker is started, leading to a failure + # in trying to create the horovod mesh. + try: + ping = worker.execute.remote(lambda _: 1) + ray.get(ping, timeout=10) + except Exception as e: + logger.error(f"{slot_info.hostname}: Ping failed - {e}") + return False + return True + + worker = self._create_remote_worker(slot_info, worker_env_vars) + if not ping_worker(worker): + return 1, time.time() + + ray.get(worker.set_queue.remote(queue)) + future = worker.execute.remote(worker_fn) + + result = None + while result is None: + try: + # TODO: make this event driven at some point. + retval = ray.get(future, timeout=0.1) + return_results.append((slot_info.rank, retval)) + # Success + result = 0, time.time() + except GetTimeoutError: + # Timeout + if any(e.is_set() for e in events): + ray.kill(worker) + result = 1, time.time() + except Exception as e: + logger.error(f"{slot_info.hostname}[{slot_info.rank}]:{e}") + ray.kill(worker) + result = 1, time.time() + logger.debug(f"Worker ({slot_info}) routine is done!") + return result + + return worker_loop + + + def run(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. + + Returns: + Deserialized return values from the target function. 
+ """ + args = args or [] + kwargs = kwargs or {} + f = lambda _: fn(*args, **kwargs) + return self._run_remote(f, callbacks=callbacks) + + def _run_remote(self, + worker_fn: Callable, + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + worker_fn: Target elastic function that can be executed. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. + + Returns: + List of return values from every completed worker. + """ + return_values = [] + from ray.util.queue import Queue + import inspect + args = inspect.getfullargspec(Queue).args + if "actor_options" not in args: + # Ray 1.1 and less + _queue = Queue() + else: + _queue = Queue(actor_options={ + "num_cpus": 0, + "resources": { + ray.state.current_node_id(): 0.001 + } + }) + self.driver.start( + self.np, + self._create_spawn_worker_fn(return_values, worker_fn, _queue)) + + def _process_calls(queue, callbacks, event): + if not callbacks: + return + while queue.actor: + if not queue.empty(): + result = queue.get_nowait() + for c in callbacks: + c(result) + # avoid slamming the CI + elif event.is_set(): + break + time.sleep(0.1) + + try: + event = threading.Event() + _callback_thread = threading.Thread( + target=_process_calls, + args=(_queue, callbacks, event), + daemon=True) + _callback_thread.start() + res = self.driver.get_results() + event.set() + if _callback_thread: + _callback_thread.join(timeout=60) + finally: + if hasattr(_queue, "shutdown"): + _queue.shutdown() + else: + done_ref = _queue.actor.__ray_terminate__.remote() + done, not_done = ray.wait([done_ref], timeout=5) + if not_done: + ray.kill(_queue.actor) + self.driver.stop() + + if res.error_message is not None: + raise RuntimeError(res.error_message) + + for name, value in sorted( + res.worker_results.items(), key=lambda item: item[1][1]): + exit_code, timestamp = value + if exit_code != 0: + raise RuntimeError( + 'Horovod detected that one or more processes ' + 'exited with non-zero ' + 'status, thus causing the job to be terminated. ' + 'The first process ' + 'to do so was:\nProcess name: {name}\nExit code: {code}\n' + .format(name=name, code=exit_code)) + + return_values = [ + value for k, value in sorted(return_values, key=lambda kv: kv[0]) + ] + return return_values + + def run_remote(self, + fn: Callable[[Any], Any]) -> List[Any]: + raise NotImplementedError("ObjectRefs cannot be returned from Elastic runs as the workers are ephemeral") + + def execute(self, fn: Callable[["executable_cls"], Any], + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function to be invoked on every object. + + Returns: + Deserialized return values from the target function. 
+ """ + return ray.get(self._run_remote(fn, callbacks=callbacks)) \ No newline at end of file diff --git a/horovod/ray/runner_v2.py b/horovod/ray/runner_v2.py new file mode 100644 index 0000000000..13883f4eeb --- /dev/null +++ b/horovod/ray/runner_v2.py @@ -0,0 +1,559 @@ +import ray +from ray.util.placement_group import get_current_placement_group + +from collections import defaultdict +from dataclasses import dataclass, asdict +import os +from typing import Dict, Callable, Any, Optional, List, Union +import logging +import ray.exceptions + +from horovod.runner.common.util import secret, timeout, hosts +from horovod.runner.http.http_server import RendezvousServer +from horovod.ray.utils import detect_nics, nics_to_env_var, map_blocking +from horovod.ray.strategy import ColocatedStrategy, PGStrategy +from horovod.ray.elastic_v2 import ElasticAdapter + +logger = logging.getLogger(__name__) + +@dataclass +class MiniSettings: + """Minimal settings necessary for Ray to work. + + Can be replaced with a proper Horovod Settings object. + """ + nics: set = None + verbose: int = 1 + key: str = secret.make_secret_key() if secret else None + ssh_port: int = None + ssh_identity_file: str = None + timeout_s: int = 300 + placement_group_timeout_s: int = 100 + elastic: bool = False + + @property + def start_timeout(self): + return timeout.Timeout( + self.timeout_s, + message="Timed out waiting for {activity}. Please " + "check connectivity between servers. You " + "may need to increase the --start-timeout " + "parameter if you have too many servers.") + + +class Coordinator: + """Responsible for instantiating the Rendezvous server. + + Args: + settings: Horovod Settings object.""" + rendezvous = None + global_rendezv_port = None + nics = None + node_id_by_rank = None + + def __init__( + self, + settings, + ): + self.settings = settings + self.node_id_by_rank = defaultdict(list) + self._hostnames = set() + + @property + def world_size(self) -> int: + return sum(len(ranks) for ranks in self.node_id_by_rank.values()) + + @property + def hostnames(self): + return self._hostnames + + @property + def node_id_string(self) -> str: + return ",".join([ + f"{node_id}:{len(ranks)}" + for node_id, ranks in self.node_id_by_rank.items() + ]) + + def register(self, hostname: str, node_id: str, world_rank: int): + self._hostnames.add(hostname) + self.node_id_by_rank[node_id].append(world_rank) + + def finalize_registration(self) -> dict: + """Return a dictionary for all ranks.""" + rank_to_info = {} + + cross_sizes = defaultdict(int) + cross_ranks = {} + for rank_list in self.node_id_by_rank.values(): + for local_rank, world_rank in enumerate(rank_list): + cross_ranks[world_rank] = cross_sizes[local_rank] + cross_sizes[local_rank] += 1 + + for node_world_rank, (node_id, ranks) in enumerate( + self.node_id_by_rank.items()): + for local_rank, world_rank in enumerate(ranks): + rank_to_info[world_rank] = dict( + HOROVOD_CROSS_RANK=cross_ranks[world_rank], + HOROVOD_CROSS_SIZE=cross_sizes[local_rank], + HOROVOD_LOCAL_RANK=local_rank, + HOROVOD_LOCAL_SIZE=len(ranks)) + return rank_to_info + + def establish_rendezvous(self) -> Dict[str, str]: + """Creates the rendezvous server and identifies the nics to be used. + + Returns: + Environment variables for each worker. 
+ """ + + # start global rendezvous server and get port that it is listening on + self.rendezvous = RendezvousServer(self.settings.verbose) + + # allocate processes into slots + # hosts = parse_hosts(hosts_string="10.11.11.11:4,10.11.11.12:4") + parsed_node_ids = hosts.parse_hosts(hosts_string=self.node_id_string) + host_alloc_plan = hosts.get_host_assignments(parsed_node_ids, + self.world_size) + + # start global rendezvous server and get port that it is listening on + self.global_rendezv_port = self.rendezvous.start() + self.rendezvous.init(host_alloc_plan) + + return { + # needs to be real address + "HOROVOD_GLOO_RENDEZVOUS_ADDR": ray.util.get_node_ip_address(), + "HOROVOD_GLOO_RENDEZVOUS_PORT": str(self.global_rendezv_port), + "HOROVOD_CONTROLLER": "gloo", + "HOROVOD_CPU_OPERATIONS": "gloo", + } + +@dataclass +class BaseParams: + np: Optional[int] = None + cpus_per_worker: int = 1 + use_gpu: bool = False + gpus_per_worker: Optional[int] = None + use_current_placement_group: bool = True + env_vars: dict = None + def __post_init__(self): + if self.gpus_per_worker and not self.use_gpu: + raise ValueError("gpus_per_worker is set, but use_gpu is False. " + "use_gpu must be True if gpus_per_worker is " + "set. ") + if self.use_gpu and isinstance(self.gpus_per_worker, + int) and self.gpus_per_worker < 1: + raise ValueError( + f"gpus_per_worker must be >= 1: Got {self.gpus_per_worker}.") + self.gpus_per_worker = self.gpus_per_worker or int(self.use_gpu) + +@dataclass +class ElasticParams(BaseParams): + min_np: int = 1 + max_np: int = None + reset_limit: int = None + elastic_timeout: int = 600 + override_discovery: bool = True + def __post_init__(self): + super().__post_init__() + self.np = self.min_np + +@dataclass +class NoneElasticParams(BaseParams): + num_hosts: Optional[int] = None + np_per_host: int = 1 + +class RayExecutorV2: + """Job class for Horovod + Ray integration. + + Args: + settings (horovod.Settings): Configuration for job setup. You can + use a standard Horovod Settings object or create one directly + from RayExecutorV2.create_settings. + num_workers (int): Number of workers to use for training. + cpus_per_worker (int): Number of CPU resources to allocate to + each worker. + use_gpu (bool): Whether to use GPU for allocation. TODO: this + can be removed. + gpus_per_worker (int): Number of GPU resources to allocate to + each worker. + num_hosts (int): Alternative API to ``num_workers``. Number of + machines to execute the job on. Used to enforce equal number of + workers on each machine. + num_workers_per_host (int): Alternative API to + ``num_workers``. Number of workers to be placed on each machine. + Used to enforce equal number of workers on each machine. Only + used in conjunction with `num_hosts`. + use_current_placement_group (bool): Whether to use the current + placement group instead of creating a new one. Defaults to True. + + """ + + @classmethod + def create_settings(cls, + timeout_s=30, + ssh_identity_file=None, + ssh_str=None, + placement_group_timeout_s=100, + nics=None): + """Create a mini setting object. + + Args: + timeout_s (int): Timeout parameter for Gloo rendezvous. + ssh_identity_file (str): Path to the identity file to + ssh into different hosts on the cluster. + ssh_str (str): CAUTION WHEN USING THIS. Private key + file contents. Writes the private key to ssh_identity_file. + placement_group_timeout_s (int): Timeout parameter for Ray + Placement Group creation. + + Returns: + MiniSettings object. 
+ """ + if ssh_str and not os.path.exists(ssh_identity_file): + with open(ssh_identity_file, "w") as f: + os.chmod(ssh_identity_file, 0o600) + f.write(ssh_str) + return MiniSettings( + ssh_identity_file=ssh_identity_file, + timeout_s=timeout_s, + placement_group_timeout_s=placement_group_timeout_s, + nics=nics) + + def __init__( + self, + settings, + params: Union[ElasticParams, NoneElasticParams]): + self.params = params + self.settings = settings + self.elastic = type(params) is ElasticParams + self.settings.elastic = self.elastic + + def start(self, + executable_cls: type = None, + executable_args: Optional[List] = None, + executable_kwargs: Optional[Dict] = None): + """Starts the workers and colocates them on all machines. + + We implement a node grouping because it seems like + our implementation doesn't quite work for imbalanced nodes. + Also, colocation performance is typically much better than + non-colocated workers. + + Args: + executable_cls (type): The class that will be created within + an actor (BaseHorovodWorker). This will allow Horovod + to establish its connections and set env vars. + executable_args (List): Arguments to be passed into the + worker class upon initialization. + executable_kwargs (Dict): Keyword arguments to be passed into the + worker class upon initialization. + extra_env_vars (Dict): Environment variables to be set + on the actors (worker processes) before initialization. + + """ + # Initialize adapter + # NonElasticAdapter if non elastic + # ElasticAdapter if elastic + self._initialize_adapter() + + kwargs_ = dict( + executable_cls=executable_cls, + executable_args=executable_args, + executable_kwargs=executable_kwargs) + return self._maybe_call_ray(self.adapter.start, **kwargs_) + + def _initialize_adapter(self): + kwargs = asdict(self.params) + logger.debug(f"Kwargs: {kwargs}") + Adapter = ElasticAdapter if self.elastic else NonElasticAdapter + self._is_remote = False + if ray.util.client.ray.is_connected(): + RemoteAdapter = ray.remote(Adapter) + self.adapter = RemoteAdapter.remote(self.settings, **kwargs) + self._is_remote = True + else: + self.adapter= Adapter(self.settings, **kwargs) + + def execute(self, fn: Callable[["executable_cls"], Any], + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function to be invoked on every object. + + Returns: + Deserialized return values from the target function. + """ + kwargs_ = dict(fn=fn, callbacks=callbacks) + # invoke run_remote + return self._maybe_call_ray(self.adapter.execute, **kwargs_) + + def run(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. + + Returns: + Deserialized return values from the target function. + """ + kwargs_ = dict(fn=fn, args=args, kwargs=kwargs, callbacks=callbacks) + return self._maybe_call_ray(self.adapter.run, **kwargs_) + + def run_remote(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None) -> List[Any]: + """Executes the provided function on all workers. 
+ + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. + + Returns: + list: List of ObjectRefs that you can run `ray.get` on to + retrieve values. + """ + kwargs_ = dict(fn=fn, args=args, kwargs=kwargs) + return self._maybe_call_ray(self.adapter.run_remote, **kwargs_) + + def execute_single(self, + fn: Callable[["executable_cls"], Any]) -> List[Any]: + """Executes the provided function on the rank 0 worker (chief). + + Args: + fn: Target function to be invoked on the chief object. + + Returns: + Deserialized return values from the target function. + """ + kwargs = dict(fn=fn) + return self._maybe_call_ray(self.adapter.execute_single, **kwargs) + + def shutdown(self): + """Destroys the provided workers.""" + result = self._maybe_call_ray(self.adapter.shutdown) + del self.adapter + return result + + def _maybe_call_ray(self, driver_func, *args, **kwargs): + if self._is_remote: + return ray.get(driver_func.remote(*args, **kwargs)) + else: + return driver_func(**kwargs) + + +class NonElasticAdapter: + """Adapter for executing Ray calls for non-elastic Horovod jobs.""" + + def __init__(self, + settings, + np: Optional[int] = None, + num_hosts: Optional[int] = None, + np_per_host: int = 1, + cpus_per_worker: int = 1, + use_gpu: bool = False, + gpus_per_worker: Optional[int] = None, + use_current_placement_group: bool = True, + env_vars: Optional[Dict] = None, + **kwargs: Optional[Dict]): + + self.settings = settings + self.np = np + self.num_hosts = num_hosts + self.np_per_host = np_per_host + self.cpus_per_worker = cpus_per_worker + self.use_gpu = use_gpu + self.gpus_per_worker = gpus_per_worker or 1 + self.use_current_placement_group = use_current_placement_group + + self.workers = [] + self.strategy = None + self.env_vars = env_vars or {} + + def _start_executables(self, executable_cls, executable_args, + executable_kwargs): + def _start_exec(worker): + return worker.start_executable.remote( + executable_cls, executable_args, executable_kwargs) + + map_blocking(_start_exec, self.workers) + + def _create_strategy(self): + assert self.np is None or self.num_hosts is None + use_pg = self.use_current_placement_group and get_current_placement_group() + if self.np or use_pg: + if use_pg: + logger.info( + "Found an existing placement group, inheriting. " + "You can disable this behavior by setting " + "`use_current_placement_group=False`." + ) + num_workers = self.np or self.num_workers_per_host * self.num_hosts + return PGStrategy( + settings=self.settings, + num_workers=num_workers, + use_gpu=self.use_gpu, + cpus_per_worker=self.cpus_per_worker, + gpus_per_worker=self.gpus_per_worker, + force_create_placement_group=not self.use_current_placement_group) + else: + return ColocatedStrategy( + settings=self.settings, + num_hosts=self.num_hosts, + num_workers_per_host=self.np_per_host, + use_gpu=self.use_gpu, + cpus_per_worker=self.cpus_per_worker, + gpus_per_worker=self.gpus_per_worker) + + def start(self, + executable_cls: type = None, + executable_args: Optional[List] = None, + executable_kwargs: Optional[Dict] = None): + """Starts the workers and colocates them on all machines. + + We implement a node grouping because it seems like + our implementation doesn't quite work for imbalanced nodes. + Also, colocation performance is typically much better than + non-colocated workers. 
+ + Args: + executable_cls (type): The class that will be created within + an actor (BaseHorovodWorker). This will allow Horovod + to establish its connections and set env vars. + executable_args (List): Arguments to be passed into the + worker class upon initialization. + executable_kwargs (Dict): Keyword arguments to be passed into the + worker class upon initialization. + extra_env_vars (Dict): Environment variables to be set + on the actors (worker processes) before initialization. + + """ + #if elastic, do the elastic driver setup(not start) + + self.strategy = self._create_strategy() + self.coordinator = Coordinator(self.settings) + executable_args = executable_args or [] + self.workers, node_workers = self.strategy.create_workers() + # Get all the hostnames of all workers + node_ids = map_blocking(lambda w: w.node_id.remote(), self.workers) + hostnames = map_blocking(lambda w: w.hostname.remote(), self.workers) + # Register each hostname to the coordinator. assumes the hostname + # ordering is the same. + for rank, (hostname, node_id) in enumerate(zip(hostnames, node_ids)): + self.coordinator.register(hostname, node_id, rank) + all_info = self.coordinator.finalize_registration() + + indexed_runners = dict(enumerate(self.workers)) + for rank, local_cross_env_var in all_info.items(): + indexed_runners[rank].update_env_vars.remote(local_cross_env_var) + + coordinator_envs = self.coordinator.establish_rendezvous() + coordinator_envs.update(self.env_vars) + nics = detect_nics( + self.settings, + all_host_names=list(self.coordinator.hostnames), + node_workers=node_workers) + coordinator_envs.update(nics_to_env_var(nics)) + + map_blocking(lambda w: w.update_env_vars.remote(coordinator_envs), + self.workers) + + self._start_executables(executable_cls, executable_args, + executable_kwargs) + + def execute(self, fn: Callable[["executable_cls"], Any], + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function to be invoked on every object. + + Returns: + Deserialized return values from the target function. + """ + return ray.get(self._run_remote(fn)) + + def run(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. + + Returns: + Deserialized return values from the target function. + """ + args = args or [] + kwargs = kwargs or {} + f = lambda w: fn(*args, **kwargs) + return ray.get(self._run_remote(fn=f)) + + def run_remote(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None): + args = args or [] + kwargs = kwargs or {} + f = lambda w: fn(*args, **kwargs) + return self._run_remote(fn=f) + + def _run_remote(self, + fn: Callable[[Any], Any]) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. 
+ + Returns: + list: List of ObjectRefs that you can run `ray.get` on to + retrieve values. + """ + # Use run_remote for all calls + # for elastic, start the driver and launch the job + return [ + worker.execute.remote(fn) for worker in self.workers + ] + + def execute_single(self, + fn: Callable[["executable_cls"], Any]) -> List[Any]: + """Executes the provided function on the rank 0 worker (chief). + + Args: + fn: Target function to be invoked on the chief object. + + Returns: + Deserialized return values from the target function. + """ + return ray.get(self.workers[0].execute.remote(fn)) + + def shutdown(self): + """Destroys the provided workers.""" + for worker in self.workers: + del worker + + if self.strategy: + self.strategy.shutdown() diff --git a/test/single/test_ray_elastic_v2.py b/test/single/test_ray_elastic_v2.py new file mode 100644 index 0000000000..f92f3a0f0a --- /dev/null +++ b/test/single/test_ray_elastic_v2.py @@ -0,0 +1,361 @@ +"""Ray-Horovod Elastic training unit tests. + +This is currently not run on the Ray CI. +""" +from contextlib import contextmanager +import psutil +import os +import socket + +import mock +import pytest +import ray + +from horovod.common.util import gloo_built +from horovod.runner.elastic.discovery import HostDiscovery +from horovod.ray.elastic import RayHostDiscovery +from horovod.ray.runner_v2 import RayExecutorV2, ElasticParams + + + +@pytest.fixture +def ray_shutdown(): + yield + # The code after the yield will run as teardown code. + ray.shutdown() + + +@pytest.fixture +def ray_8_cpus(): + ray.init(num_cpus=8, resources={ + f"node:host-{i}": 1 for i in range(10)}) + yield + # The code after the yield will run as teardown code. + ray.shutdown() + + +@pytest.fixture +def ray_8_cpus_gpus(): + if "CUDA_VISIBLE_DEVICES" in os.environ: + if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) < 8: + pytest.skip("Avoiding mismatched GPU machine.") + ray.init(num_cpus=8, num_gpus=8, resources={ + f"node:host-{i}": 1 for i in range(10)}) + try: + yield + finally: + # The code after the yield will run as teardown code. 
+ ray.shutdown() + + +class TestRayDiscoverySuite: + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_cpu_discovery(self, ray_shutdown): + ray.init(num_cpus=4, num_gpus=1) + discovery = RayHostDiscovery(cpus_per_slot=1) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 1 + assert list(mapping.values()) == [4] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_gpu_discovery(self, ray_shutdown): + ray.init(num_cpus=4, num_gpus=1) + discovery = RayHostDiscovery(use_gpu=True, cpus_per_slot=1) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 1 + assert list(mapping.values()) == [1] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_gpu_slot_discovery(self, ray_shutdown): + ray.init(num_cpus=4, num_gpus=4) + discovery = RayHostDiscovery( + use_gpu=True, cpus_per_slot=1, gpus_per_slot=2) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 1 + assert list(mapping.values()) == [2] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_multinode(self, monkeypatch): + def create_multi_node_mock(): + host_names = ["host-1", "host-2", "host-3"] + resources = {"GPU": 2, "CPU": 8} + + def create_node_entry(hostname): + return { + "NodeManagerAddress": hostname, + "Resources": resources.copy(), + "alive": True + } + + return map(create_node_entry, host_names) + + monkeypatch.setattr(ray, "nodes", create_multi_node_mock) + discovery = RayHostDiscovery(use_gpu=True, cpus_per_slot=1) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 3 + assert list(mapping.values()) == [2, 2, 2] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_multinode_gpus_per_slot(self, monkeypatch): + def create_multi_node_mock(): + host_names = ["host-1", "host-2", "host-3"] + resources = {"GPU": 2, "CPU": 8} + + def create_node_entry(hostname): + return { + "NodeManagerAddress": hostname, + "Resources": resources.copy(), + "alive": True + } + + return map(create_node_entry, host_names) + + monkeypatch.setattr(ray, "nodes", create_multi_node_mock) + discovery = RayHostDiscovery(use_gpu=True, gpus_per_slot=2) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 3 + assert list(mapping.values()) == [1, 1, 1] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_multinode_mismatch(self, monkeypatch): + def create_multi_node_mock(): + host_names = ["host-1", "host-2", "host-3"] + resources = {"CPU": 8} + + def create_node_entry(hostname): + return { + "NodeManagerAddress": hostname, + "Resources": resources.copy(), + "alive": True + } + + return map(create_node_entry, host_names) + + monkeypatch.setattr(ray, "nodes", create_multi_node_mock) + discovery = RayHostDiscovery(use_gpu=True, cpus_per_slot=1) + mapping = discovery.find_available_hosts_and_slots() + assert sum(mapping.values()) == 0 + + +class SimpleTestDiscovery(HostDiscovery): + def __init__(self, schedule): + self._schedule = schedule + self._generator = self.host_generator() + + def host_generator(self): + for iters, hosts in self._schedule: + iters = iters or 500 # max + for i in range(iters): + yield hosts + + def find_available_hosts_and_slots(self): + hostlist = next(self._generator) + hosts = {} + for item in hostlist: + host, 
slots = item.split(":") + slots = int(slots) + hosts[host] = slots + return hosts + + +class StatusCallback: + def __init__(self): + self._journal = [] + + def __call__(self, info_dict): + self._journal.append(info_dict) + + def fetch(self): + return self._journal.copy() + + +def _create_training_function(iterations): + def training_fn(): + import time + import torch + import horovod.torch as hvd + from horovod.ray import ray_logger + + hvd.init() + + model = torch.nn.Sequential(torch.nn.Linear(2, 2)) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + ray_logger.log({"started": True, "pid": os.getpid()}) + + @hvd.elastic.run + def train(state): + for state.epoch in range(state.epoch, iterations): + ray_logger.log({"training": True, "pid": os.getpid()}) + time.sleep(0.1) + state.commit() # triggers scale-up, scale-down + ray_logger.log({"finished": True, "pid": os.getpid()}) + + state = hvd.elastic.TorchState( + model, optimizer, batch=0, epoch=0, commits=0, rendezvous=0) + train(state) + return True + + return training_fn + + +@contextmanager +def fault_tolerance_patches(): + with mock.patch( + 'horovod.runner.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', + 0.1): + with mock.patch( + "horovod.runner.util.network.get_driver_ip", + return_value=socket.gethostbyname(socket.gethostname())): + yield + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.skip(reason='https://github.com/horovod/horovod/issues/3197') +def test_fault_tolerance_hosts_added_and_removed(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:2']), + (30, ['host-1:2', 'host-2:1', 'host-3:1']), + (None, ['host-2:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutorV2.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutorV2( + settings, ElasticParams( + min_np=1, + cpus_per_slot=1, override_discovery=False)) + + training_fn = _create_training_function(iterations=50) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 1 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 4, events + assert sum(int("finished" in e) for e in events) == 1, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.skip(reason='https://github.com/horovod/horovod/issues/3197') +def test_fault_tolerance_hosts_remove_and_add(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:2', 'host-2:1', 'host-3:2']), + (10, ['host-1:2']), + (None, ['host-1:2', 'host-4:1', 'host-5:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutorV2.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutorV2(settings, ElasticParams( + min_np=1, cpus_per_worker=1, override_discovery=False)) + + training_fn = _create_training_function(iterations=30) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 4 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 7, events + assert sum(int("finished" in e) for e in events) == 4, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_max_np(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, 
['host-1:2']), + (None, ['host-1:2', 'host-4:1', 'host-5:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutorV2.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutorV2(settings, + ElasticParams(min_np=1, max_np=2, cpus_per_worker=1, override_discovery=False)) + + training_fn = _create_training_function(iterations=20) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 2 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 2, events + assert sum(int("finished" in e) for e in events) == 2, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_min_np(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:1']), + (10, ['host-1:1', 'host-4:1', 'host-5:1']), + (None, ['host-1:1', 'host-4:1', 'host-5:1', 'host-6:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutorV2.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutorV2(settings, ElasticParams( + min_np=4, + max_np=4, + override_discovery=False + )) + + training_fn = _create_training_function(iterations=30) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 4 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 4, events + assert sum(int("finished" in e) for e in events) == 4, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_gpu_e2e(ray_8_cpus_gpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:1']), + (10, ['host-1:1', 'host-4:1', 'host-5:1']), + (None, ['host-1:1', 'host-4:1', 'host-5:1', 'host-6:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutorV2.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutorV2(settings, + ElasticParams(min_np=4, max_np=4, gpus_per_worker=1, use_gpu=True, override_discovery=False)) + + training_fn = _create_training_function(iterations=30) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 4 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 4, events + assert sum(int("finished" in e) for e in events) == 4, events + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(sys.argv[1:] + ["-v", "-x", __file__])) diff --git a/test/single/test_ray_v2.py b/test/single/test_ray_v2.py new file mode 100644 index 0000000000..ba2c461d49 --- /dev/null +++ b/test/single/test_ray_v2.py @@ -0,0 +1,506 @@ +"""Ray-Horovod Job unit tests. + +This is currently not run on the Ray CI. 
+""" +import os +import sys + +import socket +import pytest +import ray +from ray.util.client.ray_client_helpers import ray_start_client_server +import torch + +from horovod.common.util import gloo_built +from horovod.ray.runner_v2 import (Coordinator, MiniSettings, NoneElasticParams, RayExecutorV2) +from horovod.ray.worker import BaseHorovodWorker +from horovod.ray.strategy import create_placement_group + +sys.path.append(os.path.dirname(__file__)) + + +@pytest.fixture +def ray_start_2_cpus(): + address_info = ray.init(num_cpus=2) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + +@pytest.fixture +def ray_start_4_cpus(): + address_info = ray.init(num_cpus=4, _redis_max_memory=1024 * 1024 * 1024) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + +@pytest.fixture +def ray_start_6_cpus(): + address_info = ray.init(num_cpus=6) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + +@pytest.fixture +def ray_start_4_cpus_4_gpus(): + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + address_info = ray.init(num_cpus=4, num_gpus=4) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + del os.environ["CUDA_VISIBLE_DEVICES"] + + +@pytest.fixture +def ray_start_client(): + def ray_connect_handler(job_config=None): + # Ray client will disconnect from ray when + # num_clients == 0. + if ray.is_initialized(): + return + else: + return ray.init(job_config=job_config, num_cpus=4) + + assert not ray.util.client.ray.is_connected() + with ray_start_client_server(ray_connect_handler=ray_connect_handler): + yield + + +def check_resources(original_resources): + for i in reversed(range(10)): + if original_resources == ray.available_resources(): + return True + else: + print(ray.available_resources()) + import time + time.sleep(0.5) + return False + + +def test_coordinator_registration(): + settings = MiniSettings() + coord = Coordinator(settings) + assert coord.world_size == 0 + assert coord.node_id_string == "" + ranks = list(range(12)) + + for i, hostname in enumerate(["a", "b", "c"]): + for r in ranks: + if r % 3 == i: + coord.register(hostname, node_id=str(i), world_rank=r) + + rank_to_info = coord.finalize_registration() + assert len(rank_to_info) == len(ranks) + assert all( + info["HOROVOD_CROSS_SIZE"] == 3 for info in rank_to_info.values()) + assert {info["HOROVOD_CROSS_RANK"] + for info in rank_to_info.values()} == {0, 1, 2} + assert all( + info["HOROVOD_LOCAL_SIZE"] == 4 for info in rank_to_info.values()) + assert {info["HOROVOD_LOCAL_RANK"] + for info in rank_to_info.values()} == {0, 1, 2, 3} + + +@pytest.mark.parametrize("use_same_host", [True, False]) +def test_cross_rank(use_same_host): + settings = MiniSettings() + coord = Coordinator(settings) + assert coord.world_size == 0 + assert coord.node_id_string == "" + ranks = list(range(12)) + + for r in ranks: + if r < 5: + coord.register("host1", node_id="host1", world_rank=r) + elif r < 9: + coord.register( + "host1" if use_same_host else "host2", + node_id="host2", + world_rank=r) + else: + coord.register( + "host1" if use_same_host else "host3", + node_id="host3", + world_rank=r) + + rank_to_info = coord.finalize_registration() + assert len(rank_to_info) == len(ranks) + # check that there is only 1 rank with cross_size == 1, cross_rank == 0 + cross_size_1 = list(info for info in rank_to_info.values() + if info["HOROVOD_CROSS_SIZE"] == 1) + assert len(cross_size_1) == 1 + 
assert cross_size_1[0]["HOROVOD_CROSS_RANK"] == 0 + # check that there is only 2 rank with cross_size == 2 + cross_size_2 = list(info for info in rank_to_info.values() + if info["HOROVOD_CROSS_SIZE"] == 2) + assert len(cross_size_2) == 2 + + # check that if cross_size == 2, set(cross_rank) == 0,1 + assert set(d["HOROVOD_CROSS_RANK"] for d in cross_size_2) == {0, 1} + + # check that there is 9 rank with cross_size = 3 + cross_size_3 = list(info for info in rank_to_info.values() + if info["HOROVOD_CROSS_SIZE"] == 3) + assert len(cross_size_3) == 9 + + +# Used for Pytest parametrization. +parameter_str = "num_workers,num_hosts,num_workers_per_host" +ray_executor_parametrized = [(4, None, None), (None, 1, 4)] + + +@pytest.mark.parametrize(parameter_str, ray_executor_parametrized) +def test_infeasible_placement(ray_start_2_cpus, num_workers, num_hosts, + num_workers_per_host): + setting = RayExecutorV2.create_settings( + timeout_s=30, placement_group_timeout_s=5) + np = NoneElasticParams( + np=num_workers, + num_hosts=num_hosts, + np_per_host=num_workers_per_host + ) + hjob = RayExecutorV2(setting, np) + with pytest.raises(TimeoutError): + hjob.start() + hjob.shutdown() + + +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="GPU test requires 4 GPUs") +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="GPU test requires CUDA.") +def test_gpu_ids(ray_start_4_cpus_4_gpus): + original_resources = ray.available_resources() + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + num_hosts=1, + use_gpu=True + ) + hjob = RayExecutorV2( + setting, np) + hjob.start() + all_envs = hjob.execute(lambda _: os.environ.copy()) + all_cudas = {ev["CUDA_VISIBLE_DEVICES"] for ev in all_envs} + assert len(all_cudas) == 1, all_cudas + assert len(all_envs[0]["CUDA_VISIBLE_DEVICES"].split(",")) == 4 + hjob.shutdown() + assert check_resources(original_resources) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="GPU test requires 4 GPUs") +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="GPU test requires CUDA.") +def test_gpu_ids_num_workers(ray_start_4_cpus_4_gpus): + original_resources = ray.available_resources() + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=4, + use_gpu=True + ) + hjob = RayExecutorV2(setting, np) + hjob.start() + all_envs = hjob.execute(lambda _: os.environ.copy()) + all_cudas = {ev["CUDA_VISIBLE_DEVICES"] for ev in all_envs} + + assert len(all_cudas) == 1, all_cudas + assert len(all_envs[0]["CUDA_VISIBLE_DEVICES"].split(",")) == 4 + + def _test(worker): + import horovod.torch as hvd + hvd.init() + local_rank = str(hvd.local_rank()) + return local_rank in os.environ["CUDA_VISIBLE_DEVICES"] + + all_valid_local_rank = hjob.execute(_test) + assert all(all_valid_local_rank) + hjob.shutdown() + assert check_resources(original_resources) + + +def test_horovod_mixin(ray_start_2_cpus): + class Test(BaseHorovodWorker): + pass + + assert Test().hostname() == socket.gethostname() + actor = ray.remote(BaseHorovodWorker).remote() + DUMMY_VALUE = 1123123 + actor.update_env_vars.remote({"TEST": DUMMY_VALUE}) + assert ray.get(actor.env_vars.remote())["TEST"] == str(DUMMY_VALUE) + + +@pytest.mark.parametrize(parameter_str, ray_executor_parametrized) +def test_local(ray_start_4_cpus, num_workers, num_hosts, num_workers_per_host): + original_resources = ray.available_resources() + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=num_workers, + 
num_hosts=num_hosts, + np_per_host=num_workers_per_host + ) + hjob = RayExecutorV2(setting,np) + hjob.start() + hostnames = hjob.execute(lambda _: socket.gethostname()) + assert len(set(hostnames)) == 1, hostnames + hjob.shutdown() + assert check_resources(original_resources) + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.parametrize(parameter_str, ray_executor_parametrized) +def test_ray_init(ray_start_4_cpus, num_workers, num_hosts, + num_workers_per_host): + original_resources = ray.available_resources() + + def simple_fn(worker): + import horovod.torch as hvd + hvd.init() + return hvd.rank() + + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=num_workers, + num_hosts=num_hosts, + np_per_host=num_workers_per_host, + use_gpu=torch.cuda.is_available() + ) + hjob = RayExecutorV2(setting,np) + hjob.start() + result = hjob.execute(simple_fn) + assert len(set(result)) == 4 + hjob.shutdown() + assert check_resources(original_resources) + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.parametrize(parameter_str, ray_executor_parametrized) +def test_ray_exec_func(ray_start_4_cpus, num_workers, num_hosts, + num_workers_per_host): + def simple_fn(num_epochs): + import horovod.torch as hvd + hvd.init() + return hvd.rank() * num_epochs + + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=num_workers, + num_hosts=num_hosts, + np_per_host=num_workers_per_host, + use_gpu=torch.cuda.is_available() + ) + hjob = RayExecutorV2(setting, np) + hjob.start() + result = hjob.run(simple_fn, args=[0]) + assert len(set(result)) == 1 + hjob.shutdown() + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.parametrize(parameter_str, ray_executor_parametrized) +def test_ray_exec_remote_func(ray_start_4_cpus, num_workers, num_hosts, + num_workers_per_host): + def simple_fn(num_epochs): + import horovod.torch as hvd + hvd.init() + return hvd.rank() * num_epochs + + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=num_workers, + num_hosts=num_hosts, + np_per_host=num_workers_per_host, + use_gpu=torch.cuda.is_available() + ) + hjob = RayExecutorV2(setting, np) + hjob.start() + object_refs = hjob.run_remote(simple_fn, args=[0]) + result = ray.get(object_refs) + assert len(set(result)) == 1 + hjob.shutdown() + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.parametrize(parameter_str, ray_executor_parametrized) +def test_ray_executable(ray_start_4_cpus, num_workers, num_hosts, + num_workers_per_host): + class Executable: + def __init__(self, epochs): + import horovod.torch as hvd + self.hvd = hvd + self.epochs = epochs + self.hvd.init() + + def rank_epoch(self): + return self.hvd.rank() * self.epochs + + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=num_workers, + num_hosts=num_hosts, + np_per_host=num_workers_per_host, + use_gpu=torch.cuda.is_available() + ) + hjob = RayExecutorV2(setting, np) + hjob.start(executable_cls=Executable, executable_args=[2]) + result = hjob.execute(lambda w: w.rank_epoch()) + assert set(result) == {0, 2, 4, 6} + hjob.shutdown() + + +# @pytest.mark.skipif( +# not gloo_built(), reason='Gloo is required for Ray integration') +# def test_ray_deprecation(ray_start_4_cpus): +# class Executable: +# def __init__(self, epochs): +# 
import horovod.torch as hvd +# self.hvd = hvd +# self.epochs = epochs +# self.hvd.init() + +# def rank_epoch(self): +# return self.hvd.rank() * self.epochs + +# setting = RayExecutorV2.create_settings(timeout_s=30) +# np = NoneElasticParams( +# num_hosts=1, +# np=2, +# np_per_host=2, +# use_gpu=torch.cuda.is_available() +# ) +# hjob = RayExecutorV2( +# setting, +# num_hosts=1, +# num_slots=2, +# cpus_per_slot=2, +# use_gpu=torch.cuda.is_available()) +# hjob.start(executable_cls=Executable, executable_args=[2]) +# result = hjob.execute(lambda w: w.rank_epoch()) +# assert set(result) == {0, 2} +# hjob.shutdown() + + +def _train(batch_size=32, batch_per_iter=10): + import torch.nn.functional as F + import torch.optim as optim + import torch.utils.data.distributed + import horovod.torch as hvd + import timeit + + hvd.init() + + # Set up fixed fake data + data = torch.randn(batch_size, 2) + target = torch.LongTensor(batch_size).random_() % 2 + + model = torch.nn.Sequential(torch.nn.Linear(2, 2)) + optimizer = optim.SGD(model.parameters(), lr=0.01) + + # Horovod: wrap optimizer with DistributedOptimizer. + optimizer = hvd.DistributedOptimizer( + optimizer, named_parameters=model.named_parameters()) + + # Horovod: broadcast parameters & optimizer state. + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + def benchmark_step(): + optimizer.zero_grad() + output = model(data) + loss = F.cross_entropy(output, target) + loss.backward() + optimizer.step() + + timeit.timeit(benchmark_step, number=batch_per_iter) + return hvd.local_rank() + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.parametrize(parameter_str, ray_executor_parametrized) +def test_horovod_train(ray_start_4_cpus, num_workers, num_hosts, + num_workers_per_host): + def simple_fn(worker): + local_rank = _train() + return local_rank + + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=num_workers, + num_hosts=num_hosts, + np_per_host=num_workers_per_host, + use_gpu=torch.cuda.is_available() + ) + hjob = RayExecutorV2(setting, np) + hjob.start() + result = hjob.execute(simple_fn) + assert set(result) == {0, 1, 2, 3} + hjob.shutdown() + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_horovod_train_in_pg(ray_start_4_cpus): + pg, _ = create_placement_group( + {"CPU": 1, "GPU": int(torch.cuda.is_available())}, 4, 30, "PACK") + + @ray.remote + class _Actor(): + def run(self): + def simple_fn(worker): + local_rank = _train() + return local_rank + + setting = RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=4, + num_hosts=None, + np_per_host=None, + cpus_per_worker=1, + use_gpu=torch.cuda.is_available(), + gpus_per_worker=int(torch.cuda.is_available()) or None + ) + hjob = RayExecutorV2(setting, np) + hjob.start() + assert not hjob.adapter.strategy._created_placement_group + result = hjob.execute(simple_fn) + assert set(result) == {0, 1, 2, 3} + hjob.shutdown() + actor = _Actor.options( + num_cpus=0, num_gpus=0, placement_group_capture_child_tasks=True, placement_group=pg).remote() + ray.get(actor.run.remote()) + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_remote_client_train(ray_start_client): + def simple_fn(worker): + local_rank = _train() + return local_rank + + assert ray.util.client.ray.is_connected() + + setting = 
RayExecutorV2.create_settings(timeout_s=30) + np = NoneElasticParams( + np=3, + use_gpu=torch.cuda.is_available() + ) + hjob = RayExecutorV2(setting, np) + hjob.start() + result = hjob.execute(simple_fn) + assert set(result) == {0, 1, 2} + result = ray.get(hjob.run_remote(simple_fn, args=[None])) + assert set(result) == {0, 1, 2} + hjob.shutdown() + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
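
For reference, a minimal usage sketch of the API added by this patch, distilled from the accompanying tests (test_ray_v2.py and test_ray_elastic_v2.py). It assumes a running Ray cluster and a Gloo-enabled Horovod build; `train_fn` is a placeholder worker function:

import ray
from horovod.ray.runner_v2 import RayExecutorV2, NoneElasticParams, ElasticParams


def train_fn():
    # Placeholder worker function; real jobs would run a training loop here
    # (and, in the elastic case, wrap it with hvd.elastic.run and a State
    # object, as shown in test_ray_elastic_v2.py).
    import horovod.torch as hvd
    hvd.init()
    return hvd.rank()


ray.init()  # or ray.init(address="auto") to attach to an existing cluster

# Non-elastic: a fixed number of workers.
settings = RayExecutorV2.create_settings(timeout_s=30)
job = RayExecutorV2(settings, NoneElasticParams(np=4))
job.start()
print(job.run(train_fn))  # one return value per worker, ordered by rank
job.shutdown()

# Elastic: the worker count may vary between min_np and max_np at runtime.
elastic_settings = RayExecutorV2.create_settings(timeout_s=30)
elastic_job = RayExecutorV2(elastic_settings, ElasticParams(min_np=1, max_np=4))
elastic_job.start()
print(elastic_job.run(train_fn))
# The elastic tests tear down via ray.shutdown() rather than job.shutdown().
ray.shutdown()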