Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions src/aspire/basis/ffb_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aspire.basis import FBBasis3D
from aspire.basis.basis_utils import lgwt, norm_assoc_legendre, sph_bessel
from aspire.nufft import anufft, nufft
from aspire.numeric import xp
from aspire.utils.matlab_compat import m_flatten, m_reshape

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -146,10 +147,10 @@ def _precomp(self):
)

return {
"radial_wtd": radial_wtd,
"ang_phi_wtd_even": ang_phi_wtd_even,
"ang_phi_wtd_odd": ang_phi_wtd_odd,
"ang_theta_wtd": ang_theta_wtd,
"radial_wtd": xp.asarray(radial_wtd),
"ang_phi_wtd_even": [xp.asarray(x) for x in ang_phi_wtd_even],
"ang_phi_wtd_odd": [xp.asarray(x) for x in ang_phi_wtd_odd],
"ang_theta_wtd": xp.asarray(ang_theta_wtd),
"fourier_pts": fourier_pts,
}

Expand All @@ -163,6 +164,7 @@ def _evaluate(self, v):
coordinate basis. This is an array whose last three dimensions equal
`self.sz` and the remaining dimensions correspond to `v`.
"""
v = xp.asarray(v)
# roll dimensions of v
sz_roll = v.shape[:-1]
v = v.reshape((-1, self.count))
Expand All @@ -175,7 +177,7 @@ def _evaluate(self, v):
# number of 3D image samples
n_data = v.shape[0]

u_even = np.zeros(
u_even = xp.zeros(
(
n_r,
int(2 * self.ell_max + 1),
Expand All @@ -184,7 +186,7 @@ def _evaluate(self, v):
),
dtype=v.dtype,
)
u_odd = np.zeros(
u_odd = xp.zeros(
(n_r, int(2 * self.ell_max + 1), n_data, int(np.ceil(self.ell_max / 2))),
dtype=v.dtype,
)
Expand Down Expand Up @@ -216,10 +218,10 @@ def _evaluate(self, v):
int((ell - 1) / 2),
] = v_ell

u_even = np.transpose(u_even, (3, 0, 1, 2))
u_odd = np.transpose(u_odd, (3, 0, 1, 2))
w_even = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype)
w_odd = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype)
u_even = u_even.transpose((3, 0, 1, 2))
u_odd = u_odd.transpose((3, 0, 1, 2))
w_even = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype)
w_odd = xp.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1), dtype=v.dtype)

# evaluate the phi parts
for m in range(0, self.ell_max + 1):
Expand Down Expand Up @@ -252,8 +254,8 @@ def _evaluate(self, v):
w_even[:, :, :, self.ell_max + sgn * m] = w_m_even
w_odd[:, :, :, self.ell_max + sgn * m] = w_m_odd

w_even = np.transpose(w_even, (3, 0, 1, 2))
w_odd = np.transpose(w_odd, (3, 0, 1, 2))
w_even = w_even.transpose((3, 0, 1, 2))
w_odd = w_odd.transpose((3, 0, 1, 2))
u_even = w_even
u_odd = w_odd

Expand All @@ -266,7 +268,7 @@ def _evaluate(self, v):

pf = w_even + 1j * w_odd
pf = m_reshape(pf, (n_theta * n_phi * n_r, n_data))
pf = np.moveaxis(pf, 0, -1)
pf = xp.moveaxis(pf, 0, -1)

# perform inverse non-uniformly FFT transformation back to 3D rectangular coordinates
freqs = m_reshape(self._precomp["fourier_pts"], (3, n_r * n_theta * n_phi))
Expand All @@ -275,7 +277,7 @@ def _evaluate(self, v):
# Roll, return the x with the last three dimensions as self.sz
# Higher dimensions should be like v.
x = x.reshape((*sz_roll, *self.sz))
return x
return xp.asnumpy(x)

def _evaluate_t(self, x):
"""
Expand All @@ -288,6 +290,7 @@ def _evaluate_t(self, x):
`self.count` and whose remaining dimensions correspond to higher
dimensions of `x`.
"""
x = xp.asarray(x)
# roll dimensions
sz_roll = x.shape[:-3]
x = x.reshape((-1, *self.sz))
Expand All @@ -303,20 +306,21 @@ def _evaluate_t(self, x):
pf = m_reshape(pf.T, (n_theta, n_phi * n_r * n_data))

# evaluate the theta parts
u_even = self._precomp["ang_theta_wtd"].T @ np.real(pf)
u_odd = self._precomp["ang_theta_wtd"].T @ np.imag(pf)
tmp = self._precomp["ang_theta_wtd"].T
u_even = tmp @ pf.real
u_odd = tmp @ pf.imag

u_even = m_reshape(u_even, (2 * self.ell_max + 1, n_phi, n_r, n_data))
u_odd = m_reshape(u_odd, (2 * self.ell_max + 1, n_phi, n_r, n_data))

u_even = np.transpose(u_even, (1, 2, 3, 0))
u_odd = np.transpose(u_odd, (1, 2, 3, 0))
u_even = u_even.transpose((1, 2, 3, 0))
u_odd = u_odd.transpose((1, 2, 3, 0))

w_even = np.zeros(
w_even = xp.zeros(
(int(np.floor(self.ell_max / 2) + 1), n_r, 2 * self.ell_max + 1, n_data),
dtype=x.dtype,
)
w_odd = np.zeros(
w_odd = xp.zeros(
(int(np.ceil(self.ell_max / 2)), n_r, 2 * self.ell_max + 1, n_data),
dtype=x.dtype,
)
Expand Down Expand Up @@ -351,11 +355,11 @@ def _evaluate_t(self, x):
end = np.size(w_odd, 0)
w_odd[end - n_odd_ell : end, :, self.ell_max + sgn * m, :] = w_m_odd

w_even = np.transpose(w_even, (1, 2, 3, 0))
w_odd = np.transpose(w_odd, (1, 2, 3, 0))
w_even = w_even.transpose((1, 2, 3, 0))
w_odd = w_odd.transpose((1, 2, 3, 0))

# evaluate the radial parts
v = np.zeros((n_data, self.count), dtype=x.dtype)
v = xp.zeros((n_data, self.count), dtype=x.dtype)
for ell in range(0, self.ell_max + 1):
k_max_ell = self.k_max[ell]
radial_wtd = self._precomp["radial_wtd"][:, 0:k_max_ell, ell]
Expand Down Expand Up @@ -388,4 +392,4 @@ def _evaluate_t(self, x):
# Roll dimensions, last dimension should be self.count,
# Higher dimensions like x.
v = v.reshape((*sz_roll, self.count))
return v
return xp.asnumpy(v)
7 changes: 4 additions & 3 deletions src/aspire/volume/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,15 +475,16 @@ def downsample(self, ds_res, mask=None):
v = self.stack_reshape(-1)

# take 3D Fourier transform of each volume in the stack
fx = fft.fftshift(fft.fftn(v._data, axes=(1, 2, 3)))
fx = xp.asnumpy(fft.fftshift(fft.fftn(xp.asarray(v._data), axes=(1, 2, 3))))
# crop each volume to the desired resolution in frequency space
crop_fx = (
np.array([crop_pad_3d(fx[i, :, :, :], ds_res) for i in range(self.n_vols)])
* mask
)
# inverse Fourier transform of each volume
out = fft.ifftn(fft.ifftshift(crop_fx), axes=(1, 2, 3)) * (
ds_res**3 / self.resolution**3
out = xp.asnumpy(
fft.ifftn(fft.ifftshift(xp.asarray(crop_fx)), axes=(1, 2, 3))
* (ds_res**3 / self.resolution**3)
)
# returns a new Volume object
return self.__class__(
Expand Down