Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions pylops/signalprocessing/shift.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
__all__ = ["Shift"]

from typing import Tuple, Union
from typing import TYPE_CHECKING, Tuple, Union

import numpy as np
import numpy.typing as npt

from pylops.basicoperators import Diagonal
from pylops.signalprocessing import FFT
from pylops.utils._internal import _value_or_sized_to_array
from pylops.utils.backend import get_normalize_axis_index
from pylops.utils.typing import DTypeLike
from pylops.utils.typing import DTypeLike, NDArray

if TYPE_CHECKING:
from pylops.linearoperator import LinearOperator


def Shift(
dims: Tuple,
shift: Union[float, npt.ArrayLike],
shift: Union[float, NDArray],
axis: int = -1,
nfft: int = None,
sampling: float = 1.0,
real: bool = False,
engine: str = "numpy",
dtype: DTypeLike = "complex128",
name: str = "S",
**kwargs_fftw
):
**kwargs_fft,
) -> "LinearOperator":
r"""Shift operator

Apply fractional shift in the frequency domain along an ``axis``
Expand Down Expand Up @@ -58,9 +60,8 @@ def Shift(
.. versionadded:: 2.0.0

Name of operator (to be used by :func:`pylops.utils.describe.describe`)
**kwargs_fftw
Arbitrary keyword arguments
for :py:class:`pyfftw.FTTW`
**kwargs_fft
Arbitrary keyword arguments to be passed to the selected fft method

Attributes
----------
Expand Down Expand Up @@ -98,7 +99,7 @@ def Shift(
real=real,
engine=engine,
dtype=dtype,
**kwargs_fftw
**kwargs_fft,
)
if isinstance(dims, int):
dimsdiag = None
Expand Down
8 changes: 7 additions & 1 deletion pylops/waveeqprocessing/marchenko.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["Marchenko"]

import logging
from typing import Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from scipy.signal import filtfilt
Expand Down Expand Up @@ -246,6 +246,7 @@ def __init__(
prescaled: bool = False,
fftengine: str = "numpy",
dtype: DTypeLike = "float64",
kwargs_fft: Optional[Dict[str, Any]] = None,
) -> None:
# Save inputs into class
self.dt = dt
Expand All @@ -257,6 +258,7 @@ def __init__(
self.prescaled = prescaled
self.fftengine = fftengine
self.dtype = dtype
self.kwargs_fft = {} if kwargs_fft is None else kwargs_fft
self.explicit = False
self.ncp = get_array_module(R)

Expand Down Expand Up @@ -384,6 +386,7 @@ def apply_onepoint(
saveGt=self.saveRt,
prescaled=self.prescaled,
usematmul=usematmul,
**self.kwargs_fft,
)
R1op = MDC(
self.Rtwosided_fft,
Expand All @@ -397,6 +400,7 @@ def apply_onepoint(
saveGt=self.saveRt,
prescaled=self.prescaled,
usematmul=usematmul,
**self.kwargs_fft,
)
Rollop = Roll(
(self.nt2, self.ns),
Expand Down Expand Up @@ -592,6 +596,7 @@ def apply_multiplepoints(
fftengine=self.fftengine,
prescaled=self.prescaled,
usematmul=usematmul,
**self.kwargs_fft,
)
R1op = MDC(
self.Rtwosided_fft,
Expand All @@ -604,6 +609,7 @@ def apply_multiplepoints(
fftengine=self.fftengine,
prescaled=self.prescaled,
usematmul=usematmul,
**self.kwargs_fft,
)
Rollop = Roll(
(self.nt2, self.ns, nvs),
Expand Down
19 changes: 12 additions & 7 deletions pylops/waveeqprocessing/mdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,15 @@ def _MDC(
real=True,
ifftshift_before=twosided,
dtype=rdtype,
**args_FFT
**args_FFT,
)
F1op = _FFT(
dims=(nt, ns, nv),
axis=0,
real=True,
ifftshift_before=False,
dtype=rdtype,
**args_FFT1
**args_FFT1,
)

# create Identity operator to extract only relevant frequencies
Expand Down Expand Up @@ -140,6 +140,7 @@ def MDC(
usematmul: bool = False,
prescaled: bool = False,
name: str = "M",
**kwargs_fft,
) -> LinearOperator:
r"""Multi-dimensional convolution.

Expand Down Expand Up @@ -188,6 +189,10 @@ def MDC(
.. versionadded:: 2.0.0

Name of operator (to be used by :func:`pylops.utils.describe.describe`)
**kwargs_fft
.. versionadded:: 2.6.0

Arbitrary keyword arguments to be passed to the selected fft method

Raises
------
Expand Down Expand Up @@ -243,8 +248,8 @@ def MDC(
saveGt=saveGt,
conj=conj,
prescaled=prescaled,
args_FFT={"engine": fftengine},
args_FFT1={"engine": fftengine},
args_FFT={**{"engine": fftengine}, **kwargs_fft},
args_FFT1={**{"engine": fftengine}, **kwargs_fft},
args_Fredholm1={"usematmul": usematmul},
)
MOp.name = name
Expand All @@ -267,7 +272,7 @@ def MDD(
add_negative: bool = True,
smooth_precond: int = 0,
fftengine: str = "numpy",
**kwargs_solver
**kwargs_solver,
) -> Union[
Tuple[NDArray, NDArray],
Tuple[NDArray, NDArray, NDArray],
Expand Down Expand Up @@ -483,7 +488,7 @@ def MDD(
MDCop,
d.ravel(),
ncp.zeros(int(MDCop.shape[1]), dtype=MDCop.dtype),
**kwargs_solver
**kwargs_solver,
)[0]
minv = ncp.squeeze(minv.reshape(nt2, nr, nv))
minv = ncp.moveaxis(minv, 0, -1)
Expand All @@ -502,7 +507,7 @@ def MDD(
PSFop,
G.ravel(),
ncp.zeros(int(PSFop.shape[1]), dtype=PSFop.dtype),
**kwargs_solver
**kwargs_solver,
)[0]
psfinv = ncp.squeeze(psfinv.reshape(nt2, nr, nr))
psfinv = ncp.moveaxis(psfinv, 0, -1)
Expand Down
15 changes: 14 additions & 1 deletion pylops/waveeqprocessing/oneway.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def PhaseShift(
ky: Optional[NDArray] = None,
dtype: DTypeLike = "float64",
name: str = "P",
fftengine: str = "numpy",
**kwargs_fft,
) -> LinearOperator:
r"""Phase shift operator

Expand Down Expand Up @@ -110,6 +112,15 @@ def PhaseShift(
.. versionadded:: 2.0.0

Name of operator (to be used by :func:`pylops.utils.describe.describe`)
fftengine : :obj:`str`, optional
.. versionadded:: 2.6.0

Engine used for fft computation (``numpy``, ``scipy``, or ``fftw``). Choose
``numpy`` when working with CuPy arrays.
**kwargs_fft
.. versionadded:: 2.6.0

Arbitrary keyword arguments to be passed to the selected fft method

Returns
-------
Expand Down Expand Up @@ -170,7 +181,9 @@ def PhaseShift(
nfft=ky.size,
real=False,
ifftshift_before=True,
engine=fftengine,
dtype=dtypefft,
**kwargs_fft,
)
Pop = _PhaseShift(vel, dz, freq, kx, ky, dtypefft)
if ky is None:
Expand Down Expand Up @@ -204,7 +217,7 @@ def Deghosting(
solver: Callable = lsqr,
dottest: bool = False,
dtype: DTypeLike = "complex128",
**kwargs_solver
**kwargs_solver,
) -> Tuple[NDArray, NDArray]:
r"""Wavefield deghosting.

Expand Down
18 changes: 11 additions & 7 deletions pytests/test_marchenko.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

if int(os.environ.get("TEST_CUPY_PYLOPS", 0)):
import cupy as np
from cupy.testing import assert_array_almost_equal, assert_array_equal

backend = "cupy"
else:
import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_equal

backend = "numpy"
import numpy as npp
Expand Down Expand Up @@ -86,11 +84,16 @@
R1twosided_fft = npp.fft.rfft(R1twosided, 2 * nt - 1, axis=-1) / npp.sqrt(2 * nt - 1)
R1twosided_fft = R1twosided_fft[..., :nfmax]


par1 = {"niter": 10, "prescaled": False, "fftengine": "numpy"}
par2 = {"niter": 10, "prescaled": True, "fftengine": "numpy"}
par3 = {"niter": 10, "prescaled": False, "fftengine": "scipy"}
par4 = {"niter": 10, "prescaled": False, "fftengine": "fftw"}
# Test parameters
par1 = {"niter": 10, "prescaled": False, "fftengine": "numpy", "kwargs_fft": None}
par2 = {"niter": 10, "prescaled": True, "fftengine": "numpy", "kwargs_fft": None}
par3 = {
"niter": 10,
"prescaled": False,
"fftengine": "scipy",
"kwargs_fft": dict(workers=4),
}
par4 = {"niter": 10, "prescaled": False, "fftengine": "fftw", "kwargs_fft": None}


@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
Expand All @@ -111,6 +114,7 @@ def test_Marchenko_freq(par):
nsmooth=nsmooth,
prescaled=par["prescaled"],
fftengine=par["fftengine"] if backend == "numpy" else "numpy",
kwargs_fft=par["kwargs_fft"] if backend == "numpy" else None,
)

solver_dict = (
Expand Down
Loading