Make fit_pathfinder more similar to fit_laplace and pm.sample #447

Status: Open. Wants to merge 1 commit into base: main.
3 changes: 3 additions & 0 deletions docs/api_reference.rst
@@ -23,7 +23,10 @@ Inference
.. autosummary::
:toctree: generated/

find_MAP
fit
fit_laplace
fit_pathfinder


Distributions
4 changes: 1 addition & 3 deletions pymc_extras/__init__.py
@@ -15,9 +15,7 @@

from pymc_extras import gp, statespace, utils
from pymc_extras.distributions import *
from pymc_extras.inference.find_map import find_MAP
from pymc_extras.inference.fit import fit
from pymc_extras.inference.laplace import fit_laplace
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
from pymc_extras.model.marginal.marginal_model import (
MarginalModel,
marginalize,
6 changes: 4 additions & 2 deletions pymc_extras/inference/__init__.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from pymc_extras.inference.find_map import find_MAP
from pymc_extras.inference.fit import fit
from pymc_extras.inference.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit"]
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
7 changes: 4 additions & 3 deletions pymc_extras/inference/fit.py
@@ -15,19 +15,20 @@

def fit(method, **kwargs):
"""
Fit a model with an inference algorithm
Fit a model with an inference algorithm.
See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.

Parameters
----------
method : str
Which inference method to run.
Supported: pathfinder or laplace

kwargs are passed on.
kwargs: keyword arguments are passed on to the inference method.

Returns
-------
arviz.InferenceData
:class:`~arviz.InferenceData`
"""
if method == "pathfinder":
from pymc_extras.inference.pathfinder import fit_pathfinder
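The dispatch inside `fit` is truncated in the diff above, but the pattern it follows (lazy import of the chosen backend, all keyword arguments forwarded unchanged) can be sketched with stand-in callables. The laplace branch and the exact error message below are assumptions mirroring the pathfinder branch shown in the diff, not the package's verbatim code:

```python
def _fit_pathfinder(**kwargs):
    # Stand-in for pymc_extras.inference.pathfinder.fit_pathfinder
    return ("pathfinder", kwargs)

def _fit_laplace(**kwargs):
    # Stand-in for pymc_extras.inference.laplace.fit_laplace
    return ("laplace", kwargs)

def fit(method, **kwargs):
    """Sketch of fit()'s dispatch: pick the backend by name, forward kwargs."""
    if method == "pathfinder":
        return _fit_pathfinder(**kwargs)
    if method == "laplace":
        return _fit_laplace(**kwargs)
    raise ValueError(f"method must be 'pathfinder' or 'laplace', got {method!r}")
```

Because the real `fit` imports the backend only inside the matching branch, selecting one method never pays the import cost of the other.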
2 changes: 1 addition & 1 deletion pymc_extras/inference/laplace.py
@@ -509,7 +509,7 @@ def fit_laplace(

Returns
-------
idata: az.InferenceData
:class:`~arviz.InferenceData`
An InferenceData object containing the approximated posterior samples.

Examples
27 changes: 20 additions & 7 deletions pymc_extras/inference/pathfinder/pathfinder.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import logging
import time
@@ -67,6 +68,7 @@
# TODO: change to typing.Self after Python versions greater than 3.10
from typing_extensions import Self

from pymc_extras.inference.laplace import add_data_to_inferencedata
from pymc_extras.inference.pathfinder.importance_sampling import (
importance_sampling as _importance_sampling,
)
@@ -1630,6 +1632,7 @@ def fit_pathfinder(
inference_backend: Literal["pymc", "blackjax"] = "pymc",
pathfinder_kwargs: dict = {},
compile_kwargs: dict = {},
initvals: dict | None = None,
) -> az.InferenceData:
"""
Fit the Pathfinder Variational Inference algorithm.
@@ -1665,12 +1668,12 @@ def fit_pathfinder(
importance_sampling : str, None, optional
Method to apply sampling based on log importance weights (logP - logQ).
Options are:
"psis" : Pareto Smoothed Importance Sampling (default)
Recommended for more stable results.
"psir" : Pareto Smoothed Importance Resampling
Less stable than PSIS.
"identity" : Applies log importance weights directly without resampling.
None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).

- **"psis"** : Pareto Smoothed Importance Sampling (default). Usually most stable.
- **"psir"** : Pareto Smoothed Importance Resampling. Less stable than PSIS.
- **"identity"** : Applies log importance weights directly without resampling.
- **None** : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
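The difference between keeping weighted draws ("identity") and resampling them ("psir") can be sketched without the Pareto-smoothing step that PSIS/PSIR add. This toy resampler is illustrative only, not the package's implementation:

```python
import math
import random

def importance_resample(draws, logw, num_draws, seed=0):
    """Toy importance *re*sampling: normalize log-weights (logP - logQ)
    and draw indices with replacement. Real PSIR additionally smooths the
    tail of the weights with a generalized Pareto fit."""
    m = max(logw)                                  # subtract max for stability
    w = [math.exp(lw - m) for lw in logw]
    total = sum(w)
    probs = [x / total for x in w]
    rng = random.Random(seed)
    idx = rng.choices(range(len(draws)), weights=probs, k=num_draws)
    return [draws[i] for i in idx]
```

With "identity", the caller receives the draws plus their raw weights; resampling instead folds the weights into the draw frequencies, which is simpler downstream but discards information and is less stable when a few weights dominate.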

progressbar : bool, optional
Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
random_seed : RandomSeed, optional
@@ -1685,17 +1688,24 @@ def fit_pathfinder(
Additional keyword arguments for the Pathfinder algorithm.
compile_kwargs
Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
initvals: dict | None = None
Initial values for the model parameters, as str: ndarray key-value pairs. Partial initialization is permitted: variables not listed keep their default initial values.
If None, the model's default initial values are used.

Returns
-------
arviz.InferenceData
:class:`~arviz.InferenceData`
The inference data containing the results of the Pathfinder algorithm.

References
----------
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

if initvals is not None:
Member: Shouldn't mutate the model; make a copy, perhaps, if there's no better way to just forward the initvals.

Author: I would normally agree. However, I tried it, and model.copy() sometimes does not produce a working model, most notably when any transformations are used. Should I use some other copy function?
for rv_name, ivals in initvals.items():
model.set_initval(model.named_vars[rv_name], ivals)
Contributor: Might need to ensure that ivals is a support point for the RV. For example, x ~ Uniform(-1, 1) would have nan initial values with model.set_initval(model.named_vars["x"], 2).

Author: While in an ideal world I would agree, in practice:
a) it is very nontrivial to do, as I understand it, since the limits are not specified anywhere they are easy to retrieve from;
b) pm.sample does no such checks, and the goal of this PR is to be compatible with that.

Contributor: Yup, seems fair enough. Thanks for this submission @velochy


model = modelcontext(model)

valid_importance_sampling = {"psis", "psir", "identity", None}
@@ -1775,4 +1785,7 @@ def fit_pathfinder(
model=model,
importance_sampling=importance_sampling,
)

idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)

return idata
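The partial-initialization semantics added by this PR (user-supplied initvals override only the named variables; everything else keeps the model defaults) can be stated independently of PyMC. The helper below is hypothetical, written only to pin down the merge rule, and is not part of the PR:

```python
def merge_initvals(default_initvals, user_initvals=None):
    """Hypothetical illustration of partial initialization: user values
    override matching variable names; unnamed variables keep the model's
    default initial values. None means 'use all defaults'."""
    merged = dict(default_initvals)
    if user_initvals is not None:
        unknown = set(user_initvals) - set(default_initvals)
        if unknown:
            raise KeyError(f"unknown variables: {sorted(unknown)}")
        merged.update(user_initvals)
    return merged
```

In the actual diff this merge happens implicitly: the loop calls model.set_initval only for the names present in initvals, so untouched variables retain whatever initial values the model already had.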