Skip to content

Commit

Permalink
ENH: unifying error type for gh scipy numpy#1339
Browse files Browse the repository at this point in the history
  • Loading branch information
WillTirone committed Jan 24, 2023
1 parent 9d0f6e5 commit 5c7c696
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion numpy/lib/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def p(x, y):
def test_poly1d_nan_roots(self):
# Ticket #396
p = np.poly1d([np.nan, np.nan, 1], r=False)
assert_raises(ValueError, getattr, p, "r")
assert_raises(np.linalg.LinAlgError, getattr, p, "r")

def test_mem_polymul(self):
# Ticket #448
Expand Down
14 changes: 9 additions & 5 deletions numpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from numpy.core.overrides import set_module
from numpy.core import overrides
from numpy.lib.twodim_base import triu, eye
from numpy.lib.function_base import asarray_chkfinite
from numpy.linalg import _umath_linalg


Expand All @@ -43,11 +42,11 @@


@set_module('numpy.linalg')
class LinAlgError(Exception):
class LinAlgError(ValueError):
"""
Generic Python-exception-derived object raised by linalg functions.
General purpose exception class, derived from Python's exception.Exception
General purpose exception class, derived from Python's ValueError
class, programmatically raised in linalg functions when a Linear
Algebra-related condition would prevent further correct execution of the
function.
Expand Down Expand Up @@ -204,6 +203,11 @@ def _assert_stacked_square(*arrays):
if m != n:
raise LinAlgError('Last 2 dimensions of the array must be square')

def _assert_finite(*arrays):
for a in arrays:
if not isfinite(a).all():
raise LinAlgError("Array must not contain infs or NaNs")

def _is_empty_2d(arr):
# check size first for efficiency
return arr.size == 0 and product(arr.shape[-2:]) == 0
Expand Down Expand Up @@ -1050,7 +1054,7 @@ def eigvals(a):
a, wrap = _makearray(a)
_assert_stacked_2d(a)
_assert_stacked_square(a)
asarray_chkfinite(a)
_assert_finite(a)
t, result_t = _commonType(a)

extobj = get_linalg_error_extobj(
Expand Down Expand Up @@ -1305,7 +1309,7 @@ def eig(a):
a, wrap = _makearray(a)
_assert_stacked_2d(a)
_assert_stacked_square(a)
asarray_chkfinite(a)
_assert_finite(a)
t, result_t = _commonType(a)

extobj = get_linalg_error_extobj(
Expand Down

0 comments on commit 5c7c696

Please sign in to comment.