# `ray` Executor

In [1]:
%run -m literary.notebook

import time

In [2]:
import asyncio
import weakref

import ray
import ray.exceptions

from .executor import AsyncExecutor



The Ray executor uses the globally initialised state.

In [3]:
class RayExecutor(AsyncExecutor):
    """Executor that dispatches work to Ray.

    The class body defines nothing of its own; its methods are attached
    afterwards with ``@patch(RayExecutor)``.
    """

To submit a task, we invoke the `remote` method of the `_ray_call` proxy, which returns a `ray.ObjectRef`. This reference is then wrapped in an `asyncio.Future` handle, which holds the status of the running task.

In [4]:
@patch(RayExecutor)
def _apply(self, func, /, *args, **kwargs) -> ray.ObjectRef:
    """Launch ``func`` as a Ray task and return the reference to its result.

    The arguments are first passed through ``_process_args`` so that any
    executor handles among them are replaced before Ray sees them.

    :param func: callable to run remotely through the ``_ray_call`` proxy
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :return: whatever ``_ray_call.remote`` returns, i.e. a ``ray.ObjectRef``
        (not an ``asyncio.Future``)
    """
    args, kwargs = self._process_args(args, kwargs)
    return _ray_call.remote(func, *args, **kwargs)

Ray requires that the object decorated with `remote` is a free function. Here, we implement a proxy that calls the function passed to it.

In [5]:
@ray.remote
def _ray_call(func, *args, **kwargs):
    """Call ``func`` with the given arguments and return its result.

    This is the single remote entry point: the function to run is passed
    in as the first argument rather than being decorated itself.
    """
    outcome = func(*args, **kwargs)
    return outcome

Ray can also accept `ray.ObjectRef` objects as arguments to other calls to `submit`. This is useful to avoid a round-trip of the data for chained computations. In order to support this, we process the arguments passed to `submit`:

In [6]:
@patch(RayExecutor)
def _process_args(self, args, kwargs):
    """Return ``(args, kwargs)`` with every value passed through ``_unwrap_maybe``.

    :param args: iterable of positional arguments
    :param kwargs: mapping of keyword arguments
    :return: a ``(list, dict)`` pair holding the processed values
    """
    positional = []
    for value in args:
        positional.append(self._unwrap_maybe(value))

    keyword = {}
    for name, value in kwargs.items():
        keyword[name] = self._unwrap_maybe(value)

    return positional, keyword

Here, we unwrap the `asyncio.Future` handles in the arguments so that they are visible to Ray.

In [7]:
@patch(RayExecutor)
def _unwrap_maybe(self, obj):
    """Unwrap ``obj`` if it is a handle, otherwise return it untouched.

    ``_unwrap_handle`` signals "not a handle" by raising ``ValueError``;
    that case is treated as a plain argument.
    """
    try:
        unwrapped = self._unwrap_handle(obj)
    except ValueError:
        return obj
    else:
        return unwrapped

In order to connect the handle with the `ray.ObjectRef` result, we implement a routine to chain these with `asyncio.Future` objects:

In [8]:
@patch(RayExecutor)
def _register_handle(self, handle: asyncio.Future, ref: ray.ObjectRef):
    """Link ``handle`` and ``ref`` so that each follows the other's state.

    When ``ref`` completes, ``handle`` is resolved on the running event
    loop: with the underlying exception if the task failed, otherwise with
    ``True`` (a completion flag, not the task's value). When ``handle`` is
    cancelled, the Ray task behind ``ref`` is cancelled too.

    Must be called while an event loop is running, because the loop is
    looked up with ``asyncio.get_running_loop()``.

    :param handle: future that mirrors the task's status
    :param ref: Ray reference whose completion is observed
    """
    # Hold the handle only weakly so that this hook does not, by itself,
    # keep the handle alive.
    weak_handle = weakref.ref(handle)

    def settle_handle(outcome):
        future = weak_handle()
        # Nothing to update if the handle is gone or was cancelled.
        if not future or future.cancelled():
            return

        if isinstance(outcome, ray.exceptions.RayTaskError):
            # Surface the failure as the exception type raised in the task.
            error = outcome.as_instanceof_cause()
        elif isinstance(outcome, ray.exceptions.RayError):
            error = outcome
        else:
            future.set_result(True)
            return
        future.set_exception(error)

    loop = asyncio.get_running_loop()

    def forward_to_loop(outcome):
        # NOTE(review): presumably Ray fires this hook off the loop thread,
        # hence the thread-safe hand-over to the loop.
        loop.call_soon_threadsafe(settle_handle, outcome)

    def cancel_task_if_cancelled(fut):
        if fut.cancelled():
            ray.cancel(ref)

    ref._on_completed(forward_to_loop)
    handle.add_done_callback(cancel_task_if_cancelled)

Finally, we implement a method to unwrap and return the result object for the Ray ref:

In [9]:
@patch(RayExecutor)
async def retrieve(self, handle: asyncio.Future):
    return await self._unwrap_handle(handle)

To demonstrate this, we can first create a local cluster:

In [10]:
# Called with no arguments, so Ray is initialised with its defaults.
ray.init()

2021-05-05 18:31:52,200	INFO services.py:1267 -- View the Ray dashboard at http://127.0.0.1:8266


{'node_ip_address': '192.168.1.123',
 'raylet_ip_address': '192.168.1.123',
 'redis_address': '192.168.1.123:44142',
 'object_store_address': '/tmp/ray/session_2021-05-05_18-31-51_197946_1045065/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2021-05-05_18-31-51_197946_1045065/sockets/raylet',
 'webui_url': '127.0.0.1:8266',
 'session_dir': '/tmp/ray/session_2021-05-05_18-31-51_197946_1045065',
 'metrics_export_port': 65243,
 'node_id': '65d8770e13b5a19c407f5a815838426cd5a7f5111cbb360c5e3df694'}

With Ray initialised, we can create an executor:

In [11]:
# The constructor is called without any arguments.
executor = RayExecutor()

To do some work, let's implement a function that sleeps for a given delay and then returns that delay:

In [12]:
def slow_function(timeout):
    """Sleep for ``timeout`` seconds, then return ``timeout`` unchanged."""
    delay = timeout
    time.sleep(delay)
    return delay

Now we can chain a few of these tasks together:

In [13]:
# Two sleeping tasks whose return values are 2 and 5 respectively.
a = executor.submit(slow_function, 2)
b = executor.submit(
    slow_function,
    5,
)
# The handles `a` and `b` are passed directly as arguments to a third task.
c = executor.submit(int.__add__, a, b)

We can wait for the result without retrieving its value:

In [14]:
# Awaiting the handle waits for completion only; the value shown below is a
# completion flag, not the task's return value.
await c

True

And when we're ready for the value, we invoke `executor.retrieve`.

In [15]:
# 7 == 2 + 5: the values returned by the two `slow_function` tasks, summed.
assert await executor.retrieve(c) == 7