Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
139 commits
Select commit Hold shift + click to select a range
778d777
10081 pipeline doc and sym updates
garrettwrong Jul 2, 2024
eaa9437
CL sync c3c4 eps change, and numerical issue
garrettwrong Jul 2, 2024
92b2e93
fix doc " typo
garrettwrong Jul 10, 2024
01058cc
Log a diagnostic whether we are actually boosting anything
garrettwrong Jul 10, 2024
db84487
symmetry_group pass-through for ClassAvgSource.
j-c-c Jul 8, 2024
2511f43
Add symmetry_group to test_indexed_source.
j-c-c Jul 10, 2024
2b63e39
missing f in log message
garrettwrong Jul 12, 2024
e92971e
use utest_tolerance for single precision run to run variability
garrettwrong Jul 15, 2024
6e45cfc
added 2D projection stub
Jun 13, 2024
5ac3cbc
initial test file add
Jun 17, 2024
fa5f37b
Stashing initial project with test placeholder
Jun 18, 2024
f6834b2
Style Updates
Jun 18, 2024
aef4619
Pytest fixtures
Jun 21, 2024
df8953d
Cleanup
Jun 21, 2024
9970e81
changed nufft call
Jun 21, 2024
eafe89e
added stub for image stack line project
garrettwrong Jun 21, 2024
c078f84
Dimensional Test Fix
Jun 24, 2024
49ecd4b
Multidim FFT
Jun 24, 2024
399be2f
Integrated stack reshape to project
Jun 25, 2024
7994ecd
Fleshed out Image Project Single and Multidim Tests
Jun 26, 2024
59e4b24
Fixed the grid issues yay
Jun 27, 2024
5532c3b
Angle slow moving axis
Jun 28, 2024
6620a2e
Replaced FFT with rfft
Jun 28, 2024
223065f
Cleaned up other unit tests
Jun 28, 2024
cff9378
Added Doc Test and Cleaned up Code
Jul 6, 2024
5192d54
fixup sinogram tests and simpler multi test
garrettwrong Jul 10, 2024
b321364
fix irfft and shift
garrettwrong Jul 11, 2024
3d9d123
added angles but need to change multidim
Jul 11, 2024
d486dd2
Added Changes from PR: parameterized angles, adjusted tests according…
Jul 11, 2024
614f08d
Added extra comments + Integrated Changes from lineproject_dbg2 branch
Jul 18, 2024
2b2e019
Changed angle fixture description + Id Name
Jul 19, 2024
fa8bd54
Docstring len cleanup
garrettwrong Jul 19, 2024
2d5a459
add simple basis benchmark and plotting script
garrettwrong May 9, 2024
f4f41df
convert cufinufft towards cupy, keeping result on dvice
garrettwrong Jun 4, 2024
a60a3e0
convert anufft and nufft towards detecting whether to keep array on gpu
garrettwrong Jun 4, 2024
1a14068
whitespace
garrettwrong Jun 5, 2024
092cda0
add sparse cupy gpu wrapper and tests for methods in use
garrettwrong Jun 5, 2024
415c941
fixup mn
garrettwrong Jun 5, 2024
62d4620
first pass migrating FLE to cupy via xp
garrettwrong Jun 5, 2024
51f60db
add dct/idct to pyfftw, scipy, cupy wrappers
garrettwrong Jun 5, 2024
61e7db1
phase 2, fle internals
garrettwrong Jun 5, 2024
a84fb5a
mem cleanup workaround
garrettwrong Jun 5, 2024
5843309
cupy eigvals needs large problem or nans...
garrettwrong Jun 5, 2024
a780c51
crop pad updates
garrettwrong Jun 5, 2024
f4c8bf7
tox cleanup
garrettwrong Jun 5, 2024
440175c
evaluate_t on gpu.
Jun 4, 2024
a64b872
Optimize ffb2d for gpu.
Jun 5, 2024
8a6b4c4
downsample return
j-c-c Jun 11, 2024
81ba7af
remove unnecessary xp.array
j-c-c Jun 11, 2024
325129f
convert pf to complex
j-c-c Jun 11, 2024
af6d519
precompute radial_norm and gl_weighted_nodes in build.
j-c-c Jun 13, 2024
a98e00b
remove comment
j-c-c Jun 13, 2024
92c61f2
use asarray
j-c-c Jun 13, 2024
030062c
Remove cupy.fill culprit. un-cupy indices.
j-c-c Jun 13, 2024
58838e3
cupy.fill culprit in fle_2d. sparse indices.
j-c-c Jun 13, 2024
03697ff
bare min vol hack
garrettwrong Jun 13, 2024
eceaf25
bare min ffb3d hacks
garrettwrong Jun 13, 2024
1aae072
better style
garrettwrong Jun 13, 2024
d63b1dc
last cupy fill
j-c-c Jun 13, 2024
b16fa01
revert config to numpy/scipy
garrettwrong Jun 18, 2024
41e3208
fft host array preservation
garrettwrong Jun 18, 2024
ca657a7
interop crop_pad_2d
garrettwrong Jun 20, 2024
dbe66e5
interop fle radial convolve
garrettwrong Jun 20, 2024
8e2f200
cleanup
garrettwrong Jun 20, 2024
57a3679
remove bbenchmark code, hackathon over
garrettwrong Jun 20, 2024
07daa17
fix interop cp check
garrettwrong Jun 20, 2024
1947c7d
use cupy modes on ampere_gpu jobs
garrettwrong Jun 20, 2024
8e60d46
ws cleanup in gha config gen
garrettwrong Jun 20, 2024
c7eb9dd
remove older GPU environments.
garrettwrong Jun 20, 2024
8accd1f
fle basis to mat xp conversion
garrettwrong Jun 21, 2024
7b3b080
better eigsh sanity check
garrettwrong Jun 21, 2024
ac63b7c
cupy fft accuracy casting work around
garrettwrong Jun 24, 2024
47ee759
some numpy/cupy interop tweaks
garrettwrong Jun 24, 2024
e58e47a
more image interop tweaks
garrettwrong Jun 24, 2024
0b877b5
misc xp/numeric wrapper cleanup
garrettwrong Jun 24, 2024
5324e7f
precache fle x y grids on gpu
garrettwrong Jun 26, 2024
afbc468
Rm unneeded gc call
garrettwrong Jun 27, 2024
8526aa7
Add cupy GPU options to config tutorial
garrettwrong Jun 28, 2024
2ef3cd3
update GPU install docs
garrettwrong Jun 28, 2024
3579787
improve crop 3d xp interop
garrettwrong Jun 28, 2024
f35ad52
ffb2d self review cleanup
garrettwrong Jul 1, 2024
e448402
ffb3d move more grid precomp to gpu
garrettwrong Jul 1, 2024
286301e
Move more FLE2D grid precomp to GPU
garrettwrong Jul 1, 2024
db0e7d6
image self review cleanup
garrettwrong Jul 1, 2024
235979c
var name improvement
garrettwrong Jul 1, 2024
4f6ca0a
minor crop pad string improvements
garrettwrong Jul 1, 2024
6fa1ec6
Update volume downsample with crop_pad_3d improvements
garrettwrong Jul 1, 2024
2431495
add docstring
garrettwrong Jul 1, 2024
8db1221
enforce filter dtype
garrettwrong Jul 1, 2024
329de8f
explicitly force C order before cufinufft call
garrettwrong Jul 1, 2024
da18c56
Add dtype note and utest tolerance for singles
garrettwrong Jul 2, 2024
7e2bdb3
configuration doc wording (strings)
garrettwrong Jul 15, 2024
a7fa3f3
keep a few more vars as cupy
garrettwrong Jul 15, 2024
9ac4be9
replace xp.newaxis with None
garrettwrong Jul 15, 2024
5ab1c7b
put cache dir on new line
garrettwrong Jul 23, 2024
b86c2d5
rename tmp to ang_theta_wtd_trans
garrettwrong Jul 23, 2024
b9f263b
gpu to GPU and rm dev comment
garrettwrong Jul 23, 2024
0c20e2e
Add epsilon arg to PowerFilter.
j-c-c Jun 6, 2024
0ee7a52
Add threshold to whiten function with matlab default.
j-c-c Jun 6, 2024
d177574
Recast ArrayFilter result after scipy workaround upcast occurs.
j-c-c Jun 10, 2024
8648698
Revert to original threshold of eps(dtype) on PSD.
j-c-c Jun 10, 2024
582077f
remove comment
j-c-c Jun 11, 2024
644de3c
set default epsilon inside whiten function.
j-c-c Jun 11, 2024
1c09a39
test PowerFilter argument
j-c-c Jun 18, 2024
61d79fb
smoke test for whiten epsilon param.
j-c-c Jun 18, 2024
75a232b
bump scipy version. remove upcast.
j-c-c Jun 18, 2024
53b06e3
test that whiten safeguard is actually working.
j-c-c Jun 18, 2024
cf7a778
use np.testing in suite that failed on arm.
j-c-c Jun 20, 2024
f58976c
use np.testing in test_FLEbasis2D.py
j-c-c Jun 20, 2024
9984cac
revert to upcasting.
j-c-c Jun 20, 2024
6cba858
update comments
j-c-c Jun 20, 2024
6d18cec
remove scipy bump
j-c-c Jun 20, 2024
5aeca45
Revert scipy workaround.
j-c-c Jul 10, 2024
598d440
xfail ArrayFilter test for singles.
j-c-c Jul 10, 2024
440cda2
make xfail strict
j-c-c Jul 12, 2024
4be9032
remove strict param
j-c-c Jul 12, 2024
4ecf092
Use pytest fixtures.
j-c-c Jul 23, 2024
2368196
remove unused import
j-c-c Jul 23, 2024
3ae167f
minimal patch to support cupy install and disabled cufinufft
garrettwrong Jul 24, 2024
57d34e0
skip enormous FFB2D test on GPU
garrettwrong Jul 24, 2024
325b1de
simpler solution
garrettwrong Jul 25, 2024
ff06daa
make the long workflow not so long
garrettwrong Jul 25, 2024
7ae2f27
run long workflow on pull requests
garrettwrong Jul 25, 2024
a1c0727
Added backproject script and stub in the image folder
Jul 22, 2024
4ddbe8c
Added One-Dimension Test for Backproject
Jul 23, 2024
40d594c
Stashing
Jul 23, 2024
17b09dc
Fixed Scaling Issue with BackProject and Integrated NRMSE to One Stac…
Jul 25, 2024
d1c7fad
fixed single back_project test
Jul 26, 2024
9d6c2f8
reorg Line to avoid circ import. Interop Image/Line classes
garrettwrong Jul 26, 2024
5d5f9b7
adjust tests towards Line/Image interop
garrettwrong Jul 26, 2024
9165226
passing the 20/20 test cases and added Attributes + Methods to the Li…
Jul 26, 2024
1bca0a9
finished multidim test
Jul 30, 2024
f3a0925
removed unused statements
Jul 30, 2024
a1eba5b
initial fft changes
Jul 30, 2024
de5388f
stashing gpu fixes
Jul 30, 2024
01d5995
forward gpu
Jul 30, 2024
b9bee17
changed backproject to run on gpu (cupy)
Aug 2, 2024
0007b52
revert config
Aug 2, 2024
8952f53
fixed gpu issues
Aug 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .github/workflows/long_workflow.yml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
name: ASPIRE Python Long Running Test Suite

on:
push:
branches:
- 'main'
- 'develop'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]

jobs:
expensive_tests:
runs-on: self-hosted
# Only run on review ready pull_requests
if: ${{ github.event_name == 'pull_request' && github.event.pull_request.draft == false }}
timeout-minutes: 360
steps:
- uses: actions/checkout@v4
Expand All @@ -33,8 +33,9 @@ jobs:
cat ${WORK_DIR}/config.yaml
- name: Run
run: |
export OMP_NUM_THREADS=1
ASPIREDIR=${{ env.WORK_DIR }} python -c \
"import aspire; print(aspire.config['ray']['temp_dir'])"
ASPIREDIR=${{ env.WORK_DIR }} python -m pytest -m "expensive" --durations=0
ASPIREDIR=${{ env.WORK_DIR }} python -m pytest -n8 -m "expensive" --durations=0
- name: Cleanup
run: rm -rf ${{ env.WORK_DIR }}
5 changes: 4 additions & 1 deletion .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ jobs:
echo "Stash the WORK_DIR to GitHub env so we can clean it up later."
echo "WORK_DIR=${WORK_DIR}" >> $GITHUB_ENV
echo -e "ray:\n temp_dir: ${WORK_DIR}\n" > ${WORK_DIR}/config.yaml
echo -e "common:\n cache_dir: ${CI_CACHE_DIR}\n" >> ${WORK_DIR}/config.yaml
echo -e "common:" >> ${WORK_DIR}/config.yaml
echo -e " cache_dir: ${CI_CACHE_DIR}" >> ${WORK_DIR}/config.yaml
echo -e " numeric: cupy" >> ${WORK_DIR}/config.yaml
echo -e " fft: cupy\n" >> ${WORK_DIR}/config.yaml
echo "Log the config: ${WORK_DIR}/config.yaml"
cat ${WORK_DIR}/config.yaml
- name: Run
Expand Down
31 changes: 13 additions & 18 deletions docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,17 @@ an M1 laptop:
Installing GPU Extensions
*************************

ASPIRE does support GPUs, depending on several external packages. The
collection of GPU extensions can be installed using ``pip``.
Extensions are grouped based on CUDA versions. To find the CUDA
driver version, run ``nvidia-smi`` on the intended system.
ASPIRE does support using a GPU, depending on several external
packages. The collection of GPU extensions can be installed using
``pip``. Extensions are grouped based on CUDA versions. To find the
CUDA driver version, run ``nvidia-smi`` on the intended system.

.. list-table:: CUDA GPU Extension Versions
:widths: 25 25
:header-rows: 1

* - CUDA Version
- ASPIRE Extension
* - 10.2
- gpu-102
* - 11.0
- gpu-110
* - 11.1
- gpu-111
* - >=11.2
- gpu-11x
* - >=12
- gpu-12x

Expand All @@ -164,12 +156,15 @@ the command below would install GPU packages required for ASPIRE.


By default if the required GPU extensions are correctly installed,
ASPIRE should automatically begin using the GPU for select components
(such as those using ``nufft``).

Because GPU extensions depend on several third party packages and
libraries, we can only offer limited support if one of the packages
has a problem on your system.
ASPIRE should automatically begin using the GPU calls to our ``nufft`` module.

Using GPU in other areas of the code is still an experimental feature
and requires a minor configuration setting to enable ``cupy``. See the
:ref:`sphx_glr_auto_tutorials_configuration.py` for details. Because
GPU extensions depend on several third party softwares and machines
vary wildly, we can only offer limited support if one of the packages
has a problem on your system. We are currently expanding GPU code
coverage.

Generating Documentation
************************
Expand Down
19 changes: 12 additions & 7 deletions gallery/experiments/experimental_abinitio_pipeline_10081.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@

# Create a source object for the experimental images
src = RelionSource(
starfile_in, pixel_size=pixel_size, max_rows=n_imgs, data_folder=data_folder
starfile_in,
pixel_size=pixel_size,
max_rows=n_imgs,
data_folder=data_folder,
symmetry_group="C4",
)

# Downsample the images
Expand Down Expand Up @@ -115,12 +119,13 @@
# Volume Reconstruction
# ----------------------
#
# Using the oriented source, attempt to reconstruct a volume.
# Since this is a Cn symmetric molecule, as indicated by
# ``symmetry="C4"`` above, the ``avgs`` images set will be repeated
# for each of the 3 additional rotations during the back-projection
# step. This boosts the effective number of images used in the
# reconstruction from ``n_classes`` to ``4*n_classes``.
# Using the oriented source, attempt to reconstruct a volume. Since
# this is a Cn symmetric molecule, as specified by ``RelionSource(...,
# symmetry_group="C4", ...)``, the ``symmetry_group`` source attribute
# will flow through the pipeline to ``avgs``. Then each image will be
# repeated for each of the 3 additional rotations during
# back-projection. This boosts the effective number of images used in
# the reconstruction from ``n_classes`` to ``4*n_classes``.

logger.info("Begin Volume reconstruction")

Expand Down
30 changes: 30 additions & 0 deletions gallery/tutorials/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,36 @@
time.sleep(1)
print("Done Loop 2\n")

# %%
# Enabling GPU Acceleration
# -------------------------
# Enabling GPU acceleration requires installing supporting software
# packages and small config changes. Installing the supporting
# software is most easily accomplished by installing ASPIRE with one
# of the published GPU extensions, for example ``pip install
# "aspire[dev,gpu_12x]"``. Once the packages are installed users
# should find that the NUFFT calls are automatically running on the
# GPU. Additional acceleration is achieved by enabling `cupy` for
# `numeric` and `fft` components.
#
# .. code-block:: yaml
#
# common:
# # numeric module to use - one of numpy/cupy
# numeric: cupy
# # fft backend to use - one of pyfftw/scipy/cupy/mkl
# fft: cupy
#
# Alternatively, like other config options, this can be changed
# dynamically with code.
#
# .. code-block:: python
#
# from aspire import config
#
# config["common"]["numeric"] = "cupy"
# config["common"]["fft"] = "cupy"
#

# %%
# Resolution
Expand Down
6 changes: 1 addition & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ dependencies = [
"Source" = "https://github.com/ComputationalCryoEM/ASPIRE-Python"

[project.optional-dependencies]
gpu-102 = ["pycuda", "cupy-cuda102", "cufinufft==1.3"]
gpu-110 = ["pycuda", "cupy-cuda110", "cufinufft==1.3"]
gpu-111 = ["pycuda", "cupy-cuda111", "cufinufft==1.3"]
gpu-11x = ["pycuda", "cupy-cuda11x", "cufinufft==1.3"]
gpu-12x = ["pycuda", "cupy-cuda12x", "cufinufft==2.2.0"]
gpu-12x = ["cupy-cuda12x", "cufinufft==2.2.0"]
dev = [
"black",
"bumpversion",
Expand Down
5 changes: 3 additions & 2 deletions src/aspire/abinitio/commonline_c3_c4.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
n_theta=None,
max_shift=0.15,
shift_step=1,
epsilon=1e-3,
epsilon=1e-2,
max_iters=1000,
degree_res=1,
seed=None,
Expand Down Expand Up @@ -691,7 +691,8 @@ def _J_sync_power_method(self, vijs):
)
while itr < max_iters and residual > epsilon:
itr += 1
vec_new = self._signs_times_v(vijs, vec)
# Note, this appears to need double precision for accuracy in the following division.
vec_new = self._signs_times_v(vijs, vec).astype(np.float64, copy=False)
vec_new = vec_new / norm(vec_new)
residual = norm(vec_new - vec)
vec = vec_new
Expand Down
85 changes: 42 additions & 43 deletions src/aspire/basis/ffb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def _build(self):
# precompute the basis functions in 2D grids
self._precomp = self._precomp()

# include the normalization factor of angular part into radial part
self.radial_norm = xp.asarray(self._precomp["radial"]) / xp.asarray(
np.expand_dims(self.angular_norms, 1)
)

# precompute weighted nodes
self.gl_weighted_nodes = xp.asarray(self._precomp["gl_weights"]) * xp.asarray(
self._precomp["gl_nodes"]
)

def _precomp(self):
"""
Precomute the basis functions on a polar Fourier grid
Expand Down Expand Up @@ -105,32 +115,31 @@ def _evaluate(self, v):
coordinate basis. This is Image instance with resolution of `self.sz`
and the first dimension correspond to remaining dimension of `v`.
"""
v = xp.asarray(v)
sz_roll = v.shape[:-1]
v = v.reshape(-1, self.count)

# number of 2D image samples
n_data = v.shape[0]

# get information on polar grids from precomputed data
n_theta = np.size(self._precomp["freqs"], 2)
n_r = np.size(self._precomp["freqs"], 1)
n_theta = self._precomp["freqs"].shape[2]
n_r = self._precomp["freqs"].shape[1]

# go through each basis function and find corresponding coefficient
pf = np.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype))
pf = xp.zeros((n_data, 2 * n_theta, n_r), dtype=complex_type(self.dtype))

ind = 0

idx = ind + np.arange(self.k_max[0], dtype=int)

# include the normalization factor of angular part into radial part
radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1)
pf[:, 0, :] = v[:, self._zero_angular_inds] @ radial_norm[idx]
ind = ind + np.size(idx)
pf[:, 0, :] = v[:, self._zero_angular_inds] @ self.radial_norm[idx]
ind = ind + idx.size

ind_pos = ind

for ell in range(1, self.ell_max + 1):
idx = ind + np.arange(self.k_max[ell], dtype=int)
idx = ind + xp.arange(self.k_max[ell], dtype=int)
idx_pos = ind_pos + np.arange(self.k_max[ell], dtype=int)
idx_neg = idx_pos + self.k_max[ell]

Expand All @@ -139,30 +148,25 @@ def _evaluate(self, v):
if np.mod(ell, 2) == 1:
v_ell = 1j * v_ell

pf_ell = v_ell @ radial_norm[idx]
pf_ell = v_ell @ self.radial_norm[idx]
pf[:, ell, :] = pf_ell

if np.mod(ell, 2) == 0:
pf[:, 2 * n_theta - ell, :] = pf_ell.conjugate()
else:
pf[:, 2 * n_theta - ell, :] = -pf_ell.conjugate()

ind = ind + np.size(idx)
ind = ind + idx.size
ind_pos = ind_pos + 2 * self.k_max[ell]

# 1D inverse FFT in the degree of polar angle
pf = 2 * pi * xp.asnumpy(fft.ifft(xp.asarray(pf), axis=1))
pf = 2 * xp.pi * fft.ifft(pf, axis=1)

# Only need "positive" frequencies.
hsize = int(np.size(pf, 1) / 2)
hsize = int(pf.shape[1] / 2)
pf = pf[:, 0:hsize, :]

for i_r in range(0, n_r):
pf[..., i_r] = pf[..., i_r] * (
self._precomp["gl_weights"][i_r] * self._precomp["gl_nodes"][i_r]
)

pf = np.reshape(pf, (n_data, n_r * n_theta))
pf *= self.gl_weighted_nodes[None, None, :]
pf = pf.reshape(n_data, n_r * n_theta)

# perform inverse non-uniformly FFT transform back to 2D coordinate basis
freqs = m_reshape(self._precomp["freqs"], (2, n_r * n_theta))
Expand All @@ -172,7 +176,7 @@ def _evaluate(self, v):
# Return X as Image instance with the last two dimensions as *self.sz
x = x.reshape((*sz_roll, *self.sz))

return x
return xp.asnumpy(x)

def _evaluate_t(self, x):
"""
Expand All @@ -193,56 +197,51 @@ def _evaluate_t(self, x):
n_images = x.shape[0]

# resamping x in a polar Fourier gird using nonuniform discrete Fourier transform
pf = nufft(x, 2 * pi * freqs)
pf = np.reshape(pf, (n_images, n_r, n_theta))
pf = nufft(xp.asarray(x), 2 * pi * freqs)
pf = pf.reshape(n_images, n_r, n_theta)

# Recover "negative" frequencies from "positive" half plane.
pf = np.concatenate((pf, pf.conjugate()), axis=2)
pf = xp.concatenate((pf, pf.conjugate()), axis=2)

# evaluate radial integral using the Gauss-Legendre quadrature rule
for i_r in range(0, n_r):
pf[:, i_r, :] = pf[:, i_r, :] * (
self._precomp["gl_weights"][i_r] * self._precomp["gl_nodes"][i_r]
)
pf = pf * self.gl_weighted_nodes[None, :, None]

# 1D FFT on the angular dimension for each concentric circle
pf = 2 * pi / (2 * n_theta) * xp.asnumpy(fft.fft(xp.asarray(pf)))
pf = 2 * xp.pi / (2 * n_theta) * fft.fft(pf)

# This only makes it easier to slice the array later.
v = np.zeros((n_images, self.count), dtype=x.dtype)
v = xp.zeros((n_images, self.count), dtype=x.dtype)

# go through each basis function and find the corresponding coefficient
ind = 0
idx = ind + np.arange(self.k_max[0])
idx = ind + xp.arange(self.k_max[0])

# include the normalization factor of angular part into radial part
radial_norm = self._precomp["radial"] / np.expand_dims(self.angular_norms, 1)
v[:, self._zero_angular_inds] = pf[:, :, 0].real @ radial_norm[idx].T
ind = ind + np.size(idx)
v[:, self._zero_angular_inds] = pf[:, :, 0].real @ self.radial_norm[idx].T
ind = ind + idx.size

ind_pos = ind
for ell in range(1, self.ell_max + 1):
idx = ind + np.arange(self.k_max[ell])
idx_pos = ind_pos + np.arange(self.k_max[ell])
idx = ind + xp.arange(self.k_max[ell])
idx_pos = ind_pos + xp.arange(self.k_max[ell])
idx_neg = idx_pos + self.k_max[ell]

v_ell = pf[:, :, ell] @ radial_norm[idx].T
v_ell = pf[:, :, ell] @ self.radial_norm[idx].T

if np.mod(ell, 2) == 0:
v_pos = np.real(v_ell)
v_neg = -np.imag(v_ell)
v_pos = v_ell.real
v_neg = -v_ell.imag
else:
v_pos = np.imag(v_ell)
v_neg = np.real(v_ell)
v_pos = v_ell.imag
v_neg = v_ell.real

v[:, idx_pos] = v_pos
v[:, idx_neg] = v_neg

ind = ind + np.size(idx)
ind = ind + idx.size

ind_pos = ind_pos + 2 * self.k_max[ell]

return v
return xp.asnumpy(v)

def filter_to_basis_mat(self, f, **kwargs):
"""
Expand Down
Loading