Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: use numpy-config and fix support for numpy 2.0 #723

Merged
merged 2 commits into from Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion pywt/_cwt.py
Expand Up @@ -7,6 +7,7 @@
_check_dtype,
)
from ._functions import integrate_wavelet, scale2frequency
from ._utils import AxisError

__all__ = ["cwt"]

Expand Down Expand Up @@ -124,7 +125,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
raise ValueError("`scales` must only include positive values")

if not np.isscalar(axis):
raise np.AxisError("axis must be a scalar.")
raise AxisError("axis must be a scalar.")

dt_out = dt_cplx if wavelet.complex_cwt else dt
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
Expand Down
6 changes: 3 additions & 3 deletions pywt/_dwt.py
Expand Up @@ -9,7 +9,7 @@
from ._extensions._dwt import dwt_max_level as _dwt_max_level
from ._extensions._dwt import upcoef as _upcoef
from ._extensions._pywt import Modes, Wavelet, _check_dtype, wavelist
from ._utils import _as_wavelet
from ._utils import AxisError, _as_wavelet

__all__ = ["dwt", "idwt", "downcoef", "upcoef", "dwt_max_level",
"dwt_coeff_len", "pad"]
Expand Down Expand Up @@ -176,7 +176,7 @@ def dwt(data, wavelet, mode='symmetric', axis=-1):
if axis < 0:
axis = axis + data.ndim
if not 0 <= axis < data.ndim:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

if data.ndim == 1:
cA, cD = dwt_single(data, wavelet, mode)
Expand Down Expand Up @@ -282,7 +282,7 @@ def idwt(cA, cD, wavelet, mode='symmetric', axis=-1):
if axis < 0:
axis = axis + ndim
if not 0 <= axis < ndim:
raise np.AxisError("Axis greater than coefficient dimensions")
raise AxisError("Axis greater than coefficient dimensions")

if ndim == 1:
rec = idwt_single(cA, cD, wavelet, mode)
Expand Down
63 changes: 31 additions & 32 deletions pywt/_extensions/meson.build
Expand Up @@ -3,45 +3,44 @@ if m_dep.found()
add_project_link_arguments('-lm', language : 'c')
endif

# For cross-compilation it is often not possible to run the Python interpreter
# in order to retrieve numpy's include directory. It can be specified in the
# cross file instead:
# [properties]
# numpy-include-dir = /abspath/to/host-pythons/site-packages/numpy/core/include
#
# This uses the path as is, and avoids running the interpreter.
incdir_numpy = meson.get_external_property('numpy-include-dir', 'not-given')
if incdir_numpy == 'not-given'
incdir_numpy = run_command(py,
[
'-c',
'''import os
# Don't use the deprecated NumPy C API. Define this to a fixed version instead of
# NPY_API_VERSION in order not to break compilation for released PyWavelets
# versions when NumPy introduces a new deprecation.
numpy_nodepr_api = ['-DNPY_NO_DEPRECATED_API=NPY_1_22_API_VERSION']

# Uses the `numpy-config` executable (or a user's numpy.pc pkg-config file),
# will work for numpy>=2.0.0b1 and meson>=1.4.0
_numpy_dep = dependency('numpy', required: false)
if _numpy_dep.found()
np_dep = declare_dependency(dependencies: _numpy_dep, compile_args: numpy_nodepr_api)
else
# For cross-compilation it is often not possible to run the Python interpreter
# in order to retrieve numpy's include directory. It can be specified in the
# cross file instead:
# [properties]
# numpy-include-dir = /abspath/to/host-pythons/site-packages/numpy/core/include
#
# This uses the path as is, and avoids running the interpreter.
incdir_numpy = meson.get_external_property('numpy-include-dir', 'not-given')
if incdir_numpy == 'not-given'
incdir_numpy = run_command(py,
[
'-c',
'''import os
import numpy as np
try:
incdir = os.path.relpath(np.get_include())
except Exception:
incdir = np.get_include()
print(incdir)
'''
],
check: true
).stdout().strip()

# We do need an absolute path to feed to `cc.find_library` below
_incdir_numpy_abs = run_command(py,
['-c', 'import os; os.chdir(".."); import numpy; print(numpy.get_include())'],
check: true
).stdout().strip()
else
_incdir_numpy_abs = incdir_numpy
'''
],
check: true
).stdout().strip()
endif
inc_np = include_directories(incdir_numpy)
np_dep = declare_dependency(include_directories: inc_np, compile_args: numpy_nodepr_api)
endif
inc_np = include_directories(incdir_numpy)

# Don't use the deprecated NumPy C API. Define this to a fixed version instead of
# NPY_API_VERSION in order not to break compilation for released PyWavelets
# versions when NumPy introduces a new deprecation.
numpy_nodepr_api = ['-DNPY_NO_DEPRECATED_API=NPY_1_22_API_VERSION']
np_dep = declare_dependency(include_directories: inc_np, compile_args: numpy_nodepr_api)

config_pxi = configure_file(
input: 'config.pxi.in',
Expand Down
4 changes: 2 additions & 2 deletions pywt/_multidim.py
Expand Up @@ -14,7 +14,7 @@

from ._c99_config import _have_c99_complex
from ._extensions._dwt import dwt_axis, idwt_axis
from ._utils import _modes_per_axis, _wavelets_per_axis
from ._utils import AxisError, _modes_per_axis, _wavelets_per_axis

__all__ = ['dwt2', 'idwt2', 'dwtn', 'idwtn']

Expand Down Expand Up @@ -288,7 +288,7 @@ def idwtn(coeffs, wavelet, mode='symmetric', axes=None):
for key_length, (axis, wav, mode) in reversed(
list(enumerate(zip(axes, wavelets, modes)))):
if axis < 0 or axis >= ndim:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

new_coeffs = {}
new_keys = [''.join(coef) for coef in product('ad', repeat=key_length)]
Expand Down
14 changes: 7 additions & 7 deletions pywt/_multilevel.py
Expand Up @@ -20,7 +20,7 @@
from ._extensions._dwt import dwt_max_level
from ._extensions._pywt import Modes, Wavelet
from ._multidim import _fix_coeffs, dwt2, dwtn, idwt2, idwtn
from ._utils import _as_wavelet, _modes_per_axis, _wavelets_per_axis
from ._utils import AxisError, _as_wavelet, _modes_per_axis, _wavelets_per_axis

__all__ = ['wavedec', 'waverec', 'wavedec2', 'waverec2', 'wavedecn',
'waverecn', 'coeffs_to_array', 'array_to_coeffs', 'ravel_coeffs',
Expand Down Expand Up @@ -93,7 +93,7 @@ def wavedec(data, wavelet, mode='symmetric', level=None, axis=-1):
try:
axes_shape = data.shape[axis]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
level = _check_level(axes_shape, wavelet.dec_len, level)

coeffs_list = []
Expand Down Expand Up @@ -170,7 +170,7 @@ def waverec(coeffs, wavelet, mode='symmetric', axis=-1):
elif a.shape[axis] != d.shape[axis]:
raise ValueError("coefficient shape mismatch")
except IndexError:
raise np.AxisError("Axis greater than coefficient dimensions")
raise AxisError("Axis greater than coefficient dimensions")
a = idwt(a, d, wavelet, mode, axis)

return a
Expand Down Expand Up @@ -233,7 +233,7 @@ def wavedec2(data, wavelet, mode='symmetric', level=None, axes=(-2, -1)):
try:
axes_sizes = [data.shape[ax] for ax in axes]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

wavelets = _wavelets_per_axis(wavelet, axes)
dec_lengths = [w.dec_len for w in wavelets]
Expand Down Expand Up @@ -352,7 +352,7 @@ def _prep_axes_wavedecn(shape, axes):
try:
axes_shapes = [shape[ax] for ax in axes]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
ndim_transform = len(axes)
return axes, axes_shapes, ndim_transform

Expand Down Expand Up @@ -1194,11 +1194,11 @@ def unravel_coeffs(arr, coeff_slices, coeff_shapes, output_format='wavedecn'):
def _check_fswavedecn_axes(data, axes):
"""Axes checks common to fswavedecn, fswaverecn."""
if len(axes) != len(set(axes)):
raise np.AxisError("The axes passed to fswavedecn must be unique.")
raise AxisError("The axes passed to fswavedecn must be unique.")
try:
[data.shape[ax] for ax in axes]
except IndexError:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")


class FswavedecnResult:
Expand Down
8 changes: 4 additions & 4 deletions pywt/_swt.py
Expand Up @@ -10,7 +10,7 @@
from ._extensions._swt import swt_axis as _swt_axis
from ._extensions._swt import swt_max_level
from ._multidim import idwt2, idwtn
from ._utils import _as_wavelet, _wavelets_per_axis
from ._utils import AxisError, _as_wavelet, _wavelets_per_axis

__all__ = ["swt", "swt_max_level", 'iswt', 'swt2', 'iswt2', 'swtn', 'iswtn']

Expand Down Expand Up @@ -141,7 +141,7 @@ def swt(data, wavelet, level=None, start_level=0, axis=-1,
if axis < 0:
axis = axis + data.ndim
if not 0 <= axis < data.ndim:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")

if level is None:
level = swt_max_level(data.shape[axis])
Expand Down Expand Up @@ -196,7 +196,7 @@ def iswt(coeffs, wavelet, norm=False, axis=-1):
coeffs_nd = [{'a': a, 'd': d} for a, d in coeffs]
return iswtn(coeffs_nd, wavelet, axes=(axis,), norm=norm)
elif axis != 0 and axis != -1:
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
if not _have_c99_complex and np.iscomplexobj(cA):
if trim_approx:
coeffs_real = [c.real for c in coeffs]
Expand Down Expand Up @@ -639,7 +639,7 @@ def swtn(data, wavelet, level, start_level=0, axes=None, trim_approx=False,
axes = range(data.ndim)
axes = [a + data.ndim if a < 0 else a for a in axes]
if any(a < 0 or a >= data.ndim for a in axes):
raise np.AxisError("Axis greater than data dimensions")
raise AxisError("Axis greater than data dimensions")
if len(axes) != len(set(axes)):
raise ValueError("The axes passed to swtn must be unique.")
num_axes = len(axes)
Expand Down
6 changes: 6 additions & 0 deletions pywt/_utils.py
Expand Up @@ -13,6 +13,12 @@
Wavelet,
)

AxisError: type[Exception]
if np.lib.NumpyVersion(np.__version__) >= '1.25.0':
from numpy.exceptions import AxisError
else:
from numpy import AxisError


def _as_wavelet(wavelet):
"""Convert wavelet name to a Wavelet object."""
Expand Down
7 changes: 4 additions & 3 deletions pywt/tests/test_mra.py
Expand Up @@ -6,6 +6,7 @@

import pywt
from pywt import data
from pywt._utils import AxisError

# tolerances used in accuracy comparisons
tol_single = 1e-6
Expand Down Expand Up @@ -84,7 +85,7 @@ def test_mra_axis(transform, ndim, axis, dtype):

# out of range axis
if axis < -x.ndim or axis >= x.ndim:
with pytest.raises(np.AxisError):
with pytest.raises(AxisError):
pywt.mra(x, 'db1', transform=transform, axis=axis)
return

Expand Down Expand Up @@ -160,7 +161,7 @@ def test_mra2_axes(transform, axes, ndim, dtype):

# out of range axis
if any(axis < -x.ndim or axis >= x.ndim for axis in axes):
with pytest.raises(np.AxisError):
with pytest.raises(AxisError):
pywt.mra2(x, 'db1', transform=transform, axes=axes)
return

Expand Down Expand Up @@ -246,7 +247,7 @@ def test_mran_axes(axes, transform):

# out of range axis
if any(axis < -x.ndim or axis >= x.ndim for axis in axes):
with pytest.raises(np.AxisError):
with pytest.raises(AxisError):
pywt.mran(x, 'db1', transform='dwtn', axes=axes)
return

Expand Down