In [1]:
using Random

using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions
using MacroTools

┌ Info: Precompiling GraphPPL [b3f8163a-e979-4e85-b43e-1f63d8c8b42c]
└ @ Base loading.jl:1278
│ - If you have GraphPPL checked out for development and have
│   added ReactiveMP as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with GraphPPL


In [2]:
# GraphPPL model: a chain of `n` hidden states where each state is the previous
# one plus the constant `c` (a deterministic sum node; no process noise is
# added), and every state is observed through Gaussian noise with variance `P`.
#
# Arguments:
#   n  - number of hidden states / observations.
#   k  - portal interval: transitions whose index `i` is divisible by `k` are
#        routed through the `sync` scheduler; `k` must be nonzero (`rem(i, 0)`
#        throws a DivideError).
#   x0 - prior over the initial state; only `mean(x0)` and `cov(x0)` are used.
#   P  - observation noise variance. NOTE(review): the `::ConstVariable`
#        annotation presumably makes GraphPPL wrap `P` as a constant node —
#        confirm against the GraphPPL version in use.
#
# Returns `(sync, x, y)`: the scheduler, the hidden-state random variables and
# the observation data variables (filled later with `update!`).
@model function smoothing(n, k, x0, P::ConstVariable)
    
    # Prior over the state that precedes x[1].
    x_prior ~ NormalMeanVariance(mean(x0), cov(x0)) 

    x = randomvar(n)        # hidden states
    y = datavar(Float64, n) # observations, supplied at inference time
    c = constvar(1.0)       # constant per-step increment

    x_prev = x_prior
    
    # Scheduler shared by every portal created below. NOTE(review): presumably
    # it postpones updates on those edges until the caller waits on it — TODO
    # confirm against ReactiveMP's PostponeScheduler documentation.
    sync = PostponeScheduler()

    for i in 1:n
        # Every k-th transition gets ScheduleOnPortal(sync); all others get EmptyPortal().
        x[i] ~ (x_prev + c) where { portal = rem(i, k) === 0 ? ScheduleOnPortal(sync) : EmptyPortal() }
        y[i] ~ NormalMeanVariance(x[i], P)
        
        x_prev = x[i]
    end

    return sync, x, y
end

smoothing (generic function with 1 method)

In [3]:
# Observation noise variance (the model's `P`), number of time steps, and the
# interval between scheduler portals.
P = 1.0

n = 10_000
k = 500

# Synthetic observations: value i at step i plus Gaussian noise with standard
# deviation sqrt(P). A fixed seed keeps the dataset, and therefore the
# benchmark below, reproducible across runs. Broadcasting over the range
# avoids materialising `collect(1:n)`.
rng  = MersenneTwister(42)
data = (1:n) .+ rand(rng, Normal(0.0, sqrt(P)), n);

In [6]:
"""
    inference(; data, k, x0, P)

Build the `smoothing` model for `length(data)` time steps, feed `data` into its
observation variables, and collect the marginals emitted for the hidden states.

# Keywords
- `data`: observed values, one per time step.
- `k`: portal interval forwarded to `smoothing` (must be nonzero).
- `x0`: prior over the initial state; the model only uses `mean(x0)` and `cov(x0)`.
- `P`: observation noise variance.

# Returns
A `Vector{Marginal}` of length `length(data)` holding the most recent batch of
marginals emitted for `x`. Entries remain unassigned if nothing was emitted.
"""
function inference(; data, k, x0, P)
    n = length(data)

    _, (sync, x, y) = smoothing(n, k, x0, P)

    buffer    = Vector{Marginal}(undef, n)
    marginals = getmarginals(x)

    # Overwrite `buffer` in place every time a new batch of marginals arrives.
    subscription = subscribe!(marginals, (ms) -> copyto!(buffer, ms))

    # Release the subscription even if waiting or updating throws; previously
    # an exception here skipped `unsubscribe!` and leaked the subscription.
    # NOTE(review): the two `wait(sync)` calls presumably flush postponed
    # updates before and after the data arrive — keep their order as is.
    try
        wait(sync)
        update!(y, data)
        wait(sync)
    finally
        unsubscribe!(subscription)
    end

    return buffer
end

inference (generic function with 1 method)

In [8]:
# Time the full pipeline: model construction, inference and marginal collection.
# Every argument is interpolated with `$`, so it is evaluated once outside the
# timed region. In particular, the prior is no longer rebuilt on every sample.
# The result is not bound to a name: an assignment inside the benchmarked
# expression is local to it and never visible afterwards.
@benchmark inference(
    data = $data,
    k = $k,
    x0 = $(NormalMeanVariance(0.0, 10000.0)),
    P = $P
)

BenchmarkTools.Trial: 
  memory estimate:  413.73 MiB
  allocs estimate:  6610062
  --------------
  minimum time:     751.389 ms (31.61% GC)
  median time:      923.108 ms (40.31% GC)
  mean time:        942.329 ms (42.82% GC)
  maximum time:     1.192 s (53.44% GC)
  --------------
  samples:          6
  evals/sample:     1