
Commit

fix max_n_samples business
adrn committed Oct 20, 2019
1 parent 23e5511 commit 0d29d35
Showing 2 changed files with 8 additions and 9 deletions.
5 changes: 2 additions & 3 deletions thejoker/sampler/fast_likelihood.pyx
@@ -334,8 +334,7 @@ cpdef batch_marginal_ln_likelihood(double[:,::1] chunk,
 
 
 cpdef batch_get_posterior_samples(double[:,::1] chunk,
-                                  data, joker_params, int max_n_samples,
-                                  rnd, return_logprobs):
+                                  data, joker_params, rnd, return_logprobs):
     """TODO:
 
     Parameters
@@ -395,7 +394,7 @@ cpdef batch_get_posterior_samples(double[:,::1] chunk,
                                        joker_params.num_params + int(return_logprobs)))
         double[::1] linear_pars
 
-    for n in range(max_n_samples):
+    for n in range(n_samples):
         pars[n, 0] = chunk[n, 0] # P
         pars[n, 1] = chunk[n, 1] # M0
         pars[n, 2] = chunk[n, 2] # e
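In this file the cap argument is dropped from the Cython kernel: batch_get_posterior_samples no longer takes max_n_samples, and the loop runs over n_samples, which the surrounding code derives from the chunk it was handed. A minimal pure-Python sketch of why that matters; the array shapes here are made up purely for illustration and are not the actual thejoker data layout:

    import numpy as np

    chunk = np.random.default_rng(42).normal(size=(3, 5))  # a batch holding only 3 prior samples
    n_samples = chunk.shape[0]
    max_n_samples = 256                                     # requested cap, larger than this batch

    # Old behavior (sketch): bounding the loop by the cap walks past the end of the batch.
    # for n in range(max_n_samples):
    #     P = chunk[n, 0]   # IndexError here (or reads past the buffer in unchecked Cython) once n >= 3

    # New behavior: the loop bound comes from the data actually present in the batch.
    for n in range(n_samples):
        P, M0, e = chunk[n, 0], chunk[n, 1], chunk[n, 2]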
12 changes: 6 additions & 6 deletions thejoker/sampler/multiproc_helpers.py
@@ -191,8 +191,8 @@ def _sample_vector_worker(task):
     is not supposed to be in the public API.
     """
 
-    (idx, chunk_index, prior_cache_file, data, joker_params, max_n_samples,
-     global_seed, return_logprobs) = task
+    (idx, chunk_index, prior_cache_file, data, joker_params, global_seed,
+     return_logprobs) = task
 
     if global_seed is not None:
         seed = global_seed + chunk_index
@@ -214,7 +214,7 @@ def _sample_vector_worker(task):
 
     chunk = chunk.astype(np.float64)
 
-    pars = batch_get_posterior_samples(chunk, data, joker_params, max_n_samples,
+    pars = batch_get_posterior_samples(chunk, data, joker_params,
                                        rnd, return_logprobs)
     if return_logprobs:
         pars = np.hstack((pars[:, :-1], ln_prior[:, None], pars[:, -1:]))
@@ -258,9 +258,9 @@ def sample_indices_to_full_samples(good_samples_idx, prior_cache_file, data,
     """
 
-    n_samples = min(len(good_samples_idx), max_n_samples)
-    args = [prior_cache_file, data, joker_params, n_samples,
-            global_seed, return_logprobs]
+    good_samples_idx = good_samples_idx[:max_n_samples]
+    n_samples = len(good_samples_idx)
+    args = [prior_cache_file, data, joker_params, global_seed, return_logprobs]
 
     if n_batches is None:
         n_batches = pool.size
     tasks = chunk_tasks(n_samples, n_batches=n_batches, arr=good_samples_idx,
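With the kernel no longer aware of the cap, the limit is enforced once at the dispatch level: good_samples_idx is truncated to max_n_samples before the work is split into batches, so each worker simply processes whatever rows it receives. A short sketch of that pattern; the dispatch function is hypothetical and NumPy's array_split stands in for the repository's chunk_tasks helper:

    import numpy as np

    def dispatch(good_samples_idx, max_n_samples, n_batches):
        # Enforce the cap once, up front, instead of threading it through every worker.
        good_samples_idx = np.asarray(good_samples_idx)[:max_n_samples]
        n_samples = len(good_samples_idx)
        # Split the capped index array into per-worker batches (stand-in for chunk_tasks).
        return np.array_split(good_samples_idx, n_batches), n_samples

    batches, n_samples = dispatch(np.arange(1000), max_n_samples=256, n_batches=4)
    assert n_samples == 256 and sum(len(b) for b in batches) == 256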
