
Computing the evidence from NUTS chains #229

Open
stefanocovino opened this issue Jun 4, 2023 · 18 comments

@stefanocovino

Dear friends,

I am trying to apply the harmonic algorithm to chains produced by the NUTS sampler in numpyro, but so far with little luck. Do you have any examples showing how to convert the NUTS chains into a form harmonic can use?

Thanks,
Stefano

@stefanocovino changed the title from "Computing the evidence from NUTS chain" to "Computing the evidence from NUTS chains" on Jun 4, 2023
@jasonmcewen
Contributor

Hi @stefanocovino, great to see you're interested in this. I don't think we've applied harmonic to NUTS samples from numpyro yet, but it's definitely on our list of things to do. If you're interested, we'd be very happy to help get things working.

Do you have a minimal working problem so we can try to help?

Basically, you should just need to extract the posterior samples (together with their log-posterior values), and harmonic can then be applied to those. I would recommend starting with a low-dimensional problem.

Pinging @alicjapolanska, @CosmoMatt, @alessiospuriomancini, @dpiras, who may also be interested in this and able to help.

@stefanocovino
Author

stefanocovino commented Jun 5, 2023 via email

@dpiras

dpiras commented Jun 5, 2023

Hi @stefanocovino! I have not used harmonic yet, but I have used numpyro's NUTS to get posterior chains. My understanding is that the log_probabilities should be minus the potential_energy, and the numpyro documentation seems to support this: if you pass only a model to the sampler (rather than a potential function), it computes the potential_energy as the negative log-probability.

I'm interested to know if it works! Can we compare the evidence values against some ground truth in this simple example?
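A minimal sketch of one possible ground truth: with only two parameters, the evidence can be brute-forced on a grid. The hypothetical helper `grid_log_evidence` integrates likelihood × prior over (`theta`, `q_perp`) with a midpoint Riemann sum, using the data and priors from the notebook in this thread. It assumes the posterior mass lies well inside the chosen `q_range`, which is worth checking by widening it.

```python
import numpy as np

# Data copied from the notebook in this thread.
x = np.array([0.42, 0.72, 0.0, 0.3, 0.15, 0.09, 0.19, 0.35, 0.4, 0.54,
              0.42, 0.69, 0.2, 0.88, 0.03, 0.67, 0.42, 0.56, 0.14, 0.2])
y = np.array([0.33, 0.41, -0.22, 0.01, -0.05, -0.05, -0.12, 0.26, 0.29, 0.39,
              0.31, 0.42, -0.01, 0.58, -0.2, 0.52, 0.15, 0.32, -0.13, -0.09])
sigma_y = np.full_like(x, 0.1)

def grid_log_evidence(n_theta=801, n_q=1601, q_range=(-1.5, 1.5)):
    """Midpoint-rule estimate of ln Z for the (theta, q_perp) straight-line model."""
    w = 1.0 / sigma_y**2
    # Weighted sums, so chi^2(m, q) is a quadratic form with no data axis on the grid.
    s1, sx, sy = w.sum(), (w * x).sum(), (w * y).sum()
    sxx, sxy, syy = (w * x * x).sum(), (w * x * y).sum(), (w * y * y).sum()

    # Cell midpoints; for theta this also avoids the singular endpoints +-pi/2.
    d_theta = np.pi / n_theta
    theta = -0.5 * np.pi + (np.arange(n_theta) + 0.5) * d_theta
    d_q = (q_range[1] - q_range[0]) / n_q
    q_perp = q_range[0] + (np.arange(n_q) + 0.5) * d_q
    th, qp = np.meshgrid(theta, q_perp, indexing="ij")

    # Same reparametrisation as the numpyro model: m = tan(theta), q = q_perp / cos(theta).
    m = np.tan(th)
    q = qp / np.cos(th)
    chi2 = syy - 2 * q * sy - 2 * m * sxy + s1 * q**2 + 2 * m * q * sx + m**2 * sxx
    ln_like = -0.5 * chi2 - np.sum(np.log(np.sqrt(2 * np.pi) * sigma_y))
    # Priors from the notebook: theta ~ U(-pi/2, pi/2), q_perp ~ N(0, 5).
    ln_prior = -np.log(np.pi) - 0.5 * (qp / 5.0) ** 2 - np.log(5.0 * np.sqrt(2 * np.pi))

    # Stable log of sum(exp(ln_f)) * cell area.
    ln_f = ln_like + ln_prior
    peak = ln_f.max()
    return peak + np.log(np.exp(ln_f - peak).sum() * d_theta * d_q)

print(grid_log_evidence())
```

Comparing two grid resolutions (and a wider `q_range`) gives a rough sense of the integration error.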

@jasonmcewen
Contributor

Thanks @stefanocovino. Given @dpiras's comment, it should indeed just be a matter of setting up the chains with the log-probability values. If you have a script or notebook with a minimal version, we can help get it running. Feel free to set up a WIP PR and we can work on it together.
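A minimal sketch of that bookkeeping, assuming scalar parameters. `to_harmonic_arrays` is a hypothetical helper (not part of harmonic or numpyro) that keeps the draws grouped by chain, so chain boundaries are explicit rather than inferred from the flattening order. The commented usage assumes numpyro's `group_by_chain=True` option on `get_samples`/`get_extra_fields` and harmonic's `Chains.add_chains_3d`; check both against your installed versions.

```python
import numpy as np

def to_harmonic_arrays(samples_by_chain, ln_post_by_chain, param_names):
    """Hypothetical helper: stack per-parameter draws into a 3D chain array.

    samples_by_chain: dict mapping parameter name -> array (nchains, nsamples),
        e.g. what mcmc.get_samples(group_by_chain=True) returns for scalar sites.
    ln_post_by_chain: array (nchains, nsamples) of log-posterior values.
    param_names: order of the parameter columns in the output.
    Returns float64 arrays of shape (nchains, nsamples, ndim) and (nchains, nsamples).
    """
    samples = np.stack(
        [np.asarray(samples_by_chain[name]) for name in param_names], axis=-1
    ).astype(np.float64)
    ln_posterior = np.asarray(ln_post_by_chain, dtype=np.float64)
    if samples.shape[:2] != ln_posterior.shape:
        raise ValueError("samples and log-posterior must share (nchains, nsamples)")
    return samples, ln_posterior

# Possible usage with the notebook's objects (untested assumption):
# samples, ln_post = to_harmonic_arrays(
#     mcmc.get_samples(group_by_chain=True),
#     -mcmc.get_extra_fields(group_by_chain=True)["potential_energy"],
#     ["theta", "q_perp"],
# )
# chains = hm.Chains(2)
# chains.add_chains_3d(samples, ln_post)
```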

@stefanocovino
Author

stefanocovino commented Jun 8, 2023 via email

@jasonmcewen
Contributor

Thanks @stefanocovino, but I'm not sure the attachment made its way to GitHub?

@stefanocovino
Author

stefanocovino commented Jun 13, 2023 via email

@dpiras

dpiras commented Jun 13, 2023

@stefanocovino could you perhaps click on this link (#229) and post it as a comment below?

@stefanocovino
Author

Actually, the system does not allow me to attach anything (I did not know that), so I have listed the code below.
Alternatively, this is the link to the Colab notebook: https://colab.research.google.com/drive/1hlmnjIftdsO9SyeHzDaXmHtDvszTeh35?usp=sharing

Play with the notebook as you like.

Stefano


```python
# -*- coding: utf-8 -*-
"""Harmonic-Numpyro-Test.ipynb
"""

!pip install numpyro
!pip install harmonic
!pip install jaxns
#!pip install tensorflow

"""# Simulated data"""

# Commented out IPython magic to ensure Python compatibility.
# %matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

data = np.array([[ 0.42,  0.72,  0.  ,  0.3 ,  0.15,
                   0.09,  0.19,  0.35,  0.4 ,  0.54,
                   0.42,  0.69,  0.2 ,  0.88,  0.03,
                   0.67,  0.42,  0.56,  0.14,  0.2 ],
                 [ 0.33,  0.41, -0.22,  0.01, -0.05,
                  -0.05, -0.12,  0.26,  0.29,  0.39,
                   0.31,  0.42, -0.01,  0.58, -0.2 ,
                   0.52,  0.15,  0.32, -0.13, -0.09],
                 [ 0.1 ,  0.1 ,  0.1 ,  0.1 ,  0.1 ,
                   0.1 ,  0.1 ,  0.1 ,  0.1 ,  0.1 ,
                   0.1 ,  0.1 ,  0.1 ,  0.1 ,  0.1 ,
                   0.1 ,  0.1 ,  0.1 ,  0.1 ,  0.1 ]])
x, y, sigma_y = data

plt.errorbar(x, y, yerr=sigma_y, fmt='o')
plt.xlabel('x')
plt.ylabel('y');

"""# Probabilistic model"""

import numpyro
import numpyro.distributions as dist
from numpyro import infer
from numpyro.infer import MCMC, NUTS
import jax
import jax.numpy as jnp

def y_model(x, m, q):
    return q + m * x

def numpyro_model(x, ey, y=None):
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    q_perp = numpyro.sample("q_perp", dist.Normal(0, 5))
    #
    m = numpyro.deterministic("m", jnp.tan(theta))
    q = numpyro.deterministic("q", q_perp / jnp.cos(theta))
    #
    ymd = y_model(x, m, q)
    #
    with numpyro.plate("data", len(x)):
        numpyro.sample("y", dist.Normal(ymd, ey), obs=y)

"""## NUTS sampling"""

nuts_kernel = NUTS(numpyro_model, dense_mass=True, target_accept_prob=0.9)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=300,
    num_samples=300,
    num_chains=4,
)
rng_key = jax.random.PRNGKey(34923)

mcmc.run(rng_key, x, sigma_y, y=y, extra_fields=('potential_energy',))
samples = mcmc.get_samples()

pred = infer.Predictive(numpyro_model, samples)(jax.random.PRNGKey(1), x, sigma_y)
pred_y = pred["y"]

for n in np.random.default_rng(0).integers(len(pred_y), size=100):
    plt.plot(x, pred['m'][n]*x + pred['q'][n], "-", color="C0", alpha=0.1, label='')

plt.errorbar(x, y, yerr=sigma_y, fmt=".k", capsize=0)
plt.xlabel("x")
plt.ylabel("y");

"""# Evidence computation by Harmonic"""

import harmonic as hm
import numpy as np

"""### Reformatting sample chains"""

inpsmpl = [samples[i].reshape(-1,1) for i in samples.keys() if i in ('theta','q_perp')]
cdata = np.float64(np.hstack(inpsmpl))

"""### Reformatting logprob (note the minus sign)"""

nuprob = np.float64(-mcmc.get_extra_fields()['potential_energy'].flatten())

chains = hm.Chains(2)
chains.add_chains_2d(cdata, nuprob, 4)

chains_train, chains_infer = hm.utils.split_data(chains, training_proportion=0.5)

domains = [np.array([1E-1,1E1])] # hyper-sphere bounding domain

# Select model
model = hm.model.HyperSphere(2, domains)

# Train model
fit_success = model.fit(chains_train.samples, chains_train.ln_posterior)

# Instantiate harmonic's evidence class
ev = hm.Evidence(chains_infer.nchains, model)

# Pass the evidence class the inference chains and compute the log of the evidence!
ev.add_chains(chains_infer)
evidence, evidence_std = ev.compute_evidence()

print(np.log(evidence), evidence_std/evidence)

"""## Nested sampling to check the evidence"""

from numpyro.contrib.nested_sampling import NestedSampler
from jax import random

ns = NestedSampler(numpyro_model)
ns.run(random.PRNGKey(0), x, sigma_y, y=y)

ns.print_summary()
nsamples = ns.get_samples(random.PRNGKey(3), num_samples=10000)

ns.diagnostics()
```
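harmonic's approach is, at its core, a re-targeted harmonic mean estimator: for posterior draws θ_i and a normalised target density φ fitted on separate training draws, 1/Z ≈ (1/N) Σ_i φ(θ_i) / [L(θ_i) π(θ_i)], which is why it needs the unnormalised log-posterior (log-likelihood plus log-prior) at each draw. Below is a minimal 1D sketch of that idea on a Gaussian toy problem with a known evidence. The shrunken Gaussian target is a stand-in we chose for harmonic's trained `HyperSphere` model, so this is not harmonic's implementation.

```python
import numpy as np

def log_normal_pdf(v, mean, sd):
    """Log density of a normal distribution, vectorised over numpy arrays."""
    return -0.5 * ((v - mean) / sd) ** 2 - np.log(sd * np.sqrt(2.0 * np.pi))

def log_mean_exp(a):
    """Numerically stable log(mean(exp(a)))."""
    peak = np.max(a)
    return peak + np.log(np.mean(np.exp(a - peak)))

def toy_harmonic_log_evidence(samples, ln_like, ln_prior, shrink=0.5, train_frac=0.5):
    """Learned-target harmonic mean estimate of ln Z from 1D posterior draws.

    A Gaussian fitted to the training draws, with its width shrunk by `shrink`,
    stands in for harmonic's trained model; a target narrower than the posterior
    keeps the estimator's variance finite.
    """
    n_train = int(train_frac * len(samples))
    train, infer = samples[:n_train], samples[n_train:]
    mu, sd = train.mean(), shrink * train.std()
    # 1/Z is estimated by mean(phi / (L * prior)) over the inference draws.
    ln_ratio = log_normal_pdf(infer, mu, sd) - ln_like(infer) - ln_prior(infer)
    return -log_mean_exp(ln_ratio)

# Toy problem with a known answer: one datum d ~ N(theta, s_l), prior theta ~ N(0, s_p).
rng = np.random.default_rng(0)
d, s_l, s_p = 1.3, 0.5, 2.0
post_var = 1.0 / (1.0 / s_p**2 + 1.0 / s_l**2)
post_mean = post_var * d / s_l**2
draws = rng.normal(post_mean, np.sqrt(post_var), size=100_000)  # exact posterior draws

ln_z_est = toy_harmonic_log_evidence(
    draws,
    ln_like=lambda t: log_normal_pdf(d, t, s_l),
    ln_prior=lambda t: log_normal_pdf(t, 0.0, s_p),
)
ln_z_true = log_normal_pdf(d, 0.0, np.sqrt(s_p**2 + s_l**2))  # analytic evidence
print(ln_z_est, ln_z_true)
```

The train/infer split mirrors `hm.utils.split_data` in the notebook: the target is fitted on one half of the draws and the estimator is evaluated on the other.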

@dpiras

dpiras commented Jun 13, 2023

Thank you @stefanocovino! I was able to run the Colab notebook.

It seems that the log(evidence) values agree between harmonic applied to the NUTS samples and jaxns (as implemented in NumPyro), right? I got:

log(Z) = 13.1 ± 0.1 (NUTS+harmonic)
log(Z) = 12.8 ± 0.4 (jaxns)

Is there also an explanation for the different standard deviations?
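A worked check on how far apart the two estimates are. It assumes that the second number printed by the notebook (`evidence_std/evidence`) is harmonic's relative standard deviation of Z, which to first order equals the standard deviation of ln Z, and it treats the two estimates as independent:

$$
\sigma_{\ln Z} \approx \frac{\sigma_Z}{Z},
\qquad
\frac{\left|\ln Z_{\mathrm{harmonic}} - \ln Z_{\mathrm{jaxns}}\right|}{\sqrt{\sigma_{\mathrm{harmonic}}^2 + \sigma_{\mathrm{jaxns}}^2}}
= \frac{|13.1 - 12.8|}{\sqrt{0.1^2 + 0.4^2}} \approx 0.73
$$

So the two values differ by less than one combined standard deviation. This does not by itself explain the different error bars, which come from different estimators (harmonic on the 600 inference draws left after the 50/50 split, versus nested sampling), each with its own error estimate.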

@stefanocovino
Author

stefanocovino commented Jun 13, 2023 via email

@jasonmcewen
Contributor

Ok, fantastic. So it seems this is working. Is it ok to close this issue then?

@stefanocovino
Author

stefanocovino commented Jun 14, 2023 via email

@dpiras

dpiras commented Aug 8, 2023

@stefanocovino I just realised that in the above I referred to the negative potential energy as log_probabilities, but that is actually the log_likelihood. However, harmonic requires the log_posterior, so one needs to add the log_prior too.

I don't think this is currently being done in the notebook you shared, but let me know if I missed something. I will shortly try running your notebook with the log_prior included and check whether the results change significantly.

The negative potential energy returned by NUTS should actually be the log_posterior, so everything should be in order 👍

@stefanocovino
Author

stefanocovino commented Aug 8, 2023 via email

@dpiras

dpiras commented Aug 8, 2023

Hi Stefano, please bear with us while we check the above. Sorry about that.

@stefanocovino
Author

stefanocovino commented Aug 8, 2023 via email

@dpiras

dpiras commented Aug 8, 2023

After some more checking, it seems that the potential energy returned by NUTS should actually be the negative log_posterior, so everything should be correct. We are testing this further and will let you know if we find anything!
