Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 20 additions & 26 deletions src/aspire/basis/ffb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def _build(self):
# precompute the basis functions in 2D grids
self._precomp = self._precomp()

# include the normalization factor of angular part into radial part
self.radial_norm = xp.asarray(self._precomp["radial"]) / xp.asarray(
np.expand_dims(self.angular_norms, 1)
)

# precompute weighted nodes
self.gl_weighted_nodes = xp.asarray(self._precomp["gl_weights"]) * xp.asarray(
self._precomp["gl_nodes"]
)

def _precomp(self):
"""
Precomute the basis functions on a polar Fourier grid
Expand Down Expand Up @@ -105,7 +115,7 @@ def _evaluate(self, v):
coordinate basis. This is Image instance with resolution of `self.sz`
and the first dimension correspond to remaining dimension of `v`.
"""
v = xp.array(v)
v = xp.asarray(v)
sz_roll = v.shape[:-1]
v = v.reshape(-1, self.count)

Expand All @@ -123,26 +133,22 @@ def _evaluate(self, v):

idx = ind + xp.arange(self.k_max[0], dtype=int)

# include the normalization factor of angular part into radial part
radial_norm = xp.array(self._precomp["radial"]) / xp.array(
np.expand_dims(self.angular_norms, 1)
)
pf[:, 0, :] = v[:, xp.array(self._zero_angular_inds)] @ radial_norm[idx]
pf[:, 0, :] = v[:, xp.asarray(self._zero_angular_inds)] @ self.radial_norm[idx]
ind = ind + idx.size

ind_pos = ind

for ell in range(1, self.ell_max + 1):
idx = ind + xp.arange(self.k_max[ell], dtype=int)
idx_pos = ind_pos + xp.arange(self.k_max[ell], dtype=int)
idx_neg = idx_pos + xp.array(self.k_max[ell])
idx_neg = idx_pos + self.k_max[ell]

v_ell = (v[:, idx_pos] - 1j * v[:, idx_neg]) / 2.0

if np.mod(ell, 2) == 1:
v_ell = 1j * v_ell

pf_ell = v_ell @ radial_norm[idx]
pf_ell = v_ell @ self.radial_norm[idx]
pf[:, ell, :] = pf_ell

if np.mod(ell, 2) == 0:
Expand All @@ -151,17 +157,15 @@ def _evaluate(self, v):
pf[:, 2 * n_theta - ell, :] = -pf_ell.conjugate()

ind = ind + idx.size
ind_pos = ind_pos + 2 * xp.array(self.k_max[ell])
ind_pos = ind_pos + 2 * self.k_max[ell]

# 1D inverse FFT in the degree of polar angle
pf = 2 * xp.pi * fft.ifft(pf, axis=1)

# Only need "positive" frequencies.
hsize = int(pf.shape[1] / 2)
pf = pf[:, 0:hsize, :]
pf *= (
xp.array(self._precomp["gl_weights"]) * xp.array(self._precomp["gl_nodes"])
)[None, None, :]
pf *= self.gl_weighted_nodes[None, None, :]
pf = pf.reshape(n_data, n_r * n_theta)

# perform inverse non-uniformly FFT transform back to 2D coordinate basis
Expand Down Expand Up @@ -195,20 +199,14 @@ def _evaluate_t(self, x):
n_images = x.shape[0]

# resamping x in a polar Fourier gird using nonuniform discrete Fourier transform
pf = nufft(xp.array(x), 2 * pi * freqs)
pf = nufft(xp.asarray(x), 2 * pi * freqs)
pf = pf.reshape(n_images, n_r, n_theta)

# Recover "negative" frequencies from "positive" half plane.
pf = xp.concatenate((pf, pf.conjugate()), axis=2)

# evaluate radial integral using the Gauss-Legendre quadrature rule
pf = (
pf
* (
xp.array(self._precomp["gl_weights"])
* xp.array(self._precomp["gl_nodes"])
)[None, :, None]
)
pf *= self.gl_weighted_nodes[None, :, None]

# 1D FFT on the angular dimension for each concentric circle
pf = 2 * xp.pi / (2 * n_theta) * fft.fft(pf)
Expand All @@ -220,11 +218,7 @@ def _evaluate_t(self, x):
ind = 0
idx = ind + xp.arange(self.k_max[0])

# include the normalization factor of angular part into radial part
radial_norm = xp.array(
self._precomp["radial"] / np.expand_dims(self.angular_norms, 1)
)
v[:, self._zero_angular_inds] = pf[:, :, 0].real @ radial_norm[idx].T
v[:, self._zero_angular_inds] = pf[:, :, 0].real @ self.radial_norm[idx].T
ind = ind + idx.size

ind_pos = ind
Expand All @@ -233,7 +227,7 @@ def _evaluate_t(self, x):
idx_pos = ind_pos + xp.arange(self.k_max[ell])
idx_neg = idx_pos + self.k_max[ell]

v_ell = pf[:, :, ell] @ radial_norm[idx].T
v_ell = pf[:, :, ell] @ self.radial_norm[idx].T

if np.mod(ell, 2) == 0:
v_pos = v_ell.real
Expand Down
12 changes: 8 additions & 4 deletions src/aspire/basis/fle_2d_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def precomp_transform_complex_to_real(ells):
"""
count = len(ells)
num_nonzero = np.sum(ells == 0) + 2 * np.sum(ells != 0)
idx = xp.zeros(num_nonzero, dtype=int)
jdx = xp.zeros(num_nonzero, dtype=int)
vals = xp.zeros(num_nonzero, dtype=np.complex128)
idx = np.zeros(num_nonzero, dtype=int)
jdx = np.zeros(num_nonzero, dtype=int)
vals = np.zeros(num_nonzero, dtype=np.complex128)

k = 0
for i in range(count):
Expand Down Expand Up @@ -86,7 +86,11 @@ def precomp_transform_complex_to_real(ells):
jdx[k] = i + 1
k = k + 1

A = sparse.csr_matrix((vals, (idx, jdx)), shape=(count, count), dtype=np.complex128)
A = sparse.csr_matrix(
(xp.asarray(vals), (xp.asarray(idx), xp.asarray(jdx))),
shape=(count, count),
dtype=np.complex128,
)

return A.conjugate()

Expand Down