Skip to content

Commit

Permalink
FEAT-modin-project#7001: Do not force materialization in MetaList.__g…
Browse files Browse the repository at this point in the history
…etitem__() (modin-project#7006)

Co-authored-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
Signed-off-by: Andrey Pavlenko <andrey.a.pavlenko@gmail.com>
  • Loading branch information
2 people authored and anmyachev committed Mar 6, 2024
1 parent 999ce44 commit 47a886c
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 30 deletions.
3 changes: 2 additions & 1 deletion modin/core/execution/ray/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

"""Common utilities for Ray execution engine."""

from .engine_wrapper import RayWrapper, SignalActor
from .engine_wrapper import MaterializationHook, RayWrapper, SignalActor
from .utils import initialize_ray

__all__ = [
"initialize_ray",
"RayWrapper",
"MaterializationHook",
"SignalActor",
]
49 changes: 45 additions & 4 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ray._private.services import get_node_ip_address
from ray.util.client.common import ClientObjectRef

from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common import MaterializationHook, RayWrapper
from modin.logging import get_logger

ObjectRefType = Union[ray.ObjectRef, ClientObjectRef, None]
Expand Down Expand Up @@ -491,9 +491,7 @@ def __getitem__(self, index):
Any
"""
obj = self._obj
if not isinstance(obj, list):
self._obj = obj = RayWrapper.materialize(obj)
return obj[index]
return obj[index] if isinstance(obj, list) else MetaListHook(self, index)

def __setitem__(self, index, value):
"""
Expand All @@ -510,6 +508,49 @@ def __setitem__(self, index, value):
obj[index] = value


class MetaListHook(MaterializationHook):
"""
Used by MetaList.__getitem__() for lazy materialization and getting a single value from the list.
Parameters
----------
meta : MetaList
Non-materialized list to get the value from.
idx : int
The value index in the list.
"""

def __init__(self, meta: MetaList, idx: int):
self.meta = meta
self.idx = idx

def pre_materialize(self):
"""
Get item at self.idx or object ref if not materialized.
Returns
-------
object
"""
obj = self.meta._obj
return obj[self.idx] if isinstance(obj, list) else obj

def post_materialize(self, materialized):
"""
Save the materialized list in self.meta and get the item at self.idx.
Parameters
----------
materialized : list
Returns
-------
object
"""
self.meta._obj = materialized
return materialized[self.idx]


class _Tag(Enum): # noqa: PR01
"""
A set of special values used for the method arguments de/construction.
Expand Down
110 changes: 101 additions & 9 deletions modin/core/execution/ray/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import asyncio
import os
from types import FunctionType
from typing import Sequence

import ray
from ray.util.client.common import ClientObjectRef
Expand Down Expand Up @@ -96,8 +97,7 @@ def is_future(cls, item):
boolean
If the value is a future.
"""
ObjectIDType = (ray.ObjectRef, ClientObjectRef)
return isinstance(item, ObjectIDType)
return isinstance(item, ObjectRefTypes)

@classmethod
def materialize(cls, obj_id):
Expand All @@ -114,7 +114,59 @@ def materialize(cls, obj_id):
object
Whatever was identified by `obj_id`.
"""
return ray.get(obj_id)
if isinstance(obj_id, MaterializationHook):
obj = obj_id.pre_materialize()
return (
obj_id.post_materialize(ray.get(obj))
if isinstance(obj, RayObjectRefTypes)
else obj
)

if not isinstance(obj_id, Sequence):
return ray.get(obj_id) if isinstance(obj_id, RayObjectRefTypes) else obj_id

if all(isinstance(obj, RayObjectRefTypes) for obj in obj_id):
return ray.get(obj_id)

ids = {}
result = []
for obj in obj_id:
if not isinstance(obj, ObjectRefTypes):
result.append(obj)
continue
if isinstance(obj, MaterializationHook):
oid = obj.pre_materialize()
if isinstance(oid, RayObjectRefTypes):
hook = obj
obj = oid
else:
result.append(oid)
continue
else:
hook = None

idx = ids.get(obj, None)
if idx is None:
ids[obj] = idx = len(ids)
if hook is None:
result.append(obj)
else:
hook._materialized_idx = idx
result.append(hook)

if len(ids) == 0:
return result

materialized = ray.get(list(ids.keys()))
for i in range(len(result)):
if isinstance((obj := result[i]), ObjectRefTypes):
if isinstance(obj, MaterializationHook):
result[i] = obj.post_materialize(
materialized[obj._materialized_idx]
)
else:
result[i] = materialized[ids[obj]]
return result

@classmethod
def put(cls, data, **kwargs):
Expand Down Expand Up @@ -161,12 +213,18 @@ def wait(cls, obj_ids, num_returns=None):
obj_ids : list, scalar
num_returns : int, optional
"""
if not isinstance(obj_ids, list):
obj_ids = [obj_ids]
unique_ids = list(set(obj_ids))
if num_returns is None:
num_returns = len(unique_ids)
ray.wait(unique_ids, num_returns=num_returns)
if not isinstance(obj_ids, Sequence):
obj_ids = list(obj_ids)

ids = set()
for obj in obj_ids:
if isinstance(obj, MaterializationHook):
obj = obj.pre_materialize()
if isinstance(obj, RayObjectRefTypes):
ids.add(obj)

if num_ids := len(ids):
ray.wait(list(ids), num_returns=num_returns or num_ids)


@ray.remote
Expand Down Expand Up @@ -218,3 +276,37 @@ def is_set(self, event_idx: int) -> bool:
bool
"""
return self.events[event_idx].is_set()


class MaterializationHook:
"""The Hook is called during the materialization and allows performing pre/post computations."""

def pre_materialize(self):
"""
Get an object reference to be materialized or a pre-computed value.
Returns
-------
ray.ObjectRef or object
"""
raise NotImplementedError()

def post_materialize(self, materialized):
"""
Perform computations on the materialized object.
Parameters
----------
materialized : object
The materialized object to be post-computed.
Returns
-------
object
The post-computed object.
"""
raise NotImplementedError()


RayObjectRefTypes = (ray.ObjectRef, ClientObjectRef)
ObjectRefTypes = (*RayObjectRefTypes, MaterializationHook)
5 changes: 2 additions & 3 deletions modin/core/execution/ray/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import psutil
import ray
from packaging import version
from ray.util.client.common import ClientObjectRef

from modin.config import (
CIAWSAccessKeyID,
Expand All @@ -40,7 +39,7 @@
from modin.core.execution.utils import set_env
from modin.error_message import ErrorMessage

from .engine_wrapper import RayWrapper
from .engine_wrapper import ObjectRefTypes, RayWrapper

_OBJECT_STORE_TO_SYSTEM_MEMORY_RATIO = 0.6
# This constant should be in sync with the limit in ray, which is private,
Expand All @@ -50,7 +49,7 @@

_RAY_IGNORE_UNHANDLED_ERRORS_VAR = "RAY_IGNORE_UNHANDLED_ERRORS"

ObjectIDType = (ray.ObjectRef, ClientObjectRef)
ObjectIDType = ObjectRefTypes


def initialize_ray(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@

from modin.config import LazyExecution
from modin.core.dataframe.pandas.partitioning.partition import PandasDataframePartition
from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common import MaterializationHook, RayWrapper
from modin.core.execution.ray.common.deferred_execution import (
DeferredExecution,
MetaList,
MetaListHook,
)
from modin.core.execution.ray.common.utils import ObjectIDType
from modin.logging import get_logger
from modin.pandas.indexing import compute_sliced_len
from modin.utils import _inherit_docstrings

compute_sliced_len = ray.remote(compute_sliced_len)


class PandasOnRayDataframePartition(PandasDataframePartition):
"""
Expand Down Expand Up @@ -199,25 +198,21 @@ def mask(self, row_labels, col_labels):
self._is_debug(log) and log.debug(f"ENTER::Partition.mask::{self._identity}")
new_obj = super().mask(row_labels, col_labels)
if isinstance(row_labels, slice) and isinstance(
self._length_cache, ObjectIDType
(len_cache := self._length_cache), ObjectIDType
):
if row_labels == slice(None):
# fast path - full axis take
new_obj._length_cache = self._length_cache
new_obj._length_cache = len_cache
else:
new_obj._length_cache = compute_sliced_len.remote(
row_labels, self._length_cache
)
new_obj._length_cache = SlicerHook(len_cache, row_labels)
if isinstance(col_labels, slice) and isinstance(
self._width_cache, ObjectIDType
(width_cache := self._width_cache), ObjectIDType
):
if col_labels == slice(None):
# fast path - full axis take
new_obj._width_cache = self._width_cache
new_obj._width_cache = width_cache
else:
new_obj._width_cache = compute_sliced_len.remote(
col_labels, self._width_cache
)
new_obj._width_cache = SlicerHook(width_cache, col_labels)
self._is_debug(log) and log.debug(f"EXIT::Partition.mask::{self._identity}")
return new_obj

Expand Down Expand Up @@ -421,3 +416,53 @@ def eager_exec(self, func, *args, length=None, width=None, **kwargs):


LazyExecution.subscribe(_configure_lazy_exec)


class SlicerHook(MaterializationHook):
"""
Used by mask() for the slilced length computation.
Parameters
----------
ref : ObjectIDType
Non-materialized length to be sliced.
slc : slice
The slice to be applied.
"""

def __init__(self, ref: ObjectIDType, slc: slice):
self.ref = ref
self.slc = slc

def pre_materialize(self):
"""
Get the sliced length or object ref if not materialized.
Returns
-------
int or ObjectIDType
"""
if isinstance(self.ref, MetaListHook):
len_or_ref = self.ref.pre_materialize()
return (
compute_sliced_len(self.slc, len_or_ref)
if isinstance(len_or_ref, int)
else len_or_ref
)
return self.ref

def post_materialize(self, materialized):
"""
Get the sliced length.
Parameters
----------
materialized : list or int
Returns
-------
int
"""
if isinstance(self.ref, MetaListHook):
materialized = self.ref.post_materialize(materialized)
return compute_sliced_len(self.slc, materialized)
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class PandasOnRayDataframePartitionManager(GenericRayDataframePartitionManager):
_column_partitions_class = PandasOnRayDataframeColumnPartition
_row_partition_class = PandasOnRayDataframeRowPartition
_execution_wrapper = RayWrapper
materialize_futures = RayWrapper.materialize

@classmethod
def wait_partitions(cls, partitions):
Expand Down

0 comments on commit 47a886c

Please sign in to comment.