Skip to content

Commit

Permalink
fix return logprobs
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed Feb 28, 2019
1 parent dcf6efc commit 5b78405
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
14 changes: 9 additions & 5 deletions thejoker/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _unpack_full_samples(self, result, prior_units, return_logprobs,
samples['v'+str(i)] = samples_arr[:, k + 2 + i] * _unit

if return_logprobs:
return samples, ln_prior
return samples, ln_prior, ln_like

else:
return samples
Expand Down Expand Up @@ -333,10 +333,12 @@ def rejection_sample(self, data, n_prior_samples=None,
prior_cache_file = f.name

# first do prior sampling, cache to temporary file
prior_samples = self.sample_prior(size=n_prior_samples)
prior_samples, lnp = self.sample_prior(size=n_prior_samples,
return_logprobs=True)
prior_units = save_prior_samples(prior_cache_file,
prior_samples,
data.rv.unit)
data.rv.unit,
ln_prior_probs=lnp)

result = self._rejection_sample_from_cache(
data, n_prior_samples, prior_cache_file, start_idx,
Expand Down Expand Up @@ -421,9 +423,11 @@ def iterative_rejection_sample(self, data, n_requested_samples,
"and saving them to: {0}".format(prior_cache_file))

# first do prior sampling, cache to temporary file
prior_samples = self.sample_prior(size=n_prior_samples)
prior_samples, lnp = self.sample_prior(size=n_prior_samples,
return_logprobs=True)
prior_units = save_prior_samples(f.name, prior_samples,
data.rv.unit)
data.rv.unit,
ln_prior_probs=lnp)

maxiter = 128
for i in range(maxiter): # we just need to iterate for a long time
Expand Down
12 changes: 12 additions & 0 deletions thejoker/sampler/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def test_rejection_sample(self):
full_samples = joker.rejection_sample(data, n_prior_samples=128)
assert quantity_allclose(full_samples['jitter'], jitter)

samples, lnp, lnl = joker.rejection_sample(data,
n_prior_samples=128,
return_logprobs=True)
assert len(lnp) == len(samples)
assert len(lnl) == len(samples)

def test_iterative_rejection_sample(self):

# First, try just running rejection_sample()
Expand All @@ -103,6 +109,12 @@ def test_iterative_rejection_sample(self):

assert quantity_allclose(samples['jitter'], jitter)

samples, lnp, lnl = joker.iterative_rejection_sample(
data, n_prior_samples=100000, n_requested_samples=2,
return_logprobs=True)
assert len(lnp) == len(samples)
assert len(lnl) == len(samples)

def test_mcmc_continue(self):
rnd = np.random.RandomState(42)

Expand Down

0 comments on commit 5b78405

Please sign in to comment.