Skip to content

Commit

Permalink
Merge 8808521 into d7c549f
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Sep 13, 2023
2 parents d7c549f + 8808521 commit d0278d0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.4.2"
version = "4.4.3"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
14 changes: 10 additions & 4 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,9 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -466,10 +469,10 @@ function mcmcsample(
# Return the new chain.
return chain
end
chains = if init_params === nothing
chains = if _init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, init_params)
Distributed.pmap(sample_chain, pool, seeds, _init_params)
end
finally
# Stop updating the progress bar.
Expand Down Expand Up @@ -499,6 +502,9 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand All @@ -519,10 +525,10 @@ function mcmcsample(
)
end

chains = if init_params === nothing
chains = if _init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, init_params)
map(sample_chain, 1:nchains, seeds, _init_params)
end

# Concatenate the chains together.
Expand Down

0 comments on commit d0278d0

Please sign in to comment.