Skip to content

Commit

Permalink
FEAT-modin-project#6990: Implement lazy execution for the Ray virtual…
Browse files Browse the repository at this point in the history
… partitions.
  • Loading branch information
AndreyPavlenko committed Mar 21, 2024
1 parent 95e704f commit 5978769
Show file tree
Hide file tree
Showing 10 changed files with 667 additions and 350 deletions.
188 changes: 131 additions & 57 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
from ray.util.client.common import ClientObjectRef

from modin.core.execution.ray.common import MaterializationHook, RayWrapper
from modin.core.execution.utils import remote_function
from modin.logging import get_logger
from modin.utils import _inherit_docstrings

ObjectRefType = Union[ray.ObjectRef, ClientObjectRef, None]
ObjectRefType = Union[ray.ObjectRef, ClientObjectRef]
ObjectRefOrListType = Union[ObjectRefType, List[ObjectRefType]]
ListOrTuple = (list, tuple)

Expand Down Expand Up @@ -68,16 +70,18 @@ class DeferredExecution:
Attributes
----------
data : ObjectRefType or DeferredExecution
data : object
The execution input.
func : callable or ObjectRefType
A function to be executed.
args : list or tuple
args : list or tuple, optional
Additional positional arguments to be passed in `func`.
kwargs : dict
kwargs : dict, optional
Additional keyword arguments to be passed in `func`.
num_returns : int
num_returns : int, default: 1
The number of the return values.
flat_data : bool
True means that the data is neither DeferredExecution nor list.
flat_args : bool
True means that there are no lists or DeferredExecution objects in `args`.
In this case, no arguments processing is performed and `args` is passed
Expand All @@ -88,26 +92,29 @@ class DeferredExecution:

def __init__(
self,
data: Union[
ObjectRefType,
"DeferredExecution",
List[Union[ObjectRefType, "DeferredExecution"]],
],
data: Any,
func: Union[Callable, ObjectRefType],
args: Union[List[Any], Tuple[Any]],
kwargs: Dict[str, Any],
args: Union[List[Any], Tuple[Any]] = None,
kwargs: Dict[str, Any] = None,
num_returns=1,
):
if isinstance(data, DeferredExecution):
data.subscribe()
self.flat_data = self._flat_args((data,))
self.data = data
self.func = func
self.args = args
self.kwargs = kwargs
self.num_returns = num_returns
self.flat_args = self._flat_args(args)
self.flat_kwargs = self._flat_args(kwargs.values())
self.subscribers = 0
if args is not None:
self.args = args
self.flat_args = self._flat_args(args)
else:
self.args = ()
self.flat_args = True
if kwargs is not None:
self.kwargs = kwargs
self.flat_kwargs = self._flat_args(kwargs.values())
else:
self.kwargs = {}
self.flat_kwargs = True

@classmethod
def _flat_args(cls, args: Iterable):
Expand All @@ -134,7 +141,7 @@ def _flat_args(cls, args: Iterable):

def exec(
self,
) -> Tuple[ObjectRefOrListType, Union["MetaList", List], Union[int, List[int]]]:
) -> Tuple[ObjectRefOrListType, "MetaList", Union[int, List[int]]]:
"""
Execute this task, if required.
Expand All @@ -150,7 +157,7 @@ def exec(
return self.data, self.meta, self.meta_offset

if (
not isinstance(self.data, DeferredExecution)
self.flat_data
and self.flat_args
and self.flat_kwargs
and self.num_returns == 1
Expand All @@ -166,19 +173,21 @@ def exec(
# it back. After the execution, the result is saved and the counter has no effect.
self.subscribers += 2
consumers, output = self._deconstruct()

# The last result is the MetaList, so adding +1 here.
num_returns = sum(c.num_returns for c in consumers) + 1
results = self._remote_exec_chain(num_returns, *output)
meta = MetaList(results.pop())
meta_offset = 0
results = iter(results)
for de in consumers:
if de.num_returns == 1:
num_returns = de.num_returns
if num_returns == 1:
de._set_result(next(results), meta, meta_offset)
meta_offset += 2
else:
res = list(islice(results, num_returns))
offsets = list(range(0, 2 * num_returns, 2))
offsets = list(range(meta_offset, meta_offset + 2 * num_returns, 2))
de._set_result(res, meta, offsets)
meta_offset += 2 * num_returns
return self.data, self.meta, self.meta_offset
Expand Down Expand Up @@ -303,7 +312,9 @@ def _deconstruct_chain(
out_extend = output.extend
while True:
de.unsubscribe()
if (out_pos := getattr(de, "out_pos", None)) and not de.has_result:
if not (has_result := de.has_result) and (
out_pos := getattr(de, "out_pos", None)
):
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
Expand All @@ -318,12 +329,13 @@ def _deconstruct_chain(
break
elif not isinstance(data := de.data, DeferredExecution):
if isinstance(data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
data, output, stack, result_consumers, out_append
)
else:
out_append(data)
if not de.has_result:
if not has_result:
stack.append(de)
break
else:
Expand Down Expand Up @@ -391,22 +403,24 @@ def _deconstruct_list(
"""
for obj in lst:
if isinstance(obj, DeferredExecution):
if out_pos := getattr(obj, "out_pos", None):
if obj.has_result:
obj = obj.data
elif out_pos := getattr(obj, "out_pos", None):
obj.unsubscribe()
if obj.has_result:
out_append(obj.data)
else:
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if obj.subscribers == 0:
output[out_pos + 1] = 0
result_consumers.remove(obj)
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if obj.subscribers == 0:
output[out_pos + 1] = 0
result_consumers.remove(obj)
continue
else:
out_append(_Tag.CHAIN)
yield cls._deconstruct_chain(obj, output, stack, result_consumers)
out_append(_Tag.END)
elif isinstance(obj, ListOrTuple):
continue

if isinstance(obj, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
obj, output, stack, result_consumers, out_append
Expand All @@ -432,13 +446,13 @@ def _remote_exec_chain(num_returns: int, *args: Tuple) -> List[Any]:
list
The execution results. The last element of this list is the ``MetaList``.
"""
# Prefer _remote_exec_single_chain(). It has fewer arguments and
# does not require the num_returns to be specified in options.
# Prefer _remote_exec_single_chain(). It does not require the num_returns
# to be specified in options.
if num_returns == 2:
return _remote_exec_single_chain.remote(*args)
else:
return _remote_exec_multi_chain.options(num_returns=num_returns).remote(
num_returns, *args
*args
)

def _set_result(
Expand All @@ -456,7 +470,7 @@ def _set_result(
meta : MetaList
meta_offset : int or list of int
"""
del self.func, self.args, self.kwargs, self.flat_args, self.flat_kwargs
del self.func, self.args, self.kwargs
self.data = result
self.meta = meta
self.meta_offset = meta_offset
Expand All @@ -466,6 +480,64 @@ def __reduce__(self):
raise NotImplementedError("DeferredExecution is not serializable!")


ObjectRefOrDeType = Union[ObjectRefType, DeferredExecution]


class DeferredGetItem(DeferredExecution):
"""
Deferred execution task that returns an item at the specified index.
Parameters
----------
data : ObjectRefOrDeType
The object to get the item from.
index : int
The item index.
"""

def __init__(self, data: ObjectRefOrDeType, index: int):
super().__init__(data, self._remote_fn, [index])
self.index = index

@property
@_inherit_docstrings(DeferredExecution.has_result)
def has_result(self):
if super().has_result:
return True

if (
isinstance(self.data, DeferredExecution)
and self.data.has_result
and self.data.num_returns != 1
):
# If `data` is a `DeferredExecution`, that returns multiple results, we
# don't need to execute `_remote_fn`, but can get the result by index instead.
self._set_result(
self.data.data[self.index],
self.data.meta,
self.data.meta_offset[self.index],
)
return True

return False

@remote_function
def _remote_fn(obj, index): # pragma: no cover
"""
Return the item by index.
Parameters
----------
obj : collection
index : int
Returns
-------
object
"""
return obj[index]


class MetaList:
"""
Meta information, containing the result lengths and the worker address.
Expand All @@ -478,6 +550,11 @@ class MetaList:
def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):
self._obj = obj

def materialize(self):
"""Materialized the list, if required."""
if not isinstance(self._obj, list):
self._obj = RayWrapper.materialize(self._obj)

def __getitem__(self, index):
"""
Get item at the specified index.
Expand Down Expand Up @@ -508,21 +585,21 @@ def __setitem__(self, index, value):
obj[index] = value


class MetaListHook(MaterializationHook):
class MetaListHook(MaterializationHook, DeferredGetItem):
"""
Used by MetaList.__getitem__() for lazy materialization and getting a single value from the list.
Parameters
----------
meta : MetaList
Non-materialized list to get the value from.
idx : int
index : int
The value index in the list.
"""

def __init__(self, meta: MetaList, idx: int):
def __init__(self, meta: MetaList, index: int):
super().__init__(meta._obj, index)
self.meta = meta
self.idx = idx

def pre_materialize(self):
"""
Expand All @@ -533,7 +610,7 @@ def pre_materialize(self):
object
"""
obj = self.meta._obj
return obj[self.idx] if isinstance(obj, list) else obj
return obj[self.index] if isinstance(obj, list) else obj

def post_materialize(self, materialized):
"""
Expand All @@ -548,7 +625,7 @@ def post_materialize(self, materialized):
object
"""
self.meta._obj = materialized
return materialized[self.idx]
return materialized[self.index]


class _Tag(Enum): # noqa: PR01
Expand Down Expand Up @@ -605,7 +682,7 @@ def exec_func(fn: Callable, obj: Any, args: Tuple, kwargs: Dict) -> Any:
raise err

@classmethod
def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
def construct(cls, args: Tuple): # pragma: no cover
"""
Construct and execute the specified chain.
Expand All @@ -615,7 +692,6 @@ def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
Parameters
----------
num_returns : int
args : tuple
Yields
Expand Down Expand Up @@ -687,7 +763,7 @@ def construct_chain(

while chain:
fn = pop()
if fn == tg_e:
if fn is tg_e:
lst.append(obj)
break

Expand Down Expand Up @@ -717,10 +793,10 @@ def construct_chain(

itr = iter([obj] if num_returns == 1 else obj)
for _ in range(num_returns):
obj = next(itr)
meta.append(len(obj) if hasattr(obj, "__len__") else 0)
meta.append(len(obj.columns) if hasattr(obj, "columns") else 0)
yield obj
o = next(itr)
meta.append(len(o) if hasattr(o, "__len__") else 0)
meta.append(len(o.columns) if hasattr(o, "columns") else 0)
yield o

@classmethod
def construct_list(
Expand Down Expand Up @@ -834,20 +910,18 @@ def _remote_exec_single_chain(
-------
Generator
"""
return remote_executor.construct(num_returns=2, args=args)
return remote_executor.construct(args=args)


@ray.remote
def _remote_exec_multi_chain(
num_returns: int, *args: Tuple, remote_executor=_REMOTE_EXEC
*args: Tuple, remote_executor=_REMOTE_EXEC
) -> Generator: # pragma: no cover
"""
Execute the deconstructed chain with a multiple return values in a worker process.
Parameters
----------
num_returns : int
The number of return values.
*args : tuple
A deconstructed chain to be executed.
remote_executor : _RemoteExecutor, default: _REMOTE_EXEC
Expand All @@ -857,4 +931,4 @@ def _remote_exec_multi_chain(
-------
Generator
"""
return remote_executor.construct(num_returns, args)
return remote_executor.construct(args)
Loading

0 comments on commit 5978769

Please sign in to comment.