Make fit_pathfinder more similar to fit_laplace and pm.sample #447
base: main
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import collections
 import logging
 import time
@@ -67,6 +68,7 @@
 # TODO: change to typing.Self after Python versions greater than 3.10
 from typing_extensions import Self

+from pymc_extras.inference.laplace import add_data_to_inferencedata
 from pymc_extras.inference.pathfinder.importance_sampling import (
     importance_sampling as _importance_sampling,
 )
```
```diff
@@ -1630,6 +1632,7 @@ def fit_pathfinder(
     inference_backend: Literal["pymc", "blackjax"] = "pymc",
     pathfinder_kwargs: dict = {},
     compile_kwargs: dict = {},
+    initvals: dict | None = None,
 ) -> az.InferenceData:
     """
     Fit the Pathfinder Variational Inference algorithm.
@@ -1665,12 +1668,12 @@ def fit_pathfinder(
     importance_sampling : str, None, optional
         Method to apply sampling based on log importance weights (logP - logQ).
         Options are:
-        "psis" : Pareto Smoothed Importance Sampling (default)
-            Recommended for more stable results.
-        "psir" : Pareto Smoothed Importance Resampling
-            Less stable than PSIS.
-        "identity" : Applies log importance weights directly without resampling.
-        None : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).
+        - **"psis"** : Pareto Smoothed Importance Sampling (default). Usually most stable.
+        - **"psir"** : Pareto Smoothed Importance Resampling. Less stable than PSIS.
+        - **"identity"** : Applies log importance weights directly without resampling.
+        - **None** : No importance sampling weights. Returns raw samples of size (num_paths, num_draws_per_path, N) where N is number of model parameters. Other methods return samples of size (num_draws, N).

     progressbar : bool, optional
         Whether to display a progress bar (default is True). Setting this to False will likely reduce the computation time.
     random_seed : RandomSeed, optional
```
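As a quick illustration of the options documented in the hunk above, here is a hedged sketch (the model is invented, and the `pymc_extras.inference.pathfinder` import path is inferred from the diff's module layout) of how the `importance_sampling` choice changes the shape of the returned draws:

```python
import numpy as np
import pymc as pm

# Import path assumed from the diff's module layout; adjust if needed.
from pymc_extras.inference.pathfinder import fit_pathfinder

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu, 1, observed=np.random.normal(size=50))

    # Default "psis": smoothed weights, draws flattened to (num_draws, N).
    idata_psis = fit_pathfinder(model=model, importance_sampling="psis")

    # None: raw draws kept per path, shape (num_paths, num_draws_per_path, N).
    idata_raw = fit_pathfinder(model=model, importance_sampling=None)
```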
```diff
@@ -1685,17 +1688,24 @@ def fit_pathfinder(
         Additional keyword arguments for the Pathfinder algorithm.
     compile_kwargs
         Additional keyword arguments for the PyTensor compiler. If not provided, the default linker is "cvm_nogc".
+    initvals : dict | None = None
+        Initial values for the model parameters, as str: ndarray key-value pairs. Partial initialization is permitted.
+        If None, the model's default initial values are used.

     Returns
     -------
-    arviz.InferenceData
+    :class:`~arviz.InferenceData`
         The inference data containing the results of the Pathfinder algorithm.

     References
     ----------
     Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
     """

+    if initvals is not None:
+        for rv_name, ivals in initvals.items():
+            model.set_initval(model.named_vars[rv_name], ivals)
```
Inline review thread on the `initvals` block:

> Might need to ensure that ivals is a support point for the RV. For example, x ~ Uniform(-1, 1) would have …

> While in the ideal world, I would agree, in practice …

> Yup, seems fair enough. Thanks for this submission @velochy
```diff
     model = modelcontext(model)

     valid_importance_sampling = {"psis", "psir", "identity", None}
```
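Regarding the support-point concern in the thread above: one way it could be addressed is with PyMC's initial-point machinery. A minimal sketch, assuming `make_initial_point_fn` and `Model.check_start_vals` (both existing PyMC APIs) are an acceptable validation path; this is not the PR's code:

```python
import numpy as np
import pymc as pm
from pymc.initial_point import make_initial_point_fn

with pm.Model() as model:
    x = pm.Uniform("x", -1, 1)

# Merge the user's partial initvals into a complete initial point and
# verify that the model's logp is finite there.
initvals = {"x": np.array(0.5)}
ipfn = make_initial_point_fn(model=model, overrides=initvals)
start = ipfn(123)              # seed -> dict of value-variable start values
model.check_start_vals(start)  # raises SamplingError if logp is non-finite
```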
```diff
@@ -1775,4 +1785,7 @@ def fit_pathfinder(
         model=model,
         importance_sampling=importance_sampling,
     )

+    idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)

     return idata
```
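For context, here is a usage sketch of the new parameter (model invented; import path assumed as above). Only `mu` is overridden, exercising the partial-initialization behavior described in the docstring:

```python
import numpy as np
import pymc as pm
from pymc_extras.inference.pathfinder import fit_pathfinder  # path assumed

rng = np.random.default_rng(0)

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.HalfNormal("sigma", 1)
    pm.Normal("y", mu, sigma, observed=rng.normal(5.0, 2.0, size=100))

    # Partial initialization: sigma keeps the model's default initval.
    idata = fit_pathfinder(model=model, initvals={"mu": np.array(5.0)})
```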
Review thread on mutating the model:

> Shouldn't mutate the model; make a copy, perhaps, if there's no better way to just forward the initvals.

> I would normally agree. However, when I tried it, model.copy() sometimes did not produce a working model - most notably when any transformations are used. Should I use some other copy function?
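If a full model copy is unreliable, one non-mutating middle ground (a sketch only; `temporary_initvals` is a hypothetical helper, not part of the PR or PyMC) would be to snapshot and restore `Model.rvs_to_initial_values` around the fit:

```python
from contextlib import contextmanager

@contextmanager
def temporary_initvals(model, initvals):
    """Hypothetical helper: apply initvals, then restore the originals."""
    saved = dict(model.rvs_to_initial_values)  # snapshot current initvals
    try:
        for rv_name, ival in initvals.items():
            model.set_initval(model.named_vars[rv_name], ival)
        yield model
    finally:
        # Undo the mutation so the caller's model is left unchanged.
        model.rvs_to_initial_values = saved
```

Whether restoring the dict alone is sufficient would need testing, particularly against transformed variables, which is the same failure mode reported for model.copy().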