diff --git a/src/aspire/basis/__init__.py b/src/aspire/basis/__init__.py index cfaea78358..a482127338 100644 --- a/src/aspire/basis/__init__.py +++ b/src/aspire/basis/__init__.py @@ -11,6 +11,7 @@ from .fb_3d import FBBasis3D from .ffb_2d import FFBBasis2D from .ffb_3d import FFBBasis3D +from .fle_2d import FLEBasis2D from .fpswf_2d import FPSWFBasis2D from .fpswf_3d import FPSWFBasis3D from .fspca import FSPCABasis diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py new file mode 100644 index 0000000000..77f8d8d286 --- /dev/null +++ b/src/aspire/basis/fle_2d.py @@ -0,0 +1,706 @@ +import logging + +import numpy as np +import scipy.sparse as sparse +from scipy.fft import dct, idct +from scipy.special import jv + +from aspire.basis import FBBasisMixin, SteerableBasis2D +from aspire.basis.basis_utils import besselj_zeros +from aspire.basis.fle_2d_utils import ( + barycentric_interp_sparse, + precomp_transform_complex_to_real, + transform_complex_to_real, +) +from aspire.nufft import anufft, nufft +from aspire.numeric import fft +from aspire.utils import complex_type, grid_2d + +logger = logging.getLogger(__name__) + + +class FLEBasis2D(SteerableBasis2D, FBBasisMixin): + """ + Define a derived class for Fast Fourier Bessel 2D expansion using interpolation + from Chebyshev nodes. + The algorithms used are described in the following publication: + N. F. Marshall, O. Mickelin, A. Singer, Fast Expansion into Harmonics on the Disk: + A Steerable Basis with Fast Radial Convolution. (submitted) + + https://arxiv.org/pdf/2207.13674.pdf + """ + + def __init__( + self, size, bandlimit=None, epsilon=1e-10, dtype=np.float32, match_fb=True + ): + """ + :param size: The size of the vectors for which to define the FLE basis. + Currently only square images are supported. + :param bandlimit: Maximum frequency band for computing basis functions. Defaults to the + resolution of the basis. + :param epsilon: Relative precision between FLE fast method and dense matrix multiplication. + :param dtype: Datatype of images and coefficients represented. + :param match_fb: + With this flag set the following will ensure that the basis functions are + identical to `FBBasis2D`: + - The initial heuristic for the number of basis functions, based on the resolution, will + be set to that of `FBBasis2D`, and the FLE frequency thresholding procedure to reduce the + number of functions will not be carried out. This means the number of basis functions for + a given image size will be identical across the two bases. + - The signs of basis functions and coefficients with `sgn == 1` will be flipped relative to + the original FLE implementation, to match FB. + - The basis functions returned will be reordered according to the FB ordering, that is, first + by `ell`s, then by `sgn`s, then by `k`s. + + """ + if isinstance(size, int): + size = (size, size) + ndim = len(size) + assert ndim == 2, "Only two-dimensional basis functions are supported." + assert len(set(size)) == 1, "Only square domains are supported" + + self.bandlimit = bandlimit + self.epsilon = epsilon + self.match_fb = match_fb + self.dtype = dtype + super().__init__(size, ell_max=None, dtype=self.dtype) + + def _build(self): + """ + Build the internal data structure for the FLEBasis2D class. 
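+
+ In order, this sets the default `bandlimit`, counts the radial frequencies
+ per angular frequency (`self.k_max`), fixes an upper bound on the number of
+ basis functions, selects the interpolation parameters `maxitr` and
+ `numsparse`, builds the Cartesian grid, runs the precomputations, and
+ finally generates the steerable and FB-compatibility index arrays.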
+ """ + + # bandlimit set to basis size by default + if not self.bandlimit: + self.bandlimit = self.nres + + # compute number of k's for each ell + self._calc_k_max() + + if self.match_fb: + # Use FB2D and FFB2D heuristic for computing max basis functions + self.max_basis_functions = self.k_max[0] + sum(2 * self.k_max[1:]) + else: + # Regular Fourier-Bessel bandlimit (equivalent to pi*R**2) + # Final self.count will be < self.max_basis_functions + # See self._threshold_basis_functions() + self.max_basis_functions = int(self.nres**2 * np.pi / 4) + + self._compute_maxitr_and_numsparse() + + self._compute_cartesian_gridpoints() + + self._precomp() + + # Steerable basis indices + self._build_indices() + + # FB compatability indices + self._generate_fb_compat_indices() + + def _build_indices(self): + self.angular_indices = np.abs(self.ells) + self.radial_indices = self.ks - 1 + self.signs_indices = np.sign(self.ells) + + def indices(self): + """ + Return the precomputed indices for each basis function. + """ + return { + "ells": self.angular_indices, + "ks": self.radial_indices, + "sgns": self.signs_indices, + } + + def _generate_fb_compat_indices(self): + """ + Generate indices to shuffle basis function ordering and flip signs in order + to match `FBBasis2D`. + """ + ind = self.indices() + # basis function ordering + self.fb_compat_indices = np.lexsort((ind["ks"], ind["sgns"], ind["ells"])) + # flip signs + self.flip_sign_indices = np.where(self.signs_indices == 1) + + def _precomp(self): + """ + Precompute the basis functions and other objects used in the evaluation of + coefficients. + """ + + # Find bessel functions zeros (the eigenvalues of the Laplacian on + # the disk) and generate the FLE Basis functions + self._lap_eig_disk() + + # Some important constants + self.smallest_lambda = np.min(self.bessel_zeros) + self.greatest_lambda = np.max(self.bessel_zeros) + self.max_ell = np.max(np.abs(self.ells)) + self.h = 1 / (self.nres / 2) + + # give each ell a positive index increasing first in |ell| + # then in sign, e.g. 0->1, -1->2, 1->3, -2->4, 2->5, etc. + self.ells_p = 2 * np.abs(self.ells) - (self.ells < 0) + self.ell_p_max = np.max(self.ells_p) + # idx_list[k] contains the indices j of ells_p where ells_p[j] = k + idx_list = [[] for i in range(self.ell_p_max + 1)] + for i in range(self.count): + ellp = self.ells_p[i] + idx_list[ellp].append(i) + self.idx_list = idx_list + + # real <-> complex + self.c2r = precomp_transform_complex_to_real(self.ells) + self.r2c = sparse.csr_matrix(self.c2r.transpose().conj()) + + # create an ordered list of the original ell values + # used in step2 (in both directions) + self.nus = np.zeros(1 + 2 * self.max_ell, dtype=int) + self.nus[0] = 0 + for i in range(1, self.max_ell + 1): + self.nus[2 * i - 1] = -i + self.nus[2 * i] = i + self.c2r_nus = precomp_transform_complex_to_real(self.nus) + self.r2c_nus = sparse.csr_matrix(self.c2r_nus.transpose().conj()) + + # radial and angular nodes for NUFFT + self._compute_nufft_points() + self.num_interp = self.num_radial_nodes + if self.numsparse > 0: + self.num_interp = 2 * self.num_radial_nodes + + self._build_interpolation_matrix() + + def _compute_maxitr_and_numsparse(self): + """ + Uses heuristics from paper to assign self.maxitr and self.numsparse. 
+ """ + # maxitr: maximum number of iterations for numerically solving linear + # system in self.evaluate() + # numsparse: parameter used to create sparse Chebyshev interpolation matrix + # see self._build_interpolation_matrix() + + if self.epsilon >= 1e-4: + numsparse = 8 + maxitr = 1 + int(np.log2(self.nres)) // 2 + elif self.epsilon >= 1e-7: + numsparse = 16 + maxitr = 1 + int(np.log2(self.nres)) + elif self.epsilon >= 1e-10: + numsparse = 22 + maxitr = 1 + int(2 * np.log2(self.nres)) + else: + # epsilon < 1e-10 + numsparse = 32 + maxitr = 1 + int(3 * np.log2(self.nres)) + + self.maxitr = maxitr + self.numsparse = numsparse + + def _compute_cartesian_gridpoints(self): + """ + Creates meshgrids based on basis size. + """ + if self.match_fb: + # creates correct odd-resolution grid + # matching other FB classes + grid = grid_2d(self.nres, dtype=self.dtype) + self.xs = grid["x"] + self.ys = grid["y"] + self.rs = grid["r"] + else: + # original implementation + R = self.nres // 2 + x = np.arange(-R, R + self.nres % 2) + xs, ys = np.meshgrid(x, x) + # Note, the original original grids were xs/R, R=nres//2. + self.xs, self.ys = xs / (self.nres / 2), ys / (self.nres / 2) + self.rs = np.sqrt(self.xs**2 + self.ys**2) + self.radial_mask = self.rs > 1 + 1e-13 + + def _compute_nufft_points(self): + """ + Computes gridpoints for the non-uniform FFT. + """ + + # Number of radial nodes + # (Lemma 4.1) + # compute max {2.4 * self.nres , Log2 ( 1 / epsilon) } + Q = int(np.ceil(2.4 * self.nres)) + num_radial_nodes = Q + tmp = 1 / (np.sqrt(np.pi)) + for q in range(1, Q + 1): + tmp = tmp / q * (np.sqrt(np.pi) * self.nres / 4) + if tmp <= self.epsilon: + num_radial_nodes = int(max(q, np.log2(1 / self.epsilon))) + break + self.num_radial_nodes = max( + num_radial_nodes, int(np.ceil(np.log2(1 / self.epsilon))) + ) + + # Number of angular nodes + # (Lemma 4.2) + # compute max {7.08 * self.nres, Log2(1/epsilon) + Log2(self.nres**2) } + + S = int(max(7.08 * self.nres, -np.log2(self.epsilon) + 2 * np.log2(self.nres))) + num_angular_nodes = S + for s in range(int(self.greatest_lambda + self.ell_p_max) + 1, S + 1): + tmp = self.nres**2 * ((self.greatest_lambda + self.ell_p_max) / s) ** s + if tmp <= self.epsilon: + num_angular_nodes = int(max(int(s), np.log2(1 / self.epsilon))) + break + + # must be even + if num_angular_nodes % 2 == 1: + num_angular_nodes += 1 + + self.num_angular_nodes = num_angular_nodes + + # create gridpoints + nodes = 1 - (2 * np.arange(self.num_radial_nodes, dtype=self.dtype) + 1) / ( + 2 * self.num_radial_nodes + ) + nodes = (np.cos(np.pi * nodes) + 1) / 2 + nodes = ( + self.greatest_lambda - self.smallest_lambda + ) * nodes + self.smallest_lambda + nodes = nodes.reshape(self.num_radial_nodes, 1) + + radius = self.nres / 2 + h = 1 / radius + + phi = ( + 2 + * np.pi + * np.arange(self.num_angular_nodes // 2, dtype=self.dtype) + / self.num_angular_nodes + ) + x = np.cos(phi).reshape(1, self.num_angular_nodes // 2) + y = np.sin(phi).reshape(1, self.num_angular_nodes // 2) + x = x * nodes * h + y = y * nodes * h + self.grid_x = x.flatten() + self.grid_y = y.flatten() + + def _build_interpolation_matrix(self): + """ + Create the matrix used in the third step of evaluate_t() and the first step of evaluate() + for barycentric interpolation from Chebyshev nodes. 
+ """ + A3 = [None] * (self.ell_p_max + 1) + A3_T = [None] * (self.ell_p_max + 1) + # known points from which to interpolate Beta values to desired points + known_points = np.cos( + np.pi * (1 - (2 * np.arange(self.num_interp) + 1) / (2 * self.num_interp)) + ) + for i in range(self.ell_p_max + 1): + # target points to evaluate Betas + target_points = ( + 2 + * (self.bessel_zeros[self.idx_list[i]] - self.smallest_lambda) + / (self.greatest_lambda - self.smallest_lambda) + - 1 + ) + + A3[i], A3_T[i] = barycentric_interp_sparse( + target_points, known_points, self.numsparse + ) + self.A3 = A3 + self.A3_T = A3_T + + def _lap_eig_disk(self): + """ + Compute the eigenvalues of the Laplacian operator on a disk with Dirichlet boundary conditions. + """ + # max number of Bessel function orders being considered + max_ell = int(3 * np.sqrt(self.max_basis_functions)) + # max number of zeros per Bessel function (number of frequencies per bessel) + max_k = int(2 * np.sqrt(self.max_basis_functions)) + + # preallocate containers for roots + # 0 frequency plus pos and negative frequencies for each bessel function + # num functions per frequency + num_ells = 1 + 2 * max_ell + self.ells = np.zeros((num_ells, max_k), dtype=int) + self.ks = np.zeros((num_ells, max_k), dtype=int) + self.bessel_zeros = np.ones((num_ells, max_k), dtype=np.float64) * np.Inf + + # keep track of which order Bessel function we're on + self.ells[0, :] = 0 + # bessel_roots[0, m] is the m'th zero of J_0 + self.bessel_zeros[0, :] = besselj_zeros(0, max_k) + # table of values of which zero of J_0 we are finding + self.ks[0, :] = np.arange(max_k) + 1 + + # add roots of J_ell for ell>0 twice with +k and -k (frequencies) + # iterate over Bessel function order + for ell in range(1, max_ell + 1): + self.ells[2 * ell - 1, :] = -ell + self.ks[2 * ell - 1, :] = np.arange(max_k) + 1 + + self.bessel_zeros[2 * ell - 1, :max_k] = besselj_zeros(ell, max_k) + + self.ells[2 * ell, :] = ell + self.ks[2 * ell, :] = self.ks[2 * ell - 1, :] + self.bessel_zeros[2 * ell, :] = self.bessel_zeros[2 * ell - 1, :] + + # Reshape the arrays and order by the size of the Bessel function zeros + self._flatten_and_sort_bessel_zeros() + + # Apply threshold criterion to throw out some basis functions + # Grab final number of basis functions for this Basis + self.count = self._threshold_basis_functions() + + self._create_basis_functions() + + def _flatten_and_sort_bessel_zeros(self): + """ + Reshapes arrays self.ells, self.ks, and self.bessel_zeros + """ + # flatten list of zeros, ells and ks: + self.ells = self.ells.flatten() + self.ks = self.ks.flatten() + self.bessel_zeros = self.bessel_zeros.flatten() + + idx = np.argsort(self.bessel_zeros) + self.ells = self.ells[idx] + self.ks = self.ks[idx] + self.bessel_zeros = self.bessel_zeros[idx] + + # sort complex conjugate pairs: -ell first, +ell second + idx = np.arange(self.max_basis_functions + 1) + for i in range(self.max_basis_functions + 1): + if self.ells[i] >= 0: + continue + if np.abs(self.bessel_zeros[i] - self.bessel_zeros[i + 1]) < 1e-14: + continue + idx[i - 1] = i + idx[i] = i - 1 + + self.ells = self.ells[idx] + self.ks = self.ks[idx] + self.bessel_zeros = self.bessel_zeros[idx] + + def _threshold_basis_functions(self): + """ + Implements the bandlimit threshold which caps the number of basis functions + that are actually required. + :return: The final overall number of basis functions to be used. 
+ """ + # Maximum bandlimit + # (Section 4.1) + # Can remove frequencies above this threshold based on the fact that + # there should not be more basis functions than pixels contained in the + # unit disk inscribed on the image + _final_num_basis_functions = self.max_basis_functions + + # implement FLE thresholding unless we want to match count of other FB bases + if not self.match_fb: + for _ in range(len(self.bessel_zeros)): + if ( + self.bessel_zeros[_final_num_basis_functions] / (np.pi) + >= (self.bandlimit - 1) // 2 + ): + _final_num_basis_functions -= 1 + + # potentially subtract one to keep complex conjugate pairs + if self.ells[_final_num_basis_functions - 1] < 0: + _final_num_basis_functions -= 1 + + # discard zeros above the threshold + self.ells = self.ells[:_final_num_basis_functions] + self.ks = self.ks[:_final_num_basis_functions] + self.bessel_zeros = self.bessel_zeros[:_final_num_basis_functions] + + return _final_num_basis_functions + + def _create_basis_functions(self): + """ + Generate the actual basis functions as Python lambda operators + """ + norm_constants = np.zeros(self.count) + basis_functions = [None] * self.count + for i in range(self.count): + # parameters defining the basis function: bessel order and which bessel root + ell = self.ells[i] + bessel_zero = self.bessel_zeros[i] + + # compute normalization constant + # see Eq. 6 + c = 1 / np.sqrt(np.pi * jv(ell + 1, bessel_zero) ** 2) + # create function + # See Eq. 1 + if ell == 0: + basis_functions[i] = ( + lambda r, t, c=c, ell=ell, bessel_zero=bessel_zero: c + * jv(ell, bessel_zero * r) + * (r <= 1) + ) + else: + basis_functions[i] = ( + lambda r, t, c=c, ell=ell, bessel_zero=bessel_zero: c + * jv(ell, bessel_zero * r) + * np.exp(1j * ell * t) + * (-1) ** np.abs(ell) + * (r <= 1) + ) + + norm_constants[i] = c + + self.norm_constants = norm_constants + self.basis_functions = basis_functions + + def _evaluate(self, coeffs): + """ + Evaluates FLE coefficients and return in standard 2D Cartesian coordinates. + + :param v: A coefficient vector (or an array of coefficient vectors) to + be evaluated. The last dimension must be equal to `self.count` + :return: An Image object containing the corresponding images. + """ + if self.match_fb: + # sign of basis functions with positive indices flipped relative to FB2d + coeffs[self.flip_sign_indices] *= -1.0 + # reorder coefficients by FB2d ordering + coeffs = coeffs[self.fb_compat_indices] + + # See Remark 3.3 and Section 3.4 + betas = self._step3(coeffs) + z = self._step2(betas) + im = self._step1(z) + return im.astype(self.dtype) + + def _evaluate_t(self, imgs): + """ + Evaluate 2D Cartesian image(s) and return the corresponding FLE coefficients. + + :param imgs: An Image object containing square images of size `self.nres`. + :return: A NumPy array of size `(num_images, self.count)` containing the FLE + coefficients. + """ + # See Section 3.5 + imgs = imgs.copy() + imgs[:, self.radial_mask] = 0 + z = self._step1_t(imgs) + b = self._step2_t(z) + coeffs = self._step3_t(b) + if self.match_fb: + coeffs[:, self.flip_sign_indices] *= -1.0 + coeffs = coeffs[:, self.fb_compat_indices] + return coeffs.astype(self.coefficient_dtype) + + def _step1_t(self, im): + """ + Step 1 of the adjoint transformation (images to coefficients). + Calculates the NUFFT of the image on gridpoints `self.grid_x` and `self.grid_y`. 
+ """ + im = im.reshape(-1, self.nres, self.nres).astype(complex_type(self.dtype)) + num_img = im.shape[0] + z = np.zeros( + (num_img, self.num_radial_nodes, self.num_angular_nodes), + dtype=complex_type(self.dtype), + ) + _z = ( + nufft(im, np.stack((self.grid_x, self.grid_y)), epsilon=self.epsilon) + * self.h**2 + ) + _z = _z.reshape(num_img, self.num_radial_nodes, self.num_angular_nodes // 2) + z[:, :, : self.num_angular_nodes // 2] = _z + z[:, :, self.num_angular_nodes // 2 :] = np.conj(_z) + return z + + def _step2_t(self, z): + """ + Step 2 of the adjoint transformation (images to coefficients). + Computes values of the analytic functions Beta_n at the Chebyshev nodes. + See Lemma 2.2. + """ + num_img = z.shape[0] + # Compute FFT along angular nodes + betas = fft.fft(z, axis=2) / self.num_angular_nodes + betas = betas[:, :, self.nus] + betas = np.conj(betas) + betas = np.swapaxes(betas, 0, 2) + betas = betas.reshape(-1, self.num_radial_nodes * num_img) + betas = self.c2r_nus @ betas + betas = betas.reshape(-1, self.num_radial_nodes, num_img) + betas = np.real(np.swapaxes(betas, 0, 2)) + return betas + + def _step3_t(self, betas): + """ + Step 3 of the adjoint transformation (images to coefficients). + Uses barycenteric interpolation to compute the values of the Betas + at the Bessel roots to arrive at the Fourier-Bessel coefficients. + """ + num_img = betas.shape[0] + if self.num_interp > self.num_radial_nodes: + betas = dct(betas, axis=1, type=2) / (2 * self.num_radial_nodes) + zeros = np.zeros(betas.shape) + betas = np.concatenate((betas, zeros), axis=1) + betas = idct(betas, axis=1, type=2) * 2 * betas.shape[1] + betas = np.moveaxis(betas, 0, -1) + + coeffs = np.zeros((self.count, num_img), dtype=np.float64) + for i in range(self.ell_p_max + 1): + coeffs[self.idx_list[i]] = self.A3[i] @ betas[:, i, :] + coeffs = coeffs.T + + return coeffs * self.norm_constants / self.h + + def _step3(self, coeffs): + """ + Adjoint of _step3_t and Step 1 of the forward transformation (coefficients + to images). + Uses barycenteric interpolation in reverse to compute values of Betas + at Chebyshev nodes, given an array of FLE coefficients. + """ + coeffs = coeffs.copy().reshape(-1, self.count) + num_img = coeffs.shape[0] + coeffs *= self.h * self.norm_constants + coeffs = coeffs.T + + out = np.zeros( + (self.num_interp, 2 * self.max_ell + 1, num_img), + dtype=np.float64, + ) + for i in range(self.ell_p_max + 1): + out[:, i, :] = self.A3_T[i] @ coeffs[self.idx_list[i]] + out = np.moveaxis(out, -1, 0) + if self.num_interp > self.num_radial_nodes: + out = dct(out, axis=1, type=2) + out = out[:, : self.num_radial_nodes, :] + out = idct(out, axis=1, type=2) + + return out + + def _step2(self, betas): + """ + Adjoint of _step2_t and Step 2 of the forward transformation (coefficients + to images). + Uses the IFFT to convert Beta values into Fourier-space images. + """ + num_img = betas.shape[0] + tmp = np.zeros( + (num_img, self.num_radial_nodes, self.num_angular_nodes), + dtype=np.complex128, + ) + + betas = np.swapaxes(betas, 0, 2) + betas = betas.reshape(-1, self.num_radial_nodes * num_img) + betas = self.r2c_nus @ betas + betas = betas.reshape(-1, self.num_radial_nodes, num_img) + betas = np.swapaxes(betas, 0, 2) + + tmp[:, :, self.nus] = np.conj(betas) + z = fft.ifft(tmp, axis=2) + + return z + + def _step1(self, z): + """ + Adjoint of _step1_t and final step of the forward transformation (coefficients + to images). + Performs the NUFFT on Fourier-space images to compute real-space images. 
+ """ + num_img = z.shape[0] + z = z[:, :, : self.num_angular_nodes // 2].reshape(num_img, -1) + im = anufft( + z.astype(complex_type(self.dtype)), + np.stack((self.grid_x, self.grid_y)), + (self.nres, self.nres), + epsilon=self.epsilon, + ) + im = im + np.conj(im) + im = np.real(im) + im = im.reshape(num_img, self.nres, self.nres) + im[:, self.radial_mask] = 0 + + return im + + def _create_dense_matrix(self): + """ + Directly computes the transformation matrix from Cartesian coordinates to + FLE coordinates without any shortcuts. + :return: A NumPy array of size `(self.nres**2, self.count)` containing the matrix + entries. + """ + # See Eqns. 3 and 4, Section 1.2 + ts = np.arctan2(self.ys, self.xs) + + B = np.zeros((self.nres, self.nres, self.count), dtype=np.complex128) + for i in range(self.count): + B[:, :, i] = self.basis_functions[i](self.rs, ts) * self.h + B = B.reshape(self.nres**2, self.count) + B = transform_complex_to_real(np.conj(B), self.ells) + B = B.reshape(self.nres**2, self.count) + if self.match_fb: + B[:, self.flip_sign_indices] *= -1.0 + B = B[:, self.fb_compat_indices] + return B + + def lowpass(self, coeffs, bandlimit): + """ + Apply a low-pass filter to FLE coefficients `coeffs` with threshold `bandlimit`. + :param coeffs: A NumPy array of FLE coefficients, of shape (num_images, self.count) + :param bandlimit: Integer bandlimit (max frequency). + :return: Band-limited coefficient array. + """ + if len(coeffs.shape) == 1: + coeffs = coeffs.reshape((1, coeffs.shape[0])) + assert ( + len(coeffs.shape) == 2 + ), "Input a stack of coefficients of dimension (num_images, self.count)." + assert ( + coeffs.shape[1] == self.count + ), "Number of coefficients must match self.count." + + k = self.count - 1 + for _ in range(self.count): + if self.bessel_zeros[k] / (np.pi) > (bandlimit - 1) // 2: + k = k - 1 + coeffs[:, k + 1 :] = 0 + + return coeffs + + def radial_convolve(self, coeffs, radial_img): + """ + Convolve a stack of FLE coefficients with a 2D radial function. + :param coeffs: A NumPy array of FLE coefficients of size (num_images, self.count). + :param radial_img: A 2D NumPy array of size (self.nres, self.nres). + :return: Convolved FLE coefficients. + """ + num_img = coeffs.shape[0] + coeffs_conv = np.zeros(coeffs.shape) + for k in range(num_img): + _coeffs = coeffs[k, :] + z = self._step1_t(radial_img) + b = self._step2_t(z) + weights = self._radial_convolve_weights(b) + b = weights / (self.h**2) + b = b.reshape(self.count) + coeffs_conv[k, :] = np.real(self.c2r @ (b * (self.r2c @ _coeffs).flatten())) + + return coeffs_conv + + def _radial_convolve_weights(self, b): + """ + Helper function for step 3 of convolving with a radial function. 
+ """ + b = np.squeeze(b) + b = np.array(b) + if self.num_interp > self.num_radial_nodes: + b = dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) + bz = np.zeros(b.shape) + b = np.concatenate((b, bz), axis=0) + b = idct(b, axis=0, type=2) * 2 * b.shape[0] + a = np.zeros(self.count, dtype=np.float64) + y = [None] * (self.ell_p_max + 1) + for i in range(self.ell_p_max + 1): + y[i] = (self.A3[i] @ b[:, 0]).flatten() + for i in range(self.ell_p_max + 1): + a[self.idx_list[i]] = y[i] + + return a.flatten() diff --git a/src/aspire/basis/fle_2d_utils.py b/src/aspire/basis/fle_2d_utils.py new file mode 100644 index 0000000000..e97d309524 --- /dev/null +++ b/src/aspire/basis/fle_2d_utils.py @@ -0,0 +1,202 @@ +import numpy as np +import scipy.sparse as sparse + + +def transform_complex_to_real(B_conj, ells): + """ + Transforms coefficients of the matrix B (see Eq. 3) from complex + to real. B is the linear transformation that takes FLE coefficients + to images. + + :param B_conj: Complex conjugate of the matrix B. + :param ells: List of ells (Bessel function orders) in this basis. + :return: Transformed matrix. + """ + num_basis_functions = B_conj.shape[1] + X = np.zeros(B_conj.shape, dtype=np.float64) + + for i in range(num_basis_functions): + ell = ells[i] + if ell == 0: + X[:, i] = np.real(B_conj[:, i]) + # for each ell != 0, we can populate two entries of the matrix + # by taking the complex conjugate of the ell with the opposite sign + if ell < 0: + s = (-1) ** np.abs(ell) + x0 = (B_conj[:, i] + s * B_conj[:, i + 1]) / np.sqrt(2) + x1 = (-B_conj[:, i] + s * B_conj[:, i + 1]) / (1j * np.sqrt(2)) + X[:, i] = np.real(x0) + X[:, i + 1] = np.real(x1) + + return X + + +def precomp_transform_complex_to_real(ells): + """ + Returns a sparse matrix that transforms coefficients into the complex + representation of the basis to coefficients in the real + representation of the basis. See Remark 1.1 of Marshall, Mickelin, + and Singer. + + :param ells: The list of integer Bessel function orders. + :return: Sparse complex to real transformation matrix. + """ + count = len(ells) + num_nonzero = np.sum(ells == 0) + 2 * np.sum(ells != 0) + idx = np.zeros(num_nonzero, dtype=int) + jdx = np.zeros(num_nonzero, dtype=int) + vals = np.zeros(num_nonzero, dtype=np.complex128) + + k = 0 + for i in range(count): + ell = ells[i] + # ell = 0 is a special case (DC component) + if ell == 0: + vals[k] = 1 + idx[k] = i + jdx[k] = i + k = k + 1 + # Only branch the case ell < 0 and also update -ell + # via complex conjugation + if ell < 0: + s = (-1) ** np.abs(ell) + + # positive ell + vals[k] = 1 / np.sqrt(2) + idx[k] = i + jdx[k] = i + k = k + 1 + + # positive ell + vals[k] = s / np.sqrt(2) + idx[k] = i + jdx[k] = i + 1 + k = k + 1 + + # negative ell + vals[k] = -1 / (1j * np.sqrt(2)) + idx[k] = i + 1 + jdx[k] = i + k = k + 1 + + # negative ell + vals[k] = s / (1j * np.sqrt(2)) + idx[k] = i + 1 + jdx[k] = i + 1 + k = k + 1 + + A = sparse.csr_matrix((vals, (idx, jdx)), shape=(count, count), dtype=np.complex128) + + return A + + +def barycentric_interp_sparse(target_points, known_points, numsparse): + """ + Returns the sparse matrices that perform barycentric interpolation to compute values + of Betas at the points `target_points` at known points `known_points`, and the transpose + of this operation. For each target point in `target_points`, only `numsparse` centered + source points from `known_points` around the target point are used. 
+ + Performed via the method described in + + "Barycentric Lagrange Interpolation", Jean-Paul Berrut and Lloyd Trefethen. + SIAM Review 2004 46:3, 501-517 + https://people.maths.ox.ac.uk/trefethen/barycentric.pdf + + :param target_points: The target set of points at which to evaluate the functions. + :param known_points: The points at which the values of the functions are known. + :param numsparse: Number of points used for interpolation around each target point. + :return: The interpolation matrix and its transpose as a 2-tuple. + """ + + n = len(target_points) + m = len(known_points) + + # Modify points by 2e-16 to avoid division by zero + vals, x_ind, xs_ind = np.intersect1d( + target_points, known_points, return_indices=True, assume_unique=True + ) + target_points[x_ind] = target_points[x_ind] + 2e-16 + + idx = np.zeros((n, numsparse)) + jdx = np.zeros((n, numsparse)) + vals = np.zeros((n, numsparse)) + xss = np.zeros((n, numsparse)) + denom = np.zeros((n, 1)) + temp = np.zeros((n, 1)) + ws = np.zeros((n, numsparse)) + xdiff = np.zeros(n) + + # loop over target points + for i in range(n): + # choose `numsparse` source points centered around each target point + # in order to apply a sparse barycentric interpolation to this target + # point + k = np.searchsorted(target_points[i] < known_points, True) + + idp = np.arange(k - numsparse // 2, k + (numsparse + 1) // 2) + if idp[0] < 0: + idp = np.arange(numsparse) + if idp[-1] >= m: + idp = np.arange(m - numsparse, m) + # xss stores the values from `known_points` used for interpolation on the i'th + # target point + xss[i, :] = known_points[idp] + jdx[i, :] = idp + idx[i, :] = i + + # Auxiliary vector for computing products of expressions of the form xss[:,i] - xss[:,j] + # in order not to include xss[:,i] - xss[:,i] = 0. Will be all ones except for the index i + # not to include in the running product. this index is updated in the loop + Iw = np.ones(numsparse, dtype=bool) + ew = np.zeros((n, 1)) + xtw = np.zeros((n, numsparse - 1)) + + Iw[0] = False + const = np.zeros((n, 1)) + + for _ in range(numsparse): + ew = np.sum(-np.log(np.abs(xss[:, 0].reshape(-1, 1) - xss[:, Iw])), axis=1) + # normalization constant + constw = np.exp(ew / numsparse) + constw = constw.reshape(-1, 1) + const += constw + # this normalization constant prevents numerical issues and cancels out in the end + # not included in return result + const = const / numsparse + + for j in range(numsparse): + Iw[j] = False + # compute the denominator in Eq 3.2 of Berrut and Trefethen + xtw = const * (xss[:, j].reshape(-1, 1) - xss[:, Iw]) + ws[:, j] = 1 / np.prod(xtw, axis=1) + Iw[j] = True + + xdiff = xdiff.flatten() + target_points = target_points.flatten() + temp = temp.flatten() + denom = denom.flatten() + + for j in range(numsparse): + # xdiff[i] is the i'th target point minus the j'th source point for that target pt + # see the denominator in Eq. 3.3 of Berrut and Trefethen + xdiff = target_points - xss[:, j] + temp = ws[:, j] / xdiff + # vals[:,j] = (1/const)*w_j/(x[i] - xs[j]), with the notation in Eq. 3.3 + vals[:, j] = vals[:, j] + temp + denom = denom + temp + + # Eq 4.2 + # note that const cancels in numerator and denominator + vals = vals / denom.reshape(-1, 1) + + vals = vals.flatten() + idx = idx.flatten() + jdx = jdx.flatten() + # A is the linear operator mapping the function values from the fixed source + # points to the fixed target points. + # A(i,j) = \ell(x[i] ) w_j/(x[i] - xs[j]), with the notation in Eq. 
3.3 + A = sparse.csr_matrix((vals, (idx, jdx)), shape=(n, m), dtype=np.float64) + A_T = sparse.csr_matrix((vals, (jdx, idx)), shape=(m, n), dtype=np.float64) + + return A, A_T diff --git a/src/aspire/basis/steerable.py b/src/aspire/basis/steerable.py index 3363f1cdf5..955c92dc4e 100644 --- a/src/aspire/basis/steerable.py +++ b/src/aspire/basis/steerable.py @@ -164,7 +164,9 @@ def rotate(self, coef, radians, refl=None): ) # else: radians can be a constant - assert self.count == coef.shape[-1] + assert ( + self.count == coef.shape[-1] + ), "Number of coefficients must match self.count." # self.angular_indices are `ks` # For all coef in stack, diff --git a/src/aspire/nufft/__init__.py b/src/aspire/nufft/__init__.py index 65662c7213..3fbf3dea2b 100644 --- a/src/aspire/nufft/__init__.py +++ b/src/aspire/nufft/__init__.py @@ -141,7 +141,7 @@ def __new__(cls, *args, **kwargs): return super(Plan, cls).__new__(cls) -def anufft(sig_f, fourier_pts, sz, real=False): +def anufft(sig_f, fourier_pts, sz, real=False, epsilon=1e-8): """ Wrapper for 1, 2, and 3 dimensional Non Uniform FFT Adjoint. Dimension is based on the dimension of fourier_pts and checked against sig_f. @@ -173,12 +173,14 @@ def anufft(sig_f, fourier_pts, sz, real=False): if len(sig_f.shape) == 2: ntransforms = sig_f.shape[0] - plan = Plan(sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms) + plan = Plan( + sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms, epsilon=epsilon + ) adjoint = plan.adjoint(sig_f) return np.real(adjoint) if real else adjoint -def nufft(sig_f, fourier_pts, real=False): +def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): """ Wrapper for 1, 2, and 3 dimensional Non Uniform FFT Dimension is based on the dimension of fourier_pts and checked against sig_f. @@ -219,6 +221,8 @@ def nufft(sig_f, fourier_pts, real=False): if len(sig_f.shape) == dimension + 1: ntransforms = sig_f.shape[0] - plan = Plan(sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms) + plan = Plan( + sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms, epsilon=epsilon + ) transform = plan.transform(sig_f) return np.real(transform) if real else transform diff --git a/tests/saved_test_data/fle_radial_fn_32x32.npy b/tests/saved_test_data/fle_radial_fn_32x32.npy new file mode 100644 index 0000000000..350d6179fa Binary files /dev/null and b/tests/saved_test_data/fle_radial_fn_32x32.npy differ diff --git a/tests/test_FLEbasis2D.py b/tests/test_FLEbasis2D.py new file mode 100644 index 0000000000..d17f2e994c --- /dev/null +++ b/tests/test_FLEbasis2D.py @@ -0,0 +1,321 @@ +import os + +import numpy as np +import pytest + +from aspire.basis import FBBasis2D, FLEBasis2D +from aspire.image import Image +from aspire.nufft import backend_available +from aspire.numeric import fft +from aspire.source import Simulation +from aspire.utils import utest_tolerance +from aspire.volume import Volume + +from ._basis_util import UniversalBasisMixin + +DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") + + +def show_fle_params(basis): + return f"{basis.nres}-{basis.epsilon}" + + +def gpu_ci_skip(): + pytest.skip("1e-7 precision for FLEBasis2D.evaluate()") + + +fle_params = [ + (32, 1e-4), + (32, 1e-7), + (32, 1e-10), + (32, 1e-14), + (33, 1e-4), + (33, 1e-7), + (33, 1e-10), + (33, 1e-14), +] + +test_bases = [ + FLEBasis2D(L, epsilon=epsilon, dtype=np.float64, match_fb=False) + for L, epsilon in fle_params +] + +# add one case ensuring input/output dtypes for evaluate and evaluate_t +test_bases.append(FLEBasis2D(8, epsilon=1e-4, 
dtype=np.float32, match_fb=False)) + +test_bases_match_fb = [ + FLEBasis2D(L, epsilon=epsilon, dtype=np.float64) for L, epsilon in fle_params +] + + +def create_images(L, n): + # create sample data + v = Volume( + np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(np.float64) + ) + v = v.downsample(L) + sim = Simulation(L=L, n=n, vols=v, dtype=v.dtype, seed=1103) + img = sim.clean_images[:] + return img + + +def relerr(base, approx): + # relative error of two arrays + base = np.array(base).flatten() + approx = np.array(approx).flatten() + return np.linalg.norm(base - approx) / np.linalg.norm(base) + + +@pytest.mark.parametrize("basis", test_bases, ids=show_fle_params) +class TestFLEBasis2D(UniversalBasisMixin): + # check closeness guarantees for fast vs dense matrix method + def testFastVDense_T(self, basis): + dense_b = basis._create_dense_matrix() + + # create sample particle + x = create_images(basis.nres, 1).asnumpy() + xvec = x.reshape((basis.nres**2, 1)) + + # explicit matrix multiplication + result_dense = dense_b.T @ xvec + # fast evaluate_t + result_fast = basis.evaluate_t(Image(x)) + + assert relerr(result_dense.T, result_fast) < basis.epsilon + + def testFastVDense(self, basis): + if backend_available("cufinufft") and basis.epsilon == 1e-7: + gpu_ci_skip() + + dense_b = basis._create_dense_matrix() + + # get sample coefficients + x = create_images(basis.nres, 1) + # hold input test data constant (would depend on epsilon parameter) + coeffs = FLEBasis2D( + basis.nres, epsilon=1e-4, dtype=np.float64, match_fb=False + ).evaluate_t(x) + + result_dense = dense_b @ coeffs.T + result_fast = basis.evaluate(coeffs).asnumpy() + + assert relerr(result_dense, result_fast) < basis.epsilon + + def testEvaluateExpand(self, basis): + if backend_available("cufinufft") and basis.epsilon == 1e-7: + gpu_ci_skip() + + # compare result of evaluate() vs more accurate expand() + # get sample coefficients + x = create_images(basis.nres, 1) + # hold input test data constant (would depend on epsilon parameter) + evaluate_t = basis.evaluate(basis.evaluate_t(x)) + expand = basis.evaluate(basis.expand(evaluate_t)) + + assert relerr(expand.asnumpy(), evaluate_t.asnumpy()) < basis.epsilon + + +@pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) +def testMatchFBEvaluate(basis): + if backend_available("cufinufft") and basis.epsilon == 1e-7: + gpu_ci_skip() + + # ensure that the basis functions are identical when in match_fb mode + fb_basis = FBBasis2D(basis.nres, dtype=np.float64) + + # in match_fb, count is the same for both bases + coeffs = np.eye(basis.count) + + fb_images = fb_basis.evaluate(coeffs) + fle_images = basis.evaluate(coeffs) + + assert np.allclose(fb_images._data, fle_images._data, atol=1e-4) + + +@pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) +def testMatchFBDenseEvaluate(basis): + # ensure that images are the same when evaluating coefficients via slow + # matrix multiplication + + fb_basis = FBBasis2D(basis.nres, dtype=np.float64) + + coeffs = np.eye(basis.count) + + fb_images = fb_basis.evaluate(coeffs).asnumpy() + fle_out = basis._create_dense_matrix() @ coeffs + fle_images = Image(fle_out.T.reshape(-1, basis.nres, basis.nres)).asnumpy() + + # Matrix column reording in match_fb mode flips signs of some of the basis functions + assert np.allclose(np.abs(fb_images), np.abs(fle_images), atol=1e-3) + + +@pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) +def testMatchFBEvaluate_t(basis): + # ensure that 
coefficients are the same when evaluating images + fb_basis = FBBasis2D(basis.nres, dtype=np.float64) + + # test images to evaluate + images = fb_basis.evaluate(np.eye(basis.count)) + + fb_coeffs = fb_basis.evaluate_t(images) + fle_coeffs = basis.evaluate_t(images) + + assert np.allclose(fb_coeffs, fle_coeffs, atol=1e-4) + + +@pytest.mark.parametrize("basis", test_bases_match_fb, ids=show_fle_params) +def testMatchFBDenseEvaluate_t(basis): + # ensure that coefficients are the same when evaluating images via slow + # matrix multiplication + + fb_basis = FBBasis2D(basis.nres, dtype=np.float64) + + # test images to evaluate + # gets a stack of shape (basis.count, L, L) + images = fb_basis.evaluate(np.eye(basis.count)) + # reshape to a stack of basis.count vectors of length L**2 + vec = images.asnumpy().reshape((-1, basis.nres**2)) + + fb_coeffs = fb_basis.evaluate_t(images) + fle_coeffs = basis._create_dense_matrix().T @ vec.T + + # Matrix column reording in match_fb mode flips signs of some of the basis coefficients + assert np.allclose(np.abs(fb_coeffs), np.abs(fle_coeffs), atol=1e-4) + + +def testLowPass(): + # test that low passing removes more and more high frequency + # elements as bandlimit decreases + + L = 128 + basis = FLEBasis2D(L, match_fb=False) + + # sample coefficients + ims = create_images(L, 1) + coeffs = basis.evaluate_t(ims) + + nonzero_coeffs = [] + for i in range(4): + bandlimit = L // (2**i) + coeffs_lowpassed = basis.lowpass(coeffs, bandlimit) + nonzero_coeffs.append(np.sum(coeffs_lowpassed != 0)) + + # for bandlimit == L, no frequencies should be removed + assert nonzero_coeffs[0] == basis.count + + # for lower bandlimits, there should be fewer and fewer nonzero coeffs + assert nonzero_coeffs[0] > nonzero_coeffs[1] > nonzero_coeffs[2] > nonzero_coeffs[3] + + # make sure you can pass in a 1-D array if you want + _ = basis.lowpass(coeffs[0, :], L) + + # cannot pass in the wrong number of coefficients + with pytest.raises( + AssertionError, match="Number of coefficients must match self.count." 
+ ): + _ = basis.lowpass(coeffs[:, :1000], L) + + # cannot pass in wrong shape + with pytest.raises( + AssertionError, + match="Input a stack of coefficients of dimension", + ): + _ = basis.lowpass(np.zeros((3, 3, 3)), L) + + +def testRotate(): + # test ability to accurately rotate images via + # FLE coefficients + + L = 128 + basis = FLEBasis2D(L, match_fb=False) + + # sample image + ims = create_images(L, 1) + # rotate 90 degrees in cartesian coordinates + ims_90 = Image(np.rot90(ims.asnumpy(), axes=(1, 2))) + + # get FLE coefficients + coeffs = basis.evaluate_t(ims) + coeffs_cart_rot = basis.evaluate_t(ims_90) + + # rotate original image in FLE space using Steerable rotate method + coeffs_fle_rot = basis.rotate(coeffs, np.pi / 2) + + # back to cartesian + ims_cart_rot = basis.evaluate(coeffs_cart_rot) + ims_fle_rot = basis.evaluate(coeffs_fle_rot) + + # test rot90 close + assert np.allclose(ims_cart_rot[0], ims_fle_rot[0], atol=1e-4) + + # 2Pi identity in FLE space (rotate by 2Pi) + coeffs_fle_2pi = basis.rotate(coeffs, 2 * np.pi) + ims_fle_2pi = basis.evaluate(coeffs_fle_2pi) + + # test 2Pi identity + assert np.allclose(ims[0], ims_fle_2pi[0], atol=utest_tolerance(basis.dtype)) + + # Reflect in FLE space (rotate by Pi) + coeffs_fle_pi = basis.rotate(coeffs, np.pi) + ims_fle_pi = basis.evaluate(coeffs_fle_pi) + + # test reflection + assert np.allclose(np.flipud(ims[0]), ims_fle_pi[0], atol=1e-4) + + # make sure you can pass in a 1-D array if you want + _ = basis.lowpass(np.zeros((basis.count,)), np.pi) + + # cannot pass in the wrong number of coefficients + with pytest.raises( + AssertionError, match="Number of coefficients must match self.count." + ): + _ = basis.rotate(np.zeros((1, 10)), np.pi) + + # cannot pass in wrong shape + with pytest.raises( + AssertionError, + match="Input a stack of coefficients of dimension", + ): + _ = basis.lowpass(np.zeros((3, 3, 3)), np.pi) + + +def testRadialConvolution(): + # test ability to accurately convolve with a radial + # (e.g. CTF) function via FLE coefficients + + L = 32 + basis = FLEBasis2D(L, match_fb=False) + # load test radial function + x = np.load(os.path.join(DATA_DIR, "fle_radial_fn_32x32.npy")).reshape(1, 32, 32) + x = x / np.max(np.abs(x.flatten())) + + # get sample images + ims = create_images(L, 10) + # convolve using coefficients + coeffs = basis.evaluate_t(ims) + coeffs_convolved = basis.radial_convolve(coeffs, x) + imgs_convolved_fle = basis.evaluate(coeffs_convolved).asnumpy() + + # convolve using FFT + x = basis.evaluate(basis.evaluate_t(x)).asnumpy() + ims = basis.evaluate(coeffs).asnumpy() + + imgs_convolved_slow = np.zeros((10, L, L)) + for i in range(10): + x_pad = np.zeros((2 * L, 2 * L)) + ims_pad = np.zeros((2 * L, 2 * L)) + x_pad[L // 2 : L // 2 + L, L // 2 : L // 2 + L] = x[0, :, :] + ims_pad[L // 2 : L // 2 + L, L // 2 : L // 2 + L] = ims[i, :, :] + + x_shift = fft.fftshift(x_pad.reshape(2 * L, 2 * L)) + ims_shift = fft.fftshift(ims_pad.reshape(2 * L, 2 * L)) + + convolution_fft_pad = fft.fftshift( + fft.ifft2(np.fft.fft2(x_shift) * np.fft.fft2(ims_shift)) + ) + imgs_convolved_slow[i, :, :] = np.real( + convolution_fft_pad[L // 2 : L // 2 + L, L // 2 : L // 2 + L] + ) + + assert np.allclose(imgs_convolved_fle, imgs_convolved_slow, atol=1e-5)