Skip to content

Commit

Permalink
complex array exception handling in quantile filter
Browse files Browse the repository at this point in the history
  • Loading branch information
aromanielloNTIA committed May 8, 2023
1 parent d2e2db3 commit 0b07535
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 6 additions & 3 deletions scos_actions/signal_processing/power_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,16 @@ def filter_quantiles(x: np.ndarray, q_lo: float, q_hi: float) -> np.ndarray:
"""
Replace values outside specified quantiles with NaN.
:param x: Input N-dimensional data array.
:param x: Input N-dimensional data array. Complex valued arrays
are not supported.
:param q_lo: Lower quantile, 0 <= q_lo < q_hi.
:param q_hi: Upper quantile, q_lo < q_hi <= 1.
:return: The input data array, with values outside the
specified quantile replaced with NaN (numpy.nan).
:raises ValueError: If either ``q_lo`` or ``q_hi`` is not
within its valid range (listed above).
:raises TypeError: If ``x`` is not a NumPy array with a size
greater than 1.
:raises TypeError: If ``x`` is not a real-valued NumPy array
with a size greater than 1.
"""
if q_lo < 0 or q_lo >= q_hi:
raise ValueError("q_lo must satistfy 0 <= q_lo < q_hi")
Expand All @@ -210,6 +211,8 @@ def filter_quantiles(x: np.ndarray, q_lo: float, q_hi: float) -> np.ndarray:
raise TypeError("Input data must be a NumPy array")
if x.size <= 1:
raise TypeError("Input data must have length greater than 1")
if np.iscomplexobj(x):
raise TypeError("Input data must be real, not complex")
lo, hi = np.quantile(x, [q_lo, q_hi]) # Works on flattened array
if x.size < NUMEXPR_THRESHOLD:
x = np.where((x <= lo) | (x > hi), np.nan, x)
Expand Down
4 changes: 4 additions & 0 deletions scos_actions/signal_processing/tests/test_power_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,7 @@ def test_filter_quantiles():
for q in bad_hi_q:
with pytest.raises(ValueError):
_ = pa.filter_quantiles(test_data, lo_q, q)
# Complex input should raise TypeError
test_complex_data = test_data + 1j * test_data
with pytest.raises(TypeError):
_ = pa.filter_quantiles(test_complex_data, lo_q, hi_q)

0 comments on commit 0b07535

Please sign in to comment.