In [1]:
using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions
using Random

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1342


In [2]:
# State-space model: a chain of latent states `x[1:n]` where each state equals the
# previous one plus a constant `c`, and each state is observed through Gaussian
# noise with variance `P`.
#
# Arguments:
#   n  - number of time steps (latent states and observations)
#   x0 - prior over the initial state; only `mean(x0)` and `cov(x0)` are used
#   c  - constant increment added at every step (`x[i] ~ x[i-1] + c`)
#   P  - observation noise variance (second argument of `NormalMeanVariance`)
# Returns the latent random variables `x` and the Float64 data variables `y`.
@model function smoothing(n, x0, c::ConstVariable, P::ConstVariable)
    
    # Prior over the state that precedes x[1], built from the moments of `x0`.
    x_prior ~ NormalMeanVariance(mean(x0), cov(x0)) 

    x = randomvar(n)
    y = datavar(Float64, n)
    
    x_prev = x_prior
    
    for i in 1:n
        # Transition: previous state plus the constant `c` (no process-noise term).
        x[i] ~ x_prev + c
        # Observation: x[i] corrupted by Gaussian noise of variance `P`.
        y[i] ~ NormalMeanVariance(x[i], P)
        
        x_prev = x[i]
    end

    return x, y
end

smoothing (generic function with 1 method)

In [3]:
# Reproducible synthetic observations: data[i] = i + ε_i with ε_i ~ N(0, P),
# i.e. a sequence starting near 0 that increases by 1 per step.
seed = 123

rng = MersenneTwister(seed)

P = 1.0   # observation noise variance
n = 500   # number of observations

# `Normal` is parameterized by the standard deviation, hence `sqrt(P)`.
# Broadcasting over the range avoids materializing it with `collect`.
data = (1:n) .+ rand(rng, Normal(0.0, sqrt(P)), n);

In [4]:
"""
    inference(data, x0, P; c = 1.0)

Condition the `smoothing` model on the observations `data` and return the
marginals collected for the latent states `x`.

# Arguments
- `data`: observed values, fed to the model's data variables `y` via `update!`.
- `x0`: prior distribution over the initial state (its `mean` and `cov` are used).
- `P`: observation noise variance.

# Keywords
- `c = 1.0`: constant increment between consecutive latent states.

The subscription to the marginals is always released, even if `update!` throws.
"""
function inference(data, x0, P; c = 1.0)
    n = length(data)

    _, (x, y) = smoothing(n, x0, c, P)

    # Buffer sized for one marginal per latent state.
    x_buffer  = buffer(Marginal, n)
    marginals = getmarginals(x)

    subscription = subscribe!(marginals, x_buffer)
    try
        # Marginals produced in response to the new data reach `x_buffer`
        # through the subscription.
        update!(y, data)
    finally
        # Release the subscription even on failure so it does not leak.
        unsubscribe!(subscription)
    end

    return getvalues(x_buffer)
end

inference (generic function with 1 method)

In [5]:
# Wide Gaussian prior over the initial state: mean 0.0, variance 1.0e4.
x0_prior = NormalMeanVariance(0.0, 1.0e4)

NormalMeanVariance{Float64}(μ=0.0, v=10000.0)

In [6]:
# Benchmark the full inference pipeline. `$` interpolation makes the non-const
# globals behave like local values inside the benchmark.
@benchmark inference($data, $x0_prior, $P)

BenchmarkTools.Trial: 244 samples with 1 evaluation.
 Range (min … max):  16.704 ms … 31.415 ms  ┊ GC (min … max):  0.00% … 35.97%
 Time  (median):     18.564 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   20.499 ms ±  3.807 ms  ┊ GC (mean ± σ):  10.70% ± 13.60%

    ▂▄▂▅█▅▄▃▁
  ▄█████

In [7]:
# Run inference once, outside the benchmark, to inspect the returned marginals.
inference(data, x0_prior, P)

500-element Vector{Marginal}:
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=520.3550543909051, w=500.0000999999997))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=1020.3551543909047, w=500.0000999999997))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=1520.3552543909045, w=500.0000999999997))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=2020.3553543909045, w=500.0000999999997))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=2520.355454390904, w=500.00009999999963))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=3020.3555543909038, w=500.00009999999963))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=3520.3556543909035, w=500.0000999999997))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=4020.355754390903, w=500.0000999999997))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=4520.355854390902, w=500.0000999999997))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=5020.355954390903, w=500.00009999999975))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=55