Skip to content

Commit

Permalink
REF use our own implements decorator (#95)
Browse files Browse the repository at this point in the history
* REF use our own implements decorator

* ENH add pre-commit

* TST remove pre-commit from tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* DOC add pre-commit badge

* BUG old syntax

* BUG remove typing eww

* BUG add this back

* DOC typos

Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 7, 2024
1 parent 3292342 commit 4cef308
Show file tree
Hide file tree
Showing 35 changed files with 320 additions and 59 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest pre-commit
python -m pip install pytest
python -m pip install .
- name: Run pre-commit
run: |
pre-commit run --all-files --show-diff-on-failure
- name: Test with pytest
run: |
git submodule update --init --recursive
Expand Down
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
ci:
autofix_commit_msg: |
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
autofix_prs: false
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: monthly
skip: []
submodules: false

repos:
- repo: https://github.com/psf/black
rev: 23.9.1
Expand Down
4 changes: 2 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ However, most JAX-GalSim function will directly inherit the documentation from t

```python
import galsim as _galsim
from jax._src.numpy.util import implements
from jax_galsim.core.utils import implements
from jax.tree_util import register_pytree_node_class

@implements(_galsim.Add,
Expand All @@ -157,7 +157,7 @@ Note that this tool has the option of providing a `lax_description` which will b

### Flattening and Unflattening of objects

In order to be able to use JAX transformations, we need to be able to flatten and unflatten objects. This happens within the `tree_flatten` and `tree_unflatten` methods.
In order to be able to use JAX transformations, we need to be able to flatten and unflatten objects. This happens within the `tree_flatten` and `tree_unflatten` methods.
The unflattening can fail to work as expected when type checks are performed in the `__init__` method of a given object. To avoid this issue, the following strategy can used:

https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

**JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.**

[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](code_of_conduct.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](code_of_conduct.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main)

**Disclaimer**: This project is still in an early development phase, **please use the [reference GalSim implementation](https://github.com/GalSim-developers/GalSim) for any scientific applications.**

Expand Down
3 changes: 1 addition & 2 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
# SOFTWARE.
import galsim as _galsim
import jax.numpy as jnp
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float, ensure_hashable
from jax_galsim.core.utils import cast_to_float, ensure_hashable, implements


@implements(_galsim.AngleUnit)
Expand Down
3 changes: 2 additions & 1 deletion jax_galsim/bessel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax._src.numpy.util import implements
from tensorflow_probability.substrates.jax.math import bessel_kve as _tfp_bessel_kve

from jax_galsim.core.utils import implements


# the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp
@jax.jit
Expand Down
8 changes: 6 additions & 2 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable
from jax_galsim.core.utils import (
cast_to_float,
cast_to_int,
ensure_hashable,
implements,
)
from jax_galsim.position import Position, PositionD, PositionI

BOUNDS_LAX_DESCR = """\
Expand Down
3 changes: 1 addition & 2 deletions jax_galsim/box.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import galsim as _galsim
import jax.numpy as jnp
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.core.utils import ensure_hashable, implements
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate

Expand Down
3 changes: 1 addition & 2 deletions jax_galsim/celestial.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.angle import Angle, _Angle, arcsec, degrees, radians
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.core.utils import ensure_hashable, implements


# we have to copy this one since JAX sends in `t` as a traced array
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/convolve.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import galsim as _galsim
import jax.numpy as jnp
from galsim.errors import galsim_warn
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import implements
from jax_galsim.gsobject import GSObject
from jax_galsim.gsparams import GSParams
from jax_galsim.photon_array import PhotonArray
Expand Down
198 changes: 198 additions & 0 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import re
import textwrap
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -252,3 +255,198 @@ def _func(i, args):
fhigh = func(high)
args = (func, low, flow, high, fhigh)
return jax.lax.fori_loop(0, niter, _func, args)[-2]


# start of code from https://github.com/google/jax/blob/main/jax/_src/numpy/util.py #
# used with modifications for galsim under the following license:
# fmt: off
#
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# fmt: on

_galsim_signature_re = re.compile(r"^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$", re.MULTILINE)
_docreference = re.compile(r":doc:`(.*?)\s*<.*?>`")


class ParsedDoc(NamedTuple):
"""
docstr: full docstring
signature: signature from docstring.
summary: summary from docstring.
front_matter: front matter before sections.
sections: dictionary of section titles to section content.
"""

docstr: str = ""
signature: str = ""
summary: str = ""
front_matter: str = ""
sections: dict[str, str] = {}


def _break_off_body_section_by_newline(body):
first_lines = []
body_lines = []
found_first_break = False
for line in body.split("\n"):
if not first_lines:
first_lines.append(line)
continue

if not line.strip() and not found_first_break:
found_first_break = True
continue

if found_first_break:
body_lines.append(line)
else:
first_lines.append(line)

firstline = "\n".join(first_lines)
body = "\n".join(body_lines)
body = textwrap.dedent(body.lstrip("\n"))

return firstline, body


def _parse_galsimdoc(docstr):
"""Parse a standard galsim-style docstring.
Args:
docstr: the raw docstring from a function
Returns:
ParsedDoc: parsed version of the docstring
"""
if docstr is None or not docstr.strip():
return ParsedDoc(docstr)

# Remove any :doc: directives in the docstring to avoid sphinx errors
docstr = _docreference.sub(lambda match: f"{match.groups()[0]}", docstr)

signature, body = "", docstr
match = _galsim_signature_re.match(body)
if match:
signature = match.group()
body = docstr[match.end() :]

firstline, body = _break_off_body_section_by_newline(body)

match = _galsim_signature_re.match(body)
if match:
signature = match.group()
body = body[match.end() :]

summary = firstline
if not summary:
summary, body = _break_off_body_section_by_newline(body)

front_matter_lines = []
body_lines = []
found_params = False
for line in body.split("\n"):
if not found_params and line.lstrip().startswith("Parameters:"):
found_params = True

if found_params:
body_lines.append(line)
else:
front_matter_lines.append(line)
front_matter = "\n".join(front_matter_lines)
body = "\n".join(body_lines)

# we add back the body for now, but keep code above if we parse params in the future
front_matter = front_matter + "\n" + body

return ParsedDoc(
docstr=docstr,
signature=signature,
summary=summary,
front_matter=front_matter,
sections={},
)


def implements(
original_fun,
lax_description="",
module=None,
):
"""Decorator for JAX functions which implement a specified GalSim function.
This mainly contains logic to copy and modify the docstring of the original
function. In particular, if `update_doc` is True, parameters listed in the
original function that are not supported by the decorated function will
be removed from the docstring. For this reason, it is important that parameter
names match those in the original GalSim function.
Parameters:
original_fun: The original function being implemented
lax_description: A string description that will be added to the beginning of
the docstring.
module: An optional string specifying the module from which the original function
is imported. This is useful for objects, where the module cannot
be determined from the original function itself.
"""

def decorator(wrapped_fun):
wrapped_fun.__galsim_wrapped__ = original_fun

# Allows this pattern: @implements(getattr(np, 'new_function', None))
if original_fun is None:
if lax_description:
wrapped_fun.__doc__ = lax_description
return wrapped_fun

docstr = getattr(original_fun, "__doc__", None)
name = getattr(
original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))
)
try:
mod = module or original_fun.__module__
except AttributeError:
pass
else:
name = f"{mod}.{name}"

if docstr:
try:
parsed = _parse_galsimdoc(docstr)

docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
if lax_description:
docstr += "\n" + lax_description.strip() + "\n"
docstr += "\n*Original docstring below.*\n"

if parsed.front_matter:
docstr += "\n" + parsed.front_matter.strip() + "\n"
except Exception:
docstr = original_fun.__doc__

wrapped_fun.__doc__ = docstr
for attr in ["__name__", "__qualname__"]:
try:
value = getattr(original_fun, attr)
except AttributeError:
pass
else:
setattr(wrapped_fun, attr, value)
return wrapped_fun

return decorator


# end of code from https://github.com/google/jax/blob/main/jax/_src/numpy/util.py #
3 changes: 1 addition & 2 deletions jax_galsim/deltafunction.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.core.utils import ensure_hashable, implements
from jax_galsim.gsobject import GSObject


Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/deprecated.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import warnings

import galsim as _galsim
from jax._src.numpy.util import implements

from jax_galsim.core.utils import implements
from jax_galsim.errors import GalSimDeprecationWarning


Expand Down
3 changes: 1 addition & 2 deletions jax_galsim/exponential.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import galsim as _galsim
import jax.numpy as jnp
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.core.utils import ensure_hashable, implements
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate
from jax_galsim.utilities import lazy_property
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
from galsim.fits import FitsHeader, closeHDUList, readFile, writeFile # noqa: F401
from galsim.utilities import galsim_warn
from jax._src.numpy.util import implements

from jax_galsim.core.utils import implements
from jax_galsim.image import Image

# We wrap the galsim FITS read functions to return jax_galsim Image objects.
Expand Down
Loading

0 comments on commit 4cef308

Please sign in to comment.