In [1]:
using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions
using MacroTools

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1278
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP
┌ Info: Precompiling GraphPPL [b3f8163a-e979-4e85-b43e-1f63d8c8b42c]
└ @ Base loading.jl:1278
│ - If you have GraphPPL checked out for development and have
│   added ReactiveMP as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with GraphPPL


In [10]:
# Linear state-space model used for smoothing:
#   x_prior ~ NormalMeanVariance(x_prior_mean, x_prior_var)
#   x[i]    = x[i-1] + 1.0                (x[0] is x_prior; no process noise)
#   y[i]    = x[i] + NormalMeanVariance(noise_mean, noise_var)
#
# Arguments:
# - n_observations: number of latent states `x` and observations `y`
# - noise_mean, noise_var: mean and variance of the additive observation noise
# - splits: every `splits`-th transition node is created with an `AsyncPortal`,
#   all other transition nodes with an `EmptyPortal`
#
# Returns (x_prior_mean, x_prior_var, x, y): the prior-parameter datavars and `y`
# are meant to be fed with `update!` by the caller; `x` holds the latent randomvars.
# NOTE(review): `outbound_message_portal` and the per-node `portal` option are
# ReactiveMP/GraphPPL features; presumably `AsyncPortal` defers outgoing messages to
# the task scheduler (which is why the caller `yield()`s after pushing data) — confirm.
@model [ outbound_message_portal = EmptyPortal() ] function smoothing(n_observations, noise_mean, noise_var, splits)
    
    # Prior parameters are data variables so they can be set at inference time.
    x_prior_mean = datavar(Float64)
    x_prior_var  = datavar(Float64)
    
    x_prior ~ NormalMeanVariance(x_prior_mean, x_prior_var) 

    x = randomvar(n_observations)
    y = datavar(Float64, n_observations)

    # The chain starts from the prior state; x[1] is linked to x_prior.
    x_prev = x_prior

    for i in 1:n_observations
        
        # Deterministic "+1" transition; every `splits`-th node gets an asynchronous portal.
        x[i] ~ (x_prev + 1.0) where { portal = rem(i, splits) === 0 ? AsyncPortal() : EmptyPortal() }
        
        # Observation model: the state plus Gaussian noise.
        y[i] ~ x[i] + NormalMeanVariance(noise_mean, noise_var)
        
        x_prev = x[i]
    end

    return x_prior_mean, x_prior_var, x, y
end

smoothing (generic function with 1 method)

In [15]:
using Random

# Observation-noise parameters: used to generate the synthetic data below and,
# by default, by `inference` when building the `smoothing` model.
noise_real_mean = -10.0
noise_real_var  = 100.0

n = 100_000   # number of observations / latent states
splits = 715  # every `splits`-th transition node in `smoothing` gets an AsyncPortal

# The noise-free signal is 1:n (matching the model's "+1" transition); Gaussian noise
# with mean `noise_real_mean` and std `sqrt(noise_real_var)` is added to it.
# A fixed seed keeps the dataset (and thus the benchmark and its results) reproducible.
data = (1:n) .+ rand(MersenneTwister(42), Normal(noise_real_mean, sqrt(noise_real_var)), n);

In [16]:
import ProgressMeter

"""
    inference(; data, x_prior, splits,
              noise_mean = noise_real_mean, noise_var = noise_real_var)

Build the `smoothing` model for `data`, push the prior parameters and observations into
it, and return the latest posterior marginal of every latent state `x[i]`.

# Keywords
- `data`: observations; one `y[i]` datavar is created per element.
- `x_prior`: prior over the initial state; only `mean(x_prior)` and `var(x_prior)` are used.
- `splits`: passed to `smoothing`; every `splits`-th transition node gets an `AsyncPortal`.
- `noise_mean`, `noise_var`: observation-noise parameters passed to `smoothing`.
  They default to the globals `noise_real_mean` / `noise_real_var` (the previous,
  implicit behaviour), but can now be supplied explicitly.

# Returns
A `Vector{Marginal}` of length `length(data)`.
"""
function inference(; data, x_prior, splits,
                   noise_mean = noise_real_mean, noise_var = noise_real_var)
    n = length(data)

    _, (x_prior_mean, x_prior_var, x, y) = smoothing(n, noise_mean, noise_var, splits)

    marginals = Vector{Marginal}(undef, n)
    # Each combined emission of all x-marginals is copied into the preallocated buffer.
    msub = subscribe!(
        collectLatest(Marginal, map(getmarginal, x), Vector{Marginal}, identity),
        (m) -> copyto!(marginals, m),
    )

    try
        # NOTE(review): presumably lets already-scheduled tasks run before data is pushed.
        yield()

        update!(x_prior_mean, mean(x_prior))
        update!(x_prior_var, var(x_prior))
        update!(y, data)

        # `smoothing` creates div(n, splits) AsyncPortal nodes; yield once per portal
        # (plus one) so deferred messages get a chance to propagate. Assumes each
        # portal needs exactly one scheduler turn — confirm against ReactiveMP.
        for _ in 1:(div(n, splits) + 1)
            yield()
        end
    finally
        # Always release the subscription, even if pushing data or yielding throws.
        unsubscribe!(msub)
    end

    # If the marginal stream never emitted, the buffer still holds #undef slots that
    # would only fail later on access; surface that here instead.
    if !all(i -> isassigned(marginals, i), eachindex(marginals))
        @warn "inference: some marginals were never computed; returned vector contains #undef entries"
    end

    return marginals
end

inference (generic function with 1 method)

In [17]:
# Time one full inference run over the synthetic dataset, using a broad Gaussian
# prior (mean 0, variance 1e4) on the initial state.
@time res = let prior = NormalMeanVariance(0.0, 10000.0)
    inference(x_prior = prior, splits = splits, data = data)
end

 17.851256 seconds (92.92 M allocations: 5.559 GiB, 52.01% gc time)


100000-element Array{Marginal,1}:
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=0.9600137439904993, v=0.000999999900000007))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=1.9600137439904992, v=0.000999999900000007))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=2.9600137439904985, v=0.000999999900000007))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=3.9600137439904985, v=0.000999999900000007))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=4.9600137439904985, v=0.000999999900000007))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=5.9600137439904985, v=0.0009999999000000071))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=6.9600137439905, v=0.0009999999000000071))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVariance{Float64}(μ=7.960013743990501, v=0.0009999999000000071))
 Marginal{NormalMeanVariance{Float64}}(NormalMeanVarian