Skip to content

Commit

Permalink
Merge pull request #137 from Quansight-Labs/conversion-invert
Browse files Browse the repository at this point in the history
Invert the conversion of values
  • Loading branch information
hameerabbasi committed Apr 26, 2019
2 parents 7c8a060 + 24107e6 commit f5b7d39
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 85 deletions.
6 changes: 0 additions & 6 deletions docs/generated/uarray.Backend.register_convertor.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/generated/uarray.Backend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ uarray.Backend
.. autosummary::
:toctree:

Backend.register_convertor
Backend.register_implementation
Backend.replace_dispatchables
Backend.try_backend
Expand Down
6 changes: 6 additions & 0 deletions docs/generated/uarray.DispatchableInstance.__init__.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DispatchableInstance.\_\_init\_\_
=================================

.. currentmodule:: uarray

.. automethod:: DispatchableInstance.__init__
6 changes: 6 additions & 0 deletions docs/generated/uarray.DispatchableInstance.convert.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DispatchableInstance.convert
============================

.. currentmodule:: uarray

.. automethod:: DispatchableInstance.convert
6 changes: 6 additions & 0 deletions docs/generated/uarray.DispatchableInstance.convertors.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DispatchableInstance.convertors
===============================

.. currentmodule:: uarray

.. autoattribute:: DispatchableInstance.convertors
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DispatchableInstance.register\_convertor
========================================

.. currentmodule:: uarray

.. automethod:: DispatchableInstance.register_convertor
32 changes: 32 additions & 0 deletions docs/generated/uarray.DispatchableInstance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
DispatchableInstance
====================

.. currentmodule:: uarray

.. autoclass:: DispatchableInstance



.. rubric:: Attributes
.. autosummary::
:toctree:

DispatchableInstance.convertors






.. rubric:: Methods
.. autosummary::
:toctree:

DispatchableInstance.__init__

DispatchableInstance.convert

DispatchableInstance.register_convertor



2 changes: 2 additions & 0 deletions docs/generated/uarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

BackendOptions

DispatchableInstance

.. rubric:: Functions
.. autosummary::
:toctree:
Expand Down
145 changes: 70 additions & 75 deletions uarray/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ class Backend:

def __init__(self):
self._implementations: MethodLookupType = {}
self._convertors: TypeLookupType = {}

def register_implementation(self, method: MultiMethod, implementation: ImplementationType):
"""
Expand Down Expand Up @@ -267,58 +266,6 @@ def register_implementation(self, method: MultiMethod, implementation: Implement

self._implementations[method] = implementation

def register_convertor(self, dispatch_type: Type["DispatchableInstance"], convertor: ConvertorType):
"""
Registers a convertor for a given type.
The convertor takes in a single value and converts it to a form suitable for consumption by
the backend's implementations. It's called when the user coerces the type and the backend
has also registered the type of the dispatchable.
Parameters
----------
dispatch_type : Type["DispatchableInstance"]
The type of dispatchable to register the convertor for. The convertor will convert the
instance if coercion is enabled.
implementation : ImplementationType
The implementation of this method. It takes in a single value and converts it.
Raises
------
ValueError
If there is already a convertor for this type.
Examples
--------
>>> import uarray as ua
>>> class DispatchableInt(ua.DispatchableInstance):
... pass
>>> be = ua.Backend()
>>> # All ints piped to -2
>>> be.register_convertor(DispatchableInt, lambda x: -2)
>>> def potato_rd(args, kwargs, dispatch_args):
... # This replaces a within the args/kwargs
... return dispatch_args + args[1:], kwargs
>>> @ua.create_multimethod(potato_rd)
... def potato(a, b):
... # Here, we register a as dispatchable and mark it as an int
... return (DispatchableInt(a),)
>>> @ua.register_implementation(potato, be)
... def potato_impl(a, b):
... return a, b
>>> with ua.set_backend(be, coerce=True):
... potato(1, '2')
(-2, '2')
>>> be.register_convertor(DispatchableInt, lambda x: -2)
Traceback (most recent call last):
...
ValueError: ...
"""
if dispatch_type in self._convertors:
raise ValueError('Cannot register a different convertor once one is already registered.')

self._convertors[dispatch_type] = convertor

def replace_dispatchables(self, method: MultiMethod, args, kwargs, coerce: Optional[bool] = False):
"""
Replace dispatchables for a this method, using the convertor, if coercion is used.
Expand All @@ -341,34 +288,17 @@ def replace_dispatchables(self, method: MultiMethod, args, kwargs, coerce: Optio
replaced_args: List = []
filtered_dispatchable_args: List = []
for arg in dispatchable_args:
replaced_arg = self._replace_single(arg, coerce=coerce)
replaced_arg = arg.convert(self, coerce) if isinstance(arg, DispatchableInstance) else arg
replaced_args.append(replaced_arg)

if not isinstance(arg, DispatchableInstance):
filtered_dispatchable_args.append(replaced_arg)
elif type(arg) in self._convertors:
filtered_dispatchable_args.append(arg)
elif self in type(arg).convertors:
filtered_dispatchable_args.append(type(arg)(replaced_arg))

args, kwargs = method.argument_replacer(args, kwargs, tuple(replaced_args))
return args, kwargs, filtered_dispatchable_args

def _replace_single(self, arg: Union["DispatchableInstance", Any],
coerce: Optional[bool] = False):
if not isinstance(arg, DispatchableInstance):
return arg

arg_type = type(arg)

if coerce:
if arg.value is None:
return None

for try_type in arg_type.__mro__:
if try_type in self._convertors:
return self._convertors[try_type](arg.value)

return arg.value

def try_backend(self, method: MultiMethod, args: Tuple, kwargs: Dict, coerce: bool):
"""
Try this backend for a given args and kwargs. Returns either a
Expand Down Expand Up @@ -579,7 +509,7 @@ class DispatchableInstance:
... pass
>>> be = ua.Backend()
>>> # All ints piped to -2
>>> be.register_convertor(DispatchableInt, lambda x: -2)
>>> DispatchableInt.register_convertor(be, lambda x: -2)
>>> def potato_rd(args, kwargs, dispatch_args):
... # This replaces a within the args/kwargs
... return dispatch_args + args[1:], kwargs
Expand All @@ -594,6 +524,11 @@ class DispatchableInstance:
... potato(1, '2')
(-2, '2')
"""
convertors: Dict[Backend, ConvertorType] = {}

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.convertors = cls.convertors.copy()

def __init__(self, value: Any):
if type(self) is DispatchableInstance:
Expand All @@ -602,6 +537,66 @@ def __init__(self, value: Any):

self.value = value

@classmethod
def register_convertor(cls, backend: Backend, convertor: ConvertorType):
"""
Registers a convertor for a given type.
The convertor takes in a single value and converts it to a form suitable for consumption by
the backend's implementations. It's called when the user coerces the type and the backend
has also registered the type of the dispatchable.
Parameters
----------
dispatch_type : Type["DispatchableInstance"]
The type of dispatchable to register the convertor for. The convertor will convert the
instance if coercion is enabled.
implementation : ImplementationType
The implementation of this method. It takes in a single value and converts it.
Raises
------
ValueError
If there is already a convertor for this type.
Examples
--------
>>> import uarray as ua
>>> class DispatchableInt(ua.DispatchableInstance):
... pass
>>> be = ua.Backend()
>>> DispatchableInt.register_convertor(be, lambda x: -2)
>>> DispatchableInt(2).convert(be, coerce=True)
-2
>>> be2 = ua.Backend()
>>> DispatchableInt.register_convertor(be2, lambda x: 3)
>>> DispatchableInt(2).convert(be2, coerce=True)
3
>>> DispatchableInt.register_convertor(be, lambda x: -2)
Traceback (most recent call last):
...
ValueError: ...
"""
if backend in cls.convertors:
raise ValueError('Cannot register a different convertor once one is already registered.')

cls.convertors[backend] = convertor

def convert(self, backend: Backend, coerce: Optional[bool] = False):
"""
Convert a single argument using the given backend.
"""
cls = type(self)

if coerce:
if self.value is None:
return None

if backend in cls.convertors:
return cls.convertors[backend](self.value)

return self.value


def all_of_type(arg_type: Type[DispatchableInstance]):
"""
Expand All @@ -614,7 +609,7 @@ def all_of_type(arg_type: Type[DispatchableInstance]):
... pass
>>> be = ua.Backend()
>>> # All ints piped to -2
>>> be.register_convertor(DispatchableInt, lambda x: -2)
>>> DispatchableInt.register_convertor(be, lambda x: -2)
>>> def potato_rd(args, kwargs, dispatch_args):
... # This replaces a within the args/kwargs
... return dispatch_args + args[1:], kwargs
Expand Down
2 changes: 1 addition & 1 deletion unumpy/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def inner(self, *args, **kwargs):
register_numpy(multimethods.argmin)(np.argmin)
register_numpy(multimethods.argmax)(np.argmax)

NumpyBackend.register_convertor(ndarray, np.asarray)
ndarray.register_convertor(NumpyBackend, np.asarray)
2 changes: 1 addition & 1 deletion unumpy/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,4 @@ def argmin(a, axis=None, out=None):
return reduce(getattr(multimethods, 'min'), a, axis=axis, out=out, arg=True)


TorchBackend.register_convertor(ndarray, asarray)
ndarray.register_convertor(TorchBackend, asarray)
2 changes: 1 addition & 1 deletion unumpy/xnd_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def asarray(a):
return xnd.array(a)


XndBackend.register_convertor(ndarray, asarray)
ndarray.register_convertor(XndBackend, asarray)

0 comments on commit f5b7d39

Please sign in to comment.