diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index f2d5472e52..c41b221ec4 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -147,7 +147,10 @@ jobs: echo "Stash the WORK_DIR to GitHub env so we can clean it up later." echo "WORK_DIR=${WORK_DIR}" >> $GITHUB_ENV echo -e "ray:\n temp_dir: ${WORK_DIR}\n" > ${WORK_DIR}/config.yaml - echo -e "common:\n cache_dir: ${CI_CACHE_DIR}\n" >> ${WORK_DIR}/config.yaml + echo -e "common:" >> ${WORK_DIR}/config.yaml + echo -e " cache_dir: ${CI_CACHE_DIR}" >> ${WORK_DIR}/config.yaml + echo -e " numeric: cupy" >> ${WORK_DIR}/config.yaml + echo -e " fft: cupy\n" >> ${WORK_DIR}/config.yaml echo "Log the config: ${WORK_DIR}/config.yaml" cat ${WORK_DIR}/config.yaml - name: Run diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 4a48e3a505..5fa608ecdf 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -129,10 +129,10 @@ an M1 laptop: Installing GPU Extensions ************************* -ASPIRE does support GPUs, depending on several external packages. The -collection of GPU extensions can be installed using ``pip``. -Extensions are grouped based on CUDA versions. To find the CUDA -driver version, run ``nvidia-smi`` on the intended system. +ASPIRE does support using a GPU, depending on several external +packages. The collection of GPU extensions can be installed using +``pip``. Extensions are grouped based on CUDA versions. To find the +CUDA driver version, run ``nvidia-smi`` on the intended system. .. list-table:: CUDA GPU Extension Versions :widths: 25 25 @@ -140,14 +140,6 @@ driver version, run ``nvidia-smi`` on the intended system. * - CUDA Version - ASPIRE Extension - * - 10.2 - - gpu-102 - * - 11.0 - - gpu-110 - * - 11.1 - - gpu-111 - * - >=11.2 - - gpu-11x * - >=12 - gpu-12x @@ -164,12 +156,15 @@ the command below would install GPU packages required for ASPIRE. By default if the required GPU extensions are correctly installed, -ASPIRE should automatically begin using the GPU for select components -(such as those using ``nufft``). - -Because GPU extensions depend on several third party packages and -libraries, we can only offer limited support if one of the packages -has a problem on your system. +ASPIRE should automatically begin using the GPU calls to our ``nufft`` module. + +Using GPU in other areas of the code is still an experimental feature +and requires a minor configuration setting to enable ``cupy``. See the +:ref:`sphx_glr_auto_tutorials_configuration.py` for details. Because +GPU extensions depend on several third party softwares and machines +vary wildly, we can only offer limited support if one of the packages +has a problem on your system. We are currently expanding GPU code +coverage. Generating Documentation ************************ diff --git a/gallery/tutorials/configuration.py b/gallery/tutorials/configuration.py index 819ff9b675..372d97df06 100644 --- a/gallery/tutorials/configuration.py +++ b/gallery/tutorials/configuration.py @@ -102,6 +102,36 @@ time.sleep(1) print("Done Loop 2\n") +# %% +# Enabling GPU Acceleration +# ------------------------- +# Enabling GPU acceleration requires installing supporting software +# packages and small config changes. Installing the supporting +# software is most easily accomplished by installing ASPIRE with one +# of the published GPU extensions, for example ``pip install +# "aspire[dev,gpu_12x]"``. Once the packages are installed users +# should find that the NUFFT calls are automatically running on the +# GPU. Additional acceleration is achieved by enabling `cupy` for +# `numeric` and `fft` components. +# +# .. code-block:: yaml +# +# common: +# # numeric module to use - one of numpy/cupy +# numeric: cupy +# # fft backend to use - one of pyfftw/scipy/cupy/mkl +# fft: cupy +# +# Alternatively, like other config options, this can be changed +# dynamically with code. +# +# .. code-block:: python +# +# from aspire import config +# +# config["common"]["numeric"] = "cupy" +# config["common"]["fft"] = "cupy" +# # %% # Resolution diff --git a/pyproject.toml b/pyproject.toml index 3cd57981ef..c9c25a9976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,11 +61,7 @@ dependencies = [ "Source" = "https://github.com/ComputationalCryoEM/ASPIRE-Python" [project.optional-dependencies] -gpu-102 = ["pycuda", "cupy-cuda102", "cufinufft==1.3"] -gpu-110 = ["pycuda", "cupy-cuda110", "cufinufft==1.3"] -gpu-111 = ["pycuda", "cupy-cuda111", "cufinufft==1.3"] -gpu-11x = ["pycuda", "cupy-cuda11x", "cufinufft==1.3"] -gpu-12x = ["pycuda", "cupy-cuda12x", "cufinufft==2.2.0"] +gpu-12x = ["cupy-cuda12x", "cufinufft==2.2.0"] dev = [ "black", "bumpversion", diff --git a/src/aspire/basis/ffb_2d.py b/src/aspire/basis/ffb_2d.py index 5a5c7c3f27..8d46e8419c 100644 --- a/src/aspire/basis/ffb_2d.py +++ b/src/aspire/basis/ffb_2d.py @@ -58,6 +58,16 @@ def _build(self): # precompute the basis functions in 2D grids self._precomp = self._precomp() + # include the normalization factor of angular part into radial part + self.radial_norm = xp.asarray(self._precomp["radial"]) / xp.asarray( + np.expand_dims(self.angular_norms, 1) + ) + + # precompute weighted nodes + self.gl_weighted_nodes = xp.asarray(self._precomp["gl_weights"]) * xp.asarray( + self._precomp["gl_nodes"] + ) + def _precomp(self): """ Precomute the basis functions on a polar Fourier grid @@ -105,6 +115,7 @@ def _evaluate(self, v): coordinate basis. This is Image instance with resolution of `self.sz` and the first dimension correspond to remaining dimension of `v`. """ + v = xp.asarray(v) sz_roll = v.shape[:-1] v = v.reshape(-1, self.count) @@ -112,25 +123,23 @@ def _evaluate(self, v): n_data = v.shape[0] # get information on polar grids from precomputed data - n_theta = np.size(self._precomp["freqs"], 2) - n_r = np.size(self._precomp["freqs"], 1) + n_theta = self._precomp["freqs"].shape[2] + n_r = self._precomp["freqs"].shape[1] # go through each basis function and find corresponding coefficient - pf = np.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype)) + pf = xp.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype)) ind = 0 idx = ind + np.arange(self.k_max[0], dtype=int) - # include the normalization factor of angular part into radial part - radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1) - pf[:, 0, :] = v[:, self._zero_angular_inds] @ radial_norm[idx] - ind = ind + np.size(idx) + pf[:, 0, :] = v[:, self._zero_angular_inds] @ self.radial_norm[idx] + ind = ind + idx.size ind_pos = ind for ell in range(1, self.ell_max + 1): - idx = ind + np.arange(self.k_max[ell], dtype=int) + idx = ind + xp.arange(self.k_max[ell], dtype=int) idx_pos = ind_pos + np.arange(self.k_max[ell], dtype=int) idx_neg = idx_pos + self.k_max[ell] @@ -139,7 +148,7 @@ def _evaluate(self, v): if np.mod(ell, 2) == 1: v_ell = 1j * v_ell - pf_ell = v_ell @ radial_norm[idx] + pf_ell = v_ell @ self.radial_norm[idx] pf[:, ell, :] = pf_ell if np.mod(ell, 2) == 0: @@ -147,22 +156,17 @@ def _evaluate(self, v): else: pf[:, 2 * n_theta - ell, :] = -pf_ell.conjugate() - ind = ind + np.size(idx) + ind = ind + idx.size ind_pos = ind_pos + 2 * self.k_max[ell] # 1D inverse FFT in the degree of polar angle - pf = 2 * pi * xp.asnumpy(fft.ifft(xp.asarray(pf), axis=1)) + pf = 2 * xp.pi * fft.ifft(pf, axis=1) # Only need "positive" frequencies. - hsize = int(np.size(pf, 1) / 2) + hsize = int(pf.shape[1] / 2) pf = pf[:, 0:hsize, :] - - for i_r in range(0, n_r): - pf[..., i_r] = pf[..., i_r] * ( - self._precomp["gl_weights"][i_r] * self._precomp["gl_nodes"][i_r] - ) - - pf = np.reshape(pf, (n_data, n_r * n_theta)) + pf *= self.gl_weighted_nodes[None, None, :] + pf = pf.reshape(n_data, n_r * n_theta) # perform inverse non-uniformly FFT transform back to 2D coordinate basis freqs = m_reshape(self._precomp["freqs"], (2, n_r * n_theta)) @@ -172,7 +176,7 @@ def _evaluate(self, v): # Return X as Image instance with the last two dimensions as *self.sz x = x.reshape((*sz_roll, *self.sz)) - return x + return xp.asnumpy(x) def _evaluate_t(self, x): """ @@ -193,56 +197,51 @@ def _evaluate_t(self, x): n_images = x.shape[0] # resamping x in a polar Fourier gird using nonuniform discrete Fourier transform - pf = nufft(x, 2 * pi * freqs) - pf = np.reshape(pf, (n_images, n_r, n_theta)) + pf = nufft(xp.asarray(x), 2 * pi * freqs) + pf = pf.reshape(n_images, n_r, n_theta) # Recover "negative" frequencies from "positive" half plane. - pf = np.concatenate((pf, pf.conjugate()), axis=2) + pf = xp.concatenate((pf, pf.conjugate()), axis=2) # evaluate radial integral using the Gauss-Legendre quadrature rule - for i_r in range(0, n_r): - pf[:, i_r, :] = pf[:, i_r, :] * ( - self._precomp["gl_weights"][i_r] * self._precomp["gl_nodes"][i_r] - ) + pf = pf * self.gl_weighted_nodes[None, :, None] # 1D FFT on the angular dimension for each concentric circle - pf = 2 * pi / (2 * n_theta) * xp.asnumpy(fft.fft(xp.asarray(pf))) + pf = 2 * xp.pi / (2 * n_theta) * fft.fft(pf) # This only makes it easier to slice the array later. - v = np.zeros((n_images, self.count), dtype=x.dtype) + v = xp.zeros((n_images, self.count), dtype=x.dtype) # go through each basis function and find the corresponding coefficient ind = 0 - idx = ind + np.arange(self.k_max[0]) + idx = ind + xp.arange(self.k_max[0]) - # include the normalization factor of angular part into radial part - radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1) - v[:, self._zero_angular_inds] = pf[:, :, 0].real @ radial_norm[idx].T - ind = ind + np.size(idx) + v[:, self._zero_angular_inds] = pf[:, :, 0].real @ self.radial_norm[idx].T + ind = ind + idx.size ind_pos = ind for ell in range(1, self.ell_max + 1): - idx = ind + np.arange(self.k_max[ell]) - idx_pos = ind_pos + np.arange(self.k_max[ell]) + idx = ind + xp.arange(self.k_max[ell]) + idx_pos = ind_pos + xp.arange(self.k_max[ell]) idx_neg = idx_pos + self.k_max[ell] - v_ell = pf[:, :, ell] @ radial_norm[idx].T + v_ell = pf[:, :, ell] @ self.radial_norm[idx].T if np.mod(ell, 2) == 0: - v_pos = np.real(v_ell) - v_neg = -np.imag(v_ell) + v_pos = v_ell.real + v_neg = -v_ell.imag else: - v_pos = np.imag(v_ell) - v_neg = np.real(v_ell) + v_pos = v_ell.imag + v_neg = v_ell.real v[:, idx_pos] = v_pos v[:, idx_neg] = v_neg - ind = ind + np.size(idx) + ind = ind + idx.size ind_pos = ind_pos + 2 * self.k_max[ell] - return v + return xp.asnumpy(v) def filter_to_basis_mat(self, f, **kwargs): """ diff --git a/src/aspire/basis/ffb_3d.py b/src/aspire/basis/ffb_3d.py index 6362a9a703..7f0821b99a 100644 --- a/src/aspire/basis/ffb_3d.py +++ b/src/aspire/basis/ffb_3d.py @@ -1,11 +1,11 @@ import logging import numpy as np -from numpy import pi from aspire.basis import FBBasis3D from aspire.basis.basis_utils import lgwt, norm_assoc_legendre, sph_bessel from aspire.nufft import anufft, nufft +from aspire.numeric import xp from aspire.utils.matlab_compat import m_flatten, m_reshape logger = logging.getLogger(__name__) @@ -60,26 +60,29 @@ def _precomp(self): r, wt_r = lgwt(n_r, 0.0, self.kcut, dtype=self.dtype) z, wt_z = lgwt(n_phi, -1, 1, dtype=self.dtype) - r = m_reshape(r, (n_r, 1)) - wt_r = m_reshape(wt_r, (n_r, 1)) - z = m_reshape(z, (n_phi, 1)) - wt_z = m_reshape(wt_z, (n_phi, 1)) - phi = np.arccos(z) + r = m_reshape(xp.asarray(r), (n_r, 1)) + rh = xp.asnumpy(r) + wt_r = m_reshape(xp.asarray(wt_r), (n_r, 1)) + z = m_reshape(xp.asarray(z), (n_phi, 1)) + wt_z = m_reshape(xp.asarray(wt_z), (n_phi, 1)) + phi = xp.arccos(z) wt_phi = wt_z - theta = 2 * pi * np.arange(n_theta, dtype=self.dtype).T / (2 * n_theta) + theta = 2 * xp.pi * xp.arange(n_theta, dtype=self.dtype).T / (2 * n_theta) theta = m_reshape(theta, (n_theta, 1)) # evaluate basis function in the radial dimension - radial_wtd = np.zeros( + radial_wtd = xp.zeros( shape=(n_r, np.max(self.k_max), self.ell_max + 1), dtype=self.dtype ) for ell in range(0, self.ell_max + 1): k_max_ell = self.k_max[ell] - rmat = r * self.r0[ell][0:k_max_ell].T / self.kcut - radial_ell = np.zeros_like(rmat) + rmat = rh * self.r0[ell][0:k_max_ell].T / self.kcut # host + radial_ell = xp.zeros_like(rmat) for ik in range(0, k_max_ell): - radial_ell[:, ik] = sph_bessel(ell, rmat[:, ik]) - nrm = np.abs(sph_bessel(ell + 1, self.r0[ell][0:k_max_ell].T) / 4) + radial_ell[:, ik] = xp.asarray(sph_bessel(ell, rmat[:, ik])) + nrm = xp.abs( + xp.asarray(sph_bessel(ell + 1, self.r0[ell][0:k_max_ell].T)) / 4 + ) radial_ell = radial_ell / nrm radial_ell_wtd = r**2 * wt_r * radial_ell radial_wtd[:, 0:k_max_ell, ell] = radial_ell_wtd @@ -94,14 +97,14 @@ def _precomp(self): - np.mod(self.ell_max, 2) * np.mod(m, 2) ) n_odd_ell = int(self.ell_max - m + 1 - n_even_ell) - phi_wtd_m_even = np.zeros((n_phi, n_even_ell), dtype=phi.dtype) - phi_wtd_m_odd = np.zeros((n_phi, n_odd_ell), dtype=phi.dtype) + phi_wtd_m_even = xp.zeros((n_phi, n_even_ell), dtype=phi.dtype) + phi_wtd_m_odd = xp.zeros((n_phi, n_odd_ell), dtype=phi.dtype) ind_even = 0 ind_odd = 0 for ell in range(m, self.ell_max + 1): - phi_m_ell = norm_assoc_legendre(ell, m, z) - nrm_inv = np.sqrt(0.5 / pi) + phi_m_ell = xp.asarray(norm_assoc_legendre(ell, m, z)) + nrm_inv = np.sqrt(0.5 / np.pi) phi_m_ell = nrm_inv * phi_m_ell phi_wtd_m_ell = wt_phi * phi_m_ell if np.mod(ell, 2) == 0: @@ -115,32 +118,32 @@ def _precomp(self): ang_phi_wtd_odd.append(phi_wtd_m_odd) # evaluate basis function in the theta dimension - ang_theta = np.zeros((n_theta, 2 * self.ell_max + 1), dtype=theta.dtype) + ang_theta = xp.zeros((n_theta, 2 * self.ell_max + 1), dtype=theta.dtype) - ang_theta[:, 0 : self.ell_max] = np.sqrt(2) * np.sin( - theta @ m_reshape(np.arange(self.ell_max, 0, -1), (1, self.ell_max)) + ang_theta[:, 0 : self.ell_max] = np.sqrt(2) * xp.sin( + theta @ m_reshape(xp.arange(self.ell_max, 0, -1), (1, self.ell_max)) ) - ang_theta[:, self.ell_max] = np.ones(n_theta, dtype=theta.dtype) - ang_theta[:, self.ell_max + 1 : 2 * self.ell_max + 1] = np.sqrt(2) * np.cos( - theta @ m_reshape(np.arange(1, self.ell_max + 1), (1, self.ell_max)) + ang_theta[:, self.ell_max] = xp.ones(n_theta, dtype=theta.dtype) + ang_theta[:, self.ell_max + 1 : 2 * self.ell_max + 1] = np.sqrt(2) * xp.cos( + theta @ m_reshape(xp.arange(1, self.ell_max + 1), (1, self.ell_max)) ) - ang_theta_wtd = (2 * pi / n_theta) * ang_theta + ang_theta_wtd = (2 * np.pi / n_theta) * ang_theta - theta_grid, phi_grid, r_grid = np.meshgrid( - theta, phi, r, sparse=False, indexing="ij" + theta_grid, phi_grid, r_grid = xp.meshgrid( + theta.flatten(), phi.flatten(), r.flatten(), sparse=False, indexing="ij" ) - fourier_x = m_flatten(r_grid * np.cos(theta_grid) * np.sin(phi_grid)) - fourier_y = m_flatten(r_grid * np.sin(theta_grid) * np.sin(phi_grid)) - fourier_z = m_flatten(r_grid * np.cos(phi_grid)) + fourier_x = m_flatten(r_grid * xp.cos(theta_grid) * xp.sin(phi_grid)) + fourier_y = m_flatten(r_grid * xp.sin(theta_grid) * xp.sin(phi_grid)) + fourier_z = m_flatten(r_grid * xp.cos(phi_grid)) fourier_pts = ( 2 - * pi - * np.vstack( + * xp.pi + * xp.vstack( ( - fourier_z[np.newaxis, ...], - fourier_y[np.newaxis, ...], - fourier_x[np.newaxis, ...], + fourier_z[None, ...], + fourier_y[None, ...], + fourier_x[None, ...], ) ) ) @@ -163,6 +166,7 @@ def _evaluate(self, v): coordinate basis. This is an array whose last three dimensions equal `self.sz` and the remaining dimensions correspond to `v`. """ + v = xp.asarray(v) # roll dimensions of v sz_roll = v.shape[:-1] v = v.reshape((-1, self.count)) @@ -175,7 +179,7 @@ def _evaluate(self, v): # number of 3D image samples n_data = v.shape[0] - u_even = np.zeros( + u_even = xp.zeros( ( n_r, int(2 * self.ell_max + 1), @@ -184,7 +188,7 @@ def _evaluate(self, v): ), dtype=v.dtype, ) - u_odd = np.zeros( + u_odd = xp.zeros( (n_r, int(2 * self.ell_max + 1), n_data, int(np.ceil(self.ell_max / 2))), dtype=v.dtype, ) @@ -216,10 +220,10 @@ def _evaluate(self, v): int((ell - 1) / 2), ] = v_ell - u_even = np.transpose(u_even, (3, 0, 1, 2)) - u_odd = np.transpose(u_odd, (3, 0, 1, 2)) - w_even = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) - w_odd = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) + u_even = u_even.transpose((3, 0, 1, 2)) + u_odd = u_odd.transpose((3, 0, 1, 2)) + w_even = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) + w_odd = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) # evaluate the phi parts for m in range(0, self.ell_max + 1): @@ -252,8 +256,8 @@ def _evaluate(self, v): w_even[:, :, :, self.ell_max + sgn * m] = w_m_even w_odd[:, :, :, self.ell_max + sgn * m] = w_m_odd - w_even = np.transpose(w_even, (3, 0, 1, 2)) - w_odd = np.transpose(w_odd, (3, 0, 1, 2)) + w_even = w_even.transpose((3, 0, 1, 2)) + w_odd = w_odd.transpose((3, 0, 1, 2)) u_even = w_even u_odd = w_odd @@ -266,7 +270,7 @@ def _evaluate(self, v): pf = w_even + 1j * w_odd pf = m_reshape(pf, (n_theta * n_phi * n_r, n_data)) - pf = np.moveaxis(pf, 0, -1) + pf = xp.moveaxis(pf, 0, -1) # perform inverse non-uniformly FFT transformation back to 3D rectangular coordinates freqs = m_reshape(self._precomp["fourier_pts"], (3, n_r * n_theta * n_phi)) @@ -275,7 +279,7 @@ def _evaluate(self, v): # Roll, return the x with the last three dimensions as self.sz # Higher dimensions should be like v. x = x.reshape((*sz_roll, *self.sz)) - return x + return xp.asnumpy(x) def _evaluate_t(self, x): """ @@ -288,6 +292,7 @@ def _evaluate_t(self, x): `self.count` and whose remaining dimensions correspond to higher dimensions of `x`. """ + x = xp.asarray(x) # roll dimensions sz_roll = x.shape[:-3] x = x.reshape((-1, *self.sz)) @@ -303,20 +308,21 @@ def _evaluate_t(self, x): pf = m_reshape(pf.T, (n_theta, n_phi * n_r * n_data)) # evaluate the theta parts - u_even = self._precomp["ang_theta_wtd"].T @ np.real(pf) - u_odd = self._precomp["ang_theta_wtd"].T @ np.imag(pf) + ang_theta_wtd_trans = self._precomp["ang_theta_wtd"].T + u_even = ang_theta_wtd_trans @ pf.real + u_odd = ang_theta_wtd_trans @ pf.imag u_even = m_reshape(u_even, (2 * self.ell_max + 1, n_phi, n_r, n_data)) u_odd = m_reshape(u_odd, (2 * self.ell_max + 1, n_phi, n_r, n_data)) - u_even = np.transpose(u_even, (1, 2, 3, 0)) - u_odd = np.transpose(u_odd, (1, 2, 3, 0)) + u_even = u_even.transpose((1, 2, 3, 0)) + u_odd = u_odd.transpose((1, 2, 3, 0)) - w_even = np.zeros( + w_even = xp.zeros( (int(np.floor(self.ell_max / 2) + 1), n_r, 2 * self.ell_max + 1, n_data), dtype=x.dtype, ) - w_odd = np.zeros( + w_odd = xp.zeros( (int(np.ceil(self.ell_max / 2)), n_r, 2 * self.ell_max + 1, n_data), dtype=x.dtype, ) @@ -351,11 +357,11 @@ def _evaluate_t(self, x): end = np.size(w_odd, 0) w_odd[end - n_odd_ell : end, :, self.ell_max + sgn * m, :] = w_m_odd - w_even = np.transpose(w_even, (1, 2, 3, 0)) - w_odd = np.transpose(w_odd, (1, 2, 3, 0)) + w_even = w_even.transpose((1, 2, 3, 0)) + w_odd = w_odd.transpose((1, 2, 3, 0)) # evaluate the radial parts - v = np.zeros((n_data, self.count), dtype=x.dtype) + v = xp.zeros((n_data, self.count), dtype=x.dtype) for ell in range(0, self.ell_max + 1): k_max_ell = self.k_max[ell] radial_wtd = self._precomp["radial_wtd"][:, 0:k_max_ell, ell] @@ -388,4 +394,4 @@ def _evaluate_t(self, x): # Roll dimensions, last dimension should be self.count, # Higher dimensions like x. v = v.reshape((*sz_roll, self.count)) - return v + return xp.asnumpy(v) diff --git a/src/aspire/basis/fle_2d.py b/src/aspire/basis/fle_2d.py index 423d37c093..76330e6fba 100644 --- a/src/aspire/basis/fle_2d.py +++ b/src/aspire/basis/fle_2d.py @@ -1,8 +1,6 @@ import logging import numpy as np -import scipy.sparse as sparse -from scipy.fft import dct, idct from scipy.special import jv from aspire.basis import Coef, FBBasisMixin, SteerableBasis2D @@ -13,13 +11,28 @@ transform_complex_to_real, ) from aspire.nufft import anufft, nufft -from aspire.numeric import fft +from aspire.numeric import fft, sparse, xp from aspire.operators import DiagMatrix from aspire.utils import complex_type, grid_2d logger = logging.getLogger(__name__) +def _cleanup(): + """ + Utility for informing cupy to cleanup memory held by old vars. + + This method is designed to be safely called even when `CuPy` is + not installed, in which case it is a no-op. + """ + try: + import cupy + + cupy.get_default_memory_pool().free_all_blocks() + except ModuleNotFoundError: + pass + + class FLEBasis2D(SteerableBasis2D, FBBasisMixin): """ Define a derived class for Fast Fourier Bessel 2D expansion using interpolation @@ -278,10 +291,10 @@ def _compute_nufft_points(self): self.num_angular_nodes = num_angular_nodes # create gridpoints - nodes = 1 - (2 * np.arange(self.num_radial_nodes, dtype=self.dtype) + 1) / ( + nodes = 1 - (2 * xp.arange(self.num_radial_nodes, dtype=self.dtype) + 1) / ( 2 * self.num_radial_nodes ) - nodes = (np.cos(np.pi * nodes) + 1) / 2 + nodes = (xp.cos(np.pi * nodes) + 1) / 2 nodes = ( self.greatest_lambda - self.smallest_lambda ) * nodes + self.smallest_lambda @@ -292,16 +305,17 @@ def _compute_nufft_points(self): phi = ( 2 - * np.pi - * np.arange(self.num_angular_nodes // 2, dtype=self.dtype) + * xp.pi + * xp.arange(self.num_angular_nodes // 2, dtype=self.dtype) / self.num_angular_nodes ) - x = np.cos(phi).reshape(1, self.num_angular_nodes // 2) - y = np.sin(phi).reshape(1, self.num_angular_nodes // 2) - x = x * nodes * h - y = y * nodes * h - self.grid_x = x.flatten() - self.grid_y = y.flatten() + grid_xy = xp.empty( + (2, self.num_radial_nodes, self.num_angular_nodes // 2), dtype=self.dtype + ) + grid_xy[0] = xp.cos(phi) # x + grid_xy[1] = xp.sin(phi) # y + grid_xy = grid_xy * nodes * h + self.grid_xy = grid_xy.reshape(2, -1) def _build_interpolation_matrix(self): """ @@ -469,7 +483,7 @@ def _create_basis_functions(self): norm_constants[i] = c - self.norm_constants = norm_constants + self.norm_constants = xp.asarray(norm_constants) self.basis_functions = basis_functions def _evaluate(self, coefs): @@ -498,34 +512,39 @@ def _evaluate_t(self, imgs): coefficients. """ # See Section 3.5 - imgs = imgs.copy() + imgs = xp.array(imgs) # Intentionally copying here, mutating. imgs[:, self.radial_mask] = 0 z = self._step1_t(imgs) + del imgs # inform python we're done with imgs + _cleanup() + b = self._step2_t(z) + del z # inform python we're done with z + _cleanup() + coefs = self._step3_t(b) + del b # inform python we're done with b + _cleanup() # return in FB order coefs = coefs[..., self._fle_to_fb_indices] - return coefs.astype(self.coefficient_dtype, copy=False) + return xp.asnumpy(coefs.astype(self.coefficient_dtype)) def _step1_t(self, im): """ Step 1 of the adjoint transformation (images to coefficients). - Calculates the NUFFT of the image on gridpoints `self.grid_x` and `self.grid_y`. + Calculates the NUFFT of the image on gridpoints `grid_xy`. """ im = im.reshape(-1, self.nres, self.nres).astype(complex_type(self.dtype)) num_img = im.shape[0] - z = np.zeros( + z = xp.zeros( (num_img, self.num_radial_nodes, self.num_angular_nodes), dtype=complex_type(self.dtype), ) - _z = ( - nufft(im, np.stack((self.grid_x, self.grid_y)), epsilon=self.epsilon) - * self.h**2 - ) + _z = nufft(im, self.grid_xy, epsilon=self.epsilon) * self.h**2 _z = _z.reshape(num_img, self.num_radial_nodes, self.num_angular_nodes // 2) z[:, :, : self.num_angular_nodes // 2] = _z - z[:, :, self.num_angular_nodes // 2 :] = np.conj(_z) + z[:, :, self.num_angular_nodes // 2 :] = _z.conj() return z def _step2_t(self, z): @@ -538,12 +557,12 @@ def _step2_t(self, z): # Compute FFT along angular nodes betas = fft.fft(z, axis=2) / self.num_angular_nodes betas = betas[:, :, self.nus] - betas = np.conj(betas) - betas = np.swapaxes(betas, 0, 2) + betas = betas.conj() + betas = betas.swapaxes(0, 2) betas = betas.reshape(-1, self.num_radial_nodes * num_img) betas = self.c2r_nus @ betas betas = betas.reshape(-1, self.num_radial_nodes, num_img) - betas = np.real(np.swapaxes(betas, 0, 2)) + betas = betas.swapaxes(0, 2).real return betas def _step3_t(self, betas): @@ -554,13 +573,12 @@ def _step3_t(self, betas): """ num_img = betas.shape[0] if self.num_interp > self.num_radial_nodes: - betas = dct(betas, axis=1, type=2) / (2 * self.num_radial_nodes) - zeros = np.zeros(betas.shape) - betas = np.concatenate((betas, zeros), axis=1) - betas = idct(betas, axis=1, type=2) * 2 * betas.shape[1] - betas = np.moveaxis(betas, 0, -1) + betas = fft.dct(betas, axis=1, type=2) / (2 * self.num_radial_nodes) + betas = xp.concatenate((betas, xp.zeros(betas.shape)), axis=1) + betas = fft.idct(betas, axis=1, type=2) * 2 * betas.shape[1] + betas = xp.moveaxis(betas, 0, -1) - coefs = np.zeros((self.count, num_img), dtype=np.float64) + coefs = xp.zeros((self.count, num_img), dtype=np.float64) for i in range(self.ell_p_max + 1): coefs[self.idx_list[i]] = self.A3[i] @ betas[:, i, :] coefs = coefs.T @@ -574,22 +592,22 @@ def _step3(self, coefs): Uses barycenteric interpolation in reverse to compute values of Betas at Chebyshev nodes, given an array of FLE coefficients. """ - coefs = coefs.copy().reshape(-1, self.count) + coefs = xp.asarray(coefs.reshape(-1, self.count)) num_img = coefs.shape[0] coefs *= self.h * self.norm_constants coefs = coefs.T - out = np.zeros( + out = xp.zeros( (self.num_interp, 2 * self.max_ell + 1, num_img), dtype=np.float64, ) for i in range(self.ell_p_max + 1): out[:, i, :] = self.A3_T[i] @ coefs[self.idx_list[i]] - out = np.moveaxis(out, -1, 0) + out = xp.moveaxis(out, -1, 0) if self.num_interp > self.num_radial_nodes: - out = dct(out, axis=1, type=2) + out = fft.dct(out, axis=1, type=2) out = out[:, : self.num_radial_nodes, :] - out = idct(out, axis=1, type=2) + out = fft.idct(out, axis=1, type=2) return out @@ -600,18 +618,18 @@ def _step2(self, betas): Uses the IFFT to convert Beta values into Fourier-space images. """ num_img = betas.shape[0] - tmp = np.zeros( + tmp = xp.zeros( (num_img, self.num_radial_nodes, self.num_angular_nodes), dtype=np.complex128, ) - betas = np.swapaxes(betas, 0, 2) + betas = betas.swapaxes(0, 2) betas = betas.reshape(-1, self.num_radial_nodes * num_img) betas = self.r2c_nus @ betas betas = betas.reshape(-1, self.num_radial_nodes, num_img) - betas = np.swapaxes(betas, 0, 2) + betas = betas.swapaxes(0, 2) - tmp[:, :, self.nus] = np.conj(betas) + tmp[:, :, self.nus] = betas.conj() z = fft.ifft(tmp, axis=2) return z @@ -625,17 +643,17 @@ def _step1(self, z): num_img = z.shape[0] z = z[:, :, : self.num_angular_nodes // 2].reshape(num_img, -1) im = anufft( - z.astype(complex_type(self.dtype)), - np.stack((self.grid_x, self.grid_y)), + z.astype(complex_type(self.dtype), copy=False), + self.grid_xy, (self.nres, self.nres), epsilon=self.epsilon, ) - im = im + np.conj(im) - im = np.real(im) + im = im + im.conj() + im = im.real im = im.reshape(num_img, self.nres, self.nres) im[:, self.radial_mask] = 0 - return im + return xp.asnumpy(im) def _create_dense_matrix(self): """ @@ -702,10 +720,12 @@ def radial_convolve(self, coefs, radial_img): "`radial_convolve` currently only implemented for 1D stacks." ) - coefs = coefs.asnumpy() + # Potentially migrate to GPU + coefs = xp.asarray(coefs.asnumpy()) + radial_img = xp.asarray(radial_img) num_img = coefs.shape[0] - coefs_conv = np.zeros(coefs.shape) + coefs_conv = xp.zeros(coefs.shape) # Convert to internal FLE indices ordering coefs = coefs[..., self._fb_to_fle_indices] @@ -717,25 +737,26 @@ def radial_convolve(self, coefs, radial_img): weights = self._radial_convolve_weights(b) b = weights / (self.h**2) b = b.reshape(self.count) - coefs_conv[k, :] = np.real(self.c2r @ (b * (self.r2c @ _coefs).flatten())) + coefs_conv[k, :] = (self.c2r @ (b * (self.r2c @ _coefs).flatten())).real # Convert from internal FLE ordering to FB convention coefs_conv = coefs_conv[..., self._fle_to_fb_indices] - return Coef(self, coefs_conv) + # Return as Coef on host + return Coef(self, xp.asnumpy(coefs_conv)) def _radial_convolve_weights(self, b): """ Helper function for step 3 of convolving with a radial function. """ - b = np.squeeze(b) - b = np.array(b) + b = xp.squeeze(b) + b = xp.array(b) # implies copy if self.num_interp > self.num_radial_nodes: - b = dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) - bz = np.zeros(b.shape) - b = np.concatenate((b, bz), axis=0) - b = idct(b, axis=0, type=2) * 2 * b.shape[0] - a = np.zeros(self.count, dtype=np.float64) + b = fft.dct(b, axis=0, type=2) / (2 * self.num_radial_nodes) + bz = xp.zeros(b.shape) + b = xp.concatenate((b, bz), axis=0) + b = fft.idct(b, axis=0, type=2) * 2 * b.shape[0] + a = xp.zeros(self.count, dtype=np.float64) y = [None] * (self.ell_p_max + 1) for i in range(self.ell_p_max + 1): y[i] = (self.A3[i] @ b[:, 0]).flatten() @@ -764,20 +785,26 @@ def filter_to_basis_mat(self, f, **kwargs): # get 2D grid in polar coordinate k_vals, wts = lgwt(n_k, 0, 0.5, dtype=self.dtype) - k, theta = np.meshgrid( - k_vals, np.arange(n_theta) * 2 * np.pi / (2 * n_theta), indexing="ij" + k, theta = xp.meshgrid( + xp.asarray(k_vals), + xp.arange(n_theta) * 2 * np.pi / (2 * n_theta), + indexing="ij", ) # Get function values in polar 2D grid and average out angle contribution # NOTE: should probably just let the ctf objects handle this... - omegax = k * np.cos(theta) - omegay = k * np.sin(theta) - omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C"))) - - h_vals2d = h_fun(omega).reshape(n_k, n_theta).astype(self.dtype) - h_vals = np.sum(h_vals2d, axis=1) / n_theta + omegax = k * xp.cos(theta) + omegay = k * xp.sin(theta) + omega = 2 * xp.pi * xp.vstack((omegax.flatten("C"), omegay.flatten("C"))) + + h_vals2d = ( + xp.asarray(h_fun(omega)) + .reshape(n_k, n_theta) + .astype(self.dtype, copy=False) + ) + h_vals = xp.sum(h_vals2d, axis=1) / n_theta - h_basis = np.zeros(self.count, dtype=self.dtype) + h_basis = xp.zeros(self.count, dtype=self.dtype) # For now we just need to handle 1D (stack of one ctf) for j in range(self.ell_p_max + 1): h_basis[self.idx_list[j]] = self.A3[j] @ h_vals @@ -785,4 +812,4 @@ def filter_to_basis_mat(self, f, **kwargs): # Convert from internal FLE ordering to FB convention h_basis = h_basis[self._fle_to_fb_indices] - return DiagMatrix(h_basis) + return DiagMatrix(xp.asnumpy(h_basis)) diff --git a/src/aspire/basis/fle_2d_utils.py b/src/aspire/basis/fle_2d_utils.py index cde0cd11bf..ea459988b0 100644 --- a/src/aspire/basis/fle_2d_utils.py +++ b/src/aspire/basis/fle_2d_utils.py @@ -1,5 +1,6 @@ import numpy as np -import scipy.sparse as sparse + +from aspire.numeric import sparse, xp def transform_complex_to_real(B, ells): @@ -85,7 +86,11 @@ def precomp_transform_complex_to_real(ells): jdx[k] = i + 1 k = k + 1 - A = sparse.csr_matrix((vals, (idx, jdx)), shape=(count, count), dtype=np.complex128) + A = sparse.csr_matrix( + (xp.asarray(vals), (xp.asarray(idx), xp.asarray(jdx))), + shape=(count, count), + dtype=np.complex128, + ) return A.conjugate() @@ -190,9 +195,9 @@ def barycentric_interp_sparse(target_points, known_points, numsparse): # note that const cancels in numerator and denominator vals = vals / denom.reshape(-1, 1) - vals = vals.flatten() - idx = idx.flatten() - jdx = jdx.flatten() + vals = xp.array(vals.flatten()) + idx = xp.array(idx.flatten()) + jdx = xp.array(jdx.flatten()) # A is the linear operator mapping the function values from the fixed source # points to the fixed target points. # A(i,j) = \ell(x[i] ) w_j/(x[i] - xs[j]), with the notation in Eq. 3.3 diff --git a/src/aspire/image/image.py b/src/aspire/image/image.py index 20d998afe6..f03372b087 100644 --- a/src/aspire/image/image.py +++ b/src/aspire/image/image.py @@ -392,12 +392,17 @@ def downsample(self, ds_res): original_stack_shape = self.stack_shape im = self.stack_reshape(-1) + # Note image data is intentionally migrated via `xp.asarray` + # because all of the subsequent calls until `asnumpy` are GPU + # when xp and fft in `cupy` mode. + # compute FT with centered 0-frequency - fx = fft.centered_fft2(im._data) + fx = fft.centered_fft2(xp.asarray(im._data)) # crop 2D Fourier transform for each image - crop_fx = np.array([crop_pad_2d(fx[i], ds_res) for i in range(self.n_images)]) + crop_fx = crop_pad_2d(fx, ds_res) # take back to real space, discard complex part, and scale - out = np.real(fft.centered_ifft2(crop_fx)) * (ds_res**2 / self.resolution**2) + out = fft.centered_ifft2(crop_fx).real * (ds_res**2 / self.resolution**2) + out = xp.asnumpy(out) return self.__class__(out).stack_reshape(original_stack_shape) @@ -412,17 +417,21 @@ def filter(self, filter): im = self.stack_reshape(-1) - filter_values = filter.evaluate_grid(self.resolution) + # Note image and filter data is intentionally migrated via + # `xp.asarray` because all of the subsequent calls until + # `asnumpy` are GPU when xp and fft in `cupy` mode. + # + # Second note, filter dtype may not match image dtype. + filter_values = xp.asarray( + filter.evaluate_grid(self.resolution), dtype=self.dtype + ) - im_f = xp.asnumpy(fft.centered_fft2(xp.asarray(im._data))) + # Convolve + im_f = fft.centered_fft2(xp.asarray(im._data)) + im_f = filter_values * im_f + im = fft.centered_ifft2(im_f) - # TODO: why are these different? Doesn't the broadcast work? - if im_f.ndim > filter_values.ndim: - im_f *= filter_values - else: - im_f = filter_values * im_f - im = xp.asnumpy(fft.centered_ifft2(xp.asarray(im_f))) - im = np.real(im) + im = xp.asnumpy(im.real) return self.__class__(im).stack_reshape(original_stack_shape) @@ -493,15 +502,15 @@ def _im_translate(self, shifts): n_shifts == 1 or n_shifts == self.n_images ), "number of shifts must be 1 or match the number of images" # Cast shifts to this instance's internal dtype - shifts = shifts.astype(self.dtype) + shifts = xp.asarray(shifts, dtype=self.dtype) L = self.resolution - im_f = xp.asnumpy(fft.fft2(xp.asarray(im))) + im_f = fft.fft2(xp.asarray(im)) grid_shifted = fft.ifftshift( - xp.asarray(np.ceil(np.arange(-L / 2, L / 2, dtype=self.dtype))) + xp.ceil(xp.arange(-L / 2, L / 2, dtype=self.dtype)) ) - grid_1d = xp.asnumpy(grid_shifted) * 2 * np.pi / L - om_x, om_y = np.meshgrid(grid_1d, grid_1d, indexing="ij") + grid_1d = grid_shifted * 2 * xp.pi / L + om_x, om_y = xp.meshgrid(grid_1d, grid_1d, indexing="ij") phase_shifts_x = -shifts[:, 0].reshape((n_shifts, 1, 1)) phase_shifts_y = -shifts[:, 1].reshape((n_shifts, 1, 1)) @@ -510,10 +519,10 @@ def _im_translate(self, shifts): om_x[np.newaxis, :, :] * phase_shifts_x + om_y[np.newaxis, :, :] * phase_shifts_y ) - mult_f = np.exp(-1j * phase_shifts) + mult_f = xp.exp(-1j * phase_shifts) im_translated_f = im_f * mult_f - im_translated = xp.asnumpy(fft.ifft2(xp.asarray(im_translated_f))) - im_translated = np.real(im_translated) + im_translated = fft.ifft2(im_translated_f) + im_translated = xp.asnumpy(im_translated.real) # Reshape to stack shape return self.__class__(im_translated).stack_reshape(stack_shape) diff --git a/src/aspire/nufft/__init__.py b/src/aspire/nufft/__init__.py index aa7c3a4adf..07d92c736c 100644 --- a/src/aspire/nufft/__init__.py +++ b/src/aspire/nufft/__init__.py @@ -5,6 +5,13 @@ from aspire import config from aspire.utils import LogFilterByCount, complex_type, real_type +cp = None +try: + import cupy as cp +except ModuleNotFoundError: + pass + + logger = logging.getLogger(__name__) # Cached Plan Class objects, indexed by backend string identifier, and ordered by preference (highest first) @@ -152,6 +159,9 @@ def anufft(sig_f, fourier_pts, sz, real=False, epsilon=1e-8): Selects best available package from `nfft` `backends` configuration list. + When sig_f is provided as a CuPy GPU array with a cufinufft + backend, result is maintained on GPU. + :param sig_f: Array representing the signal(s) in Fourier space to be transformed. \ sig_f either matches length of fourier_pts or sig_f.shape is stack of (`ntransforms`, ...). :param fourier_pts: The points in Fourier space where the Fourier transform is to be calculated, @@ -162,6 +172,10 @@ def anufft(sig_f, fourier_pts, sz, real=False, epsilon=1e-8): """ + _keep_on_gpu = False + if cp and isinstance(sig_f, cp.ndarray): + _keep_on_gpu = True + if fourier_pts.dtype != real_type(sig_f.dtype): raise RuntimeError( "anufft passed inconsistent dtypes." @@ -181,7 +195,13 @@ def anufft(sig_f, fourier_pts, sz, real=False, epsilon=1e-8): sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms, epsilon=epsilon ) adjoint = plan.adjoint(sig_f) - return np.real(adjoint) if real else adjoint + + adjoint = adjoint.real if real else adjoint + + if cp and not _keep_on_gpu: + adjoint = adjoint.get() + + return adjoint def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): @@ -191,6 +211,9 @@ def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): Selects best available package from `nfft` `backends` configuration list. + When sig_f is provided as a CuPy GPU array with a cufinufft + backend, result is maintained on GPU. + :param sig_f: Array representing the signal(s) in real space to be transformed. \ sig_f either matches `sz` or sig_f.shape is stack of (..., `ntransforms`). :param fourier_pts: The points in Fourier space where the Fourier transform is to be calculated, @@ -200,6 +223,10 @@ def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): """ + _keep_on_gpu = False + if cp and isinstance(sig_f, cp.ndarray): + _keep_on_gpu = True + if fourier_pts.dtype != real_type(sig_f.dtype): raise RuntimeError( "nufft passed inconsistent dtypes." @@ -229,4 +256,10 @@ def nufft(sig_f, fourier_pts, real=False, epsilon=1e-8): sz=sz, fourier_pts=fourier_pts, ntransforms=ntransforms, epsilon=epsilon ) transform = plan.transform(sig_f) - return np.real(transform) if real else transform + + transform = transform.real if real else transform + + if cp and not _keep_on_gpu: + transform = transform.get() + + return transform diff --git a/src/aspire/nufft/cufinufft.py b/src/aspire/nufft/cufinufft.py index 465c0b23f9..218fbd5fb7 100644 --- a/src/aspire/nufft/cufinufft.py +++ b/src/aspire/nufft/cufinufft.py @@ -1,9 +1,7 @@ import logging +import cupy as cp import numpy as np -import pycuda.autoinit # noqa: F401 -import pycuda.driver as cuda # noqa: F401 -import pycuda.gpuarray as gpuarray # noqa: F401 from cufinufft import Plan as cufPlan from aspire.nufft import Plan @@ -53,11 +51,11 @@ def __init__(self, sz, fourier_pts, epsilon=1e-8, ntransforms=1, **kwargs): "cufinufft has caught a non C_CONTIGUOUS array," " `fourier_pts` will be copied to C_CONTIGUOUS." ) - self.fourier_pts = np.ascontiguousarray( - np.mod(fourier_pts + np.pi, 2 * np.pi) - np.pi, dtype=self.dtype + self.fourier_pts = cp.ascontiguousarray( + cp.mod(cp.asarray(fourier_pts, dtype=self.dtype) + cp.pi, 2 * cp.pi) - cp.pi ) - self.num_pts = fourier_pts.shape[1] + self.num_pts = self.fourier_pts.shape[1] self.epsilon = max(epsilon, np.finfo(self.dtype).eps) self._transform_plan = cufPlan( @@ -83,12 +81,8 @@ def __init__(self, sz, fourier_pts, epsilon=1e-8, ntransforms=1, **kwargs): **self.adjoint_opts, ) - # Note, I store self.fourier_pts_gpu so the GPUArrray life - # is tied to instance, instead of this method. - self.fourier_pts_gpu = gpuarray.to_gpu(self.fourier_pts) - - self._transform_plan.setpts(*self.fourier_pts_gpu) - self._adjoint_plan.setpts(*self.fourier_pts_gpu) + self._transform_plan.setpts(*self.fourier_pts) + self._adjoint_plan.setpts(*self.fourier_pts) def transform(self, signal): """ @@ -99,7 +93,7 @@ def transform(self, signal): For a batch, signal should have shape `(*sz, ntransforms)`. :returns: Transformed signal of shape `num_pts` or - `(ntransforms, num_pts)`. + `(ntransforms, num_pts)` as CuPy array. """ # Check we're not forcing a dtype workaround for ASPIRE-Python/703, @@ -113,6 +107,8 @@ def transform(self, signal): " In the future this will be an error." ) + signal = cp.asarray(signal, dtype=self.complex_dtype) + sig_shape = signal.shape res_shape = self.num_pts # Note, there is a corner case for ntransforms == 1. @@ -134,17 +130,16 @@ def transform(self, signal): sig_shape == self.sz ), f"Signal frame to be transformed must have shape {self.sz}" - signal_gpu = gpuarray.to_gpu( - np.ascontiguousarray(signal, dtype=self.complex_dtype) - ) + result = cp.empty(res_shape, dtype=self.complex_dtype) - result_gpu = gpuarray.GPUArray(res_shape, dtype=self.complex_dtype) + if signal.dtype != self.complex_dtype: + signal = signal.astype(self.complex_dtype) - self._transform_plan.execute(signal_gpu, out=result_gpu) + self._transform_plan.execute(signal, out=result) - result = result_gpu.get() # ASPIRE-Python/703 - result = result.astype(complex_type(self._original_dtype), copy=False) + if result.dtype != complex_type(self._original_dtype): + result = result.astype(complex_type(self._original_dtype)) return result @@ -156,7 +151,7 @@ def adjoint(self, signal): this should be a a 1D array of len `num_pts`. For a batch, signal should have shape `(ntransforms, num_pts)`. - :returns: Transformed signal `(sz)` or `(sz, ntransforms)`. + :returns: Transformed signal `(sz)` or `(sz, ntransforms)` as CuPy array. """ # Check we're not forcing a dtype workaround for ASPIRE-Python/703, @@ -170,6 +165,9 @@ def adjoint(self, signal): " In the future this will be an error." ) + # Note, if not C order, cuFINUFFT will copy-cast anyway. + signal = cp.asarray(signal, order="C", dtype=self.complex_dtype) + res_shape = self.sz # Note, there is a corner case for ntransforms == 1. if self.ntransforms > 1 or (self.ntransforms == 1 and len(signal.shape) == 2): @@ -181,16 +179,15 @@ def adjoint(self, signal): ), "For multiple transforms, signal stack length should match ntransforms {self.ntransforms}." res_shape = (self.ntransforms, *self.sz) - signal_gpu = gpuarray.to_gpu( - np.ascontiguousarray(signal, dtype=self.complex_dtype) - ) + result = cp.empty(res_shape, dtype=self.complex_dtype) - result_gpu = gpuarray.GPUArray(res_shape, dtype=self.complex_dtype) + if signal.dtype != self.complex_dtype: + signal = signal.astype(self.complex_dtype) - self._adjoint_plan.execute(signal_gpu, out=result_gpu) + self._adjoint_plan.execute(signal, out=result) - result = result_gpu.get() # ASPIRE-Python/703 - result = result.astype(complex_type(self._original_dtype), copy=False) + if result.dtype != complex_type(self._original_dtype): + result = result.astype(complex_type(self._original_dtype)) return result diff --git a/src/aspire/numeric/__init__.py b/src/aspire/numeric/__init__.py index d298f131e4..be88775498 100644 --- a/src/aspire/numeric/__init__.py +++ b/src/aspire/numeric/__init__.py @@ -35,3 +35,37 @@ def fft_object(which): fft = fft_object(config["common"]["fft"].as_str()) + +# Sanity check. +if (config["common"]["numeric"].as_str() == "cupy") and ( + config["common"]["fft"].as_str() != "cupy" +): + raise RuntimeError( + "Using `cupy` numeric backend without `cupy` fft is unsupported." + ) + +if (config["common"]["fft"].as_str() == "cupy") and ( + config["common"]["numeric"].as_str() != "cupy" +): + raise RuntimeError( + "Using `cupy` fft without `cupy` numeric backend is unsupported." + ) + + +# Configure `sparse` in tandem with `numeric` as the arrays generally will need to interoperate. +def sparse_object(which): + if which == "cupy": + from cupyx.scipy import sparse as SparseClass + + # CuPy imports don't work the same as scipy + from cupyx.scipy.sparse.linalg import eigsh + + SparseClass.linalg.eigsh = eigsh + elif which == "numpy": + from scipy import sparse as SparseClass + else: + raise RuntimeError(f"Invalid selection for sparse module: {which}") + return SparseClass + + +sparse = sparse_object(config["common"]["numeric"].as_str()) diff --git a/src/aspire/numeric/cupy_fft.py b/src/aspire/numeric/cupy_fft.py index 4f45f92117..6ad6a4e9da 100644 --- a/src/aspire/numeric/cupy_fft.py +++ b/src/aspire/numeric/cupy_fft.py @@ -1,8 +1,59 @@ +import functools + import cupy as cp +import cupyx.scipy.fft as cufft +import numpy as np from aspire.numeric.base_fft import FFT +# This improves the flexibility of our FFT wrappers by allowing for +# incremental code changes and testing. +def _preserve_host(func): + """ + Method decorator that returns a numpy/cupy array result when + passed a numpy/cupy array input respectively. + + At the time of writing this wrapper will also upcast cupy FFT + operations to doubles as the precision in singles can cause + accuracy issues. + """ + + @functools.wraps(func) # Pass metadata (eg name and doctrings) from `func` + def wrapper(self, x, *args, **kwargs): + + # CuPy's single precision FFT appears to be too inaccurate for + # many of our unit tests, so the signal is upcast and recast + # on return. + _singles = False + if x.dtype == np.float32: + _singles = True + x = x.astype(np.float64) + elif x.dtype == np.complex64: + _singles = True + x = x.astype(np.complex128) + + _host = False + if not isinstance(x, cp.ndarray): + _host = True + x = cp.asarray(x) + + res = func(self, x, *args, **kwargs) + + if _host: + res = res.get() + + # Recast if needed. + if _singles and res.dtype == np.float64: + res = res.astype(np.float32) + elif _singles and res.dtype == np.complex128: + res = res.astype(np.complex64) + + return res + + return wrapper + + class CupyFFT(FFT): """ Define a unified wrapper class for Cupy FFT functions @@ -10,26 +61,42 @@ class CupyFFT(FFT): To be consistent with Scipy and Pyfftw, not all arguments are included. """ + @_preserve_host def fft(self, x, axis=-1, workers=-1): return cp.fft.fft(x, axis=axis) + @_preserve_host def ifft(self, x, axis=-1, workers=-1): return cp.fft.ifft(x, axis=axis) + @_preserve_host def fft2(self, x, axes=(-2, -1), workers=-1): return cp.fft.fft2(x, axes=axes) + @_preserve_host def ifft2(self, x, axes=(-2, -1), workers=-1): return cp.fft.ifft2(x, axes=axes) + @_preserve_host def fftn(self, x, axes=None, workers=-1): return cp.fft.fftn(x, axes=axes) + @_preserve_host def ifftn(self, x, axes=None, workers=-1): return cp.fft.ifftn(x, axes=axes) + @_preserve_host def fftshift(self, x, axes=None): return cp.fft.fftshift(x, axes=axes) + @_preserve_host def ifftshift(self, x, axes=None): return cp.fft.ifftshift(x, axes=axes) + + @_preserve_host + def dct(self, x, **kwargs): + return cufft.dct(x, **kwargs) + + @_preserve_host + def idct(self, x, **kwargs): + return cufft.idct(x, **kwargs) diff --git a/src/aspire/numeric/numpy.py b/src/aspire/numeric/numpy.py index 3237c2c3ad..ddc8355816 100644 --- a/src/aspire/numeric/numpy.py +++ b/src/aspire/numeric/numpy.py @@ -1,8 +1,22 @@ import numpy as np +cp = None +try: + import cupy as cp +except ModuleNotFoundError: + pass + class Numpy: - asnumpy = staticmethod(lambda x: x) + # This can be required when mixing nufft/fft/numpy backend combinations. + @staticmethod + def asnumpy(x): + """ + Ensure `asnumpy` is always available and returns a numpy array. + """ + if cp and isinstance(x, cp.ndarray): + x = x.get() + return x def __getattr__(self, item): """ diff --git a/src/aspire/numeric/pyfftw_fft.py b/src/aspire/numeric/pyfftw_fft.py index 9cfdd45210..95a8ea80f7 100644 --- a/src/aspire/numeric/pyfftw_fft.py +++ b/src/aspire/numeric/pyfftw_fft.py @@ -159,3 +159,9 @@ def fftshift(self, a, axes=None): def ifftshift(self, a, axes=None): return scipy_fft.ifftshift(a, axes=axes) + + def dct(self, x, **kwargs): + return scipy_fft.dct(x, **kwargs) + + def idct(self, x, **kwargs): + return scipy_fft.idct(x, **kwargs) diff --git a/src/aspire/numeric/scipy_fft.py b/src/aspire/numeric/scipy_fft.py index c5a392f96b..3891d45671 100644 --- a/src/aspire/numeric/scipy_fft.py +++ b/src/aspire/numeric/scipy_fft.py @@ -33,3 +33,9 @@ def fftshift(self, x, axes=None): def ifftshift(self, x, axes=None): return sp.fft.ifftshift(x, axes=axes) + + def dct(self, x, **kwargs): + return sp.fft.dct(x, **kwargs) + + def idct(self, x, **kwargs): + return sp.fft.idct(x, **kwargs) diff --git a/src/aspire/utils/coor_trans.py b/src/aspire/utils/coor_trans.py index e909e2f394..cad8fb0295 100644 --- a/src/aspire/utils/coor_trans.py +++ b/src/aspire/utils/coor_trans.py @@ -8,6 +8,7 @@ from numpy.linalg import norm from scipy.linalg import svd +from aspire.numeric import xp from aspire.utils.random import Random from aspire.utils.rotation import Rotation @@ -368,54 +369,105 @@ def rots_to_clmatrix(rots, n_theta): def crop_pad_2d(im, size, fill_value=0): """ - :param im: A 2-dimensional numpy array + Crop/pads `im` according to `size`. + + Padding will use `fill_value`. + Return's host/GPU array based on `im`. + + :param im: A >=2-dimensional numpy array :param size: Integer size of cropped/padded output - :return: A numpy array of shape (size, size) + :return: Array of shape (..., size, size) """ - im_y, im_x = im.shape + im_y, im_x = im.shape[-2:] # shift terms start_x = math.floor(im_x / 2) - math.floor(size / 2) start_y = math.floor(im_y / 2) - math.floor(size / 2) # cropping if size <= min(im_y, im_x): - return im[start_y : start_y + size, start_x : start_x + size] + return im[..., start_y : start_y + size, start_x : start_x + size] # padding elif size >= max(im_y, im_x): - # ensure that we return in the same dtype as the input - to_return = fill_value * np.ones((size, size), dtype=im.dtype) + # Determine shape + shape = list(im.shape[:-2]) + shape.extend([size, size]) + + # Ensure that we return the same dtype as the input + _full = np.full # Default to numpy array + if isinstance(im, xp.ndarray): + # Use cupy when `im` _and_ xp are cupy ndarray + # Avoids having to handle when cupy is not installed + _full = xp.full + + to_return = _full(shape, fill_value, dtype=im.dtype) + # when padding, start_x and start_y are negative since size is larger # than im_x and im_y; the below line calculates where the original image # is placed in relation to the (now-larger) box size - to_return[-start_y : im_y - start_y, -start_x : im_x - start_x] = im + to_return[..., -start_y : im_y - start_y, -start_x : im_x - start_x] = im return to_return else: # target size is between mat_x and mat_y - raise ValueError("Cannot crop and pad an image at the same time.") + raise ValueError( + "Cannot crop and pad Image at the same time." + "If this is really what you intended," + " make two seperate calls for cropping and padding." + ) + + +def crop_pad_3d(vol, size, fill_value=0): + """ + Crop/pads `vol` according to `size`. + Padding will use `fill_value`. + Return's host/GPU array based on `vol`. -def crop_pad_3d(im, size, fill_value=0): - im_y, im_x, im_z = im.shape + :param vol: A >=3-dimensional numpy array + :param size: Integer size of cropped/padded output + :return: Array of shape (..., size, size, size) + """ + + vol_z, vol_y, vol_x = vol.shape[-3:] # shift terms - start_x = math.floor(im_x / 2) - math.floor(size / 2) - start_y = math.floor(im_y / 2) - math.floor(size / 2) - start_z = math.floor(im_z / 2) - math.floor(size / 2) + start_z = math.floor(vol_z / 2) - math.floor(size / 2) + start_y = math.floor(vol_y / 2) - math.floor(size / 2) + start_x = math.floor(vol_x / 2) - math.floor(size / 2) # cropping - if size <= min(im_y, im_x, im_z): - return im[ - start_y : start_y + size, start_x : start_x + size, start_z : start_z + size + if size <= min(vol_z, vol_y, vol_x): + return vol[ + ..., + start_z : start_z + size, + start_y : start_y + size, + start_x : start_x + size, ] # padding - elif size >= max(im_y, im_x, im_z): - to_return = fill_value * np.ones((size, size, size), dtype=im.dtype) + elif size >= max(vol_z, vol_y, vol_x): + # Determine shape + shape = list(vol.shape[:-3]) + shape.extend([size, size, size]) + + # Ensure that we return the same dtype as the input + _full = np.full # Default to numpy array + if isinstance(vol, xp.ndarray): + # Use cupy when `vol` _and_ xp are cupy ndarray + # Avoids having to handle when cupy is not installed + _full = xp.full + + to_return = _full(shape, fill_value, dtype=vol.dtype) + to_return[ - -start_y : im_y - start_y, - -start_x : im_x - start_x, - -start_z : im_z - start_z, - ] = im + ..., + -start_z : vol_z - start_z, + -start_y : vol_y - start_y, + -start_x : vol_x - start_x, + ] = vol return to_return else: - # target size is between min and max of (im_y, im_x, im_z) - raise ValueError("Cannot crop and pad a volume at the same time.") + # target size is between min and max of (vol_x, vol_y, vol_z) + raise ValueError( + "Cannot crop and pad Volume at the same time." + "If this is really what you intended," + " make two seperate calls for cropping and padding." + ) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index b6c100db36..0f01ef5e61 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -342,13 +342,13 @@ def project(self, rot_matrices): if rot_matrices.ndim == 2: rot_matrices = np.expand_dims(rot_matrices, axis=0) - data = self._data + data = xp.asarray(self._data) n_rots = rot_matrices.shape[0] pts_rot = rotated_grids(self.resolution, rot_matrices) if n_rots == self.n_vols: # Apply rotations to Volumes element-wise. - im_f = np.empty( + im_f = xp.empty( (self.n_vols, self.resolution**2), dtype=complex_type(self.dtype) ) pts_rot = pts_rot.reshape((3, n_rots, self.resolution**2)) @@ -370,9 +370,9 @@ def project(self, rot_matrices): im_f[:, 0, :] = 0 im_f[:, :, 0] = 0 - im_f = xp.asnumpy(fft.centered_ifft2(xp.asarray(im_f))) + im_f = fft.centered_ifft2(im_f) - return aspire.image.Image(np.real(im_f)) + return aspire.image.Image(xp.asnumpy(im_f.real)) def to_vec(self): """Returns an N x resolution ** 3 array.""" @@ -468,26 +468,27 @@ def downsample(self, ds_res, mask=None): :param ds_res: Desired resolution. :param mask: Optional NumPy array mask to multiply in Fourier space. """ - if mask is None: - mask = 1.0 original_stack_shape = self.stack_shape v = self.stack_reshape(-1) # take 3D Fourier transform of each volume in the stack - fx = fft.fftshift(fft.fftn(v._data, axes=(1, 2, 3))) + fx = fft.fftshift(fft.fftn(xp.asarray(v._data), axes=(1, 2, 3))) + # crop each volume to the desired resolution in frequency space - crop_fx = ( - np.array([crop_pad_3d(fx[i, :, :, :], ds_res) for i in range(self.n_vols)]) - * mask - ) + fx = crop_pad_3d(fx, ds_res) + + # Optionally apply mask + if mask is not None: + fx = fx * xp.asarray(mask) + # inverse Fourier transform of each volume - out = fft.ifftn(fft.ifftshift(crop_fx), axes=(1, 2, 3)) * ( - ds_res**3 / self.resolution**3 - ) + out = fft.ifftn(fft.ifftshift(fx), axes=(1, 2, 3)).real + out = out.real * (ds_res**3 / self.resolution**3) + # returns a new Volume object return self.__class__( - np.real(out), symmetry_group=self.symmetry_group + xp.asnumpy(out), symmetry_group=self.symmetry_group ).stack_reshape(original_stack_shape) def shift(self): diff --git a/tests/test_numeric_sparse.py b/tests/test_numeric_sparse.py new file mode 100644 index 0000000000..e58aa02e6a --- /dev/null +++ b/tests/test_numeric_sparse.py @@ -0,0 +1,58 @@ +""" +Tests basic numpy/cupy functionality of sparse numeric wrappers. +""" + +import numpy as np +import pytest + +from aspire.numeric import numeric_object, sparse_object + +# If cupy is not available, skip this entire test module +pytest.importorskip("cupy") + +NUMERICS = ["numpy", "cupy"] + + +@pytest.fixture(params=NUMERICS, ids=lambda x: f"{x}", scope="module") +def backends(request): + xp = numeric_object(request.param) + sparse = sparse_object(request.param) + return xp, sparse + + +def test_csr_matrix(backends): + """ + Create csr_matrix and multiply with an `xp` array. + """ + xp, sparse = backends + + m, n = 10, 10 + jdx = xp.arange(m) + idx = xp.arange(n) + vals = xp.random.random(10) + + # Compute dense matmul + _A = np.diag(xp.asnumpy(vals)) + _B = np.random.random((n, 20)) + _C = _A @ _B + + # Compute matmul using sparse csr + A = sparse.csr_matrix((vals, (jdx, idx)), shape=(m, n), dtype=np.float64) + B = xp.array(_B) + C = A @ B + + # Compare + np.testing.assert_allclose(_C, xp.asnumpy(C)) + + +def test_eigsh(backends): + """ + Invoke sparse eigsh call with `xp` arrays. + """ + xp, sparse = backends + + n = 123 + A = xp.diag(xp.arange(1, n + 1, dtype=np.float64)) + + lamb, _ = sparse.linalg.eigsh(A, k=1) + np.testing.assert_allclose(xp.asnumpy(lamb), n) diff --git a/tests/test_orient_sdp.py b/tests/test_orient_sdp.py index a161d2fdd7..22658ee06a 100644 --- a/tests/test_orient_sdp.py +++ b/tests/test_orient_sdp.py @@ -77,7 +77,7 @@ def test_estimate_rotations(src_orient_est_fixture): src, orient_est = src_orient_est_fixture if backend_available("cufinufft") and src.dtype == np.float32: - pytest.skip("CI on gpu fails for singles.") + pytest.skip("CI on GPU fails for singles.") orient_est.estimate_rotations()