Skip to content

Commit

Permalink
FEAT-modin-project#6990: Implement lazy execution for the Ray virtual…
Browse files Browse the repository at this point in the history
… partitions.
  • Loading branch information
AndreyPavlenko committed Mar 15, 2024
1 parent c753436 commit b81a433
Show file tree
Hide file tree
Showing 7 changed files with 511 additions and 316 deletions.
165 changes: 128 additions & 37 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from modin.core.execution.ray.common import MaterializationHook, RayWrapper
from modin.logging import get_logger
from modin.utils import _inherit_docstrings

ObjectRefType = Union[ray.ObjectRef, ClientObjectRef, None]
ObjectRefOrListType = Union[ObjectRefType, List[ObjectRefType]]
Expand Down Expand Up @@ -72,12 +73,14 @@ class DeferredExecution:
The execution input.
func : callable or ObjectRefType
A function to be executed.
args : list or tuple
args : list or tuple, optional
Additional positional arguments to be passed in `func`.
kwargs : dict
kwargs : dict, optional
Additional keyword arguments to be passed in `func`.
num_returns : int
num_returns : int, default: 1
The number of the return values.
flat_data : bool
True means that the data is neither DeferredExecution nor list.
flat_args : bool
True means that there are no lists or DeferredExecution objects in `args`.
In this case, no arguments processing is performed and `args` is passed
Expand All @@ -88,26 +91,29 @@ class DeferredExecution:

def __init__(
self,
data: Union[
ObjectRefType,
"DeferredExecution",
List[Union[ObjectRefType, "DeferredExecution"]],
],
data: Any,
func: Union[Callable, ObjectRefType],
args: Union[List[Any], Tuple[Any]],
kwargs: Dict[str, Any],
args: Union[List[Any], Tuple[Any]] = None,
kwargs: Dict[str, Any] = None,
num_returns=1,
):
if isinstance(data, DeferredExecution):
data.subscribe()
self.flat_data = self._flat_args((data,))
self.data = data
self.func = func
self.args = args
self.kwargs = kwargs
self.num_returns = num_returns
self.flat_args = self._flat_args(args)
self.flat_kwargs = self._flat_args(kwargs.values())
self.subscribers = 0
if args is not None:
self.args = args
self.flat_args = self._flat_args(args)
else:
self.args = ()
self.flat_args = True
if kwargs is not None:
self.kwargs = kwargs
self.flat_kwargs = self._flat_args(kwargs.values())
else:
self.kwargs = {}
self.flat_kwargs = True

@classmethod
def _flat_args(cls, args: Iterable):
Expand All @@ -134,7 +140,7 @@ def _flat_args(cls, args: Iterable):

def exec(
self,
) -> Tuple[ObjectRefOrListType, Union["MetaList", List], Union[int, List[int]]]:
) -> Tuple[ObjectRefOrListType, "MetaList", Union[int, List[int]]]:
"""
Execute this task, if required.
Expand All @@ -150,11 +156,29 @@ def exec(
return self.data, self.meta, self.meta_offset

if (
not isinstance(self.data, DeferredExecution)
self.flat_data
and self.flat_args
and self.flat_kwargs
and self.num_returns == 1
):
# self.data = RayWrapper.materialize(self.data)
# self.args = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in self.args
# ]
# self.kwargs = {
# k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for k, o in self.kwargs.items()
# }
# obj = _REMOTE_EXEC.exec_func(
# RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
# )
# result, length, width, ip = (
# obj,
# len(obj) if hasattr(obj, "__len__") else 0,
# len(obj.columns) if hasattr(obj, "columns") else 0,
# "",
# )
result, length, width, ip = remote_exec_func.remote(
self.func, self.data, *self.args, **self.kwargs
)
Expand All @@ -166,14 +190,23 @@ def exec(
# it back. After the execution, the result is saved and the counter has no effect.
self.subscribers += 2
consumers, output = self._deconstruct()

# assert not any(isinstance(o, ListOrTuple) for o in output)
# tmp = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in output
# ]
# list(_REMOTE_EXEC.construct(tmp))

# The last result is the MetaList, so adding +1 here.
num_returns = sum(c.num_returns for c in consumers) + 1
results = self._remote_exec_chain(num_returns, *output)
meta = MetaList(results.pop())
meta_offset = 0
results = iter(results)
for de in consumers:
if de.num_returns == 1:
num_returns = de.num_returns
if num_returns == 1:
de._set_result(next(results), meta, meta_offset)
meta_offset += 2
else:
Expand Down Expand Up @@ -318,6 +351,7 @@ def _deconstruct_chain(
break
elif not isinstance(data := de.data, DeferredExecution):
if isinstance(data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
data, output, stack, result_consumers, out_append
)
Expand Down Expand Up @@ -394,7 +428,13 @@ def _deconstruct_list(
if out_pos := getattr(obj, "out_pos", None):
obj.unsubscribe()
if obj.has_result:
out_append(obj.data)
if isinstance(obj.data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
obj.data, output, stack, result_consumers, out_append
)
else:
out_append(obj.data)
else:
out_append(_Tag.REF)
out_append(out_pos)
Expand Down Expand Up @@ -432,13 +472,13 @@ def _remote_exec_chain(num_returns: int, *args: Tuple) -> List[Any]:
list
The execution results. The last element of this list is the ``MetaList``.
"""
# Prefer _remote_exec_single_chain(). It has fewer arguments and
# does not require the num_returns to be specified in options.
# Prefer _remote_exec_single_chain(). It does not require the num_returns
# to be specified in options.
if num_returns == 2:
return _remote_exec_single_chain.remote(*args)
else:
return _remote_exec_multi_chain.options(num_returns=num_returns).remote(
num_returns, *args
*args
)

def _set_result(
Expand All @@ -456,7 +496,7 @@ def _set_result(
meta : MetaList
meta_offset : int or list of int
"""
del self.func, self.args, self.kwargs, self.flat_args, self.flat_kwargs
del self.func, self.args, self.kwargs
self.data = result
self.meta = meta
self.meta_offset = meta_offset
Expand All @@ -466,6 +506,55 @@ def __reduce__(self):
raise NotImplementedError("DeferredExecution is not serializable!")


class DeferredGetItem(DeferredExecution):
"""
Deferred execution task that returns an item at the specified index.
Parameters
----------
data : ObjectRefType or DeferredExecution
The object to get the item from.
idx : int
The item index.
"""

def __init__(self, data: Union[ObjectRefType, DeferredExecution], idx: int):
super().__init__(data, self._remote_fn(), [idx])

@_inherit_docstrings(DeferredExecution.exec)
def exec(self) -> Tuple[ObjectRefType, "MetaList", int]:
if (
self.has_result
or not isinstance(self.data, DeferredExecution)
or self.num_returns == 1
):
return super().exec()

# If `data` is a `DeferredExecution`, that returns multiple results,
# it's not required to execute `_remote_fn()`. We can only execute
# `data` and get the result by index.
obj, meta, offsets = self.data.exec()
self._set_result(obj[self.args[0]], meta, offsets[self.args[0]])
return self.data, self.meta, self.meta_offset

@classmethod
def _remote_fn(cls) -> ObjectRefType:
"""
Return the remote function reference.
Returns
-------
ObjectRefType
"""
if (fn := getattr(cls, "_GET_ITEM", None)) is None:

def get_item(obj, index):
return obj[index]

cls._GET_ITEM = fn = RayWrapper.put(get_item)
return fn


class MetaList:
"""
Meta information, containing the result lengths and the worker address.
Expand All @@ -478,6 +567,10 @@ class MetaList:
def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):
self._obj = obj

def materialize(self):
"""Materialized the list, if required."""
self._obj = RayWrapper.materialize(self._obj)

def __getitem__(self, index):
"""
Get item at the specified index.
Expand Down Expand Up @@ -508,7 +601,7 @@ def __setitem__(self, index, value):
obj[index] = value


class MetaListHook(MaterializationHook):
class MetaListHook(MaterializationHook, DeferredGetItem):
"""
Used by MetaList.__getitem__() for lazy materialization and getting a single value from the list.
Expand All @@ -521,6 +614,7 @@ class MetaListHook(MaterializationHook):
"""

def __init__(self, meta: MetaList, idx: int):
super().__init__(meta._obj, idx)
self.meta = meta
self.idx = idx

Expand Down Expand Up @@ -605,7 +699,7 @@ def exec_func(fn: Callable, obj: Any, args: Tuple, kwargs: Dict) -> Any:
raise err

@classmethod
def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
def construct(cls, args: Tuple): # pragma: no cover
"""
Construct and execute the specified chain.
Expand All @@ -615,7 +709,6 @@ def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
Parameters
----------
num_returns : int
args : tuple
Yields
Expand Down Expand Up @@ -687,7 +780,7 @@ def construct_chain(

while chain:
fn = pop()
if fn == tg_e:
if fn is tg_e:
lst.append(obj)
break

Expand Down Expand Up @@ -717,10 +810,10 @@ def construct_chain(

itr = iter([obj] if num_returns == 1 else obj)
for _ in range(num_returns):
obj = next(itr)
meta.append(len(obj) if hasattr(obj, "__len__") else 0)
meta.append(len(obj.columns) if hasattr(obj, "columns") else 0)
yield obj
o = next(itr)
meta.append(len(o) if hasattr(o, "__len__") else 0)
meta.append(len(o.columns) if hasattr(o, "columns") else 0)
yield o

@classmethod
def construct_list(
Expand Down Expand Up @@ -834,20 +927,18 @@ def _remote_exec_single_chain(
-------
Generator
"""
return remote_executor.construct(num_returns=2, args=args)
return remote_executor.construct(args=args)


@ray.remote
def _remote_exec_multi_chain(
num_returns: int, *args: Tuple, remote_executor=_REMOTE_EXEC
*args: Tuple, remote_executor=_REMOTE_EXEC
) -> Generator: # pragma: no cover
"""
Execute the deconstructed chain with a multiple return values in a worker process.
Parameters
----------
num_returns : int
The number of return values.
*args : tuple
A deconstructed chain to be executed.
remote_executor : _RemoteExecutor, default: _REMOTE_EXEC
Expand All @@ -857,4 +948,4 @@ def _remote_exec_multi_chain(
-------
Generator
"""
return remote_executor.construct(num_returns, args)
return remote_executor.construct(args)
4 changes: 2 additions & 2 deletions modin/core/execution/ray/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import asyncio
import os
from types import FunctionType
from typing import Sequence
from typing import Iterable, Sequence

import ray
from ray.util.client.common import ClientObjectRef
Expand Down Expand Up @@ -214,7 +214,7 @@ def wait(cls, obj_ids, num_returns=None):
num_returns : int, optional
"""
if not isinstance(obj_ids, Sequence):
obj_ids = list(obj_ids)
obj_ids = list(obj_ids) if isinstance(obj_ids, Iterable) else [obj_ids]

ids = set()
for obj in obj_ids:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def add_to_apply_calls(
def drain_call_queue(self):
data = self._data_ref
if not isinstance(data, DeferredExecution):
return data
return

log = get_logger()
self._is_debug(log) and log.debug(
Expand Down Expand Up @@ -419,7 +419,7 @@ def eager_exec(self, func, *args, length=None, width=None, **kwargs):
LazyExecution.subscribe(_configure_lazy_exec)


class SlicerHook(MaterializationHook):
class SlicerHook(MaterializationHook, DeferredExecution):
"""
Used by mask() for the slilced length computation.
Expand All @@ -432,6 +432,7 @@ class SlicerHook(MaterializationHook):
"""

def __init__(self, ref: ObjectIDType, slc: slice):
super().__init__(slc, compute_sliced_len, [ref])
self.ref = ref
self.slc = slc

Expand Down
Loading

0 comments on commit b81a433

Please sign in to comment.