diff --git a/src/aspire/covariance/covar.py b/src/aspire/covariance/covar.py index 3e9a64f545..21da816cee 100644 --- a/src/aspire/covariance/covar.py +++ b/src/aspire/covariance/covar.py @@ -57,7 +57,7 @@ def compute_kernel(self): weights[:, 0, :] = 0 # TODO: This is where this differs from MeanEstimator - pts_rot = np.moveaxis(pts_rot[::-1], 1, 0).reshape(-1, 3, L**2) + pts_rot = np.moveaxis(pts_rot, 1, 0).reshape(-1, 3, L**2) weights = weights.T.reshape((-1, L**2)) batch_n = weights.shape[0] @@ -67,7 +67,7 @@ def compute_kernel(self): factors[j] = anufft(weights[j], pts_rot[j], (_2L, _2L, _2L), real=True) factors = Volume(factors).to_vec() - kernel += vecmat_to_volmat(factors.T @ factors) / (n * L**8) + kernel += (factors.T @ factors).reshape((_2L,) * 6) / (n * L**8) # Ensure symmetric kernel kernel[0, :, :, :, :, :] = 0 @@ -90,6 +90,8 @@ def estimate(self, mean_vol, noise_variance, tol=1e-5, regularizer=0): b_coef = self.src_backward(mean_vol, noise_variance) est_coef = self.conj_grad(b_coef, tol=tol, regularizer=regularizer) covar_est = self.basis.mat_evaluate(est_coef) + # Note, notice these cancel out, but requires a lot of changes elsewhere in this file, + # basically totally removing all the `utils/matrix` hacks ... todo. 
covar_est = vecmat_to_volmat(make_symmat(volmat_to_vecmat(covar_est))) return covar_est @@ -180,7 +182,9 @@ def src_backward(self, mean_vol, noise_variance, shrink_method=None): im_centered_b[j] = self.src.im_backward(im_centered[j], i + j) im_centered_b = Volume(im_centered_b).to_vec() - covar_b += vecmat_to_volmat(im_centered_b.T @ im_centered_b) / self.src.n + covar_b += (im_centered_b.T @ im_centered_b).reshape( + (self.src.L,) * 6 + ) / self.src.n covar_b_coef = self.basis.mat_evaluate_t(covar_b) return self._shrink(covar_b_coef, noise_variance, shrink_method) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 9d2911ad7a..ff2353d333 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -537,7 +537,7 @@ def backproject(self, rot_matrices, symmetry_group=None): ) pts_rot = pts_rot.reshape((3, -1)) - vol += anufft(im_f, pts_rot[::-1], (L, L, L), real=True) + vol += anufft(im_f, pts_rot, (L, L, L), real=True) vol /= L diff --git a/src/aspire/reconstruction/estimator.py b/src/aspire/reconstruction/estimator.py index 679b3e1234..9d62a0b765 100644 --- a/src/aspire/reconstruction/estimator.py +++ b/src/aspire/reconstruction/estimator.py @@ -56,6 +56,13 @@ def __init__( self.basis = basis self.dtype = self.src.dtype self.batch_size = batch_size + if not preconditioner or preconditioner.lower() == "none": + # Resolve None and string nones to None + preconditioner = None + elif preconditioner not in ["circulant"]: + raise ValueError( + f"Supplied preconditioner {preconditioner} is not supported." 
+            )
         self.preconditioner = preconditioner
         self.boost = boost
@@ -128,12 +135,12 @@ def __getattr__(self, name):
     def compute_kernel(self):
         raise NotImplementedError("Subclasses must implement the compute_kernel method")
 
-    def estimate(self, b_coef=None, tol=1e-5, regularizer=0):
+    def estimate(self, b_coef=None, x0=None, tol=1e-5, regularizer=0):
         """Return an estimate as a Volume instance."""
         if b_coef is None:
             b_coef = self.src_backward()
-        est_coef = self.conj_grad(b_coef, tol=tol, regularizer=regularizer)
-        est = Coef(self.basis, est_coef).evaluate().T
+        est_coef = self.conj_grad(b_coef, x0=x0, tol=tol, regularizer=regularizer)
+        est = Coef(self.basis, est_coef).evaluate()
         return est
 
diff --git a/src/aspire/reconstruction/mean.py b/src/aspire/reconstruction/mean.py
index f7f5bf1a2c..d298ae3272 100644
--- a/src/aspire/reconstruction/mean.py
+++ b/src/aspire/reconstruction/mean.py
@@ -64,7 +64,9 @@ def __getattr__(self, name):
                 1.0 / self.kernel.circularize()
             )
         else:
-            if self.preconditioner.lower() not in (None, "none"):
+            if self.preconditioner and (
+                self.preconditioner.lower() != "none"
+            ):
                 logger.warning(
                     f"Preconditioner {self.preconditioner} is not implemented, resetting to default of None."
) @@ -126,7 +128,7 @@ def _compute_kernel(self): batch_kernel += ( 1 / (self.r * self.src.L**4) - * anufft(weights, pts_rot[::-1], (_2L, _2L, _2L), real=True) + * anufft(weights, pts_rot, (_2L, _2L, _2L), real=True) ) kernel[k, j] += batch_kernel @@ -189,7 +191,7 @@ def src_backward(self): return res - def conj_grad(self, b_coef, tol=1e-5, regularizer=0): + def conj_grad(self, b_coef, x0=None, tol=1e-5, regularizer=0): count = b_coef.shape[-1] # b_coef should be (r, basis.count) kernel = self.kernel @@ -240,12 +242,13 @@ def cb(xk): # Construct checkpoint path path = f"{self.checkpoint_prefix}_iter{self.i:04d}.npy" # Write out the current solution - np.save(path, xk.reshape(self.r, self.basis.count)) + np.save(path, xk) logger.info(f"Checkpoint saved to `{path}`") x, info = cg( operator, b_coef.flatten(), + x0=x0, M=M, callback=cb, tol=tol, diff --git a/tests/test_mean_estimator.py b/tests/test_mean_estimator.py index 4bbd64e6f3..2650839272 100644 --- a/tests/test_mean_estimator.py +++ b/tests/test_mean_estimator.py @@ -1,733 +1,198 @@ import os.path import tempfile -from unittest import TestCase import numpy as np +import pytest from pytest import raises -from aspire.basis import FBBasis3D +from aspire.basis import Coef, FBBasis3D +from aspire.image import Image from aspire.operators import RadialCTFFilter from aspire.reconstruction import MeanEstimator -from aspire.source.simulation import _LegacySimulation -from aspire.volume import LegacyVolume +from aspire.source.simulation import Simulation +from aspire.utils import grid_3d +from aspire.volume import Volume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") +# Params -class MeanEstimatorTestCase(TestCase): - def setUp(self): - self.dtype = np.float32 - self.resolution = 8 - self.vols = LegacyVolume(L=self.resolution, dtype=self.dtype).generate() - self.sim = _LegacySimulation( - n=1024, - vols=self.vols, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - 
], - dtype=self.dtype, - ) - self.basis = FBBasis3D((self.resolution,) * 3, dtype=self.dtype) +SEED = 1616 - self.estimator = MeanEstimator( - self.sim, basis=self.basis, preconditioner="none" - ) +DTYPE = [np.float32, np.float64] +L = [ + 8, + 9, +] - self.estimator_with_preconditioner = MeanEstimator( - self.sim, basis=self.basis, preconditioner="circulant" - ) +PRECONDITIONERS = [ + None, + "circulant", + pytest.param("none", marks=pytest.mark.expensive), + pytest.param("", marks=pytest.mark.expensive), +] - def tearDown(self): - pass - - def testEstimateResolutionError(self): - """ - Test mismatched resolutions yields a relevant error message. - """ - - with raises(ValueError, match=r".*resolution.*"): - # This basis is intentionally the wrong resolution. - incorrect_basis = FBBasis3D((2 * self.resolution,) * 3, dtype=self.dtype) - - _ = MeanEstimator(self.sim, basis=incorrect_basis, preconditioner="none") - - def testEstimate(self): - estimate = self.estimator.estimate() - self.assertTrue( - np.allclose( - estimate.asnumpy()[0][:, :, 4], - [ - [ - +0.00000000, - +0.00000000, - +0.00000000, - +0.00000000, - -0.00000000, - +0.00000000, - +0.00000000, - +0.00000000, - ], - [ - +0.00000000, - +0.00000000, - +0.02446793, - +0.05363505, - +0.21988572, - +0.19513786, - +0.01174418, - +0.00000000, - ], - [ - +0.00000000, - -0.06168774, - +0.13178457, - +0.36011154, - +0.88632372, - +0.92307694, - +0.45524491, - +0.15142541, - ], - [ - +0.00000000, - -0.09108749, - +0.19564009, - +0.78325885, - +2.34527692, - +2.44817345, - +1.41268619, - +0.53634876, - ], - [ - +0.00000000, - +0.07150180, - +0.38347393, - +1.70868980, - +3.78134981, - +3.03582139, - +1.49942724, - +0.52104809, - ], - [ - +0.00000000, - +0.00736866, - +0.19239950, - +1.71596036, - +3.59823119, - +2.64081679, - +1.08514933, - +0.24995637, - ], - [ - +0.00000000, - +0.11075829, - +0.43197553, - +0.82667320, - +1.51163241, - +1.25342639, - +0.36478594, - -0.00464912, - ], - [ - +0.00000000, - +0.00000000, 
- +0.43422818, - +0.64440739, - +0.44137408, - +0.25311494, - +0.00011242, - +0.00000000, - ], - ], - atol=1e-5, - ) - ) +# Fixtures. - def testAdjoint(self): - mean_b_coef = self.estimator.src_backward().squeeze() - self.assertTrue( - np.allclose( - mean_b_coef, - [ - 1.07338590e-01, - 1.23690941e-01, - 6.44482039e-03, - -5.40484306e-02, - -4.85304586e-02, - 1.09852144e-02, - 3.87838396e-02, - 3.43796455e-02, - -6.43284705e-03, - -2.86677145e-02, - -1.42313328e-02, - -2.25684091e-03, - -3.31840727e-02, - -2.59706174e-03, - -5.91919887e-04, - -9.97433028e-03, - 9.19123928e-04, - 1.19891589e-03, - 7.49154982e-03, - 6.18865229e-03, - -8.13265715e-04, - -1.30715655e-02, - -1.44160603e-02, - 2.90379956e-03, - 2.37066082e-02, - 4.88805735e-03, - 1.47870707e-03, - 7.63376018e-03, - -5.60619559e-03, - 1.05165081e-02, - 3.30510143e-03, - -3.48652120e-03, - -4.23228797e-04, - 1.40484061e-02, - 1.42914291e-03, - -1.28129504e-02, - 2.19868825e-03, - -6.30835037e-03, - 1.18524223e-03, - -2.97855052e-02, - 1.15491057e-03, - -8.27947006e-03, - 3.45442781e-03, - -4.72868856e-03, - 2.66615329e-03, - -7.87929790e-03, - 8.84126590e-04, - 1.59402808e-03, - -9.06854048e-05, - -8.79119004e-03, - 1.76449039e-03, - -1.36414673e-02, - 1.56793855e-03, - 1.44708445e-02, - -2.55974802e-03, - 5.38506357e-03, - -3.24188673e-03, - 4.81582945e-04, - 7.74260101e-05, - 5.48772082e-03, - 1.92058500e-03, - -4.63538896e-03, - -2.02735133e-03, - 3.67592386e-03, - 7.23486969e-04, - 1.81838422e-03, - 1.78793284e-03, - -8.01474060e-03, - -8.54007528e-03, - 1.96353845e-03, - -2.16254252e-03, - -3.64243996e-05, - -2.27329863e-03, - 1.11424393e-03, - -1.39389189e-03, - 2.57787159e-04, - 3.66918811e-03, - 1.31477774e-03, - 6.82220128e-04, - 1.41822851e-03, - -1.89476924e-03, - -6.43966255e-05, - -7.87888465e-04, - -6.99459279e-04, - 1.08918981e-03, - 2.25264584e-03, - -1.43651015e-04, - 7.68377620e-04, - 5.05955256e-04, - 2.66936132e-06, - 2.24934884e-03, - 6.70529439e-04, - 4.81121742e-04, - 
-6.40789745e-05, - -3.35915672e-04, - -7.98651783e-04, - -9.82705453e-04, - 6.46337066e-05, - ], - atol=1e-6, - ) - ) - def testOptimize1(self): - mean_b_coef = np.array( - [ - [ - 1.07338590e-01, - 1.23690941e-01, - 6.44482039e-03, - -5.40484306e-02, - -4.85304586e-02, - 1.09852144e-02, - 3.87838396e-02, - 3.43796455e-02, - -6.43284705e-03, - -2.86677145e-02, - -1.42313328e-02, - -2.25684091e-03, - -3.31840727e-02, - -2.59706174e-03, - -5.91919887e-04, - -9.97433028e-03, - 9.19123928e-04, - 1.19891589e-03, - 7.49154982e-03, - 6.18865229e-03, - -8.13265715e-04, - -1.30715655e-02, - -1.44160603e-02, - 2.90379956e-03, - 2.37066082e-02, - 4.88805735e-03, - 1.47870707e-03, - 7.63376018e-03, - -5.60619559e-03, - 1.05165081e-02, - 3.30510143e-03, - -3.48652120e-03, - -4.23228797e-04, - 1.40484061e-02, - 1.42914291e-03, - -1.28129504e-02, - 2.19868825e-03, - -6.30835037e-03, - 1.18524223e-03, - -2.97855052e-02, - 1.15491057e-03, - -8.27947006e-03, - 3.45442781e-03, - -4.72868856e-03, - 2.66615329e-03, - -7.87929790e-03, - 8.84126590e-04, - 1.59402808e-03, - -9.06854048e-05, - -8.79119004e-03, - 1.76449039e-03, - -1.36414673e-02, - 1.56793855e-03, - 1.44708445e-02, - -2.55974802e-03, - 5.38506357e-03, - -3.24188673e-03, - 4.81582945e-04, - 7.74260101e-05, - 5.48772082e-03, - 1.92058500e-03, - -4.63538896e-03, - -2.02735133e-03, - 3.67592386e-03, - 7.23486969e-04, - 1.81838422e-03, - 1.78793284e-03, - -8.01474060e-03, - -8.54007528e-03, - 1.96353845e-03, - -2.16254252e-03, - -3.64243996e-05, - -2.27329863e-03, - 1.11424393e-03, - -1.39389189e-03, - 2.57787159e-04, - 3.66918811e-03, - 1.31477774e-03, - 6.82220128e-04, - 1.41822851e-03, - -1.89476924e-03, - -6.43966255e-05, - -7.87888465e-04, - -6.99459279e-04, - 1.08918981e-03, - 2.25264584e-03, - -1.43651015e-04, - 7.68377620e-04, - 5.05955256e-04, - 2.66936132e-06, - 2.24934884e-03, - 6.70529439e-04, - 4.81121742e-04, - -6.40789745e-05, - -3.35915672e-04, - -7.98651783e-04, - -9.82705453e-04, - 6.46337066e-05, - ] - ], - 
dtype=self.dtype, - ) +@pytest.fixture(params=L, ids=lambda x: f"L={x}", scope="module") +def L(request): + return request.param - x = self.estimator.conj_grad(mean_b_coef) - self.assertTrue( - np.allclose( - x, - [ - 1.24325149e01, - 4.06481998e00, - 1.19149607e00, - -3.31414200e00, - -1.23897783e00, - 1.53987018e-01, - 2.50221093e00, - 9.18131863e-01, - 4.09624945e-02, - -1.81129255e00, - -2.58832135e-01, - -7.21149988e-01, - -1.00909836e00, - 5.72232366e-02, - -3.90701966e-01, - -3.65655187e-01, - 2.33601017e-01, - 1.75039197e-01, - 2.52945224e-01, - 3.29783105e-01, - 7.85601834e-02, - -3.96439884e-01, - -8.56255814e-01, - 7.35131473e-03, - 1.10704423e00, - 7.35615877e-02, - 5.61772211e-01, - 2.60428522e-01, - -5.41932165e-01, - 4.29851425e-01, - 3.86300956e-01, - -8.90168838e-02, - -1.02959264e-01, - 6.03104058e-01, - 1.85286462e-01, - -4.16434930e-01, - 2.11092135e-01, - -1.85514653e-01, - 9.80712710e-02, - -8.98429489e-01, - -9.54759574e-02, - -1.17952459e-01, - 1.41721715e-01, - -1.36184702e-01, - 3.23733962e-01, - -2.68721792e-01, - -1.42064052e-01, - 1.41909797e-01, - -2.24251300e-03, - -4.27538724e-01, - 1.28441757e-01, - -5.57623000e-01, - -1.54801935e-01, - 6.51729903e-01, - -2.15567768e-01, - 1.95041528e-01, - -4.18334680e-01, - 3.26735913e-02, - 6.35474331e-02, - 3.06828631e-01, - 1.43149180e-01, - -2.34377520e-01, - -1.54299735e-01, - 2.82627560e-01, - 9.60630473e-02, - 1.47687304e-01, - 1.38157247e-01, - -4.25581692e-01, - -5.62236939e-01, - 2.09287213e-01, - -1.14280315e-01, - 2.70617650e-02, - -1.19705716e-01, - 1.68350236e-02, - -1.20459065e-01, - 6.03971532e-02, - 3.21465643e-01, - 1.82032297e-01, - -2.95991444e-02, - 1.53711400e-01, - -1.30594319e-01, - -4.71412485e-02, - -1.35301477e-01, - -2.36292616e-01, - 1.95728111e-01, - 2.54618329e-01, - -1.81663289e-03, - 2.77960420e-02, - 3.58816749e-02, - -2.50138365e-02, - 2.54103161e-01, - 9.82534014e-02, - 9.00807559e-02, - 3.71458771e-02, - -7.86838200e-02, - -1.03837231e-01, - -1.26116949e-01, - 
9.82006976e-02, - ], - atol=1e-4, - ) - ) - def testOptimize2(self): - mean_b_coef = np.array( - [ - [ - 1.07338590e-01, - 1.23690941e-01, - 6.44482039e-03, - -5.40484306e-02, - -4.85304586e-02, - 1.09852144e-02, - 3.87838396e-02, - 3.43796455e-02, - -6.43284705e-03, - -2.86677145e-02, - -1.42313328e-02, - -2.25684091e-03, - -3.31840727e-02, - -2.59706174e-03, - -5.91919887e-04, - -9.97433028e-03, - 9.19123928e-04, - 1.19891589e-03, - 7.49154982e-03, - 6.18865229e-03, - -8.13265715e-04, - -1.30715655e-02, - -1.44160603e-02, - 2.90379956e-03, - 2.37066082e-02, - 4.88805735e-03, - 1.47870707e-03, - 7.63376018e-03, - -5.60619559e-03, - 1.05165081e-02, - 3.30510143e-03, - -3.48652120e-03, - -4.23228797e-04, - 1.40484061e-02, - 1.42914291e-03, - -1.28129504e-02, - 2.19868825e-03, - -6.30835037e-03, - 1.18524223e-03, - -2.97855052e-02, - 1.15491057e-03, - -8.27947006e-03, - 3.45442781e-03, - -4.72868856e-03, - 2.66615329e-03, - -7.87929790e-03, - 8.84126590e-04, - 1.59402808e-03, - -9.06854048e-05, - -8.79119004e-03, - 1.76449039e-03, - -1.36414673e-02, - 1.56793855e-03, - 1.44708445e-02, - -2.55974802e-03, - 5.38506357e-03, - -3.24188673e-03, - 4.81582945e-04, - 7.74260101e-05, - 5.48772082e-03, - 1.92058500e-03, - -4.63538896e-03, - -2.02735133e-03, - 3.67592386e-03, - 7.23486969e-04, - 1.81838422e-03, - 1.78793284e-03, - -8.01474060e-03, - -8.54007528e-03, - 1.96353845e-03, - -2.16254252e-03, - -3.64243996e-05, - -2.27329863e-03, - 1.11424393e-03, - -1.39389189e-03, - 2.57787159e-04, - 3.66918811e-03, - 1.31477774e-03, - 6.82220128e-04, - 1.41822851e-03, - -1.89476924e-03, - -6.43966255e-05, - -7.87888465e-04, - -6.99459279e-04, - 1.08918981e-03, - 2.25264584e-03, - -1.43651015e-04, - 7.68377620e-04, - 5.05955256e-04, - 2.66936132e-06, - 2.24934884e-03, - 6.70529439e-04, - 4.81121742e-04, - -6.40789745e-05, - -3.35915672e-04, - -7.98651783e-04, - -9.82705453e-04, - 6.46337066e-05, - ] - ] - ) +@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module") 
+def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def sim(L, dtype): + sim = Simulation( + L=L, + n=256, + C=1, # single volume + unique_filters=[ + RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + ], + dtype=dtype, + seed=SEED, + ) + + sim = sim.cache() # precompute images + + return sim + + +@pytest.fixture(scope="module") +def basis(L, dtype): + return FBBasis3D(L, dtype=dtype) + + +@pytest.fixture( + params=PRECONDITIONERS, ids=lambda x: f"preconditioner={x}", scope="module" +) +def estimator(request, sim, basis): + preconditioner = request.param + return MeanEstimator(sim, basis=basis, preconditioner=preconditioner) + + +@pytest.fixture(scope="module") +def mask(L): + return grid_3d(L)["r"] < 1 + + +# Tests +def test_resolution_error(sim, basis): + """ + Test mismatched resolutions yields a relevant error message. + """ + + with raises(ValueError, match=r".*resolution.*"): + # This basis is intentionally the wrong resolution. + incorrect_basis = FBBasis3D(sim.L + 1, dtype=sim.dtype) - x = self.estimator_with_preconditioner.conj_grad(mean_b_coef) - self.assertTrue( - np.allclose( - x, - [ - 1.24325149e01, - 4.06481998e00, - 1.19149607e00, - -3.31414200e00, - -1.23897783e00, - 1.53987018e-01, - 2.50221093e00, - 9.18131863e-01, - 4.09624945e-02, - -1.81129255e00, - -2.58832135e-01, - -7.21149988e-01, - -1.00909836e00, - 5.72232366e-02, - -3.90701966e-01, - -3.65655187e-01, - 2.33601017e-01, - 1.75039197e-01, - 2.52945224e-01, - 3.29783105e-01, - 7.85601834e-02, - -3.96439884e-01, - -8.56255814e-01, - 7.35131473e-03, - 1.10704423e00, - 7.35615877e-02, - 5.61772211e-01, - 2.60428522e-01, - -5.41932165e-01, - 4.29851425e-01, - 3.86300956e-01, - -8.90168838e-02, - -1.02959264e-01, - 6.03104058e-01, - 1.85286462e-01, - -4.16434930e-01, - 2.11092135e-01, - -1.85514653e-01, - 9.80712710e-02, - -8.98429489e-01, - -9.54759574e-02, - -1.17952459e-01, - 1.41721715e-01, - -1.36184702e-01, - 3.23733962e-01, - -2.68721792e-01, 
- -1.42064052e-01, - 1.41909797e-01, - -2.24251300e-03, - -4.27538724e-01, - 1.28441757e-01, - -5.57623000e-01, - -1.54801935e-01, - 6.51729903e-01, - -2.15567768e-01, - 1.95041528e-01, - -4.18334680e-01, - 3.26735913e-02, - 6.35474331e-02, - 3.06828631e-01, - 1.43149180e-01, - -2.34377520e-01, - -1.54299735e-01, - 2.82627560e-01, - 9.60630473e-02, - 1.47687304e-01, - 1.38157247e-01, - -4.25581692e-01, - -5.62236939e-01, - 2.09287213e-01, - -1.14280315e-01, - 2.70617650e-02, - -1.19705716e-01, - 1.68350236e-02, - -1.20459065e-01, - 6.03971532e-02, - 3.21465643e-01, - 1.82032297e-01, - -2.95991444e-02, - 1.53711400e-01, - -1.30594319e-01, - -4.71412485e-02, - -1.35301477e-01, - -2.36292616e-01, - 1.95728111e-01, - 2.54618329e-01, - -1.81663289e-03, - 2.77960420e-02, - 3.58816749e-02, - -2.50138365e-02, - 2.54103161e-01, - 9.82534014e-02, - 9.00807559e-02, - 3.71458771e-02, - -7.86838200e-02, - -1.03837231e-01, - -1.26116949e-01, - 9.82006976e-02, - ], - atol=1e-4, - ) + _ = MeanEstimator(sim, basis=incorrect_basis, preconditioner="none") + + +def test_estimate(sim, estimator, mask): + estimate = estimator.estimate() + + est = estimate * mask + vol = sim.vols * mask + + np.testing.assert_allclose( + est / np.linalg.norm(est), vol / np.linalg.norm(vol), atol=0.1 + ) + + +def test_adjoint(sim, basis, estimator): + """ + Test = + for random volume `v` and random images `u`. + """ + rots = sim.rotations + + L = sim.L + n = sim.n + + u = np.random.rand(n, L, L).astype(sim.dtype, copy=False) + v = np.random.rand(L, L, L).astype(sim.dtype, copy=False) + + proj = Volume(v).project(rots) + backproj = Image(u).backproject(rots) + + lhs = np.dot(proj.asnumpy().flatten(), u.flatten()) + rhs = np.dot(backproj.asnumpy().flatten(), v.flatten()) + + np.testing.assert_allclose(lhs, rhs, rtol=1e-6) + + +def test_src_adjoint(sim, basis, estimator): + """ + Test the built-in source based estimator's `src_backward` has + adjoint like relationship with simulated image generation. 
+ """ + + v = sim.vols.asnumpy()[0] # random volume + proj = sim.images[:] # projections of v + u = proj.asnumpy() # u = proj + + # `src_backward` scales by 1/n + backproj = Coef(basis, estimator.src_backward() * sim.n).evaluate() + + lhs = np.dot(proj.asnumpy().flatten(), u.flatten()) + rhs = np.dot(backproj.asnumpy().flatten(), v.flatten()) + + np.testing.assert_allclose(lhs, rhs, rtol=0.02) + + +def test_checkpoint(sim, basis, estimator): + """Exercise the checkpointing and max iterations branches.""" + test_iter = 2 + with tempfile.TemporaryDirectory() as tmp_input_dir: + prefix = os.path.join(tmp_input_dir, "new", "dirs", "chk") + _estimator = MeanEstimator( + sim, + basis=basis, + preconditioner="none", + checkpoint_iterations=test_iter, + maxiter=test_iter + 1, + checkpoint_prefix=prefix, ) - def testCheckpoint(self): - """Exercise the checkpointing and max iterations branches.""" - test_iter = 2 - with tempfile.TemporaryDirectory() as tmp_input_dir: - prefix = os.path.join(tmp_input_dir, "new", "dirs", "chk") - estimator = MeanEstimator( - self.sim, - basis=self.basis, - preconditioner="none", - checkpoint_iterations=test_iter, - maxiter=test_iter + 1, - checkpoint_prefix=prefix, - ) - - # Assert we raise when reading `maxiter`. - with raises(RuntimeError, match="Unable to converge!"): - _ = estimator.estimate() - - # Load the checkpoint coefficients while tmp_input_dir exists. - b_chk = np.load(f"{prefix}_iter{test_iter:04d}.npy") + # Assert we raise when reading `maxiter`. + with raises(RuntimeError, match="Unable to converge!"): + _ = _estimator.estimate() + + # Load the checkpoint coefficients while tmp_input_dir exists. 
+ x_chk = np.load(f"{prefix}_iter{test_iter:04d}.npy") # Restart estimate from checkpoint - _ = self.estimator.estimate(b_coef=b_chk) - - def testCheckpointArgs(self): - with tempfile.TemporaryDirectory() as tmp_input_dir: - prefix = os.path.join(tmp_input_dir, "chk") - - for junk in [-1, 0, "abc"]: - # Junk `checkpoint_iterations` values - with raises( - ValueError, match=r".*iterations.*should be a positive integer.*" - ): - _ = MeanEstimator( - self.sim, - basis=self.basis, - preconditioner="none", - checkpoint_iterations=junk, - checkpoint_prefix=prefix, - ) - # Junk `maxiter` values - with raises( - ValueError, match=r".*maxiter.*should be a positive integer.*" - ): - _ = MeanEstimator( - self.sim, - basis=self.basis, - preconditioner="none", - maxiter=junk, - checkpoint_prefix=prefix, - ) + _ = estimator.estimate(x0=x_chk) + + +def test_checkpoint_args(sim, basis): + with tempfile.TemporaryDirectory() as tmp_input_dir: + prefix = os.path.join(tmp_input_dir, "chk") + + for junk in [-1, 0, "abc"]: + # Junk `checkpoint_iterations` values + with raises( + ValueError, match=r".*iterations.*should be a positive integer.*" + ): + _ = MeanEstimator( + sim, + basis=basis, + preconditioner="none", + checkpoint_iterations=junk, + checkpoint_prefix=prefix, + ) + # Junk `maxiter` values + with raises(ValueError, match=r".*maxiter.*should be a positive integer.*"): + _ = MeanEstimator( + sim, + basis=basis, + preconditioner="none", + maxiter=junk, + checkpoint_prefix=prefix, + ) diff --git a/tests/test_weighted_mean_estimator.py b/tests/test_weighted_mean_estimator.py index 23496428c1..5d622c3865 100644 --- a/tests/test_weighted_mean_estimator.py +++ b/tests/test_weighted_mean_estimator.py @@ -1,787 +1,195 @@ -import logging import os.path -from unittest import TestCase +import tempfile import numpy as np +import pytest +from pytest import raises -from aspire.basis import FBBasis3D +from aspire.basis import Coef, FBBasis3D from aspire.operators import RadialCTFFilter from 
aspire.reconstruction import WeightedVolumesEstimator -from aspire.source import _LegacySimulation -from aspire.volume import LegacyVolume - -logger = logging.getLogger(__name__) +from aspire.source.simulation import Simulation +from aspire.utils import grid_3d DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") +# Params -class WeightedVolumesEstimatorTestCase(TestCase): - def setUp(self): - self.dtype = np.float32 - self.n = 1024 - self.r = 2 - self.L = L = 8 - self.sim = _LegacySimulation( - vols=LegacyVolume(L, dtype=self.dtype).generate(), - n=self.n, - unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) - ], - dtype=self.dtype, - ) - self.basis = FBBasis3D((L, L, L), dtype=self.dtype) - self.weights = np.ones((self.n, self.r)) / np.sqrt(self.n) - self.estimator = WeightedVolumesEstimator( - self.weights, self.sim, basis=self.basis, preconditioner="none" - ) - self.estimator_with_preconditioner = WeightedVolumesEstimator( - self.weights, self.sim, basis=self.basis, preconditioner="circulant" - ) +SEED = 1617 - def tearDown(self): - pass - - def testPositiveWeightedEstimates(self): - estimate = self.estimator.estimate() - - a = (estimate.asnumpy()[0][:, :, 4],) - - b = np.array( - [ - [ - +0.00000000, - +0.00000000, - +0.00000000, - +0.00000000, - -0.00000000, - +0.00000000, - +0.00000000, - +0.00000000, - ], - [ - +0.00000000, - +0.00000000, - +0.02446793, - +0.05363505, - +0.21988572, - +0.19513786, - +0.01174418, - +0.00000000, - ], - [ - +0.00000000, - -0.06168774, - +0.13178457, - +0.36011154, - +0.88632372, - +0.92307694, - +0.45524491, - +0.15142541, - ], - [ - +0.00000000, - -0.09108749, - +0.19564009, - +0.78325885, - +2.34527692, - +2.44817345, - +1.41268619, - +0.53634876, - ], - [ - +0.00000000, - +0.07150180, - +0.38347393, - +1.70868980, - +3.78134981, - +3.03582139, - +1.49942724, - +0.52104809, - ], - [ - +0.00000000, - +0.00736866, - +0.19239950, - +1.71596036, - +3.59823119, - +2.64081679, - 
+1.08514933, - +0.24995637, - ], - [ - +0.00000000, - +0.11075829, - +0.43197553, - +0.82667320, - +1.51163241, - +1.25342639, - +0.36478594, - -0.00464912, - ], - [ - +0.00000000, - +0.00000000, - +0.43422818, - +0.64440739, - +0.44137408, - +0.25311494, - +0.00011242, - +0.00000000, - ], - ] - ) +DTYPE = [np.float32, np.float64] +L = [ + 8, + 9, +] - logger.info(f"max abs diff: {np.max(np.abs(a - b))}") - self.assertTrue(np.allclose(a, b, atol=1e-5)) - - def testAdjoint(self): - mean_b_coef = self.estimator.src_backward().squeeze() - self.assertTrue( - np.allclose( - mean_b_coef, - [ - 1.07338590e-01, - 1.23690941e-01, - 6.44482039e-03, - -5.40484306e-02, - -4.85304586e-02, - 1.09852144e-02, - 3.87838396e-02, - 3.43796455e-02, - -6.43284705e-03, - -2.86677145e-02, - -1.42313328e-02, - -2.25684091e-03, - -3.31840727e-02, - -2.59706174e-03, - -5.91919887e-04, - -9.97433028e-03, - 9.19123928e-04, - 1.19891589e-03, - 7.49154982e-03, - 6.18865229e-03, - -8.13265715e-04, - -1.30715655e-02, - -1.44160603e-02, - 2.90379956e-03, - 2.37066082e-02, - 4.88805735e-03, - 1.47870707e-03, - 7.63376018e-03, - -5.60619559e-03, - 1.05165081e-02, - 3.30510143e-03, - -3.48652120e-03, - -4.23228797e-04, - 1.40484061e-02, - 1.42914291e-03, - -1.28129504e-02, - 2.19868825e-03, - -6.30835037e-03, - 1.18524223e-03, - -2.97855052e-02, - 1.15491057e-03, - -8.27947006e-03, - 3.45442781e-03, - -4.72868856e-03, - 2.66615329e-03, - -7.87929790e-03, - 8.84126590e-04, - 1.59402808e-03, - -9.06854048e-05, - -8.79119004e-03, - 1.76449039e-03, - -1.36414673e-02, - 1.56793855e-03, - 1.44708445e-02, - -2.55974802e-03, - 5.38506357e-03, - -3.24188673e-03, - 4.81582945e-04, - 7.74260101e-05, - 5.48772082e-03, - 1.92058500e-03, - -4.63538896e-03, - -2.02735133e-03, - 3.67592386e-03, - 7.23486969e-04, - 1.81838422e-03, - 1.78793284e-03, - -8.01474060e-03, - -8.54007528e-03, - 1.96353845e-03, - -2.16254252e-03, - -3.64243996e-05, - -2.27329863e-03, - 1.11424393e-03, - -1.39389189e-03, - 2.57787159e-04, - 
3.66918811e-03, - 1.31477774e-03, - 6.82220128e-04, - 1.41822851e-03, - -1.89476924e-03, - -6.43966255e-05, - -7.87888465e-04, - -6.99459279e-04, - 1.08918981e-03, - 2.25264584e-03, - -1.43651015e-04, - 7.68377620e-04, - 5.05955256e-04, - 2.66936132e-06, - 2.24934884e-03, - 6.70529439e-04, - 4.81121742e-04, - -6.40789745e-05, - -3.35915672e-04, - -7.98651783e-04, - -9.82705453e-04, - 6.46337066e-05, - ], - atol=1e-6, - ) - ) +PRECONDITIONERS = [ + None, + "circulant", + pytest.param("none", marks=pytest.mark.expensive), + pytest.param("", marks=pytest.mark.expensive), +] - def testOptimize1(self): - mean_b_coef = np.array( - [ - [ - 1.07338590e-01, - 1.23690941e-01, - 6.44482039e-03, - -5.40484306e-02, - -4.85304586e-02, - 1.09852144e-02, - 3.87838396e-02, - 3.43796455e-02, - -6.43284705e-03, - -2.86677145e-02, - -1.42313328e-02, - -2.25684091e-03, - -3.31840727e-02, - -2.59706174e-03, - -5.91919887e-04, - -9.97433028e-03, - 9.19123928e-04, - 1.19891589e-03, - 7.49154982e-03, - 6.18865229e-03, - -8.13265715e-04, - -1.30715655e-02, - -1.44160603e-02, - 2.90379956e-03, - 2.37066082e-02, - 4.88805735e-03, - 1.47870707e-03, - 7.63376018e-03, - -5.60619559e-03, - 1.05165081e-02, - 3.30510143e-03, - -3.48652120e-03, - -4.23228797e-04, - 1.40484061e-02, - 1.42914291e-03, - -1.28129504e-02, - 2.19868825e-03, - -6.30835037e-03, - 1.18524223e-03, - -2.97855052e-02, - 1.15491057e-03, - -8.27947006e-03, - 3.45442781e-03, - -4.72868856e-03, - 2.66615329e-03, - -7.87929790e-03, - 8.84126590e-04, - 1.59402808e-03, - -9.06854048e-05, - -8.79119004e-03, - 1.76449039e-03, - -1.36414673e-02, - 1.56793855e-03, - 1.44708445e-02, - -2.55974802e-03, - 5.38506357e-03, - -3.24188673e-03, - 4.81582945e-04, - 7.74260101e-05, - 5.48772082e-03, - 1.92058500e-03, - -4.63538896e-03, - -2.02735133e-03, - 3.67592386e-03, - 7.23486969e-04, - 1.81838422e-03, - 1.78793284e-03, - -8.01474060e-03, - -8.54007528e-03, - 1.96353845e-03, - -2.16254252e-03, - -3.64243996e-05, - -2.27329863e-03, - 
1.11424393e-03, - -1.39389189e-03, - 2.57787159e-04, - 3.66918811e-03, - 1.31477774e-03, - 6.82220128e-04, - 1.41822851e-03, - -1.89476924e-03, - -6.43966255e-05, - -7.87888465e-04, - -6.99459279e-04, - 1.08918981e-03, - 2.25264584e-03, - -1.43651015e-04, - 7.68377620e-04, - 5.05955256e-04, - 2.66936132e-06, - 2.24934884e-03, - 6.70529439e-04, - 4.81121742e-04, - -6.40789745e-05, - -3.35915672e-04, - -7.98651783e-04, - -9.82705453e-04, - 6.46337066e-05, - ] - ] - * self.r, - dtype=self.dtype, - ) +# Fixtures. - # Given equal weighting we should get the same result for all self.r volumes. - x = self.estimator.conj_grad(mean_b_coef) - - ref = np.array( - [ - 1.24325149e01, - 4.06481998e00, - 1.19149607e00, - -3.31414200e00, - -1.23897783e00, - 1.53987018e-01, - 2.50221093e00, - 9.18131863e-01, - 4.09624945e-02, - -1.81129255e00, - -2.58832135e-01, - -7.21149988e-01, - -1.00909836e00, - 5.72232366e-02, - -3.90701966e-01, - -3.65655187e-01, - 2.33601017e-01, - 1.75039197e-01, - 2.52945224e-01, - 3.29783105e-01, - 7.85601834e-02, - -3.96439884e-01, - -8.56255814e-01, - 7.35131473e-03, - 1.10704423e00, - 7.35615877e-02, - 5.61772211e-01, - 2.60428522e-01, - -5.41932165e-01, - 4.29851425e-01, - 3.86300956e-01, - -8.90168838e-02, - -1.02959264e-01, - 6.03104058e-01, - 1.85286462e-01, - -4.16434930e-01, - 2.11092135e-01, - -1.85514653e-01, - 9.80712710e-02, - -8.98429489e-01, - -9.54759574e-02, - -1.17952459e-01, - 1.41721715e-01, - -1.36184702e-01, - 3.23733962e-01, - -2.68721792e-01, - -1.42064052e-01, - 1.41909797e-01, - -2.24251300e-03, - -4.27538724e-01, - 1.28441757e-01, - -5.57623000e-01, - -1.54801935e-01, - 6.51729903e-01, - -2.15567768e-01, - 1.95041528e-01, - -4.18334680e-01, - 3.26735913e-02, - 6.35474331e-02, - 3.06828631e-01, - 1.43149180e-01, - -2.34377520e-01, - -1.54299735e-01, - 2.82627560e-01, - 9.60630473e-02, - 1.47687304e-01, - 1.38157247e-01, - -4.25581692e-01, - -5.62236939e-01, - 2.09287213e-01, - -1.14280315e-01, - 2.70617650e-02, - 
-1.19705716e-01, - 1.68350236e-02, - -1.20459065e-01, - 6.03971532e-02, - 3.21465643e-01, - 1.82032297e-01, - -2.95991444e-02, - 1.53711400e-01, - -1.30594319e-01, - -4.71412485e-02, - -1.35301477e-01, - -2.36292616e-01, - 1.95728111e-01, - 2.54618329e-01, - -1.81663289e-03, - 2.77960420e-02, - 3.58816749e-02, - -2.50138365e-02, - 2.54103161e-01, - 9.82534014e-02, - 9.00807559e-02, - 3.71458771e-02, - -7.86838200e-02, - -1.03837231e-01, - -1.26116949e-01, - 9.82006976e-02, - ] - * self.r - ) - logger.info(f"max abs diff: {np.max(np.abs(x.flatten() - ref))}") - self.assertTrue(np.allclose(x.flatten(), ref, atol=1e-4)) - - def testOptimize2(self): - mean_b_coef = np.array( - [ - [ - 1.07338590e-01, - 1.23690941e-01, - 6.44482039e-03, - -5.40484306e-02, - -4.85304586e-02, - 1.09852144e-02, - 3.87838396e-02, - 3.43796455e-02, - -6.43284705e-03, - -2.86677145e-02, - -1.42313328e-02, - -2.25684091e-03, - -3.31840727e-02, - -2.59706174e-03, - -5.91919887e-04, - -9.97433028e-03, - 9.19123928e-04, - 1.19891589e-03, - 7.49154982e-03, - 6.18865229e-03, - -8.13265715e-04, - -1.30715655e-02, - -1.44160603e-02, - 2.90379956e-03, - 2.37066082e-02, - 4.88805735e-03, - 1.47870707e-03, - 7.63376018e-03, - -5.60619559e-03, - 1.05165081e-02, - 3.30510143e-03, - -3.48652120e-03, - -4.23228797e-04, - 1.40484061e-02, - 1.42914291e-03, - -1.28129504e-02, - 2.19868825e-03, - -6.30835037e-03, - 1.18524223e-03, - -2.97855052e-02, - 1.15491057e-03, - -8.27947006e-03, - 3.45442781e-03, - -4.72868856e-03, - 2.66615329e-03, - -7.87929790e-03, - 8.84126590e-04, - 1.59402808e-03, - -9.06854048e-05, - -8.79119004e-03, - 1.76449039e-03, - -1.36414673e-02, - 1.56793855e-03, - 1.44708445e-02, - -2.55974802e-03, - 5.38506357e-03, - -3.24188673e-03, - 4.81582945e-04, - 7.74260101e-05, - 5.48772082e-03, - 1.92058500e-03, - -4.63538896e-03, - -2.02735133e-03, - 3.67592386e-03, - 7.23486969e-04, - 1.81838422e-03, - 1.78793284e-03, - -8.01474060e-03, - -8.54007528e-03, - 1.96353845e-03, - -2.16254252e-03, - 
-3.64243996e-05, - -2.27329863e-03, - 1.11424393e-03, - -1.39389189e-03, - 2.57787159e-04, - 3.66918811e-03, - 1.31477774e-03, - 6.82220128e-04, - 1.41822851e-03, - -1.89476924e-03, - -6.43966255e-05, - -7.87888465e-04, - -6.99459279e-04, - 1.08918981e-03, - 2.25264584e-03, - -1.43651015e-04, - 7.68377620e-04, - 5.05955256e-04, - 2.66936132e-06, - 2.24934884e-03, - 6.70529439e-04, - 4.81121742e-04, - -6.40789745e-05, - -3.35915672e-04, - -7.98651783e-04, - -9.82705453e-04, - 6.46337066e-05, - ] - ] - * self.r - ) +@pytest.fixture(params=L, ids=lambda x: f"L={x}", scope="module") +def L(request): + return request.param + + +@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def sim(L, dtype): + sim = Simulation( + L=L, + n=256, + C=1, # single volume + unique_filters=[ + RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + ], + dtype=dtype, + seed=SEED, + ) + + sim = sim.cache() # precompute images + + return sim + + +@pytest.fixture(scope="module") +def basis(L, dtype): + return FBBasis3D(L, dtype=dtype) + + +@pytest.fixture(scope="module") +def weights(sim): + # Construct simple test weights; + # one uniform positive and negative weighted volume respectively. 
+ r = 2 # Number of weighted volumes + weights = np.ones((sim.n, r)) / np.sqrt(sim.n) + weights[:, 1] *= -1 # negate second weight vector - x = self.estimator_with_preconditioner.conj_grad(mean_b_coef) - self.assertTrue( - np.allclose( - x, - [ - [ - 1.24325149e01, - 4.06481998e00, - 1.19149607e00, - -3.31414200e00, - -1.23897783e00, - 1.53987018e-01, - 2.50221093e00, - 9.18131863e-01, - 4.09624945e-02, - -1.81129255e00, - -2.58832135e-01, - -7.21149988e-01, - -1.00909836e00, - 5.72232366e-02, - -3.90701966e-01, - -3.65655187e-01, - 2.33601017e-01, - 1.75039197e-01, - 2.52945224e-01, - 3.29783105e-01, - 7.85601834e-02, - -3.96439884e-01, - -8.56255814e-01, - 7.35131473e-03, - 1.10704423e00, - 7.35615877e-02, - 5.61772211e-01, - 2.60428522e-01, - -5.41932165e-01, - 4.29851425e-01, - 3.86300956e-01, - -8.90168838e-02, - -1.02959264e-01, - 6.03104058e-01, - 1.85286462e-01, - -4.16434930e-01, - 2.11092135e-01, - -1.85514653e-01, - 9.80712710e-02, - -8.98429489e-01, - -9.54759574e-02, - -1.17952459e-01, - 1.41721715e-01, - -1.36184702e-01, - 3.23733962e-01, - -2.68721792e-01, - -1.42064052e-01, - 1.41909797e-01, - -2.24251300e-03, - -4.27538724e-01, - 1.28441757e-01, - -5.57623000e-01, - -1.54801935e-01, - 6.51729903e-01, - -2.15567768e-01, - 1.95041528e-01, - -4.18334680e-01, - 3.26735913e-02, - 6.35474331e-02, - 3.06828631e-01, - 1.43149180e-01, - -2.34377520e-01, - -1.54299735e-01, - 2.82627560e-01, - 9.60630473e-02, - 1.47687304e-01, - 1.38157247e-01, - -4.25581692e-01, - -5.62236939e-01, - 2.09287213e-01, - -1.14280315e-01, - 2.70617650e-02, - -1.19705716e-01, - 1.68350236e-02, - -1.20459065e-01, - 6.03971532e-02, - 3.21465643e-01, - 1.82032297e-01, - -2.95991444e-02, - 1.53711400e-01, - -1.30594319e-01, - -4.71412485e-02, - -1.35301477e-01, - -2.36292616e-01, - 1.95728111e-01, - 2.54618329e-01, - -1.81663289e-03, - 2.77960420e-02, - 3.58816749e-02, - -2.50138365e-02, - 2.54103161e-01, - 9.82534014e-02, - 9.00807559e-02, - 3.71458771e-02, - -7.86838200e-02, - 
-1.03837231e-01, - -1.26116949e-01, - 9.82006976e-02, - ] - ] - * self.r, - atol=1e-4, - ) + return weights + + +@pytest.fixture( + params=PRECONDITIONERS, ids=lambda x: f"preconditioner={x}", scope="module" +) +def estimator(request, sim, basis, weights): + preconditioner = request.param + + return WeightedVolumesEstimator( + weights, sim, basis=basis, preconditioner=preconditioner + ) + + +@pytest.fixture(scope="module") +def mask(L): + return grid_3d(L)["r"] < 1 + + +# Tests +def test_resolution_error(sim, basis, weights): + """ + Test mismatched resolutions yields a relevant error message. + """ + + with raises(ValueError, match=r".*resolution.*"): + # This basis is intentionally the wrong resolution. + incorrect_basis = FBBasis3D(sim.L + 1, dtype=sim.dtype) + + _ = WeightedVolumesEstimator( + weights, sim, basis=incorrect_basis, preconditioner="none" ) - def testNegativeWeightedEstimates(self): - """ - Here we'll test createing two volumes. - One with positive and another with negative weights. 
- """ - weights = np.ones((self.n, self.r)) / np.sqrt(self.n) - weights[:, 1] *= -1 # negate second set of weights - estimator = WeightedVolumesEstimator( - weights, self.sim, basis=self.basis, preconditioner="none" +def test_estimate(sim, estimator, mask): + estimate = estimator.estimate() + + est = estimate * mask + vol = sim.vols * mask + + for i, w in enumerate([1, -1]): + np.testing.assert_allclose( + w * est[i] / np.linalg.norm(est[i]), vol / np.linalg.norm(vol), atol=0.1 ) - estimate = estimator.estimate() - - a0 = estimate.asnumpy()[0][:, :, 4] - a1 = estimate.asnumpy()[1][:, :, 4] - - b = np.array( - [ - [ - +0.00000000, - +0.00000000, - +0.00000000, - +0.00000000, - -0.00000000, - +0.00000000, - +0.00000000, - +0.00000000, - ], - [ - +0.00000000, - +0.00000000, - +0.02446793, - +0.05363505, - +0.21988572, - +0.19513786, - +0.01174418, - +0.00000000, - ], - [ - +0.00000000, - -0.06168774, - +0.13178457, - +0.36011154, - +0.88632372, - +0.92307694, - +0.45524491, - +0.15142541, - ], - [ - +0.00000000, - -0.09108749, - +0.19564009, - +0.78325885, - +2.34527692, - +2.44817345, - +1.41268619, - +0.53634876, - ], - [ - +0.00000000, - +0.07150180, - +0.38347393, - +1.70868980, - +3.78134981, - +3.03582139, - +1.49942724, - +0.52104809, - ], - [ - +0.00000000, - +0.00736866, - +0.19239950, - +1.71596036, - +3.59823119, - +2.64081679, - +1.08514933, - +0.24995637, - ], - [ - +0.00000000, - +0.11075829, - +0.43197553, - +0.82667320, - +1.51163241, - +1.25342639, - +0.36478594, - -0.00464912, - ], - [ - +0.00000000, - +0.00000000, - +0.43422818, - +0.64440739, - +0.44137408, - +0.25311494, - +0.00011242, - +0.00000000, - ], - ] + +def test_src_adjoint(sim, basis, estimator): + """ + Test the built-in source based estimator's `src_backward` has + adjoint like relationship with simulated image generation. 
+ """ + + v = sim.vols.asnumpy()[0] # random volume + proj = sim.images[:] # projections of v + u = proj.asnumpy() # u = proj + + # `src_backward` scales by 1/n + backproj = Coef(basis, estimator.src_backward() * sim.n).evaluate() + + lhs = np.dot(proj.asnumpy().flatten(), u.flatten()) + + for i, w in enumerate([1, -1]): + rhs = np.dot(backproj[i].asnumpy().flatten(), w * v.flatten()) + np.testing.assert_allclose(lhs, rhs, rtol=0.02) + + +def test_checkpoint(sim, basis, estimator, weights): + """Exercise the checkpointing and max iterations branches.""" + test_iter = 2 + with tempfile.TemporaryDirectory() as tmp_input_dir: + prefix = os.path.join(tmp_input_dir, "new", "dirs", "chk") + _estimator = WeightedVolumesEstimator( + weights, + sim, + basis=basis, + preconditioner="none", + checkpoint_iterations=test_iter, + maxiter=test_iter + 1, + checkpoint_prefix=prefix, ) - logger.info(f"max abs diff: {np.max(np.abs(a0 - b))}") - self.assertTrue(np.allclose(a0, b, atol=1e-5)) - logger.info(f"max abs diff: {np.max(np.abs(-a1 - b))}") - self.assertTrue(np.allclose(-a1, b, atol=1e-5)) # negative weights + # Assert we raise when reading `maxiter`. + with raises(RuntimeError, match="Unable to converge!"): + _ = _estimator.estimate() + + # Load the checkpoint coefficients while tmp_input_dir exists. 
+ x_chk = np.load(f"{prefix}_iter{test_iter:04d}.npy") + + # Restart estimate from checkpoint + _ = estimator.estimate(x0=x_chk) + + +def test_checkpoint_args(sim, basis, weights): + with tempfile.TemporaryDirectory() as tmp_input_dir: + prefix = os.path.join(tmp_input_dir, "chk") + + for junk in [-1, 0, "abc"]: + # Junk `checkpoint_iterations` values + with raises( + ValueError, match=r".*iterations.*should be a positive integer.*" + ): + _ = WeightedVolumesEstimator( + weights, + sim, + basis=basis, + preconditioner="none", + checkpoint_iterations=junk, + checkpoint_prefix=prefix, + ) + # Junk `maxiter` values + with raises(ValueError, match=r".*maxiter.*should be a positive integer.*"): + _ = WeightedVolumesEstimator( + weights, + sim, + basis=basis, + preconditioner="none", + maxiter=junk, + checkpoint_prefix=prefix, + )