Skip to content

Commit

Permalink
Merge pull request #236 from Quansight-Labs/try_last
Browse files Browse the repository at this point in the history
Add try_last option to global backends
  • Loading branch information
hameerabbasi committed Apr 13, 2020
2 parents 8b7fdce + 354afc4 commit 606c369
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 20 deletions.
11 changes: 9 additions & 2 deletions uarray/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def get_defaults(f):
return kw_defaults, tuple(arg_defaults), opts


def set_global_backend(backend, coerce=False, only=False):
def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
"""
This utility method replaces the default backend for permanent use. It
will be tried in the list of backends automatically, unless the
Expand All @@ -339,13 +339,20 @@ def set_global_backend(backend, coerce=False, only=False):
----------
backend
The backend to register.
coerce : bool
Whether to coerce input types when trying this backend.
only : bool
If ``True``, no more backends will be tried if this fails.
Implied by ``coerce=True``.
try_last : bool
If ``True``, the global backend is tried after registered backends.
See Also
--------
set_backend: A context manager that allows setting of backends.
skip_backend: A context manager that allows skipping of backends.
"""
_uarray.set_global_backend(backend, coerce, only)
_uarray.set_global_backend(backend, coerce, only, try_last)


def register_backend(backend):
Expand Down
57 changes: 39 additions & 18 deletions uarray/_uarray_dispatch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,16 @@ struct backend_options {
struct global_backends {
backend_options global;
std::vector<py_ref> registered;
bool try_global_backend_last = false;
};

struct local_backends {
std::vector<py_ref> skipped;
std::vector<backend_options> preferred;
};

typedef std::unordered_map<std::string, global_backends> global_state_t;
typedef std::unordered_map<std::string, local_backends> local_state_t;
using global_state_t = std::unordered_map<std::string, global_backends>;
using local_state_t = std::unordered_map<std::string, local_backends>;

static py_ref BackendNotImplementedError;
static global_state_t global_domain_map;
Expand Down Expand Up @@ -337,13 +338,16 @@ struct BackendState {

static global_backends convert_global_backends(PyObject * input) {
PyObject *py_global, *py_registered;
if (!PyArg_ParseTuple(input, "OO", &py_global, &py_registered))
int try_global_backend_last;
if (!PyArg_ParseTuple(
input, "OOp", &py_global, &py_registered, &try_global_backend_last))
throw std::invalid_argument("");

global_backends output;
output.global = BackendState::convert_backend_options(py_global);
output.registered =
convert_iter<py_ref>(py_registered, BackendState::convert_backend);
output.try_global_backend_last = try_global_backend_last;

return output;
}
Expand Down Expand Up @@ -412,7 +416,8 @@ struct BackendState {
static py_ref convert_py(const global_backends & input) {
py_ref py_globals = BackendState::convert_py(input.global);
py_ref py_registered = BackendState::convert_py(input.registered);
py_ref output = py_make_tuple(py_globals, py_registered);
py_ref output = py_make_tuple(
py_globals, py_registered, py_bool(input.try_global_backend_last));

if (!output)
throw std::runtime_error("");
Expand Down Expand Up @@ -454,8 +459,8 @@ PyObject * clear_all_globals(PyObject * /* self */, PyObject * /* args */) {

PyObject * set_global_backend(PyObject * /* self */, PyObject * args) {
PyObject * backend;
int only = false, coerce = false;
if (!PyArg_ParseTuple(args, "O|pp", &backend, &coerce, &only))
int only = false, coerce = false, try_last = false;
if (!PyArg_ParseTuple(args, "O|ppp", &backend, &coerce, &only, &try_last))
return nullptr;

auto domain = backend_to_domain_string(backend);
Expand All @@ -467,7 +472,10 @@ PyObject * set_global_backend(PyObject * /* self */, PyObject * args) {
options.coerce = coerce;
options.only = only;

(*current_global_state)[domain].global = options;
auto & domain_globals = (*current_global_state)[domain];
domain_globals.global = options;
domain_globals.try_global_backend_last = try_last;

Py_RETURN_NONE;
}

Expand Down Expand Up @@ -500,6 +508,7 @@ void clear_single(const std::string & domain, bool registered, bool global) {

if (global) {
domain_globals->second.global.backend.reset();
domain_globals->second.try_global_backend_last = false;
}
}

Expand Down Expand Up @@ -758,17 +767,25 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) {
}

auto & globals = (*current_global_state)[domain_key];
auto & global_options = globals.global;
int skip_current =
global_options.backend ? should_skip(global_options.backend.get()) : 1;
if (skip_current < 0)
return LoopReturn::Error;
if (!skip_current) {
ret = call(global_options.backend.get(), global_options.coerce);
if (ret != LoopReturn::Continue)
return ret;
auto try_global_backend = [&] {
auto & options = globals.global;
if (!options.backend)
return LoopReturn::Continue;

if (global_options.only || global_options.coerce)
int skip_current = should_skip(options.backend.get());
if (skip_current < 0)
return LoopReturn::Error;
if (skip_current > 0)
return LoopReturn::Continue;

return call(options.backend.get(), options.coerce);
};

if (!globals.try_global_backend_last) {
ret = try_global_backend();

bool is_last = globals.global.coerce || globals.global.only;
if (ret != LoopReturn::Continue || is_last)
return ret;
}

Expand All @@ -784,7 +801,11 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) {
if (ret != LoopReturn::Continue)
return ret;
}
return ret;

if (!globals.try_global_backend_last) {
return ret;
}
return try_global_backend();
}

struct py_func_args {
Expand Down
14 changes: 14 additions & 0 deletions uarray/tests/test_uarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ def test_global_before_registered(nullary_mm):
assert nullary_mm() is obj


def test_global_try_last(nullary_mm):
obj = object()
obj2 = object()
be = Backend()
be.__ua_function__ = lambda f, a, kw: obj

be2 = Backend()
be2.__ua_function__ = lambda f, a, kw: obj2

ua.set_global_backend(be, try_last=True)
ua.register_backend(be2)
assert nullary_mm() is obj2


def test_global_only(nullary_mm):
obj = object()
be = Backend()
Expand Down

0 comments on commit 606c369

Please sign in to comment.