diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index f9d9346f..b2f8c93e 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -29,7 +29,7 @@ jobs:
pip install jaxlib
pip install -r requirements/requirements-core.txt
pip install -r requirements/requirements-docs.txt
- pip install .\[torch\]
+ pip install .
- name: Build Documentation
run: |
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index e9d4acbf..f2c2ebc7 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -30,7 +30,7 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements/requirements-tests.txt
pip install -r requirements/requirements-core.txt
- pip install .\[torch\]
+ pip install .
- name: Run tests
run: |
diff --git a/.pip_readme.rst b/.pip_readme.rst
index d973c4dd..22458fa9 100644
--- a/.pip_readme.rst
+++ b/.pip_readme.rst
@@ -11,7 +11,7 @@
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
.. image:: https://colab.research.google.com/assets/colab-badge.svg
- :target: https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing
+ :target: https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooksspherical_harmonic_transform.ipynb
Differentiable and accelerated spherical transforms
=================================================================================================================
@@ -31,6 +31,12 @@ As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.
+As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
+specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
+JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
+limited to CPU, however for many applications this is desirable due to memory
+constraints.
+
Documentation
=============
Read the full documentation `here `_.
diff --git a/README.md b/README.md
index a1fed65f..8cc86547 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[](https://badge.fury.io/py/s2fft)
[](https://arxiv.org/abs/2311.14670)
[](#contributors-)
-[](https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing)
+[](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/spherical_harmonic_transform.ipynb)
@@ -22,10 +22,20 @@ for adjoint transformations where needed, and comes with different
optimisations (precompute or not) that one may select depending on
available resources and desired angular resolution $L$.
+> [!IMPORTANT]
+> HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.
+
+> [!TIP]
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
precompute transforms. In future releases this support will be extended to our
on-the-fly algorithms.
+> [!TIP]
+As of version 1.1.0 `S2FFT` also provides JAX support for existing C/C++ packages,
+specifically `HEALPix` and `SSHT`. This works by wrapping python bindings with custom
+JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
+limited to CPU.
+
## Algorithms :zap:
`S2FFT` leverages new algorithmic structures that can he highly
@@ -53,7 +63,7 @@ diagram below illustrates the separable spherical harmonic transform
## Sampling :earth_africa:
The structure of the algorithms implemented in `S2FFT` can support any
-isolattitude sampling scheme. A number of sampling schemes are currently
+isolatitude sampling scheme. A number of sampling schemes are currently
supported.
The equiangular sampling schemes of [McEwen & Wiaux
@@ -73,10 +83,10 @@ so the corresponding harmonic transforms do not achieve machine
precision but exhibit some error. However, the HEALPix sampling provides
pixels of equal areas, which has many practical advantages.
-

+
> [!NOTE]
-> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
+> For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example [notebook](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_HEALPix_backend.html)). Fix for GPU execution is coming soon.
## Installation :computer:
@@ -87,12 +97,7 @@ into the active python environment by [pip](https://pypi.org) when running
``` bash
pip install s2fft
```
-This will install all core functionality which includes JAX support. To install `S2FFT`
-with PyTorch support run
-
-``` bash
-pip install s2fft[torch]
-```
+This will install all core functionality which includes JAX support (including PyTorch support).
Alternatively, the `S2FFT` package may be installed directly from GitHub by cloning this
repository and then running
@@ -101,16 +106,22 @@ repository and then running
pip install .
```
-from the root directory of the repository. To enable PyTorch support you will need to run
+from the root directory of the repository.
+
+Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest
``` bash
-pip install .[torch]
+pip install -r requirements/requirements-tests.txt
+pytest tests/
```
-Unit tests can then be executed to ensure the installation was successful by running
+Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/). To build the documentation locally run
``` bash
-pytest tests/
+pip install -r requirements/requirements-docs.txt
+cd docs
+make html
+open _build/html/index.html
```
> [!NOTE]
@@ -143,7 +154,29 @@ For further details on usage see the [documentation](https://astro-informatics.g
> [!NOTE]
> We also provide PyTorch support for the precompute version of our transforms. These are called through forward/inverse_torch(). Full PyTorch support will be provided in future releases.
-## Benchmarking :hourglass_flowing_sand:
+## C/C++ JAX Frontends for SSHT/HEALPix :bulb:
+
+`S2FFT` also provides JAX support for existing C/C++ packages, specifically [`HEALPix`](https://healpix.jpl.nasa.gov) and [`SSHT`](https://github.com/astro-informatics/ssht). This works
+by wrapping python bindings with custom JAX frontends. Note that this C/C++ to JAX interoperability is currently limited to CPU.
+
+For example, one may call these alternate backends for the spherical harmonic transform by:
+
+``` python
+# Forward SSHT spherical harmonic transform
+flm = s2fft.forward(f, L, sampling=["mw"], method="jax_ssht")
+
+# Forward HEALPix spherical harmonic transform
+flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
+```
+
+All of these JAX frontends supports out of the box reverse mode automatic differentiation,
+and under the hood is simply linking to the C/C++ packages you are familiar with. In this
+way `S2fft` enhances existing packages with gradient functionality for modern scientific computing or machine learning
+applications!
+
+For further details on usage see the associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/JAX_SSHT_backend.html).
+
+
## Contributors ✨
diff --git a/docs/api/transforms/c_backend_spherical.rst b/docs/api/transforms/c_backend_spherical.rst
new file mode 100644
index 00000000..a0b66f28
--- /dev/null
+++ b/docs/api/transforms/c_backend_spherical.rst
@@ -0,0 +1,7 @@
+:html_theme.sidebar_secondary.remove:
+
+**************************
+C/C++ custom JAX support
+**************************
+.. automodule:: s2fft.transforms.c_backend_spherical
+ :members:
\ No newline at end of file
diff --git a/docs/api/transforms/index.rst b/docs/api/transforms/index.rst
index 28e63c8b..08b6165f 100644
--- a/docs/api/transforms/index.rst
+++ b/docs/api/transforms/index.rst
@@ -41,6 +41,25 @@ Transforms
* - :func:`~s2fft.transforms.wigner.forward_jax`
- Forward Wigner transform (JAX)
+.. list-table:: C/C++ backend gradient support
+ :widths: 25 25
+ :header-rows: 1
+
+ * - Function Name
+ - Description
+ * - :func:`~s2fft.transforms.c_backend_spherical.ssht_inverse`
+ - Custom JAX frontend for inverse SSHT C spherical harmonic library.
+ * - :func:`~s2fft.transforms.c_backend_spherical.ssht_forward`
+ - Custom JAX frontend for forward SSHT C spherical harmonic library.
+ * - :func:`~s2fft.transforms.c_backend_spherical.healpy_inverse`
+ - Custom JAX frontend for inverse HEALPix C++ spherical harmonic library.
+ * - :func:`~s2fft.transforms.c_backend_spherical.healpy_forward`
+ - Custom JAX frontend for forwardHEALPix C++ spherical harmonic library.
+ * - :func:`~s2fft.transforms.wigner.inverse_jax_ssht`
+ - Custom JAX frontend for hybrid inverse SSHT C Wigner transforms.
+ * - :func:`~s2fft.transforms.wigner.forward_jax_ssht`
+ - Custom JAX frontend for hybrid forward SSHT C Wigner transforms.
+
.. list-table:: On-the-fly Price-McEwen recursions.
:widths: 25 25
:header-rows: 1
@@ -64,4 +83,5 @@ Transforms
on_the_fly_recursions
spin_spherical_transform
wigner
+ .. c_backend_spherical
diff --git a/docs/assets/figures/spherical_sampling.png b/docs/assets/figures/spherical_sampling.png
index ded84d74..2c986954 100644
Binary files a/docs/assets/figures/spherical_sampling.png and b/docs/assets/figures/spherical_sampling.png differ
diff --git a/docs/conf.py b/docs/conf.py
index a3dedda5..c63110d4 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -25,9 +25,9 @@
author = "Matthew Price, Jason McEwen, Matthew Graham, Sofia Miñano, Devaraj Gopinathan"
# The short X.Y version
-version = "1.0.2"
+version = "1.1.0"
# The full version, including alpha/beta/rc tags
-release = "1.0.2"
+release = "1.1.0"
# -- General configuration ---------------------------------------------------
@@ -106,12 +106,12 @@
"icon": "_static/arxiv-logomark-small.png",
"type": "local",
},
- # {
- # "name": "YouTube",
- # "url": "https://www.youtube.com/channel/UCrCOQsyQOJhOUaIYzmbkKQQ",
- # "icon": "fa-brands fa-youtube fa-2x",
- # "type": "fontawesome",
- # },
+ {
+ "name": "Medium",
+ "url": "https://towardsdatascience.com/differentiable-and-accelerated-spherical-harmonic-transforms-c269393d08f1",
+ "icon": "fa-brands fa-medium",
+ "type": "fontawesome",
+ },
{
"name": "PyPi",
"url": "https://pypi.org/project/s2fft/",
diff --git a/docs/index.rst b/docs/index.rst
index c7b072fd..d679f9cc 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -11,9 +11,19 @@ transforms (for both real and complex signals), with support for adjoint transfo
where needed, and comes with different optimisations (precompute or not) that one
may select depending on available resources and desired angular resolution :math:`L`.
-As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
-precompute transforms. In future releases this support will be extended to our
-on-the-fly algorithms.
+.. important::
+ HEALPix long JIT compile time fixed for CPU! Fix for GPU coming soon.
+
+.. tip::
+ As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
+ precompute transforms. In future releases this support will be extended to our
+ on-the-fly algorithms.
+
+.. tip::
+ As of version 1.1.0 ``S2FFT`` also provides JAX support for existing C/C++ packages,
+ specifically ``HEALPix`` and ``SSHT``. This works by wrapping python bindings with custom
+ JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
+ limited to CPU.
Algorithms |:zap:|
-------------------
@@ -40,7 +50,7 @@ diagram below illustrates the separable spherical harmonic transform.
.. image:: ./assets/figures/sax_schematic_github_docs.png
.. note::
- For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and will work to improve this in subsequent versions.
+ For algorithmic reasons JIT compilation of HEALPix transforms can become slow at high bandlimits, due to XLA unfolding of loops which currently cannot be avoided. After compiling HEALPix transforms should execute with the efficiency outlined in the associated paper, therefore this additional time overhead need only be incurred once. We are aware of this issue and are working to fix it. A fix for CPU execution has now been implemented (see example `notebook `_). Fix for GPU execution is coming soon.
Sampling |:earth_africa:|
-----------------------------------
@@ -53,7 +63,7 @@ The equiangular sampling schemes of `McEwen & Wiaux (2012) `_ sampling scheme (`Gorski et al. 2005 `_) is also supported. The HEALPix sampling does not exhibit a sampling theorem and so the corresponding harmonic transforms do not achieve machine precision but exhibit some error. However, the HEALPix sampling provides pixels of equal areas, which has many practical advantages.
.. image:: ./assets/figures/spherical_sampling.png
- :width: 700
+ :width: 900
:align: center
Contributors ✨
diff --git a/docs/tutorials/JAX_HEALPix/JAX_HEALPix_frontend.nblink b/docs/tutorials/JAX_HEALPix/JAX_HEALPix_frontend.nblink
new file mode 100644
index 00000000..2a665f1e
--- /dev/null
+++ b/docs/tutorials/JAX_HEALPix/JAX_HEALPix_frontend.nblink
@@ -0,0 +1,3 @@
+{
+ "path": "../../../notebooks/JAX_HEALPix_frontend.ipynb"
+}
\ No newline at end of file
diff --git a/docs/tutorials/JAX_SSHT/JAX_SSHT_frontend.nblink b/docs/tutorials/JAX_SSHT/JAX_SSHT_frontend.nblink
new file mode 100644
index 00000000..0bbd2c1c
--- /dev/null
+++ b/docs/tutorials/JAX_SSHT/JAX_SSHT_frontend.nblink
@@ -0,0 +1,3 @@
+{
+ "path": "../../../notebooks/JAX_SSHT_frontend.ipynb"
+}
\ No newline at end of file
diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst
index 12ce7955..20702c99 100644
--- a/docs/tutorials/index.rst
+++ b/docs/tutorials/index.rst
@@ -9,7 +9,7 @@ in the time being feel free to contact contributors for advice! At a high-level
``S2FFT`` package is structured such that the 2 primary transforms, the Wigner and
spherical harmonic transforms, can easily be accessed.
-Usage |:rocket:|
+Core usage |:rocket:|
-----------------
To import and use ``S2FFT`` is as simple follows:
@@ -25,39 +25,27 @@ To import and use ``S2FFT`` is as simple follows:
| f = s2fft.inverse_jax(flm, L) | f = s2fft.wigner.inverse_jax(flmn, L, N) |
+-------------------------------------------------------+------------------------------------------------------------+
+C/C++ backend usage |:bulb:|
+-----------------
+``S2FFT`` also provides JAX support for existing C/C++ packages, specifically `HEALPix `_
+and `SSHT `_. This works
+by wrapping python bindings with custom JAX frontends. Note that currently this C/C++ to JAX interoperability is currently
+limited to CPU, however for many applications this is desirable due to memory constraints.
+
+For example, one may call these alternate backends for the spherical harmonic transform by:
+
+.. code-block:: python
-Benchmarking |:hourglass_flowing_sand:|
--------------------------------------
-We benchmarked the spherical harmonic and Wigner transforms implemented in ``S2FFT``
-against the C implementations in the `SSHT `_
-pacakge.
+ # Forward SSHT spherical harmonic transform
+ flm = s2fft.forward(f, L, sampling=["mw"], method="jax_ssht")
-A brief summary is shown in the table below for the recursion (left) and precompute
-(right) algorithms, with ``S2FFT`` running on GPUs (for further details see Price &
-McEwen, in prep.). Note that our compute time is agnostic to spin number (which is not
-the case for many other methods that scale linearly with spin).
+ # Forward HEALPix spherical harmonic transform
+ flm = s2fft.forward(f, L, nside=nside, sampling="healpix", method="jax_healpy")
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| | Recursive Algorithm | Precompute Algorithm |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| L | Wall-Time | Speed-up | Error | Wall-Time | Speed-up | Error | Memory |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 64 | 3.6 ms | 0.88 | 1.81E-15 | 52.4 μs | 60.5 | 1.67E-15 | 4.2 MB |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 128 | 7.26 ms | 1.80 | 3.32E-15 | 162 μs | 80.5 | 3.64E-15 | 33 MB |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 256 | 17.3 ms | 6.32 | 6.66E-15 | 669 μs | 163 | 6.74E-15 | 268 MB |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 512 | 58.3 ms | 11.4 | 1.43E-14 | 3.6 ms | 184 | 1.37E-14 | 2.14 GB |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 1024 | 194 ms | 32.9 | 2.69E-14 | 32.6 ms | 195 | 2.47E-14 | 17.1 GB |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 2048 | 1.44 s | 49.7 | 5.17E-14 | N/A | N/A | N/A | N/A |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 4096 | 8.48 s | 133.9 | 1.06E-13 | N/A | N/A | N/A | N/A |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
-| 8192 | 82 s | 110.8 | 2.14E-13 | N/A | N/A | N/A | N/A |
-+------+-----------+-----------+----------+-----------+----------+----------+---------+
+All of these JAX frontends supports out of the box reverse mode automatic differentiation,
+and under the hood is simply linking to the C/C++ packages you are familiar with. In this
+way ``S2FFT`` enhances existing packages with gradient functionality for modern signal processing
+applications!
.. toctree::
@@ -69,3 +57,5 @@ the case for many other methods that scale linearly with spin).
wigner/wigner_transform.nblink
rotation/rotation.nblink
torch_frontend/torch_frontend.nblink
+ JAX_SSHT/JAX_SSHT_frontend.nblink
+ JAX_HEALPix/JAX_HEALPix_frontend.nblink
diff --git a/docs/user_guide/install.rst b/docs/user_guide/install.rst
index 10d15695..7dad3ef5 100644
--- a/docs/user_guide/install.rst
+++ b/docs/user_guide/install.rst
@@ -19,11 +19,7 @@ from PyPi by running
pip install s2fft
after which ``S2FFT`` may be imported and run as outlined in the associated notebooks and collab tutorials.
-To install the PyTorch functionality you will need to install the subpackage by running
-
-.. code-block:: bash
-
- pip install s2fft[torch]
+This will include PyTorch functionality.
Install from source (GitHub)
----------------------------
@@ -44,26 +40,25 @@ and pip installing locally
cd s2fft
pip install .
-from the root directory of the repository. To install the Pytorch support you will need to
-install the subpackage by running
+from the root directory of the repository. This will include PyTorch functionality.
-.. code-block:: bash
-
- pip install .[torch]
+Unit tests can then be executed to ensure the installation was successful by first installing the test requirements and then running pytest
-which, depending on operating system, can sometimes be
+.. code-block:: bash
-.. code-block:: bash
+ pip install -r requirements/requirements-tests.txt
+ pytest tests/
- pip install .\[torch\]
-
-Unit tests can then be executed to ensure the installation was successful by running
+Documentation for the released version is available `here `_. To build the documentation locally run
.. code-block:: bash
- pytest tests/
+ pip install -r requirements/requirements-docs.txt
+ cd docs
+ make html
+ open _build/html/index.html
+
-In the very near future one will be able to install ``S2FFT`` directly from `PyPi` by ``pip install s2fft`` but this is not yet supported.
Installing JAX for NVIDIA GPUs
------------------------------
diff --git a/notebooks/JAX_HEALPix_frontend.ipynb b/notebooks/JAX_HEALPix_frontend.ipynb
new file mode 100644
index 00000000..ea4b6d66
--- /dev/null
+++ b/notebooks/JAX_HEALPix_frontend.ipynb
@@ -0,0 +1,197 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# __JAX HEALPix frontend__\n",
+ "---\n",
+ "\n",
+ "[](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_HEALPix_frontend.ipyn)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install s2fft\n",
+ "!pip install s2fft &> /dev/null"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This short tutorial demonstrates how to use the custom JAX frontend support `S2FFT` provides for the [`HEALPix`](https://healpix.jpl.nasa.gov) C++ library. This solves the long JIT compile time for HEALPix when running on CPU.\n",
+ "\n",
+ "As with the other introductions, let's import some packages and define an arbitrary bandlimited signal to work with."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "jax.config.update(\"jax_enable_x64\", True)\n",
+ "\n",
+ "import numpy as np\n",
+ "import s2fft \n",
+ "\n",
+ "L = 1024\n",
+ "nside = 512\n",
+ "method = \"jax_healpy\"\n",
+ "sampling = \"healpix\"\n",
+ "flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
+ "f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Calling forward HEALPix C++ function from JAX.\n",
+ "\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flm = s2fft.forward(f, L, nside=nside, sampling=sampling, method=method)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Calling inverse HEALPix C++ function from JAX.\n",
+ "\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_recov = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Computing the roundtrip error\n",
+ "\n",
+ "---\n",
+ "\n",
+ "Let's check the associated error, which should be around 1e-5 for healpix, which is not an exact sampling of the sphere. Note that increasing `iters` will reduce the numerical error here slightly, at the cost of linearly increased compute."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mean absolute error = 2.5921182352491347e-06\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Differentiating through HEALPix C++ functions.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "So far all this is doing is providing an interface between `JAX` and `HEALPix`, the real novelty comes when we differentiate through the C++ library."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define an arbitrary JAX function\n",
+ "def differentiable_test(flm) -> int:\n",
+ " f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)\n",
+ " return jax.numpy.nanmean(jax.numpy.abs(f)**2)\n",
+ "\n",
+ "# Create the JAX reverse mode gradient function\n",
+ "gradient_func = jax.grad(differentiable_test)\n",
+ "\n",
+ "# Compute the gradient automatically\n",
+ "gradient = gradient_func(flm)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Validating these gradients\n",
+ "\n",
+ "---\n",
+ "This is all well and good, but how do we know these gradients are correct? Thankfully `JAX` prvoides a simple function to check this..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax.test_util import check_grads\n",
+ "check_grads(differentiable_test, (flm,), order=1, modes=(\"rev\"))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.10.4 ('s2fft')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/JAX_SSHT_frontend.ipynb b/notebooks/JAX_SSHT_frontend.ipynb
new file mode 100644
index 00000000..ec46c0b4
--- /dev/null
+++ b/notebooks/JAX_SSHT_frontend.ipynb
@@ -0,0 +1,195 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# __JAX SSHT frontend__\n",
+ "---\n",
+ "\n",
+ "[](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_SSHT_frontend.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install s2fft\n",
+ "!pip install s2fft &> /dev/null"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This short tutorial demonstrates how to use the custom JAX frontend support `S2FFT` provides for the [`SSHT`](https://github.com/astro-informatics/ssht) C library.\n",
+ "\n",
+ "As with the other introductions, let's import some packages and define an arbitrary bandlimited signal to work with."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "jax.config.update(\"jax_enable_x64\", True)\n",
+ "\n",
+ "import numpy as np\n",
+ "import s2fft \n",
+ "\n",
+ "L = 1024\n",
+ "method = \"jax_ssht\"\n",
+ "flm = np.random.randn(L, 2*L-1) + 1j*np.random.randn(L, 2*L-1)\n",
+ "f = s2fft.inverse(flm, L, method=method)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Calling forward SSHT C function from JAX.\n",
+ "\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flm = s2fft.forward(f, L, method=method)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Calling inverse SSHT C function from JAX.\n",
+ "\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f_recov = s2fft.inverse(flm, L, method=method)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Computing the roundtrip error\n",
+ "\n",
+ "---\n",
+ "\n",
+ "Let's check the associated error, which should be close to machine precision for the sampling scheme used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Mean absolute error = 4.909423754134027e-11\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Differentiating through SSHT C functions.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "So far all this is doing is providing an interface between `JAX` and `SSHT`, the real novelty comes when we differentiate through the C library."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define an arbitrary JAX function\n",
+ "def differentiable_test(flm) -> int:\n",
+ " f = s2fft.inverse(flm, L, method=method)\n",
+ " return jax.numpy.nanmean(jax.numpy.abs(f)**2)\n",
+ "\n",
+ "# Create the JAX reverse mode gradient function\n",
+ "gradient_func = jax.grad(differentiable_test)\n",
+ "\n",
+ "# Compute the gradient automatically\n",
+ "gradient = gradient_func(flm)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Validating these gradients\n",
+ "\n",
+ "---\n",
+ "This is all well and good, but how do we know these gradients are correct? Thankfully `JAX` prvoides a simple function to check this..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax.test_util import check_grads\n",
+ "check_grads(differentiable_test, (flm,), order=1, modes=(\"rev\"))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.10.4 ('s2fft')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.8"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/custom_gradients.ipynb b/notebooks/custom_gradients.ipynb
deleted file mode 100644
index 41ec883a..00000000
--- a/notebooks/custom_gradients.ipynb
+++ /dev/null
@@ -1,111 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "cpu\n"
- ]
- }
- ],
- "source": [
- "# Specify CUDA device\n",
- "import os\n",
- "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
- "os.environ['JAX_CHECK_TRACER_LEAKS'] = 'True'\n",
- "\n",
- "import jax\n",
- "jax.config.update(\"jax_enable_x64\", True)\n",
- "\n",
- "# Check we're running on GPU\n",
- "from jax.lib import xla_bridge\n",
- "print(xla_bridge.get_backend().platform)\n",
- "\n",
- "from jax import jit, grad \n",
- "import jax.numpy as jnp \n",
- "from jax.test_util import check_grads\n",
- "import numpy as np \n",
- "\n",
- "import s2fft "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "L = 16\n",
- "sampling = \"mw\"\n",
- "np.random.seed(1911851)\n",
- "f_target = np.random.randn(2*L, 2*L-1)+1j*np.random.randn(2*L, 2*L-1)\n",
- "flm_target = s2fft.forward_jax(f_target, L, sampling=sampling)\n",
- "f_target = s2fft.inverse_jax(flm_target, L, sampling=sampling)\n",
- "precomps = s2fft.generate_precomputes_jax(L, forward=True, sampling=sampling)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "np.random.seed(130672510)\n",
- "f = np.random.randn(2*L, 2*L-1) + 1j*np.random.randn(2*L, 2*L-1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "def func(f):\n",
- " flm = s2fft.forward_jax(f, L, reality=False, precomps=precomps,sampling=sampling)\n",
- " return jnp.sum(jnp.abs(flm-flm_target)**2)\n",
- "grad_func = grad(func)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "check_grads(func, (f,), order=1, modes=('rev'))"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.0 ('s2fft')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.4"
- },
- "orig_nbformat": 4,
- "vscode": {
- "interpreter": {
- "hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/notebooks/spherical_harmonic_transform.ipynb b/notebooks/spherical_harmonic_transform.ipynb
index d885addf..73d898e2 100644
--- a/notebooks/spherical_harmonic_transform.ipynb
+++ b/notebooks/spherical_harmonic_transform.ipynb
@@ -4,10 +4,20 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# __Spherical harmonic transform__\n",
- "\n",
+ "# __Spherical harmonic transform__ \n",
"---\n",
- "\n"
+ "\n",
+ "[](https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/spherical_harmonic_transform.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install s2fft\n",
+ "!pip install s2fft &> /dev/null"
]
},
{
@@ -15,11 +25,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "This tutorial demonstrates how to use `S2FFT` to compute spherical harmonic transforms.\n",
- "\n",
- "Specifically, we will adopt the sampling scheme of [McEwen & Wiaux (2012)](https://arxiv.org/abs/1110.6298). \n",
+ "This tutorial demonstrates how to use `S2FFT` to compute spherical harmonic transforms. Specifically, we will adopt the sampling scheme of [McEwen & Wiaux (2012)](https://arxiv.org/abs/1110.6298). \n",
"\n",
- "First let's load an input signal that is sampled on the sphere with this sampling scheme. We'll consider the Galactic plane map captured by ESA's [Gaia satellite](https://sci.esa.int/web/gaia)!"
+ "First let's load an input signal that is sampled on the sphere with this sampling scheme."
]
},
{
@@ -33,43 +41,11 @@
"\n",
"import numpy as np\n",
"import s2fft \n",
- "import plotting_functions\n",
"\n",
- "L = 1000\n",
+ "L = 256\n",
"sampling = \"mw\"\n",
- "f = np.load('data/Gaia_EDR3_flux.npy')"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now, lets take a look at the data on the sphere using [PyVista](https://docs.pyvista.org/index.html). "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "e24324bbdc364d63ae9f34f68d19aba4",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Widget(value=\"