diff --git a/.copier-answers.yml b/.copier-answers.yml new file mode 100644 index 0000000..a467091 --- /dev/null +++ b/.copier-answers.yml @@ -0,0 +1,12 @@ +# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY +_commit: 2023.12.21 +_src_path: gh:scientific-python/cookie +backend: hatch +email: nstarman@users.noreply.github.com +full_name: Nathaniel Starkman +license: BSD +org: GalacticDynamics +project_name: jax-quantity +project_short_description: Quantities in JAX +url: https://github.com/GalacticDynamics/jax-quantity +vcs: true diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 0000000..8fb235d --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1,4 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ +ref-names: $Format:%D$ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..00a7b00 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +.git_archival.txt export-subst diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..6f3942a --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,101 @@ +See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed +description of best practices for developing scientific packages. + +[spc-dev-intro]: https://learn.scientific-python.org/development/ + +# Quick development + +The fastest way to start with development is to use nox. If you don't have nox, +you can use `pipx run nox` to run it without installing, or `pipx install nox`. +If you don't have pipx (pip for applications), then you can install with +`pip install pipx` (the only case were installing an application with regular +pip is reasonable). If you use macOS, then pipx and nox are both in brew, use +`brew install pipx nox`. + +To use, run `nox`. This will lint and test using every installed version of +Python on your system, skipping ones that are not installed. You can also run +specific jobs: + +```console +$ nox -s lint # Lint only +$ nox -s tests # Python tests +$ nox -s docs -- --serve # Build and serve the docs +$ nox -s build # Make an SDist and wheel +``` + +Nox handles everything for you, including setting up an temporary virtual +environment for each run. + +# Setting up a development environment manually + +You can set up a development environment by running: + +```bash +python3 -m venv .venv +source ./.venv/bin/activate +pip install -v -e .[dev] +``` + +If you have the +[Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you +can instead do: + +```bash +py -m venv .venv +py -m install -v -e .[dev] +``` + +# Post setup + +You should prepare pre-commit, which will help you by checking that commits pass +required checks: + +```bash +pip install pre-commit # or brew install pre-commit on macOS +pre-commit install # Will install a pre-commit hook into the git repo +``` + +You can also/alternatively run `pre-commit run` (changes only) or +`pre-commit run --all-files` to check even without installing the hook. + +# Testing + +Use pytest to run the unit checks: + +```bash +pytest +``` + +# Coverage + +Use pytest-cov to generate coverage reports: + +```bash +pytest --cov=jax-quantity +``` + +# Building docs + +You can build the docs using: + +```bash +nox -s docs +``` + +You can see a preview with: + +```bash +nox -s docs -- --serve +``` + +# Pre-commit + +This project uses pre-commit for all style checking. While you can run it with +nox, this is such an important tool that it deserves to be installed on its own. +Install pre-commit and run: + +```bash +pre-commit run -a +``` + +to check all files. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..6c4b369 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + groups: + actions: + patterns: + - "*" diff --git a/.github/matchers/pylint.json b/.github/matchers/pylint.json new file mode 100644 index 0000000..e3a6bd1 --- /dev/null +++ b/.github/matchers/pylint.json @@ -0,0 +1,32 @@ +{ + "problemMatcher": [ + { + "severity": "warning", + "pattern": [ + { + "regexp": "^([^:]+):(\\d+):(\\d+): ([A-DF-Z]\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", + "file": 1, + "line": 2, + "column": 3, + "code": 4, + "message": 5 + } + ], + "owner": "pylint-warning" + }, + { + "severity": "error", + "pattern": [ + { + "regexp": "^([^:]+):(\\d+):(\\d+): (E\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", + "file": 1, + "line": 2, + "column": 3, + "code": 4, + "message": 5 + } + ], + "owner": "pylint-error" + } + ] +} diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 0000000..6ea4db9 --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,52 @@ +name: CD + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + release: + types: + - published + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 3 + +jobs: + dist: + name: Distribution build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: hynek/build-and-inspect-python-package@v2 + + publish: + needs: [dist] + name: Publish to PyPI + environment: pypi + permissions: + id-token: write + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + + steps: + - uses: actions/download-artifact@v4 + with: + name: Packages + path: dist + + - uses: pypa/gh-action-pypi-publish@release/v1 + if: github.event_name == 'release' && github.event.action == 'published' + with: + # Remember to tell (test-)pypi about this repo before publishing + # Remove this line to publish to PyPI + repository-url: https://test.pypi.org/legacy/ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..c21920e --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,69 @@ +name: CI + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 3 + +jobs: + pre-commit: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - uses: pre-commit/action@v3.0.0 + with: + extra_args: --hook-stage manual --all-files + - name: Run PyLint + run: | + echo "::add-matcher::$GITHUB_WORKSPACE/.github/matchers/pylint.json" + pipx run nox -s pylint + + checks: + name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + runs-on: ${{ matrix.runs-on }} + needs: [pre-commit] + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.12"] + runs-on: [ubuntu-latest, macos-latest, windows-latest] + + include: + - python-version: pypy-3.10 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + allow-prereleases: true + + - name: Install package + run: python -m pip install .[test] + + - name: Test package + run: >- + python -m pytest -ra --cov --cov-report=xml --cov-report=term + --durations=20 + + - name: Upload coverage report + uses: codecov/codecov-action@v3.1.4 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..25cf9a4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,158 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# setuptools_scm +src/*/_version.py + + +# ruff +.ruff_cache/ + +# OS specific stuff +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Common editor files +*~ +*.swp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..dc091a6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,87 @@ +ci: + autoupdate_commit_msg: "chore: update pre-commit hooks" + autofix_commit_msg: "style: pre-commit fixes" + +repos: + - repo: https://github.com/adamchainz/blacken-docs + rev: "1.16.0" + hooks: + - id: blacken-docs + additional_dependencies: [black==23.*] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.5.0" + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: name-tests-test + args: ["--pytest-test-first"] + - id: requirements-txt-fixer + - id: trailing-whitespace + + - repo: https://github.com/pre-commit/pygrep-hooks + rev: "v1.10.0" + hooks: + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: "v3.1.0" + hooks: + - id: prettier + types_or: [yaml, markdown, html, css, scss, javascript, json] + args: [--prose-wrap=always] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.1.9" + hooks: + - id: ruff + args: ["--fix", "--show-fixes"] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: "v1.7.1" + hooks: + - id: mypy + files: src|tests + args: [] + additional_dependencies: + - pytest + - quax + + - repo: https://github.com/codespell-project/codespell + rev: "v2.2.6" + hooks: + - id: codespell + + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: "v0.9.0.6" + hooks: + - id: shellcheck + + - repo: local + hooks: + - id: disallow-caps + name: Disallow improper capitalization + language: pygrep + entry: PyBind|Numpy|Cmake|CCache|Github|PyTest + exclude: .pre-commit-config.yaml + + - repo: https://github.com/abravalheri/validate-pyproject + rev: v0.15 + hooks: + - id: validate-pyproject + + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.27.0 + hooks: + - id: check-dependabot + - id: check-github-workflows + - id: check-readthedocs diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..7e49657 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,18 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" +sphinx: + configuration: docs/conf.py + +python: + install: + - method: pip + path: . + extra_requirements: + - docs diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..48f9efb --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2023, Nathaniel Starkman. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the vector package developers nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..7644f65 --- /dev/null +++ b/README.md @@ -0,0 +1,27 @@ +# jax-quantity + +[![Actions Status][actions-badge]][actions-link] +[![Documentation Status][rtd-badge]][rtd-link] + +[![PyPI version][pypi-version]][pypi-link] +[![Conda-Forge][conda-badge]][conda-link] +[![PyPI platforms][pypi-platforms]][pypi-link] + +[![GitHub Discussion][github-discussions-badge]][github-discussions-link] + + + + +[actions-badge]: https://github.com/GalacticDynamics/jax-quantity/workflows/CI/badge.svg +[actions-link]: https://github.com/GalacticDynamics/jax-quantity/actions +[conda-badge]: https://img.shields.io/conda/vn/conda-forge/jax-quantity +[conda-link]: https://github.com/conda-forge/jax-quantity-feedstock +[github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github +[github-discussions-link]: https://github.com/GalacticDynamics/jax-quantity/discussions +[pypi-link]: https://pypi.org/project/jax-quantity/ +[pypi-platforms]: https://img.shields.io/pypi/pyversions/jax-quantity +[pypi-version]: https://img.shields.io/pypi/v/jax-quantity +[rtd-badge]: https://readthedocs.org/projects/jax-quantity/badge/?version=latest +[rtd-link]: https://jax-quantity.readthedocs.io/en/latest/?badge=latest + + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..6d92518 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,45 @@ +"""Sphinx configuration.""" + +import importlib.metadata + +project = "jax-quantity" +copyright = "2023, Nathaniel Starkman" +author = "Nathaniel Starkman" +version = release = importlib.metadata.version("jax_quantity") + +extensions = [ + "myst_parser", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx_autodoc_typehints", + "sphinx_copybutton", +] + +source_suffix = [".rst", ".md"] +exclude_patterns = [ + "_build", + "**.ipynb_checkpoints", + "Thumbs.db", + ".DS_Store", + ".env", + ".venv", +] + +html_theme = "furo" + +myst_enable_extensions = [ + "colon_fence", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), +} + +nitpick_ignore = [ + ("py:class", "_io.StringIO"), + ("py:class", "_io.BytesIO"), +] + +always_document_param_types = True diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..e62db15 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,17 @@ +# jax-quantity + +```{toctree} +:maxdepth: 2 +:hidden: + +``` + +```{include} ../README.md +:start-after: +``` + +## Indices and tables + +- {ref}`genindex` +- {ref}`modindex` +- {ref}`search` diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000..d7fff84 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,114 @@ +"""Nox sessions.""" + +import argparse +import shutil +from pathlib import Path + +import nox + +DIR = Path(__file__).parent.resolve() + +nox.options.sessions = ["lint", "pylint", "tests"] + + +@nox.session +def lint(session: nox.Session) -> None: + """Run the linter.""" + session.install("pre-commit") + session.run( + "pre-commit", + "run", + "--all-files", + "--show-diff-on-failure", + *session.posargs, + ) + + +@nox.session +def pylint(session: nox.Session) -> None: + """Run PyLint.""" + # This needs to be installed into the package environment, and is slower + # than a pre-commit check + session.install(".", "pylint") + session.run("pylint", "jax_quantity", *session.posargs) + + +@nox.session +def tests(session: nox.Session) -> None: + """Run the unit and regular tests.""" + session.install(".[test]") + session.run("pytest", *session.posargs) + + +@nox.session(reuse_venv=True) +def docs(session: nox.Session) -> None: + """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" + parser = argparse.ArgumentParser() + parser.add_argument("--serve", action="store_true", help="Serve after building") + parser.add_argument( + "-b", + dest="builder", + default="html", + help="Build target (default: html)", + ) + args, posargs = parser.parse_known_args(session.posargs) + + if args.builder != "html" and args.serve: + session.error("Must not specify non-HTML builder with --serve") + + extra_installs = ["sphinx-autobuild"] if args.serve else [] + + session.install("-e.[docs]", *extra_installs) + session.chdir("docs") + + if args.builder == "linkcheck": + session.run( + "sphinx-build", + "-b", + "linkcheck", + ".", + "_build/linkcheck", + *posargs, + ) + return + + shared_args = ( + "-n", # nitpicky mode + "-T", # full tracebacks + f"-b={args.builder}", + ".", + f"_build/{args.builder}", + *posargs, + ) + + if args.serve: + session.run("sphinx-autobuild", *shared_args) + else: + session.run("sphinx-build", "--keep-going", *shared_args) + + +@nox.session +def build_api_docs(session: nox.Session) -> None: + """Build (regenerate) API docs.""" + session.install("sphinx") + session.chdir("docs") + session.run( + "sphinx-apidoc", + "-o", + "api/", + "--module-first", + "--no-toc", + "--force", + "../src/jax_quantity", + ) + + +@nox.session +def build(session: nox.Session) -> None: + """Build an SDist and wheel.""" + build_path = DIR.joinpath("build") + if build_path.exists(): + shutil.rmtree(build_path) + + session.install("build") + session.run("python", "-m", "build") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2c7f51f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,169 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + + +[project] +name = "jax-quantity" +authors = [ + { name = "Nathaniel Starkman", email = "nstarman@users.noreply.github.com" }, +] +description = "Quantities in JAX" +readme = "README.md" +license.file = "LICENSE" +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 1 - Planning", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Typing :: Typed", +] +dynamic = ["version"] +dependencies = [ + "array_api_jax_compat @ git+https://github.com/GalacticDynamics/array-api-jax-compat.git", + "astropy", + "equinox", + "jax", + "jaxlib", + "jaxtyping", + "quax", +] + +[project.optional-dependencies] +test = [ + "pytest >=6", + "pytest-cov >=3", +] +dev = [ + "pytest >=6", + "pytest-cov >=3", +] +docs = [ + "sphinx>=7.0", + "myst_parser>=0.13", + "sphinx_copybutton", + "sphinx_autodoc_typehints", + "furo>=2023.08.17", +] + +[project.urls] +Homepage = "https://github.com/GalacticDynamics/jax-quantity" +"Bug Tracker" = "https://github.com/GalacticDynamics/jax-quantity/issues" +Discussions = "https://github.com/GalacticDynamics/jax-quantity/discussions" +Changelog = "https://github.com/GalacticDynamics/jax-quantity/releases" + + +[tool.hatch] +version.source = "vcs" +build.hooks.vcs.version-file = "src/jax_quantity/_version.py" +metadata.allow-direct-references = true + + +[tool.hatch.env.default] +features = ["test"] +scripts.test = "pytest {args}" + + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +xfail_strict = true +filterwarnings = [ + "error", +] +log_cli_level = "INFO" +testpaths = [ + "tests", +] + + +[tool.coverage] +run.source = ["jax_quantity"] +report.exclude_also = [ + '\.\.\.', + 'if typing.TYPE_CHECKING:', +] + +[tool.mypy] +files = ["src", "tests"] +python_version = "3.10" +warn_unused_configs = true +strict = true +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +warn_unreachable = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +warn_return_any = false + + [[tool.mypy.overrides]] + module = "jax_quantity.*" + disallow_untyped_defs = true + disallow_incomplete_defs = true + + [[tool.mypy.overrides]] + module = [ + "array_api_jax_compat._dispatch.*", # TODO: resolve + "astropy.units.*", + "equinox.*", + "jax.*", + "quax.*", + ] + ignore_missing_imports = true + + + +[tool.ruff] +src = ["src"] + +[tool.ruff.lint] +extend-select = ["ALL"] +ignore = [ + "A002", # Argument is shadowing a Python builtin + "ANN101", # Missing type annotation for `self` in method + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "ARG001", # Unused function argument # TODO: resolve + "D103", # Missing docstring in public function # TODO: resolve + "D105", # Missing docstring in magic method + "D203", # 1 blank line required before class docstring + "D213", # Multi-line docstring summary should start at the second line + "FIX002", # Line contains TODO + "ISC001", # Conflicts with formatter + "PLR09", # Too many <...> + "PLR2004", # Magic value used in comparison + "PYI041", # Use `complex` instead of `int | complex` <- plum is more strict + "TD002", # Missing author in TODO + "TD003", # Missing issue link on the line following this TODO +] +# Uncomment if using a _compat.typing backport +# typing-modules = ["jax_quantity._compat.typing"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F403"] +"docs/conf.py" = ["A001", "INP001"] +"tests/**" = ["ANN", "S101", "T20"] +"noxfile.py" = ["T20"] + + +[tool.pylint] +py-version = "3.10" +ignore-paths = [".*/_version.py"] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "design", + "fixme", + "line-too-long", + "missing-function-docstring", # TODO: resolve + "missing-module-docstring", + "redefined-builtin", # handled by ruff + "wrong-import-position", +] diff --git a/src/jax_quantity/__init__.py b/src/jax_quantity/__init__.py new file mode 100644 index 0000000..5a6b1a4 --- /dev/null +++ b/src/jax_quantity/__init__.py @@ -0,0 +1,15 @@ +"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved. + +jax-quantity: Quantities in JAX +""" + +from . import _core +from ._core import * +from ._version import version as __version__ + +# isort: split +# Register dispatches +from . import _register_dispatches, _register_primitives # noqa: F401 + +__all__ = ["__version__"] +__all__ += _core.__all__ diff --git a/src/jax_quantity/_core.py b/src/jax_quantity/_core.py new file mode 100644 index 0000000..18cdc70 --- /dev/null +++ b/src/jax_quantity/_core.py @@ -0,0 +1,80 @@ +# pylint: disable=no-member, unsubscriptable-object +# b/c it doesn't understand dataclass fields + +__all__ = ["Quantity", "can_convert"] + +import operator +from dataclasses import replace +from typing import Any + +import equinox as eqx +import jax +import jax.core +from astropy.units import Unit, UnitConversionError +from jaxtyping import ArrayLike +from quax import ArrayValue, quaxify +from typing_extensions import Self + + +class Quantity(ArrayValue): # type: ignore[misc] + """Represents an array, with each axis bound to a name.""" + + value: jax.Array = eqx.field(converter=jax.numpy.asarray) + units: Unit = eqx.field(static=True, converter=Unit) + + # =============================================================== + # Quax + + @property + def shape(self) -> tuple[int, ...]: + """Shape of the array.""" + return self.value.shape + + def materialise(self) -> None: + msg = "Refusing to materialise `Quantity`." + raise RuntimeError(msg) + + def aval(self) -> jax.core.ShapedArray: + return jax.core.get_aval(self.value) # type: ignore[no-untyped-call] + + def enable_materialise(self, _: bool = True) -> Self: # noqa: FBT001, FBT002 + return type(self)(self.value, self.units) + + # =============================================================== + # Quantity + + def to_units(self, units: Unit) -> "Quantity": + return type(self)(self.value * self.units.to(units), units) + + def to_units_value(self, units: Unit) -> ArrayLike: + if units == self.units: + return self.value + return self.value * self.units.to(units) + + def __getitem__(self, key: Any) -> "Quantity": + return replace(self, value=self.value[key]) + + # __add__ + # __radd__ + # __sub__ + # __rsub__ + # __mul__ + # __rmul__ + # __matmul__ + # __rmatmul__ + __and__ = quaxify(operator.__and__) + __gt__ = quaxify(operator.__gt__) + __ge__ = quaxify(operator.__ge__) + __lt__ = quaxify(operator.__lt__) + __le__ = quaxify(operator.__le__) + __eq__ = quaxify(operator.__eq__) + __ne__ = quaxify(operator.__ne__) + __neg__ = quaxify(operator.__neg__) + + +def can_convert(from_: Unit, to: Unit) -> bool: + try: + from_.to(to) + except UnitConversionError: + return False + return True diff --git a/src/jax_quantity/_register_dispatches.py b/src/jax_quantity/_register_dispatches.py new file mode 100644 index 0000000..0e1b215 --- /dev/null +++ b/src/jax_quantity/_register_dispatches.py @@ -0,0 +1,46 @@ +from typing import Any, TypeVar + +import jax +import jax.core +import jax.numpy as jnp +from array_api_jax_compat._dispatch import dispatcher as dispatcher_ + +from ._core import Quantity + +T = TypeVar("T") + + +def dispatcher(f: T) -> T: # TODO: figure out mypy stub issue. + """Dispatcher that makes mypy happy.""" + return dispatcher_(f) + + +@dispatcher +def empty_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity: + out = Quantity(jnp.empty_like(x.value, dtype=dtype), units=x.units) + return jax.device_put(out, device=device) + + +@dispatcher +def full_like( + x: Quantity, + /, + fill_value: bool | int | float | complex, + *, + dtype: Any = None, + device: Any = None, +) -> Quantity: + out = Quantity(jnp.full_like(x.value, fill_value, dtype=dtype), units=x.units) + return jax.device_put(out, device=device) + + +@dispatcher +def ones_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity: + out = Quantity(jnp.ones_like(x.value, dtype=dtype), units=x.units) + return jax.device_put(out, device=device) + + +@dispatcher +def zeros_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity: + out = Quantity(jnp.zeros_like(x.value, dtype=dtype), units=x.units) + return jax.device_put(out, device=device) diff --git a/src/jax_quantity/_register_primitives.py b/src/jax_quantity/_register_primitives.py new file mode 100644 index 0000000..0b5089b --- /dev/null +++ b/src/jax_quantity/_register_primitives.py @@ -0,0 +1,1509 @@ +# pylint: disable=too-many-lines + +__all__: list[str] = [] + +from collections.abc import Callable, Sequence +from dataclasses import replace +from typing import Any, TypeVar + +import jax +import jax.core +import jax.numpy as jnp +from astropy.units import ( # pylint: disable=no-name-in-module + Unit, + UnitTypeError, + radian, +) +from astropy.units import dimensionless_unscaled as dimensionless +from jax import lax +from jaxtyping import ArrayLike +from quax import ArrayValue, DenseArrayValue, Value +from quax import register as register_ + +from ._core import Quantity, can_convert + +T = TypeVar("T") + + +def register(primitive: jax.core.Primitive) -> Callable[[T], T]: + """:func`quax.register`, but makes mypy happy.""" + return register_(primitive) + + +def _to_value_rad_or_one(q: Quantity) -> ArrayLike: + return ( + q.to_units_value(radian) + if can_convert(q.units, radian) + else q.to_units_value(dimensionless) + ) + + +################################################################################ +# Registering Primitives + +# ============================================================================== + + +@register(lax.abs_p) +def _abs_p(x: Quantity) -> Quantity: + return replace(x, value=lax.abs(x.value)) + + +# ============================================================================== + + +@register(lax.acos_p) +def _acos_p(x: Quantity) -> Quantity: + v = x.to_units_value(dimensionless) + return Quantity(value=lax.acos(v), units=radian) + + +# ============================================================================== + + +@register(lax.acosh_p) +def _acosh_p(x: Quantity) -> Quantity: + v = x.to_units_value(dimensionless) + return Quantity(value=lax.acosh(v), units=radian) + + +# ============================================================================== +# Addition + + +@register(lax.add_p) +def _add_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity( + lax.add(x.to_units_value(x.units), y.to_units_value(x.units)), + units=x.units, + ) + + +@register(lax.add_p) +def _add_p_vq(x: Value, y: Quantity) -> Quantity: + # x = 0 is a special case + if jnp.array_equal(x, 0): + return y + + # otherwise we can't add a quantity to a normal value + msg = f"Cannot apply {lax.add} to quantity and non-quantity." + raise ValueError(msg) + + +@register(lax.add_p) +def _add_p_qv(x: Quantity, y: Value) -> Quantity: + # y = 0 is a special case + if jnp.array_equal(y, 0): + return x + + # otherwise we can't add a normal value to a quantity + msg = f"Cannot apply {lax.add} to quantity and non-quantity." + raise ValueError(msg) + + +# ============================================================================== + + +@register(lax.after_all_p) +def _after_all_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.all_gather_p) +def _all_gather_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.all_to_all_p) +def _all_to_all_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +# TODO: return jax.Array. But `quax` is raising an error. +@register(lax.and_p) +def _and_p(x1: Quantity, x2: Quantity, /) -> Quantity: + # IDK what to do about non-dimensionless quantities. + if x1.units != dimensionless or x2.units != dimensionless: + raise NotImplementedError + return Quantity(x1.value & x2.value, units=dimensionless) + + +# ============================================================================== + + +@register(lax.approx_top_k_p) +def _approx_top_k_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.argmax_p) +def _argmax_p(operand: Quantity, *, axes: Any, index_dtype: Any) -> Quantity: + return replace(operand, value=lax.argmax(operand.value, axes[0], index_dtype)) + + +# ============================================================================== + + +@register(lax.argmin_p) +def _argmin_p(operand: Quantity, *, axes: Any, index_dtype: Any) -> Quantity: + return replace(operand, value=lax.argmin(operand.value, axes[0], index_dtype)) + + +# ============================================================================== + + +@register(lax.asin_p) +def _asin_p(x: Quantity) -> Quantity: + return replace(x, value=lax.asin(x.to_units_value(dimensionless))) + + +# ============================================================================== + + +@register(lax.asinh_p) +def _asinh_p(x: Quantity) -> Quantity: + return replace(x, value=lax.asinh(x.to_units_value(dimensionless))) + + +# ============================================================================== + + +@register(lax.atan2_p) +def _atan2_p(x: Quantity, y: Quantity) -> Quantity: + y_ = y.to_units_value(x.units) + return Quantity(lax.atan2(x.value, y_), units=radian) + + +@register(lax.atan2_p) +def _atan2_p_vq(x: Value, y: Quantity) -> Quantity: + y_ = y.to_units_value(dimensionless) + return Quantity(lax.atan2(x, y_), units=radian) + + +@register(lax.atan2_p) +def _atan2_p_qv(x: Quantity, y: Value) -> Quantity: + x_ = x.to_units_value(dimensionless) + return Quantity(lax.atan2(x_, y), units=radian) + + +# ============================================================================== + + +@register(lax.atan_p) +def _atan_p(x: Quantity) -> Quantity: + return Quantity(lax.atan(x.to_units_value(dimensionless)), units=radian) + + +# ============================================================================== + + +@register(lax.atanh_p) +def _atanh_p(x: Quantity) -> Quantity: + return Quantity(lax.atanh(x.to_units_value(dimensionless)), units=radian) + + +# ============================================================================== + + +@register(lax.axis_index_p) +def _axis_index_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.bessel_i0e_p) +def _bessel_i0e_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.bessel_i1e_p) +def _bessel_i1e_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.bitcast_convert_type_p) +def _bitcast_convert_type_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.broadcast_in_dim_p) +def _broadcast_in_dim_p( + operand: Quantity, + *, + shape: Any, + broadcast_dimensions: Any, +) -> Quantity: + return replace( + operand, + value=lax.broadcast_in_dim(operand.value, shape, broadcast_dimensions), + ) + + +# ============================================================================== + + +@register(lax.cbrt_p) +def _cbrt_p(x: Quantity) -> Quantity: + return Quantity(lax.cbrt(x.value), units=x.units ** (1 / 3)) + + +# ============================================================================== + + +@register(lax.ceil_p) +def _ceil_p(x: Quantity) -> Quantity: + return replace(x, value=lax.ceil(x.value)) + + +# ============================================================================== + + +@register(lax.clamp_p) +def _clamp_p(min: Quantity, x: Quantity, max: Quantity) -> Quantity: + return replace( + x, + value=lax.clamp( + min.to_units_value(x.units), + x.value, + max.to_units_value(x.units), + ), + ) + + +@register(lax.clamp_p) +def _clamp_p_vqq(min: Value, x: Quantity, max: Quantity) -> Quantity: + v = x.to_units_value(dimensionless) + maxv = max.to_units_value(dimensionless) + return replace(x, value=lax.clamp(min, v, maxv)) + + +@register(lax.clamp_p) +def _clamp_p_qvq(min: Quantity, x: Value, max: Quantity) -> ArrayValue: + minv = min.to_units_value(dimensionless) + maxv = max.to_units_value(dimensionless) + return DenseArrayValue(lax.clamp(minv, x, maxv)) + + +@register(lax.clamp_p) +def _clamp_p_qqv(min: Quantity, x: Quantity, max: Value) -> Quantity: + minv = min.to_units_value(dimensionless) + v = x.to_units_value(dimensionless) + return replace(x, value=lax.clamp(minv, v, max)) + + +# ============================================================================== + + +@register(lax.clz_p) +def _clz_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.complex_p) +def _complex_p(x: Quantity, y: Quantity) -> Quantity: + y_ = y.to_units_value(x.units) + return Quantity(lax.complex(x.value, y_), units=x.units) + + +# ============================================================================== + + +@register(lax.concatenate_p) +def _concatenate_p(*operands: Quantity, dimension: Any) -> Quantity: + units = operands[0].units + return Quantity( + lax.concatenate( + [op.to_units_value(units) for op in operands], + dimension=dimension, + ), + units=units, + ) + + +# ============================================================================== + + +# @register(lax.cond_p) # TODO: implement +# def _implemen(index, consts) -> Quantity: +# raise NotImplementedError + + +# ============================================================================== + + +@register(lax.conj_p) +def _conj_p(x: Quantity, *, input_dtype: Any) -> Quantity: + del input_dtype # TODO: use this? + return replace(x, value=lax.conj(x.value)) + + +# ============================================================================== + + +@register(lax.conv_general_dilated_p) +def _conv_general_dilated_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.convert_element_type_p) +def _convert_element_type_p( + operand: Quantity, + *, + new_dtype: Any, + weak_type: Any, +) -> Quantity: + del weak_type + return replace(operand, value=lax.convert_element_type(operand.value, new_dtype)) + + +# ============================================================================== + + +@register(lax.copy_p) +def _copy_p(x: Quantity) -> Quantity: + return replace(x, value=lax.copy_p.bind(x.value)) # type: ignore[no-untyped-call] + + +# ============================================================================== + + +@register(lax.cos_p) +def _cos_p(x: Quantity) -> Quantity: + return Quantity(lax.cos(_to_value_rad_or_one(x)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.cosh_p) +def _cosh_p(x: Quantity) -> Quantity: + return Quantity(lax.cosh(_to_value_rad_or_one(x)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.create_token_p) +def _create_token_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.cumlogsumexp_p) +def _cumlogsumexp_p(operand: Quantity, *, axis: Any, reverse: Any) -> Quantity: + # TODO: double check units make sense here. + return replace( + operand, + value=lax.cumlogsumexp(operand.value, axis=axis, reverse=reverse), + ) + + +# ============================================================================== + + +@register(lax.cummax_p) +def _cummax_p(operand: Quantity, *, axis: Any, reverse: Any) -> Quantity: + return replace(operand, value=lax.cummax(operand.value, axis=axis, reverse=reverse)) + + +# ============================================================================== + + +@register(lax.cummin_p) +def _cummin_p(operand: Quantity, *, axis: Any, reverse: Any) -> Quantity: + return replace(operand, value=lax.cummin(operand.value, axis=axis, reverse=reverse)) + + +# ============================================================================== + + +@register(lax.cumprod_p) +def _cumprod_p(operand: Quantity, *, axis: Any, reverse: Any) -> Quantity: + return replace( + operand, + value=lax.cumprod(operand.value, axis=axis, reverse=reverse), + ) + + +# ============================================================================== + + +@register(lax.cumsum_p) +def _cumsum_p(operand: Quantity, *, axis: Any, reverse: Any) -> Quantity: + return replace(operand, value=lax.cumsum(operand.value, axis=axis, reverse=reverse)) + + +# ============================================================================== + + +@register(lax.device_put_p) +def _device_put_p(x: Quantity, *, device: Any, src: Any) -> Quantity: + return replace(x, value=jax.device_put(x.value, device=device, src=src)) + + +# ============================================================================== + + +@register(lax.digamma_p) +def _digamma_p(x: Quantity) -> Quantity: + if x.units != dimensionless: + msg = "TODO: implement the result units for `digamma`." + raise NotImplementedError(msg) + + return Quantity(lax.digamma(x.value), units=dimensionless) + + +# ============================================================================== +# Division + + +@register(lax.div_p) +def _div_p_qq(x: Quantity, y: Quantity) -> Quantity: + units = Unit(x.units / y.units) + return Quantity(lax.div(x.value, y.value), units=units) + + +@register(lax.div_p) +def _div_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.div(x, y.value), units=1 / y.units) + + +@register(lax.div_p) +def _div_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.div(x.value, y), units=x.units) + + +# ============================================================================== + + +@register(lax.dot_general_p) # TODO: implement +def _implemen() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.dynamic_slice_p) +def _dynamic_slice_p( + operand: Quantity, + start_indices: DenseArrayValue, + dynamic_sizes: DenseArrayValue, + *, + slice_sizes: Any, +) -> Quantity: + raise NotImplementedError # TODO: implement + + +# ============================================================================== + + +@register(lax.dynamic_update_slice_p) +def _dynamic_update_slice_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.eq_p) +def _eq_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.eq(x.value, y.to_units_value(x.units)), units=dimensionless) + + +@register(lax.eq_p) +def _eq_p_vq(x: ArrayValue, y: Quantity) -> Quantity: + return Quantity(lax.eq(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.eq_p) +def _eq_p_qv(x: Quantity, y: ArrayValue) -> Quantity: + # special-case for all-0 values + if jnp.all(y.array == 0) or jnp.all(jnp.isinf(y.array)): + return Quantity(lax.eq(x.value, y), units=dimensionless) + return Quantity(lax.eq(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== + + +@register(lax.eq_to_p) +def _eq_to_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.erf_inv_p) +def _erf_inv_p(x: Quantity) -> Quantity: + # TODO: can this support non-dimensionless quantities? + return Quantity(lax.erf_inv(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.erf_p) +def _erf_p(x: Quantity) -> Quantity: + # TODO: can this support non-dimensionless quantities? + return Quantity(lax.erf(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.erfc_p) +def _erfc_p(x: Quantity) -> Quantity: + # TODO: can this support non-dimensionless quantities? + return Quantity(lax.erfc(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.exp2_p) +def _exp2_p(x: Quantity) -> Quantity: + return Quantity(lax.exp2(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.exp_p) +def _exp_p(x: Quantity) -> Quantity: + # TODO: more meaningful error message. + return Quantity(lax.exp(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.expm1_p) +def _expm1_p(x: Quantity) -> Quantity: + return Quantity(lax.expm1(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.fft_p) +def _fft_p(x: Quantity, *, fft_type: Any, fft_lengths: Any) -> Quantity: + # TODO: what units can this support? + return Quantity( + lax.fft(x.to_units_value(dimensionless), fft_type, fft_lengths), + units=dimensionless, + ) + + +# ============================================================================== + + +@register(lax.floor_p) +def _floor_p(x: Quantity) -> Quantity: + return replace(x, value=lax.floor(x.value)) + + +# ============================================================================== + + +@register(lax.gather_p) +def _gather_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.ge_p) +def _ge_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.ge(x.value, y.to_units_value(x.units)), units=dimensionless) + + +@register(lax.ge_p) +def _ge_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.ge(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.ge_p) +def _ge_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.ge(x.to_units_value(dimensionless), y), units=dimensionless) + + +@register(lax.ge_p) +def _ge_p_qi(x: Quantity, y: int) -> Quantity: + return Quantity(lax.ge(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== + + +@register(lax.gt_p) +def _gt_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.gt(x.value, y.to_units_value(x.units)), units=dimensionless) + + +@register(lax.gt_p) +def _gt_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.gt(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.gt_p) +def _gt_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.gt(x.to_units_value(dimensionless), y), units=dimensionless) + + +@register(lax.gt_p) +def _gt_p_qi(x: Quantity, y: int) -> Quantity: + return Quantity(lax.gt(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== + + +@register(lax.igamma_grad_a_p) +def _igamma_grad_a_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.igamma_p) +def _igamma_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.igammac_p) +def _igammac_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.imag_p) +def _imag_p(x: Quantity) -> Quantity: + return replace(x, value=lax.imag(x.value)) + + +# ============================================================================== + + +@register(lax.infeed_p) +def _infeed_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.integer_pow_p) +def _integer_pow_p(x: Quantity, *, y: Any) -> Quantity: + return replace(x, value=lax.integer_pow(x.value, y)) + + +# ============================================================================== + + +# @register(lax.iota_p) +# def _iota_p(dtype: Quantity) -> Quantity: +# raise NotImplementedError + + +# ============================================================================== + + +@register(lax.is_finite_p) +def _is_finite_p(x: Quantity) -> Quantity: + return Quantity(value=lax.is_finite(x.value), units=dimensionless) + + +# ============================================================================== + + +@register(lax.le_p) +def _le_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.le(x.value, y.to_units_value(x.units)), units=dimensionless) + + +@register(lax.le_p) +def _le_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.le(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.le_p) +def _le_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.le(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== + + +@register(lax.le_to_p) +def _le_to_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.lgamma_p) +def _lgamma_p(x: Quantity) -> Quantity: + # TODO: handle non-dimensionless quantities. + return Quantity(lax.lgamma(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.linear_solve_p) +def _linear_solve_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.log1p_p) +def _log1p_p(x: Quantity) -> Quantity: + return Quantity(lax.log1p(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.log_p) +def _log_p(x: Quantity) -> Quantity: + return Quantity(lax.log(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.logistic_p) +def _logistic_p(x: Quantity) -> Quantity: + return Quantity(lax.logistic(x.to_units_value(dimensionless)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.lt_p) +def _lt_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.lt(x.value, y.to_units_value(x.units)), units=dimensionless) + + +@register(lax.lt_p) +def _lt_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.lt(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.lt_p) +def _lt_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.lt(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== + + +@register(lax.lt_to_p) +def _lt_to_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.max_p) +def _max_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.max(x.value, y.to_units_value(x.units)), units=x.units) + + +@register(lax.max_p) +def _max_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.max(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.max_p) +def _max_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.max(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== + + +@register(lax.min_p) +def _min_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.min(x.value, y.to_units_value(x.units)), units=x.units) + + +@register(lax.min_p) +def _min_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.min(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.min_p) +def _min_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.min(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== +# Multiplication + + +@register(lax.mul_p) +def _mul_p_qq(x: Quantity, y: Quantity) -> Quantity: + units = Unit(x.units * y.units) + return Quantity(lax.mul(x.value, y.value), units=units) + + +@register(lax.mul_p) +def _mul_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.mul(x, y.value), units=y.units) + + +@register(lax.mul_p) +def _mul_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.mul(x.value, y), units=x.units) + + +# ============================================================================== + + +@register(lax.ne_p) +def _ne_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.ne(x.value, y.to_units_value(x.units)), units=dimensionless) + + +@register(lax.ne_p) +def _ne_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.ne(x, y.to_units_value(dimensionless)), units=dimensionless) + + +@register(lax.ne_p) +def _ne_p_qv(x: Quantity, y: ArrayValue) -> Quantity: + # special-case for scalar value=0, units=dimensionless + if y.shape == () and y.array == 0: + return Quantity(lax.ne(x.value, y), units=dimensionless) + return Quantity(lax.ne(x.to_units_value(dimensionless), y), units=dimensionless) + + +# ============================================================================== + + +@register(lax.neg_p) +def _neg_p(x: Quantity) -> Quantity: + return replace(x, value=lax.neg(x.value)) + + +# ============================================================================== + + +@register(lax.nextafter_p) +def _nextafter_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.not_p) +def _not_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.or_p) +def _or_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.outfeed_p) +def _outfeed_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pad_p) +def _pad_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pmax_p) +def _pmax_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pmin_p) +def _pmin_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.polygamma_p) +def _polygamma_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.population_count_p) +def _population_count_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.pow_p) +def _pow_p_qq(x: Quantity, y: Quantity) -> Quantity: + if y.units != dimensionless: + msg = f"power must be dimensionless, got {y.units}" + raise UnitTypeError(msg) + + y0 = y.value.flatten()[0] + if not all(y.value == y0): + msg = "power must be a scalar" + raise ValueError(msg) + + return Quantity( + value=lax.pow(x.value, y0), + units=x.units**y0, + ) + + +@register(lax.pow_p) +def _pow_p_qf(x: Quantity, y: int | float) -> Quantity: + return Quantity(value=lax.pow(x.value, y), units=x.units**y) + + +# ============================================================================== + + +@register(lax.ppermute_p) +def _ppermute_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.psum_p) +def _psum_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.random_gamma_grad_p) +def _random_gamma_grad_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.real_p) +def _real_p(x: Quantity) -> Quantity: + return replace(x, value=lax.real(x.value)) + + +# ============================================================================== + + +@register(lax.reduce_and_p) +def _reduce_and_p( + operand: Quantity, + *, + axes: Sequence[int], +) -> Any: + return lax.reduce_and_p.bind(operand.value, axes=tuple(axes)) + + +# ============================================================================== + + +@register(lax.reduce_max_p) +def _reduce_max_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_min_p) +def _reduce_min_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_or_p) +def _reduce_or_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_p) +def _reduce_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_precision_p) +def _reduce_precision_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_prod_p) +def _reduce_prod_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_sum_p) +def _reduce_sum_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_window_max_p) +def _reduce_window_max_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_window_min_p) +def _reduce_window_min_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_window_p) +def _reduce_window_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_window_sum_p) +def _reduce_window_sum_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.reduce_xor_p) +def _reduce_xor_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.regularized_incomplete_beta_p) +def _regularized_incomplete_beta_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.rem_p) +def _rem_p(x: Quantity, y: Quantity) -> Quantity: + return Quantity(lax.rem(x.value, y.to_units_value(x.units)), units=x.units) + + +# ============================================================================== + + +@register(lax.reshape_p) +def _reshape_p(operand: Quantity, *, new_sizes: Any, dimensions: Any) -> Quantity: + return replace(operand, value=lax.reshape(operand.value, new_sizes, dimensions)) + + +# ============================================================================== + + +@register(lax.rev_p) +def _rev_p(operand: Quantity, *, dimensions: Any) -> Quantity: + return replace(operand, value=lax.rev(operand.value, dimensions)) + + +# ============================================================================== + + +@register(lax.rng_bit_generator_p) +def _rng_bit_generator_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.rng_uniform_p) +def _rng_uniform_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.round_p) +def _round_p(x: Quantity, *, rounding_method: Any) -> Quantity: + return replace(x, value=lax.round(x.value, rounding_method)) + + +# ============================================================================== + + +@register(lax.rsqrt_p) +def _rsqrt_p(x: Quantity) -> Quantity: + return Quantity(lax.rsqrt(x.value), units=x.units ** (-1 / 2)) + + +# ============================================================================== + + +@register(lax.scan_p) +def _scan_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_add_p) +def _scatter_add_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_max_p) +def _scatter_max_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_min_p) +def _scatter_min_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_mul_p) +def _scatter_mul_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.scatter_p) +def _scatter_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_and_gather_add_p) +def _select_and_gather_add_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_and_scatter_add_p) +def _select_and_scatter_add_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_and_scatter_p) +def _select_and_scatter_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.select_n_p) +def _select_n_p(which: Quantity, *cases: Quantity) -> Quantity: + # TODO: check correct dtype for `which`. + # TODO: check correct units for `cases`. + units = cases[0].units + return Quantity( + lax.select_n( + which.to_units_value(dimensionless), + *(case.to_units_value(units) for case in cases), + ), + units=units, + ) + + +# ============================================================================== + + +@register(lax.sharding_constraint_p) +def _sharding_constraint_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.shift_left_p) +def _shift_left_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.shift_right_arithmetic_p) +def _shift_right_arithmetic_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.shift_right_logical_p) +def _shift_right_logical_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.sign_p) +def _sign_p(x: Quantity) -> Quantity: + return Quantity(lax.sign(x.value), units=dimensionless) + + +# ============================================================================== + + +@register(lax.sin_p) +def _sin_p(x: Quantity) -> Quantity: + return Quantity(lax.sin(_to_value_rad_or_one(x)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.sinh_p) +def _sinh_p(x: Quantity) -> Quantity: + return Quantity(lax.sinh(_to_value_rad_or_one(x)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.slice_p) +def _slice_p( + operand: Quantity, + *, + start_indices: Any, + limit_indices: Any, + strides: Any, +) -> Quantity: + return replace( + operand, + value=lax.slice_p.bind( + operand.value, + start_indices=start_indices, + limit_indices=limit_indices, + strides=strides, + ), + ) + + +# ============================================================================== + + +@register(lax.sort_p) +def _sort_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.sqrt_p) +def _sqrt_p(x: Quantity) -> Quantity: + return Quantity(lax.sqrt(x.value), units=x.units ** (1 / 2)) + + +# ============================================================================== + + +@register(lax.squeeze_p) +def _squeeze_p(x: Quantity, *, dimensions: Any) -> Quantity: + return replace(x, value=lax.squeeze(x.value, dimensions)) + + +# ============================================================================== + + +@register(lax.stop_gradient_p) +def _stop_gradient_p(x: Quantity) -> Quantity: + return replace(x, value=lax.stop_gradient(x.value)) + + +# ============================================================================== +# Subtraction + + +@register(lax.sub_p) +def _sub_p_qq(x: Quantity, y: Quantity) -> Quantity: + return Quantity( + lax.sub(x.to_units_value(x.units), y.to_units_value(x.units)), + units=x.units, + ) + + +@register(lax.sub_p) +def _sub_p_vq(x: Value, y: Quantity) -> Quantity: + return Quantity(lax.sub(x, y.value), units=y.units) + + +@register(lax.sub_p) +def _sub_p_qv(x: Quantity, y: Value) -> Quantity: + return Quantity(lax.sub(x.value, y), units=x.units) + + +# ============================================================================== + + +@register(lax.tan_p) +def _tan_p(x: Quantity) -> Quantity: + return Quantity(lax.tan(_to_value_rad_or_one(x)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.tanh_p) +def _tanh_p(x: Quantity) -> Quantity: + return Quantity(lax.tanh(_to_value_rad_or_one(x)), units=dimensionless) + + +# ============================================================================== + + +@register(lax.top_k_p) +def _top_k_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.transpose_p) +def _transpose_p(operand: Quantity, *, permutation: Any) -> Quantity: + return replace(operand, value=lax.transpose(operand.value, permutation)) + + +# ============================================================================== + + +@register(lax.while_p) +def _while_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.xor_p) +def _xor_p() -> Quantity: + raise NotImplementedError + + +# ============================================================================== + + +@register(lax.zeta_p) +def _zeta_p() -> Quantity: + raise NotImplementedError diff --git a/src/jax_quantity/_version.pyi b/src/jax_quantity/_version.pyi new file mode 100644 index 0000000..5bb2b22 --- /dev/null +++ b/src/jax_quantity/_version.pyi @@ -0,0 +1,2 @@ +version: str +version_tuple: tuple[int, int, int] | tuple[int, int, int, str, str] diff --git a/src/jax_quantity/py.typed b/src/jax_quantity/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d420712 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests.""" diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 0000000..44b822d --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,10 @@ +"""Test the package itself.""" + +import importlib.metadata + +import jax_quantity as m + + +def test_version(): + """Test version.""" + assert importlib.metadata.version("jax_quantity") == m.__version__