Skip to content

Commit

Permalink
Add context managers for backend determination (#244)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed May 15, 2020
1 parent a1c98d6 commit ff57432
Show file tree
Hide file tree
Showing 7 changed files with 393 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/generated/uarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ uarray
get_state
set_state
reset_state
determine_backend
determine_backend_multi



Expand Down
2 changes: 1 addition & 1 deletion uarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
... overridden_me(1, "2")
Traceback (most recent call last):
...
uarray.backend.BackendNotImplementedError: ...
uarray.BackendNotImplementedError: ...
The last possibility is if we don't have ``__ua_convert__``, in which case the job is left
up to ``__ua_function__``, but putting things back into arrays after conversion will not be
Expand Down
168 changes: 167 additions & 1 deletion uarray/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
"set_global_backend",
"skip_backend",
"register_backend",
"determine_backend",
"determine_backend_multi",
"clear_backends",
"create_multimethod",
"generate_multimethod",
Expand Down Expand Up @@ -220,7 +222,8 @@ def generate_multimethod(
>>> overridden_me(1, "a")
Traceback (most recent call last):
...
uarray.backend.BackendNotImplementedError: ...
uarray.BackendNotImplementedError: ...
>>> overridden_me2 = generate_multimethod(
... override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
... )
Expand Down Expand Up @@ -537,3 +540,166 @@ def __ua_convert__(self, dispatchables, coerce):
return converted

return __ua_convert__


def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
"""Set the backend to the first active backend that supports ``value``
This is useful for functions that call multimethods without any dispatchable
arguments. You can use :func:`determine_backend` to ensure the same backend
is used everywhere in a block of multimethod calls.
Parameters
----------
value
The value being tested
dispatch_type
The dispatch type associated with ``value``, aka
":ref:`marking <MarkingGlossary>`".
domain: string
The domain to query for backends and set.
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only: bool
Whether or not this should be the last backend to try.
See Also
--------
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends not
supporting the type must return ``NotImplemented`` from their
``__ua_convert__`` if they don't support input of that type.
Examples
--------
Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:
>>> with ua.set_backend(ex.BackendA):
... ex.call_multimethod(ex.TypeB(), ex.TypeB())
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
Now consider a multimethod that creates a new object of ``TypeA``, or
``TypeB`` depending on the active backend.
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, ex.TypeA())
Traceback (most recent call last):
...
uarray.BackendNotImplementedError: ...
``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
innermost with statement. So, ``call_multimethod`` fails since the types
don't match.
Instead, we need to first find a backend suitable for all of our objects.
>>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
... x = ex.TypeA()
... with ua.determine_backend(x, "mark", domain="ua_examples"):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, x)
TypeA
"""
dispatchables = (Dispatchable(value, dispatch_type, coerce),)
backend = _uarray.determine_backend(domain, dispatchables, coerce)

return set_backend(backend, coerce=coerce, only=only)


def determine_backend_multi(
dispatchables, *, domain, only=True, coerce=False, **kwargs
):
"""Set a backend supporting all ``dispatchables``
This is useful for functions that call multimethods without any dispatchable
arguments. You can use :func:`determine_backend_multi` to ensure the same
backend is used everywhere in a block of multimethod calls involving
multiple arrays.
Parameters
----------
dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
The dispatchables that must be supported
domain: string
The domain to query for backends and set.
coerce: bool
Whether or not to allow coercion to the backend's types. Implies ``only``.
only: bool
Whether or not this should be the last backend to try.
dispatch_type: Optional[Any]
The default dispatch type associated with ``dispatchables``, aka
":ref:`marking <MarkingGlossary>`".
See Also
--------
determine_backend: For a single dispatch value
set_backend: For when you know which backend to set
Notes
-----
Support is determined by the ``__ua_convert__`` protocol. Backends not
supporting the type must return ``NotImplemented`` from their
``__ua_convert__`` if they don't support input of that type.
Examples
--------
:func:`determine_backend` allows the backend to be set from a single
object. :func:`determine_backend_multi` allows multiple objects to be
checked simultaneously for support in the backend. Suppose we have a
``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
and a ``BackendBC`` that doesn't support ``TypeA``.
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
... a, b = ex.TypeA(), ex.TypeB()
... with ua.determine_backend_multi(
... [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
... domain="ua_examples"
... ):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, a, b)
TypeA
This won't call ``BackendBC`` because it doesn't support ``TypeA``.
We can also use leave out the ``ua.Dispatchable`` if we specify the
default ``dispatch_type`` for the ``dispatchables`` argument.
>>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
... a, b = ex.TypeA(), ex.TypeB()
... with ua.determine_backend_multi(
... [a, b], dispatch_type="mark", domain="ua_examples"
... ):
... res = ex.creation_multimethod()
... ex.call_multimethod(res, a, b)
TypeA
"""
if "dispatch_type" in kwargs:
disp_type = kwargs.pop("dispatch_type")
dispatchables = tuple(
d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
for d in dispatchables
)
else:
dispatchables = tuple(dispatchables)
if not all(isinstance(d, Dispatchable) for d in dispatchables):
raise TypeError("dispatchables must be instances of uarray.Dispatchable")

if len(kwargs) != 0:
raise TypeError("Received unexpected keyword arguments: {}".format(kwargs))

backend = _uarray.determine_backend(domain, dispatchables, coerce)

return set_backend(backend, coerce=coerce, only=only)
59 changes: 59 additions & 0 deletions uarray/_uarray_dispatch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,64 @@ PyObject * set_state(PyObject * /* self */, PyObject * args) {
Py_RETURN_NONE;
}

PyObject * determine_backend(PyObject * /*self*/, PyObject * args) {
PyObject *domain_object, *dispatchables;
int coerce;
if (!PyArg_ParseTuple(
args, "OOp:determine_backend", &domain_object, &dispatchables,
&coerce))
return nullptr;

auto domain = domain_to_string(domain_object);
if (domain.empty())
return nullptr;

auto dispatchables_tuple = py_ref::steal(PySequence_Tuple(dispatchables));
if (!dispatchables_tuple)
return nullptr;

py_ref selected_backend;
auto result = for_each_backend_in_domain(
domain, [&](PyObject * backend, bool coerce_backend) {
auto ua_convert = py_ref::steal(
PyObject_GetAttr(backend, identifiers.ua_convert.get()));

if (!ua_convert) {
// If no __ua_convert__, assume it won't accept the type
PyErr_Clear();
return LoopReturn::Continue;
}

auto convert_args = py_make_tuple(
dispatchables_tuple, py_bool(coerce && coerce_backend));
if (!convert_args)
return LoopReturn::Error;

auto res = py_ref::steal(
PyObject_Call(ua_convert.get(), convert_args.get(), nullptr));
if (!res) {
return LoopReturn::Error;
}

if (res == Py_NotImplemented) {
return LoopReturn::Continue;
}

// __ua_convert__ succeeded, so select this backend
selected_backend = py_ref::ref(backend);
return LoopReturn::Break;
});

if (result != LoopReturn::Continue)
return selected_backend.get();

// All backends failed, raise an error
PyErr_SetString(
BackendNotImplementedError.get(),
"No backends could accept input of this type.");
return nullptr;
}


// getset takes mutable char * in python < 3.7
static char dict__[] = "__dict__";
Expand Down Expand Up @@ -1683,6 +1741,7 @@ PyMethodDef method_defs[] = {
{"register_backend", register_backend, METH_VARARGS, nullptr},
{"clear_all_globals", clear_all_globals, METH_NOARGS, nullptr},
{"clear_backends", clear_backends, METH_VARARGS, nullptr},
{"determine_backend", determine_backend, METH_VARARGS, nullptr},
{"get_state", get_state, METH_NOARGS, nullptr},
{"set_state", set_state, METH_VARARGS, nullptr},
{NULL} /* Sentinel */
Expand Down
10 changes: 10 additions & 0 deletions uarray/tests/conftest.py → uarray/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import sys
import uarray
import pytest # type: ignore

from .tests import example_helpers


def pytest_cmdline_preparse(args):
Expand All @@ -17,3 +21,9 @@ def pytest_cmdline_preparse(args):
else:
args.append("--mypy")
print("uarray: Enabling pytest-mypy")


@pytest.fixture(autouse=True)
def add_namespaces(doctest_namespace):
doctest_namespace["ua"] = uarray
doctest_namespace["ex"] = example_helpers
46 changes: 46 additions & 0 deletions uarray/tests/example_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import uarray as ua


class _TypedBackend:
__ua_domain__ = "ua_examples"

def __init__(self, *my_types):
self.my_types = my_types

def __ua_convert__(self, dispatchables, coerce):
if not all(type(d.value) in self.my_types for d in dispatchables):
return NotImplemented
return tuple(d.value for d in dispatchables)

def __ua_function__(self, func, args, kwargs):
return self.my_types[0]()


class TypeA:
@classmethod
def __repr__(cls):
return cls.__name__


class TypeB(TypeA):
pass


class TypeC(TypeA):
pass


BackendA = _TypedBackend(TypeA)
BackendB = _TypedBackend(TypeB)
BackendC = _TypedBackend(TypeC)
BackendAB = _TypedBackend(TypeA, TypeB)
BackendBC = _TypedBackend(TypeB, TypeC)

creation_multimethod = ua.generate_multimethod(
lambda: (), lambda a, kw, d: (a, kw), "ua_examples"
)
call_multimethod = ua.generate_multimethod(
lambda *a: tuple(ua.Dispatchable(x, "mark") for x in a),
lambda a, kw, d: (a, kw),
"ua_examples",
)

0 comments on commit ff57432

Please sign in to comment.