Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Implement multimethod call in the Python C API #170

Merged
merged 29 commits into from Jul 6, 2019
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9a25e33
ENH: Implement multimethod call in the Python C API
peterbell10 Jun 24, 2019
fba2d39
Add extension to setup.py
peterbell10 Jun 24, 2019
822f0ee
language=c++11
peterbell10 Jun 25, 2019
bfcb7f7
Missing header
peterbell10 Jun 25, 2019
22e639f
Use member functions and canonicalize replaced kwargs
peterbell10 Jun 25, 2019
399430c
Propagate errors from replace_dispatchables
peterbell10 Jun 26, 2019
c7bac6f
Remove debug printing
peterbell10 Jun 26, 2019
5613a92
Tell mypy to ignore _uarray
peterbell10 Jun 26, 2019
5d9d033
Implement descriptor protocol
peterbell10 Jun 27, 2019
ccfb461
Fix descriptor get refcounting
peterbell10 Jun 27, 2019
8e13be9
Implement cyclic GC interface
peterbell10 Jun 27, 2019
f7e1b38
Fix module init code
peterbell10 Jun 28, 2019
539ece5
Better debug output from azure pipelines
peterbell10 Jun 28, 2019
6b54e04
Downgrade dask
peterbell10 Jun 28, 2019
b44fe57
Revert "Better debug output from azure pipelines"
peterbell10 Jun 29, 2019
02a7fdc
Allow argument_replacer to be None
peterbell10 Jul 1, 2019
2e8e9de
Add wrapper around PyContextVar
peterbell10 Jul 2, 2019
8c7d58e
WIP: Store current backends in C++
peterbell10 Jul 3, 2019
b843ddd
STY: Purge tabs from c++ code
peterbell10 Jul 3, 2019
bc32409
Improve error propagation
peterbell10 Jul 4, 2019
1da4d45
Fix skipping global backend
peterbell10 Jul 4, 2019
bdc71ce
Mark XND as xfail
peterbell10 Jul 4, 2019
7be7d78
Revert "Mark XND as xfail"
peterbell10 Jul 4, 2019
da4fa1c
Fix backend order of preferrence
peterbell10 Jul 4, 2019
2389d3b
Decref globals before reaching shutdown
peterbell10 Jul 4, 2019
291668c
Better handling for defaults in terms of other multimethods.
peterbell10 Jul 4, 2019
a71ab7d
Fix iterator invalidation bug
peterbell10 Jul 4, 2019
e737290
Remove contextvar wrapper
peterbell10 Jul 5, 2019
8d8d1ea
Intern protocol identifiers at module init time, avoid GetAttrString
peterbell10 Jul 5, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .conda/environment.yml
Expand Up @@ -15,7 +15,7 @@ dependencies:
- pytorch-cpu
- scipy
- gumath
- dask
- dask=1.2
- sparse
- doc8
- black
Expand Down
27 changes: 25 additions & 2 deletions setup.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python

from setuptools import setup, find_packages
from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
import versioneer
from pathlib import Path
import sys
Expand Down Expand Up @@ -43,10 +44,31 @@ def parse_requires():
with open("README.md") as f:
long_desc = f.read()

class build_cpp11_ext(build_ext):
def build_extension(self, ext):
cc = self.compiler
if cc.compiler_type == "unix":
ext.extra_compile_args.append("--std=c++11")
build_ext.build_extension(self, ext)


cmdclass = {"build_ext": build_cpp11_ext}
cmdclass.update(versioneer.get_cmdclass())


extensions = [
Extension(
"uarray._uarray",
sources=["uarray/_uarray_dispatch.cxx"],
depends=["uarray/_python_support.h"],
language="c++",
)
]

setup(
name="uarray",
version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(),
cmdclass=cmdclass,
description="Array interface object for Python with pluggable backends and a multiple-dispatch"
"mechanism for defining down-stream functions",
url="https://github.com/Quansight-Labs/uarray/",
Expand Down Expand Up @@ -79,4 +101,5 @@ def parse_requires():
"Tracker": "https://github.com/Quansight-Labs/uarray/issues",
},
python_requires=">=3.5, <4",
ext_modules=extensions
)
199 changes: 16 additions & 183 deletions uarray/_backend.py
Expand Up @@ -14,15 +14,12 @@
from contextvars import ContextVar
import functools
import contextlib
from . import _uarray # type: ignore

ArgumentExtractorType = Callable[..., Tuple["Dispatchable", ...]]
ArgumentReplacerType = Callable[[Tuple, Dict, Tuple], Tuple[Tuple, Dict]]


class BackendNotImplementedError(NotImplementedError):
"""
An exception that is thrown when no compatible backend is found for a method.
"""
from ._uarray import BackendNotImplementedError


def create_multimethod(*args, **kwargs):
Expand Down Expand Up @@ -108,166 +105,18 @@ def generate_multimethod(
See the module documentation for how to override the method by creating backends.
"""
kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
ua_func = _uarray.Function(
argument_extractor,
argument_replacer,
domain,
arg_defaults,
kw_defaults,
default,
)

@functools.wraps(argument_extractor)
def inner(*args, **kwargs):
dispatchable_args = argument_extractor(*args, **kwargs)
errors = []

args = canonicalize_args(args, kwargs)
result = NotImplemented

for options in _backend_order(domain):
res = (
replace_dispatchables(
options.backend,
args,
kwargs,
dispatchable_args,
coerce=options.coerce,
)
if hasattr(options.backend, "__ua_convert__")
else (args, kwargs)
)

if res is NotImplemented:
continue

a, kw = res

for k, v in kw_defaults.items():
if k in kw and kw[k] is v:
del kw[k]

result = options.backend.__ua_function__(inner, a, kw)

if result is NotImplemented:
result = try_default(a, kw, options, errors)

if result is not NotImplemented:
break
else:
result = try_default(args, kwargs, None, errors)

if result is NotImplemented:
raise BackendNotImplementedError(
"No selected backends had an implementation for this function.", errors
)

return result

def try_default(args, kwargs, options, errors):
if default is not None:
try:
if options is not None:
with set_backend(options.backend, only=True, coerce=options.coerce):
return default(*args, **kwargs)
else:
return default(*args, **kwargs)
except BackendNotImplementedError as e:
errors.append(e)

return NotImplemented

def replace_dispatchables(
backend, args, kwargs, dispatchable_args, coerce: Optional[bool] = False
):
replaced_args: Iterable = backend.__ua_convert__(dispatchable_args, coerce)

if replaced_args is NotImplemented:
return NotImplemented

return argument_replacer(args, kwargs, tuple(replaced_args))

def canonicalize_args(args, kwargs):
if len(args) > len(arg_defaults):
return args

match = 0
for a, d in zip(args[::-1], arg_defaults[len(args) - 1 :: -1]):
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved
if a is d:
match += 1
else:
break

args = args[:-match] if match > 0 else args
return args

inner._coerce_args = replace_dispatchables # type: ignore

return inner
return functools.update_wrapper(ua_func, argument_extractor)


class _BackendOptions:
def __init__(self, backend, coerce: bool = False, only: bool = False):
"""
The backend plus any additonal options associated with it.

Parameters
----------
backend : Backend
The associated backend.
coerce: bool, optional
Whether or not the backend is being coerced. Implies ``only``.
only: bool, optional
Whether or not this is the only backend to try.
"""
self.backend = backend
self.coerce = coerce
self.only = only or coerce


_backends: Dict[str, ContextVar] = {}


def _backend_order(domain: str) -> Iterable[_BackendOptions]:
skip = _get_skipped_backends(domain).get()
pref = _get_preferred_backends(domain).get()

for options in pref:
if options.backend not in skip:
yield options

if options.only:
return

if domain in _backends and _backends[domain] not in skip:
yield _BackendOptions(_backends[domain])

if domain in _registered_backend:
for backend in _registered_backend[domain]:
if backend not in skip:
yield _BackendOptions(backend)


def _get_preferred_backends(domain: str) -> ContextVar[Tuple[_BackendOptions, ...]]:
if domain not in _preferred_backend:
_preferred_backend[domain] = ContextVar(
f"_preferred_backend[{domain}]", default=()
)
return _preferred_backend[domain]


def _get_registered_backends(domain: str) -> Set[_BackendOptions]:
if domain not in _registered_backend:
_registered_backend[domain] = set()
return _registered_backend[domain]


def _get_skipped_backends(domain: str) -> ContextVar[Set]:
if domain not in _skipped_backend:
_skipped_backend[domain] = ContextVar(
f"_skipped_backend[{domain}]", default=set()
)
return _skipped_backend[domain]


_preferred_backend: Dict[str, ContextVar[Tuple[_BackendOptions, ...]]] = {}
_registered_backend: Dict[str, Set[_BackendOptions]] = {}
_skipped_backend: Dict[str, ContextVar[Set]] = {}


@contextlib.contextmanager
def set_backend(backend, *args, **kwargs):
"""
A context manager that sets the preferred backend. Uses :obj:`BackendOptions` to create
Expand All @@ -283,17 +132,9 @@ def set_backend(backend, *args, **kwargs):
BackendOptions: The backend plus options.
skip_backend: A context manager that allows skipping of backends.
"""
options = _BackendOptions(backend, *args, **kwargs)
pref = _get_preferred_backends(backend.__ua_domain__)
token = pref.set((options,) + pref.get())
return _uarray.SetBackendContext(backend, *args, **kwargs)

try:
yield
finally:
pref.reset(token)


@contextlib.contextmanager
def skip_backend(backend):
"""
A context manager that allows one to skip a given backend from processing
Expand All @@ -309,15 +150,7 @@ def skip_backend(backend):
--------
set_backend: A context manager that allows setting of backends.
"""
skip = _get_skipped_backends(backend.__ua_domain__)
new = set(skip.get())
new.add(backend)
token = skip.set(new)

try:
yield
finally:
skip.reset(token)
return _uarray.SkipBackendContext(backend)


def get_defaults(f):
Expand All @@ -335,7 +168,7 @@ def get_defaults(f):
arg_defaults.append(v.default)
opts.add(k)

return kw_defaults, arg_defaults, opts
return kw_defaults, tuple(arg_defaults), opts


def set_global_backend(backend):
Expand All @@ -358,7 +191,7 @@ def set_global_backend(backend):
backend
The backend to register.
"""
_backends[backend.__ua_domain__] = backend
_uarray.set_global_backend(backend)


def register_backend(backend):
Expand All @@ -374,7 +207,7 @@ def register_backend(backend):
backend
The backend to register.
"""
_get_registered_backends(backend.__ua_domain__).add(backend)
_uarray.register_backend(backend)


class Dispatchable:
Expand Down