diff --git a/.conda/environment.yml b/.conda/environment.yml index 19513888..eacfd1ce 100644 --- a/.conda/environment.yml +++ b/.conda/environment.yml @@ -15,7 +15,7 @@ dependencies: - pytorch-cpu - scipy - gumath - - dask + - dask=1.2 - sparse - doc8 - black diff --git a/setup.py b/setup.py index 683c22d7..aa9b851d 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ #!/usr/bin/env python -from setuptools import setup, find_packages +from setuptools import setup, find_packages, Extension +from setuptools.command.build_ext import build_ext import versioneer from pathlib import Path import sys @@ -43,10 +44,30 @@ def parse_requires(): with open("README.md") as f: long_desc = f.read() +class build_cpp11_ext(build_ext): + def build_extension(self, ext): + cc = self.compiler + if cc.compiler_type == "unix": + ext.extra_compile_args.append("--std=c++11") + build_ext.build_extension(self, ext) + + +cmdclass = {"build_ext": build_cpp11_ext} +cmdclass.update(versioneer.get_cmdclass()) + + +extensions = [ + Extension( + "uarray._uarray", + sources=["uarray/_uarray_dispatch.cxx"], + language="c++", + ) +] + setup( name="uarray", version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), + cmdclass=cmdclass, description="Array interface object for Python with pluggable backends and a multiple-dispatch" "mechanism for defining down-stream functions", url="https://github.com/Quansight-Labs/uarray/", @@ -79,4 +100,5 @@ def parse_requires(): "Tracker": "https://github.com/Quansight-Labs/uarray/issues", }, python_requires=">=3.5, <4", + ext_modules=extensions ) diff --git a/uarray/_backend.py b/uarray/_backend.py index 2f478519..046c7f32 100644 --- a/uarray/_backend.py +++ b/uarray/_backend.py @@ -14,15 +14,16 @@ from contextvars import ContextVar import functools import contextlib +from . import _uarray # type: ignore ArgumentExtractorType = Callable[..., Tuple["Dispatchable", ...]] ArgumentReplacerType = Callable[[Tuple, Dict, Tuple], Tuple[Tuple, Dict]] +from ._uarray import BackendNotImplementedError -class BackendNotImplementedError(NotImplementedError): - """ - An exception that is thrown when no compatible backend is found for a method. - """ +import atexit + +atexit.register(_uarray.clear_all_globals) def create_multimethod(*args, **kwargs): @@ -108,166 +109,18 @@ def generate_multimethod( See the module documentation for how to override the method by creating backends. """ kw_defaults, arg_defaults, opts = get_defaults(argument_extractor) + ua_func = _uarray.Function( + argument_extractor, + argument_replacer, + domain, + arg_defaults, + kw_defaults, + default, + ) - @functools.wraps(argument_extractor) - def inner(*args, **kwargs): - dispatchable_args = argument_extractor(*args, **kwargs) - errors = [] - - args = canonicalize_args(args, kwargs) - result = NotImplemented - - for options in _backend_order(domain): - res = ( - replace_dispatchables( - options.backend, - args, - kwargs, - dispatchable_args, - coerce=options.coerce, - ) - if hasattr(options.backend, "__ua_convert__") - else (args, kwargs) - ) - - if res is NotImplemented: - continue - - a, kw = res - - for k, v in kw_defaults.items(): - if k in kw and kw[k] is v: - del kw[k] - - result = options.backend.__ua_function__(inner, a, kw) - - if result is NotImplemented: - result = try_default(a, kw, options, errors) - - if result is not NotImplemented: - break - else: - result = try_default(args, kwargs, None, errors) - - if result is NotImplemented: - raise BackendNotImplementedError( - "No selected backends had an implementation for this function.", errors - ) - - return result - - def try_default(args, kwargs, options, errors): - if default is not None: - try: - if options is not None: - with set_backend(options.backend, only=True, coerce=options.coerce): - return default(*args, **kwargs) - else: - return default(*args, **kwargs) - except BackendNotImplementedError as e: - errors.append(e) - - return NotImplemented - - def replace_dispatchables( - backend, args, kwargs, dispatchable_args, coerce: Optional[bool] = False - ): - replaced_args: Iterable = backend.__ua_convert__(dispatchable_args, coerce) - - if replaced_args is NotImplemented: - return NotImplemented - - return argument_replacer(args, kwargs, tuple(replaced_args)) - - def canonicalize_args(args, kwargs): - if len(args) > len(arg_defaults): - return args - - match = 0 - for a, d in zip(args[::-1], arg_defaults[len(args) - 1 :: -1]): - if a is d: - match += 1 - else: - break - - args = args[:-match] if match > 0 else args - return args - - inner._coerce_args = replace_dispatchables # type: ignore - - return inner + return functools.update_wrapper(ua_func, argument_extractor) -class _BackendOptions: - def __init__(self, backend, coerce: bool = False, only: bool = False): - """ - The backend plus any additonal options associated with it. - - Parameters - ---------- - backend : Backend - The associated backend. - coerce: bool, optional - Whether or not the backend is being coerced. Implies ``only``. - only: bool, optional - Whether or not this is the only backend to try. - """ - self.backend = backend - self.coerce = coerce - self.only = only or coerce - - -_backends: Dict[str, ContextVar] = {} - - -def _backend_order(domain: str) -> Iterable[_BackendOptions]: - skip = _get_skipped_backends(domain).get() - pref = _get_preferred_backends(domain).get() - - for options in pref: - if options.backend not in skip: - yield options - - if options.only: - return - - if domain in _backends and _backends[domain] not in skip: - yield _BackendOptions(_backends[domain]) - - if domain in _registered_backend: - for backend in _registered_backend[domain]: - if backend not in skip: - yield _BackendOptions(backend) - - -def _get_preferred_backends(domain: str) -> ContextVar[Tuple[_BackendOptions, ...]]: - if domain not in _preferred_backend: - _preferred_backend[domain] = ContextVar( - f"_preferred_backend[{domain}]", default=() - ) - return _preferred_backend[domain] - - -def _get_registered_backends(domain: str) -> Set[_BackendOptions]: - if domain not in _registered_backend: - _registered_backend[domain] = set() - return _registered_backend[domain] - - -def _get_skipped_backends(domain: str) -> ContextVar[Set]: - if domain not in _skipped_backend: - _skipped_backend[domain] = ContextVar( - f"_skipped_backend[{domain}]", default=set() - ) - return _skipped_backend[domain] - - -_preferred_backend: Dict[str, ContextVar[Tuple[_BackendOptions, ...]]] = {} -_registered_backend: Dict[str, Set[_BackendOptions]] = {} -_skipped_backend: Dict[str, ContextVar[Set]] = {} - - -@contextlib.contextmanager def set_backend(backend, *args, **kwargs): """ A context manager that sets the preferred backend. Uses :obj:`BackendOptions` to create @@ -283,17 +136,9 @@ def set_backend(backend, *args, **kwargs): BackendOptions: The backend plus options. skip_backend: A context manager that allows skipping of backends. """ - options = _BackendOptions(backend, *args, **kwargs) - pref = _get_preferred_backends(backend.__ua_domain__) - token = pref.set((options,) + pref.get()) + return _uarray.SetBackendContext(backend, *args, **kwargs) - try: - yield - finally: - pref.reset(token) - -@contextlib.contextmanager def skip_backend(backend): """ A context manager that allows one to skip a given backend from processing @@ -309,15 +154,7 @@ def skip_backend(backend): -------- set_backend: A context manager that allows setting of backends. """ - skip = _get_skipped_backends(backend.__ua_domain__) - new = set(skip.get()) - new.add(backend) - token = skip.set(new) - - try: - yield - finally: - skip.reset(token) + return _uarray.SkipBackendContext(backend) def get_defaults(f): @@ -335,7 +172,7 @@ def get_defaults(f): arg_defaults.append(v.default) opts.add(k) - return kw_defaults, arg_defaults, opts + return kw_defaults, tuple(arg_defaults), opts def set_global_backend(backend): @@ -358,7 +195,7 @@ def set_global_backend(backend): backend The backend to register. """ - _backends[backend.__ua_domain__] = backend + _uarray.set_global_backend(backend) def register_backend(backend): @@ -374,7 +211,7 @@ def register_backend(backend): backend The backend to register. """ - _get_registered_backends(backend.__ua_domain__).add(backend) + _uarray.register_backend(backend) class Dispatchable: diff --git a/uarray/_uarray_dispatch.cxx b/uarray/_uarray_dispatch.cxx new file mode 100644 index 00000000..22bee396 --- /dev/null +++ b/uarray/_uarray_dispatch.cxx @@ -0,0 +1,1044 @@ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +/** Handle to a python object that automatically DECREFs */ +class py_ref +{ + explicit py_ref(PyObject * object): obj_(object) {} +public: + + py_ref() noexcept: obj_(nullptr) {} + py_ref(std::nullptr_t) noexcept: py_ref() {} + + py_ref(const py_ref & other) noexcept: obj_(other.obj_) { Py_XINCREF(obj_); } + py_ref(py_ref && other) noexcept: obj_(other.obj_) { other.obj_ = nullptr; } + + /** Construct from new reference (No INCREF) */ + static py_ref steal(PyObject * object) { return py_ref(object); } + + /** Construct from borrowed reference (and INCREF) */ + static py_ref ref(PyObject * object) + { + Py_XINCREF(object); + return py_ref(object); + } + + ~py_ref(){ Py_XDECREF(obj_); } + + py_ref & operator = (const py_ref & other) noexcept + { + py_ref(other).swap(*this); + return *this; + } + + py_ref & operator = (py_ref && other) noexcept + { + py_ref(std::move(other)).swap(*this); + return *this; + } + + void swap(py_ref & other) noexcept + { + std::swap(other.obj_, obj_); + } + + friend void swap(py_ref & a, py_ref & b) noexcept + { + a.swap(b); + } + + explicit operator bool () const { return obj_ != nullptr; } + + operator PyObject* () const { return get(); } + + PyObject * get() const { return obj_; } + PyObject * release() + { + PyObject * t = obj_; + obj_ = nullptr; + return t; + } + void reset() + { + Py_CLEAR(obj_); + } +private: + PyObject * obj_; +}; + +/** Make tuple from variadic set of PyObjects */ +template +py_ref py_make_tuple(Ts... args) +{ + using py_obj = PyObject *; + return py_ref::steal(PyTuple_Pack(sizeof...(args), py_obj{args}...)); +} + +struct global_backends +{ + py_ref global; + std::vector registered; +}; + + +struct backend_options +{ + py_ref backend; + bool coerce, only; + + bool operator == (const backend_options & other) const + { + return (backend == other.backend + && coerce == other.coerce + && only == other.only); + } +}; + +struct local_backends +{ + std::vector skipped; + std::vector preferred; +}; + + +static py_ref BackendNotImplementedError; +static std::unordered_map global_domain_map; +thread_local std::unordered_map< + std::string, local_backends> local_domain_map; + +/** Constant Python string identifiers + +Using these with PyObject_GetAttr is faster than PyObject_GetAttrString which +has to create a new python string internally. + */ +struct +{ + py_ref ua_convert; + py_ref ua_domain; + py_ref ua_function; + + bool init() + { + ua_convert = py_ref::steal(PyUnicode_InternFromString("__ua_convert__")); + if (!ua_convert) + return false; + + ua_domain = py_ref::steal(PyUnicode_InternFromString("__ua_domain__")); + if (!ua_domain) + return false; + + ua_function = py_ref::steal(PyUnicode_InternFromString("__ua_function__")); + if (!ua_function) + return false; + + return true; + } + + void clear() + { + ua_convert.reset(); + ua_domain.reset(); + ua_function.reset(); + } +} identifiers; + + +std::string domain_to_string(PyObject * domain) +{ + if (!PyUnicode_Check(domain)) + { + PyErr_SetString(PyExc_TypeError, "__ua_domain__ must be a string"); + return {}; + } + + Py_ssize_t size; + const char * str = PyUnicode_AsUTF8AndSize(domain, &size); + if (!str) + return {}; + + if (size == 0) + { + PyErr_SetString(PyExc_ValueError, "__ua_domain__ must be non-empty"); + return {}; + } + + return std::string(str, size); +} + +std::string backend_to_domain_string(PyObject * backend) +{ + auto domain = py_ref::steal( + PyObject_GetAttr(backend, identifiers.ua_domain)); + if (!domain) + return {}; + + return domain_to_string(domain); +} + + +/** Use to clean up python references before the interpreter is finalized. + * + * This must be installed in a python atexit handler. This prevents Py_DECREF + * being called after the interpreter has already shudown. + */ +PyObject * clear_all_globals(PyObject * /*self*/, PyObject * /*args*/) +{ + global_domain_map.clear(); + BackendNotImplementedError.reset(); + identifiers.clear(); + Py_RETURN_NONE; +} + + +PyObject * set_global_backend(PyObject * /*self*/, PyObject * args) +{ + PyObject * backend; + if (!PyArg_ParseTuple(args, "O", &backend)) + return nullptr; + + auto domain = backend_to_domain_string(backend); + if (domain.empty()) + return nullptr; + + global_domain_map[domain].global = py_ref::ref(backend); + Py_RETURN_NONE; +} + +PyObject * register_backend(PyObject * /*self*/, PyObject * args) +{ + PyObject * backend; + if (!PyArg_ParseTuple(args, "O", &backend)) + return nullptr; + + auto domain = backend_to_domain_string(backend); + if (domain.empty()) + return nullptr; + + global_domain_map[domain].registered.push_back(py_ref::ref(backend)); + Py_RETURN_NONE; +} + + +/** Common functionality of set_backend and skip_backend */ +template +class context_helper +{ + T new_backend_; + std::vector * backends_; + size_t enter_size_; +public: + + context_helper(): + backends_(nullptr), + enter_size_(size_t(-1)) + {} + + bool init(std::vector & backends, T new_backend) + { + backends_ = &backends; + new_backend_ = std::move(new_backend); + return true; + } + + bool enter() + { + enter_size_ = backends_->size(); + try { backends_->push_back(new_backend_); } + catch(std::bad_alloc&) + { + PyErr_NoMemory(); + return false; + } + return true; + } + + bool exit() + { + bool success = (enter_size_ + 1 == backends_->size() + && backends_->back() == new_backend_); + if (enter_size_ < backends_->size()) + backends_->resize(enter_size_); + + if (!success) + PyErr_SetString(PyExc_RuntimeError, + "Found invalid context state while in __exit__"); + return success; + } +}; + + +struct SetBackendContext +{ + PyObject_HEAD + + context_helper ctx_; + + static void dealloc(SetBackendContext * self) + { + self->~SetBackendContext(); + Py_TYPE(self)->tp_free(self); + } + + static PyObject * new_(PyTypeObject * type, PyObject * args, PyObject * kwargs) + { + auto self = reinterpret_cast(type->tp_alloc(type, 0)); + if (self == nullptr) + return nullptr; + + // Placement new + self = new (self) SetBackendContext; + return reinterpret_cast(self); + } + + static int init( + SetBackendContext * self, PyObject * args, PyObject * kwargs) + { + static const char * kwlist[] = {"backend", "coerce", "only", nullptr}; + PyObject * backend = nullptr; + PyObject * coerce = nullptr; + PyObject * only = nullptr; + + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, + "O|O!O!", (char**)kwlist, + &backend, + &PyBool_Type, &coerce, + &PyBool_Type, &only)) + return -1; + + + auto domain = backend_to_domain_string(backend); + if (domain.empty()) + return -1; + backend_options opt; + opt.backend = py_ref::ref(backend); + opt.coerce = (coerce == Py_True); + opt.only = (only == Py_True); + + try + { + if (!self->ctx_.init(local_domain_map[domain].preferred, std::move(opt))) + return -1; + } + catch(std::bad_alloc&) + { + PyErr_NoMemory(); + return -1; + } + + return 0; + } + + static PyObject * enter__(SetBackendContext * self, PyObject * /*args*/) + { + if (!self->ctx_.enter()) + return nullptr; + Py_RETURN_NONE; + } + + static PyObject * exit__(SetBackendContext * self, PyObject * /*args*/) + { + if (!self->ctx_.exit()) + return nullptr; + Py_RETURN_NONE; + } +}; + + +struct SkipBackendContext +{ + PyObject_HEAD + + context_helper ctx_; + + static void dealloc(SkipBackendContext * self) + { + self->~SkipBackendContext(); + Py_TYPE(self)->tp_free(self); + } + + static PyObject * new_(PyTypeObject * type, PyObject * args, PyObject * kwargs) + { + auto self = reinterpret_cast(type->tp_alloc(type, 0)); + if (self == nullptr) + return nullptr; + + // Placement new + self = new (self) SkipBackendContext; + return reinterpret_cast(self); + } + + static int init( + SkipBackendContext * self, PyObject * args, PyObject * kwargs) + { + static const char *kwlist[] = {"backend", nullptr}; + PyObject * backend; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, + "O", (char**)kwlist, + &backend)) + return -1; + + auto domain = backend_to_domain_string(backend); + if (domain.empty()) + return -1; + + try + { + if (!self->ctx_.init( + local_domain_map[domain].skipped, py_ref::ref(backend))) + return -1; + } + catch(std::bad_alloc&) + { + PyErr_NoMemory(); + return -1; + } + + return 0; + } + + static PyObject * enter__(SkipBackendContext * self, PyObject * /*args*/) + { + if (!self->ctx_.enter()) + return nullptr; + Py_RETURN_NONE; + } + + static PyObject * exit__(SkipBackendContext * self, PyObject * /*args*/) + { + if (!self->ctx_.exit()) + return nullptr; + Py_RETURN_NONE; + } +}; + +enum class LoopReturn { Continue, Break, Error }; + +template +LoopReturn for_each_backend(const std::string & domain_key, Callback call) +{ + local_backends * locals = nullptr; + try + { + locals = &local_domain_map[domain_key]; + } + catch (std::bad_alloc&) + { + PyErr_NoMemory(); + return LoopReturn::Error; + } + + + auto & skip = locals->skipped; + auto & pref = locals->preferred; + + auto should_skip = + [&](PyObject * backend) + { + auto it = std::find_if( + skip.begin(), skip.end(), + [&](const py_ref & be) { return be.get() == backend; }); + + return (it != skip.end()); + }; + + LoopReturn ret = LoopReturn::Continue; + for (int i = pref.size()-1; i >= 0; --i) + { + auto options = pref[i]; + if (should_skip(options.backend)) + continue; + + ret = call(options.backend.get(), options.coerce); + if (ret != LoopReturn::Continue) + return ret; + + if (options.only || options.coerce) + return ret; + } + + auto & globals = global_domain_map[domain_key]; + + if (globals.global && !should_skip(globals.global)) + { + ret = call(globals.global.get(), false); + if (ret != LoopReturn::Continue) + return ret; + } + + for (size_t i = 0; i < globals.registered.size(); ++i) + { + py_ref backend = globals.registered[i]; + if (should_skip(backend)) + continue; + + ret = call(backend.get(), false); + if (ret != LoopReturn::Continue) + return ret; + } + return ret; +} + +struct py_func_args { py_ref args, kwargs; }; + +struct Function +{ + PyObject_HEAD + py_ref extractor_, replacer_; // functions to handle dispatchables + std::string domain_key_; // associated __ua_domain__ in UTF8 + py_ref def_args_, def_kwargs_; // default arguments + py_ref def_impl_; // default implementation + py_ref dict_; // __dict__ + + PyObject * call(PyObject * args, PyObject * kwargs); + + py_func_args replace_dispatchables( + PyObject * backend, PyObject * args, PyObject * kwargs, PyObject * coerce); + + py_ref canonicalize_args(PyObject * args); + py_ref canonicalize_kwargs(PyObject * kwargs); + + static void dealloc(Function * self) + { + PyObject_GC_UnTrack(self); + self->~Function(); + Py_TYPE(self)->tp_free(self); + } + + static PyObject * new_(PyTypeObject * type, PyObject * args, PyObject * kwargs) + { + auto self = reinterpret_cast(type->tp_alloc(type, 0)); + if (self == nullptr) + return nullptr; + + // Placement new + self = new (self) Function; + return reinterpret_cast(self); + } + + static int init(Function * self, PyObject * args, PyObject * /*kwargs*/) + { + PyObject * extractor, * replacer; + PyObject * domain; + PyObject * def_args, * def_kwargs; + PyObject * def_impl; + + if (!PyArg_ParseTuple( + args, "OOO!O!O!O", + &extractor, + &replacer, + &PyUnicode_Type, &domain, + &PyTuple_Type, &def_args, + &PyDict_Type, &def_kwargs, + &def_impl)) + { + return -1; + } + + if (!PyCallable_Check(extractor) + || (replacer != Py_None && !PyCallable_Check(replacer))) + { + PyErr_SetString(PyExc_TypeError, + "Argument extractor and replacer must be callable"); + return -1; + } + + if (def_impl != Py_None && !PyCallable_Check(def_impl)) + { + PyErr_SetString(PyExc_TypeError, + "Default implementation must be Callable or None"); + return -1; + } + + self->domain_key_ = domain_to_string(domain); + if (PyErr_Occurred()) + return -1; + + self->extractor_ = py_ref::ref(extractor); + self->replacer_ = py_ref::ref(replacer); + self->def_args_ = py_ref::ref(def_args); + self->def_kwargs_ = py_ref::ref(def_kwargs); + self->def_impl_ = py_ref::ref(def_impl); + + return 0; + } +}; + + +bool is_default(PyObject * value, PyObject * def) +{ + // TODO: richer comparison for builtin types? (if cheap) + return (value == def); +} + + +py_ref Function::canonicalize_args(PyObject * args) +{ + const auto arg_size = PyTuple_GET_SIZE(args); + const auto def_size = PyTuple_GET_SIZE(def_args_.get()); + + if (arg_size > def_size) + return py_ref::ref(args); + + Py_ssize_t mismatch = 0; + for (Py_ssize_t i = arg_size - 1; i >= 0; --i) + { + auto val = PyTuple_GET_ITEM(args, i); + auto def = PyTuple_GET_ITEM(def_args_.get(), i); + if (!is_default(val, def)) + { + mismatch = i + 1; + break; + } + } + + return py_ref::steal(PyTuple_GetSlice(args, 0, mismatch)); +} + + +py_ref Function::canonicalize_kwargs(PyObject * kwargs) +{ + if (kwargs == nullptr) + return py_ref::steal(PyDict_New()); + + PyObject * key, * def_value; + Py_ssize_t pos = 0; + while (PyDict_Next(def_kwargs_, &pos, &key, &def_value)) + { + auto val = PyDict_GetItem(kwargs, key); + if (val && is_default(val, def_value)) + { + PyDict_DelItem(kwargs, key); + } + } + return py_ref::ref(kwargs); +} + + +py_func_args Function::replace_dispatchables( + PyObject * backend, PyObject * args, PyObject * kwargs, PyObject * coerce) +{ + auto ua_convert = py_ref::steal( + PyObject_GetAttr(backend, identifiers.ua_convert)); + + if (!ua_convert) + { + PyErr_Clear(); + return {py_ref::ref(args), py_ref::ref(kwargs)}; + } + + auto dispatchables = py_ref::steal(PyObject_Call(extractor_, args, kwargs)); + if (!dispatchables) + return {}; + + auto convert_args = py_make_tuple(dispatchables, coerce); + auto res = py_ref::steal(PyObject_Call(ua_convert, convert_args, nullptr)); + if (!res) + { + return {}; + } + + if (res == Py_NotImplemented) + { + return {std::move(res), nullptr}; + } + + auto replaced_args = py_ref::steal(PySequence_Tuple(res)); + if (!replaced_args) + return {}; + + auto replacer_args = py_make_tuple(args, kwargs, replaced_args); + if (!replacer_args) + return {}; + + res = py_ref::steal(PyObject_Call(replacer_, replacer_args, nullptr)); + if (!res) + return {}; + + if (!PyTuple_Check(res) || PyTuple_Size(res) != 2) + { + PyErr_SetString(PyExc_TypeError, + "Argument replacer must return a 2-tuple (args, kwargs)"); + return {}; + } + + auto new_args = py_ref::ref(PyTuple_GET_ITEM(res.get(), 0)); + auto new_kwargs = py_ref::ref(PyTuple_GET_ITEM(res.get(), 1)); + + new_kwargs = canonicalize_kwargs(new_kwargs); + + if (!PyTuple_Check(new_args) || !PyDict_Check(new_kwargs)) + { + PyErr_SetString(PyExc_ValueError, "Invalid return from argument_replacer"); + return {}; + } + + return {std::move(new_args), std::move(new_kwargs)}; +} + + +PyObject * Function_call( + Function * self, PyObject * args, PyObject * kwargs) +{ + return self->call(args, kwargs); +} + + +PyObject * Function::call(PyObject * args_, PyObject * kwargs_) +{ + auto args = canonicalize_args(args_); + auto kwargs = canonicalize_kwargs(kwargs_); + + py_ref result; + + auto ret = for_each_backend( + domain_key_, + [&, this](PyObject * backend, bool coerce) + { + auto new_args = replace_dispatchables(backend, args, kwargs, + coerce ? Py_True : Py_False); + if (new_args.args == Py_NotImplemented) + return LoopReturn::Continue; + if (new_args.args == nullptr) + return LoopReturn::Error; + + auto ua_function = py_ref::steal( + PyObject_GetAttr(backend, identifiers.ua_function)); + if (!ua_function) + return LoopReturn::Error; + + auto ua_func_args = py_make_tuple( + reinterpret_cast(this), new_args.args, new_args.kwargs); + if (!ua_func_args) + return LoopReturn::Error; + + result = py_ref::steal( + PyObject_Call(ua_function, ua_func_args, nullptr)); + + // Try the default with this backend + if (result == Py_NotImplemented && def_impl_ != Py_None) + { + backend_options opt; + opt.backend = py_ref::ref(backend); + opt.coerce = coerce; + opt.only = true; + context_helper ctx; + try + { + if (!ctx.init( + local_domain_map[domain_key_].preferred, std::move(opt))) + return LoopReturn::Error; + } + catch(std::bad_alloc&) + { + PyErr_NoMemory(); + return LoopReturn::Error; + } + + if (!ctx.enter()) + return LoopReturn::Error; + + result = py_ref::steal( + PyObject_Call(def_impl_, new_args.args, new_args.kwargs)); + + if (PyErr_Occurred() && PyErr_ExceptionMatches(BackendNotImplementedError)) + { + PyErr_Clear(); // Suppress exception + result = py_ref::ref(Py_NotImplemented); + } + + if (!ctx.exit()) + return LoopReturn::Error; + } + + if (!result) + return LoopReturn::Error; + + if (result == Py_NotImplemented) + return LoopReturn::Continue; + + return LoopReturn::Break; // Backend called successfully + } + ); + + if (ret != LoopReturn::Continue) + return result.release(); + + if (def_impl_ == Py_None) + { + PyErr_SetString( + BackendNotImplementedError, + "No selected backends had an implementation for this function."); + return nullptr; + } + + return PyObject_Call(def_impl_, args, kwargs); +} + + +PyObject * Function_repr(Function * self) +{ + if (self->dict_) + if (auto name = PyDict_GetItemString(self->dict_, "__name__")) + return PyUnicode_FromFormat("", name); + + return PyUnicode_FromString(""); +} + + +/** Implements the descriptor protocol to allow binding to class instances */ +PyObject * Function_descr_get(PyObject * self, PyObject * obj, PyObject * type) +{ + if (!obj) + { + Py_INCREF(self); + return self; + } + + return PyMethod_New(self, obj); +} + + +/** Make members visible to the garbage collector */ +int Function_traverse(Function * self, visitproc visit, void * arg) +{ + Py_VISIT(self->extractor_); + Py_VISIT(self->replacer_); + Py_VISIT(self->def_args_); + Py_VISIT(self->def_kwargs_); + Py_VISIT(self->def_impl_); + Py_VISIT(self->dict_); + return 0; +} + + +PyGetSetDef Function_getset[] = +{ + {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict}, + {NULL} /* Sentinel */ +}; + +PyTypeObject FunctionType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_uarray.Function", /* tp_name */ + sizeof(Function), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)Function::dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + (reprfunc)Function_repr, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + (ternaryfunc)Function_call, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + PyObject_GenericSetAttr, /* tp_setattro */ + 0, /* tp_as_buffer */ + (Py_TPFLAGS_DEFAULT + | Py_TPFLAGS_HAVE_GC), /* tp_flags */ + 0, /* tp_doc */ + (traverseproc)Function_traverse,/* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + Function_getset, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + Function_descr_get, /* tp_descr_get */ + 0, /* tp_descr_set */ + offsetof(Function, dict_), /* tp_dictoffset */ + (initproc)Function::init, /* tp_init */ + 0, /* tp_alloc */ + Function::new_, /* tp_new */ +}; + + +PyMethodDef SetBackendContext_Methods[] = { + {"__enter__", (binaryfunc)SetBackendContext::enter__, METH_NOARGS, nullptr}, + {"__exit__", (binaryfunc)SetBackendContext::exit__, METH_VARARGS, nullptr}, + {NULL} /* Sentinel */ +}; + +PyTypeObject SetBackendContextType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_uarray.SetBackendContext", /* tp_name */ + sizeof(SetBackendContext), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)SetBackendContext::dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + 0, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + SetBackendContext_Methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)SetBackendContext::init, /* tp_init */ + 0, /* tp_alloc */ + SetBackendContext::new_, /* tp_new */ +}; + + +PyMethodDef SkipBackendContext_Methods[] = { + {"__enter__", (binaryfunc)SkipBackendContext::enter__, METH_NOARGS, nullptr}, + {"__exit__", (binaryfunc)SkipBackendContext::exit__, METH_VARARGS, nullptr}, + {NULL} /* Sentinel */ +}; + +PyTypeObject SkipBackendContextType = { + PyVarObject_HEAD_INIT(NULL, 0) + "_uarray.SkipBackendContext", /* tp_name */ + sizeof(SkipBackendContext), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)SkipBackendContext::dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + 0, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + SkipBackendContext_Methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)SkipBackendContext::init, /* tp_init */ + 0, /* tp_alloc */ + SkipBackendContext::new_, /* tp_new */ +}; + + +PyObject * dummy(PyObject * /*self*/, PyObject * args) +{ + Py_RETURN_NONE; +} + + +PyMethodDef method_defs[] = +{ + {"set_global_backend", set_global_backend, METH_VARARGS, nullptr}, + {"register_backend", register_backend, METH_VARARGS, nullptr}, + {"clear_all_globals", clear_all_globals, METH_NOARGS, nullptr}, + {"dummy", dummy, METH_VARARGS, nullptr}, + {NULL} /* Sentinel */ +}; + +PyModuleDef uarray_module = +{ + PyModuleDef_HEAD_INIT, + "_uarray", + nullptr, + -1, + method_defs, +}; + +} // namespace (anonymous) + + +#if defined(WIN32) || defined(_WIN32) +# define MODULE_EXPORT __declspec(dllexport) +#else +# define MODULE_EXPORT __attribute__ ((visibility("default"))) +#endif + +extern "C" MODULE_EXPORT PyObject * +PyInit__uarray(void) +{ + + auto m = py_ref::steal(PyModule_Create(&uarray_module)); + if (!m) + return nullptr; + + if (PyType_Ready(&FunctionType) < 0) + return nullptr; + Py_INCREF(&FunctionType); + PyModule_AddObject(m, "Function", (PyObject *)&FunctionType); + + if (PyType_Ready(&SetBackendContextType) < 0) + return nullptr; + Py_INCREF(&SetBackendContextType); + PyModule_AddObject(m, "SetBackendContext", (PyObject*)&SetBackendContextType); + + if (PyType_Ready(&SkipBackendContextType) < 0) + return nullptr; + Py_INCREF(&SkipBackendContextType); + PyModule_AddObject( + m, "SkipBackendContext", (PyObject*)&SkipBackendContextType); + + BackendNotImplementedError = py_ref::steal( + PyErr_NewExceptionWithDoc( + "uarray.BackendNotImplementedError", + "An exception that is thrown when no compatible" + " backend is found for a method.", + PyExc_NotImplementedError, + nullptr)); + if (!BackendNotImplementedError) + return nullptr; + Py_INCREF(BackendNotImplementedError.get()); + PyModule_AddObject( + m, "BackendNotImplementedError", BackendNotImplementedError); + + if (!identifiers.init()) + return nullptr; + + return m.release(); +}