Skip to content

Conversation

CosmoMatt
Copy link
Collaborator

@CosmoMatt CosmoMatt commented Apr 8, 2024

This PR adds frontend JAX support for existing python wrapped C/C++ spherical harmonic transform libraries; specifically ssht and healpy, though in principle any spherical harmonic transform may be straightforwardly included.

The JAX frontend we provide provides automatic reverse mode differentiation, so existing packages may be integrated as transforms within differentiable programming pipelines.

This PR includes:

  • JAX frontend for SSHT library (for McEwen Wiaux, Driscoll Healy, Gauss Legendre sampling schemes)
  • JAX frontend for HEALPix library (for HEALPix sampling scheme)
  • Notebooks demonstrating the above
  • Update core requirements
  • Unit testing of reverse mode gradients against finite differences
  • Update docstrings across the board.

Important

Our JAX wrappers are limited by the hardware on which these C/C++ libraries are designed to run. So currently these transforms run only on CPU, however they are highly optimised and consequently very fast. For applications with a large memory overhead (e.g. sampling methods on the sphere), computing everything on CPU may be more useful to avoid throttling due to I/O.

Tip

Users should notice that this functionality for HEALPix sampling avoids the compile time issues we are still investigating in the core GPU transforms. So this may be a useful stop-gap until this issue is resolve.

@CosmoMatt CosmoMatt added documentation Improvements or additions to documentation enhancement New feature or request labels Apr 8, 2024
@CosmoMatt CosmoMatt requested a review from jasonmcewen April 8, 2024 10:58
@CosmoMatt CosmoMatt self-assigned this Apr 8, 2024
Copy link

codecov bot commented Apr 8, 2024

Codecov Report

Attention: Patch coverage is 94.44444% with 8 lines in your changes are missing coverage. Please review.

Project coverage is 92.63%. Comparing base (b3d033c) to head (64d5586).

❗ Current head 64d5586 differs from pull request most recent head baa412b. Consider uploading reports for the commit baa412b to get more accurate results

Files Patch % Lines
s2fft/transforms/spherical.py 80.95% 4 Missing ⚠️
s2fft/_version.py 0.00% 2 Missing ⚠️
s2fft/transforms/wigner.py 94.59% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #195      +/-   ##
==========================================
+ Coverage   92.49%   92.63%   +0.14%     
==========================================
  Files          27       28       +1     
  Lines        2971     3109     +138     
==========================================
+ Hits         2748     2880     +132     
- Misses        223      229       +6     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

CosmoMatt and others added 24 commits April 8, 2024 12:37
@CosmoMatt CosmoMatt merged commit 76fa862 into main Apr 9, 2024
@CosmoMatt CosmoMatt deleted the feature/JAX_frontend_for_C++_codes branch April 9, 2024 07:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants