Skip to content

Commit

Permalink
Move to rich comparison of backends.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Sep 28, 2019
1 parent ce82053 commit fb33842
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
30 changes: 25 additions & 5 deletions uarray/_uarray_dispatch.cxx
Expand Up @@ -496,11 +496,22 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call)
auto & pref = locals->preferred;

auto should_skip =
[&](PyObject * backend)
[&](PyObject * backend) -> int
{
bool success = true;
auto it = std::find_if(
skip.begin(), skip.end(),
[&](const py_ref & be) { return be.get() == backend; });
[&](const py_ref & be)
{
auto result = PyObject_RichCompareBool(be.get(), backend, Py_EQ);
success = (result >= 0);
return (result != 0);
});

if (!success)
{
return -1;
}

return (it != skip.end());
};
Expand All @@ -509,7 +520,10 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call)
for (int i = pref.size()-1; i >= 0; --i)
{
auto options = pref[i];
if (should_skip(options.backend))
int skip_current = should_skip(options.backend);
if (skip_current < 0)
return LoopReturn::Error;
if (skip_current)
continue;

ret = call(options.backend.get(), options.coerce);
Expand All @@ -522,7 +536,10 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call)

auto& globals = global_domain_map[domain_key];
auto& global_options = globals.global;
if (global_options.backend && !should_skip(global_options.backend))
int skip_current = global_options.backend.get() != nullptr ? should_skip(global_options.backend) : 1;
if (skip_current < 0)
return LoopReturn::Error;
if (global_options.backend && !skip_current)
{
ret = call(global_options.backend.get(), global_options.coerce);
if (ret != LoopReturn::Continue)
Expand All @@ -535,7 +552,10 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call)
for (size_t i = 0; i < globals.registered.size(); ++i)
{
py_ref backend = globals.registered[i];
if (should_skip(backend))
int skip_current = should_skip(backend);
if (skip_current < 0)
return LoopReturn::Error;
if (skip_current)
continue;

ret = call(backend.get(), false);
Expand Down
38 changes: 38 additions & 0 deletions uarray/tests/test_uarray.py
Expand Up @@ -217,3 +217,41 @@ def test_invalid():
ctx1.__exit__(None, None, None)
finally:
ctx2.__exit__(None, None, None)


def test_skip_comparison(nullary_mm):
be1 = Backend()
be1.__ua_function__ = lambda f, a, kw: None

class Backend2(Backend):
@staticmethod
def __ua_function__(f, a, kw):
pass

def __eq__(self, other):
return other is self or other is be1

with pytest.raises(ua.BackendNotImplementedError):
with ua.set_backend(be1), ua.skip_backend(Backend2()):
nullary_mm()


def test_skip_raises(nullary_mm):
be1 = Backend()
be1.__ua_function__ = lambda f, a, kw: None

foo = Exception("Foo")

class Backend2(Backend):
@staticmethod
def __ua_function__(f, a, kw):
pass

def __eq__(self, other):
raise foo

with pytest.raises(Exception) as e:
with ua.set_backend(be1), ua.skip_backend(Backend2()):
nullary_mm()

assert e.value is foo

0 comments on commit fb33842

Please sign in to comment.