Skip to content

Commit

Permalink
Merge pull request #238 from Quansight-Labs/heirarchical_domains
Browse files Browse the repository at this point in the history
Add hierarchical domain dispatch
  • Loading branch information
hameerabbasi committed Apr 17, 2020
2 parents 606c369 + 9a8d7f2 commit ed4307e
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 20 deletions.
10 changes: 7 additions & 3 deletions docs/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ or temporarily.
Domain
------

A domain is a collection or grouping of multimethods. A domain's string,
by convention (although not by force), is the name of the module that provides
the multimethods.
A domain defines the hierarchical grouping of multimethods. The domain string
is, by convention, the name of the module that provides the multimethods.

Sub-domains are denoted with a separating ``.``. For example, a multimethod in
``"numpy.fft"`` is also considered to be in the domain ``"numpy"``. When calling
a multimethod, the backends for the most specific sub-domain are always tried first,
followed by the next domain up the hierarchy.

Dispatching
-----------
Expand Down
72 changes: 55 additions & 17 deletions uarray/_uarray_dispatch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ py_ref py_bool(bool input) { return py_ref::ref(input ? Py_True : Py_False); }

struct backend_options {
py_ref backend;
bool coerce, only;
bool coerce = false;
bool only = false;

bool operator==(const backend_options & other) const {
return (
Expand Down Expand Up @@ -719,20 +720,35 @@ struct SkipBackendContext {
}
};

const local_backends & get_local_backends(const std::string & domain_key) {
static const local_backends null_local_backends;
auto itr = local_domain_map.find(domain_key);
if (itr == local_domain_map.end()) {
return null_local_backends;
}
return itr->second;
}


const global_backends & get_global_backends(const std::string & domain_key) {
static const global_backends null_global_backends;
const auto & cur_globals = *current_global_state;
auto itr = cur_globals.find(domain_key);
if (itr == cur_globals.end()) {
return null_global_backends;
}
return itr->second;
}

enum class LoopReturn { Continue, Break, Error };

template <typename Callback>
LoopReturn for_each_backend(const std::string & domain_key, Callback call) {
local_backends * locals = nullptr;
try {
locals = &local_domain_map[domain_key];
} catch (std::bad_alloc &) {
PyErr_NoMemory();
return LoopReturn::Error;
}
LoopReturn for_each_backend_in_domain(
const std::string & domain_key, Callback call) {
const local_backends & locals = get_local_backends(domain_key);

auto & skip = locals->skipped;
auto & pref = locals->preferred;
auto & skip = locals.skipped;
auto & pref = locals.preferred;

auto should_skip = [&](PyObject * backend) -> int {
bool success = true;
Expand Down Expand Up @@ -763,10 +779,10 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) {
return ret;

if (options.only || options.coerce)
return ret;
return LoopReturn::Break;
}

auto & globals = (*current_global_state)[domain_key];
auto & globals = get_global_backends(domain_key);
auto try_global_backend = [&] {
auto & options = globals.global;
if (!options.backend)
Expand All @@ -783,10 +799,11 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) {

if (!globals.try_global_backend_last) {
ret = try_global_backend();

bool is_last = globals.global.coerce || globals.global.only;
if (ret != LoopReturn::Continue || is_last)
if (ret != LoopReturn::Continue)
return ret;

if (globals.global.only || globals.global.coerce)
return LoopReturn::Break;
}

for (size_t i = 0; i < globals.registered.size(); ++i) {
Expand All @@ -808,6 +825,24 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) {
return try_global_backend();
}

template <typename Callback>
LoopReturn for_each_backend(std::string domain, Callback call) {
do {
auto ret = for_each_backend_in_domain(domain, call);
if (ret != LoopReturn::Continue) {
return ret;
}

auto dot_pos = domain.rfind('.');
if (dot_pos == std::string::npos) {
return ret;
}

domain.resize(dot_pos);
} while (!domain.empty());
return LoopReturn::Continue;
}

struct py_func_args {
py_ref args, kwargs;
};
Expand Down Expand Up @@ -1115,7 +1150,10 @@ PyObject * Function::call(PyObject * args_, PyObject * kwargs_) {
return LoopReturn::Break; // Backend called successfully
});

if (ret != LoopReturn::Continue)
if (ret == LoopReturn::Error)
return nullptr;

if (result && result != Py_NotImplemented)
return result.release();

if (def_impl_ != Py_None) {
Expand Down
66 changes: 66 additions & 0 deletions uarray/tests/test_uarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,69 @@ def test_pickle_state():
state_loaded = pickle.loads(pickle.dumps(state))

assert state._pickle() == state_loaded._pickle()


def test_hierarchical_backends():
mm = ua.generate_multimethod(
lambda: (), lambda a, kw, d: (a, kw), "ua_tests.foo.bar"
)
subdomains = "ua_tests.foo.bar".split(".")
depth = len(subdomains)

mms = [
ua.generate_multimethod(
lambda: (), lambda a, kw, d: (a, kw), ".".join(subdomains[: i + 1])
)
for i in range(depth)
]

class DisableBackend:
def __init__(self, domain):
self.__ua_domain__ = domain
self.active = True
self.ret = object()

def __ua_function__(self, f, a, kw):
if self.active:
return self.ret

raise ua.BackendNotImplementedError(self.__ua_domain__)

be = [DisableBackend(".".join(subdomains[: i + 1])) for i in range(depth)]

ua.set_global_backend(be[1])
with pytest.raises(ua.BackendNotImplementedError):
mms[0]()

for i in range(1, depth):
assert mms[i]() is be[1].ret

ua.set_global_backend(be[0])

for i in range(depth):
assert mms[i]() is be[min(i, 1)].ret

ua.set_global_backend(be[2])

for i in range(depth):
assert mms[i]() is be[i].ret

be[2].active = False
for i in range(depth):
print(i)
assert mms[i]() is be[min(i, 1)].ret

be[1].active = False
for i in range(depth):
assert mms[i]() is be[0].ret

be[0].active = False
for i in range(depth):
with pytest.raises(ua.BackendNotImplementedError):
mms[i]()

# only=True prevents all further domain checking
be[0].active = True
be[1].active = True
with ua.set_backend(be[2], only=True), pytest.raises(ua.BackendNotImplementedError):
mms[2]()

0 comments on commit ed4307e

Please sign in to comment.