Skip to content

Commit

Permalink
Merge 985bb5f into 5444313
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jul 1, 2021
2 parents 5444313 + 985bb5f commit 69ab1bc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
AbstractMCMC = "3.2"
AdvancedHMC = "0.2.24"
AdvancedMH = "0.6"
AdvancedPS = "0.2"
AdvancedPS = "0.2.4"
AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8, 0.9"
Expand Down
2 changes: 1 addition & 1 deletion src/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ function AbstractMCMC.step(
particles = AdvancedPS.ParticleContainer(x)

# Perform a particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler)
logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, reference)

# Pick a particle to be retained.
Ws = AdvancedPS.getweights(particles)
Expand Down
40 changes: 15 additions & 25 deletions test/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,47 +32,37 @@
end
@numerical_testset "gibbs inference" begin
Random.seed!(100)
alg = Gibbs(
CSMC(10, :s),
HMC(0.2, 4, :m))
chain = sample(gdemo(1.5, 2.0), alg, 1_500)
alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m))
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.15)

Random.seed!(100)

alg = Gibbs(
MH(:s),
HMC(0.2, 4, :m))
chain = sample(gdemo(1.5, 2.0), alg, 5000)
alg = Gibbs(MH(:s), HMC(0.2, 4, :m))
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)

alg = Gibbs(
CSMC(15, :s),
ESS(:m))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
alg = Gibbs(CSMC(15, :s), ESS(:m))
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)

alg = CSMC(10)
chain = sample(gdemo(1.5, 2.0), alg, 5000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.25)
alg = CSMC(15)
chain = sample(gdemo(1.5, 2.0), alg, 5_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)

setadsafe(true)

Random.seed!(200)
gibbs = Gibbs(
PG(10, :z1, :z2, :z3, :z4),
HMC(0.15, 3, :mu1, :mu2))
chain = sample(MoGtest_default, gibbs, 1500)
check_MoGtest_default(chain, atol=0.2)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), HMC(0.15, 3, :mu1, :mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.15)

setadsafe(false)

Random.seed!(200)
gibbs = Gibbs(
PG(10, :z1, :z2, :z3, :z4),
ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, gibbs, 1500)
check_MoGtest_default(chain, atol = 0.15)
gibbs = Gibbs(PG(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, gibbs, 5_000)
check_MoGtest_default(chain, atol=0.1)
end

@turing_testset "transitions" begin
Expand Down

0 comments on commit 69ab1bc

Please sign in to comment.