In [1]:
using Rocket
using ReactiveMP
using GraphPPL
using Distributions

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1278
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP
┌ Info: Precompiling GraphPPL [b3f8163a-e979-4e85-b43e-1f63d8c8b42c]
└ @ Base loading.jl:1278
│ - If you have GraphPPL checked out for development and have
│   added ReactiveMP as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with GraphPPL


In [54]:
# Coin-toss model: a single observed outcome `y` drawn from Bernoulli(θ), with a
# Beta(a, b) prior on the coin bias θ.
# The prior parameters `a` and `b` are declared as `datavar`s rather than
# constants, so their values are supplied at inference time with `update!`.
# This is what allows `inference` below to feed each posterior back in as the next prior.
@model function coin_model()
    # Data inputs: the two Beta prior parameters and the observed coin outcome.
    a = datavar(Float64)
    b = datavar(Float64)
    y = datavar(Float64)
    
    θ ~ Beta(a, b)      # prior over the coin bias
    y ~ Bernoulli(θ)    # likelihood of the observed outcome
    
    # Callers destructure the result in exactly this order: (y, a, b, θ).
    return y, a, b, θ
end

coin_model (generic function with 1 method)

In [62]:
# Synthetic data: N independent tosses of a coin whose heads probability is p.
N = 100_000    # number of coin tosses
p = 0.5        # success probability of each Bernoulli trial

# Draw `n` independent Bernoulli(`p`) outcomes, encoded as Int 1 (success) / 0 (failure).
sbernoulli(n, p) = map(_ -> Int(rand() < p), 1:n)

# Convert to Float64 because the model's observation `y` is a `datavar(Float64)`.
dataset = float.(sbernoulli(N, p));

In [63]:
"""
    inference(data; prior_a = 3.0, prior_b = 3.0)

Run sequential (online) Bayesian inference for the coin bias `θ` of `coin_model`.

Observations in `data` are processed one at a time. For each one, the observation and
the current Beta prior parameters are pushed into the model with `update!`. The
parameters of the latest posterior marginal of `θ` then become the prior for the
next observation.

# Arguments
- `data`: iterable of observations. They are presumably `0.0`/`1.0` coin outcomes,
  since `y` is a `datavar(Float64)` observed through a Bernoulli likelihood (TODO confirm).

# Keywords
- `prior_a`, `prior_b`: parameters of the initial Beta prior (both default to `3.0`).

# Returns
A tuple `(θs, fe)`:
- `θs`: every marginal of `θ` received from the `getmarginal(θ)` subscription,
  in arrival order.
- `fe`: every Bethe free energy value received from the `score` subscription,
  in arrival order.
"""
function inference(data; prior_a = 3.0, prior_b = 3.0)
    model, (y, a, b, θ) = coin_model()

    fe = Vector{Float64}()
    θs = Vector{Marginal}()

    fe_sub = subscribe!(
        score(BetheFreeEnergy(), model, AsapScheduler()),
        (f) -> push!(fe, f),
    )
    θ_sub = subscribe!(getmarginal(θ), (m) -> push!(θs, m))

    # Always release both subscriptions, even if an update (or `θs[end]` on an
    # empty marginal stream) throws partway through the data.
    try
        for d in data
            update!(y, d)
            update!(a, prior_a)
            update!(b, prior_b)

            # Carry the latest posterior forward as the prior for the next observation.
            prior_a, prior_b = params(getdata(θs[end]))
        end
    finally
        unsubscribe!(θ_sub)
        unsubscribe!(fe_sub)
    end

    return θs, fe
end

inference (generic function with 1 method)

In [70]:
using BenchmarkTools



In [72]:
# `@btime` evaluates its expression inside a generated benchmark function, so an
# assignment written inside the macro only binds locals there. `@btime` returns
# the value of the benchmarked expression, so bind that result here instead.
# This makes `est` and `fe` available to the cells below.
est, fe = @btime inference($dataset);

  840.480 ms (11500812 allocations: 506.06 MiB)


In [73]:
# Posterior mean of θ after the final observation (compare with the true `p`).
mean(last(est))

0.5006599604023758

In [75]:
# Most recently recorded Bethe free energy value.
last(fe)

0.6918381032936693

In [76]:
# Reference value -log(1/2) = log 2: the surprise (in nats) of a single fair-coin
# outcome, apparently for comparison with the final free energy above.
-log(0.5)

0.6931471805599453

└ @ Revise /Users/bvdmitri/.julia/packages/Revise/fwStr/src/packagedef.jl:551
