Multi-threaded sampling with reversediff backend and rdcache gives bad samples #1412

@BlackWingedKing

Below are the posterior density plots for the Turing coin-flip example.
[image: posterior density plots for each sampling configuration]
As the plot shows, the posterior from multi-threaded sampling with rdcache enabled is nowhere near the others.
Code to reproduce this:

# Load the modules
using Turing, MCMCChains, Distributions, StatsPlots, Random
using ReverseDiff, Memoization, Zygote

println("loaded modules")
# Set the true probability of heads in a coin.
p_true = 0.5

# Draw 100 coin flips with a fixed seed (Ns is only used for its upper bound).
Ns = 0:100
Random.seed!(12)
data = rand(Bernoulli(p_true), last(Ns))
# define the model
@model coinflip(y) = begin
    p ~ Beta(1, 1)
    N = length(y)
    for n in 1:N
        y[n] ~ Bernoulli(p)
    end
end
# parameters
num_chains = 10
iterations = 100
model_coin = coinflip(data)

# sampling with the default AD backend: serial chains first, then multi-threaded
chain_serial = mapreduce(c -> sample(model_coin, NUTS(0.65), iterations), chainscat, 1:num_chains)
chain_multithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# now enable Zygote backend
Turing.setadbackend(:zygote)
chain_zymultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# now enable ReverseDiff backend without rdcache
Turing.setadbackend(:reversediff)
Turing.setrdcache(false)
chain_rdmultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# with rdcache
Turing.setrdcache(true)
chain_rdcachemultithread = sample(model_coin, NUTS(0.65), MCMCThreads(), iterations, num_chains)

# plot the combined density of all those chains
density(chain_serial[:p][:], label="serial", legend=:topleft)
density!(chain_multithread[:p][:], label="multi-thread-normal", legend=:topleft)
density!(chain_zymultithread[:p][:], label="multi-thread-zygote", legend=:topleft)
density!(chain_rdmultithread[:p][:], label="multi-thread-reversediff", legend=:topleft)
density!(chain_rdcachemultithread[:p][:], label="multi-thread-rdcache", legend=:topleft)

savefig("mwe.png")
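
To put a number on "not at all close", the chains can also be compared against the exact posterior. With a Beta(1, 1) prior and Bernoulli observations the posterior is conjugate: Beta(1 + heads, 1 + tails). The snippet below is only a sketch of that check; `mean_gap` is a hypothetical helper name, not part of Turing or MCMCChains, and the drawn data depend on the Julia version's random number generator.

```julia
using Distributions, Random

# Same seed and data generation as in the reproduction script.
Random.seed!(12)
data = rand(Bernoulli(0.5), 100)

# Conjugate update: Beta(1, 1) prior with a Bernoulli likelihood.
heads = sum(data)
tails = length(data) - heads
posterior = Beta(1 + heads, 1 + tails)

println("analytic posterior mean: ", mean(posterior))
println("analytic posterior std:  ", std(posterior))

# Hypothetical helper: absolute gap between the sample mean of p
# and the analytic posterior mean.
mean_gap(samples) = abs(mean(samples) - mean(posterior))
```

Calling, for example, `mean_gap(chain_rdcachemultithread[:p][:])` and `mean_gap(chain_serial[:p][:])` after running the script should give a rough sense of how far each configuration is from the exact answer.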

My configuration

  JULIA_NUM_THREADS=4
  [31c24e10] Distributions v0.23.12
  [ced4e74d] DistributionsAD v0.6.9
  [c7f686f2] MCMCChains v4.2.1
  [6fafb56a] Memoization v0.1.4
  [37e2e3b7] ReverseDiff v1.4.3
  [f3b207a7] StatsPlots v0.14.13
  [fce5fe82] Turing v0.14.3
  [e88e6eb3] Zygote v0.5.7

Also posted in the Slack channel.
