# `dask.distributed` Executor

In [1]:
%run -m literary.notebook

import time

In [2]:
import asyncio
import enum
import weakref

from dask import distributed

from .executor import Executor
from .futures import create_future
from .wrap import AttributeHandleWrapper

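Dask reports the status of a future as a plain string. Here we collect the possible values in a Python enum, so that a mistyped status name fails loudly (with an `AttributeError`) instead of silently never matching. Because the enum also subclasses `str`, its members compare equal to Dask's raw strings.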

In [3]:
class DaskStatus(str, enum.Enum):
    """Status strings reported by ``distributed.Future.status``.

    Mixing in ``str`` makes every member compare equal to the raw string
    Dask uses, so ``dask_future.status == DaskStatus.FINISHED`` works
    without any conversion.
    """

    FINISHED = "finished"
    CANCELLED = "cancelled"
    LOST = "lost"
    PENDING = "pending"
    ERROR = "error"

The basic Dask executor wraps a `distributed.Client`, which must be asynchronous (created with `asynchronous=True`).

In [4]:
class DaskExecutor(Executor):
    """Executor that runs tasks on a Dask cluster via ``distributed.Client``.

    Task handles are ``asyncio.Future`` objects wrapped around the
    corresponding ``distributed.Future`` so they can be unwrapped again.
    """

    def __init__(self, client: distributed.Client):
        """Create an executor backed by ``client``.

        :param client: a ``distributed.Client`` created with
            ``asynchronous=True``.
        :raises ValueError: if ``client`` is not asynchronous.
        """
        # Validate explicitly rather than with ``assert``: assertions are
        # stripped under ``python -O``, and this executor relies on the
        # client's calls returning awaitables.
        if not client.asynchronous:
            raise ValueError(
                "DaskExecutor requires an asynchronous distributed.Client "
                "(create it with asynchronous=True)"
            )

        self._client = client
        # Maps asyncio handles to/from the underlying distributed.Future.
        self._wrapper = AttributeHandleWrapper()

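To submit a task, we invoke `distributed.Client.submit` and wrap the resulting `distributed.Future` in an `asyncio.Future` handle. The handle tracks the status of the running task: it resolves to `True` on success rather than to the task's value, which stays on the cluster until it is retrieved.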

In [5]:
@patch(DaskExecutor)
def submit(self, func, /, *args, **kwargs) -> asyncio.Future:
    """Schedule ``func(*args, **kwargs)`` on the Dask cluster.

    Handles from earlier ``submit`` calls may appear among the arguments;
    they are swapped for their Dask futures before submission.

    :returns: an ``asyncio.Future`` handle that mirrors the Dask task.
    """
    call_args, call_kwargs = self._process_args(args, kwargs)
    remote = self._client.submit(func, *call_args, **call_kwargs)

    # Create the local handle, link its state to the remote task, and
    # record the association so the handle can be unwrapped later.
    local = create_future()
    self._chain_futures(local, remote)
    self._wrapper.wrap(local, remote)
    return local

Dask can also accept `distributed.Future` objects as arguments to other calls to `submit`. This is useful to avoid a round-trip of the data for chained computations. In order to support this, we process the arguments passed to `submit`:

In [6]:
@patch(DaskExecutor)
def _process_args(self, args, kwargs):
    """Replace any wrapped handles in the call arguments with Dask futures.

    :returns: ``(positional, keyword)`` as a list and a dict; arguments
        that are not handles are passed through unchanged.
    """
    unwrap = self._unwrap_maybe
    positional = list(map(unwrap, args))
    keyword = {name: unwrap(value) for name, value in kwargs.items()}
    return positional, keyword

Here, we unwrap any `asyncio.Future` handles among the arguments, so that Dask receives the underlying `distributed.Future` objects. Every other argument is passed through unchanged.

In [7]:
@patch(DaskExecutor)
def _unwrap_maybe(self, obj):
    """Return the Dask future behind ``obj``, or ``obj`` itself.

    An ``AttributeError`` from the wrapper is taken to mean ``obj`` is not
    a handle it can unwrap, so the object is passed through untouched.
    """
    try:
        unwrapped = self._wrapper.unwrap(obj)
    except AttributeError:
        unwrapped = obj
    return unwrapped

To connect each handle to its `distributed.Future`, we implement a routine that chains the two futures together: the handle mirrors the outcome of the Dask future, and cancelling the handle cancels the Dask task.

In [8]:
@patch(DaskExecutor)
def _chain_futures(self, fut: asyncio.Future, dask_fut: distributed.Future):
    """Make the asyncio handle ``fut`` mirror the outcome of ``dask_fut``.

    When ``dask_fut`` settles, ``fut`` is settled as follows:

    * finished: ``fut`` resolves to ``True`` (the value stays on the cluster).
    * cancelled: ``fut`` is cancelled.
    * error: ``fut`` raises the task's exception, with its traceback.
    * any other status (e.g. lost): ``fut`` raises ``RuntimeError``.

    Conversely, cancelling ``fut`` cancels the Dask task.

    The Dask-side callback holds only a weak reference to ``fut``, so
    dropping the handle does not keep it alive.
    """
    loop = fut.get_loop()
    fut_ref = weakref.ref(fut)

    async def settle():
        # asyncio futures are not thread-safe, so all mutation of the handle
        # happens here, on the event loop that owns it.
        handle = fut_ref()
        if handle is None or handle.done():
            return

        status = dask_fut.status
        if status == DaskStatus.FINISHED:
            handle.set_result(True)
        elif status == DaskStatus.CANCELLED:
            handle.cancel()
        elif status == DaskStatus.ERROR:
            # With an asynchronous client, ``dask_fut.result()`` only returns
            # an un-awaited coroutine. Awaiting the future instead re-raises
            # the remote exception together with its traceback.
            try:
                await dask_fut
            except BaseException as err:
                if not handle.done():
                    handle.set_exception(err)
            else:
                # Defensive: the task did not actually fail after all.
                if not handle.done():
                    handle.set_result(True)
        else:
            # e.g. "lost". Awaiting the future here could block
            # indefinitely, so report the status instead.
            handle.set_exception(
                RuntimeError(
                    f"Dask future {dask_fut.key!r} ended with status {status!r}"
                )
            )

    def on_dask_fut_done(_):
        # distributed may invoke done-callbacks from a separate thread, so
        # hand the work back to the handle's loop instead of touching it here.
        if fut_ref() is None or loop.is_closed():
            return
        asyncio.run_coroutine_threadsafe(settle(), loop)

    dask_fut.add_done_callback(on_dask_fut_done)

    def on_fut_done(handle):
        # Runs on the handle's loop. Propagate cancellation to the cluster,
        # unless the handle was cancelled because the Dask task already was.
        if handle.cancelled() and dask_fut.status != DaskStatus.CANCELLED:
            loop.create_task(dask_fut.cancel(force=True))

    fut.add_done_callback(on_fut_done)

Finally, we implement a method that unwraps a handle and returns the result of the underlying Dask future:

In [9]:
@patch(DaskExecutor)
async def retrieve(self, handle: asyncio.Future):
    """Fetch the value computed for ``handle`` from the cluster.

    :param handle: a handle previously returned by ``submit``.
    :returns: the task's result, as produced by awaiting its Dask future.
    """
    return await self._wrapper.unwrap(handle)

To demonstrate the executor, we first create a local cluster with an asynchronous client:

In [28]:
client = await distributed.Client(
    scheduler_port=0, dashboard_address=":0", asynchronous=True
)
client

Client — Scheduler: tcp://127.0.0.1:43035, Dashboard: http://127.0.0.1:39005/status
Cluster — Workers: 4, Cores: 12, Memory: 15.56 GiB


Using this asynchronous client, we can create an executor:

In [12]:
executor = DaskExecutor(client=client)  # requires the asynchronous client above

To simulate some work, let's implement a function that sleeps for the given delay and then returns it:

In [22]:
def slow_function(timeout):
    """Block for ``timeout`` seconds to simulate work, then return ``timeout``."""
    delay = timeout
    time.sleep(delay)
    return delay

Now we can chain a few of these tasks together:

In [23]:
# ``a`` and ``b`` are independent tasks. ``c`` receives their handles, which
# the executor unwraps so that Dask feeds it their results on the cluster.
a = executor.submit(slow_function, 2)
b = executor.submit(slow_function, 5)
c = executor.submit(int.__add__, a, b)

We can wait for the result without retrieving its value:

In [24]:
# Awaiting the handle waits for the task without fetching its value;
# on success the handle resolves to True.
await c

True

And when we're ready for the value, we invoke `executor.retrieve`.

In [26]:
assert await executor.retrieve(c) == 7