diff --git a/.bumpversion.cfg b/.bumpversion.cfg index f1257cb71a..af2e0c8e7c 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.12.3 +current_version = 0.13.0 commit = True tag = True diff --git a/.github/workflows/long_workflow.yml b/.github/workflows/long_workflow.yml index bf676b45fb..ec5714a29e 100644 --- a/.github/workflows/long_workflow.yml +++ b/.github/workflows/long_workflow.yml @@ -1,14 +1,14 @@ name: ASPIRE Python Long Running Test Suite on: - push: - branches: - - 'main' - - 'develop' + pull_request: + types: [opened, synchronize, reopened, ready_for_review] jobs: expensive_tests: runs-on: self-hosted + # Only run on review ready pull_requests + if: ${{ github.event_name == 'pull_request' && github.event.pull_request.draft == false }} timeout-minutes: 360 steps: - uses: actions/checkout@v4 @@ -33,8 +33,9 @@ jobs: cat ${WORK_DIR}/config.yaml - name: Run run: | + export OMP_NUM_THREADS=1 ASPIREDIR=${{ env.WORK_DIR }} python -c \ "import aspire; print(aspire.config['ray']['temp_dir'])" - ASPIREDIR=${{ env.WORK_DIR }} python -m pytest -m "expensive" --durations=0 + ASPIREDIR=${{ env.WORK_DIR }} python -m pytest -n8 -m "expensive" --durations=0 - name: Cleanup run: rm -rf ${{ env.WORK_DIR }} diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index f2d5472e52..33a974b214 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -55,6 +55,8 @@ jobs: run: tox --skip-missing-interpreters false -e py${{ matrix.python-version }}-${{ matrix.pyenv }} - name: Upload Coverage to CodeCov uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} conda-build: needs: check @@ -147,14 +149,17 @@ jobs: echo "Stash the WORK_DIR to GitHub env so we can clean it up later." 
echo "WORK_DIR=${WORK_DIR}" >> $GITHUB_ENV echo -e "ray:\n temp_dir: ${WORK_DIR}\n" > ${WORK_DIR}/config.yaml - echo -e "common:\n cache_dir: ${CI_CACHE_DIR}\n" >> ${WORK_DIR}/config.yaml + echo -e "common:" >> ${WORK_DIR}/config.yaml + echo -e " cache_dir: ${CI_CACHE_DIR}" >> ${WORK_DIR}/config.yaml + echo -e " numeric: cupy" >> ${WORK_DIR}/config.yaml + echo -e " fft: cupy\n" >> ${WORK_DIR}/config.yaml echo "Log the config: ${WORK_DIR}/config.yaml" cat ${WORK_DIR}/config.yaml - name: Run run: | ASPIREDIR=${{ env.WORK_DIR }} python -c \ "import aspire; print(aspire.config['ray']['temp_dir'])" - ASPIREDIR=${{ env.WORK_DIR }} python -m pytest --durations=50 + ASPIREDIR=${{ env.WORK_DIR }} PYTHONWARNINGS=error python -m pytest --durations=50 - name: Cache Data run: | ASPIREDIR=${{ env.WORK_DIR }} python -c \ @@ -219,29 +224,18 @@ jobs: retention-days: 7 osx_arm: - defaults: - run: - shell: bash -l {0} needs: check runs-on: macos-14 # Run on every code push, but only on review ready PRs if: ${{ github.event_name == 'push' || github.event.pull_request.draft == false }} steps: - uses: actions/checkout@v4 - - name: Set up Conda - uses: conda-incubator/setup-miniconda@v2.3.0 + - uses: actions/setup-python@v5 with: - miniconda-version: "latest" - auto-update-conda: true - python-version: '3.8' - activate-environment: aspire - environment-file: environment-accelerate.yml - auto-activate-base: false - - name: Complete Install and Log Environment ${{ matrix.os }} Python ${{ matrix.python-version }} + python-version: '3.11' + - name: Complete Install and Log Environment run: | - conda info - conda list - conda install pyshtools # debug depends issues + python --version pip install -e ".[dev]" # install aspire pip freeze - name: Test diff --git a/MANIFEST.in b/MANIFEST.in index 4477aa87c0..ecc7484b40 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -17,6 +17,7 @@ recursive-include docs *.rst recursive-include docs Makefile recursive-include docs *.sh recursive-include src *.conf +recursive-include src *.cu recursive-include src *.yaml prune docs/build prune docs/source diff --git a/README.md b/README.md index d3fe484579..fd1b16c81f 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5657281.svg)](https://doi.org/10.5281/zenodo.5657281) [![Downloads](https://static.pepy.tech/badge/aspire/month)](https://pepy.tech/project/aspire) -# ASPIRE - Algorithms for Single Particle Reconstruction - v0.12.3 +# ASPIRE - Algorithms for Single Particle Reconstruction - v0.13.0 The ASPIRE-Python project supersedes [Matlab ASPIRE](https://github.com/PrincetonUniversity/aspire). @@ -20,7 +20,7 @@ For more information about the project, algorithms, and related publications ple Please cite using the following DOI. This DOI represents all versions, and will always resolve to the latest one. ``` -ComputationalCryoEM/ASPIRE-Python: v0.12.3 https://doi.org/10.5281/zenodo.5657281 +ComputationalCryoEM/ASPIRE-Python: v0.13.0 https://doi.org/10.5281/zenodo.5657281 ``` @@ -37,11 +37,7 @@ install `aspire` safely in that environment. If you are unfamiliar with `conda`, the [Miniconda](https://docs.conda.io/en/latest/miniconda.html) -distribution for `x86_64` is recommended. For Apple silicon to use -the osx-arm platform, patching and building some dependencies from -source is currently required. 
The Intel `osx-64` install is still
-preferred even for Apple silicon users, otherwise [notes are
-provided.](https://github.com/ComputationalCryoEM/ASPIRE-Python/discussions/969)
+distribution for `x86_64` is recommended.
 
 Assuming you have `conda` and a compatible system, the following steps
 will checkout current code release, create an environment, and install
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1ae8045688..3aa429641c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -86,7 +86,7 @@
 # built documents.
 #
 # The full version, including alpha/beta/rc tags.
-release = version = "0.12.3"
+release = version = "0.13.0"
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f34b7a1f4c..dda90a41d7 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,4 +1,4 @@
-Aspire v0.12.3
+Aspire v0.13.0
 ==============
 
 Algorithms for Single Particle Reconstruction
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index 4a48e3a505..1fca5a35dd 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -34,12 +34,6 @@
 to view Conda's installation instructions.
 
 .. note:: If you're not sure which distribution is right for you, go with `Miniconda <https://docs.conda.io/en/latest/miniconda.html>`__
 
-.. note:: For Apple silicon to use the osx-arm platform, patching and
-   building some dependencies from source is currently required. The
-   Intel ``osx-64`` install is still preferred even for Apple silicon
-   users, otherwise `notes are
-   provided. <https://github.com/ComputationalCryoEM/ASPIRE-Python/discussions/969>`_
-
 Getting Started - Installation
 ************************************
 
@@ -129,10 +123,10 @@
 an M1 laptop:
 
 Installing GPU Extensions
 *************************
 
-ASPIRE does support GPUs, depending on several external packages. The
-collection of GPU extensions can be installed using ``pip``.
-Extensions are grouped based on CUDA versions. To find the CUDA
-driver version, run ``nvidia-smi`` on the intended system.
+ASPIRE does support using a GPU, depending on several external
+packages. The collection of GPU extensions can be installed using
+``pip``. Extensions are grouped based on CUDA versions. To find the
+CUDA driver version, run ``nvidia-smi`` on the intended system.
 
 .. list-table:: CUDA GPU Extension Versions
    :widths: 25 25
 
    * - CUDA Version
      - ASPIRE Extension
-   * - 10.2
-     - gpu-102
-   * - 11.0
-     - gpu-110
-   * - 11.1
-     - gpu-111
-   * - >=11.2
-     - gpu-11x
    * - >=12
      - gpu-12x
 
@@ -164,12 +150,15 @@
 the command below would install GPU packages required for ASPIRE.
 
 By default if the required GPU extensions are correctly installed,
-ASPIRE should automatically begin using the GPU for select components
-(such as those using ``nufft``).
-
-Because GPU extensions depend on several third party packages and
-libraries, we can only offer limited support if one of the packages
-has a problem on your system.
+ASPIRE should automatically begin using the GPU for calls to our ``nufft`` module.
+
+Using GPU in other areas of the code is still an experimental feature
+and requires a minor configuration setting to enable ``cupy``. See the
+:ref:`sphx_glr_auto_tutorials_configuration.py` tutorial for details.
+Because GPU extensions depend on several third party software packages
+and machines vary widely, we can only offer limited support if one of
+the packages has a problem on your system. We are currently expanding
+GPU code coverage.
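+
+As a quick sanity check (a minimal sketch; it assumes the ``gpu-12x``
+extension above installed cleanly), the configured backends can be
+inspected before running a larger pipeline:
+
+.. code-block:: python
+
+    import aspire
+
+    # Expect both to resolve to "cupy" once the config override is active.
+    print(aspire.config["common"]["numeric"].get())
+    print(aspire.config["common"]["fft"].get())
+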
Generating Documentation ************************ diff --git a/environment-accelerate.yml b/environment-accelerate.yml index 38dd49813d..cfe9631f3c 100644 --- a/environment-accelerate.yml +++ b/environment-accelerate.yml @@ -7,7 +7,6 @@ channels: dependencies: - pip - python=3.8 - - pyshtools - numpy=1.24.1 - scipy=1.10.1 - scikit-learn diff --git a/gallery/experiments/experimental_abinitio_pipeline_10081.py b/gallery/experiments/experimental_abinitio_pipeline_10081.py index be27bc6e43..838b2c2d5a 100644 --- a/gallery/experiments/experimental_abinitio_pipeline_10081.py +++ b/gallery/experiments/experimental_abinitio_pipeline_10081.py @@ -59,7 +59,11 @@ # Create a source object for the experimental images src = RelionSource( - starfile_in, pixel_size=pixel_size, max_rows=n_imgs, data_folder=data_folder + starfile_in, + pixel_size=pixel_size, + max_rows=n_imgs, + data_folder=data_folder, + symmetry_group="C4", ) # Downsample the images @@ -115,12 +119,13 @@ # Volume Reconstruction # ---------------------- # -# Using the oriented source, attempt to reconstruct a volume. -# Since this is a Cn symmetric molecule, as indicated by -# ``symmetry="C4"`` above, the ``avgs`` images set will be repeated -# for each of the 3 additional rotations during the back-projection -# step. This boosts the effective number of images used in the -# reconstruction from ``n_classes`` to ``4*n_classes``. +# Using the oriented source, attempt to reconstruct a volume. Since +# this is a Cn symmetric molecule, as specified by ``RelionSource(..., +# symmetry_group="C4", ...)``, the ``symmetry_group`` source attribute +# will flow through the pipeline to ``avgs``. Then each image will be +# repeated for each of the 3 additional rotations during +# back-projection. This boosts the effective number of images used in +# the reconstruction from ``n_classes`` to ``4*n_classes``. logger.info("Begin Volume reconstruction") diff --git a/gallery/experiments/simulated_abinitio_pipeline.py b/gallery/experiments/simulated_abinitio_pipeline.py index ca0e4fbac5..4638c30716 100644 --- a/gallery/experiments/simulated_abinitio_pipeline.py +++ b/gallery/experiments/simulated_abinitio_pipeline.py @@ -54,13 +54,9 @@ # --------------- # Start with the hi-res volume map EMDB-2660 sourced from EMDB, # https://www.ebi.ac.uk/emdb/EMD-2660, and dowloaded via ASPIRE's downloader utility. 
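+# (The Simulation below is constructed with dtype=np.float64, so the volume
+# is promoted to doubles up front to keep the whole pipeline in one precision.)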
-og_v = emdb_2660() +og_v = emdb_2660().astype(np.float64) logger.info("Original volume map data" f" shape: {og_v.shape} dtype:{og_v.dtype}") -logger.info(f"Downsampling to {(img_size,)*3}") -v = og_v.downsample(img_size) -L = v.resolution - # Then create a filter based on that variance # This is an example of a custom noise profile @@ -70,7 +66,7 @@ def noise_function(x, y): # White f1 = noise_variance # Violet-ish - f2 = noise_variance * (x * x + y * y) / L * L + f2 = noise_variance * (x * x + y * y) / img_size * img_size return (alpha * f1 + beta * f2) / 2.0 @@ -78,7 +74,7 @@ def noise_function(x, y): logger.info("Initialize CTF filters.") # Create some CTF effects -pixel_size = 5 * 65 / img_size # Pixel size of the images (in angstroms) +pixel_size = og_v.pixel_size # Pixel size (in angstroms) voltage = 200 # Voltage (in KV) defocus_min = 1.5e4 # Minimum defocus value (in angstroms) defocus_max = 2.5e4 # Maximum defocus value (in angstroms) @@ -94,13 +90,16 @@ def noise_function(x, y): # Finally create the Simulation src = Simulation( - L=v.resolution, n=num_imgs, - vols=v, + vols=og_v, noise_adder=custom_noise, unique_filters=ctf_filters, - dtype=v.dtype, + dtype=np.float64, ) + +# Downsample +src = src.downsample(img_size).cache() + # Peek if interactive: src.images[:10].show() @@ -115,7 +114,7 @@ def noise_function(x, y): # Plot the noise profile for inspection if interactive: - plt.imshow(aiso_noise_estimator.filter.evaluate_grid(L)) + plt.imshow(aiso_noise_estimator.filter.evaluate_grid(img_size)) plt.show() # Peek, what do the whitened images look like... diff --git a/gallery/tutorials/aspire_introduction.py b/gallery/tutorials/aspire_introduction.py index 648750ac01..cffe6d544e 100644 --- a/gallery/tutorials/aspire_introduction.py +++ b/gallery/tutorials/aspire_introduction.py @@ -571,7 +571,7 @@ def noise_function(x, y): # Generate several CTFs. ctf_filters = [ - RadialCTFFilter(pixel_size=5, defocus=d) + RadialCTFFilter(pixel_size=vol_ds.pixel_size, defocus=d) for d in np.linspace(defocus_min, defocus_max, defocus_ct) ] diff --git a/gallery/tutorials/configuration.py b/gallery/tutorials/configuration.py index 819ff9b675..372d97df06 100644 --- a/gallery/tutorials/configuration.py +++ b/gallery/tutorials/configuration.py @@ -102,6 +102,36 @@ time.sleep(1) print("Done Loop 2\n") +# %% +# Enabling GPU Acceleration +# ------------------------- +# Enabling GPU acceleration requires installing supporting software +# packages and small config changes. Installing the supporting +# software is most easily accomplished by installing ASPIRE with one +# of the published GPU extensions, for example ``pip install +# "aspire[dev,gpu_12x]"``. Once the packages are installed users +# should find that the NUFFT calls are automatically running on the +# GPU. Additional acceleration is achieved by enabling `cupy` for +# `numeric` and `fft` components. +# +# .. code-block:: yaml +# +# common: +# # numeric module to use - one of numpy/cupy +# numeric: cupy +# # fft backend to use - one of pyfftw/scipy/cupy/mkl +# fft: cupy +# +# Alternatively, like other config options, this can be changed +# dynamically with code. +# +# .. 
code-block:: python +# +# from aspire import config +# +# config["common"]["numeric"] = "cupy" +# config["common"]["fft"] = "cupy" +# # %% # Resolution diff --git a/gallery/tutorials/pipeline_demo.py b/gallery/tutorials/pipeline_demo.py index 8910436de2..77d304b156 100644 --- a/gallery/tutorials/pipeline_demo.py +++ b/gallery/tutorials/pipeline_demo.py @@ -63,7 +63,7 @@ defocus_ct = 7 ctf_filters = [ - RadialCTFFilter(pixel_size=5, defocus=d) + RadialCTFFilter(pixel_size=vol.pixel_size, defocus=d) for d in np.linspace(defocus_min, defocus_max, defocus_ct) ] diff --git a/gallery/tutorials/tutorials/cov3d_simulation.py b/gallery/tutorials/tutorials/cov3d_simulation.py index 5fced70fbb..741a47de99 100644 --- a/gallery/tutorials/tutorials/cov3d_simulation.py +++ b/gallery/tutorials/tutorials/cov3d_simulation.py @@ -47,7 +47,9 @@ L=img_size, n=num_imgs, vols=vols, - unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)], + unique_filters=[ + RadialCTFFilter(pixel_size=10, defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + ], dtype=dtype, ) diff --git a/pyproject.toml b/pyproject.toml index 3cd57981ef..364ef107f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "aspire" -version = "0.12.3" +version = "0.13.0" description = "Algorithms for Single Particle Reconstruction" readme = "README.md" # Optional requires-python = ">=3.8" @@ -31,9 +31,7 @@ dependencies = [ "click", "confuse >= 2.0.0", "cvxpy", - # finufft 2.2.0 doesn't seemt to run on GHA Windows CI... - "finufft==2.2.0 ; sys_platform != 'win32'", - "finufft==2.1.0 ; sys_platform == 'win32'", + "finufft==2.3.0", "gemmi >= 0.6.5", "grpcio >= 1.54.2", "joblib", @@ -45,7 +43,6 @@ dependencies = [ "pillow", "psutil", "pymanopt", - "pyshtools<=4.10.4", # 4.11.7 might have a packaging bug "PyWavelets", "ray >= 2.9.2", "scipy >= 1.10.0", @@ -61,11 +58,7 @@ dependencies = [ "Source" = "https://github.com/ComputationalCryoEM/ASPIRE-Python" [project.optional-dependencies] -gpu-102 = ["pycuda", "cupy-cuda102", "cufinufft==1.3"] -gpu-110 = ["pycuda", "cupy-cuda110", "cufinufft==1.3"] -gpu-111 = ["pycuda", "cupy-cuda111", "cufinufft==1.3"] -gpu-11x = ["pycuda", "cupy-cuda11x", "cufinufft==1.3"] -gpu-12x = ["pycuda", "cupy-cuda12x", "cufinufft==2.2.0"] +gpu-12x = ["cupy-cuda12x", "cufinufft==2.3.0"] dev = [ "black", "bumpversion", diff --git a/src/aspire/__init__.py b/src/aspire/__init__.py index 4a1419c7ed..e6b58dfbf5 100644 --- a/src/aspire/__init__.py +++ b/src/aspire/__init__.py @@ -12,7 +12,7 @@ from aspire.exceptions import handle_exception # version in maj.min.bld format -__version__ = "0.12.3" +__version__ = "0.13.0" # Setup `confuse` config diff --git a/src/aspire/abinitio/__init__.py b/src/aspire/abinitio/__init__.py index ff14cc2d45..e8115ea185 100644 --- a/src/aspire/abinitio/__init__.py +++ b/src/aspire/abinitio/__init__.py @@ -4,8 +4,10 @@ # isort: off from .commonline_sync import CLSyncVoting +from .commonline_sync3n import CLSync3N from .commonline_c3_c4 import CLSymmetryC3C4 from .commonline_cn import CLSymmetryCn from .commonline_c2 import CLSymmetryC2 +from .commonline_d2 import CLSymmetryD2 # isort: on diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index e3d8fa50db..c0c3718803 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -4,6 +4,7 @@ import numpy as np import scipy.sparse as sparse +from aspire.image import Image from aspire.operators import PolarFT from aspire.utils import 
common_line_from_rots, fuzzy_mask, tqdm from aspire.utils.random import choice @@ -22,6 +23,8 @@ def __init__( n_rad=None, n_theta=360, n_check=None, + hist_bin_width=3, + full_width=6, max_shift=0.15, shift_step=1, mask=True, @@ -41,6 +44,10 @@ def __init__( of the resolution. Default is 0.15. :param shift_step: Resolution of shift estimation in pixels. Default is 1 pixel. + :param hist_bin_width: Bin width in smoothing histogram (degrees). + :param full_width: Selection width around smoothed histogram peak (degrees). + `adaptive` will attempt to automatically find the smallest number of + `hist_bin_width`s required to find at least one valid image index. :param mask: Option to mask `src.images` with a fuzzy mask (boolean). Default, `True`, applies a mask. """ @@ -52,6 +59,8 @@ def __init__( self.n_rad = n_rad self.n_theta = n_theta self.n_check = n_check + self.hist_bin_width = hist_bin_width + self.full_width = full_width self.clmatrix = None self.max_shift = math.ceil(max_shift * self.n_res) self.shift_step = shift_step @@ -91,8 +100,14 @@ def _prepare_pf(self): imgs = self.src.images[:] if self.mask: - fuzz_mask = fuzzy_mask((self.n_res, self.n_res), self.dtype) + # For best results and to reproduce MATLAB: + # Set risetime=2 + # Always compute mask (erf) in doubles. + fuzz_mask = fuzzy_mask((self.n_res, self.n_res), np.float64, risetime=2) + # Apply mask in doubles (allow imgs to upcast as needed) imgs = imgs * fuzz_mask + # Cast to desired type + imgs = Image(imgs.asnumpy().astype(self.dtype, copy=False)) # Obtain coefficients of polar Fourier transform for input 2D images pft = PolarFT( @@ -259,8 +274,9 @@ def estimate_shifts(self, equations_factor=1, max_memory=4000): show = False if logging.getLogger().isEnabledFor(logging.DEBUG): show = True - # Negative sign comes from using -i conversion of Fourier transformation - est_shifts = sparse.linalg.lsqr(shift_equations, -shift_b, show=show)[0] + + # Estimate shifts. + est_shifts = sparse.linalg.lsqr(shift_equations, shift_b, show=show)[0] est_shifts = est_shifts.reshape((self.n_img, 2)) return est_shifts @@ -305,15 +321,16 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): n_equations = self._estimate_num_shift_equations( n_img, equations_factor, max_memory ) + # Allocate local variables for estimating 2D shifts based on the estimated number # of equations. The shift equations are represented using a sparse matrix, # since each row in the system contains four non-zeros (as it involves # exactly four unknowns). The variables below are used to construct # this sparse system. The k'th non-zero element of the equations matrix # is stored at index (shift_i(k),shift_j(k)). 
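+        # (Illustrative sketch of the layout, not executed here: for the image
+        # pair (i, j) = (0, 1), equation row k contributes entries at
+        # shift_i[k] = [k, k, k, k] and shift_j[k] = [0, 1, 2, 3]; the flattened
+        # (value, (row, col)) triplets feed scipy.sparse.csr_matrix below.)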
- shift_i = np.zeros(4 * n_equations, dtype=self.dtype) - shift_j = np.zeros(4 * n_equations, dtype=self.dtype) - shift_eq = np.zeros(4 * n_equations, dtype=self.dtype) + shift_i = np.zeros((n_equations, 4), dtype=self.dtype) + shift_j = np.zeros((n_equations, 4), dtype=self.dtype) + shift_eq = np.zeros((n_equations, 4), dtype=self.dtype) shift_b = np.zeros(n_equations, dtype=self.dtype) # Prepare the shift phases to try and generate filter for common-line detection @@ -373,33 +390,33 @@ def _get_shift_equations_approx(self, equations_factor=1, max_memory=4000): sidx = sidx1 if c1[sidx1] > c2[sidx2] else sidx2 dx = -max_shift + sidx * shift_step - # Create a shift equation for the image pair [i,j] - idx = np.arange(4 * shift_eq_idx, 4 * shift_eq_idx + 4) # angle of common ray in image i shift_alpha = c_ij * d_theta # Angle of common ray in image j. shift_beta = c_ji * d_theta # Row index to construct the sparse equations - shift_i[idx] = shift_eq_idx + shift_i[shift_eq_idx] = shift_eq_idx # Columns of the shift variables that correspond to the current pair [i, j] - shift_j[idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1] + shift_j[shift_eq_idx] = [2 * i, 2 * i + 1, 2 * j, 2 * j + 1] # Right hand side of the current equation shift_b[shift_eq_idx] = dx # Compute the coefficients of the current equation coefs = np.array( [ - np.sin(shift_alpha), np.cos(shift_alpha), - -np.sin(shift_beta), + np.sin(shift_alpha), -np.cos(shift_beta), + -np.sin(shift_beta), ] ) - shift_eq[idx] = -1 * coefs if is_pf_j_flipped else coefs + shift_eq[shift_eq_idx] = ( + [-1, -1, 0, 0] * coefs if is_pf_j_flipped else coefs + ) # create sparse matrix object only containing non-zero elements shift_equations = sparse.csr_matrix( - (shift_eq, (shift_i, shift_j)), + (shift_eq.flatten(), (shift_i.flatten(), shift_j.flatten())), shape=(n_equations, 2 * n_img), dtype=self.dtype, ) diff --git a/src/aspire/abinitio/commonline_c3_c4.py b/src/aspire/abinitio/commonline_c3_c4.py index 8e9652258a..670d314d36 100644 --- a/src/aspire/abinitio/commonline_c3_c4.py +++ b/src/aspire/abinitio/commonline_c3_c4.py @@ -47,7 +47,7 @@ def __init__( n_theta=None, max_shift=0.15, shift_step=1, - epsilon=1e-3, + epsilon=1e-2, max_iters=1000, degree_res=1, seed=None, @@ -561,7 +561,7 @@ def _syncmatrix_ij_vote_3n(self, clmatrix, i, j, k_list, n_theta): :param n_theta: The number of points in the theta direction (common lines) :return: The (i,j) rotation block of the synchronization matrix """ - good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list) + _, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list) rots = self._rotratio_eulerangle_vec(clmatrix, i, j, good_k, n_theta) @@ -691,7 +691,8 @@ def _J_sync_power_method(self, vijs): ) while itr < max_iters and residual > epsilon: itr += 1 - vec_new = self._signs_times_v(vijs, vec) + # Note, this appears to need double precision for accuracy in the following division. + vec_new = self._signs_times_v(vijs, vec).astype(np.float64, copy=False) vec_new = vec_new / norm(vec_new) residual = norm(vec_new - vec) vec = vec_new @@ -849,7 +850,13 @@ def cl_angles_to_ind(cl_angles, n_theta): thetas = np.mod(thetas, 2 * np.pi) # linear scale from [0,2*pi) to [0,n_theta). - return np.mod(np.round(thetas / (2 * np.pi) * n_theta), n_theta).astype(int) + ind = np.mod(np.round(thetas / (2 * np.pi) * n_theta), n_theta).astype(int) + + # Return scalar for single value. 
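+        # (e.g. converting a single commonline angle yields a plain scalar
+        # index here, rather than a 1-element array.)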
+        if ind.size == 1:
+            ind = ind.flat[0]
+
+        return ind
 
     @staticmethod
     def g_sync(rots, order, rots_gt):
diff --git a/src/aspire/abinitio/commonline_d2.py b/src/aspire/abinitio/commonline_d2.py
new file mode 100644
index 0000000000..a8e951c642
--- /dev/null
+++ b/src/aspire/abinitio/commonline_d2.py
@@ -0,0 +1,1975 @@
+import logging
+
+import numpy as np
+import scipy.sparse.linalg as la
+from numpy.linalg import norm
+
+from aspire.abinitio import CLOrient3D
+from aspire.operators import PolarFT
+from aspire.utils import J_conjugate, Rotation, all_pairs, all_triplets, tqdm, trange
+from aspire.utils.random import randn
+from aspire.volume import DnSymmetryGroup
+
+logger = logging.getLogger(__name__)
+
+
+class CLSymmetryD2(CLOrient3D):
+    """
+    Define a class to estimate 3D orientations using common lines methods for
+    molecules with D2 (dihedral) symmetry.
+
+    Corresponding publication:
+    E. Rosen and Y. Shkolnisky,
+    Common lines ab-initio reconstruction of D2-symmetric molecules,
+    SIAM Journal on Imaging Sciences, volume 13-4, p. 1898-1994, 2020
+    """
+
+    def __init__(
+        self,
+        src,
+        n_rad=None,
+        n_theta=None,
+        max_shift=0.15,
+        shift_step=1,
+        grid_res=1200,
+        inplane_res=5,
+        eq_min_dist=7,
+        epsilon=0.01,
+        seed=None,
+        mask=True,
+    ):
+        """
+        Initialize object for estimating 3D orientations for molecules with D2 symmetry.
+
+        :param src: The source object of 2D denoised or class-averaged images with metadata.
+        :param n_rad: The number of points in the radial direction of Fourier image.
+        :param n_theta: The number of points in the theta direction of Fourier image.
+        :param max_shift: Maximum range for shifts as a proportion of resolution. Default = 0.15.
+        :param shift_step: Resolution of shift estimation in pixels. Default = 1 pixel.
+        :param grid_res: Number of sampling points on sphere for projection directions.
+            These are generated using the Saff-Kuijlaars algorithm. Default value is 1200.
+        :param inplane_res: The sampling resolution of in-plane rotations for each
+            projection direction. Default value is 5 degrees.
+        :param eq_min_dist: Width of strip around equator projection directions from
+            which we do not sample directions. Default value is 7 degrees.
+        :param epsilon: Tolerance for J-synchronization power method.
+        :param seed: Optional seed for RNG.
+        :param mask: Option to mask `src.images` with a fuzzy mask (boolean).
+            Default, `True`, applies a mask.
+        """
+
+        super().__init__(
+            src,
+            n_rad=n_rad,
+            n_theta=n_theta,
+            max_shift=max_shift,
+            shift_step=shift_step,
+            mask=mask,
+        )
+
+        self.grid_res = grid_res
+        self.inplane_res = inplane_res
+        self.n_inplane_rots = int(360 / self.inplane_res)
+        self.eq_min_dist = eq_min_dist
+        self.seed = seed
+        self.epsilon = epsilon
+
+        self.triplets = all_triplets(self.n_img)
+        self.pairs, self.pairs_to_linear = all_pairs(self.n_img, return_map=True)
+        self.n_pairs = len(self.pairs)
+
+        # D2 symmetry group.
+        # Rearrange in order Identity, about_x, about_y, about_z.
+        # This ordering is necessary for reproducing MATLAB code results.
+        self.gs = DnSymmetryGroup(order=2, dtype=self.dtype).matrices[[0, 3, 2, 1]]
+
+    def estimate_rotations(self):
+        """
+        Estimate rotation matrices for molecules with D2 symmetry. Sets the attribute
+        self.rotations with an array of estimated rotation matrices, size src.nx3x3.
+        """
+        # Pre-compute phase-shifted polar Fourier.
+        self._compute_shifted_pf()
+
+        # Generate lookup data
+        self._generate_lookup_data()
+        self._generate_scl_lookup_data()
+
+        # Compute self common-line scores.
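+        # (Scores each non-topview candidate rotation by how well its
+        # predicted self-commonlines correlate within each image; see
+        # `_compute_scl_scores` below.)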
+        self._compute_scl_scores()
+
+        # Compute common-lines and estimate relative rotations Rijs.
+        self._compute_cl_scores()
+
+        # Perform handedness synchronization.
+        self.Rijs_sync = self._global_J_sync(self.Rijs_est)
+
+        # Synchronize colors.
+        self.colors, self.Rijs_rows = self._sync_colors(self.Rijs_sync)
+
+        # Synchronize signs.
+        Ris = self._sync_signs(self.Rijs_rows, self.colors)
+
+        # Assign rotations.
+        self.rotations = Ris
+
+    #########################
+    # Prepare Polar Fourier #
+    #########################
+
+    def _compute_shifted_pf(self):
+        """
+        Pre-compute shifted and full polar Fourier transforms.
+        """
+        logger.info("Preparing polar Fourier transform.")
+        pf = self.pf
+
+        # Generate shift phases.
+        r_max = pf.shape[-1]
+        max_shift_1d = np.ceil(2 * np.sqrt(2) * self.max_shift)
+        shifts, shift_phases, _ = self._generate_shift_phase_and_filter(
+            r_max, max_shift_1d, self.shift_step
+        )
+        self.n_shifts = len(shifts)
+
+        # Reconstruct full polar Fourier for use in correlation.
+        pf[:, :, 0] = 0  # Matching MATLAB convention to zero out the lowest frequency.
+        pf /= norm(pf, axis=2)[..., np.newaxis]  # Normalize each ray.
+        self.pf_full = PolarFT.half_to_full(pf)
+
+        # Pre-compute shifted pf's.
+        pf_shifted = pf[:, None] * shift_phases[None, :, None]
+        self.pf_shifted = pf_shifted.reshape(
+            (self.n_img, self.n_shifts * (self.n_theta // 2), r_max)
+        )
+
+    ###################################
+    # Generate Commonline Lookup Data #
+    ###################################
+
+    def _generate_lookup_data(self):
+        """
+        Generate candidate relative rotations and corresponding common line indices.
+        """
+        logger.info("Generating commonline lookup data.")
+        # Generate uniform grid on sphere with Saff-Kuijlaars and take one quarter
+        # of sphere because of D2 symmetry redundancy.
+        sphere_grid = self._saff_kuijlaars(self.grid_res)
+        octant1_mask = np.all(sphere_grid > 0, axis=1)
+        octant2_mask = (
+            (sphere_grid[:, 0] > 0) & (sphere_grid[:, 1] > 0) & (sphere_grid[:, 2] < 0)
+        )
+        sphere_grid1 = sphere_grid[octant1_mask]
+        sphere_grid2 = sphere_grid[octant2_mask]
+
+        # Mark Equator Directions.
+        # Common lines between projection directions which are perpendicular to
+        # symmetry axes (equator images) have common line degeneracies. Two images
+        # taken from directions on the same great circle which is perpendicular to
+        # some symmetry axis only have 2 common lines instead of 4, and must be
+        # treated separately.
+        # We detect such directions by taking a strip of radius
+        # `eq_min_dist` about the 3 great circles perpendicular to the symmetry
+        # axes of D2 (i.e. to the X, Y and Z axes).
+        eq_class1 = self._mark_equators(sphere_grid1, self.eq_min_dist)
+        eq_class2 = self._mark_equators(sphere_grid2, self.eq_min_dist)
+
+        # Mark Top View Directions.
+        # A Top view projection image is taken from the direction of one of the
+        # symmetry axes. Since all symmetry axes of D2 molecules are perpendicular,
+        # this means that such an image is an equator with respect to both symmetry
+        # axes which are perpendicular to the direction of the symmetry axis from
+        # which the image was made, e.g. if the image was formed by projecting in
+        # the direction of the X (symmetry) axis, then it is an equator with
+        # respect to both Y and Z symmetry axes (its direction is the
+        # intersection of 2 great circles perpendicular to the Y and Z axes).
+        # Such images have severe degeneracies. A pair of Top View images (taken
+        # from different directions), or a Top View and an equator image, only
+        # have a single common line. A top view and a regular non-equator image
+        # only have two common lines.
+
+        # Remove top views from sphere grids and update equator indices and classes.
+        self.sphere_grid1 = sphere_grid1[eq_class1 < 4]
+        self.sphere_grid2 = sphere_grid2[eq_class2 < 4]
+        self.eq_class1 = eq_class1[eq_class1 < 4]
+        self.eq_class2 = eq_class2[eq_class2 < 4]
+
+        # Generate in-plane rotations for each grid point on the sphere.
+        self.inplane_rotated_grid1 = self._generate_inplane_rots(
+            self.sphere_grid1, self.inplane_res
+        )
+        self.inplane_rotated_grid2 = self._generate_inplane_rots(
+            self.sphere_grid2, self.inplane_res
+        )
+
+        # Generate commonline angles induced by all relative rotation candidates.
+        cl_angles1, self.eq2eq_Rij_table_11 = self._generate_commonline_angles(
+            self.inplane_rotated_grid1,
+            self.inplane_rotated_grid1,
+            self.eq_class1,
+            self.eq_class1,
+        )
+        cl_angles2, self.eq2eq_Rij_table_12 = self._generate_commonline_angles(
+            self.inplane_rotated_grid1,
+            self.inplane_rotated_grid2,
+            self.eq_class1,
+            self.eq_class2,
+            same_octant=False,
+        )
+
+        # Generate commonline indices.
+        self.cl_idx_1 = self._generate_commonline_indices(cl_angles1)
+        self.cl_idx_2 = self._generate_commonline_indices(cl_angles2)
+        self.cl_idx = np.hstack((self.cl_idx_1, self.cl_idx_2))
+
+    def _generate_commonline_angles(
+        self,
+        Ris,
+        Rjs,
+        Ri_eq_class,
+        Rj_eq_class,
+        same_octant=True,
+    ):
+        """
+        Compute commonline angles induced by the 4 sets of relative rotations
+        Rij = Ri.T @ g_m @ Rj, m = 0,1,2,3, where g_m is the identity and rotations
+        about the three axes of symmetry of a D2 symmetric molecule. Note, we only
+        compute commonline angles between pairs of images which are not equator
+        images with respect to the same axis of symmetry. To do this we build a
+        table, `eq2eq_Rij_table`, which is `False` for pairs of images that are
+        equator images with respect to the same axis of symmetry and `True` otherwise.
+
+        :param Ris: First set of candidate rotations.
+        :param Rjs: Second set of candidate rotations.
+        :param Ri_eq_class: Equator classification for Ris.
+        :param Rj_eq_class: Equator classification for Rjs.
+        :param same_octant: True if both sets of candidates are in the same octant.
+
+        :return: Commonline angles induced by relative rotation candidates.
+        """
+        n_rots_i = len(Ris)
+        n_theta = Ris.shape[1]  # Same for Rjs, TODO: Don't call this n_theta
+
+        # Generate upper triangular table of indicators of all pairs which are not
+        # equators with respect to the same symmetry axis (named unique_pairs).
+        eq_table = np.outer(Ri_eq_class > 0, Rj_eq_class > 0)
+        in_same_class = (Ri_eq_class[:, None] - Rj_eq_class.T[None]) == 0
+        eq2eq_Rij_table = ~(eq_table * in_same_class)
+
+        # For candidates in the same octant we only need the upper triangle of the table.
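+        # (`np.triu(..., 1)` below keeps only strict i < j pairs, dropping
+        # self-pairs and duplicate (j, i) entries.)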
+ if same_octant: + eq2eq_Rij_table = np.triu(eq2eq_Rij_table, 1) + + n_pairs = np.count_nonzero(eq2eq_Rij_table) + idx = 0 + cl_angles = np.zeros((2, n_pairs, n_theta, n_theta // 2, 4, 2)) + + for i in range(n_rots_i): + unique_pairs_i = np.nonzero(eq2eq_Rij_table[i])[0] + if len(unique_pairs_i) == 0: + continue + Ri = Ris[i] + for j in unique_pairs_i: + Rj = Rjs[j, : n_theta // 2] + + # Compute relative rotations candidates Rij = Ri.T @ gs @ Rj + Rijs = ( + np.transpose(Ri, axes=(0, 2, 1))[:, None, None] + @ self.gs + @ Rj[:, None] + ) + + # Common line indices induced by Rijs + cl_angles[0, idx, :, :, :, 0] = np.arctan2( + -Rijs[..., 0, 2], Rijs[..., 1, 2] + ) + cl_angles[0, idx, :, :, :, 1] = np.arctan2( + Rijs[..., 2, 0], -Rijs[..., 2, 1] + ) + cl_angles[1, idx, :, :, :, 0] = np.arctan2( + -Rijs[..., 2, 0], Rijs[..., 2, 1] + ) + cl_angles[1, idx, :, :, :, 1] = np.arctan2( + Rijs[..., 0, 2], -Rijs[..., 1, 2] + ) + + idx += 1 + + # Make all angles non-negative and convert to degrees. + cl_angles = (cl_angles + 2 * np.pi) % (2 * np.pi) + cl_angles = cl_angles * 180 / np.pi + + return cl_angles, eq2eq_Rij_table + + ######################################## + # Generate Self-Commonline Lookup Data # + ######################################## + + def _generate_scl_lookup_data(self): + """ + Generate lookup data for self-commonlines. + """ + logger.info("Generating self-commonline lookup data.") + # Get self-commonline angles. + self.scl_angles1 = self._generate_scl_angles( + self.inplane_rotated_grid1, + self.eq_class1, + ) + self.scl_angles2 = self._generate_scl_angles( + self.inplane_rotated_grid2, + self.eq_class2, + ) + + # Get self-commonline indices. + self.scl_idx_1, self.scl_eq_lin_idx_lists_1 = self._generate_scl_indices( + self.scl_angles1, self.eq_class1 + ) + self.scl_idx_2, self.scl_eq_lin_idx_lists_2 = self._generate_scl_indices( + self.scl_angles2, self.eq_class2 + ) + self.scl_idx_lists = np.concatenate( + (self.scl_eq_lin_idx_lists_1, self.scl_eq_lin_idx_lists_2), axis=1 + ) + + # Compute non-equator indices. + # Register non equator indices. Denote by C_ij the j'th in-plane rotation of + # the i'th ML candidate, and arrange all candidates in a list with their in-plane + # rotations in the order: C_11,...,C_1r,...,C_m1,...,C_mr where m is the + # number of candidates and r is the number of in plane rotations. Here we + # create a sub-list of only non equator candidates, i.e., if i_1,...,i_p are + # non equators then we have the sub list is + # C_(i_1)1,...,C(i_1)r,...C_(i_p)1,...,C_(i_p)r. + n_non_eq = np.count_nonzero(self.eq_class1 == 0) + np.count_nonzero( + self.eq_class2 == 0 + ) + non_eq_idx = np.zeros((n_non_eq, self.n_inplane_rots), dtype=int) + non_eq_idx[:, 0] = ( + np.hstack( + ( + np.nonzero(self.eq_class1 == 0)[0], + len(self.eq_class1) + np.nonzero(self.eq_class2 == 0)[0], + ) + ) + * self.n_inplane_rots + ) + non_eq_idx[:, 1:] = non_eq_idx[:, [0]] + np.arange(1, self.n_inplane_rots) + + self.non_eq_idx = non_eq_idx + + # Non-topview equator indices. + self.non_tv_eq_idx = np.concatenate( + ( + np.nonzero(self.eq_class1 > 0)[0], + len(self.eq_class1) + np.nonzero(self.eq_class2 > 0)[0], + ) + ) + + # Generate maps from scl indices to relative rotations. + self._generate_scl_scores_idx_map() + + def _generate_scl_angles(self, Ris, eq_class): + """ + Generate self-commonline angles. For each candidate rotation a pair of self-commonline + angles are generated for each of the 3 self-commonlines induced by D2 symmetry. 
+
+        :param Ris: Candidate rotation matrices, (n_sphere_grid, n_inplane_rots, 3, 3).
+        :param eq_class: Equator classification for Ris.
+
+        :return: `scl_angles` of shape (n_sphere_grid, n_inplane_rots, 3, 2).
+        """
+
+        # For each candidate rotation Ri we generate the set of 3 self-commonlines.
+        scl_angles = np.zeros((*Ris.shape[:2], 3, 2), dtype=Ris.dtype)
+        n_rots = len(Ris)
+        for i in range(n_rots):
+            Ri = Ris[i]
+            for k, g in enumerate(self.gs[1:]):
+                g_Ri = g @ Ri
+                Riis = np.transpose(Ri, axes=(0, 2, 1)) @ g_Ri
+
+                scl_angles[i, :, k, 0] = np.arctan2(Riis[:, 2, 0], -Riis[:, 2, 1])
+                scl_angles[i, :, k, 1] = np.arctan2(-Riis[:, 0, 2], Riis[:, 1, 2])
+
+        # Prepare self commonline coordinates.
+        scl_angles = scl_angles % (2 * np.pi)
+
+        # Deal with non top view equators.
+        # A non-TV equator has only one self common line. However, we classify an
+        # equator as an image whose projection direction is at radial distance <
+        # `eq_min_dist` from the great circle perpendicular to a symmetry axis,
+        # and not strictly zero distance. Thus in most cases we get 2 common lines
+        # differing by a small angle. Actually the calculation above
+        # gives us two NEARLY antipodal lines, so we first flip one of them by
+        # adding 180 degrees to it. Then we aggregate all the rays within the range
+        # between these two resulting lines to compute the score of this self common
+        # line for this candidate. The scoring part is done in the ML function itself.
+        # Furthermore, the line perpendicular to the self common line, though not
+        # really a self common line, has the property that all its values are real
+        # and both halves of the line (rays differing by pi, emanating from the
+        # origin) have the same values, and so it 'behaves like' a self common
+        # line which we also register here and exploit in the ML function.
+        # We put the 'real' self common line in the first 2 coordinates; the
+        # candidate for the perpendicular line is in the 3rd coordinate.
+
+        # If this is a self common line with respect to an x-equator then the actual self
+        # common line(s) is given by the self relative rotations given by the y and z
+        # rotation (by 180 degrees) group members, i.e. Ri^TgyRj and Ri^TgzRj.
+        scl_angles[eq_class == 1] = scl_angles[eq_class == 1][:, :, [1, 2, 0]]
+        scl_angles[eq_class == 1, :, 0] = scl_angles[eq_class == 1][:, :, 0, [1, 0]]
+
+        # If this is a self common line with respect to a y-equator then the actual self
+        # common line(s) is given by the self relative rotations given by the x and z
+        # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgzRj.
+        scl_angles[eq_class == 2] = scl_angles[eq_class == 2][:, :, [0, 2, 1]]
+        scl_angles[eq_class == 2, :, 0] = scl_angles[eq_class == 2][:, :, 0, [1, 0]]
+
+        # If this is a self common line with respect to a z-equator then the actual self
+        # common line(s) is given by the self relative rotations given by the x and y
+        # rotation (by 180 degrees) group members, i.e. Ri^TgxRj and Ri^TgyRj.
+        # No need to rearrange entries, the "real" common lines are already in
+        # indices 1 and 2, but flip one common line to antipodal.
+        scl_angles[eq_class == 3, :, 0] = scl_angles[eq_class == 3][:, :, 0, [1, 0]]
+
+        # Make sure angle range is < 180 degrees.
+        # p1 marks "equator" self-commonlines where both entries of the first
+        # scl are greater than both entries of the second scl.
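+        # (Both conditions flag scl pairs whose angular range would exceed
+        # 180 degrees; flagged pairs are swapped below so the aggregated
+        # range stays under 180.)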
+ p1 = scl_angles[eq_class > 0, :, 0] > scl_angles[eq_class > 0, :, 1] + p1 = p1[:, :, 0] & p1[:, :, 1] + # p2 marks "equator" self-commonlines where the angle range between the + # first and second sets of self-commonlines is greater than 180. + p2 = scl_angles[eq_class > 0, :, 0] - scl_angles[eq_class > 0, :, 1] < -np.pi + p2 = p2[:, :, 0] | p2[:, :, 1] + p = p1 | p2 + + # Swap entries satisfying either of the above conditions. + scl_angles[eq_class > 0] = ( + scl_angles[eq_class > 0][:, :, [1, 0, 2]] * p[:, :, None, None] + + scl_angles[eq_class > 0] * ~p[:, :, None, None] + ) + + # Convert from radians [0,2*pi) to degrees [0, 360). + return np.round(scl_angles * 180 / np.pi) % 360 + + def _generate_scl_indices(self, scl_angles, eq_class): + """ + Generate self-commonline indices. This includes a set of linear indices for + all candidate rotations as well as lists of self-commonline index ranges for + equator candidates. + + :param scl_angles: Self-commonline angles, shape (n_sphere_grid, n_inplane_rots, 3, 2). + :param eq_class: Equator classification for the sphere_grid points represented + by the first axis of `scl_angles`. + + :returns: + - scl_indices, self-commonline linear indices. + - eq_lin_idx_lists, a list containing a range of self-commonline + indices for each equator candidate. + """ + L = self.n_theta + + # Convert from angles to indices. + scl_indices = self._generate_commonline_indices(scl_angles) + scl_angles = np.mod(np.round(scl_angles / (2 * np.pi) * L), L).astype(int) + + # Create candidate common line linear indices lists for equators. + # As indicated above for equator candidate, for each self common line we + # don't get a single coordinate but a range of them. Here we register a + # list of coordinates for each such self common line candidate. + non_top_view_eq_idx = np.nonzero(eq_class > 0)[0] + n_eq = len(non_top_view_eq_idx) + n_inplane_rots = scl_angles.shape[1] + count_eq = 0 + + # eq_lin_idx_lists[0,i,j] registers a list of linear indices of the j'th + # in-plane rotation of the range for the (only) self common line of the i'th + # candidate. eq_lin_idx_lists[1,i,j] registers the actual (integer) angle + # of the self common line in the 2D Fourier space. Note that we need only + # one number since each self common line has radial coordinates of the form + # (theta, theta+180). + eq_lin_idx_lists = np.empty((2, n_eq, n_inplane_rots), dtype=object) + for i in non_top_view_eq_idx.tolist(): + for j in range(n_inplane_rots): + idx1 = self._circ_seq(scl_angles[i, j, 0, 0], scl_angles[i, j, 1, 0], L) + idx2 = self._circ_seq(scl_angles[i, j, 0, 1], scl_angles[i, j, 1, 1], L) + + # Ensure idx1 and idx2 have same number of elements. + # Might be off by one due to n_theta discretization. + end = np.minimum(len(idx1), len(idx2)) + idx1, idx2 = idx1[:end], idx2[:end] + + # Adjust so idx1 is in [0, 180) range. + is_geq_than_pi = idx1 >= L // 2 + idx1[is_geq_than_pi] = idx1[is_geq_than_pi] - L // 2 + idx2[is_geq_than_pi] = (idx2[is_geq_than_pi] + L // 2) % L + + # register indices in list. + eq_lin_idx_lists[0, count_eq, j] = np.ravel_multi_index( + (idx1, idx2), (L // 2, L) + ) + eq_lin_idx_lists[1, count_eq, j] = idx1 + count_eq += 1 + + return scl_indices, eq_lin_idx_lists + + def _generate_scl_scores_idx_map(self): + """ + Generates lookup tables for maximum likelihood scheme to estimate commonlines + between images. 
+ + This method creates two lookup tables (`oct1_ij_map` and `oct2_ij_map`) + for pairs of candidate rotations (i, j) under the following conditions: + + 1. Both rotations Ri and Rj are in octant 1. + 2. Ri is in octant 1 and Rj is in octant 2. + + For each pair of candidate rotations the tables give a map into the set of + self-commonlines induced by those rotations. This table will be used later + to incorporate a likelihood score for self-commonlines into the likelihood + score for common lines for each pair of images. + """ + # Calculate number of rotations in each octant. + n_rot_1 = len(self.scl_idx_1) // (3 * self.n_inplane_rots) + n_rot_2 = len(self.scl_idx_2) // (3 * self.n_inplane_rots) + + # First the map for i 0] + if len(unique_pairs_i) == 0: + continue + i_idx_plus_offset = i_idx + (i * self.n_inplane_rots) + + for j in unique_pairs_i: + j_idx_plus_offset = j_idx + (j * self.n_inplane_rots) + oct2_ij_map[idx] = np.column_stack( + (i_idx_plus_offset, j_idx_plus_offset) + ) + idx += 1 + + tmp1 = oct1_ij_map[:, :, 0].flatten() + tmp2 = oct1_ij_map[:, :, 1].flatten() + self.oct1_ij_map = np.column_stack((tmp1, tmp2)) + + tmp1 = oct2_ij_map[:, :, 0].flatten() + tmp2 = oct2_ij_map[:, :, 1].flatten() + self.oct2_ij_map = np.column_stack((tmp1, tmp2)) + + ############################################## + # Compute Self-Commonline Correlation Scores # + ############################################## + + def _compute_scl_scores(self): + """ + Compute correlations for self-commonline candidates. For each image i + we compute an auto-correlation table between all polar Fourier rays. + We then use that table to apply a score to each non-topview candidate + rotation which gives the likelihood that the self-commonlines induced + by that candidate belong to the image i.. + """ + logger.info("Computing self-commonline correlation scores.") + n_img = self.n_img + n_theta = self.n_theta + n_eq = len(self.non_tv_eq_idx) + n_inplane = self.n_inplane_rots + + # Prepare self-commonline indices. + scl_matrix = np.concatenate((self.scl_idx_1, self.scl_idx_2)) + M = len(scl_matrix) // 3 + scl_idx = scl_matrix.reshape(M, 3) + + # Get non-equator indices to use with corrs matrix. + non_eq_lin_idx = self.non_eq_idx.flatten() + n_non_eq = len(non_eq_lin_idx) + non_eq_idx = np.unravel_index( + scl_idx[non_eq_lin_idx].flatten(), (n_theta // 2, n_theta) + ) + + # Compute max correlation over all shifts. + corrs = np.real( + self.pf_shifted @ np.transpose(np.conj(self.pf_full), (0, 2, 1)) + ) + corrs = np.reshape(corrs, (self.n_img, self.n_shifts, n_theta // 2, n_theta)) + corrs = np.max(corrs, axis=1) + + # Map correlations to probabilities (in the spirit of Maximum Likelihood). + corrs = 0.5 * (corrs + 1) + + # Compute equator measures. + eq_measures = np.zeros((self.n_img, n_theta // 2), dtype=self.dtype) + for i in range(self.n_img): + eq_measures[i] = self._all_eq_measures(corrs[i]) + + # Handle the cases: Non-equator, Non-top-view equator images. + # 1. Non-equators: just take product of probabilities. + corrs_out = np.zeros((n_img, M), dtype=self.dtype) + prod_corrs = np.prod( + corrs[:, non_eq_idx[0], non_eq_idx[1]].reshape(self.n_img, n_non_eq, 3), + axis=2, + ) + corrs_out[:, non_eq_lin_idx] = prod_corrs + + # 2. 
Non-topview equators: adjust scores by eq_measures.
+        for eq_idx in range(n_eq):
+            for j in range(n_inplane):
+                # Take the correlations for the self common line candidate of the
+                # "equator rotation" `eq_idx` with respect to image i, and
+                # multiply by all scores from the function `_all_eq_measures` (see
+                # documentation inside the function). Then take the maximum over
+                # all the scores.
+                scl_idx_list = np.unravel_index(
+                    self.scl_idx_lists[0, eq_idx, j], (n_theta // 2, n_theta)
+                )
+                true_scls_corrs = corrs[:, scl_idx_list[0], scl_idx_list[1]]
+                scls_cand_idx = self.scl_idx_lists[1, eq_idx, j]
+                eq_measures_j = eq_measures[:, scls_cand_idx]
+                measures_agg = true_scls_corrs[:, :, None] * eq_measures_j[:, None, :]
+                k = self.non_tv_eq_idx[eq_idx]
+                corrs_out[:, k * n_inplane + j] = np.max(measures_agg, axis=(-2, -1))
+
+        self.scls_scores = corrs_out
+
+    def _all_eq_measures(self, corrs):
+        """
+        Compute a measure indicating how likely an image is an equator image.
+
+        :param corrs: Correlation table of shape (n_theta // 2, n_theta).
+
+        :return: (n_theta // 2) likelihood scores.
+        """
+        # First compute the eq measure (corrs(scl-k,scl+k) for k=1:n_theta // 4).
+        # An equator image of a D2 molecule has the following property: If t_i is
+        # the angle of one of the rays of the self common line then all the pairs of
+        # rays of the form (t_i-k,t_i+k) for k=1:n_theta // 4 are identical. For each t_i we
+        # average over correlations between the lines (t_i-k,t_i+k) for k=1:n_theta // 4
+        # to measure the likelihood that the image is an equator and the ray (line)
+        # with angle t_i is a self common line.
+        # (This first loop can be done once outside this function and then pass
+        # idx as an argument).
+        L = self.n_theta
+        L_half = L // 2
+
+        # Generate indices using broadcasting.
+        t_i = np.arange(L_half)[:, None, None]
+        k_vals = np.arange(1, L // 4 + 1)[None, :, None]
+        neg_pos_k = np.array([-1, 1])[None, None, :]
+
+        # Calculate indices, shape: (L//2, L//4, 2).
+        idx = np.mod(t_i + k_vals * neg_pos_k, L)
+
+        # Convert to Fourier ray indices.
+        idx_1 = idx[:, :, 0].flatten()
+        idx_2 = idx[:, :, 1].flatten()
+
+        # Adjust idx_1 to be within [0, 180) and adjust idx_2 accordingly.
+        is_geq_than_pi = idx_1 >= L_half
+        idx_1[is_geq_than_pi] -= L_half
+        idx_2[is_geq_than_pi] = (idx_2[is_geq_than_pi] + L_half) % L
+
+        # Compute correlations.
+        eq_corrs = corrs[idx_1, idx_2].reshape(L_half, L // 4)
+        corrs_mean = np.mean(eq_corrs, axis=1)
+
+        # Now compute correlations for normals to scls.
+        # An equator image of a D2 molecule has the following additional property:
+        # The normal line to a self common line in the 2D Fourier plane is real valued
+        # and both of its rays have identical values. We use the correlation
+        # between one Fourier ray of the normal to a self common line candidate t_i
+        # with its anti-podal as an additional way to measure if the image is an
+        # equator and t_i+0.5*pi is the normal to its self common line.
+        r = np.ceil(2 * L / 360).astype(
+            int
+        )  # Search radius within 2 degrees of normal ray.
+
+        # Generate indices for normal to scl index.
+        normal_2_scl_idx_0 = (
+            L_half - np.arange(L_half // 2 - r, L_half // 2 + r + 1)
+        ) % L
+        normal_2_scl_idx = (normal_2_scl_idx_0 + np.arange(L_half).reshape(-1, 1)) % L
+
+        # Adjust indices to be within [0, 180) range.
+        normal_2_scl_idx = np.where(
+            normal_2_scl_idx >= L_half, normal_2_scl_idx - L_half, normal_2_scl_idx
+        )
+
+        # Compute correlations for normals.
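+        # (Each candidate's normal ray is correlated against its antipodal
+        # ray; the maximum over the ~2 degree search radius is kept below.)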
+ normal_corrs = corrs[normal_2_scl_idx, normal_2_scl_idx + L_half] + normal_corrs_max = np.max(normal_corrs, axis=1) + + return corrs_mean * normal_corrs_max + + ######################################### + # Compute Commonline Correlation Scores # + ######################################### + + def _compute_cl_scores(self): + """ + Run common lines Maximum likelihood procedure for a D2 molecule, to find + the set of rotations Ri^TgkRj, k=1,2,3,4 for each pair of images i and j. + """ + logger.info("Computing commonline correlation scores.") + L = self.n_theta + n_pairs = self.n_img * (self.n_img - 1) // 2 + + # Map the self common line scores of each 2 candidate rotations R_i, R_j + n_lookup_1 = len(self.scl_idx_1) // 3 + oct1_ij_map = np.vstack((self.oct1_ij_map, self.oct1_ij_map[:, [1, 0]])) + oct2_ij_map = self.oct2_ij_map + oct2_ij_map[:, 1] += n_lookup_1 + oct2_ij_map = np.vstack((oct2_ij_map, oct2_ij_map[:, [1, 0]])) + ij_map = np.vstack((oct1_ij_map, oct2_ij_map)) + + # Gather commonline indices and unravel to index into correlations. + cl_idx = np.unravel_index(self.cl_idx, (L // 2, L)) + + # Allocate output variables + corrs_idx = np.zeros(n_pairs, dtype=np.int64) + corrs_out = np.zeros(n_pairs, dtype=self.dtype) + + ij_idx = 0 + pbar = tqdm( + desc="Searching for commonlines between pairs of images", total=n_pairs + ) + + # For each i'th image compute the correlation with all j'th images, j > i. + for i in range(self.n_img - 1): + pf_i = self.pf_shifted[i] + scores_i = self.scls_scores[i] + + # Gather all pf_j in one array for vectorized computation + pf_js = self.pf_full[i + 1 : self.n_img] + n_pf_js = pf_js.shape[0] + + # Compute maximum correlation over all shifts for all pf_j + corrs = np.real(pf_i @ np.conj(pf_js.transpose(0, 2, 1))) + corrs = corrs.reshape(n_pf_js, self.n_shifts, L // 2, L) + corrs = np.max(corrs, axis=1) # Max over shifts + + # Take the product over symmetrically induced candidates. Eq. 4.5 in paper. + prod_corrs = corrs[:, cl_idx[0], cl_idx[1]] + prod_corrs = prod_corrs.reshape(n_pf_js, len(prod_corrs[0]) // 4, 4) + prod_corrs = np.prod(prod_corrs, axis=2) + + # Incorporate scores of individual rotations from self-commonlines + scores_js = self.scls_scores[i + 1 : self.n_img] + scores_ij = scores_i[ij_map[:, 0]] * scores_js[:, ij_map[:, 1]] + + # Find maximum correlations and update results + prod_corrs = prod_corrs * scores_ij + max_indices = np.argmax(prod_corrs, axis=1) + corrs_idx[ij_idx : ij_idx + len(max_indices)] = max_indices + corrs_out[ij_idx : ij_idx + len(max_indices)] = prod_corrs[ + np.arange(len(max_indices)), max_indices + ] + + ij_idx += len(max_indices) + pbar.update(len(max_indices)) + + pbar.close() + + # Get estimated relative viewing directions + self.corrs_idx = corrs_idx + self.Rijs_est = self._get_Rijs_from_lin_idx(corrs_idx) + + def _get_Rijs_from_lin_idx(self, lin_idx): + """ + Restore map results from maximum-likelihood over commonlines to corresponding + relative rotations. + + :param lin_idx: Set of linear indices corresponding to best estimate of Rijs. + + :return: Estimated Rijs. 
+ """ + Rijs_est = np.zeros((len(lin_idx), 4, 3, 3), dtype=self.dtype) + n_cand_per_oct = len(self.cl_idx_1) // 4 + oct1_idx = lin_idx < n_cand_per_oct + n_est_in_oct1 = np.count_nonzero(oct1_idx) + if n_est_in_oct1 > 0: + Rijs_est[oct1_idx] = self._get_Rijs_from_oct(lin_idx[oct1_idx], octant=1) + if n_est_in_oct1 <= len(lin_idx): + Rijs_est[~oct1_idx] = self._get_Rijs_from_oct( + lin_idx[~oct1_idx] - n_cand_per_oct, octant=2 + ) + + return Rijs_est + + def _get_Rijs_from_oct(self, lin_idx, octant=1): + """ + Calculate estimated relative rotations Rijs from the linear indices of + common-lines estimates from the search table. Rijs are generated from the + rotation grids from which the common-lines table was generated. + + :param lin_idx: Set of linear indices corresponding to best estimate of Rijs. + :param octant: Octant of rotation grid from which the Rj rotation was selected + when generating the common-lines table. + :return: Estimated Rijs. + """ + if octant not in [1, 2]: + raise ValueError("`octant` must be 1 or 2.") + + # Get pairs lookup table. + if octant == 1: + unique_pairs = self.eq2eq_Rij_table_11 + else: + unique_pairs = self.eq2eq_Rij_table_12 + + n_theta = self.n_inplane_rots + n_lookup_pairs = np.count_nonzero(unique_pairs) + n_rots = len(self.sphere_grid1) + if octant == 1: + n_rots2 = n_rots + else: + n_rots2 = len(self.sphere_grid2) + + # Map linear indices of chosen pairs of rotation candidates from ML to regular indices. + p_idx, inplane_i, inplane_j = np.unravel_index( + lin_idx, (2 * n_lookup_pairs, n_theta, n_theta // 2) + ) + transpose_idx = p_idx >= n_lookup_pairs + p_idx[transpose_idx] -= n_lookup_pairs + s = self.inplane_rotated_grid1.shape + inplane_rotated_grid = np.reshape( + self.inplane_rotated_grid1, (np.prod(s[0:2]), 3, 3) + ) + if octant == 1: + s2 = s + inplane_rotated_grid2 = inplane_rotated_grid + else: + s2 = self.inplane_rotated_grid2.shape + inplane_rotated_grid2 = np.reshape( + self.inplane_rotated_grid2, (np.prod(s2[0:2]), 3, 3) + ) + + # Convert linear indices of unique table to linear indices of index pairs table. + idx_vec = np.arange(np.prod(unique_pairs.shape)) + unique_lin_idx = idx_vec[unique_pairs.flatten()] + I, J = np.unravel_index(unique_lin_idx, (n_rots, n_rots2)) + est_idx = np.vstack((I[p_idx], J[p_idx])) + + # Assemble relative rotations Ri^TgRj using linear indices, where g is a group member of D2. + Ris_lin_idx = np.ravel_multi_index((est_idx[0], inplane_i), s[:2]) + Rjs_lin_idx = np.ravel_multi_index((est_idx[1], inplane_j), s2[:2]) + Ris_t = np.transpose(inplane_rotated_grid[Ris_lin_idx], (0, 2, 1)) + Rjs = inplane_rotated_grid2[Rjs_lin_idx] + Rijs_est = Ris_t[:, None] @ self.gs @ Rjs[:, None] + + Rijs_est[transpose_idx] = np.transpose(Rijs_est[transpose_idx], (0, 1, 3, 2)) + + return Rijs_est + + #################################### + # Perform Global J Synchronization # + #################################### + + def _global_J_sync(self, Rijs): + """ + Global J-synchronization of all third row outer products. Given n_pairsx4x3x3 + matrices Rijs, each of which might contain a spurious J, ie. + Rij = J @ Ri.T @ gs @ Rj @ J instead of Rij = Ri.T @ gs @ Rj, we return Rijs + that all have either a spurious J or not. + + :param Rijs: An (n-choose-2)x4 x3x3 array where each 3x3 slice holds an estimate + for the corresponding outer-product Ri.T @ Rj. Each estimate might have a + spurious J independently of other estimates. + + :return: Rijs, all of which have a spurious J or not. 
+ """ + logger.info("Performing global handedness synchronization.") + # Find best J_configuration. + J_list = self._J_configuration(Rijs) + + # Determine relative handedness of Rijs. + sign_ij_J = self._J_sync_power_method(J_list) + + # Synchronize Rijs + logger.info("Applying global handedness synchronization.") + mask = sign_ij_J == 1 + Rijs[mask] = J_conjugate(Rijs[mask]) + + return Rijs + + def _J_configuration(self, Rijs): + """ + For each triplet of indices (i, j, k), consider the relative rotations + tuples {Ri^TgmRj}, {Ri^TglRk} and {Rj^TgrRk}. Compute norms of the form + ||Ri^TgmRj*Rj^TglRk-Ri^TglRk||, ||J*Ri^TgmRj*J*Rj^TglRk-Ri^TglRk||, + ||Ri^TgmRj*J*Rj^TglRk*J-Ri^TglRk| and ||Ri^TgmRj*Rj^TglRk-J*Ri^TglRk*J|| + where gm,gl,gr are the varipus gorup members of Dn and J=diag([1,1-1]). + The correct "J-configuration" is given for the smallest of these 4 norms. + + :param Rijs: (n-choose-2)x3x3 array of relative rotations. + :return: List of n-choose-3 indices in {0,1,2,3} indicating + which J-configuration for each triplet of Rijs, i epsilon: + itr += 1 + vec_new = self._signs_times_v(J_list, vec) + vec_new = vec_new / norm(vec_new) + residual = norm(vec_new - vec) + vec = vec_new + logger.info( + f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" + ) + + # We need only the signs of the eigenvector + J_sync = np.sign(vec) + J_sync = np.sign(J_sync[0]) * J_sync # Stabilize J_sync + + return J_sync + + def _signs_times_v(self, J_list, vec): + """ + Multiplication of the J-synchronization matrix by a candidate eigenvector. + + The J-synchronization matrix is a matrix representation of the handedness graph, + Gamma, whose set of nodes consists of the estimates Rijs and whose set of edges + consists of the undirected edges between all triplets of estimates Rij, Rjk, + and Rik, where i rel_perm[2] + ) + trip_idx += 1 + + colors_i = np.sum(colors_i, axis=1) + + return colors_i + + def _mult_cmat_by_vec(self, c_perms, v): + """ + Multiply color matrix by vector v "on the fly". + + :param c_perms: An (N over 3) vector. Each corresponds to a triplet of + indices i0 and -1->1 + + return sync_signs2 + + def _estimate_rows(self, sync_signs2, c_mat_5d): + """ + Construct 3N x 3N matrix of rank-1 3x3 blocks of sij*vi_m.T @ vj_m, + the leading eigenvectors of which correspond to estimates for the rows + of the rotations Ri, up to signs. + """ + c_mat_5d_mp = np.concatenate((c_mat_5d, -c_mat_5d), axis=1) + rows_arr = np.zeros((3, self.n_img, 3 * self.n_img), dtype=self.dtype) + svals = np.zeros((3, 2, self.n_img), dtype=self.dtype) + + logger.info("Constructing and decomposing N sign synchronization matrices...") + for c in range(3): + for r in range(self.n_img): + # Image r used for signs. + c_mat_eff = self._fill_sign_sync_matrix_c( + c_mat_5d_mp, sync_signs2, c, r + ) + + # Construct (3*N)x(3*N) rank 1 matrices from Qik + c_mat_for_svd = np.zeros( + (3 * self.n_img, 3 * self.n_img), dtype=self.dtype + ) + for i in range(self.n_img): + row_3Nx3 = c_mat_eff[i] + row_3Nx3 = row_3Nx3.reshape(3 * self.n_img, 3) + c_mat_for_svd[:, 3 * i : 3 * i + 3] = row_3Nx3 + + c_mat_for_svd = c_mat_for_svd + c_mat_for_svd.T + + # Extract leading eigenvector of rank 1 matrix. For each r and c + # this gives an estimate for the c'th row of the rotation Rr, up + # to sign +/-. 
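+                # (If the sign assignment were perfectly consistent this
+                # symmetric matrix would be exactly rank 1, so the second
+                # singular value stored in `svals` below can serve as a
+                # consistency diagnostic.)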
+ for i in range(self.n_img): + c_mat_for_svd[3 * i : 3 * i + 3, 3 * i : 3 * i + 3] = c_mat_eff[ + i, i + ] + U, S, _ = np.linalg.svd(c_mat_for_svd) + svals[c, :, r] = S[:2] + rows_arr[c, r] = U[:, 0] + + return rows_arr + + def _compute_signs_adjustment(self, rows_arr): + """ + Compute signs adjustment vector. + """ + # Sync signs according to results for each image. Dot products between + # signed row estimates are used to construct an (N over 2)x(N over 2) + # sign synchronization matrix S. If (v_i)k and (v_j)k are the i'th and + # j'th estimates for the c'th row of Rk, then the entry (i,k),(k,j) entry + # of S is <(v_i)k,(v_j)k>, where the rows and columns of S are indexed by + # double indexes (i,j), 1<=i 0: + ij_signs[zeros_idx] = 1 + + return np.sign(ij_signs) + + def _mult_smat_by_vec(self, v, sign_mat, pairs_map): + """ + Multiplies the signs sync matrix by a vector. + """ + v_out = np.zeros_like(v) + for i in range(self.n_img): + for j in range(i + 1, self.n_img): + ij = self.pairs_to_linear[i, j] + v_out[ij] = sign_mat[ij] @ v[pairs_map[ij]] + return v_out + + #################### + # Helper Functions # + #################### + + @staticmethod + def _circ_seq(n1, n2, L): + """ + For integers 0 <= n1, n2 < L, make a circular sequence of integers between + n1 and n2 modulo L. + + :param n1: First integer in sequence. + :param n2: Last integer in sequence. + :param L: Modulus of values in sequence. + :return: Circular sequence modulo L. + """ + if min(n1, n2) < 0 or max(n1, n2) >= L: + raise ValueError( + f"n1 and n2 must both be in [0, {L}). Found n1={n1}, n2={n2}." + ) + if n2 < n1: + n2 += L + if n1 == n2: + return np.array([n1]).astype(int) % L + + seq = np.arange(n1, n2 + 1).astype(int) % L + + return seq + + @staticmethod + def _saff_kuijlaars(N): + """ + Generates N vertices on the unit sphere that are approximately evenly distributed. + + This implements the recommended algorithm in spherical coordinates + (theta, phi) according to "Distributing many points on a sphere" + by E.B. Saff and A.B.J. Kuijlaars, Mathematical Intelligencer 19.1 + (1997) 5--11. + + :param N: Number of vertices to generate. + + :return: Nx3 array of vertices in cartesian coordinates. + """ + k = np.arange(1, N + 1) + h = -1 + 2 * (k - 1) / (N - 1) + theta = np.arccos(h) + phi = np.zeros(N) + + for i in range(1, N - 1): + phi[i] = (phi[i - 1] + 3.6 / (np.sqrt(N * (1 - h[i] ** 2)))) % (2 * np.pi) + + # Spherical coordinates + x = np.sin(theta) * np.cos(phi) + y = np.sin(theta) * np.sin(phi) + z = np.cos(theta) + + mesh = np.column_stack((x, y, z)) + + return mesh + + @staticmethod + def _mark_equators(sphere_grid, eq_filter_angle): + """ + This method categorizes a set of 3D unit vectors into equator and non-equator + vectors determined by the parameter `eq_filter_angle`, returned as `eq_idx`. + It further categorizes the vectors into the classes non_equator, z-equator, + y-equator, x-equator, z-top_view, y-top_view, and x-top_view, which are labeled + respectively with the values 0 - 6 and returned as `eq_class`. + + :param sphere_grid: Nx3 array of vertices in cartesian coordinates. + :param eq_filter_angle: Angular distance from equator to be marked as + an equator point. + + :return: eq_class, n_rots length array of values indicating equator class. + """ + # Project each vector onto xy, xz, yz planes and measure angular distance + # from each plane. 
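+        # (Concretely: for a unit vector v, zeroing one coordinate and
+        # renormalizing yields its projection onto that equator's plane, and
+        # the inner product <v, proj> is the cosine of v's angular distance
+        # from that equator.)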
+ n_rots = len(sphere_grid) + angular_dists = np.zeros((n_rots, 3), dtype=sphere_grid.dtype) + + # For each grid point get the distance from the z, y, and x-axis equators. + for i in range(3): + proj_along_axis = sphere_grid.copy() + proj_along_axis[:, 2 - i] = 0 + proj_along_axis /= np.linalg.norm(proj_along_axis, axis=1)[:, None] + angular_dists[:, i] = np.sum(sphere_grid * proj_along_axis, axis=-1) + + # Mark all views close to an equator. + eq_min_dist = np.cos(eq_filter_angle * np.pi / 180) + n_eqs = np.count_nonzero(angular_dists > eq_min_dist, axis=1) + + # Classify equators. + # 0 -> non-equator view + # 1 -> z equator + # 2 -> y equator + # 3 -> x equator + # 4 -> z top view + # 5 -> y top view + # 6 -> x top view + eq_class = np.zeros(n_rots) + + # Grid points which are equator points with respect to 2 equators are considered top views. + # For example, a grid point that is close to both the x and y equator is a z top view. + top_view_idx = n_eqs > 1 + top_view_class = np.argmin(angular_dists[top_view_idx] > eq_min_dist, axis=1) + eq_class[top_view_idx] = top_view_class + 4 + + # Assign grid points which are equator points with respect to only 1 equator. + eq_view_idx = n_eqs == 1 + eq_view_class = np.argmax(angular_dists[eq_view_idx] > eq_min_dist, axis=1) + eq_class[eq_view_idx] = eq_view_class + 1 + + return eq_class + + @staticmethod + def _generate_inplane_rots(sphere_grid, d_theta): + """ + This function takes projection directions (points on the 2-sphere) and + generates rotation matrices in SO(3). The projection direction + is the 3rd column and columns 1 and 2 span the perpendicular plane. + To properly discretize SO(3), for each projection direction we generate + [2*pi/dtheta] "in-plane" rotations, of the plane + perpendicular to this direction. This is done by generating one rotation + for each direction and then multiplying on the right by a rotation about + the Z-axis by k*dtheta degrees, k=0...2*pi/dtheta-1. + + :param sphere_grid: A set of points on the 2-sphere. + :param d_theta: Resolution for in-plane rotations (in degrees) + :returns: 4D array of rotations of size len(sphere_grid) x n_inplane_rots x 3 x 3. + """ + dtype = sphere_grid.dtype + # Generate one rotation for each point on the sphere. + n_rots = len(sphere_grid) + Ri2 = np.column_stack((-sphere_grid[:, 1], sphere_grid[:, 0], np.zeros(n_rots))) + Ri2 /= np.linalg.norm(Ri2, axis=1)[:, None] + Ri1 = np.cross(Ri2, sphere_grid) + Ri1 /= np.linalg.norm(Ri1, axis=1)[:, None] + + rots_grid = np.zeros((n_rots, 3, 3), dtype=dtype) + rots_grid[:, :, 0] = Ri1 + rots_grid[:, :, 1] = Ri2 + rots_grid[:, :, 2] = sphere_grid + + # Generate in-plane rotations. + d_theta *= np.pi / 180 + # Negative signs to match matlab. + inplane_rots = Rotation.about_axis( + "z", np.arange(0, -2 * np.pi, -d_theta), dtype=dtype + ).matrices + n_inplane_rots = len(inplane_rots) + + # Generate in-plane rotations of rots_grid. + inplane_rotated_grid = np.zeros((n_rots, n_inplane_rots, 3, 3), dtype=dtype) + for i in range(n_rots): + inplane_rotated_grid[i] = rots_grid[i] @ inplane_rots + + return inplane_rotated_grid + + def _generate_commonline_indices(self, cl_angles): + """ + Converts a multi-dimensional stack of pairs of commonline angles in [0, 360) degrees + into a flattened stack of polar Fourier linear indices, with the convention that + each linear index corresponds to an unraveled index in [0, n_theta // 2) x [0, n_theta). + + :param cl_angles: A multi-dimensional stack of commonline angles in degrees, shape (..., 2). 
+ :return: cl_idx, a 1D array of linear indices. + """ + L = self.n_theta + + # Flatten the stack + og_shape = cl_angles.shape + cl_angles = np.reshape(cl_angles, (np.prod(og_shape[:-1]), 2)) + + # Fourier ray index + row_sub = np.round(cl_angles[:, 0] * L / 360).astype("int") % L + col_sub = np.round(cl_angles[:, 1] * L / 360).astype("int") % L + + # Restrict Ri in-plane coordinates to <180 degrees. + is_geq_than_pi = row_sub >= L // 2 + row_sub[is_geq_than_pi] = row_sub[is_geq_than_pi] - L // 2 + col_sub[is_geq_than_pi] = (col_sub[is_geq_than_pi] + (L // 2)) % L + + # Convert to linear indices in 180x360 correlation matrix. + cl_idx = np.ravel_multi_index((row_sub, col_sub), dims=(L // 2, L)) + + return cl_idx diff --git a/src/aspire/abinitio/commonline_sync.py b/src/aspire/abinitio/commonline_sync.py index 5e07181f4a..aae9e4b3e0 100644 --- a/src/aspire/abinitio/commonline_sync.py +++ b/src/aspire/abinitio/commonline_sync.py @@ -24,7 +24,15 @@ class CLSyncVoting(CLOrient3D, SyncVotingMixin): """ def __init__( - self, src, n_rad=None, n_theta=360, max_shift=0.15, shift_step=1, mask=True + self, + src, + n_rad=None, + n_theta=360, + max_shift=0.15, + shift_step=1, + hist_bin_width=3, + full_width=6, + mask=True, ): """ Initialize an object for estimating 3D orientations using synchronization matrix @@ -36,6 +44,10 @@ def __init__( :param max_shift: Determines maximum range for shifts as a proportion of the resolution. Default is 0.15. :param shift_step: Resolution for shift estimation in pixels. Default is 1 pixel. + :param hist_bin_width: Bin width in smoothing histogram (degrees). + :param full_width: Selection width around smoothed histogram peak (degrees). + `adaptive` will attempt to automatically find the smallest number of + `hist_bin_width`s required to find at least one valid image index. :param mask: Option to mask `src.images` with a fuzzy mask (boolean). Default, `True`, applies a mask. """ @@ -45,6 +57,8 @@ def __init__( n_theta=n_theta, max_shift=max_shift, shift_step=shift_step, + hist_bin_width=hist_bin_width, + full_width=full_width, mask=mask, ) self.syncmatrix = None @@ -189,7 +203,7 @@ def _syncmatrix_ij_vote(self, clmatrix, i, j, k_list, n_theta): :return: The (i,j) rotation block of the synchronization matrix """ - good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list) + _, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list) rots = self._rotratio_eulerangle_vec(clmatrix, i, j, good_k, n_theta) diff --git a/src/aspire/abinitio/commonline_sync3n.cu b/src/aspire/abinitio/commonline_sync3n.cu new file mode 100644 index 0000000000..eeaee723b9 --- /dev/null +++ b/src/aspire/abinitio/commonline_sync3n.cu @@ -0,0 +1,415 @@ + +/* from i,j indices to the common index in the N-choose-2 sized array */ +#define PAIR_IDX(N,I,J) ((2*N-I-1)*I/2 + J-I-1) + + +inline void mult_3x3(double *out, double *R1, double *R2) { + /* 3X3 matrices multiplication: out = R1*R2 + * Note, this differs from the MATLAB mult_3x3. 
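+     * Matrices here are stored row-major as length-9 arrays, so element
+     * (i,j) of R lives at R[i*3 + j]; the loops below index accordingly.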
+ */ + + int i,j,k; + + for(i=0; i<3; i++){ + for(j=0; j<3; j++){ + out[i*3 + j] = 0; + for (k=0; k<3; k++){ + out[i*3 + j] += R1[i*3+k] * R2[k*3+j]; + } + } + } +} + +inline void JRJ(double *R, double *A) { + /* multiple 3X3 matrix by J from both sizes: A = JRJ */ + A[0]=R[0]; + A[1]=R[1]; + A[2]=-R[2]; + A[3]=R[3]; + A[4]=R[4]; + A[5]=-R[5]; + A[6]=-R[6]; + A[7]=-R[7]; + A[8]=R[8]; +} + +inline double diff_norm_3x3(const double *R1, const double *R2) { + /* difference 2 matrices and return squared norm: ||R1-R2||^2 */ + int i; + double norm = 0; + for (i=0; i<9; i++) {norm += (R1[i]-R2[i])*(R1[i]-R2[i]);} + return norm; +} + + +extern "C" __global__ +void signs_times_v(int n, double* Rijs, const double* vec, double* new_vec, bool J_weighting) +{ + /* thread index (1d), represents "i" index */ + unsigned int i = blockDim.x * blockIdx.x + threadIdx.x; + + /* no-op when out of bounds */ + if(i >= n) return; + + double c[4]; + unsigned int j; + unsigned int k; + for(k=0;k<4;k++){c[k]=0;} + unsigned long ij, jk, ik; + int best_i; + double best_val; + double s_ij_jk, s_ik_jk, s_ij_ik; + double alt_ij_jk, alt_ij_ik, alt_ik_jk; + + double *Rij, *Rjk, *Rik; + double JRijJ[9], JRjkJ[9], JRikJ[9]; + double tmp[9]; + + int signs_confs[4][3]; + for(int a=0; a<4; a++) { for(k=0; k<3; k++) { signs_confs[a][k]=1; } } + signs_confs[1][0]=-1; signs_confs[1][2]=-1; + signs_confs[2][0]=-1; signs_confs[2][1]=-1; + signs_confs[3][1]=-1; signs_confs[3][2]=-1; + + /* initialize alternatives */ + /* when we find the best J-configuration, we also compare it to the alternative 2nd best one. + * this comparison is done for every pair in the triplete independently. to make sure that the + * alternative is indeed different in relation to the pair, we document the differences between + * the configurations in advance: + * ALTS(:,best_conf,pair) = the two configurations in which J-sync differs from + * best_conf in relation to pair */ + + int ALTS[2][4][3]; + ALTS[0][0][0]=1; ALTS[0][1][0]=0; ALTS[0][2][0]=0; ALTS[0][3][0]=1; + ALTS[1][0][0]=2; ALTS[1][1][0]=3; ALTS[1][2][0]=3; ALTS[1][3][0]=2; + ALTS[0][0][1]=2; ALTS[0][1][1]=2; ALTS[0][2][1]=0; ALTS[0][3][1]=0; + ALTS[1][0][1]=3; ALTS[1][1][1]=3; ALTS[1][2][1]=1; ALTS[1][3][1]=1; + ALTS[0][0][2]=1; ALTS[0][1][2]=0; ALTS[0][2][2]=1; ALTS[0][3][2]=0; + ALTS[1][0][2]=3; ALTS[1][1][2]=2; ALTS[1][2][2]=3; ALTS[1][3][2]=2; + + + for(j=i+1; j< (n - 1); j++){ + ij = PAIR_IDX(n, i, j); + for(k=j+1; k< n; k++){ + ik = PAIR_IDX(n, i, k); + jk = PAIR_IDX(n, j, k); + + /* compute configurations matches scores */ + Rij = Rijs + 9*ij; + Rjk = Rijs + 9*jk; + Rik = Rijs + 9*ik; + + JRJ(Rij, JRijJ); + JRJ(Rjk, JRjkJ); + JRJ(Rik, JRikJ); + + mult_3x3(tmp, Rij, Rjk); + c[0] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, JRijJ, Rjk); + c[1] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, Rij, JRjkJ); + c[2] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, Rij, Rjk); + c[3] = diff_norm_3x3(tmp, JRikJ); + + /* find best match */ + best_i=0; best_val=c[0]; + if (c[1]= n) return; + + double c[4]; + unsigned int j; + unsigned int k; + for(k=0;k<4;k++){c[k]=0;} + unsigned long ij, jk, ik; + int best_i; + double best_val; + double s_ij_jk, s_ik_jk, s_ij_ik; + double alt_ij_jk, alt_ij_ik, alt_ik_jk; + double f_ij_jk, f_ik_jk, f_ij_ik; + + + double *Rij, *Rjk, *Rik; + double JRijJ[9], JRjkJ[9], JRikJ[9]; + double tmp[9]; + + int signs_confs[4][3]; + for(int a=0; a<4; a++) { for(k=0; k<3; k++) { signs_confs[a][k]=1; } } + signs_confs[1][0]=-1; signs_confs[1][2]=-1; + signs_confs[2][0]=-1; 
signs_confs[2][1]=-1; + signs_confs[3][1]=-1; signs_confs[3][2]=-1; + + /* initialize alternatives */ + /* when we find the best J-configuration, we also compare it to the alternative 2nd best one. + * this comparison is done for every pair in the triplete independently. to make sure that the + * alternative is indeed different in relation to the pair, we document the differences between + * the configurations in advance: + * ALTS(:,best_conf,pair) = the two configurations in which J-sync differs from + * best_conf in relation to pair */ + + int ALTS[2][4][3]; + ALTS[0][0][0]=1; ALTS[0][1][0]=0; ALTS[0][2][0]=0; ALTS[0][3][0]=1; + ALTS[1][0][0]=2; ALTS[1][1][0]=3; ALTS[1][2][0]=3; ALTS[1][3][0]=2; + ALTS[0][0][1]=2; ALTS[0][1][1]=2; ALTS[0][2][1]=0; ALTS[0][3][1]=0; + ALTS[1][0][1]=3; ALTS[1][1][1]=3; ALTS[1][2][1]=1; ALTS[1][3][1]=1; + ALTS[0][0][2]=1; ALTS[0][1][2]=0; ALTS[0][2][2]=1; ALTS[0][3][2]=0; + ALTS[1][0][2]=3; ALTS[1][1][2]=2; ALTS[1][2][2]=3; ALTS[1][3][2]=2; + + + for(j=i+1; j< (n - 1); j++){ + ij = PAIR_IDX(n, i, j); + for(k=j+1; k< n; k++){ + ik = PAIR_IDX(n, i, k); + jk = PAIR_IDX(n, j, k); + + /* compute configurations matches scores */ + Rij = Rijs + 9*ij; + Rjk = Rijs + 9*jk; + Rik = Rijs + 9*ik; + + JRJ(Rij, JRijJ); + JRJ(Rjk, JRjkJ); + JRJ(Rik, JRikJ); + + mult_3x3(tmp, Rij, Rjk); + c[0] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, JRijJ, Rjk); + c[1] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, Rij, JRjkJ); + c[2] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, Rij, Rjk); + c[3] = diff_norm_3x3(tmp, JRikJ); + + /* find best match */ + best_i=0; best_val=c[0]; + if (c[1]= n) return; + + double c[4]; + unsigned int j; + unsigned int k; + for(k=0;k<4;k++){c[k]=0;} + unsigned long ij, jk, ik; + int best_i; + double best_val; + double s_ij_jk, s_ik_jk, s_ij_ik; + double alt_ij_jk, alt_ij_ik, alt_ik_jk; + unsigned int l1,l2,l3; + double threshold; + double h = 1. / n_intervals; + + double *Rij, *Rjk, *Rik; + double JRijJ[9], JRjkJ[9], JRikJ[9]; + double tmp[9]; + + /* initialize alternatives */ + /* when we find the best J-configuration, we also compare it to the alternative 2nd best one. + * this comparison is done for every pair in the triplete independently. 
to make sure that the + * alternative is indeed different in relation to the pair, we document the differences between + * the configurations in advance: + * ALTS(:,best_conf,pair) = the two configurations in which J-sync differs from + * best_conf in relation to pair */ + + int ALTS[2][4][3]; + ALTS[0][0][0]=1; ALTS[0][1][0]=0; ALTS[0][2][0]=0; ALTS[0][3][0]=1; + ALTS[1][0][0]=2; ALTS[1][1][0]=3; ALTS[1][2][0]=3; ALTS[1][3][0]=2; + ALTS[0][0][1]=2; ALTS[0][1][1]=2; ALTS[0][2][1]=0; ALTS[0][3][1]=0; + ALTS[1][0][1]=3; ALTS[1][1][1]=3; ALTS[1][2][1]=1; ALTS[1][3][1]=1; + ALTS[0][0][2]=1; ALTS[0][1][2]=0; ALTS[0][2][2]=1; ALTS[0][3][2]=0; + ALTS[1][0][2]=3; ALTS[1][1][2]=2; ALTS[1][2][2]=3; ALTS[1][3][2]=2; + + + for(j=i+1; j< (n - 1); j++){ + ij = PAIR_IDX(n, i, j); + for(k=j+1; k< n; k++){ + ik = PAIR_IDX(n, i, k); + jk = PAIR_IDX(n, j, k); + + /* compute configurations matches scores */ + Rij = Rijs + 9*ij; + Rjk = Rijs + 9*jk; + Rik = Rijs + 9*ik; + + JRJ(Rij, JRijJ); + JRJ(Rjk, JRjkJ); + JRJ(Rik, JRikJ); + + mult_3x3(tmp, Rij, Rjk); + c[0] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, JRijJ, Rjk); + c[1] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, Rij, JRjkJ); + c[2] = diff_norm_3x3(tmp, Rik); + + mult_3x3(tmp, Rij, Rjk); + c[3] = diff_norm_3x3(tmp, JRikJ); + + /* find best match */ + best_i=0; best_val=c[0]; + if (c[1] src.n {src.n}." + " Consider reducing if curve fitting is infeasable." + ) + + # Auto configure GPU + self.__gpu_module = None + if not disable_gpu: + try: + import cupy as cp + + if cp.cuda.runtime.getDeviceCount() >= 1: + gpu_id = cp.cuda.runtime.getDevice() + logger.info( + f"cupy and GPU {gpu_id} found by cuda runtime; enabling cupy." + ) + self.__gpu_module = self.__init_cupy_module() + else: + logger.info("GPU not found, defaulting to numpy.") + + except ModuleNotFoundError: + logger.info("cupy not found, defaulting to numpy.") + + ########################################### + # High level algorithm steps # + ########################################### + def estimate_rotations(self): + """ + Estimate rotation matrices. + + :return: Array of rotation matrices, size n_imgx3x3. + """ + + logger.info(f"Estimating relative viewing directions for {self.n_img} images.") + + # Detect a single pair of common-lines between each pair of images + self.build_clmatrix() + + # Initial estimate of viewing directions + # Calculate relative rotations + Rijs0 = self._estimate_all_Rijs(self.clmatrix) + + # Compute and apply global handedness + Rijs = self._global_J_sync(Rijs0) + + # Build sync3n matrix + S = self._construct_sync3n_matrix(Rijs) + + # Optionally compute S weights + W = None + if self.S_weighting is True: + W = self._syncmatrix_weights(Rijs) + + # Yield rotations from S + self.rotations = self._sync3n_S_to_rot(S, W) + + ####################### + # Main Sync3N Methods # + ####################### + def _sync3n_S_to_rot(self, S, W=None, n_eigs=4): + """ + Use eigen decomposition of S to estimate transforms, + then project transforms to nearest rotations. + + :param S: Numpy array representing Synchronization matrix. + :param W: Optional weights array, default `None` is equal weighting of `S`. + :param n_eigs: Optional, number of eigenvalues to compute (min 3). + """ + + # Critical this occurs in double precision + S = S.astype(np.float64, copy=False) + + if n_eigs < 3: + raise ValueError( + f"n_eigs must be greater than 3, default is 4. 
Invoked with {n_eigs}" + ) + + if W is not None: + logger.info("Applying weights to synchronization matrix.") + if not W.shape == (self.n_img, self.n_img): + raise RuntimeError( + f"Shape of W should be {(self.n_img, self.n_img)}." + f" Received {W.shape}." + ) + # Initialize D + # Critical this occurs in double precision + W = W.astype(np.float64, copy=False) + D = np.mean(W, axis=1) + + Dhalf = D + # Compute mask of trouble D values + nulls = np.abs(D) < self._D_null + # Avoid trouble values when exponentiating + Dhalf[~nulls] = Dhalf[~nulls] ** (-0.5) + # Flush trouble values to zero + Dhalf[nulls] = 0 + # expand diagonal + Dhalf = np.diag(Dhalf) + + # Report W Diagnostic + W_normalized = Dhalf**2 @ W + nzidx = np.sum(W_normalized, axis=1) != 0 + err = np.linalg.norm(np.sum(W_normalized[nzidx], axis=1) - self.n_img) + if err > 1e-10: + logger.warning(f"Large Weights Matrix Normalization Error: {err}") + + # Make W of size 3Nx3N + W = np.kron(W, np.ones((3, 3), dtype=self.dtype)) + + # Make Dhalf of size 3Nx3N + Dhalf = np.diag(np.kron(np.diag(Dhalf), np.ones(3, dtype=np.float64))) + + # Apply weights to S + S = Dhalf @ (W * S) @ Dhalf + + # Extract three eigenvectors corresponding to non-zero eigenvalues. + d, v = stable_eigsh(S, n_eigs, which="LM") + + sort_idx = np.argsort(-d) + logger.info( + f"Top {n_eigs} eigenvalues from synchronization voting matrix: {d[sort_idx]}" + ) + + # Only need the top 3 eigen-vectors. + v = v[:, sort_idx[:3]] + + # Cancel symmetrization when using weights W + if W is not None: + # Untill now we used a symmetrized variant of the weighted Sync matrix, + # thus we didn't get the right eigenvectors. to fix that we just need + # to multiply: + v = Dhalf @ v + + # Yield estimated rotations from the eigen-vectors + rotations = v.reshape(self.n_img, 3, 3).transpose(0, 2, 1) + + # Enforce we are returning actual rotations + rotations = nearest_rotations(rotations, allow_reflection=True) + + return rotations.astype(self.dtype) + + def _construct_sync3n_matrix(self, Rij): + """ + Construct sync3n matrix from estimated rotations Rij. + + :param Rij: Numpy array of estimated rotations (all pairs). + :return: Synchronization matrix S, (3*N, 3*N). + """ + + # Initialize S with diag identity blocks + n = self.n_img + S = np.eye(3 * n, dtype=self.dtype).reshape(n, 3, n, 3) + + idx = 0 + for i in range(n): + for j in range(i + 1, n): + # S( (3*i-2):(3*i) , (3*j-2):(3*j) ) = Rij(:,:,idx); % Rij + S[i, :, j, :] = Rij[idx] + # S( (3*j-2):(3*j) , (3*i-2):(3*i) ) = Rij(:,:,idx)'; % Rji = Rij' + S[j, :, i, :] = Rij[idx].T + idx += 1 + + # Convert S shape to 3Nx3N + S = S.reshape(3 * n, 3 * n) + + return S + + def _syncmatrix_weights( + self, + Rijs, + permitted_inconsistency=1.5, + p_domain_limit=0.7, + max_iterations=12, + min_p_permitted=0.04, + ): + """ + Given relative rotations matrix `Rij`, + compute and return probability weights `W` for S. + + Default parameters here were taken from those in the MATLAB + code, with the original author noting they were found + empirically. + + :param permitted_inconsistency: Consistency condition is + `mean(Pij)/permitted_inconsistency < P < + mean(Pij)*permitted_inconsistency`. + :param p_domain_limit: Domain of P is [Pmin,Pmax], with + Pmin=p_domain_limit*Pmax + :param max_iterations: Maximum iterations for P estimation. + :param min_p_permitted: Small value at which to stop + attempting to synchronize P. + :return: Synchronization matrix weights `W`. 
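+
+        As a sketch of the consistency test performed each iteration
+        (numbers hypothetical)::
+
+            P, mean_Pij = 0.5, 0.6
+            too_low = P < mean_Pij / permitted_inconsistency   # 0.5 < 0.4, False
+            too_high = P > mean_Pij * permitted_inconsistency  # 0.5 > 0.9, False
+            inconsistent = too_low | too_high                  # consistent; stop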
+ """ + logger.info("Computing synchronization matrix weights.") + + def _body(prev_too_low, Pmin, Pmax, hist, p_domain_limit=p_domain_limit): + """ + Helper function to run and test triangle_scores. + """ + # Get inistial estimate for Pij + P, sigma, Pij, hist = self._triangle_scores(Rijs, hist, Pmin, Pmax) + + # Check if P and Pij are consistent + mean_Pij = np.mean(Pij) + too_low = P < mean_Pij / permitted_inconsistency + too_high = P > mean_Pij * permitted_inconsistency + inconsistent = too_low | too_high + + # Check trend + if prev_too_low is not None and too_low != prev_too_low: + p_domain_limit = np.sqrt(p_domain_limit) + + # define limits for next P estimation + if too_high: + if P < min_p_permitted: + logger.error( + "Triangles Scores are poorly distributed, whatever small P we force." + ) + + if Pmax is not None: + Pmax = Pmax * p_domain_limit + else: + Pmax = P + + Pmin = Pmax * p_domain_limit + else: # too low + if Pmin is not None: + Pmin = Pmin / p_domain_limit + else: + Pmin = P + + Pmax = Pmin / p_domain_limit + + return inconsistent, Pij, (too_low, Pmin, Pmax, hist) + + # Repeat iteratively until estimations of P & Pij are consistent + i = 0 + res = (None,) * 4 + inconsistent = True + while inconsistent and i < max_iterations: + inconsistent, Pij, res = _body(*res) + i += 1 + + # Pack W + W = np.zeros((self.n_img, self.n_img), dtype=self.dtype) + idx = 0 + for i in range(self.n_img): + for j in range(i + 1, self.n_img): + W[i, j] = Pij[idx] + W[j, i] = Pij[idx] + idx += 1 + + return W + + def _triangle_scores_inner(self, Rijs): + """ + Computes histogram of `triangle scores`. + + Wrapper for cpu/gpu dispatch. + + :param Rijs: nchoose2 by 3 by 3 array of rotations. + :return: Histogram of triangle scores. + """ + + # host/gpu dispatch + if self.__gpu_module: + scores_hist = self._triangle_scores_inner_cupy(Rijs) + else: + scores_hist = self._triangle_scores_inner_host(Rijs) + + return scores_hist + + def _triangle_scores_inner_host(self, Rijs): + """ + See _triangle_scores_inner. + + CPU implementation. 
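+
+        Each triangle edge is scored in the body as::
+
+            s = 1 - sqrt(best_val / alt_val)
+
+        which lies in [0, 1] and approaches 1 when the best J-configuration
+        is unambiguous.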
+ """ + + # The following is adopted from Matlab triangle_scores_mex.c + + # Initialize probability result arrays + scores_hist = np.zeros(self.hist_intervals, dtype=np.uint32) + + c = np.empty((4), dtype=Rijs.dtype) + s = np.empty((3), dtype=Rijs.dtype) + for i in trange(self.n_img - 2, desc="Computing triangle scores"): + for j in range( + i + 1, self.n_img - 1 + ): # check bound (taken from MATLAB mex) + ij = self._pairs_to_linear[i, j] + Rij = Rijs[ij] + for k in range(j + 1, self.n_img): + ik = self._pairs_to_linear[i, k] + jk = self._pairs_to_linear[j, k] + Rik = Rijs[ik] + Rjk = Rijs[jk] + + # Compute conjugated rotats + Rij_J = J_conjugate(Rij) + Rik_J = J_conjugate(Rik) + Rjk_J = J_conjugate(Rjk) + + # Compute R muls and norms + c[0] = np.sum(((Rij @ Rjk) - Rik) ** 2) + c[1] = np.sum(((Rij_J @ Rjk) - Rik) ** 2) + c[2] = np.sum(((Rij @ Rjk_J) - Rik) ** 2) + c[3] = np.sum(((Rij @ Rjk) - Rik_J) ** 2) + + # Find best match + best_i = np.argmin(c) + best_val = c[best_i] + + # For each triangle side, find the best alternative + alt_ij_jk = c[self._ALTS[0][best_i][0]] + if c[self._ALTS[1][best_i][0]] < alt_ij_jk: + alt_ij_jk = c[self._ALTS[1][best_i][0]] + + alt_ik_jk = c[self._ALTS[0][best_i][1]] + if c[self._ALTS[1][best_i][1]] < alt_ik_jk: + alt_ik_jk = c[self._ALTS[1][best_i][1]] + + alt_ij_ik = c[self._ALTS[0][best_i][2]] + if c[self._ALTS[1][best_i][2]] < alt_ij_ik: + alt_ij_ik = c[self._ALTS[1][best_i][2]] + + # Compute scores + s[0] = 1 - np.sqrt(best_val / alt_ij_jk) # s_ij_jk + s[1] = 1 - np.sqrt(best_val / alt_ik_jk) # s_ik_jk + s[2] = 1 - np.sqrt(best_val / alt_ij_ik) # s_ij_ik + + # Update histogram + # Find integer bin [0,self.hist_intervals) + _l1, _l2, _l3 = np.maximum( + np.minimum( + (self.hist_intervals * s).astype(int), # implicit floor + self.hist_intervals - 1, # clamp upper bound + ), + 0, # clamp lower bound + ) + + scores_hist[_l1] += 1 + scores_hist[_l2] += 1 + scores_hist[_l3] += 1 + + return scores_hist + + def _triangle_scores_inner_cupy(self, Rijs): + """ + See _triangle_scores_inner. + + GPU implementation. + """ + + import cupy as cp + + triangle_scores = self.__gpu_module.get_function("triangle_scores_inner") + + Rijs_dev = cp.array(Rijs, dtype=np.float64) + + # This holds integer counts + scores_hist_dev = cp.zeros((self.hist_intervals), dtype=np.uint32) + + # call the kernel + blkszx = 512 + nblkx = (self.n_img + blkszx - 1) // blkszx + triangle_scores( + (nblkx,), + (blkszx,), + ( + self.n_img, + Rijs_dev, + self.hist_intervals, + scores_hist_dev, + ), + ) + + # d2h + scores_hist = scores_hist_dev.get() + + return scores_hist + + def _pairs_probabilities(self, Rijs, P2, A, a, B, b, x0): + """ + This function computes the probability of a pair `ij` having + an observed value of triangles score under two priors. Once + given it has an indicative common line, and again once given + it has an arbitrary common line. + + The probability of the common line to be indicative can then + be derived by Bayes Theorem. + + Wrapper for cpu/gpu dispatch. + + :param Rijs: nchoose2 by 3 by 3 array of rotations. + :param P2: distribution parameter + :param A: distribution parameter + :param a: distribution parameter + :param B: distribution parameter + :param b: distribution parameter + :param x0: Initial guess + :return: (log indicative probabilities, log arbitrary probabilities) + """ + # These param values are passed to C, force doubles. 
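+        # (The CUDA kernel, when dispatched, presumably consumes these as
+        # `double` scalars, hence the explicit float64 cast below.)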
+ params = np.array([P2, A, a, B, b, x0], dtype=np.float64) + + # host/gpu dispatch + if self.__gpu_module: + ln_f_ind, ln_f_arb = self._pairs_probabilities_cupy(Rijs, *params) + else: + ln_f_ind, ln_f_arb = self._pairs_probabilities_host(Rijs, *params) + + return ln_f_ind, ln_f_arb + + def _pairs_probabilities_host(self, Rijs, P2, A, a, B, b, x0): + """ + See _pairs_probabilities. + + CPU implementation. + """ + # The following is adopted from Matlab pairs_probabilities_mex.c `looper` + + # Initialize probability result arrays + ln_f_ind = np.zeros(len(Rijs), dtype=Rijs.dtype) + ln_f_arb = np.zeros(len(Rijs), dtype=Rijs.dtype) + + c = np.empty((4), dtype=Rijs.dtype) + for i in trange(self.n_img - 2, desc="Computing pair probabilities"): + for j in range(i + 1, self.n_img - 1): + ij = self._pairs_to_linear[i, j] + Rij = Rijs[ij] + for k in range(j + 1, self.n_img): + ik = self._pairs_to_linear[i, k] + jk = self._pairs_to_linear[j, k] + Rik = Rijs[ik] + Rjk = Rijs[jk] + + # Compute conjugated rotats + Rij_J = J_conjugate(Rij) + Rik_J = J_conjugate(Rik) + Rjk_J = J_conjugate(Rjk) + + # Compute R muls and norms + c[0] = np.sum(((Rij @ Rjk) - Rik) ** 2) + c[1] = np.sum(((Rij_J @ Rjk) - Rik) ** 2) + c[2] = np.sum(((Rij @ Rjk_J) - Rik) ** 2) + c[3] = np.sum(((Rij @ Rjk) - Rik_J) ** 2) + + # Find best match + best_i = np.argmin(c) + best_val = c[best_i] + + # For each triangle side, find the best alternative + alt_ij_jk = c[self._ALTS[0][best_i][0]] + if c[self._ALTS[1][best_i][0]] < alt_ij_jk: + alt_ij_jk = c[self._ALTS[1][best_i][0]] + alt_ik_jk = c[self._ALTS[0][best_i][1]] + if c[self._ALTS[1][best_i][1]] < alt_ik_jk: + alt_ik_jk = c[self._ALTS[1][best_i][1]] + alt_ij_ik = c[self._ALTS[0][best_i][2]] + if c[self._ALTS[1][best_i][2]] < alt_ij_ik: + alt_ij_ik = c[self._ALTS[1][best_i][2]] + + # Compute scores + s_ij_jk = 1 - np.sqrt(best_val / alt_ij_jk) + s_ik_jk = 1 - np.sqrt(best_val / alt_ik_jk) + s_ij_ik = 1 - np.sqrt(best_val / alt_ij_ik) + + # Update probabilities + # # Probability of pair ij having score given indicicative common line + # P2, B, b, x0, A, a + f_ij_jk = np.log( + P2 + * ( + B + * np.power(1 - s_ij_jk, b) + * np.exp(-b / (1 - x0) * (1 - s_ij_jk)) + ) + + (1 - P2) * A * np.power((1 - s_ij_jk), a) + ) + f_ik_jk = np.log( + P2 + * ( + B + * np.power(1 - s_ik_jk, b) + * np.exp(-b / (1 - x0) * (1 - s_ik_jk)) + ) + + (1 - P2) * A * np.power((1 - s_ik_jk), a) + ) + f_ij_ik = np.log( + P2 + * ( + B + * np.power(1 - s_ij_ik, b) + * np.exp(-b / (1 - x0) * (1 - s_ij_ik)) + ) + + (1 - P2) * A * np.power((1 - s_ij_ik), a) + ) + ln_f_ind[ij] += f_ij_jk + f_ij_ik + ln_f_ind[jk] += f_ij_jk + f_ik_jk + ln_f_ind[ik] += f_ik_jk + f_ij_ik + + # # Probability of pair ij having score given arbitrary common line + f_ij_jk = np.log(A * np.power((1 - s_ij_jk), a)) + f_ik_jk = np.log(A * np.power((1 - s_ik_jk), a)) + f_ij_ik = np.log(A * np.power((1 - s_ij_ik), a)) + ln_f_arb[ij] += f_ij_jk + f_ij_ik + ln_f_arb[jk] += f_ij_jk + f_ik_jk + ln_f_arb[ik] += f_ik_jk + f_ij_ik + + return ln_f_ind, ln_f_arb + + def _pairs_probabilities_cupy(self, Rijs, P2, A, a, B, b, x0): + """ + See _pairs_probabilities. + + GPU implementation. 
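+
+        The two returned log-likelihood vectors are later combined via Bayes'
+        rule in `_triangle_scores`::
+
+            Pij = 1 / (1 + (1 - P) / P * np.exp(ln_f_arb - ln_f_ind))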
+ """ + + import cupy as cp + + pairs_probabilities = self.__gpu_module.get_function("pairs_probabilities") + + Rijs_dev = cp.array(Rijs, dtype=np.float64) + ln_f_ind_dev = cp.zeros((self.n_img * (self.n_img - 1) // 2), dtype=np.float64) + ln_f_arb_dev = cp.zeros((self.n_img * (self.n_img - 1) // 2), dtype=np.float64) + + # call the kernel + blkszx = 512 + nblkx = (self.n_img + blkszx - 1) // blkszx + pairs_probabilities( + (nblkx,), + (blkszx,), + (self.n_img, Rijs_dev, P2, A, a, B, b, x0, ln_f_ind_dev, ln_f_arb_dev), + ) + + # accumulate over thread results + ln_f_arb = ln_f_arb_dev.get().astype(self.dtype, copy=False) + ln_f_ind = ln_f_ind_dev.get().astype(self.dtype, copy=False) + + return ln_f_ind, ln_f_arb + + def _triangle_scores( + self, + Rijs, + scores_hist, + Pmin, + Pmax, + a=2.2, + peak2sigma=2.43e-2, + P=0.5, + b=2.5, + x0=0.78, + ): + """ + Computes `triangle_scores`, attempts to fit curve to + distribution, and uses estimated distribution to compute + `pairs_probabilities`. + + Default parameters here were taken from those in the MATLAB + code, with the original author noting they were found + empirically. + + :param a: distribution parameter + :param peak2sigma: empirical relation between the location of + the peak of the histigram, and the mean error in the + common lines estimations. + :param P: distribution parameter + :param b: distribution parameter + :param x0: Initial guess + :return: Tuple of pairs probabilty Pij and related terms + (P, sigma, Pij, scores_hist) + """ + + Pmin = Pmin or 0 + Pmin = max(Pmin, 0) # Clamp probability to [0,1] + Pmax = Pmax or 1 + Pmax = min(Pmax, 1) # Clamp probability to [0,1] + + if scores_hist is None: + scores_hist = self._triangle_scores_inner(Rijs) + + # Histogram decomposition: P & sigma evaluation + h = 1 / self.hist_intervals + hist_x = np.arange(h / 2, 1, h) + # normalization factor of one component of the histogram + A = ( + (self.n_img * (self.n_img - 1) * (self.n_img - 2) / 2) + / self.hist_intervals + * (a + 1) + ) + # normalization of 2nd component: B = P*N_delta/sum(f), where f is the component formula + # B0 = ( + # P + # * (self.n_img * (self.n_img - 1) * (self.n_img - 2) / 2) + # / np.sum(((1 - hist_x) ** b) * np.exp(-b / (1 - x0) * (1 - hist_x))) + # ) + # P must be in lower and upper bounds or `curve_fit` will error + # This was not the case for MATLAB... + # P0 = np.clip(P, Pmin**3, Pmax**3) + # Note, MATLAB suggests the following, but I feel it is a bug. + # Will discuss with Yoel about the original code's intent. + # np.array([B0, P0, b, x0], dtype=np.float64) + start_values = None + lower_bounds = np.array([0, Pmin**3, 2, 0], dtype=np.float64) + upper_bounds = np.array([np.inf, Pmax**3, np.inf, 1], dtype=np.float64) + + with np.printoptions(precision=2): + logger.info(f"curve_fit lower_bounds:{lower_bounds}") + logger.info(f"curve_fit start_values:{start_values}") + logger.info(f"curve_fit upper_bounds:{upper_bounds}") + + # Fit distribution + def fun(x, B, P, b, x0, A=A, a=a): + """Function to fit. 
x is data vector.""" + return (1 - P) * A * (1 - x) ** a + P * B * (1 - x) ** b * np.exp( + -b / (1 - x0) * (1 - x) + ) + + popt, pcov = curve_fit( + fun, + hist_x.astype(np.float64, copy=False), + scores_hist.astype(np.float64, copy=False), + p0=start_values, + bounds=(lower_bounds, upper_bounds), + method="trf", # MATLAB used method "LAR" with algo "Trust-Region" + ) + B, P, b, x0 = popt + + # Derive P and sigma + P = P ** (1 / 3) + sigma = (1 - x0) / peak2sigma + + logger.info(f"Estimated CL Errors P,STD:\t{100*P:.2f}%\t{sigma:.2f}") + + # Initialize probability computations + # Local histograms analysis + A = a + 1 # distribution 1st component normalization factor + # distribution 2nd component normalization factor + B = B / ( + (self.n_img * (self.n_img - 1) * (self.n_img - 2) / 2) / self.hist_intervals + ) + + # Calculate probabilities + ln_f_ind, ln_f_arb = self._pairs_probabilities(Rijs, P**2, A, a, B, b, x0) + + with warnings.catch_warnings(): + # For large values of (ln_f_arb - ln_f_ind), numpy exponential will overflow. We still + # get the intended result of Pij = 0, so we capture and ignore the overflow warning. + warnings.filterwarnings("ignore", r".*overflow encountered in exp.*") + + Pij = 1 / (1 + (1 - P) / P * np.exp(ln_f_arb - ln_f_ind)) + + # Fix singular output + num_nan = np.sum(np.isnan(Pij)) + if num_nan > 0: + logger.error( + f"NaN probabilities occurred {num_nan} times out of {np.size(Pij)}. Setting NaNs to zero." + ) + Pij = np.nan_to_num(Pij) + + logger.info( + f"Common lines probabilities to be indicative Pij={100*np.mean(Pij):.2f}%" + ) + + return P, sigma, Pij, scores_hist + + ########################################### + # Primary Methods # + ########################################### + + def _global_J_sync(self, Rijs): + """ + Apply global J-synchronization. + + Given all pairs of estimated rotation matrices `Rijs` with + arbitrary handedness (J conjugation), attempt to detect and + conjugate entries of `Rijs` such that all rotations have same + handedness. + + :param Rijs: Array of all pairs of rotation matrices + :return: Array of all pairs of J synchronized rotation matrices + """ + + # Determine relative handedness of Rijs. + sign_ij_J = self._J_sync_power_method(Rijs) + + # Synchronize Rijs + logger.info("Applying global handedness synchronization.") + mask = sign_ij_J == -1 + Rijs[mask] = J_conjugate(Rijs[mask]) + + return Rijs + + def _estimate_all_Rijs(self, clmatrix): + """ + Estimate Rijs using the voting method. + + :param clmatrix: Common lines matrix + :return: Estimated rotations + """ + n_img = self.n_img + n_theta = self.n_theta + Rijs = np.zeros((len(self._pairs), 3, 3)) + + for idx, (i, j) in enumerate(tqdm(self._pairs, desc="Estimate Rijs")): + Rijs[idx] = self._syncmatrix_ij_vote_3n( + clmatrix, i, j, np.arange(n_img), n_theta + ) + + return Rijs + + def _syncmatrix_ij_vote_3n(self, clmatrix, i, j, k_list, n_theta): + """ + Compute the (i,j) rotation block of the synchronization matrix using voting method + + Given the common lines matrix `clmatrix`, a list of images specified in k_list + and the number of common lines n_theta, find the (i, j) rotation block Rij. 
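+
+        As a sketch of the assembly performed in the body below::
+
+            angles = [clmatrix[i, j] * 2 * pi / n_theta + pi / 2,
+                      mean(alphas),
+                      -pi / 2 - clmatrix[j, i] * 2 * pi / n_theta]
+            rot = Rotation.from_euler(angles).matrices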
+ + :param clmatrix: The common lines matrix + :param i: The i image + :param j: The j image + :param k_list: The list of images for the third image for voting algorithm + :param n_theta: The number of points in the theta direction (common lines) + :return: The (i,j) rotation block of the synchronization matrix + """ + alphas, good_k = self._vote_ij(clmatrix, n_theta, i, j, k_list, sync=True) + + angles = np.zeros(3) + + if alphas is not None: + angles[0] = clmatrix[i, j] * 2 * np.pi / n_theta + np.pi / 2 + angles[1] = np.mean(alphas) + angles[2] = -np.pi / 2 - clmatrix[j, i] * 2 * np.pi / n_theta + rot = Rotation.from_euler(angles).matrices + + else: + # This is for the case that images i and j correspond to the same + # viewing direction and differ only by in-plane rotation. + # We set to zero as in the Matlab code. + rot = np.zeros((3, 3)) + + return rot + + ####################################### + # Secondary Methods for Global J Sync # + ####################################### + + def _J_sync_power_method(self, Rijs): + """ + Calculate the leading eigenvector of the J-synchronization matrix + using the power method. + + As the J-synchronization matrix is of size (n-choose-2)x(n-choose-2), we + use the power method to compute the eigenvalues and eigenvectors, + while constructing the matrix on-the-fly. + + :param Rijs: (n-choose-2)x3x3 array of estimates of relative orientation matrices. + + :return: An array of length n-choose-2 consisting of 1 or -1, where the sign of the + i'th entry indicates whether the i'th relative orientation matrix will be J-conjugated. + """ + + logger.info( + "Initiating power method to estimate J-synchronization matrix eigenvector." + ) + # Set power method tolerance and maximum iterations. + epsilon = self.epsilon + max_iters = self.max_iters + + # Initialize candidate eigenvectors + n_Rijs = Rijs.shape[0] + vec = rand(n_Rijs, seed=self.seed) + vec = vec / norm(vec) + residual = 1 + itr = 0 + + # Todo + # I don't like that epsilon>1 (residual) returns signs of random vector + # maybe force to run once? or return vec as zeros in that case? + # Seems unintended, but easy to do. + + # Power method iterations + while itr < max_iters and residual > epsilon: + itr += 1 + # Todo, this code code actually needs double precision for accuracy... forcing. + vec_new = self._signs_times_v(Rijs, vec).astype(np.float64, copy=False) + vec_new = vec_new / norm(vec_new) + residual = norm(vec_new - vec) + vec = vec_new + logger.info( + f"Iteration {itr}, residual {round(residual, 5)} (target {epsilon})" + ) + + # We need only the signs of the eigenvector + J_sync = np.sign(vec) + J_sync = np.sign(J_sync[0]) * J_sync # Stabilize J_sync + + return J_sync + + def _signs_times_v(self, Rijs, vec): + """ + Multiplication of the J-synchronization matrix by a candidate eigenvector `vec` + + Wrapper for cpu/gpu dispatch. + + :param Rijs: An n-choose-2x3x3 array of estimates of relative rotations + :param vec: The current candidate eigenvector of length n-choose-2 from the power method. + :return: New candidate eigenvector. + """ + # host/gpu dispatch + if self.__gpu_module: + new_vec = self._signs_times_v_cupy(Rijs, vec) + else: + new_vec = self._signs_times_v_host(Rijs, vec) + + return new_vec.astype(vec.dtype, copy=False) + + def _signs_times_v_host(self, Rijs, vec): + """ + See `_signs_times_v`. + + CPU implementation. 
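+
+        The four J-configurations map to edge signs via `_signs_confs`::
+
+            conf 0: (+1, +1, +1)    conf 1: (-1, +1, -1)
+            conf 2: (-1, -1, +1)    conf 3: (+1, -1, -1)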
+ """ + + new_vec = np.zeros_like(vec) + + _signs_confs = np.array( + [[1, 1, 1], [-1, 1, -1], [-1, -1, 1], [1, -1, -1]], dtype=int + ) + + c = np.empty((4)) + desc = "Computing signs_times_v" + if self.J_weighting: + desc += " with J_weighting" + for i in trange(self.n_img - 2, desc=desc): + for j in range( + i + 1, self.n_img - 1 + ): # check bound (taken from MATLAB mex) + ij = self._pairs_to_linear[i, j] + Rij = Rijs[ij] + for k in range(j + 1, self.n_img): + ik = self._pairs_to_linear[i, k] + jk = self._pairs_to_linear[j, k] + Rik = Rijs[ik] + Rjk = Rijs[jk] + + # Compute conjugated rotats + Rij_J = J_conjugate(Rij) + Rik_J = J_conjugate(Rik) + Rjk_J = J_conjugate(Rjk) + + # Compute R muls and norms + c[0] = np.sum(((Rij @ Rjk) - Rik) ** 2) + c[1] = np.sum(((Rij_J @ Rjk) - Rik) ** 2) + c[2] = np.sum(((Rij @ Rjk_J) - Rik) ** 2) + c[3] = np.sum(((Rij @ Rjk) - Rik_J) ** 2) + + # Find best match + best_i = np.argmin(c) + best_val = c[best_i] + + # MATLAB: scores_as_entries == 0 + s_ij_jk = _signs_confs[best_i][0] + s_ik_jk = _signs_confs[best_i][1] + s_ij_ik = _signs_confs[best_i][2] + + # Note there was a third J_weighting option (2) in MATLAB, + # but it was not exposed at top level. + if self.J_weighting: + # MATLAB: scores_as_entries == 1 + # For each triangle side, find the best alternative + alt_ij_jk = c[self._ALTS[0][best_i][0]] + if c[self._ALTS[1][best_i][0]] < alt_ij_jk: + alt_ij_jk = c[self._ALTS[1][best_i][0]] + + alt_ik_jk = c[self._ALTS[0][best_i][1]] + if c[self._ALTS[1][best_i][1]] < alt_ik_jk: + alt_ik_jk = c[self._ALTS[1][best_i][1]] + + alt_ij_ik = c[self._ALTS[0][best_i][2]] + if c[self._ALTS[1][best_i][2]] < alt_ij_ik: + alt_ij_ik = c[self._ALTS[1][best_i][2]] + + # Compute scores + s_ij_jk *= 1 - np.sqrt(best_val / alt_ij_jk) + s_ik_jk *= 1 - np.sqrt(best_val / alt_ik_jk) + s_ij_ik *= 1 - np.sqrt(best_val / alt_ij_ik) + + # Update vector entries + new_vec[ij] += s_ij_jk * vec[jk] + s_ij_ik * vec[ik] + new_vec[jk] += s_ij_jk * vec[ij] + s_ik_jk * vec[ik] + new_vec[ik] += s_ij_ik * vec[ij] + s_ik_jk * vec[jk] + + return new_vec + + def _signs_times_v_cupy(self, Rijs, vec): + """ + See `_signs_times_v`. + + CPU implementation. + """ + import cupy as cp + + signs_times_v = self.__gpu_module.get_function("signs_times_v") + + Rijs_dev = cp.array(Rijs, dtype=np.float64) + vec_dev = cp.array(vec, dtype=np.float64) + new_vec_dev = cp.zeros((vec.shape[0]), dtype=np.float64) + + # call the kernel + blkszx = 512 + nblkx = (self.n_img + blkszx - 1) // blkszx + signs_times_v( + (nblkx,), + (blkszx,), + (self.n_img, Rijs_dev, vec_dev, new_vec_dev, self.J_weighting), + ) + + # dtoh + new_vec = new_vec_dev.get().astype(vec.dtype, copy=False) + + return new_vec + + @staticmethod + def __init_cupy_module(): + """ + Private utility method to read in CUDA source and return as + compiled CUPY module. + """ + + import cupy as cp + + # Read in contents of file + fp = os.path.join(os.path.dirname(__file__), "commonline_sync3n.cu") + with open(fp, "r") as fh: + module_code = fh.read() + + # CUPY compile the CUDA code + return cp.RawModule(code=module_code) diff --git a/src/aspire/abinitio/sync_voting.py b/src/aspire/abinitio/sync_voting.py index abc11ef6e1..fb626a8a91 100644 --- a/src/aspire/abinitio/sync_voting.py +++ b/src/aspire/abinitio/sync_voting.py @@ -37,11 +37,14 @@ def _rotratio_eulerangle_vec(self, clmatrix, i, j, good_k, n_theta): # cl_diff2 is for the angle on C2 created by its intersection with C1 and C3. 
# cl_diff3 is for the angle on C3 created by its intersection with C2 and C1. cl_diff1 = clmatrix[i, good_k] - clmatrix[i, j] # for theta1 - cl_diff2 = clmatrix[j, good_k] - clmatrix[j, i] # for - theta2 + cl_diff2 = clmatrix[j, good_k] - clmatrix[j, i] # for theta2 cl_diff3 = clmatrix[good_k, j] - clmatrix[good_k, i] # for theta3 # Calculate the cos values of rotation angles between i an j images for good k images - c_alpha, good_idx = self._get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta) + c_alpha, good_idx = self._get_cos_phis( + cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False + ) + if len(c_alpha) == 0: return None alpha = np.arccos(c_alpha) @@ -55,7 +58,7 @@ def _rotratio_eulerangle_vec(self, clmatrix, i, j, good_k, n_theta): return r[good_idx, :, :] - def _vote_ij(self, clmatrix, n_theta, i, j, k_list): + def _vote_ij(self, clmatrix, n_theta, i, j, k_list, sync=False): """ Apply the voting algorithm for images i and j. @@ -68,12 +71,14 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list): :param i: The i image :param j: The j image :param k_list: The list of images for the third image for voting algorithm - :return: good_k, the list of all third images in the peak of the histogram - corresponding to the pair of images (i,j) + :param sync: Perform 180 degree ambiguity synchronization. + :return: (alpha, good_k), angles and list of all third images + in the peak of the histogram corresponding to the pair of + images (i,j) """ if i == j or clmatrix[i, j] == -1: - return [] + return None, [] # Some of the entries in clmatrix may be zero if we cleared # them due to small correlation, or if for each image @@ -102,10 +107,13 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list): # cl_diff2 is for the angle on C2 created by its intersection with C1 and C3. # cl_diff3 is for the angle on C3 created by its intersection with C2 and C1. cl_diff1 = cl_idx13 - cl_idx12 - cl_diff2 = cl_idx21 - cl_idx23 + cl_diff2 = cl_idx23 - cl_idx21 cl_diff3 = cl_idx32 - cl_idx31 + # Calculate the cos values of rotation angles between i an j images for good k images - cos_phi2, good_idx = self._get_cos_phis(cl_diff1, cl_diff2, cl_diff3, n_theta) + cos_phi2, good_idx = self._get_cos_phis( + cl_diff1, cl_diff2, cl_diff3, n_theta, sync=sync + ) if np.any(np.abs(cos_phi2) - 1 > 1e-12): logger.warning( @@ -121,13 +129,15 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list): inds = k_list[good_idx] if phis.shape[0] == 0: - return [] + return None, [] # Parameters used to compute the smoothed angle histogram. - ntics = 60 - angles_grid = np.linspace(0, 180, ntics, True) + ntics = int(180 / self.hist_bin_width) + angles_grid = np.linspace(0, 180, ntics + 1, True) + # Get angles between images i and j for computing the histogram angles = np.arccos(phis[:]) * 180 / np.pi + # Angles that are up to 10 degrees apart are considered # similar. This sigma ensures that the width of the density # estimation kernel is roughly 10 degrees. 
For 15 degrees, the @@ -135,14 +145,8 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list): sigma = 3.0 # Compute the histogram of the angles between images i and j - squared_values = np.add.outer(np.square(angles), np.square(angles_grid)) - angles_hist = np.sum( - np.exp( - (2 * np.multiply.outer(angles, angles_grid) - squared_values) - / (2 * sigma**2) - ), - 0, - ) + angles_distances = angles_grid[None, :] - angles[:, None] + angles_hist = np.sum(np.exp(-(angles_distances**2) / (2 * sigma**2)), axis=0) # We assume that at the location of the peak we get the true angle # between images i and j. Find all third images k, that induce an @@ -151,11 +155,29 @@ def _vote_ij(self, clmatrix, n_theta, i, j, k_list): # tics, since the peak might move a little bit due to wrong k images # that accidentally fall near the peak. peak_idx = angles_hist.argmax() - idx = np.abs(angles - angles_grid[peak_idx]) < 360 / ntics + + if str(self.full_width).lower() == "adaptive": + # Adaptive width (MATLAB) + # Look for the estimations in the peak of the histogram + w_theta_needed = 0 + idx = [] + while sum(idx) == 0: + w_theta_needed += self.hist_bin_width # widen peak as needed + idx = np.abs(angles - angles_grid[peak_idx]) < w_theta_needed + if w_theta_needed > self.hist_bin_width: + logger.info( + f"Adaptive width {w_theta_needed} required for ({i},{j}), found {sum(idx)} indices." + ) + else: + # Fixed width + idx = np.abs(angles - angles_grid[peak_idx]) < self.full_width + good_k = inds[idx] - return good_k.astype("int") + alpha = np.arccos(phis[idx]) - def _get_cos_phis(self, cl_diff1, cl_diff2, cl_diff3, n_theta): + return alpha, good_k.astype("int") + + def _get_cos_phis(self, cl_diff1, cl_diff2, cl_diff3, n_theta, sync=False): """ Calculate cos values of rotation angles between i and j images @@ -179,6 +201,7 @@ def _get_cos_phis(self, cl_diff1, cl_diff2, cl_diff3, n_theta): :param cl_diff3: Difference of common line indices on C3 created by its intersection with C2 and C1 :param n_theta: The number of points in the theta direction (common lines) + :param sync: Perform 180 degree ambiguity synchronization. :return: cos values of rotation angles between i and j images and indices for good k """ @@ -224,7 +247,38 @@ def _get_cos_phis(self, cl_diff1, cl_diff2, cl_diff3, n_theta): good_idx = np.nonzero(cond > 1e-5)[0] # Calculated cos values of angle between i and j images - cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( - np.sin(theta1[good_idx]) * np.sin(theta2[good_idx]) - ) + if sync: + # MATLAB + cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( + np.sqrt(1 - c1[good_idx] ** 2) * np.sqrt(1 - c2[good_idx] ** 2) + ) + + # Some synchronization must be applied when common line is + # out by 180 degrees. + # Here fix the angles between c_ij(c_ji) and c_ik(c_jk) to be smaller than pi/2, + # otherwise there will be an ambiguity between alpha and pi-alpha. 
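+            # (Since cos(pi - alpha) = -cos(alpha), the two cases differ only
+            # by sign; the flip applied to `cos_phi2[align180]` below resolves
+            # exactly that ambiguity.)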
+ TOL_idx = 1e-12 + + # Select only good_idx + theta1 = theta1[good_idx] + theta2 = theta2[good_idx] + theta3 = theta3[good_idx] + + # Check sync conditions + ind1 = (theta1 > (np.pi + TOL_idx)) | ( + (theta1 < -TOL_idx) & (theta1 > -np.pi) + ) + ind2 = (theta2 > (np.pi + TOL_idx)) | ( + (theta2 < -TOL_idx) & (theta2 > -np.pi) + ) + align180 = (ind1 & ~ind2) | (~ind1 & ind2) + + # Apply sync + cos_phi2[align180] = -cos_phi2[align180] + else: + # Python + cos_phi2 = (c3[good_idx] - c1[good_idx] * c2[good_idx]) / ( + np.sin(theta1[good_idx]) * np.sin(theta2[good_idx]) + ) + return cos_phi2, good_idx diff --git a/src/aspire/basis/basis_utils.py b/src/aspire/basis/basis_utils.py index fe599e9fdc..c32e73fbf3 100644 --- a/src/aspire/basis/basis_utils.py +++ b/src/aspire/basis/basis_utils.py @@ -4,13 +4,11 @@ """ import logging -import warnings import numpy as np from numpy import diff, exp, log, pi from numpy.polynomial.legendre import leggauss -from pyshtools.expand import spharm_lm -from scipy.special import jn, jv, sph_harm +from scipy.special import jn, jv from aspire.utils import grid_2d, grid_3d @@ -158,6 +156,33 @@ def norm_assoc_legendre(j, m, x): return px +def sph_harm(j, m, theta, phi): + """ + Compute spherical harmonics. + + Note call signature convention may be different from other packages. + + :param m: Order |m| <= j + :param j: Harmonic degree, j>=0 + :param theta: latitude coordinate [0, pi] + :param phi: longitude coordinate [0, 2*pi] + :return: Complex array of evaluated spherical harmonics. + """ + + # Compute sph_harm for positive `abs(m)` + y = ( + norm_assoc_legendre(j, abs(m), np.cos(theta)) + * np.exp(1j * abs(m) * phi) + * np.sqrt(0.5 / np.pi) + ) + + # Use identity for negative `m` + if m < 0: + y = (-1) ** (m % 2) * np.conj(y) + + return y + + def real_sph_harmonic(j, m, theta, phi): """ Evaluate a real spherical harmonic @@ -172,28 +197,8 @@ def real_sph_harmonic(j, m, theta, phi): """ abs_m = abs(m) - # The `scipy` sph_harm implementation is much faster, - # but incorrectly returns NaN for high orders. - # For higher order use `pyshtools`. - if j < 86: - y = sph_harm(abs_m, j, phi, theta) - else: - warnings.warn( - "Computing higher order spherical harmonics is slow." - " Consider using `FFBBasis3D` or decreasing volume size.", - stacklevel=1, - ) - - y = spharm_lm( - j, - abs_m, - theta, - phi, - kind="complex", - degrees=False, - csphase=-1, - normalization="ortho", - ) + # Note the calling convention here may not match other `sph_harm` packages + y = sph_harm(j, abs_m, theta, phi) if m < 0: y = np.sqrt(2) * np.imag(y) diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 5a5c7c3f27..8d46e8419c 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -58,6 +58,16 @@ def _build(self): # precompute the basis functions in 2D grids self._precomp = self._precomp() + # include the normalization factor of angular part into radial part + self.radial_norm = xp.asarray(self._precomp["radial"]) / xp.asarray( + np.expand_dims(self.angular_norms, 1) + ) + + # precompute weighted nodes + self.gl_weighted_nodes = xp.asarray(self._precomp["gl_weights"]) * xp.asarray( + self._precomp["gl_nodes"] + ) + def _precomp(self): """ Precomute the basis functions on a polar Fourier grid @@ -105,6 +115,7 @@ def _evaluate(self, v): coordinate basis. This is Image instance with resolution of `self.sz` and the first dimension correspond to remaining dimension of `v`. 
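+        (`xp` dispatches to numpy or cupy per the configured numeric backend,
+        so this transform can run end-to-end on the GPU.)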
""" + v = xp.asarray(v) sz_roll = v.shape[:-1] v = v.reshape(-1, self.count) @@ -112,25 +123,23 @@ def _evaluate(self, v): n_data = v.shape[0] # get information on polar grids from precomputed data - n_theta = np.size(self._precomp["freqs"], 2) - n_r = np.size(self._precomp["freqs"], 1) + n_theta = self._precomp["freqs"].shape[2] + n_r = self._precomp["freqs"].shape[1] # go through each basis function and find corresponding coefficient - pf = np.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype)) + pf = xp.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype)) ind = 0 idx = ind + np.arange(self.k_max[0], dtype=int) - # include the normalization factor of angular part into radial part - radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1) - pf[:, 0, :] = v[:, self._zero_angular_inds] @ radial_norm[idx] - ind = ind + np.size(idx) + pf[:, 0, :] = v[:, self._zero_angular_inds] @ self.radial_norm[idx] + ind = ind + idx.size ind_pos = ind for ell in range(1, self.ell_max + 1): - idx = ind + np.arange(self.k_max[ell], dtype=int) + idx = ind + xp.arange(self.k_max[ell], dtype=int) idx_pos = ind_pos + np.arange(self.k_max[ell], dtype=int) idx_neg = idx_pos + self.k_max[ell] @@ -139,7 +148,7 @@ def _evaluate(self, v): if np.mod(ell, 2) == 1: v_ell = 1j * v_ell - pf_ell = v_ell @ radial_norm[idx] + pf_ell = v_ell @ self.radial_norm[idx] pf[:, ell, :] = pf_ell if np.mod(ell, 2) == 0: @@ -147,22 +156,17 @@ def _evaluate(self, v): else: pf[:, 2 * n_theta - ell, :] = -pf_ell.conjugate() - ind = ind + np.size(idx) + ind = ind + idx.size ind_pos = ind_pos + 2 * self.k_max[ell] # 1D inverse FFT in the degree of polar angle - pf = 2 * pi * xp.asnumpy(fft.ifft(xp.asarray(pf), axis=1)) + pf = 2 * xp.pi * fft.ifft(pf, axis=1) # Only need "positive" frequencies. - hsize = int(np.size(pf, 1) / 2) + hsize = int(pf.shape[1] / 2) pf = pf[:, 0:hsize, :] - - for i_r in range(0, n_r): - pf[..., i_r] = pf[..., i_r] * ( - self._precomp["gl_weights"][i_r] * self._precomp["gl_nodes"][i_r] - ) - - pf = np.reshape(pf, (n_data, n_r * n_theta)) + pf *= self.gl_weighted_nodes[None, None, :] + pf = pf.reshape(n_data, n_r * n_theta) # perform inverse non-uniformly FFT transform back to 2D coordinate basis freqs = m_reshape(self._precomp["freqs"], (2, n_r * n_theta)) @@ -172,7 +176,7 @@ def _evaluate(self, v): # Return X as Image instance with the last two dimensions as *self.sz x = x.reshape((*sz_roll, *self.sz)) - return x + return xp.asnumpy(x) def _evaluate_t(self, x): """ @@ -193,56 +197,51 @@ def _evaluate_t(self, x): n_images = x.shape[0] # resamping x in a polar Fourier gird using nonuniform discrete Fourier transform - pf = nufft(x, 2 * pi * freqs) - pf = np.reshape(pf, (n_images, n_r, n_theta)) + pf = nufft(xp.asarray(x), 2 * pi * freqs) + pf = pf.reshape(n_images, n_r, n_theta) # Recover "negative" frequencies from "positive" half plane. - pf = np.concatenate((pf, pf.conjugate()), axis=2) + pf = xp.concatenate((pf, pf.conjugate()), axis=2) # evaluate radial integral using the Gauss-Legendre quadrature rule - for i_r in range(0, n_r): - pf[:, i_r, :] = pf[:, i_r, :] * ( - self._precomp["gl_weights"][i_r] * self._precomp["gl_nodes"][i_r] - ) + pf = pf * self.gl_weighted_nodes[None, :, None] # 1D FFT on the angular dimension for each concentric circle - pf = 2 * pi / (2 * n_theta) * xp.asnumpy(fft.fft(xp.asarray(pf))) + pf = 2 * xp.pi / (2 * n_theta) * fft.fft(pf) # This only makes it easier to slice the array later. 
- v = np.zeros((n_images, self.count), dtype=x.dtype) + v = xp.zeros((n_images, self.count), dtype=x.dtype) # go through each basis function and find the corresponding coefficient ind = 0 - idx = ind + np.arange(self.k_max[0]) + idx = ind + xp.arange(self.k_max[0]) - # include the normalization factor of angular part into radial part - radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1) - v[:, self._zero_angular_inds] = pf[:, :, 0].real @ radial_norm[idx].T - ind = ind + np.size(idx) + v[:, self._zero_angular_inds] = pf[:, :, 0].real @ self.radial_norm[idx].T + ind = ind + idx.size ind_pos = ind for ell in range(1, self.ell_max + 1): - idx = ind + np.arange(self.k_max[ell]) - idx_pos = ind_pos + np.arange(self.k_max[ell]) + idx = ind + xp.arange(self.k_max[ell]) + idx_pos = ind_pos + xp.arange(self.k_max[ell]) idx_neg = idx_pos + self.k_max[ell] - v_ell = pf[:, :, ell] @ radial_norm[idx].T + v_ell = pf[:, :, ell] @ self.radial_norm[idx].T if np.mod(ell, 2) == 0: - v_pos = np.real(v_ell) - v_neg = -np.imag(v_ell) + v_pos = v_ell.real + v_neg = -v_ell.imag else: - v_pos = np.imag(v_ell) - v_neg = np.real(v_ell) + v_pos = v_ell.imag + v_neg = v_ell.real v[:, idx_pos] = v_pos v[:, idx_neg] = v_neg - ind = ind + np.size(idx) + ind = ind + idx.size ind_pos = ind_pos + 2 * self.k_max[ell] - return v + return xp.asnumpy(v) def filter_to_basis_mat(self, f, **kwargs): """ diff --git a/src/aspire/basis/ffb_3d.py b/src/aspire/basis/ffb_3d.py index 6362a9a703..7f0821b99a 100644 --- a/src/aspire/basis/ffb_3d.py +++ b/src/aspire/basis/ffb_3d.py @@ -1,11 +1,11 @@ import logging import numpy as np -from numpy import pi from aspire.basis import FBBasis3D from aspire.basis.basis_utils import lgwt, norm_assoc_legendre, sph_bessel from aspire.nufft import anufft, nufft +from aspire.numeric import xp from aspire.utils.matlab_compat import m_flatten, m_reshape logger = logging.getLogger(__name__) @@ -60,26 +60,29 @@ def _precomp(self): r, wt_r = lgwt(n_r, 0.0, self.kcut, dtype=self.dtype) z, wt_z = lgwt(n_phi, -1, 1, dtype=self.dtype) - r = m_reshape(r, (n_r, 1)) - wt_r = m_reshape(wt_r, (n_r, 1)) - z = m_reshape(z, (n_phi, 1)) - wt_z = m_reshape(wt_z, (n_phi, 1)) - phi = np.arccos(z) + r = m_reshape(xp.asarray(r), (n_r, 1)) + rh = xp.asnumpy(r) + wt_r = m_reshape(xp.asarray(wt_r), (n_r, 1)) + z = m_reshape(xp.asarray(z), (n_phi, 1)) + wt_z = m_reshape(xp.asarray(wt_z), (n_phi, 1)) + phi = xp.arccos(z) wt_phi = wt_z - theta = 2 * pi * np.arange(n_theta, dtype=self.dtype).T / (2 * n_theta) + theta = 2 * xp.pi * xp.arange(n_theta, dtype=self.dtype).T / (2 * n_theta) theta = m_reshape(theta, (n_theta, 1)) # evaluate basis function in the radial dimension - radial_wtd = np.zeros( + radial_wtd = xp.zeros( shape=(n_r, np.max(self.k_max), self.ell_max + 1), dtype=self.dtype ) for ell in range(0, self.ell_max + 1): k_max_ell = self.k_max[ell] - rmat = r * self.r0[ell][0:k_max_ell].T / self.kcut - radial_ell = np.zeros_like(rmat) + rmat = rh * self.r0[ell][0:k_max_ell].T / self.kcut # host + radial_ell = xp.zeros_like(rmat) for ik in range(0, k_max_ell): - radial_ell[:, ik] = sph_bessel(ell, rmat[:, ik]) - nrm = np.abs(sph_bessel(ell + 1, self.r0[ell][0:k_max_ell].T) / 4) + radial_ell[:, ik] = xp.asarray(sph_bessel(ell, rmat[:, ik])) + nrm = xp.abs( + xp.asarray(sph_bessel(ell + 1, self.r0[ell][0:k_max_ell].T)) / 4 + ) radial_ell = radial_ell / nrm radial_ell_wtd = r**2 * wt_r * radial_ell radial_wtd[:, 0:k_max_ell, ell] = radial_ell_wtd @@ -94,14 +97,14 @@ def _precomp(self): - 
np.mod(self.ell_max, 2) * np.mod(m, 2) ) n_odd_ell = int(self.ell_max - m + 1 - n_even_ell) - phi_wtd_m_even = np.zeros((n_phi, n_even_ell), dtype=phi.dtype) - phi_wtd_m_odd = np.zeros((n_phi, n_odd_ell), dtype=phi.dtype) + phi_wtd_m_even = xp.zeros((n_phi, n_even_ell), dtype=phi.dtype) + phi_wtd_m_odd = xp.zeros((n_phi, n_odd_ell), dtype=phi.dtype) ind_even = 0 ind_odd = 0 for ell in range(m, self.ell_max + 1): - phi_m_ell = norm_assoc_legendre(ell, m, z) - nrm_inv = np.sqrt(0.5 / pi) + phi_m_ell = xp.asarray(norm_assoc_legendre(ell, m, z)) + nrm_inv = np.sqrt(0.5 / np.pi) phi_m_ell = nrm_inv * phi_m_ell phi_wtd_m_ell = wt_phi * phi_m_ell if np.mod(ell, 2) == 0: @@ -115,32 +118,32 @@ def _precomp(self): ang_phi_wtd_odd.append(phi_wtd_m_odd) # evaluate basis function in the theta dimension - ang_theta = np.zeros((n_theta, 2 * self.ell_max + 1), dtype=theta.dtype) + ang_theta = xp.zeros((n_theta, 2 * self.ell_max + 1), dtype=theta.dtype) - ang_theta[:, 0 : self.ell_max] = np.sqrt(2) * np.sin( - theta @ m_reshape(np.arange(self.ell_max, 0, -1), (1, self.ell_max)) + ang_theta[:, 0 : self.ell_max] = np.sqrt(2) * xp.sin( + theta @ m_reshape(xp.arange(self.ell_max, 0, -1), (1, self.ell_max)) ) - ang_theta[:, self.ell_max] = np.ones(n_theta, dtype=theta.dtype) - ang_theta[:, self.ell_max + 1 : 2 * self.ell_max + 1] = np.sqrt(2) * np.cos( - theta @ m_reshape(np.arange(1, self.ell_max + 1), (1, self.ell_max)) + ang_theta[:, self.ell_max] = xp.ones(n_theta, dtype=theta.dtype) + ang_theta[:, self.ell_max + 1 : 2 * self.ell_max + 1] = np.sqrt(2) * xp.cos( + theta @ m_reshape(xp.arange(1, self.ell_max + 1), (1, self.ell_max)) ) - ang_theta_wtd = (2 * pi / n_theta) * ang_theta + ang_theta_wtd = (2 * np.pi / n_theta) * ang_theta - theta_grid, phi_grid, r_grid = np.meshgrid( - theta, phi, r, sparse=False, indexing="ij" + theta_grid, phi_grid, r_grid = xp.meshgrid( + theta.flatten(), phi.flatten(), r.flatten(), sparse=False, indexing="ij" ) - fourier_x = m_flatten(r_grid * np.cos(theta_grid) * np.sin(phi_grid)) - fourier_y = m_flatten(r_grid * np.sin(theta_grid) * np.sin(phi_grid)) - fourier_z = m_flatten(r_grid * np.cos(phi_grid)) + fourier_x = m_flatten(r_grid * xp.cos(theta_grid) * xp.sin(phi_grid)) + fourier_y = m_flatten(r_grid * xp.sin(theta_grid) * xp.sin(phi_grid)) + fourier_z = m_flatten(r_grid * xp.cos(phi_grid)) fourier_pts = ( 2 - * pi - * np.vstack( + * xp.pi + * xp.vstack( ( - fourier_z[np.newaxis, ...], - fourier_y[np.newaxis, ...], - fourier_x[np.newaxis, ...], + fourier_z[None, ...], + fourier_y[None, ...], + fourier_x[None, ...], ) ) ) @@ -163,6 +166,7 @@ def _evaluate(self, v): coordinate basis. This is an array whose last three dimensions equal `self.sz` and the remaining dimensions correspond to `v`. 
""" + v = xp.asarray(v) # roll dimensions of v sz_roll = v.shape[:-1] v = v.reshape((-1, self.count)) @@ -175,7 +179,7 @@ def _evaluate(self, v): # number of 3D image samples n_data = v.shape[0] - u_even = np.zeros( + u_even = xp.zeros( ( n_r, int(2 * self.ell_max + 1), @@ -184,7 +188,7 @@ def _evaluate(self, v): ), dtype=v.dtype, ) - u_odd = np.zeros( + u_odd = xp.zeros( (n_r, int(2 * self.ell_max + 1), n_data, int(np.ceil(self.ell_max / 2))), dtype=v.dtype, ) @@ -216,10 +220,10 @@ def _evaluate(self, v): int((ell - 1) / 2), ] = v_ell - u_even = np.transpose(u_even, (3, 0, 1, 2)) - u_odd = np.transpose(u_odd, (3, 0, 1, 2)) - w_even = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) - w_odd = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) + u_even = u_even.transpose((3, 0, 1, 2)) + u_odd = u_odd.transpose((3, 0, 1, 2)) + w_even = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) + w_odd = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) # evaluate the phi parts for m in range(0, self.ell_max + 1): @@ -252,8 +256,8 @@ def _evaluate(self, v): w_even[:, :, :, self.ell_max + sgn * m] = w_m_even w_odd[:, :, :, self.ell_max + sgn * m] = w_m_odd - w_even = np.transpose(w_even, (3, 0, 1, 2)) - w_odd = np.transpose(w_odd, (3, 0, 1, 2)) + w_even = w_even.transpose((3, 0, 1, 2)) + w_odd = w_odd.transpose((3, 0, 1, 2)) u_even = w_even u_odd = w_odd @@ -266,7 +270,7 @@ def _evaluate(self, v): pf = w_even + 1j * w_odd pf = m_reshape(pf, (n_theta * n_phi * n_r, n_data)) - pf = np.moveaxis(pf, 0, -1) + pf = xp.moveaxis(pf, 0, -1) # perform inverse non-uniformly FFT transformation back to 3D rectangular coordinates freqs = m_reshape(self._precomp["fourier_pts"], (3, n_r * n_theta * n_phi)) @@ -275,7 +279,7 @@ def _evaluate(self, v): # Roll, return the x with the last three dimensions as self.sz # Higher dimensions should be like v. x = x.reshape((*sz_roll, *self.sz)) - return x + return xp.asnumpy(x) def _evaluate_t(self, x): """ @@ -288,6 +292,7 @@ def _evaluate_t(self, x): `self.count` and whose remaining dimensions correspond to higher dimensions of `x`. 
""" + x = xp.asarray(x) # roll dimensions sz_roll = x.shape[:-3] x = x.reshape((-1, *self.sz)) @@ -303,20 +308,21 @@ def _evaluate_t(self, x): pf = m_reshape(pf.T, (n_theta, n_phi * n_r * n_data)) # evaluate the theta parts - u_even = self._precomp["ang_theta_wtd"].T @ np.real(pf) - u_odd = self._precomp["ang_theta_wtd"].T @ np.imag(pf) + ang_theta_wtd_trans = self._precomp["ang_theta_wtd"].T + u_even = ang_theta_wtd_trans @ pf.real + u_odd = ang_theta_wtd_trans @ pf.imag u_even = m_reshape(u_even, (2 * self.ell_max + 1, n_phi, n_r, n_data)) u_odd = m_reshape(u_odd, (2 * self.ell_max + 1, n_phi, n_r, n_data)) - u_even = np.transpose(u_even, (1, 2, 3, 0)) - u_odd = np.transpose(u_odd, (1, 2, 3, 0)) + u_even = u_even.transpose((1, 2, 3, 0)) + u_odd = u_odd.transpose((1, 2, 3, 0)) - w_even = np.zeros( + w_even = xp.zeros( (int(np.floor(self.ell_max / 2) + 1), n_r, 2 * self.ell_max + 1, n_data), dtype=x.dtype, ) - w_odd = np.zeros( + w_odd = xp.zeros( (int(np.ceil(self.ell_max / 2)), n_r, 2 * self.ell_max + 1, n_data), dtype=x.dtype, ) @@ -351,11 +357,11 @@ def _evaluate_t(self, x): end = np.size(w_odd, 0) w_odd[end - n_odd_ell : end, :, self.ell_max + sgn * m, :] = w_m_odd - w_even = np.transpose(w_even, (1, 2, 3, 0)) - w_odd = np.transpose(w_odd, (1, 2, 3, 0)) + w_even = w_even.transpose((1, 2, 3, 0)) + w_odd = w_odd.transpose((1, 2, 3, 0)) # evaluate the radial parts - v = np.zeros((n_data, self.count), dtype=x.dtype) + v = xp.zeros((n_data, self.count), dtype=x.dtype) for ell in range(0, self.ell_max + 1): k_max_ell = self.k_max[ell] radial_wtd = self._precomp["radial_wtd"][:, 0:k_max_ell, ell] @@ -388,4 +394,4 @@ def _evaluate_t(self, x): # Roll dimensions, last dimension should be self.count, # Higher dimensions like x. v = v.reshape((*sz_roll, self.count)) - return v + return xp.asnumpy(v) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 423d37c093..76330e6fba 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -1,8 +1,6 @@ import logging import numpy as np -import scipy.sparse as sparse -from scipy.fft import dct, idct from scipy.special import jv from aspire.basis import Coef, FBBasisMixin, SteerableBasis2D @@ -13,13 +11,28 @@ transform_complex_to_real, ) from aspire.nufft import anufft, nufft -from aspire.numeric import fft +from aspire.numeric import fft, sparse, xp from aspire.operators import DiagMatrix from aspire.utils import complex_type, grid_2d logger = logging.getLogger(__name__) +def _cleanup(): + """ + Utility for informing cupy to cleanup memory held by old vars. + + This method is designed to be safely called even when `CuPy` is + not installed, in which case it is a no-op. 
+ """ + try: + import cupy + + cupy.get_default_memory_pool().free_all_blocks() + except ModuleNotFoundError: + pass + + class FLEBasis2D(SteerableBasis2D, FBBasisMixin): """ Define a derived class for Fast Fourier Bessel 2D expansion using interpolation @@ -278,10 +291,10 @@ def _compute_nufft_points(self): self.num_angular_nodes = num_angular_nodes # create gridpoints - nodes = 1 - (2 * np.arange(self.num_radial_nodes, dtype=self.dtype) + 1) / ( + nodes = 1 - (2 * xp.arange(self.num_radial_nodes, dtype=self.dtype) + 1) / ( 2 * self.num_radial_nodes ) - nodes = (np.cos(np.pi * nodes) + 1) / 2 + nodes = (xp.cos(np.pi * nodes) + 1) / 2 nodes = ( self.greatest_lambda - self.smallest_lambda ) * nodes + self.smallest_lambda @@ -292,16 +305,17 @@ def _compute_nufft_points(self): phi = ( 2 - * np.pi - * np.arange(self.num_angular_nodes // 2, dtype=self.dtype) + * xp.pi + * xp.arange(self.num_angular_nodes // 2, dtype=self.dtype) / self.num_angular_nodes ) - x = np.cos(phi).reshape(1, self.num_angular_nodes // 2) - y = np.sin(phi).reshape(1, self.num_angular_nodes // 2) - x = x * nodes * h - y = y * nodes * h - self.grid_x = x.flatten() - self.grid_y = y.flatten() + grid_xy = xp.empty( + (2, self.num_radial_nodes, self.num_angular_nodes // 2), dtype=self.dtype + ) + grid_xy[0] = xp.cos(phi) # x + grid_xy[1] = xp.sin(phi) # y + grid_xy = grid_xy * nodes * h + self.grid_xy = grid_xy.reshape(2, -1) def _build_interpolation_matrix(self): """ @@ -469,7 +483,7 @@ def _create_basis_functions(self): norm_constants[i] = c - self.norm_constants = norm_constants + self.norm_constants = xp.asarray(norm_constants) self.basis_functions = basis_functions def _evaluate(self, coefs): @@ -498,34 +512,39 @@ def _evaluate_t(self, imgs): coefficients. """ # See Section 3.5 - imgs = imgs.copy() + imgs = xp.array(imgs) # Intentionally copying here, mutating. imgs[:, self.radial_mask] = 0 z = self._step1_t(imgs) + del imgs # inform python we're done with imgs + _cleanup() + b = self._step2_t(z) + del z # inform python we're done with z + _cleanup() + coefs = self._step3_t(b) + del b # inform python we're done with b + _cleanup() # return in FB order coefs = coefs[..., self._fle_to_fb_indices] - return coefs.astype(self.coefficient_dtype, copy=False) + return xp.asnumpy(coefs.astype(self.coefficient_dtype)) def _step1_t(self, im): """ Step 1 of the adjoint transformation (images to coefficients). - Calculates the NUFFT of the image on gridpoints `self.grid_x` and `self.grid_y`. + Calculates the NUFFT of the image on gridpoints `grid_xy`. 
""" im = im.reshape(-1, self.nres, self.nres).astype(complex_type(self.dtype)) num_img = im.shape[0] - z = np.zeros( + z = xp.zeros( (num_img, self.num_radial_nodes, self.num_angular_nodes), dtype=complex_type(self.dtype), ) - _z = ( - nufft(im, np.stack((self.grid_x, self.grid_y)), epsilon=self.epsilon) - * self.h**2 - ) + _z = nufft(im, self.grid_xy, epsilon=self.epsilon) * self.h**2 _z = _z.reshape(num_img, self.num_radial_nodes, self.num_angular_nodes // 2) z[:, :, : self.num_angular_nodes // 2] = _z - z[:, :, self.num_angular_nodes // 2 :] = np.conj(_z) + z[:, :, self.num_angular_nodes // 2 :] = _z.conj() return z def _step2_t(self, z): @@ -538,12 +557,12 @@ def _step2_t(self, z): # Compute FFT along angular nodes betas = fft.fft(z, axis=2) / self.num_angular_nodes betas = betas[:, :, self.nus] - betas = np.conj(betas) - betas = np.swapaxes(betas, 0, 2) + betas = betas.conj() + betas = betas.swapaxes(0, 2) betas = betas.reshape(-1, self.num_radial_nodes * num_img) betas = self.c2r_nus @ betas betas = betas.reshape(-1, self.num_radial_nodes, num_img) - betas = np.real(np.swapaxes(betas, 0, 2)) + betas = betas.swapaxes(0, 2).real return betas def _step3_t(self, betas): @@ -554,13 +573,12 @@ def _step3_t(self, betas): """ num_img = betas.shape[0] if self.num_interp > self.num_radial_nodes: - betas = dct(betas, axis=1, type=2) / (2 * self.num_radial_nodes) - zeros = np.zeros(betas.shape) - betas = np.concatenate((betas, zeros), axis=1) - betas = idct(betas, axis=1, type=2) * 2 * betas.shape[1] - betas = np.moveaxis(betas, 0, -1) + betas = fft.dct(betas, axis=1, type=2) / (2 * self.num_radial_nodes) + betas = xp.concatenate((betas, xp.zeros(betas.shape)), axis=1) + betas = fft.idct(betas, axis=1, type=2) * 2 * betas.shape[1] + betas = xp.moveaxis(betas, 0, -1) - coefs = np.zeros((self.count, num_img), dtype=np.float64) + coefs = xp.zeros((self.count, num_img), dtype=np.float64) for i in range(self.ell_p_max + 1): coefs[self.idx_list[i]] = self.A3[i] @ betas[:, i, :] coefs = coefs.T @@ -574,22 +592,22 @@ def _step3(self, coefs): Uses barycenteric interpolation in reverse to compute values of Betas at Chebyshev nodes, given an array of FLE coefficients. """ - coefs = coefs.copy().reshape(-1, self.count) + coefs = xp.asarray(coefs.reshape(-1, self.count)) num_img = coefs.shape[0] coefs *= self.h * self.norm_constants coefs = coefs.T - out = np.zeros( + out = xp.zeros( (self.num_interp, 2 * self.max_ell + 1, num_img), dtype=np.float64, ) for i in range(self.ell_p_max + 1): out[:, i, :] = self.A3_T[i] @ coefs[self.idx_list[i]] - out = np.moveaxis(out, -1, 0) + out = xp.moveaxis(out, -1, 0) if self.num_interp > self.num_radial_nodes: - out = dct(out, axis=1, type=2) + out = fft.dct(out, axis=1, type=2) out = out[:, : self.num_radial_nodes, :] - out = idct(out, axis=1, type=2) + out = fft.idct(out, axis=1, type=2) return out @@ -600,18 +618,18 @@ def _step2(self, betas): Uses the IFFT to convert Beta values into Fourier-space images. 
""" num_img = betas.shape[0] - tmp = np.zeros( + tmp = xp.zeros( (num_img, self.num_radial_nodes, self.num_angular_nodes), dtype=np.complex128, ) - betas = np.swapaxes(betas, 0, 2) + betas = betas.swapaxes(0, 2) betas = betas.reshape(-1, self.num_radial_nodes * num_img) betas = self.r2c_nus @ betas betas = betas.reshape(-1, self.num_radial_nodes, num_img) - betas = np.swapaxes(betas, 0, 2) + betas = betas.swapaxes(0, 2) - tmp[:, :, self.nus] = np.conj(betas) + tmp[:, :, self.nus] = betas.conj() z = fft.ifft(tmp, axis=2) return z @@ -625,17 +643,17 @@ def _step1(self, z): num_img = z.shape[0] z = z[:, :, : self.num_angular_nodes // 2].reshape(num_img, -1) im = anufft( - z.astype(complex_type(self.dtype)), - np.stack((self.grid_x, self.grid_y)), + z.astype(complex_type(self.dtype), copy=False), + self.grid_xy, (self.nres, self.nres), epsilon=self.epsilon, ) - im = im + np.conj(im) - im = np.real(im) + im = im + im.conj() + im = im.real im = im.reshape(num_img, self.nres, self.nres) im[:, self.radial_mask] = 0 - return im + return xp.asnumpy(im) def _create_dense_matrix(self): """ @@ -702,10 +720,12 @@ def radial_convolve(self, coefs, radial_img): "`radial_convolve` currently only implemented for 1D stacks." ) - coefs = coefs.asnumpy() + # Potentially migrate to GPU + coefs = xp.asarray(coefs.asnumpy()) + radial_img = xp.asarray(radial_img) num_img = coefs.shape[0] - coefs_conv = np.zeros(coefs.shape) + coefs_conv = xp.zeros(coefs.shape) # Convert to internal FLE indices ordering coefs = coefs[..., self._fb_to_fle_indices] @@ -717,25 +737,26 @@ def radial_convolve(self, coefs, radial_img): weights = self._radial_convolve_weights(b) b = weights / (self.h**2) b = b.reshape(self.count) - coefs_conv[k, :] = np.real(self.c2r @ (b * (self.r2c @ _coefs).flatten())) + coefs_conv[k, :] = (self.c2r @ (b * (self.r2c @ _coefs).flatten())).real # Convert from internal FLE ordering to FB convention coefs_conv = coefs_conv[..., self._fle_to_fb_indices] - return Coef(self, coefs_conv) + # Return as Coef on host + return Coef(self, xp.asnumpy(coefs_conv)) def _radial_convolve_weights(self, b): """ Helper function for step 3 of convolving with a radial function. """ - b = np.squeeze(b) - b = np.array(b) + b = xp.squeeze(b) + b = xp.array(b) # implies copy if self.num_interp > self.num_radial_nodes: - b = dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) - bz = np.zeros(b.shape) - b = np.concatenate((b, bz), axis=0) - b = idct(b, axis=0, type=2) * 2 * b.shape[0] - a = np.zeros(self.count, dtype=np.float64) + b = fft.dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) + bz = xp.zeros(b.shape) + b = xp.concatenate((b, bz), axis=0) + b = fft.idct(b, axis=0, type=2) * 2 * b.shape[0] + a = xp.zeros(self.count, dtype=np.float64) y = [None] * (self.ell_p_max + 1) for i in range(self.ell_p_max + 1): y[i] = (self.A3[i] @ b[:, 0]).flatten() @@ -764,20 +785,26 @@ def filter_to_basis_mat(self, f, **kwargs): # get 2D grid in polar coordinate k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) - k, theta = np.meshgrid( - k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij" + k, theta = xp.meshgrid( + xp.asarray(k_vals), + xp.arange(n_theta) * 2 * np.pi / (2 * n_theta), + indexing="ij", ) # Get function values in polar 2D grid and average out angle contribution # NOTE: should probably just let the ctf objects handle this... 
- omegax = k * np.cos(theta) - omegay = k * np.sin(theta) - omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) - - h_vals2d = h_fun(omega).reshape(n_k, n_theta).astype(self.dtype) - h_vals = np.sum(h_vals2d, axis=1) / n_theta + omegax = k * xp.cos(theta) + omegay = k * xp.sin(theta) + omega = 2 * xp.pi * xp.vstack((omegax.flatten("C"), omegay.flatten("C"))) + + h_vals2d = ( + xp.asarray(h_fun(omega)) + .reshape(n_k, n_theta) + .astype(self.dtype, copy=False) + ) + h_vals = xp.sum(h_vals2d, axis=1) / n_theta - h_basis = np.zeros(self.count, dtype=self.dtype) + h_basis = xp.zeros(self.count, dtype=self.dtype) # For now we just need to handle 1D (stack of one ctf) for j in range(self.ell_p_max + 1): h_basis[self.idx_list[j]] = self.A3[j] @ h_vals @@ -785,4 +812,4 @@ def filter_to_basis_mat(self, f, **kwargs): # Convert from internal FLE ordering to FB convention h_basis = h_basis[self._fle_to_fb_indices] - return DiagMatrix(h_basis) + return DiagMatrix(xp.asnumpy(h_basis)) diff --git a/src/aspire/basis/fle_2d_utils.py b/src/aspire/basis/fle_2d_utils.py index cde0cd11bf..ea459988b0 100644 --- a/src/aspire/basis/fle_2d_utils.py +++ b/src/aspire/basis/fle_2d_utils.py @@ -1,5 +1,6 @@ import numpy as np -import scipy.sparse as sparse + +from aspire.numeric import sparse, xp def transform_complex_to_real(B, ells): @@ -85,7 +86,11 @@ def precomp_transform_complex_to_real(ells): jdx[k] = i + 1 k = k + 1 - A = sparse.csr_matrix((vals, (idx, jdx)), shape=(count, count), dtype=np.complex128) + A = sparse.csr_matrix( + (xp.asarray(vals), (xp.asarray(idx), xp.asarray(jdx))), + shape=(count, count), + dtype=np.complex128, + ) return A.conjugate() @@ -190,9 +195,9 @@ def barycentric_interp_sparse(target_points, known_points, numsparse): # note that const cancels in numerator and denominator vals = vals / denom.reshape(-1, 1) - vals = vals.flatten() - idx = idx.flatten() - jdx = jdx.flatten() + vals = xp.array(vals.flatten()) + idx = xp.array(idx.flatten()) + jdx = xp.array(jdx.flatten()) # A is the linear operator mapping the function values from the fixed source # points to the fixed target points. # A(i,j) = \ell(x[i] ) w_j/(x[i] - xs[j]), with the notation in Eq. 3.3 diff --git a/src/aspire/config_default.yaml b/src/aspire/config_default.yaml index def78983c0..9af8732dd1 100644 --- a/src/aspire/config_default.yaml +++ b/src/aspire/config_default.yaml @@ -1,4 +1,4 @@ -version: 0.12.3 +version: 0.13.0 common: # numeric module to use - one of numpy/cupy numeric: numpy diff --git a/src/aspire/denoising/class_avg.py b/src/aspire/denoising/class_avg.py index 586e0a08a9..20c3694dbf 100644 --- a/src/aspire/denoising/class_avg.py +++ b/src/aspire/denoising/class_avg.py @@ -76,6 +76,7 @@ def __init__( L=self.averager.src.L, n=self.averager.src.n, dtype=self.averager.src.dtype, + symmetry_group=self.src.symmetry_group, ) # Any further operations should not mutate this instance. diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index ff2353d333..57b32041cd 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -8,8 +8,9 @@ from PIL import Image as PILImage from scipy.linalg import lstsq +import aspire.sinogram import aspire.volume -from aspire.nufft import anufft +from aspire.nufft import anufft, nufft from aspire.numeric import fft, xp from aspire.utils import FourierRingCorrelation, anorm, crop_pad_2d, grid_2d from aspire.volume import SymmetryGroup @@ -80,7 +81,7 @@ def load_mrc(filepath): Load raw data from `.mrc` into an array. 
:param filepath: File path (string). - :return: numpy array of image data. + :return: (numpy array of image data, pixel_size) """ # mrcfile tends to yield many warnings about EMPIAR datasets being corrupt @@ -92,6 +93,7 @@ def load_mrc(filepath): with mrcfile.open(filepath, mode="r", permissive=True) as mrc: im = mrc.data + pixel_size = Image._vx_array_to_size(mrc.voxel_size) # Log each mrcfile warning to debug log, noting the associated file for w in ws: @@ -110,19 +112,29 @@ def load_mrc(filepath): f" Will attempt to continue processing {filepath}" ) - return im + return im, pixel_size def load_tiff(filepath): """ Load raw data from `.tiff` into an array. + Note, TIFF does not natively provide equivalent to pixel/voxel_size, + so users of TIFF files may need to manually assign `pixel_size` to + `Image` instances when required. Defaults to `pixel_size=None`. + :param filepath: File path (string). - :return: numpy array of image data. + :return: (numpy array of image data, pixel_size=None) """ + # Use PIL to open `filepath` + img = PILImage.open(filepath) + + # Future todo, extract `voxel_size` if available in TIFF tags (custom tag?) + # For now, default to `None`. + pixel_size = None - # Use PIL to open `filepath` and cast to numpy array. - return np.array(PILImage.open(filepath)) + # Cast image data as numpy array + return np.array(img), pixel_size class Image: @@ -133,7 +145,7 @@ class Image: ".tiff": load_tiff, } - def __init__(self, data, dtype=None): + def __init__(self, data, pixel_size=None, dtype=None): """ A stack of one or more images. @@ -149,6 +161,10 @@ def __init__(self, data, dtype=None): :param data: Numpy array containing image data with shape `(..., resolution, resolution)`. + :param pixel_size: Optional pixel size in angstroms. + When provided will be saved with `mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param dtype: Optionally cast `data` to this dtype. Defaults to `data.dtype`. @@ -180,12 +196,52 @@ def __init__(self, data, dtype=None): self.stack_shape = self._data.shape[:-2] self.n_images = np.prod(self.stack_shape) self.resolution = self._data.shape[-1] + self.pixel_size = None + if pixel_size is not None: + self.pixel_size = float(pixel_size) # Numpy interop # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol self.__array_interface__ = self._data.__array_interface__ self.__array__ = self._data + def project(self, angles): + """ + Computes the Radon Transform on an Image Stack using + Non-Uniform Fast Fourier Transforms. This method projects the + Image stack along different angles and returns the Radon + Transform. + + :param angles: A 1-D Numpy Array of angles in Radians. + This is used to compute the Radon Transform at different angles. + :return: Radon transform of the Image Stack. 
+        :rtype: Sinogram (stack size, number of angles, image resolution)
+        """
+        # number of points to sample on radial line in polar grid
+        n_points = self.resolution
+        original_stack = self.stack_shape
+
+        # 2-D grid
+        radial_idx = fft.rfftfreq(n_points) * xp.pi * 2
+        n_real_points = len(radial_idx)
+        n_angles = len(angles)
+        angles = xp.asarray(angles)
+
+        pts = xp.empty((2, n_angles, n_real_points), dtype=self.dtype)
+        pts[0] = radial_idx[xp.newaxis, :] * xp.sin(angles)[:, xp.newaxis]
+        pts[1] = radial_idx[xp.newaxis, :] * xp.cos(angles)[:, xp.newaxis]
+        pts = pts.reshape(2, n_real_points * n_angles)
+
+        # compute the polar nufft (NUFFT)
+        image_ft = nufft(xp.asarray(self.stack_reshape(-1)._data), pts).reshape(
+            self.n_images, n_angles, n_real_points
+        )
+
+        # Radon transform, output: (stack size, angles, points)
+        image_rt = fft.fftshift(fft.irfft(image_ft, n=n_points, axis=-1), axes=-1)
+        image_rt = image_rt.reshape(*original_stack, n_angles, n_points)
+        return aspire.sinogram.Sinogram(xp.asnumpy(image_rt))
+
     @property
     def res(self):
         warn(
@@ -202,7 +258,7 @@ def _check_key_dims(self, key):

     def __getitem__(self, key):
         self._check_key_dims(key)
-        return self.__class__(self._data[key])
+        return self.__class__(self._data[key], pixel_size=self.pixel_size)

     def __setitem__(self, key, value):
         self._check_key_dims(key)
@@ -230,31 +286,34 @@ def stack_reshape(self, *args):
                 f"Number of images {self.n_images} cannot be reshaped to {shape}."
             )

-        return self.__class__(self._data.reshape(*shape, *self._data.shape[-2:]))
+        return self.__class__(
+            self._data.reshape(*shape, *self._data.shape[-2:]),
+            pixel_size=self.pixel_size,
+        )

     def __add__(self, other):
         if isinstance(other, Image):
             other = other._data

-        return self.__class__(self._data + other)
+        return self.__class__(self._data + other, pixel_size=self.pixel_size)

     def __sub__(self, other):
         if isinstance(other, Image):
             other = other._data

-        return self.__class__(self._data - other)
+        return self.__class__(self._data - other, pixel_size=self.pixel_size)

     def __mul__(self, other):
         if isinstance(other, Image):
             other = other._data

-        return self.__class__(self._data * other)
+        return self.__class__(self._data * other, pixel_size=self.pixel_size)

     def __neg__(self):
-        return self.__class__(-self._data)
+        return self.__class__(-self._data, pixel_size=self.pixel_size)

     def sqrt(self):
-        return self.__class__(np.sqrt(self._data))
+        return self.__class__(np.sqrt(self._data), pixel_size=self.pixel_size)

     @property
     def T(self):
@@ -276,7 +335,9 @@ def transpose(self):
         im = self.stack_reshape(-1)

         imt = np.transpose(im._data, (0, -1, -2))
-        return self.__class__(imt).stack_reshape(original_stack_shape)
+        return self.__class__(imt, pixel_size=self.pixel_size).stack_reshape(
+            original_stack_shape
+        )

     def flip(self, axis=-2):
         """
@@ -299,11 +360,15 @@ def flip(self, axis=-2):
                     f"Cannot flip axis {ax}: stack axis. Did you mean {ax-3}?"
                 )

-        return self.__class__(np.flip(self._data, axis))
+        return self.__class__(np.flip(self._data, axis), pixel_size=self.pixel_size)

     def __repr__(self):
+        px_msg = "."
+        if self.pixel_size is not None:
+            px_msg = f" with pixel_size={self.pixel_size} angstroms."
+
         msg = f"{self.n_images} {self.dtype} images arranged as a {self.stack_shape} stack"
-        msg += f" each of size {self.resolution}x{self.resolution}."
+        msg += f" each of size {self.resolution}x{self.resolution}{px_msg}"
         return msg

     def asnumpy(self):
@@ -319,7 +384,7 @@ def asnumpy(self):
         return view

     def copy(self):
-        return self.__class__(self._data.copy())
+        return self.__class__(self._data.copy(), pixel_size=self.pixel_size)

     def shift(self, shifts):
         """
@@ -344,26 +409,46 @@ def shift(self, shifts):

         return self._im_translate(shifts)

-    def downsample(self, ds_res):
+    def downsample(self, ds_res, zero_nyquist=True):
         """
         Downsample Image to a specific resolution. This method returns a new Image.

         :param ds_res: int - new resolution, should be <= the current resolution
             of this Image
+        :param zero_nyquist: Option to keep or remove Nyquist frequency for even resolution.
+            Defaults to zero_nyquist=True, removing the Nyquist frequency.
         :return: The downsampled Image object.
         """
         original_stack_shape = self.stack_shape
         im = self.stack_reshape(-1)

+        # Note image data is intentionally migrated via `xp.asarray`
+        # because all of the subsequent calls until `asnumpy` run on
+        # the GPU when `xp` and `fft` are in `cupy` mode.
+
         # compute FT with centered 0-frequency
-        fx = fft.centered_fft2(im._data)
+        fx = fft.centered_fft2(xp.asarray(im._data))

         # crop 2D Fourier transform for each image
-        crop_fx = np.array([crop_pad_2d(fx[i], ds_res) for i in range(self.n_images)])
+        crop_fx = crop_pad_2d(fx, ds_res)
+
+        # If downsampled resolution is even, optionally zero out the Nyquist frequency.
+        if ds_res % 2 == 0 and zero_nyquist is True:
+            crop_fx[:, 0, :] = 0
+            crop_fx[:, :, 0] = 0

         # take back to real space, discard complex part, and scale
-        out = np.real(fft.centered_ifft2(crop_fx)) * (ds_res**2 / self.resolution**2)
+        out = fft.centered_ifft2(crop_fx).real * (ds_res**2 / self.resolution**2)
+        out = xp.asnumpy(out)
+
+        # Optionally scale pixel size
+        ds_pixel_size = self.pixel_size
+        if ds_pixel_size is not None:
+            ds_pixel_size *= self.resolution / ds_res

-        return self.__class__(out).stack_reshape(original_stack_shape)
+        return self.__class__(out, pixel_size=ds_pixel_size).stack_reshape(
+            original_stack_shape
+        )

     def filter(self, filter):
         """
@@ -376,19 +461,25 @@ def filter(self, filter):

         im = self.stack_reshape(-1)

-        filter_values = filter.evaluate_grid(self.resolution)
+        # Note image and filter data are intentionally migrated via
+        # `xp.asarray` because all of the subsequent calls until
+        # `asnumpy` run on the GPU when `xp` and `fft` are in `cupy` mode.
+        #
+        # Second note, filter dtype may not match image dtype.
+        filter_values = xp.asarray(
+            filter.evaluate_grid(self.resolution), dtype=self.dtype
+        )

-        im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(im._data)))
+        # Convolve
+        im_f = fft.centered_fft2(xp.asarray(im._data))
+        im_f = filter_values * im_f
+        im = fft.centered_ifft2(im_f)

-        # TODO: why are these different? Doesn't the broadcast work?
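Fourier cropping as used by `downsample` above, reduced to plain numpy for a single square image (a sketch that ignores the stack axes and the xp/GPU migration):

```python
import numpy as np

def fourier_downsample(img, ds_res, zero_nyquist=True):
    """Downsample one square image by cropping its centered 2D spectrum."""
    L = img.shape[-1]
    fx = np.fft.fftshift(np.fft.fft2(img))  # centered spectrum
    start = (L - ds_res) // 2
    fx = fx[start : start + ds_res, start : start + ds_res]  # central crop
    if ds_res % 2 == 0 and zero_nyquist:
        fx[0, :] = 0  # unmatched Nyquist row
        fx[:, 0] = 0  # unmatched Nyquist column
    out = np.fft.ifft2(np.fft.ifftshift(fx)).real
    return out * (ds_res**2 / L**2)  # rescale to preserve intensity
```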
- if im_f.ndim > filter_values.ndim: - im_f *= filter_values - else: - im_f = filter_values * im_f - im = xp.asnumpy(fft.centered_ifft2(xp.asarray(im_f))) - im = np.real(im) + im = xp.asnumpy(im.real) - return self.__class__(im).stack_reshape(original_stack_shape) + return self.__class__(im, pixel_size=self.pixel_size).stack_reshape( + original_stack_shape + ) def rotate(self): raise NotImplementedError @@ -400,6 +491,9 @@ def save(self, mrcs_filepath, overwrite=False): with mrcfile.new(mrcs_filepath, overwrite=overwrite) as mrc: # original input format (the image index first) mrc.set_data(self._data.astype(np.float32)) + # Note assigning voxel_size must come after `set_data` + if self.pixel_size is not None: + mrc.voxel_size = self.pixel_size @staticmethod def load(filepath, dtype=None): @@ -424,14 +518,14 @@ def load(filepath, dtype=None): ) # Call the appropriate file reader - im = Image.extensions[ext](filepath) + im, pixel_size = Image.extensions[ext](filepath) # Attempt casting when user provides dtype if dtype is not None: im = im.astype(dtype, copy=False) # Return as Image instance - return Image(im) + return Image(im, pixel_size=pixel_size) def _im_translate(self, shifts): """ @@ -457,15 +551,17 @@ def _im_translate(self, shifts): n_shifts == 1 or n_shifts == self.n_images ), "number of shifts must be 1 or match the number of images" # Cast shifts to this instance's internal dtype - shifts = shifts.astype(self.dtype) + shifts = xp.asarray(shifts, dtype=self.dtype) L = self.resolution - im_f = xp.asnumpy(fft.fft2(xp.asarray(im))) + im_f = fft.fft2(xp.asarray(im)) grid_shifted = fft.ifftshift( - xp.asarray(np.ceil(np.arange(-L / 2, L / 2, dtype=self.dtype))) + xp.ceil(xp.arange(-L / 2, L / 2, dtype=self.dtype)) ) - grid_1d = xp.asnumpy(grid_shifted) * 2 * np.pi / L - om_x, om_y = np.meshgrid(grid_1d, grid_1d, indexing="ij") + grid_1d = grid_shifted * 2 * xp.pi / L + + # Grid indexing changed to "xy" to match Relion shift conventions. + om_x, om_y = xp.meshgrid(grid_1d, grid_1d, indexing="xy") phase_shifts_x = -shifts[:, 0].reshape((n_shifts, 1, 1)) phase_shifts_y = -shifts[:, 1].reshape((n_shifts, 1, 1)) @@ -474,13 +570,15 @@ def _im_translate(self, shifts): om_x[np.newaxis, :, :] * phase_shifts_x + om_y[np.newaxis, :, :] * phase_shifts_y ) - mult_f = np.exp(-1j * phase_shifts) + mult_f = xp.exp(-1j * phase_shifts) im_translated_f = im_f * mult_f - im_translated = xp.asnumpy(fft.ifft2(xp.asarray(im_translated_f))) - im_translated = np.real(im_translated) + im_translated = fft.ifft2(im_translated_f) + im_translated = xp.asnumpy(im_translated.real) # Reshape to stack shape - return self.__class__(im_translated).stack_reshape(stack_shape) + return self.__class__(im_translated, pixel_size=self.pixel_size).stack_reshape( + stack_shape + ) def norm(self): return anorm(self._data) @@ -490,7 +588,7 @@ def size(self): # probably not needed, transition return np.size(self._data) - def backproject(self, rot_matrices, symmetry_group=None): + def backproject(self, rot_matrices, symmetry_group=None, zero_nyquist=True): """ Backproject images along rotations. If a symmetry group is provided, images used in back-projection are duplicated (boosted) for symmetric viewing directions. @@ -500,6 +598,8 @@ def backproject(self, rot_matrices, symmetry_group=None): corresponding to viewing directions. :param symmetry_group: A SymmetryGroup instance or string indicating symmetry, ie. "C3". If supplied, uses symmetry to increase number of images used in back-projection. 
+        :param zero_nyquist: Option to keep or remove Nyquist frequency for even resolution.
+            Defaults to zero_nyquist=True, removing the Nyquist frequency.
         :return: Volume instance corresponding to the backprojected images.
         """

@@ -517,10 +617,14 @@ def backproject(self, rot_matrices, symmetry_group=None, zero_nyquist=True):

         # Get symmetry rotations from SymmetryGroup.
         symmetry_rots = SymmetryGroup.parse(symmetry_group, dtype=self.dtype).matrices
+        if len(symmetry_rots) > 1:
+            logger.info(f"Boosting with {len(symmetry_rots)} rotational symmetries.")

         # Compute Fourier transform of images.
         im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(self._data))) / (L**2)
-        if L % 2 == 0:
+
+        # If resolution is even, optionally zero out the Nyquist frequency.
+        if L % 2 == 0 and zero_nyquist is True:
             im_f[:, 0, :] = 0
             im_f[:, :, 0] = 0

@@ -541,7 +645,9 @@ def backproject(self, rot_matrices, symmetry_group=None, zero_nyquist=True):

         vol /= L

-        return aspire.volume.Volume(vol, symmetry_group=symmetry_group)
+        return aspire.volume.Volume(
+            vol, pixel_size=self.pixel_size, symmetry_group=symmetry_group
+        )

     def show(self, columns=5, figsize=(20, 10), colorbar=True):
         """
@@ -584,7 +690,7 @@ def show(self, columns=5, figsize=(20, 10), colorbar=True):

         plt.show()

-    def frc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False):
+    def frc(self, other, cutoff=None, method="fft", plot=False):
         r"""
         Compute the Fourier ring correlation between two images.

@@ -602,8 +708,6 @@ def frc(self, other, cutoff=None, method="fft", plot=False):
             Default `None` implies `cutoff=1` and excludes
             plotting cutoff line.

-        :param pixel_size: Pixel size in angstrom. Default `None`
-            implies unit in pixels, equivalent to pixel_size=1.
         :param method: Selects either 'fft' (on cartesian grid),
             or 'nufft' (on polar grid).  Defaults to 'fft'.
         :param plot: Optionally plot to screen or file.
@@ -623,7 +727,7 @@ def frc(self, other, cutoff=None, method="fft", plot=False):
         frc = FourierRingCorrelation(
             a=self.asnumpy(),
             b=other.asnumpy(),
-            pixel_size=pixel_size,
+            pixel_size=self.pixel_size,
             method=method,
         )

@@ -634,6 +738,32 @@ def frc(self, other, cutoff=None, method="fft", plot=False):

         return frc.analyze_correlations(cutoff), frc.correlations

+    @staticmethod
+    def _vx_array_to_size(vx):
+        """
+        Utility to convert from several possible `mrcfile.voxel_size`
+        representations to a single (float) value or None.
+        """
+
+        # Convert from recarray to single values,
+        # checking uniformity.
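With `pixel_size` now carried on the `Image` itself, the FRC no longer takes it as an argument; a hypothetical usage sketch (random data, illustrative values only):

```python
import numpy as np
from aspire.image import Image

a = Image(np.random.rand(1, 64, 64).astype(np.float32), pixel_size=1.34)
b = a.shift(np.array([[1.0, -2.0]]))  # a perturbed copy to compare against

# The resolution estimate now reads `a.pixel_size` (angstroms) internally.
resolution, correlations = a.frc(b, cutoff=0.143)
```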
+ if isinstance(vx, np.recarray): + if vx.x != vx.y: + raise ValueError(f"Voxel sizes are not uniform: {vx}") + vx = vx.x + + # Convert `0` to `None` + if ( + isinstance(vx, int) or isinstance(vx, float) or isinstance(vx, np.ndarray) + ) and vx == 0: + vx = None + + # Consistently return a `float` when not None + if vx is not None: + vx = float(vx) + + return vx + class CartesianImage(Image): def expand(self, basis): diff --git a/src/aspire/nufft/__init__.py b/src/aspire/nufft/__init__.py index aa7c3a4adf..23ebe2c115 100644 --- a/src/aspire/nufft/__init__.py +++ b/src/aspire/nufft/__init__.py @@ -3,8 +3,16 @@ import numpy as np from aspire import config +from aspire.numeric import xp from aspire.utils import LogFilterByCount, complex_type, real_type +cp = None +try: + import cupy as cp +except ModuleNotFoundError: + pass + + logger = logging.getLogger(__name__) # Cached Plan Class objects, indexed by backend string identifier, and ordered by preference (highest first) @@ -152,6 +160,9 @@ def anufft(sig_f, fourier_pts, sz, real=False, epsilon=1e-8): Selects best available package from `nfft` `backends` configuration list. + When sig_f is provided as a CuPy GPU array with a cufinufft + backend, result is maintained on GPU. + :param sig_f: Array representing the signal(s) in Fourier space to be transformed. \ sig_f either matches length of fourier_pts or sig_f.shape is stack of (`ntransforms`, ...). :param fourier_pts: The points in Fourier space where the Fourier transform is to be calculated, @@ -162,6 +173,10 @@ def anufft(sig_f, fourier_pts, sz, real=False, epsilon=1e-8): """ + _keep_on_gpu = False + if cp and isinstance(sig_f, cp.ndarray): + _keep_on_gpu = True + if fourier_pts.dtype != real_type(sig_f.dtype): raise RuntimeError( "anufft passed inconsistent dtypes." @@ -181,7 +196,13 @@ def anufft(sig_f, fourier_pts, sz, real=False, epsilon=1e-8): sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms, epsilon=epsilon ) adjoint = plan.adjoint(sig_f) - return np.real(adjoint) if real else adjoint + + adjoint = adjoint.real if real else adjoint + + if not _keep_on_gpu: + adjoint = xp.asnumpy(adjoint) + + return adjoint def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): @@ -191,6 +212,9 @@ def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): Selects best available package from `nfft` `backends` configuration list. + When sig_f is provided as a CuPy GPU array with a cufinufft + backend, result is maintained on GPU. + :param sig_f: Array representing the signal(s) in real space to be transformed. \ sig_f either matches `sz` or sig_f.shape is stack of (..., `ntransforms`). :param fourier_pts: The points in Fourier space where the Fourier transform is to be calculated, @@ -200,6 +224,10 @@ def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): """ + _keep_on_gpu = False + if cp and isinstance(sig_f, cp.ndarray): + _keep_on_gpu = True + if fourier_pts.dtype != real_type(sig_f.dtype): raise RuntimeError( "nufft passed inconsistent dtypes." 
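The residency rule introduced for `nufft`/`anufft` is: numpy in, numpy out; CuPy in, CuPy out (under a cufinufft backend). A hedged usage sketch, assuming a CPU finufft backend is installed:

```python
import numpy as np
from aspire.nufft import nufft

sig = np.random.rand(8, 8).astype(np.float32)  # dtypes of signal and points must agree
pts = (np.random.rand(2, 100).astype(np.float32) - 0.5) * 2 * np.pi

out = nufft(sig, pts)  # numpy in -> numpy out
# With `cupy` arrays and the cufinufft backend the result stays on the GPU:
#   out_gpu = nufft(cp.asarray(sig), pts)  # cupy in -> cupy out
```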
@@ -229,4 +257,10 @@ def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8):
         sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms, epsilon=epsilon
     )
     transform = plan.transform(sig_f)
-    return np.real(transform) if real else transform
+
+    transform = transform.real if real else transform
+
+    if not _keep_on_gpu:
+        transform = xp.asnumpy(transform)
+
+    return transform
diff --git a/src/aspire/nufft/cufinufft.py b/src/aspire/nufft/cufinufft.py
index 465c0b23f9..fd869aacfd 100644
--- a/src/aspire/nufft/cufinufft.py
+++ b/src/aspire/nufft/cufinufft.py
@@ -1,9 +1,7 @@
 import logging

+import cupy as cp
 import numpy as np
-import pycuda.autoinit  # noqa: F401
-import pycuda.driver as cuda  # noqa: F401
-import pycuda.gpuarray as gpuarray  # noqa: F401
 from cufinufft import Plan as cufPlan

 from aspire.nufft import Plan
@@ -53,11 +51,11 @@ def __init__(self, sz, fourier_pts, epsilon=1e-8, ntransforms=1, **kwargs):
                 "cufinufft has caught a non C_CONTIGUOUS array,"
                 " `fourier_pts` will be copied to C_CONTIGUOUS."
             )
-        self.fourier_pts = np.ascontiguousarray(
-            np.mod(fourier_pts + np.pi, 2 * np.pi) - np.pi, dtype=self.dtype
+        self.fourier_pts = cp.ascontiguousarray(
+            cp.mod(cp.asarray(fourier_pts, dtype=self.dtype) + cp.pi, 2 * cp.pi) - cp.pi
         )

-        self.num_pts = fourier_pts.shape[1]
+        self.num_pts = self.fourier_pts.shape[1]
         self.epsilon = max(epsilon, np.finfo(self.dtype).eps)

         self._transform_plan = cufPlan(
@@ -83,12 +81,8 @@ def __init__(self, sz, fourier_pts, epsilon=1e-8, ntransforms=1, **kwargs):
             **self.adjoint_opts,
         )

-        # Note, I store self.fourier_pts_gpu so the GPUArrray life
-        # is tied to instance, instead of this method.
-        self.fourier_pts_gpu = gpuarray.to_gpu(self.fourier_pts)
-
-        self._transform_plan.setpts(*self.fourier_pts_gpu)
-        self._adjoint_plan.setpts(*self.fourier_pts_gpu)
+        self._transform_plan.setpts(*self.fourier_pts)
+        self._adjoint_plan.setpts(*self.fourier_pts)

     def transform(self, signal):
         """
@@ -99,7 +93,7 @@ def transform(self, signal):
             For a batch, signal should have shape `(*sz, ntransforms)`.

         :returns: Transformed signal of shape `num_pts` or
-            `(ntransforms, num_pts)`.
+            `(ntransforms, num_pts)` as CuPy array.
         """

         # Check we're not forcing a dtype workaround for ASPIRE-Python/703,
@@ -113,6 +107,9 @@ def transform(self, signal):
             " In the future this will be an error."
             )

+        # Note, if not C order, cuFINUFFT will copy-cast anyway.
+        signal = cp.asarray(signal, order="C", dtype=self.complex_dtype)
+
         sig_shape = signal.shape
         res_shape = self.num_pts
         # Note, there is a corner case for ntransforms == 1.
@@ -134,17 +131,16 @@ def transform(self, signal):
             sig_shape == self.sz
         ), f"Signal frame to be transformed must have shape {self.sz}"

-        signal_gpu = gpuarray.to_gpu(
-            np.ascontiguousarray(signal, dtype=self.complex_dtype)
-        )
+        result = cp.empty(res_shape, dtype=self.complex_dtype)

-        result_gpu = gpuarray.GPUArray(res_shape, dtype=self.complex_dtype)
+        if signal.dtype != self.complex_dtype:
+            signal = signal.astype(self.complex_dtype)

-        self._transform_plan.execute(signal_gpu, out=result_gpu)
+        self._transform_plan.execute(signal, out=result)

-        result = result_gpu.get()
         # ASPIRE-Python/703
-        result = result.astype(complex_type(self._original_dtype), copy=False)
+        if result.dtype != complex_type(self._original_dtype):
+            result = result.astype(complex_type(self._original_dtype))

         return result

@@ -156,7 +152,7 @@ def adjoint(self, signal):
             this should be a 1D array of len `num_pts`.
             For a batch, signal should have shape `(ntransforms, num_pts)`.
-        :returns: Transformed signal `(sz)` or `(sz, ntransforms)`.
+        :returns: Transformed signal `(sz)` or `(sz, ntransforms)` as CuPy array.
         """

         # Check we're not forcing a dtype workaround for ASPIRE-Python/703,
@@ -170,6 +166,9 @@ def adjoint(self, signal):
             " In the future this will be an error."
             )

+        # Note, if not C order, cuFINUFFT will copy-cast anyway.
+        signal = cp.asarray(signal, order="C", dtype=self.complex_dtype)
+
         res_shape = self.sz
         # Note, there is a corner case for ntransforms == 1.
         if self.ntransforms > 1 or (self.ntransforms == 1 and len(signal.shape) == 2):
@@ -181,16 +180,15 @@ def adjoint(self, signal):
             ), f"For multiple transforms, signal stack length should match ntransforms {self.ntransforms}."
             res_shape = (self.ntransforms, *self.sz)

-        signal_gpu = gpuarray.to_gpu(
-            np.ascontiguousarray(signal, dtype=self.complex_dtype)
-        )
+        result = cp.empty(res_shape, dtype=self.complex_dtype)

-        result_gpu = gpuarray.GPUArray(res_shape, dtype=self.complex_dtype)
+        if signal.dtype != self.complex_dtype:
+            signal = signal.astype(self.complex_dtype)

-        self._adjoint_plan.execute(signal_gpu, out=result_gpu)
+        self._adjoint_plan.execute(signal, out=result)

-        result = result_gpu.get()
         # ASPIRE-Python/703
-        result = result.astype(complex_type(self._original_dtype), copy=False)
+        if result.dtype != complex_type(self._original_dtype):
+            result = result.astype(complex_type(self._original_dtype))

         return result
diff --git a/src/aspire/numeric/__init__.py b/src/aspire/numeric/__init__.py
index d298f131e4..be88775498 100644
--- a/src/aspire/numeric/__init__.py
+++ b/src/aspire/numeric/__init__.py
@@ -35,3 +35,37 @@ def fft_object(which):


 fft = fft_object(config["common"]["fft"].as_str())
+
+# Sanity check.
+if (config["common"]["numeric"].as_str() == "cupy") and (
+    config["common"]["fft"].as_str() != "cupy"
+):
+    raise RuntimeError(
+        "Using `cupy` numeric backend without `cupy` fft is unsupported."
+    )
+
+if (config["common"]["fft"].as_str() == "cupy") and (
+    config["common"]["numeric"].as_str() != "cupy"
+):
+    raise RuntimeError(
+        "Using `cupy` fft without `cupy` numeric backend is unsupported."
+    )
+
+
+# Configure `sparse` in tandem with `numeric` as the arrays generally will need to interoperate.
+def sparse_object(which):
+    if which == "cupy":
+        from cupyx.scipy import sparse as SparseClass
+
+        # CuPy imports don't work the same as scipy
+        from cupyx.scipy.sparse.linalg import eigsh
+
+        SparseClass.linalg.eigsh = eigsh
+    elif which == "numpy":
+        from scipy import sparse as SparseClass
+    else:
+        raise RuntimeError(f"Invalid selection for sparse module: {which}")
+    return SparseClass
+
+
+sparse = sparse_object(config["common"]["numeric"].as_str())
diff --git a/src/aspire/numeric/cupy_fft.py b/src/aspire/numeric/cupy_fft.py
index 4f45f92117..b491a0dcd1 100644
--- a/src/aspire/numeric/cupy_fft.py
+++ b/src/aspire/numeric/cupy_fft.py
@@ -1,8 +1,59 @@
+import functools
+
 import cupy as cp
+import cupyx.scipy.fft as cufft
+import numpy as np

 from aspire.numeric.base_fft import FFT

+
+# This improves the flexibility of our FFT wrappers by allowing for
+# incremental code changes and testing.
+def _preserve_host(func):
+    """
+    Method decorator that returns a numpy/cupy array result when
+    passed a numpy/cupy array input respectively.
+
+    At the time of writing, this wrapper will also upcast cupy FFT
+    operations to doubles as the precision in singles can cause
+    accuracy issues.
+ """ + + @functools.wraps(func) # Pass metadata (eg name and doctrings) from `func` + def wrapper(self, x, *args, **kwargs): + + # CuPy's single precision FFT appears to be too inaccurate for + # many of our unit tests, so the signal is upcast and recast + # on return. + _singles = False + if x.dtype == np.float32: + _singles = True + x = x.astype(np.float64) + elif x.dtype == np.complex64: + _singles = True + x = x.astype(np.complex128) + + _host = False + if not isinstance(x, cp.ndarray): + _host = True + x = cp.asarray(x) + + res = func(self, x, *args, **kwargs) + + if _host: + res = res.get() + + # Recast if needed. + if _singles and res.dtype == np.float64: + res = res.astype(np.float32) + elif _singles and res.dtype == np.complex128: + res = res.astype(np.complex64) + + return res + + return wrapper + + class CupyFFT(FFT): """ Define a unified wrapper class for Cupy FFT functions @@ -10,26 +61,53 @@ class CupyFFT(FFT): To be consistent with Scipy and Pyfftw, not all arguments are included. """ + @_preserve_host def fft(self, x, axis=-1, workers=-1): return cp.fft.fft(x, axis=axis) + @_preserve_host def ifft(self, x, axis=-1, workers=-1): return cp.fft.ifft(x, axis=axis) + @_preserve_host def fft2(self, x, axes=(-2, -1), workers=-1): return cp.fft.fft2(x, axes=axes) + @_preserve_host def ifft2(self, x, axes=(-2, -1), workers=-1): return cp.fft.ifft2(x, axes=axes) + @_preserve_host def fftn(self, x, axes=None, workers=-1): return cp.fft.fftn(x, axes=axes) + @_preserve_host def ifftn(self, x, axes=None, workers=-1): return cp.fft.ifftn(x, axes=axes) + @_preserve_host def fftshift(self, x, axes=None): return cp.fft.fftshift(x, axes=axes) + @_preserve_host def ifftshift(self, x, axes=None): return cp.fft.ifftshift(x, axes=axes) + + @_preserve_host + def dct(self, x, **kwargs): + return cufft.dct(x, **kwargs) + + @_preserve_host + def idct(self, x, **kwargs): + return cufft.idct(x, **kwargs) + + def rfftfreq(self, n, **kwargs): + return cufft.rfftfreq(n, **kwargs) + + @_preserve_host + def irfft(self, x, **kwargs): + return cufft.irfft(x, **kwargs) + + @_preserve_host + def rfft(self, x, **kwargs): + return cufft.rfft(x, **kwargs) diff --git a/src/aspire/numeric/numpy.py b/src/aspire/numeric/numpy.py index 3237c2c3ad..ddc8355816 100644 --- a/src/aspire/numeric/numpy.py +++ b/src/aspire/numeric/numpy.py @@ -1,8 +1,22 @@ import numpy as np +cp = None +try: + import cupy as cp +except ModuleNotFoundError: + pass + class Numpy: - asnumpy = staticmethod(lambda x: x) + # This can be required when mixing nufft/fft/numpy backend combinations. + @staticmethod + def asnumpy(x): + """ + Ensure `asnumpy` is always available and returns a numpy array. + """ + if cp and isinstance(x, cp.ndarray): + x = x.get() + return x def __getattr__(self, item): """ diff --git a/src/aspire/numeric/pyfftw_fft.py b/src/aspire/numeric/pyfftw_fft.py index 9cfdd45210..95a8ea80f7 100644 --- a/src/aspire/numeric/pyfftw_fft.py +++ b/src/aspire/numeric/pyfftw_fft.py @@ -159,3 +159,9 @@ def fftshift(self, a, axes=None): def ifftshift(self, a, axes=None): return scipy_fft.ifftshift(a, axes=axes) + + def dct(self, x, **kwargs): + return scipy_fft.dct(x, **kwargs) + + def idct(self, x, **kwargs): + return scipy_fft.idct(x, **kwargs) diff --git a/src/aspire/numeric/scipy.py b/src/aspire/numeric/scipy.py index 8f2e7d86d8..c913e917a3 100644 --- a/src/aspire/numeric/scipy.py +++ b/src/aspire/numeric/scipy.py @@ -8,10 +8,11 @@ def cg(*args, **kwargs): """ - Supports scipy cg before and after 1.14.0. 
+ Supports scipy cg before and after 1.12.0. """ - # older scipy cg interface uses `tol` instead of `rtol` - if Version(scipy.__version__) < Version("1.14.0"): + # older (<1.12.0) scipy cg interface uses `tol` instead of `rtol`. + # `tol` will be removed in scipy 1.14.0. + if Version(scipy.__version__) < Version("1.12.0"): kwargs["tol"] = kwargs.pop("rtol", None) return scipy.sparse.linalg.cg(*args, **kwargs) diff --git a/src/aspire/numeric/scipy_fft.py b/src/aspire/numeric/scipy_fft.py index c5a392f96b..0ef5c95f16 100644 --- a/src/aspire/numeric/scipy_fft.py +++ b/src/aspire/numeric/scipy_fft.py @@ -33,3 +33,18 @@ def fftshift(self, x, axes=None): def ifftshift(self, x, axes=None): return sp.fft.ifftshift(x, axes=axes) + + def dct(self, x, **kwargs): + return sp.fft.dct(x, **kwargs) + + def idct(self, x, **kwargs): + return sp.fft.idct(x, **kwargs) + + def rfftfreq(self, x, **kwargs): + return sp.fft.rfftfreq(x, **kwargs) + + def irfft(self, x, **kwargs): + return sp.fft.irfft(x, **kwargs) + + def rfft(self, x, **kwargs): + return sp.fft.rfft(x, **kwargs) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 9b910a8fe0..e75187fb4a 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -184,9 +184,19 @@ class PowerFilter(Filter): A Filter object that is composed of a regular `Filter` object, but evaluates it to a specified power. """ - def __init__(self, filter, power=1): + def __init__(self, filter, power=1, epsilon=None): + """ + Initialize PowerFilter instance. + + :param filter: A Filter instance. + :param power: Exponent to raise filter values. + :param epsilon: Threshold on filter values that get raised to a negative power. + `filter` values below this threshold will be set to zero during evaluation. + Default uses machine epsilon for filter.dtype. + """ self._filter = filter self._power = power + self._epsilon = epsilon super().__init__(dim=filter.dim, radial=filter.radial) def _evaluate(self, omega): @@ -204,7 +214,9 @@ def evaluate_grid(self, L, *args, dtype=np.float32, **kwargs): # Place safeguard on values below machine epsilon for negative powers. if self._power < 0: - eps = np.finfo(filter_vals.dtype).eps + eps = self._epsilon + if eps is None: + eps = np.finfo(filter_vals.dtype).eps condition = abs(filter_vals) < eps num_less_eps = np.count_nonzero(condition) if num_less_eps > 0: @@ -391,7 +403,7 @@ def __init__(self, dim=None): class CTFFilter(Filter): def __init__( self, - pixel_size=10, + pixel_size=1, voltage=200, defocus_u=15000, defocus_v=15000, @@ -403,7 +415,7 @@ def __init__( """ A CTF (Contrast Transfer Function) Filter - :param pixel_size: Pixel size in angstrom + :param pixel_size: Pixel size in angstrom, default 1. 
@@ -391,7 +403,7 @@ def __init__(self, dim=None):

 class CTFFilter(Filter):
     def __init__(
         self,
-        pixel_size=10,
+        pixel_size=1,
         voltage=200,
         defocus_u=15000,
         defocus_v=15000,
@@ -403,7 +415,7 @@ def __init__(
         """
         A CTF (Contrast Transfer Function) Filter

-        :param pixel_size: Pixel size in angstrom
+        :param pixel_size: Pixel size in angstrom, default 1.
         :param voltage: Electron voltage in kV
         :param defocus_u: Defocus depth along the u-axis in angstrom
         :param defocus_v: Defocus depth along the v-axis in angstrom
@@ -413,7 +425,7 @@ def __init__(
         :param B: Envelope decay in inverse square angstrom (default 0)
         """
         super().__init__(dim=2, radial=defocus_u == defocus_v)
-        self.pixel_size = pixel_size
+        self.pixel_size = float(pixel_size)
         self.voltage = voltage
         self.wavelength = voltage_to_wavelength(self.voltage)
         self.defocus_u = defocus_u
@@ -470,7 +482,7 @@ def scale(self, c=1):

 class RadialCTFFilter(CTFFilter):
     def __init__(
-        self, pixel_size=10, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0
+        self, pixel_size=1, voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0
     ):
         super().__init__(
             pixel_size=pixel_size,
diff --git a/src/aspire/operators/wemd.py b/src/aspire/operators/wemd.py
index 45db7203bb..09fd1db563 100644
--- a/src/aspire/operators/wemd.py
+++ b/src/aspire/operators/wemd.py
@@ -46,7 +46,7 @@ def wemd_embed(arr, wavelet="coif3", level=None):
             message="Level value of .* is too high:"
             " all coefficients will experience boundary effects.",
         )
-    arrdwt = pywt.wavedecn(arr, wavelet, mode="zero", level=level)
+    arrdwt = pywt.wavedecn(arr - arr.mean(), wavelet, mode="zero", level=level)

     detail_coefs = arrdwt[1:]
     assert len(detail_coefs) == level
diff --git a/src/aspire/reconstruction/estimator.py b/src/aspire/reconstruction/estimator.py
index 9d62a0b765..2bf0ec3366 100644
--- a/src/aspire/reconstruction/estimator.py
+++ b/src/aspire/reconstruction/estimator.py
@@ -17,7 +17,7 @@ def __init__(
         preconditioner="circulant",
         checkpoint_iterations=10,
         checkpoint_prefix="volume_checkpoint",
-        maxiter=100,
+        maxiter=50,
         boost=True,
     ):
         """
@@ -34,9 +34,12 @@ def __init__(
             `src` during back projection and kernel estimation steps.
         :param preconditioner: Optional kernel preconditioner (`string`).
             Currently supported options are "circulant" or None.
-        :param checkpoint_iterations: Optionally save `cg` estimated `Volume`
-            instance periodically each `checkpoint_iterations`.
-            Setting to None disables, otherwise checks for positive integer.
+        :param checkpoint_iterations: Optionally save the `cg` estimated
+            `basis` coefficients every `checkpoint_iterations` iterations.
+            Setting to `None` disables checkpointing; otherwise a
+            positive integer is required. Note, when `maxiter` is not
+            `None` and `cg` fails to converge, a final checkpoint will
+            still be written.
         :param checkpoint_prefix: Optional path prefix for `cg`
             checkpoint files.  If the parent directory does not exist,
             creation is attempted.  `_iter{N}` will be appended to the
diff --git a/src/aspire/reconstruction/mean.py b/src/aspire/reconstruction/mean.py
index 58f9c566b6..d0cade9754 100644
--- a/src/aspire/reconstruction/mean.py
+++ b/src/aspire/reconstruction/mean.py
@@ -231,7 +231,7 @@ def cb(xk):

             # Do checkpoint at `checkpoint_iterations`,
             _do_checkpoint = (
-                self.checkpoint_iterations
+                self.checkpoint_iterations is not None
                 and (self.i % self.checkpoint_iterations) == 0
             )
             # or the last iteration when `maxiter` provided.
@@ -258,7 +258,9 @@ def cb(xk):
         )

         if info != 0:
-            raise RuntimeError("Unable to converge!")
+            logger.warning(
+                f"Conjugate gradient unable to converge after {info} iterations."
+            )

         return x.reshape(self.r, self.basis.count)
diff --git a/src/aspire/sinogram/__init__.py b/src/aspire/sinogram/__init__.py
new file mode 100644
index 0000000000..98e489eedf
--- /dev/null
+++ b/src/aspire/sinogram/__init__.py
@@ -0,0 +1 @@
+from .sinogram import Sinogram
diff --git a/src/aspire/sinogram/sinogram.py b/src/aspire/sinogram/sinogram.py
new file mode 100644
index 0000000000..7c7bb43662
--- /dev/null
+++ b/src/aspire/sinogram/sinogram.py
@@ -0,0 +1,155 @@
+import logging
+
+import numpy as np
+
+import aspire.image
+from aspire.nufft import anufft
+from aspire.numeric import fft, xp
+
+logger = logging.getLogger(__name__)
+
+
+class Sinogram:
+    def __init__(self, data, dtype=None):
+        """
+        Initialize a Sinogram Object. This is a stack of one or more line projections or sinograms.
+
+        The stack can be multidimensional with `self.n` equal to the product
+        of the stack dimensions. Singletons will be expanded into a stack
+        with one entry.
+
+        :param data: Numpy array containing image data with shape
+            `(..., angles, radial points)`.
+        :param dtype: Optionally cast `data` to this dtype.
+            Defaults to `data.dtype`.
+
+        :return: Sinogram instance holding `data`.
+        """
+        if dtype is None:
+            self.dtype = data.dtype
+        else:
+            self.dtype = np.dtype(dtype)
+
+        if data.ndim == 2:
+            data = data[np.newaxis, :, :]
+        if data.ndim < 3:
+            raise ValueError(
+                f"Invalid data shape: {data.shape}. Expected shape: (..., angles, radial_points), where '...' is the stack number."
+            )
+
+        self._data = data.astype(self.dtype, copy=False)
+        self.ndim = self._data.ndim
+        self.shape = self._data.shape
+        self.stack_shape = self._data.shape[:-2]
+        self.stack_n_dim = self._data.ndim - 2
+        self.n = np.prod(self.stack_shape)
+        self.n_angles = self._data.shape[-2]
+        self.n_radial_points = self._data.shape[-1]
+
+        # Numpy interop
+        # https://numpy.org/devdocs/user/basics.interoperability.html#the-array-interface-protocol
+        self.__array_interface__ = self._data.__array_interface__
+        self.__array__ = self._data
+
+    def _check_key_dims(self, key):
+        if isinstance(key, tuple) and (len(key) > self._data.ndim):
+            raise ValueError(
+                f"Sinogram stack_dim is {self.stack_n_dim}, slice length must be <= {self.ndim}"
+            )
+
+    def __getitem__(self, key):
+        self._check_key_dims(key)
+        return self.__class__(self._data[key])
+
+    def __setitem__(self, key, value):
+        self._check_key_dims(key)
+        self._data[key] = value
+
+    def stack_reshape(self, *args):
+        """
+        Reshape the stack axis.
+
+        :param args: Integer(s) or tuple describing the intended shape.
+
+        :return: Sinogram instance
+        """
+
+        # If we're passed a tuple, use that
+        if len(args) == 1 and isinstance(args[0], tuple):
+            shape = args[0]
+        else:
+            # Otherwise use the variadic args
+            shape = args
+
+        # Sanity check the size
+        if shape != (-1,) and np.prod(shape) != self.n:
+            raise ValueError(
+                f"Number of sinogram images {self.n} cannot be reshaped to {shape}."
+            )
+
+        return self.__class__(self._data.reshape(*shape, *self._data.shape[-2:]))
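Aside: a short sketch of constructing the new `Sinogram` and reshaping its stack axis (shapes illustrative):

    import numpy as np
    from aspire.sinogram import Sinogram

    data = np.random.rand(6, 180, 65).astype(np.float32)  # 6 sinograms
    s = Sinogram(data)
    assert s.stack_shape == (6,) and s.n_angles == 180 and s.n_radial_points == 65
    s2 = s.stack_reshape(2, 3)  # stack axes become (2, 3)
    assert s2.shape == (2, 3, 180, 65) and s2.n == 6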
+    def asnumpy(self):
+        """
+        Return image data as a read-only array view with shape
+        `(..., angles, radial_points)`.
+
+        :return: read-only ndarray view
+        """
+
+        view = self._data.view()
+        view.flags.writeable = False
+        return view
+
+    def copy(self):
+        return self.__class__(self._data.copy())
+
+    def __str__(self):
+        return f"Sinogram(n_images = {self.n}, n_angles = {self.n_angles}, n_radial_points = {self.n_radial_points})"
+
+    def __repr__(self):
+        msg = f"Sinogram: {self.n} images of dtype {self.dtype}, "
+        msg += f"arranged as a stack with shape {self.stack_shape}. "
+        msg += f"Each image has {self.n_angles} angles and {self.n_radial_points} radial points."
+        return msg
+
+    def backproject(self, angles):
+        """
+        Backprojection method for a single stack of lines.
+
+        :param angles: 1D np.ndarray of angles in radians,
+            one entry per projection line used to
+            reconstruct the images.
+        :return: An Image instance holding the reconstructed images,
+            retaining the original stack shape.
+            Expected return shape is (..., n_radial_points, n_radial_points).
+        """
+        if len(angles) != self.n_angles:
+            raise ValueError("Number of angles must match the number of projections.")
+
+        original_stack_shape = self.stack_shape
+        sinogram = xp.asarray(self.stack_reshape(-1)._data)
+        L = self.n_radial_points
+        sinogram = fft.ifftshift(sinogram, axes=-1)
+        sinogram_ft = fft.rfft(sinogram, axis=-1)
+        sinogram_ft *= xp.pi  # Fix scale to match
+        sinogram_ft[..., 0] /= 2  # Fix DC
+        angles = xp.asarray(angles)
+
+        # grid generation with real points
+        y_idx = fft.rfftfreq(self.n_radial_points) * xp.pi * 2
+        n_real_points = len(y_idx)
+        pts = xp.empty((2, len(angles), n_real_points), dtype=self.dtype)
+        pts[0] = y_idx[xp.newaxis, :] * xp.sin(angles)[:, xp.newaxis]
+        pts[1] = y_idx[xp.newaxis, :] * xp.cos(angles)[:, xp.newaxis]
+
+        imgs = anufft(
+            sinogram_ft.reshape(self.n, -1),
+            pts.reshape(2, n_real_points * len(angles)),
+            sz=(L, L),
+            real=True,
+        ).reshape(self.n, L, L)
+
+        imgs = imgs / (self.n_radial_points * len(angles))
+        return aspire.image.Image(xp.asnumpy(imgs)).stack_reshape(original_stack_shape)
diff --git a/src/aspire/source/coordinates.py b/src/aspire/source/coordinates.py
index dca7aaf873..299422df70 100644
--- a/src/aspire/source/coordinates.py
+++ b/src/aspire/source/coordinates.py
@@ -490,7 +490,9 @@ def _images(self, indices):
             cropped = self._crop_micrograph(arr, next(coord))
             im[i] = cropped
         # Finally, apply transforms to resulting Image
-        return self.generation_pipeline.forward(Image(im), indices)
+        return self.generation_pipeline.forward(
+            Image(im, pixel_size=self.pixel_size), indices
+        )

     @staticmethod
     def _is_number(text):
diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py
index 473585acb9..285b7163b5 100644
--- a/src/aspire/source/image.py
+++ b/src/aspire/source/image.py
@@ -2,7 +2,6 @@
 import functools
 import logging
 import os.path
-import types
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from collections.abc import Iterable
@@ -150,7 +149,14 @@ class ImageSource(ABC):
     _mutable = True

     def __init__(
-        self, L, n, dtype="double", metadata=None, memory=None, symmetry_group=None
+        self,
+        L,
+        n,
+        dtype="double",
+        metadata=None,
+        memory=None,
+        symmetry_group=None,
+        pixel_size=None,
     ):
         """
         A cryo-EM ImageSource object that supplies images along with other parameters for image manipulation.
@@ -163,6 +169,7 @@ def __init__(
             The path of the base directory to use as a data store or None. If None is given, no caching is performed.
:param symmetry_group: A SymmetryGroup instance or string indicating the underlying symmetry of the molecule. Defaults to the `IdentitySymmetryGroup`, which represents an asymmetric particle, if none provided. + :param pixel_size: Pixel size of the images in angstroms, default `None`. """ # Instantiate the accessor for the `images` property @@ -172,6 +179,9 @@ def __init__( self._n = None self.n = n self.dtype = np.dtype(dtype) + if pixel_size is not None: + pixel_size = float(pixel_size) + self.pixel_size = pixel_size # The private attribute '_cached_im' can be populated by calling this object's cache() method explicitly self._cached_im = None @@ -202,40 +212,6 @@ def __init__( logger.info(f"Creating {self.__class__.__name__} with {len(self)} images.") - def __deepcopy__(self, memo): - """ - A custom __deepcopy__ implementation to individually handle special cases. - Mostly copied over from https://stackoverflow.com/a/71125311 - """ - # Get a reference to the bound deepcopy method - deepcopy_method = self.__deepcopy__ - # Temporarily disable __deepcopy__ to avoid infinite recursion - self.__deepcopy__ = None - # Create a deepcopy cp using the normal procedure - cp = copy.deepcopy(self, memo) - - # -------------------------------------- - # Handle any special cases for cp here. - # -------------------------------------- - # This is the whole reason this method exists. If this section is empty, - # then this entire __deepcopy__ implementation can be removed. - - # The 'dtype' attribute is a numpy module level singleton obtained by np.dtype(..) call - # The 'finufft' library currently compares this to the result of a new np.dtype(..) call - # by reference, not by value (as it should). A deepcopy will make a copy of the singleton, - # and thus comparison by reference will fail. Till this bug in 'finufft' is removed, we assign - # self.dtype to dtype - cp.dtype = self.dtype - - # -------------------------------------- - - # Reattach the bound deepcopy method - self.__deepcopy__ = deepcopy_method - # Get the unbounded function corresponding to the bound deepcopy method and rebind to cp - cp.__deepcopy__ = types.MethodType(deepcopy_method.__func__, cp) - - return cp - @property def symmetry_group(self): """ @@ -736,7 +712,7 @@ def _apply_filters( f"_apply_filters() passed {type(im_orig)} instead of Image instance" ) # for now just convert it - im_orig = Image(im_orig) + im_orig = Image(im_orig, pixel_size=self.pixel_size) im = im_orig.copy() @@ -798,7 +774,7 @@ def downsample(self, L): self.L = L @_as_copy - def whiten(self, noise_estimate=None): + def whiten(self, noise_estimate=None, epsilon=None): """ Modify the `ImageSource` in-place by appending a whitening filter to the generation pipeline. @@ -810,6 +786,9 @@ def whiten(self, noise_estimate=None): passed a `NoiseEstimator` the `filter` attribute will be queried. Alternatively, the noise PSD may be passed directly as a `Filter` object. + :param epsilon: Threshold used to determine which frequencies to whiten + and which to set to zero. By default all PSD values in the `noise_estimate` + less than eps(self.dtype) are zeroed out in the whitening filter. :return: On return, the `ImageSource` object has been modified in place. """ @@ -827,8 +806,11 @@ def whiten(self, noise_estimate=None): " instead of `NoiseEstimator` or `Filter`." 
) + if epsilon is None: + epsilon = np.finfo(self.dtype).eps + logger.info("Whitening source object") - whiten_filter = PowerFilter(noise_filter, power=-0.5) + whiten_filter = PowerFilter(noise_filter, power=-0.5, epsilon=epsilon) logger.info("Transforming all CTF Filters into Multiplicative Filters") self.unique_filters = [ @@ -1475,6 +1457,7 @@ def __init__(self, src, indices, memory=None): dtype=src.dtype, metadata=metadata, memory=memory, + pixel_size=src.pixel_size, ) # Create filter indices, these are required to pass unharmed through filter eval code @@ -1644,7 +1627,9 @@ class ArrayImageSource(ImageSource): if available, is consulted directly by the parent class, bypassing `_images`. """ - def __init__(self, im, metadata=None, angles=None, symmetry_group=None): + def __init__( + self, im, metadata=None, angles=None, symmetry_group=None, pixel_size=None + ): """ Initialize from an `Image` object. @@ -1653,12 +1638,13 @@ def __init__(self, im, metadata=None, angles=None, symmetry_group=None): :param metadata: A Dataframe of metadata information corresponding to this ImageSource's images :param angles: Optional n-by-3 array of rotation angles corresponding to `im`. :param symmetry_group: A SymmetryGroup instance or string indicating the underlying symmetry of the molecule. + :param pixel_size: Pixel size of the images in angstroms, default `None`. """ if not isinstance(im, Image): logger.info("Attempting to create an Image object from Numpy array.") try: - im = Image(im) + im = Image(im, pixel_size=pixel_size) except Exception as e: raise RuntimeError( "Creating Image object from Numpy array failed." @@ -1672,6 +1658,7 @@ def __init__(self, im, metadata=None, angles=None, symmetry_group=None): metadata=metadata, memory=None, symmetry_group=symmetry_group, + pixel_size=im.pixel_size, ) self._cached_im = im diff --git a/src/aspire/source/micrograph.py b/src/aspire/source/micrograph.py index 182133d982..2d654401b5 100644 --- a/src/aspire/source/micrograph.py +++ b/src/aspire/source/micrograph.py @@ -17,11 +17,14 @@ class MicrographSource(ABC): - def __init__(self, micrograph_count, micrograph_size, dtype): + def __init__(self, micrograph_count, micrograph_size, dtype, pixel_size=None): """ """ self.micrograph_count = int(micrograph_count) self.micrograph_size = int(micrograph_size) self.dtype = np.dtype(dtype) + if pixel_size is not None: + pixel_size = float(pixel_size) + self.pixel_size = pixel_size self._images_accessor = _ImageAccessor(self._images, self.micrograph_count) @@ -85,7 +88,7 @@ def show(self, *args, **kwargs): """ Helper function to display micrograph. See Image.show(). """ - Image(self.asnumpy()).show(*args, **kwargs) + Image(self.asnumpy(), pixel_size=self.pixel_size).show(*args, **kwargs) @property def images(self): @@ -107,7 +110,7 @@ def _images(self, indices): class ArrayMicrographSource(MicrographSource): - def __init__(self, micrographs, dtype=None): + def __init__(self, micrographs, dtype=None, pixel_size=None): """ Instantiate a `MicrographSource` with `micrographs`. @@ -119,6 +122,7 @@ def __init__(self, micrographs, dtype=None): Currently only `float32` and `float64` are supported. Note, due to limitations of common MRC implementations, saving is limited to single precision. + :param pixel_size: Pixel size of the images in angstroms, default `None`. 
""" # Check micrographs is an array @@ -140,6 +144,7 @@ def __init__(self, micrographs, dtype=None): micrograph_count=micrographs.shape[0], micrograph_size=micrographs.shape[-1], dtype=dtype or micrographs.dtype, + pixel_size=pixel_size, ) # We're already backed by an array, access it directly. @@ -152,11 +157,11 @@ def _images(self, indices): :param indices: A 1-D Numpy array of integer indices. :return: An array backed `MicrographSource` object representing the micrographs for `indices`. """ - return Image(self._data[indices]) + return Image(self._data[indices], pixel_size=self.pixel_size) class DiskMicrographSource(MicrographSource): - def __init__(self, micrographs_path, dtype=None): + def __init__(self, micrographs_path, dtype=None, pixel_size=None): """ Instantiate a `MicrographSource` with `micrographs_path`. @@ -190,11 +195,16 @@ def __init__(self, micrographs_path, dtype=None): # Load the first micrograph to infer shape/type # Size will be checked during on-the-fly loading of subsequent micrographs. micrograph0 = Image.load(self.micrograph_files[0]) + if micrograph0.pixel_size is not None and micrograph0.pixel_size != pixel_size: + raise ValueError( + f"Mismatched pixel size. {micrograph0.pixel_size} angstroms defined in {self.micrograph_files[0]}, but provided {pixel_size} angstroms." + ) super().__init__( micrograph_count=len(self.micrograph_files), micrograph_size=micrograph0.resolution, dtype=dtype or micrograph0.dtype, + pixel_size=pixel_size, ) # Prepare accessor to load files from disk on the fly. @@ -262,8 +272,16 @@ def _images(self, indices): ) # Assign to array, implicitly performs casting to dtype micrographs[i] = micrograph.asnumpy() + # Assert pixel_size + if ( + micrograph.pixel_size is not None + and micrograph.pixel_size != self.pixel_size + ): + raise ValueError( + f"Mismatched pixel size. {micrograph.pixel_size} angstroms defined in {self.micrograph_files[ind]}, but provided {self.pixel_size} angstroms." 
+                )

-        return Image(micrographs)
+        return Image(micrographs, pixel_size=self.pixel_size)


 class MicrographSimulation(MicrographSource):
@@ -557,7 +575,7 @@ def _clean_images(self, indices):
             self.pad : self.micrograph_size + self.pad,
             self.pad : self.micrograph_size + self.pad,
         ]
-        return Image(clean_micrograph)
+        return Image(clean_micrograph, pixel_size=self.pixel_size)

     def get_micrograph_index(self, particle_index):
         """
diff --git a/src/aspire/source/relion.py b/src/aspire/source/relion.py
index 99907cbf6a..bd6d660dd3 100644
--- a/src/aspire/source/relion.py
+++ b/src/aspire/source/relion.py
@@ -59,7 +59,6 @@ def __init__(

         self.filepath = filepath
         self.data_folder = data_folder
-        self.pixel_size = pixel_size
         self.B = B
         self.n_workers = n_workers
         self.max_rows = max_rows
@@ -112,6 +111,7 @@ def __init__(
             metadata=metadata,
             symmetry_group=symmetry_group,
             memory=memory,
+            pixel_size=pixel_size,
         )

         # CTF estimation parameters coming from Relion
@@ -272,4 +272,6 @@ def load_single_mrcs(filepath, indices):
         logger.debug(f"Loading {len(indices)} images complete")

         # Finally, apply transforms to resulting Image
-        return self.generation_pipeline.forward(Image(im), indices)
+        return self.generation_pipeline.forward(
+            Image(im, pixel_size=self.pixel_size), indices
+        )
diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py
index e2ef10da12..331b86e442 100644
--- a/src/aspire/source/simulation.py
+++ b/src/aspire/source/simulation.py
@@ -50,6 +50,7 @@ def __init__(
         memory=None,
         noise_adder=None,
         symmetry_group=None,
+        pixel_size=None,
     ):
         """
         A `Simulation` object that supplies images along with other parameters for image manipulation.
@@ -79,6 +80,7 @@ def __init__(
         :param noise_adder: Optionally append instance of `NoiseAdder`
             to generation pipeline.
         :param symmetry_group: A SymmetryGroup instance or string indicating symmetry of the molecule.
+        :param pixel_size: Pixel size of the images in angstroms, default `None`.

         :return: A Simulation object.
         """
@@ -91,6 +93,7 @@ def __init__(
             self.vols = AsymmetricVolume(
                 L=L or 8,
                 C=C,
+                pixel_size=pixel_size,
                 seed=self.seed,
                 dtype=dtype or np.float32,
             ).generate()
@@ -122,6 +125,7 @@ def __init__(
             dtype=self.vols.dtype,
             memory=memory,
             symmetry_group=symmetry_group,
+            pixel_size=self.vols.pixel_size,
         )

         # If a user provides both `L` and `vols`, resolution should match.
@@ -153,6 +157,7 @@ def __init__(
         if unique_filters is None:
             unique_filters = []
         self.unique_filters = unique_filters
+        self._check_filter_pixel_size(unique_filters)
         # sim_filters must be a deep copy so that it is not changed
         # when unique_filters is changed
         self.sim_filters = copy.deepcopy(unique_filters)
@@ -231,6 +236,29 @@ def _populate_ctf_metadata(self, filter_indices):
             filter_values,
         )

+    def _check_filter_pixel_size(self, unique_filters):
+        """
+        Private method to ensure user provided filters match `Simulation` pixel size.
+
+        When `Simulation.pixel_size` is not `None`, any
+        `unique_filters` having a non-matching `pixel_size` attribute
+        will raise.
+        """
+
+        # Skip when Simulation pixel_size is not explicitly provided.
+        if self.pixel_size is None:
+            return
+
+        for f in unique_filters:
+            f_pixel_size = getattr(f, "pixel_size", None)
+            if f_pixel_size is not None and not np.isclose(
+                f_pixel_size, self.pixel_size
+            ):
+                raise ValueError(
+                    f"`Simulation.pixel_size` {self.pixel_size} does not match filter {f} pixel size {f_pixel_size}."
+                    " Ensure provided `pixel_size` attributes match."
+                )
+
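Aside: the guard above means a `Simulation` constructed with an explicit `pixel_size` now rejects filters carrying a different one. Sketch (class names as in this patch; values illustrative):

    from aspire.operators import RadialCTFFilter
    from aspire.source import Simulation

    flt = RadialCTFFilter(pixel_size=1.34)
    sim = Simulation(L=8, n=4, unique_filters=[flt], pixel_size=1.34)  # ok
    # Simulation(L=8, n=4, unique_filters=[flt], pixel_size=2.0)
    # -> ValueError: `Simulation.pixel_size` 2.0 does not match filter ...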
     @property
     def projections(self):
         """
@@ -245,7 +273,7 @@ def projections(self):

     def _projections(self, indices):
         """
-        Accesses and returns projections as an `Image` instance. Called by self._projections_accessor
+        Accesses and returns projections as an `Image` instance. Called by self._projections_accessor.
         """
         im = np.zeros(
             (len(indices), self._original_L, self._original_L), dtype=self.dtype
@@ -260,7 +288,7 @@ def _projections(self, indices):
             im_k = self.vols[k - 1].project(rot_matrices=rot)
             im[idx_k, :, :] = im_k.asnumpy()

-        return Image(im)
+        return Image(im, pixel_size=self.pixel_size)

     @property
     def clean_images(self):
diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py
index e909e2f394..b33098d3b4 100644
--- a/src/aspire/utils/coor_trans.py
+++ b/src/aspire/utils/coor_trans.py
@@ -2,15 +2,19 @@
 General purpose math functions, mostly geometric in nature.
 """

+import logging
 import math

 import numpy as np
 from numpy.linalg import norm
 from scipy.linalg import svd

+from aspire.numeric import xp
 from aspire.utils.random import Random
 from aspire.utils.rotation import Rotation

+logger = logging.getLogger(__name__)
+

 def cart2pol(x, y):
     """
@@ -298,6 +302,7 @@ def mean_aligned_angular_distance(rots_est, rots_gt, degree_tol=None):
     and the ground truth (in degrees).
     """
     Q_mat, flag = register_rotations(rots_est, rots_gt)
+    logger.debug(f"Registration Q_mat: {Q_mat}\nflag: {flag}")
     regrot = get_aligned_rotations(rots_est, Q_mat, flag)
     mean_ang_dist = Rotation.mean_angular_distance(regrot, rots_gt) * 180 / np.pi

@@ -321,7 +326,7 @@ def common_line_from_rots(r1, r2, ell):

     ut = np.dot(r2, r1.T)
     alpha_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi
-    alpha_ji = np.arctan2(ut[0, 2], -ut[1, 2]) + np.pi
+    alpha_ji = np.arctan2(-ut[0, 2], ut[1, 2]) + np.pi

     ell_ij = alpha_ij * ell / (2 * np.pi)
     ell_ji = alpha_ji * ell / (2 * np.pi)
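Aside: the sign fix to `alpha_ji` can be checked against the geometry directly, since both in-plane angles must map to the same 3D common-line direction. A quick numeric check (illustrative, not part of the test suite):

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    r1, r2 = R.random(2, random_state=0).as_matrix()
    ut = r2 @ r1.T
    a_ij = np.arctan2(ut[2, 0], -ut[2, 1]) + np.pi
    a_ji = np.arctan2(-ut[0, 2], ut[1, 2]) + np.pi  # corrected form
    v1 = r1.T @ np.array([np.cos(a_ij), np.sin(a_ij), 0.0])
    v2 = r2.T @ np.array([np.cos(a_ji), np.sin(a_ji), 0.0])
    np.testing.assert_allclose(v1, v2, atol=1e-12)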
@@ -368,54 +373,105 @@ def rots_to_clmatrix(rots, n_theta):


 def crop_pad_2d(im, size, fill_value=0):
     """
-    :param im: A 2-dimensional numpy array
+    Crops/pads `im` according to `size`.
+
+    Padding will use `fill_value`.
+    Returns a host/GPU array based on `im`.
+
+    :param im: A >=2-dimensional numpy array
     :param size: Integer size of cropped/padded output
-    :return: A numpy array of shape (size, size)
+    :return: Array of shape (..., size, size)
     """

-    im_y, im_x = im.shape
+    im_y, im_x = im.shape[-2:]

     # shift terms
     start_x = math.floor(im_x / 2) - math.floor(size / 2)
     start_y = math.floor(im_y / 2) - math.floor(size / 2)

     # cropping
     if size <= min(im_y, im_x):
-        return im[start_y : start_y + size, start_x : start_x + size]
+        return im[..., start_y : start_y + size, start_x : start_x + size]
     # padding
     elif size >= max(im_y, im_x):
-        # ensure that we return in the same dtype as the input
-        to_return = fill_value * np.ones((size, size), dtype=im.dtype)
+        # Determine shape
+        shape = list(im.shape[:-2])
+        shape.extend([size, size])
+
+        # Ensure that we return the same dtype as the input
+        _full = np.full  # Default to numpy array
+        if isinstance(im, xp.ndarray):
+            # Use cupy when `im` _and_ xp are cupy ndarray
+            # Avoids having to handle when cupy is not installed
+            _full = xp.full
+
+        to_return = _full(shape, fill_value, dtype=im.dtype)
+
         # when padding, start_x and start_y are negative since size is larger
         # than im_x and im_y; the below line calculates where the original image
         # is placed in relation to the (now-larger) box size
-        to_return[-start_y : im_y - start_y, -start_x : im_x - start_x] = im
+        to_return[..., -start_y : im_y - start_y, -start_x : im_x - start_x] = im
         return to_return
     else:
         # target size is between mat_x and mat_y
-        raise ValueError("Cannot crop and pad an image at the same time.")
+        raise ValueError(
+            "Cannot crop and pad Image at the same time."
+            " If this is really what you intended,"
+            " make two separate calls for cropping and padding."
+        )
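Aside: `crop_pad_2d` now broadcasts over leading stack axes, so whole stacks can be cropped or zero-padded in one call. Sketch (import path per this patch):

    import numpy as np
    from aspire.utils.coor_trans import crop_pad_2d

    stack = np.ones((5, 64, 64), dtype=np.float32)
    assert crop_pad_2d(stack, 32).shape == (5, 32, 32)      # center crop
    assert crop_pad_2d(stack, 128).shape == (5, 128, 128)   # zero pad
    assert crop_pad_2d(stack, 128).dtype == np.float32      # dtype preserved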
+
+
+def crop_pad_3d(vol, size, fill_value=0):
+    """
+    Crops/pads `vol` according to `size`.
+
+    Padding will use `fill_value`.
+    Returns a host/GPU array based on `vol`.
+
-def crop_pad_3d(im, size, fill_value=0):
-    im_y, im_x, im_z = im.shape
+    :param vol: A >=3-dimensional numpy array
+    :param size: Integer size of cropped/padded output
+    :return: Array of shape (..., size, size, size)
+    """
+
+    vol_z, vol_y, vol_x = vol.shape[-3:]

     # shift terms
-    start_x = math.floor(im_x / 2) - math.floor(size / 2)
-    start_y = math.floor(im_y / 2) - math.floor(size / 2)
-    start_z = math.floor(im_z / 2) - math.floor(size / 2)
+    start_z = math.floor(vol_z / 2) - math.floor(size / 2)
+    start_y = math.floor(vol_y / 2) - math.floor(size / 2)
+    start_x = math.floor(vol_x / 2) - math.floor(size / 2)

     # cropping
-    if size <= min(im_y, im_x, im_z):
-        return im[
-            start_y : start_y + size, start_x : start_x + size, start_z : start_z + size
+    if size <= min(vol_z, vol_y, vol_x):
+        return vol[
+            ...,
+            start_z : start_z + size,
+            start_y : start_y + size,
+            start_x : start_x + size,
         ]
     # padding
-    elif size >= max(im_y, im_x, im_z):
-        to_return = fill_value * np.ones((size, size, size), dtype=im.dtype)
+    elif size >= max(vol_z, vol_y, vol_x):
+        # Determine shape
+        shape = list(vol.shape[:-3])
+        shape.extend([size, size, size])
+
+        # Ensure that we return the same dtype as the input
+        _full = np.full  # Default to numpy array
+        if isinstance(vol, xp.ndarray):
+            # Use cupy when `vol` _and_ xp are cupy ndarray
+            # Avoids having to handle when cupy is not installed
+            _full = xp.full
+
+        to_return = _full(shape, fill_value, dtype=vol.dtype)
+
         to_return[
-            -start_y : im_y - start_y,
-            -start_x : im_x - start_x,
-            -start_z : im_z - start_z,
-        ] = im
+            ...,
+            -start_z : vol_z - start_z,
+            -start_y : vol_y - start_y,
+            -start_x : vol_x - start_x,
+        ] = vol
         return to_return
     else:
-        # target size is between min and max of (im_y, im_x, im_z)
-        raise ValueError("Cannot crop and pad a volume at the same time.")
+        # target size is between min and max of (vol_x, vol_y, vol_z)
+        raise ValueError(
+            "Cannot crop and pad Volume at the same time."
+            " If this is really what you intended,"
+            " make two separate calls for cropping and padding."
+        )
diff --git a/src/aspire/utils/matrix.py b/src/aspire/utils/matrix.py
index 5e56d2e65e..71c709608b 100644
--- a/src/aspire/utils/matrix.py
+++ b/src/aspire/utils/matrix.py
@@ -434,11 +434,14 @@ def best_rank1_approximation(A):
     return (U @ S_rank1 @ V).reshape(og_shape)


-def nearest_rotations(A):
+def nearest_rotations(A, allow_reflection=False):
     """
     Uses the SVD method to compute the set of nearest rotations to the set A of noisy rotations.

+    Note when `allow_reflection` is `True`, results may contain reflections.
+
     :param A: A 2D array or a 3D array where the first axis is the stack axis.
+    :param allow_reflection: Optionally allow reflections (disables correction).
     :return: ndarray of rotations of equal size to A.
     """
     og_shape = A.shape
@@ -451,12 +454,16 @@ def nearest_rotations(A):
             f"Array must be of shape (3, 3) or (n, 3, 3). Found shape {A.shape}."
         )

-    # For the singular value decomposition A = U @ S @ V, we compute the nearest rotation
-    # matrices R = U @ V. If det(U)*det(V) = -1, we negate the third singular value to ensure
-    # we have a rotation.
+    # For the singular value decomposition A = U @ S @ V,
+    # we compute the nearest rotation matrices R = U @ V.
     U, _, V = np.linalg.svd(A)
-    neg_det_idx = np.linalg.det(U) * np.linalg.det(V) < 0
-    U[neg_det_idx] = U[neg_det_idx] @ np.diag((1, 1, -1)).astype(dtype, copy=False)
+
+    if not allow_reflection:
+        # If det(U)*det(V) = -1, we negate the third singular value to
+        # ensure we have a rotation.
+ neg_det_idx = np.linalg.det(U) * np.linalg.det(V) < 0 + U[neg_det_idx] = U[neg_det_idx] @ np.diag((1, 1, -1)).astype(dtype, copy=False) + rots = U @ V return rots.reshape(og_shape) diff --git a/src/aspire/utils/misc.py b/src/aspire/utils/misc.py index f72efbaf14..d0c30b9f90 100644 --- a/src/aspire/utils/misc.py +++ b/src/aspire/utils/misc.py @@ -294,7 +294,8 @@ def fuzzy_mask(L, dtype, r0=None, risetime=None): if r0 is None: r0 = np.floor(0.45 * L[0]) if risetime is None: - risetime = np.floor(0.05 * L[0]) + # Guard against zero here for small L + risetime = max(np.floor(0.05 * L[0]), 1.0) dim = len(L) axes = ["x"] @@ -367,12 +368,12 @@ def J_conjugate(A): """ Conjugate the 3x3 matrix A by the diagonal matrix J=diag((-1, -1, 1)). - :param A: A 3x3 matrix. - :return: J*A*J + :param A: A 3x3 matrix, or nx3x3 matrix. + :return: J@A@J """ - J = np.array([[1, 1, -1], [1, 1, -1], [-1, -1, 1]], dtype=A.dtype) + JJop = np.array([[1, 1, -1], [1, 1, -1], [-1, -1, 1]], dtype=A.dtype) - return A * J + return A * JJop def cyclic_rotations(order, dtype=np.float64): diff --git a/src/aspire/utils/resolution_estimation.py b/src/aspire/utils/resolution_estimation.py index 14c142bdc5..af5a46f8a9 100644 --- a/src/aspire/utils/resolution_estimation.py +++ b/src/aspire/utils/resolution_estimation.py @@ -354,7 +354,8 @@ def plot(self, cutoff=None, save_to_file=False, labels=None): plt.ylabel("Correlation") plt.ylim([0, 1.1]) for i, line in enumerate(self.correlations): - _label = None + # Set default label for single correlation (required by plt.legend() below). + _label = "correlation" if len(self.correlations) > 1: _label = f"{i}" if labels is not None: diff --git a/src/aspire/utils/rotation.py b/src/aspire/utils/rotation.py index 08bec4ca3d..07a31df9df 100644 --- a/src/aspire/utils/rotation.py +++ b/src/aspire/utils/rotation.py @@ -408,6 +408,10 @@ def angle_dist(r1, r2, dtype=None): theta = (tr_r[non_zero_dist_ind] - 1) / 2 theta = np.maximum(np.minimum(theta, 1), -1) # Clamp theta in [-1,1] dist[non_zero_dist_ind] = np.arccos(theta, dtype=dtype) + + # Return scalar for single value. + if dist.size == 1: + dist = dist.flat[0] return dist @staticmethod diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index b6c100db36..0ed31b8b74 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -58,7 +58,7 @@ class Volume: Volume is an (N1 x ...) x L x L x L array, along with associated utility methods. """ - def __init__(self, data, dtype=None, symmetry_group=None): + def __init__(self, data, dtype=None, pixel_size=None, symmetry_group=None): """ A stack of one or more volumes. @@ -76,6 +76,10 @@ def __init__(self, data, dtype=None, symmetry_group=None): `(..., resolution, resolution, resolution)`. :param dtype: Optionally cast `data` to this dtype. Defaults to `data.dtype`. + :param pixel_size: Optional voxel_size in angstroms. + When provided will be saved with `map`/`mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param symmetry_group: A SymmetryGroup instance or string indicating symmetry of the Volume. :return: A Volume instance holding `data`. @@ -107,6 +111,9 @@ def __init__(self, data, dtype=None, symmetry_group=None): self.n_vols = np.prod(self.stack_shape) self.resolution = self._data.shape[-1] self.size = self._data.size + self.pixel_size = None + if pixel_size is not None: + self.pixel_size = float(pixel_size) # Set symmetry_group. If None, default to 'C1'. 
self._set_symmetry_group(symmetry_group) @@ -140,7 +147,9 @@ def astype(self, dtype, copy=True): :return: Volume instance """ return self.__class__( - self.asnumpy().astype(dtype, copy=copy), symmetry_group=self.symmetry_group + self.asnumpy().astype(dtype, copy=copy), + pixel_size=self.pixel_size, + symmetry_group=self.symmetry_group, ) def _check_key_dims(self, key): @@ -151,7 +160,11 @@ def _check_key_dims(self, key): def __getitem__(self, key): self._check_key_dims(key) - return self.__class__(self._data[key], symmetry_group=self.symmetry_group) + return self.__class__( + self._data[key], + pixel_size=self.pixel_size, + symmetry_group=self.symmetry_group, + ) def __setitem__(self, key, value): self._check_key_dims(key) @@ -242,14 +255,19 @@ def stack_reshape(self, *args): return self.__class__( self._data.reshape(*shape, *self._data.shape[-3:]), + pixel_size=self.pixel_size, symmetry_group=self.symmetry_group, ) def __repr__(self): + px_msg = "." + if self.pixel_size is not None: + px_msg = f" with pixel_size={self.pixel_size} angstroms." + msg = ( f"{self.n_vols} {self.dtype} volumes arranged as a {self.stack_shape} stack" ) - msg += f" each of size {self.resolution}x{self.resolution}x{self.resolution}." + msg += f" each of size {self.resolution}x{self.resolution}x{self.resolution}{px_msg}" return msg def __len__(self): @@ -258,9 +276,15 @@ def __len__(self): def __add__(self, other): symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data + other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data + other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data + other, symmetry_group=symmetry) + res = self.__class__( + self._data + other, pixel_size=self.pixel_size, symmetry_group=symmetry + ) return res @@ -270,21 +294,37 @@ def __radd__(self, otherL): def __sub__(self, other): symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data - other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data - other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data - other, symmetry_group=symmetry) + res = self.__class__( + self._data - other, pixel_size=self.pixel_size, symmetry_group=symmetry + ) return res def __rsub__(self, otherL): - return self.__class__(otherL - self._data) + return self.__class__( + otherL - self._data, + pixel_size=self.pixel_size, + symmetry_group=self.symmetry_group, + ) def __mul__(self, other): symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data * other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data * other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data * other, symmetry_group=symmetry) + res = self.__class__( + self._data * other, pixel_size=self.pixel_size, symmetry_group=symmetry + ) return res @@ -297,9 +337,15 @@ def __truediv__(self, other): """ symmetry = self._result_symmetry(other) if isinstance(other, Volume): - res = self.__class__(self._data / other.asnumpy(), symmetry_group=symmetry) + res = self.__class__( + self._data / other.asnumpy(), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) else: - res = self.__class__(self._data / other, symmetry_group=symmetry) + res = self.__class__( + self._data / other, pixel_size=self.pixel_size, 
symmetry_group=symmetry + ) return res @@ -307,9 +353,12 @@ def __rtruediv__(self, otherL): """ Right scalar division, follows numpy semantics. """ - return otherL * Volume(1.0 / self._data) + return otherL * Volume( + 1.0 / self._data, + pixel_size=self.pixel_size, + ) - def project(self, rot_matrices): + def project(self, rot_matrices, zero_nyquist=True): """ Using the stack of rot_matrices, project images of Volume. When projecting over a stack of volumes, a singleton Rotation or a Rotation with stack size @@ -318,6 +367,8 @@ def project(self, rot_matrices): and a Rotation stack, the i'th Volume will be projected using the i'th Rotation. :param rot_matrices: Stack of rotations. Rotation or ndarray instance. + :param zero_nyquist: Option to keep or remove Nyquist frequency for even resolution. + Defaults to zero_nyquist=True, removing the Nyquist frequency. :return: `Image` instance. """ # See Issue #727 @@ -342,13 +393,13 @@ def project(self, rot_matrices): if rot_matrices.ndim == 2: rot_matrices = np.expand_dims(rot_matrices, axis=0) - data = self._data + data = xp.asarray(self._data) n_rots = rot_matrices.shape[0] pts_rot = rotated_grids(self.resolution, rot_matrices) if n_rots == self.n_vols: # Apply rotations to Volumes element-wise. - im_f = np.empty( + im_f = xp.empty( (self.n_vols, self.resolution**2), dtype=complex_type(self.dtype) ) pts_rot = pts_rot.reshape((3, n_rots, self.resolution**2)) @@ -366,13 +417,13 @@ def project(self, rot_matrices): im_f = im_f.reshape(-1, self.resolution, self.resolution) - if self.resolution % 2 == 0: + # If resolution is even, optionally zero out the nyquist frequency. + if self.resolution % 2 == 0 and zero_nyquist is True: im_f[:, 0, :] = 0 im_f[:, :, 0] = 0 - im_f = xp.asnumpy(fft.centered_ifft2(xp.asarray(im_f))) - - return aspire.image.Image(np.real(im_f)) + im_f = fft.centered_ifft2(im_f) + return aspire.image.Image(xp.asnumpy(im_f.real), pixel_size=self.pixel_size) def to_vec(self): """Returns an N x resolution ** 3 array.""" @@ -416,7 +467,7 @@ def transpose(self): v = self._data.reshape(-1, *self._data.shape[-3:]) vt = np.transpose(v, (0, -1, -2, -3)) vt = vt.reshape(*original_stack_shape, *self._data.shape[-3:]) - return self.__class__(vt, symmetry_group=symmetry) + return self.__class__(vt, pixel_size=self.pixel_size, symmetry_group=symmetry) @property def T(self): @@ -459,35 +510,55 @@ def flip(self, axis=-3): f"Cannot flip axis {ax}: stack axis. Did you mean {ax-4}?" ) - return self.__class__(np.flip(self._data, axis), symmetry_group=symmetry) + return self.__class__( + np.flip(self._data, axis), + pixel_size=self.pixel_size, + symmetry_group=symmetry, + ) - def downsample(self, ds_res, mask=None): + def downsample(self, ds_res, mask=None, zero_nyquist=True): """ Downsample each volume to a desired resolution (only cubic supported). :param ds_res: Desired resolution. + :param zero_nyquist: Option to keep or remove Nyquist frequency for even resolution. + Defaults to zero_nyquist=True, removing the Nyquist frequency. :param mask: Optional NumPy array mask to multiply in Fourier space. 
""" - if mask is None: - mask = 1.0 original_stack_shape = self.stack_shape v = self.stack_reshape(-1) # take 3D Fourier transform of each volume in the stack - fx = fft.fftshift(fft.fftn(v._data, axes=(1, 2, 3))) + fx = fft.centered_fftn(xp.asarray(v._data)) + # crop each volume to the desired resolution in frequency space - crop_fx = ( - np.array([crop_pad_3d(fx[i, :, :, :], ds_res) for i in range(self.n_vols)]) - * mask - ) + fx = crop_pad_3d(fx, ds_res) + + # If downsample resolution is even, optionally zero out the nyquist frequency. + if ds_res % 2 == 0 and zero_nyquist is True: + fx[:, 0, :, :] = 0 + fx[:, :, 0, :] = 0 + fx[:, :, :, 0] = 0 + + # Optionally apply mask + if mask is not None: + fx = fx * xp.asarray(mask) + # inverse Fourier transform of each volume - out = fft.ifftn(fft.ifftshift(crop_fx), axes=(1, 2, 3)) * ( - ds_res**3 / self.resolution**3 - ) + out = fft.centered_ifftn(fx) + out = out.real * (ds_res**3 / self.resolution**3) + + # Optionally scale pixel size + ds_pixel_size = self.pixel_size + if ds_pixel_size is not None: + ds_pixel_size *= self.resolution / ds_res + # returns a new Volume object return self.__class__( - np.real(out), symmetry_group=self.symmetry_group + xp.asnumpy(out), + pixel_size=ds_pixel_size, + symmetry_group=self.symmetry_group, ).stack_reshape(original_stack_shape) def shift(self): @@ -549,7 +620,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): vol_f = vol_f.reshape(-1, self.resolution, self.resolution, self.resolution) - # If resolution is even, we zero out the nyquist frequency by default. + # If resolution is even, optionally zero out the nyquist frequency. if self.resolution % 2 == 0 and zero_nyquist is True: vol_f[:, 0, :, :] = 0 vol_f[:, :, 0, :] = 0 @@ -559,7 +630,7 @@ def rotate(self, rot_matrices, zero_nyquist=True): np.real(fft.centered_ifftn(xp.asarray(vol_f), axes=(-3, -2, -1))) ) - return self.__class__(vol, symmetry_group=symmetry) + return self.__class__(vol, pixel_size=self.pixel_size, symmetry_group=symmetry) def denoise(self): raise NotImplementedError @@ -580,6 +651,9 @@ def save(self, filename, overwrite=False): with mrcfile.new(filename, overwrite=overwrite) as mrc: mrc.set_data(self._data.astype(np.float32)) + # Note assigning voxel_size must come after `set_data` + if self.pixel_size is not None: + mrc.voxel_size = self.pixel_size if self.dtype != np.float32: logger.info(f"Volume with dtype {self.dtype} saved with dtype float32") @@ -598,20 +672,20 @@ def load(cls, filename, permissive=True, dtype=None, symmetry_group=None): :return: Volume instance. """ with mrcfile.open(filename, permissive=permissive) as mrc: - loaded_data = mrc.data - - # FINUFFT work around - if loaded_data.dtype == np.float32: - loaded_data = loaded_data.astype(np.float32) - elif loaded_data.dtype == np.float64: - loaded_data = loaded_data.astype(np.float64) + loaded_data = mrc.data.copy() # Allow mutation + pixel_size = Volume._vx_array_to_size(mrc.voxel_size) if loaded_data.dtype != dtype: logger.info(f"{filename} with dtype {loaded_data.dtype} loaded as {dtype}") - return cls(loaded_data, symmetry_group=symmetry_group, dtype=dtype) + return cls( + loaded_data, + pixel_size=pixel_size, + symmetry_group=symmetry_group, + dtype=dtype, + ) - def fsc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): + def fsc(self, other, cutoff=None, method="fft", plot=False): r""" Compute the Fourier shell correlation between two volumes. 
@@ -628,8 +702,6 @@ def fsc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): :param cutoff: Cutoff value, traditionally `.143`. Default `None` implies `cutoff=1` and excludes plotting cutoff line. - :param pixel_size: Pixel size in angstrom. Default `None` - implies unit in pixels, equivalent to pixel_size=1. :param method: Selects either 'fft' (on cartesian grid), or 'nufft' (on polar grid). Defaults to 'fft'. :param plot: Optionally plot to screen or file. @@ -649,7 +721,7 @@ def fsc(self, other, cutoff=None, pixel_size=None, method="fft", plot=False): fsc = FourierShellCorrelation( a=self.asnumpy(), b=other.asnumpy(), - pixel_size=pixel_size, + pixel_size=self.pixel_size, method=method, ) @@ -668,7 +740,7 @@ def empty_like(v): :param v: Volume instance :return: Volume instance """ - return Volume(np.empty(v.shape, dtype=v.dtype)) + return Volume(np.empty(v.shape, dtype=v.dtype), pixel_size=v.pixel_size) @staticmethod def zeros_like(v): @@ -678,7 +750,33 @@ def zeros_like(v): :param v: Volume instance :return: Volume instance """ - return Volume(np.zeros(v.shape, dtype=v.dtype)) + return Volume(np.zeros(v.shape, dtype=v.dtype), pixel_size=v.pixel_size) + + @staticmethod + def _vx_array_to_size(vx): + """ + Utility to convert from several possible `mrcfile.voxel_size` + representations to a single (float) value or None. + """ + + # Convert from recarray to single values, + # checks uniformity. + if isinstance(vx, np.recarray): + if vx.x != vx.y or vx.x != vx.z: + raise ValueError(f"Voxel sizes are not uniform: {vx}") + vx = vx.x + + # Convert `0` to `None` + if ( + isinstance(vx, int) or isinstance(vx, float) or isinstance(vx, np.ndarray) + ) and vx == 0: + vx = None + + # Consistently return a `float` when not None + if vx is not None: + vx = float(vx) + + return vx class CartesianVolume(Volume): diff --git a/src/aspire/volume/volume_synthesis.py b/src/aspire/volume/volume_synthesis.py index b9514df5ea..43f794bfaf 100644 --- a/src/aspire/volume/volume_synthesis.py +++ b/src/aspire/volume/volume_synthesis.py @@ -16,11 +16,12 @@ class SyntheticVolumeBase(abc.ABC): - def __init__(self, L, C, seed=None, dtype=np.float64): + def __init__(self, L, C, pixel_size=None, seed=None, dtype=np.float64): self.L = L self.C = C self.seed = seed self.dtype = dtype + self.pixel_size = pixel_size @abc.abstractmethod def generate(self): @@ -39,18 +40,24 @@ class GaussianBlobsVolume(SyntheticVolumeBase): A base class for all volumes which are generated with randomized 3D Gaussians. """ - def __init__(self, L, C, K=16, alpha=1, seed=None, dtype=np.float64): + def __init__( + self, L, C, K=16, alpha=1, pixel_size=None, seed=None, dtype=np.float64 + ): """ :param L: Resolution of the Volume(s) in pixels. :param C: Number of Volumes to generate. :param K: Number of Gaussian blobs used to construct the Volume(s). :param alpha: Scaling factor for variance of Gaussian blobs. Default=1. + :param pixel_size: Optional voxel_size in angstroms. + When provided will be saved with `map`/`mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param seed: Random seed for generating random Gaussian blobs. 
:param dtype: dtype for Volume(s) """ self.K = int(K) self.alpha = float(alpha) - super().__init__(L=L, C=C, seed=seed, dtype=dtype) + super().__init__(L=L, C=C, pixel_size=pixel_size, seed=seed, dtype=dtype) self._set_symmetry_group() @abc.abstractproperty @@ -75,7 +82,11 @@ def generate(self): """ vol = self._gaussian_blob_vols() bump_mask = bump_3d(self.L, spread=5, dtype=self.dtype) - return Volume(bump_mask * vol, symmetry_group=self.symmetry_group) + return Volume( + bump_mask * vol, + symmetry_group=self.symmetry_group, + pixel_size=self.pixel_size, + ) def _gaussian_blob_vols(self): """ @@ -168,18 +179,26 @@ class CnSymmetricVolume(GaussianBlobsVolume): A Volume object with cyclically symmetric volumes constructed of random 3D Gaussian blobs. """ - def __init__(self, L, C, order, K=16, alpha=1, seed=None, dtype=np.float64): + def __init__( + self, L, C, order, K=16, alpha=1, pixel_size=None, seed=None, dtype=np.float64 + ): """ :param L: Resolution of the Volume(s) in pixels. :param C: Number of Volumes to generate. :param order: An integer representing the cyclic order of the Volume(s). :param K: Number of Gaussian blobs used to construct the Volume(s). + :param pixel_size: Optional voxel_size in angstroms. + When provided will be saved with `map`/`mrc` metadata. + Default of `None` will not write to file, + but will be considered unit pixels (1) for FSC. :param seed: Random seed for generating random Gaussian blobs. :param dtype: dtype for Volume(s) """ self.order = int(order) self._check_order() - super().__init__(L=L, C=C, K=K, alpha=alpha, seed=seed, dtype=dtype) + super().__init__( + L=L, C=C, K=K, alpha=alpha, pixel_size=pixel_size, seed=seed, dtype=dtype + ) def _check_order(self): if self.order < 2: @@ -239,8 +258,10 @@ class AsymmetricVolume(CnSymmetricVolume): An asymmetric Volume constructed of random 3D Gaussian blobs with compact support in the unit sphere. """ - def __init__(self, L, C, K=64, seed=None, dtype=np.float64): - super().__init__(L=L, C=C, K=K, order=1, seed=seed, dtype=dtype) + def __init__(self, L, C, K=64, pixel_size=None, seed=None, dtype=np.float64): + super().__init__( + L=L, C=C, K=K, order=1, pixel_size=pixel_size, seed=seed, dtype=dtype + ) def _check_order(self): if self.order != 1: @@ -260,8 +281,8 @@ class LegacyVolume(AsymmetricVolume): An asymmetric Volume object used for testing of legacy code. """ - def __init__(self, L, C=2, K=16, seed=0, dtype=np.float64): - super().__init__(L=L, C=C, K=K, seed=seed, dtype=dtype) + def __init__(self, L, C=2, K=16, pixel_size=None, seed=0, dtype=np.float64): + super().__init__(L=L, C=C, K=K, pixel_size=pixel_size, seed=seed, dtype=dtype) def generate(self): """ @@ -272,4 +293,4 @@ def generate(self): # Swap axes to retain Legacy xyz-indexing. 
vols = np.swapaxes(vols, 1, 3) - return Volume(vols) + return Volume(vols, pixel_size=self.pixel_size) diff --git a/tests/saved_test_data/rln_proj_64_centered.mrcs b/tests/saved_test_data/rln_proj_64_centered.mrcs new file mode 100644 index 0000000000..7dc22c9caf Binary files /dev/null and b/tests/saved_test_data/rln_proj_64_centered.mrcs differ diff --git a/tests/saved_test_data/rln_proj_64_centered.star b/tests/saved_test_data/rln_proj_64_centered.star new file mode 100644 index 0000000000..71a5cf8bd1 --- /dev/null +++ b/tests/saved_test_data/rln_proj_64_centered.star @@ -0,0 +1,32 @@ + +# version 30001 + +data_optics + +loop_ +_rlnOpticsGroup #1 +_rlnOpticsGroupName #2 +_rlnVoltage #3 +_rlnSphericalAberration #4 +_rlnImagePixelSize #5 +_rlnImageSize #6 +_rlnImageDimensionality #7 + 1 optics1 300.000000 2.700000 1.000000 64 2 + + +# version 30001 + +data_particles + +loop_ +_rlnAngleRot #1 +_rlnAngleTilt #2 +_rlnAnglePsi #3 +_rlnOriginXAngst #4 +_rlnOriginYAngst #5 +_rlnOpticsGroup #6 +_rlnImageName #7 + 235.820138 113.086030 50.468981 0.000000 0.000000 1 000001@rln_proj_64_centered.mrcs + 86.698555 31.958115 139.545228 0.000000 0.000000 1 000002@rln_proj_64_centered.mrcs + 48.456166 71.176316 185.304830 0.000000 0.000000 1 000003@rln_proj_64_centered.mrcs + 215.714386 105.017323 154.043384 0.000000 0.000000 1 000004@rln_proj_64_centered.mrcs diff --git a/tests/saved_test_data/rln_proj_64_shifted.mrcs b/tests/saved_test_data/rln_proj_64_shifted.mrcs new file mode 100644 index 0000000000..864c937bd2 Binary files /dev/null and b/tests/saved_test_data/rln_proj_64_shifted.mrcs differ diff --git a/tests/saved_test_data/rln_proj_64_shifted.star b/tests/saved_test_data/rln_proj_64_shifted.star new file mode 100644 index 0000000000..4be5add306 --- /dev/null +++ b/tests/saved_test_data/rln_proj_64_shifted.star @@ -0,0 +1,32 @@ + +# version 30001 + +data_optics + +loop_ +_rlnOpticsGroup #1 +_rlnOpticsGroupName #2 +_rlnVoltage #3 +_rlnSphericalAberration #4 +_rlnImagePixelSize #5 +_rlnImageSize #6 +_rlnImageDimensionality #7 + 1 optics1 300.000000 2.700000 1.000000 64 2 + + +# version 30001 + +data_particles + +loop_ +_rlnAngleRot #1 +_rlnAngleTilt #2 +_rlnAnglePsi #3 +_rlnOriginX #4 +_rlnOriginY #5 +_rlnOpticsGroup #6 +_rlnImageName #7 + 235.820138 113.086030 50.468981 6.000000 10.000000 1 000001@rln_proj_64_shifted.mrcs + 86.698555 31.958115 139.545228 10.000000 -5.000000 1 000002@rln_proj_64_shifted.mrcs + 48.456166 71.176316 185.304830 -8.000000 11.000000 1 000003@rln_proj_64_shifted.mrcs + 215.714386 105.017323 154.043384 -13.000000 -3.000000 1 000004@rln_proj_64_shifted.mrcs diff --git a/tests/saved_test_data/rln_proj_65_centered.mrcs b/tests/saved_test_data/rln_proj_65_centered.mrcs new file mode 100644 index 0000000000..57fc725464 Binary files /dev/null and b/tests/saved_test_data/rln_proj_65_centered.mrcs differ diff --git a/tests/saved_test_data/rln_proj_65_centered.star b/tests/saved_test_data/rln_proj_65_centered.star new file mode 100644 index 0000000000..1e105ca1dc --- /dev/null +++ b/tests/saved_test_data/rln_proj_65_centered.star @@ -0,0 +1,35 @@ + +# version 30001 + +data_optics + +loop_ +_rlnOpticsGroup #1 +_rlnOpticsGroupName #2 +_rlnVoltage #3 +_rlnSphericalAberration #4 +_rlnImagePixelSize #5 +_rlnImageSize #6 +_rlnImageDimensionality #7 + 1 optics1 300.000000 2.700000 1.000000 65 2 + + +# version 30001 + +data_particles + +loop_ +_rlnAngleRot #1 +_rlnAngleTilt #2 +_rlnAnglePsi #3 +_rlnOriginXAngst #4 +_rlnOriginYAngst #5 +_rlnOpticsGroup #6 +_rlnImageName #7 +_rlnOriginX #8 
+_rlnOriginY #9 + 235.820138 113.086030 50.468981 0.000000 0.000000 1 000001@rln_proj_65_centered.mrcs 0.00000 0.000000 + 86.698555 31.958115 139.545228 0.000000 0.000000 1 000002@rln_proj_65_centered.mrcs 0.00000 0.000000 + 48.456166 71.176316 185.304830 0.000000 0.000000 1 000003@rln_proj_65_centered.mrcs 0.00000 0.000000 + 215.714386 105.017323 154.043384 0.000000 0.000000 1 000004@rln_proj_65_centered.mrcs 0.00000 0.000000 + diff --git a/tests/saved_test_data/rln_proj_65_shifted.mrcs b/tests/saved_test_data/rln_proj_65_shifted.mrcs new file mode 100644 index 0000000000..55ba195d48 Binary files /dev/null and b/tests/saved_test_data/rln_proj_65_shifted.mrcs differ diff --git a/tests/saved_test_data/rln_proj_65_shifted.star b/tests/saved_test_data/rln_proj_65_shifted.star new file mode 100644 index 0000000000..f36eb664cf --- /dev/null +++ b/tests/saved_test_data/rln_proj_65_shifted.star @@ -0,0 +1,32 @@ + +# version 30001 + +data_optics + +loop_ +_rlnOpticsGroup #1 +_rlnOpticsGroupName #2 +_rlnVoltage #3 +_rlnSphericalAberration #4 +_rlnImagePixelSize #5 +_rlnImageSize #6 +_rlnImageDimensionality #7 + 1 optics1 300.000000 2.700000 1.000000 65 2 + + +# version 30001 + +data_particles + +loop_ +_rlnAngleRot #1 +_rlnAngleTilt #2 +_rlnAnglePsi #3 +_rlnOriginX #4 +_rlnOriginY #5 +_rlnOpticsGroup #6 +_rlnImageName #7 + 235.820138 113.086030 50.468981 6.000000 10.000000 1 000001@rln_proj_65_shifted.mrcs + 86.698555 31.958115 139.545228 10.000000 -5.000000 1 000002@rln_proj_65_shifted.mrcs + 48.456166 71.176316 185.304830 -8.000000 11.000000 1 000003@rln_proj_65_shifted.mrcs + 215.714386 105.017323 154.043384 -13.000000 -3.000000 1 000004@rln_proj_65_shifted.mrcs diff --git a/tests/test_FFBbasis2D.py b/tests/test_FFBbasis2D.py index 8acf7201d1..c3ee42dd75 100644 --- a/tests/test_FFBbasis2D.py +++ b/tests/test_FFBbasis2D.py @@ -6,6 +6,7 @@ from scipy.special import jv from aspire.basis import Coef, FFBBasis2D +from aspire.nufft import all_backends from aspire.source import Simulation from aspire.utils.misc import grid_2d from aspire.volume import Volume @@ -126,6 +127,9 @@ def testShift(self, basis): params = [pytest.param(512, np.float32, marks=pytest.mark.expensive)] +@pytest.mark.skipif( + all_backends()[0] == "cufinufft", reason="Not enough memory to run via GPU" +) @pytest.mark.parametrize( "L, dtype", params, @@ -136,6 +140,7 @@ def testHighResFFBBasis2D(L, dtype): sim = Simulation( n=1, L=L, + C=1, dtype=dtype, amplitudes=1, offsets=0, @@ -149,4 +154,6 @@ def testHighResFFBBasis2D(L, dtype): # Mask to compare inside disk of radius 1. 
     mask = grid_2d(L, normalized=True)["r"] < 1
-    assert np.allclose(im_ffb.asnumpy()[0][mask], im.asnumpy()[0][mask], atol=1e-4)
+    np.testing.assert_allclose(
+        im_ffb.asnumpy()[0][mask], im.asnumpy()[0][mask], rtol=1e-05, atol=1e-4
+    )
diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py
index 7d6b3f5c47..94ea9c4316 100644
--- a/tests/test_FLEbasis2D.py
+++ b/tests/test_FLEbasis2D.py
@@ -1,4 +1,5 @@
 import os
+import platform
 import sys

 import numpy as np
@@ -70,8 +71,12 @@ def relerr(base, approx):

 @pytest.mark.parametrize("basis", test_bases, ids=show_fle_params)
 class TestFLEBasis2D(UniversalBasisMixin):
-    # Loosen the tolerance for `cufinufft` to be within 15%
-    test_eps = 1.15 if backend_available("cufinufft") else 1.0
+    # Loosen the tolerance for `cufinufft` and `osx_arm`
+    test_eps = 1.0
+    if backend_available("cufinufft"):
+        test_eps = 1.15
+    elif platform.system() == "Darwin":
+        test_eps = 1.30

     # check closeness guarantees for fast vs dense matrix method
     def testFastVDense_T(self, basis):
@@ -142,7 +147,7 @@ def testMatchFBEvaluate(basis):
     fb_images = fb_basis.evaluate(coefs)
     fle_images = basis.evaluate(coefs)

-    assert np.allclose(fb_images._data, fle_images._data, atol=1e-4)
+    np.testing.assert_allclose(fb_images._data, fle_images._data, atol=1e-4)


 @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params)
@@ -159,8 +164,8 @@ def testMatchFBDenseEvaluate(basis):
     fle_images = Image(fle_out.T.reshape(-1, basis.nres, basis.nres)).asnumpy()

     # Matrix column reordering in match_fb mode flips signs of some of the basis functions
-    assert np.allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3)
-    assert np.allclose(fb_images, fle_images, atol=1e-3)
+    np.testing.assert_allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3)
+    np.testing.assert_allclose(fb_images, fle_images, atol=1e-3)


 @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params)
@@ -177,7 +182,7 @@ def testMatchFBEvaluate_t(basis):
     fb_coefs = fb_basis.evaluate_t(images)
     fle_coefs = basis.evaluate_t(images)

-    assert np.allclose(fb_coefs, fle_coefs, atol=1e-4)
+    np.testing.assert_allclose(fb_coefs, fle_coefs, atol=1e-4)


 @pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params)
@@ -197,7 +202,7 @@ def testMatchFBDenseEvaluate_t(basis):
     fle_coefs = basis._create_dense_matrix().T @ vec.T

     # Matrix column reordering in match_fb mode flips signs of some of the basis coefficients
-    assert np.allclose(np.abs(fb_coefs), np.abs(fle_coefs), atol=1e-4)
+    np.testing.assert_allclose(np.abs(fb_coefs), np.abs(fle_coefs), atol=1e-4)


 def testLowPass():
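Aside: the mechanical `assert np.allclose(...)` to `np.testing.assert_allclose(...)` changes throughout these tests trade a bare boolean failure for a diagnostic one. Illustrative:

    import numpy as np

    a = np.array([1.0, 2.0])
    b = np.array([1.0, 2.1])
    np.testing.assert_allclose(a, b, atol=1e-4)
    # Raises AssertionError listing the mismatched elements and the max
    # absolute/relative difference, rather than a bare `assert False`.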
CTF) function via FLE coefficients - L = 32 - basis = FLEBasis2D(L, match_fb=False) + # load test radial function x = np.load(os.path.join(DATA_DIR, "fle_radial_fn_32x32.npy")).reshape(1, 32, 32) x = x / np.max(np.abs(x.flatten())) # get sample images ims = create_images(L, 10) + # convolve using coefficients + basis = FLEBasis2D(L, match_fb=False, dtype=ims.dtype) coefs = basis.evaluate_t(ims) coefs_convolved = basis.radial_convolve(coefs, x) imgs_convolved_fle = basis.evaluate(coefs_convolved).asnumpy() # convolve using FFT - x = basis.evaluate(basis.evaluate_t(x)).asnumpy() + x = basis.evaluate(basis.evaluate_t(Image(x))).asnumpy() ims = basis.evaluate(coefs).asnumpy() imgs_convolved_slow = np.zeros((10, L, L)) @@ -265,4 +271,4 @@ def testRadialConvolution(): convolution_fft_pad[L // 2 : L // 2 + L, L // 2 : L // 2 + L] ) - assert np.allclose(imgs_convolved_fle, imgs_convolved_slow, atol=1e-5) + np.testing.assert_allclose(imgs_convolved_fle, imgs_convolved_slow, atol=1e-5) diff --git a/tests/test_anisotropic_noise.py b/tests/test_anisotropic_noise.py index 12b52064ea..caaedc4aff 100644 --- a/tests/test_anisotropic_noise.py +++ b/tests/test_anisotropic_noise.py @@ -20,11 +20,17 @@ def setUp(self): n=1024, vols=self.vol, unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + # Set legacy pixel size + RadialCTFFilter(pixel_size=10, defocus=d) + for d in np.linspace(1.5e4, 2.5e4, 7) ], dtype=self.dtype, ) + # Keep hardcoded tests passing after fixing swapped offsets. + # See github issue #1146. + self.sim = self.sim.update(offsets=self.sim.offsets[:, [1, 0]]) + def tearDown(self): pass diff --git a/tests/test_basis_utils.py b/tests/test_basis_utils.py index 2d6a3efdb2..2e416a5dee 100644 --- a/tests/test_basis_utils.py +++ b/tests/test_basis_utils.py @@ -1,6 +1,7 @@ from unittest import TestCase import numpy as np +from scipy.special import sph_harm as sp_sph_harm from aspire.basis.basis_utils import ( all_besselj_zeros, @@ -9,10 +10,61 @@ norm_assoc_legendre, real_sph_harmonic, sph_bessel, + sph_harm, unique_coords_nd, ) + +def test_sph_harm_low_order(): + """ + Test that the `sph_harm` implementation matches `scipy` at lower orders. + """ + m = 3 + j = 5 + x = np.linspace(0, np.pi, 42) + y = np.linspace(0, 2 * np.pi, 42) + + ref = sp_sph_harm(m, j, y, x) # Note calling convention is different + np.testing.assert_allclose(sph_harm(j, m, x, y), ref) + + # negative m + m *= -1 + ref = sp_sph_harm(m, j, y, x) # Note calling convention is different + np.testing.assert_allclose(sph_harm(j, m, x, y), ref) + + +def test_sph_harm_high_order(): + """ + Test that we remain finite at higher orders where `scipy.special.sph_harm` overflows. + """ + m = 87 + j = 87 + x = 0.12345 + y = 0.56789 + + # If scipy fixes its implementation for higher orders in the future, + # this check will fail and we can reconsider that package. + ref = sp_sph_harm(m, j, y, x) # Note calling convention is different + assert not np.isfinite(ref) + + # Can manually check against pyshtools, + # but we are avoiding that package dependency. + # Leaving this here intentionally for future developers. + # y = spharm_lm( + # j, + # abs_m, + # theta, + # phi, + # kind="complex", + # degrees=False, + # csphase=-1, + # normalization="ortho", + # ) + + # Check we are finite.
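# A hedged aside on the overflow tested above: the orthonormal harmonic
# Y_l^m carries sqrt((2l+1)/(4*pi) * (l-m)!/(l+m)!), and evaluating
# Gamma(l + m + 1) = 174! directly (~ e^724) exceeds float64's ~e^709 limit
# even though the full ratio is representable. Working with gammaln
# differences is one standard way to stay finite; this may or may not be how
# aspire's sph_harm does it:
import numpy as np
from scipy.special import gammaln

l, m = 87, 87
log_norm = 0.5 * (
    np.log(2 * l + 1) - np.log(4 * np.pi) + gammaln(l - m + 1) - gammaln(l + m + 1)
)
assert np.isfinite(log_norm)  # the log-space value is perfectly tame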
+ assert np.isfinite(sph_harm(j, m, x, y)) + + class BesselTestCase(TestCase): def setUp(self): pass diff --git a/tests/test_class_src.py b/tests/test_class_src.py index 0c169621cd..e395aecc8e 100644 --- a/tests/test_class_src.py +++ b/tests/test_class_src.py @@ -128,7 +128,14 @@ def class_sim_fixture(dtype, img_size): # Note using a single volume via C=1 is critical to matching # alignment without the complexity of remapping via states etc. src = Simulation( - L=img_size, n=n, vols=v, offsets=0, amplitudes=1, C=1, angles=true_rots.angles + L=img_size, + n=n, + vols=v, + offsets=0, + amplitudes=1, + C=1, + angles=true_rots.angles, + symmetry_group="C4", # For testing symmetry_group pass-through. ) # Prefetch all the images src = src.cache() @@ -193,6 +200,9 @@ class averages. k = len(src2.class_indices) np.testing.assert_equal(src2.class_indices, test_src.class_indices[::3][:k]) + # Check symmetry_group pass-through. + assert test_src.symmetry_group == class_sim_fixture.symmetry_group + # Test the _HeapItem helper class def test_heap_helper(): diff --git a/tests/test_coef.py b/tests/test_coef.py index ab546d9bac..3ace0ddec5 100644 --- a/tests/test_coef.py +++ b/tests/test_coef.py @@ -48,7 +48,7 @@ def dtype(request): return request.param -@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +@pytest.fixture(params=DTYPES, ids=lambda x: f"basis_dtype={x}", scope="module") def basis_dtype(request): """ Dtypes for basis @@ -416,11 +416,16 @@ def test_shifts(coef_fixture, basis, rots): shifts = np.column_stack((rots, rots[::-1])) # Compare + min_dtype = ( + np.float32 + if (basis.dtype == np.float32 or coef_fixture.dtype == np.float32) + else np.float64 + ) np.testing.assert_allclose( coef_fixture.shift(shifts), basis.shift(coef_fixture, shifts), rtol=1e-05, - atol=utest_tolerance(basis.dtype), + atol=utest_tolerance(min_dtype), ) diff --git a/tests/test_commonline_sync3n.py b/tests/test_commonline_sync3n.py new file mode 100644 index 0000000000..600f883c2d --- /dev/null +++ b/tests/test_commonline_sync3n.py @@ -0,0 +1,138 @@ +import copy +import os + +import numpy as np +import pytest + +from aspire.abinitio import CLSync3N +from aspire.source import Simulation +from aspire.utils import mean_aligned_angular_distance, rots_to_clmatrix +from aspire.volume import AsymmetricVolume + +DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") + +RESOLUTION = [ + 40, + pytest.param(41, marks=pytest.mark.expensive), +] + +OFFSETS = [ + 0, + pytest.param(None, marks=pytest.mark.expensive), +] + +DTYPES = [ + np.float32, + pytest.param(np.float64, marks=pytest.mark.expensive), +] + + +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") +def resolution(request): + return request.param + + +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module") +def offsets(request): + return request.param + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def source_orientation_objs(resolution, offsets, dtype): + src = Simulation( + n=100, + L=resolution, + vols=AsymmetricVolume( + L=resolution, C=1, K=100, seed=123, dtype=dtype + ).generate(), + offsets=offsets, + amplitudes=1, + seed=456, + ).cache() + + # Search for common lines over fewer shifts for 0 offsets.
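# Hedged sketch of the shift search implied by these knobs, assuming
# `max_shift` is a fraction of the image size L and `shift_step` is in
# pixels (the 1-D candidate grid would then look like this):
import numpy as np

L, max_shift, shift_step = 40, 1 / 40, 1  # illustrative values from above
shifts = np.arange(-max_shift * L, max_shift * L + shift_step, shift_step)
assert shifts.tolist() == [-1.0, 0.0, 1.0]  # only +/- 1 pixel for 0 offsets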
+ max_shift = 1 / resolution + shift_step = 1 + if src.offsets.all() != 0: + max_shift = 0.20 + shift_step = 0.25 # Reduce shift steps for non-integer offsets of Simulation. + + orient_est = CLSync3N(src, max_shift=max_shift, shift_step=shift_step, seed=789) + + # Estimate rotations once for all tests. + orient_est.estimate_rotations() + + return src, orient_est + + +def test_build_clmatrix(source_orientation_objs): + src, orient_est = source_orientation_objs + + gt_clmatrix = rots_to_clmatrix(src.rotations, orient_est.n_theta) + + angle_diffs = abs(orient_est.clmatrix - gt_clmatrix) * 360 / orient_est.n_theta + + # Count number of estimates within 5 degrees of ground truth. + within_5 = np.sum((angle_diffs - 360) % 360 < 5) + + # Check that at least 98% of estimates are within 5 degrees. + tol = 0.98 + if src.offsets.all() != 0: + # Set tolerance to 75% when using nonzero offsets. + tol = 0.75 + assert within_5 / angle_diffs.size > tol + + +def test_estimate_shifts_with_gt_rots(source_orientation_objs): + src, orient_est = source_orientation_objs + + # Assign ground truth rotations. + # Deep copy to prevent altering for other tests. + orient_est = copy.deepcopy(orient_est) + orient_est.rotations = src.rotations + + # Estimate shifts using ground truth rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.8) + else: + np.testing.assert_allclose(mean_dist, 0) + + +def test_estimate_shifts_with_est_rots(source_orientation_objs): + src, orient_est = source_orientation_objs + # Estimate shifts using estimated rotations. + est_shifts = orient_est.estimate_shifts() + + # Calculate the mean 2D distance between estimates and ground truth. + error = src.offsets - est_shifts + mean_dist = np.hypot(error[:, 0], error[:, 1]).mean() + + # Assert that on average estimated shifts are close (within 0.8 pix) to src.offsets + if src.offsets.all() != 0: + np.testing.assert_array_less(mean_dist, 0.8) + else: + np.testing.assert_allclose(mean_dist, 0) + + +def test_estimate_rotations(source_orientation_objs): + src, orient_est = source_orientation_objs + + # Register estimates to ground truth rotations and compute the + # mean angular distance between them (in degrees). + # Assert that mean angular distance is less than 1 degree (4 with offsets). 
+ tol = 1 + if src.offsets.all() != 0: + tol = 4 + mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=tol) diff --git a/tests/test_commonline_sync3n_cupy.py b/tests/test_commonline_sync3n_cupy.py new file mode 100644 index 0000000000..9bea14c21f --- /dev/null +++ b/tests/test_commonline_sync3n_cupy.py @@ -0,0 +1,233 @@ +import numpy as np +import pytest + +from aspire.abinitio.commonline_sync3n import CLSync3N +from aspire.source import Simulation + +# If cupy is not available, skip this entire module +pytest.importorskip("cupy") + + +N = 32 # Number of images +n_pairs = N * (N - 1) // 2 +DTYPES = [np.float32, np.float64] + + +@pytest.fixture(scope="module", params=DTYPES, ids=lambda x: f"dtype={x}") +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def src_fixture(dtype): + src = Simulation(n=N, L=32, C=1, dtype=dtype) + src = src.cache() + return src + + +@pytest.fixture(scope="module") +def cl3n_fixture(src_fixture): + cl = CLSync3N(src_fixture) + return cl + + +@pytest.fixture(scope="module") +def rijs_fixture(dtype): + Rijs = np.arange(n_pairs * 3 * 3, dtype=dtype).reshape(n_pairs, 3, 3) + return Rijs + + +def test_pairs_prob_host_vs_cupy(cl3n_fixture, rijs_fixture): + """ + Compares pairs_probabilities between host and cupy implementations. + """ + + P2, A, a, B, b, x0 = 1, 2, 3, 4, 5, 6 + + # DTYPE is critical here (we are manually calling a private method) + params = np.array([P2, A, a, B, b, x0], dtype=np.float64) + + # Execute CUPY + indscp, arbcp = cl3n_fixture._pairs_probabilities_cupy(rijs_fixture, *params) + + # Execute host + indsh, arbh = cl3n_fixture._pairs_probabilities_host(rijs_fixture, *params) + + # Compare host to cupy calls + rtol = 1e-07 # np testing default + if rijs_fixture.dtype != np.float64: + rtol = 2e-5 + np.testing.assert_allclose(indsh, indscp, rtol=rtol) + np.testing.assert_allclose(arbh, arbcp, rtol=rtol) + + +def test_triangle_scores_host_vs_cupy(cl3n_fixture, rijs_fixture): + """ + Compares triangle_scores between host and cupy implementations. + """ + + # Execute CUPY + hist_cp = cl3n_fixture._triangle_scores_inner_cupy(rijs_fixture) + + # Execute host + hist_h = cl3n_fixture._triangle_scores_inner_host(rijs_fixture) + + # Compare host to cupy calls + np.testing.assert_allclose(hist_cp, hist_h) + + +def test_stv_host_vs_cupy(cl3n_fixture, rijs_fixture): + """ + Compares signs_times_v between host and cupy implementations. + + Default J_weighting=False + """ + # dummy data vector + vec = np.ones(n_pairs, dtype=rijs_fixture.dtype) + + # J_weighting=False + assert cl3n_fixture.J_weighting is False + + # Execute CUPY + new_vec_cp = cl3n_fixture._signs_times_v_cupy(rijs_fixture, vec) + + # Execute host + new_vec_h = cl3n_fixture._signs_times_v_host(rijs_fixture, vec) + + # Compare host to cupy calls + np.testing.assert_allclose(new_vec_cp, new_vec_h) + + +def test_stvJwt_host_vs_cupy(cl3n_fixture, rijs_fixture): + """ + Compares signs_times_v between host and cupy implementations.
+ + Force J_weighting=True + """ + # dummy data vector + vec = np.ones(n_pairs, dtype=rijs_fixture.dtype) + + # J_weighting=True + cl3n_fixture.J_weighting = True + + # Execute CUPY + new_vec_cp = cl3n_fixture._signs_times_v_cupy(rijs_fixture, vec) + + # Execute host + new_vec_h = cl3n_fixture._signs_times_v_host(rijs_fixture, vec) + + # Compare host to cupy calls + rtol = 1e-7 # np testing default + if vec.dtype != np.float64: + rtol = 3e-07 + np.testing.assert_allclose(new_vec_cp, new_vec_h, rtol=rtol) + + +# The following fixture and tests compare against the legacy MATLAB implementation + + +@pytest.fixture +def matlab_ref_fixture(): + """ + Setup ASPIRE-Python objects using dummy data that is easily + constructed in MATLAB. + """ + DTYPE = np.float64 # MATLAB code is doubles only + n = 5 + n_pairs = n * (n - 1) // 2 + + # Dummy input vector. + Rijs = np.transpose( + np.arange(1, n_pairs * 3 * 3 + 1, dtype=DTYPE).reshape(n_pairs, 3, 3), (0, 2, 1) + ) + # Equivalent MATLAB + # n=5; np=n*(n-1)/2; rijs= reshape([1:np*3*3],[3,3,np]) + + # Create CL object for testing function calls + src = Simulation(L=8, n=n, C=1, dtype=DTYPE) + cl3n = CLSync3N(src, seed=314, S_weighting=False, J_weighting=False) + + return Rijs, cl3n + + +def test_triangles_scores(matlab_ref_fixture): + """ + Compares output of identical dummy data between this + implementation and legacy MATLAB triangles_scores_mex. + """ + Rijs, cl3n = matlab_ref_fixture + + hist = cl3n._triangle_scores_inner(Rijs) + + # Default is 100 histogram intervals, + # so the histogram reference is compressed. + ref_hist = np.zeros(cl3n.hist_intervals) + # Nonzeros, [[indices, ...], [values, ...]] + ref_compressed = np.array( + [[0, 10, 11, 12, 70, 71, 72, 76, 81, 89], [14, 2, 2, 2, 1, 1, 2, 1, 2, 3]] + ) + # Pack the reference histogram + np.put(ref_hist, *ref_compressed) + + np.testing.assert_allclose(hist, ref_hist) + + +def test_pairs_prob_mex(matlab_ref_fixture): + """ + Compares output of identical dummy data between this + implementation and legacy MATLAB pairs_probabilities_mex. + """ + Rijs, cl3n = matlab_ref_fixture + + params = np.arange(1, 7) + + ln_f_ind, ln_f_arb = cl3n._pairs_probabilities_host(Rijs, *params) + + ref_ln_f_ind = [ + -24.1817, + -5.6554, + 4.9117, + 12.7047, + -12.9374, + -5.5158, + 1.5289, + -9.0406, + -2.2067, + -7.3968, + ] + + ref_ln_f_arb = [ + -17.1264, + -6.7218, + -0.8876, + 3.3437, + -10.7251, + -6.7051, + -2.9029, + -8.5061, + -4.8288, + -7.5608, + ] + + np.testing.assert_allclose(ln_f_arb, ref_ln_f_arb, atol=5e-5) + + np.testing.assert_allclose(ln_f_ind, ref_ln_f_ind, atol=5e-5) + + +def test_signs_times_v_mex(matlab_ref_fixture): + """ + Compares output of identical dummy data between this + implementation and legacy MATLAB signs_times_v. + """ + Rijs, cl3n = matlab_ref_fixture + + # Dummy input vector + vec = np.ones(len(Rijs), dtype=Rijs.dtype) + # Equivalent matlab + # vec=ones([1,np]); + + new_vec = cl3n._signs_times_v(Rijs, vec) + + ref_vec = [0, -2, -2, 0, -6, -4, -2, -2, -2, 0] + + np.testing.assert_allclose(new_vec, ref_vec) diff --git a/tests/test_covar2d.py b/tests/test_covar2d.py index 05e0eda509..6ec5bf0b14 100644 --- a/tests/test_covar2d.py +++ b/tests/test_covar2d.py @@ -56,10 +56,10 @@ def img_size(request): def volume(dtype, img_size): # Get a volume v = Volume( - np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(dtype) + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")).astype(dtype) ) # 1e3 is hardcoded to match legacy test files. 
- return v.downsample(img_size) * 1.0e3 + return v * 1.0e3 @pytest.fixture(params=BASIS, ids=lambda x: f"basis={x}") diff --git a/tests/test_covar2d_denoiser.py b/tests/test_covar2d_denoiser.py index a403a72109..ea5410fb34 100644 --- a/tests/test_covar2d_denoiser.py +++ b/tests/test_covar2d_denoiser.py @@ -6,6 +6,7 @@ from aspire.noise import WhiteNoiseAdder from aspire.operators import IdentityFilter, RadialCTFFilter from aspire.source import Simulation +from aspire.utils import utest_tolerance # TODO, parameterize these further. dtype = np.float32 @@ -17,12 +18,28 @@ RadialCTFFilter(5, 200, defocus=d, Cs=2.0, alpha=0.1) for d in np.linspace(1.5e4, 2.5e4, 7) ] + +# For (F)PSWFBasis2D we get off-block entries which are truncated +# when converting to block-diagonal. We filter these warnings. BASIS = [ pytest.param(FBBasis2D, marks=pytest.mark.expensive), FFBBasis2D, FLEBasis2D, - pytest.param(PSWFBasis2D, marks=pytest.mark.expensive), - FPSWFBasis2D, + pytest.param( + PSWFBasis2D, + marks=[ + pytest.mark.expensive, + pytest.mark.filterwarnings( + "ignore:BlkDiagMatrix.from_dense truncating values*" + ), + ], + ), + pytest.param( + FPSWFBasis2D, + marks=pytest.mark.filterwarnings( + "ignore:BlkDiagMatrix.from_dense truncating values*" + ), + ), ] @@ -89,7 +106,9 @@ def test_batched_rotcov2d_MSE(sim, basis): # Additionally test the `DenoisedSource` and lazy-eval-cache # of the cov2d estimator. src = DenoisedSource(sim, denoiser) - np.testing.assert_allclose(imgs_denoised, src.images[:], rtol=1e-05, atol=1e-08) + np.testing.assert_allclose( + imgs_denoised, src.images[:], rtol=1e-05, atol=utest_tolerance(src.dtype) + ) def test_source_mismatch(sim, basis): diff --git a/tests/test_diag_matrix.py b/tests/test_diag_matrix.py index ecce899105..05805912c9 100644 --- a/tests/test_diag_matrix.py +++ b/tests/test_diag_matrix.py @@ -77,7 +77,7 @@ def test_repr(): Test accessing the `repr` does not crash. """ - d = DiagMatrix(np.empty((10, 8))) + d = DiagMatrix(np.ones((10, 8))) assert repr(d).startswith("DiagMatrix(") @@ -86,7 +86,7 @@ def test_str(): Test accessing the `str` does not crash. """ - d = DiagMatrix(np.empty((10, 8))) + d = DiagMatrix(np.ones((10, 8))) assert str(d).startswith("DiagMatrix(") @@ -104,13 +104,13 @@ def test_len(): """ Test the `len`. """ - d = DiagMatrix(np.empty((10, 8))) + d = DiagMatrix(np.ones((10, 8))) assert d.size == 10 assert d.count == 8 assert len(d) == 10 - d = DiagMatrix(np.empty((2, 5, 8))) + d = DiagMatrix(np.ones((2, 5, 8))) assert d.size == 10 assert d.count == 8 @@ -121,8 +121,8 @@ def test_size_mismatch(): """ Test we raise operating on `DiagMatrix` having different counts. """ - d1 = DiagMatrix(np.empty((10, 8))) - d2 = DiagMatrix(np.empty((10, 7))) + d1 = DiagMatrix(np.ones((10, 8))) + d2 = DiagMatrix(np.ones((10, 7))) with pytest.raises(RuntimeError, match=r".*not same dimension.*"): _ = d1 + d2 @@ -132,8 +132,8 @@ def test_dtype_mismatch(): """ Test we raise operating on `DiagMatrix` having different dtypes. """ - d1 = DiagMatrix(np.empty((10, 8)), dtype=np.float32) - d2 = DiagMatrix(np.empty((10, 8)), dtype=np.float64) + d1 = DiagMatrix(np.ones((10, 8)), dtype=np.float32) + d2 = DiagMatrix(np.ones((10, 8)), dtype=np.float64) with pytest.raises(RuntimeError, match=r".*received different types.*"): _ = d1 + d2 @@ -144,7 +144,7 @@ def test_dtype_passthrough(): Test that the datatype is inferred correctly. 
""" for dtype in (int, np.float32, np.float64, np.complex64, np.complex128): - d_np = np.empty(42, dtype=dtype) + d_np = np.ones(42, dtype=dtype) d = DiagMatrix(d_np) assert d.dtype == dtype @@ -154,7 +154,7 @@ def test_dtype_cast(): Test that a datatype is cast when overridden. """ for dtype in (int, np.float32, np.float64, np.complex64, np.complex128): - d_np = np.empty(42, dtype=np.float16) + d_np = np.ones(42, dtype=np.float16) d = DiagMatrix(d_np, dtype) assert d.dtype == dtype @@ -444,7 +444,7 @@ def test_diag_badtype_matmul(): """ Test matrix multiply of `DiagMatrix` with incompatible type raises. """ - d1 = DiagMatrix(np.empty(8)) + d1 = DiagMatrix(np.ones(8)) # matmul with pytest.raises(RuntimeError, match=r".*not implemented for.*"): @@ -576,7 +576,7 @@ def test_bad_as_blk_diag(matrix_size, blk_diag): """ with pytest.raises(RuntimeError, match=r".*only implemented for singletons.*"): # Construct via Numpy. - d_np = np.empty((2, matrix_size), dtype=blk_diag.dtype) + d_np = np.ones((2, matrix_size), dtype=blk_diag.dtype) # Create DiagMatrix then convert to BlkDiagMatrix d = DiagMatrix(d_np) @@ -654,7 +654,7 @@ def test_diag_blk_mul(): """ Test mixing `BlkDiagMatrix` with `DiagMatrix` element-wise multiplication raises. """ - d = DiagMatrix(np.empty(8)) + d = DiagMatrix(np.ones(8)) partition = [(4, 4), (4, 4)] b = BlkDiagMatrix.ones(partition, dtype=d.dtype) @@ -672,7 +672,7 @@ def test_non_square_as_blk_diag(): """ Test non square partition blocks raise an error in as_blk_diag. """ - d = DiagMatrix(np.empty(8)) + d = DiagMatrix(np.ones(8)) partition = [(4, 5), (4, 3)] with pytest.raises(RuntimeError, match=r".*not square.*"): @@ -683,8 +683,8 @@ def test_bad_broadcast(): """ Test incompatible stack shapes raise appropriate error. """ - d1 = DiagMatrix(np.empty((2, 3, 8))) - d2 = DiagMatrix(np.empty((2, 2, 8))) + d1 = DiagMatrix(np.ones((2, 3, 8))) + d2 = DiagMatrix(np.ones((2, 2, 8))) with pytest.raises(ValueError, match=r".*incompatible shapes.*"): _ = d1 + d2 diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 13fad279df..4c990212ce 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from aspire.downloader import emdb_2660 from aspire.image import Image from aspire.source import Simulation from aspire.utils import utest_tolerance @@ -89,6 +90,9 @@ def test_downsample_2d_case(L, L_ds): assert (N, L_ds, L_ds) == imgs_ds.shape # check center points for all images assert checkCenterPoint(imgs_org, imgs_ds) + # Confirm default `pixel_size` + assert imgs_org.pixel_size is None + assert imgs_ds.pixel_size is None @pytest.mark.parametrize("L", [65, 66]) @@ -102,8 +106,80 @@ def test_downsample_3d_case(L, L_ds): assert checkCenterPoint(vols_org, vols_ds) # check signal energy is conserved assert checkSignalEnergy(vols_org, vols_ds) + # Confirm default `pixel_size` + assert vols_org.pixel_size is None + assert vols_ds.pixel_size is None def test_integer_offsets(): sim = Simulation(offsets=0) _ = sim.downsample(3) + + +# Test that vol.downsample.project == vol.project.downsample. 
+DTYPES = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +RES = [65, 66] +RES_DS = [32, 33] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=RES, ids=lambda x: f"resolution={x}", scope="module") +def res(request): + return request.param + + +@pytest.fixture(params=RES_DS, ids=lambda x: f"resolution_ds={x}", scope="module") +def res_ds(request): + return request.param + + +@pytest.fixture(scope="module") +def emdb_vol(): + return emdb_2660() + + +@pytest.fixture(scope="module") +def volume(emdb_vol, res, dtype): + vol = emdb_vol.astype(dtype, copy=False) + vol = vol.downsample(res) + return vol + + +def test_downsample_project(volume, res_ds): + """ + Test that vol.downsample.project == vol.project.downsample. + """ + rot = np.eye(3, dtype=volume.dtype) # project along z-axis + im_ds_proj = volume.downsample(res_ds).project(rot) + im_proj_ds = volume.project(rot).downsample(res_ds) + + tol = 1e-07 + if volume.dtype == np.float64: + tol = 1e-09 + np.testing.assert_allclose(im_ds_proj, im_proj_ds, atol=tol) + + +def test_pixel_size(): + """ + Test that downsampling rescales the `pixel_size` attribute. + """ + # Image sizes in pixels + L = 8 # original + dsL = 5 # downsampled + + # Construct a small test Image + img = Image(np.random.random((1, L, L)).astype(DTYPE, copy=False), pixel_size=1.23) + + # Downsample the image + result = img.downsample(dsL) + + # Confirm the pixel size is scaled + np.testing.assert_approx_equal( + result.pixel_size, + img.pixel_size * L / dsL, + err_msg="Incorrect pixel size.", + ) diff --git a/tests/test_filters.py b/tests/test_filters.py index 35d7955a9e..b0b23bb74f 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -118,7 +118,8 @@ def testRadialCTFFilter(self): self.assertEqual(result.shape, (256,)) def testRadialCTFFilterGrid(self): - filter = RadialCTFFilter(defocus=2.5e4) + # Set legacy pixel size + filter = RadialCTFFilter(pixel_size=10, defocus=2.5e4) result = filter.evaluate_grid(8, dtype=self.dtype) self.assertEqual(result.shape, (8, 8)) @@ -218,7 +219,10 @@ def testRadialCTFFilterGrid(self): ) def testRadialCTFFilterMultiplierGrid(self): - filter = RadialCTFFilter(defocus=2.5e4) * RadialCTFFilter(defocus=2.5e4) + # Set legacy pixel size + filter = RadialCTFFilter(pixel_size=10, defocus=2.5e4) * RadialCTFFilter( + pixel_size=10, defocus=2.5e4 + ) result = filter.evaluate_grid(8, dtype=self.dtype) self.assertEqual(result.shape, (8, 8)) @@ -332,20 +336,36 @@ def testFilterSigns(self): self.assertTrue(np.allclose(sign_filter.evaluate(self.omega), signs)) -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_power_filter_safeguard(dtype, caplog): +DTYPES = [np.float32, np.float64] +EPS = [None, 0.01] + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=EPS, ids=lambda x: f"epsilon={x}", scope="module") +def epsilon(request): + return request.param + + +def test_power_filter_safeguard(dtype, epsilon, caplog): L = 25 arr = np.ones((L, L), dtype=dtype) # Set a few values below machine epsilon. num_eps = 3 - eps = np.finfo(dtype).eps + eps = epsilon + if eps is None: + eps = np.finfo(dtype).eps arr[L // 2, L // 2 : L // 2 + num_eps] = eps / 2 # For negative powers, values below machine eps will be set to zero.
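# A minimal sketch of the safeguard pattern this test exercises, assuming
# `PowerFilter` zeroes values at or below the threshold before applying a
# negative power (raising near-zero values to a negative power would blow up):
import numpy as np

def safe_negative_power(h, power, eps):
    out = np.zeros_like(h)
    keep = h > eps
    out[keep] = h[keep] ** power  # sub-threshold values are left at zero
    return out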
filt = PowerFilter( filter=ArrayFilter(arr), power=-0.5, + epsilon=epsilon, ) caplog.clear() @@ -361,3 +381,20 @@ # Check caplog for warning. msg = f"setting {num_eps} extremal filter value(s) to zero." assert msg in caplog.text + + +def test_array_filter_dtype_passthrough(dtype): + """ + We upcast to use scipy's fast interpolator. We do not recast + on exit, so this is an expected fail for singles. + """ + if dtype == np.float32: + pytest.xfail(reason="ArrayFilter currently upcasts singles.") + + L = 8 + arr = np.ones((L, L), dtype=dtype) + + filt = ArrayFilter(arr) + filt_vals = filt.evaluate_grid(L, dtype=dtype) + + assert filt_vals.dtype == dtype diff --git a/tests/test_fourier_correlation.py b/tests/test_fourier_correlation.py index 282089ce12..79240572f8 100644 --- a/tests/test_fourier_correlation.py +++ b/tests/test_fourier_correlation.py @@ -115,7 +115,7 @@ def volume_fixture(img_size, dtype): def test_frc_id(image_fixture, method): img, _, _ = image_fixture - frc_resolution, frc = img.frc(img, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img, cutoff=0.143, method=method) assert np.isclose(frc_resolution[0], 2, rtol=0.02) assert np.allclose(frc, 1, rtol=0.01) @@ -123,14 +123,14 @@ def test_frc_trunc(image_fixture, method): img_a, img_b, _ = image_fixture assert img_a.dtype == img_b.dtype - frc_resolution, frc = img_a.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img_a.frc(img_b, cutoff=0.143, method=method) assert frc_resolution[0] > 3.0 def test_frc_noise(image_fixture, method): img_a, _, img_n = image_fixture - frc_resolution, frc = img_a.frc(img_n, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img_a.frc(img_n, cutoff=0.143, method=method) assert frc_resolution[0] > 3.5 @@ -142,13 +142,13 @@ def test_frc_img_plot(image_fixture): # Plot to screen with matplotlib_no_gui(): - _ = img_a.frc(img_n, pixel_size=1, cutoff=0.143, plot=True) + _ = img_a.frc(img_n, cutoff=0.143, plot=True) # Plot to file # Also tests `cutoff=None` with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") - img_a.frc(img_n, pixel_size=1, cutoff=None, plot=file_path) + img_a.frc(img_n, cutoff=None, plot=file_path) assert os.path.exists(file_path) @@ -160,9 +160,7 @@ def test_frc_plot(image_fixture, method): """ img_a, img_b, _ = image_fixture - frc = FourierRingCorrelation( - img_a.asnumpy(), img_b.asnumpy(), pixel_size=1, method=method - ) + frc = FourierRingCorrelation(img_a.asnumpy(), img_b.asnumpy(), method=method) with matplotlib_no_gui(): frc.plot(cutoff=0.5) @@ -178,7 +176,7 @@ def test_fsc_id(volume_fixture, method): vol, _ = volume_fixture - fsc_resolution, fsc = vol.fsc(vol, pixel_size=1, cutoff=0.143, method=method) + fsc_resolution, fsc = vol.fsc(vol, cutoff=0.143, method=method) assert np.isclose(fsc_resolution[0], 2, rtol=0.02) assert np.allclose(fsc, 1, rtol=0.01) @@ -186,11 +184,11 @@ def test_fsc_trunc(volume_fixture, method): vol_a, vol_b = volume_fixture - fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) + fsc_resolution, fsc = vol_a.fsc(vol_b, cutoff=0.143, method=method) assert fsc_resolution[0] > 3.0 # The following should correspond to the test_fsc_plot below.
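# The cutoffs used throughout these tests are the two widely used FSC/FRC
# conventions, 0.143 and 0.5. A hedged sketch of reading a resolution off a
# correlation curve (aspire's estimator may interpolate and use a physical
# pixel size when one is supplied; this is just the first-crossing
# convention in pixel units):
import numpy as np

def resolution_pixels(corr, L, cutoff):
    # corr[k - 1] holds shell k; shell k has frequency k / L cycles per
    # pixel, i.e. a real-space period of L / k pixels.
    below = np.nonzero(corr < cutoff)[0]
    k = below[0] + 1 if below.size else len(corr)
    return L / k

assert resolution_pixels(np.array([1.0, 0.9, 0.4]), L=64, cutoff=0.5) == 64 / 3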
- fsc_resolution, fsc = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, method=method) + fsc_resolution, fsc = vol_a.fsc(vol_b, cutoff=0.5, method=method) assert fsc_resolution[0] > 3.9 @@ -202,13 +200,13 @@ def test_fsc_vol_plot(volume_fixture): # Plot to screen with matplotlib_no_gui(): - _ = vol_a.fsc(vol_b, pixel_size=1, cutoff=0.5, plot=True) + _ = vol_a.fsc(vol_b, cutoff=0.5, plot=True) # Plot to file # Also tests `cutoff=None` with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "vol_fsc_curve.png") - vol_a.fsc(vol_b, pixel_size=1, cutoff=None, plot=file_path) + vol_a.fsc(vol_b, cutoff=None, plot=file_path) assert os.path.exists(file_path) @@ -218,9 +216,7 @@ def test_fsc_plot(volume_fixture, method): """ vol_a, vol_b = volume_fixture - fsc = FourierShellCorrelation( - vol_a.asnumpy(), vol_b.asnumpy(), pixel_size=1, method=method - ) + fsc = FourierShellCorrelation(vol_a.asnumpy(), vol_b.asnumpy(), method=method) with matplotlib_no_gui(): fsc.plot(cutoff=0.5) @@ -306,7 +302,7 @@ def test_img_type_mismatch(): b = a.asnumpy() with pytest.raises(TypeError, match=r"`other` image must be an `Image` instance"): - _ = a.frc(b, pixel_size=1, cutoff=0.143) + _ = a.frc(b, cutoff=0.143) def test_vol_type_mismatch(): @@ -314,7 +310,7 @@ def test_vol_type_mismatch(): b = a.asnumpy() with pytest.raises(TypeError, match=r"`other` volume must be an `Volume` instance"): - _ = a.fsc(b, pixel_size=1, cutoff=0.143) + _ = a.fsc(b, cutoff=0.143) # Broadcasting @@ -329,7 +325,7 @@ def test_frc_id_bcast(image_fixture, method): k = 3 img_b = Image(np.tile(img, (3, 1, 1))) - frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img_b, cutoff=0.143, method=method) assert np.allclose( frc_resolution, [ @@ -344,7 +340,7 @@ def test_frc_id_bcast(image_fixture, method): # (1) x (1,3) img_b = img_b.stack_reshape(1, 3) - frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img_b, cutoff=0.143, method=method) assert np.allclose( frc_resolution, [ @@ -359,7 +355,7 @@ def test_frc_id_bcast(image_fixture, method): # (1) x (3,1) img_b = img_b.stack_reshape(3, 1) - frc_resolution, frc = img.frc(img_b, pixel_size=1, cutoff=0.143, method=method) + frc_resolution, frc = img.frc(img_b, cutoff=0.143, method=method) assert np.allclose( frc_resolution, [ @@ -378,7 +374,7 @@ def test_fsc_id_bcast(volume_fixture, method): k = 3 vol_b = Volume(np.tile(vol.asnumpy(), (3, 1, 1, 1))) - fsc_resolution, fsc = vol.fsc(vol_b, pixel_size=1, cutoff=0.143, method=method) + fsc_resolution, fsc = vol.fsc(vol_b, cutoff=0.143, method=method) assert np.allclose( fsc_resolution, [ @@ -400,12 +396,12 @@ def test_frc_img_plot_bcast(image_fixture): # Plot to screen, one:many with matplotlib_no_gui(): - _ = img_a.frc(img_b, pixel_size=1, cutoff=0.143, plot=True) + _ = img_a.frc(img_b, cutoff=0.143, plot=True) # Plot to file, many elementwise with tempfile.TemporaryDirectory() as tmp_input_dir: file_path = os.path.join(tmp_input_dir, "img_frc_curve.png") - img_b.frc(img_b, pixel_size=1, cutoff=0.143, plot=file_path) + img_b.frc(img_b, cutoff=0.143, plot=file_path) assert os.path.exists(file_path) diff --git a/tests/test_image.py b/tests/test_image.py index d9a062bbb7..89fbde4a84 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -24,8 +24,22 @@ n = 3 mdim = 2 +PARITY = [0, 1] +DTYPES = [np.float32, np.float64] -def get_images(parity=0, dtype=np.float32): + +@pytest.fixture(params=PARITY, 
ids=lambda x: f"parity={x}", scope="module") +def parity(request): + return request.param + + +@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(scope="module") +def get_images(parity, dtype): size = 768 - parity # numpy array for top-level functions that directly expect it im_np = face(gray=True).astype(dtype)[np.newaxis, :size, :size] @@ -33,36 +47,40 @@ im_np /= denom # Normalize test image data to 0,1 # Independent Image object for testing Image methods - im = Image(im_np.copy()) + im = Image(im_np.copy(), pixel_size=1.23) return im_np, im -def get_stacks(parity=0, dtype=np.float32): - im_np, im = get_images(parity, dtype) +@pytest.fixture(scope="module") +def get_stacks(get_images, dtype): + im_np, im = get_images # Construct a simple stack of Images - ims_np = np.empty((n, *im_np.shape[1:]), dtype=dtype) + ims_np = np.empty((n, *im_np.shape[1:]), dtype=im_np.dtype) for i in range(n): ims_np[i] = im_np * (i + 1) / float(n) # Independent Image stack object for testing Image methods - ims = Image(ims_np) + ims = Image(ims_np.copy()) return ims_np, ims -def get_mdim_images(parity=0, dtype=np.float32): - ims_np, im = get_stacks(parity, dtype) +# Note that `get_mdim_images` is mutated by some tests, +# so we force per-function scope. +@pytest.fixture(scope="function") +def get_mdim_images(get_stacks): + ims_np, im = get_stacks # Multi dimensional stack Image object mdim = 2 mdim_ims_np = np.concatenate([ims_np] * mdim).reshape(mdim, *ims_np.shape) # Independent multidimensional Image stack object for testing Image methods - mdim_ims = Image(mdim_ims_np) + mdim_ims = Image(mdim_ims_np.copy()) return mdim_ims_np, mdim_ims -def testRepr(): - _, mdim_ims = get_mdim_images() +def testRepr(get_mdim_images): + _, mdim_ims = get_mdim_images r = repr(mdim_ims) logger.info(f"Image repr:\n{r}") @@ -73,9 +91,8 @@ def testNonSquare(): _ = Image(np.empty((4, 5))) -@pytest.mark.parametrize("parity,dtype", params) -def testImShift(parity, dtype): - im_np, im = get_images(parity, dtype) +def testImShift(get_images, dtype): + im_np, im = get_images # Note that the _im_translate method can handle float input shifts, as it # computes the shifts in Fourier space, rather than performing a roll # However, NumPy's roll() only accepts integer inputs @@ -87,19 +104,22 @@ im1 = im._im_translate(shifts) # test that float input returns the same thing im2 = im.shift(shifts.astype(dtype)) - # ground truth numpy roll - im3 = np.roll(im_np[0, :, :], -shifts, axis=(0, 1)) + # ground truth numpy roll. + # Note: NumPy axes 0 and 1 correspond to the row and column of an array, + # respectively, which correspond to the y-axis and x-axis when that array + # represents an image. Since our shifts are (x-shifts, y-shifts), the axis + # parameter for np.roll() must be set to (1, 0) to accommodate.
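# A tiny self-contained check of the axis convention described above:
import numpy as np

a = np.zeros((4, 4))
a[1, 2] = 1  # row (y) = 1, column (x) = 2
b = np.roll(a, (1, 2), axis=(1, 0))  # shift x by 1 and y by 2
assert b[3, 3] == 1  # lands at y = 1 + 2, x = 2 + 1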
+ im3 = np.roll(im_np[0, :, :], -shifts, axis=(1, 0)) atol = utest_tolerance(dtype) - assert np.allclose(im0.asnumpy(), im1.asnumpy(), atol=atol) - assert np.allclose(im1.asnumpy(), im2.asnumpy(), atol=atol) - assert np.allclose(im0.asnumpy()[0, :, :], im3, atol=atol) + np.testing.assert_allclose(im0.asnumpy(), im1.asnumpy(), atol=atol) + np.testing.assert_allclose(im1.asnumpy(), im2.asnumpy(), atol=atol) + np.testing.assert_allclose(im0.asnumpy()[0, :, :], im3, atol=atol) -@pytest.mark.parametrize("parity,dtype", params) -def testImShiftStack(parity, dtype): - ims_np, ims = get_stacks(parity, dtype) +def testImShiftStack(get_stacks, dtype): + ims_np, ims = get_stacks # test stack of shifts (same number as Image.num_img) # mix of odd and even shifts = np.array([[100, 200], [203, 150], [55, 307]]) @@ -111,19 +131,23 @@ # test that float input returns the same thing im2 = ims.shift(shifts.astype(dtype)) # ground truth numpy roll + # Note: NumPy axes 0 and 1 correspond to the row and column of an array, + # respectively, which correspond to the y-axis and x-axis when that array + # represents an image. Since our shifts are (x-shifts, y-shifts), the axis + # parameter for np.roll() must be set to (1, 0) to accommodate. im3 = np.array( - [np.roll(ims_np[i, :, :], -shifts[i], axis=(0, 1)) for i in range(n)] + [np.roll(ims_np[i, :, :], -shifts[i], axis=(1, 0)) for i in range(n)] ) atol = utest_tolerance(dtype) - assert np.allclose(im0.asnumpy(), im1.asnumpy(), atol=atol) - assert np.allclose(im1.asnumpy(), im2.asnumpy(), atol=atol) - assert np.allclose(im0.asnumpy(), im3, atol=atol) + np.testing.assert_allclose(im0.asnumpy(), im1.asnumpy(), atol=atol) + np.testing.assert_allclose(im1.asnumpy(), im2.asnumpy(), atol=atol) + np.testing.assert_allclose(im0.asnumpy(), im3, atol=atol) -def testImageShiftErrors(): - _, im = get_images(0, np.float32) +def testImageShiftErrors(get_images): + _, im = get_images # test bad shift shape with pytest.raises(ValueError, match="Input shifts must be of shape"): _ = im.shift(np.array([100, 100, 100])) @@ -132,18 +156,16 @@ _ = im.shift(np.array([[100, 200], [100, 200]])) -@pytest.mark.parametrize("parity,dtype", params) -def testImageSqrt(parity, dtype): - im_np, im = get_images(parity, dtype) - ims_np, ims = get_stacks(parity, dtype) +def testImageSqrt(get_images, get_stacks): + im_np, im = get_images + ims_np, ims = get_stacks assert np.allclose(im.sqrt().asnumpy(), np.sqrt(im_np)) assert np.allclose(ims.sqrt().asnumpy(), np.sqrt(ims_np)) -@pytest.mark.parametrize("parity,dtype", params) -def testImageTranspose(parity, dtype): - im_np, im = get_images(parity, dtype) - ims_np, ims = get_stacks(parity, dtype) +def testImageTranspose(get_images, get_stacks): + im_np, im = get_images + ims_np, ims = get_stacks # test method and abbreviation assert np.allclose(im.T.asnumpy(), np.transpose(im_np, (0, 2, 1))) assert np.allclose(im.transpose().asnumpy(), np.transpose(im_np, (0, 2, 1))) @@ -154,10 +176,9 @@ assert np.allclose(ims.transpose()[i], ims_np[i].T) -@pytest.mark.parametrize("parity,dtype", params) -def testImageFlip(parity, dtype): - im_np, im = get_images(parity, dtype) - ims_np, ims = get_stacks(parity, dtype) +def testImageFlip(get_images, get_stacks): + im_np, im = get_images + ims_np, ims = get_stacks for axis in powerset(range(1, 3)): if not axis: # test default @@ -179,31 +200,31 @@ _ = im.flip(axis) -def testShape():
ims_np, ims = get_stacks() +def testShape(get_stacks): + ims_np, ims = get_stacks assert ims.shape == ims_np.shape assert ims.stack_shape == ims_np.shape[:-2] assert ims.stack_ndim == 1 -def testMultiDimShape(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimShape(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images assert mdim_ims.shape == mdim_ims_np.shape assert mdim_ims.stack_shape == mdim_ims_np.shape[:-2] assert mdim_ims.stack_ndim == mdim assert mdim_ims.n_images == mdim * ims.n_images -def testBadKey(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testBadKey(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images with pytest.raises(ValueError, match="slice length must be"): _ = mdim_ims[tuple(range(mdim_ims.ndim + 1))] -def testMultiDimGets(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimGets(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images for X in mdim_ims: assert np.allclose(ims_np, X) @@ -211,9 +232,9 @@ assert np.allclose(mdim_ims[:, 1:], ims[1:]) -def testMultiDimSets(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimSets(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images mdim_ims[0, 1] = 123 # Check the values changed assert np.allclose(mdim_ims[0, 1], 123) @@ -223,9 +244,9 @@ assert np.allclose(mdim_ims[1, :], ims_np) -def testMultiDimSetsSlice(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimSetsSlice(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images # Test setting a slice mdim_ims[0, 1:] = 456 # Check the values changed @@ -235,9 +256,9 @@ assert np.allclose(mdim_ims[1, :], ims_np) -def testMultiDimReshape(): +def testMultiDimReshape(get_mdim_images): # Try mdim reshape - mdim_ims_np, mdim_ims = get_mdim_images() + mdim_ims_np, mdim_ims = get_mdim_images X = mdim_ims.stack_reshape(*mdim_ims.stack_shape[::-1]) assert X.stack_shape == mdim_ims.stack_shape[::-1] # Compare with direct np.reshape of axes of ndarray @@ -245,22 +266,22 @@ assert np.allclose(X.asnumpy(), mdim_ims_np.reshape(shape)) -def testMultiDimFlattens(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimFlattens(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Try flattening X = mdim_ims.stack_reshape(mdim_ims.n_images) assert X.stack_shape == (mdim_ims.n_images,) -def testMultiDimFlattensTrick(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimFlattensTrick(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Try flattening with -1 X = mdim_ims.stack_reshape(-1) assert X.stack_shape == (mdim_ims.n_images,) -def testMultiDimReshapeTuples(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimReshapeTuples(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Try flattening with (-1,) X = mdim_ims.stack_reshape((-1,)) assert X.stack_shape == (mdim_ims.n_images,) @@ -270,8 +291,8 @@ assert X.stack_shape == mdim_ims.stack_shape[::-1] -def testMultiDimBadReshape(): - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimBadReshape(get_mdim_images): + mdim_ims_np, mdim_ims = get_mdim_images # Incorrect flat shape with pytest.raises(ValueError,
match="Number of images"): _ = mdim_ims.stack_reshape(8675309) @@ -281,11 +302,11 @@ def testMultiDimBadReshape(): _ = mdim_ims.stack_reshape(42, 8675309) -def testMultiDimBroadcast(): - ims_np, ims = get_stacks() - mdim_ims_np, mdim_ims = get_mdim_images() +def testMultiDimBroadcast(get_stacks, get_mdim_images): + ims_np, ims = get_stacks + mdim_ims_np, mdim_ims = get_mdim_images X = mdim_ims + ims - assert np.allclose(X[0], 2 * ims.asnumpy()) + np.testing.assert_allclose(X[0], 2 * ims.asnumpy()) @matplotlib_dry_run @@ -297,12 +318,12 @@ def testShow(): im.show() -def test_backproject_symmetry_group(): +def test_backproject_symmetry_group(dtype): """ Test backproject SymmetryGroup pass through and error message. """ ary = np.random.random((5, 8, 8)) - im = Image(ary) + im = Image(ary, dtype=dtype) rots = Rotation.generate_random_rotations(5).matrices # Attempt backproject with bad symmetry group. @@ -315,9 +336,7 @@ def test_backproject_symmetry_group(): assert isinstance(vol.symmetry_group, CnSymmetryGroup) # Symmetry from instance. - vol = im.backproject( - rots, symmetry_group=CnSymmetryGroup(order=3, dtype=np.float32) - ) + vol = im.backproject(rots, symmetry_group=CnSymmetryGroup(order=3, dtype=dtype)) assert isinstance(vol.symmetry_group, CnSymmetryGroup) @@ -334,34 +353,34 @@ def test_asnumpy_readonly(): vw[0, 0, 0] = 123 -@pytest.mark.xfail(reason="Ray logging issue ray#37711", strict=False) def test_corrupt_mrc_load(caplog): """ Test that corrupt mrc files are logged as expected. """ - caplog.set_level(logging.WARNING) - # Create a tmp dir for this test output with tempfile.TemporaryDirectory() as tmpdir_name: # tmp filename mrc_path = os.path.join(tmpdir_name, "bad.mrc") # Create and save image - Image(np.empty((1, 8, 8), dtype=np.float32)).save(mrc_path) + Image(np.ones((1, 8, 8), dtype=np.float32)).save(mrc_path) # Open mrc file and soft corrupt it with mrcfile.open(mrc_path, "r+") as fh: fh.header.map = -1 # Check that we get a WARNING - _ = Image.load(mrc_path) + with caplog.at_level(logging.WARNING): + _ = Image.load(mrc_path) - # Check the message prefix - assert f"Image.load of {mrc_path} reporting 1 corruptions" in caplog.text + # Check the message prefix + assert f"Image.load of {mrc_path} reporting 1 corruptions" in caplog.text - # Check the message contains the file path - assert mrc_path in caplog.text + # Check the message contains the file path + assert mrc_path in caplog.text + + caplog.clear() def test_load_bad_ext(): @@ -372,7 +391,7 @@ def test_load_bad_ext(): _ = Image.load("bad.ext") -def test_load_mrc(): +def test_load_mrc(dtype): """ Test `Image.load` round-trip. 
""" @@ -381,27 +400,19 @@ def test_load_mrc(): filepath = os.path.join(DATA_DIR, "sample.mrc") # Load data from file - im = Image.load(filepath) - im_64 = Image.load(filepath, dtype=np.float64) + im = Image.load(filepath, dtype=dtype) with tempfile.TemporaryDirectory() as tmpdir_name: # tmp filename test_filepath = os.path.join(tmpdir_name, "test.mrc") - test_filepath_64 = os.path.join(tmpdir_name, "test_64.mrc") im.save(test_filepath) - im_64.save(test_filepath_64) - im2 = Image.load(test_filepath) - im2_64 = Image.load(test_filepath_64, dtype=np.float64) + im2 = Image.load(test_filepath, dtype) # Check the single precision round-trip assert np.array_equal(im, im2) - assert im2.dtype == np.float32 - - # check the double precision round-trip - assert np.array_equal(im_64, im2_64) - assert im2_64.dtype == np.float64 + assert im2.dtype == dtype def test_load_tiff(): @@ -427,3 +438,30 @@ def test_load_tiff(): # Check contents assert np.array_equal(im, im2) + + +def test_save_load_pixel_size(get_images, dtype): + """ + Test saving and loading an MRC with pixel size attribute + """ + + im_np, im = get_images + + with tempfile.TemporaryDirectory() as tmpdir_name: + # tmp filename + test_filepath = os.path.join(tmpdir_name, "test.mrc") + + # Save image to file + im.save(test_filepath) + + # Load image from file + im2 = Image.load(test_filepath, dtype) + + # Check we've loaded the image data + np.testing.assert_allclose(im2, im) + # Check we've loaded the image dtype + assert im2.dtype == im.dtype, "Image dtype mismatched on save-load" + # Check we've loaded the pixel size + np.testing.assert_almost_equal( + im2.pixel_size, im.pixel_size, err_msg="Image pixel_size incorrect save-load" + ) diff --git a/tests/test_indexed_source.py b/tests/test_indexed_source.py index 30a23ee16b..3092ed16f4 100644 --- a/tests/test_indexed_source.py +++ b/tests/test_indexed_source.py @@ -13,7 +13,7 @@ def sim_fixture(): """ Generate a very small simulation and slice it. """ - sim = Simulation(L=8, n=10, C=1) + sim = Simulation(L=8, n=10, C=1, symmetry_group="D3") sim2 = sim[0::2] # Slice the evens return sim, sim2 @@ -22,15 +22,22 @@ def test_remapping(sim_fixture): sim, sim2 = sim_fixture # Check images are served correctly, using internal index. - assert np.allclose(sim.images[sim2.index_map].asnumpy(), sim2.images[:].asnumpy()) + np.testing.assert_allclose( + sim.images[sim2.index_map].asnumpy(), sim2.images[:].asnumpy(), atol=1e-6 + ) # Check images are served correctly, using known index (evens). index = list(range(0, sim.n, 2)) - assert np.allclose(sim.images[index].asnumpy(), sim2.images[:].asnumpy()) + np.testing.assert_allclose( + sim.images[index].asnumpy(), sim2.images[:].asnumpy(), atol=1e-6 + ) # Check meta is served correctly. assert np.all(sim.get_metadata(indices=sim2.index_map) == sim2.get_metadata()) + # Check symmetry_group pass-through. + assert sim.symmetry_group == sim2.symmetry_group + def test_repr(sim_fixture): sim, sim2 = sim_fixture diff --git a/tests/test_mean_estimator.py b/tests/test_mean_estimator.py index 2650839272..e6b2a2f837 100644 --- a/tests/test_mean_estimator.py +++ b/tests/test_mean_estimator.py @@ -159,10 +159,7 @@ def test_checkpoint(sim, basis, estimator): maxiter=test_iter + 1, checkpoint_prefix=prefix, ) - - # Assert we raise when reading `maxiter`. - with raises(RuntimeError, match="Unable to converge!"): - _ = _estimator.estimate() + _ = _estimator.estimate() # Load the checkpoint coefficients while tmp_input_dir exists. 
x_chk = np.load(f"{prefix}_iter{test_iter:04d}.npy") diff --git a/tests/test_mean_estimator_boosting.py b/tests/test_mean_estimator_boosting.py index 9251dee09e..6eac159115 100644 --- a/tests/test_mean_estimator_boosting.py +++ b/tests/test_mean_estimator_boosting.py @@ -122,7 +122,7 @@ def weighted_source(weighted_volume): def test_fsc(source, estimated_volume): """Compare estimated volume to source volume with FSC.""" # Fourier Shell Correlation - fsc_resolution, fsc = source.vols.fsc(estimated_volume, pixel_size=1, cutoff=0.5) + fsc_resolution, fsc = source.vols.fsc(estimated_volume, cutoff=0.5) # Check that resolution is less than 2.1 pixels. np.testing.assert_array_less(fsc_resolution, 2.1) diff --git a/tests/test_micrograph_source.py b/tests/test_micrograph_source.py index 06c5874a07..d4793cf61f 100644 --- a/tests/test_micrograph_source.py +++ b/tests/test_micrograph_source.py @@ -285,7 +285,7 @@ def test_rectangular_micrograph_source_files(): """ # Test inconsistent mrc files - imgs = [np.empty((7, 7)), np.empty((8, 8))] + imgs = [np.zeros((7, 7)), np.zeros((8, 8))] with tempfile.TemporaryDirectory() as tmp_output_dir: # Save the files for i, img in enumerate(imgs): diff --git a/tests/test_numeric_sparse.py b/tests/test_numeric_sparse.py new file mode 100644 index 0000000000..e58aa02e6a --- /dev/null +++ b/tests/test_numeric_sparse.py @@ -0,0 +1,58 @@ +""" +Tests basic numpy/cupy functionality of sparse numeric wrappers. +""" + +import numpy as np +import pytest + +from aspire.numeric import numeric_object, sparse_object + +# If cupy is not available, skip this entire test module +pytest.importorskip("cupy") + +NUMERICS = ["numpy", "cupy"] + + +@pytest.fixture(params=NUMERICS, ids=lambda x: f"{x}", scope="module") +def backends(request): + xp = numeric_object(request.param) + sparse = sparse_object(request.param) + return xp, sparse + + +def test_csr_matrix(backends): + """ + Create csr_matrix and multiply with an `xp` array. + """ + xp, sparse = backends + + m, n = 10, 10 + jdx = xp.arange(m) + idx = xp.arange(n) + vals = xp.random.random(10) + + # Compute dense matmul + _A = np.diag(xp.asnumpy(vals)) + _B = np.random.random((n, 20)) + _C = _A @ _B + + # Compute matmul using sparse csr + A = sparse.csr_matrix((vals, (jdx, idx)), shape=(m, n), dtype=np.float64) + B = xp.array(_B) + C = A @ B + + # Compare + np.testing.assert_allclose(_C, xp.asnumpy(C)) + + +def test_eigsh(backends): + """ + Invoke sparse eigsh call with `xp` arrays. + """ + xp, sparse = backends + + n = 123 + A = xp.diag(xp.arange(1, n + 1, dtype=np.float64)) + + lamb, _ = sparse.linalg.eigsh(A, k=1) + np.testing.assert_allclose(xp.asnumpy(lamb), n) diff --git a/tests/test_orient_d2.py b/tests/test_orient_d2.py new file mode 100644 index 0000000000..ca972bcf7c --- /dev/null +++ b/tests/test_orient_d2.py @@ -0,0 +1,478 @@ +import numpy as np +import pytest + +from aspire.abinitio import CLSymmetryD2 +from aspire.source import Simulation +from aspire.utils import ( + J_conjugate, + Random, + Rotation, + all_pairs, + mean_aligned_angular_distance, + utest_tolerance, +) +from aspire.volume import DnSymmetricVolume, DnSymmetryGroup + +############## +# Parameters # +############## + +DTYPE = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +RESOLUTION = [48, 49] +N_IMG = [10] +OFFSETS = [0, pytest.param(None, marks=pytest.mark.expensive)] + +# Since these tests are optimized for runtime, detuned parameters cause +# the algorithm to be fickle, especially for small problem sizes. 
+# In particular, the parameters `grid_res`, `inplane_res`, and `eq_min_dist`, +# which control the number of candidate rotations used in the D2 algorithm, +# will produce bad estimates if the candidates do not align closely with the +# ground truth rotations. +# This seed is chosen so the tests pass CI on github's envs as well +# as our self-hosted runner. +SEED = 3 + + +@pytest.fixture(params=DTYPE, ids=lambda x: f"dtype={x}", scope="module") +def dtype(request): + return request.param + + +@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module") +def resolution(request): + return request.param + + +@pytest.fixture(params=N_IMG, ids=lambda x: f"n images={x}", scope="module") +def n_img(request): + return request.param + + +@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module") +def offsets(request): + return request.param + + +############ +# Fixtures # +############ + + +@pytest.fixture(scope="module") +def source(n_img, resolution, dtype, offsets): + vol = DnSymmetricVolume( + L=resolution, order=2, C=1, K=100, dtype=dtype, seed=SEED + ).generate() + + src = Simulation( + n=n_img, + L=resolution, + vols=vol, + offsets=offsets, + amplitudes=1, + seed=SEED, + ) + src = src.cache() # Precompute image stack + + return src + + +@pytest.fixture(scope="module") +def orient_est(source): + return build_cl_from_source(source) + + +######### +# Tests # +######### + + +def test_estimate_rotations(orient_est): + """ + This test runs through the complete D2 algorithm and compares the + estimated rotations to the ground truth rotations. In particular, + we check that the estimates are close to the ground truth up to + a local rotation by a D2 symmetry group member, a global J-conjugation, + and a globally aligning rotation. + """ + # Estimate rotations. + orient_est.estimate_rotations() + rots_est = orient_est.rotations + + # Ground truth rotations. + rots_gt = orient_est.src.rotations + + # g-sync ground truth rotations. + rots_gt_sync = g_sync_d2(rots_est, rots_gt) + + # Register estimates to ground truth rotations and check that the mean angular + # distance between them is less than 5 degrees. + mean_aligned_angular_distance(rots_est, rots_gt_sync, degree_tol=5) + + # Check dtype pass-through. + assert rots_est.dtype == orient_est.dtype + + +def test_scl_scores(orient_est): + """ + This test uses a Simulation generated with rotations taken directly + from the D2 algorithm `sphere_grid` of candidate rotations. It is + these candidates which should produce maximum correlation scores since + they perfectly match the Simulation rotations. + """ + # Generate lookup data and extract rotations from the candidate `sphere_grid`. + # In this case, we take the first 10 candidates from a non-equator viewing direction. + orient_est._generate_lookup_data() + cand_rots = orient_est.inplane_rotated_grid1 + non_eq_idx = int(np.argwhere(orient_est.eq_class1 == 0)[0][0]) + rots = cand_rots[non_eq_idx, :10] + angles = Rotation(rots).angles + + # Create a Simulation using those first 10 candidate rotations. + src = Simulation( + n=orient_est.src.n, + L=orient_est.src.L, + vols=orient_est.src.vols, + angles=angles, + offsets=orient_est.src.offsets, + amplitudes=1, + seed=SEED, + ) + + # Initialize CL instance with new source. + cl = build_cl_from_source(src) + + # Generate lookup data. + cl._compute_shifted_pf() + cl._generate_lookup_data() + cl._generate_scl_lookup_data() + + # Compute self-commonline scores.
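# Hedged sketch of the flat candidate indexing relied on below: with
# candidates laid out per viewing direction and then per in-plane rotation,
# candidate (d, r) lands at flat index d * n_inplane_rots + r, which is where
# each image's argmax over `scls_scores` is expected to land.
import numpy as np

n_inplane_rots_demo, d_demo = 360, 7  # illustrative sizes
flat = d_demo * n_inplane_rots_demo + np.arange(10)
assert flat[0] == 2520 and flat[-1] == 2529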
+ cl._compute_scl_scores() + + # cl.scls_scores is shape (n_img, n_cand_rots). Since we used the first + # 10 candidate rotations of the first non-equator viewing direction as our + # Simulation rotations, the maximum correlation for image i should occur at + # candidate rotation index (non_eq_idx * cl.n_inplane_rots + i). + max_corr_idx = np.argmax(cl.scls_scores, axis=1) + gt_idx = non_eq_idx * cl.n_inplane_rots + np.arange(10) + + # Check that self-commonline indices match ground truth. + n_match = np.sum(max_corr_idx == gt_idx) + match_tol = 0.99 # match at least 99%. + if not (src.offsets == 0.0).all(): + match_tol = 0.89 # match at least 89% with offsets. + np.testing.assert_array_less(match_tol, n_match / src.n) + + # Check dtype pass-through. + assert cl.scls_scores.dtype == orient_est.dtype + + +def test_global_J_sync(orient_est): + """ + For this test we build a set of relative rotations, Rijs, of shape + (npairs, order(D2), 3, 3) and randomly J_conjugate them. We then test + that the J-configuration is correctly detected and that J-synchronization + is correct up to conjugation of the entire set. + """ + # Grab set of rotations and generate a set of relative rotations, Rijs. + rots = orient_est.src.rotations + Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype) + for p, (i, j) in enumerate(orient_est.pairs): + Rij = rots[i].T @ orient_est.gs @ rots[j] + np.random.shuffle(Rij) # Mix up the ordering of Rijs + Rijs[p] = Rij + + # J-conjugate a random set of Rijs. + Rijs_conj = Rijs.copy() + inds = np.random.choice( + orient_est.n_pairs, size=orient_est.n_pairs // 2, replace=False + ) + Rijs_conj[inds] = J_conjugate(Rijs[inds]) + + # Create J-configuration conditions for the triplet Rij, Rjk, Rik. + J_conds = { + (False, False, False): 0, + (True, True, True): 0, + (True, False, False): 1, + (False, True, True): 1, + (False, True, False): 2, + (True, False, True): 2, + (False, False, True): 3, + (True, True, False): 3, + } + + # Construct ground truth J-configuration list based on `inds` of Rijs + # that have been conjugated above. + J_list_gt = np.zeros(len(orient_est.triplets), dtype=int) + idx = 0 + for i, j, k in orient_est.triplets: + ij = orient_est.pairs_to_linear[i, j] + jk = orient_est.pairs_to_linear[j, k] + ik = orient_est.pairs_to_linear[i, k] + + J_conf = (ij in inds, jk in inds, ik in inds) + J_list_gt[idx] = J_conds[J_conf] + idx += 1 + + # Perform J-configuration and compare to ground truth. + J_list = orient_est._J_configuration(Rijs_conj) + np.testing.assert_equal(J_list, J_list_gt) + + # Perform global J-synchronization and check that + # Rijs_sync is equal to either Rijs or J_conjugate(Rijs). + Rijs_sync = orient_est._global_J_sync(Rijs_conj) + need_to_conj_Rijs = not np.allclose(Rijs_sync[inds][0], Rijs[inds][0]) + if need_to_conj_Rijs: + np.testing.assert_allclose(Rijs_sync, J_conjugate(Rijs)) + else: + np.testing.assert_allclose(Rijs_sync, Rijs) + + # Check dtype pass-through. + assert Rijs_sync.dtype == orient_est.dtype + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_global_J_sync_single_triplet(dtype): + """ + This exercises the J-synchronization algorithm using the smallest + possible problem size, a single triplet of relative rotations, Rijs. + """ + # Generate a 3-image source and orientation object. + src = Simulation(n=3, L=10, dtype=dtype, seed=SEED) + orient_est = build_cl_from_source(src) + + # Grab set of rotations and generate a set of relative rotations, Rijs.
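# Hedged sketch of J-conjugation as used in these tests, assuming the usual
# convention from the common-lines literature: conjugation by
# J = diag(1, 1, -1), an involution that encodes the global handedness
# ambiguity of the reconstruction.
import numpy as np

J = np.diag([1.0, 1.0, -1.0])

def j_conjugate(R):
    return J @ R @ J

R = np.eye(3)
assert np.allclose(j_conjugate(j_conjugate(R)), R)  # applying twice is identity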
+    rots = orient_est.src.rotations
+    Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype)
+    for p, (i, j) in enumerate(orient_est.pairs):
+        Rij = rots[i].T @ orient_est.gs @ rots[j]
+        np.random.shuffle(Rij)  # Mix up the ordering within Rij.
+        Rijs[p] = Rij
+
+    # J-conjugate a random Rij.
+    Rijs_conj = Rijs.copy()
+    inds = np.random.choice(orient_est.n_pairs, size=1, replace=False)
+    Rijs_conj[inds] = J_conjugate(Rijs[inds])
+
+    # Perform global J-synchronization and check that
+    # Rijs_sync is equal to either Rijs or J_conjugate(Rijs).
+    Rijs_sync = orient_est._global_J_sync(Rijs_conj)
+    need_to_conj_Rijs = not np.allclose(Rijs_sync[inds][0], Rijs[inds][0])
+    if need_to_conj_Rijs:
+        np.testing.assert_allclose(Rijs_sync, J_conjugate(Rijs))
+    else:
+        np.testing.assert_allclose(Rijs_sync, Rijs)
+
+
+def test_sync_colors(orient_est):
+    """
+    A set of estimated relative rotations, Rijs, has the shape (n_pairs, 4, 3, 3),
+    where each 4-tuple Rij is given by Rij = Ri.T @ g_m @ Rj, for m in [0, 1, 2, 3],
+    where each g_m is an element of the D2 symmetry group. The ordering of the symmetry
+    group elements, g_m, is unknown and independent between Rijs. The `_sync_colors`
+    algorithm forms the set of vijs of shape (n_pairs, 3, 3, 3), where each vij, given
+    by vij = (Rij[0] + Rij[m]) / 2 with m = 1, 2, 3, is some permutation of the outer
+    products of the k'th rows of the rotation matrices Ri and Rj, for k = 0, 1, 2.
+
+    The `_sync_colors` algorithm uses a colored graph to partition the set of vijs
+    based on k'th row outer products and returns those outer products along with
+    a color mapping encoding a permutation for each vij.
+
+    In this test we form a set of Rijs with randomly ordered symmetry group elements
+    and extract the ground truth color permutations based on that ordering. We then
+    construct a set of ground truth vijs adjusted by the ground truth color permutations,
+    and compare the estimated vijs and color permutations to ground truth.
+    """
+    # Grab the set of rotations and generate a set of relative rotations, Rijs.
+    rots = orient_est.src.rotations
+    Rijs = np.zeros((orient_est.n_pairs, 4, 3, 3), dtype=orient_est.dtype)
+    gt_colors = np.zeros((orient_est.n_pairs, 3), dtype=int)
+
+    with Random(123):
+        for p, (i, j) in enumerate(orient_est.pairs):
+            gs = orient_est.gs
+            if p > 0:
+                np.random.shuffle(gs)  # Mix up the ordering of all but 1st Rijs.
+
+            # Compute the rotation row permutation created by the ordering of gs.
+            # See Proposition 5.1 in the related publication for details.
+            for m in range(3):
+                gt_colors[p, m] = np.argmax(
+                    np.sum(abs(0.5 * (gs[0] + gs[m + 1])), axis=0)
+                )
+
+            # Compute Rijs with shuffled gs.
+            Rij = rots[i].T @ gs @ rots[j]
+            Rijs[p] = Rij
+
+    # Compute ground truth m'th row outer products.
+    vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype)
+    for p, (i, j) in enumerate(orient_est.pairs):
+        for m in range(3):
+            row = gt_colors[p, m]
+            vijs[p, m] = np.outer(rots[i][row], rots[j][row])
+
+    # Perform color synchronization.
+    # `est_vijs` is shape (n_pairs, 3, 3, 3) where est_vijs[ij, m] corresponds
+    # to the outer product vij_m = rots[i, m].T @ rots[j, m] where m is the m'th row
+    # of the rotation matrices Ri and Rj. `est_colors` partitions the set of `est_vijs`
+    # such that the indices of `est_colors` correspond to the row index m.
+    est_colors, est_vijs = orient_est._sync_colors(Rijs)
+
+    # Reshape `est_colors` to shape (n_pairs, 3) and use it to index est_vijs into
+    # the correctly ordered m'th row outer products vijs.
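+    # `_sync_colors` returns the color labels flattened, three per image pair
+    # (one per row index m), hence the reshape below to (n_pairs, 3).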
+    est_colors = est_colors.reshape(orient_est.n_pairs, 3)
+
+    # `est_colors` is an arbitrary (but globally consistent) permutation, and we know
+    # that est_colors[0] should correspond to the ordering [0, 1, 2] due to the construction
+    # of Rijs[0] using the symmetric rotations g0, g1, g2, g3 in non-permuted order.
+    # So we sort the columns such that est_colors[0] = [0, 1, 2].
+
+    # Create a mapping array
+    perm = est_colors[0]
+    mapping = np.zeros_like(perm)
+    mapping[perm] = np.arange(3)
+
+    # Apply this mapping to all rows of the est_colors array
+    est_colors_mapped = mapping[est_colors]
+
+    # Check that remapped color permutations match ground truth.
+    np.testing.assert_allclose(est_colors_mapped, gt_colors)
+
+    # est_vijs should match the ground truth vijs up to the sign of each row.
+    # So we multiply by the sign of the first column of the last two axes to sync signs.
+    vijs = vijs * np.sign(vijs[..., 0])[..., None]
+    est_vijs = est_vijs * np.sign(est_vijs[..., 0])[..., None]
+    np.testing.assert_allclose(vijs, est_vijs, atol=utest_tolerance(orient_est.dtype))
+
+    # Check dtype pass-through.
+    assert est_vijs.dtype == orient_est.dtype
+
+
+def test_sync_signs(orient_est):
+    """
+    Sign synchronization consumes a set of m'th row outer products along with
+    a color synchronizing vector and returns a set of rotation matrices
+    that are the result of synchronizing the signs of the rows of the outer
+    products and factoring the outer products to form the rows of the rotations.
+
+    In this test we provide a color-synchronized set of m'th row outer products
+    with a corresponding color vector and test that the output rotations are
+    equivalent to the ground truth rotations up to a global alignment.
+    """
+    rots = orient_est.src.rotations
+
+    # Compute ground truth m'th row outer products.
+    vijs = np.zeros((orient_est.n_pairs, 3, 3, 3), dtype=orient_est.dtype)
+    for p, (i, j) in enumerate(orient_est.pairs):
+        for m in range(3):
+            vijs[p, m] = np.outer(rots[i][m], rots[j][m])
+
+    # We will pass in m'th row outer products that are color synchronized,
+    # i.e. colors = [0, 1, 2, 0, 1, 2, ...]
+    perm = np.array([0, 1, 2])
+    colors = np.tile(perm, orient_est.n_pairs)
+
+    # Estimate rotations and check against ground truth.
+    rots_est = orient_est._sync_signs(vijs, colors)
+    mean_aligned_angular_distance(rots, rots_est, degree_tol=1e-5)
+
+    # Check dtype pass-through.
+    assert rots_est.dtype == orient_est.dtype
+
+
+####################
+# Helper Functions #
+####################
+
+
+def g_sync_d2(rots, rots_gt):
+    """
+    Every estimated rotation might be a version of the ground truth rotation
+    rotated by g^{s_i}, where s_i = 0, 1, ..., order - 1. This method synchronizes the
+    ground truth rotations so that only a single global rotation need be applied
+    to all estimates for error analysis.
+
+    :param rots: Estimated rotation matrices.
+    :param rots_gt: Ground truth rotation matrices.
+
+    :return: g-synchronized ground truth rotations.
+    """
+    assert len(rots) == len(
+        rots_gt
+    ), "Number of estimates not equal to number of references."
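+    # The loop below determines, for each pair (i, j), which symmetry element
+    # g_s (possibly J-conjugated) best explains the estimated relative
+    # rotation, and encodes that choice as a phase in A_g. The leading
+    # eigenvector of the Hermitian matrix A_g then yields a globally
+    # consistent power of g for each image.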
+    n_img = len(rots)
+    dtype = rots.dtype
+
+    rots_symm = DnSymmetryGroup(2, dtype).matrices
+    order = len(rots_symm)
+
+    A_g = np.zeros((n_img, n_img), dtype=complex)
+
+    pairs = all_pairs(n_img)
+
+    for i, j in pairs:
+        Ri = rots[i]
+        Rj = rots[j]
+        Rij = Ri.T @ Rj
+
+        Ri_gt = rots_gt[i]
+        Rj_gt = rots_gt[j]
+
+        diffs = np.zeros(order)
+        for s, g_s in enumerate(rots_symm):
+            Rij_gt = Ri_gt.T @ g_s @ Rj_gt
+            diffs[s] = min(
+                [
+                    np.linalg.norm(Rij - Rij_gt),
+                    np.linalg.norm(Rij - J_conjugate(Rij_gt)),
+                ]
+            )
+
+        idx = np.argmin(diffs)
+
+        A_g[i, j] = np.exp(-1j * 2 * np.pi / order * idx)
+
+    # A_g(k,l) is exp(-j(-theta_k+theta_l)).
+    # Diagonal elements correspond to exp(-i*0), so we put 1.
+    # This matters only for verifying that the spectrum is (K,0,0,...,0).
+    A_g += np.conj(A_g).T + np.eye(n_img)
+
+    _, eig_vecs = np.linalg.eigh(A_g)
+    leading_eig_vec = eig_vecs[:, -1]
+
+    angles = np.exp(1j * 2 * np.pi / order * np.arange(order))
+    rots_gt_sync = np.zeros((n_img, 3, 3), dtype=dtype)
+
+    for i, rot_gt in enumerate(rots_gt):
+        # Since the closest ccw or cw rotations are just as good,
+        # we take the absolute value of the angle differences.
+        angle_dists = np.abs(np.angle(leading_eig_vec[i] / angles))
+        power_g_Ri = np.argmin(angle_dists)
+        rots_gt_sync[i] = rots_symm[power_g_Ri] @ rot_gt
+
+    return rots_gt_sync
+
+
+def build_cl_from_source(source):
+    # Search for common lines over fewer shifts for 0 offsets.
+    max_shift = 0
+    shift_step = 1
+    if source.offsets.all() != 0:
+        max_shift = 0.2
+        shift_step = 0.02  # Reduce shift steps for non-integer offsets of Simulation.
+
+    orient_est = CLSymmetryD2(
+        source,
+        max_shift=max_shift,
+        shift_step=shift_step,
+        n_theta=180,
+        n_rad=source.L,
+        grid_res=350,  # Tuned for speed
+        inplane_res=12,  # Tuned for speed
+        eq_min_dist=10,  # Tuned for speed
+        epsilon=0.001,
+        seed=SEED,
+    )
+    return orient_est
diff --git a/tests/test_orient_sdp.py b/tests/test_orient_sdp.py
index a161d2fdd7..22658ee06a 100644
--- a/tests/test_orient_sdp.py
+++ b/tests/test_orient_sdp.py
@@ -77,7 +77,7 @@ def test_estimate_rotations(src_orient_est_fixture):
     src, orient_est = src_orient_est_fixture
 
     if backend_available("cufinufft") and src.dtype == np.float32:
-        pytest.skip("CI on gpu fails for singles.")
+        pytest.skip("CI on GPU fails for singles.")
 
     orient_est.estimate_rotations()
diff --git a/tests/test_orient_sync_voting.py b/tests/test_orient_sync_voting.py
index 31d6b20e94..4bea8df6c2 100644
--- a/tests/test_orient_sync_voting.py
+++ b/tests/test_orient_sync_voting.py
@@ -1,3 +1,4 @@
+import copy
 import os
 import os.path
 import tempfile
@@ -32,27 +33,27 @@
 ]
 
 
-@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}")
+@pytest.fixture(params=RESOLUTION, ids=lambda x: f"resolution={x}", scope="module")
 def resolution(request):
     return request.param
 
 
-@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}")
+@pytest.fixture(params=OFFSETS, ids=lambda x: f"offsets={x}", scope="module")
 def offsets(request):
     return request.param
 
 
-@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}")
+@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module")
 def dtype(request):
     return request.param
 
 
-@pytest.fixture
+@pytest.fixture(scope="module")
 def source_orientation_objs(resolution, offsets, dtype):
     src = Simulation(
         n=50,
         L=resolution,
-        vols=AsymmetricVolume(L=resolution, C=1, K=100, seed=0).generate(),
+        vols=AsymmetricVolume(L=resolution, C=1, K=100, seed=0, dtype=dtype).generate(),
         offsets=offsets,
         amplitudes=1,
         seed=0,
@@ -68,6 +69,9 @@ def source_orientation_objs(resolution, offsets, dtype):
         src, max_shift=max_shift, shift_step=shift_step, mask=False
     )
 
+    # Estimate rotations once for all tests.
+    orient_est.estimate_rotations()
+
     return src, orient_est
 
@@ -96,23 +100,49 @@ def test_build_clmatrix(source_orientation_objs):
 
 def test_estimate_rotations(source_orientation_objs):
     src, orient_est = source_orientation_objs
 
-    orient_est.estimate_rotations()
-
     # Register estimates to ground truth rotations and compute the
     # mean angular distance between them (in degrees).
     # Assert that mean angular distance is less than 1 degree.
     mean_aligned_angular_distance(orient_est.rotations, src.rotations, degree_tol=1)
 
 
-def test_estimate_shifts(source_orientation_objs):
+def test_estimate_shifts_with_gt_rots(source_orientation_objs):
     src, orient_est = source_orientation_objs
+
+    # Assign ground truth rotations.
+    # Deep copy to prevent altering for other tests.
+    orient_est = copy.deepcopy(orient_est)
+    orient_est.rotations = src.rotations
+
+    # Estimate shifts using ground truth rotations.
+    est_shifts = orient_est.estimate_shifts()
+
+    # Calculate the mean 2D distance between estimates and ground truth.
+    error = src.offsets - est_shifts
+    mean_dist = np.hypot(error[:, 0], error[:, 1]).mean()
+
+    # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets.
     if src.offsets.all() != 0:
-        pytest.xfail("Currently failing under non-zero offsets.")
+        np.testing.assert_array_less(mean_dist, 0.5)
+    else:
+        np.testing.assert_allclose(mean_dist, 0)
+
+
+def test_estimate_shifts_with_est_rots(source_orientation_objs):
+    src, orient_est = source_orientation_objs
+
+    # Estimate shifts using estimated rotations.
     est_shifts = orient_est.estimate_shifts()
 
-    # Assert that estimated shifts are close to src.offsets
-    assert np.allclose(est_shifts, src.offsets)
+    # Calculate the mean 2D distance between estimates and ground truth.
+    error = src.offsets - est_shifts
+    mean_dist = np.hypot(error[:, 0], error[:, 1]).mean()
+
+    # Assert that on average estimated shifts are close (within 0.5 pix) to src.offsets.
+    if src.offsets.all() != 0:
+        np.testing.assert_array_less(mean_dist, 0.5)
+    else:
+        np.testing.assert_allclose(mean_dist, 0)
 
 
 def test_estimate_rotations_fuzzy_mask():
diff --git a/tests/test_preprocess_pipeline.py b/tests/test_preprocess_pipeline.py
index fb7d2427ec..0f67ae1b67 100644
--- a/tests/test_preprocess_pipeline.py
+++ b/tests/test_preprocess_pipeline.py
@@ -101,7 +101,7 @@ def testWhiten(dtype):
     corr_coef = np.corrcoef(imgs_wt[:, L - 1, L - 1], imgs_wt[:, L - 2, L - 1])
 
     # correlation matrix should be close to identity
-    assert np.allclose(np.eye(2), corr_coef, atol=1e-1)
+    np.testing.assert_allclose(np.eye(2), corr_coef, atol=1e-1)
 
     # dtype of returned images should be the same
    assert dtype == imgs_wt.dtype
@@ -123,7 +123,36 @@ def testWhiten2(dtype):
     corr_coef = np.corrcoef(imgs_wt[:, L - 1, L - 1], imgs_wt[:, L - 2, L - 1])
 
     # Correlation matrix should be close to identity
-    assert np.allclose(np.eye(2), corr_coef, atol=2e-1)
+    np.testing.assert_allclose(np.eye(2), corr_coef, atol=2e-1)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_whiten_safeguard(dtype):
+    """Test that the whitening safeguard works as expected."""
+    L = 25
+    epsilon = 0.02
+    sim = get_sim_object(L, dtype)
+    noise_estimator = AnisotropicNoiseEstimator(sim)
+    sim = sim.whiten(noise_estimator.filter, epsilon=epsilon)
+
+    # Get whitening_filter from generation pipeline.
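+    # The safeguard zeroes the whitening filter wherever the estimated noise
+    # power falls below `epsilon`, rather than amplifying those frequencies
+    # through the **-0.5 inversion checked below.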
+    whiten_filt = sim.generation_pipeline.xforms[0].filter.evaluate_grid(sim.L)
+
+    # Generate whitening_filter without safeguard directly from noise_estimator.
+    filt_vals = noise_estimator.filter.xfer_fn_array
+    whiten_filt_unsafe = filt_vals**-0.5
+
+    # Get indices where safeguard should be applied
+    # and assert that they are not empty.
+    ind = np.where(filt_vals < epsilon)
+    np.testing.assert_array_less(0, len(ind[0]))
+
+    # Check that whiten_filt and whiten_filt_unsafe agree up to safeguard indices.
+    disagree = np.where(whiten_filt != whiten_filt_unsafe)
+    np.testing.assert_array_equal(ind, disagree)
+
+    # Check that whiten_filt is zero at safeguard indices.
+    np.testing.assert_allclose(whiten_filt[ind], 0.0)
 
 
 @pytest.mark.parametrize("L, dtype", params)
@@ -138,7 +167,9 @@ def testInvertContrast(L, dtype):
     imgs2_rc = sim2.images[:num_images]
 
     # all images should be the same after inverting contrast
-    assert np.allclose(imgs1_rc.asnumpy(), imgs2_rc.asnumpy())
+    np.testing.assert_allclose(
+        imgs1_rc.asnumpy(), imgs2_rc.asnumpy(), rtol=1e-05, atol=1e-06
+    )
     # dtype of returned images should be the same
     assert dtype == imgs1_rc.dtype
     assert dtype == imgs2_rc.dtype
diff --git a/tests/test_relion_interop.py b/tests/test_relion_interop.py
index a1a2796675..a79176b8be 100644
--- a/tests/test_relion_interop.py
+++ b/tests/test_relion_interop.py
@@ -4,29 +4,43 @@
 import pytest
 
 from aspire.source import RelionSource, Simulation
+from aspire.utils import utest_tolerance
 from aspire.volume import Volume
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
 
-STARFILE = ["rln_proj_65.star", "rln_proj_64.star"]
+STARFILE_ODD = [
+    "rln_proj_65_centered.star",
+    "rln_proj_65_shifted.star",
+]
+STARFILE_EVEN = [
+    "rln_proj_64_centered.star",
+    "rln_proj_64_shifted.star",
+]
 
-@pytest.fixture(params=STARFILE, scope="module")
+
+@pytest.fixture(params=STARFILE_ODD + STARFILE_EVEN, scope="module")
 def sources(request):
+    """
+    Initialize RelionSource from starfile and generate corresponding ASPIRE
+    Simulation source.
+    """
     starfile = os.path.join(DATA_DIR, request.param)
     rln_src = RelionSource(starfile)
 
     # Generate Volume used for Relion projections.
-    # Note, `downsample` is a no-op for resolution 65.
     vol_path = os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")
-    vol = Volume(np.load(vol_path), dtype=rln_src.dtype).downsample(rln_src.L)
+    vol = Volume(np.load(vol_path, allow_pickle=True), dtype=rln_src.dtype)
+    if rln_src.L == 64:
+        vol = vol.downsample(64)
 
     # Create Simulation source using Volume and angles from Relion projections.
     # Note, for odd resolution Relion projections are shifted by 1 pixel in x and y.
-    offsets = 0
+    offsets = rln_src.offsets
     if rln_src.L % 2 == 1:
-        offsets = -np.ones((rln_src.n, 2), dtype=rln_src.dtype)
+        offsets -= 1
 
     sim_src = Simulation(
         n=rln_src.n,
@@ -39,6 +53,21 @@ def sources(request):
     return rln_src, sim_src
 
 
+@pytest.fixture(params=[STARFILE_ODD, STARFILE_EVEN], scope="module")
+def rln_sources(request):
+    """
+    Initialize centered and shifted RelionSource objects generated using the
+    same viewing angles.
+ """ + starfile_centered = os.path.join(DATA_DIR, request.param[0]) + starfile_shifted = os.path.join(DATA_DIR, request.param[1]) + + rln_src_centered = RelionSource(starfile_centered) + rln_src_shifted = RelionSource(starfile_shifted) + + return rln_src_centered, rln_src_shifted + + def test_projections_relative_error(sources): """Check the relative error between Relion and ASPIRE projection images.""" rln_src, sim_src = sources @@ -51,11 +80,11 @@ def test_projections_relative_error(sources): rln_np = (rln_np - np.mean(rln_np)) / np.std(rln_np) sim_np = (sim_np - np.mean(sim_np)) / np.std(sim_np) - # Check that relative error is less than 3%. + # Check that relative error is less than 4%. error = np.linalg.norm(rln_np - sim_np, axis=(1, 2)) / np.linalg.norm( rln_np, axis=(1, 2) ) - np.testing.assert_array_less(error, 0.03) + np.testing.assert_array_less(error, 0.04) def test_projections_frc(sources): @@ -67,4 +96,18 @@ def test_projections_frc(sources): # Check that estimated resolution is high (< 2.5 pixels) and correlation is close to 1. np.testing.assert_array_less(res, 2.5) - np.testing.assert_array_less(1 - corr[:, -2], 0.02) + np.testing.assert_array_less(1 - corr[:, -2], 0.025) + + +def test_relion_source_centering(rln_sources): + """Test that centering by using provided Relion shifts works.""" + rln_src_centered, rln_src_shifted = rln_sources + ims_centered = rln_src_centered.images[:] + ims_shifted = rln_src_shifted.images[:] + + offsets = rln_src_shifted.offsets + np.testing.assert_allclose( + ims_centered.asnumpy(), + ims_shifted.shift(-offsets).asnumpy(), + atol=utest_tolerance(rln_src_centered.dtype), + ) diff --git a/tests/test_rotation.py b/tests/test_rotation.py index 9e0dba4ec6..e02e650bd5 100644 --- a/tests/test_rotation.py +++ b/tests/test_rotation.py @@ -173,6 +173,9 @@ def test_angle_dist(dtype): with pytest.raises(ValueError, match=r"r1 and r2 are not broadcastable*"): _ = Rotation.angle_dist(rots[:3], rots[:5]) + # Test that single value returns as 0-dim. 
+ assert Rotation.angle_dist(rots[0], rots[1], dtype).ndim == 0 + def test_mean_angular_distance(dtype): rots_z = Rotation.about_axis("z", [0, np.pi / 4, np.pi / 2], dtype=dtype).matrices diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 659ce95603..6780595856 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -21,15 +21,19 @@ class SingleSimTestCase(TestCase): """Test we can construct a length 1 Sim.""" def setUp(self): - self.sim = Simulation( - n=1, - L=8, - ) + self._pixel_size = 1.23 # Test value + + self.sim = Simulation(n=1, L=8, pixel_size=self._pixel_size) def testImage(self): """Test we can get an Image from a length 1 Sim.""" _ = self.sim.images[0] + def testPixelSize(self): + """Test pixel_size is passing through Simulation.""" + self.assertTrue(self.sim.pixel_size == self._pixel_size) + self.assertTrue(self.sim.pixel_size == self.sim.vols.pixel_size) + @matplotlib_dry_run def testImageShow(self): self.sim.images[:].show() @@ -106,9 +110,12 @@ def setUp(self): self.n = 1024 self.L = 8 self.dtype = np.float32 + # Set legacy pixel_size + self._pixel_size = 10 self.vols = LegacyVolume( L=self.L, + pixel_size=self._pixel_size, dtype=self.dtype, ).generate() @@ -117,44 +124,45 @@ def setUp(self): L=self.L, vols=self.vols, unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + RadialCTFFilter(pixel_size=self._pixel_size, defocus=d) + for d in np.linspace(1.5e4, 2.5e4, 7) ], noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, ) + # Keep hardcoded tests passing after fixing swapped offsets. + # See github issue #1146. + self.sim = self.sim.update(offsets=self.sim.offsets[:, [1, 0]]) + def tearDown(self): pass def testGaussianBlob(self): blobs = self.sim.vols.asnumpy() ref = np.load(os.path.join(DATA_DIR, "sim_blobs.npy")) - self.assertTrue(np.allclose(blobs, ref)) + np.testing.assert_allclose(blobs, ref, rtol=1e-05, atol=1e-08) def testSimulationRots(self): - self.assertTrue( - np.allclose( - self.sim.rots_zyx_to_legacy_aspire(self.sim.rotations[0, :, :]), - np.array( - [ - [0.91675498, 0.2587233, 0.30433956], - [0.39941773, -0.58404652, -0.70665065], - [-0.00507853, 0.76938412, -0.63876622], - ] - ), - atol=utest_tolerance(self.dtype), - ) + np.testing.assert_allclose( + self.sim.rots_zyx_to_legacy_aspire(self.sim.rotations[0, :, :]), + np.array( + [ + [0.91675498, 0.2587233, 0.30433956], + [0.39941773, -0.58404652, -0.70665065], + [-0.00507853, 0.76938412, -0.63876622], + ] + ), + atol=utest_tolerance(self.dtype), ) def testSimulationImages(self): images = self.sim.clean_images[:512].asnumpy() - self.assertTrue( - np.allclose( - images, - np.load(os.path.join(DATA_DIR, "sim_clean_images.npy")), - rtol=1e-2, - atol=utest_tolerance(self.sim.dtype), - ) + np.testing.assert_allclose( + images, + np.load(os.path.join(DATA_DIR, "sim_clean_images.npy")), + rtol=1e-2, + atol=utest_tolerance(self.sim.dtype), ) def testSimulationCached(self): @@ -162,39 +170,38 @@ def testSimulationCached(self): n=self.n, L=self.L, vols=self.vols, + offsets=self.sim.offsets, unique_filters=[ - RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7) + # Set legacy pixel size + RadialCTFFilter(pixel_size=self._pixel_size, defocus=d) + for d in np.linspace(1.5e4, 2.5e4, 7) ], noise_adder=WhiteNoiseAdder(var=1), dtype=self.dtype, ) sim_cached = sim_cached.cache() - self.assertTrue( - np.array_equal(sim_cached.images[:].asnumpy(), self.sim.images[:].asnumpy()) + np.testing.assert_allclose( + sim_cached.images[:].asnumpy(), 
self.sim.images[:].asnumpy(), atol=1e-6 ) def testSimulationImagesNoisy(self): images = self.sim.images[:512].asnumpy() - self.assertTrue( - np.allclose( - images, - np.load(os.path.join(DATA_DIR, "sim_images_with_noise.npy")), - rtol=1e-2, - atol=utest_tolerance(self.sim.dtype), - ) + np.testing.assert_allclose( + images, + np.load(os.path.join(DATA_DIR, "sim_images_with_noise.npy")), + rtol=1e-2, + atol=utest_tolerance(self.sim.dtype), ) def testSimulationImagesDownsample(self): # The simulation already generates images of size 8 x 8; Downsampling to resolution 8 should thus have no effect self.sim = self.sim.downsample(8) images = self.sim.clean_images[:512].asnumpy() - self.assertTrue( - np.allclose( - images, - np.load(os.path.join(DATA_DIR, "sim_clean_images.npy")), - rtol=1e-2, - atol=utest_tolerance(self.sim.dtype), - ) + np.testing.assert_allclose( + images, + np.load(os.path.join(DATA_DIR, "sim_clean_images.npy")), + rtol=1e-2, + atol=utest_tolerance(self.sim.dtype), ) def testSimulationImagesShape(self): @@ -210,192 +217,192 @@ def testSimulationImagesDownsampleShape(self): def testSimulationEigen(self): eigs_true, lambdas_true = self.sim.eigs() - self.assertTrue( - np.allclose( - eigs_true.asnumpy()[0, :, :, 2], - np.array( - [ - [ - -1.67666201e-07, - -7.95741380e-06, - -1.49160041e-04, - -1.10151654e-03, - -3.11287888e-03, - -3.09157884e-03, - -9.91418026e-04, - -1.31673165e-04, - ], - [ - -1.15402077e-06, - -2.49849709e-05, - -3.51658906e-04, - -2.21575261e-03, - -7.83315487e-03, - -9.44795180e-03, - -4.07636259e-03, - -9.02186439e-04, - ], - [ - -1.88737249e-05, - -1.91418396e-04, - -1.09021540e-03, - -1.02020288e-03, - 1.39411855e-02, - 8.58035963e-03, - -5.54619730e-03, - -3.86377703e-03, - ], - [ - -1.21280536e-04, - -9.51461843e-04, - -3.22565017e-03, - -1.05731178e-03, - 2.61375736e-02, - 3.11595201e-02, - 6.40814053e-03, - -2.31698658e-02, - ], - [ - -2.44067283e-04, - -1.40560151e-03, - -6.73082832e-05, - 1.44160679e-02, - 2.99893934e-02, - 5.92632964e-02, - 7.75623545e-02, - 3.06570008e-02, - ], - [ - -1.53507499e-04, - -7.21709803e-04, - 8.54929152e-04, - -1.27235036e-02, - -5.34382043e-03, - 2.18879692e-02, - 6.22706190e-02, - 4.51998860e-02, - ], - [ - -3.00595184e-05, - -1.43038429e-04, - -2.15870258e-03, - -9.99002904e-02, - -7.79077187e-02, - -1.53395887e-02, - 1.88777559e-02, - 1.68759506e-02, - ], - [ - 3.22692649e-05, - 4.07977635e-03, - 1.63959339e-02, - -8.68835449e-02, - -7.86240026e-02, - -1.75694861e-02, - 3.24984640e-03, - 1.95389288e-03, - ], - ] - ), - ) - ) - - def testSimulationMean(self): - mean_vol = self.sim.mean_true() - self.assertTrue( - np.allclose( + np.testing.assert_allclose( + eigs_true.asnumpy()[0, :, :, 2], + np.array( [ [ - 0.00000930, - 0.00033866, - 0.00490734, - 0.01998369, - 0.03874487, - 0.04617764, - 0.02970645, - 0.00967604, + -1.67666201e-07, + -7.95741380e-06, + -1.49160041e-04, + -1.10151654e-03, + -3.11287888e-03, + -3.09157884e-03, + -9.91418026e-04, + -1.31673165e-04, ], [ - 0.00003904, - 0.00247391, - 0.03818476, - 0.12325402, - 0.22278425, - 0.25246665, - 0.14093882, - 0.03683474, + -1.15402077e-06, + -2.49849709e-05, + -3.51658906e-04, + -2.21575261e-03, + -7.83315487e-03, + -9.44795180e-03, + -4.07636259e-03, + -9.02186439e-04, ], [ - 0.00014177, - 0.01191146, - 0.14421064, - 0.38428235, - 0.78645319, - 0.86522675, - 0.44862473, - 0.16382280, + -1.88737249e-05, + -1.91418396e-04, + -1.09021540e-03, + -1.02020288e-03, + 1.39411855e-02, + 8.58035963e-03, + -5.54619730e-03, + -3.86377703e-03, ], [ - 0.00066036, - 
0.03137806, - 0.29226971, - 0.97105378, - 2.39410496, - 2.17099857, - 1.23595858, - 0.49233940, + -1.21280536e-04, + -9.51461843e-04, + -3.22565017e-03, + -1.05731178e-03, + 2.61375736e-02, + 3.11595201e-02, + 6.40814053e-03, + -2.31698658e-02, ], [ - 0.00271748, - 0.05491289, - 0.49955708, - 2.05356097, - 3.70941424, - 3.01578689, - 1.51441932, - 0.52054572, + -2.44067283e-04, + -1.40560151e-03, + -6.73082832e-05, + 1.44160679e-02, + 2.99893934e-02, + 5.92632964e-02, + 7.75623545e-02, + 3.06570008e-02, ], [ - 0.00584845, - 0.06962635, - 0.50568032, - 1.99643707, - 3.77415895, - 2.76039767, - 1.04602003, - 0.20633197, + -1.53507499e-04, + -7.21709803e-04, + 8.54929152e-04, + -1.27235036e-02, + -5.34382043e-03, + 2.18879692e-02, + 6.22706190e-02, + 4.51998860e-02, ], [ - 0.00539583, - 0.06068972, - 0.47008955, - 1.17128026, - 1.82821035, - 1.18743944, - 0.30667788, - 0.04851476, + -3.00595184e-05, + -1.43038429e-04, + -2.15870258e-03, + -9.99002904e-02, + -7.79077187e-02, + -1.53395887e-02, + 1.88777559e-02, + 1.68759506e-02, ], [ - 0.00246362, - 0.04867788, - 0.65284950, - 0.65238875, - 0.65745538, - 0.37955678, - 0.08053055, - 0.01210055, + 3.22692649e-05, + 4.07977635e-03, + 1.63959339e-02, + -8.68835449e-02, + -7.86240026e-02, + -1.75694861e-02, + 3.24984640e-03, + 1.95389288e-03, ], + ] + ), + rtol=1e-05, + atol=1e-08, + ) + + def testSimulationMean(self): + mean_vol = self.sim.mean_true() + np.testing.assert_allclose( + [ + [ + 0.00000930, + 0.00033866, + 0.00490734, + 0.01998369, + 0.03874487, + 0.04617764, + 0.02970645, + 0.00967604, ], - mean_vol.asnumpy()[0, :, :, 4], - ) + [ + 0.00003904, + 0.00247391, + 0.03818476, + 0.12325402, + 0.22278425, + 0.25246665, + 0.14093882, + 0.03683474, + ], + [ + 0.00014177, + 0.01191146, + 0.14421064, + 0.38428235, + 0.78645319, + 0.86522675, + 0.44862473, + 0.16382280, + ], + [ + 0.00066036, + 0.03137806, + 0.29226971, + 0.97105378, + 2.39410496, + 2.17099857, + 1.23595858, + 0.49233940, + ], + [ + 0.00271748, + 0.05491289, + 0.49955708, + 2.05356097, + 3.70941424, + 3.01578689, + 1.51441932, + 0.52054572, + ], + [ + 0.00584845, + 0.06962635, + 0.50568032, + 1.99643707, + 3.77415895, + 2.76039767, + 1.04602003, + 0.20633197, + ], + [ + 0.00539583, + 0.06068972, + 0.47008955, + 1.17128026, + 1.82821035, + 1.18743944, + 0.30667788, + 0.04851476, + ], + [ + 0.00246362, + 0.04867788, + 0.65284950, + 0.65238875, + 0.65745538, + 0.37955678, + 0.08053055, + 0.01210055, + ], + ], + mean_vol.asnumpy()[0, :, :, 4], + rtol=1e-05, + atol=1e-08, ) def testSimulationVolCoords(self): coords, norms, inners = self.sim.vol_coords() - self.assertTrue(np.allclose([4.72837704, -4.72837709], coords, atol=1e-4)) - self.assertTrue(np.allclose([8.20515764e-07, 1.17550184e-06], norms, atol=1e-4)) - self.assertTrue( - np.allclose([3.78030562e-06, -4.20475816e-06], inners, atol=1e-4) + np.testing.assert_allclose([4.72837704, -4.72837709], coords, atol=1e-4) + np.testing.assert_allclose([8.20515764e-07, 1.17550184e-06], norms, atol=1e-4) + np.testing.assert_allclose( + [[3.78030562e-06, -4.20475816e-06]], inners, atol=1e-4 ) def testSimulationCovar(self): @@ -483,23 +490,23 @@ def testSimulationCovar(self): ], ] - self.assertTrue(np.allclose(result, covar[:, :, 4, 4, 4, 4], atol=1e-4)) + np.testing.assert_allclose(result, covar[:, :, 4, 4, 4, 4], atol=1e-4) def testSimulationEvalMean(self): mean_est = Volume(np.load(os.path.join(DATA_DIR, "mean_8_8_8.npy"))) result = self.sim.eval_mean(mean_est) - self.assertTrue(np.allclose(result["err"], 2.664116055950763, atol=1e-4)) - 
self.assertTrue(np.allclose(result["rel_err"], 0.1765943704851626, atol=1e-4)) - self.assertTrue(np.allclose(result["corr"], 0.9849211540734224, atol=1e-4)) + np.testing.assert_allclose(result["err"], 2.664116055950763, atol=1e-4) + np.testing.assert_allclose(result["rel_err"], 0.1765943704851626, atol=1e-4) + np.testing.assert_allclose(result["corr"], 0.9849211540734224, atol=1e-4) def testSimulationEvalCovar(self): covar_est = np.load(os.path.join(DATA_DIR, "covar_8_8_8_8_8_8.npy")) result = self.sim.eval_covar(covar_est) - self.assertTrue(np.allclose(result["err"], 13.322721549011165, atol=1e-4)) - self.assertTrue(np.allclose(result["rel_err"], 0.5958936073938558, atol=1e-4)) - self.assertTrue(np.allclose(result["corr"], 0.8405347287741631, atol=1e-4)) + np.testing.assert_allclose(result["err"], 13.322721549011165, atol=1e-4) + np.testing.assert_allclose(result["rel_err"], 0.5958936073938558, atol=1e-4) + np.testing.assert_allclose(result["corr"], 0.8405347287741631, atol=1e-4) def testSimulationEvalCoords(self): mean_est = Volume(np.load(os.path.join(DATA_DIR, "mean_8_8_8.npy"))) @@ -513,58 +520,54 @@ def testSimulationEvalCoords(self): result = self.sim.eval_coords(mean_est, eigs_est, clustered_coords_est) - self.assertTrue( - np.allclose( - result["err"][0, :10], - [ - 1.58382394, - 1.58382394, - 3.72076112, - 1.58382394, - 1.58382394, - 3.72076112, - 3.72076112, - 1.58382394, - 1.58382394, - 1.58382394, - ], - ) + np.testing.assert_allclose( + result["err"][0, :10], + [ + 1.58382394, + 1.58382394, + 3.72076112, + 1.58382394, + 1.58382394, + 3.72076112, + 3.72076112, + 1.58382394, + 1.58382394, + 1.58382394, + ], ) - self.assertTrue( - np.allclose( - result["rel_err"][0, :10], - [ - 0.11048937, - 0.11048937, - 0.21684697, - 0.11048937, - 0.11048937, - 0.21684697, - 0.21684697, - 0.11048937, - 0.11048937, - 0.11048937, - ], - ) + np.testing.assert_allclose( + result["rel_err"][0, :10], + [ + 0.11048937, + 0.11048937, + 0.21684697, + 0.11048937, + 0.11048937, + 0.21684697, + 0.21684697, + 0.11048937, + 0.11048937, + 0.11048937, + ], ) - self.assertTrue( - np.allclose( - result["corr"][0, :10], - [ - 0.99390133, - 0.99390133, - 0.97658719, - 0.99390133, - 0.99390133, - 0.97658719, - 0.97658719, - 0.99390133, - 0.99390133, - 0.99390133, - ], - ) + np.testing.assert_allclose( + result["corr"][0, :10], + [ + 0.99390133, + 0.99390133, + 0.97658719, + 0.99390133, + 0.99390133, + 0.97658719, + 0.97658719, + 0.99390133, + 0.99390133, + 0.99390133, + ], + rtol=1e-05, + atol=1e-08, ) def testSimulationSaveFile(self): @@ -589,7 +592,9 @@ def testSimulationSaveFile(self): relion_src = RelionSource(star_filepath, tmpdir, max_rows=1024) imgs_sav = relion_src.images[:1024] # Compare original images with saved images - self.assertTrue(np.allclose(imgs_org.asnumpy(), imgs_sav.asnumpy())) + np.testing.assert_allclose( + imgs_org.asnumpy(), imgs_sav.asnumpy(), atol=1e-6 + ) # Save images into multiple MRCS files based on batch size batch_size = 512 info = self.sim.save(star_filepath, batch_size=batch_size, overwrite=False) @@ -608,7 +613,9 @@ def testSimulationSaveFile(self): relion_src = RelionSource(star_filepath, tmpdir, max_rows=1024) imgs_sav = relion_src.images[:1024] # Compare original images with saved images - self.assertTrue(np.allclose(imgs_org.asnumpy(), imgs_sav.asnumpy())) + np.testing.assert_allclose( + imgs_org.asnumpy(), imgs_sav.asnumpy(), atol=1e-6 + ) def test_default_symmetry_group(): @@ -651,6 +658,20 @@ def test_cached_image_accessors(): cached_src = src.cache() # Compare the 
cached vs dynamic image sets.
-    np.testing.assert_allclose(cached_src.projections[:], src.projections[:])
-    np.testing.assert_allclose(cached_src.images[:], src.images[:])
-    np.testing.assert_allclose(cached_src.clean_images[:], src.clean_images[:])
+    np.testing.assert_allclose(cached_src.projections[:], src.projections[:], atol=1e-6)
+    np.testing.assert_allclose(cached_src.images[:], src.images[:], atol=1e-6)
+    np.testing.assert_allclose(
+        cached_src.clean_images[:], src.clean_images[:], atol=1e-6
+    )
+
+
+def test_mismatched_pixel_size():
+    """
+    Confirm raises error when explicit Simulation and CTFFilter pixel sizes mismatch.
+    """
+    # Create a CTF with a pixel_size
+    filts = [RadialCTFFilter(pixel_size=5)]
+
+    # Try to create a Simulation with a different pixel_size
+    with raises(ValueError, match=r"pixel_size.*does not match filter.*"):
+        _ = Simulation(L=8, n=1, C=1, pixel_size=10, unique_filters=filts)
diff --git a/tests/test_sinogram.py b/tests/test_sinogram.py
new file mode 100644
index 0000000000..dc173449b9
--- /dev/null
+++ b/tests/test_sinogram.py
@@ -0,0 +1,272 @@
+import numpy as np
+import pytest
+from skimage import data
+from skimage.transform import iradon, radon
+
+from aspire.image import Image
+from aspire.utils import grid_2d
+
+# Relative tolerances used when comparing our forward and back projections
+# against the scikit-image reference implementations.
+SK_TOL_FORWARDPROJECT = 0.005
+
+SK_TOL_BACKPROJECT = 0.0025
+
+IMG_SIZES = [
+    511,
+    512,
+]
+
+DTYPES = [
+    np.float32,
+    np.float64,
+]
+
+ANGLES = [
+    1,
+    50,
+    pytest.param(90, marks=pytest.mark.expensive),
+    pytest.param(117, marks=pytest.mark.expensive),
+    pytest.param(180, marks=pytest.mark.expensive),
+    pytest.param(360, marks=pytest.mark.expensive),
+]
+
+
+@pytest.fixture(params=DTYPES, ids=lambda x: f"dtype={x}", scope="module")
+def dtype(request):
+    """
+    Dtypes for image.
+    """
+    return request.param
+
+
+@pytest.fixture(params=IMG_SIZES, ids=lambda x: f"px={x}", scope="module")
+def img_size(request):
+    """
+    Image size.
+    """
+    return request.param
+
+
+@pytest.fixture(params=ANGLES, ids=lambda x: f"n_angles={x}", scope="module")
+def num_ang(request):
+    """
+    Number of angles in radon transform.
+    """
+    return request.param
+
+
+@pytest.fixture
+def masked_image(dtype, img_size):
+    """
+    Creates a masked image fixture using camera data from scikit-image.
+    """
+    g = grid_2d(img_size, normalized=True, shifted=True)
+    mask = g["r"] < 0.99
+
+    image = data.camera().astype(dtype)
+    image = image[:img_size, :img_size]
+    return Image(image * mask)
+
+
+# Test Image.project and compare results to skimage.radon
+def test_project_single(masked_image, num_ang):
+    """
+    Test Image.project on a single stack of images. Compares the project
+    method's output with the scikit-image `radon` reference.
+    """
+    ny = masked_image.resolution
+    angles = np.linspace(0, 360, num_ang, endpoint=False)
+    rads = angles / 180 * np.pi
+    s = masked_image.project(rads)
+    assert s.shape == (1, len(angles), ny)
+
+    # scikit-image `radon` reference
+    #
+    # Note, Image.project's angles are wrt the projection line (i.e. the
+    # grid), while sk's radon angles are wrt the image. To correspond,
+    # the rotations are inverted. This was the convention preferred by
+    # the original author of this method.
+    #
+    # Note, transpose sk output to match (angles, points)
+    # Note, `radon` does not admit read only views, so the slice is copied.
+    reference_sinogram = radon(masked_image.asnumpy()[0].copy(), theta=angles[::-1]).T
+    assert reference_sinogram.shape == (len(angles), ny), "Incorrect Shape"
+
+    # Compare the project method against the scikit-image reference.
+    nrms = np.sqrt(
+        np.mean((s[0].asnumpy() - reference_sinogram) ** 2, axis=-1)
+    ) / np.linalg.norm(reference_sinogram, axis=-1)
+
+    np.testing.assert_array_less(
+        nrms, SK_TOL_FORWARDPROJECT, err_msg="Error in image projections."
+    )
+
+
+def test_project_multidim(num_ang):
+    """
+    Test Image.project on stacks of images. Extension of test_project_single
+    but for multi-dimensional stacks.
+    """
+
+    L = 512  # pixels
+    n = 3
+    m = 2
+
+    # Generate a mask
+    g = grid_2d(L, normalized=True, shifted=True)
+    mask = g["r"] < 0.99
+
+    # Generate images
+    imgs = Image(np.random.random((m, n, L, L))) * mask
+
+    # Generate line project angles
+    angles = np.linspace(0, 360, num_ang, endpoint=False)
+    rads = angles / 180.0 * np.pi
+    s = imgs.project(rads)
+
+    # Compare
+    reference_sinograms = np.empty((m, n, num_ang, L))
+    for i in range(m):
+        for j in range(n):
+            img = imgs[i, j]
+            # Compute the singleton case, and compare with stack.
+            single_sinogram = img.project(rads)
+
+            # These should be allclose up to determinism in the FFT and NUFFT.
+            np.testing.assert_allclose(s[i, j : j + 1], single_sinogram)
+
+            # Next individually compute sk's radon transform for each image.
+            # Note, `radon` does not admit read only views, so the slice is copied.
+            reference_sinograms[i, j] = radon(
+                img.asnumpy()[0].copy(), theta=angles[::-1]
+            ).T
+
+    _nrms = np.sqrt(np.mean((s - reference_sinograms) ** 2, axis=-1)) / np.linalg.norm(
+        reference_sinograms, axis=-1
+    )
+    np.testing.assert_array_less(
+        _nrms, SK_TOL_FORWARDPROJECT, err_msg="Error in image projections."
+    )
+
+
+def test_backproject_single(masked_image, num_ang):
+    """
+    Test Sinogram.backproject on a single stack of line projections (sinograms).
+
+    This test compares the reconstructed image from the `backproject` method to
+    the scikit-image method `iradon`.
+    """
+    angles = np.linspace(0, 360, num_ang, endpoint=False)
+    rads = angles / 180 * np.pi
+    sinogram = masked_image.project(rads)
+    sinogram_np = sinogram.asnumpy()
+    back_project = sinogram.backproject(rads)
+
+    assert masked_image.shape == back_project.shape, "The shape must be the same."
+
+    # Generate a circular mask with radius 1 for the reconstructed image,
+    # aiming to remove discrepancies at the edges of the image.
+    g = grid_2d(back_project.resolution, normalized=True, shifted=True)
+    mask = g["r"] < 0.99
+    our_back_project = back_project.asnumpy()[0] * mask
+
+    # Generate the scikit-image back projection (`iradon`) with no filter.
+    sk_image_iradon = iradon(sinogram_np[0].T, theta=-angles, filter_name=None) * mask
+
+    # Compute a normalized root mean square error, i.e. the RMSE relative to
+    # the range of the reference image.
+    nrmse = np.sqrt(np.mean((our_back_project - sk_image_iradon) ** 2)) / (
+        np.max(sk_image_iradon) - np.min(sk_image_iradon)
+    )
+    np.testing.assert_array_less(
+        nrmse,
+        SK_TOL_BACKPROJECT,
+        err_msg=f"NRMSE is too high: {nrmse}, expected less than {SK_TOL_BACKPROJECT}",
+    )
+
+
+def test_backproject_multidim(num_ang):
+    """
+    Test Sinogram.backproject on a stack of line projections.
+
+    Extension of the `backproject_single` test but checks for multi-dimensional stacks.
+ """ + L = 512 # pixels + n = 3 + m = 2 + + g = grid_2d(L, normalized=True, shifted=True) + mask = g["r"] < 0.99 + + # Generate images + imgs = Image(np.random.random((m, n, L, L))) * mask + angles = np.linspace(0, 360, num_ang, endpoint=False) + rads = angles / 180 * np.pi + + # apply a forward project on the image, then backwards + ours_forward = imgs.project(rads) + ours_backward = ours_forward.backproject(rads) + + # Compare + reference_back_projects = np.empty((m, n, L, L)) + for i in range(m): + for j in range(n): + img = imgs[i, j] + # Compute the singleton case, and compare with stack. + single_sinogram = img.project(rads) + back_project = single_sinogram.backproject(rads) + + # These should be allclose up to determinism. + np.testing.assert_allclose(ours_backward[i, j : j + 1], back_project[0]) + + # Next individually compute sk's iradon transform for each image. + reference_back_projects[i, j] = ( + iradon( + single_sinogram.asnumpy()[0].T, theta=-1 * angles, filter_name=None + ) + * mask + ) + + # apply a mask, then find the NRMSE on the collection of images + # similar tolerance level to single project test + nrmse = np.sqrt( + np.mean( + (ours_backward.asnumpy() * mask - reference_back_projects), axis=(-2, -1) + ) + ** 2 + ) / ( + np.max(reference_back_projects, axis=(-2, -1)) + - np.min(reference_back_projects, axis=(-2, -1)) + ) + + np.testing.assert_array_less( + nrmse, SK_TOL_BACKPROJECT, err_msg="Error with the reconstructed images." + ) + + +# testing the str method +def test_sinogram_str_method(masked_image, num_ang): + angles = np.linspace(0, 360, num_ang, endpoint=False) + rads = angles / 180 * np.pi + sinogram = masked_image.project(rads) + n_images = sinogram.n + n_angles = sinogram.n_angles + n_radial_points = sinogram.n_radial_points + expected_str = f"Sinogram(n_images = {n_images}, n_angles = {n_angles}, n_radial_points = {n_radial_points})" + assert str(sinogram) == expected_str + + +# testing the repr method +def test_sinogram_repr_method(masked_image, num_ang): + angles = np.linspace(0, 360, num_ang, endpoint=False) + rads = angles / 180 * np.pi + sinogram = masked_image.project(rads) + n_images = sinogram.n + dtype = sinogram.dtype + stack_shape = sinogram.stack_shape + n_angles = sinogram.n_angles + n_radial_points = sinogram.n_radial_points + expected_repr = ( + f"Sinogram: {n_images} images of dtype {dtype}, " + f"arranged as a stack with shape {stack_shape}. " + f"Each image has {n_angles} angles and {n_radial_points} radial points." + ) + assert repr(sinogram) == expected_repr diff --git a/tests/test_synthetic_volume.py b/tests/test_synthetic_volume.py index ddcdcbcab5..fec7591764 100644 --- a/tests/test_synthetic_volume.py +++ b/tests/test_synthetic_volume.py @@ -20,6 +20,9 @@ # dtype fixture to pass into volume fixture. DTYPES = [np.float32, pytest.param(np.float64, marks=pytest.mark.expensive)] +# Pixel sized used to test assignment +PXSZ = 3.0 + @pytest.fixture(params=DTYPES) def dtype_fixture(request): @@ -85,6 +88,10 @@ def vol_fixture(request, dtype_fixture): if len(params) > 2: vol_kwargs["order"] = params[2] + # Assign some volumes a pixel_size, leave others as default. + if res % 2: + vol_kwargs["pixel_size"] = PXSZ + return vol_class(**vol_kwargs) @@ -96,8 +103,15 @@ def test_volume_repr(vol_fixture): def test_volume_generate(vol_fixture): - """Test that a volume is generated""" - _ = vol_fixture.generate() + """ + Test that a volume is generated + and stores pixel_size when provided. 
+ """ + v = vol_fixture.generate() + + # In vol_fixture, we assign pixel_size to volumes having odd voxel sizes. + if vol_fixture.L % 2: + np.testing.assert_approx_equal(v.pixel_size, PXSZ) def test_simulation_init(vol_fixture): diff --git a/tests/test_utils.py b/tests/test_utils.py index ffad5bc9f6..040f427758 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -397,8 +397,16 @@ def matplotlib_no_gui(): with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"Matplotlib is currently using agg.*") + # Ignore the specific UserWarning about non-interactive FigureCanvasAgg + warnings.filterwarnings( + "ignore", r"FigureCanvasAgg is non-interactive, and thus cannot be shown" + ) + yield + # Explicitly close all figures before making backend changes. + matplotlib.pyplot.close("all") + # Restore backend matplotlib.use(backend) diff --git a/tests/test_volume.py b/tests/test_volume.py index ea52d1d67f..1f55645e10 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -30,6 +30,7 @@ def res_id(params): RES = [42, 43] +TEST_PX_SZ = 4.56 @pytest.fixture(params=RES, ids=res_id, scope="module") @@ -75,7 +76,7 @@ def vols_1(data_1): @pytest.fixture def vols_2(data_2): - return Volume(data_2) + return Volume(data_2, pixel_size=TEST_PX_SZ) @pytest.fixture @@ -291,6 +292,39 @@ def test_save_load(vols_1): assert np.allclose(vols_1, vols_loaded_single) assert isinstance(vols_loaded_double, Volume) assert np.allclose(vols_1, vols_loaded_double) + assert vols_loaded_single.pixel_size is None, "Pixel size should be None" + assert vols_loaded_double.pixel_size is None, "Pixel size should be None" + + +def test_volume_pixel_size(vols_2): + """ + Test volume is storing pixel_size attribute. + """ + assert np.isclose(TEST_PX_SZ, vols_2.pixel_size), "Incorrect Volume pixel_size" + + +def test_save_load_pixel_size(vols_2): + # Create a tmpdir in a context. It will be cleaned up on exit. + with tempfile.TemporaryDirectory() as tmpdir: + # Save the Volume object into an MRC files + mrcs_filepath = os.path.join(tmpdir, "test.mrc") + vols_2.save(mrcs_filepath) + + # Load saved MRC file as a Volume of dtypes single and double. + vols_loaded_single = Volume.load(mrcs_filepath, dtype=np.float32) + vols_loaded_double = Volume.load(mrcs_filepath, dtype=np.float64) + + # Confirm the pixel size is loaded + np.testing.assert_approx_equal( + vols_loaded_single.pixel_size, + vols_2.pixel_size, + err_msg="Incorrect pixel size in singles.", + ) + np.testing.assert_approx_equal( + vols_loaded_double.pixel_size, + vols_2.pixel_size, + err_msg="Incorrect pixel size in doubles.", + ) def test_project(vols_hot_cold): @@ -320,7 +354,7 @@ def test_project(vols_hot_cold): # Generate projection images. projections = vols.project(rots) - # Check that new hot/cold spots are within 1 pixel of expectecd locations. + # Check that new hot/cold spots are within 1 pixel of expected locations. 
for i in range(vols.n_vols): p = projections.asnumpy()[i] new_hot_loc = np.unravel_index(np.argmax(p), (L, L)) @@ -545,21 +579,30 @@ def test_flip(vols_1, data_1): def test_downsample(res): - vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy"))) - result = vols.downsample(8) - res = vols.resolution + vols = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")), pixel_size=1.23 + ) + result = vols.downsample(res) + og_res = vols.resolution ds_res = result.resolution + # Confirm the pixel size is scaled + np.testing.assert_approx_equal( + result.pixel_size, + vols.pixel_size * og_res / ds_res, + err_msg="Incorrect pixel size.", + ) + # check signal energy - assert np.allclose( - anorm(vols.asnumpy(), axes=(1, 2, 3)) / res, + np.testing.assert_allclose( + anorm(vols.asnumpy(), axes=(1, 2, 3)) / og_res, anorm(result.asnumpy(), axes=(1, 2, 3)) / ds_res, atol=1e-3, ) # check gridpoints - assert np.allclose( - vols.asnumpy()[:, res // 2, res // 2, res // 2], + np.testing.assert_allclose( + vols.asnumpy()[:, og_res // 2, og_res // 2, og_res // 2], result.asnumpy()[:, ds_res // 2, ds_res // 2, ds_res // 2], atol=1e-4, ) @@ -803,12 +846,6 @@ def test_aglebraic_ops_symmetry_warnings(symmetric_vols): # Should have 4 warnings on record. assert len(record) == 4 - # Check that warning occurs only once per line. - with warnings.catch_warnings(record=True) as record: - for _ in range(5): - vol_c3 + vol_c4 - assert len(record) == 1 - def test_volume_load_with_symmetry(): # Check we can load a Volume with symmetry_group. diff --git a/tests/test_weighted_mean_estimator.py b/tests/test_weighted_mean_estimator.py index 5d622c3865..eabcd0574f 100644 --- a/tests/test_weighted_mean_estimator.py +++ b/tests/test_weighted_mean_estimator.py @@ -155,9 +155,7 @@ def test_checkpoint(sim, basis, estimator, weights): checkpoint_prefix=prefix, ) - # Assert we raise when reading `maxiter`. - with raises(RuntimeError, match="Unable to converge!"): - _ = _estimator.estimate() + _ = _estimator.estimate() # Load the checkpoint coefficients while tmp_input_dir exists. x_chk = np.load(f"{prefix}_iter{test_iter:04d}.npy")
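For reference, the back projection tests above score each image by a root mean square error normalized by the reference image's dynamic range. A minimal standalone sketch of that metric (the `nrmse` helper name is illustrative, not part of this change set):

```python
import numpy as np


def nrmse(est: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """RMSE over each (H, W) image, normalized by the reference's dynamic range."""
    rmse = np.sqrt(np.mean((est - ref) ** 2, axis=(-2, -1)))
    return rmse / (ref.max(axis=(-2, -1)) - ref.min(axis=(-2, -1)))
```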