Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions src/aspire/basis/fb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from aspire.basis import FBBasisMixin, SteerableBasis2D
from aspire.basis.basis_utils import unique_coords_nd
from aspire.utils import roll_dim, unroll_dim
from aspire.utils.matlab_compat import m_flatten, m_reshape

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -185,11 +183,11 @@ def _evaluate(self, v):
dimensions correspond to first dimensions of `v`.
"""
# Transpose here once, instead of several times below #RCOPT
v = v.reshape(-1, self.count).T
v = v.T

r_idx = self.basis_coords["r_idx"]
ang_idx = self.basis_coords["ang_idx"]
mask = m_flatten(self.basis_coords["mask"])
mask = self.basis_coords["mask"].flatten()

ind = 0
ind_radial = 0
Expand Down Expand Up @@ -220,32 +218,28 @@ def _evaluate(self, v):

return x

def _evaluate_t(self, v):
def _evaluate_t(self, x):
"""
Evaluate coefficient in FB basis from those in standard 2D coordinate basis

:param v: The coefficient array to be evaluated. The last dimensions
:param x: The coefficient array to be evaluated. The last dimensions
must equal `self.sz`.
:return: The evaluation of the coefficient array `v` in the dual basis
of `basis`. This is an array of vectors whose last dimension equals
`self.count` and whose first dimensions correspond to
first dimensions of `v`.
"""
v = v.T
x, sz_roll = unroll_dim(v, self.ndim + 1)
x = m_reshape(
x, new_shape=tuple([np.prod(self.sz)] + list(x.shape[self.ndim :]))
)
x = x.reshape(x.shape[0], -1)

r_idx = self.basis_coords["r_idx"]
ang_idx = self.basis_coords["ang_idx"]
mask = m_flatten(self.basis_coords["mask"])
mask = self.basis_coords["mask"].flatten()

ind = 0
ind_radial = 0
ind_ang = 0

v = np.zeros(shape=tuple([self.count] + list(x.shape[1:])), dtype=v.dtype)
v = np.zeros((x.shape[0], self.count), dtype=x.dtype)
for ell in range(0, self.ell_max + 1):
k_max = self.k_max[ell]
idx_radial = ind_radial + np.arange(0, k_max)
Expand All @@ -259,14 +253,13 @@ def _evaluate_t(self, v):
ang = self._precomp["ang"][:, ind_ang]
ang_radial = np.expand_dims(ang[ang_idx], axis=1) * radial[r_idx]
idx = ind + np.arange(0, k_max)
v[idx] = ang_radial.T @ x[mask]
v[:, idx] = x[:, mask] @ ang_radial
ind += len(idx)
ind_ang += 1

ind_radial += len(idx_radial)

v = roll_dim(v, sz_roll)
return v.T # RCOPT
return v

def calculate_bispectrum(
self, coef, flatten=False, filter_nonzero_freqs=False, freq_cutoff=None
Expand Down
42 changes: 14 additions & 28 deletions src/aspire/basis/fb_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from aspire.basis import Basis, FBBasisMixin
from aspire.basis.basis_utils import real_sph_harmonic, sph_bessel, unique_coords_nd
from aspire.utils import roll_dim, unroll_dim
from aspire.utils.matlab_compat import m_flatten, m_reshape

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -150,21 +148,17 @@ def _evaluate(self, v):
This is an array whose first dimensions equal `self.z` and the
remaining dimensions correspond to dimensions two and higher of `v`.
"""

v = v.T
v, sz_roll = unroll_dim(v, 2)

stack_shape = v.shape[:-1]
v = v.reshape(-1, v.shape[-1])
r_idx = self.basis_coords["r_idx"]
ang_idx = self.basis_coords["ang_idx"]
mask = m_flatten(self.basis_coords["mask"])
mask = self.basis_coords["mask"].flatten()

ind = 0
ind_radial = 0
ind_ang = 0

x = np.zeros(
shape=tuple([np.prod(self.sz)] + list(v.shape[1:])), dtype=self.dtype
)
x = np.zeros((v.shape[0], np.prod(self.sz)), dtype=self.dtype)
for ell in range(0, self.ell_max + 1):
k_max = self.k_max[ell]
idx_radial = ind_radial + np.arange(0, k_max)
Expand All @@ -176,43 +170,36 @@ def _evaluate(self, v):
ang = self._precomp["ang"][:, ind_ang]
ang_radial = np.expand_dims(ang[ang_idx], axis=1) * radial[r_idx]
idx = ind + np.arange(0, len(idx_radial))
x[mask] += ang_radial @ v[idx]
x[:, mask] += v[:, idx] @ ang_radial.T
ind += len(idx)
ind_ang += 1

ind_radial += len(idx_radial)

x = m_reshape(x, self.sz + x.shape[1:])
x = roll_dim(x, sz_roll)
return x.reshape(*stack_shape, *self.sz)

return x.T

def _evaluate_t(self, v):
def _evaluate_t(self, x):
"""
Evaluate coefficient in FB basis from those in standard 3D coordinate basis

:param v: The coefficient array to be evaluated. The first dimensions
:param x: The coefficient array to be evaluated. The first dimensions
must equal `self.sz`.
:return: The evaluation of the coefficient array `v` in the dual
basis of `basis`. This is an array of vectors whose first dimension
equals `self.count` and whose remaining dimensions correspond
to higher dimensions of `v`.
"""
v = v.T
x, sz_roll = unroll_dim(v, self.ndim + 1)
x = m_reshape(
x, new_shape=tuple([np.prod(self.sz)] + list(x.shape[self.ndim :]))
)

stack_shape = x.shape[: -self.ndim]
x = x.reshape(-1, np.prod(self.sz))
r_idx = self.basis_coords["r_idx"]
ang_idx = self.basis_coords["ang_idx"]
mask = m_flatten(self.basis_coords["mask"])
mask = self.basis_coords["mask"].flatten()

ind = 0
ind_radial = 0
ind_ang = 0

v = np.zeros(shape=tuple([self.count] + list(x.shape[1:])), dtype=self.dtype)
v = np.zeros((x.shape[0], self.count), dtype=self.dtype)
for ell in range(0, self.ell_max + 1):
k_max = self.k_max[ell]
idx_radial = ind_radial + np.arange(0, k_max)
Expand All @@ -224,11 +211,10 @@ def _evaluate_t(self, v):
ang = self._precomp["ang"][:, ind_ang]
ang_radial = np.expand_dims(ang[ang_idx], axis=1) * radial[r_idx]
idx = ind + np.arange(0, len(idx_radial))
v[idx] = np.real(ang_radial.T @ x[mask])
v[:, idx] = x[:, mask] @ ang_radial
ind += len(idx)
ind_ang += 1

ind_radial += len(idx_radial)

v = roll_dim(v, sz_roll)
return v.T
return v.reshape(*stack_shape, self.count)
31 changes: 13 additions & 18 deletions src/aspire/basis/ffb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from aspire.numeric import fft, xp
from aspire.operators import BlkDiagMatrix
from aspire.utils import complex_type
from aspire.utils.matlab_compat import m_reshape

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -93,14 +92,9 @@ def _precomp(self):
ind_radial += 1

# Only calculate "positive" frequencies in one half-plane.
freqs_x = m_reshape(r, (n_r, 1)) @ m_reshape(
np.cos(np.arange(n_theta, dtype=self.dtype) * 2 * pi / (2 * n_theta)),
(1, n_theta),
)
freqs_y = m_reshape(r, (n_r, 1)) @ m_reshape(
np.sin(np.arange(n_theta, dtype=self.dtype) * 2 * pi / (2 * n_theta)),
(1, n_theta),
)
theta_grid = np.arange(n_theta, dtype=self.dtype) * 2 * pi / (2 * n_theta)
freqs_x = r[:, None] @ np.cos(theta_grid)[None, :]
freqs_y = r[:, None] @ np.sin(theta_grid)[None, :]
freqs = np.vstack((freqs_y[np.newaxis, ...], freqs_x[np.newaxis, ...]))

return {"gl_nodes": r, "gl_weights": w, "radial": radial, "freqs": freqs}
Expand All @@ -127,13 +121,13 @@ def _evaluate(self, v):
n_r = self._precomp["freqs"].shape[1]

# go through each basis function and find corresponding coefficient
pf = xp.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype))
pf = xp.zeros((2 * n_theta, n_data, n_r), dtype=complex_type(self.dtype))

ind = 0

idx = ind + np.arange(self.k_max[0], dtype=int)

pf[:, 0, :] = v[:, self._zero_angular_inds] @ self.radial_norm[idx]
pf[0] = v[:, self._zero_angular_inds] @ self.radial_norm[idx]
ind = ind + idx.size

ind_pos = ind
Expand All @@ -149,27 +143,28 @@ def _evaluate(self, v):
v_ell = 1j * v_ell

pf_ell = v_ell @ self.radial_norm[idx]
pf[:, ell, :] = pf_ell
pf[ell] = pf_ell

if np.mod(ell, 2) == 0:
pf[:, 2 * n_theta - ell, :] = pf_ell.conjugate()
pf[2 * n_theta - ell] = pf_ell.conjugate()
else:
pf[:, 2 * n_theta - ell, :] = -pf_ell.conjugate()
pf[2 * n_theta - ell] = -pf_ell.conjugate()

ind = ind + idx.size
ind_pos = ind_pos + 2 * self.k_max[ell]

# 1D inverse FFT in the degree of polar angle
pf = 2 * xp.pi * fft.ifft(pf, axis=1)
pf = 2 * xp.pi * fft.ifft(pf, axis=0)

# Only need "positive" frequencies.
hsize = int(pf.shape[1] / 2)
pf = pf[:, 0:hsize, :]
hsize = int(pf.shape[0] / 2)
pf = pf[0:hsize]
pf *= self.gl_weighted_nodes[None, None, :]
pf = pf.transpose(1, 2, 0)
pf = pf.reshape(n_data, n_r * n_theta)

# perform inverse non-uniformly FFT transform back to 2D coordinate basis
freqs = m_reshape(self._precomp["freqs"], (2, n_r * n_theta))
freqs = self._precomp["freqs"].reshape(2, n_r * n_theta)

x = 2 * anufft(pf, 2 * pi * freqs, self.sz, real=True)

Expand Down
Loading
Loading