diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 9b910a8fe0..bb7491c780 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -184,9 +184,19 @@ class PowerFilter(Filter): A Filter object that is composed of a regular `Filter` object, but evaluates it to a specified power. """ - def __init__(self, filter, power=1): + def __init__(self, filter, power=1, epsilon=None): + """ + Initialize PowerFilter instance. + + :param filter: A Filter instance. + :param power: Exponent to raise filter values. + :param epsilon: Threshold on filter values that get raised to a negative power. + `filter` values below this threshold will be set to zero during evaluation. + Default uses machine epsilon for filter.dtype. + """ self._filter = filter self._power = power + self._epsilon = epsilon super().__init__(dim=filter.dim, radial=filter.radial) def _evaluate(self, omega): @@ -204,7 +214,9 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): # Place safeguard on values below machine epsilon for negative powers. if self._power < 0: - eps = np.finfo(filter_vals.dtype).eps + eps = self._epsilon + if eps is None: + eps = np.finfo(filter_vals.dtype).eps condition = abs(filter_vals) < eps num_less_eps = np.count_nonzero(condition) if num_less_eps > 0: diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index 473585acb9..fa5be1f7f7 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -798,7 +798,7 @@ def downsample(self, L): self.L = L @_as_copy - def whiten(self, noise_estimate=None): + def whiten(self, noise_estimate=None, epsilon=None): """ Modify the `ImageSource` in-place by appending a whitening filter to the generation pipeline. @@ -810,6 +810,9 @@ def whiten(self, noise_estimate=None): passed a `NoiseEstimator` the `filter` attribute will be queried. Alternatively, the noise PSD may be passed directly as a `Filter` object. + :param epsilon: Threshold used to determine which frequencies to whiten + and which to set to zero. By default all PSD values in the `noise_estimate` + less than eps(self.dtype) are zeroed out in the whitening filter. :return: On return, the `ImageSource` object has been modified in place. """ @@ -827,8 +830,11 @@ def whiten(self, noise_estimate=None): " instead of `NoiseEstimator` or `Filter`." ) + if epsilon is None: + epsilon = np.finfo(self.dtype).eps + logger.info("Whitening source object") - whiten_filter = PowerFilter(noise_filter, power=-0.5) + whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon) logger.info("Transforming all CTF Filters into Multiplicative Filters") self.unique_filters = [ diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index 7d6b3f5c47..ffb1f8f7d1 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -142,7 +142,7 @@ def testMatchFBEvaluate(basis): fb_images = fb_basis.evaluate(coefs) fle_images = basis.evaluate(coefs) - assert np.allclose(fb_images._data, fle_images._data, atol=1e-4) + np.testing.assert_allclose(fb_images._data, fle_images._data, atol=1e-4) @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) @@ -159,8 +159,8 @@ def testMatchFBDenseEvaluate(basis): fle_images = Image(fle_out.T.reshape(-1, basis.nres, basis.nres)).asnumpy() # Matrix column reording in match_fb mode flips signs of some of the basis functions - assert np.allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3) - assert np.allclose(fb_images, fle_images, atol=1e-3) + np.testing.assert_allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3) + np.testing.assert_allclose(fb_images, fle_images, atol=1e-3) @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) @@ -177,7 +177,7 @@ def testMatchFBEvaluate_t(basis): fb_coefs = fb_basis.evaluate_t(images) fle_coefs = basis.evaluate_t(images) - assert np.allclose(fb_coefs, fle_coefs, atol=1e-4) + np.testing.assert_allclose(fb_coefs, fle_coefs, atol=1e-4) @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) @@ -197,7 +197,7 @@ def testMatchFBDenseEvaluate_t(basis): fle_coefs = basis._create_dense_matrix().T @ vec.T # Matrix column reording in match_fb mode flips signs of some of the basis coefficients - assert np.allclose(np.abs(fb_coefs), np.abs(fle_coefs), atol=1e-4) + np.testing.assert_allclose(np.abs(fb_coefs), np.abs(fle_coefs), atol=1e-4) def testLowPass(): @@ -265,4 +265,4 @@ def testRadialConvolution(): convolution_fft_pad[L // 2 : L // 2 + L, L // 2 : L // 2 + L] ) - assert np.allclose(imgs_convolved_fle, imgs_convolved_slow, atol=1e-5) + np.testing.assert_allclose(imgs_convolved_fle, imgs_convolved_slow, atol=1e-5) diff --git a/tests/test_filters.py b/tests/test_filters.py index 35d7955a9e..911e3b347b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -332,20 +332,36 @@ def testFilterSigns(self): self.assertTrue(np.allclose(sign_filter.evaluate(self.omega), signs)) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_power_filter_safeguard(dtype, caplog): +DTYPES = [np.float32, np.float64] +EPS = [None, 0.01] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=EPS, ids=lambda x: f"epsilon={x}", scope="module") +def epsilon(request): + return request.param + + +def test_power_filter_safeguard(dtype, epsilon, caplog): L = 25 arr = np.ones((L, L), dtype=dtype) # Set a few values below machine epsilon. num_eps = 3 - eps = np.finfo(dtype).eps + eps = epsilon + if eps is None: + eps = np.finfo(dtype).eps arr[L // 2, L // 2 : L // 2 + num_eps] = eps / 2 # For negative powers, values below machine eps will be set to zero. filt = PowerFilter( filter=ArrayFilter(arr), power=-0.5, + epsilon=epsilon, ) caplog.clear() @@ -361,3 +377,20 @@ def test_power_filter_safeguard(dtype, caplog): # Check caplog for warning. msg = f"setting {num_eps} extremal filter value(s) to zero." assert msg in caplog.text + + +def test_array_filter_dtype_passthrough(dtype): + """ + We upcast to use scipy's fast interpolator. We do not recast + on exit, so this is an expected fail for singles. + """ + if dtype == np.float32: + pytest.xfail(reason="ArrayFilter currently upcasts singles.") + + L = 8 + arr = np.ones((L, L), dtype=dtype) + + filt = ArrayFilter(arr) + filt_vals = filt.evaluate_grid(L, dtype=dtype) + + assert filt_vals.dtype == dtype diff --git a/tests/test_preprocess_pipeline.py b/tests/test_preprocess_pipeline.py index fb7d2427ec..d7416095b0 100644 --- a/tests/test_preprocess_pipeline.py +++ b/tests/test_preprocess_pipeline.py @@ -101,7 +101,7 @@ def testWhiten(dtype): corr_coef = np.corrcoef(imgs_wt[:, L - 1, L - 1], imgs_wt[:, L - 2, L - 1]) # correlation matrix should be close to identity - assert np.allclose(np.eye(2), corr_coef, atol=1e-1) + np.testing.assert_allclose(np.eye(2), corr_coef, atol=1e-1) # dtype of returned images should be the same assert dtype == imgs_wt.dtype @@ -123,7 +123,36 @@ def testWhiten2(dtype): corr_coef = np.corrcoef(imgs_wt[:, L - 1, L - 1], imgs_wt[:, L - 2, L - 1]) # Correlation matrix should be close to identity - assert np.allclose(np.eye(2), corr_coef, atol=2e-1) + np.testing.assert_allclose(np.eye(2), corr_coef, atol=2e-1) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_whiten_safeguard(dtype): + """Test that whitening safeguard works as expected.""" + L = 25 + epsilon = 0.02 + sim = get_sim_object(L, dtype) + noise_estimator = AnisotropicNoiseEstimator(sim) + sim = sim.whiten(noise_estimator.filter, epsilon=epsilon) + + # Get whitening_filter from generation pipeline. + whiten_filt = sim.generation_pipeline.xforms[0].filter.evaluate_grid(sim.L) + + # Generate whitening_filter without safeguard directly from noise_estimator. + filt_vals = noise_estimator.filter.xfer_fn_array + whiten_filt_unsafe = filt_vals**-0.5 + + # Get indices where safeguard should be applied + # and assert that they are not empty. + ind = np.where(filt_vals < epsilon) + np.testing.assert_array_less(0, len(ind[0])) + + # Check that whiten_filt and whiten_filt_unsafe agree up to safeguard indices. + disagree = np.where(whiten_filt != whiten_filt_unsafe) + np.testing.assert_array_equal(ind, disagree) + + # Check that whiten_filt is zero at safeguard indices. + np.testing.assert_allclose(whiten_filt[ind], 0.0) @pytest.mark.parametrize("L, dtype", params) @@ -138,7 +167,9 @@ def testInvertContrast(L, dtype): imgs2_rc = sim2.images[:num_images] # all images should be the same after inverting contrast - assert np.allclose(imgs1_rc.asnumpy(), imgs2_rc.asnumpy()) + np.testing.assert_allclose( + imgs1_rc.asnumpy(), imgs2_rc.asnumpy(), rtol=1e-05, atol=1e-08 + ) # dtype of returned images should be the same assert dtype == imgs1_rc.dtype assert dtype == imgs2_rc.dtype