Skip to content

Commit

Permalink
NUTS implementation (#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Apr 24, 2017
1 parent b663714 commit c724040
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/samplers/nuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ function step(model, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
else
# Set parameters
δ = spl.alg.delta
λ = spl.alg.lambda
ϵ = spl.info[]

dprintln(2, "current ϵ: ")
Expand All @@ -54,6 +53,7 @@ function step(model, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
u = rand() * exp(-find_H(p, model, vi, spl))

θm, θp, rm, rp, j, vi_new, n, s = deepcopy(vi), deepcopy(vi), deepcopy(p), deepcopy(p), 0, deepcopy(vi), 1, 1
local α, n_α
while s == 1
v_j = rand([-1, 1]) # Note: this variable actually does not depend on j;
# it is set as `v_j` just to be consistent to the paper
Expand All @@ -77,8 +77,6 @@ function step(model, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)

cleandual!(vi)

α = min(1, exp(-ΔH)) # MH accept rate

# Use Dual Averaging to adapt ϵ
m = spl.info[:m] += 1
if m <= spl.alg.n_adapt
Expand All @@ -95,7 +93,7 @@ function step(model, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
end
end

function build_tree(θ, r, u, v, j, ϵ, θ0, r0)
function build_tree(θ, r, u, v, j, ϵ, θ0, r0, model, spl)
doc"""
- θ : model parameter
- r : momentum variable
Expand All @@ -114,12 +112,12 @@ function build_tree(θ, r, u, v, j, ϵ, θ0, r0)
return θ′, r′, θ′, r′, θ′, n′, s′, min(1, exp(-find_H(r′, model, θ′, spl) - (-find_H(r0, model, θ0, spl)))), 1
else
# Recursion - build the left and right subtrees.
θm, rm, θp, rp, θ′, n′, s′, α′, n′_α = build_tree(θ, r, u, v, j - 1, ϵ, θ0, r0)
θm, rm, θp, rp, θ′, n′, s′, α′, n′_α = build_tree(θ, r, u, v, j - 1, ϵ, θ0, r0, model, spl)
if s′ == 1
if v == -1
θm, rm, _, _, θ′′, n′′, s′′, α′′, n′′_α = build_tree(θm, rm, u, v, j - 1, ϵ, θ0, r0)
θm, rm, _, _, θ′′, n′′, s′′, α′′, n′′_α = build_tree(θm, rm, u, v, j - 1, ϵ, θ0, r0, model, spl)
else
_, _, θp, rp, θ′′, n′′, s′′, α′′, n′′_α = build_tree(θp, rp, u, v, j - 1, ϵ, θ0, r0)
_, _, θp, rp, θ′′, n′′, s′′, α′′, n′′_α = build_tree(θp, rp, u, v, j - 1, ϵ, θ0, r0, model, spl)
end
if rand() < n′′ / (n′ + n′′)
θ′ = θ′′
Expand Down
23 changes: 23 additions & 0 deletions test/nuts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Distributions, Turing

@model gdemo(x) = begin
# s ~ InverseGamma(2,3)
s = 1
m ~ Normal(0,sqrt(s))
x[1] ~ Normal(m, sqrt(s))
x[2] ~ Normal(m, sqrt(s))
return s, m
end

alg = NUTS(1000, 200, 0.65)
res = sample(gdemo([1.5, 2.0]), alg)
# print(mean(res[:m]))

ans1 = abs(mean(res[:m]) - 7/6) <= 0.15
print("E[m] ≈ $(7/6) ? ")
if ans1
print_with_color(:green, "\n")
else
print_with_color(:red, " X\n")
print_with_color(:red, " m = $(mean(res[:m])), diff = $(abs(mean(res[:m]) - 7/6))\n")
end

0 comments on commit c724040

Please sign in to comment.