# Graph - Custom Collections
https://dask.pydata.org/en/latest/custom-collections.html

In [1]:
import dask

## Internals of the Core Dask Methods
Dask has a few core functions (and corresponding methods) that implement common operations:

- `compute`: convert one or more dask collections into their in-memory counterparts
- `persist`: convert one or more dask collections into equivalent dask collections with their results already computed and cached in memory.
- `optimize`: convert one or more dask collections into equivalent dask collections sharing one large optimized graph.
- `visualize`: given one or more dask collections, draw out the graph that would be passed to the scheduler during a call to compute or persist

### Compute
In pseudocode this process looks like:

In [2]:
def compute(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    if kwargs.pop('optimize_graph', True):
        # If optimization is turned on, group the collections by
        # optimization method, and apply each method only once to the merged
        # sub-graphs.
        optimization_groups = groupby_optimization_methods(collections)
        graphs = []
        for optimize_method, cols in optimization_groups:
            # Merge the graphs and keys for the subset of collections that
            # share this optimization method
            sub_graph = merge_graphs([x.__dask_graph__() for x in cols])
            sub_keys = [x.__dask_keys__() for x in cols]
            # kwargs are forwarded to ``__dask_optimize__`` from compute
            optimized_graph = optimize_method(sub_graph, sub_keys, **kwargs)
            graphs.append(optimized_graph)
        graph = merge_graphs(graphs)
    else:
        graph = merge_graphs([x.__dask_graph__() for x in collections])
    # Keys are always the same
    keys = [x.__dask_keys__() for x in collections]

    # 2. Computation
    # --------------
    # Determine appropriate get function based on collections, global
    # settings, and keyword arguments
    get = determine_get_function(collections, **kwargs)
    # Pass the merged graph, keys, and kwargs to ``get``
    results = get(graph, keys, **kwargs)

    # 3. Postcompute
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        finalize, extra_args = collection.__dask_postcompute__()
        out = finalize(res, **extra_args)
        output.append(out)

    # `dask.compute` always returns tuples
    return tuple(output)

### Persist 
In pseudocode this looks like:

In [3]:
def persist(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...
    keys = ...

    # 2. Computation
    # --------------
    # **Same as in compute**
    results = ...

    # 3. Postpersist
    # --------------
    output = []
    # Iterate over the results and collections
    for res, collection in zip(results, collections):
        # res has the same structure as keys
        keys = collection.__dask_keys__()
        # Get the computed graph for this collection.
        # Here flatten converts a nested list into a single list
        subgraph = {k: r for (k, r) in zip(flatten(keys), flatten(res))}

        # Rebuild the output dask collection with the computed graph
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(subgraph, *extra_args)

        output.append(out)

    # dask.persist always returns tuples
    return tuple(output)

### Optimize
In pseudocode this looks like:

In [4]:
def optimize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Rebuilding
    # -------------
    # Rebuild each dask collection using the same large optimized graph
    output = []
    for collection in collections:
        rebuild, extra_args = collection.__dask_postpersist__()
        out = rebuild(graph, *extra_args)
        output.append(out)

    # dask.optimize always returns tuples
    return tuple(output)

### Visualize n using graphviz and output to the specified file.
In pseudocode this looks like:

In [5]:
def visualize(*collections, **kwargs):
    # 1. Graph Merging & Optimization
    # -------------------------------
    # **Same as in compute**
    graph = ...

    # 2. Graph Drawing
    # ----------------
    # Draw the graph with graphviz's `dot` tool and return the result.
    return dot_graph(graph, **kwargs)

## Adding the Core Dask Methods to Your Class
Defining the above interface will allow your object to used by the core dask functions (`dask.compute`, `dask.persist`, `dask.visualize`, etc…). To add corresponding method versions of these subclass from `dask.base.DaskMethodsMixin`, which adds implementations of `compute`, `persist` and `visualize` based on the interface above.

## Example Dask Collection
Here we create a dask collection representing a tuple. Every element in the tuple is represented as a task in the graph. Note that this is for illustration purposes only - the same user experience could be done using normal tuples with elements of dask.delayed.

In [6]:
# Saved as dask_tuple.py
from dask.base import DaskMethodsMixin
from dask.optimization import cull

# We subclass from DaskMethodsMixin to add common dask methods to our
# class. This is nice but not necessary for creating a dask collection.
class Tuple(DaskMethodsMixin):
    def __init__(self, dsk, keys):
        # The init method takes in a dask graph and a set of keys to use
        # as outputs.
        self._dsk = dsk
        self._keys = keys

    def __dask_graph__(self):
        return self._dsk

    def __dask_keys__(self):
        return self._keys

    @staticmethod
    def __dask_optimize__(dsk, keys, **kwargs):
        # We cull unnecessary tasks here. Note that this isn't necessary,
        # dask will do this automatically, this just shows one optimization
        # you could do.
        dsk2, _ = cull(dsk, keys)
        return dsk2

    # Use the threaded scheduler by default.
    __dask_scheduler__ = staticmethod(dask.threaded.get)

    def __dask_postcompute__(self):
        # We want to return the results as a tuple, so our finalize
        # function is `tuple`. There are no extra arguments, so we also
        # return an empty tuple.
        return tuple, ()

    def __dask_postpersist__(self):
        # Since our __init__ takes a graph as its first argument, our
        # rebuild function can just be the class itself. For extra
        # arguments we also return a tuple containing just the keys.
        return Tuple, (self._keys,)

    def __dask_tokenize__(self):
        # For tokenize to work we want to return a value that fully
        # represents this object. In this case it's the list of keys
        # to be computed.
        return tuple(self._keys)

Demonstrating this class:

In [7]:
# >>> from dask_tuple import Tuple
>>> from operator import add, mul

# Define a dask graph
>>> dsk = {'a': 1,
...        'b': 2,
...        'c': (add, 'a', 'b'),
...        'd': (mul, 'b', 2),
...        'e': (add, 'b', 'c')}

# The output keys for this graph
>>> keys = ['b', 'c', 'd', 'e']

>>> x = Tuple(dsk, keys)

# Compute turns Tuple into a tuple
>>> x.compute()

(2, 3, 4, 5)

In [8]:
# Persist turns Tuple into a Tuple, with each task already computed
>>> x2 = x.persist()
>>> isinstance(x2, Tuple)

True

In [9]:
>>> x2.__dask_graph__()

{'b': 2, 'c': 3, 'd': 4, 'e': 5}

In [10]:
x2.compute()

(2, 3, 4, 5)

## Checking if an object is a dask collection
To check if an object is a dask collection, use `dask.base.is_dask_collection`:

In [11]:
>>> from dask.base import is_dask_collection
>>> from dask import delayed

>>> x = delayed(sum)([1, 2, 3])
>>> is_dask_collection(x)

True

In [12]:
is_dask_collection(1)

False

## Implementing Deterministic Hashing
Dask implements its own deterministic hash function to generate keys based on the value of arguments. This function is available as `dask.base.tokenize`. Many common types already have implementations of `tokenize`, which can be found in `dask/base.py`.

When creating your own custom classes you may need to register a `tokenize` implementation. There are two ways to do this:

In [13]:
>>> from dask.base import tokenize, normalize_token

# Define a tokenize implementation using a method.
>>> class Foo(object):
...     def __init__(self, a, b):
...         self.a = a
...         self.b = b
...
...     def __dask_tokenize__(self):
...         # This tuple fully represents self
...         return (Foo, self.a, self.b)

>>> x = Foo(1, 2)
>>> tokenize(x)

'5988362b6e07087db2bc8e7c1c8cc560'

In [14]:
tokenize(x) == tokenize(x)  # token is deterministic

True

In [15]:
# Register an implementation with normalize_token
>>> class Bar(object):
...     def __init__(self, x, y):
...         self.x = x
...         self.y = y

>>> @normalize_token.register(Bar)
... def tokenize_bar(x):
...     return (Bar, x.x, x.x)

>>> y = Bar(1, 2)
>>> tokenize(y)

'5a7e9c3645aa44cf13d021c14452152e'

In [16]:
>>> tokenize(y) == tokenize(y)

True

In [17]:
>>> tokenize(y) == tokenize(x)  # tokens for different objects aren't equal

False