Skip to content

Commit cb39a8a

Browse files
Simplify Fourier-Bessel basis list of zeros (#776)
* no m_reshape * implement ragged list of 1d np arrays * missed one r0 usage
1 parent 63255ce commit cb39a8a

File tree

8 files changed

+13
-18
lines changed

8 files changed

+13
-18
lines changed

src/aspire/basis/fb.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44

55
from aspire.basis.basis_utils import all_besselj_zeros
6-
from aspire.utils.matlab_compat import m_reshape
76

87
logger = logging.getLogger(__name__)
98

@@ -47,10 +46,6 @@ def _calc_k_max(self):
4746
# set the maximum of k for each ell
4847
self.k_max = np.array(n, dtype=int)
4948

50-
max_num_zeros = max(len(z) for z in zeros)
51-
for i, z in enumerate(zeros):
52-
zeros[i] = np.hstack(
53-
(z, np.zeros(max_num_zeros - len(z), dtype=self.dtype))
54-
)
55-
56-
self.r0 = m_reshape(np.hstack(zeros), (-1, self.ell_max + 1)).astype(self.dtype)
49+
# set the zeros for each ell
50+
# this is a ragged list of 1d ndarrays, where the i'th element is of size self.k_max[i]
51+
self.r0 = zeros

src/aspire/basis/fb_2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _precomp(self):
143143
for k in range(1, self.k_max[ell] + 1):
144144
# Only normalized by the radial part of basis function
145145
radial[:, ind_radial] = (
146-
jv(ell, self.r0[k - 1, ell] * r_unique)
146+
jv(ell, self.r0[ell][k - 1] * r_unique)
147147
/ self.radial_norms[ind_radial]
148148
)
149149
ind_radial += 1
@@ -177,7 +177,7 @@ def basis_norm_2d(self, ell, k):
177177
Calculate the normalized factors from radial and angular parts of a specified basis function
178178
"""
179179
rad_norm = (
180-
np.abs(jv(ell + 1, self.r0[k - 1, ell]))
180+
np.abs(jv(ell + 1, self.r0[ell][k - 1]))
181181
* np.sqrt(1 / 2.0)
182182
* self.nres
183183
/ 2.0

src/aspire/basis/fb_3d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _precomp(self):
105105

106106
for ell in range(0, self.ell_max + 1):
107107
for k in range(1, self.k_max[ell] + 1):
108-
radial[:, ind_radial] = sph_bessel(ell, self.r0[k - 1, ell] * r_unique)
108+
radial[:, ind_radial] = sph_bessel(ell, self.r0[ell][k - 1] * r_unique)
109109
ind_radial += 1
110110

111111
for m in range(-ell, ell + 1):
@@ -136,7 +136,7 @@ def basis_norm_3d(self, ell, k):
136136
Calculate the normalized factor of a specified basis function.
137137
"""
138138
return (
139-
np.abs(sph_bessel(ell + 1, self.r0[k - 1, ell]))
139+
np.abs(sph_bessel(ell + 1, self.r0[ell][k - 1]))
140140
/ np.sqrt(2)
141141
* np.sqrt((self.nres / 2) ** 3)
142142
)

src/aspire/basis/ffb_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _precomp(self):
7474
ind_radial = 0
7575
for ell in range(0, self.ell_max + 1):
7676
for k in range(1, self.k_max[ell] + 1):
77-
radial[ind_radial] = jv(ell, self.r0[k - 1, ell] * r / self.kcut)
77+
radial[ind_radial] = jv(ell, self.r0[ell][k - 1] * r / self.kcut)
7878
# NOTE: We need to remove the factor due to the discretization here
7979
# since it is already included in our quadrature weights
8080
# Only normalized by the radial part of basis function

src/aspire/basis/ffb_3d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ def _precomp(self):
7575
)
7676
for ell in range(0, self.ell_max + 1):
7777
k_max_ell = self.k_max[ell]
78-
rmat = r * self.r0[0:k_max_ell, ell].T / self.kcut
78+
rmat = r * self.r0[ell][0:k_max_ell].T / self.kcut
7979
radial_ell = np.zeros_like(rmat)
8080
for ik in range(0, k_max_ell):
8181
radial_ell[:, ik] = sph_bessel(ell, rmat[:, ik])
82-
nrm = np.abs(sph_bessel(ell + 1, self.r0[0:k_max_ell, ell].T) / 4)
82+
nrm = np.abs(sph_bessel(ell + 1, self.r0[ell][0:k_max_ell].T) / 4)
8383
radial_ell = radial_ell / nrm
8484
radial_ell_wtd = r**2 * wt_r * radial_ell
8585
radial_wtd[:, 0:k_max_ell, ell] = radial_ell_wtd

src/aspire/utils/filter_to_fb_mat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def filter_to_fb_mat(h_fun, fbasis):
4343
ind_ell = 0
4444
for ell in range(0, fbasis.ell_max + 1):
4545
k_max = fbasis.k_max[ell]
46-
rmat = 2 * k_vals.reshape(n_k, 1) * fbasis.r0[0:k_max, ell].T
46+
rmat = 2 * k_vals.reshape(n_k, 1) * fbasis.r0[ell][0:k_max].T
4747
fb_vals = np.zeros_like(rmat)
4848
ind_radial = np.sum(fbasis.k_max[0:ell])
4949
fb_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T

tests/test_FBbasis2D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _testElement(self, ell, k, sgn):
5151
g2d = grid_2d(self.L, dtype=self.dtype)
5252
mask = g2d["r"] < 1
5353

54-
r0 = self.basis.r0[k, ell]
54+
r0 = self.basis.r0[ell][k]
5555

5656
im = np.zeros((self.L, self.L), dtype=self.dtype)
5757
im[mask] = jv(ell, g2d["r"][mask] * r0)

tests/test_FFBbasis2D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _testElement(self, ell, k, sgn):
4949
g2d = grid_2d(self.L, dtype=self.dtype)
5050
mask = g2d["r"] < 1
5151

52-
r0 = self.basis.r0[k, ell]
52+
r0 = self.basis.r0[ell][k]
5353

5454
# TODO: Figure out where these factors of 1 / 2 are coming from.
5555
# Intuitively, the grid should go from -L / 2 to L / 2, not -L / 2 to

0 commit comments

Comments
 (0)