In [1]:
using Rocket
using ReactiveMP
using GraphPPL
using BenchmarkTools
using Distributions
using MacroTools

In [2]:
# State-space model: a scalar hidden state that increases by the constant `c` (= 1.0)
# at every step and is observed through Gaussian noise with variance `P`.
#
# Arguments:
#   n  - number of time steps (length of both `x` and `y`).
#   x0 - prior for the state preceding x[1]; only `mean(x0)` and `cov(x0)` are used.
#   P  - observation noise variance.
# Returns the tuple `(x, y)`: the hidden-state random variables and the data variables.
@model function smoothing(n, x0, P::ConstVariable)
    
    # Gaussian prior on the state that precedes x[1].
    x_prior ~ NormalMeanVariance(mean(x0), cov(x0)) 

    x = randomvar(n)
    y = datavar(Float64, n)
    c = constvar(1.0)   # per-step increment of the hidden state

    x_prev = x_prior

    for i in 1:n
        # Transition: x[i] is the previous state shifted by `c`.
        x[i] ~ x_prev + c
        # Observation: y[i] is Normal around x[i] with variance P.
        y[i] ~ NormalMeanVariance(x[i], P)
        
        x_prev = x[i]
    end

    return x, y
end

smoothing (generic function with 1 method)

In [3]:
# Extra message-update rules: whenever one of the two incoming messages is
# `missing`, the outgoing message is `missing` too. Presumably these are needed
# because `data` contains `missing` observations that are pushed into `y` — confirm
# against the ReactiveMP version in use.
#
# NormalMeanVariance node: outbound message towards the mean edge `μ`.
@rule NormalMeanVariance(:μ, Marginalisation) (m_out::Any, m_v::Missing) = missing
@rule NormalMeanVariance(:μ, Marginalisation) (m_out::Missing, m_v::Any) = missing

# Addition (`+`) node: outbound message towards the first input `in1`.
@rule typeof(+)(:in1, Marginalisation) (m_out::Missing, m_in2::Any) = missing
@rule typeof(+)(:in1, Marginalisation) (m_out::Any, m_in2::Missing) = missing

In [7]:
# Observation noise variance.
P = 1.0

# Number of time steps.
n = 500

# Noisy observations of the sequence 1, 2, …, n with N(0, P) noise. The element
# type admits `missing` so that observations can be dropped below.
data = convert(Vector{Union{Float64, Missing}}, collect(1:n) + sqrt(P) .* randn(n));

# Drop observations at n ÷ 2 uniformly random indices. Indices are drawn with
# replacement, so fewer than n ÷ 2 distinct entries may end up missing. Sampling
# directly from `1:n` guarantees valid 1-based indices, and `÷` keeps this
# well-defined for odd `n`.
for index in rand(1:n, n ÷ 2)
    data[index] = missing
end

In [8]:
"""
    inference(data, x0, P)

Build the `smoothing` model for `length(data)` time steps, push `data` into its
observation variables, and return the posterior marginals of the hidden states.

# Arguments
- `data`: observations, one per time step; entries may be `missing`.
- `x0`: prior for the initial state; the model uses only `mean(x0)` and `cov(x0)`.
- `P`: observation noise variance.

# Returns
A `Vector{Marginal}` of length `length(data)` holding the most recently emitted
marginals of the hidden states.
"""
function inference(data, x0, P)
    n = length(data)

    _, (x, y) = smoothing(n, x0, P)

    # Overwritten by the subscription callback on every emission. Entries stay
    # unassigned if no marginals are emitted before `update!` returns.
    buffer    = Vector{Marginal}(undef, n)
    marginals = getmarginals(x)

    subscription = subscribe!(marginals, (ms) -> copyto!(buffer, ms))

    # Release the subscription even if pushing data through the graph throws.
    try
        update!(y, data)
    finally
        unsubscribe!(subscription)
    end

    return buffer
end

inference (generic function with 1 method)

In [10]:
# Broad Gaussian prior on the initial state: mean 0, variance 1e3.
x0_prior = NormalMeanVariance(0.0, 1e3)
res = inference(data, x0_prior, P)

500-element Array{Marginal,1}:
 Marginal(NormalMeanVariance{Float64}(μ=0.9701310471387815, v=0.0032573183800704265))
 Marginal(NormalMeanVariance{Float64}(μ=1.9701310471387816, v=0.0032573183800704265))
 Marginal(NormalMeanVariance{Float64}(μ=2.970131047138782, v=0.0032573183800704273))
 Marginal(NormalMeanVariance{Float64}(μ=3.9701310471387825, v=0.0032573183800704273))
 Marginal(NormalMeanVariance{Float64}(μ=4.970131047138782, v=0.003257318380070427))
 Marginal(NormalMeanVariance{Float64}(μ=5.970131047138782, v=0.003257318380070427))
 Marginal(NormalMeanVariance{Float64}(μ=6.97013104713878, v=0.003257318380070426))
 Marginal(NormalMeanVariance{Float64}(μ=7.97013104713878, v=0.003257318380070426))
 Marginal(NormalMeanVariance{Float64}(μ=8.97013104713878, v=0.003257318380070426))
 Marginal(NormalMeanVariance{Float64}(μ=9.97013104713878, v=0.0032573183800704256))
 Marginal(NormalMeanVariance{Float64}(μ=10.97013104713878, v=0.003257318380070426))
 Marginal(NormalMeanVariance{Float64}(μ=1