Skip to content

Commit

Permalink
FEAT-modin-project#7001: Do not force materialization in MetaList.__g…
Browse files Browse the repository at this point in the history
…etitem__()

Signed-off-by: Andrey Pavlenko <andrey.a.pavlenko@gmail.com>
  • Loading branch information
AndreyPavlenko committed Mar 5, 2024
1 parent cd3d0c6 commit ae05a5a
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 30 deletions.
3 changes: 2 additions & 1 deletion modin/core/execution/ray/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

"""Common utilities for Ray execution engine."""

from .engine_wrapper import RayWrapper, SignalActor
from .engine_wrapper import ObjectRefMapper, RayWrapper, SignalActor
from .utils import initialize_ray

__all__ = [
"initialize_ray",
"RayWrapper",
"ObjectRefMapper",
"SignalActor",
]
47 changes: 43 additions & 4 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ray._private.services import get_node_ip_address
from ray.util.client.common import ClientObjectRef

from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common import ObjectRefMapper, RayWrapper
from modin.logging import get_logger

ObjectRefType = Union[ray.ObjectRef, ClientObjectRef, None]
Expand Down Expand Up @@ -491,9 +491,7 @@ def __getitem__(self, index):
Any
"""
obj = self._obj
if not isinstance(obj, list):
self._obj = obj = RayWrapper.materialize(obj)
return obj[index]
return obj[index] if isinstance(obj, list) else MetaListMapper(self, index)

def __setitem__(self, index, value):
"""
Expand All @@ -510,6 +508,47 @@ def __setitem__(self, index, value):
obj[index] = value


class MetaListMapper(ObjectRefMapper):
"""
Used by MetaList.__getitem__() for lazy materialization.
Parameters
----------
meta : MetaList
idx : int
"""

def __init__(self, meta: MetaList, idx: int):
self.meta = meta
self.idx = idx

def get(self):
"""
Get item at self.idx or object ref if not materialized.
Returns
-------
object
"""
obj = self.meta._obj
return obj[self.idx] if isinstance(obj, list) else obj

def map(self, materialized):
"""
Save the materialized list in self.meta and get the item at self.idx.
Parameters
----------
materialized : list
Returns
-------
object
"""
self.meta._obj = materialized
return materialized[self.idx]


class _Tag(Enum): # noqa: PR01
"""
A set of special values used for the method arguments de/construction.
Expand Down
105 changes: 96 additions & 9 deletions modin/core/execution/ray/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import asyncio
import os
from types import FunctionType
from typing import Sequence

import ray
from ray.util.client.common import ClientObjectRef
Expand Down Expand Up @@ -96,8 +97,7 @@ def is_future(cls, item):
boolean
If the value is a future.
"""
ObjectIDType = (ray.ObjectRef, ClientObjectRef)
return isinstance(item, ObjectIDType)
return isinstance(item, ObjectRefTypes)

@classmethod
def materialize(cls, obj_id):
Expand All @@ -114,7 +114,56 @@ def materialize(cls, obj_id):
object
Whatever was identified by `obj_id`.
"""
return ray.get(obj_id)
if isinstance(obj_id, ObjectRefMapper):
obj = obj_id.get()
return (
obj_id.map(ray.get(obj)) if isinstance(obj, RayObjectRefTypes) else obj
)

if not isinstance(obj_id, Sequence):
return ray.get(obj_id) if isinstance(obj_id, RayObjectRefTypes) else obj_id

if all(isinstance(obj, RayObjectRefTypes) for obj in obj_id):
return ray.get(obj_id)

ids = {}
result = []
for obj in obj_id:
if isinstance(obj, ObjectRefTypes):
if isinstance(obj, ObjectRefMapper):
oid = obj.get()
if isinstance(oid, RayObjectRefTypes):
mapper = obj
obj = oid
else:
result.append(oid)
continue
else:
mapper = None
else:
result.append(obj)
continue

idx = ids.get(obj, None)
if idx is None:
ids[obj] = idx = len(ids)
if mapper is None:
result.append(obj)
else:
mapper._materialized_idx = idx
result.append(mapper)

if len(ids) == 0:
return result

materialized = ray.get(list(ids.keys()))
for i in range(len(result)):
if isinstance((obj := result[i]), ObjectRefTypes):
if isinstance(obj, ObjectRefMapper):
result[i] = obj.map(materialized[obj._materialized_idx])
else:
result[i] = materialized[ids[obj]]
return result

@classmethod
def put(cls, data, **kwargs):
Expand Down Expand Up @@ -161,12 +210,18 @@ def wait(cls, obj_ids, num_returns=None):
obj_ids : list, scalar
num_returns : int, optional
"""
if not isinstance(obj_ids, list):
obj_ids = [obj_ids]
unique_ids = list(set(obj_ids))
if num_returns is None:
num_returns = len(unique_ids)
ray.wait(unique_ids, num_returns=num_returns)
if not isinstance(obj_ids, Sequence):
obj_ids = list(obj_ids)

ids = set()
for obj in obj_ids:
if isinstance(obj, ObjectRefMapper):
obj = obj.get()
if isinstance(obj, RayObjectRefTypes):
ids.add(obj)

if num_ids := len(ids):
ray.wait(list(ids), num_returns=num_returns or num_ids)


@ray.remote
Expand Down Expand Up @@ -218,3 +273,35 @@ def is_set(self, event_idx: int) -> bool:
bool
"""
return self.events[event_idx].is_set()


class ObjectRefMapper:
"""Map the materialized object to a different value."""

def get(self):
"""
Get an object reference or the cached, previously mapped value.
Returns
-------
ray.ObjectRef or object
"""
raise NotImplementedError()

def map(self, materialized):
"""
Map the materialized object.
Parameters
----------
materialized : object
Returns
-------
object
"""
raise NotImplementedError()


RayObjectRefTypes = (ray.ObjectRef, ClientObjectRef)
ObjectRefTypes = (*RayObjectRefTypes, ObjectRefMapper)
5 changes: 2 additions & 3 deletions modin/core/execution/ray/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import psutil
import ray
from packaging import version
from ray.util.client.common import ClientObjectRef

from modin.config import (
CIAWSAccessKeyID,
Expand All @@ -40,7 +39,7 @@
from modin.core.execution.utils import set_env
from modin.error_message import ErrorMessage

from .engine_wrapper import RayWrapper
from .engine_wrapper import ObjectRefTypes, RayWrapper

_OBJECT_STORE_TO_SYSTEM_MEMORY_RATIO = 0.6
# This constant should be in sync with the limit in ray, which is private,
Expand All @@ -50,7 +49,7 @@

_RAY_IGNORE_UNHANDLED_ERRORS_VAR = "RAY_IGNORE_UNHANDLED_ERRORS"

ObjectIDType = (ray.ObjectRef, ClientObjectRef)
ObjectIDType = ObjectRefTypes


def initialize_ray(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@

from modin.config import LazyExecution
from modin.core.dataframe.pandas.partitioning.partition import PandasDataframePartition
from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common import ObjectRefMapper, RayWrapper
from modin.core.execution.ray.common.deferred_execution import (
DeferredExecution,
MetaList,
MetaListMapper,
)
from modin.core.execution.ray.common.utils import ObjectIDType
from modin.logging import get_logger
from modin.pandas.indexing import compute_sliced_len
from modin.utils import _inherit_docstrings

compute_sliced_len = ray.remote(compute_sliced_len)


class PandasOnRayDataframePartition(PandasDataframePartition):
"""
Expand Down Expand Up @@ -199,25 +198,21 @@ def mask(self, row_labels, col_labels):
self._is_debug(log) and log.debug(f"ENTER::Partition.mask::{self._identity}")
new_obj = super().mask(row_labels, col_labels)
if isinstance(row_labels, slice) and isinstance(
self._length_cache, ObjectIDType
(len_cache := self._length_cache), ObjectIDType
):
if row_labels == slice(None):
# fast path - full axis take
new_obj._length_cache = self._length_cache
new_obj._length_cache = len_cache
else:
new_obj._length_cache = compute_sliced_len.remote(
row_labels, self._length_cache
)
new_obj._length_cache = SlicedLenMapper(len_cache, row_labels)
if isinstance(col_labels, slice) and isinstance(
self._width_cache, ObjectIDType
(width_cache := self._width_cache), ObjectIDType
):
if col_labels == slice(None):
# fast path - full axis take
new_obj._width_cache = self._width_cache
new_obj._width_cache = width_cache
else:
new_obj._width_cache = compute_sliced_len.remote(
col_labels, self._width_cache
)
new_obj._width_cache = SlicedLenMapper(width_cache, col_labels)
self._is_debug(log) and log.debug(f"EXIT::Partition.mask::{self._identity}")
return new_obj

Expand Down Expand Up @@ -421,3 +416,51 @@ def eager_exec(self, func, *args, length=None, width=None, **kwargs):


LazyExecution.subscribe(_configure_lazy_exec)


class SlicedLenMapper(ObjectRefMapper):
"""
Used by mask() for the slilced length computation.
Parameters
----------
ref : ObjectIDType
slc : slice
"""

def __init__(self, ref: ObjectIDType, slc: slice):
self.ref = ref
self.slc = slc

def get(self):
"""
Get the sliced length or object ref if not materialized.
Returns
-------
int or ObjectIDType
"""
if isinstance(self.ref, MetaListMapper):
len_or_ref = self.ref.get()
return (
compute_sliced_len(self.slc, len_or_ref)
if isinstance(len_or_ref, int)
else len_or_ref
)
return self.ref

def map(self, materialized):
"""
Get the sliced length.
Parameters
----------
materialized : list or int
Returns
-------
int
"""
if isinstance(self.ref, MetaListMapper):
materialized = self.ref.map(materialized)
return compute_sliced_len(self.slc, materialized)
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class PandasOnRayDataframePartitionManager(GenericRayDataframePartitionManager):
_column_partitions_class = PandasOnRayDataframeColumnPartition
_row_partition_class = PandasOnRayDataframeRowPartition
_execution_wrapper = RayWrapper
materialize_futures = RayWrapper.materialize

@classmethod
def wait_partitions(cls, partitions):
Expand Down

0 comments on commit ae05a5a

Please sign in to comment.