Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1404575
add jax frontend support for c/c++ sht libraries
CosmoMatt Apr 8, 2024
da1b009
remove comment on healpix spherical custom grad tests
CosmoMatt Apr 8, 2024
c23d620
Remove partially implemented PyTorch optional dependence and update r…
jasonmcewen Apr 8, 2024
097cca0
update sampling figure to include gl sampling
CosmoMatt Apr 8, 2024
7bfb783
default ci to standard install over optional torch
CosmoMatt Apr 8, 2024
0d944f7
reorder sampling schemes in sampling diagram
CosmoMatt Apr 8, 2024
934d0d1
Update HEALPix comments in readme and docs and remove benchmarking
jasonmcewen Apr 8, 2024
cc5b63f
Merge
jasonmcewen Apr 8, 2024
54c48bc
Update HEALPix doc comment to match readme and change sampling plot s…
jasonmcewen Apr 8, 2024
68189c1
Minor cosmetic changes to comments in example
jasonmcewen Apr 8, 2024
45cd9e9
Update documentation
jasonmcewen Apr 8, 2024
363a308
remove commented Wigner tests
CosmoMatt Apr 8, 2024
bc64ff5
Merge branch 'feature/JAX_frontend_for_C++_codes' of https://github.c…
CosmoMatt Apr 8, 2024
42ef10c
clean error statement for c backend
CosmoMatt Apr 8, 2024
7320fe8
add google colab support for notebooks
CosmoMatt Apr 8, 2024
d0693fe
update collab link to automatically update in readme
CosmoMatt Apr 8, 2024
4c54ede
Correct colab links in notebooks to work in documentation
jasonmcewen Apr 8, 2024
5ea9c8d
Correct some docstrings; only warn for high spin for certain methods
jasonmcewen Apr 8, 2024
1ae80ba
Update main colab badge in readme to point to basic spherical harmoni…
jasonmcewen Apr 8, 2024
939ae74
Correct version number for upcoming release
jasonmcewen Apr 8, 2024
e184f53
Add comment about precomputes trading off memory for speed
jasonmcewen Apr 8, 2024
824e2c7
Remove highly optimised comments
jasonmcewen Apr 8, 2024
706bbb2
Remove highly optimised comments
jasonmcewen Apr 8, 2024
d6769cb
Correct upcoming version number (replace 1.0.3 by 1.1.0)
jasonmcewen Apr 8, 2024
64d5586
Correct upcoming version number (replace 1.0.3 by 1.1.0)
jasonmcewen Apr 8, 2024
baa412b
add error catch tests to CI
CosmoMatt Apr 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
pip install jaxlib
pip install -r requirements/requirements-core.txt
pip install -r requirements/requirements-docs.txt
pip install .\[torch\]
pip install .

- name: Build Documentation
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements/requirements-tests.txt
pip install -r requirements/requirements-core.txt
pip install .\[torch\]
pip install .

- name: Run tests
run: |
Expand Down
8 changes: 7 additions & 1 deletion .pip_readme.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
.. image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing
:target: https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooksspherical_harmonic_transform.ipynb

Differentiable and accelerated spherical transforms
=================================================================================================================
Expand All @@ -31,6 +31,12 @@ As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.

As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU, however for many applications this is desirable due to memory
constraints.

Documentation
=============
Read the full documentation `here <https://astro-informatics.github.io/s2fft/>`_.
Expand Down
65 changes: 49 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![image](https://badge.fury.io/py/s2fft.svg)](https://badge.fury.io/py/s2fft)
[![image](http://img.shields.io/badge/arXiv-2311.14670-orange.svg?style=flat)](https://arxiv.org/abs/2311.14670)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-9-orange.svg?style=flat-square)](#contributors-)<!-- ALL-CONTRIBUTORS-BADGE:END -->
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/spherical_harmonic_transform.ipynb)
<!-- [![image](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -->

<img align="left" height="85" width="98" src="./docs/assets/sax_logo.png">
Expand All @@ -22,10 +22,20 @@ for adjoint transformations where needed, and comes with different
optimisations (precompute or not) that one may select depending on
available resources and desired angular resolution $L$.

> [!IMPORTANT]
> HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.

> [!TIP]
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.

> [!TIP]
As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU.

## Algorithms :zap:

`S2FFT` leverages new algorithmic structures that can he highly
Expand Down Expand Up @@ -53,7 +63,7 @@ diagram below illustrates the separable spherical harmonic transform
## Sampling :earth_africa:

The structure of the algorithms implemented in `S2FFT` can support any
isolattitude sampling scheme. A number of sampling schemes are currently
isolatitude sampling scheme. A number of sampling schemes are currently
supported.

The equiangular sampling schemes of [McEwen & Wiaux
Expand All @@ -73,10 +83,10 @@ so the corresponding harmonic transforms do not achieve machine
precision but exhibit some error. However, the HEALPix sampling provides
pixels of equal areas, which has many practical advantages.

<p align="center"><img src="./docs/assets/figures/spherical_sampling.png" width="500"></p>
<p align="center"><img src="./docs/assets/figures/spherical_sampling.png" width="700"></p>

> [!NOTE]
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html)). Fix for GPU execution is coming soon.

## Installation :computer:

Expand All @@ -87,12 +97,7 @@ into the active python environment by [pip](https://pypi.org) when running
``` bash
pip install s2fft
```
This will install all core functionality which includes JAX support. To install `S2FFT`
with PyTorch support run

``` bash
pip install s2fft[torch]
```
This will install all core functionality which includes JAX support (including PyTorch support).

Alternatively, the `S2FFT` package may be installed directly from GitHub by cloning this
repository and then running
Expand All @@ -101,16 +106,22 @@ repository and then running
pip install .
```

from the root directory of the repository. To enable PyTorch support you will need to run
from the root directory of the repository.

Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest

``` bash
pip install .[torch]
pip install -r requirements/requirements-tests.txt
pytest tests/
```

Unit tests can then be executed to ensure the installation was successful by running
Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/). To build the documentation locally run

``` bash
pytest tests/
pip install -r requirements/requirements-docs.txt
cd docs
make html
open _build/html/index.html
```

> [!NOTE]
Expand Down Expand Up @@ -143,7 +154,29 @@ For further details on usage see the [documentation](https://astro-informatics.g
> [!NOTE]
> We also provide PyTorch support for the precompute version of our transforms. These are called through forward/inverse_torch(). Full PyTorch support will be provided in future releases.

## Benchmarking :hourglass_flowing_sand:
## C/C++ JAX Frontends for SSHT/HEALPix :bulb:

`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
by wrapping python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.

For example, one may call these alternate backends for the spherical harmonic transform by:

``` python
# Forward SSHT spherical harmonic transform
flm = s2fft.forward(f, L, sampling=["mw"], method="jax_ssht")

# Forward HEALPix spherical harmonic transform
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
```

All of these JAX frontends supports out of the box reverse mode automatic differentiation,
and under the hood is simply linking to the C/C++ packages you are familiar with. In this
way `S2fft` enhances existing packages with gradient functionality for modern scientific computing or machine learning
applications!

For further details on usage see the associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_SSHT_backend.html).

<!-- ## Benchmarking :hourglass_flowing_sand:

We benchmarked the spherical harmonic and Wigner transforms implemented
in `S2FFT` against the C implementations in the
Expand All @@ -167,7 +200,7 @@ that scale linearly with spin).
| 8192 | 82 s | 110.8 | 2.14E-13 | N/A | N/A | N/A | N/A |

where the left hand results are for the recursive based algorithm and the right hand side are
our precompute implementation.
our precompute implementation. -->

## Contributors ✨

Expand Down
7 changes: 7 additions & 0 deletions docs/api/transforms/c_backend_spherical.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:html_theme.sidebar_secondary.remove:

**************************
C/C++ custom JAX support
**************************
.. automodule:: s2fft.transforms.c_backend_spherical
:members:
20 changes: 20 additions & 0 deletions docs/api/transforms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ Transforms
* - :func:`~s2fft.transforms.wigner.forward_jax`
- Forward Wigner transform (JAX)

.. list-table:: C/C++ backend gradient support
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2fft.transforms.c_backend_spherical.ssht_inverse`
- Custom JAX frontend for inverse SSHT C spherical harmonic library.
* - :func:`~s2fft.transforms.c_backend_spherical.ssht_forward`
- Custom JAX frontend for forward SSHT C spherical harmonic library.
* - :func:`~s2fft.transforms.c_backend_spherical.healpy_inverse`
- Custom JAX frontend for inverse HEALPix C++ spherical harmonic library.
* - :func:`~s2fft.transforms.c_backend_spherical.healpy_forward`
- Custom JAX frontend for forwardHEALPix C++ spherical harmonic library.
* - :func:`~s2fft.transforms.wigner.inverse_jax_ssht`
- Custom JAX frontend for hybrid inverse SSHT C Wigner transforms.
* - :func:`~s2fft.transforms.wigner.forward_jax_ssht`
- Custom JAX frontend for hybrid forward SSHT C Wigner transforms.

.. list-table:: On-the-fly Price-McEwen recursions.
:widths: 25 25
:header-rows: 1
Expand All @@ -64,4 +83,5 @@ Transforms
on_the_fly_recursions
spin_spherical_transform
wigner
.. c_backend_spherical

Binary file modified docs/assets/figures/spherical_sampling.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
author = "Matthew Price, Jason McEwen, Matthew Graham, Sofia Miñano, Devaraj Gopinathan"

# The short X.Y version
version = "1.0.2"
version = "1.1.0"
# The full version, including alpha/beta/rc tags
release = "1.0.2"
release = "1.1.0"


# -- General configuration ---------------------------------------------------
Expand Down Expand Up @@ -106,12 +106,12 @@
"icon": "_static/arxiv-logomark-small.png",
"type": "local",
},
# {
# "name": "YouTube",
# "url": "https://www.youtube.com/channel/UCrCOQsyQOJhOUaIYzmbkKQQ",
# "icon": "fa-brands fa-youtube fa-2x",
# "type": "fontawesome",
# },
{
"name": "Medium",
"url": "https://towardsdatascience.com/differentiable-and-accelerated-spherical-harmonic-transforms-c269393d08f1",
"icon": "fa-brands fa-medium",
"type": "fontawesome",
},
{
"name": "PyPi",
"url": "https://pypi.org/project/s2fft/",
Expand Down
20 changes: 15 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,19 @@ transforms (for both real and complex signals), with support for adjoint transfo
where needed, and comes with different optimisations (precompute or not) that one
may select depending on available resources and desired angular resolution :math:`L`.

As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.
.. important::
HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.

.. tip::
As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.

.. tip::
As of version 1.1.0 ``S2FFT`` also provides JAX support for existing C/C++ packages,
specifically ``HEALPix`` and ``SSHT``. This works by wrapping python bindings with custom
JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU.

Algorithms |:zap:|
-------------------
Expand All @@ -40,7 +50,7 @@ diagram below illustrates the separable spherical harmonic transform.
.. image:: ./assets/figures/sax_schematic_github_docs.png

.. note::
For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example `notebook <https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html>`_). Fix for GPU execution is coming soon.

Sampling |:earth_africa:|
-----------------------------------
Expand All @@ -53,7 +63,7 @@ The equiangular sampling schemes of `McEwen & Wiaux (2012) <https://arxiv.org/ab
The popular `HEALPix <https://healpix.jpl.nasa.gov>`_ sampling scheme (`Gorski et al. 2005 <https://arxiv.org/abs/astro-ph/0409513>`_) is also supported. The HEALPix sampling does not exhibit a sampling theorem and so the corresponding harmonic transforms do not achieve machine precision but exhibit some error. However, the HEALPix sampling provides pixels of equal areas, which has many practical advantages.

.. image:: ./assets/figures/spherical_sampling.png
:width: 700
:width: 900
:align: center

Contributors ✨
Expand Down
3 changes: 3 additions & 0 deletions docs/tutorials/JAX_HEALPix/JAX_HEALPix_frontend.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"path": "../../../notebooks/JAX_HEALPix_frontend.ipynb"
}
3 changes: 3 additions & 0 deletions docs/tutorials/JAX_SSHT/JAX_SSHT_frontend.nblink
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"path": "../../../notebooks/JAX_SSHT_frontend.ipynb"
}
52 changes: 21 additions & 31 deletions docs/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ in the time being feel free to contact contributors for advice! At a high-level
``S2FFT`` package is structured such that the 2 primary transforms, the Wigner and
spherical harmonic transforms, can easily be accessed.

Usage |:rocket:|
Core usage |:rocket:|
-----------------
To import and use ``S2FFT`` is as simple follows:

Expand All @@ -25,39 +25,27 @@ To import and use ``S2FFT`` is as simple follows:
| f = s2fft.inverse_jax(flm, L) | f = s2fft.wigner.inverse_jax(flmn, L, N) |
+-------------------------------------------------------+------------------------------------------------------------+

C/C++ backend usage |:bulb:|
-----------------
``S2FFT`` also provides JAX support for existing C/C++ packages, specifically `HEALPix <https://healpix.jpl.nasa.gov>`_
and `SSHT <https://github.com/astro-informatics/ssht>`_. This works
by wrapping python bindings with custom JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
limited to CPU, however for many applications this is desirable due to memory constraints.

For example, one may call these alternate backends for the spherical harmonic transform by:

.. code-block:: python

Benchmarking |:hourglass_flowing_sand:|
-------------------------------------
We benchmarked the spherical harmonic and Wigner transforms implemented in ``S2FFT``
against the C implementations in the `SSHT <https://github.com/astro-informatics/ssht>`_
pacakge.
# Forward SSHT spherical harmonic transform
flm = s2fft.forward(f, L, sampling=["mw"], method="jax_ssht")

A brief summary is shown in the table below for the recursion (left) and precompute
(right) algorithms, with ``S2FFT`` running on GPUs (for further details see Price &
McEwen, in prep.). Note that our compute time is agnostic to spin number (which is not
the case for many other methods that scale linearly with spin).
# Forward HEALPix spherical harmonic transform
flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")

+------+-----------+-----------+----------+-----------+----------+----------+---------+
| | Recursive Algorithm | Precompute Algorithm |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| L | Wall-Time | Speed-up | Error | Wall-Time | Speed-up | Error | Memory |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 64 | 3.6 ms | 0.88 | 1.81E-15 | 52.4 μs | 60.5 | 1.67E-15 | 4.2 MB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 128 | 7.26 ms | 1.80 | 3.32E-15 | 162 μs | 80.5 | 3.64E-15 | 33 MB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 256 | 17.3 ms | 6.32 | 6.66E-15 | 669 μs | 163 | 6.74E-15 | 268 MB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 512 | 58.3 ms | 11.4 | 1.43E-14 | 3.6 ms | 184 | 1.37E-14 | 2.14 GB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 1024 | 194 ms | 32.9 | 2.69E-14 | 32.6 ms | 195 | 2.47E-14 | 17.1 GB |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 2048 | 1.44 s | 49.7 | 5.17E-14 | N/A | N/A | N/A | N/A |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 4096 | 8.48 s | 133.9 | 1.06E-13 | N/A | N/A | N/A | N/A |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
| 8192 | 82 s | 110.8 | 2.14E-13 | N/A | N/A | N/A | N/A |
+------+-----------+-----------+----------+-----------+----------+----------+---------+
All of these JAX frontends supports out of the box reverse mode automatic differentiation,
and under the hood is simply linking to the C/C++ packages you are familiar with. In this
way ``S2FFT`` enhances existing packages with gradient functionality for modern signal processing
applications!


.. toctree::
Expand All @@ -69,3 +57,5 @@ the case for many other methods that scale linearly with spin).
wigner/wigner_transform.nblink
rotation/rotation.nblink
torch_frontend/torch_frontend.nblink
JAX_SSHT/JAX_SSHT_frontend.nblink
JAX_HEALPix/JAX_HEALPix_frontend.nblink
Loading