Below are the posterior density plots for the Turing coin-flip example:

As the plot shows, the posterior from multi-threaded sampling with the ReverseDiff tape cache (rdcache) enabled is nowhere near the others.

Code to reproduce:
```julia
# Load the modules
using Turing, MCMCChains, Distributions, StatsPlots, Random
using ReverseDiff, Memoization, Zygote
println("loaded modules")

# Set the true probability of heads in a coin.
p_true = 0.5
# Iterate from having seen 0 observations to 100 observations.
Ns = 0:100
Random.seed!(12)
data = rand(Bernoulli(p_true), last(Ns))

# define the model
@model coinflip(y) = begin
    p ~ Beta(1, 1)
    N = length(y)
    for n in 1:N
        y[n] ~ Bernoulli(p)
    end
end

# parameters
num_chains = 10
iterations = 100
model_coin = coinflip(data)

# sampling
chain_serial = mapreduce(c -> sample(model_coin, NUTS(0.65), iterations), chainscat, 1:num_chains)
chain_multithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# now enable Zygote backend
Turing.setadbackend(:zygote)
chain_zymultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# now enable ReverseDiff backend without rdcache
Turing.setadbackend(:reversediff)
Turing.setrdcache(false)
chain_rdmultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# with rdcache
Turing.setrdcache(true)
chain_rdcachemultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# plot the combined density of all those chains
density(chain_serial[:p][:], label="serial", legend=:topleft)
density!(chain_multithread[:p][:], label="multi-thread-normal", legend=:topleft)
density!(chain_zymultithread[:p][:], label="multi-thread-zygote", legend=:topleft)
density!(chain_rdmultithread[:p][:], label="multi-thread-reversediff", legend=:topleft)
density!(chain_rdcachemultithread[:p][:], label="multi-thread-rdcache", legend=:topleft)
savefig("mwe.png")
```
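Because the `Beta(1, 1)` prior is conjugate to the Bernoulli likelihood, this model has an exact posterior, `Beta(1 + #heads, 1 + #tails)`, which gives a fixed reference to compare each chain against instead of only comparing the chains with one another. The sketch below rebuilds the same data and computes that reference; `compare_to_exact` is a hypothetical helper introduced here for illustration, not part of Turing or MCMCChains.

```julia
using Distributions, Random, Statistics

# Same data-generating setup as the reproduction script (seed 12, 100 flips)
Random.seed!(12)
p_true = 0.5
data = rand(Bernoulli(p_true), 100)

# Conjugacy: Beta(1, 1) prior + Bernoulli likelihood
# gives the exact posterior Beta(1 + #heads, 1 + #tails)
heads = count(data)
tails = length(data) - heads
exact_posterior = Beta(1 + heads, 1 + tails)

# Hypothetical helper: how far a vector of draws of p is from the exact
# posterior, measured as differences in mean and standard deviation
function compare_to_exact(samples::AbstractVector{<:Real}, post::Beta)
    return (mean_diff = mean(samples) - mean(post),
            std_diff = std(samples) - std(post))
end

println("exact posterior mean: ", mean(exact_posterior))
println("exact 95% interval: ", [quantile(exact_posterior, q) for q in (0.025, 0.975)])
```

With the reproduction script loaded, a call such as `compare_to_exact(chain_rdcachemultithread[:p][:], exact_posterior)` would put a number on how far the rdcache chains drift, alongside the same call for `chain_serial`.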
My configuration:

```
JULIA_NUM_THREADS=4
[31c24e10] Distributions v0.23.12
[ced4e74d] DistributionsAD v0.6.9
[c7f686f2] MCMCChains v4.2.1
[6fafb56a] Memoization v0.1.4
[37e2e3b7] ReverseDiff v1.4.3
[f3b207a7] StatsPlots v0.14.13
[fce5fe82] Turing v0.14.3
[e88e6eb3] Zygote v0.5.7
```
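For anyone comparing results on another machine, the thread count and package versions listed above can be collected from within Julia. This is a minimal sketch using only standard-library calls; the exact output depends on the active environment.

```julia
using InteractiveUtils, Pkg

# Threads Julia was started with (set via JULIA_NUM_THREADS before launch);
# MCMCThreads() only runs chains in parallel when this is greater than 1
println("threads: ", Threads.nthreads())

# Julia version and platform details
versioninfo()

# Versions of the packages in the active environment
Pkg.status()
```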
Also posted in the Slack channel.