diff --git a/gallery/tutorials/lecture_feature_demo.py b/gallery/tutorials/lecture_feature_demo.py index eb8f4a48d4..f3b156e79a 100644 --- a/gallery/tutorials/lecture_feature_demo.py +++ b/gallery/tutorials/lecture_feature_demo.py @@ -42,8 +42,8 @@ from aspire.noise import AnisotropicNoiseEstimator, WhiteNoiseEstimator from aspire.operators import FunctionFilter, RadialCTFFilter, ScalarFilter from aspire.source import RelionSource, Simulation -from aspire.utils import Rotation -from aspire.utils.coor_trans import ( +from aspire.utils import ( + Rotation, get_aligned_rotations, get_rots_mse, register_rotations, diff --git a/gallery/tutorials/orient3d_simulation.py b/gallery/tutorials/orient3d_simulation.py index ce23ee5f87..c3719f0a0d 100644 --- a/gallery/tutorials/orient3d_simulation.py +++ b/gallery/tutorials/orient3d_simulation.py @@ -15,11 +15,7 @@ from aspire.abinitio import CLSyncVoting from aspire.operators import RadialCTFFilter from aspire.source.simulation import Simulation -from aspire.utils.coor_trans import ( - get_aligned_rotations, - get_rots_mse, - register_rotations, -) +from aspire.utils import get_aligned_rotations, get_rots_mse, register_rotations from aspire.volume import Volume logger = logging.getLogger(__name__) diff --git a/src/aspire/abinitio/commonline_base.py b/src/aspire/abinitio/commonline_base.py index f46302c577..beb2bbd1d3 100644 --- a/src/aspire/abinitio/commonline_base.py +++ b/src/aspire/abinitio/commonline_base.py @@ -6,7 +6,7 @@ from aspire.abinitio.orientation_src import OrientEstSource from aspire.basis import PolarBasis2D -from aspire.utils.coor_trans import common_line_from_rots +from aspire.utils import common_line_from_rots from aspire.utils.random import choice logger = logging.getLogger(__name__) diff --git a/src/aspire/basis/basis_utils.py b/src/aspire/basis/basis_utils.py index ce05f5c4c7..8219e350b2 100644 --- a/src/aspire/basis/basis_utils.py +++ b/src/aspire/basis/basis_utils.py @@ -10,7 +10,7 @@ from numpy.polynomial.legendre import leggauss from scipy.special import jn, jv, sph_harm -from aspire.utils.coor_trans import grid_2d, grid_3d +from aspire.utils import grid_2d, grid_3d logger = logging.getLogger(__name__) diff --git a/src/aspire/ctf/ctf_estimator.py b/src/aspire/ctf/ctf_estimator.py index 46e2ff9682..d7b933b9f9 100644 --- a/src/aspire/ctf/ctf_estimator.py +++ b/src/aspire/ctf/ctf_estimator.py @@ -21,8 +21,7 @@ from aspire.numeric import fft from aspire.operators import voltage_to_wavelength from aspire.storage import StarFile -from aspire.utils import abs2, complex_type -from aspire.utils.coor_trans import grid_1d, grid_2d +from aspire.utils import abs2, complex_type, grid_1d, grid_2d logger = logging.getLogger(__name__) diff --git a/src/aspire/denoising/adaptive_support.py b/src/aspire/denoising/adaptive_support.py index c95afdc9fa..fa9d9bc608 100644 --- a/src/aspire/denoising/adaptive_support.py +++ b/src/aspire/denoising/adaptive_support.py @@ -5,7 +5,7 @@ from aspire.noise import WhiteNoiseEstimator from aspire.numeric import fft from aspire.source import ImageSource -from aspire.utils.coor_trans import grid_2d +from aspire.utils import grid_2d logger = logging.getLogger(__name__) diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 2a175a8f7c..2f8e2a8a6d 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -9,7 +9,7 @@ import aspire.volume from aspire.nufft import anufft from aspire.numeric import fft, xp -from aspire.utils.coor_trans import grid_2d +from aspire.utils import grid_2d from aspire.utils.matrix import anorm logger = logging.getLogger(__name__) diff --git a/src/aspire/noise/noise.py b/src/aspire/noise/noise.py index b99a43f665..af04a00c0f 100644 --- a/src/aspire/noise/noise.py +++ b/src/aspire/noise/noise.py @@ -4,7 +4,7 @@ from aspire.numeric import fft, xp from aspire.operators import ArrayFilter, ScalarFilter -from aspire.utils.coor_trans import grid_2d +from aspire.utils import grid_2d logger = logging.getLogger(__name__) diff --git a/src/aspire/operators/filters.py b/src/aspire/operators/filters.py index 2397bd1404..8bf95bd935 100644 --- a/src/aspire/operators/filters.py +++ b/src/aspire/operators/filters.py @@ -5,7 +5,7 @@ import numpy as np from scipy.interpolate import RegularGridInterpolator -from aspire.utils.coor_trans import grid_2d +from aspire.utils import grid_2d from aspire.utils.filter_to_fb_mat import filter_to_fb_mat logger = logging.getLogger(__name__) diff --git a/src/aspire/source/image.py b/src/aspire/source/image.py index bb96e251fb..8967d5dfc3 100644 --- a/src/aspire/source/image.py +++ b/src/aspire/source/image.py @@ -22,8 +22,7 @@ PowerFilter, ) from aspire.storage import MrcStats, StarFile -from aspire.utils import Rotation -from aspire.utils.coor_trans import grid_2d +from aspire.utils import Rotation, grid_2d logger = logging.getLogger(__name__) diff --git a/src/aspire/source/simulation.py b/src/aspire/source/simulation.py index f76e07bec0..a9453ab632 100644 --- a/src/aspire/source/simulation.py +++ b/src/aspire/source/simulation.py @@ -8,8 +8,14 @@ from aspire.image.xform import NoiseAdder from aspire.operators import ZeroFilter from aspire.source import ImageSource -from aspire.utils import acorr, ainner, anorm, make_symmat, vecmat_to_volmat -from aspire.utils.coor_trans import uniform_random_angles +from aspire.utils import ( + acorr, + ainner, + anorm, + make_symmat, + uniform_random_angles, + vecmat_to_volmat, +) from aspire.utils.random import rand, randi, randn from aspire.volume import Volume, gaussian_blob_vols diff --git a/src/aspire/utils/__init__.py b/src/aspire/utils/__init__.py index 0517ffce08..8f88da3f0f 100644 --- a/src/aspire/utils/__init__.py +++ b/src/aspire/utils/__init__.py @@ -1,4 +1,23 @@ -from .misc import abs2, get_full_version, powerset, sha256sum # isort:skip +from .coor_trans import ( # isort:skip + common_line_from_rots, + crop_pad_2d, + get_aligned_rotations, + get_rots_mse, + grid_1d, + grid_2d, + grid_3d, + register_rotations, + uniform_random_angles, +) +from .misc import ( # isort:skip + abs2, + circ, + gaussian_2d, + get_full_version, + inverse_r, + powerset, + sha256sum, +) from .matrix import ( acorr, ainner, @@ -23,6 +42,5 @@ vol_to_vec, volmat_to_vecmat, ) -from .misc import circ, gaussian_2d, inverse_r from .rotation import Rotation from .types import complex_type, real_type, utest_tolerance diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index d03f4345ff..a8ad8f2d40 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -2,6 +2,8 @@ General purpose math functions, mostly geometric in nature. """ +import math + import numpy as np from numpy.linalg import norm from scipy.linalg import svd @@ -302,3 +304,32 @@ def common_line_from_rots(r1, r2, ell): ell_ji = int(np.mod(np.round(ell_ji), ell)) return ell_ij, ell_ji + + +def crop_pad_2d(im, size, fill_value=0): + """ + :param im: A 2-dimensional numpy array + :param size: Integer size of cropped/padded output + :return: A numpy array of shape (size, size) + """ + + im_y, im_x = im.shape + # shift terms + start_x = math.floor(im_x / 2) - math.floor(size / 2) + start_y = math.floor(im_y / 2) - math.floor(size / 2) + + # cropping + if size <= min(im_y, im_x): + return im[start_y : start_y + size, start_x : start_x + size] + # padding + elif size >= max(im_y, im_x): + # ensure that we return in the same dtype as the input + to_return = fill_value * np.ones((size, size), dtype=im.dtype) + # when padding, start_x and start_y are negative since size is larger + # than im_x and im_y; the below line calculates where the original image + # is placed in relation to the (now-larger) box size + to_return[-start_y : im_y - start_y, -start_x : im_x - start_x] = im + return to_return + else: + # target size is between mat_x and mat_y + raise ValueError("Cannot crop and pad an image at the same time.") diff --git a/src/aspire/utils/misc.py b/src/aspire/utils/misc.py index 85a19b4ab1..c1608d7408 100644 --- a/src/aspire/utils/misc.py +++ b/src/aspire/utils/misc.py @@ -9,7 +9,7 @@ import numpy as np -from aspire.utils.coor_trans import grid_1d, grid_2d, grid_3d +from aspire.utils import grid_1d, grid_2d, grid_3d logger = logging.getLogger(__name__) diff --git a/src/aspire/volume/__init__.py b/src/aspire/volume/__init__.py index a0288fa75a..16128246a2 100644 --- a/src/aspire/volume/__init__.py +++ b/src/aspire/volume/__init__.py @@ -7,8 +7,7 @@ import aspire.image from aspire.nufft import nufft from aspire.numeric import fft, xp -from aspire.utils import Rotation, mat_to_vec, vec_to_mat -from aspire.utils.coor_trans import grid_2d, grid_3d +from aspire.utils import Rotation, grid_2d, grid_3d, mat_to_vec, vec_to_mat from aspire.utils.matlab_compat import m_reshape from aspire.utils.random import Random, randn from aspire.utils.types import complex_type diff --git a/tests/test_coor_trans.py b/tests/test_coor_trans.py index 167e58bf74..56603bb702 100644 --- a/tests/test_coor_trans.py +++ b/tests/test_coor_trans.py @@ -3,8 +3,9 @@ import numpy as np -from aspire.utils import Rotation -from aspire.utils.coor_trans import ( +from aspire.utils import ( + Rotation, + crop_pad_2d, get_aligned_rotations, grid_2d, grid_3d, @@ -77,3 +78,184 @@ def testRegisterRots(self): q_mat_est, flag_est = register_rotations(rots_ref, regrots_ref) self.assertTrue(np.allclose(flag_est, flag) and np.allclose(q_mat_est, q_mat)) + + def testSquareCrop2D(self): + # Test even/odd cases based on the convention that the center of a sequence of length n + # is (n+1)/2 if n is odd and n/2 + 1 if even. + # Cropping is done to keep the center of the sequence the same value before and after. + # Therefore the following apply: + # Cropping even to odd will result in the 0-index (beginning) + # of the sequence being chopped off (x marks the center, ~ marks deleted data): + # ---x-- => ~--x-- + # Cropping odd to even will result in the -1-index (end) + # of the sequence being chopped off: + # ---x--- => ---x--~ + + # even to even + a = np.diag(np.arange(8)) + test_a = np.diag(np.arange(1, 7)) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 6))) + + # even to odd + # the extra row/column cut off are the top and left + # due to the centering convention + a = np.diag(np.arange(8)) + test_a = np.diag(np.arange(1, 8)) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 7))) + + # odd to odd + a = np.diag(np.arange(9)) + test_a = np.diag(np.arange(1, 8)) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 7))) + + # odd to even + # the extra row/column cut off are the bottom and right + # due to the centering convention + a = np.diag(np.arange(9)) + test_a = np.diag(np.arange(8)) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 8))) + + def testSquarePad2D(self): + # Test even/odd cases based on the convention that the center of a sequence of length n + # is (n+1)/2 if n is odd and n/2 + 1 if even. + # Padding is done to keep the center of the sequence the same value before and after. + # Therefore the following apply: + # Padding from even to odd results in the spare padding being added to the -1-index (end) + # of the sequence (x represents the center, + represents padding): + # ---x-- => ---x--+ + # Padding from odd to even results in the spare padding being added to the 0-index (beginning) + # of the sequence: + # --x-- => +--x-- + + # even to even + a = np.diag(np.arange(1, 9)) + test_a = np.diag([0, 1, 2, 3, 4, 5, 6, 7, 8, 0]) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 10))) + + # even to odd + # the extra padding is to the bottom and right + # due to the centering convention + a = np.diag(np.arange(1, 9)) + test_a = np.diag([0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0]) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 11))) + + # odd to odd + a = np.diag(np.arange(1, 10)) + test_a = np.diag([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 11))) + + # odd to even + # the extra padding is to the top and left + # due to the centering convention + a = np.diag(np.arange(1, 10)) + test_a = np.diag([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + self.assertTrue(np.array_equal(test_a, crop_pad_2d(a, 10))) + + def testRectCrop2D(self): + # Additional sanity checks for rectangular cropping case + + # 12x10 -> 10x10 + a = np.diag(np.arange(1, 11)) + # augment to 12 rows + aug = np.vstack([a, np.zeros(10)]) + aug = np.vstack([np.zeros(10), aug]) + # make sure the top and bottom rows are stripped + self.assertTrue(np.array_equal(a, crop_pad_2d(aug, 10))) + + # 10x12 -> 10x10 + a = np.diag(np.arange(1, 11)) + # augment to 12 columns + aug = np.column_stack([a, np.zeros(10)]) + aug = np.column_stack([np.zeros(10), aug]) + # make sure the left and right columns are stripped + self.assertTrue(np.array_equal(a, crop_pad_2d(aug, 10))) + + # 9x7 -> 7x7 + a = np.diag(np.arange(1, 8)) + # augment to 9 rows + aug = np.vstack([a, np.zeros(7)]) + aug = np.vstack([np.zeros(7), aug]) + # make sure the top and bottom rows are stripped + self.assertTrue(np.array_equal(a, crop_pad_2d(aug, 7))) + + # 7x9 -> 7x7 + a = np.diag(np.arange(1, 8)) + # augment to 9 columns + aug = np.column_stack([a, np.zeros(7)]) + aug = np.column_stack([np.zeros(7), aug]) + # make sure the left and right columns are stripped + self.assertTrue(np.array_equal(a, crop_pad_2d(aug, 7))) + + def testRectPad2D(self): + # Additional sanity checks for rectangular padding case + + # 12x10 -> 12x12 + a = np.diag(np.arange(1, 11)) + # augment to 12 rows + aug = np.vstack([a, np.zeros(10)]) + aug = np.vstack([np.zeros(10), aug]) + # expected result + padded = np.column_stack([aug, np.zeros(12)]) + padded = np.column_stack([np.zeros(12), padded]) + # make sure columns of fill value (0) are added to the + # left and right + self.assertTrue(np.array_equal(padded, crop_pad_2d(aug, 12))) + + # 10x12 -> 12x12 + a = np.diag(np.arange(1, 11)) + # augment to 12 columns + aug = np.column_stack([a, np.zeros(10)]) + aug = np.column_stack([np.zeros(10), aug]) + # expected result + padded = np.vstack([aug, np.zeros(12)]) + padded = np.vstack([np.zeros(12), padded]) + # make sure rows of fill value (0) are added to the + # top and bottom + self.assertTrue(np.array_equal(padded, crop_pad_2d(aug, 12))) + + # 9x7 -> 9x9 + a = np.diag(np.arange(1, 8)) + # augment to 9 rows + aug = np.vstack([a, np.zeros(7)]) + aug = np.vstack([np.zeros(7), aug]) + # expected result + padded = np.column_stack([aug, np.zeros(9)]) + padded = np.column_stack([np.zeros(9), padded]) + # make sure columns of fill value (0) are added to the + # left and right + self.assertTrue(np.array_equal(padded, crop_pad_2d(aug, 9))) + + # 7x9 -> 9x9 + a = np.diag(np.arange(1, 8)) + # augment to 9 columns + aug = np.column_stack([a, np.zeros(7)]) + aug = np.column_stack([np.zeros(7), aug]) + # expected result + padded = np.vstack([aug, np.zeros(9)]) + padded = np.vstack([np.zeros(9), padded]) + # make sure rows of fill value (0) are added to the + # top and bottom + self.assertTrue(np.array_equal(padded, crop_pad_2d(aug, 9))) + + def testCropPad2DError(self): + with self.assertRaises(ValueError) as e: + _ = crop_pad_2d(np.zeros((6, 10)), 8) + self.assertTrue( + "Cannot crop and pad an image at the same time.", str(e.exception) + ) + + def testCrop2DDtype(self): + # crop_pad_2d must return an array of the same dtype it was given + # in particular, because the method is used for Fourier downsampling + # methods involving cropping complex arrays + self.assertEqual( + crop_pad_2d(np.eye(10).astype("complex"), 5).dtype, np.dtype("complex128") + ) + + def testCrop2DFillValue(self): + # make sure the fill value is as expected + # we are padding from an odd to an even dimension + # so the padded column is added to the left + a = np.ones((4, 3)) + b = crop_pad_2d(a, 4, fill_value=-1) + self.assertTrue(np.array_equal(b[:, 0], np.array([-1, -1, -1, -1]))) diff --git a/tests/test_grids.py b/tests/test_grids.py index 87e6372829..c2a4b657c3 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -4,8 +4,7 @@ import numpy as np -from aspire.utils import utest_tolerance -from aspire.utils.coor_trans import grid_2d, grid_3d +from aspire.utils import grid_2d, grid_3d, utest_tolerance logger = logging.getLogger(__name__) diff --git a/tests/test_preprocess_pipeline.py b/tests/test_preprocess_pipeline.py index c1a4c6d8fa..a2efcbcf1f 100644 --- a/tests/test_preprocess_pipeline.py +++ b/tests/test_preprocess_pipeline.py @@ -7,8 +7,7 @@ from aspire.operators.filters import FunctionFilter, RadialCTFFilter, ScalarFilter from aspire.source import ArrayImageSource from aspire.source.simulation import Simulation -from aspire.utils import utest_tolerance -from aspire.utils.coor_trans import grid_2d, grid_3d +from aspire.utils import grid_2d, grid_3d, utest_tolerance from aspire.utils.matrix import anorm from aspire.volume import Volume diff --git a/tests/test_volume.py b/tests/test_volume.py index c990c22fe7..a360641f67 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -8,8 +8,7 @@ from parameterized import parameterized from pytest import raises -from aspire.utils import Rotation, powerset -from aspire.utils.coor_trans import grid_3d +from aspire.utils import Rotation, grid_3d, powerset from aspire.utils.types import utest_tolerance from aspire.volume import Volume, gaussian_blob_vols