Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions notebooks/custom_gradients.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
"os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
"os.environ['JAX_CHECK_TRACER_LEAKS'] = 'True'\n",
"\n",
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"# Check we're running on GPU\n",
"from jax.lib import xla_bridge\n",
"print(xla_bridge.get_backend().platform)\n",
"\n",
"import jax\n",
"from jax import jit, grad \n",
"import jax.numpy as jnp \n",
"from jax.test_util import check_grads\n",
Expand Down Expand Up @@ -98,7 +97,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
10 changes: 5 additions & 5 deletions notebooks/spherical_harmonic_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import numpy as np\n",
"import s2fft \n",
Expand Down Expand Up @@ -199,7 +199,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.16 64-bit ('s2fft')",
"display_name": "Python 3.10.4 ('s2fft')",
"language": "python",
"name": "python3"
},
Expand All @@ -213,12 +213,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "d6019e21eb0d27eebd69283f1089b8b605b46cb058a452b887458f3af7017e46"
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
}
}
},
Expand Down
4 changes: 2 additions & 2 deletions notebooks/spherical_rotation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import numpy as np\n",
"import s2fft \n",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/wigner_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
"metadata": {},
"outputs": [],
"source": [
"from jax.config import config\n",
"config.update(\"jax_enable_x64\", True) \n",
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import numpy as np\n",
"import s2fft \n",
Expand Down
4 changes: 2 additions & 2 deletions s2fft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from .utils.rotation import rotate_flms, generate_rotate_dls

import logging
from jax.config import config
import jax

if config.read("jax_enable_x64") is False:
if jax.config.read("jax_enable_x64") is False:
logger = logging.getLogger("s2fft")
logger.warning(
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L."
Expand Down
4 changes: 2 additions & 2 deletions s2fft/precompute_transforms/construct.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

import numpy as np
import jax.numpy as jnp
Expand Down
6 changes: 3 additions & 3 deletions s2fft/recursions/risbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
r"""Compute Wigner-d at argument :math:`\beta` for full plane using
Risbo recursion.

The Wigner-d plane is computed by recursion over :math:`\ell` (`el`).
The Wigner-d plane is computed by recursion over :math:`\ell`.
Thus, for :math:`\ell > 0` the plane must be computed already for
:math:`\ell - 1`. At present, for :math:`\ell = 0` the recusion is initialised.

Expand All @@ -19,7 +19,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
el (int): Spherical harmonic degree :math:`\ell`.

Returns:
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
np.ndarray: Plane of Wigner-d for :math:`\ell` and :math:`\beta`, with full plane computed.
"""

_arg_checks(dl, beta, L, el)
Expand Down Expand Up @@ -103,7 +103,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:


def _arg_checks(dl: np.ndarray, beta: float, L: int, el: int):
"""Check arguments of Risbo functions.
r"""Check arguments of Risbo functions.

Args:
dl (np.ndarray): Wigner-d plane of which to check shape.
Expand Down
4 changes: 2 additions & 2 deletions s2fft/transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def f_bwd(res, gtm):
ftm = ftm.at[:, m_start_ind + m_offset :].multiply(phase_shifts)

# Perform longitundal Fast Fourier Transforms
ftm *= (-1) ** spin
ftm *= (-1) ** jnp.abs(spin)
if reality:
ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
Expand Down Expand Up @@ -657,4 +657,4 @@ def f_bwd(res, glm):
indices = jnp.repeat(jnp.expand_dims(jnp.arange(L), -1), 2 * L - 1, axis=-1)
flm = jnp.where(indices < abs(spin), jnp.zeros_like(flm), flm[..., :])

return flm * (-1) ** spin
return flm * (-1) ** jnp.abs(spin)
2 changes: 1 addition & 1 deletion s2fft/transforms/wigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def inverse_jax(
def spherical_loop(n, args):
fban, flmn, lrenorm, vsign, spins = args
fban = fban.at[n].add(
(-1) ** spins[n]
(-1) ** jnp.abs(spins[n])
* s2fft.inverse_jax(
flmn[n],
L,
Expand Down
2 changes: 1 addition & 1 deletion s2fft/utils/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def rotate_flms(
dl = (
dl_array
if dl_array != None
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.complex128)
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.float64)
)

# Perform rotation
Expand Down
7 changes: 3 additions & 4 deletions tests/test_healpix_ffts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import healpy as hp
import pytest
from jax import config
import jax

jax.config.update("jax_enable_x64", True)
from s2fft.sampling import s2_samples as samples
from s2fft.utils.healpix_ffts import (
healpix_fft_jax,
Expand All @@ -11,9 +13,6 @@
)


config.update("jax_enable_x64", True)


nside_to_test = [4, 5]
reality_to_test = [False, True]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_spherical_custom_grads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import pytest
import jax.numpy as jnp
from jax.test_util import check_grads
Expand Down
4 changes: 2 additions & 2 deletions tests/test_spherical_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import pytest
import pyssht as ssht
import numpy as np
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax.config import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import pytest
import pyssht as ssht
import numpy as np
Expand Down
4 changes: 2 additions & 2 deletions tests/test_wigner_custom_grads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import pytest
import jax.numpy as jnp
from jax.test_util import check_grads
Expand Down
1 change: 0 additions & 1 deletion tests/test_wigner_precompute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
import numpy as np
import s2fft
from s2fft.precompute_transforms.wigner import inverse, forward
from s2fft.precompute_transforms.construct import wigner_kernel
from s2fft.base_transforms import wigner as base
Expand Down
4 changes: 2 additions & 2 deletions tests/test_wigner_recursions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax.config import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import pytest
import numpy as np
import jax.numpy as jnp
Expand Down
4 changes: 2 additions & 2 deletions tests/test_wigner_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from jax import config
import jax

config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)
import pytest
import numpy as np

Expand Down