In [6]:
using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions

In [7]:
# Model specification for a chain of `n_observations` latent states with noisy
# observations, written in the GraphPPL `@model` DSL.
#
# Arguments:
#   n_observations - number of latent states / observations in the chain
#   noise_mean     - mean of the additive Gaussian observation noise
#   noise_var      - variance of the additive Gaussian observation noise
#
# Returns the tuple `(x_prior_mean, x_prior_var, x, y)` so that the caller can
# feed values into the data variables and read results from the random variables.
@model function smoothing(n_observations, noise_mean, noise_var)
    
   # Placeholders for the parameters of the prior; no values are given here,
   # they are expected to be supplied by the caller after the model is built.
   x_prior_mean = datavar(Float64)
   x_prior_var  = datavar(Float64)
   
   # Gaussian prior on the state that precedes the first element of the chain.
   x_prior ~ NormalMeanVariance(x_prior_mean, x_prior_var)
   
   # One latent state and one observation placeholder per step.
   x = randomvar(n_observations)
   y = datavar(Float64, n_observations)
   
   # Tracks the predecessor state while the chain is being built.
   x_prev = x_prior
   
   for i in 1:n_observations
       # State transition: each state is the previous one shifted by the constant 1.0.
       x[i] ~ x_prev + 1.0
       # Observation: the state plus Gaussian noise with the given mean and variance.
       y[i] ~ x[i] + NormalMeanVariance(noise_mean, noise_var)
       x_prev = x[i]
   end
   
   return x_prior_mean, x_prior_var, x, y
end

smoothing (generic function with 1 method)

In [8]:
# Parameters of the Gaussian noise added to the synthetic observations.
noise_real_mean = 0.0
noise_real_var  = 10.0

# Number of synthetic observations.
n = 100

# Observation `d` is the value `d` itself plus one independent noise draw.
# The variance is converted with `sqrt` before being handed to `Normal`.
data = [d + rand(Normal(noise_real_mean, sqrt(noise_real_var))) for d in 1:n];

In [13]:
"""
    inference(; data, x_prior)

Build a `smoothing` model with one latent state per element of `data`, feed it
the prior parameters and the observations, and return the collected marginals
of the latent states.

# Keywords
- `data`: collection of observations; `length(data)` determines the model size
  and its elements are passed to the observation variables in iteration order.
- `x_prior`: distribution whose `mean` and `var` are used as the prior
  parameters of the model.

# Returns
A `Vector{Marginal}` of length `length(data)`. Entry `i` holds the last value
delivered by the marginal stream of state `i` while the subscriptions were
active; an entry whose stream delivered nothing is left unassigned.

# Notes
The observation-noise parameters are read from the globals `noise_real_mean`
and `noise_real_var`.
"""
function inference(; data, x_prior)
    n = length(data)

    x_prior_mean, x_prior_var, x, y = smoothing(n, noise_real_mean, noise_real_var)

    marginals = Vector{Marginal}(undef, n)

    # Grown with `push!` instead of being preallocated with `undef`, so the
    # cleanup in `finally` only ever sees subscriptions that really exist.
    subscriptions = Teardown[]
    sizehint!(subscriptions, n)

    try
        for i in 1:n
            push!(subscriptions, subscribe!(getmarginal(x[i]), (m) -> marginals[i] = m))
        end

        update!(x_prior_mean, mean(x_prior))
        update!(x_prior_var, var(x_prior))

        # Iterate `data` directly rather than indexing it with `@inbounds`:
        # the caller controls its type, so unchecked `data[i]` is not provably safe.
        for (i, observation) in enumerate(data)
            update!(y[i], observation)
        end
    finally
        # Release every subscription even when an update above throws.
        foreach(unsubscribe!, subscriptions)
    end

    return marginals
end

inference (generic function with 1 method)

In [14]:
# Run inference on the synthetic data set with a very wide Gaussian prior
# (mean 0.0, variance 10000.0) on the initial state.
inference(; data = data, x_prior = NormalMeanVariance(0.0, 10000.0))

100-element Array{Marginal,1}:
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=1.052071277527567, v=0.09999900001000005))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=2.052071277527567, v=0.09999900001000007))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=3.0520712775275665, v=0.09999900001000006))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=4.052071277527566, v=0.09999900001000003))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=5.052071277527565, v=0.09999900001000003))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=6.052071277527566, v=0.09999900001000003))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=7.052071277527566, v=0.09999900001000002))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=8.052071277527569, v=0.09999900001000005))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=9.0