From d622e6cdc376e6995c542440d1f485daf8984be0 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Jun 2024 10:58:04 -0400 Subject: [PATCH 1/3] bare min vol hack --- src/aspire/volume/volume.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/aspire/volume/volume.py b/src/aspire/volume/volume.py index b6c100db36..7883f59190 100644 --- a/src/aspire/volume/volume.py +++ b/src/aspire/volume/volume.py @@ -475,15 +475,16 @@ def downsample(self, ds_res, mask=None): v = self.stack_reshape(-1) # take 3D Fourier transform of each volume in the stack - fx = fft.fftshift(fft.fftn(v._data, axes=(1, 2, 3))) + fx = xp.asnumpy(fft.fftshift(fft.fftn(xp.asarray(v._data), axes=(1, 2, 3)))) # crop each volume to the desired resolution in frequency space crop_fx = ( np.array([crop_pad_3d(fx[i, :, :, :], ds_res) for i in range(self.n_vols)]) * mask ) # inverse Fourier transform of each volume - out = fft.ifftn(fft.ifftshift(crop_fx), axes=(1, 2, 3)) * ( - ds_res**3 / self.resolution**3 + out = xp.asnumpy( + fft.ifftn(fft.ifftshift(xp.asarray(crop_fx)), axes=(1, 2, 3)) + * (ds_res**3 / self.resolution**3) ) # returns a new Volume object return self.__class__( From 2e5809e38a2c754209931165c39c56a0d5312a67 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Jun 2024 11:33:04 -0400 Subject: [PATCH 2/3] bare min ffb3d hacks [skip ci] --- src/aspire/basis/ffb_3d.py | 52 ++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/src/aspire/basis/ffb_3d.py b/src/aspire/basis/ffb_3d.py index 6362a9a703..1ac5fd62ff 100644 --- a/src/aspire/basis/ffb_3d.py +++ b/src/aspire/basis/ffb_3d.py @@ -6,6 +6,7 @@ from aspire.basis import FBBasis3D from aspire.basis.basis_utils import lgwt, norm_assoc_legendre, sph_bessel from aspire.nufft import anufft, nufft +from aspire.numeric import xp from aspire.utils.matlab_compat import m_flatten, m_reshape logger = logging.getLogger(__name__) @@ -146,10 +147,10 @@ def _precomp(self): ) return { - "radial_wtd": radial_wtd, - "ang_phi_wtd_even": ang_phi_wtd_even, - "ang_phi_wtd_odd": ang_phi_wtd_odd, - "ang_theta_wtd": ang_theta_wtd, + "radial_wtd": xp.asarray(radial_wtd), + "ang_phi_wtd_even": [xp.asarray(x) for x in ang_phi_wtd_even], + "ang_phi_wtd_odd": [xp.asarray(x) for x in ang_phi_wtd_odd], + "ang_theta_wtd": xp.asarray(ang_theta_wtd), "fourier_pts": fourier_pts, } @@ -163,6 +164,7 @@ def _evaluate(self, v): coordinate basis. This is an array whose last three dimensions equal `self.sz` and the remaining dimensions correspond to `v`. """ + v = xp.asarray(v) # roll dimensions of v sz_roll = v.shape[:-1] v = v.reshape((-1, self.count)) @@ -175,7 +177,7 @@ def _evaluate(self, v): # number of 3D image samples n_data = v.shape[0] - u_even = np.zeros( + u_even = xp.zeros( ( n_r, int(2 * self.ell_max + 1), @@ -184,7 +186,7 @@ def _evaluate(self, v): ), dtype=v.dtype, ) - u_odd = np.zeros( + u_odd = xp.zeros( (n_r, int(2 * self.ell_max + 1), n_data, int(np.ceil(self.ell_max / 2))), dtype=v.dtype, ) @@ -216,10 +218,10 @@ def _evaluate(self, v): int((ell - 1) / 2), ] = v_ell - u_even = np.transpose(u_even, (3, 0, 1, 2)) - u_odd = np.transpose(u_odd, (3, 0, 1, 2)) - w_even = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) - w_odd = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) + u_even = xp.transpose(u_even, (3, 0, 1, 2)) + u_odd = xp.transpose(u_odd, (3, 0, 1, 2)) + w_even = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) + w_odd = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) # evaluate the phi parts for m in range(0, self.ell_max + 1): @@ -252,8 +254,8 @@ def _evaluate(self, v): w_even[:, :, :, self.ell_max + sgn * m] = w_m_even w_odd[:, :, :, self.ell_max + sgn * m] = w_m_odd - w_even = np.transpose(w_even, (3, 0, 1, 2)) - w_odd = np.transpose(w_odd, (3, 0, 1, 2)) + w_even = xp.transpose(w_even, (3, 0, 1, 2)) + w_odd = xp.transpose(w_odd, (3, 0, 1, 2)) u_even = w_even u_odd = w_odd @@ -266,7 +268,7 @@ def _evaluate(self, v): pf = w_even + 1j * w_odd pf = m_reshape(pf, (n_theta * n_phi * n_r, n_data)) - pf = np.moveaxis(pf, 0, -1) + pf = xp.moveaxis(pf, 0, -1) # perform inverse non-uniformly FFT transformation back to 3D rectangular coordinates freqs = m_reshape(self._precomp["fourier_pts"], (3, n_r * n_theta * n_phi)) @@ -275,7 +277,7 @@ def _evaluate(self, v): # Roll, return the x with the last three dimensions as self.sz # Higher dimensions should be like v. x = x.reshape((*sz_roll, *self.sz)) - return x + return xp.asnumpy(x) def _evaluate_t(self, x): """ @@ -288,6 +290,7 @@ def _evaluate_t(self, x): `self.count` and whose remaining dimensions correspond to higher dimensions of `x`. """ + x = xp.asarray(x) # roll dimensions sz_roll = x.shape[:-3] x = x.reshape((-1, *self.sz)) @@ -303,20 +306,21 @@ def _evaluate_t(self, x): pf = m_reshape(pf.T, (n_theta, n_phi * n_r * n_data)) # evaluate the theta parts - u_even = self._precomp["ang_theta_wtd"].T @ np.real(pf) - u_odd = self._precomp["ang_theta_wtd"].T @ np.imag(pf) + tmp = self._precomp["ang_theta_wtd"].T + u_even = tmp @ xp.real(pf) + u_odd = tmp @ xp.imag(pf) u_even = m_reshape(u_even, (2 * self.ell_max + 1, n_phi, n_r, n_data)) u_odd = m_reshape(u_odd, (2 * self.ell_max + 1, n_phi, n_r, n_data)) - u_even = np.transpose(u_even, (1, 2, 3, 0)) - u_odd = np.transpose(u_odd, (1, 2, 3, 0)) + u_even = xp.transpose(u_even, (1, 2, 3, 0)) + u_odd = xp.transpose(u_odd, (1, 2, 3, 0)) - w_even = np.zeros( + w_even = xp.zeros( (int(np.floor(self.ell_max / 2) + 1), n_r, 2 * self.ell_max + 1, n_data), dtype=x.dtype, ) - w_odd = np.zeros( + w_odd = xp.zeros( (int(np.ceil(self.ell_max / 2)), n_r, 2 * self.ell_max + 1, n_data), dtype=x.dtype, ) @@ -351,11 +355,11 @@ def _evaluate_t(self, x): end = np.size(w_odd, 0) w_odd[end - n_odd_ell : end, :, self.ell_max + sgn * m, :] = w_m_odd - w_even = np.transpose(w_even, (1, 2, 3, 0)) - w_odd = np.transpose(w_odd, (1, 2, 3, 0)) + w_even = xp.transpose(w_even, (1, 2, 3, 0)) + w_odd = xp.transpose(w_odd, (1, 2, 3, 0)) # evaluate the radial parts - v = np.zeros((n_data, self.count), dtype=x.dtype) + v = xp.zeros((n_data, self.count), dtype=x.dtype) for ell in range(0, self.ell_max + 1): k_max_ell = self.k_max[ell] radial_wtd = self._precomp["radial_wtd"][:, 0:k_max_ell, ell] @@ -388,4 +392,4 @@ def _evaluate_t(self, x): # Roll dimensions, last dimension should be self.count, # Higher dimensions like x. v = v.reshape((*sz_roll, self.count)) - return v + return xp.asnumpy(v) From 50bfcb3eb8c40ada8b5a5f9f319fb322f2072660 Mon Sep 17 00:00:00 2001 From: Garrett Wright Date: Thu, 13 Jun 2024 14:58:00 -0400 Subject: [PATCH 3/3] better style --- src/aspire/basis/ffb_3d.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/aspire/basis/ffb_3d.py b/src/aspire/basis/ffb_3d.py index 1ac5fd62ff..5740900a34 100644 --- a/src/aspire/basis/ffb_3d.py +++ b/src/aspire/basis/ffb_3d.py @@ -218,8 +218,8 @@ def _evaluate(self, v): int((ell - 1) / 2), ] = v_ell - u_even = xp.transpose(u_even, (3, 0, 1, 2)) - u_odd = xp.transpose(u_odd, (3, 0, 1, 2)) + u_even = u_even.transpose((3, 0, 1, 2)) + u_odd = u_odd.transpose((3, 0, 1, 2)) w_even = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) w_odd = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype) @@ -254,8 +254,8 @@ def _evaluate(self, v): w_even[:, :, :, self.ell_max + sgn * m] = w_m_even w_odd[:, :, :, self.ell_max + sgn * m] = w_m_odd - w_even = xp.transpose(w_even, (3, 0, 1, 2)) - w_odd = xp.transpose(w_odd, (3, 0, 1, 2)) + w_even = w_even.transpose((3, 0, 1, 2)) + w_odd = w_odd.transpose((3, 0, 1, 2)) u_even = w_even u_odd = w_odd @@ -307,14 +307,14 @@ def _evaluate_t(self, x): # evaluate the theta parts tmp = self._precomp["ang_theta_wtd"].T - u_even = tmp @ xp.real(pf) - u_odd = tmp @ xp.imag(pf) + u_even = tmp @ pf.real + u_odd = tmp @ pf.imag u_even = m_reshape(u_even, (2 * self.ell_max + 1, n_phi, n_r, n_data)) u_odd = m_reshape(u_odd, (2 * self.ell_max + 1, n_phi, n_r, n_data)) - u_even = xp.transpose(u_even, (1, 2, 3, 0)) - u_odd = xp.transpose(u_odd, (1, 2, 3, 0)) + u_even = u_even.transpose((1, 2, 3, 0)) + u_odd = u_odd.transpose((1, 2, 3, 0)) w_even = xp.zeros( (int(np.floor(self.ell_max / 2) + 1), n_r, 2 * self.ell_max + 1, n_data), @@ -355,8 +355,8 @@ def _evaluate_t(self, x): end = np.size(w_odd, 0) w_odd[end - n_odd_ell : end, :, self.ell_max + sgn * m, :] = w_m_odd - w_even = xp.transpose(w_even, (1, 2, 3, 0)) - w_odd = xp.transpose(w_odd, (1, 2, 3, 0)) + w_even = w_even.transpose((1, 2, 3, 0)) + w_odd = w_odd.transpose((1, 2, 3, 0)) # evaluate the radial parts v = xp.zeros((n_data, self.count), dtype=x.dtype)