diff --git a/.bumpversion.cfg b/.bumpversion.cfg index e8869dac22..3d6c2a6025 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.12.0 +current_version = 0.12.1 commit = True tag = True diff --git a/.github/workflows/long_workflow.yml b/.github/workflows/long_workflow.yml index 350e65e133..4346b93222 100644 --- a/.github/workflows/long_workflow.yml +++ b/.github/workflows/long_workflow.yml @@ -9,6 +9,7 @@ on: jobs: expensive_tests: runs-on: self-hosted + timeout-minutes: 360 steps: - uses: actions/checkout@v3 - name: Install dependencies diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index c211b5234c..44d799e4ae 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -155,7 +155,7 @@ jobs: - name: Cache Data run: | ASPIREDIR=${{ env.WORK_DIR }} python -c \ - "import aspire; print(aspire.config['common']['cache_dir']); aspire.downloader.emdb_2660()" + "import aspire; print(aspire.config['common']['cache_dir']); import aspire.downloader; aspire.downloader.emdb_2660()" - name: Cleanup run: rm -rf ${{ env.WORK_DIR }} diff --git a/README.md b/README.md index ffdeb55387..278ad11df5 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5657281.svg)](https://doi.org/10.5281/zenodo.5657281) [![Downloads](https://static.pepy.tech/badge/aspire/month)](https://pepy.tech/project/aspire) -# ASPIRE - Algorithms for Single Particle Reconstruction - v0.12.0 +# ASPIRE - Algorithms for Single Particle Reconstruction - v0.12.1 The ASPIRE-Python project supersedes [Matlab ASPIRE](https://github.com/PrincetonUniversity/aspire). @@ -20,7 +20,7 @@ For more information about the project, algorithms, and related publications ple Please cite using the following DOI. This DOI represents all versions, and will always resolve to the latest one. ``` -ComputationalCryoEM/ASPIRE-Python: v0.12.0 https://doi.org/10.5281/zenodo.5657281 +ComputationalCryoEM/ASPIRE-Python: v0.12.1 https://doi.org/10.5281/zenodo.5657281 ``` diff --git a/docs/check_docstrings.py b/docs/check_docstrings.py new file mode 100644 index 0000000000..68d2afe956 --- /dev/null +++ b/docs/check_docstrings.py @@ -0,0 +1,62 @@ +import logging +import os +import re +import sys +from glob import glob + +logger = logging.getLogger(__name__) + + +def check_blank_line_above_param_section(file_path): + """ + Check that every docstring with both a body section and a parameter + section separates the two sections with exactly one blank line. Log + errors and return count. + + :param file_path: File path to check for error. + :return: Per file error count. + """ + error_count = 0 + with open(file_path, "r") as file: + content = file.read() + + regex = ( + r" {4,}\"\"\"\n(?:^[^:]+?[^\n])+(\n|\n\n\n+) {4,}(:p|:r)(?:.*\n)+? {4,}\"\"\"" + ) + + bad_docstrings = re.finditer(regex, content, re.MULTILINE) + for docstring in bad_docstrings: + line_number = content.count("\n", 0, docstring.start()) + 1 + + # Log error message. + msg = "Must have exactly 1 blank line between docstring body and parameter sections." + logger.error(f"{file_path}: {line_number}: {msg}") + error_count += 1 + + return error_count + + +def process_directory(directory): + """ + Recursively walk through directories and check for docstring errors. + If any errors found, log error count and exit. + + :param directory: Directory path to walk. 
+ """ + error_count = 0 + for file in glob(os.path.join(directory, "**/*.py"), recursive=True): + error_count += check_blank_line_above_param_section(file) + if error_count > 0: + logger.error(f"Found {error_count} docstring errors.") + sys.exit(1) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + logger.warning("Usage: python check_docstrings.py ") + sys.exit(1) + + target_directory = sys.argv[1] + if not os.path.isdir(target_directory): + raise RuntimeError(f"Invalid target directory path: {target_directory}") + process_directory(target_directory) diff --git a/docs/source/conf.py b/docs/source/conf.py index cd83c6a7ff..f9c69025c4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -86,7 +86,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -release = version = "0.12.0" +release = version = "0.12.1" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/source/index.rst b/docs/source/index.rst index 15a7335521..3802166549 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,4 +1,4 @@ -Aspire v0.12.0 +Aspire v0.12.1 ============== Algorithms for Single Particle Reconstruction diff --git a/gallery/experiments/cov2d_experiment.py.dontrun b/gallery/experiments/cov2d_experiment.py.dontrun deleted file mode 100755 index 24d2e22799..0000000000 --- a/gallery/experiments/cov2d_experiment.py.dontrun +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python -""" -This script illustrates denoising 2D images using batched Cov2D class -from experime dataset and outputing to mrcs file. -""" - -import logging - -from aspire.basis import FFBBasis2D -from aspire.denoising.denoiser_cov2d import DenoiserCov2D -from aspire.noise import AnisotropicNoiseEstimator -from aspire.source.relion import RelionSource - -logger = logging.getLogger(__name__) - - -# Set input path and files and initialize other parameters -DATA_FOLDER = "/path/to/untarred/empiar/dataset/" -STARFILE_IN = "/path/to/untarred/empiar/dataset/input.star" -STARFILE_OUT = "/path/to/output/ouput.star" -PIXEL_SIZE = 1.34 -MAX_ROWS = 1024 -MAX_RESOLUTION = 60 - -# Create a source object for 2D images -logger.info(f"Read in images from {STARFILE_IN} and preprocess the images.") -source = RelionSource( - STARFILE_IN, DATA_FOLDER, pixel_size=PIXEL_SIZE, max_rows=MAX_ROWS -) - -# Downsample the images -logger.info(f"Set the resolution to {MAX_RESOLUTION} X {MAX_RESOLUTION}") -if MAX_RESOLUTION < source.L: - source = source.downsample(MAX_RESOLUTION) -else: - logger.warn(f"Unable to downsample to {max_resolution}, using {source.L}") - - -# Specify the fast FB basis method for expending the 2D images -basis = FFBBasis2D((source.L, source.L)) - -# Estimate the noise of images -logger.info(f"Estimate the noise of images using anisotropic method") -noise_estimator = AnisotropicNoiseEstimator(source) -var_noise = noise_estimator.estimate() -logger.info(f"var_noise before whitening {var_noise}") - -# Whiten the noise of images -logger.info(f"Whiten the noise of images from the noise estimator") -source = source.whiten(noise_estimator) -# Note this changes the noise variance, -# flattening spectrum and converging towards 1. -# Noise variance will be recomputed in DenoiserCov2D by default. 
- -logger.info(f"Denoise the images using batched cov2D method.") -denoiser = DenoiserCov2D(source, basis) -denoised_src = denoiser.denoise(batch_size=512) -denoised_src.save(STARFILE_OUT, batch_size=512, save_mode="single", overwrite=False) diff --git a/gallery/experiments/cov3d_experiment.dontrun b/gallery/experiments/cov3d_experiment.dontrun deleted file mode 100644 index a907649480..0000000000 --- a/gallery/experiments/cov3d_experiment.dontrun +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python -""" -This script illustrates Cov3D analysis using experimental dataset -""" -import numpy as np - -from aspire.basis import FBBasis3D -from aspire.covariance import CovarianceEstimator -from aspire.noise import AnisotropicNoiseEstimator -from aspire.reconstruction import MeanEstimator -from aspire.source.relion import RelionSource -from aspire.utils import eigs -from aspire.volume import Volume - -# Set input path and files and initialize other parameters -DATA_FOLDER = "/path/to/untarred/empiar/dataset/" -STARFILE = "/path/to/untarred/empiar/dataset/input.star" -PIXEL_SIZE = 5.0 -MAX_ROWS = 1024 -MAX_RESOLUTION = 8 -CG_TOL = 1e-5 - -# Set number of eigen-vectors to keep -NUM_EIGS = 16 - -# Create a source object for experimental 2D images with estimated rotation angles -print(f"Read in images from {STARFILE} and preprocess the images.") -source = RelionSource( - STARFILE, data_folder=DATA_FOLDER, pixel_size=PIXEL_SIZE, max_rows=MAX_ROWS -) - -# Downsample the images -print(f"Set the resolution to {MAX_RESOLUTION} X {MAX_RESOLUTION}") -if MAX_RESOLUTION < source.L: - source = source.downsample(MAX_RESOLUTION) - -# Estimate the noise of images -print("Estimate the noise of images using anisotropic method") -noise_estimator = AnisotropicNoiseEstimator(source, batchSize=512) - -# Whiten the noise of images -print("Whiten the noise of images from the noise estimator") -source = source.whiten(noise_estimator) -# Estimate the noise variance. This is needed for the covariance estimation step below. -noise_variance = noise_estimator.estimate() -print(f"Noise Variance = {noise_variance}") - -# Specify the fast FB basis method for expanding the 2D images -basis = FBBasis3D((MAX_RESOLUTION, MAX_RESOLUTION, MAX_RESOLUTION), dtype=source.dtype) - -mean_estimator = MeanEstimator(source, basis, batch_size=512) -mean_est = mean_estimator.estimate() - -# Passing in a mean_kernel argument to the following constructor speeds up some calculations -covar_estimator = CovarianceEstimator(source, basis, mean_kernel=mean_estimator.kernel) -covar_est = covar_estimator.estimate(mean_est, noise_variance, tol=CG_TOL) - -# Extract the top eigenvectors and eigenvalues of the covariance estimate. -eigs_est, lambdas_est = eigs(covar_est, NUM_EIGS) -for i in range(NUM_EIGS): - print(f"Top {i}th eigen value: {lambdas_est[i, i]}") - -# Eigs should probably return a Volume, for now hack it. 
-# move the last axis to the first -eigs_est_c = np.moveaxis(eigs_est, -1, 0) -eigs_est = Volume(eigs_est_c) diff --git a/gallery/experiments/experimental_abinitio_pipeline_10028.py b/gallery/experiments/experimental_abinitio_pipeline_10028.py index ccd236605a..472c21a638 100644 --- a/gallery/experiments/experimental_abinitio_pipeline_10028.py +++ b/gallery/experiments/experimental_abinitio_pipeline_10028.py @@ -30,7 +30,7 @@ from aspire.abinitio import CLSyncVoting from aspire.basis import FFBBasis3D -from aspire.denoising import DefaultClassAvgSource, DenoiserCov2D +from aspire.denoising import DefaultClassAvgSource, DenoisedSource, DenoiserCov2D from aspire.noise import AnisotropicNoiseEstimator from aspire.reconstruction import MeanEstimator from aspire.source import OrientedSource, RelionSource @@ -121,7 +121,7 @@ # Use CWF denoising cwf_denoiser = DenoiserCov2D(src) # Use denoised src for classification - classification_src = cwf_denoiser.denoise() + classification_src = DenoisedSource(src, cwf_denoiser) # Cache for speedup. Avoids recomputing. classification_src = classification_src.cache() # Peek, what do the denoised images look like... diff --git a/gallery/experiments/orient3d_experiment.dontrun b/gallery/experiments/orient3d_experiment.dontrun deleted file mode 100644 index ec7a9eb079..0000000000 --- a/gallery/experiments/orient3d_experiment.dontrun +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python -""" -This script illustrates the estimation of orientation angles for experimental dataset -""" - -from aspire.abinitio import CLSyncVoting -from aspire.source.relion import RelionSource - -# Set input path and files and initialize other parameters -DATA_FOLDER = "/path/to/untarred/empiar/dataset/" -STARFILE_IN = "/path/to/untarred/empiar/dataset/input.star" -STARFILE_OUT = "/path/to/output/output.star" -PIXEL_SIZE = 1.34 -MAX_ROWS = 1024 - -# Create a source object for 2D images -print(f"Read in images from {STARFILE_IN}.") -source = RelionSource( - STARFILE_IN, DATA_FOLDER, pixel_size=PIXEL_SIZE, max_rows=MAX_ROWS -) - -# Estimate rotation matrices -print("Estimate rotation matrices.") -orient_est = CLSyncVoting(source) -orient_est.estimate_rotations() - -# Create new source object and save estimate rotation matrices -print("Save estimate rotation matrices.") -orient_est_src = orient_est.save_rotations() - -# Output orientational angles -print("Save orientational angles to STAR file.") -orient_est_src.save_metadata(STARFILE_OUT) diff --git a/gallery/experiments/preprocess_imgs_exp.py.dontrun b/gallery/experiments/preprocess_imgs_exp.py.dontrun deleted file mode 100644 index 9a02e43dd9..0000000000 --- a/gallery/experiments/preprocess_imgs_exp.py.dontrun +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python -""" -This script illustrates how to preprocess experimental cryo-EM images -before starting the pipeline of reconstructing 3D map. 
-""" - -import matplotlib.pyplot as plt - -from aspire.noise import WhiteNoiseEstimator -from aspire.source.relion import RelionSource - -# Set input path and files and initialize other parameters -DATA_FOLDER = '/path/to/untarred/empiar/dataset/' -STARFILE_IN = '/path/to/untarred/empiar/dataset/input.star' -PIXEL_SIZE = 1.34 -NUM_IMGS = 100 - -print('This script illustrates how to preprocess experimental cryo-EM images') -print(f'Read in images from {STARFILE_IN} and preprocess the images') -source = RelionSource( - STARFILE_IN, - DATA_FOLDER, - pixel_size=PIXEL_SIZE, - max_rows=NUM_IMGS -) - -# number of images to extract for plotting -nimgs_ext = 1 - -print('Obtain original images') -imgs_od = source.images[:nimgs_ext] - -print('Perform phase flip to input images') -source = source.phase_flip() -imgs_pf = source.images[:nimgs_ext] - -max_resolution = 60 -print(f'Downsample resolution to {max_resolution} X {max_resolution}') -source = source.downsample(max_resolution) -imgs_ds = source.images[:nimgs_ext] - -print('Normalize images to noise background') -source = source.normalize_background() -imgs_nb = source.images[:nimgs_ext] - -print('Whiten noise of images') -noise_estimator = WhiteNoiseEstimator(source) -source = source.whiten(noise_estimator) -imgs_wt = source.images[:nimgs_ext] - -print('Invert global density contrast') -source = source.invert_contrast() -imgs_rc = source.images[:nimgs_ext] - -# plot the first images -print('plot the first images') -idm = 0 -plt.subplot(2, 3, 1) -plt.imshow(imgs_od[idm], cmap='gray') -plt.colorbar(orientation='horizontal') -plt.title('original image') - -plt.subplot(2, 3, 2) -plt.imshow(imgs_pf[idm], cmap='gray') -plt.colorbar(orientation='horizontal') -plt.title('phase flip') - -plt.subplot(2, 3, 3) -plt.imshow(imgs_ds[idm], cmap='gray') -plt.colorbar(orientation='horizontal') -plt.title('downsample') - -plt.subplot(2, 3, 4) -plt.imshow(imgs_nb[idm], cmap='gray') -plt.colorbar(orientation='horizontal') -plt.title('normalize background') - -plt.subplot(2, 3, 5) -plt.imshow(imgs_wt[idm], cmap='gray') -plt.colorbar(orientation='horizontal') -plt.title('noise whitening') - -plt.subplot(2, 3, 6) -plt.imshow(imgs_rc[idm], cmap='gray') -plt.colorbar(orientation='horizontal') -plt.title('invert contrast') -plt.show() diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index ec775fc973..63dce7e62a 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -22,17 +22,13 @@ from aspire.abinitio import CLSyncVoting from aspire.basis import FFBBasis3D -from aspire.denoising import DefaultClassAvgSource, DenoiserCov2D +from aspire.denoising import DefaultClassAvgSource, DenoisedSource, DenoiserCov2D from aspire.downloader import emdb_2660 from aspire.noise import AnisotropicNoiseEstimator, CustomNoiseAdder from aspire.operators import FunctionFilter, RadialCTFFilter from aspire.reconstruction import MeanEstimator from aspire.source import OrientedSource, Simulation -from aspire.utils.coor_trans import ( - get_aligned_rotations, - get_rots_mse, - register_rotations, -) +from aspire.utils import mean_aligned_angular_distance logger = logging.getLogger(__name__) @@ -149,7 +145,7 @@ def noise_function(x, y): # Use CWF denoising cwf_denoiser = DenoiserCov2D(src) # Use denoised src for classification - classification_src = cwf_denoiser.denoise() + classification_src = DenoisedSource(src, cwf_denoiser) # Peek, what do the denoised images 
look like... if interactive: classification_src.images[:10].show() @@ -198,12 +194,11 @@ def noise_function(x, y): oriented_src = OrientedSource(avgs, orient_est) logger.info("Compare with known rotations") -# Compare with known true rotations -Q_mat, flag = register_rotations(oriented_src.rotations, true_rotations) -regrot = get_aligned_rotations(oriented_src.rotations, Q_mat, flag) -mse_reg = get_rots_mse(regrot, true_rotations) +# Compare with known true rotations. ``mean_aligned_angular_distance`` globally aligns the estimated +# rotations to the ground truth and finds the mean angular distance between them. +mean_ang_dist = mean_aligned_angular_distance(oriented_src.rotations, true_rotations) logger.info( - f"MSE deviation of the estimated rotations using register_rotations : {mse_reg}\n" + f"Mean angular distance between globally aligned estimates and ground truth rotations: {mean_ang_dist}\n" ) # %% diff --git a/gallery/tutorials/aspire_introduction.py b/gallery/tutorials/aspire_introduction.py index e12b3651be..f87afaee2f 100644 --- a/gallery/tutorials/aspire_introduction.py +++ b/gallery/tutorials/aspire_introduction.py @@ -245,7 +245,7 @@ # classDiagram # class Filter{ # +evaluate() -# +fb_mat() +# +basis_mat() # +scale() # +evaluate_grid() # +dual() diff --git a/gallery/tutorials/data/rln_proj_65.mrcs b/gallery/tutorials/data/rln_proj_65.mrcs new file mode 120000 index 0000000000..e722cc0de2 --- /dev/null +++ b/gallery/tutorials/data/rln_proj_65.mrcs @@ -0,0 +1 @@ +../../../tests/saved_test_data/rln_proj_65.mrcs \ No newline at end of file diff --git a/gallery/tutorials/data/rln_proj_65.star b/gallery/tutorials/data/rln_proj_65.star new file mode 120000 index 0000000000..5965d5d1cf --- /dev/null +++ b/gallery/tutorials/data/rln_proj_65.star @@ -0,0 +1 @@ +../../../tests/saved_test_data/rln_proj_65.star \ No newline at end of file diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 3119b8e1bd..f2547d2c00 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -9,7 +9,7 @@ # %% # Download an Example Volume -# ----------------- +# -------------------------- # We begin by downloading a high resolution volume map of the 80S # Ribosome, sourced from EMDB: https://www.ebi.ac.uk/emdb/EMD-2660. # This is one of several volume maps that can be downloaded with @@ -213,23 +213,17 @@ # %% -# Mean Squared Error -# ------------------ -# ASPIRE has some built-in utility functions for globally aligning the -# estimated rotations to the true rotations and computing the mean -# squared error. - -from aspire.utils.coor_trans import ( - get_aligned_rotations, - get_rots_mse, - register_rotations, -) +# Mean Error of Estimated Rotations +# --------------------------------- +# ASPIRE has the built-in utility function, ``mean_aligned_angular_distance``, which globally +# aligns the estimated rotations to the true rotations and computes the mean +# angular distance (in degrees). 
+ +from aspire.utils import mean_aligned_angular_distance # Compare with known true rotations -Q_mat, flag = register_rotations(oriented_src.rotations, true_rotations) -regrot = get_aligned_rotations(oriented_src.rotations, Q_mat, flag) -mse_reg = get_rots_mse(regrot, true_rotations) -mse_reg +mean_ang_dist = mean_aligned_angular_distance(oriented_src.rotations, true_rotations) +print(f"Mean aligned angular distance: {mean_ang_dist} degrees") # %% diff --git a/gallery/tutorials/tutorials/cov2d_simulation.py b/gallery/tutorials/tutorials/cov2d_simulation.py index 162eda4e80..e0ed8328e0 100644 --- a/gallery/tutorials/tutorials/cov2d_simulation.py +++ b/gallery/tutorials/tutorials/cov2d_simulation.py @@ -116,7 +116,7 @@ h_idx = sim.filter_indices # Evaluate CTF in the 8X8 FB basis -h_ctf_fb = [filt.fb_mat(ffbbasis) for filt in ctf_filters] +h_ctf_fb = [ffbbasis.filter_to_basis_mat(filt) for filt in ctf_filters] # Get clean images from projections of 3D map. logger.info("Apply CTF filters to clean images.") @@ -140,8 +140,8 @@ # ``basis.evaluate_t``. logger.info("Get coefficients of clean and noisy images in FFB basis.") -coeff_clean = ffbbasis.evaluate_t(imgs_clean) -coeff_noise = ffbbasis.evaluate_t(imgs_noise) +coef_clean = ffbbasis.evaluate_t(imgs_clean) +coef_noise = ffbbasis.evaluate_t(imgs_noise) # %% # Create Cov2D Object and Calculate Mean and Variance for Clean Images @@ -161,8 +161,8 @@ "Get 2D covariance matrices of clean and noisy images using FB coefficients." ) cov2d = RotCov2D(ffbbasis) -mean_coeff = cov2d.get_mean(coeff_clean) -covar_coeff = cov2d.get_covar(coeff_clean, mean_coeff, noise_var=0) +mean_coef = cov2d.get_mean(coef_clean) +covar_coef = cov2d.get_covar(coef_clean, mean_coef, noise_var=0) # %% # Estimate mean and covariance for noisy images with CTF and shrink method @@ -184,12 +184,12 @@ "precision": "float64", "preconditioner": "identity", } -mean_coeff_est = cov2d.get_mean(coeff_noise, h_ctf_fb, h_idx) -covar_coeff_est = cov2d.get_covar( - coeff_noise, +mean_coef_est = cov2d.get_mean(coef_noise, h_ctf_fb, h_idx) +covar_coef_est = cov2d.get_covar( + coef_noise, h_ctf_fb, h_idx, - mean_coeff_est, + mean_coef_est, noise_var=noise_var, covar_est_opt=covar_opt, ) @@ -203,17 +203,17 @@ # the lowest expected mean square error out of all linear estimators. logger.info("Get the CWF coefficients of noising images.") -coeff_est = cov2d.get_cwf_coeffs( - coeff_noise, +coef_est = cov2d.get_cwf_coefs( + coef_noise, h_ctf_fb, h_idx, - mean_coeff=mean_coeff_est, - covar_coeff=covar_coeff_est, + mean_coef=mean_coef_est, + covar_coef=covar_coef_est, noise_var=noise_var, ) # Convert Fourier-Bessel coefficients back into 2D images -imgs_est = ffbbasis.evaluate(coeff_est) +imgs_est = ffbbasis.evaluate(coef_est) # %% # Evaluate the Results @@ -221,12 +221,12 @@ # Calculate the difference between the estimated covariance and the "true" # covariance estimated from the clean Fourier-Bessel coefficients. -covar_coeff_diff = covar_coeff - covar_coeff_est +covar_coef_diff = covar_coef - covar_coef_est # Calculate the deviation between the clean estimates and those obtained from # the noisy, filtered images. -diff_mean = anorm(mean_coeff_est - mean_coeff) / anorm(mean_coeff) -diff_covar = covar_coeff_diff.norm() / covar_coeff.norm() +diff_mean = anorm(mean_coef_est - mean_coef) / anorm(mean_coef) +diff_covar = covar_coef_diff.norm() / covar_coef.norm() # Calculate the normalized RMSE of the estimated images. 
nrmse_ims = (imgs_est - imgs_clean).norm() / imgs_clean.norm() diff --git a/gallery/tutorials/tutorials/cov3d_simulation.py b/gallery/tutorials/tutorials/cov3d_simulation.py index c30653d5d9..da75f998c8 100644 --- a/gallery/tutorials/tutorials/cov3d_simulation.py +++ b/gallery/tutorials/tutorials/cov3d_simulation.py @@ -33,10 +33,11 @@ num_eigs = 16 # number of eigen-vectors to keep dtype = np.float32 -# Generate a ``Volume`` object for use in the simulation. Here we use a ``LegacyVolume`` which -# by default generates 2 unique random volumes. +# Generate a ``Volume`` object for use in the simulation. Here we use a ``LegacyVolume`` and +# set C = 3 to generate 3 unique random volumes. vols = LegacyVolume( L=img_size, + C=3, dtype=dtype, ).generate() @@ -49,7 +50,7 @@ dtype=dtype, ) -# The Simulation object was created using 2 volumes. +# The Simulation object was created using 3 volumes. num_vols = sim.C # Specify the normal FB basis method for expending the 2D images @@ -159,6 +160,6 @@ logger.info(f'Coordinates (mean correlation) = {np.mean(coords_perf["corr"])}') # Basic Check -assert covar_perf["rel_err"] <= 0.60 -assert np.mean(coords_perf["corr"]) >= 0.98 +assert covar_perf["rel_err"] <= 0.80 +assert np.mean(coords_perf["corr"]) >= 0.97 assert clustering_accuracy >= 0.99 diff --git a/gallery/tutorials/tutorials/image_class.py b/gallery/tutorials/tutorials/image_class.py index 36d689ddc0..ca4df61961 100644 --- a/gallery/tutorials/tutorials/image_class.py +++ b/gallery/tutorials/tutorials/image_class.py @@ -14,7 +14,7 @@ file_path = os.path.join(os.path.dirname(os.getcwd()), "data", "monuments.npy") img_data = np.load(file_path) -img_data.shape, img_data.dtype +print(img_data.shape, img_data.dtype) # %% # Create an Image Instance diff --git a/gallery/tutorials/tutorials/image_expansion.py b/gallery/tutorials/tutorials/image_expansion.py index 93adaec129..4c3f36aac4 100644 --- a/gallery/tutorials/tutorials/image_expansion.py +++ b/gallery/tutorials/tutorials/image_expansion.py @@ -51,13 +51,13 @@ # Get the expansion coefficients based on FB basis logger.info("Start normal FB expansion of original images.") tstart = timeit.default_timer() -fb_coeffs = fb_basis.evaluate_t(org_images) +fb_coefs = fb_basis.evaluate_t(org_images) tstop = timeit.default_timer() dtime = tstop - tstart logger.info(f"Finish normal FB expansion of original images in {dtime:.4f} seconds.") # Reconstruct images from the expansion coefficients based on FB basis -fb_images = fb_basis.evaluate(fb_coeffs).asnumpy() +fb_images = fb_basis.evaluate(fb_coefs).asnumpy() logger.info("Finish reconstruction of images from normal FB expansion coefficients.") # Calculate the mean value of maximum differences between the FB estimated images and the original images @@ -94,13 +94,13 @@ # Get the expansion coefficients based on fast FB basis logger.info("start fast FB expansion of original images.") tstart = timeit.default_timer() -ffb_coeffs = ffb_basis.evaluate_t(org_images) +ffb_coefs = ffb_basis.evaluate_t(org_images) tstop = timeit.default_timer() dtime = tstop - tstart logger.info(f"Finish fast FB expansion of original images in {dtime:.4f} seconds.") # Reconstruct images from the expansion coefficients based on fast FB basis -ffb_images = ffb_basis.evaluate(ffb_coeffs).asnumpy() +ffb_images = ffb_basis.evaluate(ffb_coefs).asnumpy() logger.info("Finish reconstruction of images from fast FB expansion coefficients.") # Calculate the mean value of maximum differences between the fast FB estimated images to the original images @@ 
-138,13 +138,13 @@ # Get the expansion coefficients based on direct PSWF basis logger.info("Start direct PSWF expansion of original images.") tstart = timeit.default_timer() -pswf_coeffs = pswf_basis.evaluate_t(org_images) +pswf_coefs = pswf_basis.evaluate_t(org_images) tstop = timeit.default_timer() dtime = tstop - tstart logger.info(f"Finish direct PSWF expansion of original images in {dtime:.4f} seconds.") # Reconstruct images from the expansion coefficients based on direct PSWF basis -pswf_images = pswf_basis.evaluate(pswf_coeffs).asnumpy() +pswf_images = pswf_basis.evaluate(pswf_coefs).asnumpy() logger.info("Finish reconstruction of images from direct PSWF expansion coefficients.") # Calculate the mean value of maximum differences between direct PSWF estimated images and original images @@ -182,13 +182,13 @@ # Get the expansion coefficients based on fast PSWF basis logger.info("Start fast PSWF expansion of original images.") tstart = timeit.default_timer() -fpswf_coeffs = fpswf_basis.evaluate_t(org_images) +fpswf_coefs = fpswf_basis.evaluate_t(org_images) tstop = timeit.default_timer() dtime = tstop - tstart logger.info(f"Finish fast PSWF expansion of original images in {dtime:.4f} seconds.") # Reconstruct images from the expansion coefficients based on direct PSWF basis -fpswf_images = fpswf_basis.evaluate(fpswf_coeffs).asnumpy() +fpswf_images = fpswf_basis.evaluate(fpswf_coefs).asnumpy() logger.info("Finish reconstruction of images from fast PSWF expansion coefficients.") # Calculate mean value of maximum differences between the fast PSWF estimated images and the original images diff --git a/gallery/tutorials/tutorials/orient3d_simulation.py b/gallery/tutorials/tutorials/orient3d_simulation.py index 433223e6d2..5934eec030 100644 --- a/gallery/tutorials/tutorials/orient3d_simulation.py +++ b/gallery/tutorials/tutorials/orient3d_simulation.py @@ -14,7 +14,7 @@ from aspire.abinitio import CLSyncVoting from aspire.operators import RadialCTFFilter from aspire.source import OrientedSource, Simulation -from aspire.utils import get_aligned_rotations, get_rots_mse, register_rotations +from aspire.utils import mean_aligned_angular_distance from aspire.volume import Volume logger = logging.getLogger(__name__) @@ -87,23 +87,26 @@ # ---------------------------------------- # Initialize an orientation estimation object and create an ``OrientedSource`` object -# to perform viewing angle estimation +# to perform viewing angle estimation. Here, because of the small image size of the +# ``Simulation``, we customize the ``CLSyncVoting`` method to use fewer theta values +# when searching for common-lines between pairs of images. Additionally, since we are +# processing images with no noise, we opt not to use a ``fuzzy_mask``, an option that +# improves common-line detection in higher noise regimes. 
logger.info("Estimate rotation angles using synchronization matrix and voting method.") -orient_est = CLSyncVoting(sim, n_theta=36) +orient_est = CLSyncVoting(sim, n_theta=36, mask=False) oriented_src = OrientedSource(sim, orient_est) rots_est = oriented_src.rotations # %% -# Mean Squared Error -# ------------------ +# Mean Angular Distance +# --------------------- -# Get register rotations after performing global alignment -Q_mat, flag = register_rotations(rots_est, rots_true) -regrot = get_aligned_rotations(rots_est, Q_mat, flag) -mse_reg = get_rots_mse(regrot, rots_true) +# ``mean_aligned_angular_distance`` will perform global alignment of the estimated rotations +# to the ground truth and find the mean angular distance between them (in degrees). +mean_ang_dist = mean_aligned_angular_distance(rots_est, rots_true) logger.info( - f"MSE deviation of the estimated rotations using register_rotations : {mse_reg}" + f"Mean angular distance between estimates and ground truth: {mean_ang_dist} degrees" ) # Basic Check -assert mse_reg < 0.06 +assert mean_ang_dist < 10 diff --git a/gallery/tutorials/tutorials/relion_projection_interop.py b/gallery/tutorials/tutorials/relion_projection_interop.py new file mode 100644 index 0000000000..96122282db --- /dev/null +++ b/gallery/tutorials/tutorials/relion_projection_interop.py @@ -0,0 +1,103 @@ +""" +================================== +Relion Projection Interoperability +================================== + +In this tutorial we compare projections generated by Relion +with projections generated by ASPIRE's ``Simulation`` class. +Both sets of projections are generated using a downsampled +volume map of a 70S Ribosome, absent of noise and CTF corruption. +""" + +import os + +import numpy as np + +from aspire.source import RelionSource, Simulation +from aspire.volume import Volume + +# %% +# Load Relion Projections +# ----------------------- +# We load the Relion projections as a ``RelionSource`` and view the images. + +starfile = os.path.join(os.path.dirname(os.getcwd()), "data", "rln_proj_65.star") +rln_src = RelionSource(starfile) +rln_src.images[:].show(colorbar=False) + +# %% +# .. note:: +# The projections above were generated in Relion using the following command:: +# +# relion_project --i clean70SRibosome_vol_65p.mrc --nr_uniform 3000 --angpix 5 +# +# For this tutorial we take a subset of these projections consisting of the first 5 images. + +# %% +# Generate Projections using ``Simulation`` +# ----------------------------------------- +# Using the metadata associated with the ``RelionSource`` and the same volume +# we generate an analogous set of projections with ASPIRE's ``Simulation`` class. + +# Load the volume from file as a ``Volume`` object. +filepath = os.path.join( + os.path.dirname(os.getcwd()), "data", "clean70SRibosome_vol_65p.mrc" +) +vol = Volume.load(filepath, dtype=rln_src.dtype) + +# Create a ``Simulation`` source using metadata from the RelionSource projections. +# Note, for odd resolution Relion projections are shifted from ASPIRE projections +# by 1 pixel in x and y. +sim_src = Simulation( + n=rln_src.n, + vols=vol, + offsets=-np.ones((rln_src.n, 2), dtype=rln_src.dtype), + amplitudes=rln_src.amplitudes, + angles=rln_src.angles, + dtype=rln_src.dtype, +) + +sim_src.images[:].show(colorbar=False) + +# %% +# Comparing the Projections +# ------------------------- +# We will take a few different approaches to comparing the two sets of projection images. 
+
+# %%
+# Visual Comparison
+# ^^^^^^^^^^^^^^^^^
+# We'll first look at a side-by-side of the two sets of images to confirm visually that
+# the projections are taken from the same viewing angles.
+
+rln_src.images[:].show(colorbar=False)
+sim_src.images[:].show(colorbar=False)
+
+# %%
+# Fourier Ring Correlation
+# ^^^^^^^^^^^^^^^^^^^^^^^^
+# Additionally, we can compare the two sets of images using the FRC. Note that the images
+# are tightly correlated up to a high resolution of 2 pixels.
+rln_src.images[:].frc(sim_src.images[:], cutoff=0.143, plot=True)
+
+# %%
+# Relative Error
+# ^^^^^^^^^^^^^^
+# As Relion and ASPIRE differ in methods of generating projections, the pixel intensity of
+# the images may not correspond perfectly, so we begin by normalizing the two sets of projections.
+# We then check that the relative error with respect to the Frobenius norm is less than 3%.
+
+# Work with numpy arrays.
+rln_np = rln_src.images[:].asnumpy()
+sim_np = sim_src.images[:].asnumpy()
+
+# Normalize images.
+rln_np = (rln_np - np.mean(rln_np)) / np.std(rln_np)
+sim_np = (sim_np - np.mean(sim_np)) / np.std(sim_np)
+
+# Assert that error is less than 3%.
+error = np.linalg.norm(rln_np - sim_np, axis=(1, 2)) / np.linalg.norm(
+    rln_np, axis=(1, 2)
+)
+assert all(error < 0.03)
+print(f"Relative per-image error: {error}")
diff --git a/gallery/tutorials/tutorials/weighted_volume_estimation.py b/gallery/tutorials/tutorials/weighted_volume_estimation.py
new file mode 100644
index 0000000000..a8e7893292
--- /dev/null
+++ b/gallery/tutorials/tutorials/weighted_volume_estimation.py
@@ -0,0 +1,115 @@
+"""
+Weighted Volume Reconstruction
+==============================
+
+This tutorial demonstrates a weighted volume reconstruction,
+using a published reference dataset.
+"""
+
+# %%
+# Download an Example Dataset
+# ---------------------------
+# ASPIRE's downloader will download, cache,
+# and unpack the reference dataset.
+# More information about the dataset can be found on
+# `Zenodo `_
+# and in this `paper `_
+
+from aspire import downloader
+
+sim_data = downloader.simulated_channelspin()
+
+# This data contains a `Volume` stack, an `Image` stack, weights and
+# corresponding parameters that were used to derive images
+# from the volumes. For example, the rotations below are the known
+# true simulation projection rotations. In practice these would be
+# derived from an orientation estimation component.
+
+imgs = sim_data["images"]  # Simulated image stack (`Image` object)
+rots = sim_data["rots"]  # True projection rotations (`Rotation` object)
+weights = sim_data["weights"]  # Volume weights (`Numpy` array)
+vols = sim_data["vols"]  # True reference volumes (`Volume` object)
+
+# %%
+# Create an ``ImageSource``
+# -------------------------
+# The image stack and projection rotation (Euler) angles can be
+# associated together during instantiation of an ``ImageSource``.
+# Because this example starts with a dense array of images,
+# an ``ArrayImageSource`` is used.
+
+from aspire.source import ArrayImageSource
+
+src = ArrayImageSource(imgs, angles=rots.angles)
+
+# The images are downsampled for the sake of a quicker tutorial.
+# This line can be commented out to achieve the reference size (54 pixels).
+src = src.downsample(24)
+
+# %%
+# .. note::
+#     This tutorial demonstrates bringing in reference data.
+#     It is also possible to just create a ``Simulation`` or use other
+#     ``ImageSource`` objects, so long as the rotations required
+#     for backprojecting are assigned, as sketched below.
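+
+# %%
+# A minimal sketch of that alternative follows; the image count ``n=512``
+# is illustrative and not taken from the reference dataset::
+#
+#     from aspire.source import Simulation
+#
+#     sim_src = Simulation(n=512, vols=vols, angles=rots.angles)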
+
+# %%
+# Volume Reconstruction
+# ---------------------
+# Performing a weighted volume reconstruction requires defining an
+# appropriate 3D basis and supplying an associated image-to-volume
+# weight mapping as an array.
+
+from aspire.basis import FFBBasis3D
+from aspire.reconstruction import WeightedVolumesEstimator
+
+# Create a reasonable Basis
+basis = FFBBasis3D(src.L, dtype=src.dtype)
+
+# Set up an estimator to perform the backprojections and volume estimation.
+# In this case, the `weights` array comes from the reference data set,
+# and is shaped to map images to spectral volumes.
+# Note that we can have many more actual/reference volumes generating
+# the image stack than spectral volumes. In this case the input
+# images were generated from 54 volumes, but are described by 16
+# spectral volumes.
+print("`weights` shape:", weights.shape)
+estimator = WeightedVolumesEstimator(weights, src, basis, preconditioner="none")
+
+# Perform the estimation, returning a `Volume` stack.
+estimated_volume = estimator.estimate()
+
+# %%
+# .. note::
+#     The ``estimate()`` method requires a fair amount of compute time,
+#     but there should be regularly logged progress towards convergence.
+
+# %%
+# Comparison of Estimated Volume with Source Volume
+# -------------------------------------------------
+# Generate several random projection rotations, then compare these
+# projections between the estimated spectral volumes and a known volume.
+# If ``src`` was downsampled above, the resulting estimated volumes
+# and projections will be of similar downsampled quality.
+#
+# Note that the estimated spectral volumes are treated as `Volume`
+# objects purely for convenience and are not expected to correspond
+# exactly to any particular reference volume. The spectral volumes
+# collectively describe motion features derived from the input data.
+# However, basic visual comparison is useful as a sanity check to
+# demonstrate that we are in fact generating spectral volumes that
+# appear reasonably similar to the input volumes.
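+
+# Both sets of projections below use the same random rotations, so
+# corresponding images should share viewing directions.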
+ +from aspire.utils import Rotation, uniform_random_angles + +reference_v = 0 # Actual volume under comparison +spectral_v = 0 # Estimated spectral volume +m = 3 # Number of projections + +random_rotations = Rotation.from_euler(uniform_random_angles(m, dtype=src.dtype)) + +# Estimated volume projections +estimated_volume[spectral_v].project(random_rotations).show() + +# Source volume projections +vols[reference_v].project(random_rotations).show() diff --git a/pyproject.toml b/pyproject.toml index 70cfc85960..3ddeb7ac1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "aspire" -version = "0.12.0" +version = "0.12.1" description = "Algorithms for Single Particle Reconstruction" readme = "README.md" # Optional requires-python = ">=3.8" @@ -30,6 +30,7 @@ classifiers = [ dependencies = [ "click", "confuse >= 2.0.0", + "cvxpy", "finufft", "gemmi >= 0.4.8", "grpcio >= 1.54.2", @@ -39,11 +40,13 @@ dependencies = [ "numpy>=1.21.5", "packaging", "pooch>=1.7.0", + "pillow", "psutil", + "pydantic<2", # Workaround for Ray<2.9 "pyfftw", "pymanopt", + "pyshtools", "PyWavelets", - "pillow", "ray", "scipy >= 1.10.0", "scikit-learn", @@ -72,7 +75,6 @@ dev = [ "pooch", "pyflakes", "pydocstyle", - "parameterized", "pytest", "pytest-cov", "pytest-random-order", diff --git a/src/aspire/__init__.py b/src/aspire/__init__.py index 654112e47d..8e412fbab8 100644 --- a/src/aspire/__init__.py +++ b/src/aspire/__init__.py @@ -12,7 +12,7 @@ from aspire.exceptions import handle_exception # version in maj.min.bld format -__version__ = "0.12.0" +__version__ = "0.12.1" # Setup `confuse` config @@ -69,8 +69,18 @@ sys.excepthook = handle_exception +# Collect set of all module names in package +_modules = set(item[1] for item in pkgutil.iter_modules(aspire.__path__)) +# Automatically add modules __all__ = [] -for _, modname, _ in pkgutil.iter_modules(aspire.__path__): +for modname in _modules: __all__.append(modname) # Add module to __all_ - importlib.import_module(f"aspire.{modname}") # Import the module + + +# Dynamically load and return attributes +def __getattr__(attr): + if attr in _modules: + return importlib.import_module(f"aspire.{attr}") + else: + raise AttributeError(f"module `{__name__}` has no attribute `{attr}`.") diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index 69483c9800..ff14cc2d45 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -1,4 +1,5 @@ from .commonline_base import CLOrient3D +from .commonline_sdp import CommonlineSDP from .sync_voting import SyncVotingMixin # isort: off diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index a7a044b60e..d1831c8177 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -5,7 +5,7 @@ import scipy.sparse as sparse from aspire.operators import PolarFT -from aspire.utils import common_line_from_rots +from aspire.utils import common_line_from_rots, fuzzy_mask from aspire.utils.random import choice logger = logging.getLogger(__name__) @@ -17,12 +17,19 @@ class CLOrient3D: """ def __init__( - self, src, n_rad=None, n_theta=360, n_check=None, max_shift=0.15, shift_step=1 + self, + src, + n_rad=None, + n_theta=360, + n_check=None, + max_shift=0.15, + shift_step=1, + mask=True, ): """ - Initialize an object for estimating 3D orientations using common lines + Initialize an object for estimating 3D orientations using common lines. 
- :param src: The source object of 2D denoised or class-averaged imag + :param src: The source object of 2D denoised or class-averaged images. :param n_rad: The number of points in the radial direction. If None, n_rad will default to the ceiling of half the resolution of the source. :param n_theta: The number of points in the theta direction. This value must be even. @@ -34,6 +41,8 @@ def __init__( of the resolution. Default is 0.15. :param shift_step: Resolution of shift estimation in pixels. Default is 1 pixel. + :param mask: Option to mask `src.images` with a fuzzy mask (boolean). + Default, `True`, applies a mask. """ self.src = src # Note dtype is inferred from self.src @@ -46,6 +55,7 @@ def __init__( self.clmatrix = None self.max_shift = math.ceil(max_shift * self.n_res) self.shift_step = shift_step + self.mask = mask self.rotations = None self._build() @@ -69,6 +79,10 @@ def _build(self): imgs = self.src.images[:] + if self.mask: + fuzz_mask = fuzzy_mask((self.n_res, self.n_res), self.dtype) + imgs = imgs * fuzz_mask + # Obtain coefficients of polar Fourier transform for input 2D images self.pft = PolarFT( (self.n_res, self.n_res), self.n_rad, self.n_theta, dtype=self.dtype @@ -356,7 +370,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): shift_b[shift_eq_idx] = dx # Compute the coefficients of the current equation - coeffs = np.array( + coefs = np.array( [ np.sin(shift_alpha), np.cos(shift_alpha), @@ -364,7 +378,7 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): -np.cos(shift_beta), ] ) - shift_eq[idx] = -1 * coeffs if is_pf_j_flipped else coeffs + shift_eq[idx] = -1 * coefs if is_pf_j_flipped else coefs # create sparse matrix object only containing non-zero elements shift_equations = sparse.csr_matrix( @@ -381,6 +395,7 @@ def _estimate_num_shift_equations(self, n_img, equations_factor=1, max_memory=40 The function computes total number of shift equations based on number of images and preselected memory factor. + :param n_img: The total number of input images :param equations_factor: The factor to rescale the number of shift equations (=1 in default) @@ -428,6 +443,7 @@ def _generate_shift_phase_and_filter(self, r_max, max_shift, shift_step): The shift phases are pre-defined in a range of max_shift that can be applied to maximize the common line calculation. The common-line filter is also applied to the radial direction for easier detection. + :param r_max: Maximum index for common line detection :param max_shift: Maximum value of 1D shift (in pixels) to search :param shift_step: Resolution of shift estimation in pixels @@ -497,6 +513,7 @@ def _apply_filter_and_norm(self, subscripts, pf, r_max, h): :subscripts: Specifies the subscripts for summation of Numpy `einsum` function + :param pf: Fourier transform of images :param r_max: Maximum index for common line detection :param h: common lines filter diff --git a/src/aspire/abinitio/commonline_c2.py b/src/aspire/abinitio/commonline_c2.py index 382078cc52..6a762a59c9 100644 --- a/src/aspire/abinitio/commonline_c2.py +++ b/src/aspire/abinitio/commonline_c2.py @@ -43,6 +43,7 @@ def __init__( degree_res=1, min_dist_cls=25, seed=None, + mask=True, ): """ Initialize object for estimating 3D orientations for molecules with C2 symmetry. @@ -57,6 +58,8 @@ def __init__( :param degree_res: Degree resolution for estimating in-plane rotations. :param min_dist_cls: Minimum distance between mutual common-lines. Default = 25 degrees. :param seed: Optional seed for RNG. 
+ :param mask: Option to mask `src.images` with a fuzzy mask (boolean). + Default, `True`, applies a mask. """ super().__init__( src, @@ -69,6 +72,7 @@ def __init__( max_iters=max_iters, degree_res=degree_res, seed=seed, + mask=mask, ) self.min_dist_cls = min_dist_cls diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 7ddfbde7cd..63ac41ac88 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -51,6 +51,7 @@ def __init__( max_iters=1000, degree_res=1, seed=None, + mask=True, ): """ Initialize object for estimating 3D orientations for molecules with C3 and C4 symmetry. @@ -65,6 +66,8 @@ def __init__( :param max_iter: Maximum iterations for the power method. :param degree_res: Degree resolution for estimating in-plane rotations. :param seed: Optional seed for RNG. + :param mask: Option to mask `src.images` with a fuzzy mask (boolean). + Default, `True`, applies a mask. """ super().__init__( @@ -73,6 +76,7 @@ def __init__( n_theta=n_theta, max_shift=max_shift, shift_step=shift_step, + mask=mask, ) self._check_symmetry(symmetry) @@ -549,6 +553,7 @@ def _syncmatrix_ij_vote_3n(self, clmatrix, i, j, k_list, n_theta): Given the common lines matrix `clmatrix`, a list of images specified in k_list and the number of common lines n_theta, find the (i, j) rotation block Rij. + :param clmatrix: The common lines matrix :param i: The i image :param j: The j image diff --git a/src/aspire/abinitio/commonline_cn.py b/src/aspire/abinitio/commonline_cn.py index d267a367c1..94c9fa3ca3 100644 --- a/src/aspire/abinitio/commonline_cn.py +++ b/src/aspire/abinitio/commonline_cn.py @@ -40,6 +40,7 @@ def __init__( n_points_sphere=500, equator_threshold=10, seed=None, + mask=True, ): """ Initialize object for estimating 3D orientations for molecules with Cn symmetry, n>4. @@ -57,6 +58,8 @@ def __init__( :param equator_threshold: Threshold for removing candidate rotations within `equator_threshold` degrees of being an equator image. Default is 10 degrees. :param seed: Optional seed for RNG. + :param mask: Option to mask `src.images` with a fuzzy mask (boolean). + Default, `True`, applies a mask. """ super().__init__( @@ -70,6 +73,7 @@ def __init__( max_iters=max_iters, degree_res=degree_res, seed=seed, + mask=mask, ) self.n_points_sphere = n_points_sphere diff --git a/src/aspire/abinitio/commonline_sdp.py b/src/aspire/abinitio/commonline_sdp.py index ceb9666267..0fdae3c8b6 100644 --- a/src/aspire/abinitio/commonline_sdp.py +++ b/src/aspire/abinitio/commonline_sdp.py @@ -1,30 +1,250 @@ import logging +import cvxpy as cp +import numpy as np +from scipy.sparse import csr_array + from aspire.abinitio import CLOrient3D +from aspire.utils import nearest_rotations +from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) -class CommLineSDP(CLOrient3D): +class CommonlineSDP(CLOrient3D): """ - Class to estimate 3D orientations using Semi-Definite Programming - :cite:`DBLP:journals/siamis/SingerS11` + Class to estimate 3D orientations using semi-definite programming. + + See the following publication for more details: + + A. Singer and Y. Shkolnisky, + "Three-Dimensional Structure Determination from Common Lines in Cryo-EM + by Eigenvectors and Semidefinite Programming" + SIAM J. Imaging Sciences, Vol. 4, No. 2, (2011): 543-572. 
doi:10.1137/090767777
+    """
 
-    def __init__(self, src):
+    def estimate_rotations(self):
+        """
+        Estimate rotation matrices using the common lines method with semi-definite programming.
+        """
+        logger.info("Computing the common lines matrix.")
+        self.build_clmatrix()
+
+        S = self._construct_S(self.clmatrix)
+        A, b = self._sdp_prep()
+        gram = self._compute_gram_matrix(S, A, b)
+        rotations = self._deterministic_rounding(gram)
+        self.rotations = rotations
+
+    def _construct_S(self, clmatrix):
+        """
+        Construct the 2*n_img x 2*n_img quadratic form matrix S corresponding to the common-lines
+        matrix 'clmatrix'.
+
+        :param clmatrix: n_img x n_img common-lines matrix.
+
+        :return: 2*n_img x 2*n_img quadratic form matrix S.
+        """
+        logger.info("Constructing the common line quadratic form matrix S.")
+
+        S11 = np.zeros((self.n_img, self.n_img), dtype=self.dtype)
+        S12 = np.zeros((self.n_img, self.n_img), dtype=self.dtype)
+        S21 = np.zeros((self.n_img, self.n_img), dtype=self.dtype)
+        S22 = np.zeros((self.n_img, self.n_img), dtype=self.dtype)
+
+        for i in range(self.n_img):
+            for j in range(i + 1, self.n_img):
+                cij = clmatrix[i, j]
+                cji = clmatrix[j, i]
+
+                xij = np.cos(2 * np.pi * cij / self.n_theta)
+                yij = np.sin(2 * np.pi * cij / self.n_theta)
+                xji = np.cos(2 * np.pi * cji / self.n_theta)
+                yji = np.sin(2 * np.pi * cji / self.n_theta)
+
+                S11[i, j] = xij * xji
+                S11[j, i] = xji * xij
+
+                S12[i, j] = xij * yji
+                S12[j, i] = xji * yij
+
+                S21[i, j] = yij * xji
+                S21[j, i] = yji * xij
+
+                S22[i, j] = yij * yji
+                S22[j, i] = yji * yij
+
+        S = np.block([[S11, S12], [S21, S22]])
+
+        return S
+
+    def _sdp_prep(self):
+        """
+        Prepare optimization problem constraints.
+
+        The constraints for the SDP optimization, max tr(SG), performed in `_compute_gram_matrix()`
+        as min tr(-SG), are that the Gram matrix, G, is positive semidefinite and G11_ii = G22_ii = 1,
+        G12_ii = G21_ii = 0, i=1,2,...,N, for the block representation of G = [[G11, G12], [G21, G22]].
+
+        We build a corresponding constraint for CVXPY in the form of tr(A_j @ G) = b_j, j = 1,...,p.
+        For the constraint G11_ii = G22_ii = 1, we have A_j[i, i] = 1 (zeros elsewhere) and b_j = 1.
+        For the constraint G12_ii = G21_ii = 0, we have A_j[i, i] = 1 (zeros elsewhere) and b_j = 0.
+
+        :returns: Constraint data A, b.
+        """
+        logger.info("Preparing SDP optimization constraints.")
+
+        n = 2 * self.n_img
+        A = []
+        b = []
+        data = np.ones(1, dtype=self.dtype)
+        for i in range(n):
+            row_ind = np.array([i])
+            col_ind = np.array([i])
+            A_i = csr_array((data, (row_ind, col_ind)), shape=(n, n), dtype=self.dtype)
+            A.append(A_i)
+            b.append(1)
+
+        for i in range(self.n_img):
+            row_ind = np.array([i])
+            col_ind = np.array([self.n_img + i])
+            A_i = csr_array((data, (row_ind, col_ind)), shape=(n, n), dtype=self.dtype)
+            A.append(A_i)
+            b.append(0)
+
+        b = np.array(b, dtype=self.dtype)
+
+        return A, b
+
+    def _compute_gram_matrix(self, S, A, b):
         """
-        constructor of an object for estimating 3D orientations
+        Compute the Gram matrix by solving an SDP optimization.
+
+        The Gram matrix will be of the form G = R.T @ R, where R = [R1 R2] or the concatenation
+        of the first columns of every rotation, R1, and the second columns of every rotation, R2.
+        From this Gram matrix, the rotations can be recovered using the deterministic rounding
+        procedure below.
+
+        Here we optimize over G, max tr(SG), written as min tr(-SG), subject to the constraints
+        described in `_sdp_prep()`. It should be noted that tr(SG) = sum(dot(R_i @ c_ij, R_j @ c_ji)),
+        and that maximizing this objective function is equivalent to minimizing the L2 norm
+        of R_i @ c_ij - R_j @ c_ji, i.e., finding the best approximation for the rotations R_i.
+
+        :param S: The common-line quadratic form matrix of shape 2 * n_img x 2 * n_img.
+        :param A: 3 * n_img sparse arrays of constraint data.
+        :param b: 3 * n_img scalars such that tr(A_i G) = b_i.
+
+        :return: Gram matrix.
         """
-        pass
+        logger.info("Solving SDP to approximate Gram matrix.")
+
+        n = 2 * self.n_img
+        # Define and solve the CVXPY problem.
+        # Create a symmetric matrix variable.
+        G = cp.Variable((n, n), symmetric=True)
+        # The operator >> denotes matrix inequality.
+        constraints = [G >> 0]
+        constraints += [cp.trace(A[i] @ G) == b[i] for i in range(3 * self.n_img)]
+        prob = cp.Problem(cp.Minimize(cp.trace(-S @ G)), constraints)
+        prob.solve()
 
-    def estimate(self):
+        return G.value
+
+    def _deterministic_rounding(self, gram):
         """
-        perform estimation of orientations
+        Deterministic rounding procedure to recover the rotations from the Gram matrix.
+
+        The Gram matrix contains information about the first two columns of every rotation
+        matrix. These columns are extracted and used to form the remaining column of every
+        rotation matrix.
+
+        :param gram: A 2n_img x 2n_img Gram matrix.
+
+        :return: An n_img x 3 x 3 stack of rotation matrices.
         """
-        pass
+        logger.info("Recovering rotations from Gram matrix.")
+
+        # Obtain top eigenvectors from Gram matrix.
+        d, v = stable_eigsh(gram, 5)
+        sort_idx = np.argsort(-d)
+        logger.info(f"Top 5 eigenvalues from (rank-3) Gram matrix: {d[sort_idx]}")
 
-    def output(self):
+        # Only need the top 3 eigen-vectors.
+        v = v[:, sort_idx[:3]]
+
+        # According to the structure of the Gram matrix, the first `n_img` rows, denoted v1,
+        # correspond to the linear combination of the vectors R_{i}^{1}, i=1,...,K, that is of
+        # column 1 of all rotation matrices. Similarly, the second `n_img` rows of v,
+        # denoted v2, are linear combinations of R_{i}^{2}, i=1,...,K, that is, the second
+        # column of all rotation matrices.
+        v1 = v[: self.n_img].T
+        v2 = v[self.n_img : 2 * self.n_img].T
+
+        # Use a least-squares method to get A.T*A and a Cholesky decomposition to find A.
+        A = self._ATA_solver(v1, v2)
+
+        # Recover the rotations. The first two columns of all rotation
+        # matrices are given by unmixing V1 and V2 using A. The third
+        # column is the cross product of the first two.
+        r1 = np.dot(A.T, v1)
+        r2 = np.dot(A.T, v2)
+        r3 = np.cross(r1, r2, axis=0)
+        rotations = np.stack((r1.T, r2.T, r3.T), axis=-1)
+
+        # Make sure that we got rotations by enforcing R to be
+        # a rotation (in case the error is large)
+        rotations = nearest_rotations(rotations)
+
+        return rotations
+
+    @staticmethod
+    def _ATA_solver(v1, v2):
         """
-        Output the 3D orientations
+        Uses a least squares method to solve for the linear transformation A
+        such that A*v1=R1 and A*v2=R2 correspond to the first and second columns
+        of a sequence of rotation matrices.
+
+        :param v1: 3 x n_img array corresponding to linear combinations of the first
+            columns of all rotation matrices.
+        :param v2: 3 x n_img array corresponding to linear combinations of the second
+            columns of all rotation matrices.
+
+        :return: 3x3 linear transformation mapping v1, v2 to first two columns of rotations.
         """
-        pass
+        # We look for a linear transformation (3 x 3 matrix) A such that
+        # A*v1=R1 and A*v2=R2 are the columns of the rotation matrices.
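+        # Since the columns of a rotation matrix are orthonormal, the unmixed
+        # columns A*v1 and A*v2 must have unit norm and be mutually orthogonal.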
+ # Therefore: + # v1 * A'*A v1' = 1 + # v2 * A'*A v2' = 1 + # v1 * A'*A v2' = 0 + # These are 3*K linear equations for 9 matrix entries of A'*A + # Actually, there are only 6 unknown variables, because A'*A is symmetric. + # So we will truncate from 9 variables to 6 variables corresponding + # to the upper half of the matrix A'*A + n_img = v1.shape[-1] + truncated_equations = np.zeros((3 * n_img, 9), dtype=v1.dtype) + k = 0 + for i in range(3): + for j in range(3): + truncated_equations[0::3, k] = v1[i] * v1[j] + truncated_equations[1::3, k] = v2[i] * v2[j] + truncated_equations[2::3, k] = v1[i] * v2[j] + k += 1 + + # b = [1 1 0 1 1 0 ...]' is the right hand side vector + b = np.ones(3 * n_img) + b[2::3] = 0 + + # Find the least squares approximation of A'*A in vector form + ATA_vec = np.linalg.lstsq(truncated_equations, b, rcond=None)[0] + + # Construct the matrix A'*A from the vectorized matrix. + # Note, this is only the lower triangle of A'*A. + ATA = ATA_vec.reshape(3, 3) + + # The Cholesky decomposition of A'*A gives A (lower triangle). + # Note, that `np.linalg.cholesky()` only uses the lower-triangular + # and diagonal elements of ATA. + A = np.linalg.cholesky(ATA) + + return A diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index c19640bb35..5e07181f4a 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -3,6 +3,7 @@ import numpy as np from aspire.abinitio import CLOrient3D, SyncVotingMixin +from aspire.utils import nearest_rotations from aspire.utils.matlab_compat import stable_eigsh logger = logging.getLogger(__name__) @@ -22,7 +23,9 @@ class CLSyncVoting(CLOrient3D, SyncVotingMixin): Journal of Structural Biology, 169, 312-322 (2010). """ - def __init__(self, src, n_rad=None, n_theta=360, max_shift=0.15, shift_step=1): + def __init__( + self, src, n_rad=None, n_theta=360, max_shift=0.15, shift_step=1, mask=True + ): """ Initialize an object for estimating 3D orientations using synchronization matrix @@ -33,6 +36,8 @@ def __init__(self, src, n_rad=None, n_theta=360, max_shift=0.15, shift_step=1): :param max_shift: Determines maximum range for shifts as a proportion of the resolution. Default is 0.15. :param shift_step: Resolution for shift estimation in pixels. Default is 1 pixel. + :param mask: Option to mask `src.images` with a fuzzy mask (boolean). + Default, `True`, applies a mask. """ super().__init__( src, @@ -40,6 +45,7 @@ def __init__(self, src, n_rad=None, n_theta=360, max_shift=0.15, shift_step=1): n_theta=n_theta, max_shift=max_shift, shift_step=shift_step, + mask=mask, ) self.syncmatrix = None @@ -133,8 +139,7 @@ def estimate_rotations(self): rotations[:, :, 2] = r3.T # Make sure that we got rotations by enforcing R to be # a rotation (in case the error is large) - u, _, v = np.linalg.svd(rotations) - np.einsum("ijk, ikl -> ijl", u, v, out=rotations) + rotations = nearest_rotations(rotations) self.rotations = rotations @@ -175,6 +180,7 @@ def _syncmatrix_ij_vote(self, clmatrix, i, j, k_list, n_theta): Given the common lines matrix `clmatrix`, a list of images specified in k_list and the number of common lines n_theta, find the (i, j) rotation block (in X and Y) of the synchronization matrix. 
+ :param clmatrix: The common lines matrix :param i: The i image :param j: The j image diff --git a/src/aspire/abinitio/sync_voting.py b/src/aspire/abinitio/sync_voting.py index 0902c57059..abc11ef6e1 100644 --- a/src/aspire/abinitio/sync_voting.py +++ b/src/aspire/abinitio/sync_voting.py @@ -19,6 +19,7 @@ def _rotratio_eulerangle_vec(self, clmatrix, i, j, good_k, n_theta): Given a common lines matrix, where the index of each common line is in the range of n_theta and a list of good image k from voting results. + :param clmatrix: The common lines matrix :param i: The i image :param j: The j image @@ -61,6 +62,7 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list): clmatrix is the common lines matrix, constructed using angular resolution, n_theta. k_list are the images to be used for voting of the pair of images (i ,j). + :param clmatrix: The common lines matrix :param n_theta: The number of points in the theta direction (common lines) :param i: The i image diff --git a/src/aspire/apple/helper.py b/src/aspire/apple/helper.py index 410c3a9b76..25b59a13a3 100644 --- a/src/aspire/apple/helper.py +++ b/src/aspire/apple/helper.py @@ -6,11 +6,11 @@ class PickerHelper: @classmethod def gaussian_filter(cls, size_filter, std): - """Computes low-pass filter. + """ + Computes low-pass filter. - Args: - size_filter: Size of filter (size_filter x size_filter). - std: sigma value in filter. + :param size_filter: Size of filter (size_filter x size_filter). + :param std: sigma value in filter. """ y, x = xp.mgrid[ @@ -27,15 +27,14 @@ def gaussian_filter(cls, size_filter, std): @classmethod def extract_windows(cls, img, block_size): - """Extracts blocks of size (block_size x block_size) from the micrograph. Blocks are + """ + Extracts blocks of size (block_size x block_size) from the micrograph. Blocks are extracted with steps of size (block_size) - Args: - img: Micrograph image. - block_size: required block size. + :param img: Micrograph image. + :param block_size: required block size. - Returns: - 3D Matrix of blocks. For example, img[0] is the first block. + :return: 3D Matrix of blocks. For example, img[0] is the first block. """ # Compute x,y boundary using block_size @@ -57,15 +56,14 @@ def extract_windows(cls, img, block_size): @classmethod def extract_query(cls, img, block_size): - """Extract all query images from the micrograph. windows are + """ + Extract all query images from the micrograph. windows are extracted with steps of size (block_size/2) - Args: - img: Micrograph image. - block_size: Query images must be of size (block_size x block_size). + :param img: Micrograph image. + :param block_size: Query images must be of size (block_size x block_size). - Returns: - 4D Matrix of query images. + :return: 4D Matrix of query images. """ # keep only the portion of the image that can be split into blocks with no remainder @@ -135,16 +133,15 @@ def reference_size(cls, img, container_size): @classmethod def extract_references(cls, img, query_size, container_size): - """Chooses and extracts reference images from the micrograph. + """ + Chooses and extracts reference images from the micrograph. - Args: - img: Micrograph image. - query_size: Reference images must be of the same size of query images, i.e. (query_size x query_size). - container_size: Containers are large regions used to select reference images. The size of each - region is (container_size x container_size) + :param img: Micrograph image. + :param query_size: Reference images must be of the same size of query images, i.e. 
(query_size x query_size).
+        :param container_size: Containers are large regions used to select reference images. The size of each
+            region is (container_size x container_size)
 
-        Returns:
-            3D Matrix of reference images. windows[0] is the first reference window.
+        :return: 3D Matrix of reference images. windows[0] is the first reference window.
         """
 
         img = xp.asarray(img)
@@ -220,16 +217,15 @@ def extract_references(cls, img, query_size, container_size):
 
     @classmethod
     def get_training_set(cls, micro_img, bw_mask_p, bw_mask_n, n):
-        """Gets training set for the SVM classifier.
+        """
+        Gets training set for the SVM classifier.
 
-        Args:
-            micro_img: Micrograph image.
-            bw_mask_p: Binary image indicating regions from which to extract examples of particles.
-            bw_mask_n: Binary image indicating regions from which to extract examples of noise.
-            n: Size of training windows.
+        :param micro_img: Micrograph image.
+        :param bw_mask_p: Binary image indicating regions from which to extract examples of particles.
+        :param bw_mask_n: Binary image indicating regions from which to extract examples of noise.
+        :param n: Size of training windows.
 
-        Returns:
-            A matrix of features and a vector of labels for the SVM training.
+        :return: A matrix of features and a vector of labels for the SVM training.
         """
 
         non_overlap = cls.extract_windows(micro_img, n)
@@ -260,15 +256,14 @@ def get_training_set(cls, micro_img, bw_mask_p, bw_mask_n, n):
 
     @classmethod
     def moments(cls, img, query_size):
-        """Calculates the mean and standard deviation for each window of size (query_size x query_size)
+        """
+        Calculates the mean and standard deviation for each window of size (query_size x query_size)
         in the micrograph.
 
-        Args:
-            img: Micrograph image.
-            query_size: Size of windows for which to compute mean and std.
+        :param img: Micrograph image.
+        :param query_size: Size of windows for which to compute mean and std.
 
-        Returns:
-            A matrix of mean intensity and a matrix of variance, each containing a single
+        :return: A matrix of mean intensity and a matrix of variance, each containing a single
             entry for each possible (query_size x query_size) window in the micrograph.
         """
 
diff --git a/src/aspire/basis/__init__.py b/src/aspire/basis/__init__.py
index 7d57b57295..8f292a31bc 100644
--- a/src/aspire/basis/__init__.py
+++ b/src/aspire/basis/__init__.py
@@ -1,7 +1,7 @@
 # We'll tell isort not to sort these base classes
 # isort: off
 
-from .basis import Basis
+from .basis import Basis, Coef, ComplexCoef
 from .steerable import SteerableBasis2D
 from .fb import FBBasisMixin
 
diff --git a/src/aspire/basis/basis.py b/src/aspire/basis/basis.py
index 9a78b55b9d..3dd768a788 100644
--- a/src/aspire/basis/basis.py
+++ b/src/aspire/basis/basis.py
@@ -10,6 +10,355 @@
 logger = logging.getLogger(__name__)
 
 
+class Coef:
+    """
+    Numpy interoperable container for stacks of real coefficient vectors.
+    Each `Coef` instance has an associated `Basis`.
+    """
+
+    _allowed_dtypes = (np.float32, np.float64)
+
+    def __init__(self, basis, data, dtype=None):
+        """
+        A stack of one or more coefficient arrays.
+
+        The stack can be multidimensional with `stack_size` equal
+        to the product of the stack dimensions. Singletons will be
+        expanded into a 1D stack of length one.
+
+        The last axis always represents the coefficient `count`.
+
+        :param basis: `Basis` associated with `data` coefficients.
+        :param data: Numpy array containing coefficient data with shape
+            `(..., count)`.
+        :param dtype: Optionally cast `data` to this dtype.
+            Defaults to `data.dtype`.
+
+        :return: `Coef` instance holding `data`.
+        """
+
+        if not isinstance(data, np.ndarray):
+            raise ValueError("Coef should be instantiated with an ndarray")
+
+        if data.ndim < 1:
+            raise ValueError(
+                "Coef data should be ndarray with shape (N1...) x count or (count)."
+            )
+        elif data.ndim == 1:
+            data = np.expand_dims(data, axis=0)
+
+        if dtype is None:
+            self.dtype = data.dtype
+        else:
+            self.dtype = np.dtype(dtype)
+
+        # Check real/complex dtype based on class.
+        self._check_dtype()
+
+        if not isinstance(basis, Basis):
+            raise TypeError(
+                f"`basis` is required to be a `Basis` instance, received {type(basis)}"
+            )
+        self.basis = basis
+
+        self._data = data.astype(self.dtype, copy=False)
+        self.ndim = self._data.ndim
+        self.shape = self._data.shape
+        self.stack_ndim = self._data.ndim - 1
+        self.stack_shape = self._data.shape[:-1]
+        self.stack_size = np.prod(self.stack_shape)
+        self.count = self._data.shape[-1]
+
+        # Derive count from basis.
+        basis_count = self._get_basis_count()
+
+        if self.count != basis_count:
+            raise RuntimeError(
+                f"Provided data count of {self.count} does not match basis count of {basis_count}."
+            )
+
+        # Numpy interop
+        # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol
+        self.__array_interface__ = self.asnumpy().__array_interface__
+        self.__array__ = self.asnumpy()
+
+    def _check_dtype(self):
+        """
+        Private helper method to check real/complex dtype based on class `_allowed_dtypes`.
+
+        Raises on mismatch.
+        """
+
+        if self.dtype not in self._allowed_dtypes:
+            raise TypeError(
+                f"{self.__class__.__name__} requires {self._allowed_dtypes} coefficients, attempted {self.dtype}."
+            )
+
+    def _get_basis_count(self):
+        """
+        Private helper method to return coefficient count from basis.
+
+        :return: Basis count (integer).
+        """
+        return int(self.basis.count)
+
+    def __len__(self):
+        """
+        Return length of slowest stack axis.
+        """
+        return self.stack_shape[0]
+
+    def asnumpy(self):
+        """
+        Return coefficient data as a `(*stack_shape, count)`
+        read-only array view.
+
+        :return: read-only ndarray view
+        """
+
+        view = self._data.view()
+        view.flags.writeable = False
+        return view
+
+    def _check_key_dims(self, key):
+        if isinstance(key, tuple) and (len(key) > self._data.ndim):
+            raise ValueError(
+                f"Coef stack_dim is {self.stack_ndim}, slice length must be <= {self.ndim}"
+            )
+
+    def __getitem__(self, key):
+        self._check_key_dims(key)
+        return self.__class__(self.basis, self._data[key])
+
+    def __setitem__(self, key, value):
+        self._check_key_dims(key)
+        self._data[key] = value
+
+    def stack_reshape(self, *args):
+        """
+        Reshape the stack axis.
+
+        :param args: Integer(s) or tuple describing the intended shape.
+
+        :returns: Coef instance
+        """
+
+        # If we're passed a tuple, use that
+        if len(args) == 1 and isinstance(args[0], tuple):
+            shape = args[0]
+        else:
+            # Otherwise use the variadic args
+            shape = args
+
+        # Sanity check the size
+        if shape != (-1,) and np.prod(shape) != self.stack_size:
+            raise ValueError(
+                f"Number of coefficient vectors {self.stack_size} cannot be reshaped to {shape}."
+            )
+
+        return self.__class__(
+            self.basis, self._data.reshape(*shape, self._data.shape[-1])
+        )
+
+    def copy(self):
+        """
+        Return a new `Coef` instance with a deep copy of the data.
+        """
+        return self.__class__(self.basis, self._data.copy())
+
+    def evaluate(self):
+        """
+        Return the evaluation of coefficients in the associated `basis`.
+        """
+        return self.basis.evaluate(self)
+
+    def rotate(self, radians, refl=None):
+        """
+        Returns coefs rotated counter-clockwise by `radians`.
+
+        Raises error if underlying coef basis does not support rotations.
+
+        :param radians: Rotation in radians.
+        :param refl: Optional reflect image (about y=0) (bool)
+        :return: rotated coefs.
+        """
+
+        if not callable(getattr(self.basis, "rotate", None)):
+            raise RuntimeError(
+                f"self.basis={self.basis} does not provide `rotate` method."
+            )
+
+        return self.basis.rotate(self, radians, refl)
+
+    def shift(self, shifts):
+        """
+        Returns coefs shifted by `shifts`.
+
+        This will transform to real cartesian space, shift,
+        and transform back to basis space.
+
+        :param shifts: Shifts in pixels (x,y). Shape (1,2) or (len(coef), 2).
+        :return: coefs of shifted images.
+        """
+
+        if not callable(getattr(self.basis, "shift", None)):
+            raise RuntimeError(
+                f"self.basis={self.basis} does not provide `shift` method."
+            )
+
+        return self.basis.shift(self, shifts)
+
+    def __mul__(self, other):
+        """
+        Overload operator for multiplication.
+
+        :param other: `Coef` instance to multiply with.
+            Also allows for multiplication by Numpy arrays and scalars.
+        :return: `Coef` instance.
+        """
+
+        if isinstance(other, Coef):
+            other = other._data
+
+        return self.__class__(self.basis, self._data * other)
+
+    def __add__(self, other):
+        """
+        Overload operator for addition.
+
+        :param other: `Coef` instance to add.
+            Also allows for addition by Numpy arrays and scalars.
+        :return: `Coef` instance.
+        """
+
+        if isinstance(other, Coef):
+            other = other._data
+
+        return self.__class__(self.basis, self._data + other)
+
+    def __sub__(self, other):
+        """
+        Overload operator for subtraction.
+
+        :param other: `Coef` instance to subtract.
+            Also allows for subtraction by Numpy arrays and scalars.
+        :return: `Coef` instance.
+        """
+
+        if isinstance(other, Coef):
+            other = other._data
+
+        return self.__class__(self.basis, self._data - other)
+
+    def __neg__(self):
+        """
+        Overload operator for negation.
+
+        :return: `Coef` instance.
+        """
+
+        return self.__class__(self.basis, -self._data)
+
+    @property
+    def size(self):
+        """
+        Return np.size of underlying data.
+
+        This should be `stack_size * count`,
+        or `len(self) * count`.
+        """
+        return np.size(self._data)
+
+    # This is included for completeness, but is not being adopted yet.
+    def by_indices(self, **kwargs):
+        """
+        Select coefficients by indices (`radial`, `angular`).
+
+        See `SteerableBasis.indices_mask` for argument details.
+
+        :return: Numpy array.
+        """
+
+        mask = self.basis.indices_mask(**kwargs)
+        return self._data[..., mask]
+
+    def to_complex(self):
+        """
+        Convert and return real coefficients as `ComplexCoef`.
+        """
+        return self.basis.to_complex(self)
+
+    def to_real(self):
+        """
+        Not implemented for real Coef.
+        """
+        raise TypeError("Coef already real.")
+
+
+class ComplexCoef(Coef):
+    """
+    Numpy interoperable container for stacks of complex coefficient vectors.
+    Each `ComplexCoef` instance has an associated `Basis`.
+    """
+
+    _allowed_dtypes = (np.complex64, np.complex128)
+
+    def _get_basis_count(self):
+        """
+        Private helper method to return coefficient complex count from basis.
+
+        :return: Basis complex count (integer).
+        """
+
+        return int(self.basis.complex_count)
+
+    def evaluate(self):
+        """
+        Return the evaluation of coefficients in the associated `basis`.
+        """
+        return self.to_real().evaluate()
+
+    def rotate(self, radians, refl=None):
+        """
+        Returns coefs rotated counter-clockwise by `radians`.
+
+        Raises error if underlying coef basis does not support rotations.
+
+        :param radians: Rotation in radians.
+        :param refl: Optional reflect image (about y=0) (bool)
+        :return: Rotated ComplexCoefs.
+        """
+
+        return self.to_real().rotate(radians, refl).to_complex()
+
+    def shift(self, shifts):
+        """
+        Returns complex coefs shifted by `shifts`.
+
+        This will transform to real cartesian space, shift,
+        and transform back to basis space.
+
+        :param shifts: Shifts in pixels (x,y). Shape (1,2) or (len(coef), 2).
+        :return: Complex coefs of shifted images.
+        """
+
+        return self.to_real().shift(shifts).to_complex()
+
+    def to_real(self):
+        """
+        Convert and return complex coefficients as `Coef`.
+        """
+        return self.basis.to_real(self)
+
+    def to_complex(self):
+        """
+        Not implemented for ComplexCoef.
+        """
+        raise TypeError("ComplexCoef already complex.")
+
+
 class Basis:
     """
     Define a base class for expanding 2D particle images and 3D structure volumes
@@ -60,12 +409,6 @@ def _build(self):
         """
         raise NotImplementedError("subclasses must implement this")
 
-    def indices(self):
-        """
-        Create the indices for each basis function
-        """
-        raise NotImplementedError("subclasses must implement this")
-
     def _precomp(self):
         """
         Precompute the basis functions at defined sample points
@@ -82,22 +425,28 @@ def evaluate(self, v):
         """
         Evaluate coefficient vector in basis
 
-        :param v: A coefficient vector (or an array of coefficient vectors)
-            to be evaluated. The first dimension must correspond to the number of
-            coefficient vectors, while the second must correspond to `self.count`
+        :param v: `Coef` instance containing the coefficients to be
+            evaluated. The first dimension must correspond to the
+            number of coefficient vectors, while the second must
+            correspond to `self.count`.
 
         :return: The evaluation of the coefficient vector(s) `v` for this basis.
            This is an Image or a Volume object containing one image/volume for each
            coefficient vector, and of size `self.sz`.
        """
+
        if v.dtype != self.coefficient_dtype:
            logger.warning(
                f"{self.__class__.__name__}::evaluate"
                f" Inconsistent dtypes v: {v.dtype} self coefficient dtype: {self.coefficient_dtype}"
            )
 
-        # Flatten stack, ndim is wrt Basis (2 or 3)
-        stack_shape = v.shape[:-1]
-        v = v.reshape(-1, self.count)
+        if not isinstance(v, Coef):
+            raise TypeError(f"`evaluate` should be passed a `Coef`, received {type(v)}")
+
+        # Flatten stack
+        stack_shape = v.stack_shape
+        v = v.stack_reshape(-1).asnumpy()
+
         # Compute the transform
         x = self._evaluate(v)
         # Restore stack shape
@@ -141,8 +490,7 @@ def evaluate_t(self, v):
         # Restore stack shape
         x = x.reshape(*stack_shape, self.count)
 
-        # Return an ndarray
-        return x
+        return Coef(self, x)
 
     def _evaluate_t(self, v):
         raise NotImplementedError("Subclasses should implement this")
@@ -175,7 +523,7 @@ def mat_evaluate_t(self, X):
         """
         return mdim_mat_fun_conj(X, len(self.sz), 1, self._evaluate_t)
 
-    def expand(self, x):
+    def expand(self, x, tol=None, atol=0):
         """
         Obtain coefficients in the basis from those in standard coordinate basis
 
@@ -184,12 +532,16 @@ def expand(self, x):
         :param x: An array whose last two or three dimensions are to be expanded
             the desired basis. These dimensions must equal `self.sz`.
+        :param tol: Relative tolerance for convergence, `norm(residual) <= max(tol*norm(b), atol)`.
+            Default `None` sets to dtype's `eps`*10.
+ :param atol: Absolute tolerance for convergence, `norm(residual) <= max(tol*norm(b), atol)`. :return: The coefficients of `v` expanded in the desired basis. The last dimension of `v` is with size of `count` and the first dimensions of the return value correspond to those first dimensions of `x`. """ + if isinstance(x, Image) or isinstance(x, Volume): x = x.asnumpy() @@ -199,6 +551,7 @@ def expand(self, x): f" Inconsistent dtypes x: {x.dtype} self: {self.dtype}" ) + # TODO: We should only need to do this block when we are not passed Image/Volume. # check that last ndim values of input shape match # the shape of this basis assert ( @@ -211,25 +564,27 @@ def expand(self, x): operator = LinearOperator( shape=(self.count, self.count), - matvec=lambda v: self.evaluate_t(self.evaluate(v)), + matvec=lambda v: self.evaluate_t(self.evaluate(Coef(self, v))), dtype=self.dtype, ) - # TODO: (from MATLAB implementation) - Check that this tolerance make sense for multiple columns in v - tol = 10 * np.finfo(x.dtype).eps - logger.info("Expanding array in basis") + if tol is None: + # TODO: (from MATLAB implementation) - Check that this tolerance make sense for multiple columns in v + tol = 10 * np.finfo(x.dtype).eps + logger.info(f"Expanding array in basis with tol={tol} atol={atol}") # number of image samples n_data = x.shape[0] v = np.zeros((n_data, self.count), dtype=self.coefficient_dtype) for isample in range(0, n_data): - b = self.evaluate_t(self._cls(x[isample])).T + b = self.evaluate_t(self._cls(x[isample])).asnumpy().T # TODO: need check the initial condition x0 can improve the results or not. - v[isample], info = cg(operator, b, tol=tol, atol=0) + v[isample], info = cg(operator, b, tol=tol, atol=atol) if info != 0: - raise RuntimeError("Unable to converge!") + raise RuntimeError(f"Unable to converge! cg info={info}") # return v coefficients with the last dimension of self.count v = v.reshape((*sz_roll, self.count)) - return v + + return Coef(self, v) diff --git a/src/aspire/basis/basis_utils.py b/src/aspire/basis/basis_utils.py index 05366a58e3..fe599e9fdc 100644 --- a/src/aspire/basis/basis_utils.py +++ b/src/aspire/basis/basis_utils.py @@ -4,10 +4,12 @@ """ import logging +import warnings import numpy as np from numpy import diff, exp, log, pi from numpy.polynomial.legendre import leggauss +from pyshtools.expand import spharm_lm from scipy.special import jn, jv, sph_harm from aspire.utils import grid_2d, grid_3d @@ -170,7 +172,29 @@ def real_sph_harmonic(j, m, theta, phi): """ abs_m = abs(m) - y = sph_harm(abs_m, j, phi, theta) + # The `scipy` sph_harm implementation is much faster, + # but incorrectly returns NaN for high orders. + # For higher order use `pyshtools`. + if j < 86: + y = sph_harm(abs_m, j, phi, theta) + else: + warnings.warn( + "Computing higher order spherical harmonics is slow." 
+ " Consider using `FFBBasis3D` or decreasing volume size.", + stacklevel=1, + ) + + y = spharm_lm( + j, + abs_m, + theta, + phi, + kind="complex", + degrees=False, + csphase=-1, + normalization="ortho", + ) + if m < 0: y = np.sqrt(2) * np.imag(y) elif m > 0: diff --git a/src/aspire/basis/fb_2d.py b/src/aspire/basis/fb_2d.py index b2bcc68f69..2698e41ed6 100644 --- a/src/aspire/basis/fb_2d.py +++ b/src/aspire/basis/fb_2d.py @@ -5,7 +5,7 @@ from aspire.basis import FBBasisMixin, SteerableBasis2D from aspire.basis.basis_utils import unique_coords_nd -from aspire.utils import complex_type, real_type, roll_dim, unroll_dim +from aspire.utils import roll_dim, unroll_dim from aspire.utils.matlab_compat import m_flatten, m_reshape logger = logging.getLogger(__name__) @@ -63,7 +63,6 @@ def _build(self): # generate 1D indices for basis functions self._compute_indices() - self._indices = self.indices() # get normalized factors self.radial_norms, self.angular_norms = self.norms() @@ -109,19 +108,6 @@ def _compute_indices(self): self.angular_indices = indices_ells self.radial_indices = indices_ks self.signs_indices = indices_sgns - # Relating to paper: a[i] = a_ell_ks = a_angularindices[i]_radialindices[i] - self.complex_angular_indices = indices_ells[self._pos] # k - self.complex_radial_indices = indices_ks[self._pos] # q - - def indices(self): - """ - Return the precomputed indices for each basis function. - """ - return { - "ells": self.angular_indices, - "ks": self.radial_indices, - "sgns": self.signs_indices, - } def _precomp(self): """ @@ -282,99 +268,6 @@ def _evaluate_t(self, v): v = roll_dim(v, sz_roll) return v.T # RCOPT - def to_complex(self, coef): - """ - Return complex valued representation of coefficients. - This can be useful when comparing or implementing methods - from literature. - - There is a corresponding method, to_real. - - :param coef: Coefficients from this basis. - :return: Complex coefficent representation from this basis. - """ - - if coef.ndim == 1: - coef = coef.reshape(1, -1) - - if coef.dtype not in (np.float64, np.float32): - raise TypeError("coef provided to to_complex should be real.") - - # Pass through dtype precions, but check and warn if mismatched. - dtype = complex_type(coef.dtype) - if coef.dtype != self.dtype: - logger.warning( - f"coef dtype {coef.dtype} does not match precision of basis.dtype {self.dtype}, returning {dtype}." - ) - - # Return the same precision as coef - imaginary = dtype(1j) - - ccoef = np.zeros((coef.shape[0], self.complex_count), dtype=dtype) - - ind = 0 - idx = np.arange(self.k_max[0], dtype=int) - ind += np.size(idx) - - ccoef[:, idx] = coef[:, idx] - - for ell in range(1, self.ell_max + 1): - idx = ind + np.arange(self.k_max[ell], dtype=int) - ccoef[:, idx] = ( - coef[:, self._pos[idx]] - imaginary * coef[:, self._neg[idx]] - ) / 2.0 - - ind += np.size(idx) - - return ccoef - - def to_real(self, complex_coef): - """ - Return real valued representation of complex coefficients. - This can be useful when comparing or implementing methods - from literature. - - There is a corresponding method, to_complex. - - :param complex_coef: Complex coefficients from this basis. - :return: Real coefficent representation from this basis. - """ - if complex_coef.ndim == 1: - complex_coef = complex_coef.reshape(1, -1) - - if complex_coef.dtype not in (np.complex128, np.complex64): - raise TypeError("coef provided to to_real should be complex.") - - # Pass through dtype precions, but check and warn if mismatched. 
- dtype = real_type(complex_coef.dtype) - if dtype != self.dtype: - logger.warning( - f"Complex coef dtype {complex_coef.dtype} does not match precision of basis.dtype {self.dtype}, returning {dtype}." - ) - - coef = np.zeros((complex_coef.shape[0], self.count), dtype=dtype) - - ind = 0 - idx = np.arange(self.k_max[0], dtype=int) - ind += np.size(idx) - ind_pos = ind - - coef[:, idx] = complex_coef[:, idx].real - - for ell in range(1, self.ell_max + 1): - idx = ind + np.arange(self.k_max[ell], dtype=int) - idx_pos = ind_pos + np.arange(self.k_max[ell], dtype=int) - idx_neg = idx_pos + self.k_max[ell] - - c = complex_coef[:, idx] - coef[:, idx_pos] = 2.0 * np.real(c) - coef[:, idx_neg] = -2.0 * np.imag(c) - - ind += np.size(idx) - ind_pos += 2 * self.k_max[ell] - - return coef - def calculate_bispectrum( self, coef, flatten=False, filter_nonzero_freqs=False, freq_cutoff=None ): @@ -401,3 +294,9 @@ def calculate_bispectrum( filter_nonzero_freqs=filter_nonzero_freqs, freq_cutoff=freq_cutoff, ) + + def filter_to_basis_mat(self, *args, **kwargs): + """ + See `SteerableBasis2D.filter_to_basis_mat`. + """ + return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 74d46cc20c..5a5c7c3f27 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -8,6 +8,7 @@ from aspire.basis.basis_utils import lgwt from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp +from aspire.operators import BlkDiagMatrix from aspire.utils import complex_type from aspire.utils.matlab_compat import m_reshape @@ -50,7 +51,6 @@ def _build(self): # generate 1D indices for basis functions self._compute_indices() - self._indices = self.indices() # get normalized factors self.radial_norms, self.angular_norms = self.norms() @@ -95,12 +95,6 @@ def _precomp(self): return {"gl_nodes": r, "gl_weights": w, "radial": radial, "freqs": freqs} - def get_radial(self): - """ - Return precomputed radial part - """ - return self._precomp["radial"] - def _evaluate(self, v): """ Evaluate coefficients in standard 2D coordinate basis from those in FB basis @@ -123,7 +117,6 @@ def _evaluate(self, v): # go through each basis function and find corresponding coefficient pf = np.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype)) - mask = self._indices["ells"] == 0 ind = 0 @@ -131,7 +124,7 @@ def _evaluate(self, v): # include the normalization factor of angular part into radial part radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1) - pf[:, 0, :] = v[:, mask] @ radial_norm[idx] + pf[:, 0, :] = v[:, self._zero_angular_inds] @ radial_norm[idx] ind = ind + np.size(idx) ind_pos = ind @@ -221,11 +214,10 @@ def _evaluate_t(self, x): # go through each basis function and find the corresponding coefficient ind = 0 idx = ind + np.arange(self.k_max[0]) - mask = self._indices["ells"] == 0 # include the normalization factor of angular part into radial part radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1) - v[:, mask] = pf[:, :, 0].real @ radial_norm[idx].T + v[:, self._zero_angular_inds] = pf[:, :, 0].real @ radial_norm[idx].T ind = ind + np.size(idx) ind_pos = ind @@ -251,3 +243,60 @@ def _evaluate_t(self, x): ind_pos = ind_pos + 2 * self.k_max[ell] return v + + def filter_to_basis_mat(self, f, **kwargs): + """ + See `SteerableBasis2D.filter_to_basis_mat`. + """ + # Note 'method' and 'truncate' not relevant for this optimized FFB code. 
+        method = kwargs.get("method", None)
+        if method is not None:
+            raise NotImplementedError(
+                f"`FFBBasis2D.filter_to_basis_mat` method {method} not supported."
+                " Use `method=None`."
+            )
+
+        # These form a circular dependence, import locally until time to clean up.
+        from aspire.basis.basis_utils import lgwt
+
+        # Get the filter's evaluate function.
+        h_fun = f.evaluate
+
+        # Set same dimensions as basis object
+        n_k = self.n_r
+        n_theta = self.n_theta
+        radial = self._precomp["radial"]
+
+        # get 2D grid in polar coordinate
+        k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype)
+        k, theta = np.meshgrid(
+            k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij"
+        )
+
+        # Get function values in polar 2D grid and average out angle contribution
+        omegax = k * np.cos(theta)
+        omegay = k * np.sin(theta)
+        omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C")))
+
+        h_vals2d = h_fun(omega).reshape(n_k, n_theta).astype(self.dtype)
+        h_vals = np.sum(h_vals2d, axis=1) / n_theta
+
+        # Represent 1D function values in basis
+        h_basis = BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype)
+        ind_ell = 0
+        for ell in range(0, self.ell_max + 1):
+            k_max = self.k_max[ell]
+            rmat = 2 * k_vals.reshape(n_k, 1) * self.r0[ell][0:k_max].T
+            basis_vals = np.zeros_like(rmat)
+            ind_radial = np.sum(self.k_max[0:ell])
+            basis_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T
+            h_basis_vals = basis_vals * h_vals.reshape(n_k, 1)
+            h_basis_ell = basis_vals.T @ (
+                h_basis_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1)
+            )
+            h_basis[ind_ell] = h_basis_ell
+            ind_ell += 1
+            if ell > 0:
+                h_basis[ind_ell] = h_basis[ind_ell - 1]
+                ind_ell += 1
+
+        return h_basis
diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py
index e9a595d727..423d37c093 100644
--- a/src/aspire/basis/fle_2d.py
+++ b/src/aspire/basis/fle_2d.py
@@ -5,8 +5,8 @@
 from scipy.fft import dct, idct
 from scipy.special import jv
 
-from aspire.basis import FBBasisMixin, SteerableBasis2D
-from aspire.basis.basis_utils import besselj_zeros
+from aspire.basis import Coef, FBBasisMixin, SteerableBasis2D
+from aspire.basis.basis_utils import besselj_zeros, lgwt
 from aspire.basis.fle_2d_utils import (
     barycentric_interp_sparse,
     precomp_transform_complex_to_real,
@@ -14,6 +14,7 @@
 )
 from aspire.nufft import anufft, nufft
 from aspire.numeric import fft
+from aspire.operators import DiagMatrix
 from aspire.utils import complex_type, grid_2d
 
 logger = logging.getLogger(__name__)
@@ -30,6 +31,9 @@ class FLEBasis2D(SteerableBasis2D, FBBasisMixin):
     https://arxiv.org/pdf/2207.13674.pdf
     """
 
+    # Default matrix type for basis representation.
+    matrix_type = DiagMatrix
+
     def __init__(
         self, size, bandlimit=None, epsilon=1e-10, dtype=np.float32, match_fb=True
     ):
@@ -118,16 +122,26 @@ def _build_indices(self):
         self.radial_indices = self._fle_radial_indices[self._fle_to_fb_indices]
         # Note we negate the FLE signs?
         self.signs_indices = self._fle_signs_indices[self._fle_to_fb_indices]
-
-    def indices(self):
-        """
-        Return the precomputed indices for each basis function.
- """ - return { - "ells": self.angular_indices, - "ks": self.radial_indices, - "sgns": self.signs_indices, - } + # These map indices in complex array to pair of indices in real array + self.complex_count = sum(self.k_max) + self._pos = np.zeros(self.complex_count, dtype=int) + self._neg = np.zeros(self.complex_count, dtype=int) + i = 0 + ci = 0 + for ell in range(self.ell_max + 1): + sgns = (1,) if ell == 0 else (1, -1) + ks = np.arange(0, self.k_max[ell]) + + for sgn in sgns: + rng = np.arange(i, i + len(ks)) + if sgn == 1: + self._pos[ci + ks] = rng + elif sgn == -1: + self._neg[ci + ks] = rng + + i += len(ks) + + ci += len(ks) def _precomp(self): """ @@ -392,6 +406,7 @@ def _threshold_basis_functions(self): """ Implements the bandlimit threshold which caps the number of basis functions that are actually required. + :return: The final overall number of basis functions to be used. """ # Maximum bandlimit @@ -457,7 +472,7 @@ def _create_basis_functions(self): self.norm_constants = norm_constants self.basis_functions = basis_functions - def _evaluate(self, coeffs): + def _evaluate(self, coefs): """ Evaluates FLE coefficients and return in standard 2D Cartesian coordinates. @@ -466,10 +481,10 @@ def _evaluate(self, coeffs): :return: An Image object containing the corresponding images. """ # convert from FB order - coeffs = coeffs[..., self._fb_to_fle_indices] + coefs = coefs[..., self._fb_to_fle_indices] # See Remark 3.3 and Section 3.4 - betas = self._step3(coeffs) + betas = self._step3(coefs) z = self._step2(betas) im = self._step1(z) return im.astype(self.dtype) @@ -487,11 +502,11 @@ def _evaluate_t(self, imgs): imgs[:, self.radial_mask] = 0 z = self._step1_t(imgs) b = self._step2_t(z) - coeffs = self._step3_t(b) + coefs = self._step3_t(b) # return in FB order - coeffs = coeffs[..., self._fle_to_fb_indices] - return coeffs.astype(self.coefficient_dtype, copy=False) + coefs = coefs[..., self._fle_to_fb_indices] + return coefs.astype(self.coefficient_dtype, copy=False) def _step1_t(self, im): """ @@ -545,31 +560,31 @@ def _step3_t(self, betas): betas = idct(betas, axis=1, type=2) * 2 * betas.shape[1] betas = np.moveaxis(betas, 0, -1) - coeffs = np.zeros((self.count, num_img), dtype=np.float64) + coefs = np.zeros((self.count, num_img), dtype=np.float64) for i in range(self.ell_p_max + 1): - coeffs[self.idx_list[i]] = self.A3[i] @ betas[:, i, :] - coeffs = coeffs.T + coefs[self.idx_list[i]] = self.A3[i] @ betas[:, i, :] + coefs = coefs.T - return coeffs * self.norm_constants / self.h + return coefs * self.norm_constants / self.h - def _step3(self, coeffs): + def _step3(self, coefs): """ Adjoint of _step3_t and Step 1 of the forward transformation (coefficients to images). Uses barycenteric interpolation in reverse to compute values of Betas at Chebyshev nodes, given an array of FLE coefficients. 
""" - coeffs = coeffs.copy().reshape(-1, self.count) - num_img = coeffs.shape[0] - coeffs *= self.h * self.norm_constants - coeffs = coeffs.T + coefs = coefs.copy().reshape(-1, self.count) + num_img = coefs.shape[0] + coefs *= self.h * self.norm_constants + coefs = coefs.T out = np.zeros( (self.num_interp, 2 * self.max_ell + 1, num_img), dtype=np.float64, ) for i in range(self.ell_p_max + 1): - out[:, i, :] = self.A3_T[i] @ coeffs[self.idx_list[i]] + out[:, i, :] = self.A3_T[i] @ coefs[self.idx_list[i]] out = np.moveaxis(out, -1, 0) if self.num_interp > self.num_radial_nodes: out = dct(out, axis=1, type=2) @@ -626,6 +641,7 @@ def _create_dense_matrix(self): """ Directly computes the transformation matrix from Cartesian coordinates to FLE coordinates without any shortcuts. + :return: A NumPy array of size `(self.nres**2, self.count)` containing the matrix entries. """ @@ -642,56 +658,71 @@ def _create_dense_matrix(self): return B - def lowpass(self, coeffs, bandlimit): + def lowpass(self, coefs, bandlimit): """ - Apply a low-pass filter to FLE coefficients `coeffs` with threshold `bandlimit`. - :param coeffs: A NumPy array of FLE coefficients, of shape (num_images, self.count) + Apply a low-pass filter to FLE coefficients `coefs` with threshold `bandlimit`. + + :param coefs: A `Coef` instance containing FLE coefficients. :param bandlimit: Integer bandlimit (max frequency). :return: Band-limited coefficient array. """ - if len(coeffs.shape) == 1: - coeffs = coeffs.reshape((1, coeffs.shape[0])) - assert ( - len(coeffs.shape) == 2 - ), "Input a stack of coefficients of dimension (num_images, self.count)." - assert ( - coeffs.shape[1] == self.count - ), "Number of coefficients must match self.count." + + if not isinstance(coefs, Coef): + raise TypeError( + f"`coefs` should be a `Coef` instance, received {type(coefs)}." + ) + + # Copy to mutate the coefs. + coefs = coefs.asnumpy().copy() k = self.count - 1 for _ in range(self.count): if self.bessel_zeros[k] / (np.pi) > (bandlimit - 1) // 2: k = k - 1 - coeffs[:, k + 1 :] = 0 + coefs[:, k + 1 :] = 0 - return coeffs + return Coef(self, coefs) - def radial_convolve(self, coeffs, radial_img): + def radial_convolve(self, coefs, radial_img): """ Convolve a stack of FLE coefficients with a 2D radial function. - :param coeffs: A NumPy array of FLE coefficients of size (num_images, self.count). + + :param coefs: A `Coef` instance containing FLE coefficients. :param radial_img: A 2D NumPy array of size (self.nres, self.nres). :return: Convolved FLE coefficients. """ - num_img = coeffs.shape[0] - coeffs_conv = np.zeros(coeffs.shape) + + if not isinstance(coefs, Coef): + raise TypeError( + f"`coefs` should be a `Coef` instance, received {type(coefs)}." + ) + + if len(coefs.stack_shape) > 1: + raise NotImplementedError( + "`radial_convolve` currently only implemented for 1D stacks." 
+            )
+
+        coefs = coefs.asnumpy()
+
+        num_img = coefs.shape[0]
+        coefs_conv = np.zeros(coefs.shape)
 
         # Convert to internal FLE indices ordering
-        coeffs = coeffs[..., self._fb_to_fle_indices]
+        coefs = coefs[..., self._fb_to_fle_indices]
 
         for k in range(num_img):
-            _coeffs = coeffs[k, :]
+            _coefs = coefs[k, :]
             z = self._step1_t(radial_img)
             b = self._step2_t(z)
             weights = self._radial_convolve_weights(b)
             b = weights / (self.h**2)
             b = b.reshape(self.count)
-            coeffs_conv[k, :] = np.real(self.c2r @ (b * (self.r2c @ _coeffs).flatten()))
+            coefs_conv[k, :] = np.real(self.c2r @ (b * (self.r2c @ _coefs).flatten()))
 
         # Convert from internal FLE ordering to FB convention
-        coeffs_conv = coeffs_conv[..., self._fle_to_fb_indices]
+        coefs_conv = coefs_conv[..., self._fle_to_fb_indices]
 
-        return coeffs_conv
+        return Coef(self, coefs_conv)
 
     def _radial_convolve_weights(self, b):
         """
@@ -712,3 +743,46 @@
         a[self.idx_list[i]] = y[i]
 
         return a.flatten()
+
+    def filter_to_basis_mat(self, f, **kwargs):
+        """
+        See `SteerableBasis2D.filter_to_basis_mat`.
+        """
+        # Note 'method' and 'truncate' not relevant for this optimized FLE code.
+        method = kwargs.get("method", None)
+        if method is not None:
+            raise NotImplementedError(
+                f"`FLEBasis2D.filter_to_basis_mat` method {method} not supported."
+                " Use `method=None`."
+            )
+
+        # Get the filter's evaluate function.
+        h_fun = f.evaluate
+
+        # Set same dimensions as basis object
+        n_k = 2 * self.num_radial_nodes  # self.n_r
+        n_theta = self.num_angular_nodes  # self.n_theta
+
+        # get 2D grid in polar coordinate
+        k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype)
+        k, theta = np.meshgrid(
+            k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij"
+        )
+
+        # Get function values in polar 2D grid and average out angle contribution
+        # NOTE: should probably just let the ctf objects handle this...
+ omegax = k * np.cos(theta) + omegay = k * np.sin(theta) + omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) + + h_vals2d = h_fun(omega).reshape(n_k, n_theta).astype(self.dtype) + h_vals = np.sum(h_vals2d, axis=1) / n_theta + + h_basis = np.zeros(self.count, dtype=self.dtype) + # For now we just need to handle 1D (stack of one ctf) + for j in range(self.ell_p_max + 1): + h_basis[self.idx_list[j]] = self.A3[j] @ h_vals + + # Convert from internal FLE ordering to FB convention + h_basis = h_basis[self._fle_to_fb_indices] + + return DiagMatrix(h_basis) diff --git a/src/aspire/basis/fpswf_2d.py b/src/aspire/basis/fpswf_2d.py index f030d12e20..5476eed61d 100644 --- a/src/aspire/basis/fpswf_2d.py +++ b/src/aspire/basis/fpswf_2d.py @@ -5,6 +5,7 @@ from scipy.optimize import least_squares from scipy.special import jn +from aspire.basis import ComplexCoef from aspire.basis.basis_utils import lgwt, t_x_mat, t_x_mat_dot from aspire.basis.pswf_2d import PSWFBasis2D from aspire.nufft import nufft @@ -97,7 +98,9 @@ def _precomp(self): n_max, ) = self._pswf_integration_sub_routine() - self.us_fft_pts = us_fft_pts + self.us_fft_pts = us_fft_pts.astype( + self.dtype, copy=False + ) # TODO, debug where this is incorrect dtype self.blk_r = blk_r self.num_angular_pts = num_angular_pts self.r_quad_indices = r_quad_indices @@ -121,9 +124,9 @@ def _evaluate_t(self, images): nfft_res = nufft(images_disk, self.us_fft_pts) # Accumulate coefficients - coefficients = self._pswf_integration(nfft_res) + coefficients = ComplexCoef(self, self._pswf_integration(nfft_res)) - return coefficients + return coefficients.to_real().asnumpy() def _generate_pswf_quad( self, n, bandlimit, phi_approximate_error, lambda_max, epsilon @@ -178,8 +181,8 @@ def _generate_pswf_quad( pts_y = quad_rule_pts_r * np.sin(quad_rule_pts_theta) return ( - pts_x, pts_y, + pts_x, quad_rule_weights, radial_quad_points, quad_rule_radial_weights, @@ -214,7 +217,9 @@ def _generate_pswf_radial_quad( if k % 2 == 0: k = k + 1 - range_array = np.arange(approx_length).reshape((1, approx_length)) + range_array = np.arange(approx_length, dtype=self.dtype).reshape( + (1, approx_length) + ) idx_for_quad_nodes = int((k + 1) / 2) num_quad_pts = idx_for_quad_nodes - 1 @@ -277,14 +282,13 @@ def phi_for_quad_nodes(t): fun_vec = phi_for_quad_nodes(x) sign_flipping_vec = np.where(np.sign(fun_vec[:-1]) != np.sign(fun_vec[1:]))[0] - phi_zeros = np.zeros(idx_for_quad_nodes - 1) + phi_zeros = np.zeros(idx_for_quad_nodes - 1, dtype=self.dtype) tmp = phi_for_quad_nodes(x) for i, j in enumerate(sign_flipping_vec[: idx_for_quad_nodes - 1]): new_zero = x[j] - tmp[j] * (x[j + 1] - x[j]) / (tmp[j + 1] - tmp[j]) phi_zeros[i] = new_zero - phi_zeros = np.array(phi_zeros) return phi_zeros def _sum_minus_cumsum_smaller_eps(self, x, eps): @@ -299,17 +303,17 @@ def _pswf_integration_sub_routine(self): r_quad_indices.extend(num_angular_pts) r_quad_indices = np.cumsum(r_quad_indices, dtype="int") - n_max = int(max(self.ang_freqs) + 1) + n_max = int(max(self.complex_angular_indices) + 1) numel_for_n = np.zeros(n_max, dtype="int") for i in range(n_max): - numel_for_n[i] = np.count_nonzero(self.ang_freqs == i) + numel_for_n[i] = np.count_nonzero(self.complex_angular_indices == i) indices_for_n = [0] indices_for_n.extend(numel_for_n) indices_for_n = np.cumsum(indices_for_n, dtype="int") - blk_r = [0] * n_max + blk_r = [0] * n_max # TODO, consider array here temp_const = self.bandlimit / (2 * np.pi * self.rcut) for i in range(n_max): blk_r[i] = ( @@ -351,13 +355,20 @@ 
def _pswf_integration(self, images_nufft): r_n_eval_mat = r_n_eval_mat.reshape( (len(self.radial_quad_pts) * self.n_max, num_images), order="F" ) - coeff_vec_quad = np.zeros( - (num_images, len(self.ang_freqs)), dtype=complex_type(self.dtype) + coef_vec_quad = np.zeros( + (num_images, len(self.complex_angular_indices)), + dtype=complex_type(self.dtype), ) m = self.pswf_radial_quad.shape[1] for i in range(self.n_max): - coeff_vec_quad[ + coef_vec_quad[ :, self.indices_for_n[i] + np.arange(self.numel_for_n[i]) ] = np.dot(self.blk_r[i], r_n_eval_mat[i * m : (i + 1) * m, :]).T - return coeff_vec_quad + return coef_vec_quad + + def filter_to_basis_mat(self, *args, **kwargs): + """ + See `SteerableBasis2D.filter_to_basis_mat`. + """ + return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/fspca.py b/src/aspire/basis/fspca.py index 31a926472a..16500918af 100644 --- a/src/aspire/basis/fspca.py +++ b/src/aspire/basis/fspca.py @@ -3,7 +3,7 @@ import numpy as np -from aspire.basis import FFBBasis2D, SteerableBasis2D +from aspire.basis import Coef, ComplexCoef, FFBBasis2D, SteerableBasis2D from aspire.covariance import BatchedRotCov2D from aspire.operators import BlkDiagMatrix from aspire.utils import complex_type, fix_signs, real_type @@ -42,7 +42,7 @@ def __init__( Default value of `None` will use `self.basis.count`. :param noise_var: Optionally assign noise variance. Default value of `None` will estimate noise with WhiteNoiseEstimator. - Use 0 when using clean images so cov2d skips applying noisy covar coeffs.. + Use 0 when using clean images so cov2d skips applying noisy covar coefs.. :param batch_size: Batch size for computing basis coefficients. `batch_size` is also passed to BatchedRotCov2D. """ @@ -71,7 +71,7 @@ def __init__( self.complex_count = self.basis.complex_count self.angular_indices = self.basis.angular_indices self.radial_indices = self.basis.radial_indices - self.signs_indices = self.basis._indices["sgns"] + self.signs_indices = self.basis.signs_indices self.complex_angular_indices = self.basis.complex_angular_indices self.complex_radial_indices = self.basis.complex_radial_indices @@ -162,13 +162,13 @@ def build(self): } self.mean_coef_est = cov2d.get_mean() self.covar_coef_est = cov2d.get_covar( - mean_coeff=self.mean_coef_est, + mean_coef=self.mean_coef_est, noise_var=self.noise_var, covar_est_opt=covar_opt, ) # Create the arrays to be packed by _compute_spca - self.eigvals = np.zeros(self.basis.count, dtype=self.dtype) + self._eigvals = np.zeros(self.basis.count, dtype=self.dtype) self.eigvecs = BlkDiagMatrix.empty(2 * self.basis.ell_max + 1, dtype=self.dtype) @@ -211,7 +211,7 @@ def _compute_spca(self): basis_inds.append(_basis_inds) # Store the eigvals for this block, note this is a flat array. - self.eigvals[_basis_inds] = eigvals_k + self._eigvals[_basis_inds] = eigvals_k # Store the eigvecs, note this is a BlkDiagMatrix and is assigned incrementally. self.eigvecs[angular_index] = eigvecs_k @@ -231,7 +231,7 @@ def _compute_spca(self): # # We can pass a full or truncated slice of sorted_indices to any array indexed by # the coefs. This is used later for compression and index re-generation. 
- self.sorted_indices = np.argsort(-np.abs(self.eigvals)) + self.sorted_indices = np.argsort(-np.abs(self._eigvals)) compressed_indices = self._get_compressed_indices() @@ -240,10 +240,10 @@ def _compute_spca(self): ) # Compute coefficient vector of mean image at zeroth component - self.mean_coef_zero = self.mean_coef_est[self.angular_indices == 0] + self.mean_coef_zero = self.mean_coef_est.asnumpy()[0][self.angular_indices == 0] # Define mask for zero angular mode, used in loop below - zero_ell_mask = self.basis._indices["ells"] == 0 + zero_ell_mask = self.basis.angular_indices == 0 # Apply Data matrix batchwise num_batches = (self.src.n + self.batch_size - 1) // self.batch_size @@ -252,6 +252,7 @@ def _compute_spca(self): start = i * self.batch_size finish = min((i + 1) * self.batch_size, self.src.n) batch_coef = self.basis.evaluate_t(self.src.images[start:finish]) + batch_coef = batch_coef.asnumpy() # Make the Data matrix (A_k) # # Construct A_k, matrix of expansion coefficients a^i_k_q @@ -274,9 +275,9 @@ def _compute_spca(self): for ell in range( 1, self.basis.ell_max + 1 ): # `ell` in this code is `k` from paper - mask_ell = self.basis._indices["ells"] == ell - mask_pos = mask_ell & (self.basis._indices["sgns"] == +1) - mask_neg = mask_ell & (self.basis._indices["sgns"] == -1) + mask_ell = self.basis.angular_indices == ell + mask_pos = mask_ell & (self.basis.signs_indices == +1) + mask_neg = mask_ell & (self.basis.signs_indices == -1) A.append(batch_coef[:, mask_pos]) A.append(batch_coef[:, mask_neg]) @@ -330,13 +331,15 @@ def expand(self, x): Fourier Bessel basis. :return: Stack of coefs in the FSPCABasis. """ + if not isinstance(x, Coef): + raise TypeError(f"'x' should be `Coef` instance, received {type(x)}.") # Apply linear combination defined by FSPCA (eigvecs) - c_fspca = x @ self.eigvecs + c_fspca = x.asnumpy() @ self.eigvecs assert c_fspca.shape == (x.shape[0], self.count) - return c_fspca + return Coef(self, c_fspca) def evaluate_to_image_basis(self, c): """ @@ -346,6 +349,9 @@ def evaluate_to_image_basis(self, c): :return: The Image instance representing a stack of images in the standard 2D coordinate basis.. """ + if not isinstance(c, Coef): + raise TypeError(f"'c' should be `Coef` instance, received {type(c)}.") + c_fb = self.evaluate(c) return self.basis.evaluate(c_fb) @@ -358,6 +364,9 @@ def evaluate(self, c): :return: The (real) coefs representing a stack of images in self.basis """ + if not isinstance(c, Coef): + raise TypeError(f"'c' should be `Coef` instance, received {type(c)}.") + # apply FSPCA eigenvector to coefs c, yields coefs in self.basis eigvecs = self.eigvecs if isinstance(eigvecs, BlkDiagMatrix): @@ -368,7 +377,7 @@ def evaluate(self, c): # corrected_c[:, self.angular_indices!=0] *= 2 # return corrected_c @ eigvecs.T - return c @ eigvecs.T + return Coef(self.basis, c @ eigvecs.T) # TODO: Python>=3.8 @cached_property def _get_compressed_indices(self): @@ -396,8 +405,8 @@ def _get_compressed_indices(self): top_components = list(ordered_components)[: self.components] # Now we need to find the locations of both the + and - sgns. - pos_mask = self.basis._indices["sgns"] == 1 - neg_mask = self.basis._indices["sgns"] == -1 + pos_mask = self.basis.signs_indices == 1 + neg_mask = self.basis.signs_indices == -1 compressed_indices = [] for k, q in top_components: # Compute the locations of coefs we're interested in. 
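As an illustrative aside (not part of this patch), the `pos_mask`/`neg_mask` bookkeeping above, and the `to_complex`/`to_real` conversions that follow, rest on one packing convention visible in these hunks: a complex coefficient c is stored as the real pair (2·Re(c), −2·Im(c)) and rebuilt as c = (p − iq)/2. A tiny numpy check of the round trip, with made-up values:

```python
import numpy as np

p, q = 0.7, -1.3                        # real coefficients for a sgn=+1 / sgn=-1 pair
c = (p - 1j * q) / 2.0                  # to_complex convention
p2, q2 = 2.0 * c.real, -2.0 * c.imag    # to_real convention
assert np.allclose([p, q], [p2, q2])    # exact round trip
```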
@@ -435,7 +444,7 @@ def _compress(self):
         compressed_indices = self._get_compressed_indices()
 
         self.count = len(compressed_indices)
-        self.eigvals = self.eigvals[compressed_indices]
+        self._eigvals = self._eigvals[compressed_indices]
         if isinstance(self.eigvecs, BlkDiagMatrix):
             self.eigvecs = self.eigvecs.dense()
         self.eigvecs = self.eigvecs[:, compressed_indices]
@@ -462,11 +471,11 @@ def to_complex(self, coef):
         There is a corresponding method, to_real.
 
         :param coef: Coefficients from this basis.
-        :return: Complex coefficent representation from this basis.
+        :return: Complex coefficient representation from this basis.
         """
-
-        if coef.ndim == 1:
-            coef = coef.reshape(1, -1)
+        if not isinstance(coef, Coef):
+            raise TypeError(f"'coef' should be `Coef` instance, received {type(coef)}.")
+        coef = coef.asnumpy()
 
         if coef.dtype not in (np.float64, np.float32):
             raise TypeError("coef provided to to_complex should be real.")
@@ -503,7 +512,7 @@ def to_complex(self, coef):
         for i, k in enumerate(ccoef_d.keys()):
             ccoef[:, i] = ccoef_d[k]
 
-        return ccoef
+        return ComplexCoef(self, ccoef)
 
     def to_real(self, complex_coef):
         """
@@ -514,7 +523,7 @@ def to_real(self, complex_coef):
         There is a corresponding method, to_complex.
 
         :param complex_coef: Complex coefficients from this basis.
-        :return: Real coefficent representation from this basis.
+        :return: Real coefficient representation from this basis.
         """
 
         if complex_coef.ndim == 1:
@@ -543,7 +552,7 @@ def to_real(self, complex_coef):
             coef[:, pos_i] = 2.0 * complex_coef[:, i].real
             coef[:, neg_i] = -2.0 * complex_coef[:, i].imag
 
-        return coef
+        return Coef(self, coef)
 
     def calculate_bispectrum(
         self, coef, flatten=False, filter_nonzero_freqs=False, freq_cutoff=None
@@ -557,6 +566,13 @@ def calculate_bispectrum(
             freq_cutoff=freq_cutoff,
         )
 
+    @property
+    def eigvals(self):
+        """
+        Return the eigenvalues of FSPCABasis as a Numpy array.
+        """
+        return self._eigvals
+
     def eigen_images(self):
         """
         Return the eigen images of the FSPCA basis, evaluated to image space.
@@ -570,7 +586,7 @@ def eigen_images(self):
         if isinstance(eigvecs, BlkDiagMatrix):
             eigvecs = eigvecs.dense()
 
-        return self.basis.evaluate(eigvecs.T)
+        return Coef(self.basis, eigvecs.T).evaluate()
 
     def shift(self, coef, shifts):
         """
@@ -597,3 +613,15 @@ def shift(self, coef, shifts):
         return self.expand_from_image_basis(
             self.evaluate_to_image_basis(coef).shift(shifts)
         )
+
+    def filter_to_basis_mat(self, f):
+        """
+        Convert a filter into a basis representation.
+
+        :param f: `Filter` object, usually a `CTFFilter`.
+
+        :return: Representation of filter in `basis`.
+            Return type will be based on the class's `matrix_type`.
+        """
+        # This is possible to implement, but there are no current use cases.
+ raise NotImplementedError("Not currently implemented for compressed basis.") diff --git a/src/aspire/basis/pswf_2d.py b/src/aspire/basis/pswf_2d.py index 6fda8661f5..c9795ec1bc 100644 --- a/src/aspire/basis/pswf_2d.py +++ b/src/aspire/basis/pswf_2d.py @@ -2,7 +2,7 @@ import numpy as np -from aspire.basis import Basis +from aspire.basis import Coef, ComplexCoef, SteerableBasis2D from aspire.basis.basis_utils import ( d_decay_approx_fun, k_operator, @@ -12,12 +12,13 @@ t_x_mat, ) from aspire.basis.pswf_utils import BNMatrix -from aspire.utils import complex_type +from aspire.operators import BlkDiagMatrix +from aspire.utils import complex_type, grid_2d logger = logging.getLogger(__name__) -class PSWFBasis2D(Basis): +class PSWFBasis2D(SteerableBasis2D): """ Define a derived class for direct Prolate Spheroidal Wave Function (PSWF) expanding 2D images @@ -33,6 +34,8 @@ class PSWFBasis2D(Basis): Comput. Harmon. Anal. 22, 235-256 (2007). """ + matrix_type = BlkDiagMatrix + def __init__(self, size, gamma_trunc=1.0, beta=1.0, dtype=np.float32): """ Initialize an object for 2D PSWF basis expansion using direct method @@ -58,9 +61,6 @@ def __init__(self, size, gamma_trunc=1.0, beta=1.0, dtype=np.float32): self.beta = beta super().__init__(size, dtype=dtype) - # this basis has complex coefficients - self.coefficient_dtype = complex_type(self.dtype) - def _build(self): """ Build internal data structures for the direct 2D PSWF method @@ -82,22 +82,11 @@ def _build(self): def _generate_grid(self): """ Generate the 2D sampling grid - - TODO: need to re-implement to use the similar grid function as FB methods. """ - if self.nres % 2 == 0: - x_1d_grid = range(-self.rcut, self.rcut) - else: - x_1d_grid = range(-self.rcut, self.rcut + 1) - x_2d_grid, y_2d_grid = np.meshgrid(x_1d_grid, x_1d_grid) - r_2d_grid = np.sqrt(np.square(x_2d_grid) + np.square(y_2d_grid)) - points_in_disk = r_2d_grid <= self.rcut - x = y_2d_grid[points_in_disk] - y = x_2d_grid[points_in_disk] - self._r_disk = np.sqrt(np.square(x) + np.square(y)) / self.rcut - self._theta_disk = np.angle(x + 1j * y) - self._image_height = len(x_1d_grid) - self._disk_mask = points_in_disk + grid = grid_2d(self.nres, normalized=False, indexing="yx") + self._disk_mask = grid["r"] <= self.rcut + self._r_disk = grid["r"][self._disk_mask] / self.rcut + self._theta_disk = grid["phi"][self._disk_mask] def _precomp(self): """ @@ -105,14 +94,6 @@ def _precomp(self): """ self._generate_samples() - self.non_neg_freq_inds = slice(0, len(self.ang_freqs)) - - tmp = np.nonzero(self.ang_freqs == 0)[0] - self.zero_freq_inds = slice(tmp[0], tmp[-1] + 1) - - tmp = np.nonzero(self.ang_freqs > 0)[0] - self.pos_freq_inds = slice(tmp[0], tmp[-1] + 1) - def _generate_samples(self): """ Generate sample points for PSWF functions @@ -138,18 +119,67 @@ def _generate_samples(self): alpha_all.extend(alpha[:n_end]) m += 1 - self.alpha_nn = np.array(alpha_all).reshape(-1, 1) + self.alpha_nn = np.array(alpha_all, dtype=complex_type(self.dtype)).reshape( + -1, 1 + ) self.max_ns = max_ns self.samples = self._evaluate_pswf2d_all(self._r_disk, self._theta_disk, max_ns) - self.ang_freqs = np.repeat(np.arange(len(max_ns)), max_ns).astype("float") - self.rad_freqs = np.concatenate([range(1, i + 1) for i in max_ns]).astype( - "float" + self.complex_angular_indices = np.repeat( + np.arange(len(max_ns), dtype=int), max_ns ) + self.complex_radial_indices = np.concatenate( + [np.arange(1, i + 1, dtype=int) for i in max_ns] + ) + + # Added to support subclassing SteerableBasis + 
self.complex_signs_indices = np.sign(self.complex_angular_indices) + self.samples = (self.beta / 2.0) * self.samples * self.alpha_nn self.samples_conj_transpose = self.samples.conj().transpose() # the column dimension of samples_conj_transpose is the number of basis coefficients - self.count = self.samples_conj_transpose.shape[1] + self.complex_count = self.samples_conj_transpose.shape[1] + + # Add required real indices attributes and maps + # TODO, this block of code can probably be consolidated with + # FB basis. For now, just get everything working together. + nz = np.sum(self.complex_signs_indices == 0) + nnz = self.complex_count - nz + + self.real_count = nz + 2 * nnz + self.count = self.real_count + + self.radial_indices = np.empty(self.real_count, dtype=int) + self.angular_indices = np.empty(self.real_count, dtype=int) + self.signs_indices = np.empty(self.real_count, dtype=int) + + self._pos = np.zeros(self.complex_count, dtype=int) + self._neg = np.zeros(self.complex_count, dtype=int) + + i = 0 + ci = 0 + self.k_max = [] + self.ell_max = np.max(self.complex_angular_indices) + for ell in range(self.ell_max + 1): + sgns = (1,) if ell == 0 else (1, -1) + k_max = np.sum(self.complex_angular_indices == ell) + self.k_max.append(k_max) + ks = np.arange(0, k_max) + + for sgn in sgns: + rng = np.arange(i, i + len(ks)) + self.angular_indices[rng] = ell + self.radial_indices[rng] = ks + self.signs_indices[rng] = sgn + + if sgn == 1: + self._pos[ci + ks] = rng + elif sgn == -1: + self._neg[ci + ks] = rng + + i += len(ks) + + ci += len(ks) def _evaluate_t(self, images): """ @@ -160,24 +190,27 @@ def _evaluate_t(self, images): :return: The evaluation of the coefficient array in the PSWF basis. """ flattened_images = images[:, self._disk_mask] - - return flattened_images @ self.samples_conj_transpose + complex_coef = ComplexCoef(self, flattened_images @ self.samples_conj_transpose) + return complex_coef.to_real().asnumpy() def _evaluate(self, coefficients): """ Evaluate coefficients in standard 2D coordinate basis from those in PSWF basis - :param coeffcients: A coefficient vector (or an array of coefficient + :param coefficients: A coefficient vector (or an array of coefficient vectors) in PSWF basis to be evaluated. (n_image, count) :return : Image in standard 2D coordinate basis. """ + # Convert real coefficient to complex. + coefficients = Coef(self, coefficients).to_complex() + # Handle a single coefficient vector or stack of vectors. 
coefficients = np.atleast_2d(coefficients) n_images = coefficients.shape[0] - angular_is_zero = np.absolute(self.ang_freqs) == 0 + angular_is_zero = np.absolute(self.complex_angular_indices) == 0 flatten_images = coefficients[:, angular_is_zero] @ self.samples[ angular_is_zero @@ -185,9 +218,7 @@ def _evaluate(self, coefficients): coefficients[:, ~angular_is_zero] @ self.samples[~angular_is_zero] ) - images = np.zeros( - (n_images, self._image_height, self._image_height), dtype=self.dtype - ) + images = np.zeros((n_images, self.nres, self.nres), dtype=self.dtype) images[:, self._disk_mask] = np.real(flatten_images) return images @@ -259,7 +290,7 @@ def _evaluate_pswf2d_all(self, r, theta, max_ns): d_vec = self.d_vec_all[i] phase_part = np.exp(1j * i * theta) / np.sqrt(2 * np.pi) - range_array = np.arange(len(d_vec)) + range_array = np.arange(len(d_vec), dtype=self.dtype) r_radial_part_mat = t_radial_part_mat(r, i, range_array, len(d_vec)).dot( d_vec[:, :max_n] ) @@ -364,5 +395,11 @@ def _pswf_2d_minor_computations(self, big_n, n, bandlimit, phi_approximate_error d_vec, _ = BNMatrix(big_n, bandlimit, approx_length).get_eig_vectors() - range_array = np.array(range(approx_length)) + range_array = np.arange(approx_length, dtype=self.dtype) return d_vec, approx_length, range_array + + def filter_to_basis_mat(self, *args, **kwargs): + """ + See `SteerableBasis2D.filter_to_basis_mat`. + """ + return super().filter_to_basis_mat(*args, **kwargs) diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 955c92dc4e..a2b9872886 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -1,20 +1,25 @@ +import abc import logging from collections.abc import Iterable import numpy as np -from aspire.basis import Basis -from aspire.utils import complex_type +from aspire.basis import Basis, Coef, ComplexCoef +from aspire.operators import BlkDiagMatrix +from aspire.utils import LogFilterByCount, complex_type, real_type, trange logger = logging.getLogger(__name__) -class SteerableBasis2D(Basis): +class SteerableBasis2D(Basis, abc.ABC): """ - SteerableBasis2D is an extension of Basis that is expected to have + `SteerableBasis2D` is an extension of Basis that is expected to have `rotation` (steerable) and `calculate_bispectrum` methods. """ + # Default matrix type for basis representation. + matrix_type = BlkDiagMatrix + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -24,6 +29,16 @@ def __init__(self, *args, **kwargs): self._zero_angular_inds = self.angular_indices == 0 self._pos_angular_inds = (self.signs_indices == 1) & (self.angular_indices != 0) self._neg_angular_inds = self.signs_indices == -1 + self._non_neg_angular_inds = self.signs_indices >= 0 + self._blk_diag_cov_shape = None + + # Centralize indices attributes between FB/PSWF/FLE in SteerableBasis2D + self.complex_count = self.count - sum(self._neg_angular_inds) + self.complex_angular_indices = self.angular_indices[self._non_neg_angular_inds] + self.complex_radial_indices = self.radial_indices[self._non_neg_angular_inds] + + # Attribute for caching the blk_diag shape once known. + self._blk_diag_cov_shape = None def calculate_bispectrum( self, complex_coef, flatten=False, filter_nonzero_freqs=False, freq_cutoff=None @@ -43,6 +58,10 @@ def calculate_bispectrum( :return: Bispectum matrix (complex valued). 
""" + if not isinstance(complex_coef, Coef): + raise TypeError(f"Expect `Coef` received {type(complex_coef)}.") + complex_coef = complex_coef.asnumpy() + # Check shape if complex_coef.shape[0] != 1: raise ValueError( @@ -151,16 +170,23 @@ def rotate(self, coef, radians, refl=None): :return: rotated coefs. """ - # Enforce a stack axis to support sanity checks - coef = np.atleast_2d(coef) + if not isinstance(coef, Coef): + raise TypeError(f"`coef` must be `Coef` instance, received {type(coef)}.") + + coef = coef.asnumpy() # Covert radians to a broadcastable shape if isinstance(radians, Iterable): - radians = np.fromiter(radians, dtype=self.dtype).reshape(-1, 1) - if len(radians) != len(coef): + radians = np.array(radians, dtype=self.dtype) + if radians.ndim < 2: + radians = radians.reshape(-1, 1) + else: + radians = np.expand_dims(radians, axis=-1) + + if radians.size != np.prod(coef.shape[:-1]): raise RuntimeError( - "`rotate` call `radians` length cannot broadcast with" - f" `coef` {len(coef)} != {len(radians)}" + f"`rotate` call `radians` {radians.shape} does not match" + f" `coef` {coef.shape[:-1]}." ) # else: radians can be a constant @@ -172,17 +198,17 @@ def rotate(self, coef, radians, refl=None): # For all coef in stack, # compute the ks * radian used in the trig functions ks_rad = np.atleast_2d(self.angular_indices * radians) - ks_pos = ks_rad[:, self._pos_angular_inds] - ks_neg = ks_rad[:, self._neg_angular_inds] + ks_pos = ks_rad[..., self._pos_angular_inds] + ks_neg = ks_rad[..., self._neg_angular_inds] # Slice the coef on postive and negative ells - coef_zer = coef[:, self._zero_angular_inds] - coef_pos = coef[:, self._pos_angular_inds] - coef_neg = coef[:, self._neg_angular_inds] + coef_zer = coef[..., self._zero_angular_inds] + coef_pos = coef[..., self._pos_angular_inds] + coef_neg = coef[..., self._neg_angular_inds] # Handle zero case and avoid mutating the original array coef = np.empty_like(coef) - coef[:, self._zero_angular_inds] = coef_zer + coef[..., self._zero_angular_inds] = coef_zer # refl if refl is not None: @@ -193,14 +219,14 @@ def rotate(self, coef, radians, refl=None): coef_neg[refl] = coef_neg[refl] * -1 # Apply formula - coef[:, self._pos_angular_inds] = coef_pos * np.cos(ks_pos) + coef_neg * np.sin( - ks_neg - ) - coef[:, self._neg_angular_inds] = coef_neg * np.cos(ks_neg) - coef_pos * np.sin( + coef[..., self._pos_angular_inds] = coef_pos * np.cos( ks_pos - ) + ) + coef_neg * np.sin(ks_neg) + coef[..., self._neg_angular_inds] = coef_neg * np.cos( + ks_neg + ) - coef_pos * np.sin(ks_pos) - return coef + return Coef(self, coef) def complex_rotate(self, complex_coef, radians, refl=None): """ @@ -274,3 +300,230 @@ def shift(self, coef, shifts): ) return self.evaluate_t(self.evaluate(coef).shift(shifts)) + + @property + def blk_diag_cov_shape(self): + """ + Return the `BlkDiagMatrix` partition shapes. + + If the shape has already been cached, + returns cached value. Otherwise, will + compute the shape and cache in this instance. + """ + # Compute the _blk_diag_cov_shape as needed. + if self._blk_diag_cov_shape is None: + blks = [] + for ell in range(self.ell_max + 1): + sgns = (1,) if ell == 0 else (1, -1) + for _ in sgns: + blks.append( + [ + self.k_max[ell], + ] + * 2 + ) + self._blk_diag_cov_shape = np.array(blks) + + # Return the cached shape + return self._blk_diag_cov_shape + + # This is included for completion, but is not being adopted yet. 
+    # This is included for completeness, but is not being adopted yet.
+    def indices_mask(self, **kwargs):
+        """
+        Given `radial=` or `angular=` expressions, return (`count`,)
+        shaped mask where values satisfying the expression are `True`.
+
+        Examples:
+        No args yield all indices.
+        `angular=0` creates a mask for selecting coefficients with zero angular indices.
+        `angular=1, radial=2` selects coefficients satisfying angular index of 1 _and_ radial index of 2.
+        More advanced operations can combine indices attributes.
+        `angular=self.angular_indices>=0, radial=r` selects coefficients with non-negative angular indices and some radial index `r`.
+
+        :return: Boolean mask of shape (`count`,).
+            Intended to be broadcast with `Coef` containers.
+        """
+
+        radial = kwargs.get("radial", None)
+        angular = kwargs.get("angular", None)
+        signs = kwargs.get("signs", None)
+
+        # Construct each component mask in turn.
+        signs_mask = np.zeros(self.count, dtype=bool)
+        radial_mask = signs_mask.copy()
+        angular_mask = signs_mask.copy()
+
+        if radial is None:
+            radial_mask[:] = True
+        else:
+            for k in np.atleast_1d(radial):
+                radial_mask[self.radial_indices == k] = True
+
+        if angular is None:
+            angular_mask[:] = True
+        else:
+            for el in np.atleast_1d(angular):
+                angular_mask[self.angular_indices == el] = True
+
+        if signs is None:
+            signs_mask[:] = True
+        else:
+            for s in np.atleast_1d(signs):
+                signs_mask[self.signs_indices == s] = True
+
+        mask = radial_mask & angular_mask & signs_mask
+
+        return mask
+
+    def to_real(self, complex_coef):
+        """
+        Return real valued representation of complex coefficients.
+        This can be useful when comparing, prototyping, or
+        implementing methods from literature.
+
+        There is a corresponding method, `to_complex`.
+
+        :param complex_coef: `ComplexCoef` from this basis.
+        :return: Real `Coef` representation from this basis.
+        """
+
+        if not isinstance(complex_coef, ComplexCoef):
+            raise TypeError(
+                f"complex_coef should be instance of `ComplexCoef`, received {type(complex_coef)}."
+            )
+
+        if complex_coef.dtype not in (np.complex128, np.complex64):
+            raise TypeError("coef provided to to_real should be complex.")
+
+        # Pass through dtype precisions, but check and warn if mismatched.
+        dtype = real_type(complex_coef.dtype)
+        if dtype != self.dtype:
+            logger.warning(
+                f"Complex coef dtype {complex_coef.dtype} does not match precision of basis.dtype {self.dtype}, returning {dtype}."
+            )
+
+        coef = np.zeros((*complex_coef.stack_shape, self.count), dtype=dtype)
+        complex_coef = complex_coef.asnumpy()
+
+        ind = 0
+        idx = np.arange(self.k_max[0], dtype=int)
+        ind += np.size(idx)
+        ind_pos = ind
+
+        coef[..., idx] = complex_coef[..., idx].real
+
+        for ell in range(1, self.ell_max + 1):
+            idx = ind + np.arange(self.k_max[ell], dtype=int)
+            idx_pos = ind_pos + np.arange(self.k_max[ell], dtype=int)
+            idx_neg = idx_pos + self.k_max[ell]
+
+            c = complex_coef[..., idx]
+            coef[..., idx_pos] = 2.0 * np.real(c)
+            coef[..., idx_neg] = -2.0 * np.imag(c)
+
+            ind += np.size(idx)
+            ind_pos += 2 * self.k_max[ell]
+
+        return Coef(self, coef)
+
+    def to_complex(self, coef):
+        """
+        Return complex valued representation of real coefficients.
+        This can be useful when comparing, prototyping, or
+        implementing methods from literature.
+
+        There is a corresponding method, `to_real`.
+
+        :param coef: Real `Coef` from this basis.
+        :return: `ComplexCoef` representation from this basis.
+        """
+
+        if not isinstance(coef, Coef):
+            raise TypeError(
+                f"coef should be instance of `Coef`, received {type(coef)}."
+            )
+
+        if coef.dtype not in (np.float64, np.float32):
+            raise TypeError("coef provided to to_complex should be real.")
+
+        # Pass through dtype precisions, but check and warn if mismatched.
+        dtype = complex_type(coef.dtype)
+        if coef.dtype != self.dtype:
+            logger.warning(
+                f"coef dtype {coef.dtype} does not match precision of basis.dtype {self.dtype}, returning {dtype}."
+            )
+
+        # Return the same precision as coef
+        imaginary = dtype(1j)
+
+        complex_coef = np.zeros((*coef.stack_shape, self.complex_count), dtype=dtype)
+        coef = coef.asnumpy()
+
+        ind = 0
+        idx = np.arange(self.k_max[0], dtype=int)
+        ind += np.size(idx)
+
+        complex_coef[..., idx] = coef[..., idx]
+
+        for ell in range(1, self.ell_max + 1):
+            idx = ind + np.arange(self.k_max[ell], dtype=int)
+            complex_coef[..., idx] = (
+                coef[..., self._pos[idx]] - imaginary * coef[..., self._neg[idx]]
+            ) / 2.0
+
+            ind += np.size(idx)
+
+        return ComplexCoef(self, complex_coef)
+
+    # `abstractmethod` enforces that any new subclass of
+    # `SteerableBasis2D` explicitly implements this method.
+    # This is intended to encourage future basis authors
+    # to consider this method for their application.
+    @abc.abstractmethod
+    def filter_to_basis_mat(self, f, method="evaluate_t", truncate=True):
+        """
+        Convert a filter into a basis operator representation.
+
+        :param f: `Filter` object, usually a `CTFFilter`.
+        :param method: `evaluate_t` or `expand`.
+        :param truncate: Optionally, truncate dense matrix to BlkDiagMatrix.
+            Defaults to True.
+
+        :return: Representation of filter as `basis` operator.
+            Return type will be based on the class's `matrix_type`.
+        """
+        # evaluate_t is not as accurate, but much, much faster...
+        if method == "evaluate_t":
+            expand_method = self.evaluate_t
+        elif method == "expand":
+            expand_method = self.expand
+        else:
+            raise NotImplementedError(
+                f"`filter_to_basis_mat` method {method} not supported."
+                " Try `evaluate_t` or `expand`."
+            )
+
+        coef = Coef(self, np.eye(self.count, dtype=self.dtype))
+        img = coef.evaluate()
+
+        # Expansion can fail for some filters on specific basis vectors.
+        # Expand the filtered basis vectors one by one,
+        # zeroing any that fail.
+        filt = np.zeros((self.count, self.count), self.dtype)
+        with LogFilterByCount(logger, 1):
+            for i in trange(self.count):
+                try:
+                    filt[i] = expand_method(img[i].filter(f)).asnumpy()[0]
+                except Exception:
+                    logger.warning(
+                        f"Failed to expand basis vector {i} after filter {f}."
+                    )
+
+        # Optionally truncate off-block elements to zero.
+        if truncate:
+            filt = BlkDiagMatrix.from_dense(
+                filt,
+                self.blk_diag_cov_shape,
+            )
+
+        return filt
diff --git a/src/aspire/classification/averager2d.py b/src/aspire/classification/averager2d.py
index 4700a7a230..3c487070ab 100644
--- a/src/aspire/classification/averager2d.py
+++ b/src/aspire/classification/averager2d.py
@@ -11,6 +11,7 @@
 from ray.util.multiprocessing import Pool
 
 from aspire import config
+from aspire.basis import Coef
 from aspire.classification.reddy_chatterji import reddy_chatterji_register
 from aspire.image import Image, ImageStacker, MeanImageStacker
 from aspire.utils import trange
@@ -230,7 +231,7 @@ def _innerloop(i):
             )
 
             # Averaging in composite_basis
-            return self.image_stacker(neighbors_coefs)
+            return self.image_stacker(neighbors_coefs.asnumpy())
 
         if self.num_procs <= 1:
             for i in trange(n_classes):
@@ -253,7 +254,7 @@ def _innerloop(i):
                 b_avgs[i] = result
 
         # Now we convert the averaged images from Basis to Cartesian.
- return self.composite_basis.evaluate(b_avgs) + return Coef(self.composite_basis, b_avgs).evaluate() def _shift_search_grid(self, L, radius, roll_zero=False): """ @@ -362,12 +363,12 @@ def _innerloop(k): ) # then store dot between class base image (0) and each nbor - for j, nbor in enumerate(rotated_nbrs): + for j, nbor in enumerate(rotated_nbrs.asnumpy()): # Skip the base image. if j == 0: continue norm_nbor = np.linalg.norm(nbor) - _correlations[j, i] = np.dot(nbr_coef[0], nbor) / ( + _correlations[j, i] = np.dot(nbr_coef.asnumpy()[0], nbor) / ( norm_nbor * norm_0 ) @@ -681,7 +682,7 @@ def _innerloop(i): ) # Averaging in composite_basis - return self.image_stacker(neighbors_coefs) + return self.image_stacker(neighbors_coefs.asnumpy()) if self.num_procs <= 1: for i in trange(n_classes): @@ -704,7 +705,7 @@ def _innerloop(i): b_avgs[i] = result # Now we convert the averaged images from Basis to Cartesian. - return self.composite_basis.evaluate(b_avgs) + return Coef(self.composite_basis, b_avgs).evaluate() class BFSReddyChatterjiAverager2D(ReddyChatterjiAverager2D): diff --git a/src/aspire/classification/legacy_implementations.py b/src/aspire/classification/legacy_implementations.py index e627c19742..2809a175ce 100644 --- a/src/aspire/classification/legacy_implementations.py +++ b/src/aspire/classification/legacy_implementations.py @@ -125,18 +125,18 @@ def bispec_operator_1(freqs): return o1, o2 -def bispec_2drot_large(coeff, freqs, eigval, alpha, sample_n, seed=None): +def bispec_2drot_large(coef, freqs, eigval, alpha, sample_n, seed=None): """ alpha 1/3 sample_n 4000 """ freqs_not_zero = freqs != 0 - coeff_norm = np.log(np.power(np.absolute(coeff[freqs_not_zero]), alpha)) - if np.any(coeff_norm == float("-inf")): - raise ValueError("coeff_norm should not be -inf") + coef_norm = np.log(np.power(np.absolute(coef[freqs_not_zero]), alpha)) + if np.any(coef_norm == float("-inf")): + raise ValueError("coef_norm should not be -inf") - phase = coeff[freqs_not_zero] / np.absolute(coeff[freqs_not_zero]) + phase = coef[freqs_not_zero] / np.absolute(coef[freqs_not_zero]) phase = np.arctan2(np.imag(phase), np.real(phase)) eigval = eigval[freqs_not_zero] o1, o2 = bispec_operator_1(freqs[freqs_not_zero]) @@ -151,15 +151,15 @@ def bispec_2drot_large(coeff, freqs, eigval, alpha, sample_n, seed=None): m_id = np.where(x < sample_n * p_m)[0] o1 = o1[m_id] o2 = o2[m_id] - m = np.exp(o1 * coeff_norm + 1j * o2 * phase) + m = np.exp(o1 * coef_norm + 1j * o2 * phase) # svd of the reduced bispectrum - u, s, v = pca_y(m, 300, seed=seed) + u, s, v = pca_y(m, min(300, len(m)), seed=seed) - coeff_b = np.einsum("i, ij -> ij", s, np.conjugate(v)) - coeff_b_r = np.conjugate(u.T).dot(np.conjugate(m)) + coef_b = np.einsum("i, ij -> ij", s, np.conjugate(v)) + coef_b_r = np.conjugate(u.T).dot(np.conjugate(m)) - coeff_b = coeff_b / np.linalg.norm(coeff_b, axis=0) - coeff_b_r = coeff_b_r / np.linalg.norm(coeff_b_r, axis=0) + coef_b = coef_b / np.linalg.norm(coef_b, axis=0) + coef_b_r = coef_b_r / np.linalg.norm(coef_b_r, axis=0) - return coeff_b, coeff_b_r + return coef_b, coef_b_r diff --git a/src/aspire/classification/rir_class2d.py b/src/aspire/classification/rir_class2d.py index a20ce3c0d2..6927d19bc5 100644 --- a/src/aspire/classification/rir_class2d.py +++ b/src/aspire/classification/rir_class2d.py @@ -4,7 +4,7 @@ import numpy as np from sklearn.neighbors import NearestNeighbors -from aspire.basis import FSPCABasis +from aspire.basis import Coef, ComplexCoef, FSPCABasis from aspire.classification import Class2D from 
aspire.classification.legacy_implementations import bispec_2drot_large, pca_y from aspire.numeric import ComplexPCA @@ -174,14 +174,11 @@ def classify(self, diagnostics=False): self.src, components=self.fspca_components, batch_size=self.batch_size ) - # For convenience, assign the fb_basis used in the pca_basis. - self.fb_basis = self.pca_basis.basis - # Get the expanded coefs in the compressed FSPCA space. self.fspca_coef = self.pca_basis.spca_coef # Compute Bispectrum - coef_b, coef_b_r = self.bispectrum(self.fspca_coef) + coef_b, coef_b_r = self.bispectrum(Coef(self.pca_basis, self.fspca_coef)) # # Stage 2: Compute Nearest Neighbors logger.info(f"Calculate Nearest Neighbors using {self._nn_implementation}.") @@ -251,10 +248,16 @@ def bispectrum(self, coef): :param coef: complex steerable coefficients (eg. from FSPCABasis). :returns: tuple of arrays (coef_b, coef_b_r) """ + + if not isinstance(coef, Coef): + raise TypeError( + f"`coef` should be a `Coef` instance, received {type(coef)}" + ) + # _bispectrum is assigned during initialization. return self._bispectrum(coef) - def _sk_nn_classification(self, coeff_b, coeff_b_r): + def _sk_nn_classification(self, coef_b, coef_b_r): """ Perform nearest neighbor classification using scikit learn. @@ -269,10 +272,10 @@ def _sk_nn_classification(self, coeff_b, coeff_b_r): # so we'll pretend we have 2*n_features of real values. # Don't worry about the copy because NearestNeighbors wants # C-contiguous anyway... (it would copy internally otherwise) - X = np.column_stack((coeff_b.real, coeff_b.imag)) + X = np.column_stack((coef_b.real, coef_b.imag)) # We'll also want to consider the neighbors under reflection. - # These coefficients should be provided by coeff_b_r - X_r = np.column_stack((coeff_b_r.real, coeff_b_r.imag)) + # These coefficients should be provided by coef_b_r + X_r = np.column_stack((coef_b_r.real, coef_b_r.imag)) # We can compare both non-reflected and reflected representations as one large set by # taking care later that we store refl=True where indices>=n_img @@ -291,7 +294,7 @@ def _sk_nn_classification(self, coeff_b, coeff_b_r): return classes, refl, distances - def _legacy_nn_classification(self, coeff_b, coeff_b_r): + def _legacy_nn_classification(self, coef_b, coef_b_r): """ Perform nearest neighbor classification using port of ASPIRE legacy MATLAB code. 
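The `_sk_nn_classification` hunk above searches plain and reflected representations as one set by stacking real and imaginary parts into real feature vectors. A toy sketch of that arrangement (all shapes and values invented):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)

# Toy complex feature vectors standing in for coef_b and coef_b_r.
coef_b = rng.normal(size=(8, 4)) + 1j * rng.normal(size=(8, 4))
coef_b_r = coef_b.conj()  # stand-in for the reflected features

# Split complex values into twice as many real features.
X = np.column_stack((coef_b.real, coef_b.imag))
X_r = np.column_stack((coef_b_r.real, coef_b_r.imag))

# Search both sets at once; indices >= n_img imply a reflection.
n_img = len(X)
nn = NearestNeighbors(n_neighbors=3).fit(np.concatenate((X, X_r)))
distances, classes = nn.kneighbors(X)
refl = classes >= n_img
```

Treating reflections as extra rows keeps a single index space, so storing `refl = classes >= n_img` afterward recovers which neighbors were reflected.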
@@ -299,8 +302,8 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r): """ # Note kept ordering from legacy code (n_features, n_img) - coeff_b = coeff_b.T - coeff_b_r = coeff_b_r.T + coef_b = coef_b.T + coef_b_r = coef_b_r.T n_im = self.src.n # Shouldn't have more neighbors than images @@ -311,7 +314,7 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r): ) n_nbor = n_im - 1 - concat_coeff = np.concatenate((coeff_b, coeff_b_r), axis=1) + concat_coef = np.concatenate((coef_b, coef_b_r), axis=1) num_batches = (n_im + self.batch_size - 1) // self.batch_size @@ -321,8 +324,8 @@ def _legacy_nn_classification(self, coeff_b, coeff_b_r): for i in trange(num_batches): start = i * self.batch_size finish = min((i + 1) * self.batch_size, n_im) - batch = np.conjugate(coeff_b[:, start:finish]) - corr = np.real(np.dot(batch.T, concat_coeff)) + batch = np.conjugate(coef_b[:, start:finish]) + corr = np.real(np.dot(batch.T, concat_coef)) assert np.all( np.abs(corr) <= 1.01 # Allow some numerical wiggle @@ -432,7 +435,7 @@ def _devel_bispectrum(self, coef): for i in trange(self.src.n): B = self.pca_basis.calculate_bispectrum( - coef_normed[i, np.newaxis], + ComplexCoef(self.pca_basis, coef_normed[i]), filter_nonzero_freqs=True, freq_cutoff=self.bispectrum_freq_cutoff, ) @@ -480,10 +483,18 @@ def _legacy_bispectrum(self, coef, retry_attempts=3): :return: Compressed feature and reflected feature vectors. """ + if not isinstance(coef, Coef): + raise TypeError( + f"`coef` should be a `Coef` instance, received {type(coef)}" + ) + # The legacy code expects the complex representation - coef = self.pca_basis.to_complex(coef) - complex_eigvals = self.pca_basis.to_complex(self.pca_basis.eigvals).reshape( - self.pca_basis.complex_count + coef = self.pca_basis.to_complex(coef).asnumpy() + complex_eigvals = ( + Coef(self.pca_basis, self.pca_basis.eigvals) + .to_complex() + .asnumpy() + .reshape(self.pca_basis.complex_count) ) # flatten # bispec_2drot_large has a random selection component. 
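The batched loop above ranks neighbors by the real part of a conjugate inner product over unit-normalized feature columns. A small standalone sketch with invented data:

```python
import numpy as np

rng = np.random.default_rng(0)

# Legacy layout: columns are images; normalize so correlations lie in [-1, 1].
coef_b = rng.normal(size=(16, 6)) + 1j * rng.normal(size=(16, 6))
coef_b /= np.linalg.norm(coef_b, axis=0)
concat_coef = np.concatenate((coef_b, np.conjugate(coef_b)), axis=1)

# One batch of the correlation computation.
batch = np.conjugate(coef_b[:, 0:3])
corr = np.real(batch.T @ concat_coef)  # shape (3, 12)

# Largest correlations first; drop column 0, the self-match.
nbrs = np.argsort(-corr, axis=1)[:, 1:4]
```

Here the conjugated copy is only a stand-in for the reflected feature block; normalization keeps `np.abs(corr) <= 1` up to numerical wiggle, matching the assertion in the hunk above.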
@@ -496,7 +507,7 @@ def _legacy_bispectrum(self, coef, retry_attempts=3): _seed = self.seed or 0 while attempt < retry_attempts: coef_b, coef_b_r = bispec_2drot_large( - coeff=coef.T, # Note F style transpose here and in return + coef=coef.T, # Note F style transpose here and in return freqs=self.pca_basis.complex_angular_indices, eigval=complex_eigvals, alpha=self.alpha, diff --git a/src/aspire/commands/denoise.py b/src/aspire/commands/denoise.py index e0900e69c9..8a56a18f8e 100644 --- a/src/aspire/commands/denoise.py +++ b/src/aspire/commands/denoise.py @@ -4,7 +4,7 @@ from aspire.basis import FFBBasis2D from aspire.commands import log_level_option -from aspire.denoising.denoiser_cov2d import DenoiserCov2D +from aspire.denoising import DenoisedSource, DenoiserCov2D from aspire.noise import AnisotropicNoiseEstimator, WhiteNoiseEstimator from aspire.source.relion import RelionSource from aspire.utils.logging import setConsoleLoggingLevel @@ -101,7 +101,7 @@ def denoise( if denoise_method == "CWF": logger.info("Denoise the images using CWF cov2D method.") denoiser = DenoiserCov2D(source, basis) - denoised_src = denoiser.denoise(batch_size=512) + denoised_src = DenoisedSource(source, denoiser) denoised_src.save( starfile_out, batch_size=512, save_mode="single", overwrite=False ) diff --git a/src/aspire/config_default.yaml b/src/aspire/config_default.yaml index 1768d18f08..cb82637a9b 100644 --- a/src/aspire/config_default.yaml +++ b/src/aspire/config_default.yaml @@ -1,4 +1,4 @@ -version: 0.12.0 +version: 0.12.1 common: # numeric module to use - one of numpy/cupy numeric: numpy diff --git a/src/aspire/covariance/covar.py b/src/aspire/covariance/covar.py index ad72849221..3e9a64f545 100644 --- a/src/aspire/covariance/covar.py +++ b/src/aspire/covariance/covar.py @@ -87,15 +87,15 @@ def compute_kernel(self): def estimate(self, mean_vol, noise_variance, tol=1e-5, regularizer=0): logger.info("Running Covariance Estimator") - b_coeff = self.src_backward(mean_vol, noise_variance) - est_coeff = self.conj_grad(b_coeff, tol=tol, regularizer=regularizer) - covar_est = self.basis.mat_evaluate(est_coeff) + b_coef = self.src_backward(mean_vol, noise_variance) + est_coef = self.conj_grad(b_coef, tol=tol, regularizer=regularizer) + covar_est = self.basis.mat_evaluate(est_coef) covar_est = vecmat_to_volmat(make_symmat(volmat_to_vecmat(covar_est))) return covar_est - def conj_grad(self, b_coeff, tol=1e-5, regularizer=0): - b_coeff = symmat_to_vec_iso(b_coeff) - N = b_coeff.shape[0] + def conj_grad(self, b_coef, tol=1e-5, regularizer=0): + b_coef = symmat_to_vec_iso(b_coef) + N = b_coef.shape[0] kernel = self.kernel if regularizer > 0: @@ -118,41 +118,41 @@ def conj_grad(self, b_coeff, tol=1e-5, regularizer=0): dtype=self.dtype, ) - target_residual = tol * norm(b_coeff) + target_residual = tol * norm(b_coef) def cb(xk): logger.info( - f"Delta {norm(b_coeff - self.apply_kernel(xk, packed=True))} (target {target_residual})" + f"Delta {norm(b_coef - self.apply_kernel(xk, packed=True))} (target {target_residual})" ) x, info = scipy.sparse.linalg.cg( - operator, b_coeff, M=M, callback=cb, tol=tol, atol=0 + operator, b_coef, M=M, callback=cb, tol=tol, atol=0 ) if info != 0: raise RuntimeError("Unable to converge!") return vec_to_symmat_iso(x) - def apply_kernel(self, coeff, kernel=None, packed=False): + def apply_kernel(self, coef, kernel=None, packed=False): """ Applies the kernel represented by convolution - :param coeff: The volume matrix (6 dimensions) to be convolved (but see the `packed` argument below). 
+ :param coef: The volume matrix (6 dimensions) to be convolved (but see the `packed` argument below). :param kernel: a Kernel object. If None, the kernel for this Estimator is used. - :param packed: whether the `coeff` matrix represents an isometrically mapped packed vector, - through the `symmat_to_vec_iso` function. In this case, the function expands `coeff` into a symmetric + :param packed: whether the `coef` matrix represents an isometrically mapped packed vector, + through the `symmat_to_vec_iso` function. In this case, the function expands `coef` into a symmetric matrix internally, and returns a packed vector in return. - :return: The result of evaluating `coeff` in the given basis, convolving with the kernel given by + :return: The result of evaluating `coef` in the given basis, convolving with the kernel given by kernel, and backprojecting into the basis. If `packed` is True, then the isometrically mapped packed vector is returned instead. """ if kernel is None: kernel = self.kernel if packed: - coeff = vec_to_symmat_iso(coeff) + coef = vec_to_symmat_iso(coef) result = self.basis.mat_evaluate_t( - kernel.convolve_volume_matrix(self.basis.mat_evaluate(coeff)) + kernel.convolve_volume_matrix(self.basis.mat_evaluate(coef)) ) return symmat_to_vec_iso(result) if packed else result @@ -182,14 +182,14 @@ def src_backward(self, mean_vol, noise_variance, shrink_method=None): covar_b += vecmat_to_volmat(im_centered_b.T @ im_centered_b) / self.src.n - covar_b_coeff = self.basis.mat_evaluate_t(covar_b) - return self._shrink(covar_b_coeff, noise_variance, shrink_method) + covar_b_coef = self.basis.mat_evaluate_t(covar_b) + return self._shrink(covar_b_coef, noise_variance, shrink_method) - def _shrink(self, covar_b_coeff, noise_variance, method=None): + def _shrink(self, covar_b_coef, noise_variance, method=None): """ Shrink covariance matrix - :param covar_b_coeff: Outer products of the mean-subtracted images + :param covar_b_coef: Outer products of the mean-subtracted images :param noise_variance: Noise variance :param method: One of None/'frobenius_norm'/'operator_norm'/'soft_threshold' :return: Shrunk covariance matrix @@ -203,8 +203,8 @@ def _shrink(self, covar_b_coeff, noise_variance, method=None): An = self.basis.mat_evaluate_t(self.mean_kernel.toeplitz()) if method is None: - covar_b_coeff -= noise_variance * An + covar_b_coef -= noise_variance * An else: raise NotImplementedError("Only default shrink method supported.") - return covar_b_coeff + return covar_b_coef diff --git a/src/aspire/covariance/covar2d.py b/src/aspire/covariance/covar2d.py index d0b1507009..a4f1971b78 100644 --- a/src/aspire/covariance/covar2d.py +++ b/src/aspire/covariance/covar2d.py @@ -4,7 +4,8 @@ from numpy.linalg import eig, inv from scipy.linalg import solve, sqrtm -from aspire.operators import BlkDiagMatrix, RadialCTFFilter +from aspire.basis import Coef, FFBBasis2D +from aspire.operators import BlkDiagMatrix, DiagMatrix from aspire.optimization import conj_grad, fill_struct from aspire.utils import make_symmat from aspire.utils.matlab_compat import m_reshape @@ -100,121 +101,141 @@ def __init__(self, basis): self.dtype = self.basis.dtype assert basis.ndim == 2, "Only two-dimensional basis functions are needed." - def _get_mean(self, coeffs): + def _ctf_identity_mat(self): + """ + Returns CTF identity corresponding to the `matrix_type` of `self.basis`. 
+
+        :return: Identity BlkDiagMatrix or DiagMatrix
+        """
+        if self.basis.matrix_type == DiagMatrix:
+            return DiagMatrix.eye(self.basis.count, dtype=self.dtype)
+        else:
+            return BlkDiagMatrix.eye(self.basis.blk_diag_cov_shape, dtype=self.dtype)
+
+    def _get_mean(self, coefs):
         """
         Calculate the mean vector from the expansion coefficients of 2D images without CTF information.
 
-        :param coeffs: A coefficient vector (or an array of coefficient vectors) to be averaged.
+        :param coefs: A coefficient vector (or an array of coefficient vectors) to be averaged.
         :return: The mean value vector for all images.
         """
-        if coeffs.size == 0:
+
+        if coefs.size == 0:
             raise RuntimeError("The coefficients need to be calculated first!")
-        mask = self.basis._indices["ells"] == 0
-        mean_coeff = np.zeros(self.basis.count, dtype=coeffs.dtype)
-        mean_coeff[mask] = np.mean(coeffs[..., mask], axis=0)
 
-        return mean_coeff
+        mean_coef = np.zeros(self.basis.count, dtype=coefs.dtype)
+        mean_coef[self.basis._zero_angular_inds] = np.mean(
+            coefs[..., self.basis._zero_angular_inds], axis=0
+        )
+
+        return mean_coef
 
-    def _get_covar(self, coeffs, mean_coeff=None, do_refl=True):
+    def _get_covar(self, coefs, mean_coef=None, do_refl=True):
         """
         Calculate the covariance matrix from the expansion coefficients without CTF information.
 
-        :param coeffs: A coefficient vector (or an array of coefficient vectors) calculated from 2D images.
-        :param mean_coeff: The mean vector calculated from the `coeffs`.
+        :param coefs: A coefficient vector (or an array of coefficient vectors) calculated from 2D images.
+        :param mean_coef: The mean vector calculated from the `coefs`.
         :param do_refl: If True, enforce invariance to reflection (default True).
         :return: The covariance matrix of coefficients for all images.
         """
-        if coeffs.size == 0:
+        if coefs.size == 0:
             raise RuntimeError("The coefficients need to be calculated first!")
-        if mean_coeff is None:
-            mean_coeff = self._get_mean(coeffs)
+        if mean_coef is None:
+            mean_coef = self._get_mean(coefs)
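The `_get_mean` hunk above now selects the zero-angular block via `_zero_angular_inds` rather than the private `_indices` dict; only `ell = 0` coefficients survive averaging over in-plane rotation. A toy illustration with invented indices:

```python
import numpy as np

# Invented angular indices for a 5-coefficient basis.
angular_indices = np.array([0, 0, 1, 1, 2])
zero_mask = angular_indices == 0

# The mean is nonzero only on the ell = 0 block.
coefs = np.arange(10.0).reshape(2, 5)  # stand-in coefficient stack
mean_coef = np.zeros(5)
mean_coef[zero_mask] = np.mean(coefs[..., zero_mask], axis=0)
print(mean_coef)  # [2.5 3.5 0.  0.  0. ]
```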
         # Initialize a totally empty BlkDiagMatrix, build incrementally.
-        covar_coeff = BlkDiagMatrix.empty(0, dtype=coeffs.dtype)
+        covar_coef = BlkDiagMatrix.empty(0, dtype=coefs.dtype)
 
         ell = 0
-        mask = self.basis._indices["ells"] == ell
-        coeff_ell = coeffs[..., mask] - mean_coeff[mask]
-        covar_ell = np.array(coeff_ell.T @ coeff_ell / coeffs.shape[0])
-        covar_coeff.append(covar_ell)
+
+        mask = self.basis.angular_indices == ell
+
+        coef_ell = coefs[..., mask] - mean_coef[mask]
+        covar_ell = np.array(coef_ell.T @ coef_ell / coefs.shape[0])
+        covar_coef.append(covar_ell)
 
         for ell in range(1, self.basis.ell_max + 1):
-            mask_ell = self.basis._indices["ells"] == ell
-            mask_pos = mask_ell & (self.basis._indices["sgns"] == +1)
-            mask_neg = mask_ell & (self.basis._indices["sgns"] == -1)
+            mask_ell = self.basis.angular_indices == ell
+            mask_pos = mask_ell & (self.basis.signs_indices == +1)
+            mask_neg = mask_ell & (self.basis.signs_indices == -1)
 
             covar_ell_diag = np.array(
-                coeffs[:, mask_pos].T @ coeffs[:, mask_pos]
-                + coeffs[:, mask_neg].T @ coeffs[:, mask_neg]
-            ) / (2 * coeffs.shape[0])
+                coefs[:, mask_pos].T @ coefs[:, mask_pos]
+                + coefs[:, mask_neg].T @ coefs[:, mask_neg]
+            ) / (2 * coefs.shape[0])
 
             if do_refl:
-                covar_coeff.append(covar_ell_diag)
-                covar_coeff.append(covar_ell_diag)
+                covar_coef.append(covar_ell_diag)
+                covar_coef.append(covar_ell_diag)
             else:
                 covar_ell_off = np.array(
                     (
-                        coeffs[:, mask_pos] @ coeffs[:, mask_neg].T / coeffs.shape[0]
-                        - coeffs[:, mask_pos].T @ coeffs[:, mask_neg]
+                        coefs[:, mask_pos] @ coefs[:, mask_neg].T / coefs.shape[0]
+                        - coefs[:, mask_pos].T @ coefs[:, mask_neg]
                     )
-                    / (2 * coeffs.shape[0])
+                    / (2 * coefs.shape[0])
                 )
 
                 hsize = covar_ell_diag.shape[0]
-                covar_coeff_blk = np.zeros((2, hsize, 2, hsize))
+                covar_coef_blk = np.zeros((2, hsize, 2, hsize))
 
-                covar_coeff_blk[0:2, :, 0:2, :] = covar_ell_diag[:hsize, :hsize]
-                covar_coeff_blk[0, :, 1, :] = covar_ell_off[:hsize, :hsize]
-                covar_coeff_blk[1, :, 0, :] = covar_ell_off.T[:hsize, :hsize]
+                covar_coef_blk[0:2, :, 0:2, :] = covar_ell_diag[:hsize, :hsize]
+                covar_coef_blk[0, :, 1, :] = covar_ell_off[:hsize, :hsize]
+                covar_coef_blk[1, :, 0, :] = covar_ell_off.T[:hsize, :hsize]
 
-                covar_coeff.append(covar_coeff_blk.reshape(2 * hsize, 2 * hsize))
+                covar_coef.append(covar_coef_blk.reshape(2 * hsize, 2 * hsize))
 
-        return covar_coeff
+        return covar_coef
 
-    def get_mean(self, coeffs, ctf_fb=None, ctf_idx=None):
+    def get_mean(self, coefs, ctf_basis=None, ctf_idx=None):
         """
         Calculate the mean vector from the expansion coefficients with CTF information.
 
-        :param coeffs: A coefficient vector (or an array of coefficient vectors) to be averaged.
-        :param ctf_fb: The CFT functions in the FB expansion.
-        :param ctf_idx: An array of the CFT function indices for all 2D images.
-            If ctf_fb or ctf_idx is None, the identity filter will be applied.
+        :param coefs: A coefficient vector (or an array of coefficient vectors) to be averaged.
+        :param ctf_basis: The CTF functions in the Basis expansion.
+        :param ctf_idx: An array of the CTF function indices for all 2D images.
+            If ctf_basis or ctf_idx is None, the identity filter will be applied.
         :return: The mean value vector for all images.
         """
-        if coeffs.size == 0:
+        if not isinstance(coefs, Coef):
+            raise TypeError(
+                f"`coefs` should be instance of `Coef`, received {type(coefs)}."
+            )
+
+        coefs = coefs.asnumpy()
+
+        # TODO: Redundant, remove?
+        if coefs.size == 0:
             raise RuntimeError("The coefficients need to be calculated!")
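The CTF handling below accumulates weighted normal equations, `b += w_k * C_k.T.apply(mu_k)` and `A += w_k * (C_k.T @ C_k)`, then returns `A.solve(b)`. A dense stand-in for those block-diagonal operators (all values hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
p = 4  # stand-in for basis.count

# Dense stand-ins for per-group CTF operators and per-group means.
ctfs = [np.diag(rng.uniform(0.5, 1.0, p)) for _ in range(2)]
means = [rng.normal(size=p) for _ in range(2)]
weights = [0.25, 0.75]

# b = sum_k w_k C_k^T mu_k and A = sum_k w_k C_k^T C_k, then solve A x = b.
b = sum(w * C.T @ m for w, C, m in zip(weights, ctfs, means))
A = sum(w * C.T @ C for w, C in zip(weights, ctfs))
mean_coef = np.linalg.solve(A, b)
```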
 
         # should assert we require none or both...
-        if (ctf_fb is None) or (ctf_idx is None):
-            ctf_idx = np.zeros(coeffs.shape[0], dtype=int)
-            ctf_fb = [
-                BlkDiagMatrix.eye_like(
-                    RadialCTFFilter().fb_mat(self.basis), dtype=self.dtype
-                )
-            ]
+        if (ctf_basis is None) or (ctf_idx is None):
+            ctf_idx = np.zeros(coefs.shape[0], dtype=int)
+            ctf_basis = [self._ctf_identity_mat()]
 
-        b = np.zeros(self.basis.count, dtype=coeffs.dtype)
+        b = np.zeros(self.basis.count, dtype=coefs.dtype)
 
-        A = BlkDiagMatrix.zeros_like(ctf_fb[0])
+        A = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape)
         for k in np.unique(ctf_idx[:]).T:
-            coeff_k = coeffs[ctf_idx == k]
-            weight = coeff_k.shape[0] / coeffs.shape[0]
-            mean_coeff_k = self._get_mean(coeff_k)
+            coef_k = coefs[ctf_idx == k]
+            weight = coef_k.shape[0] / coefs.shape[0]
+            mean_coef_k = self._get_mean(coef_k)
 
-            ctf_fb_k = ctf_fb[k]
-            ctf_fb_k_t = ctf_fb_k.T
-            b += weight * ctf_fb_k_t.apply(mean_coeff_k)
-            A += weight * (ctf_fb_k_t @ ctf_fb_k)
+            ctf_basis_k = ctf_basis[k]
+            ctf_basis_k_t = ctf_basis_k.T
+            b += weight * ctf_basis_k_t.apply(mean_coef_k)
+            A += weight * (ctf_basis_k_t @ ctf_basis_k)
 
-        mean_coeff = A.solve(b)
-        return mean_coeff
+        mean_coef = A.solve(b)
+        return Coef(self.basis, mean_coef)
 
     def get_covar(
         self,
-        coeffs,
-        ctf_fb=None,
+        coefs,
+        ctf_basis=None,
         ctf_idx=None,
-        mean_coeff=None,
+        mean_coef=None,
         do_refl=True,
         noise_var=0,
         covar_est_opt=None,
@@ -223,34 +244,36 @@
         """
         Calculate the covariance matrix from the expansion coefficients and CTF information.
 
-        :param coeffs: A coefficient vector (or an array of coefficient vectors) to be calculated.
-        :param ctf_fb: The CFT functions in the FB expansion.
-        :param ctf_idx: An array of the CFT function indices for all 2D images.
-            If ctf_fb or ctf_idx is None, the identity filter will be applied.
-        :param mean_coeff: The mean value vector from all images.
-        :param noise_var: The estimated variance of noise. The value should be zero for `coeffs`
+        :param coefs: A coefficient vector (or an array of coefficient vectors) to be calculated.
+        :param ctf_basis: The CTF functions in the Basis expansion.
+        :param ctf_idx: An array of the CTF function indices for all 2D images.
+            If ctf_basis or ctf_idx is None, the identity filter will be applied.
+        :param mean_coef: The mean value vector from all images.
+        :param noise_var: The estimated variance of noise. The value should be zero for `coefs`
             from clean images of simulation data.
         :param covar_est_opt: The optimization parameter list for obtaining the Cov2D matrix.
         :param make_psd: If True, make the covariance matrix positive semidefinite
         :return: The basis coefficients of the covariance matrix in the form of cell array representing
             a block diagonal matrix. These block diagonal matrices are implemented as BlkDiagMatrix instances.
-            The covariance is calculated from the images represented by the coeffs array,
+            The covariance is calculated from the images represented by the coefs array,
             along with all possible rotations and reflections. As a result, the computed covariance
-            matrix is invariant to both reflection and rotation. The effect of the filters in ctf_fb
+            matrix is invariant to both reflection and rotation. The effects of the filters in ctf_basis
             are accounted for and inverted to yield a covariance estimate of the unfiltered images.
         """
 
-        if coeffs.size == 0:
+        if not isinstance(coefs, Coef):
+            raise TypeError(
+                f"`coefs` should be instance of `Coef`, received {type(coefs)}."
+ ) + coefs = coefs.asnumpy() + + if coefs.size == 0: raise RuntimeError("The coefficients need to be calculated!") - if (ctf_fb is None) or (ctf_idx is None): - ctf_idx = np.zeros(coeffs.shape[0], dtype=int) - ctf_fb = [ - BlkDiagMatrix.eye_like( - RadialCTFFilter().fb_mat(self.basis), dtype=self.dtype - ) - ] + if (ctf_basis is None) or (ctf_idx is None): + ctf_idx = np.zeros(coefs.shape[0], dtype=int) + ctf_basis = [self._ctf_identity_mat()] def identity(x): return x @@ -268,44 +291,48 @@ def identity(x): covar_est_opt = fill_struct(covar_est_opt, default_est_opt) - if mean_coeff is None: - mean_coeff = self.get_mean(coeffs, ctf_fb, ctf_idx) + if mean_coef is None: + mean_coef = self.get_mean(Coef(self.basis, coefs), ctf_basis, ctf_idx) - b_coeff = BlkDiagMatrix.zeros_like(ctf_fb[0]) - b_noise = BlkDiagMatrix.zeros_like(ctf_fb[0]) + b_coef = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape) + b_noise = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape) A = [] - for _ in range(len(ctf_fb)): - A.append(BlkDiagMatrix.zeros_like(ctf_fb[0])) + for _ in range(len(ctf_basis)): + A.append(BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape)) - M = BlkDiagMatrix.zeros_like(ctf_fb[0]) + M = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape) for k in np.unique(ctf_idx[:]): - coeff_k = coeffs[ctf_idx == k].astype(self.dtype) - weight = coeff_k.shape[0] / coeffs.shape[0] + coef_k = coefs[ctf_idx == k].astype(self.dtype) + weight = coef_k.shape[0] / coefs.shape[0] + + ctf_basis_k = ctf_basis[k] + ctf_basis_k_t = ctf_basis_k.T + mean_coef_k = ctf_basis_k.apply(mean_coef.asnumpy()[0]) + covar_coef_k = self._get_covar(coef_k, mean_coef_k) - ctf_fb_k = ctf_fb[k] - ctf_fb_k_t = ctf_fb_k.T - mean_coeff_k = ctf_fb_k.apply(mean_coeff) - covar_coeff_k = self._get_covar(coeff_k, mean_coeff_k) + b_coef += weight * (ctf_basis_k_t @ covar_coef_k @ ctf_basis_k) - b_coeff += weight * (ctf_fb_k_t @ covar_coeff_k @ ctf_fb_k) + ctf_basis_k_sq = ctf_basis_k_t @ ctf_basis_k + b_noise += weight * ctf_basis_k_sq - ctf_fb_k_sq = ctf_fb_k_t @ ctf_fb_k - b_noise += weight * ctf_fb_k_sq + A_k = np.sqrt(weight) * ctf_basis_k_sq + if not isinstance(A_k, BlkDiagMatrix): + A_k = DiagMatrix(A_k).as_blk_diag(self.basis.blk_diag_cov_shape) - A[k] = np.sqrt(weight) * ctf_fb_k_sq + A[k] = A_k M += A[k] - if not b_coeff.check_psd(): + if not b_coef.check_psd(): logger.warning("Left side b in Cov2D is not positive semidefinite.") if covar_est_opt["shrinker"] is None: - b = b_coeff - noise_var * b_noise + b = b_coef - noise_var * b_noise else: b = self.shrink_covar_backward( - b_coeff, + b_coef, b_noise, - np.size(coeffs, 0), + np.size(coefs, 0), noise_var, covar_est_opt["shrinker"], ) @@ -319,7 +346,7 @@ def identity(x): cg_opt = covar_est_opt - covar_coeff = BlkDiagMatrix.zeros_like(ctf_fb[0]) + covar_coef = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape) def precond_fun(S, x): p = np.size(S, 0) @@ -346,18 +373,18 @@ def apply(A, x): b_ell = m_reshape(b[ell], (p**2,)) S = inv(M[ell]) cg_opt["preconditioner"] = lambda x, S=S: precond_fun(S, x) - covar_coeff_ell, _, _ = conj_grad( + covar_coef_ell, _, _ = conj_grad( lambda x, A_ell=A_ell: apply(A_ell, x), b_ell, cg_opt ) - covar_coeff[ell] = m_reshape(covar_coeff_ell, (p, p)) + covar_coef[ell] = m_reshape(covar_coef_ell, (p, p)) - if not covar_coeff.check_psd(): + if not covar_coef.check_psd(): logger.warning("Covariance matrix in Cov2D is not positive semidefinite.") if make_psd: logger.info("Convert matrices to positive semidefinite.") - covar_coeff = covar_coeff.make_psd() + 
covar_coef = covar_coef.make_psd()
 
-        return covar_coeff
+        return covar_coef
 
     def shrink_covar_backward(self, b, b_noise, n, noise_var, shrinker):
         """
@@ -383,76 +410,83 @@ def shrink_covar_backward(self, b, b_noise, n, noise_var, shrinker):
             b_out[ell] = b_ell
         return b_out
 
-    def get_cwf_coeffs(
+    def get_cwf_coefs(
         self,
-        coeffs,
-        ctf_fb=None,
+        coefs,
+        ctf_basis=None,
         ctf_idx=None,
-        mean_coeff=None,
-        covar_coeff=None,
+        mean_coef=None,
+        covar_coef=None,
         noise_var=0,
     ):
         """
         Estimate the expansion coefficients using the Covariance Wiener Filtering (CWF) method.
 
-        :param coeffs: A coefficient vector (or an array of coefficient vectors) to be calculated.
-        :param ctf_fb: The CFT functions in the FB expansion.
-        :param ctf_idx: An array of the CFT function indices for all 2D images.
-            If ctf_fb or ctf_idx is None, the identity filter will be applied.
-        :param mean_coeff: The mean value vector from all images.
-        :param covar_coeff: The block diagonal covariance matrix of the clean coefficients represented by a cell array.
-        :param noise_var: The estimated variance of noise. The value should be zero for `coeffs`
+        :param coefs: A coefficient vector (or an array of coefficient vectors) to be calculated.
+        :param ctf_basis: The CTF functions in the Basis expansion.
+        :param ctf_idx: An array of the CTF function indices for all 2D images.
+            If ctf_basis or ctf_idx is None, the identity filter will be applied.
+        :param mean_coef: The mean value vector from all images.
+        :param covar_coef: The block diagonal covariance matrix of the clean coefficients represented by a cell array.
+        :param noise_var: The estimated variance of noise. The value should be zero for `coefs`
            from clean images of simulation data.
         :return: The estimated coefficients of the unfiltered images in certain math basis.
             These are obtained using a Wiener filter with the specified covariance for the clean images
            and white noise of variance `noise_var` for the noise.
         """
-        if mean_coeff is None:
-            mean_coeff = self.get_mean(coeffs, ctf_fb, ctf_idx)
+        if not isinstance(coefs, Coef):
+            raise TypeError(
+                f"`coefs` should be instance of `Coef`, received {type(coefs)}."
+            )
+
+        if mean_coef is None:
+            mean_coef = self.get_mean(coefs, ctf_basis, ctf_idx)
 
-        if covar_coeff is None:
-            covar_coeff = self.get_covar(
-                coeffs, ctf_fb, ctf_idx, mean_coeff, noise_var=noise_var
+        if covar_coef is None:
+            covar_coef = self.get_covar(
+                coefs, ctf_basis, ctf_idx, mean_coef, noise_var=noise_var
             )
 
+        coefs = coefs.asnumpy()
+
         # Handle CTF arguments.
-        if (ctf_fb is None) ^ (ctf_idx is None):
+        if (ctf_basis is None) ^ (ctf_idx is None):
             raise RuntimeError(
-                "Both `ctf_fb` and `ctf_idx` should be provided,"
+                "Both `ctf_basis` and `ctf_idx` should be provided,"
                 " or both should be `None`."
- f' Given {"ctf_fb" if ctf_idx is None else "ctf_idx"}' + f' Given {"ctf_basis" if ctf_idx is None else "ctf_idx"}' ) - elif ctf_fb is None: + elif ctf_basis is None: # Setup defaults for CTF - ctf_idx = np.zeros(coeffs.shape[0], dtype=int) - ctf_fb = [BlkDiagMatrix.eye_like(covar_coeff)] + ctf_idx = np.zeros(coefs.shape[0], dtype=int) + ctf_basis = [BlkDiagMatrix.eye_like(covar_coef)] - noise_covar_coeff = noise_var * BlkDiagMatrix.eye_like(covar_coeff) + noise_covar_coef = noise_var * BlkDiagMatrix.eye_like(covar_coef) - coeffs_est = np.zeros_like(coeffs) + coefs_est = np.zeros_like(coefs) for k in np.unique(ctf_idx[:]): - coeff_k = coeffs[ctf_idx == k] - ctf_fb_k = ctf_fb[k] - ctf_fb_k_t = ctf_fb_k.T + coef_k = coefs[ctf_idx == k] + ctf_basis_k = ctf_basis[k] + ctf_basis_k_t = ctf_basis_k.T - mean_coeff_k = ctf_fb_k.apply(mean_coeff) - coeff_est_k = coeff_k - mean_coeff_k + mean_coef_k = ctf_basis_k.apply(mean_coef.asnumpy()[0]) + coef_est_k = coef_k - mean_coef_k if noise_var == 0: - coeff_est_k = ctf_fb_k.solve(coeff_est_k.T).T + coef_est_k = ctf_basis_k.solve(coef_est_k.T).T else: - sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t - sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff + sig_covar_coef = ctf_basis_k @ covar_coef @ ctf_basis_k_t + sig_noise_covar_coef = sig_covar_coef + noise_covar_coef - coeff_est_k = sig_noise_covar_coeff.solve(coeff_est_k.T).T - coeff_est_k = (covar_coeff @ ctf_fb_k_t).apply(coeff_est_k.T).T + coef_est_k = sig_noise_covar_coef.solve(coef_est_k.T).T + coef_est_k = (covar_coef @ ctf_basis_k_t).apply(coef_est_k.T).T - coeff_est_k = coeff_est_k + mean_coeff - coeffs_est[ctf_idx == k] = coeff_est_k + coef_est_k = coef_est_k + mean_coef + coefs_est[ctf_idx == k] = coef_est_k - return coeffs_est + return Coef(self.basis, coefs_est) class BatchedRotCov2D(RotCov2D): @@ -492,57 +526,61 @@ def _build(self): src = self.src if self.basis is None: - from aspire.basis import FFBBasis2D - self.basis = FFBBasis2D((src.L, src.L), dtype=self.dtype) if not src.unique_filters: logger.info("CTF filters are not included in Cov2D denoising") # set all CTF filters to an identity filter self.ctf_idx = np.zeros(src.n, dtype=int) - self.ctf_fb = [BlkDiagMatrix.eye_like(RadialCTFFilter().fb_mat(self.basis))] + self.ctf_basis = [self._ctf_identity_mat()] + else: - logger.info("Represent CTF filters in FB basis") + logger.info("Represent CTF filters in basis") unique_filters = src.unique_filters self.ctf_idx = src.filter_indices - self.ctf_fb = [f.fb_mat(self.basis) for f in unique_filters] + self.ctf_basis = [self.basis.filter_to_basis_mat(f) for f in unique_filters] def _calc_rhs(self): src = self.src basis = self.basis - ctf_fb = self.ctf_fb + ctf_basis = self.ctf_basis ctf_idx = self.ctf_idx - zero_coeff = np.zeros((basis.count,), dtype=self.dtype) + zero_coef = np.zeros((basis.count,), dtype=self.dtype) - b_mean = [np.zeros(basis.count, dtype=self.dtype) for _ in ctf_fb] + b_mean = [np.zeros(basis.count, dtype=self.dtype) for _ in ctf_basis] - b_covar = BlkDiagMatrix.zeros_like(ctf_fb[0]) + b_covar = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, dtype=self.dtype) for start in range(0, src.n, self.batch_size): batch = np.arange(start, min(start + self.batch_size, src.n)) im = src.images[batch[0] : batch[0] + len(batch)] - coeff = basis.evaluate_t(im) + coef = basis.evaluate_t(im).asnumpy() for k in np.unique(ctf_idx[batch]): - coeff_k = coeff[ctf_idx[batch] == k] - weight = np.size(coeff_k, 0) / src.n + coef_k = coef[ctf_idx[batch] == k] + weight = 
np.size(coef_k, 0) / src.n - mean_coeff_k = self._get_mean(coeff_k) + mean_coef_k = self._get_mean(coef_k) - ctf_fb_k = ctf_fb[k] - ctf_fb_k_t = ctf_fb_k.T + ctf_basis_k = ctf_basis[k] + ctf_basis_k_t = ctf_basis_k.T - b_mean_k = weight * ctf_fb_k_t.apply(mean_coeff_k) + b_mean_k = weight * ctf_basis_k_t.apply(mean_coef_k) + + if isinstance(b_mean_k, DiagMatrix): + # Convert to a column vector + b_mean_k = b_mean_k.asnumpy().T b_mean[k] += b_mean_k - covar_coeff_k = self._get_covar(coeff_k, zero_coeff) + covar_coef_k = self._get_covar(coef_k, zero_coef) + + b_covar_k = ctf_basis_k_t @ covar_coef_k - b_covar_k = ctf_fb_k_t @ covar_coeff_k - b_covar_k = b_covar_k @ ctf_fb_k + b_covar_k = b_covar_k @ ctf_basis_k b_covar_k *= weight b_covar += b_covar_k @@ -553,24 +591,23 @@ def _calc_rhs(self): def _calc_op(self): src = self.src - ctf_fb = self.ctf_fb + ctf_basis = self.ctf_basis ctf_idx = self.ctf_idx - A_mean = BlkDiagMatrix.zeros_like(ctf_fb[0]) - A_covar = [None for _ in ctf_fb] - M_covar = BlkDiagMatrix.zeros_like(ctf_fb[0]) + A_mean = BlkDiagMatrix.zeros(self.basis.blk_diag_cov_shape, self.dtype) + A_covar = [None for _ in ctf_basis] + M_covar = BlkDiagMatrix.zeros_like(A_mean) for k in np.unique(ctf_idx): weight = np.count_nonzero(ctf_idx == k) / src.n - ctf_fb_k = ctf_fb[k] - ctf_fb_k_t = ctf_fb_k.T + ctf_basis_k = ctf_basis[k] + ctf_basis_k_t = ctf_basis_k.T - ctf_fb_k_sq = ctf_fb_k_t @ ctf_fb_k - A_mean_k = weight * ctf_fb_k_sq + ctf_basis_k_sq = ctf_basis_k_t @ ctf_basis_k + A_mean_k = weight * ctf_basis_k_sq A_mean += A_mean_k - - A_covar_k = np.sqrt(weight) * ctf_fb_k_sq + A_covar_k = np.sqrt(weight) * ctf_basis_k_sq A_covar[k] = A_covar_k M_covar += A_covar_k @@ -579,13 +616,13 @@ def _calc_op(self): self.A_covar = A_covar self.M_covar = M_covar - def _mean_correct_covar_rhs(self, b_covar, b_mean, mean_coeff): + def _mean_correct_covar_rhs(self, b_covar, b_mean, mean_coef): src = self.src - ctf_fb = self.ctf_fb + ctf_basis = self.ctf_basis ctf_idx = self.ctf_idx - partition = ctf_fb[0].partition + partition = self.basis.blk_diag_cov_shape # Note: If we don't do this, we'll be modifying the stored `b_covar` # since the operations below are in-place. 
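The next hunk subtracts a rank-one mean correction from the `ell = 0` block of `b_covar` in place, which is why the note above warns about mutating the stored value. A toy version of that correction (sizes and values invented):

```python
import numpy as np

rng = np.random.default_rng(0)
p = 3  # stand-in for the ell = 0 block size

mean_coef_k = rng.normal(size=p)  # CTF-applied mean, stand-in
b_mean_k = rng.normal(size=p)     # accumulated weighted mean RHS
weight = 0.5

correction = (
    np.outer(mean_coef_k, b_mean_k)
    + np.outer(b_mean_k, mean_coef_k)
    - weight * np.outer(mean_coef_k, mean_coef_k)
)

b_covar_0 = rng.normal(size=(p, p))
b_covar_0 = b_covar_0 - correction  # mirrors `b_covar[0] -= correction`
```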
@@ -594,19 +631,19 @@ def _mean_correct_covar_rhs(self, b_covar, b_mean, mean_coeff):
         for k in np.unique(ctf_idx):
             weight = np.count_nonzero(ctf_idx == k) / src.n
 
-            ctf_fb_k = ctf_fb[k]
-            ctf_fb_k_t = ctf_fb_k.T
+            ctf_basis_k = ctf_basis[k]
+            ctf_basis_k_t = ctf_basis_k.T
 
-            mean_coeff_k = ctf_fb_k.apply(mean_coeff)
-            mean_coeff_k = ctf_fb_k_t.apply(mean_coeff_k)
+            mean_coef_k = ctf_basis_k.apply(mean_coef.asnumpy()[0])
+            mean_coef_k = ctf_basis_k_t.apply(mean_coef_k)
 
-            mean_coeff_k = mean_coeff_k[: partition[0][0]]
+            mean_coef_k = mean_coef_k[: partition[0][0]]
             b_mean_k = b_mean[k][: partition[0][0]]
 
             correction = (
-                np.outer(mean_coeff_k, b_mean_k)
-                + np.outer(b_mean_k, mean_coeff_k)
-                - weight * np.outer(mean_coeff_k, mean_coeff_k)
+                np.outer(mean_coef_k, b_mean_k)
+                + np.outer(b_mean_k, mean_coef_k)
+                - weight * np.outer(mean_coef_k, mean_coef_k)
             )
 
             b_covar[0] -= correction
@@ -625,8 +662,26 @@ def _noise_correct_covar_rhs(self, b_covar, b_noise, noise_var, shrinker):
         return b_covar
 
     def _solve_covar(self, A_covar, b_covar, M, covar_est_opt):
-        ctf_fb = self.ctf_fb
+        method = self._solve_covar_cg
+        if self.basis.matrix_type == DiagMatrix:
+            method = self._solve_covar_direct
+
+        return method(A_covar, b_covar, M, covar_est_opt)
 
+    def _solve_covar_direct(self, A_covar, b_covar, M, covar_est_opt):
+        # A_covar is a list of DiagMatrix, representing each ctf in self.basis.
+        # b_covar is a BlkDiagMatrix
+        # M is sum of weighted A squared, only used for cg, ignore here.
+        A_covar = DiagMatrix(np.concatenate([x.asnumpy() for x in A_covar]))
+        A2i = A_covar * A_covar
+
+        res = BlkDiagMatrix.empty(b_covar.nblocks, self.dtype)
+        for b in range(b_covar.nblocks):
+            res.data[b] = b_covar[b] / A2i[b]
+
+        return res
+
+    def _solve_covar_cg(self, A_covar, b_covar, M, covar_est_opt):
         def precond_fun(S, x):
             p = np.size(S, 0)
             assert np.size(x) == p * p, "The sizes of S and x are not consistent."
@@ -645,7 +700,9 @@ def apply(A, x):
             return y
 
         cg_opt = covar_est_opt
-        covar_coeff = BlkDiagMatrix.zeros_like(ctf_fb[0])
+        covar_coef = BlkDiagMatrix.zeros(
+            self.basis.blk_diag_cov_shape, dtype=self.dtype
+        )
 
         for ell in range(0, len(b_covar)):
             A_ell = []
@@ -655,12 +712,12 @@ def apply(A, x):
             b_ell = m_reshape(b_covar[ell], (p**2,))
             S = inv(M[ell])
             cg_opt["preconditioner"] = lambda x, S=S: precond_fun(S, x)
-            covar_coeff_ell, _, _ = conj_grad(
+            covar_coef_ell, _, _ = conj_grad(
                 lambda x, A_ell=A_ell: apply(A_ell, x), b_ell, cg_opt
             )
-            covar_coeff[ell] = m_reshape(covar_coeff_ell, (p, p))
+            covar_coef[ell] = m_reshape(covar_coef_ell, (p, p))
 
-        return covar_coeff
+        return covar_coef
 
     def get_mean(self):
         """
@@ -677,19 +734,17 @@ def get_mean(self):
         self._calc_op()
 
         b_mean_all = np.stack(self.b_mean).sum(axis=0)
-        mean_coeff = self.A_mean.solve(b_mean_all)
+        mean_coef = self.A_mean.solve(b_mean_all)
 
-        return mean_coeff
+        return Coef(self.basis, mean_coef)
 
-    def get_covar(
-        self, noise_var=0, mean_coeff=None, covar_est_opt=None, make_psd=True
-    ):
+    def get_covar(self, noise_var=0, mean_coef=None, covar_est_opt=None, make_psd=True):
         """
         Calculate the block diagonal covariance matrix in the basis coefficients.
 
         :param noise_var: The variance of the noise in the images (default 0)
-        :param mean_coeff: If specified, overrides the mean coefficient vector
+        :param mean_coef: If specified, overrides the mean coefficient vector
             used to calculate the covariance (default `self.get_mean()`).
         :param covar_est_opt: The estimation parameters for obtaining the covariance
             matrix in the form of a dictionary.
Keys include:
@@ -738,12 +793,12 @@ def identity(x):
         if not self.A_covar or self.M_covar:
             self._calc_op()
 
-        if mean_coeff is None:
-            mean_coeff = self.get_mean()
+        if mean_coef is None:
+            mean_coef = self.get_mean()
 
         b_covar = self.b_covar
 
-        b_covar = self._mean_correct_covar_rhs(b_covar, self.b_mean, mean_coeff)
+        b_covar = self._mean_correct_covar_rhs(b_covar, self.b_mean, mean_coef)
 
         if not b_covar.check_psd():
             logger.warning("Left side b in Batched Cov2D is not positive semidefinite.")
@@ -756,78 +811,84 @@
                 "in Batched Cov2D is not positive semidefinite."
             )
 
-        covar_coeff = self._solve_covar(
+        covar_coef = self._solve_covar(
             self.A_covar, b_covar, self.M_covar, covar_est_opt
         )
 
-        if not covar_coeff.check_psd():
+        if not covar_coef.check_psd():
             logger.warning(
                 "Covariance matrix in Batched Cov2D is not positive semidefinite."
             )
 
         if make_psd:
             logger.info("Convert matrices to positive semidefinite.")
-            covar_coeff = covar_coeff.make_psd()
+            covar_coef = covar_coef.make_psd()
 
-        return covar_coeff
+        return covar_coef
 
-    def get_cwf_coeffs(
-        self, coeffs, ctf_fb, ctf_idx, mean_coeff, covar_coeff, noise_var=0
+    def get_cwf_coefs(
+        self, coefs, ctf_basis, ctf_idx, mean_coef, covar_coef, noise_var=0
     ):
         """
         Estimate the expansion coefficients using the Covariance Wiener Filtering (CWF) method.
 
-        :param coeffs: A coefficient vector (or an array of coefficient vectors) to be calculated.
-        :param ctf_fb: The CFT functions in the FB expansion.
-        :param ctf_idx: An array of the CFT function indices for all 2D images.
-            If ctf_fb or ctf_idx is None, the identity filter will be applied.
-        :param mean_coeff: The mean value vector from all images.
-        :param covar_coeff: The block diagonal covariance matrix of the clean coefficients represented by a cell array.
-        :param noise_var: The estimated variance of noise. The value should be zero for `coeffs`
+        :param coefs: A coefficient vector (or an array of coefficient vectors) to be calculated.
+        :param ctf_basis: The CTF functions in the Basis expansion.
+        :param ctf_idx: An array of the CTF function indices for all 2D images.
+            If ctf_basis or ctf_idx is None, the identity filter will be applied.
+        :param mean_coef: The mean value vector from all images.
+        :param covar_coef: The block diagonal covariance matrix of the clean coefficients represented by a cell array.
+        :param noise_var: The estimated variance of noise. The value should be zero for `coefs`
             from clean images of simulation data.
         :return: The estimated coefficients of the unfiltered images in certain math basis.
            These are obtained using a Wiener filter with the specified covariance for the clean images
            and white noise of variance `noise_var` for the noise.
         """
-        if mean_coeff is None:
-            mean_coeff = self.get_mean()
+        if not isinstance(coefs, Coef):
+            raise TypeError(
+                f"`coefs` should be instance of `Coef`, received {type(coefs)}."
+            )
+        coefs = coefs.asnumpy()
+
+        if mean_coef is None:
+            mean_coef = self.get_mean()
 
-        if covar_coeff is None:
-            covar_coeff = self.get_covar(noise_var=noise_var, mean_coeff=mean_coeff)
+        if covar_coef is None:
+            covar_coef = self.get_covar(noise_var=noise_var, mean_coef=mean_coef)
 
         # Handle CTF arguments.
-        if (ctf_fb is None) ^ (ctf_idx is None):
+        if (ctf_basis is None) ^ (ctf_idx is None):
             raise RuntimeError(
-                "Both `ctf_fb` and `ctf_idx` should be provided,"
+                "Both `ctf_basis` and `ctf_idx` should be provided,"
                 " or both should be `None`."
- f' Given {"ctf_fb" if ctf_idx is None else "ctf_idx"}' + f' Given {"ctf_basis" if ctf_idx is None else "ctf_idx"}' ) - elif ctf_fb is None: + elif ctf_basis is None: # Setup defaults for CTF - ctf_idx = np.zeros(coeffs.shape[0], dtype=int) - ctf_fb = [BlkDiagMatrix.eye_like(covar_coeff)] + ctf_idx = np.zeros(coefs.shape[0], dtype=int) + ctf_basis = [BlkDiagMatrix.eye_like(covar_coef)] - noise_covar_coeff = noise_var * BlkDiagMatrix.eye_like(covar_coeff) + noise_covar_coef = noise_var * BlkDiagMatrix.eye_like(covar_coef) - coeffs_est = np.zeros_like(coeffs) + coefs_est = np.zeros_like(coefs) for k in np.unique(ctf_idx[:]): - coeff_k = coeffs[ctf_idx == k] - ctf_fb_k = ctf_fb[k] - ctf_fb_k_t = ctf_fb_k.T + coef_k = coefs[ctf_idx == k] + ctf_basis_k = ctf_basis[k] + ctf_basis_k_t = ctf_basis_k.T - mean_coeff_k = ctf_fb_k.apply(mean_coeff) - coeff_est_k = coeff_k - mean_coeff_k + mean_coef_k = ctf_basis_k.apply(mean_coef.asnumpy()[0]) + coef_est_k = coef_k - mean_coef_k if noise_var == 0: - coeff_est_k = ctf_fb_k.solve(coeff_est_k.T).T + coef_est_k = ctf_basis_k.solve(coef_est_k.T).T else: - sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t - sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff + sig_covar_coef = ctf_basis_k @ covar_coef @ ctf_basis_k_t + sig_noise_covar_coef = sig_covar_coef + noise_covar_coef - coeff_est_k = sig_noise_covar_coeff.solve(coeff_est_k.T).T - coeff_est_k = (covar_coeff @ ctf_fb_k_t).apply(coeff_est_k.T).T + coef_est_k = sig_noise_covar_coef.solve(coef_est_k.T).T + coef_est_k = (covar_coef @ ctf_basis_k_t).apply(coef_est_k.T).T - coeff_est_k = coeff_est_k + mean_coeff - coeffs_est[ctf_idx == k] = coeff_est_k + coef_est_k = coef_est_k + mean_coef + coefs_est[ctf_idx == k] = coef_est_k - return coeffs_est + return Coef(self.basis, coefs_est) diff --git a/src/aspire/ctf/ctf_estimator.py b/src/aspire/ctf/ctf_estimator.py index 7f9fab3571..b6983cb071 100644 --- a/src/aspire/ctf/ctf_estimator.py +++ b/src/aspire/ctf/ctf_estimator.py @@ -1,8 +1,26 @@ """ +Contains code supporting CTF parameter estimation. +Generally, this is a port of ASPIRE-CTF from MATLAB. + +See paper: + + | "Reducing bias and variance for CTF estimation in single particle cryo-EM" + | Ayelet Heimowitz, Joakim Andén, Amit Singer + | Ultramicroscopy, Volume 212, 2020 + | https://doi.org/10.1016/j.ultramic.2020.112950. + +Note: +``CtfEstimator`` computes the background as a monotonically decreasing +function of spatial frequency. This practice may lead to an inaccurate +background estimation for experimental images produced using a K2 +camera in counting mode, as the background in this case is not +monotonically decreasing. Despite this, CTF parameters are captured +successfully in such situations. + Created on Sep 10, 2019 @author: Ayelet Heimowitz, Amit Moscovich -Integrated into ASPIRE by Garrett Wright Feb 2021. +Integrated into ASPIRE-Python by Garrett Wright Feb 2021. 
""" import logging @@ -15,7 +33,7 @@ from scipy.optimize import linprog from scipy.signal.windows import dpss -from aspire.basis.ffb_2d import FFBBasis2D +from aspire.basis import Coef, FFBBasis2D from aspire.image import Image from aspire.numeric import fft from aspire.storage import StarFile @@ -266,19 +284,19 @@ def elliptical_average(self, ffbbasis, amplitude_spectrum, circular): """ # RCOPT, come back and change the indices for this method - coeffs_s = ffbbasis.evaluate_t(amplitude_spectrum).T - coeffs_n = coeffs_s.copy() + coefs_s = ffbbasis.evaluate_t(amplitude_spectrum).asnumpy().copy().T + coefs_n = coefs_s.copy() - coeffs_s[np.argwhere(ffbbasis._indices["ells"] == 1)] = 0 + coefs_s[np.argwhere(ffbbasis.angular_indices == 1)] = 0 if circular: - coeffs_s[np.argwhere(ffbbasis._indices["ells"] == 2)] = 0 + coefs_s[np.argwhere(ffbbasis.angular_indices == 2)] = 0 noise = amplitude_spectrum else: - coeffs_n[np.argwhere(ffbbasis._indices["ells"] == 0)] = 0 - coeffs_n[np.argwhere(ffbbasis._indices["ells"] == 2)] = 0 - noise = ffbbasis.evaluate(coeffs_n.T) + coefs_n[np.argwhere(ffbbasis.angular_indices == 0)] = 0 + coefs_n[np.argwhere(ffbbasis.angular_indices == 2)] = 0 + noise = Coef(ffbbasis, coefs_n.T).evaluate() - psd = ffbbasis.evaluate(coeffs_s.T) + psd = Coef(ffbbasis, coefs_s.T).evaluate() return psd, noise diff --git a/src/aspire/denoising/__init__.py b/src/aspire/denoising/__init__.py index 96bd4a8213..920638eb44 100644 --- a/src/aspire/denoising/__init__.py +++ b/src/aspire/denoising/__init__.py @@ -1,5 +1,9 @@ from .adaptive_support import adaptive_support from .class_avg import ClassAvgSource, DebugClassAvgSource, DefaultClassAvgSource -from .denoised_src import DenoisedImageSource + +# isort: off from .denoiser import Denoiser from .denoiser_cov2d import DenoiserCov2D, src_wiener_coords +from .denoised_src import DenoisedSource + +# isort: on diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index d297841156..c9d3f7dade 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -348,6 +348,28 @@ def _images(self, indices): # Finally, apply transforms to resulting Images return self.generation_pipeline.forward(im, indices) + def _get_classifier_basis(self, classifier): + """ + Returns underlying basis of a classifier. + + For classifiers using compressed basis, + returns the underlying uncompressed basis. + + Defaults to `FFBBasis2D` when `pca_basis` is not found. + + :param classifier: Class2D subclass to query. + :return: `classifier` basis + """ + + if hasattr(classifier, "pca_basis") and classifier.pca_basis is not None: + basis = classifier.pca_basis.basis + else: + # In the cases where a basis is not defined yet, + # construct a FFBBasis2D default. 
+            basis = FFBBasis2D(classifier.src.L, dtype=classifier.dtype)
+
+        return basis
+

 # The following sub classes attempt to pack sensible defaults
 # into ClassAvgSource so that users don't need to
@@ -407,7 +429,7 @@ def __init__(

         if averager is None:
             averager = BFRAverager2D(
-                FFBBasis2D(src.L, dtype=src.dtype),
+                self._get_classifier_basis(classifier),
                 src,
                 num_procs=num_procs,
                 dtype=dtype,
@@ -544,7 +566,7 @@ def __init__(
             if averager_src is None:
                 averager_src = src

-            basis_2d = FFBBasis2D(averager_src.L, dtype=dtype)
+            basis_2d = self._get_classifier_basis(classifier)

             averager = BFSRAverager2D(
                 composite_basis=basis_2d,
diff --git a/src/aspire/denoising/denoised_src.py b/src/aspire/denoising/denoised_src.py
index 94d4c1e4e9..ac891f20d4 100644
--- a/src/aspire/denoising/denoised_src.py
+++ b/src/aspire/denoising/denoised_src.py
@@ -1,31 +1,36 @@
 import logging

-import numpy as np
-
-from aspire.image import Image
+from aspire.denoising import Denoiser
 from aspire.source import ImageSource

 logger = logging.getLogger(__name__)


-class DenoisedImageSource(ImageSource):
+class DenoisedSource(ImageSource):
     """
-    Define a derived ImageSource class to perform operations for denoised 2D images
+    `ImageSource` class serving denoised 2D images.
     """

-    def __init__(self, src, denoiser, batch_size=512):
+    def __init__(self, src, denoiser):
         """
-        Initialize a denoised ImageSource object from original ImageSource of noisy images
+        Initialize a denoised `ImageSource` object from an `ImageSource`.

-        :param src: Original ImageSource object storing noisy images
-        :param denoiser: A Denoiser object for specifying a method for denoising
-        :param batch_size: Batch size for loading denoised images.
+        :param src: Original `ImageSource` object storing noisy images
+        :param denoiser: A `Denoiser` object for specifying a method for denoising
         """
         super().__init__(src.L, src.n, dtype=src.dtype, metadata=src._metadata.copy())
-        self._im = None
+        # TODO, we can probably set up a reasonable default here.
         self.denoiser = denoiser
-        self.batch_size = batch_size
+        if not isinstance(denoiser, Denoiser):
+            raise TypeError("`denoiser` must be a subclass of `Denoiser`")
+
+        # Safety check that `src` and `self.denoiser.src` are the same.
+        # See #1020
+        if src != self.denoiser.src:
+            raise NotImplementedError(
+                "Denoiser `src` and noisy image `src` must match."
+            )

         # Any further operations should not mutate this instance.
         self._mutable = False
@@ -34,9 +39,11 @@ def _images(self, indices):
         """
         Internal function to return a set of images after denoising,
         when accessed via the `ImageSource.images` property
-        :param indices: The indices of images to return as a 1-D NumPy array.
-        :return: an `Image` object after denoisng.
+
+        :param indices: The indices of images to return as a 1-D NumPy array.
+        :return: an `Image` object after denoising.
""" + # check for cached images first if self._cached_im is not None: logger.info("Loading images from cache") @@ -44,23 +51,7 @@ def _images(self, indices): self._cached_im[indices, :, :], indices ) - # start and end (and indices) refer to the indices in the DenoisedImageSource - # that are being denoised and returned in batches - start = indices.min() - end = indices.max() - - nimgs = len(indices) - im = np.empty((nimgs, self.L, self.L), self.dtype) - - # If we request less than a whole batch, don't crash - batch_size = min(nimgs, self.batch_size) - - logger.info(f"Loading {nimgs} images complete") - for batch_start in range(start, end + 1, batch_size): - imgs_denoised = self.denoiser.images(batch_start, batch_size) - batch_end = min(batch_start + batch_size, end + 1) - # we subtract start here to correct for any offset in the indices - im[batch_start - start : batch_end - start] = imgs_denoised.asnumpy() + imgs_denoised = self.denoiser.denoise[indices] # Finally, apply transforms to resulting Image - return self.generation_pipeline.forward(Image(im), indices) + return self.generation_pipeline.forward(imgs_denoised, indices) diff --git a/src/aspire/denoising/denoiser.py b/src/aspire/denoising/denoiser.py index 5a80f0da39..997cf26fd1 100644 --- a/src/aspire/denoising/denoiser.py +++ b/src/aspire/denoising/denoiser.py @@ -1,31 +1,42 @@ import logging +from abc import ABC, abstractmethod + +from aspire.source.image import _ImageAccessor logger = logging.getLogger(__name__) -class Denoiser: +class Denoiser(ABC): """ - Define a base class for denoising 2D images + Base class for 2D image denoisers. """ def __init__(self, src): """ - Initialize an object for denoising 2D images from the image source + Initialize an object for denoising 2D images from `src`. - :param src: The source object of 2D images with metadata + :param src: `ImageSource` providing noisy images. """ + self.src = src self.dtype = src.dtype - self.nimg = src.n + self.n = src.n + self._img_accessor = _ImageAccessor(self._denoise, self.n) + @property def denoise(self): """ - Precompute for Denoiser and DenoisedImageSource for 2D images + Subscriptable property returning 2D images after denoising. + + See `_ImageAccessor`. """ - raise NotImplementedError("subclasses must implement this") + return self._img_accessor - def image(self, istart=0, batch_size=512): + @abstractmethod + def _denoise(self, indices): """ - Obtain a batch size of 2D images after denosing by a specified method + Subclasses must implement a private `_denoise` method accepting `indices`. + Subclasses handle any caching as well as denoising. + + See `_ImageAccessor`. 
""" - raise NotImplementedError("subclasses must implement this") diff --git a/src/aspire/denoising/denoiser_cov2d.py b/src/aspire/denoising/denoiser_cov2d.py index b5581013e4..545f6ce642 100644 --- a/src/aspire/denoising/denoiser_cov2d.py +++ b/src/aspire/denoising/denoiser_cov2d.py @@ -1,4 +1,5 @@ import logging +from copy import deepcopy import numpy as np from numpy.linalg import solve @@ -6,7 +7,6 @@ from aspire.basis import FFBBasis2D from aspire.covariance import BatchedRotCov2D from aspire.denoising import Denoiser -from aspire.denoising.denoised_src import DenoisedImageSource from aspire.noise import WhiteNoiseEstimator from aspire.optimization import fill_struct from aspire.utils import mat_to_vec @@ -104,15 +104,31 @@ class DenoiserCov2D(Denoiser): Define a derived class for denoising 2D images using Cov2D method """ - def __init__(self, src, basis=None, var_noise=None): + # Default options for cov2d configuration. + default_opt = { + "shrinker": "frobenius_norm", + "verbose": 0, + "max_iter": 250, + "iter_callback": [], + "store_iterates": False, + "rel_tolerance": 1e-12, + } + + def __init__(self, src, basis=None, var_noise=None, batch_size=512, covar_opt=None): """ Initialize an object for denoising 2D images using Cov2D method :param src: The source object of 2D images with metadata :param basis: The basis method to expand 2D images :param var_noise: The estimated variance of noise + :param batch_size: Integer batch size for processing images. + Defaults to 512. + :param covar_opt: Optional dictionary of option overides for Cov2D. + Provided options will supersede defaults in `DenoiserCov2D.default_opt`. """ + super().__init__(src) + self.batch_size = int(batch_size) # When var_noise is not specfically over-ridden, # recompute it now. See #496. @@ -126,76 +142,69 @@ def __init__(self, src, basis=None, var_noise=None): if basis is None: basis = FFBBasis2D((self.src.L, self.src.L), dtype=src.dtype) - if not isinstance(basis, FFBBasis2D): - raise NotImplementedError("Currently only fast FB method is supported") - self.basis = basis self.cov2d = None self.mean_est = None self.covar_est = None - def denoise(self, covar_opt=None, batch_size=512): + # Create a local copy of the default options. + default_opt = deepcopy(self.default_opt) + # Assign the dtype corresponding to this instance. + default_opt["precision"] = self.dtype + # Apply any overrides provided by the user. + self.covar_opt = fill_struct(covar_opt, default_opt) + + # Initialize the rotationally invariant covariance matrix of 2D images + # A fixed batch_size is used to loop through image stack. + self.cov2d = BatchedRotCov2D(self.src, self.basis, batch_size=batch_size) + + def build_denoiser(self): """ - Build covariance matrix of 2D images and return a new ImageSource object + Build estimated mean and covariance matrix of 2D images. - :param covar_opt: The option list for building Cov2D matrix - :param batch_size: The batch size for processing images - :return: A `DenoisedImageSource` object with the specified denoising object + This method should be computed once, on first `images` access. 
""" - # Initialize the rotationally invariant covariance matrix of 2D images - # A fixed batch size is used to go through each image - self.cov2d = BatchedRotCov2D(self.src, self.basis, batch_size=batch_size) + if self.covar_est is not None: + return - default_opt = { - "shrinker": "frobenius_norm", - "verbose": 0, - "max_iter": 250, - "iter_callback": [], - "store_iterates": False, - "rel_tolerance": 1e-12, - "precision": self.dtype, - } - - covar_opt = fill_struct(covar_opt, default_opt) - # Calculate the mean and covariance for the rotationally invariant covariance matrix of 2D images + logger.info(f"Building mean estimate for {len(self.src)} images.") self.mean_est = self.cov2d.get_mean() + logger.info(f"Building covariance estimates for {len(self.src)} images.") self.covar_est = self.cov2d.get_covar( - noise_var=self.var_noise, mean_coeff=self.mean_est, covar_est_opt=covar_opt + noise_var=self.var_noise, + mean_coef=self.mean_est, + covar_est_opt=self.covar_opt, ) - return DenoisedImageSource(self.src, self) - - def images(self, istart=0, batch_size=512): + def _denoise(self, indices): """ - Obtain a batch size of 2D images after denosing by Cov2D method + Compute denoised 2D images corresponding to `indices`. - :param istart: the index of starting image - :param batch_size: The batch size for processing images - :return: an `Image` object with denoised images + :return: `Image` object containing denoised images. """ - src = self.src - - # Denoise one batch size of 2D images using the SPCAs from the rotationally invariant covariance matrix - img_start = istart - img_end = min(istart + batch_size, src.n) - imgs_noise = src.images[img_start : img_start + batch_size] - coeffs_noise = self.basis.evaluate_t(imgs_noise) - logger.info( - f"Estimating Cov2D coefficients for images from {img_start} to {img_end-1}" - ) - coeffs_estim = self.cov2d.get_cwf_coeffs( - coeffs_noise, - self.cov2d.ctf_fb, - self.cov2d.ctf_idx[img_start:img_end], - mean_coeff=self.mean_est, - covar_coeff=self.covar_est, + + # Lazy evaluate estimates on access. + # `build_denoiser` internally guards to compute once. + self.build_denoiser() + + # Denoise requested `indices` selection of 2D images. 
+        imgs_noise = self.src.images[indices]
+
+        coefs_noise = self.basis.evaluate_t(imgs_noise)
+        logger.debug(f"Estimating Cov2D coefficients for {imgs_noise.n_images} images.")
+        coefs_estim = self.cov2d.get_cwf_coefs(
+            coefs_noise,
+            self.cov2d.ctf_basis,
+            self.cov2d.ctf_idx[indices],
+            mean_coef=self.mean_est,
+            covar_coef=self.covar_est,
             noise_var=self.var_noise,
         )

         # Convert Fourier-Bessel coefficients back into 2D images
         logger.info("Converting Cov2D coefficients back to 2D images")
-        imgs_denoised = self.basis.evaluate(coeffs_estim)
+        imgs_denoised = self.basis.evaluate(coefs_estim)

         return imgs_denoised
diff --git a/src/aspire/downloader/__init__.py b/src/aspire/downloader/__init__.py
index e29b9a0983..be0d375878 100644
--- a/src/aspire/downloader/__init__.py
+++ b/src/aspire/downloader/__init__.py
@@ -18,4 +18,5 @@
     emdb_10835,
     emdb_14621,
     remove_downloads,
+    simulated_channelspin,
 )
diff --git a/src/aspire/downloader/data_fetcher.py b/src/aspire/downloader/data_fetcher.py
index 44b11dbd9d..3fe163521b 100644
--- a/src/aspire/downloader/data_fetcher.py
+++ b/src/aspire/downloader/data_fetcher.py
@@ -1,9 +1,13 @@
 import shutil

+import numpy as np
 import pooch

 from aspire import config
 from aspire.downloader import file_to_method_map, registry, registry_urls
+from aspire.image import Image
+from aspire.source import _LegacySimulation
+from aspire.utils import Rotation
 from aspire.volume import Volume

 # Initialize pooch data fetcher instance.
@@ -263,3 +267,27 @@ def emdb_6458():
     vol = Volume.load(file_path, symmetry_group="C11")

     return vol
+
+
+def simulated_channelspin():
+    """
+    Downloads the Simulated ChannelSpin dataset and returns a dictionary
+    of its contents.
+
+    This dataset includes a stack of 54 volumes sized (54,54,54)
+    and a corresponding stack of 10000 projection images (54,54).
+
+    :return: Dictionary containing Volume and Image instances,
+        along with associated metadata fields in NumPy arrays.
+    """
+    file_path = fetch_data("simulated_channelspin.npz")
+    # Use context manager so the file handle closes.
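+    # (`np.load` on an .npz archive returns a lazily-read `NpzFile`
+    # that holds an open file handle until closed.)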
+ with np.load(file_path) as data: + # Convert to dict so that the entries can be modified + data = dict(data) + + # Instantiate ASPIRE objects where appropriate + data["vols"] = Volume(data["vols"]) + data["images"] = Image(data["images"]) + data["rots"] = Rotation(_LegacySimulation.rots_zyx_to_legacy_aspire(data["rots"])) + + return data diff --git a/src/aspire/downloader/registry.py b/src/aspire/downloader/registry.py index 8c82b4701e..467ad3c772 100644 --- a/src/aspire/downloader/registry.py +++ b/src/aspire/downloader/registry.py @@ -13,6 +13,7 @@ "emdb_14621.map": "b45774245c2bd5e1a44e801b8fb1705a44d5850631838d060294be42e34a6900", "emdb_2484.map": "6a324e23352bea101c191d5e854026162a5a9b0b8fc73ac5a085cc22038e1999", "emdb_6458.map": "645208af6d36bbd3d172c549e58d387b81142fd320e064bc66105be0eae540d1", + "simulated_channelspin.npz": "c0752674acb85417f6a77a28ac55280c1926c73fda9e25ce0a9940728b1dfcc8", } registry_urls = { @@ -29,6 +30,7 @@ "emdb_14621.map": "https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-14621/map/emd_14621.map.gz", "emdb_2484.map": "https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-2484/map/emd_2484.map.gz", "emdb_6458.map": "https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-6458/map/emd_6458.map.gz", + "simulated_channelspin.npz": "https://zenodo.org/records/8186548/files/example_FakeKV_dataset.npz", } file_to_method_map = { @@ -45,4 +47,5 @@ "emdb_14621.map": "emdb_14621", "emdb_2484.map": "emdb_2484", "emdb_6458.map": "emdb_6458", + "simulated_channelspin.npz": "simulated_channelspin", } diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 8d8b7652b0..11ee7f412a 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -58,8 +58,8 @@ def normalize_bg(imgs, bg_radius=1.0, do_ramp=True): imgs = imgs.reshape((-1, L * L)) # Fit a ramping background and apply to images - coeff = lstsq(ramp_mask, imgs[:, mask_reshape].T)[0] # RCOPT - imgs = imgs - (ramp_all @ coeff).T # RCOPT + coef = lstsq(ramp_mask, imgs[:, mask_reshape].T)[0] # RCOPT + imgs = imgs - (ramp_all @ coef).T # RCOPT imgs = imgs.reshape((-1, L, L)) # Apply mask images and calculate mean and std values of background @@ -437,6 +437,7 @@ def load(filepath, dtype=None): def _im_translate(self, shifts): """ Translate image by shifts + :param im: An array of size n-by-L-by-L containing images to be translated. :param shifts: An array of size n-by-2 specifying the shifts in pixels. Alternatively, it can be a row vector of length 2, in which case the same shifts is applied to each image. @@ -493,6 +494,7 @@ def size(self): def backproject(self, rot_matrices): """ Backproject images along rotation + :param im: An Image (stack) to backproject. :param rot_matrices: An n-by-3-by-3 array of rotation matrices \ corresponding to viewing directions. @@ -569,7 +571,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True): plt.show() - def frc(self, other, cutoff, pixel_size=None, method="fft", plot=False): + def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): r""" Compute the Fourier ring correlation between two images. @@ -584,6 +586,9 @@ def frc(self, other, cutoff, pixel_size=None, method="fft", plot=False): :param other: `Image` instance to compare. :param cutoff: Cutoff value, traditionally `.143`. + Default `None` implies `cutoff=1` and excludes + plotting cutoff line. + :param pixel_size: Pixel size in angstrom. Default `None` implies unit in pixels, equivalent to pixel_size=1. 
:param method: Selects either 'fft' (on cartesian grid), diff --git a/src/aspire/image/xform.py b/src/aspire/image/xform.py index 189bc33ce4..0759c819e8 100644 --- a/src/aspire/image/xform.py +++ b/src/aspire/image/xform.py @@ -76,6 +76,7 @@ def _forward(self, im, indices): def enabled(self): """ Enable this Xform in a context manager, regardless of its `active` attribute value. + :return: A context manager in which this Xform is enabled. """ return Xform.XformActiveContextManager(self, active=True) @@ -83,6 +84,7 @@ def enabled(self): def disabled(self): """ Disable this Xform in a context manager, regardless of its `active` attribute value. + :return: A context manager in which this Xform is disabled. """ return Xform.XformActiveContextManager(self, active=False) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index 154e77f624..642ef0f5c5 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -277,6 +277,7 @@ def _create_filter(self, noise_variance=None): def _estimate_noise_variance(self): """ Any additional arguments/keyword-arguments are passed on to the Source's 'images' method + :return: The estimated noise variance of the images in the Source used to create this estimator. TODO: How's this initial estimate of variance different from the 'estimate' method? """ @@ -292,7 +293,7 @@ def _estimate_noise_variance(self): _denominator = self.src.n * np.sum(mask) first_moment += np.sum(images_masked) / _denominator - second_moment += np.sum(np.abs(images_masked**2)) / _denominator + second_moment += np.sum(np.abs(images_masked) ** 2) / _denominator return second_moment - first_moment**2 @@ -338,7 +339,7 @@ def estimate_noise_psd(self): _denominator = self.src.n * np.sum(mask) mean_est += np.sum(images_masked) / _denominator im_masked_f = xp.asnumpy(fft.centered_fft2(xp.asarray(images_masked))) - noise_psd_est += np.sum(np.abs(im_masked_f**2), axis=0) / _denominator + noise_psd_est += np.sum(np.abs(im_masked_f) ** 2, axis=0) / _denominator mid = self.src.L // 2 noise_psd_est[mid, mid] -= mean_est**2 diff --git a/src/aspire/nufft/__init__.py b/src/aspire/nufft/__init__.py index 6452956ba5..aa7c3a4adf 100644 --- a/src/aspire/nufft/__init__.py +++ b/src/aspire/nufft/__init__.py @@ -18,6 +18,7 @@ def check_backends(raise_errors=True): """ Check all NFFT backends in package configuration + :param raise_errors: Whether to raise a RuntimeError if no backends detected. :return: On return, the global names 'backends'/'default_plan_class' have been populated @@ -28,6 +29,7 @@ def check_backends(raise_errors=True): def _try_backend(backend): """ This function tries out a particular NFFT backend by name. + :param backend: A string representing the NFFT backend we want to try. Currently one of: 'cufinufft' The Python wrapper for the CUDA variant of FINUFFT library diff --git a/src/aspire/nufft/pynfft.py b/src/aspire/nufft/pynfft.py index fdeb831956..fec853098f 100644 --- a/src/aspire/nufft/pynfft.py +++ b/src/aspire/nufft/pynfft.py @@ -17,6 +17,7 @@ def epsilon_to_nfft_cutoff(epsilon): def __init__(self, sz, fourier_pts, epsilon=1e-15, **kwargs): """ A plan for non-uniform FFT (3D) + :param sz: A tuple indicating the geometry of the signal :param fourier_pts: The points in Fourier space where the Fourier transform is to be calculated, arranged as a 3-by-K array. These need to be in the range [-pi, pi] in each dimension. 
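
The `cutoff=None` defaults introduced for `Image.frc`, `Volume.fsc`, and the
`FourierCorrelation` plotting in this patch make the correlation threshold
optional: `None` is analyzed as a cutoff of 1 and suppresses the cutoff line
in plots. A minimal sketch of the call pattern, assuming two single-image
`Image` stacks (the arrays and values below are synthetic and purely
illustrative):

    import numpy as np

    from aspire.image import Image

    rng = np.random.default_rng(0)
    x = rng.standard_normal((1, 64, 64)).astype(np.float32)
    a = Image(x)
    b = Image(x + 0.1 * rng.standard_normal((1, 64, 64)).astype(np.float32))

    # Traditional resolution estimate at the 0.143 correlation cutoff.
    res = a.frc(b, cutoff=0.143, pixel_size=1.0)

    # With the new default, `cutoff=None` is analyzed as `cutoff=1`,
    # and the cutoff line is omitted when plotting.
    res_default = a.frc(b)
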
diff --git a/src/aspire/operators/blk_diag_matrix.py b/src/aspire/operators/blk_diag_matrix.py index a92dd078b0..3e493f9525 100644 --- a/src/aspire/operators/blk_diag_matrix.py +++ b/src/aspire/operators/blk_diag_matrix.py @@ -3,6 +3,9 @@ block diagonal matrices as used by ASPIRE. """ +import logging +import warnings + import numpy as np from numpy.linalg import norm, solve from scipy.linalg import block_diag @@ -10,6 +13,8 @@ from aspire.utils import make_psd from aspire.utils.cell import Cell2D +logger = logging.getLogger(__name__) + def is_scalar_type(x): """ @@ -86,7 +91,7 @@ def append(self, blk): :param blk: Block to append (ndarray). """ - self.data.append(blk) + self.data.append(blk.astype(self.dtype, copy=False)) self.nblocks += 1 self.reset_cache() @@ -125,7 +130,7 @@ def __setitem__(self, key, value): Convenience wrapper, setter on self.data. """ - self.data[key] = value + self.data[key] = value.astype(self.dtype, copy=False) self.reset_cache() def __len__(self): @@ -706,14 +711,14 @@ def apply(self, X): cols = self.partition[:, 1] - if np.sum(cols) != np.size(X, 0): - raise RuntimeError("Sizes of matrix `self` and `X` are not compatible.") - vector = False if np.ndim(X) == 1: X = X[:, np.newaxis] vector = True + if np.sum(cols) != np.size(X, 0): + raise RuntimeError("Sizes of matrix `self` and `X` are not compatible.") + rows = np.array( [ np.size(X, 1), @@ -754,6 +759,7 @@ def rapply(self, X): def eigvals(self): """ Compute the eigenvalues of a BlkDiagMatrix. + :return: Array of eigvals, with length equal to the fully expanded matrix diagonal. """ @@ -765,7 +771,7 @@ def check_psd(self): :return: True if all blocks have non-negative eigenvalues. """ - return np.alltrue(self.eigvals() > 0.0) + return np.all(self.eigvals() > 0.0) def make_psd(self): """ @@ -940,3 +946,45 @@ def diag(self): diag.extend(list(np.diag(blk))) return DiagMatrix(np.array(diag, dtype=self.dtype)) + + @staticmethod + def from_dense(A, blk_partition, warn_eps=1e-3): + """ + Create BlkDiagMatrix with `blk_partition` from dense matrix `A`. + + :param A: Dense `Numpy` array. + :param blk_partition: List of block partition shapes. + :param warn_eps: Optionally warn if off block values from `A` + exceed `warn_eps`. `None` disables warnings. + :return: `BlkDiagMatrix` with values from `A`. + """ + + # Instantiate an empty BlkDiagMatrix with `blk_partition` + B = BlkDiagMatrix.zeros(blk_partition, dtype=A.dtype) + + # Set the data + inds = np.array([0, 0]) + for i in range(B.nblocks): + ends = inds + B.partition[i] + B[i][:, :] = A[inds[0] : ends[0], inds[1] : ends[1]] + inds = ends + + # We should reach exactly the end of A when partition was correct + if not np.all(inds == A.shape): + raise RuntimeError( + "Block partition appears to mismatch shape of dense matrix A." + ) + + if warn_eps is not None: + max_diff = np.max(np.abs((A - B.dense()))) + if max_diff > warn_eps: + # Warn (once) + warnings.warn( + f"BlkDiagMatrix.from_dense truncating values exceeding {warn_eps}", + UserWarning, + stacklevel=2, + ) + # Log the specifics for debugging + logger.debug(f"BlkDiagMatrix.from_dense truncated max value {max_diff}") + + return B diff --git a/src/aspire/operators/diag_matrix.py b/src/aspire/operators/diag_matrix.py index 751d748fbf..4c94ab83d1 100644 --- a/src/aspire/operators/diag_matrix.py +++ b/src/aspire/operators/diag_matrix.py @@ -468,13 +468,17 @@ def dense(self): def apply(self, X): """ Define the apply option of a diagonal matrix with a matrix of - coefficient vectors. + coefficient column vectors. 
- :param X: Coefficient matrix, each column is a coefficient vector. - :return: A matrix with new coefficient vectors. + :param X: Coefficient matrix (ndarray), each column is a coefficient vector. + :return: A matrix with new coefficient column vectors. """ - return self * DiagMatrix(X) + # Transpose X to become row major because, + # X is a coefficient matrix (ndarray), each column is a coefficient vector. + # Transpose the row major multiplication result back to column major, to + # return a matrix with new coefficient column vectors. + return (self * DiagMatrix(X.T)).asnumpy().T def rapply(self, X): """ diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 839a6dffe5..696482d649 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -5,7 +5,6 @@ from scipy.interpolate import RegularGridInterpolator from aspire.utils import grid_2d, voltage_to_wavelength -from aspire.utils.filter_to_fb_mat import filter_to_fb_mat logger = logging.getLogger(__name__) @@ -53,6 +52,7 @@ def __mul__(self, other): def __str__(self): """ Show class name of Filter + :return: A string of class name """ return self.__class__.__name__ @@ -89,11 +89,15 @@ def evaluate(self, omega): def _evaluate(self, omega): raise NotImplementedError("Subclasses should implement this method") - def fb_mat(self, fbasis): + def basis_mat(self, basis): """ - Represent the filter in FB basis matrix + Represent the filter in `basis`. + + :param basis: 2D Basis. + :return: `basis` representation of this filter. + Return type will depend on `basis`. """ - return filter_to_fb_mat(self.evaluate, fbasis) + return basis.filter_to_basis_mat(self) def scale(self, c=1): """ @@ -105,7 +109,7 @@ def scale(self, c=1): """ return ScaledFilter(self, c) - def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs): + def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): """ Generates a two dimensional grid with prescribed dtype, yielding the values (omega) which are then evaluated by @@ -187,7 +191,7 @@ def __init__(self, filter, power=1): def _evaluate(self, omega): return self._filter.evaluate(omega) ** self._power - def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs): + def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): """ Calls the provided filter's evaluate_grid method in case there is an optimization. @@ -195,10 +199,23 @@ def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs): See `Filter.evaluate_grid` for usage. """ + filter_vals = self._filter.evaluate_grid(L, *args, dtype=dtype, **kwargs) + + # Place safeguard on values below machine epsilon for negative powers. + if self._power < 0: + eps = np.finfo(filter_vals.dtype).eps + condition = abs(filter_vals) < eps + num_less_eps = np.count_nonzero(condition) + if num_less_eps > 0: + logger.warning( + f"{self} setting {num_less_eps} extremal filter value(s) to zero." + ) - return ( - self._filter.evaluate_grid(L, dtype=dtype, *args, **kwargs) ** self._power - ) + filter_vals = np.where(condition, 0, filter_vals**self._power) + + return filter_vals + + return filter_vals**self._power class LambdaFilter(Filter): @@ -321,7 +338,7 @@ def _evaluate(self, omega): return result - def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs): + def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): """ Optimized evaluate_grid method for ArrayFilter. @@ -336,7 +353,6 @@ def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs): See Filter.evaluate_grid for usage. 
""" - if all(dim == L for dim in self.xfer_fn_array.shape): logger.debug( "Size of transfer function matches evaluate_grid size L exactly," @@ -345,7 +361,7 @@ def evaluate_grid(self, L, dtype=np.float32, *args, **kwargs): res = self.xfer_fn_array else: # Otherwise call parent code to generate a grid then evaluate. - res = super().evaluate_grid(L, dtype=dtype, *args, **kwargs) + res = super().evaluate_grid(L, *args, dtype=dtype, **kwargs) return res diff --git a/src/aspire/operators/polar_ft.py b/src/aspire/operators/polar_ft.py index 48b08c4c0c..2b90d5cad6 100644 --- a/src/aspire/operators/polar_ft.py +++ b/src/aspire/operators/polar_ft.py @@ -77,8 +77,8 @@ def _precomp(self): # only need half size of ntheta freqs = np.zeros((2, self.ntheta // 2, self.nrad), dtype=self.dtype) for i in range(self.ntheta // 2): - freqs[0, i] = np.cos(i * dtheta) - freqs[1, i] = np.sin(i * dtheta) + freqs[0, i] = np.sin(i * dtheta) + freqs[1, i] = np.cos(i * dtheta) freqs *= omega0 * np.arange(self.nrad) diff --git a/src/aspire/reconstruction/estimator.py b/src/aspire/reconstruction/estimator.py index c2a07e7a78..dec6a901da 100644 --- a/src/aspire/reconstruction/estimator.py +++ b/src/aspire/reconstruction/estimator.py @@ -1,5 +1,6 @@ import logging +from aspire.basis import Coef from aspire.reconstruction.kernel import FourierKernel logger = logging.getLogger(__name__) @@ -58,26 +59,29 @@ def __getattr__(self, name): def compute_kernel(self): raise NotImplementedError("Subclasses must implement the compute_kernel method") - def estimate(self, b_coeff=None, tol=1e-5, regularizer=0): + def estimate(self, b_coef=None, tol=1e-5, regularizer=0): """Return an estimate as a Volume instance.""" - if b_coeff is None: - b_coeff = self.src_backward() - est_coeff = self.conj_grad(b_coeff, tol=tol, regularizer=regularizer) - est = self.basis.evaluate(est_coeff).T + if b_coef is None: + b_coef = self.src_backward() + est_coef = self.conj_grad(b_coef, tol=tol, regularizer=regularizer) + est = Coef(self.basis, est_coef).evaluate().T return est - def apply_kernel(self, vol_coeff, kernel=None): + def apply_kernel(self, vol_coef, kernel=None): """ Applies the kernel represented by convolution - :param vol_coeff: The volume to be convolved, stored in the basis coefficients. + + :param vol_coef: The volume to be convolved, stored in the basis coefficients. :param kernel: a Kernel object. If None, the kernel for this Estimator is used. - :return: The result of evaluating `vol_coeff` in the given basis, convolving with the kernel given by + :return: The result of evaluating `vol_coef` in the given basis, convolving with the kernel given by kernel, and backprojecting into the basis. 
""" + if kernel is None: kernel = self.kernel - vol = self.basis.evaluate(vol_coeff) # returns a Volume + + vol = Coef(self.basis, vol_coef).evaluate() # returns a Volume vol = kernel.convolve_volume(vol) # returns a Volume vol_coef = self.basis.evaluate_t(vol) return vol_coef diff --git a/src/aspire/reconstruction/mean.py b/src/aspire/reconstruction/mean.py index 96dc811080..fda258cecd 100644 --- a/src/aspire/reconstruction/mean.py +++ b/src/aspire/reconstruction/mean.py @@ -6,6 +6,7 @@ from scipy.sparse.linalg import LinearOperator, cg from aspire import config +from aspire.basis import Coef from aspire.nufft import anufft from aspire.numeric import fft from aspire.operators import evaluate_src_filters_on_grid @@ -172,8 +173,8 @@ def src_backward(self): return res - def conj_grad(self, b_coeff, tol=1e-5, regularizer=0): - count = b_coeff.shape[-1] # b_coef should be (r, basis.count) + def conj_grad(self, b_coef, tol=1e-5, regularizer=0): + count = b_coef.shape[-1] # b_coef should be (r, basis.count) kernel = self.kernel if regularizer > 0: @@ -197,50 +198,50 @@ def conj_grad(self, b_coeff, tol=1e-5, regularizer=0): ) tol = tol or config.mean.cg_tol - target_residual = tol * norm(b_coeff) + target_residual = tol * norm(b_coef) def cb(xk): logger.info( - f"Delta {norm(b_coeff - self.apply_kernel(xk))} (target {target_residual})" + f"Delta {norm(b_coef - self.apply_kernel(xk))} (target {target_residual})" ) - x, info = cg(operator, b_coeff.flatten(), M=M, callback=cb, tol=tol, atol=0) + x, info = cg(operator, b_coef.flatten(), M=M, callback=cb, tol=tol, atol=0) if info != 0: raise RuntimeError("Unable to converge!") return x.reshape(self.r, self.basis.count) - def apply_kernel(self, vol_coeff, kernel=None): + def apply_kernel(self, vol_coef, kernel=None): """ Applies the kernel represented by convolution - :param vol_coeff: The volume to be convolved, stored in the basis coefficients. + :param vol_coef: The volume to be convolved, stored in the basis coefficients. :param kernel: a Kernel object. If None, the kernel for this Estimator is used. - :return: The result of evaluating `vol_coeff` in the given basis, convolving with the kernel given by + :return: The result of evaluating `vol_coef` in the given basis, convolving with the kernel given by kernel, and backprojecting into the basis. 
""" if kernel is None: kernel = self.kernel - assert np.size(vol_coeff) == self.r * self.basis.count - if vol_coeff.ndim == 1: - vol_coeff = vol_coeff.reshape(self.r, self.basis.count) + assert np.size(vol_coef) == self.r * self.basis.count + if vol_coef.ndim == 1: + vol_coef = vol_coef.reshape(self.r, self.basis.count) vols_out = Volume( np.zeros((self.r, self.src.L, self.src.L, self.src.L), dtype=self.dtype) ) - vol = self.basis.evaluate(vol_coeff) + vol = Coef(self.basis, vol_coef).evaluate() for k in range(self.r): for j in range(self.r): vols_out[k] = vols_out[k] + kernel.convolve_volume(vol[j], j, k) # Note this is where we would add mask_gamma - vol_coeff = self.basis.evaluate_t(vols_out) + vol_coef = self.basis.evaluate_t(vols_out) - return vol_coeff + return vol_coef class MeanEstimator(WeightedVolumesEstimator): diff --git a/src/aspire/source/__init__.py b/src/aspire/source/__init__.py index 7fd7691a3b..c3703db8bc 100644 --- a/src/aspire/source/__init__.py +++ b/src/aspire/source/__init__.py @@ -8,7 +8,7 @@ OrientedSource, ) from aspire.source.relion import RelionSource -from aspire.source.simulation import Simulation +from aspire.source.simulation import Simulation, _LegacySimulation # isort: off from aspire.source.micrograph import ( diff --git a/src/aspire/source/coordinates.py b/src/aspire/source/coordinates.py index 027490711b..dca7aaf873 100644 --- a/src/aspire/source/coordinates.py +++ b/src/aspire/source/coordinates.py @@ -197,6 +197,7 @@ def _box_coord_from_center(center, particle_size): to a list `[lower left x, lower left y, particle_size, particle_size]` representing the box around the particle in box format. + :param center: a list of length two representing a center :param particle_size: the size of the box around the particle """ @@ -219,6 +220,7 @@ def _center_from_box_coord(box_coord): `[lower left x, lower left y, particle_size, particle_size]` representing a particle in the box format to a list `[x, y]` representing the particle center. + :param box_coord: a list of length 4 representing the particle box """ # Get lower left corner x and y coordinates @@ -232,6 +234,7 @@ def _coords_list_from_star(self, star_file): """ Given a Relion STAR coordinate file (generally containing particle centers) return a list of coordinates in box format. + :param star_file: A path to a STAR file containing particle centers """ data_block = StarFile(star_file).get_block_by_index(0) @@ -245,6 +248,7 @@ def _populate_local_metadata(self): """ Called during ImageSource.save(), populates metadata columns specific to `CoordinateSource` when saving to STAR file. + :return: A list of the names of the columns added. """ # Insert stored particle coordinates (centers) into metadata @@ -268,6 +272,7 @@ def _exclude_boundary_particles(self): """ Remove particles boxes which do not fit in the micrograph with the given `particle_size`. + :return: Number of particles removed """ out_of_range = [] @@ -414,6 +419,7 @@ def _crop_micrograph(data, coord): Crops a particle box defined by `coord` out of `data`. According to MRC 2014 convention, the origin represents the bottom-left corner of the image. + :param data: A 2D numpy array representing a micrograph :param coord: A list of integers: (lower left X, lower left Y, X, Y) """ @@ -430,6 +436,7 @@ def _images(self, indices): particles were excluded due to their box not fitting into the mrc dimensions. Thus, the exact particles returned are a function of the `particle_size`. + :param indices: A 1-D NumPy array of integer indices. 
:return: An `Image` object. """ diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index e5fb7a4350..aa127bfb66 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -53,6 +53,7 @@ def __getitem__(self, indices): ImageAccessor can be indexed via Python slice object, 1-D NumPy array, list, or a single integer, corresponding to the indices of the requested images. By default, slices default to a start of 0, an end of self.num_imgs, and a step of 1. + :return: An Image object containing the requested images. """ if isinstance(indices, Iterable) and not isinstance(indices, np.ndarray): @@ -304,6 +305,7 @@ def __len__(self): def _metadata_as_dict(self, metadata_fields, indices, default_value=None): """ Return a dictionary of selected metadata fields at selected indices. + :param metadata_fields: An iterable of strings specifying metadata fields. :param indices: An ndarray of 0-indexed locations we're interested in. :param default_value: A scalar default value to use if a metadata_field is not found. @@ -324,6 +326,7 @@ def _metadata_as_dict(self, metadata_fields, indices, default_value=None): def _metadata_as_ndarray(self, metadata_fields, indices, default_value=None): """ Return a numpy array of selected metadata fields at selected indices. + :param metadata_fields: An iterable of strings specifying metadata fields. :param indices: An ndarray of 0-indexed locations we're interested in. :param default_value: A scalar default value to use if a metadata_field is not found. @@ -1078,6 +1081,7 @@ def _populate_local_metadata(self): """ Populate metadata columns specific to the `ImageSource` subclass being saved. Subclasses optionally override, but must return a list of strings. + :return: A list of the names of the columns added. """ return [] @@ -1687,6 +1691,7 @@ def _images(self, indices): """ Returns images corresponding to `indices` after being accessed via the `ImageSource.images` property + :param indices: A 1-D NumPy array of indices. :return: An `Image` object. """ diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py index 6106bf4389..02b8ab9363 100644 --- a/src/aspire/source/relion.py +++ b/src/aspire/source/relion.py @@ -203,6 +203,7 @@ def _images(self, indices): """ Returns particle images when accessed via the `ImageSource.images` property. Loads particle images corresponding to `indices` from StarFile and .mrcs stacks. + :param indices: A 1-D NumPy array of integer indices. :return: An `Image` object. """ diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index 2aa544abcf..e2ef10da12 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -10,6 +10,7 @@ from aspire.source import ImageSource from aspire.source.image import _ImageAccessor from aspire.utils import ( + Rotation, acorr, ainner, anorm, @@ -147,9 +148,7 @@ def __init__( states = randi(self.C, n, seed=seed) self.states = states - if angles is None: - angles = uniform_random_angles(n, seed=seed, dtype=self.dtype) - self.angles = angles + self.angles = self._init_angles(angles) if unique_filters is None: unique_filters = [] @@ -192,6 +191,11 @@ def __init__( # Any further operations should not mutate this instance. 
self._mutable = False + def _init_angles(self, angles): + if angles is None: + angles = uniform_random_angles(self.n, seed=self.seed, dtype=self.dtype) + return angles + def _populate_ctf_metadata(self, filter_indices): # Since we are not reading from a starfile, we must construct # metadata based on the CTF filters by hand and set the values @@ -272,6 +276,7 @@ def _clean_images(self, indices): def _images(self, indices, clean_images=False): """ Returns particle images when accessed via the `ImageSource.images` property. + :param indices: A 1-D NumPy array of integer indices. :param clean_images: Only used internally, toggled on when `clean_images` requested. Will skip accessing cache, skip noise, while applying CTF to projections. @@ -366,6 +371,7 @@ def covar_true(self): def eigs(self): """ Eigendecomposition of volume covariance matrix of simulation + :return: A 2-tuple: eigs_true: The eigenvectors of the volume covariance matrix in the form of Volume instance. lambdas_true: The eigenvalues of the covariance matrix in the form of a (C-1)-by-(C-1) diagonal matrix. @@ -471,9 +477,10 @@ def eval_coords(self, mean_vol, eig_vols, coords_est): :param mean_vol: A mean volume in the form of a Volume instance. :param eig_vols: A set of eigenvolumes in an Volume instance. - :param coords_est: The estimated coordinates in the affine space defined centered at `mean_vol` and spanned - by `eig_vols`. - :return: + :param coords_est: The estimated coordinates in the affine space defined centered + at `mean_vol` and spanned by `eig_vols`. + :return: Dictionary containing error, relative error, and correlation for each set + of estimated coordinates. """ assert isinstance(mean_vol, Volume) assert isinstance(eig_vols, Volume) @@ -481,36 +488,49 @@ def eval_coords(self, mean_vol, eig_vols, coords_est): # 0-indexed states vector states = self.states - 1 - coords_true = coords_true[states] + + coords_true = coords_true.T[states] res_norms = res_norms[states] res_inners = res_inners[:, states] - mean_eigs_inners = (mean_vol.to_vec() @ eig_vols.to_vec().T).item() + if coords_est.ndim == 1: + coords_est = coords_est[:, None] + coords_true = coords_true[:, None] + + mean_eigs_inners = mean_vol.to_vec() @ eig_vols.to_vec().T coords_err = coords_true - coords_est - err = np.hypot(res_norms, coords_err) + K = coords_true.shape[-1] + err = np.zeros((K, len(coords_true))) + rel_err = np.zeros((K, len(coords_true))) + corr = np.zeros((K, len(coords_true))) - mean_vol_norm2 = anorm(mean_vol) ** 2 - norm_true = np.sqrt( - coords_true**2 - + mean_vol_norm2 - + 2 * res_inners - + 2 * mean_eigs_inners * coords_true - ) - norm_true = np.hypot(res_norms, norm_true) + for k in range(K): + err[k] = np.hypot(res_norms, coords_err[:, k]) - rel_err = err / norm_true - inner = ( - mean_vol_norm2 - + mean_eigs_inners * (coords_true + coords_est) - + coords_true * coords_est - + res_inners - ) - norm_est = np.sqrt( - coords_est**2 + mean_vol_norm2 + 2 * mean_eigs_inners * coords_est - ) + mean_vol_norm2 = anorm(mean_vol) ** 2 + norm_true = np.sqrt( + coords_true[:, k] ** 2 + + mean_vol_norm2 + + 2 * res_inners + + 2 * mean_eigs_inners[:, k] * coords_true[:, k] + ) + norm_true = np.hypot(res_norms, norm_true) + + rel_err[k] = err[k] / norm_true + inner = ( + mean_vol_norm2 + + mean_eigs_inners[:, k] * (coords_true[:, k] + coords_est[:, k]) + + coords_true[:, k] * coords_est[:, k] + + res_inners + ) + norm_est = np.sqrt( + coords_est[:, k] ** 2 + + mean_vol_norm2 + + 2 * mean_eigs_inners[:, k] * coords_est[:, k] + ) - corr = inner / 
(norm_true * norm_est)
+                corr[k] = inner / (norm_true * norm_est)

         return {"err": err, "rel_err": rel_err, "corr": corr}

@@ -530,3 +550,53 @@ def true_snr(self, *args, **kwargs):
         noise_power = self.noise_adder.noise_var
         signal_power = self.true_signal_power(*args, **kwargs)
         return signal_power / noise_power
+
+
+class _LegacySimulation(Simulation):
+    """
+    Legacy Simulation enforces the legacy grid convention for generating projection
+    images.
+
+    Note that `angles`, and thus `rotations`, are altered upon initialization.
+    To recover the rotations associated with the input angles, use the staticmethod
+    `rots_zyx_to_legacy_aspire()`.
+    """
+
+    def _init_angles(self, angles):
+        angles = super()._init_angles(angles)
+
+        # Convert to rotations.
+        rots = Rotation.from_euler(angles).matrices
+
+        # Transform rotations to replicate legacy grid convention.
+        legacy_rots = Rotation(self.rots_zyx_to_legacy_aspire(rots))
+
+        # Convert back to angles.
+        return legacy_rots.angles.astype(self.dtype)
+
+    @staticmethod
+    def rots_zyx_to_legacy_aspire(rots):
+        """
+        Helper function to transform rotations to mimic the original ASPIRE-Python
+        grid indexing. Now that we are enforcing "zyx" grid indexing across the
+        code base, in particular for the rotated_grids used for volume projection,
+        we must transform rotation matrices to allow for existing hardcoded tests
+        to remain valid.
+
+        Note that this transformation is its own inverse.
+
+        :param rots: n_rot x 3 x 3 array of rotation matrices.
+        :return: Transformed rotations.
+        """
+        dtype = rots.dtype
+
+        # Handle singletons
+        og_shape = rots.shape
+        if len(og_shape) == 2:
+            rots = np.expand_dims(rots, axis=0)
+
+        # Transform rots
+        flip_xy = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype=dtype)
+        new_rots = rots[:, ::-1] @ flip_xy
+
+        return new_rots.reshape(og_shape)
diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py
index 6aa71fb8f5..e691e0ba5e 100644
--- a/src/aspire/utils/__init__.py
+++ b/src/aspire/utils/__init__.py
@@ -1,6 +1,7 @@
 from .types import complex_type, real_type, utest_tolerance  # isort:skip
 from .coor_trans import (  # isort:skip
     common_line_from_rots,
+    mean_aligned_angular_distance,
     crop_pad_2d,
     crop_pad_3d,
     get_aligned_rotations,
@@ -45,6 +46,7 @@
     make_symmat,
     mat_to_vec,
     mdim_mat_fun_conj,
+    nearest_rotations,
     roll_dim,
     symmat_to_vec,
     symmat_to_vec_iso,
diff --git a/src/aspire/utils/bot_align.py b/src/aspire/utils/bot_align.py
index 68f6f9047c..fc8fbe4fdc 100644
--- a/src/aspire/utils/bot_align.py
+++ b/src/aspire/utils/bot_align.py
@@ -11,7 +11,6 @@
 from numpy.linalg import norm
 from scipy.optimize import minimize

-from aspire.operators import wemd_embed
 from aspire.utils.rotation import Rotation

 # Store parameters specific to each loss_type.
@@ -64,6 +63,9 @@ def align_BO(
         Default `None` infers dtype from `vol_ref`.
     :return: Rotation matrix R_init (without refinement) or (R_init, R_est)
         (with refinement).
""" + # Avoid utils/operators/utils circular import + from aspire.operators import wemd_embed + # Infer dtype dtype = np.dtype(dtype or vol_ref.dtype) diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index a29e041966..e909e2f394 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -9,6 +9,7 @@ from scipy.linalg import svd from aspire.utils.random import Random +from aspire.utils.rotation import Rotation def cart2pol(x, y): @@ -237,7 +238,6 @@ def get_aligned_rotations(rots, Q_mat, flag): Calculated aligned rotation matrices from the orthogonal transformation that best aligns the estimated rotations to the reference rotations. - :param rots: The reference rotations to which we would like to align in the form of a n-by-3-by-3 array. :param Q_mat: optimal orthogonal 3x3 transformation matrix @@ -283,6 +283,30 @@ def get_rots_mse(rots_reg, rots_ref): return mse +def mean_aligned_angular_distance(rots_est, rots_gt, degree_tol=None): + """ + Register estimates to ground truth rotations and compute the + mean angular distance between them (in degrees). + + :param rots_est: A set of estimated rotations of size nx3x3. + :param rots_gt: A set of ground truth rotations of size nx3x3. + :param degree_tol: Option to assert if the mean angular distance is + less than `degree_tol` degrees. If `None`, returns the mean + aligned angular distance. + + :return: The mean angular distance between registered estimates + and the ground truth (in degrees). + """ + Q_mat, flag = register_rotations(rots_est, rots_gt) + regrot = get_aligned_rotations(rots_est, Q_mat, flag) + mean_ang_dist = Rotation.mean_angular_distance(regrot, rots_gt) * 180 / np.pi + + if degree_tol is not None: + np.testing.assert_array_less(mean_ang_dist, degree_tol) + + return mean_ang_dist + + def common_line_from_rots(r1, r2, ell): """ Compute the common line induced by rotation matrices r1 and r2. diff --git a/src/aspire/utils/filter_to_fb_mat.py b/src/aspire/utils/filter_to_fb_mat.py deleted file mode 100644 index 9933ffe3ba..0000000000 --- a/src/aspire/utils/filter_to_fb_mat.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np - -from aspire.operators import BlkDiagMatrix - - -def filter_to_fb_mat(h_fun, fbasis): - """ - Convert a nonradial function in k space into a basis representation. - - :param h_fun: The function form in k space. - :param fbasis: The basis object for expanding. - - :return: a BlkDiagMatrix instance representation using the - `fbasis` expansion. - """ - - # These form a circular dependence, import locally until time to clean up. 
- from aspire.basis import FFBBasis2D - from aspire.basis.basis_utils import lgwt - - if not isinstance(fbasis, FFBBasis2D): - raise NotImplementedError("Currently only fast FB method is supported") - # Set same dimensions as basis object - n_k = fbasis.n_r - n_theta = fbasis.n_theta - radial = fbasis.get_radial() - - # get 2D grid in polar coordinate - k_vals, wts = lgwt(n_k, 0, 0.5, dtype=fbasis.dtype) - k, theta = np.meshgrid( - k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij" - ) - - # Get function values in polar 2D grid and average out angle contribution - omegax = k * np.cos(theta) - omegay = k * np.sin(theta) - omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) - h_vals2d = h_fun(omega).reshape(n_k, n_theta).astype(fbasis.dtype) - h_vals = np.sum(h_vals2d, axis=1) / n_theta - - # Represent 1D function values in fbasis - h_fb = BlkDiagMatrix.empty(2 * fbasis.ell_max + 1, dtype=fbasis.dtype) - ind_ell = 0 - for ell in range(0, fbasis.ell_max + 1): - k_max = fbasis.k_max[ell] - rmat = 2 * k_vals.reshape(n_k, 1) * fbasis.r0[ell][0:k_max].T - fb_vals = np.zeros_like(rmat) - ind_radial = np.sum(fbasis.k_max[0:ell]) - fb_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T - h_fb_vals = fb_vals * h_vals.reshape(n_k, 1) - h_fb_ell = fb_vals.T @ ( - h_fb_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1) - ) - h_fb[ind_ell] = h_fb_ell - ind_ell += 1 - if ell > 0: - h_fb[ind_ell] = h_fb[ind_ell - 1] - ind_ell += 1 - - return h_fb diff --git a/src/aspire/utils/matrix.py b/src/aspire/utils/matrix.py index 20d54ee2dd..5e56d2e65e 100644 --- a/src/aspire/utils/matrix.py +++ b/src/aspire/utils/matrix.py @@ -434,6 +434,34 @@ def best_rank1_approximation(A): return (U @ S_rank1 @ V).reshape(og_shape) +def nearest_rotations(A): + """ + Uses the SVD method to compute the set of nearest rotations to the set A of noisy rotations. + + :param A: A 2D array or a 3D array where the first axis is the stack axis. + :return: ndarray of rotations of equal size to A. + """ + og_shape = A.shape + dtype = A.dtype + + if A.ndim == 2: + A = A[np.newaxis] + if A.ndim != 3 or not A.shape[1] == A.shape[2] == 3: + raise ValueError( + f"Array must be of shape (3, 3) or (n, 3, 3). Found shape {A.shape}." + ) + + # For the singular value decomposition A = U @ S @ V, we compute the nearest rotation + # matrices R = U @ V. If det(U)*det(V) = -1, we negate the third singular value to ensure + # we have a rotation. + U, _, V = np.linalg.svd(A) + neg_det_idx = np.linalg.det(U) * np.linalg.det(V) < 0 + U[neg_det_idx] = U[neg_det_idx] @ np.diag((1, 1, -1)).astype(dtype, copy=False) + rots = U @ V + + return rots.reshape(og_shape) + + def fix_signs(u): """ Negates columns so the sign of the largest element in the column is positive. diff --git a/src/aspire/utils/misc.py b/src/aspire/utils/misc.py index 7a91f66168..b113bcf85b 100644 --- a/src/aspire/utils/misc.py +++ b/src/aspire/utils/misc.py @@ -206,6 +206,7 @@ def gaussian_3d(size, mu=(0, 0, 0), sigma=(1, 1, 1), indexing="zyx", dtype=np.fl def bump_3d(size, spread=1, dtype=np.float64): """ Returns a centered 3D bump function in a (size)x(size)x(size) numpy array. + :param size: The length of the dimensions of the array (pixels. :param spread: A factor controling the spread of the bump function. 
:param dtype: dtype of returned array @@ -274,26 +275,52 @@ def inverse_r(size, x0=0, y0=0, peak=1, dtype=np.float64): return (peak / vals).astype(dtype) -def fuzzy_mask(L, r0, risetime, origin=None): +def fuzzy_mask(L, dtype, r0=None, risetime=None): """ - Create a centered 1D to 3D fuzzy mask of radius r0 + Create a centered 1D to 3D fuzzy mask of radius r0. Made with an error function with effective rise time. - :param L: The sizes of image in tuple structure - :param r0: The specified radius - :param risetime: The rise time for `erf` function - :param origin: The coordinates of origin + :param L: The sizes of image in tuple structure. Must be 1D, 2D square, + or 3D cube. + :param dtype: dtype for fuzzy mask. + :param r0: The specified radius. Defaults to floor(0.45 * L) + :param risetime: The rise time for `erf` function. Defaults to floor(0.05 * L) + :return: The desired fuzzy mask """ + # Note: default values for r0 and risetime are from Matlab common-lines code. + if r0 is None: + r0 = np.floor(0.45 * L[0]) + if risetime is None: + risetime = np.floor(0.05 * L[0]) + + dim = len(L) + axes = ["x"] + grid_kwargs = {"n": L[0], "shifted": False, "normalized": False, "dtype": dtype} + + if dim == 1: + grid = grid_1d(**grid_kwargs) - center = [sz // 2 + 1 for sz in L] - if origin is None: - origin = center + elif dim == 2: + if not (L[0] == L[1]): + raise ValueError(f"A 2D fuzzy_mask must be square, found L={L}.") + grid = grid_2d(**grid_kwargs) + axes.insert(0, "y") + + elif dim == 3: + if not (L[0] == L[1] == L[2]): + raise ValueError(f"A 3D fuzzy_mask must be cubic, found L={L}.") + grid = grid_3d(**grid_kwargs) + axes.insert(0, "y") + axes.insert(0, "z") + + else: + raise RuntimeError( + f"Only 1D, 2D, or 3D fuzzy_mask supported. Found {dim}-dimensional `L`." + ) - grids = [np.arange(1 - org, ell - org + 1) for ell, org in zip(L, origin)] - XYZ = np.meshgrid(*grids, indexing="ij") - XYZ_sq = [X**2 for X in XYZ] + XYZ_sq = [grid[axis] ** 2 for axis in axes] R = np.sqrt(np.sum(XYZ_sq, axis=0)) k = 1.782 / risetime m = 0.5 * (1 - erf(k * (R - r0))) diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 73780e36be..8e2fee99c9 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -2,6 +2,7 @@ This module contains code for estimating resolution achieved by reconstructions. """ import logging +import warnings import matplotlib.pyplot as plt import numpy as np @@ -38,13 +39,13 @@ class FourierCorrelation: def __init__(self, a, b, pixel_size=None, method="fft"): """ - :param a: Input array a, shape(..., *dim). - :param b: Input array b, shape(..., *dim). - :param pixel_size: Pixel size in angstrom. - Default `None` implies "pixel" units. - :param method: Selects either 'fft' (on Cartesian grid), - or 'nufft' (on polar grid). Defaults to 'fft'. - 7""" + :param a: Input array a, shape(..., *dim). + :param b: Input array b, shape(..., *dim). + :param pixel_size: Pixel size in angstrom. + Default `None` implies "pixel" units. + :param method: Selects either 'fft' (on Cartesian grid), + or 'nufft' (on polar grid). Defaults to 'fft'. + """ # Sanity checks if not hasattr(self, "dim"): @@ -248,8 +249,13 @@ def _nufft_correlations(self): def analyze_correlations(self, cutoff): """ Convert from the Fourier correlations to frequencies and resolution. + :param cutoff: Cutoff value, traditionally `.143`. + Note `cutoff=None` evaluates as `cutoff=1`. """ + # Handle optional cutoff plotting. 
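+        # `None` maps to an effective cutoff of 1 (the maximum correlation);
+        # the cutoff line itself is only drawn in `plot`, not here.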
+ if cutoff is None: + cutoff = 1 cutoff = float(cutoff) if not (0 <= cutoff <= 1): @@ -270,8 +276,12 @@ def analyze_correlations(self, cutoff): # Convert indices to frequency (as 1/angstrom) frequencies = self._freq(c_ind) - # Convert to resolution in angstrom, smaller is higher frequency. - self._resolutions = 1 / frequencies + with warnings.catch_warnings(): + # When using high cutoff (eg. 1) it is possible `frequencies` + # contains 0; capture and ignore that division warning. + warnings.filterwarnings("ignore", r".*divide by zero.*") + # Convert to resolution in angstrom, smaller is higher frequency. + self._resolutions = 1 / frequencies return self._resolutions @@ -288,17 +298,25 @@ def _freq(self, k): # Similar to wavenumbers. Larger is higher frequency. return k / (self.L * self.pixel_size) - def plot(self, cutoff, save_to_file=False, labels=None): + def plot(self, cutoff=None, save_to_file=False, labels=None): """ Generates a Fourier correlation plot. :param cutoff: Cutoff value, traditionally `.143`. + Default `None` implies `cutoff=1` and excludes + plotting cutoff line. :param save_to_file: Optionally, save plot to file. Defaults False, enabled by providing a string filename. User is responsible for providing reasonable filename. See `https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html`. """ - cutoff = float(cutoff) + + # Handle optional cutoff plotting. + _plot_cutoff = True + if cutoff is None: + cutoff = 1 + _plot_cutoff = False + if not (0 <= cutoff <= 1): raise ValueError("Supplied correlation `cutoff` not in [0,1], {cutoff}") @@ -343,17 +361,20 @@ def plot(self, cutoff, save_to_file=False, labels=None): _label = labels[i] plt.plot(freqs_units, line, label=_label) - # Display cutoff - plt.axhline(y=cutoff, color="r", linestyle="--", label=f"cutoff={cutoff}") estimated_resolution = self.analyze_correlations(cutoff)[0] - # Display resolution - plt.axvline( - x=estimated_resolution, - color="b", - linestyle=":", - label=f"Resolution={estimated_resolution:.3f}", - ) + # Display cutoff + if _plot_cutoff: + plt.axhline(y=cutoff, color="r", linestyle="--", label=f"cutoff={cutoff}") + + # Display resolution + plt.axvline( + x=estimated_resolution, + color="b", + linestyle=":", + label=f"Resolution={estimated_resolution:.3f}", + ) + # x-axis decreasing plt.gca().invert_xaxis() plt.legend(title=f"Method: {self.method}") diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index c7738db9f6..cc29710045 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -434,10 +434,9 @@ def shift(self): def rotate(self, rot_matrices, zero_nyquist=True): """ - Rotate volumes using a `Rotation` object. If the `Rotation` object - is a single rotation, each volume will be rotated by that rotation. - If the `Rotation` object is a stack of rotations of length n_vols, - the ith volume is rotated by the ith rotation. + Rotate volumes, within a fixed grid, by `rot_matrices`. If `rot_matrices` is a single + rotation, each volume will be rotated by that rotation. If `rot_matrices` is a stack of + rotations of length n_vols, the ith volume will be rotated by the ith rotation. :param rot_matrices: `Rotation` object of length 1 or n_vols. :param zero_nyquist: Option to keep or remove Nyquist frequency for even resolution. @@ -454,13 +453,13 @@ def rotate(self, rot_matrices, zero_nyquist=True): rot_matrices, Rotation ), f"Argument must be an instance of the Rotation class. {type(rot_matrices)} was supplied." 
- # Get numpy representation of Rotation object. - rot_matrices = rot_matrices.matrices + # Invert the rotations passed to `rotated_grids_3d` and get numpy representation of Rotation object. + rots_inverted = rot_matrices.invert().matrices - K = len(rot_matrices) # Rotation stack size + K = len(rots_inverted) # Rotation stack size assert K == self.n_vols or K == 1, "Rotation object must be length 1 or n_vols." - if rot_matrices.dtype != self.dtype: + if rots_inverted.dtype != self.dtype: logger.warning( f"{self.__class__.__name__}" f" rot_matrices.dtype {rot_matrices.dtype}" @@ -470,19 +469,19 @@ def rotate(self, rot_matrices, zero_nyquist=True): # If K = 1 we broadcast the single Rotation object across each volume. if K == 1: - pts_rot = rotated_grids_3d(self.resolution, rot_matrices) + pts_rot = rotated_grids_3d(self.resolution, rots_inverted) vol_f = nufft(self.asnumpy(), pts_rot) vol_f = vol_f.reshape(-1, self.resolution, self.resolution, self.resolution) # If K = n_vols, we apply the ith rotation to ith volume. else: - rot_matrices = rot_matrices.reshape((K, 1, 3, 3)) + rots_inverted = rots_inverted.reshape((K, 1, 3, 3)) pts_rot = np.zeros((K, 3, self.resolution**3), dtype=self.dtype) vol_f = np.empty( (self.n_vols, self.resolution**3), dtype=complex_type(self.dtype) ) for i in range(K): - pts_rot[i] = rotated_grids_3d(self.resolution, rot_matrices[i]) + pts_rot[i] = rotated_grids_3d(self.resolution, rots_inverted[i]) vol_f[i] = nufft(self[i].asnumpy(), pts_rot[i]) @@ -550,7 +549,7 @@ def load(cls, filename, permissive=True, dtype=None, symmetry_group=None): return cls(loaded_data, symmetry_group=symmetry_group, dtype=dtype) - def fsc(self, other, cutoff, pixel_size=None, method="fft", plot=False): + def fsc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): r""" Compute the Fourier shell correlation between two volumes. @@ -565,6 +564,8 @@ def fsc(self, other, cutoff, pixel_size=None, method="fft", plot=False): :param other: `Volume` instance to compare. :param cutoff: Cutoff value, traditionally `.143`. + Default `None` implies `cutoff=1` and excludes + plotting cutoff line. :param pixel_size: Pixel size in angstrom. Default `None` implies unit in pixels, equivalent to pixel_size=1. :param method: Selects either 'fft' (on cartesian grid), @@ -658,9 +659,11 @@ def rotated_grids(L, rot_matrices): Frequencies are in the range [-pi, pi]. """ - grid2d = grid_2d(L, indexing="xy", dtype=rot_matrices.dtype) + grid2d = grid_2d(L, indexing="yx", dtype=rot_matrices.dtype) num_pts = L**2 num_rots = rot_matrices.shape[0] + + # Frequency points flattened and placed in xyz order to apply rotations. pts = np.pi * np.vstack( [ grid2d["x"].flatten(), @@ -672,7 +675,8 @@ def rotated_grids(L, rot_matrices): for i in range(num_rots): pts_rot[:, i, :] = rot_matrices[i, :, :] @ pts - pts_rot = pts_rot.reshape((3, num_rots, L, L)) + # Reshape rotated frequency points and convert back into zyx convention. + pts_rot = pts_rot.reshape((3, num_rots, L, L))[::-1] return pts_rot @@ -687,9 +691,11 @@ def rotated_grids_3d(L, rot_matrices): Frequencies are in the range [-pi, pi]. """ - grid3d = grid_3d(L, indexing="xyz", dtype=rot_matrices.dtype) + grid3d = grid_3d(L, indexing="zyx", dtype=rot_matrices.dtype) num_pts = L**3 num_rots = rot_matrices.shape[0] + + # Frequency points flattened and placed in xyz order to apply rotations. 
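Two conventions drive the `rotate` and `rotated_grids*` changes above: rotating the sampling grid by `R` samples `vol(R x)`, i.e. it moves the volume content by the inverse rotation (hence the `invert()` call), and rows stacked in (x, y, z) order are flipped to the (z, y, x) convention with `[::-1]`. A small self-contained sketch of both facts:

```python
import numpy as np

# A rotation matrix's inverse is its transpose, so inverting a stack is cheap.
theta = np.pi / 4
R = np.array(
    [[np.cos(theta), -np.sin(theta), 0.0],
     [np.sin(theta), np.cos(theta), 0.0],
     [0.0, 0.0, 1.0]]
)
assert np.allclose(np.linalg.inv(R), R.T)

# Reversing the leading axis converts (x, y, z) row order to (z, y, x).
pts = np.stack([np.zeros(4), np.ones(4), np.full(4, 2.0)])  # rows: x, y, z
assert np.array_equal(pts[::-1][0], pts[2])  # leading row is now z
```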
pts = np.pi * np.vstack( [ grid3d["x"].flatten(), @@ -702,4 +708,4 @@ def rotated_grids_3d(L, rot_matrices): pts_rot[:, i, :] = rot_matrices[i, :, :] @ pts # Note we return grids as (Z,Y,X) - return pts_rot.reshape(3, -1) + return pts_rot.reshape(3, -1)[::-1] diff --git a/src/aspire/volume/volume_synthesis.py b/src/aspire/volume/volume_synthesis.py index 579eb05919..c10c466081 100644 --- a/src/aspire/volume/volume_synthesis.py +++ b/src/aspire/volume/volume_synthesis.py @@ -78,9 +78,9 @@ def generate(self): def _gaussian_blob_vols(self): """ - Generates a Volume object composed of Gaussian blobs. + Generates a 4D array representing a stack of volumes composed of Gaussian blobs. - :return: A Volume instance containing C Gaussian blob volumes. + :return: An ndarray containing C Gaussian blob volumes. """ vols = np.zeros(shape=((self.C,) + (self.L,) * 3)).astype(self.dtype) with Random(self.seed): @@ -88,7 +88,7 @@ def _gaussian_blob_vols(self): Q, D, mu = self._gen_gaussians() Q_rot, D_sym, mu_rot = self._symmetrize_gaussians(Q, D, mu) vols[c] = self._eval_gaussians(Q_rot, D_sym, mu_rot) - return Volume(vols) + return vols def _gen_gaussians(self): """ @@ -141,7 +141,7 @@ def _eval_gaussians(self, Q, D, mu): :return: An L x L x L array. """ - g = grid_3d(self.L, indexing="xyz", dtype=self.dtype) + g = grid_3d(self.L, indexing="zyx", dtype=self.dtype) coords = np.array( [g["x"].flatten(), g["y"].flatten(), g["z"].flatten()], dtype=self.dtype ) @@ -263,4 +263,9 @@ def generate(self): """ Generates an asymmetric volume composed of random 3D Gaussian blobs. """ - return self._gaussian_blob_vols() + vols = self._gaussian_blob_vols() + + # Swap axes to retain Legacy xyz-indexing. + vols = np.swapaxes(vols, 1, 3) + + return Volume(vols) diff --git a/tests/_basis_util.py b/tests/_basis_util.py index fc4cea74a2..5138887496 100644 --- a/tests/_basis_util.py +++ b/tests/_basis_util.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from aspire.basis import Coef from aspire.image import Image from aspire.utils import gaussian_2d, utest_tolerance from aspire.utils.coor_trans import grid_2d @@ -44,8 +45,6 @@ def testIndices(self, basis): ell_max = basis.ell_max k_max = basis.k_max - indices = basis.indices() - i = 0 for ell in range(ell_max + 1): @@ -56,9 +55,9 @@ def testIndices(self, basis): for sgn in sgns: for k in range(k_max[ell]): - assert indices["ells"][i] == ell - assert indices["sgns"][i] == sgn - assert indices["ks"][i] == k + assert basis.angular_indices[i] == ell + assert basis.signs_indices[i] == sgn + assert basis.radial_indices[i] == k i += 1 @@ -96,12 +95,12 @@ def testIsotropic(self, basis): sigma = L / 8 im = gaussian_2d(L, sigma=sigma, dtype=basis.dtype) - coef = basis.expand(im) + coef_np = basis.expand(im).asnumpy() - ells = basis.indices()["ells"] + ells = basis.angular_indices - energy_outside = np.sum(np.abs(coef[ells != 0]) ** 2) - energy_total = np.sum(np.abs(coef) ** 2) + energy_outside = np.sum(np.abs(coef_np[..., ells != 0]) ** 2) + energy_total = np.sum(np.abs(coef_np) ** 2) energy_ratio = energy_outside / energy_total @@ -122,12 +121,12 @@ def testModulated(self, basis): for trig_fun in (np.sin, np.cos): im1 = im * trig_fun(ell * g2d["phi"]) - coef = basis.expand(im1) + coef_np = basis.expand(im1).asnumpy() - ells = basis.indices()["ells"] + ells = basis.angular_indices - energy_outside = np.sum(np.abs(coef[ells != ell]) ** 2) - energy_total = np.sum(np.abs(coef) ** 2) + energy_outside = np.sum(np.abs(coef_np[..., ells != ell]) ** 2) + energy_total = 
np.sum(np.abs(coef_np) ** 2) energy_ratio = energy_outside / energy_total @@ -135,19 +134,21 @@ def testModulated(self, basis): def testEvaluateExpand(self, basis): coef1 = randn(basis.count, seed=self.seed) - coef1 = coef1.astype(basis.dtype) + coef1 = Coef(basis, coef1.astype(basis.dtype)) im = basis.evaluate(coef1) if isinstance(im, Image): im = im.asnumpy() - coef2 = basis.expand(im)[0] + coef2 = basis.expand(im) - assert coef1.shape == coef2.shape + assert ( + coef1.shape == coef2.shape + ), f"shape mismatch {coef1.shape} != {coef2.shape}" assert np.allclose(coef1, coef2, atol=utest_tolerance(basis.dtype)) def testAdjoint(self, basis): u = randn(basis.count, seed=self.seed) - u = u.astype(basis.dtype) + u = Coef(basis, u, dtype=basis.dtype) Au = basis.evaluate(u) if isinstance(Au, Image): @@ -180,7 +181,8 @@ def testEvaluate(self, basis): # evaluate should take a NumPy array of type basis.coefficient_dtype # and return an Image/Volume _class = self.getClass(basis) - result = basis.evaluate(np.zeros((basis.count), dtype=basis.coefficient_dtype)) + coef = Coef(basis, np.zeros((basis.count)), dtype=basis.coefficient_dtype) + result = basis.evaluate(coef) assert isinstance(result, _class) def testEvaluate_t(self, basis): @@ -190,7 +192,7 @@ def testEvaluate_t(self, basis): result = basis.evaluate_t( _class(np.zeros((basis.nres,) * basis.ndim, dtype=basis.dtype)) ) - assert isinstance(result, np.ndarray) + assert isinstance(result, Coef) assert result.dtype == basis.coefficient_dtype def testExpand(self, basis): @@ -200,7 +202,7 @@ def testExpand(self, basis): result = basis.expand( _class(np.zeros((basis.nres,) * basis.ndim, dtype=basis.dtype)) ) - assert isinstance(result, np.ndarray) + assert isinstance(result, Coef) assert result.dtype == basis.coefficient_dtype def testInitWithIntSize(self, basis): diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_cwf_coeff.npy b/tests/saved_test_data/clean70SRibosome_cov2d_cwf_coef.npy similarity index 100% rename from tests/saved_test_data/clean70SRibosome_cov2d_cwf_coeff.npy rename to tests/saved_test_data/clean70SRibosome_cov2d_cwf_coef.npy diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_cwf_coeff_clean.npy b/tests/saved_test_data/clean70SRibosome_cov2d_cwf_coef_clean.npy similarity index 100% rename from tests/saved_test_data/clean70SRibosome_cov2d_cwf_coeff_clean.npy rename to tests/saved_test_data/clean70SRibosome_cov2d_cwf_coef_clean.npy diff --git a/tests/saved_test_data/clean70SRibosome_cov2d_cwf_coeff_noCTF.npy b/tests/saved_test_data/clean70SRibosome_cov2d_cwf_coef_noCTF.npy similarity index 100% rename from tests/saved_test_data/clean70SRibosome_cov2d_cwf_coeff_noCTF.npy rename to tests/saved_test_data/clean70SRibosome_cov2d_cwf_coef_noCTF.npy diff --git a/tests/saved_test_data/fbbasis_coefficients_8_8.npy b/tests/saved_test_data/fbbasis_coeficients_8_8.npy similarity index 100% rename from tests/saved_test_data/fbbasis_coefficients_8_8.npy rename to tests/saved_test_data/fbbasis_coeficients_8_8.npy diff --git a/tests/saved_test_data/ffbbasis2d_vcoeff_out_8_8.npy b/tests/saved_test_data/ffbbasis2d_vcoef_out_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis2d_vcoeff_out_8_8.npy rename to tests/saved_test_data/ffbbasis2d_vcoef_out_8_8.npy diff --git a/tests/saved_test_data/ffbbasis2d_vcoeff_out_exp_8_8.npy b/tests/saved_test_data/ffbbasis2d_vcoef_out_exp_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis2d_vcoeff_out_exp_8_8.npy rename to 
tests/saved_test_data/ffbbasis2d_vcoef_out_exp_8_8.npy diff --git a/tests/saved_test_data/ffbbasis2d_xcoeff_in_8_8.npy b/tests/saved_test_data/ffbbasis2d_xcoef_in_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis2d_xcoeff_in_8_8.npy rename to tests/saved_test_data/ffbbasis2d_xcoef_in_8_8.npy diff --git a/tests/saved_test_data/ffbbasis2d_xcoeff_out_8_8.npy b/tests/saved_test_data/ffbbasis2d_xcoef_out_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis2d_xcoeff_out_8_8.npy rename to tests/saved_test_data/ffbbasis2d_xcoef_out_8_8.npy diff --git a/tests/saved_test_data/ffbbasis3d_vcoeff_out_8_8_8.npy b/tests/saved_test_data/ffbbasis3d_vcoef_out_8_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis3d_vcoeff_out_8_8_8.npy rename to tests/saved_test_data/ffbbasis3d_vcoef_out_8_8_8.npy diff --git a/tests/saved_test_data/ffbbasis3d_vcoeff_out_exp_8_8_8.npy b/tests/saved_test_data/ffbbasis3d_vcoef_out_exp_8_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis3d_vcoeff_out_exp_8_8_8.npy rename to tests/saved_test_data/ffbbasis3d_vcoef_out_exp_8_8_8.npy diff --git a/tests/saved_test_data/ffbbasis3d_xcoeff_in_8_8_8.npy b/tests/saved_test_data/ffbbasis3d_xcoef_in_8_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis3d_xcoeff_in_8_8_8.npy rename to tests/saved_test_data/ffbbasis3d_xcoef_in_8_8_8.npy diff --git a/tests/saved_test_data/ffbbasis3d_xcoeff_out_8_8_8.npy b/tests/saved_test_data/ffbbasis3d_xcoef_out_8_8_8.npy similarity index 100% rename from tests/saved_test_data/ffbbasis3d_xcoeff_out_8_8_8.npy rename to tests/saved_test_data/ffbbasis3d_xcoef_out_8_8_8.npy diff --git a/tests/saved_test_data/pfbasis_coefficients_8_4_32.npy b/tests/saved_test_data/pfbasis_coeficients_8_4_32.npy similarity index 100% rename from tests/saved_test_data/pfbasis_coefficients_8_4_32.npy rename to tests/saved_test_data/pfbasis_coeficients_8_4_32.npy diff --git a/tests/saved_test_data/pswf2d_vcoeffs_out_8_8.npy b/tests/saved_test_data/pswf2d_vcoefs_out_8_8.npy similarity index 100% rename from tests/saved_test_data/pswf2d_vcoeffs_out_8_8.npy rename to tests/saved_test_data/pswf2d_vcoefs_out_8_8.npy diff --git a/tests/saved_test_data/pswf2d_xcoeff_out_8_8.npy b/tests/saved_test_data/pswf2d_xcoef_out_8_8.npy similarity index 100% rename from tests/saved_test_data/pswf2d_xcoeff_out_8_8.npy rename to tests/saved_test_data/pswf2d_xcoef_out_8_8.npy diff --git a/tests/saved_test_data/rln_proj_64.mrcs b/tests/saved_test_data/rln_proj_64.mrcs new file mode 100644 index 0000000000..bdc2f9f8b2 Binary files /dev/null and b/tests/saved_test_data/rln_proj_64.mrcs differ diff --git a/tests/saved_test_data/rln_proj_64.star b/tests/saved_test_data/rln_proj_64.star new file mode 100644 index 0000000000..b7d88e788a --- /dev/null +++ b/tests/saved_test_data/rln_proj_64.star @@ -0,0 +1,18 @@ + +# version 30001 + +data_particles + +loop_ +_rlnAngleRot #1 +_rlnAngleTilt #2 +_rlnAnglePsi #3 +_rlnOriginXAngst #4 +_rlnOriginYAngst #5 +_rlnOpticsGroup #6 +_rlnImageName #7 + 275.167758 67.327105 47.970504 0.000000 0.000000 1 000001@rln_proj_64.mrcs + 19.551395 77.303721 168.404682 0.000000 0.000000 1 000002@rln_proj_64.mrcs + 76.607773 91.268867 326.522598 0.000000 0.000000 1 000003@rln_proj_64.mrcs + 116.208186 115.859671 140.116990 0.000000 0.000000 1 000004@rln_proj_64.mrcs + 250.742147 120.229633 124.614004 0.000000 0.000000 1 000005@rln_proj_64.mrcs diff --git a/tests/saved_test_data/rln_proj_65.mrcs 
b/tests/saved_test_data/rln_proj_65.mrcs new file mode 100644 index 0000000000..58ae2769e4 Binary files /dev/null and b/tests/saved_test_data/rln_proj_65.mrcs differ diff --git a/tests/saved_test_data/rln_proj_65.star b/tests/saved_test_data/rln_proj_65.star new file mode 100644 index 0000000000..d70dc8b683 --- /dev/null +++ b/tests/saved_test_data/rln_proj_65.star @@ -0,0 +1,18 @@ + +# version 30001 + +data_particles + +loop_ +_rlnAngleRot #1 +_rlnAngleTilt #2 +_rlnAnglePsi #3 +_rlnOriginXAngst #4 +_rlnOriginYAngst #5 +_rlnOpticsGroup #6 +_rlnImageName #7 + 355.858841 72.200518 240.688133 0.000000 0.000000 1 000001@rln_proj_65.mrcs + 63.759542 92.697219 303.796113 0.000000 0.000000 1 000002@rln_proj_65.mrcs + 178.751045 42.736465 137.511331 0.000000 0.000000 1 000003@rln_proj_65.mrcs + 248.292861 83.402259 197.749400 0.000000 0.000000 1 000004@rln_proj_65.mrcs + 99.680307 72.793409 270.791295 0.000000 0.000000 1 000005@rln_proj_65.mrcs diff --git a/tests/saved_test_data/sample_docstrings.py b/tests/saved_test_data/sample_docstrings.py new file mode 100644 index 0000000000..dc81b54f43 --- /dev/null +++ b/tests/saved_test_data/sample_docstrings.py @@ -0,0 +1,70 @@ +def good_fun1(frog, dog): + """ + This docstring is properly formatted. + + It has a multi-line, multi-section body + followed by exactly one blank line. + + :param frog: This param description is + multiline. + :param dog: Single line description + :return: A frog on a dog + """ + + +def good_fun2(): + """ + This function has only a return. + + :return: Just a return. + """ + + +def good_fun3(): + def nested_fun(bip): + """ + This is a properly formatted docstring + in a nested function. + + :param bip: A small bip + :return: A large bop + """ + + +def good_fun4(bing, bong): + """ + :param bing: This docstring has no body. + :param bong: Should not error. + :return: Boom. + """ + + +def bad_fun1(cat, hat): + """ + This docstring is missing a blank line + between the body and parameter sections. + :param cat: A cat. + :param hat: A hat. + :return: A cat in a hat. + """ + + +def bad_fun2(foo): + """ + This docstring has too many blank lines between + the body and parameter sections. + + + :param foo: foo description. + :return: bar + """ + + +def bad_fun3(): + def nested_fun(bip): + """ + This is an improperly formatted docstring + in a nested function. + :param bip: A small bip + :return: A large bop + """ diff --git a/tests/test_BlkDiagMatrix.py b/tests/test_BlkDiagMatrix.py index 9f606b9731..1b92b3b13f 100644 --- a/tests/test_BlkDiagMatrix.py +++ b/tests/test_BlkDiagMatrix.py @@ -125,22 +125,22 @@ def testBlkDiagMatrixSub(self): def testBlkDiagMatrixApply(self): m = np.sum(self.blk_a.partition[:, 1]) k = 3 - coeffm = np.arange(k * m).reshape(m, k).astype(self.blk_a.dtype) + coefm = np.arange(k * m).reshape(m, k).astype(self.blk_a.dtype) # Manually compute ind = 0 - res = np.empty_like(coeffm) + res = np.empty_like(coefm) for b, blk in enumerate(self.blk_a): col = self.blk_a.partition[b, 1] - res[ind : ind + col, :] = blk @ coeffm[ind : ind + col, :] + res[ind : ind + col, :] = blk @ coefm[ind : ind + col, :] ind += col # Check ndim 1 case - c = self.blk_a.apply(coeffm[:, 0]) + c = self.blk_a.apply(coefm[:, 0]) self.allallfunc(c, res[:, 0]) # Check ndim 2 case - d = self.blk_a.apply(coeffm) + d = self.blk_a.apply(coefm) self.allallfunc(res, d) # Here we are checking that the ndim 2 case distributes as described. @@ -148,33 +148,33 @@ def testBlkDiagMatrixApply(self): # should be equivalent to e = [A.apply(r0), ... A.apply(ri)]. 
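For context, each of the `rln_proj_*` fixtures added above pairs five projection images with RELION Euler angles. A hypothetical sanity check using `RelionSource` (paths assumed relative to the repository root; not part of the test suite):

```python
from aspire.source import RelionSource

# Hypothetical quick look at the new fixture; path layout assumed.
src = RelionSource(
    "tests/saved_test_data/rln_proj_64.star",
    "tests/saved_test_data",
)
print(src.n)          # expect the 5 particles listed in the STAR file
imgs = src.images[:]  # resolves 000001@rln_proj_64.mrcs ... 000005@...
```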
e = np.empty((m, k)) for i in range(k): - e[:, i] = self.blk_a.apply(coeffm[:, i]) + e[:, i] = self.blk_a.apply(coefm[:, i]) self.allallfunc(e, d) # We can use syntactic sugar @ for apply as well - f = self.blk_a @ coeffm + f = self.blk_a @ coefm self.allallfunc(f, d) # Test the rapply is also functional - coeffm = coeffm.T # matmul dimensions - res = coeffm @ self.blk_a.dense() - d = self.blk_a.rapply(coeffm) + coefm = coefm.T # matmul dimensions + res = coefm @ self.blk_a.dense() + d = self.blk_a.rapply(coefm) self.allallfunc(res, d) # And the syntactic sugar @ - d = coeffm @ self.blk_a + d = coefm @ self.blk_a self.allallfunc(res, d) # And test some incorrect invocations: # inplace not supported for matmul of mixed classes. with pytest.raises(RuntimeError, match=r".*method not supported.*"): - self.blk_a @= coeffm + self.blk_a @= coefm # Test left operand of an __rmatmul__ must be an ndarray with pytest.raises( RuntimeError, match=r".*only defined for np.ndarray @ BlkDiagMatrix.*" ): - _ = list(coeffm) @ self.blk_a + _ = list(coefm) @ self.blk_a def testBlkDiagMatrixMatMult(self): result = [np.matmul(*tup) for tup in zip(self.blk_a, self.blk_b)] @@ -285,18 +285,18 @@ def testBlkDiagMatrixSolve(self): m = np.sum(self.blk_a.partition[:, 1]) k = 3 - coeffm = np.arange(k * m).reshape(m, k).astype(self.blk_a.dtype) + coefm = np.arange(k * m).reshape(m, k).astype(self.blk_a.dtype) # Manually compute ind = 0 - res = np.empty_like(coeffm) + res = np.empty_like(coefm) for b, blk in enumerate(B): col = self.blk_a.partition[b, 1] - res[ind : ind + col, :] = solve(blk, coeffm[ind : ind + col, :]) + res[ind : ind + col, :] = solve(blk, coefm[ind : ind + col, :]) ind += col - coeff_est = B.solve(coeffm) - self.allallfunc(res, coeff_est) + coef_est = B.solve(coefm) + self.allallfunc(res, coef_est) def testBlkDiagMatrixTranspose(self): blk_c = [blk.T for blk in self.blk_a] @@ -378,6 +378,35 @@ def test_blk_diag_to_diag(self): """ self.assertTrue(np.allclose(np.diag(self.blk_a.dense()), self.blk_a.diag())) + def test_from_dense(self): + """ + Test truncating dense array returns correct block diagonal entries. + """ + B = BlkDiagMatrix.from_dense(self.dense, self.blk_partition) + + self.allallfunc(B, self.blk_a) + + def test_from_dense_warns(self): + """ + Test that a warning is emitted when values outside the blocks + are larger than some `eps`. + """ + # Add ones to the entire dense matrix, to exceed `warn_eps` below. + dense = self.dense + 1 + + with pytest.warns(UserWarning, match=r".*truncating values.*"): + _ = BlkDiagMatrix.from_dense(dense, self.blk_partition, warn_eps=1e-6) + + def test_from_dense_incorrect_shape(self): + """ + Test truncating dense array raises an error on incorrect shape. + """ + # Pad the dense array so there will be a leftover row and column.
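As background for the `apply` and `from_dense` tests above, block-diagonal application is just a per-block matmul over a row/column partition. A self-contained NumPy sketch of the same bookkeeping (a toy stand-in, not the `BlkDiagMatrix` class):

```python
import numpy as np

blocks = [np.eye(2), 2.0 * np.eye(3)]      # toy block-diagonal partition
x = np.arange(5.0).reshape(5, 1)

out, ind = [], 0
for blk in blocks:                         # apply each block to its rows
    n = blk.shape[1]
    out.append(blk @ x[ind:ind + n])
    ind += n
result = np.vstack(out)

dense = np.zeros((5, 5))                   # equivalent dense matrix
dense[:2, :2] = blocks[0]
dense[2:, 2:] = blocks[1]
assert np.allclose(result, dense @ x)
```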
+ dense = np.pad(self.dense, (0, 1)) + + with pytest.raises(RuntimeError, match=r".*mismatch shape.*"): + _ = BlkDiagMatrix.from_dense(dense, self.blk_partition) + class IrrBlkDiagMatrixTestCase(TestCase): """ @@ -446,42 +475,42 @@ def testApply(self): n = np.sum(self.blk_x.partition[:, 0]) m = np.sum(self.blk_x.partition[:, 1]) k = 3 - coeffm = np.arange(k * m).reshape(m, k).astype(self.blk_x.dtype) + coefm = np.arange(k * m).reshape(m, k).astype(self.blk_x.dtype) # Manually compute indc = 0 indr = 0 - res = np.empty(shape=(n, k), dtype=coeffm.dtype) + res = np.empty(shape=(n, k), dtype=coefm.dtype) for b, blk in enumerate(self.blk_x): row, col = self.blk_x.partition[b] - res[indr : indr + row, :] = blk @ coeffm[indc : indc + col, :] + res[indr : indr + row, :] = blk @ coefm[indc : indc + col, :] indc += col indr += row # Check ndim 1 case - c = self.blk_x.apply(coeffm[:, 0]) + c = self.blk_x.apply(coefm[:, 0]) self.allallfunc(c, res[:, 0]) # Check ndim 2 case - d = self.blk_x.apply(coeffm) + d = self.blk_x.apply(coefm) self.allallfunc(res, d) # Check against dense numpy matmul - self.allallfunc(d, self.blk_x.dense() @ coeffm) + self.allallfunc(d, self.blk_x.dense() @ coefm) def testSolve(self): """ Test attempts to solve non square BlkDiagMatrix raise error. """ - # Setup a dummy coeff matrix + # Setup a dummy coef matrix n = np.sum(self.blk_x.partition[:, 0]) k = 3 - coeffm = np.arange(n * k).reshape(n, k).astype(self.blk_x.dtype) + coefm = np.arange(n * k).reshape(n, k).astype(self.blk_x.dtype) with pytest.raises( NotImplementedError, match=r"BlkDiagMatrix.solve is only defined for square arrays.*", ): # Attempt solve using the Block Diagonal implementation - _ = self.blk_x.solve(coeffm) + _ = self.blk_x.solve(coefm) diff --git a/tests/test_FBbasis2D.py b/tests/test_FBbasis2D.py index c5e9c555ce..2682747ba1 100644 --- a/tests/test_FBbasis2D.py +++ b/tests/test_FBbasis2D.py @@ -5,7 +5,7 @@ from pytest import raises from scipy.special import jv -from aspire.basis import FBBasis2D +from aspire.basis import Coef, ComplexCoef, FBBasis2D from aspire.image import Image from aspire.source import Simulation from aspire.utils import complex_type, real_type @@ -33,10 +33,9 @@ def _testElement(self, basis, ell, k, sgn): # This is covered by the isotropic test. assert ell > 0 - indices = basis.indices() - ells = indices["ells"] - sgns = indices["sgns"] - ks = indices["ks"] + ells = basis.angular_indices + sgns = basis.signs_indices + ks = basis.radial_indices g2d = grid_2d(basis.nres, dtype=basis.dtype) mask = g2d["r"] < 1 @@ -56,7 +55,7 @@ def _testElement(self, basis, ell, k, sgn): coef_ref = np.zeros(basis.count, dtype=basis.dtype) coef_ref[(ells == ell) & (sgns == sgn) & (ks == k)] = 1 - im_ref = basis.evaluate(coef_ref) + im_ref = basis.evaluate(Coef(basis, coef_ref)) coef = basis.expand(im) @@ -86,16 +85,35 @@ def testComplexCoversion(self, basis): # The round trip should be equivalent up to machine precision assert np.allclose(v1, v2) + # Convert real FB coef to complex coef using Coef class + cv = v1.to_complex() + # then convert back to real coef representation. + v2 = cv.to_real() + + # The round trip should be equivalent up to machine precision + assert np.allclose(v1, v2) + def testComplexCoversionErrorsToComplex(self, basis): - x = randn(*basis.sz, seed=self.seed) + x = randn(*basis.sz, seed=self.seed).astype(basis.dtype) - # Express in an FB basis - v1 = basis.expand(x.astype(basis.dtype)) + # Express in an FB basis, cast to array.
+ v1 = basis.expand(x).asnumpy() + + # Test catching Errors + with raises(TypeError): + # Pass complex into `to_complex` + v1_cpx = Coef(basis, v1, dtype=np.complex64) + _ = basis.to_complex(v1_cpx) # Test catching Errors with raises(TypeError): # Pass complex into `to_complex` - _ = basis.to_complex(v1.astype(np.complex64)) + v1_cpx = Coef(basis, v1, dtype=np.complex64) + + with raises(TypeError): + # Pass complex into `to_complex` + v1_cpx = Coef(basis, v1).to_complex() + _ = v1_cpx.to_complex() # Test casting case, where basis and coef don't match if basis.dtype == np.float32: @@ -105,22 +123,25 @@ def testComplexCoversionErrorsToComplex(self, basis): # Result should be same precision as coef input, just complex. result_dtype = complex_type(test_dtype) - v3 = basis.to_complex(v1.astype(test_dtype)) + v3 = basis.to_complex(Coef(basis, v1, dtype=test_dtype)) assert v3.dtype == result_dtype - # Try 0d vector, should not crash. - _ = basis.to_complex(v1.reshape(-1)) - def testComplexCoversionErrorsToReal(self, basis): x = randn(*basis.sz, seed=self.seed) # Express in an FB basis - cv1 = basis.to_complex(basis.expand(x.astype(basis.dtype))) + cv = basis.expand(x.astype(basis.dtype)) + ccv = cv.to_complex() # Test catching Errors with raises(TypeError): # Pass real into `to_real` - _ = basis.to_real(cv1.real.astype(np.float32)) + _ = basis.to_real(cv) + + # Test catching Errors + with raises(TypeError): + # Pass real into `to_real` + _ = cv.to_real() # Test casting case, where basis and coef precision don't match if basis.dtype == np.float32: @@ -130,12 +151,10 @@ def testComplexCoversionErrorsToReal(self, basis): # Result should be same precision as coef input, just real. result_dtype = real_type(test_dtype) - v3 = basis.to_real(cv1.astype(test_dtype)) + x = ComplexCoef(basis, ccv.asnumpy().astype(test_dtype)) + v3 = x.to_real() assert v3.dtype == result_dtype - # Try a 0d vector, should not crash. - _ = basis.to_real(cv1.reshape(-1)) - params = [pytest.param(256, np.float32, marks=pytest.mark.expensive)] @@ -158,8 +177,8 @@ def testHighResFBBasis2D(L, dtype): im = sim.images[0] # Round trip - coeff = basis.expand(im) - im_fb = basis.evaluate(coeff) + coef = basis.expand(im) + im_fb = basis.evaluate(coef) # Mask to compare inside disk of radius 1. mask = grid_2d(L, normalized=True)["r"] < 1 diff --git a/tests/test_FBbasis3D.py b/tests/test_FBbasis3D.py index 02ba88109a..7f9a581fbd 100644 --- a/tests/test_FBbasis3D.py +++ b/tests/test_FBbasis3D.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from aspire.basis import FBBasis3D +from aspire.basis import Coef, FBBasis3D from aspire.utils import grid_3d, utest_tolerance from aspire.volume import AsymmetricVolume, Volume @@ -355,7 +355,7 @@ def testFBBasis3DNorms(self, basis): ) def testFBBasis3DEvaluate(self, basis): - coeffs = np.array( + coefs = np.array( [ 1.07338590e-01, 1.23690941e-01, @@ -458,7 +458,7 @@ def testFBBasis3DEvaluate(self, basis): ], dtype=basis.dtype, ) - result = basis.evaluate(coeffs) + result = Coef(basis, coefs).evaluate() assert np.allclose( result.asnumpy(), @@ -687,7 +687,7 @@ def testFBBasis3DExpand(self, basis): ) -# NOTE: This test is failing for L=64. `coeff_0` has a few NANs which propogate into `vol_1`. See GH issue #923 +# NOTE: This test is failing for L=64. `coef_0` has a few NaNs which propagate into `vol_1`.
See GH issue #923 params = [pytest.param(64, np.float32, marks=pytest.mark.expensive)] @@ -702,8 +702,8 @@ def testHighResFBbasis3D(L, dtype): vol = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=seed).generate() # Round trip - coeff = basis.expand(vol) - vol_fb = basis.evaluate(coeff) + coef = basis.expand(vol) + vol_fb = basis.evaluate(coef) # Mask to compare inside sphere of radius 1. mask = grid_3d(L, normalized=True)["r"] < 1 diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index 1796bbf0e7..8acf7201d1 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -5,8 +5,7 @@ import pytest from scipy.special import jv -from aspire.basis import FFBBasis2D -from aspire.image import Image +from aspire.basis import Coef, FFBBasis2D from aspire.source import Simulation from aspire.utils.misc import grid_2d from aspire.volume import Volume @@ -31,10 +30,9 @@ class TestFFBBasis2D(Steerable2DMixin, UniversalBasisMixin): seed = 9161341 def _testElement(self, basis, ell, k, sgn): - indices = basis.indices() - ells = indices["ells"] - sgns = indices["sgns"] - ks = indices["ks"] + ells = basis.angular_indices + sgns = basis.signs_indices + ks = basis.radial_indices g2d = grid_2d(basis.nres, dtype=basis.dtype) mask = g2d["r"] < 1 @@ -64,7 +62,7 @@ def _testElement(self, basis, ell, k, sgn): coef_ref = np.zeros(basis.count, dtype=basis.dtype) coef_ref[(ells == ell) & (sgns == sgn) & (ks == k)] = 1 - im_ref = basis.evaluate(coef_ref).asnumpy()[0] + im_ref = Coef(basis, coef_ref).evaluate().asnumpy()[0] coef = basis.expand(im) @@ -82,96 +80,6 @@ def testElements(self, basis): for ell, k, sgn in zip(ells, ks, sgns): self._testElement(basis, ell, k, sgn) - def testRotate(self, basis): - # Now low res (8x8) had problems; - # better with odd (7x7), but still not good. - # We'll use a higher res test image. - # fh = np.load(os.path.join(DATA_DIR, 'ffbbasis2d_xcoeff_in_8_8.npy'))[:7,:7] - # Use a real data volume to generate a clean test image. - v = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype( - basis.dtype - ) - ) - src = Simulation(L=v.resolution, n=1, vols=v, dtype=v.dtype) - # Extract, this is the original image to transform. - x1 = src.images[0] - - # Rotate 90 degrees CCW in cartesian coordinates. - x2 = Image(np.rot90(x1.asnumpy(), axes=(1, 2))) - - # Express in an FB basis - basis = FFBBasis2D(x1.resolution, dtype=x1.dtype) - v1 = basis.evaluate_t(x1) - v2 = basis.evaluate_t(x2) - - # Reflect in the FB basis space - v4 = basis.rotate(v1, 0, refl=[True]) - - # Rotate in the FB basis space - v3 = basis.rotate(v1, 2 * np.pi) - v1 = basis.rotate(v1, np.pi / 2) - - # Evaluate back into cartesian - y1 = basis.evaluate(v1).asnumpy() - y2 = basis.evaluate(v2).asnumpy() - y3 = basis.evaluate(v3).asnumpy() - y4 = basis.evaluate(v4).asnumpy() - - # Rotate 90 - assert np.allclose(y1[0], y2[0], atol=1e-5) - - # 2*pi Identity - assert np.allclose(x1[0], y3[0], atol=1e-5) - - # Refl (flipped using flipud) - assert np.allclose(np.flipud(x1.asnumpy()[0]), y4[0], atol=1e-5) - - def testRotateComplex(self, basis): - # Now low res (8x8) had problems; - # better with odd (7x7), but still not good. - # We'll use a higher res test image. - # fh = np.load(os.path.join(DATA_DIR, 'ffbbasis2d_xcoeff_in_8_8.npy'))[:7,:7] - # Use a real data volume to generate a clean test image. 
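The recurring pattern in these test rewrites is that coefficient arrays are now wrapped in a `Coef` attached to a basis: `expand`/`evaluate_t` return `Coef`, and `Coef.evaluate()` maps back to images or volumes. A minimal round-trip sketch, with the small basis size assumed from the surrounding tests:

```python
import numpy as np
from aspire.basis import Coef, FFBBasis2D

basis = FFBBasis2D(8, dtype=np.float64)
coef = Coef(basis, np.zeros(basis.count, dtype=basis.coefficient_dtype))
img = coef.evaluate()          # equivalent to basis.evaluate(coef)
coef2 = basis.evaluate_t(img)  # returns a Coef, not a raw ndarray
assert np.allclose(coef.asnumpy(), coef2.asnumpy())
```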
- v = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype( - basis.dtype - ) - ) - src = Simulation(L=v.resolution, n=1, vols=v, dtype=v.dtype) - # Extract, this is the original image to transform. - x1 = src.images[0] - - # Rotate 90 degrees CCW in cartesian coordinates. - x2 = Image(np.rot90(x1.asnumpy(), axes=(1, 2))) - - # Express in an FB basis - basis = FFBBasis2D(x1.resolution, dtype=x1.dtype) - v1 = basis.evaluate_t(x1) - v2 = basis.evaluate_t(x2) - - # Reflect in the FB basis space - v4 = basis.to_real(basis.complex_rotate(basis.to_complex(v1), 0, refl=[True])) - - # Complex Rotate in the FB basis space - v3 = basis.to_real(basis.complex_rotate(basis.to_complex(v1), 2 * np.pi)) - v1 = basis.to_real(basis.complex_rotate(basis.to_complex(v1), np.pi / 2)) - - # Evaluate back into cartesian - y1 = basis.evaluate(v1).asnumpy() - y2 = basis.evaluate(v2).asnumpy() - y3 = basis.evaluate(v3).asnumpy() - y4 = basis.evaluate(v4).asnumpy() - - # Rotate 90 - assert np.allclose(y1[0], y2[0], atol=1e-5) - - # 2*pi Identity - assert np.allclose(x1[0].asnumpy(), y3[0], atol=1e-5) - - # Refl (flipped using flipud) - assert np.allclose(np.flipud(x1.asnumpy()[0]), y4[0], atol=1e-5) - def testShift(self, basis): """ Compare shifting using Image with shifting provided by the Basis. @@ -236,8 +144,8 @@ def testHighResFFBBasis2D(L, dtype): im = sim.images[0] # Round trip - coeff = basis.evaluate_t(im) - im_ffb = basis.evaluate(coeff) + coef = basis.evaluate_t(im) + im_ffb = basis.evaluate(coef) # Mask to compare inside disk of radius 1. mask = grid_2d(L, normalized=True)["r"] < 1 diff --git a/tests/test_FFBbasis3D.py b/tests/test_FFBbasis3D.py index c8879ab2a8..03bdfdd21b 100644 --- a/tests/test_FFBbasis3D.py +++ b/tests/test_FFBbasis3D.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from aspire.basis import FFBBasis3D +from aspire.basis import Coef, FFBBasis3D from aspire.utils import grid_3d from aspire.volume import AsymmetricVolume, Volume @@ -356,7 +356,7 @@ def testFFBBasis3DNorms(self, basis): ) def testFFBBasis3DEvaluate(self, basis): - coeffs = np.array( + coefs = np.array( [ 1.07338590e-01, 1.23690941e-01, @@ -460,29 +460,29 @@ def testFFBBasis3DEvaluate(self, basis): dtype=basis.dtype, ) - result = basis.evaluate(coeffs) + result = Coef(basis, coefs).evaluate() ref = np.load( - os.path.join(DATA_DIR, "ffbbasis3d_xcoeff_out_8_8_8.npy") + os.path.join(DATA_DIR, "ffbbasis3d_xcoef_out_8_8_8.npy") ).T # RCOPT assert np.allclose(result, ref, atol=1e-2) def testFFBBasis3DEvaluate_t(self, basis): - x = np.load(os.path.join(DATA_DIR, "ffbbasis3d_xcoeff_in_8_8_8.npy")).T # RCOPT + x = np.load(os.path.join(DATA_DIR, "ffbbasis3d_xcoef_in_8_8_8.npy")).T # RCOPT x = x.astype(basis.dtype, copy=False) result = basis.evaluate_t(Volume(x)) - ref = np.load(os.path.join(DATA_DIR, "ffbbasis3d_vcoeff_out_8_8_8.npy"))[..., 0] + ref = np.load(os.path.join(DATA_DIR, "ffbbasis3d_vcoef_out_8_8_8.npy"))[..., 0] assert np.allclose(result, ref, atol=1e-2) def testFFBBasis3DExpand(self, basis): - x = np.load(os.path.join(DATA_DIR, "ffbbasis3d_xcoeff_in_8_8_8.npy")).T # RCOPT + x = np.load(os.path.join(DATA_DIR, "ffbbasis3d_xcoef_in_8_8_8.npy")).T # RCOPT x = x.astype(basis.dtype, copy=False) result = basis.expand(x) - ref = np.load(os.path.join(DATA_DIR, "ffbbasis3d_vcoeff_out_exp_8_8_8.npy"))[ + ref = np.load(os.path.join(DATA_DIR, "ffbbasis3d_vcoef_out_exp_8_8_8.npy"))[ ..., 0 ] @@ -502,8 +502,8 @@ def testHighResFFBbasis3D(L, dtype): vol = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, 
seed=seed).generate() # Round trip - coeff = basis.evaluate_t(vol) - vol_ffb = basis.evaluate(coeff) + coef = basis.evaluate_t(vol) + vol_ffb = basis.evaluate(coef) # Mask to compare inside sphere of radius 1. mask = grid_3d(L, normalized=True)["r"] < 1 diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py index 6af33580cb..7d6b3f5c47 100644 --- a/tests/test_FLEbasis2D.py +++ b/tests/test_FLEbasis2D.py @@ -4,12 +4,11 @@ import numpy as np import pytest -from aspire.basis import FBBasis2D, FFBBasis2D, FLEBasis2D +from aspire.basis import Coef, FBBasis2D, FLEBasis2D from aspire.image import Image from aspire.nufft import backend_available from aspire.numeric import fft from aspire.source import Simulation -from aspire.utils import utest_tolerance from aspire.volume import Volume from ._basis_util import UniversalBasisMixin @@ -55,7 +54,9 @@ def create_images(L, n): np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(np.float64) ) v = v.downsample(L) - sim = Simulation(L=L, n=n, vols=v, dtype=v.dtype, seed=1103) + sim = Simulation( + L=L, n=n, vols=v, dtype=v.dtype, offsets=0, amplitudes=1, seed=1103 + ) img = sim.clean_images[:] return img @@ -99,12 +100,12 @@ def testFastVDense(self, basis): # get sample coefficients x = create_images(basis.nres, 1) # hold input test data constant (would depend on epsilon parameter) - coeffs = FLEBasis2D( + coefs = FLEBasis2D( basis.nres, epsilon=1e-4, dtype=np.float64, match_fb=False ).evaluate_t(x) - result_dense = dense_b @ coeffs.T - result_fast = basis.evaluate(coeffs).asnumpy() + result_dense = dense_b @ coefs.asnumpy().T + result_fast = basis.evaluate(coefs).asnumpy() assert relerr(result_dense, result_fast) < (self.test_eps * basis.epsilon) @@ -136,10 +137,10 @@ def testMatchFBEvaluate(basis): fb_basis = FBBasis2D(basis.nres, dtype=np.float64) # in match_fb, count is the same for both bases - coeffs = np.eye(basis.count) + coefs = Coef(basis, np.eye(basis.count)) - fb_images = fb_basis.evaluate(coeffs) - fle_images = basis.evaluate(coeffs) + fb_images = fb_basis.evaluate(coefs) + fle_images = basis.evaluate(coefs) assert np.allclose(fb_images._data, fle_images._data, atol=1e-4) @@ -151,10 +152,10 @@ def testMatchFBDenseEvaluate(basis): fb_basis = FBBasis2D(basis.nres, dtype=np.float64) - coeffs = np.eye(basis.count) + coefs = Coef(basis, np.eye(basis.count)) - fb_images = fb_basis.evaluate(coeffs).asnumpy() - fle_out = basis._create_dense_matrix() @ coeffs + fb_images = fb_basis.evaluate(coefs).asnumpy() + fle_out = basis._create_dense_matrix() @ coefs fle_images = Image(fle_out.T.reshape(-1, basis.nres, basis.nres)).asnumpy() # Matrix column reordering in match_fb mode flips signs of some of the basis functions @@ -171,12 +172,12 @@ def testMatchFBEvaluate_t(basis): fb_basis = FBBasis2D(basis.nres, dtype=np.float64) # test images to evaluate - images = fb_basis.evaluate(np.eye(basis.count)) + images = fb_basis.evaluate(Coef(basis, np.eye(basis.count))) - fb_coeffs = fb_basis.evaluate_t(images) - fle_coeffs = basis.evaluate_t(images) + fb_coefs = fb_basis.evaluate_t(images) + fle_coefs = basis.evaluate_t(images) - assert np.allclose(fb_coeffs, fle_coeffs, atol=1e-4) + assert np.allclose(fb_coefs, fle_coefs, atol=1e-4) @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) @@ -188,15 +189,15 @@ def testMatchFBDenseEvaluate_t(basis): # test images to evaluate # gets a stack of shape (basis.count, L, L) - images = fb_basis.evaluate(np.eye(basis.count)) + images = fb_basis.evaluate(Coef(basis,
np.eye(basis.count))) # reshape to a stack of basis.count vectors of length L**2 vec = images.asnumpy().reshape((-1, basis.nres**2)) - fb_coeffs = fb_basis.evaluate_t(images) - fle_coeffs = basis._create_dense_matrix().T @ vec.T + fb_coefs = fb_basis.evaluate_t(images) + fle_coefs = basis._create_dense_matrix().T @ vec.T # Matrix column reordering in match_fb mode flips signs of some of the basis coefficients - assert np.allclose(np.abs(fb_coeffs), np.abs(fle_coeffs), atol=1e-4) + assert np.allclose(np.abs(fb_coefs), np.abs(fle_coefs), atol=1e-4) def testLowPass(): @@ -208,120 +209,22 @@ def testLowPass(): # sample coefficients ims = create_images(L, 1) - coeffs = basis.evaluate_t(ims) + coefs = basis.evaluate_t(ims) - nonzero_coeffs = [] + nonzero_coefs = [] for i in range(4): bandlimit = L // (2**i) - coeffs_lowpassed = basis.lowpass(coeffs, bandlimit) - nonzero_coeffs.append(np.sum(coeffs_lowpassed != 0)) + coefs_lowpassed = basis.lowpass(coefs, bandlimit).asnumpy() + nonzero_coefs.append(np.sum(coefs_lowpassed != 0)) # for bandlimit == L, no frequencies should be removed - assert nonzero_coeffs[0] == basis.count - - # for lower bandlimits, there should be fewer and fewer nonzero coeffs - assert nonzero_coeffs[0] > nonzero_coeffs[1] > nonzero_coeffs[2] > nonzero_coeffs[3] - - # make sure you can pass in a 1-D array if you want - _ = basis.lowpass(coeffs[0, :], L) - - # cannot pass in the wrong number of coefficients - with pytest.raises( - AssertionError, match="Number of coefficients must match self.count." - ): - _ = basis.lowpass(coeffs[:, :1000], L) - - # cannot pass in wrong shape - with pytest.raises( - AssertionError, - match="Input a stack of coefficients of dimension", - ): - _ = basis.lowpass(np.zeros((3, 3, 3)), L) - - -def testRotate(): - # test ability to accurately rotate images via - # FLE coefficients - - L = 128 - basis = FLEBasis2D(L, match_fb=False) - - # sample image - ims = create_images(L, 1) - # rotate 90 degrees in cartesian coordinates - ims_90 = Image(np.rot90(ims.asnumpy(), axes=(1, 2))) - - # get FLE coefficients - coeffs = basis.evaluate_t(ims) - coeffs_cart_rot = basis.evaluate_t(ims_90) - - # rotate original image in FLE space using Steerable rotate method - coeffs_fle_rot = basis.rotate(coeffs, np.pi / 2) - - # back to cartesian - ims_cart_rot = basis.evaluate(coeffs_cart_rot) - ims_fle_rot = basis.evaluate(coeffs_fle_rot) + assert nonzero_coefs[0] == basis.count - # test rot90 close - assert np.allclose(ims_cart_rot[0], ims_fle_rot[0], atol=1e-4) - - # 2Pi identity in FLE space (rotate by 2Pi) - coeffs_fle_2pi = basis.rotate(coeffs, 2 * np.pi) - ims_fle_2pi = basis.evaluate(coeffs_fle_2pi) - - # test 2Pi identity - assert np.allclose(ims[0], ims_fle_2pi[0], atol=utest_tolerance(basis.dtype)) - - # Reflect in FLE space (rotate by Pi) - coeffs_fle_pi = basis.rotate(coeffs, np.pi) - ims_fle_pi = basis.evaluate(coeffs_fle_pi) - - # test reflection - assert np.allclose(np.flipud(ims.asnumpy()[0]), ims_fle_pi[0], atol=1e-4) + # for lower bandlimits, there should be fewer and fewer nonzero coefs + assert nonzero_coefs[0] > nonzero_coefs[1] > nonzero_coefs[2] > nonzero_coefs[3] # make sure you can pass in a 1-D array if you want - _ = basis.lowpass(np.zeros((basis.count,)), np.pi) - - # cannot pass in the wrong number of coefficients - with pytest.raises( - AssertionError, match="Number of coefficients must match self.count."
- ): - _ = basis.rotate(np.zeros((1, 10)), np.pi) - - # cannot pass in wrong shape - with pytest.raises( - AssertionError, - match="Input a stack of coefficients of dimension", - ): - _ = basis.lowpass(np.zeros((3, 3, 3)), np.pi) - - -def testRotate45(): - # test ability to accurately rotate images via - # FLE coefficients - dtype = np.float64 - - L = 128 - fb_basis = FFBBasis2D(L, dtype=dtype) - basis = FLEBasis2D(L, match_fb=True, dtype=dtype) - - # sample image - ims = create_images(L, 1) - - # get FLE coefficients - fb_coeffs = fb_basis.evaluate_t(ims) - coeffs = basis.evaluate_t(ims) - - # rotate original image in FLE space using Steerable rotate method - fb_coeffs_rot = fb_basis.rotate(fb_coeffs, np.pi / 4) - coeffs_rot = basis.rotate(coeffs, np.pi / 4) - - # back to cartesian - fb_ims_rot = fb_basis.evaluate(fb_coeffs_rot) - ims_rot = basis.evaluate(coeffs_rot) - - # test close - assert np.allclose(ims_rot[0], fb_ims_rot[0], atol=1e-4) + _ = basis.lowpass(coefs[0, :], L) def testRadialConvolution(): @@ -337,13 +240,13 @@ def testRadialConvolution(): # get sample images ims = create_images(L, 10) # convolve using coefficients - coeffs = basis.evaluate_t(ims) - coeffs_convolved = basis.radial_convolve(coeffs, x) - imgs_convolved_fle = basis.evaluate(coeffs_convolved).asnumpy() + coefs = basis.evaluate_t(ims) + coefs_convolved = basis.radial_convolve(coefs, x) + imgs_convolved_fle = basis.evaluate(coefs_convolved).asnumpy() # convolve using FFT x = basis.evaluate(basis.evaluate_t(x)).asnumpy() - ims = basis.evaluate(coeffs).asnumpy() + ims = basis.evaluate(coefs).asnumpy() imgs_convolved_slow = np.zeros((10, L, L)) for i in range(10): diff --git a/tests/test_FPSWFbasis2D.py b/tests/test_FPSWFbasis2D.py index 2ca6e3e213..b53c043a2c 100644 --- a/tests/test_FPSWFbasis2D.py +++ b/tests/test_FPSWFbasis2D.py @@ -3,8 +3,9 @@ import numpy as np import pytest -from aspire.basis import FPSWFBasis2D +from aspire.basis import ComplexCoef, FPSWFBasis2D from aspire.image import Image +from aspire.utils import utest_tolerance from ._basis_util import UniversalBasisMixin, pswf_params_2d, show_basis_params @@ -16,24 +17,28 @@ @pytest.mark.parametrize("basis", test_bases, ids=show_basis_params) class TestFPSWFBasis2D(UniversalBasisMixin): def testFPSWFBasis2DEvaluate_t(self, basis): - img_ary = np.load( - os.path.join(DATA_DIR, "ffbbasis2d_xcoeff_in_8_8.npy") - ).T # RCOPT + img_ary = np.load(os.path.join(DATA_DIR, "ffbbasis2d_xcoef_in_8_8.npy")) images = Image(img_ary) + result = basis.evaluate_t(images) - coeffs = np.load( - os.path.join(DATA_DIR, "pswf2d_vcoeffs_out_8_8.npy") - ).T # RCOPT - # make sure both real and imaginary parts are consistent. - assert np.allclose(np.real(result), np.real(coeffs)) and np.allclose( - np.imag(result) * 1j, np.imag(coeffs) * 1j - ) + # Historically, FPSWF returned complex values. + # Load and convert them for this hard coded test. + ccoefs = np.load(os.path.join(DATA_DIR, "pswf2d_vcoefs_out_8_8.npy")).T # RCOPT + coefs = ComplexCoef(basis, ccoefs).to_real() + + np.testing.assert_allclose(result, coefs, atol=utest_tolerance(basis.dtype)) def testFPSWFBasis2DEvaluate(self, basis): - coeffs = np.load( - os.path.join(DATA_DIR, "pswf2d_vcoeffs_out_8_8.npy") - ).T # RCOPT - result = basis.evaluate(coeffs) - images = np.load(os.path.join(DATA_DIR, "pswf2d_xcoeff_out_8_8.npy")).T # RCOPT - assert np.allclose(result.asnumpy(), images) + # Historically, FPSWF returned complex values. + # Load and convert them for this hard coded test. 
+ ccoefs = np.load(os.path.join(DATA_DIR, "pswf2d_vcoefs_out_8_8.npy")).T # RCOPT + coefs = ComplexCoef(basis, ccoefs).to_real() + result = coefs.evaluate() + + # This hardcoded reference result requires transposing the stack axis. + images = np.transpose( + np.load(os.path.join(DATA_DIR, "pswf2d_xcoef_out_8_8.npy")), (2, 0, 1) + ) + + np.testing.assert_allclose(result.asnumpy(), images, rtol=1e-05, atol=1e-08) diff --git a/tests/test_PSWFbasis2D.py b/tests/test_PSWFbasis2D.py index 8908d3173b..3660843ccd 100644 --- a/tests/test_PSWFbasis2D.py +++ b/tests/test_PSWFbasis2D.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from aspire.basis import PSWFBasis2D +from aspire.basis import ComplexCoef, PSWFBasis2D from aspire.image import Image from ._basis_util import UniversalBasisMixin, pswf_params_2d, show_basis_params @@ -16,26 +16,29 @@ @pytest.mark.parametrize("basis", test_bases, ids=show_basis_params) class TestPSWFBasis2D(UniversalBasisMixin): def testPSWFBasis2DEvaluate_t(self, basis): - img_ary = np.load( - os.path.join(DATA_DIR, "ffbbasis2d_xcoeff_in_8_8.npy") - ).T # RCOPT + img_ary = np.load(os.path.join(DATA_DIR, "ffbbasis2d_xcoef_in_8_8.npy")) images = Image(img_ary) result = basis.evaluate_t(images) - coeffs = np.load( - os.path.join(DATA_DIR, "pswf2d_vcoeffs_out_8_8.npy") - ).T # RCOPT + # Historically, PSWF returned complex values. + # Load and convert them for this hard coded test. + ccoefs = np.load(os.path.join(DATA_DIR, "pswf2d_vcoefs_out_8_8.npy")).T # RCOPT + coefs = ComplexCoef(basis, ccoefs).to_real() - # make sure both real and imaginary parts are consistent. - assert np.allclose(np.real(result), np.real(coeffs)) and np.allclose( - np.imag(result) * 1j, np.imag(coeffs) * 1j - ) + np.testing.assert_allclose(result, coefs, rtol=1e-05, atol=1e-08) def testPSWFBasis2DEvaluate(self, basis): - coeffs = np.load( - os.path.join(DATA_DIR, "pswf2d_vcoeffs_out_8_8.npy") - ).T # RCOPT - result = basis.evaluate(coeffs) - images = np.load(os.path.join(DATA_DIR, "pswf2d_xcoeff_out_8_8.npy")).T # RCOPT - assert np.allclose(result.asnumpy(), images) + # Historically, PSWF returned complex values. + # Load and convert them for this hard coded test. + ccoefs = np.load(os.path.join(DATA_DIR, "pswf2d_vcoefs_out_8_8.npy")).T # RCOPT + coefs = ComplexCoef(basis, ccoefs).to_real() + + result = coefs.evaluate() + + # This hardcoded reference result requires transposing the stack axis. 
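The PSWF updates above lean on the new conversion helpers: a real `Coef` can be round-tripped through its complex representation. A sketch of that conversion, assuming the same small FB basis used elsewhere in these tests:

```python
import numpy as np
from aspire.basis import Coef, FBBasis2D

basis = FBBasis2D(8, dtype=np.float64)
coef = Coef(basis, np.random.rand(basis.count))
ccoef = coef.to_complex()  # real coefficients -> ComplexCoef
coef2 = ccoef.to_real()    # and back
assert np.allclose(coef.asnumpy(), coef2.asnumpy())
```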
+ images = np.transpose( + np.load(os.path.join(DATA_DIR, "pswf2d_xcoef_out_8_8.npy")), (2, 0, 1) + ) + + np.testing.assert_allclose(result.asnumpy(), images, rtol=1e-05, atol=1e-08) diff --git a/tests/test_anisotropic_noise.py b/tests/test_anisotropic_noise.py index dc4732732e..12b52064ea 100644 --- a/tests/test_anisotropic_noise.py +++ b/tests/test_anisotropic_noise.py @@ -5,8 +5,8 @@ from aspire.noise import AnisotropicNoiseEstimator, WhiteNoiseEstimator from aspire.operators import RadialCTFFilter -from aspire.source import ArrayImageSource, Simulation -from aspire.utils.types import utest_tolerance +from aspire.source import ArrayImageSource, _LegacySimulation +from aspire.utils import utest_tolerance from aspire.volume import LegacyVolume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -16,7 +16,7 @@ class SimTestCase(TestCase): def setUp(self): self.dtype = np.float32 self.vol = LegacyVolume(L=8, dtype=self.dtype).generate() - self.sim = Simulation( + self.sim = _LegacySimulation( n=1024, vols=self.vol, unique_filters=[ diff --git a/tests/test_averager2d.py b/tests/test_averager2d.py index c09da2c42e..efa3a683ac 100644 --- a/tests/test_averager2d.py +++ b/tests/test_averager2d.py @@ -127,9 +127,11 @@ def _construct_rotations(self): 0, 2 * np.pi, num=self.n_img, endpoint=False, retstep=True, dtype=self.dtype ) - # Generate rotations to be used by `Simulation` + # Generate rotations to be used by `Simulation`. Since `Simulation` rotates + # the coordinate grid and the averager aligns by rotating the projection images, + # we negate the angles fed into `Simulation` for direct comparison later. self.rotations = Rotation.about_axis( - "z", self.thetas, dtype=self.dtype, gimble_lock_warnings=False + "z", -self.thetas, dtype=self.dtype, gimble_lock_warnings=False ) diff --git a/tests/test_batched_covar2d.py b/tests/test_batched_covar2d.py index 71c3366870..fbc6d3d7dc 100644 --- a/tests/test_batched_covar2d.py +++ b/tests/test_batched_covar2d.py @@ -2,7 +2,7 @@ import numpy as np -from aspire.basis import FFBBasis2D +from aspire.basis import Coef, FFBBasis2D from aspire.covariance import BatchedRotCov2D, RotCov2D from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter @@ -17,7 +17,7 @@ class BatchedRotCov2DTestCase(TestCase): filters = None ctf_idx = None - ctf_fb = None + ctf_basis = None def setUp(self): n = 32 @@ -39,7 +39,7 @@ def setUp(self): noise_adder=noise_adder, ) self.basis = FFBBasis2D((L, L), dtype=self.dtype) - self.coeff = self.basis.evaluate_t(self.src.images[:]) + self.coef = self.basis.evaluate_t(self.src.images[:]) self.cov2d = RotCov2D(self.basis) self.bcov2d = BatchedRotCov2D(self.src, self.basis, batch_size=7) @@ -60,12 +60,12 @@ def testMeanCovar(self): # Test basic functionality against RotCov2D. mean_cov2d = self.cov2d.get_mean( - self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx + self.coef, ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx ) covar_cov2d = self.cov2d.get_covar( - self.coeff, - mean_coeff=mean_cov2d, - ctf_fb=self.ctf_fb, + self.coef, + mean_coef=mean_cov2d, + ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx, noise_var=self.noise_var, ) @@ -85,13 +85,16 @@ def testMeanCovar(self): def testZeroMean(self): # Make sure it works with zero mean (pure second moment). 
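The negated angles in the `test_averager2d` fixture above follow from a standard duality: rotating the coordinate grid by theta moves the image content by -theta. A 2D sketch of why the two conventions are exact inverses:

```python
import numpy as np

theta = np.pi / 3
c, s = np.cos(theta), np.sin(theta)
R_grid = np.array([[c, -s], [s, c]])  # rotates the sampling grid by theta
R_content = R_grid.T                  # equivalent motion of the image content
assert np.allclose(R_grid @ R_content, np.eye(2))
```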
- zero_coeff = np.zeros((self.basis.count,), dtype=self.dtype) + zero_coef = Coef(self.basis, np.zeros((self.basis.count,), dtype=self.dtype)) covar_cov2d = self.cov2d.get_covar( - self.coeff, mean_coeff=zero_coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx + self.coef, + mean_coef=zero_coef, + ctf_basis=self.ctf_basis, + ctf_idx=self.ctf_idx, ) - covar_bcov2d = self.bcov2d.get_covar(mean_coeff=zero_coeff) + covar_bcov2d = self.bcov2d.get_covar(mean_coef=zero_coef) self.assertTrue( self.blk_diag_allclose( @@ -102,7 +105,7 @@ def testZeroMean(self): def testAutoMean(self): # Make sure it automatically calls get_mean if needed. covar_cov2d = self.cov2d.get_covar( - self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx + self.coef, ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx ) covar_bcov2d = self.bcov2d.get_covar() @@ -126,8 +129,8 @@ def testShrink(self): } covar_cov2d = self.cov2d.get_covar( - self.coeff, - ctf_fb=self.ctf_fb, + self.coef, + ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx, covar_est_opt=covar_est_opt, ) @@ -152,22 +155,22 @@ def testAutoBasis(self): def testCWFCoeff(self): # Calculate CWF coefficients using Cov2D base class mean_cov2d = self.cov2d.get_mean( - self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx + self.coef, ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx ) covar_cov2d = self.cov2d.get_covar( - self.coeff, - ctf_fb=self.ctf_fb, + self.coef, + ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx, noise_var=self.noise_var, make_psd=True, ) - coeff_cov2d = self.cov2d.get_cwf_coeffs( - self.coeff, - self.ctf_fb, + coef_cov2d = self.cov2d.get_cwf_coefs( + self.coef, + self.ctf_basis, self.ctf_idx, - mean_coeff=mean_cov2d, - covar_coeff=covar_cov2d, + mean_coef=mean_cov2d, + covar_coef=covar_cov2d, noise_var=self.noise_var, ) @@ -175,9 +178,9 @@ def testCWFCoeff(self): mean_bcov2d = self.bcov2d.get_mean() covar_bcov2d = self.bcov2d.get_covar(noise_var=self.noise_var, make_psd=True) - coeff_bcov2d = self.bcov2d.get_cwf_coeffs( - self.coeff, - self.ctf_fb, + coef_bcov2d = self.bcov2d.get_cwf_coefs( + self.coef, + self.ctf_basis, self.ctf_idx, mean_bcov2d, covar_bcov2d, @@ -185,15 +188,15 @@ def testCWFCoeff(self): ) self.assertTrue( self.blk_diag_allclose( - coeff_cov2d, - coeff_bcov2d, + coef_cov2d, + coef_bcov2d, atol=utest_tolerance(self.dtype), ) ) def testCWFCoeffCleanCTF(self): """ - Test case of clean images (coeff_clean and noise_var=0) + Test case of clean images (coef_clean and noise_var=0) while using a non Identity CTF. 
This case may come up when a developer switches between @@ -202,22 +205,22 @@ def testCWFCoeffCleanCTF(self): # Calculate CWF coefficients using Cov2D base class mean_cov2d = self.cov2d.get_mean( - self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx + self.coef, ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx ) covar_cov2d = self.cov2d.get_covar( - self.coeff, - ctf_fb=self.ctf_fb, + self.coef, + ctf_basis=self.ctf_basis, ctf_idx=self.ctf_idx, noise_var=self.noise_var, make_psd=True, ) - coeff_cov2d = self.cov2d.get_cwf_coeffs( - self.coeff, - self.ctf_fb, + coef_cov2d = self.cov2d.get_cwf_coefs( + self.coef, + self.ctf_basis, self.ctf_idx, - mean_coeff=mean_cov2d, - covar_coeff=covar_cov2d, + mean_coef=mean_cov2d, + covar_coef=covar_cov2d, noise_var=0, ) @@ -225,9 +228,9 @@ def testCWFCoeffCleanCTF(self): mean_bcov2d = self.bcov2d.get_mean() covar_bcov2d = self.bcov2d.get_covar(noise_var=self.noise_var, make_psd=True) - coeff_bcov2d = self.bcov2d.get_cwf_coeffs( - self.coeff, - self.ctf_fb, + coef_bcov2d = self.bcov2d.get_cwf_coefs( + self.coef, + self.ctf_basis, self.ctf_idx, mean_bcov2d, covar_bcov2d, @@ -235,8 +238,8 @@ def testCWFCoeffCleanCTF(self): ) self.assertTrue( self.blk_diag_allclose( - coeff_cov2d, - coeff_bcov2d, + coef_cov2d, + coef_bcov2d, atol=utest_tolerance(self.dtype), ) ) @@ -259,5 +262,5 @@ def ctf_idx(self): return self.src.filter_indices @property - def ctf_fb(self): - return [f.fb_mat(self.basis) for f in self.src.unique_filters] + def ctf_basis(self): + return [self.basis.filter_to_basis_mat(f) for f in self.src.unique_filters] diff --git a/tests/test_class2D.py b/tests/test_class2D.py index bec7348442..450c90ba9c 100644 --- a/tests/test_class2D.py +++ b/tests/test_class2D.py @@ -1,12 +1,19 @@ import logging import os -from unittest import TestCase import numpy as np import pytest from sklearn import datasets -from aspire.basis import FFBBasis2D, FSPCABasis +from aspire.basis import ( + Coef, + FBBasis2D, + FFBBasis2D, + FLEBasis2D, + FPSWFBasis2D, + FSPCABasis, + PSWFBasis2D, +) from aspire.classification import RIRClass2D from aspire.classification.legacy_implementations import bispec_2drot_large, pca_y from aspire.noise import WhiteNoiseAdder @@ -23,369 +30,421 @@ SEED = 42 -class FSPCATestCase(TestCase): - def setUp(self): - self.resolution = 16 - self.dtype = np.float32 +IMG_SIZES = [16] +DTYPES = [np.float32] +# Basis used in FSPCA for class averaging. 
+BASIS = [ + FFBBasis2D, + pytest.param(FBBasis2D, marks=pytest.mark.expensive), + pytest.param(FLEBasis2D, marks=pytest.mark.expensive), + pytest.param(PSWFBasis2D, marks=pytest.mark.expensive), + pytest.param(FPSWFBasis2D, marks=pytest.mark.expensive), +] - # Get a volume - v = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype( - self.dtype - ) - ) - v = v.downsample(self.resolution) - # Create a src from the volume - self.src = Simulation( - L=self.resolution, n=321, vols=v, dtype=self.dtype, seed=SEED - ) - self.src = self.src.cache() # Precompute image stack +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param - # Calculate some projection images - self.imgs = self.src.images[:] - # Configure an FSPCA basis - self.fspca_basis = FSPCABasis(self.src, noise_var=0) +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}", scope="module") +def img_size(request): + return request.param - def testExpandEval(self): - coef = self.fspca_basis.expand_from_image_basis(self.imgs) - recon = self.fspca_basis.evaluate_to_image_basis(coef) - # Check recon is close to imgs - rmse = np.sqrt(np.mean(np.square(self.imgs.asnumpy() - recon.asnumpy()))) - logger.info(f"FSPCA Expand Eval Image Round True RMSE: {rmse}") - self.assertTrue(rmse < utest_tolerance(self.dtype)) +@pytest.fixture(scope="module") +def volume(dtype, img_size): + # Get a volume + v = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(dtype) + ) - def testComplexConversionErrors(self): - """ - Test we raise when passed incorrect dtypes. + return v.downsample(img_size) - Also checks we can handle 0d vector in `to_real`. - Most other cases covered by classification unit tests. - """ +@pytest.fixture(scope="module") +def sim_fixture(volume, img_size, dtype): + """ + Provides a clean simulation parameterized by `img_size` and `dtype`. + """ - with pytest.raises( - TypeError, match="coef provided to to_complex should be real." - ): - _ = self.fspca_basis.to_complex( - np.arange(self.fspca_basis.count, dtype=np.complex64) - ) + # Create a src from the volume + src = Simulation(L=img_size, n=321, vols=volume, dtype=dtype, seed=SEED) + src = src.cache() # Precompute image stack - with pytest.raises( - TypeError, match="coef provided to to_real should be complex." - ): - _ = self.fspca_basis.to_real( - np.arange(self.fspca_basis.count, dtype=np.float32).flatten() - ) + # Calculate some projection images + imgs = src.images[:] - def testRotate(self): - """ - Trivial test of rotation in FSPCA Basis. - - Also covers to_real and to_complex conversions in FSPCA Basis. - """ - coef = self.fspca_basis.expand_from_image_basis(self.imgs) - # rotate by pi - rot_coef = self.fspca_basis.rotate(coef, radians=np.pi) - rot_imgs = self.fspca_basis.evaluate_to_image_basis(rot_coef) - - for i, img in enumerate(self.imgs): - rmse = np.sqrt(np.mean(np.square(np.flip(img) - rot_imgs[i]))) - self.assertTrue(rmse < 10 * utest_tolerance(self.dtype)) - - def testBasisTooSmall(self): - """ - When number of components is more than basis functions raise with descriptive error. 
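The `TestCase` classes being removed below are replaced with module-scoped, parametrized pytest fixtures, which is what allows individual entries of `BASIS` above to carry `pytest.mark.expensive`. A minimal, self-contained sketch of the fixture pattern (not the test module itself):

```python
import numpy as np
import pytest

@pytest.fixture(params=[np.float32], ids=lambda x: f"dtype={x}", scope="module")
def dtype(request):
    # One test id per param; scope="module" builds the fixture once per module.
    return request.param

def test_uses_dtype(dtype):
    assert np.dtype(dtype).kind == "f"
```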
- """ - fb_basis = FFBBasis2D((self.resolution, self.resolution), dtype=self.dtype) - - with pytest.raises(ValueError, match=r".*Reduce components.*"): - # Configure an FSPCA basis - _ = FSPCABasis( - self.src, basis=fb_basis, components=fb_basis.count * 2, noise_var=0 - ) + # Configure an FSPCA basis + fspca_basis = FSPCABasis(src, noise_var=0) + + return imgs, src, fspca_basis + + +@pytest.fixture(params=BASIS, ids=lambda x: f"basis={x}", scope="module") +def basis(request, img_size, dtype): + cls = request.param + # Setup a Basis + basis = cls(img_size, dtype=dtype) + return basis + + +def test_expand_eval(sim_fixture): + imgs, _, fspca_basis = sim_fixture + coef = fspca_basis.expand_from_image_basis(imgs) + recon = fspca_basis.evaluate_to_image_basis(coef) + # Check recon is close to imgs + rmse = np.sqrt(np.mean(np.square(imgs.asnumpy() - recon.asnumpy()))) + logger.info(f"FSPCA Expand Eval Image Round True RMSE: {rmse}") + assert rmse < utest_tolerance(fspca_basis.dtype) -class RIRClass2DTestCase(TestCase): - def setUp(self): - self.n_classes = 5 - self.resolution = 16 - self.dtype = np.float64 - self.n_img = 150 - # Create some projections - v = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype( - self.dtype +def test_complex_conversions_errors(sim_fixture): + """ + Test we raise when passed incorrect dtypes. + + Also checks we can handle 0d vector in `to_real`. + + Most other cases covered by classification unit tests. + """ + imgs, _, fspca_basis = sim_fixture + + with pytest.raises(TypeError): + _ = fspca_basis.to_complex( + Coef( + fspca_basis, + np.arange(fspca_basis.count), + dtype=np.complex64, ) ) - v = v.downsample(self.resolution) - # Clean - self.clean_src = Simulation( - L=self.resolution, n=self.n_img, vols=v, dtype=self.dtype, seed=SEED + with pytest.raises(TypeError): + _ = fspca_basis.to_real( + Coef(fspca_basis, np.arange(fspca_basis.count), dtype=np.float32) ) - # With Noise - noise_var = 0.01 * np.var(np.sum(v[0], axis=0)) - noise_adder = WhiteNoiseAdder(var=noise_var) - self.noisy_src = Simulation( - L=self.resolution, - n=self.n_img, - vols=v, - dtype=self.dtype, - noise_adder=noise_adder, - seed=SEED, - ) - # Set up FFB - # Setup a Basis - self.basis = FFBBasis2D((self.resolution, self.resolution), dtype=self.dtype) - - # Create Basis, use precomputed Basis - self.clean_fspca_basis = FSPCABasis( - self.clean_src, self.basis, noise_var=0 - ) # Note noise_var assigned zero, skips eigval filtering. - - self.clean_fspca_basis_compressed = FSPCABasis( - self.clean_src, self.basis, components=101, noise_var=0 - ) # Note noise_var assigned zero, skips eigval filtering. - - # Ceate another fspca_basis, use autogeneration FFB2D Basis - self.noisy_fspca_basis = FSPCABasis(self.noisy_src) - - def testSourceTooSmall(self): - """ - When number of images in source is less than requested bispectrum components, - raise with descriptive error. - """ - - with pytest.raises( - RuntimeError, match=r".*Increase number of images or reduce components.*" - ): - _ = RIRClass2D( - self.clean_src, - fspca_components=self.clean_src.n * 4, - bispectrum_components=self.clean_src.n * 2, - ) +def test_rotate(sim_fixture): + """ + Trivial test of rotation in FSPCA Basis. - def testIncorrectComponents(self): - """ - Check we raise with inconsistent configuration of FSPCA components. - """ - - with pytest.raises( - RuntimeError, match=r"`pca_basis` components.*provided by user." 
-        ):
-            _ = RIRClass2D(
-                self.clean_src,
-                self.clean_fspca_basis,  # 400 components
-                fspca_components=100,
-                large_pca_implementation="legacy",
-                nn_implementation="legacy",
-                bispectrum_implementation="legacy",
-            )
+    Also covers to_real and to_complex conversions in FSPCA Basis.
+    """
+    imgs, _, fspca_basis = sim_fixture
+
+    coef = fspca_basis.expand_from_image_basis(imgs)
+    # rotate by pi
+    rot_coef = fspca_basis.rotate(coef, radians=np.pi)
+    rot_imgs = fspca_basis.evaluate_to_image_basis(rot_coef)

-        # Explicitly providing the same number should be okay.
+    for i, img in enumerate(imgs):
+        rmse = np.sqrt(np.mean(np.square(np.flip(img) - rot_imgs[i])))
+        assert rmse < 10 * utest_tolerance(fspca_basis.dtype)
+
+
+def test_basis_too_small(sim_fixture, basis):
+    """
+    When number of components is more than basis functions raise with descriptive error.
+    """
+    src = sim_fixture[1]
+
+    with pytest.raises(ValueError, match=r".*Reduce components.*"):
+        # Configure an FSPCA basis
+        _ = FSPCABasis(src, basis=basis, components=basis.count * 2, noise_var=0)
+
+
+@pytest.fixture(scope="module")
+def sim_fixture2(volume, basis, img_size, dtype):
+    """
+    Provides clean/noisy pair of smaller parameterized simulations,
+    along with corresponding clean/noisy basis and an additional
+    compressed basis.
+
+    These are slightly smaller than `sim_fixture` and support covering
+    additional code and corner cases.
+    """
+
+    n_img = 150
+
+    # Clean
+    clean_src = Simulation(L=img_size, n=n_img, vols=volume, dtype=dtype, seed=SEED)
+    clean_src = clean_src.cache()
+
+    # With Noise
+    noise_var = 0.01 * np.var(np.sum(volume[0], axis=0))
+    noise_adder = WhiteNoiseAdder(var=noise_var)
+    noisy_src = Simulation(
+        L=img_size,
+        n=n_img,
+        vols=volume,
+        dtype=dtype,
+        noise_adder=noise_adder,
+        seed=SEED,
+    )
+    noisy_src = noisy_src.cache()
+
+    # Create Basis, use precomputed Basis
+    clean_fspca_basis = FSPCABasis(
+        clean_src, basis, noise_var=0
+    )  # Note noise_var assigned zero, skips eigval filtering.
+
+    clean_fspca_basis_compressed = FSPCABasis(
+        clean_src, basis, components=101, noise_var=0
+    )  # Note noise_var assigned zero, skips eigval filtering.
+
+    # Create another fspca_basis, use autogeneration Basis
+    noisy_fspca_basis = FSPCABasis(noisy_src)
+
+    return (
+        clean_src,
+        noisy_src,
+        clean_fspca_basis,
+        clean_fspca_basis_compressed,
+        noisy_fspca_basis,
+    )
+
+
+def test_source_too_small(sim_fixture2):
+    """
+    When number of images in source is less than requested bispectrum components,
+    raise with descriptive error.
+    """
+    clean_src = sim_fixture2[0]
+
+    with pytest.raises(
+        RuntimeError, match=r".*Increase number of images or reduce components.*"
+    ):
         _ = RIRClass2D(
-            self.clean_src,
-            fspca_components=self.clean_src.n * 4,
-            bispectrum_components=self.clean_src.n * 2,
+            clean_src,
+            fspca_components=clean_src.n * 4,
+            bispectrum_components=clean_src.n * 2,
         )

-    def testRIRLegacy(self):
-        """
-        Currently just tests for runtime errors.
-        """
-        clean_fspca_basis = FSPCABasis(
-            self.clean_src, self.basis, noise_var=0, components=100
-        )  # Note noise_var assigned zero, skips eigval filtering.

+def test_incorrect_components(sim_fixture2):
+    """
+    Check we raise with inconsistent configuration of FSPCA components.
+ """ + clean_src, clean_fspca_basis = sim_fixture2[0], sim_fixture2[2] - rir = RIRClass2D( - self.clean_src, - clean_fspca_basis, - bispectrum_components=42, + with pytest.raises( + RuntimeError, match=r"`pca_basis` components.*provided by user." + ): + _ = RIRClass2D( + clean_src, + clean_fspca_basis, # 400 components + fspca_components=100, large_pca_implementation="legacy", nn_implementation="legacy", bispectrum_implementation="legacy", - seed=SEED, ) - _ = rir.classify() + # Explicitly providing the same number should be okay. + _ = RIRClass2D( + clean_src, + clean_fspca_basis, # 400 components + fspca_components=clean_fspca_basis.components, + bispectrum_components=100, + large_pca_implementation="legacy", + nn_implementation="legacy", + bispectrum_implementation="legacy", + seed=SEED, + ) - def testRIRDevelBisp(self): - """ - Currently just tests for runtime errors. - """ - # Use the basis class setup, only requires a Source. - rir = RIRClass2D( - self.clean_src, - fspca_components=self.clean_fspca_basis.components, - bispectrum_components=self.clean_fspca_basis.components - 1, - large_pca_implementation="legacy", - nn_implementation="legacy", - bispectrum_implementation="devel", - ) +def test_RIR_legacy(basis, sim_fixture2): + """ + Currently just tests for runtime errors. + """ + clean_src = sim_fixture2[0] - _ = rir.classify() + clean_fspca_basis = FSPCABasis( + clean_src, basis, noise_var=0, components=100 + ) # Note noise_var assigned zero, skips eigval filtering. - def testRIRsk(self): - """ - Excercises the eigenvalue based filtering, - along with other swappable components. + rir = RIRClass2D( + clean_src, + clean_fspca_basis, + bispectrum_components=42, + large_pca_implementation="legacy", + nn_implementation="legacy", + bispectrum_implementation="legacy", + seed=SEED, + ) - Currently just tests for runtime errors. - """ - rir = RIRClass2D( - self.noisy_src, - self.noisy_fspca_basis, - bispectrum_components=100, - sample_n=42, - large_pca_implementation="sklearn", - nn_implementation="sklearn", - bispectrum_implementation="devel", - seed=SEED, + _ = rir.classify() + + +def test_RIR_devel_disp(sim_fixture2): + """ + Currently just tests for runtime errors. + """ + clean_src, fspca_basis = sim_fixture2[0], sim_fixture2[3] + + # Use the basis class setup, only requires a Source. + rir = RIRClass2D( + clean_src, + fspca_components=fspca_basis.components, + bispectrum_components=fspca_basis.components - 1, + large_pca_implementation="legacy", + nn_implementation="legacy", + bispectrum_implementation="devel", + ) + + _ = rir.classify() + + +def test_RIR_sk(sim_fixture2): + """ + Excercises the eigenvalue based filtering, + along with other swappable components. + + Currently just tests for runtime errors. + """ + noisy_src, noisy_fspca_basis = sim_fixture2[1], sim_fixture2[4] + + rir = RIRClass2D( + noisy_src, + noisy_fspca_basis, + bispectrum_components=100, + sample_n=42, + large_pca_implementation="sklearn", + nn_implementation="sklearn", + bispectrum_implementation="devel", + seed=SEED, + ) + + _ = rir.classify() + + +def test_eigein_images(sim_fixture2): + """ + Test we can return eigenimages. + """ + clean_fspca_basis, clean_fspca_basis_compressed = sim_fixture2[2], sim_fixture2[3] + + # Get the eigenimages from an FSPCA basis for testing + eigimg_uncompressed = clean_fspca_basis.eigen_images() + + # Get the eigenimages from a compressed FSPCA basis for testing + eigimg_compressed = clean_fspca_basis_compressed.eigen_images() + + # Check they are close. 
+    # Note it is expected the compression reorders the eigvecs,
+    # and thus the eigimages.
+    # We sum over all the eigimages to yield an "average" for comparison
+    assert np.allclose(
+        np.sum(eigimg_uncompressed.asnumpy(), axis=0),
+        np.sum(eigimg_compressed.asnumpy(), axis=0),
+        atol=utest_tolerance(clean_fspca_basis.dtype),
+    )
+
+
+def test_component_size(sim_fixture2):
+    """
+    Tests we raise when number of components are too small.
+
+    Also tests dtype mismatch behavior.
+    """
+    clean_src, compressed_fspca_basis = sim_fixture2[0], sim_fixture2[3]
+
+    with pytest.raises(RuntimeError, match=r".*Reduce bispectrum_components.*"):
+        _ = RIRClass2D(
+            clean_src,
+            compressed_fspca_basis,
+            bispectrum_components=clean_src.n + 1,
         )

-        # Check they are close.
-        # Note it is expected the compression reorders the eigvecs,
-        # and thus the eigimages.
-        # We sum over all the eigimages to yield an "average" for comparison
-        self.assertTrue(
-            np.allclose(
-                np.sum(eigimg_uncompressed.asnumpy(), axis=0),
-                np.sum(eigimg_compressed.asnumpy(), axis=0),
-            )
+
+def test_implementations(basis, sim_fixture2):
+    """
+    Test optional implementations handle bad inputs with a descriptive error.
+    """
+    clean_src, clean_fspca_basis = sim_fixture2[0], sim_fixture2[2]
+
+    # Nearest Neighbor component
+    with pytest.raises(ValueError, match=r"Provided nn_implementation.*"):
+        _ = RIRClass2D(
+            clean_src,
+            clean_fspca_basis,
+            bispectrum_components=int(0.75 * clean_fspca_basis.basis.count),
+            nn_implementation="badinput",
+        )
+
+    # Large PCA component
+    with pytest.raises(ValueError, match=r"Provided large_pca_implementation.*"):
+        _ = RIRClass2D(
+            clean_src,
+            clean_fspca_basis,
+            large_pca_implementation="badinput",
+        )
+
+    # Bispectrum component
+    with pytest.raises(ValueError, match=r"Provided bispectrum_implementation.*"):
+        _ = RIRClass2D(
+            clean_src,
+            clean_fspca_basis,
+            bispectrum_implementation="badinput",
+        )
+
+    # Legacy Bispectrum implies legacy bispectrum (they're integrated).
+    with pytest.raises(
+        ValueError, match=r'"legacy" bispectrum_implementation implies.*'
+    ):
+        _ = RIRClass2D(
+            clean_src,
+            clean_fspca_basis,
+            bispectrum_implementation="legacy",
+            large_pca_implementation="sklearn",
+        )

-    def testComponentSize(self):
-        """
-        Tests we raise when number of components are too small.
+    # Currently we only support FSPCA Basis in RIRClass2D
+    with pytest.raises(
+        RuntimeError,
+        match="RIRClass2D has currently only been developed for pca_basis as a FSPCABasis.",
+    ):
+        _ = RIRClass2D(clean_src, basis)

-        Also tests dtype mismatch behavior.
-        """
-        with pytest.raises(RuntimeError, match=r".*Reduce bispectrum_components.*"):
-            _ = RIRClass2D(
-                self.clean_src,
-                self.clean_fspca_basis,
-                bispectrum_components=self.clean_src.n + 1,
-                dtype=np.float64,
-            )
+# Cover branches of Legacy code not taken by the classification unit tests.

-    def testImplementations(self):
-        """
-        Test optional implementations handle bad inputs with a descriptive error.
- """ - - # Nearest Neighbhor component - with pytest.raises(ValueError, match=r"Provided nn_implementation.*"): - _ = RIRClass2D( - self.clean_src, - self.clean_fspca_basis, - bispectrum_components=int(0.75 * self.clean_fspca_basis.basis.count), - nn_implementation="badinput", - ) - # Large PCA component - with pytest.raises(ValueError, match=r"Provided large_pca_implementation.*"): - _ = RIRClass2D( - self.clean_src, - self.clean_fspca_basis, - large_pca_implementation="badinput", - ) +def test_pca_y(): + """ + We want to check that real inputs and differing input matrix shapes work. - # Bispectrum component - with pytest.raises(ValueError, match=r"Provided bispectrum_implementation.*"): - _ = RIRClass2D( - self.clean_src, - self.clean_fspca_basis, - bispectrum_implementation="badinput", - ) + Most of pca_y is covered by the classificiation unit tests. + """ - # Legacy Bispectrum implies legacy bispectrum (they're integrated). - with pytest.raises( - ValueError, match=r'"legacy" bispectrum_implementation implies.*' - ): - _ = RIRClass2D( - self.clean_src, - self.clean_fspca_basis, - bispectrum_implementation="legacy", - large_pca_implementation="sklearn", - ) + # The iris dataset is a small 150 sample by 5 feature dataset in float64 + iris = datasets.load_iris() - # Currently we only FSPCA Basis in RIRClass2D - with pytest.raises( - RuntimeError, - match="RIRClass2D has currently only been developed for pca_basis as a FSPCABasis.", - ): - _ = RIRClass2D(self.clean_src, self.basis) + # Extract the data matrix, run once as is (150, 5), + # and once tranposed so shape[0] < shape[1] (5, 150) + for x in (iris.data, iris.data.T): + # Run pca_y and check reconstruction holds + lsvec, svals, rsvec = pca_y(x, 5) + # svd ~~> A = U S V = (U S) V + recon = np.dot(lsvec * svals, rsvec) -class LegacyImplementationTestCase(TestCase): + assert np.allclose(x, recon) + + +def test_bispect_overflow(): """ - Cover branches of Legacy code not taken by the classification unit tests. + A zero value coef will cause a div0 error in log call. + Check it is raised. """ - def setUp(self): - pass - - def test_pca_y(self): - """ - We want to check that real inputs and differing input matrix shapes work. - - Most of pca_y is covered by the classificiation unit tests. - """ - - # The iris dataset is a small 150 sample by 5 feature dataset in float64 - iris = datasets.load_iris() - - # Extract the data matrix, run once as is (150, 5), - # and once tranposed so shape[0] < shape[1] (5, 150) - for x in (iris.data, iris.data.T): - # Run pca_y and check reconstruction holds - lsvec, svals, rsvec = pca_y(x, 5) - - # svd ~~> A = U S V = (U S) V - recon = np.dot(lsvec * svals, rsvec) - - self.assertTrue(np.allclose(x, recon)) - - def testBispectOverflow(self): - """ - A zero value coeff will cause a div0 error in log call. - Check it is raised. 
- """ - - with pytest.raises(ValueError, match="coeff_norm should not be -inf"): - # This should emit a warning before raising - with self.assertWarns(RuntimeWarning): - bispec_2drot_large( - coeff=np.arange(10), - freqs=np.arange(1, 11), - eigval=np.arange(10), - alpha=1 / 3, - sample_n=4000, - ) + with pytest.raises(ValueError, match="coef_norm should not be -inf"): + # This should emit a warning before raising + with pytest.warns(RuntimeWarning): + bispec_2drot_large( + coef=np.arange(10), + freqs=np.arange(1, 11), + eigval=np.arange(10), + alpha=1 / 3, + sample_n=4000, + ) diff --git a/tests/test_class_src.py b/tests/test_class_src.py index 6e1b9ca903..0c169621cd 100644 --- a/tests/test_class_src.py +++ b/tests/test_class_src.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from aspire.basis import FFBBasis2D +from aspire.basis import FBBasis2D, FFBBasis2D, FLEBasis2D, FPSWFBasis2D, PSWFBasis2D from aspire.classification import ( BandedSNRImageQualityFunction, BFRAverager2D, @@ -35,10 +35,12 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") +# RNG SEED, should help small class average tests be deterministic. +SEED = 5552368 IMG_SIZES = [ - 32, - pytest.param(31, marks=pytest.mark.expensive), + 16, + pytest.param(15, marks=pytest.mark.expensive), ] DTYPES = [ np.float64, @@ -49,23 +51,40 @@ NUM_PROCS = 1 +BASIS = [ + FFBBasis2D, + pytest.param(FBBasis2D, marks=pytest.mark.expensive), + pytest.param(FLEBasis2D, marks=pytest.mark.expensive), + pytest.param(PSWFBasis2D, marks=pytest.mark.expensive), + pytest.param(FPSWFBasis2D, marks=pytest.mark.expensive), +] + + +@pytest.fixture(params=BASIS, ids=lambda x: f"basis={x}", scope="module") +def basis(request, img_size, dtype): + cls = request.param + # Setup a Basis + basis = cls(img_size, dtype=dtype) + return basis + + def sim_fixture_id(params): res = params[0] dtype = params[1] return f"res={res}, dtype={dtype}" -@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}") +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") def dtype(request): return request.param -@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}") +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}", scope="module") def img_size(request): return request.param -@pytest.fixture +@pytest.fixture(scope="module") def class_sim_fixture(dtype, img_size): """ Construct a Simulation with explicit viewing angles forming @@ -117,12 +136,31 @@ def class_sim_fixture(dtype, img_size): return src -@pytest.fixture(params=CLS_SRCS, ids=lambda param: f"ClassSource={param.__class__}") +@pytest.fixture( + params=CLS_SRCS, + ids=lambda param: f"ClassSource={param.__class__.__name__}", + scope="module", +) def test_src_cls(request): return request.param -def test_basic_averaging(class_sim_fixture, test_src_cls): +@pytest.fixture(scope="module") +def classifier(class_sim_fixture): + return RIRClass2D( + class_sim_fixture, + fspca_components=63, + bispectrum_components=51, # Compressed Features after last PCA stage. + n_nbor=10, + sample_n=50000, + large_pca_implementation="legacy", + nn_implementation="legacy", + bispectrum_implementation="legacy", + seed=SEED, + ) + + +def test_basic_averaging(class_sim_fixture, test_src_cls, basis, classifier): """ Test that the default `ClassAvgSource` implementations return class averages. @@ -130,8 +168,10 @@ class averages. cmp_n = 5 - # Classify, Select, and compute averaged images. 
- test_src = test_src_cls(src=class_sim_fixture, num_procs=NUM_PROCS) + test_src = test_src_cls( + src=class_sim_fixture, classifier=classifier, num_procs=NUM_PROCS + ) + test_imgs = test_src.images[:cmp_n] # Fetch reference images from the original source. @@ -177,13 +217,21 @@ def test_heap_helper(): assert popped == a, "Failed to pop min item" -@pytest.fixture() +@pytest.fixture(scope="module") def cls_fixture(class_sim_fixture): """ Classifier fixture. """ # Create the classifier - c2d = RIRClass2D(class_sim_fixture, nn_implementation="sklearn") + c2d = RIRClass2D( + class_sim_fixture, + fspca_components=63, + bispectrum_components=51, # Compressed Features after last PCA stage. + n_nbor=10, + sample_n=50000, + nn_implementation="sklearn", + seed=SEED, + ) # Compute the classification # (classes, reflections, distances) return c2d.classify() @@ -234,9 +282,9 @@ def test_online_selector(cls_fixture, selector): "quality_function", QUALITY_FUNCTIONS, ids=lambda param: f"Quality Function={param}" ) @pytest.mark.expensive -def test_global_selector(class_sim_fixture, cls_fixture, selector, quality_function): - basis = FFBBasis2D(class_sim_fixture.L, dtype=class_sim_fixture.dtype) - +def test_global_selector( + class_sim_fixture, cls_fixture, selector, quality_function, basis +): averager = BFRAverager2D(basis, class_sim_fixture, num_procs=NUM_PROCS) fun = quality_function() @@ -277,8 +325,10 @@ def test_contrast_selector(dtype): assert np.allclose(selector._quality_scores, ref_scores) -def test_avg_src_starfileio(class_sim_fixture, test_src_cls): - src = test_src_cls(src=class_sim_fixture, num_procs=NUM_PROCS) +def test_avg_src_starfileio(class_sim_fixture, test_src_cls, classifier): + src = test_src_cls( + src=class_sim_fixture, classifier=classifier, num_procs=NUM_PROCS + ) # Save and load the source as a STAR file. # Saving should force classification and selection to occur, diff --git a/tests/test_coef.py b/tests/test_coef.py new file mode 100644 index 0000000000..ab546d9bac --- /dev/null +++ b/tests/test_coef.py @@ -0,0 +1,446 @@ +import numpy as np +import pytest + +from aspire.basis import ( + Coef, + FBBasis2D, + FFBBasis2D, + FLEBasis2D, + FPSWFBasis2D, + PSWFBasis2D, +) +from aspire.utils import utest_tolerance + +IMG_SIZE = [ + 31, + 32, +] +DTYPES = [ + np.float32, + np.float64, +] +STACKS = [ + (), + (1,), + (2,), + (3, 4), +] + +ALLYOURBASES = [ + FBBasis2D, + FFBBasis2D, + PSWFBasis2D, + FPSWFBasis2D, + FLEBasis2D, +] + + +def sim_fixture_id(params): + stack, count, dtype = params + return f"stack={stack}, count={count}, dtype={dtype}" + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + """ + Dtypes for coef array + """ + return request.param + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def basis_dtype(request): + """ + Dtypes for basis + """ + return request.param + + +@pytest.fixture(params=IMG_SIZE, ids=lambda x: f"count={x}", scope="module") +def img_size(request): + """ + Image size for basis. + """ + return request.param + + +@pytest.fixture(params=STACKS, ids=lambda x: f"stack={x}", scope="module") +def stack(request): + """ + Stack dimensions. + """ + return request.param + + +@pytest.fixture(params=ALLYOURBASES, ids=lambda x: f"basis={x}", scope="module") +def basis(request, img_size, basis_dtype): + """ + Parameterized `Basis` instantiation. 
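+
+    Each requesting test runs once per entry in `ALLYOURBASES`,
+    crossed with the `img_size` and `basis_dtype` fixtures.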
+ """ + cls = request.param + return cls(img_size, dtype=basis_dtype) + + +@pytest.fixture(scope="module") +def coef_fixture(basis, stack, dtype): + """ + Construct parameterized testing coefficient array as `Coef`. + """ + # Combine the stack and coeficent counts into multidimensional + # shape. + size = stack + (basis.count,) + + coef_np = np.random.random(size=size).astype(dtype, copy=False) + + return Coef(basis, coef_np, dtype=dtype) + + +@pytest.fixture(scope="module") +def rots(coef_fixture, dtype): + # Rotations + return np.linspace(-np.pi, np.pi, coef_fixture.stack_size).reshape( + coef_fixture.stack_shape + ) + + +def test_mismatch_count(basis): + """ + Confirm raises when instantiated with incorrect coef vector len. + """ + # Derive an incorrect Coef + x = np.empty(basis.count + 1, basis.dtype) + with pytest.raises(RuntimeError, match=r".*does not match basis count.*"): + _ = Coef(basis, x) + + +def test_incorrect_coef_type(basis): + """ + Confirm raises when instantiated with incorrect coef type. + """ + # Construct incorrect Coef type (list) + x = list(range(basis.count + 1)) + with pytest.raises(ValueError, match=r".*should be instantiated with an.*"): + _ = Coef(basis, x) + + +def test_0dim(basis): + """ + Confirm raises when instantiated with 0dim scalar. + """ + # Construct 0dim scalar object + x = np.array(1) + with pytest.raises(ValueError, match=r".*with shape.*"): + _ = Coef(basis, x) + + +def test_not_a_basis(): + """ + Confirm raises when instantiated with something that is not a Basis. + """ + # Derive an incorrect Coef + x = np.empty(10) + with pytest.raises(TypeError, match=r".*required to be a `Basis`.*"): + _ = Coef(None, x) + + +def test_coef_key_dims(coef_fixture): + """ + Test key lookup out of bounds dimension raises. + """ + dim = coef_fixture.ndim + # Construct a key with too many dims + key = (0,) * (dim + 1) + with pytest.raises(ValueError, match=r".*stack_dim is.*"): + _ = coef_fixture[key] + + +def test_incorrect_reshape(basis): + """ + Confirm raises when attempting incorrect stack reshape. + """ + + # create a multi dim coef array. + x = np.empty((2, 3, 4, basis.count)) + c = Coef(basis, x) + + # Alter the stack shape, creating an incorrect shape. + shp = list(c.stack_shape) + shp[0] = shp[0] + 1 + + with pytest.raises(ValueError, match=r".*cannot be reshaped to.*"): + _ = c.stack_reshape(*shp) + + +def test_stack_reshape(basis): + """ + Test stack_reshape matches corresponding pure Numpy reshape. + """ + # create a multi dim coef array. + x = np.empty((2, 3, 4, basis.count)) + c = Coef(basis, x) + + # Test -1 flatten + ref = x.reshape(-1, basis.count) + np.testing.assert_allclose(c.stack_reshape(-1).asnumpy(), ref) + # Test 1d flatten + np.testing.assert_allclose(c.stack_reshape(np.prod(x.shape[:-1])).asnumpy(), ref) + # Test 2d reshape tuple (2,3,4) ~> ((6,4)) + ref = x.reshape(np.prod(x.shape[:-2]), x.shape[-2], basis.count) + np.testing.assert_allclose( + c.stack_reshape((np.prod(x.shape[:-2]), x.shape[-2])).asnumpy(), ref + ) + # Test 2d reshape args (2,3,4) ~> (6,4) + ref = x.reshape(np.prod(x.shape[:-2]), x.shape[-2], basis.count) + np.testing.assert_allclose( + c.stack_reshape(np.prod(x.shape[:-2]), x.shape[-2]).asnumpy(), ref + ) + + +def test_size(coef_fixture): + """ + Confirm size matches. 
+ """ + np.testing.assert_equal(coef_fixture.size, coef_fixture.asnumpy().size) + np.testing.assert_equal(coef_fixture.size, coef_fixture._data.size) + + +# Test basic arithmetic functions + + +def test_add(basis, coef_fixture): + """ + Tests addition operation against pure Numpy. + """ + # Make array + x = np.random.random(size=coef_fixture.shape).astype(coef_fixture.dtype, copy=False) + # Construct Coef + c = Coef(basis, x) + + # Perform operation as array for reference + ref = coef_fixture.asnumpy() + x + + # Perform operation as `Coef` for result + res = coef_fixture + c + + # Compare result with reference + np.testing.assert_allclose(res, ref) + + +def test_sub(basis, coef_fixture): + """ + Tests subtraction operation against pure Numpy. + """ + # Make array + x = np.random.random(size=coef_fixture.shape).astype(coef_fixture.dtype, copy=False) + # Construct Coef + c = Coef(basis, x) + + # Perform operation as array for reference + ref = coef_fixture.asnumpy() - x + + # Perform operation as `Coef` for result + res = coef_fixture - c + + # Compare result with reference + np.testing.assert_allclose(res, ref) + + +def test_neg(basis, coef_fixture): + """ + Tests negation operation against pure Numpy. + """ + # Perform operation as array for reference + ref = -coef_fixture.asnumpy() + + # Perform operation as `Coef` for result + res = -coef_fixture + + # Compare result with reference + np.testing.assert_allclose(res, ref) + + +def test_mul(basis, coef_fixture): + """ + Tests multiplication operation against pure Numpy. + """ + # Make array + x = np.random.random(size=coef_fixture.shape).astype(coef_fixture.dtype, copy=False) + # Construct Coef + c = Coef(basis, x) + + # Perform operation as array for reference + ref = coef_fixture.asnumpy() * x + + # Perform operation as `Coef` for result + res = coef_fixture * c + + # Compare result with reference + np.testing.assert_allclose(res, ref) + + +# Test Passthrough Functions + + +def test_by_indices(coef_fixture, basis): + """ + Test indice passthrough. + """ + keys = [ + dict(), + dict(angular=1), + dict(radial=2), + dict(angular=1, radial=2), + dict(angular=basis.angular_indices > 0), + ] + + for key in keys: + np.testing.assert_allclose( + coef_fixture.by_indices(**key), + coef_fixture.asnumpy()[..., basis.indices_mask(**key)], + ) + + +def test_coef_evalute(coef_fixture, basis): + """ + Test evaluate pass through. + """ + np.testing.assert_allclose( + coef_fixture.evaluate(), + basis.evaluate(coef_fixture), + rtol=1e-05, + atol=utest_tolerance(basis.dtype), + ) + + +def test_coef_rotate(coef_fixture, basis, rots): + """ + Test rotation pass through. + """ + + # Refl + refl = ( + np.random.rand(coef_fixture.stack_size).reshape(coef_fixture.stack_shape) > 0.5 + ) # Random bool + + np.testing.assert_allclose( + coef_fixture.rotate(rots), basis.rotate(coef_fixture, rots) + ) + + np.testing.assert_allclose( + coef_fixture.rotate(rots, refl), basis.rotate(coef_fixture, rots, refl) + ) + + +# Test related Basis Coef checks. +# These are easier to test here via parameterization. +def test_evaluate_incorrect_type(coef_fixture, basis): + """ + Test that evaluate raises when passed non Coef type. + """ + with pytest.raises(TypeError, match=r".*should be passed a `Coef`.*"): + # Pass something that is not a Coef, eg Numpy array. + basis.evaluate(coef_fixture.asnumpy()) + + +def test_to_real_incorrect_type(coef_fixture, basis): + """ + Test to_real conversion raises on non `Coef` type. + """ + # Convert Coef to complex, then to Numpy. 
+ x = basis.to_complex(coef_fixture).asnumpy() + + # Call to_real with Numpy array + with pytest.raises(TypeError, match=r".*should be instance of `Coef`.*"): + _ = basis.to_real(x) + + +def test_to_complex_incorrect_type(coef_fixture, basis): + """ + Test to_complex conversion raises on non `Coef` type. + """ + # Convert Coef to Numpy. + x = coef_fixture.asnumpy() + + # Call to_complex with Numpy array + with pytest.raises(TypeError, match=r".*should be instance of `Coef`.*"): + _ = basis.to_complex(x) + + +def test_real_complex_real_roundtrip(coef_fixture, basis): + rcoef = basis.to_real(basis.to_complex(coef_fixture)) + + np.testing.assert_allclose(rcoef, coef_fixture, rtol=1e-05, atol=1e-08) + + +def test_complex_evaluate(coef_fixture): + """ + Confirm using `ComplexCoef.evaluate` is equivalent to `Coef.evaluate`. + """ + + # Create a ComplexCoef + complex_coef = coef_fixture.to_complex() + + # Compare + np.testing.assert_allclose( + complex_coef.evaluate(), + coef_fixture.evaluate(), + rtol=1e-05, + atol=utest_tolerance(coef_fixture.basis.dtype), + ) + + +def test_complex_rotate(coef_fixture, rots): + """ + Confirm using `ComplexCoef.rotate` is equivalent to `Coef.rotate`. + """ + # Create a ComplexCoef + complex_coef = coef_fixture.to_complex() + + # Compare + np.testing.assert_allclose( + complex_coef.rotate(rots), + coef_fixture.rotate(rots).to_complex(), + rtol=1e-05, + atol=utest_tolerance(coef_fixture.basis.dtype), + ) + + +def test_shifts(coef_fixture, basis, rots): + """ + Confirm using `Coef.shift` is equivalent to `basis.shift`. + """ + if coef_fixture.stack_ndim > 1: + pytest.xfail(reason="Shifts currently only support 1d stack axis.") + + # Create some shifts, by reusing the `rots` array. + shifts = np.column_stack((rots, rots[::-1])) + + # Compare + np.testing.assert_allclose( + coef_fixture.shift(shifts), + basis.shift(coef_fixture, shifts), + rtol=1e-05, + atol=utest_tolerance(basis.dtype), + ) + + +def test_complex_shift(coef_fixture, rots): + """ + Confirm using `ComplexCoef.shift` is equivalent to `Coef.shift`. + """ + if coef_fixture.stack_ndim > 1: + pytest.xfail(reason="Shifts currently only support 1d stack axis.") + + # Create a ComplexCoef + complex_coef = coef_fixture.to_complex() + + # Create some shifts, by reusing the `rots` array. + shifts = np.column_stack((rots, rots[::-1])) + + # Compare + np.testing.assert_allclose( + complex_coef.shift(shifts), + coef_fixture.shift(shifts).to_complex(), + rtol=1e-05, + atol=utest_tolerance(coef_fixture.basis.dtype), + ) diff --git a/tests/test_coor_trans.py b/tests/test_coor_trans.py index 6da99cbe9b..70d4dc37a4 100644 --- a/tests/test_coor_trans.py +++ b/tests/test_coor_trans.py @@ -10,6 +10,7 @@ get_aligned_rotations, grid_2d, grid_3d, + mean_aligned_angular_distance, register_rotations, uniform_random_angles, ) @@ -335,3 +336,18 @@ def testCrop3DFillValue(self): a = np.ones((4, 4, 3)) b = crop_pad_3d(a, 4, fill_value=-1) self.assertTrue(np.array_equal(b[:, :, 0], -1 * np.ones((4, 4)))) + + +def test_mean_aligned_angular_distance(): + n_rots = 10 + dtype = np.float32 + rots_gt = Rotation.generate_random_rotations(n_rots, dtype=dtype).matrices + + # Create a set of rotations that can be exactly globally aligned to rots_gt. + rots_est = rots_gt[0] @ rots_gt + + # Check that the mean angular distance is zero degrees. + np.testing.assert_allclose(mean_aligned_angular_distance(rots_est, rots_gt), 0.0) + + # Test internal assert using the `degree_tol` argument. 
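+    # Should not raise; the distance above is zero degrees, under the 0.1 tolerance.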
+    mean_aligned_angular_distance(rots_est, rots_gt, degree_tol=0.1)
diff --git a/tests/test_covar2d.py b/tests/test_covar2d.py
index b8e223d92f..05e0eda509 100644
--- a/tests/test_covar2d.py
+++ b/tests/test_covar2d.py
@@ -1,259 +1,329 @@
 import os
 import os.path
-from unittest import TestCase

 import numpy as np
-from parameterized import parameterized
+import pytest
 from pytest import raises

 from aspire.basis import FFBBasis2D
 from aspire.covariance import RotCov2D
 from aspire.noise import WhiteNoiseAdder
 from aspire.operators import RadialCTFFilter
-from aspire.source.simulation import Simulation
-from aspire.utils import utest_tolerance
+from aspire.source.simulation import _LegacySimulation
+from aspire.utils import randi, utest_tolerance
 from aspire.volume import Volume

 DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")


-class Cov2DTestCase(TestCase):
+IMG_SIZES = [8]
+DTYPES = [np.float32]
+# Basis used in Cov2D.
+BASIS = [
+    FFBBasis2D,
+]
+
+# Hard coded to match legacy files.
+NOISE_VAR = 1.3957e-4
+
+# Cover `test_shrinkage`
+SHRINKERS = [None, "frobenius_norm", "operator_norm", "soft_threshold"]
+
+CTF_ENABLED = [True, False]
+
+
+@pytest.fixture(params=CTF_ENABLED, ids=lambda x: f"ctf={x}")
+def ctf_enabled(request):
+    return request.param
+
+
+@pytest.fixture(params=SHRINKERS, ids=lambda x: f"shrinker={x}")
+def shrinker(request):
+    return request.param
+
+
+@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}")
+def dtype(request):
+    return request.param
+
+
+@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}")
+def img_size(request):
+    return request.param
+
+
+@pytest.fixture
+def volume(dtype, img_size):
+    # Get a volume
+    v = Volume(
+        np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(dtype)
+    )
+    # 1e3 is hardcoded to match legacy test files.
+    return v.downsample(img_size) * 1.0e3
+
+
+@pytest.fixture(params=BASIS, ids=lambda x: f"basis={x}")
+def basis(request, img_size, dtype):
+    cls = request.param
+    # Setup a Basis
+    basis = cls(img_size, dtype=dtype)
+    return basis
+
+
+@pytest.fixture
+def cov2d_fixture(volume, basis, ctf_enabled):
     """
-    Cov2D Test without CTFFilters populated.
+    Cov2D Test Fixture.
     """
+    n = 32
+
+    # Default CTF params
     unique_filters = None
     h_idx = None
     h_ctf_fb = None
+    # Populate CTF
+    if ctf_enabled:
+        unique_filters = [
+            RadialCTFFilter(
+                5.0 * 65 / volume.resolution, 200, defocus=d, Cs=2.0, alpha=0.1
+            )
+            for d in np.linspace(1.5e4, 2.5e4, 7)
+        ]

-    # These class variables support parameterized arg checking in `testShrinkers`
-    shrinkers = [(None,), "frobenius_norm", "operator_norm", "soft_threshold"]
-    bad_shrinker_inputs = ["None", "notashrinker", ""]
+        # Copied from simulation defaults to match legacy test files.
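+        # The -1 converts the draws to 0-based filter indices.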
+ h_idx = randi(len(unique_filters), n, seed=0) - 1 - def setUp(self): - self.dtype = np.float32 + h_ctf_fb = [basis.filter_to_basis_mat(f) for f in unique_filters] - self.L = L = 8 - n = 32 + noise_adder = WhiteNoiseAdder(var=NOISE_VAR) - self.noise_var = 1.3957e-4 - noise_adder = WhiteNoiseAdder(var=self.noise_var) + sim = _LegacySimulation( + n=n, + vols=volume, + unique_filters=unique_filters, + filter_indices=h_idx, + offsets=0.0, + amplitudes=1.0, + dtype=volume.dtype, + noise_adder=noise_adder, + ) + sim.cache() - vols = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype( - self.dtype - ) - ) # RCOPT - vols = vols.downsample(L) * 1.0e3 - # Since FFBBasis2D doesn't yet implement dtype, we'll set this to double to match its built in types. - self.sim = Simulation( - n=n, - L=L, - vols=vols, - unique_filters=self.unique_filters, - offsets=0.0, - amplitudes=1.0, - dtype=self.dtype, - noise_adder=noise_adder, - ) + cov2d = RotCov2D(basis) + coef_clean = basis.evaluate_t(sim.projections[:]) + coef = basis.evaluate_t(sim.images[:]) - self.basis = FFBBasis2D((L, L), dtype=self.dtype) + return sim, cov2d, coef_clean, coef, h_ctf_fb, h_idx - self.imgs_clean = self.sim.projections[:] - self.imgs_ctf_clean = self.sim.clean_images[:] - self.imgs_ctf_noise = self.sim.images[:n] - self.cov2d = RotCov2D(self.basis) - self.coeff_clean = self.basis.evaluate_t(self.imgs_clean) - self.coeff = self.basis.evaluate_t(self.imgs_ctf_noise) +def test_get_mean(cov2d_fixture): + results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_mean.npy")) + cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] - def tearDown(self): - pass + mean_coef = cov2d._get_mean(coef_clean.asnumpy()) + np.testing.assert_allclose(results, mean_coef, atol=utest_tolerance(cov2d.dtype)) - def testGetMean(self): - results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_mean.npy")) - mean_coeff = self.cov2d._get_mean(self.coeff_clean) - self.assertTrue(np.allclose(results, mean_coeff)) - def testGetCovar(self): - results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), - allow_pickle=True, - ) - covar_coeff = self.cov2d._get_covar(self.coeff_clean) - - for im, mat in enumerate(results.tolist()): - self.assertTrue(np.allclose(mat, covar_coeff[im])) - - def testGetMeanCTF(self): - """ - Compare `get_mean` (no CTF args) with `_get_mean` (no CTF model). - """ - mean_coeff_ctf = self.cov2d.get_mean(self.coeff, self.h_ctf_fb, self.h_idx) - mean_coeff = self.cov2d._get_mean(self.coeff_clean) - self.assertTrue(np.allclose(mean_coeff_ctf, mean_coeff, atol=0.002)) - - def testGetCWFCoeffsClean(self): - results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff_clean.npy") - ) - coeff_cwf_clean = self.cov2d.get_cwf_coeffs(self.coeff_clean, noise_var=0) - self.assertTrue( - np.allclose(results, coeff_cwf_clean, atol=utest_tolerance(self.dtype)) - ) +def test_get_covar(cov2d_fixture): + results = np.load( + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), + allow_pickle=True, + ) - def testGetCWFCoeffsCleanCTF(self): - """ - Test case of clean images (coeff_clean and noise_var=0) - while using a non Identity CTF. + cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] - This case may come up when a developer switches between - clean and dirty images. 
- """ + covar_coef = cov2d._get_covar(coef_clean.asnumpy()) - coeff_cwf = self.cov2d.get_cwf_coeffs( - self.coeff_clean, self.h_ctf_fb, self.h_idx, noise_var=0 - ) + for im, mat in enumerate(results.tolist()): + np.testing.assert_allclose(mat, covar_coef[im], rtol=1e-05) - # Reconstruct images from output of get_cwf_coeffs - img_est = self.basis.evaluate(coeff_cwf) - # Compare with clean images - delta = np.mean(np.square((self.imgs_clean - img_est).asnumpy())) - self.assertTrue(delta < 0.02) - - # Note, parameterized module can be removed at a later date - # and replaced with pytest if ASPIRE-Python moves away from - # the TestCase class style tests. - # Paramaterize over known shrinkers and some bad values - @parameterized.expand(shrinkers + bad_shrinker_inputs) - def testShrinkers(self, shrinker): - """Test all the shrinkers we know about run without crashing, - and check we raise with specific message for unsupporting shrinker arg.""" - - if shrinker in self.bad_shrinker_inputs: - with raises(AssertionError, match="Unsupported shrink method"): - _ = self.cov2d.get_covar( - self.coeff_clean, covar_est_opt={"shrinker": shrinker} - ) - return - - results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), - allow_pickle=True, - ) - covar_coeff = self.cov2d.get_covar( - self.coeff_clean, covar_est_opt={"shrinker": shrinker} - ) +def test_get_mean_ctf(cov2d_fixture, ctf_enabled): + """ + Compare `get_mean` (no CTF args) with `_get_mean` (no CTF model). + """ + sim, cov2d, coef_clean, coef, h_ctf_fb, h_idx = cov2d_fixture + + mean_coef_ctf = cov2d.get_mean(coef, h_ctf_fb, h_idx) + + tol = utest_tolerance(sim.dtype) + if ctf_enabled: + result = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_meanctf.npy")) + else: + result = cov2d._get_mean(coef_clean.asnumpy()) + tol = 0.002 + + np.testing.assert_allclose(mean_coef_ctf.asnumpy()[0], result, atol=tol) - for im, mat in enumerate(results.tolist()): - self.assertTrue( - np.allclose(mat, covar_coeff[im], atol=utest_tolerance(self.dtype)) - ) +def test_get_cwf_coefs_clean(cov2d_fixture): + results = np.load( + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coef_clean.npy") + ) -class Cov2DTestCaseCTF(Cov2DTestCase): + cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] + + coef_cwf_clean = cov2d.get_cwf_coefs(coef_clean, noise_var=0) + np.testing.assert_allclose( + results, coef_cwf_clean, atol=utest_tolerance(cov2d.dtype) + ) + + +def test_get_cwf_coefs_clean_ctf(cov2d_fixture): """ - Cov2D Test with CTFFilters populated. + Test case of clean images (coef_clean and noise_var=0) + while using a non Identity CTF. + + This case may come up when a developer switches between + clean and dirty images. 
""" + sim, cov2d, coef_clean, _, h_ctf_fb, h_idx = cov2d_fixture - @property - def unique_filters(self): - return [ - RadialCTFFilter(5.0 * 65 / self.L, 200, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(1.5e4, 2.5e4, 7) - ] + coef_cwf = cov2d.get_cwf_coefs(coef_clean, h_ctf_fb, h_idx, noise_var=0) - @property - def h_idx(self): - return self.sim.filter_indices + # Reconstruct images from output of get_cwf_coefs + img_est = cov2d.basis.evaluate(coef_cwf) + # Compare with clean images + delta = np.mean(np.square((sim.clean_images[:] - img_est).asnumpy())) + np.testing.assert_array_less(delta, 0.01) - @property - def h_ctf_fb(self): - return [filt.fb_mat(self.basis) for filt in self.unique_filters] - def testGetCWFCoeffsCTFargs(self): - """ - Test we raise when user supplies incorrect CTF arguments, - and that the error message matches. - """ +def test_shrinker_inputs(cov2d_fixture): + """ + Check we raise with specific message for unsupporting shrinker arg. + """ + cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] - with raises(RuntimeError, match=r".*Given ctf_fb.*"): - _ = self.cov2d.get_cwf_coeffs( - self.coeff, self.h_ctf_fb, None, noise_var=self.noise_var - ) + bad_shrinker_inputs = ["None", "notashrinker", ""] - def testGetMeanCTF(self): - """ - Compare `get_mean` with saved legacy cov2d results. - """ - results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_meanctf.npy")) - mean_coeff_ctf = self.cov2d.get_mean(self.coeff, self.h_ctf_fb, self.h_idx) - self.assertTrue(np.allclose(results, mean_coeff_ctf)) - - def testGetCWFCoeffs(self): - """ - Tests `get_cwf_coeffs` with poulated CTF. - """ - results = np.load( - os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff.npy") - ) - coeff_cwf = self.cov2d.get_cwf_coeffs( - self.coeff, self.h_ctf_fb, self.h_idx, noise_var=self.noise_var - ) - self.assertTrue( - np.allclose(results, coeff_cwf, atol=utest_tolerance(self.dtype)) + for shrinker in bad_shrinker_inputs: + with raises(AssertionError, match="Unsupported shrink method"): + _ = cov2d.get_covar(coef_clean, covar_est_opt={"shrinker": shrinker}) + + +def test_shrinkage(cov2d_fixture, shrinker): + """ + Test all the shrinkers we know about run without crashing, + """ + cov2d, coef_clean = cov2d_fixture[1], cov2d_fixture[2] + + results = np.load( + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"), + allow_pickle=True, + ) + + covar_coef = cov2d.get_covar(coef_clean, covar_est_opt={"shrinker": shrinker}) + + for im, mat in enumerate(results.tolist()): + np.testing.assert_allclose( + mat, covar_coef[im], atol=utest_tolerance(cov2d.dtype) ) + +def test_get_cwf_coefs_ctf_args(cov2d_fixture): + """ + Test we raise when user supplies incorrect CTF arguments, + and that the error message matches. + """ + sim, cov2d, _, coef, h_ctf_fb, _ = cov2d_fixture + + # When half the ctf info (h_ctf_fb) is populated, + # set the other ctf param (h_idx) to none. + h_idx = None + if h_ctf_fb is None: + # And when h_ctf_fb is None, we'll populate the other half (h_idx) + h_idx = sim.filter_indices + + # Both the above situations should be an error. + with raises(RuntimeError, match=r".*Given ctf_.*"): + _ = cov2d.get_cwf_coefs(coef, h_ctf_fb, h_idx, noise_var=NOISE_VAR) + + +def test_get_cwf_coefs(cov2d_fixture, ctf_enabled): + """ + Tests `get_cwf_coefs` with poulated CTF. + """ + _, cov2d, coef_clean, coef, h_ctf_fb, h_idx = cov2d_fixture + + # Hard coded file expects sim with ctf. 
+    if not ctf_enabled:
+        pytest.skip(reason="Reference file n/a.")
+
+    results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coef.npy"))
+
+    coef_cwf = cov2d.get_cwf_coefs(coef, h_ctf_fb, h_idx, noise_var=NOISE_VAR)
+
+    np.testing.assert_allclose(results, coef_cwf, atol=utest_tolerance(cov2d.dtype))
+
+
+def test_get_cwf_coefs_without_ctf_args(cov2d_fixture, ctf_enabled):
+    """
+    Tests `get_cwf_coefs` without providing CTF args.
+    (Internally uses an identity CTF.)
+    """
+
+    _, cov2d, _, coef, _, _ = cov2d_fixture
+
+    # Hard coded file expects sim with ctf.
+    if not ctf_enabled:
+        pytest.skip(reason="Reference file n/a.")
+
     # Note, I think this file is incorrectly named...
     # It appears to have come from operations on images with ctf applied.
-    def testGetCWFCoeffsNoCTF(self):
-        """
-        Tests `get_cwf_coeffs` without providing CTF. (Internally uses IdentityCTF).
-        """
-        results = np.load(
-            os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff_noCTF.npy")
-        )
-        coeff_cwf_noCTF = self.cov2d.get_cwf_coeffs(
-            self.coeff, noise_var=self.noise_var
-        )
+    results = np.load(
+        os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coef_noCTF.npy")
+    )

-        self.assertTrue(
-            np.allclose(results, coeff_cwf_noCTF, atol=utest_tolerance(self.dtype))
-        )
+    coef_cwf = cov2d.get_cwf_coefs(coef, noise_var=NOISE_VAR)

-    def testGetCovarCTF(self):
-        results = np.load(
-            os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npy"),
-            allow_pickle=True,
-        )
-        covar_coeff_ctf = self.cov2d.get_covar(
-            self.coeff, self.h_ctf_fb, self.h_idx, noise_var=self.noise_var
-        )
-        for im, mat in enumerate(results.tolist()):
-            self.assertTrue(np.allclose(mat, covar_coeff_ctf[im]))
+    np.testing.assert_allclose(results, coef_cwf, atol=utest_tolerance(cov2d.dtype))

-    def testGetCovarCTFShrink(self):
-        results = np.load(
-            os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf_shrink.npy"),
-            allow_pickle=True,
-        )
-        covar_opt = {
-            "shrinker": "frobenius_norm",
-            "verbose": 0,
-            "max_iter": 250,
-            "iter_callback": [],
-            "store_iterates": False,
-            "rel_tolerance": 1e-12,
-            "precision": self.dtype,
-        }
-        covar_coeff_ctf_shrink = self.cov2d.get_covar(
-            self.coeff,
-            self.h_ctf_fb,
-            self.h_idx,
-            noise_var=self.noise_var,
-            covar_est_opt=covar_opt,
-        )
-        for im, mat in enumerate(results.tolist()):
-            self.assertTrue(np.allclose(mat, covar_coeff_ctf_shrink[im]))
+
+def test_get_covar_ctf(cov2d_fixture, ctf_enabled):
+    # Hard coded file expects sim with ctf.
+    if not ctf_enabled:
+        pytest.skip(reason="Reference file n/a.")
+
+    sim, cov2d, _, coef, h_ctf_fb, h_idx = cov2d_fixture
+
+    results = np.load(
+        os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npy"),
+        allow_pickle=True,
+    )
+
+    covar_coef_ctf = cov2d.get_covar(coef, h_ctf_fb, h_idx, noise_var=NOISE_VAR)
+    for im, mat in enumerate(results.tolist()):
+        np.testing.assert_allclose(mat, covar_coef_ctf[im], rtol=1e-05, atol=1e-08)
+
+
+def test_get_covar_ctf_shrink(cov2d_fixture, ctf_enabled):
+    sim, cov2d, _, coef, h_ctf_fb, h_idx = cov2d_fixture
+
+    # Hard coded file expects sim with ctf.
+ if not ctf_enabled: + pytest.skip(reason="Reference file n/a.") + + results = np.load( + os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf_shrink.npy"), + allow_pickle=True, + ) + + covar_opt = { + "shrinker": "frobenius_norm", + "verbose": 0, + "max_iter": 250, + "iter_callback": [], + "store_iterates": False, + "rel_tolerance": 1e-12, + "precision": cov2d.dtype, + } + + covar_coef_ctf_shrink = cov2d.get_covar( + coef, + h_ctf_fb, + h_idx, + noise_var=NOISE_VAR, + covar_est_opt=covar_opt, + ) + + for im, mat in enumerate(results.tolist()): + np.testing.assert_allclose(mat, covar_coef_ctf_shrink[im]) diff --git a/tests/test_covar2d_denoiser.py b/tests/test_covar2d_denoiser.py index 77883db7ea..a403a72109 100644 --- a/tests/test_covar2d_denoiser.py +++ b/tests/test_covar2d_denoiser.py @@ -1,44 +1,226 @@ -from unittest import TestCase - import numpy as np +import pytest -from aspire.basis.ffb_2d import FFBBasis2D -from aspire.denoising.denoiser_cov2d import DenoiserCov2D +from aspire.basis import FBBasis2D, FFBBasis2D, FLEBasis2D, FPSWFBasis2D, PSWFBasis2D +from aspire.denoising import DenoisedSource, DenoiserCov2D from aspire.noise import WhiteNoiseAdder -from aspire.operators.filters import RadialCTFFilter -from aspire.source.simulation import Simulation - - -class BatchedRotCov2DTestCase(TestCase): - def testMSE(self): - # need larger numbers of images and higher resolution for good MSE - dtype = np.float32 - img_size = 64 - num_imgs = 1024 - noise_var = 0.1848 - noise_adder = WhiteNoiseAdder(var=noise_var) - filters = [ - RadialCTFFilter(5, 200, defocus=d, Cs=2.0, alpha=0.1) - for d in np.linspace(1.5e4, 2.5e4, 7) - ] - # set simulation object - sim = Simulation( - L=img_size, - n=num_imgs, - unique_filters=filters, - offsets=0.0, - amplitudes=1.0, - dtype=dtype, - noise_adder=noise_adder, - ) - imgs_clean = sim.projections[:] - - # Specify the fast FB basis method for expending the 2D images - ffbbasis = FFBBasis2D((img_size, img_size), dtype=dtype) - denoiser = DenoiserCov2D(sim, ffbbasis, noise_var) - denoised_src = denoiser.denoise(batch_size=64) - imgs_denoised = denoised_src.images[:] - # Calculate the normalized RMSE of the estimated images. - nrmse_ims = (imgs_denoised - imgs_clean).norm() / imgs_clean.norm() - - self.assertTrue(nrmse_ims < 0.25) +from aspire.operators import IdentityFilter, RadialCTFFilter +from aspire.source import Simulation + +# TODO, parameterize these further. +dtype = np.float32 +img_size = 32 +num_imgs = 1024 +noise_var = 0.1848 +noise_adder = WhiteNoiseAdder(var=noise_var) +filters = [ + RadialCTFFilter(5, 200, defocus=d, Cs=2.0, alpha=0.1) + for d in np.linspace(1.5e4, 2.5e4, 7) +] +BASIS = [ + pytest.param(FBBasis2D, marks=pytest.mark.expensive), + FFBBasis2D, + FLEBasis2D, + pytest.param(PSWFBasis2D, marks=pytest.mark.expensive), + FPSWFBasis2D, +] + + +@pytest.fixture(params=BASIS, scope="module", ids=lambda x: f"basis={x}") +def basis(request): + """ + Construct and return a 2D Basis. 
+ """ + cls = request.param + return cls(img_size, dtype=dtype) + + +@pytest.fixture(scope="module") +def sim(): + """Create a reusable Simulation source.""" + sim = Simulation( + L=img_size, + n=num_imgs, + unique_filters=filters, + offsets=0.0, + amplitudes=1.0, + dtype=dtype, + noise_adder=noise_adder, + ) + sim = sim.cache() + return sim + + +@pytest.fixture(scope="module") +def coef(sim, basis): + """Generate small set of reference coefficients.""" + return basis.expand(sim.images[:3]) + + +def test_batched_rotcov2d_MSE(sim, basis): + """ + Check calling `DenoiserCov2D` via `DenoiserSource` framework yields acceptable error. + """ + # Smoke test reference values (chosen by experimentation). + refs = { + "FBBasis2D": 0.23, + "FFBBasis2D": 0.23, + "PSWFBasis2D": 0.76, + "FPSWFBasis2D": 0.76, + "FLEBasis2D": 0.52, + } + + # need larger numbers of images and higher resolution for good MSE + imgs_clean = sim.projections[:] + + # Specify the fast FB basis method for expending the 2D images + denoiser = DenoiserCov2D(sim, basis, noise_var) + imgs_denoised = denoiser.denoise[:] + + # Calculate the normalized RMSE of the estimated images. + nrmse_ims = (imgs_denoised - imgs_clean).norm() / imgs_clean.norm() + ref = refs[basis.__class__.__name__] + np.testing.assert_array_less( + nrmse_ims, + ref, + err_msg=f"Comparison failed for {basis}. Achieved: {nrmse_ims} expected: {ref}.", + ) + + # Additionally test the `DenoisedSource` and lazy-eval-cache + # of the cov2d estimator. + src = DenoisedSource(sim, denoiser) + np.testing.assert_allclose(imgs_denoised, src.images[:], rtol=1e-05, atol=1e-08) + + +def test_source_mismatch(sim, basis): + """ + Assert mismatched sources raises an error. + """ + # Create a denoiser. + denoiser = DenoiserCov2D(sim, basis, noise_var) + + # Create a different source. + src2 = sim[: sim.n - 1] + + # Raise because src2 not identical to denoiser.src (sim) + with pytest.raises(NotImplementedError, match=r".*must match.*"): + _ = DenoisedSource(src2, denoiser) + + +def test_filter_to_basis_mat_id(coef, basis): + """ + Test `basis.filter_to_basis_mat` operator performance against + manual sequence of evaluate->filter->expand for `IdentifyFilter`. + """ + + refs = { + "FBBasis2D": 0.025, + "FFBBasis2D": 3e-6, + "PSWFBasis2D": 0.14, + "FPSWFBasis2D": 0.14, + "FLEBasis2D": 4e-7, + } + + # IdentityFilter should produce id + filt = IdentityFilter() + + # Apply the basis filter operator. + # Note transpose because `apply` expects and returns column vectors. + coef_ftbm = (basis.filter_to_basis_mat(filt) @ coef.asnumpy().T).T + + # Apply evaluate->filter->expand manually + imgs = coef.evaluate() + imgs_manual = imgs.filter(filt) + coef_manual = basis.expand(imgs_manual) + + # Compare coefs from using ftbm operator with coef from eval->filter->exp + rms = np.sqrt(np.mean(np.square(coef_ftbm - coef_manual))) + ref = refs[basis.__class__.__name__] + np.testing.assert_array_less( + rms, + ref, + err_msg=f"Comparison failed for {basis}. Achieved: {rms} expected: {ref}", + ) + + +def test_filter_to_basis_mat_ctf(coef, basis): + """ + Test `basis.filter_to_basis_mat` operator performance against + manual sequence of evaluate->filter->expand for `RadialCTFFilter`. + """ + + refs = { + "FBBasis2D": 0.025, + "FFBBasis2D": 0.35, + "PSWFBasis2D": 0.11, + "FPSWFBasis2D": 0.11, + "FLEBasis2D": 0.4, + } + + # Create a RadialCTFFilter + filt = RadialCTFFilter(pixel_size=1) + + # Apply the basis filter operator. + # Note transpose because `apply` expects and returns column vectors. 
+    coef_ftbm = (basis.filter_to_basis_mat(filt, truncate=False) @ coef.asnumpy().T).T
+
+    # Apply evaluate->filter->expand manually
+    imgs = coef.evaluate()
+    imgs_manual = imgs.filter(filt)
+    coef_manual = basis.expand(imgs_manual)
+
+    # Compare coefs from using ftbm operator with coef from eval->filter->exp
+    rms = np.sqrt(np.mean(np.square(coef_ftbm - coef_manual)))
+    ref = refs[basis.__class__.__name__]
+    np.testing.assert_array_less(
+        rms,
+        ref,
+        err_msg=f"Comparison failed for {basis}. Achieved: {rms} expected: {ref}",
+    )
+
+
+def test_filter_to_basis_mat_id_expand(coef, basis):
+    """
+    Test `basis.filter_to_basis_mat` operator performance using the slower
+    `expand` method against a manual sequence of
+    evaluate->filter->expand for `IdentityFilter`.
+    """
+
+    refs = {
+        "FBBasis2D": 4e-7,
+        "PSWFBasis2D": 5e-6,
+        "FPSWFBasis2D": 5e-6,
+    }
+
+    # IdentityFilter should produce id
+    filt = IdentityFilter()
+
+    # Some bases do not provide alternative `method`s
+    if isinstance(basis, FFBBasis2D) or isinstance(basis, FLEBasis2D):
+        with pytest.raises(NotImplementedError, match=r".*not supported.*"):
+            _ = basis.filter_to_basis_mat(filt, method="expand")
+        return
+
+    # Apply the basis filter operator.
+    # Note transpose because `apply` expects and returns column vectors.
+    coef_ftbm = (basis.filter_to_basis_mat(filt, method="expand") @ coef.asnumpy().T).T
+
+    # Apply evaluate->filter->expand manually
+    imgs = coef.evaluate()
+    imgs_manual = imgs.filter(filt)
+    coef_manual = basis.expand(imgs_manual)
+
+    # Compare coefs from using ftbm operator with coef from eval->filter->exp
+    rms = np.sqrt(np.mean(np.square(coef_ftbm - coef_manual)))
+    ref = refs[basis.__class__.__name__]
+    np.testing.assert_array_less(
+        rms,
+        ref,
+        err_msg=f"Comparison failed for {basis}. Achieved: {rms} expected: {ref}",
+    )
+
+
+def test_filter_to_basis_mat_bad(coef, basis):
+    filt = IdentityFilter()
+    with pytest.raises(NotImplementedError, match=r".*not supported.*"):
+        _ = basis.filter_to_basis_mat(filt, method="bad_method")
diff --git a/tests/test_covar3d.py b/tests/test_covar3d.py
index 1d3dadbc71..9bf04b4594 100644
--- a/tests/test_covar3d.py
+++ b/tests/test_covar3d.py
@@ -12,7 +12,7 @@
 from aspire.denoising import src_wiener_coords
 from aspire.operators import RadialCTFFilter
 from aspire.reconstruction import MeanEstimator
-from aspire.source.simulation import Simulation
+from aspire.source.simulation import _LegacySimulation
 from aspire.utils import eigs
 from aspire.utils.random import Random
 from aspire.volume import LegacyVolume, Volume
@@ -25,7 +25,7 @@ class Covar3DTestCase(TestCase):
     def setUpClass(cls):
         cls.dtype = np.float32
         cls.vols = LegacyVolume(L=8, dtype=cls.dtype).generate()
-        cls.sim = Simulation(
+        cls.sim = _LegacySimulation(
             n=1024,
             vols=cls.vols,
             unique_filters=[
diff --git a/tests/test_diag_matrix.py b/tests/test_diag_matrix.py
index e883d7f55f..ecce899105 100644
--- a/tests/test_diag_matrix.py
+++ b/tests/test_diag_matrix.py
@@ -139,7 +139,6 @@ def test_dtype_mismatch():
         _ = d1 + d2
 
 
-# Explicit Tests (non parameterized).
 def test_dtype_passthrough():
     """
     Test that the datatype is inferred correctly.
@@ -611,7 +610,8 @@ def test_apply(diag_matrix_fixture):
     """
     d1, _, d_np = diag_matrix_fixture
 
-    x = d1.apply(d_np)
+    # Apply is used on column vectors, transpose.
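+    # d_np.T supplies column vectors; the result is transposed back for comparison.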
+ x = d1.apply(d_np.T).T np.testing.assert_allclose(x, d_np[0][None, :] * d_np) @@ -623,9 +623,10 @@ def test_rapply(diag_matrix_fixture): d1, _, d_np = diag_matrix_fixture - x = d1.rapply(d_np) + # Apply is used on column vectors, transpose. + x = d1.rapply(d_np.T) - np.testing.assert_allclose(x, (d_np * d_np[0])) + np.testing.assert_allclose(x.T, (d_np * d_np[0])) def test_solve(diag_matrix_fixture): diff --git a/tests/test_docstring_checker.py b/tests/test_docstring_checker.py new file mode 100644 index 0000000000..514d06fbbe --- /dev/null +++ b/tests/test_docstring_checker.py @@ -0,0 +1,31 @@ +import logging +import os + +from docs import check_docstrings + + +def test_check_blank_line(caplog): + test_string = os.path.join( + os.path.dirname(__file__), "saved_test_data", "sample_docstrings.py" + ) + + caplog.clear() + caplog.set_level(logging.ERROR) + error_count = check_docstrings.check_blank_line_above_param_section(test_string) + + # Line numbers of good and bad docstrings in sample_docstrings.py + good_doc_line_nums = [2, 16, 25, 35] + bad_doc_line_nums = [43, 53, 65] + + # Check that good docstrings do not log error + for line_num in good_doc_line_nums: + msg = f"sample_docstrings.py: {line_num}: Must have exactly 1 blank line" + assert msg not in caplog.text + + # Check that bad docstrings log error + for line_num in bad_doc_line_nums: + msg = f"sample_docstrings.py: {line_num}: Must have exactly 1 blank line" + assert msg in caplog.text + + # Check total error count log + assert error_count == len(bad_doc_line_nums) diff --git a/tests/test_filters.py b/tests/test_filters.py index 40b0fb2a9d..35d7955a9e 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,9 +1,12 @@ +import logging import os.path from unittest import TestCase import numpy as np +import pytest from aspire.operators import ( + ArrayFilter, CTFFilter, FunctionFilter, IdentityFilter, @@ -327,3 +330,34 @@ def testFilterSigns(self): signs = np.sign(ctf_filter.evaluate(self.omega)) sign_filter = ctf_filter.sign self.assertTrue(np.allclose(sign_filter.evaluate(self.omega), signs)) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_power_filter_safeguard(dtype, caplog): + L = 25 + arr = np.ones((L, L), dtype=dtype) + + # Set a few values below machine epsilon. + num_eps = 3 + eps = np.finfo(dtype).eps + arr[L // 2, L // 2 : L // 2 + num_eps] = eps / 2 + + # For negative powers, values below machine eps will be set to zero. + filt = PowerFilter( + filter=ArrayFilter(arr), + power=-0.5, + ) + + caplog.clear() + caplog.set_level(logging.WARN) + filt_vals = filt.evaluate_grid(L, dtype=dtype) + + # Check that extreme values are set to zero. + ref = np.ones((L, L), dtype=dtype) + ref[L // 2, L // 2 : L // 2 + num_eps] = 0 + + np.testing.assert_array_equal(filt_vals, ref) + + # Check caplog for warning. + msg = f"setting {num_eps} extremal filter value(s) to zero." 
+ assert msg in caplog.text diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 476e2f5548..282089ce12 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -145,9 +145,10 @@ def test_frc_img_plot(image_fixture): _ = img_a.frc(img_n, pixel_size=1, cutoff=0.143, plot=True) # Plot to file + # Also tests `cutoff=None` with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") - img_a.frc(img_n, pixel_size=1, cutoff=0.143, plot=file_path) + img_a.frc(img_n, pixel_size=1, cutoff=None, plot=file_path) assert os.path.exists(file_path) @@ -204,9 +205,10 @@ def test_fsc_vol_plot(volume_fixture): _ = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, plot=True) # Plot to file + # Also tests `cutoff=None` with tempfile.TemporaryDirectory() as tmp_input_dir: - file_path = os.path.join(tmp_input_dir, "img_fsc_curve.png") - vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, plot=file_path) + file_path = os.path.join(tmp_input_dir, "vol_fsc_curve.png") + vol_a.fsc(vol_b, pixel_size=1, cutoff=None, plot=file_path) assert os.path.exists(file_path) diff --git a/tests/test_matrix.py b/tests/test_matrix.py index e0af0e3954..728200b9b3 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -2,15 +2,21 @@ from unittest import TestCase import numpy as np +import pytest from aspire.utils import ( + Rotation, best_rank1_approximation, fix_signs, im_to_vec, mat_to_vec, + mean_aligned_angular_distance, + nearest_rotations, + randn, roll_dim, symmat_to_vec_iso, unroll_dim, + utest_tolerance, vec_to_im, vec_to_symmat, vec_to_symmat_iso, @@ -342,3 +348,80 @@ def testFixSigns(self): x[:, 3] = 0 y[:, 3] = 0 self.assertTrue(np.allclose(fix_signs(x), y)) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_nearest_rotations(dtype): + n_rots = 5 + rots = Rotation.generate_random_rotations(n_rots, seed=0, dtype=dtype).matrices + + # Add some noise to the rotations. + noise = 1e-3 * randn(n_rots * 9, seed=0).astype(dtype, copy=False).reshape( + n_rots, 3, 3 + ) + noisy_rots = rots + noise + + # Find nearest rotations for stack. + nearest_rots = nearest_rotations(noisy_rots) + + # Check that estimates are rotation matrices. + _is_rotation(nearest_rots, dtype) + + # Check that estimates are close to original rotations. + mean_aligned_angular_distance(rots, nearest_rots, degree_tol=1) + + # Check dtype pass-through. + assert nearest_rots.dtype == dtype + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_nearest_rotations_reflection(dtype): + # Generate singleton rotation. + rot = Rotation.generate_random_rotations(1, seed=0, dtype=dtype).matrices[0] + + # Add a reflection and some noise to the rotation. + refl = rot @ np.diag((1, -1, 1)).astype(dtype) + noise = 1e-3 * randn(9, seed=0).astype(dtype, copy=False).reshape(3, 3) + noisy_refl = refl + noise + + # Find nearest rotation. + nearest_rot = nearest_rotations(noisy_refl) + + # Check that estimate is a rotation. + _is_rotation(nearest_rot, dtype) + + # Check that we retain singleton shape. + assert nearest_rot.shape == rot.shape + + +def test_nearest_rotations_error(): + # Check error for bad ndim. + A = np.empty((2, 5, 3, 3)) + with pytest.raises(ValueError, match="Array must be of shape"): + _ = nearest_rotations(A) + + # Check error for bad shape. 
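+    # Trailing dims are (3, 2) rather than the required (3, 3).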
+    A = np.empty((5, 3, 2))
+    with pytest.raises(ValueError, match="Array must be of shape"):
+        _ = nearest_rotations(A)
+
+
+def _is_rotation(R, dtype):
+    """
+    Helper function to check that a set of 3x3 matrices are rotations
+    by checking that R.T @ R = I and det(R) = 1.
+
+    :param R: Singleton or stack of 3x3 arrays.
+    :param dtype: dtype to use for test tolerance.
+    :return: None. Raises an assertion error if any 3x3 array is not a rotation.
+    """
+    if R.ndim == 2:
+        R = R[np.newaxis]
+
+    n_rots = len(R)
+    RTR = np.transpose(R, axes=(0, 2, 1)) @ R
+    atol = utest_tolerance(dtype)
+    np.testing.assert_allclose(
+        RTR, np.broadcast_to(np.eye(3), (n_rots, 3, 3)), atol=atol
+    )
+    np.testing.assert_allclose(np.linalg.det(R), 1, atol=atol)
diff --git a/tests/test_mean_estimator.py b/tests/test_mean_estimator.py
index ad73bb2716..2c7563dbb8 100644
--- a/tests/test_mean_estimator.py
+++ b/tests/test_mean_estimator.py
@@ -7,7 +7,7 @@
 from aspire.basis import FBBasis3D
 from aspire.operators import RadialCTFFilter
 from aspire.reconstruction import MeanEstimator
-from aspire.source.simulation import Simulation
+from aspire.source.simulation import _LegacySimulation
 from aspire.volume import LegacyVolume
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
@@ -18,7 +18,7 @@ def setUp(self):
         self.dtype = np.float32
         self.resolution = 8
         self.vols = LegacyVolume(L=self.resolution, dtype=self.dtype).generate()
-        self.sim = sim = Simulation(
+        self.sim = _LegacySimulation(
             n=1024,
             vols=self.vols,
             unique_filters=[
@@ -28,10 +28,10 @@ def setUp(self):
         )
 
         basis = FBBasis3D((self.resolution,) * 3, dtype=self.dtype)
-        self.estimator = MeanEstimator(sim, basis, preconditioner="none")
+        self.estimator = MeanEstimator(self.sim, basis, preconditioner="none")
 
         self.estimator_with_preconditioner = MeanEstimator(
-            sim, basis, preconditioner="circulant"
+            self.sim, basis, preconditioner="circulant"
         )
 
     def tearDown(self):
@@ -140,10 +140,10 @@ def testEstimate(self):
         )
 
     def testAdjoint(self):
-        mean_b_coeff = self.estimator.src_backward().squeeze()
+        mean_b_coef = self.estimator.src_backward().squeeze()
         self.assertTrue(
             np.allclose(
-                mean_b_coeff,
+                mean_b_coef,
                 [
                     1.07338590e-01,
                     1.23690941e-01,
@@ -249,7 +249,7 @@ def testAdjoint(self):
         )
 
     def testOptimize1(self):
-        mean_b_coeff = np.array(
+        mean_b_coef = np.array(
             [
                 [
                     1.07338590e-01,
@@ -354,7 +354,7 @@ def testOptimize1(self):
             ]
         )
 
-        x = self.estimator.conj_grad(mean_b_coeff)
+        x = self.estimator.conj_grad(mean_b_coef)
         self.assertTrue(
             np.allclose(
                 x,
@@ -463,7 +463,7 @@ def testOptimize1(self):
         )
 
     def testOptimize2(self):
-        mean_b_coeff = np.array(
+        mean_b_coef = np.array(
             [
                 [
                     1.07338590e-01,
@@ -568,7 +568,7 @@ def testOptimize2(self):
             ]
         )
 
-        x = self.estimator_with_preconditioner.conj_grad(mean_b_coeff)
+        x = self.estimator_with_preconditioner.conj_grad(mean_b_coef)
         self.assertTrue(
             np.allclose(
                 x,
diff --git a/tests/test_orient_sdp.py b/tests/test_orient_sdp.py
new file mode 100644
index 0000000000..a161d2fdd7
--- /dev/null
+++ b/tests/test_orient_sdp.py
@@ -0,0 +1,195 @@
+import numpy as np
+import pytest
+
+from aspire.abinitio import CommonlineSDP
+from aspire.nufft import backend_available
+from aspire.source import Simulation
+from aspire.utils import (
+    Rotation,
+    get_aligned_rotations,
+    mean_aligned_angular_distance,
+    register_rotations,
+    rots_to_clmatrix,
+)
+from aspire.volume import AsymmetricVolume
+
+RESOLUTION = [
+    32,
+    33,
+]
+
+OFFSETS = [
+    None,  # Defaults to random offsets.
+    0,
+]
+
+DTYPES = [
+    np.float32,
+    pytest.param(np.float64, marks=pytest.mark.expensive),
+]
+
+
+@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}")
+def resolution(request):
+    return request.param
+
+
+@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}")
+def offsets(request):
+    return request.param
+
+
+@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}")
+def dtype(request):
+    return request.param
+
+
+@pytest.fixture
+def src_orient_est_fixture(resolution, offsets, dtype):
+    """Fixture for simulation source and orientation estimation object."""
+    src = Simulation(
+        n=50,
+        L=resolution,
+        vols=AsymmetricVolume(L=resolution, C=1, K=100, seed=0).generate(),
+        offsets=offsets,
+        amplitudes=1,
+        seed=0,
+    )
+
+    # Increase max_shift and set shift_step to be sub-pixel when using
+    # random offsets in the Simulation. This improves common-line detection.
+    max_shift = 0.20
+    shift_step = 0.25
+
+    # Set max_shift to 1 pixel and shift_step to 1 pixel when using 0 offsets.
+    if np.all(src.offsets == 0.0):
+        max_shift = 1 / src.L
+        shift_step = 1
+
+    orient_est = CommonlineSDP(
+        src, max_shift=max_shift, shift_step=shift_step, mask=False
+    )
+
+    return src, orient_est
+
+
+def test_estimate_rotations(src_orient_est_fixture):
+    src, orient_est = src_orient_est_fixture
+
+    if backend_available("cufinufft") and src.dtype == np.float32:
+        pytest.skip("CI on gpu fails for singles.")
+
+    orient_est.estimate_rotations()
+
+    # Register estimates to ground truth rotations and compute the
+    # angular distance between them (in degrees).
+    # Assert that mean aligned angular distance is less than 1 degree.
+    mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1)
+
+
+def test_construct_S(src_orient_est_fixture):
+    """Test properties of the common-line quadratic form matrix S."""
+    src, orient_est = src_orient_est_fixture
+
+    # Since we are using the ground truth cl_matrix there is no need to test with offsets.
+    if src.offsets.all() != 0:
+        pytest.skip("No need to test with offsets.")
+
+    # Construct the matrix S using ground truth common-lines.
+    gt_cl_matrix = rots_to_clmatrix(src.rotations, orient_est.n_theta)
+    S = orient_est._construct_S(gt_cl_matrix)
+
+    # Check that S is symmetric.
+    np.testing.assert_allclose(S, S.T)
+
+    # For uniformly distributed rotations the top eigenvalue should have multiplicity 3.
+    # As such, we can expect that the top 3 eigenvalues will all be close in value to their mean.
+    eigs = np.linalg.eigvalsh(S)
+    eigs_mean = np.mean(eigs[:3])
+
+    # Check that the top 3 eigenvalues are all within 10% of their mean.
+    np.testing.assert_array_less(abs((eigs[:3] - eigs_mean) / eigs_mean), 0.10)
+
+    # Check that the next eigenvalue is not close to the top 3, ie. multiplicity is not greater than 3.
+    np.testing.assert_array_less(0.25, abs((eigs[4] - eigs_mean) / eigs_mean))
+
+
+def test_gram_matrix(src_orient_est_fixture):
+    """Test properties of the common-line Gram matrix."""
+    src, orient_est = src_orient_est_fixture
+
+    # Since we are using the ground truth cl_matrix there is no need to test with offsets.
+    if src.offsets.all() != 0:
+        pytest.skip("No need to test with offsets.")
+
+    # Construct a ground truth S to pass into Gram computation.
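+    # Building S from noiseless ground truth common lines isolates the
+    # Gram computation from common-line detection errors.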
+ gt_cl_matrix = rots_to_clmatrix(src.rotations, orient_est.n_theta) + S = orient_est._construct_S(gt_cl_matrix) + + # Estimate the Gram matrix + A, b = orient_est._sdp_prep() + gram = orient_est._compute_gram_matrix(S, A, b) + + # Construct the ground truth Gram matrix, G = R @ R.T, where R = [R1, R2] + # with R1 and R2 being the concatenation of the first and second columns + # of all ground truth rotation matrices, respectively. + rots = src.rotations + R1 = rots[:, :, 0] + R2 = rots[:, :, 1] + R = np.concatenate((R1, R2)) + gt_gram = R @ R.T + + # We'll check that the RMSE is within 10% of the mean value of gt_gram + rmse = np.sqrt(np.mean((gram - R @ R.T) ** 2)) + np.testing.assert_array_less(rmse / np.mean(gt_gram), 0.10) + + +def test_ATA_solver(): + # Generate some rotations. + seed = 42 + n_rots = 73 + dtype = np.float32 + rots = Rotation.generate_random_rotations(n=n_rots, seed=seed, dtype=dtype).matrices + + # Create a simple reference linear transformation A that is rank-3. + A_ref = np.diag([1, 2, 3]).astype(dtype, copy=False) + + # Create v1 and v2 such that A_ref*v1=R1 and A_ref*v2=R2, R1 and R2 are the first + # and second columns of all rotations. + R1 = rots[:, :, 0].T + R2 = rots[:, :, 1].T + v1 = np.linalg.inv(A_ref) @ R1 + v2 = np.linalg.inv(A_ref) @ R2 + + # Use ATA_solver to solve for A, given v1 and v2. + A = CommonlineSDP._ATA_solver(v1, v2) + + # Check that A is close to A_ref. + np.testing.assert_allclose(A, A_ref, atol=1e-7) + + +def test_deterministic_rounding(src_orient_est_fixture): + """Test deterministic rounding, which recovers rotations from a Gram matrix.""" + src, orient_est = src_orient_est_fixture + + # Since we are using the ground truth cl_matrix there is no need to test with offsets. + if src.offsets.all() != 0: + pytest.skip("No need to test with offsets.") + + # Construct the ground truth Gram matrix, G = R @ R.T, where R = [R1, R2] + # with R1 and R2 being the concatenation of the first and second columns + # of all ground truth rotation matrices, respectively. + gt_rots = src.rotations + R1 = gt_rots[:, :, 0] + R2 = gt_rots[:, :, 1] + R = np.concatenate((R1, R2)) + gt_gram = R @ R.T + + # Pass the Gram matrix into the deterministic rounding procedure to recover rotations. + est_rots = orient_est._deterministic_rounding(gt_gram) + + # Check that the estimated rotations are close to ground truth after global alignment. + Q_mat, flag = register_rotations(est_rots, gt_rots) + regrot = get_aligned_rotations(est_rots, Q_mat, flag) + + np.testing.assert_allclose(regrot, gt_rots) diff --git a/tests/test_orient_symmetric.py b/tests/test_orient_symmetric.py index 746213913c..b008769e72 100644 --- a/tests/test_orient_symmetric.py +++ b/tests/test_orient_symmetric.py @@ -6,10 +6,15 @@ from aspire.abinitio import CLSymmetryC2, CLSymmetryC3C4, CLSymmetryCn from aspire.abinitio.commonline_cn import MeanOuterProductEstimator from aspire.source import Simulation -from aspire.utils import Rotation, utest_tolerance -from aspire.utils.coor_trans import get_aligned_rotations, register_rotations -from aspire.utils.misc import J_conjugate, all_pairs, cyclic_rotations -from aspire.utils.random import randn +from aspire.utils import ( + J_conjugate, + Rotation, + all_pairs, + cyclic_rotations, + mean_aligned_angular_distance, + randn, + utest_tolerance, +) from aspire.volume import CnSymmetricVolume # A set of these parameters are marked expensive to reduce testing time. 
@@ -84,6 +89,7 @@ def source_orientation_objs(n_img, L, order, dtype): n_theta=360, max_shift=1 / L, seed=seed, + mask=False, ) if order in [3, 4]: @@ -116,14 +122,9 @@ def test_estimate_rotations(n_img, L, order, dtype): # g-synchronize ground truth rotations. rots_gt_sync = cl_symm.g_sync(rots_est, order, rots_gt) - # Register estimates to ground truth rotations and compute the - # angular distance between them (in degrees). - Q_mat, flag = register_rotations(rots_est, rots_gt_sync) - regrot = get_aligned_rotations(rots_est, Q_mat, flag) - mean_ang_dist = Rotation.mean_angular_distance(regrot, rots_gt_sync) * 180 / np.pi - - # Assert mean angular distance is reasonable. - assert mean_ang_dist < 3 + # Register estimates to ground truth rotations and check that the + # mean angular distance between them is less than 3 degrees. + mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=3) @pytest.mark.parametrize("n_img, L, order, dtype", param_list_c3_c4) diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py index 8b427984ef..31d6b20e94 100644 --- a/tests/test_orient_sync_voting.py +++ b/tests/test_orient_sync_voting.py @@ -8,13 +8,9 @@ from aspire.abinitio import CLOrient3D, CLSyncVoting from aspire.commands.orient3d import orient3d +from aspire.noise import WhiteNoiseAdder from aspire.source import Simulation -from aspire.utils import ( - Rotation, - get_aligned_rotations, - register_rotations, - rots_to_clmatrix, -) +from aspire.utils import mean_aligned_angular_distance, rots_to_clmatrix from aspire.volume import AsymmetricVolume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -68,7 +64,9 @@ def source_orientation_objs(resolution, offsets, dtype): if src.offsets.all() != 0: max_shift = 0.20 shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. - orient_est = CLSyncVoting(src, max_shift=max_shift, shift_step=shift_step) + orient_est = CLSyncVoting( + src, max_shift=max_shift, shift_step=shift_step, mask=False + ) return src, orient_est @@ -89,6 +87,9 @@ def test_build_clmatrix(source_orientation_objs): # Check that at least 98% of estimates are within 5 degrees. tol = 0.98 + if src.offsets.all() != 0: + # Set tolerance to 95% when using nonzero offsets. + tol = 0.95 assert within_5 / angle_diffs.size > tol @@ -98,14 +99,9 @@ def test_estimate_rotations(source_orientation_objs): orient_est.estimate_rotations() # Register estimates to ground truth rotations and compute the - # angular distance between them (in degrees). - Q_mat, flag = register_rotations(orient_est.rotations, src.rotations) - regrot = get_aligned_rotations(orient_est.rotations, Q_mat, flag) - mean_ang_dist = Rotation.mean_angular_distance(regrot, src.rotations) * 180 / np.pi - - # Assert that mean angular distance is less than 1 degree (5 degrees with shifts). - degree_tol = 1 - assert mean_ang_dist < degree_tol + # mean angular distance between them (in degrees). + # Assert that mean angular distance is less than 1 degree. 
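+    # Note, when `degree_tol` is provided the helper is assumed to apply
+    # the assertion internally, so the return value is unused here.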
+    mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1)
 
 
 def test_estimate_shifts(source_orientation_objs):
@@ -119,6 +115,45 @@
     assert np.allclose(est_shifts, src.offsets)
 
 
+def test_estimate_rotations_fuzzy_mask():
+    noisy_src = Simulation(
+        n=35,
+        vols=AsymmetricVolume(L=128, C=1, K=400, seed=0).generate(),
+        offsets=0,
+        amplitudes=1,
+        noise_adder=WhiteNoiseAdder.from_snr(snr=2),
+        seed=0,
+    )
+
+    # Orientation estimation without fuzzy_mask.
+    max_shift = 1 / noisy_src.L
+    shift_step = 1
+    orient_est = CLSyncVoting(
+        noisy_src, max_shift=max_shift, shift_step=shift_step, mask=False
+    )
+    orient_est.estimate_rotations()
+
+    # Orientation estimation with fuzzy mask.
+    orient_est_fuzzy = CLSyncVoting(
+        noisy_src, max_shift=max_shift, shift_step=shift_step
+    )
+    orient_est_fuzzy.estimate_rotations()
+
+    # Check that fuzzy_mask improves orientation estimation.
+    mean_angle_dist = mean_aligned_angular_distance(
+        orient_est.rotations, noisy_src.rotations
+    )
+    mean_angle_dist_fuzzy = mean_aligned_angular_distance(
+        orient_est_fuzzy.rotations, noisy_src.rotations
+    )
+
+    # Check that the estimate is reasonable, ie. mean_angle_dist < 10 degrees.
+    np.testing.assert_array_less(mean_angle_dist, 10)
+
+    # Check that fuzzy_mask improves the estimate.
+    np.testing.assert_array_less(mean_angle_dist_fuzzy, mean_angle_dist)
+
+
 def test_theta_error():
     """
     Test that CLSyncVoting when instantiated with odd value for `n_theta`
diff --git a/tests/test_oriented_source.py b/tests/test_oriented_source.py
index e91b36c0c7..54c34ea404 100644
--- a/tests/test_oriented_source.py
+++ b/tests/test_oriented_source.py
@@ -39,7 +39,7 @@ def src_fixture(request):
 
     # Generate an original source and an oriented source.
     og_src = Simulation(L=L, n=n, vols=vol, offsets=0)
-    orient_est = estimator(og_src, max_shift=1 / L, **estimator_kwargs)
+    orient_est = estimator(og_src, max_shift=1 / L, mask=False, **estimator_kwargs)
     oriented_src = OrientedSource(og_src, orient_est)
 
     return og_src, oriented_src
diff --git a/tests/test_relion_interop.py b/tests/test_relion_interop.py
new file mode 100644
index 0000000000..a1a2796675
--- /dev/null
+++ b/tests/test_relion_interop.py
@@ -0,0 +1,70 @@
+import os
+
+import numpy as np
+import pytest
+
+from aspire.source import RelionSource, Simulation
+from aspire.volume import Volume
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
+
+
+STARFILE = ["rln_proj_65.star", "rln_proj_64.star"]
+
+
+@pytest.fixture(params=STARFILE, scope="module")
+def sources(request):
+    starfile = os.path.join(DATA_DIR, request.param)
+    rln_src = RelionSource(starfile)
+
+    # Generate Volume used for Relion projections.
+    # Note, `downsample` is a no-op for resolution 65.
+    vol_path = os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")
+    vol = Volume(np.load(vol_path), dtype=rln_src.dtype).downsample(rln_src.L)
+
+    # Create Simulation source using Volume and angles from Relion projections.
+    # Note, for odd resolution Relion projections are shifted by 1 pixel in x and y.
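+    # The offsets below compensate for that one pixel shift when L is odd.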
+ offsets = 0 + if rln_src.L % 2 == 1: + offsets = -np.ones((rln_src.n, 2), dtype=rln_src.dtype) + + sim_src = Simulation( + n=rln_src.n, + vols=vol, + offsets=offsets, + amplitudes=1, + angles=rln_src.angles, + dtype=rln_src.dtype, + ) + return rln_src, sim_src + + +def test_projections_relative_error(sources): + """Check the relative error between Relion and ASPIRE projection images.""" + rln_src, sim_src = sources + + # Work with numpy arrays. + rln_np = rln_src.images[:].asnumpy() + sim_np = sim_src.images[:].asnumpy() + + # Normalize images. + rln_np = (rln_np - np.mean(rln_np)) / np.std(rln_np) + sim_np = (sim_np - np.mean(sim_np)) / np.std(sim_np) + + # Check that relative error is less than 3%. + error = np.linalg.norm(rln_np - sim_np, axis=(1, 2)) / np.linalg.norm( + rln_np, axis=(1, 2) + ) + np.testing.assert_array_less(error, 0.03) + + +def test_projections_frc(sources): + """Compute the FRC between Relion and ASPIRE projection images.""" + rln_src, sim_src = sources + + # Compute the Fourier Ring Correlation. + res, corr = rln_src.images[:].frc(sim_src.images[:], cutoff=0.143) + + # Check that estimated resolution is high (< 2.5 pixels) and correlation is close to 1. + np.testing.assert_array_less(res, 2.5) + np.testing.assert_array_less(1 - corr[:, -2], 0.02) diff --git a/tests/test_rotation.py b/tests/test_rotation.py index caffb17f0c..9e0dba4ec6 100644 --- a/tests/test_rotation.py +++ b/tests/test_rotation.py @@ -1,5 +1,4 @@ import logging -from unittest import TestCase import numpy as np import pytest @@ -10,122 +9,157 @@ logger = logging.getLogger(__name__) -class UtilsTestCase(TestCase): - def setUp(self): - self.dtype = np.float32 - self.num_rots = 32 - self.rot_obj = Rotation.generate_random_rotations( - self.num_rots, seed=0, dtype=self.dtype - ) - self.angles = self.rot_obj.angles - self.matrices = self.rot_obj.matrices - - def testRotMatrices(self): - rot_ref = sp_rot.from_matrix(self.matrices) - matrices = rot_ref.as_matrix().astype(self.dtype) - self.assertTrue( - np.allclose(self.matrices, matrices, atol=utest_tolerance(self.dtype)) +# Parameters + +NUM_ROTS = 32 +SEED = 0 + +DTYPES = [ + np.float32, + np.float64, +] + + +# Fixtures + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def rot_obj(dtype): + return Rotation.generate_random_rotations(NUM_ROTS, seed=SEED, dtype=dtype) + + +# Rotation Class Tests + + +def test_matrices(rot_obj): + rot_ref = sp_rot.from_matrix(rot_obj.matrices) + matrices = rot_ref.as_matrix() + np.testing.assert_allclose( + rot_obj.matrices, matrices, atol=utest_tolerance(rot_obj.dtype) + ) + + +def test_as_angles(rot_obj): + rot_ref = sp_rot.from_euler("ZYZ", rot_obj.angles, degrees=False) + angles = rot_ref.as_euler("ZYZ", degrees=False) + np.testing.assert_allclose(rot_obj.angles, angles) + + +def test_from_matrix(rot_obj): + rot_ref = sp_rot.from_matrix(rot_obj.matrices) + angles = rot_ref.as_euler("ZYZ", degrees=False) + rot = Rotation.from_matrix(rot_obj.matrices) + np.testing.assert_allclose(rot.angles, angles) + + +def test_from_euler(rot_obj): + rot_ref = sp_rot.from_euler("ZYZ", rot_obj.angles, degrees=False) + matrices = rot_ref.as_matrix() + rot = Rotation.from_euler(rot_obj.angles, dtype=rot_obj.dtype) + np.testing.assert_allclose(rot._matrices, matrices) + + +def test_invert(rot_obj): + rot_mat = rot_obj.matrices + rot_mat_t = rot_obj.invert() + np.testing.assert_allclose(rot_mat_t, np.transpose(rot_mat, (0, 2, 
1))) + + +def test_multiplication(rot_obj): + result = (rot_obj * rot_obj.invert()).matrices + for i in range(len(rot_obj)): + np.testing.assert_allclose( + np.eye(3), result[i], atol=utest_tolerance(rot_obj.dtype) ) - def testRotAngles(self): - rot_ref = sp_rot.from_euler("ZYZ", self.angles, degrees=False) - angles = rot_ref.as_euler("ZYZ", degrees=False).astype(self.dtype) - self.assertTrue(np.allclose(self.angles, angles)) - - def testFromMatrix(self): - rot_ref = sp_rot.from_matrix(self.matrices) - angles = rot_ref.as_euler("ZYZ", degrees=False).astype(self.dtype) - rot = Rotation.from_matrix(self.matrices, dtype=self.dtype) - self.assertTrue(np.allclose(rot.angles, angles)) - - def testFromEuler(self): - rot_ref = sp_rot.from_euler("ZYZ", self.angles, degrees=False) - matrices = rot_ref.as_matrix().astype(self.dtype) - rot = Rotation.from_euler(self.angles, dtype=self.dtype) - self.assertTrue(np.allclose(rot._matrices, matrices)) - - def testInvert(self): - rot_mat = self.rot_obj.matrices - rot_mat_t = self.rot_obj.invert() - self.assertTrue(np.allclose(rot_mat_t, np.transpose(rot_mat, (0, 2, 1)))) - - def testMultiplication(self): - result = (self.rot_obj * self.rot_obj.invert()).matrices - for i in range(len(self.rot_obj)): - self.assertTrue( - np.allclose(np.eye(3), result[i], atol=utest_tolerance(self.dtype)) - ) - - def testRegisterRots(self): - q_mat = Rotation.generate_random_rotations(1, dtype=self.dtype)[0] - for flag in [0, 1]: - regrots_ref = self.rot_obj.apply_registration(q_mat, flag) - q_mat_est, flag_est = self.rot_obj.find_registration(regrots_ref) - self.assertTrue( - np.allclose(flag_est, flag) - and np.allclose(q_mat_est, q_mat, atol=utest_tolerance(self.dtype)) - ) - - def testRegister(self): - # These will yield two more distinct sets of random rotations wrt self.rot_obj - set1 = Rotation.generate_random_rotations(self.num_rots, dtype=self.dtype) - set2 = Rotation.generate_random_rotations( - self.num_rots, dtype=self.dtype, seed=7 + +def test_register_rots(rot_obj): + q_mat = Rotation.generate_random_rotations(1, dtype=rot_obj.dtype)[0] + for flag in [0, 1]: + regrots_ref = rot_obj.apply_registration(q_mat, flag) + q_mat_est, flag_est = rot_obj.find_registration(regrots_ref) + np.testing.assert_allclose(flag_est, flag) + np.testing.assert_allclose( + q_mat_est, q_mat, atol=utest_tolerance(rot_obj.dtype) ) - # Align both sets of random rotations to rot_obj - aligned_rots1 = self.rot_obj.register(set1) - aligned_rots2 = self.rot_obj.register(set2) - self.assertTrue(aligned_rots1.mse(aligned_rots2) < utest_tolerance(self.dtype)) - self.assertTrue(aligned_rots2.mse(aligned_rots1) < utest_tolerance(self.dtype)) - - def testMSE(self): - q_ang = [np.random.random(3)] - q_mat = sp_rot.from_euler("ZYZ", q_ang, degrees=False).as_matrix()[0] - for flag in [0, 1]: - regrots_ref = self.rot_obj.apply_registration(q_mat, flag) - mse = self.rot_obj.mse(regrots_ref) - self.assertTrue(mse < utest_tolerance(self.dtype)) - - def testCommonLines(self): - ell_ij, ell_ji = self.rot_obj.common_lines(8, 11, 360) - self.assertTrue(ell_ij == 235 and ell_ji == 284) - - def testString(self): - logger.debug(str(self.rot_obj)) - - def testRepr(self): - logger.debug(repr(self.rot_obj)) - - def testLen(self): - self.assertTrue(len(self.rot_obj) == self.num_rots) - - def testSetterGetter(self): - # Excute set - tmp = np.arange(9).reshape((3, 3)) - self.rot_obj[13] = tmp - # Execute get - self.assertTrue(np.all(self.rot_obj[13] == tmp)) - - def testDtype(self): - self.assertTrue(self.dtype == 
self.rot_obj.dtype)
-
-    def testFromRotvec(self):
-        # Build random rotation vectors.
-        axis = np.array([1, 0, 0], dtype=self.dtype)
-        angles = np.random.uniform(0, 2 * np.pi, 10)
-        rot_vecs = np.array([angle * axis for angle in angles], dtype=self.dtype)
-
-        # Build rotations using from_rotvec and about_axis (as reference).
-        rotations = Rotation.from_rotvec(rot_vecs, dtype=self.dtype)
-        ref_rots = Rotation.about_axis("x", angles, dtype=self.dtype)
-
-        self.assertTrue(isinstance(rotations, Rotation))
-        self.assertTrue(rotations.matrices.dtype == self.dtype)
-        self.assertTrue(np.allclose(rotations.matrices, ref_rots.matrices))
-
-
-def test_angle_dist():
-    dtype = np.float32
+
+
+def test_register(rot_obj):
+    # These will yield two more distinct sets of random rotations wrt rot_obj
+    set1 = Rotation.generate_random_rotations(NUM_ROTS, dtype=rot_obj.dtype)
+    set2 = Rotation.generate_random_rotations(
+        NUM_ROTS, dtype=rot_obj.dtype, seed=SEED + 7
+    )
+    # Align both sets of random rotations to rot_obj
+    aligned_rots1 = rot_obj.register(set1)
+    aligned_rots2 = rot_obj.register(set2)
+    tol = utest_tolerance(rot_obj.dtype)
+    np.testing.assert_array_less(aligned_rots1.mse(aligned_rots2), tol)
+    np.testing.assert_array_less(aligned_rots2.mse(aligned_rots1), tol)
+
+
+def test_mse(rot_obj):
+    q_ang = [np.random.random(3)]
+    q_mat = sp_rot.from_euler("ZYZ", q_ang, degrees=False).as_matrix()[0]
+    for flag in [0, 1]:
+        regrots_ref = rot_obj.apply_registration(q_mat, flag)
+        mse = rot_obj.mse(regrots_ref)
+        np.testing.assert_array_less(mse, utest_tolerance(rot_obj.dtype))
+
+
+def test_common_lines(rot_obj):
+    ell_ij, ell_ji = rot_obj.common_lines(8, 11, 360)
+    np.testing.assert_equal([ell_ij, ell_ji], [235, 284])
+
+
+def test_string(rot_obj):
+    logger.debug(str(rot_obj))
+
+
+def test_repr(rot_obj):
+    logger.debug(repr(rot_obj))
+
+
+def test_len(rot_obj):
+    assert len(rot_obj) == NUM_ROTS
+
+
+def test_setter_getter(rot_obj):
+    # Execute set
+    tmp = np.arange(9).reshape((3, 3))
+    rot_obj[13] = tmp
+    # Execute get
+    np.testing.assert_equal(rot_obj[13], tmp)
+
+
+def test_dtype(dtype, rot_obj):
+    assert dtype == rot_obj.dtype
+
+
+def test_from_rotvec(rot_obj):
+    # Build random rotation vectors.
+    axis = np.array([1, 0, 0], dtype=rot_obj.dtype)
+    angles = np.random.uniform(0, 2 * np.pi, 10)
+    rot_vecs = np.array([angle * axis for angle in angles], dtype=rot_obj.dtype)
+
+    # Build rotations using from_rotvec and about_axis (as reference).
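+    # Rotation vectors along the x-axis should agree with `about_axis("x", ...)`.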
+ rotations = Rotation.from_rotvec(rot_vecs, dtype=rot_obj.dtype) + ref_rots = Rotation.about_axis("x", angles, dtype=rot_obj.dtype) + + assert isinstance(rotations, Rotation) + assert rotations.matrices.dtype == rot_obj.dtype + np.testing.assert_allclose(rotations.matrices, ref_rots.matrices) + + +# Angular Distance Tests + + +def test_angle_dist(dtype): angles = np.array([i * np.pi / 360 for i in range(360)], dtype=dtype) rots = Rotation.about_axis("x", angles, dtype=dtype) @@ -140,11 +174,9 @@ def test_angle_dist(): _ = Rotation.angle_dist(rots[:3], rots[:5]) -def test_mean_angular_distance(): - rots_z = Rotation.about_axis( - "z", [0, np.pi / 4, np.pi / 2], dtype=np.float32 - ).matrices - rots_id = Rotation.about_axis("z", [0, 0, 0], dtype=np.float32).matrices +def test_mean_angular_distance(dtype): + rots_z = Rotation.about_axis("z", [0, np.pi / 4, np.pi / 2], dtype=dtype).matrices + rots_id = Rotation.about_axis("z", [0, 0, 0], dtype=dtype).matrices mean_ang_dist = Rotation.mean_angular_distance(rots_z, rots_id) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 11c7779245..659ce95603 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -8,9 +8,8 @@ from aspire.noise import WhiteNoiseAdder from aspire.operators import RadialCTFFilter -from aspire.source.relion import RelionSource -from aspire.source.simulation import Simulation -from aspire.utils.types import utest_tolerance +from aspire.source import RelionSource, Simulation, _LegacySimulation +from aspire.utils import utest_tolerance from aspire.volume import LegacyVolume, SymmetryGroup, Volume from .test_utils import matplotlib_dry_run @@ -113,7 +112,7 @@ def setUp(self): dtype=self.dtype, ).generate() - self.sim = Simulation( + self.sim = _LegacySimulation( n=self.n, L=self.L, vols=self.vols, @@ -135,7 +134,7 @@ def testGaussianBlob(self): def testSimulationRots(self): self.assertTrue( np.allclose( - self.sim.rotations[0, :, :], + self.sim.rots_zyx_to_legacy_aspire(self.sim.rotations[0, :, :]), np.array( [ [0.91675498, 0.2587233, 0.30433956], @@ -143,6 +142,7 @@ def testSimulationRots(self): [-0.00507853, 0.76938412, -0.63876622], ] ), + atol=utest_tolerance(self.dtype), ) ) @@ -158,7 +158,7 @@ def testSimulationImages(self): ) def testSimulationCached(self): - sim_cached = Simulation( + sim_cached = _LegacySimulation( n=self.n, L=self.L, vols=self.vols, @@ -515,7 +515,7 @@ def testSimulationEvalCoords(self): self.assertTrue( np.allclose( - result["err"][:10], + result["err"][0, :10], [ 1.58382394, 1.58382394, diff --git a/tests/test_steerable_bases_2d.py b/tests/test_steerable_bases_2d.py new file mode 100644 index 0000000000..9777c3d2ed --- /dev/null +++ b/tests/test_steerable_bases_2d.py @@ -0,0 +1,119 @@ +import logging + +import numpy as np +import PIL.Image as PILImage +import pytest + +from aspire.basis import FBBasis2D, FFBBasis2D, FLEBasis2D, FPSWFBasis2D, PSWFBasis2D +from aspire.image import Image +from aspire.utils import gaussian_2d + +logger = logging.getLogger(__name__) + + +# Parameters + +DTYPES = [ + np.float32, + pytest.param(np.float64, marks=pytest.mark.expensive), +] + +BASES = [ + FFBBasis2D, + FBBasis2D, + FLEBasis2D, + PSWFBasis2D, + FPSWFBasis2D, +] + +IMG_SIZES = [ + 31, + pytest.param(32, marks=pytest.mark.expensive), +] + +# Fixtures + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"img_size={x}", scope="module") +def img_size(request): 
+    return request.param
+
+
+@pytest.fixture(params=BASES, ids=lambda x: f"basis={x}", scope="module")
+def basis(request, img_size, dtype):
+    cls = request.param
+    # Set up a Basis
+    basis = cls(img_size, dtype=dtype)
+    return basis
+
+
+# Basis Rotations
+
+
+def test_basis_rotation_2d(basis):
+    """
+    Test that steerable basis rotation performs a similar operation to PIL real-space image rotation.
+
+    Checks both orientation and rough values.
+    """
+    # Set a rotation amount
+    rot_radians = np.pi / 6
+
+    # Create an Image containing a smooth blob.
+    L = basis.nres
+    img = gaussian_2d(L, mu=(L // 4, 0), dtype=basis.dtype)
+    img = Image(img / np.linalg.norm(img))  # Normalize
+
+    # Rotate with an ASPIRE steerable basis, returning to real space.
+    rot_img = basis.expand(img).rotate(rot_radians).evaluate()
+
+    # Rotate image with PIL, returning to Numpy array.
+    pil_rot_img = np.asarray(
+        PILImage.fromarray(img.asnumpy()[0]).rotate(
+            rot_radians * 180 / np.pi, resample=PILImage.BICUBIC
+        )
+    )
+
+    # Rough compare arrays.
+    np.testing.assert_allclose(rot_img.asnumpy()[0], pil_rot_img, atol=0.15)
+
+
+def test_basis_reflection_2d(basis):
+    """
+    Test that steerable basis reflection performs a similar operation to Numpy flips.
+
+    Checks both orientation and rough values.
+    """
+
+    # Create an Image containing a smooth blob.
+    L = basis.nres
+    img = gaussian_2d(L, mu=(L // 4, L // 5), dtype=basis.dtype)
+    img = Image(img / np.linalg.norm(img))  # Normalize
+
+    # Reflect with an ASPIRE steerable basis, returning to real space.
+    refl_img = basis.expand(img).rotate(0, refl=True).evaluate()
+
+    # Reflect image with Numpy.
+    # Note, for odd images we can accurately use Numpy,
+    # but even images have the expected offset issue
+    # when compared to a purely row/col based flip.
+    flip = np.flipud
+    if isinstance(basis, PSWFBasis2D):
+        # TODO, reconcile PSWF reflection axis
+        flip = np.fliplr
+
+    refl_img_np = flip(img.asnumpy()[0])
+
+    # Rough compare arrays.
+    atol = 0.01
+    if L % 2 == 0:
+        # The even-image test is crude,
+        # but it is enough to ensure flipping without complicating the test.
+        atol = 0.5
+
+    np.testing.assert_allclose(refl_img.asnumpy()[0], refl_img_np, atol=atol)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3b9c26fa44..ffad5bc9f6 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -350,8 +350,24 @@ def test_fuzzy_mask():
             ],
         ]
     )
-    fmask = fuzzy_mask((8, 8), 2, 2)
-    assert np.allclose(results, fmask, atol=1e-7)
+    fmask = fuzzy_mask((8, 8), results.dtype, r0=2, risetime=2)
+    np.testing.assert_allclose(results, fmask, atol=1e-7)
+
+    # Smoke test for 1D, 2D, and 3D fuzzy_mask.
+    for dim in range(1, 4):
+        _ = fuzzy_mask((32,) * dim, np.float32)
+
+    # Check that we raise an error for bad dimension.
+    with pytest.raises(RuntimeError, match=r"Only 1D, 2D, or 3D fuzzy_mask*"):
+        _ = fuzzy_mask((8,) * 4, np.float32)
+
+    # Check we raise for bad 2D shape.
+    with pytest.raises(ValueError, match=r"A 2D fuzzy_mask must be square*"):
+        _ = fuzzy_mask((2, 3), np.float32)
+
+    # Check we raise for bad 3D shape.
+    with pytest.raises(ValueError, match=r"A 3D fuzzy_mask must be cubic*"):
+        _ = fuzzy_mask((2, 3, 3), np.float32)
 
 
 def test_multiprocessing_utils():
@@ -379,7 +395,7 @@ def matplotlib_no_gui():
 
     # Save and restore current warnings list.
with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Matplotlib is currently using agg") + warnings.filterwarnings("ignore", r"Matplotlib is currently using agg.*") yield diff --git a/tests/test_volume.py b/tests/test_volume.py index aa3eaf336e..9d039c4c9f 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -8,9 +8,8 @@ from numpy import pi from pytest import raises, skip -from aspire.utils import Rotation, grid_2d, powerset -from aspire.utils.matrix import anorm -from aspire.utils.types import utest_tolerance +from aspire.source import _LegacySimulation +from aspire.utils import Rotation, anorm, grid_2d, powerset, utest_tolerance from aspire.volume import ( AsymmetricVolume, CnSymmetryGroup, @@ -31,7 +30,7 @@ def res_id(params): RES = [42, 43] -@pytest.fixture(params=RES, ids=res_id) +@pytest.fixture(params=RES, ids=res_id, scope="module") def res(request): return request.param @@ -43,7 +42,7 @@ def dtype_id(params): DTYPES = [np.float32, np.float64] -@pytest.fixture(params=DTYPES, ids=dtype_id) +@pytest.fixture(params=DTYPES, ids=dtype_id, scope="module") def dtype(request): return request.param @@ -82,6 +81,35 @@ def vols_12(data_12): return Volume(data_12) +@pytest.fixture +def asym_vols(res, dtype): + vols = AsymmetricVolume(L=res, C=N, dtype=dtype, seed=0).generate() + return vols + + +@pytest.fixture(scope="module") +def vols_hot_cold(res, dtype): + L = res + n_vols = 5 + + # Generate random locations for hot/cold spots, each at a distance of approximately + # L // 4 from (0, 0, 0). Note, these points are considered to be in (z, y, x) order. + hot_cold_locs = np.random.uniform(low=-1, high=1, size=(n_vols, 2, 3)) + hot_cold_locs = np.round( + (hot_cold_locs / np.linalg.norm(hot_cold_locs, axis=-1)[:, :, None]) * (L // 4) + ).astype("int") + + # Generate Volumes, each with one hot and one cold spot. + vols = np.zeros((n_vols, L, L, L), dtype=dtype) + vol_center = np.array((L // 2, L // 2, L // 2), dtype="int") + for i in range(n_vols): + vols[i][tuple(vol_center + hot_cold_locs[i, 0])] = 1 + vols[i][tuple(vol_center + hot_cold_locs[i, 1])] = -1 + vols = Volume(vols) + + return vols, hot_cold_locs, vol_center + + @pytest.fixture def random_data(res, dtype): return np.random.randn(res, res, res).astype(dtype) @@ -256,7 +284,43 @@ def test_save_load(vols_1): assert np.allclose(vols_1, vols_loaded_double) -def test_project(vols_1, dtype): +def test_project(vols_hot_cold): + """ + We project Volumes containing random hot/cold spots using random rotations and check that + hot/cold spots in the projections are in the expected locations. + """ + vols, hot_cold_locs, vol_center = vols_hot_cold + dtype = vols.dtype + L = vols.resolution + + # Generate random rotations. + rots = Rotation.generate_random_rotations(n=vols.n_vols, dtype=dtype) + + # To find the expected location of hot/cold spots in the projections we rotate the 3D + # vector of locations by the transpose, ie. rots.invert(), (since our projections are + # produced by rotating the underlying grid) and then project along the z-axis. + + # Expected location of hot/cold spots relative to (0, 0, 0) origin in (x, y, z) order. + # Note, we write the simpler `(x, y, z) @ rots` in place of `(rots.T @ (x, y, z).T).T` + expected_hot_cold = hot_cold_locs[..., ::-1] @ rots.matrices + + # Expected location of hot/cold spots relative to center (L/2, L/2, L/2) in (z, y, x) order. + # Then projected along z-axis by dropping the z component. 
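+    # After reversing to (z, y, x), `[..., 1:]` keeps only the (y, x) components.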
+    expected_locs = np.round(expected_hot_cold[..., ::-1] + vol_center)[..., 1:]
+
+    # Generate projection images.
+    projections = vols.project(rots)
+
+    # Check that new hot/cold spots are within 1 pixel of expected locations.
+    for i in range(vols.n_vols):
+        p = projections.asnumpy()[i]
+        new_hot_loc = np.unravel_index(np.argmax(p), (L, L))
+        new_cold_loc = np.unravel_index(np.argmin(p), (L, L))
+        np.testing.assert_allclose(new_hot_loc, expected_locs[i, 0], atol=1)
+        np.testing.assert_allclose(new_cold_loc, expected_locs[i, 1], atol=1)
+
+
+def test_project_axes(vols_1, dtype):
     L = vols_1.resolution
     # first test with synthetic data
     # Create a stack of rotations to test.
@@ -294,37 +358,40 @@
     vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")))
     rots = np.load(os.path.join(DATA_DIR, "rand_rot_matrices32.npy"))
     rots = np.moveaxis(rots, 2, 0)
+
+    # Note, transforming rotations to compensate for legacy grid convention used in saved data.
+    rots = _LegacySimulation.rots_zyx_to_legacy_aspire(rots)
+
     imgs_clean = vols.project(rots).asnumpy()
     assert np.allclose(results, imgs_clean, atol=1e-7)
 
 
-# Parameterize over even and odd resolutions
-@pytest.mark.parametrize("L", RES)
-def test_rotate(L, dtype):
+def test_rotate_axes(res, dtype):
     # In this test we instantiate Volume instance `vol`, containing a single nonzero
     # voxel in the first octant, and rotate it by multiples of pi/2 about each axis.
     # We then compare to reference volumes containing appropriately located nonzero voxel.
 
     # Create a Volume instance to rotate.
     # This volume has a value of 1 in the first octant at (1, 1, 1) and zeros elsewhere.
+    L = res
     data = np.zeros((L, L, L), dtype=dtype)
     data[L // 2 + 1, L // 2 + 1, L // 2 + 1] = 1
     vol = Volume(data)
 
-    # Create a dict with map from axis and angle of rotation to new location of nonzero voxel.
+    # Create a dict with map from axis and angle of rotation to new location (z, y, x) of nonzero voxel.
     ref_pts = {
         ("x", 0): (1, 1, 1),
-        ("x", pi / 2): (1, 1, -1),
-        ("x", pi): (1, -1, -1),
-        ("x", 3 * pi / 2): (1, -1, 1),
+        ("x", pi / 2): (1, -1, 1),
+        ("x", pi): (-1, -1, 1),
+        ("x", 3 * pi / 2): (-1, 1, 1),
         ("y", 0): (1, 1, 1),
         ("y", pi / 2): (-1, 1, 1),
         ("y", pi): (-1, 1, -1),
         ("y", 3 * pi / 2): (1, 1, -1),
         ("z", 0): (1, 1, 1),
-        ("z", pi / 2): (1, -1, 1),
-        ("z", pi): (-1, -1, 1),
-        ("z", 3 * pi / 2): (-1, 1, 1),
+        ("z", pi / 2): (1, 1, -1),
+        ("z", pi): (1, -1, -1),
+        ("z", 3 * pi / 2): (1, -1, 1),
     }
 
     center = np.array([L // 2] * 3)
@@ -349,23 +416,52 @@
     assert np.allclose(ref_vol, rot_vol, atol=utest_tolerance(dtype))
 
 
-def test_rotate_broadcast_unicast(vols_1, dtype):
+def test_rotate(vols_hot_cold):
+    """
+    We rotate Volumes containing random hot/cold spots by random rotations and check that
+    hot/cold spots in the rotated Volumes are in the expected locations.
+    """
+    vols, hot_cold_locs, vol_center = vols_hot_cold
+    dtype = vols.dtype
+    L = vols.resolution
+
+    # Generate random rotations.
+    rots = Rotation.generate_random_rotations(n=vols.n_vols, dtype=dtype)
+
+    # Expected location of hot/cold spots relative to (0, 0, 0) origin in (x, y, z) order.
+    # Note, we write the simpler `(x, y, z) @ rots.T` in place of `(rots @ (x, y, z).T).T`
+    expected_hot_cold = hot_cold_locs[..., ::-1] @ rots.invert().matrices
+
+    # Expected location of hot/cold spots relative to Volume center (L/2, L/2, L/2) in (z, y, x) order.
+    expected_locs = np.round(expected_hot_cold[..., ::-1] + vol_center)
+
+    # Rotate Volumes.
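+    # One rotation per volume (unicast), matching `hot_cold_locs` by index.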
+    rotated_vols = vols.rotate(rots)
+
+    # Check that new hot/cold spots are within 1 pixel of expected locations.
+    for i in range(vols.n_vols):
+        v = rotated_vols.asnumpy()[i]
+        new_hot_loc = np.unravel_index(np.argmax(v), (L, L, L))
+        new_cold_loc = np.unravel_index(np.argmin(v), (L, L, L))
+        np.testing.assert_allclose(new_hot_loc, expected_locs[i, 0], atol=1)
+        np.testing.assert_allclose(new_cold_loc, expected_locs[i, 1], atol=1)
+
+
+def test_rotate_broadcast_unicast(asym_vols):
     # Build `Rotation` objects. A singleton for broadcasting and a stack for unicasting.
     # The stack consists of copies of the singleton.
-    angles = np.array([pi, pi / 2, 0], dtype=dtype)
-    angles = np.tile(angles, (3, 1))
-    rot_mat = Rotation.from_euler(angles, dtype=dtype).matrices
-    rot = Rotation(rot_mat[0])
-    rots = Rotation(rot_mat)
+    dtype = asym_vols.dtype
+    rot = Rotation.generate_random_rotations(n=1, seed=1234, dtype=dtype)
+    rots = Rotation(np.broadcast_to(rot.matrices, (asym_vols.n_vols, 3, 3)))
 
     # Broadcast the singleton `Rotation` across the `Volume` stack.
-    vols_broadcast = vols_1.rotate(rot)
+    vols_broadcast = asym_vols.rotate(rot)
 
     # Unicast the `Rotation` stack across the `Volume` stack.
-    vols_unicast = vols_1.rotate(rots)
+    vols_unicast = asym_vols.rotate(rots)
 
-    for i in range(N):
-        assert np.allclose(vols_broadcast[i], vols_unicast[i])
+    # Tests that all volumes match.
+    assert np.allclose(vols_broadcast, vols_unicast, atol=utest_tolerance(dtype))
 
 
 def to_vec(vols_1, vec):
diff --git a/tests/test_weighted_mean_estimator.py b/tests/test_weighted_mean_estimator.py
index e7be5cf77a..4d185ce142 100644
--- a/tests/test_weighted_mean_estimator.py
+++ b/tests/test_weighted_mean_estimator.py
@@ -7,7 +7,7 @@
 from aspire.basis import FBBasis3D
 from aspire.operators import RadialCTFFilter
 from aspire.reconstruction import WeightedVolumesEstimator
-from aspire.source.simulation import Simulation
+from aspire.source import _LegacySimulation
 from aspire.volume import LegacyVolume
 
 logger = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ def setUp(self):
         self.n = 1024
         self.r = 2
         self.L = L = 8
-        self.sim = Simulation(
+        self.sim = _LegacySimulation(
             vols=LegacyVolume(L, dtype=self.dtype).generate(),
             n=self.n,
             unique_filters=[
@@ -135,10 +135,10 @@ def testPositiveWeightedEstimates(self):
         self.assertTrue(np.allclose(a, b, atol=1e-5))
 
     def testAdjoint(self):
-        mean_b_coeff = self.estimator.src_backward().squeeze()
+        mean_b_coef = self.estimator.src_backward().squeeze()
         self.assertTrue(
             np.allclose(
-                mean_b_coeff,
+                mean_b_coef,
                 [
                     1.07338590e-01,
                     1.23690941e-01,
@@ -244,7 +244,7 @@ def testAdjoint(self):
         )
 
     def testOptimize1(self):
-        mean_b_coeff = np.array(
+        mean_b_coef = np.array(
             [
                 [
                     1.07338590e-01,
@@ -351,7 +351,7 @@ def testOptimize1(self):
             ]
         )
 
        # Given equal weighting we should get the same result for all self.r volumes.
- x = self.estimator.conj_grad(mean_b_coeff) + x = self.estimator.conj_grad(mean_b_coef) ref = np.array( [ @@ -461,7 +461,7 @@ def testOptimize1(self): self.assertTrue(np.allclose(x.flatten(), ref, atol=1e-4)) def testOptimize2(self): - mean_b_coeff = np.array( + mean_b_coef = np.array( [ [ 1.07338590e-01, @@ -567,7 +567,7 @@ def testOptimize2(self): * self.r ) - x = self.estimator_with_preconditioner.conj_grad(mean_b_coeff) + x = self.estimator_with_preconditioner.conj_grad(mean_b_coef) self.assertTrue( np.allclose( x, diff --git a/tox.ini b/tox.ini index cd98bab956..3970e6e317 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,6 @@ minversion = 3.8.0 [testenv] changedir = tests deps = - parameterized pooch pytest pytest-cov @@ -56,6 +55,7 @@ commands = flake8 . isort --check-only --diff . black --check --diff . + python docs/check_docstrings.py src/aspire python -m json.tool .zenodo.json /dev/null check-manifest . python -m build @@ -69,9 +69,13 @@ per-file-ignores = __init__.py: F401 gallery/tutorials/aspire_introduction.py: T201, F401, E402 gallery/tutorials/configuration.py: T201, E402 + gallery/tutorials/pipeline_demo.py: T201 gallery/tutorials/turorials/data_downloader.py: E402 gallery/tutorials/tutorials/ctf.py: T201, E402 + gallery/tutorials/tutorials/image_class.py: T201 gallery/tutorials/tutorials/micrograph_source.py: T201, E402 + gallery/tutorials/tutorials/weighted_volume_estimation.py: T201, E402 + gallery/tutorials/tutorials/relion_projection_interop.py: T201 # Ignore Sphinx gallery builds docs/build/html/_downloads/*/*.py: T201, E402, F401, E265 docs/source/auto*/*.py: T201, E402, F401, E265