Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/aspire/operators/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,19 @@ class PowerFilter(Filter):
A Filter object that is composed of a regular `Filter` object, but evaluates it to a specified power.
"""

def __init__(self, filter, power=1):
def __init__(self, filter, power=1, epsilon=None):
"""
Initialize PowerFilter instance.

:param filter: A Filter instance.
:param power: Exponent to raise filter values.
:param epsilon: Threshold on filter values that get raised to a negative power.
`filter` values below this threshold will be set to zero during evaluation.
Default uses machine epsilon for filter.dtype.
"""
self._filter = filter
self._power = power
self._epsilon = epsilon
super().__init__(dim=filter.dim, radial=filter.radial)

def _evaluate(self, omega):
Expand All @@ -204,7 +214,9 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs):

# Place safeguard on values below machine epsilon for negative powers.
if self._power < 0:
eps = np.finfo(filter_vals.dtype).eps
eps = self._epsilon
if eps is None:
eps = np.finfo(filter_vals.dtype).eps
condition = abs(filter_vals) < eps
num_less_eps = np.count_nonzero(condition)
if num_less_eps > 0:
Expand Down
10 changes: 8 additions & 2 deletions src/aspire/source/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ def downsample(self, L):
self.L = L

@_as_copy
def whiten(self, noise_estimate=None):
def whiten(self, noise_estimate=None, epsilon=None):
"""
Modify the `ImageSource` in-place by appending a whitening filter to the generation pipeline.

Expand All @@ -810,6 +810,9 @@ def whiten(self, noise_estimate=None):
passed a `NoiseEstimator` the `filter` attribute will be
queried. Alternatively, the noise PSD may be passed
directly as a `Filter` object.
:param epsilon: Threshold used to determine which frequencies to whiten
and which to set to zero. By default all PSD values in the `noise_estimate`
less than eps(self.dtype) are zeroed out in the whitening filter.
:return: On return, the `ImageSource` object has been modified in place.
"""

Expand All @@ -827,8 +830,11 @@ def whiten(self, noise_estimate=None):
" instead of `NoiseEstimator` or `Filter`."
)

if epsilon is None:
epsilon = np.finfo(self.dtype).eps

logger.info("Whitening source object")
whiten_filter = PowerFilter(noise_filter, power=-0.5)
whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon)

logger.info("Transforming all CTF Filters into Multiplicative Filters")
self.unique_filters = [
Expand Down
12 changes: 6 additions & 6 deletions tests/test_FLEbasis2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def testMatchFBEvaluate(basis):
fb_images = fb_basis.evaluate(coefs)
fle_images = basis.evaluate(coefs)

assert np.allclose(fb_images._data, fle_images._data, atol=1e-4)
np.testing.assert_allclose(fb_images._data, fle_images._data, atol=1e-4)


@pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params)
Expand All @@ -159,8 +159,8 @@ def testMatchFBDenseEvaluate(basis):
fle_images = Image(fle_out.T.reshape(-1, basis.nres, basis.nres)).asnumpy()

# Matrix column reording in match_fb mode flips signs of some of the basis functions
assert np.allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3)
assert np.allclose(fb_images, fle_images, atol=1e-3)
np.testing.assert_allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3)
np.testing.assert_allclose(fb_images, fle_images, atol=1e-3)


@pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params)
Expand All @@ -177,7 +177,7 @@ def testMatchFBEvaluate_t(basis):
fb_coefs = fb_basis.evaluate_t(images)
fle_coefs = basis.evaluate_t(images)

assert np.allclose(fb_coefs, fle_coefs, atol=1e-4)
np.testing.assert_allclose(fb_coefs, fle_coefs, atol=1e-4)


@pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params)
Expand All @@ -197,7 +197,7 @@ def testMatchFBDenseEvaluate_t(basis):
fle_coefs = basis._create_dense_matrix().T @ vec.T

# Matrix column reording in match_fb mode flips signs of some of the basis coefficients
assert np.allclose(np.abs(fb_coefs), np.abs(fle_coefs), atol=1e-4)
np.testing.assert_allclose(np.abs(fb_coefs), np.abs(fle_coefs), atol=1e-4)


def testLowPass():
Expand Down Expand Up @@ -265,4 +265,4 @@ def testRadialConvolution():
convolution_fft_pad[L // 2 : L // 2 + L, L // 2 : L // 2 + L]
)

assert np.allclose(imgs_convolved_fle, imgs_convolved_slow, atol=1e-5)
np.testing.assert_allclose(imgs_convolved_fle, imgs_convolved_slow, atol=1e-5)
39 changes: 36 additions & 3 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,20 +332,36 @@ def testFilterSigns(self):
self.assertTrue(np.allclose(sign_filter.evaluate(self.omega), signs))


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_power_filter_safeguard(dtype, caplog):
DTYPES = [np.float32, np.float64]
EPS = [None, 0.01]


@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module")
def dtype(request):
return request.param


@pytest.fixture(params=EPS, ids=lambda x: f"epsilon={x}", scope="module")
def epsilon(request):
return request.param


def test_power_filter_safeguard(dtype, epsilon, caplog):
L = 25
arr = np.ones((L, L), dtype=dtype)

# Set a few values below machine epsilon.
num_eps = 3
eps = np.finfo(dtype).eps
eps = epsilon
if eps is None:
eps = np.finfo(dtype).eps
arr[L // 2, L // 2 : L // 2 + num_eps] = eps / 2

# For negative powers, values below machine eps will be set to zero.
filt = PowerFilter(
filter=ArrayFilter(arr),
power=-0.5,
epsilon=epsilon,
)

caplog.clear()
Expand All @@ -361,3 +377,20 @@ def test_power_filter_safeguard(dtype, caplog):
# Check caplog for warning.
msg = f"setting {num_eps} extremal filter value(s) to zero."
assert msg in caplog.text


def test_array_filter_dtype_passthrough(dtype):
"""
We upcast to use scipy's fast interpolator. We do not recast
on exit, so this is an expected fail for singles.
"""
if dtype == np.float32:
pytest.xfail(reason="ArrayFilter currently upcasts singles.")

L = 8
arr = np.ones((L, L), dtype=dtype)

filt = ArrayFilter(arr)
filt_vals = filt.evaluate_grid(L, dtype=dtype)

assert filt_vals.dtype == dtype
37 changes: 34 additions & 3 deletions tests/test_preprocess_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def testWhiten(dtype):
corr_coef = np.corrcoef(imgs_wt[:, L - 1, L - 1], imgs_wt[:, L - 2, L - 1])

# correlation matrix should be close to identity
assert np.allclose(np.eye(2), corr_coef, atol=1e-1)
np.testing.assert_allclose(np.eye(2), corr_coef, atol=1e-1)
# dtype of returned images should be the same
assert dtype == imgs_wt.dtype

Expand All @@ -123,7 +123,36 @@ def testWhiten2(dtype):
corr_coef = np.corrcoef(imgs_wt[:, L - 1, L - 1], imgs_wt[:, L - 2, L - 1])

# Correlation matrix should be close to identity
assert np.allclose(np.eye(2), corr_coef, atol=2e-1)
np.testing.assert_allclose(np.eye(2), corr_coef, atol=2e-1)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_whiten_safeguard(dtype):
"""Test that whitening safeguard works as expected."""
L = 25
epsilon = 0.02
sim = get_sim_object(L, dtype)
noise_estimator = AnisotropicNoiseEstimator(sim)
sim = sim.whiten(noise_estimator.filter, epsilon=epsilon)

# Get whitening_filter from generation pipeline.
whiten_filt = sim.generation_pipeline.xforms[0].filter.evaluate_grid(sim.L)

# Generate whitening_filter without safeguard directly from noise_estimator.
filt_vals = noise_estimator.filter.xfer_fn_array
whiten_filt_unsafe = filt_vals**-0.5

# Get indices where safeguard should be applied
# and assert that they are not empty.
ind = np.where(filt_vals < epsilon)
np.testing.assert_array_less(0, len(ind[0]))

# Check that whiten_filt and whiten_filt_unsafe agree up to safeguard indices.
disagree = np.where(whiten_filt != whiten_filt_unsafe)
np.testing.assert_array_equal(ind, disagree)

# Check that whiten_filt is zero at safeguard indices.
np.testing.assert_allclose(whiten_filt[ind], 0.0)


@pytest.mark.parametrize("L, dtype", params)
Expand All @@ -138,7 +167,9 @@ def testInvertContrast(L, dtype):
imgs2_rc = sim2.images[:num_images]

# all images should be the same after inverting contrast
assert np.allclose(imgs1_rc.asnumpy(), imgs2_rc.asnumpy())
np.testing.assert_allclose(
imgs1_rc.asnumpy(), imgs2_rc.asnumpy(), rtol=1e-05, atol=1e-08
)
# dtype of returned images should be the same
assert dtype == imgs1_rc.dtype
assert dtype == imgs2_rc.dtype