In [1]:
using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions
using MacroTools

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1278
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP
┌ Info: Precompiling GraphPPL [b3f8163a-e979-4e85-b43e-1f63d8c8b42c]
└ @ Base loading.jl:1278
│ - If you have GraphPPL checked out for development and have
│   added ReactiveMP as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with GraphPPL


In [2]:
# State-space model for smoothing a drifting scalar signal:
#   x_prior ~ NormalMeanVariance(mean(x0), cov(x0))  -- initial-state prior taken from `x0`
#   x[i]    ~ x[i-1] + c                             -- each state is the previous one shifted by `c`
#                                                       (x[0] is `x_prior`)
#   y[i]    ~ NormalMeanVariance(x[i], P)            -- noisy observation of x[i]
# Arguments:
#   n  -- number of time steps (length of `x` and `y`)
#   x0 -- distribution whose `mean` and `cov` parameterize the initial-state prior
#   c  -- constant per-step offset
#   P  -- observation noise, passed as the variance argument of NormalMeanVariance
# Returns the latent-state variables `x` and the observation variables `y`
# (callers feed observations into `y` with `update!`).
@model function smoothing(n, x0, c, P)
    
    x_prior ~ NormalMeanVariance(mean(x0), cov(x0)) 

    x = randomvar(n)          # latent states, one per time step
    y = datavar(Float64, n)   # observed values, supplied later as data

    x_prev = x_prior

    for i in 1:n
        x[i] ~ x_prev + c                     # deterministic drift from the previous state
        y[i] ~ NormalMeanVariance(x[i], P)    # observation with variance P
        
        x_prev = x[i]
    end

    return x, y
end

smoothing (generic function with 1 method)

In [96]:
using Random

In [97]:
# Observation noise variance; reused by the objective `f` when running inference.
P = 1.0

# Seed the global RNG so the synthetic dataset is reproducible.
Random.seed!(123)

# Synthetic observations: the noiseless signal is i - 5 for i = 1..n (slope 1,
# offset -5), corrupted by zero-mean Gaussian noise with variance P.
n = 250
data = (1:n) .- 5.0 .+ rand(Normal(0.0, sqrt(P)), n);

In [4]:
"""
    inference(data, x0, c, P)

Run message passing on the `smoothing` model conditioned on `data`.

Builds a `smoothing` model with `length(data)` time steps, subscribes to the
posterior marginals of the latent states and to the Bethe free energy, pushes
`data` into the observation variables, and then releases both subscriptions.
The subscriptions are released even if feeding the data throws.

# Arguments
- `data`: vector of observations, one per time step.
- `x0`: distribution whose `mean`/`cov` define the initial-state prior.
- `c`: constant per-step offset of the state.
- `P`: observation noise variance.

# Returns
A tuple `(ms, fe)`:
- `ms` is a `Vector{Marginal}` of length `length(data)` holding the latest state marginals.
- `fe` is the last Bethe free energy value received, or `nothing` if none was emitted.
"""
function inference(data, x0, c, P)
    n = length(data)

    model, (x, y) = smoothing(n, x0, c, P)

    # NOTE(review): entries remain undefined if the marginals stream never emits.
    ms_buffer = Vector{Marginal}(undef, n)
    fe_buffer = nothing

    marginals = collectLatest(getmarginals(x))

    fe_stream = score(BetheFreeEnergy(), model, AsapScheduler())
    fe_subscription = subscribe!(fe_stream, (fe) -> fe_buffer = fe)
    try
        ms_subscription = subscribe!(marginals, (ms) -> copyto!(ms_buffer, ms))
        try
            # Presumably the callbacks above fire during this call (AsapScheduler) — TODO confirm.
            update!(y, data)
        finally
            # Always release subscriptions, in the original order (marginals first).
            unsubscribe!(ms_subscription)
        end
    finally
        unsubscribe!(fe_subscription)
    end

    return ms_buffer, fe_buffer
end

inference (generic function with 1 method)

In [94]:
"""
    f(c)

Objective for the optimizer: the Bethe free energy of the `smoothing` model
fitted to the global `data`.

`c` is a 2-element parameter vector:
- `c[1]` is the per-step offset `C`.
- `c[2]` is the prior mean `μ0` of the initial state.

The initial-state prior is `NormalMeanVariance(μ0, 100.0)`, with a fixed variance
of 100.0. The global `P` is used as the observation variance.
"""
function f(c)
    C, μ0 = c[1], c[2]
    initial_prior = NormalMeanVariance(μ0, 100.0)
    _, free_energy = inference(data, initial_prior, C, P)
    return free_energy
end

f (generic function with 1 method)

In [98]:
using Optim

In [99]:
# Minimize the free energy over [C, μ0], starting from [1.0, 1.0].
res = optimize(
    f,
    ones(2),
    GradientDescent(),
    Optim.Options(
        g_tol = 1e-3,         # stop once the gradient norm falls below 1e-3
        iterations = 100,     # hard cap on iterations
        store_trace = true,   # keep per-iteration state in `res`
        show_trace = true,    # print progress while optimizing
    ),
)

Iter     Function value   Gradient norm 
     0     3.655789e+02     8.149754e+02
 * time: 0.0001761913299560547
     1     3.653239e+02     5.997070e-02
 * time: 0.23605704307556152
     2     3.652277e+02     3.413628e+02
 * time: 1.9465692043304443
     3     3.651829e+02     2.788324e-02
 * time: 2.1853091716766357
     4     3.651634e+02     1.591045e+02
 * time: 3.8991951942443848
     5     3.651536e+02     1.385709e-02
 * time: 4.15316915512085
     6     3.651486e+02     7.898214e+01
 * time: 5.894620180130005
     7     3.651462e+02     6.599552e-03
 * time: 6.1291420459747314
     8     3.651447e+02     3.462404e+01
 * time: 7.8428730964660645
     9     3.651442e+02     2.002090e-03
 * time: 8.084136009216309
    10     3.651441e+02     1.142017e+01
 * time: 9.761056184768677
    11     3.651441e+02     9.742167e-04
 * time: 10.023941993713379


 * Status: success

 * Candidate solution
    Final objective value:     3.651441e+02

 * Found with
    Algorithm:     Gradient Descent

 * Convergence measures
    |x - x'|               = 8.77e-06 ≰ 0.0e+00
    |x - x'|/|x'|          = 1.79e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 5.01e-05 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 1.37e-07 ≰ 0.0e+00
    |g(x)|                 = 9.74e-04 ≤ 1.0e-03

 * Work counters
    Seconds run:   10  (vs limit Inf)
    Iterations:    11
    f(x) calls:    83
    ∇f(x) calls:   83


In [100]:
Optim.minimizer(res)  # estimated [C, μ0]; the data were generated with C = 1.0 and μ0 = -5.0

2-element Array{Float64,1}:
  1.0006315218517985
 -4.90059862195555