Skip to content

Commit

Permalink
ENH update to latest JAX version and slight refactor in test suite (#94)
Browse files Browse the repository at this point in the history
* change `_wraps` to `implements` for compatibility with new jax version

* update `jax.config` usage due to deprecation

* black formatting

* pin tensorflow-probability

* REf make sure to use full import path

* Update setup.py

* Update bessel.py

* TST fix test for new error string

* TST fix more tests

* TST fix more tests

* TST fix tests for api

* TST simplify a bit

* ENH update submodule

* TST run specific test

* TST update submodule again

* TST now run only test_fft

* TST update to latest tests

* TST run all tests

* TST make GHA config cleaner

* TST make tests more robust

* TST add colors and workflow dipsatch to running tests

---------

Co-authored-by: Matthew R. Becker <beckermr@users.noreply.github.com>
Co-authored-by: beckermr <becker.mr@gmail.com>
  • Loading branch information
3 people committed Jun 5, 2024
1 parent 4b12d6b commit 3292342
Show file tree
Hide file tree
Showing 37 changed files with 323 additions and 304 deletions.
31 changes: 16 additions & 15 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
@@ -1,45 +1,46 @@
name: Python package

env:
PY_COLORS: "1"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

on:
push:
branches:
- main
pull_request:
workflow_dispatch: null

jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install isort flake8 pytest black==23.3.0 flake8-pyproject
python -m pip install pytest pre-commit
python -m pip install .
- name: Ensure black formatting
run: |
black --check jax_galsim/ tests/ --exclude "tests/GalSim/|tests/Coord/|tests/jax/galsim/"
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 jax_galsim/ --count --exit-zero --statistics
flake8 tests/jax/ --count --exit-zero --statistics
- name: Ensure isort
- name: Run pre-commit
run: |
isort --check jax_galsim
pre-commit run --all-files --show-diff-on-failure
- name: Test with pytest
run: |
git submodule update --init --recursive
pytest --durations=0
pytest -v --durations=0
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
language: python
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
rev: 7.0.0
hooks:
- id: flake8
entry: pflake8
Expand Down
8 changes: 4 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ And that's all you need to do from now on.

JAX-GalSim follows the NumPy/SciPy format: <https://numpydoc.readthedocs.io/en/latest/format.html>

However, most JAX-GalSim function will directly inherit the documentation from the reference GalSim project. We recommend avoid copy/pasting documentation, and instead using the `_wraps` utility to automatically reuse GalSim documentation:
However, most JAX-GalSim function will directly inherit the documentation from the reference GalSim project. We recommend avoid copy/pasting documentation, and instead using the `implements` utility to automatically reuse GalSim documentation:

```python
import galsim as _galsim
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

@_wraps(_galsim.Add,
@implements(_galsim.Add,
lax_description="Does not support `ChromaticObject` at this point.")
def Add(*args, **kwargs):
return Sum(*args, **kwargs)
Expand All @@ -160,4 +160,4 @@ Note that this tool has the option of providing a `lax_description` which will b
In order to be able to use JAX transformations, we need to be able to flatten and unflatten objects. This happens within the `tree_flatten` and `tree_unflatten` methods.
The unflattening can fail to work as expected when type checks are performed in the `__init__` method of a given object. To avoid this issue, the following strategy can used:

https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
20 changes: 10 additions & 10 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
# SOFTWARE.
import galsim as _galsim
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float, ensure_hashable


@_wraps(_galsim.AngleUnit)
@implements(_galsim.AngleUnit)
@register_pytree_node_class
class AngleUnit(object):
valid_names = ["rad", "deg", "hr", "hour", "arcmin", "arcsec"]
Expand Down Expand Up @@ -61,7 +61,7 @@ def __div__(self, unit):
__truediv__ = __div__

@staticmethod
@_wraps(_galsim.AngleUnit.from_name)
@implements(_galsim.AngleUnit.from_name)
def from_name(unit):
unit = unit.strip().lower()
if unit.startswith("rad"):
Expand Down Expand Up @@ -127,7 +127,7 @@ def tree_unflatten(cls, aux_data, children):
arcsec = AngleUnit(jnp.pi / 648000.0)


@_wraps(_galsim.Angle)
@implements(_galsim.Angle)
@register_pytree_node_class
class Angle(object):
def __init__(self, theta, unit=None):
Expand Down Expand Up @@ -198,7 +198,7 @@ def __div__(self, other):

__truediv__ = __div__

@_wraps(_galsim.Angle.wrap)
@implements(_galsim.Angle.wrap)
def wrap(self, center=None):
if center is None:
center = _Angle(0.0)
Expand Down Expand Up @@ -329,15 +329,15 @@ def _make_dms_string(decimal, sep, prec, pad, plus_sign):
string = string + sep3
return string

@_wraps(_galsim.Angle.hms)
@implements(_galsim.Angle.hms)
def hms(self, sep=":", prec=None, pad=True, plus_sign=False):
if not len(sep) <= 3:
raise ValueError("sep must be a string or tuple of length <= 3")
if prec is not None and not prec >= 0:
raise ValueError("prec must be >= 0")
return self._make_dms_string(self / hours, sep, prec, pad, plus_sign)

@_wraps(_galsim.Angle.dms)
@implements(_galsim.Angle.dms)
def dms(self, sep=":", prec=None, pad=True, plus_sign=False):
if not len(sep) <= 3:
raise ValueError("sep must be a string or tuple of length <= 3")
Expand All @@ -346,12 +346,12 @@ def dms(self, sep=":", prec=None, pad=True, plus_sign=False):
return self._make_dms_string(self / degrees, sep, prec, pad, plus_sign)

@staticmethod
@_wraps(_galsim.Angle.from_hms)
@implements(_galsim.Angle.from_hms)
def from_hms(str):
return Angle._parse_dms(str) * hours

@staticmethod
@_wraps(_galsim.Angle.from_dms)
@implements(_galsim.Angle.from_dms)
def from_dms(str):
return Angle._parse_dms(str) * degrees

Expand Down Expand Up @@ -400,7 +400,7 @@ def tree_unflatten(cls, aux_data, children):
return ret


@_wraps(_galsim._Angle)
@implements(_galsim._Angle)
def _Angle(theta):
ret = Angle.__new__(Angle)
ret._rad = theta
Expand Down
8 changes: 4 additions & 4 deletions jax_galsim/bessel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from tensorflow_probability.substrates.jax.math import bessel_kve as _tfp_bessel_kve


# the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp
Expand Down Expand Up @@ -91,7 +91,7 @@ def _si_small_pade(x, x2):
# fmt: on


@_wraps(_galsim.bessel.si)
@implements(_galsim.bessel.si)
@jax.jit
def si(x):
x2 = x * x
Expand All @@ -109,4 +109,4 @@ def kv(nu, x):
"""Modified Bessel 2nd kind"""
nu = 1.0 * nu
x = 1.0 * x
return tfp.substrates.jax.math.bessel_kve(nu, x) / jnp.exp(jnp.abs(x))
return _tfp_bessel_kve(nu, x) / jnp.exp(jnp.abs(x))
12 changes: 6 additions & 6 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable
Expand All @@ -16,7 +16,7 @@


# The reason for avoid these tests is that they are not easy to do for jitted code.
@_wraps(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class Bounds(_galsim.Bounds):
def _parse_args(self, *args, **kwargs):
Expand Down Expand Up @@ -104,7 +104,7 @@ def true_center(self):
)
return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0)

@_wraps(_galsim.Bounds.includes)
@implements(_galsim.Bounds.includes)
def includes(self, *args):
if len(args) == 1:
if isinstance(args[0], Bounds):
Expand Down Expand Up @@ -138,7 +138,7 @@ def includes(self, *args):
else:
raise TypeError("include takes at most 2 arguments (%d given)" % len(args))

@_wraps(_galsim.Bounds.expand)
@implements(_galsim.Bounds.expand)
def expand(self, factor_x, factor_y=None):
if factor_y is None:
factor_y = factor_x
Expand Down Expand Up @@ -266,7 +266,7 @@ def from_galsim(cls, galsim_bounds):
return _cls()


@_wraps(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR)
@implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsD(Bounds):
_pos_class = PositionD
Expand Down Expand Up @@ -300,7 +300,7 @@ def _center(self):
return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0)


@_wraps(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR)
@implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsI(Bounds):
_pos_class = PositionI
Expand Down
12 changes: 6 additions & 6 deletions jax_galsim/box.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import galsim as _galsim
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
Expand All @@ -9,7 +9,7 @@
from jax_galsim.random import UniformDeviate


@_wraps(_galsim.Box)
@implements(_galsim.Box)
@register_pytree_node_class
class Box(GSObject):
_has_hard_edges = True
Expand Down Expand Up @@ -100,7 +100,7 @@ def _drawKImage(self, image, jac=None):
_jac = jnp.eye(2) if jac is None else jac
return draw_by_kValue(self, image, _jac)

@_wraps(_galsim.Box.withFlux)
@implements(_galsim.Box.withFlux)
def withFlux(self, flux):
return Box(
width=self.width, height=self.height, flux=flux, gsparams=self.gsparams
Expand All @@ -116,7 +116,7 @@ def tree_unflatten(cls, aux_data, children):
**aux_data
)

@_wraps(_galsim.Box._shoot)
@implements(_galsim.Box._shoot)
def _shoot(self, photons, rng):
ud = UniformDeviate(rng)

Expand All @@ -126,7 +126,7 @@ def _shoot(self, photons, rng):
photons.flux = self.flux / photons.size()


@_wraps(_galsim.Pixel)
@implements(_galsim.Pixel)
@register_pytree_node_class
class Pixel(Box):
def __init__(self, scale, flux=1.0, gsparams=None):
Expand All @@ -153,7 +153,7 @@ def __str__(self):
s += ")"
return s

@_wraps(_galsim.Pixel.withFlux)
@implements(_galsim.Pixel.withFlux)
def withFlux(self, flux):
return Pixel(scale=self.scale, flux=flux, gsparams=self.gsparams)

Expand Down
Loading

0 comments on commit 3292342

Please sign in to comment.