In [1]:
using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions
using MacroTools

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1278
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP
┌ Info: Precompiling GraphPPL [b3f8163a-e979-4e85-b43e-1f63d8c8b42c]
└ @ Base loading.jl:1278
│ - If you have GraphPPL checked out for development and have
│   added ReactiveMP as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with GraphPPL


In [60]:
# State-space model for smoothing a noisy, unit-slope trend.
#
# Structure (as written below):
#   x_prior ~ NormalMeanVariance(x_prior_mean, x_prior_var)
#   x[1]    ~ x_prior + 1.0
#   x[i]    ~ x[i-1]  + 1.0            (deterministic shift; no process noise here)
#   y[i]    ~ x[i] + NormalMeanVariance(noise_mean, noise_var)   (additive obs. noise)
#
# Arguments:
#   n_observations        - number of time steps (length of `x` and `y`)
#   noise_mean, noise_var - observation-noise parameters; they are passed as plain
#                           values rather than `datavar`s, so presumably fixed when the
#                           graph is built (NOTE(review): confirm against GraphPPL)
#
# Returns the tuple (x_prior_mean, x_prior_var, x, y). Calling `smoothing(...)` yields a
# pair whose second element is this tuple (see how `inference` destructures it); the
# data variables are fed with `update!` and the hidden states read via `getmarginal`.
@model function smoothing(n_observations, noise_mean, noise_var)
    
    # Parameters of the initial-state prior, supplied later as data.
    x_prior_mean = datavar(Float64)
    x_prior_var  = datavar(Float64)
    
    x_prior ~ NormalMeanVariance(x_prior_mean, x_prior_var)

    # One hidden state and one observation per time step.
    x = randomvar(n_observations)
    y = datavar(Float64, n_observations)

    # `x_prev` walks along the chain so that x[1] hangs off the prior state.
    x_prev = x_prior

    for i in 1:n_observations
        # Each state is the previous one shifted by exactly 1.0.
        x[i] ~ x_prev + 1.0
        # Observation = state plus Gaussian noise with the given mean and variance.
        y[i] ~ x[i] + NormalMeanVariance(noise_mean, noise_var)
        
        x_prev = x[i]
    end

    return x_prior_mean, x_prior_var, x, y
end

smoothing (generic function with 1 method)

In [61]:
using Random: MersenneTwister

# Ground-truth observation-noise parameters used to simulate the data; `inference`
# also reads these globals.
noise_real_mean = -10.0
noise_real_var  = 100.0

n = 500

# Simulated observations: the noiseless trajectory 1, 2, …, n plus Gaussian noise.
# `Normal` takes a standard deviation, hence `sqrt` of the variance. An explicitly
# seeded RNG keeps the simulated data (and anything computed from it) reproducible.
data = (1:n) .+ rand(
    MersenneTwister(42), Normal(noise_real_mean, sqrt(noise_real_var)), n,
);

In [62]:
"""
    inference(; data, x_prior, noise_mean = noise_real_mean, noise_var = noise_real_var)

Build the `smoothing` model for `data`, feed it the prior and the observations, and
return the posterior marginals of the hidden states.

# Keywords
- `data`: observed values, one per time step; its length sets the model size.
- `x_prior`: prior for the initial state; only its `mean` and `var` are used.
- `noise_mean`, `noise_var`: observation-noise parameters passed to `smoothing`.
  They default to the globals `noise_real_mean` / `noise_real_var`, which matches the
  previous hard-wired behaviour.

# Returns
A `Vector{Marginal}` of length `length(data)`. Each emission of the combined marginal
stream is copied into it, and the vector is returned after unsubscribing.
"""
function inference(;
    data,
    x_prior,
    noise_mean = noise_real_mean,
    noise_var = noise_real_var,
)
    n = length(data)

    _, (x_prior_mean, x_prior_var, x, y) = smoothing(n, noise_mean, noise_var)

    marginals = Vector{Marginal}(undef, n)

    # Overwrite `marginals` with every emission of the combined per-state marginal stream.
    msub = subscribe!(
        collectLatest(Marginal, map(getmarginal, x)),
        (result) -> copyto!(marginals, result),
    )

    try
        # Presumably the reactive graph emits marginals synchronously while these
        # updates run, since `marginals` is returned right afterwards (confirm).
        update!(x_prior_mean, mean(x_prior))
        update!(x_prior_var, var(x_prior))
        update!(y, data)
    finally
        # Release the subscription even if one of the updates throws.
        unsubscribe!(msub)
    end

    return marginals
end

inference (generic function with 1 method)

In [63]:
# Single inference run under a wide prior (variance 10000.0) on the initial state;
# the resulting vector of marginals is the cell's displayed value.
let wide_prior = NormalMeanVariance(0.0, 10000.0)
    inference(data = data, x_prior = wide_prior)
end

500-element Array{Marginal,1}:
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=0.631192850547531, v=0.19999600007999815))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=1.631192850547531, v=0.1999960000799982))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=2.631192850547531, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=3.6311928505475306, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=4.6311928505475315, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=5.631192850547531, v=0.19999600007999815))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=6.6311928505475315, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=7.6311928505475315, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=8

In [64]:
# Benchmark a full inference run (model construction + updates). Both keyword values
# are `$`-interpolated so BenchmarkTools treats them as fixed inputs. Only the
# `inference` call itself is timed, not the global lookup of `data` or the
# construction of the prior on every sample.
@btime inference(
    data = $data,
    x_prior = $(NormalMeanVariance(0.0, 10000.0))
)

  36.474 ms (473716 allocations: 29.84 MiB)


500-element Array{Marginal,1}:
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=0.631192850547531, v=0.19999600007999815))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=1.631192850547531, v=0.1999960000799982))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=2.631192850547531, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=3.6311928505475306, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=4.6311928505475315, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=5.631192850547531, v=0.19999600007999815))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=6.6311928505475315, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=7.6311928505475315, v=0.19999600007999818))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=8