Skip to content

Commit

Permalink
Merge pull request #600 from j-towns/fix-597
Browse files Browse the repository at this point in the history
Add support for NamedTuples, support NumPy 1.25 (fixes #597)
  • Loading branch information
j-towns committed Jun 22, 2023
2 parents c2b8ab5 + 7c2a7ff commit 06e6e2d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
6 changes: 6 additions & 0 deletions autograd/builtins.py
Expand Up @@ -167,3 +167,9 @@ def _subval(self, xs, idx, x):
ListVSpace.register(list_)
TupleVSpace.register(tuple_)
DictVSpace.register(dict_)

class NamedTupleVSpace(SequenceVSpace):
def _map(self, f, *args):
return self.seq_type(*map(f, self.shape, *args))
def _subval(self, xs, idx, x):
return self.seq_type(*subvals(xs, [(idx, x)]))
8 changes: 8 additions & 0 deletions autograd/numpy/numpy_boxes.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import
import numpy as np
from autograd.extend import Box, primitive
from autograd.builtins import SequenceBox
from . import numpy_wrapper as anp

Box.__array_priority__ = 90.0
Expand Down Expand Up @@ -64,3 +65,10 @@ def __hash__(self): return id(self)

# Flatten has no function, only a method.
setattr(ArrayBox, 'flatten', anp.__dict__['ravel'])

if np.__version__ >= '1.25':
SequenceBox.register(np.linalg.linalg.EigResult)
SequenceBox.register(np.linalg.linalg.EighResult)
SequenceBox.register(np.linalg.linalg.QRResult)
SequenceBox.register(np.linalg.linalg.SlogdetResult)
SequenceBox.register(np.linalg.linalg.SVDResult)
15 changes: 15 additions & 0 deletions autograd/numpy/numpy_vspaces.py
@@ -1,5 +1,6 @@
import numpy as np
from autograd.extend import VSpace
from autograd.builtins import NamedTupleVSpace

class ArrayVSpace(VSpace):
def __init__(self, value):
Expand Down Expand Up @@ -63,3 +64,17 @@ def _covector(self, x):

for type_ in [complex, np.complex64, np.complex128, np.complex256]:
ComplexArrayVSpace.register(type_)


if np.__version__ >= '1.25':
class EigResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EigResult
class EighResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.EighResult
class QRResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.QRResult
class SlogdetResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.SlogdetResult
class SVDResultVSpace(NamedTupleVSpace): seq_type = np.linalg.linalg.SVDResult

EigResultVSpace.register(np.linalg.linalg.EigResult)
EighResultVSpace.register(np.linalg.linalg.EighResult)
QRResultVSpace.register(np.linalg.linalg.QRResult)
SlogdetResultVSpace.register(np.linalg.linalg.SlogdetResult)
SVDResultVSpace.register(np.linalg.linalg.SVDResult)

0 comments on commit 06e6e2d

Please sign in to comment.