From fb338427bbe5d4a52e02fba1e064597b176880a0 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi Date: Sat, 28 Sep 2019 17:08:20 +0200 Subject: [PATCH] Move to rich comparison of backends. --- uarray/_uarray_dispatch.cxx | 30 ++++++++++++++++++++++++----- uarray/tests/test_uarray.py | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/uarray/_uarray_dispatch.cxx b/uarray/_uarray_dispatch.cxx index 79bef17d..bda40e84 100644 --- a/uarray/_uarray_dispatch.cxx +++ b/uarray/_uarray_dispatch.cxx @@ -496,11 +496,22 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) auto & pref = locals->preferred; auto should_skip = - [&](PyObject * backend) + [&](PyObject * backend) -> int { + bool success = true; auto it = std::find_if( skip.begin(), skip.end(), - [&](const py_ref & be) { return be.get() == backend; }); + [&](const py_ref & be) + { + auto result = PyObject_RichCompareBool(be.get(), backend, Py_EQ); + success = (result >= 0); + return (result != 0); + }); + + if (!success) + { + return -1; + } return (it != skip.end()); }; @@ -509,7 +520,10 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) for (int i = pref.size()-1; i >= 0; --i) { auto options = pref[i]; - if (should_skip(options.backend)) + int skip_current = should_skip(options.backend); + if (skip_current < 0) + return LoopReturn::Error; + if (skip_current) continue; ret = call(options.backend.get(), options.coerce); @@ -522,7 +536,10 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) auto& globals = global_domain_map[domain_key]; auto& global_options = globals.global; - if (global_options.backend && !should_skip(global_options.backend)) + int skip_current = global_options.backend.get() != nullptr ? should_skip(global_options.backend) : 1; + if (skip_current < 0) + return LoopReturn::Error; + if (global_options.backend && !skip_current) { ret = call(global_options.backend.get(), global_options.coerce); if (ret != LoopReturn::Continue) @@ -535,7 +552,10 @@ LoopReturn for_each_backend(const std::string & domain_key, Callback call) for (size_t i = 0; i < globals.registered.size(); ++i) { py_ref backend = globals.registered[i]; - if (should_skip(backend)) + int skip_current = should_skip(backend); + if (skip_current < 0) + return LoopReturn::Error; + if (skip_current) continue; ret = call(backend.get(), false); diff --git a/uarray/tests/test_uarray.py b/uarray/tests/test_uarray.py index 861d2d1d..12789319 100644 --- a/uarray/tests/test_uarray.py +++ b/uarray/tests/test_uarray.py @@ -217,3 +217,41 @@ def test_invalid(): ctx1.__exit__(None, None, None) finally: ctx2.__exit__(None, None, None) + + +def test_skip_comparison(nullary_mm): + be1 = Backend() + be1.__ua_function__ = lambda f, a, kw: None + + class Backend2(Backend): + @staticmethod + def __ua_function__(f, a, kw): + pass + + def __eq__(self, other): + return other is self or other is be1 + + with pytest.raises(ua.BackendNotImplementedError): + with ua.set_backend(be1), ua.skip_backend(Backend2()): + nullary_mm() + + +def test_skip_raises(nullary_mm): + be1 = Backend() + be1.__ua_function__ = lambda f, a, kw: None + + foo = Exception("Foo") + + class Backend2(Backend): + @staticmethod + def __ua_function__(f, a, kw): + pass + + def __eq__(self, other): + raise foo + + with pytest.raises(Exception) as e: + with ua.set_backend(be1), ua.skip_backend(Backend2()): + nullary_mm() + + assert e.value is foo