diff --git a/hamilton/plugins/h_ray.py b/hamilton/plugins/h_ray.py index c78b7918a..51c6bdf28 100644 --- a/hamilton/plugins/h_ray.py +++ b/hamilton/plugins/h_ray.py @@ -209,19 +209,33 @@ class RayTaskExecutor(executors.TaskExecutor): This is still experimental, so the API might change. """ - def __init__(self, num_cpus: int): + def __init__( + self, + num_cpus: int = None, + ray_init_config: typing.Dict[str, typing.Any] = None, + skip_init: bool = False, + ): """Creates a ray task executor. Note this will likely take in more parameters. This is experimental, so the API will likely change, although we will do our best to make it backwards compatible. - :param num_cpus: Number of cores to use for initialization, passed drirectly to ray.init + :param num_cpus: Number of cores to use for initialization, passed directly to ray.init. Defaults to all cores. + :param ray_init_config: General configuration to pass to ray.init. Defaults to None. + :param skip_init: Skips ray init if you already have Ray initialized. Default is False. """ self.num_cpus = num_cpus + self.ray_init_config = ray_init_config if ray_init_config else {} + self.skip_init = skip_init def init(self): - ray.init(num_cpus=self.num_cpus) + if self.skip_init: + return + ray.init(num_cpus=self.num_cpus, **self.ray_init_config) def finalize(self): + if self.skip_init: + # we assume that if we didn't init it, we don't need to shutdown either. + return ray.shutdown() def submit_task(self, task: TaskImplementation) -> TaskFuture: