Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V0.6 #212

Merged
merged 40 commits into from
May 11, 2023
Merged

V0.6 #212

merged 40 commits into from
May 11, 2023

Conversation

thomaspinder
Copy link
Collaborator

@thomaspinder thomaspinder commented Apr 9, 2023

This is the master PR for v0.6 release.

The major change invoked by this PR is the backend transition to PyTrees. Further changes include the addition of decoupled sampling, abstract objective functions, and a general imrpovement to the code's cleanliness.

Pull request type

Please check the type of change your PR introduces:

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • Refactoring (no functional changes, no api changes)
  • Build related changes
  • Documentation content changes
  • Other (please describe):

Issue Number: #227 #218 #216 #215 #214 #213 #203 #200 #144

henrymoss and others added 28 commits March 21, 2023 14:23
* Add module.

* Add stuff.

* Fix mytree links

* Tests fixed

* Refactor tests

* Reformat

* Module methods' return type is Self (the subclass)

* refactor matern12

* stationary kernels refactoring

* tests draft

* spectral density as property (RBF)

* add jitter in gram test

* fix default engine for white kernel

* fix jaxtyping hints

* Fix bugs on the base.

* Refactored variational families.

* Update likelihoods and refactor collapsed variational family

* Update likelihoods.

* Add fit.py and test.

* Remove types and add dataset.

* Commit.

* Use tfb bijectors, update base.

* Minimal passing tests except for eigen and basis work

Note:

- The tests need rewriting, with additional checks for the param_field leaves and for static_fields.

- Docstrings need doing, etc.

- The code is rough, and needs to be cleaned up, and the structure needs to be revised in places.

* Improve dataset tests.

* Update fit testing.

* Refactor docs

* Classification nb

* Collapsed VI

* Sampling fixed

* Sampling fixed

* Graph kernel

* RFF refactored

* Graph kernel refactored

* Fix imports and switch to FillTriangular for now, to avoid dtype error.

* Docs complete

* Docs outline

* Push fix.

* DKL fixed

* Docs up-to-date

* Add flax to reqs

* Drop beartype refs

* Fix link fn. tests

* Add flax deps

* Documentation text updates

---------

Co-authored-by: Daniel Dodd <daniel_dodd@icloud.com>
Co-authored-by: frazane <zanetta.francesco@gmail.com>
* Add module.

* Add stuff.

* Fix mytree links

* Tests fixed

* Refactor tests

* Reformat

* Add poetry setup

* Drop more files

* Drop gitattributes and move contributing

* Create static

* Update citation

* Readd pyspelling

* Update workflows

* Upload to codecov

* Add dependabot

* Add labels.yml

* Add release drafter

* Update mds

* Drop circleci

* Simplify directories

* Add doc deps

* Add PR welcome

* Drop tilde req

* Add TFP dep

* Drop distrax refs

* Drop JaxUtils refs

* Run docs workflow

* Implement sample method

* Fix graph kernel sampler

* Add version

* Dynamic versioning

* Fix jit

* Add poetry build

* Add poetry build

---------

Co-authored-by: Daniel Dodd <daniel_dodd@icloud.com>
* Initial update.

* Delete .mailmap

---------

Co-authored-by: Daniel Dodd <daniel_dodd@icloud.com>
…e). (#222)

Signed-off-by: Daniel Dodd <daniel_dodd@icloud.com>
* add likelihood tests

* add nonstationary kernels tests

* add stationary kernels tests

* remove distrax import

* Fix typing

---------

Co-authored-by: Thomas Pinder <tompinder@live.co.uk>
* update dependencies

* add data

* wip example

* spatial example wip

* update dependencies

* spatial example wip (almost there)

* add reference (wilson 2020)

* fix markdown cell

* minor fixes

* link in docs index

* enforce capital G in bib

* tom's review revision
* add beartype dependency

* from typing import -> from beartype.typing import

* jaxtyping import_hook for @jaxtyped @beartype everywhere

* fix Type[] of class-as-argument

* fix KeyArray type hint (should probably move into jaxutils though)

* fix return value of slice_input when active_dims is None

* fix return value of squared_distance

* fix return type of recursive_bijectors

* fix slice_input type annotations

* new KernelCallable type to fix kernel_fn annotations

* fix kernel __call__ annotation

* fix KeyArray type hint

* beartype does not like forward references; replaced with string types

* linops other type hint fixes

* fix KeyArray

* abstractions.py some type fixes

* fix GaussianDistribution.log_prob return type

* fix depreciations & warnings

* fix scalar array types

* introduce ScalarBool, ScalarInt for jitted calls in abstractions

* relax LinearOperator's solve() types (can be both matrix or vector), not ideal :S

* remove _stop_grad type hints, not sure what they should be

* found some more

* float -> ScalarFloat fixes

* linops log_det type fixes

* some more linops type fixes

* actually commit KeyArray and Scalar* types

* add beartype to pyproject

* from beartype.typing import ...

* try to fix Self in gpjax/base/module

* fix _check_shape

* gpjax.objectives: always import from gps and variational_families

* Revert "gpjax.objectives: always import from gps and variational_families"

This reverts commit 9359b05.

* fix gpjax.objectives imported types

* <...> | None not supported by beartype; replaced by Optional[<...>]

* gpjax.datasets: cannot specify strict array shape AND rely on _check_shape

* our tfd.Distribution subclassing requires the fix introduced in jaxtyping 0.2.15

* need to import base first!

* bugfix

* AbstractKernel: string for forward references

* remove from __future__ import annotations

* fix type annotations to make up for changes in 0c3ae8a

* pytree map functions may take a non-Module argument

* ScalarFloat

* VecNOrMatNM

* remove unnecessary / buggy methods

* more ScalarFloat

* ScalarFloat

* type fixes

* fix shape type

* fix one KeyArray

* more ScalarFloat corrections in kernels

* fix test_stationary accordingly for ScalarFloat params

* fix return type

* ScalarInt for Polynomial kernel and fix test for Scalar* params

* fix mock in test_abstract_variational_family

* fix link_function and variational_expectations shape annotations

* minor test fix

* fix exception test for beartype

* fix Constant mean function

* base_kernel as kwarg in test_approximations

* rename func to test_ so it actually gets collected

* mark test_graph_kernel as broken

* fix LinearOperator DTypeT

* Revert "fix Constant mean function"

This reverts commit 6fc3a55.

* fix test_mean_functions instead

* fix one more bug in RFF test

* Self

* relax fit objective type

* rename gpjax.utils -> gpjax.typing

* Kernel = Any -> string forward reference

* relax Gaussian.predict type annotation to include GaussianDistribution

* our own `Array` type that accepts both JAX and Numpy arrays

* some Float -> Num relaxations for graph kernel...

* ScalarFloat for GraphKernel hyperparams

* fix type hints to what happens (even if it seems wrong)

* type relaxation for deep_kernels.pct.py

* some more minor consistency fixes

* bugfix

* Update examples/graph_kernels.pct.py

Co-authored-by: st-- <st--@users.noreply.github.com>
Signed-off-by: Thomas Pinder <tompinder@live.co.uk>

* Update gpjax/dataset.py

Allow for integer responses

Co-authored-by: st-- <st--@users.noreply.github.com>
Signed-off-by: Thomas Pinder <tompinder@live.co.uk>

* jaxtyping import hook for notebooks

* conftest.py to apply jaxtyping import hook before loading tests

* remove import hook from gpjax/__init__

* Update gpjax/dataset.py

Signed-off-by: st-- <st--@users.noreply.github.com>

* Update gpjax/dataset.py

Co-authored-by: st-- <st--@users.noreply.github.com>
Signed-off-by: Thomas Pinder <tompinder@live.co.uk>

* fix tests of shape checks now that we have beartype

---------

Signed-off-by: Thomas Pinder <tompinder@live.co.uk>
Signed-off-by: st-- <st--@users.noreply.github.com>
Co-authored-by: Thomas Pinder <tompinder@live.co.uk>
* add .pre-commit-config.yaml

* poetry add --dev pre-commit

* poetry run pre-commit run --all-files

* revert examples/ to v0.6, only run black

* run trailing-whitespace, end-of-file-fixer on examples/

* manually curated ruff edits to examples/

* exclude examples/ in isort and ruff pre-commit

* incorporate pyproject.toml edits (isort, ruff config) from docs_update branch

* poetry run pre-commit -a with new settings

* update poetry.lock
* add poisson likelihood

* add tests

* import poisson

* poisson example wip

* test posterior construction with poisson

* example wip

* cleanup, begin with elliptical slice sampler

* Update poisson.pct.py

* revert to nuts mcmc

* fix merge

* revision

---------

Co-authored-by: Daniel Dodd <daniel_dodd@icloud.com>
* WIP

* first go

* nice test

---------

Signed-off-by: Thomas Pinder <tompinder@live.co.uk>
Co-authored-by: Thomas Pinder <tompinder@live.co.uk>
…tibility (#246)

* Add static_field to gpjax.base

* Update imports.

* Add static_field to base

* Add fix for python 3.11

* Update tests to run on 3.11

* Update test_base.py
* Revamp docs

* Revamp docs

* Update workflows

* Ruff passing

* Resolve ruff issues

* Resolve ruff issues

* Drop usage of mu

* Spell checker

* Sort imports correctly

* Revert to py files

* Update pytrees.py

* Add RBF dataclass text

* Fix bold symbols

* Show pre-commit

* Fix clf example

* Update pytrees.py

* Update pytrees.py

* Resolve conflicts

* fix spaces in shape strings

* fix pytrees.py example

* Revert "fix spaces in shape strings"

Spaces actually needed for linting to pass
(https://github.com/google/jaxtyping/blob/main/FAQ.md#flake8-or-ruff-are-throwing-an-error)

This reverts commit f129775.

* remaining manual changes (see 1e2201e)

* ruff fixes

* fix some types

* fix some more types

* fix some more more types

* Types fixed

* Update lockfile

* Fix bib

* Respond to Ti comments

* Clean up docstrings

* Simplify watermark

* Update pytrees.py

* Update pytrees.py

* Update index.md

* Update index.md

* Format doc

* Update pytrees.py

* Update pytrees.py

* Update pytrees.py

* Update pytrees.py

* Update PyTree doc

* Test code cells

* PyTrees update

* katex backend

* All docstrings updated

* Add katex install

* Add target PR

* Remove ipynb

* Update typing

* Update workflows

* Update markdown checks

* Update docs/examples/README.md

Co-authored-by: st-- <st--@users.noreply.github.com>
Signed-off-by: Thomas Pinder <tompinder@live.co.uk>

* Resolve 3.11 issue

* Drop dupe doc

* Update README.md

Co-authored-by: st-- <st--@users.noreply.github.com>
Signed-off-by: Thomas Pinder <tompinder@live.co.uk>

* Fix factory bug

* Make API TOC entries lead with a capital letter

* Improve TOC rendering

* Update formatting

---------

Signed-off-by: Thomas Pinder <tompinder@live.co.uk>
Co-authored-by: Daniel Dodd <daniel_dodd@icloud.com>
Co-authored-by: ST John <st--@users.noreply.github.com>
@codecov-commenter
Copy link

codecov-commenter commented May 11, 2023

Codecov Report

Merging #212 (6d3fffb) into main (aa3f5d2) will increase coverage by 36.42%.
The diff coverage is 89.61%.

❗ Current head 6d3fffb differs from pull request most recent head ca2ac13. Consider uploading reports for the commit ca2ac13 to get more accurate results

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

@@             Coverage Diff             @@
##             main     #212       +/-   ##
===========================================
+ Coverage   51.87%   88.30%   +36.42%     
===========================================
  Files          66       53       -13     
  Lines        3302     2026     -1276     
  Branches        0      232      +232     
===========================================
+ Hits         1713     1789       +76     
+ Misses       1589      123     -1466     
- Partials        0      114      +114     
Flag Coverage Δ
unittests 88.30% <89.61%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
gpjax/progress_bar.py 0.00% <ø> (ø)
gpjax/quadrature.py 100.00% <ø> (ø)
gpjax/scan.py 98.00% <ø> (ø)
gpjax/typing.py 100.00% <ø> (ø)
gpjax/variational_families.py 94.80% <ø> (-4.83%) ⬇️
gpjax/mean_functions.py 67.34% <64.44%> (-29.80%) ⬇️
gpjax/base/module.py 76.38% <76.38%> (ø)
gpjax/fit.py 80.55% <80.55%> (ø)
gpjax/kernels/computations/base.py 88.23% <81.81%> (+1.27%) ⬆️
gpjax/kernels/base.py 85.24% <83.01%> (+22.74%) ⬆️
... and 43 more

... and 16 files with indirect coverage changes

@thomaspinder thomaspinder merged commit d50be70 into main May 11, 2023
11 checks passed
@thomaspinder thomaspinder deleted the v0.6 branch May 11, 2023 21:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants