Skip to content

Commit

Permalink
Add support for NamedTuples, support NumPy 1.25
Browse files Browse the repository at this point in the history
The functions in numpy.linalg which returned tuples in version <= 1.24
have switched to using a different NamedTuple type for each function. To
support this I've added the necessary Box and VSpace registrations.
  • Loading branch information
j-towns committed Jun 22, 2023
1 parent c2b8ab5 commit 7c2a7ff
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
6 changes: 6 additions & 0 deletions autograd/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,9 @@ def _subval(self, xs, idx, x):
ListVSpace.register(list_)
TupleVSpace.register(tuple_)
DictVSpace.register(dict_)

class NamedTupleVSpace(SequenceVSpace):
def _map(self, f, *args):
return self.seq_type(*map(f, self.shape, *args))
def _subval(self, xs, idx, x):
return self.seq_type(*subvals(xs, [(idx, x)]))
8 changes: 8 additions & 0 deletions autograd/numpy/numpy_boxes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
import numpy as np
from autograd.extend import Box, primitive
from autograd.builtins import SequenceBox
from . import numpy_wrapper as anp

Box.__array_priority__ = 90.0
Expand Down Expand Up @@ -64,3 +65,10 @@ def __hash__(self): return id(self)

# Flatten has no function, only a method.
setattr(ArrayBox, 'flatten', anp.__dict__['ravel'])

if np.__version__ >= '1.25':
SequenceBox.register(np.linalg.linalg.EigResult)
SequenceBox.register(np.linalg.linalg.EighResult)
SequenceBox.register(np.linalg.linalg.QRResult)
SequenceBox.register(np.linalg.linalg.SlogdetResult)
SequenceBox.register(np.linalg.linalg.SVDResult)
15 changes: 15 additions & 0 deletions autograd/numpy/numpy_vspaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from autograd.extend import VSpace
from autograd.builtins import NamedTupleVSpace

class ArrayVSpace(VSpace):
def __init__(self, value):
Expand Down Expand Up @@ -63,3 +64,17 @@ def _covector(self, x):

for type_ in [complex, np.complex64, np.complex128, np.complex256]:
ComplexArrayVSpace.register(type_)


if np.__version__ >= '1.25':
class EigResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EigResult
class EighResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EighResult
class QRResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.QRResult
class SlogdetResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.SlogdetResult
class SVDResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.SVDResult

EigResultVSpace.register(np.linalg.linalg.EigResult)
EighResultVSpace.register(np.linalg.linalg.EighResult)
QRResultVSpace.register(np.linalg.linalg.QRResult)
SlogdetResultVSpace.register(np.linalg.linalg.SlogdetResult)
SVDResultVSpace.register(np.linalg.linalg.SVDResult)

0 comments on commit 7c2a7ff

Please sign in to comment.