New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Quadrature Refactoring #1505
Quadrature Refactoring #1505
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @gustavocmv, thanks for your amazing PR and apologies for the long-overdue review - let's try to get this PR merged this week. I left you a couple of minor comments. Would you mind following Ti's suggestion and moving the old quadrature code to the tests. We will also need to update the code base to use the new quadrature code.
Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com>
* Update pull request template (#1510) Clarify template to make it easier for contributors to fill in relevant information. * Temporary workaround for tensorflow_probability dependency issue (#1522) * pin cloudpickle==1.3.0 as temporary workaround for tensorflow/probability#991 to unblock our build (to be reverted once fixed upstream) * Update readme with new project using GPflow (#1530) * fix bug in varying_noise notebook (#1526) * Fix formatting in docs (intro.md) and restore link removed by #1498 (#1520) * pin tensorflow<2.3 tensorflow-probability<0.11 (#1537) * Quadrature Refactoring (#1505) * WIP: quadrature refactoring * Removing old ndiagquad code * deleted test code * formatting and type-hint * merge modules * black formatting * formatting * solving failing tests * fixing failing tests * fixes * adapting tests for new syntax, keeping numerical behavior * black formatting * remove printf * changed code for compiled tf compatibility * black * restored to original version * undoing changes * renaming * renaming * renaming * reshape kwargs * quadrature along axis=-2, simplified broadcasting * black * docs * docs * helper function * docstrings and typing * added new and old quadrature equivalence tests * black * Removing comments Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> * Typo Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> * notation Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> * reshape_Z_dZ return docstring fix * FIX: quad_old computed with the ndiagquad_old Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> * more readable implementation Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> * tf.ensure_shape added * removed ndiagquad * removed ndiagquad * Revert "removed ndiagquad" This reverts commit 7bb0e9f. * FIX: shape checking of dZ * Revert "removed ndiagquad" This reverts commit 8e23524. Co-authored-by: Gustavo Carvalho <gustavo.carvalho@delfosim.com> Co-authored-by: ST John <st@prowler.io> Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> * Add base_conditional_with_lm function (#1528) * Added base_conditional_with_lm function, which accepts Lm instead of Kmm Co-authored-by: Neil Ferguson <neil@prowler.io> Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> Co-authored-by: st-- <st--@users.noreply.github.com> * Fixed separate_independent_conditional to correctly handle q_sqrt=None. (#1533) * Fixed separate_independent_conditional to correctly handle q_sqrt=None. Co-authored-by: Aidan Scannell <scannell.aidan@gmail.com> Co-authored-by: st-- <st--@users.noreply.github.com> * Bump version numbers to 2.1.0. (#1544) * Re-introduce pytest-xdist (#1541) Enables pytest-xdist for locally running tests (`make test`) on multiple cores in parallel. * check dependency versions are valid on CI (#1536) * Update to not use custom image (#1545) * Update to not use custom image * Add test requirements * Update parameter to be savable (#1518) * Fix for quadrature failure mode when autograph was set to False (#1548) * Fix and test * Change shape of quadrature tensors for better broadcasting (#1542) * using the first dimension to hold the quadrature summation * adapting ndiagquad wrapper * Changed bf for bX in docstrings Co-authored-by: Gustavo Carvalho <gustavo.carvalho@delfosim.com> Co-authored-by: st-- <st--@users.noreply.github.com> Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com> * Update min TFP supported version to 0.10 (#1551) * Broadcasting constant and zero mean function (#1550) * Broadcasting constant and zero mean function * Use rank instead of ndim Co-authored-by: st-- <st--@users.noreply.github.com> Co-authored-by: joelberkeley-pio <joel.berkeley@prowler.io> Co-authored-by: gustavocmv <47801305+gustavocmv@users.noreply.github.com> Co-authored-by: Gustavo Carvalho <gustavo.carvalho@delfosim.com> Co-authored-by: ST John <st@prowler.io> Co-authored-by: Neil Ferguson <nfergu@users.noreply.github.com> Co-authored-by: Neil Ferguson <neil@prowler.io> Co-authored-by: Aidan Scannell <as12528@my.bristol.ac.uk> Co-authored-by: Aidan Scannell <scannell.aidan@gmail.com> Co-authored-by: Sandeep Tailor <s.tailor@insysion.net> Co-authored-by: Artem Artemev <art.art.v@gmail.com>
Goals
This PR refactors the
ndiagquad
function responsible for performing the quadrature needed to compute expectations related to non-conjugate likelihoods.The new implementation uses classes to implement quadrature methods. This enables the class to store the points and weights used for quadrature. It also opens the possibility of componentizing the quadrature class inside the likelihood using it, allowing for operations such as training with some quadrature algorithm (e.g. Gauss-Hermite) and predicting with another (e.g. Monte Carlo).
It also clears the quadrature function signature. Instead of expecting lists Fmu and Fvar representing multiple latent GP's, it now takes the tensors as arguments, with the latent GPs stacked in the last dimension.
Design
The new quadrature class design can be summarized as follows:
GaussianQuadrature
abstract base class, which provides the methods for computing expectations (or log-expections-exp) given points X and weights W, which must be provided by the inheriting class nNDiagGHQuadrature
class, which inherits fromGaussianQuadrature
, and implements the construction of the points X and weights W with the methodThis separation aims to modularize these concepts and provide better extension capacity.
Backward Compatibility
Due to the migration to a class architecture, the old
gpflow/quadrature.py
file was converted to a submodule in the foldergpflow/quadrature
, with it's contents moved togpflow/quadrature/deprecated.py
.For backward compatibility, The
ndiagquad
function was refactored to wrap the newNDiagGHQuadrature
class.The old code was not deleted -- it was moved to
gpflow/quadrature/old.py
and can be accessed withgpflow.quadrature.ndiagquad_old
. @st-- suggested moving it to thetests
and build tests to compare both implementations.