**Writing custom array containers**

In [2]:
import numpy as np
class DiagonalArray:
    """An N x N diagonal array whose diagonal entries all share one value.

    Only ``N`` and the diagonal ``value`` are stored.  ``__array__`` lets NumPy
    build the dense ``value * eye(N)`` ndarray on demand.  This first version
    defines no ufunc hooks, so ufuncs convert it to a plain dense ndarray.
    """

    def __init__(self, N, value):
        self._N = N  # side length of the square array
        self._i = value  # value repeated along the diagonal

    def __repr__(self):
        return f"{self.__class__.__name__}(N={self._N}, value={self._i})"

    def __array__(self, dtype=None, copy=None):
        """Return the dense ``value * eye(N)`` ndarray.

        ``copy`` is part of the NumPy 2 ``__array__`` protocol.  A new array is
        always built here, so a request for no copy (``copy=False``) cannot be
        honoured and raises ``ValueError``.
        """
        if copy is False:
            raise ValueError(
                "`copy=False` isn't supported. A copy is always created."
            )
        return self._i * np.eye(self._N, dtype=dtype)

In [3]:
# Only N and the diagonal value are stored; __repr__ shows them without building the dense array.
arr = DiagonalArray(5, 1)
arr

DiagonalArray(N=5, value=1)

In [4]:
# np.asarray calls DiagonalArray.__array__, which returns value * np.eye(N).
np.asarray(arr)

array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

In [5]:
# With only __array__ defined, the ufunc works on the dense ndarray form of arr.
np.multiply(arr, 2)

array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

In [6]:
# The result is a plain ndarray: the compact DiagonalArray representation is lost.
type(np.multiply(arr, 2))

numpy.ndarray

In [7]:
from numbers import Number
class DiagonalArray:
    """An N x N diagonal array whose diagonal entries all share one value.

    Only ``N`` and the diagonal ``value`` are stored; ``np.asarray`` builds the
    dense array on demand.  NumPy ufunc calls are intercepted by
    ``__array_ufunc__`` and applied to the stored value only, so the result
    stays a ``DiagonalArray``.

    Caveat: the off-diagonal zeros are never passed to the ufunc.  The result
    matches the dense computation only for ufuncs that map zeros to zero (e.g.
    ``multiply``, ``sin``).  ``np.add(arr, 3)`` keeps zeros off the diagonal,
    whereas ``np.asarray(arr) + 3`` would not.
    """

    def __init__(self, N, value):
        self._N = N  # side length of the square array
        self._i = value  # value repeated along the diagonal

    def __repr__(self):
        return f"{self.__class__.__name__}(N={self._N}, value={self._i})"

    def __array__(self, dtype=None, copy=None):
        """Return the dense ``value * eye(N)`` ndarray.

        ``copy`` is part of the NumPy 2 ``__array__`` protocol.  A new array is
        always built, so ``copy=False`` cannot be honoured and raises
        ``ValueError``.
        """
        if copy is False:
            raise ValueError(
                "`copy=False` isn't supported. A copy is always created."
            )
        return self._i * np.eye(self._N, dtype=dtype)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply ``ufunc`` to the diagonal values and wrap the result.

        Returns ``NotImplemented`` (NumPy then raises ``TypeError``) in three
        cases:

        * the ufunc method is not a plain call;
        * an ``out=`` argument is given;
        * an operand is neither a number nor a ``DiagonalArray``.

        Raises:
            TypeError: if two ``DiagonalArray`` operands have different sizes.
        """
        # ``out=`` cannot be honoured.  Forwarding it to ``ufunc`` would hand an
        # overriding object back to NumPy and re-enter this method endlessly.
        if method != '__call__' or 'out' in kwargs:
            return NotImplemented
        N = None
        scalars = []
        for operand in inputs:
            # Only plain numbers or DiagonalArrays are accepted.
            if isinstance(operand, Number):
                scalars.append(operand)
            elif isinstance(operand, self.__class__):
                scalars.append(operand._i)
                # Compare each operand's own size.  ``self._N`` is always the
                # size of the operand NumPy dispatched to, so it cannot detect
                # a mismatch.
                if N is None:
                    N = operand._N
                elif N != operand._N:
                    raise TypeError("inconsistent sizes")
            else:
                return NotImplemented
        return self.__class__(N, ufunc(*scalars, **kwargs))

In [8]:
# Re-create arr so it is an instance of the redefined class; the old instance lacks __array_ufunc__.
arr = DiagonalArray(5, 1)
np.multiply(arr, 3)

DiagonalArray(N=5, value=3)

In [9]:
# Only the diagonal value is updated (1 + 3).  The dense equivalent, np.asarray(arr) + 3,
# would also have 3 off the diagonal, which this representation cannot express.
np.add(arr, 3)

DiagonalArray(N=5, value=4)

In [10]:
# sin(0) == 0, so applying the ufunc to the diagonal value alone matches the dense result.
np.sin(arr)

DiagonalArray(N=5, value=0.8414709848078965)

In [11]:
# This class defines no __add__, so Python's + operator fails even though np.add works.
# The error is the point of this cell; catch it so "Restart & Run All" keeps going.
try:
    arr + 3
except TypeError as exc:
    print(f"{type(exc).__name__}: {exc}")

TypeError: unsupported operand type(s) for +: 'DiagonalArray' and 'int'

In [12]:
import numpy.lib.mixins
class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    """N x N diagonal array storing one diagonal value, with Python operators.

    ``NDArrayOperatorsMixin`` supplies Python operators (``+``, ``>``, ...) that
    call the matching NumPy ufunc.  Those calls are routed through
    ``__array_ufunc__`` and return ``DiagonalArray`` objects.

    Caveat: ufuncs are applied to the diagonal value only and off-diagonal
    zeros are left untouched.  Results agree with the dense computation only
    when the ufunc maps zeros to zero; ``arr + 3`` keeps zeros off the diagonal.
    """

    def __init__(self, N, value):
        self._N = N  # side length of the square array
        self._i = value  # value repeated along the diagonal

    def __repr__(self):
        return f"{self.__class__.__name__}(N={self._N}, value={self._i})"

    def __array__(self, dtype=None, copy=None):
        """Return the dense ``value * eye(N)`` ndarray.

        ``copy`` is part of the NumPy 2 ``__array__`` protocol.  A new array is
        always built, so ``copy=False`` cannot be honoured and raises
        ``ValueError``.
        """
        if copy is False:
            raise ValueError(
                "`copy=False` isn't supported. A copy is always created."
            )
        return self._i * np.eye(self._N, dtype=dtype)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply ``ufunc`` to the diagonal values and wrap the result.

        Returns ``NotImplemented`` (NumPy then raises ``TypeError``) in three
        cases:

        * the ufunc method is not a plain call;
        * an ``out=`` argument is given, e.g. from in-place operators like ``+=``;
        * an operand is neither a number nor a ``DiagonalArray``.

        Raises:
            TypeError: if two ``DiagonalArray`` operands have different sizes.
        """
        # ``out=`` cannot be honoured.  Forwarding it to ``ufunc`` would hand an
        # overriding object back to NumPy and re-enter this method endlessly.
        if method != '__call__' or 'out' in kwargs:
            return NotImplemented
        N = None
        scalars = []
        for operand in inputs:
            # Only plain numbers or DiagonalArrays are accepted.
            if isinstance(operand, Number):
                scalars.append(operand)
            elif isinstance(operand, self.__class__):
                scalars.append(operand._i)
                # Compare each operand's own size.  ``self._N`` is always the
                # size of the operand NumPy dispatched to, so it cannot detect
                # a mismatch.
                if N is None:
                    N = operand._N
                elif N != operand._N:
                    raise TypeError("inconsistent sizes")
            else:
                return NotImplemented
        return self.__class__(N, ufunc(*scalars, **kwargs))

In [13]:
# Re-create arr from the mixin-based class so the + operator is routed through __array_ufunc__.
arr = DiagonalArray(5, 1)
arr + 3

DiagonalArray(N=5, value=4)

In [14]:
# The mixin also routes comparison operators through __array_ufunc__, so the result stays a DiagonalArray.
arr > 0

DiagonalArray(N=5, value=True)

In [15]:
# Registry mapping NumPy API functions (e.g. np.sum) to DiagonalArray-specific
# implementations; __array_function__ looks functions up here.
# NOTE: re-running this cell replaces the dict and drops earlier registrations.
HANDLED_FUNCTIONS = {}
class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
    """N x N diagonal array storing one diagonal value.

    Supports three NumPy hooks:

    * Python operators, supplied by ``NDArrayOperatorsMixin``;
    * ufuncs, through ``__array_ufunc__``, applied to the diagonal value only;
    * NumPy API functions, through ``__array_function__``, for functions
      registered in ``HANDLED_FUNCTIONS``.

    Caveat: off-diagonal zeros are never passed to ufuncs.  Results match the
    dense computation only for ufuncs that map zeros to zero.
    """

    def __init__(self, N, value):
        self._N = N  # side length of the square array
        self._i = value  # value repeated along the diagonal

    def __repr__(self):
        return f"{self.__class__.__name__}(N={self._N}, value={self._i})"

    def __array__(self, dtype=None, copy=None):
        """Return the dense ``value * eye(N)`` ndarray.

        ``copy`` is part of the NumPy 2 ``__array__`` protocol.  A new array is
        always built, so ``copy=False`` cannot be honoured and raises
        ``ValueError``.
        """
        if copy is False:
            raise ValueError(
                "`copy=False` isn't supported. A copy is always created."
            )
        return self._i * np.eye(self._N, dtype=dtype)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply ``ufunc`` to the diagonal values and wrap the result.

        Returns ``NotImplemented`` (NumPy then raises ``TypeError``) in three
        cases:

        * the ufunc method is not a plain call;
        * an ``out=`` argument is given, e.g. from in-place operators like ``+=``;
        * an operand is neither a number nor a ``DiagonalArray``.

        Raises:
            TypeError: if two ``DiagonalArray`` operands have different sizes.
        """
        # ``out=`` cannot be honoured.  Forwarding it to ``ufunc`` would hand an
        # overriding object back to NumPy and re-enter this method endlessly.
        if method != '__call__' or 'out' in kwargs:
            return NotImplemented
        N = None
        scalars = []
        for operand in inputs:
            # In this case we accept only scalar numbers or DiagonalArrays.
            if isinstance(operand, Number):
                scalars.append(operand)
            elif isinstance(operand, self.__class__):
                scalars.append(operand._i)
                # Compare each operand's own size.  ``self._N`` is always the
                # size of the operand NumPy dispatched to, so it cannot detect
                # a mismatch.
                if N is None:
                    N = operand._N
                elif N != operand._N:
                    raise TypeError("inconsistent sizes")
            else:
                return NotImplemented
        return self.__class__(N, ufunc(*scalars, **kwargs))

    def __array_function__(self, func, types, args, kwargs):
        """Dispatch NumPy API functions to the implementations in HANDLED_FUNCTIONS.

        Returns ``NotImplemented`` (NumPy then raises ``TypeError``) in two
        cases:

        * ``func`` has no registered implementation;
        * an overriding argument type is not a ``DiagonalArray`` (sub)class.
        """
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Note: this allows subclasses that don't override
        # __array_function__ to handle DiagonalArray objects.
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

In [16]:
def implements(np_function):
    """Build a decorator that registers its target as DiagonalArray's ``np_function``.

    The decorated function is stored in ``HANDLED_FUNCTIONS`` under the key
    ``np_function`` and returned unchanged.
    """
    def register(impl):
        HANDLED_FUNCTIONS[np_function] = impl
        return impl

    return register

In [17]:
@implements(np.sum)
def sum(arr):
    """``np.sum`` for DiagonalArray objects, computed without building the dense array.

    The dense form holds the value in its N diagonal entries and zeros
    elsewhere, so the total is value * N.
    NOTE: this module-level name shadows the built-in ``sum`` in later cells.
    """
    diagonal_value, side = arr._i, arr._N
    return diagonal_value * side

In [18]:
@implements(np.mean)
def mean(arr):
    """``np.mean`` for DiagonalArray objects, computed without building the dense array.

    The dense form has N * N entries: N equal the diagonal value and the rest
    are zero.  The mean is therefore value * N / N**2, i.e. value / N.
    """
    diagonal_value, side = arr._i, arr._N
    return diagonal_value / side
# Build a fresh instance and exercise the np.sum override registered above.
arr = DiagonalArray(5, 1)
np.sum(arr)

5

In [19]:
# Dispatched through __array_function__ to the registered mean: value / N = 1 / 5.
np.mean(arr)

0.2

In [20]:
# np.concatenate has no entry in HANDLED_FUNCTIONS, so __array_function__ returns
# NotImplemented and NumPy raises TypeError.  Catch it so the notebook runs end to end.
try:
    np.concatenate([arr, arr])
except TypeError as exc:
    print(f"{type(exc).__name__}: {exc}")

TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

In [21]:
# The registered sum(arr) takes no keyword arguments, so forwarding np.sum's axis= fails.
# Catch the intentional error so the notebook runs end to end.
try:
    np.sum(arr, axis=0)
except TypeError as exc:
    print(f"{type(exc).__name__}: {exc}")

TypeError: sum() got an unexpected keyword argument 'axis'

In [22]:
# Functions without a registered implementation still work after explicit conversion to dense ndarrays.
np.concatenate([np.asarray(arr), np.asarray(arr)])

array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])