In [1]:
using Distributions
using Rocket
using ReactiveMP
using BenchmarkTools

import Base: show

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1278
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP


In [2]:
"""
    filtering_model()

Build and activate a single-step ReactiveMP model of a scalar state observed in
Gaussian noise:

    x_prior ~ NormalMeanVariance(x_prior_mean, x_prior_var)
    x       = x_prior + c                     (c = 1.0)
    noise   ~ NormalMeanVariance(0.0, 200.0)
    y       = x + noise

Returns `(x_prior_mean, x_prior_var, x, y)`. The two prior parameters and the
observation `y` are data variables that the caller feeds with `update!` (as
`filtering` does). `x` is the latent variable whose marginal the caller reads back
with `getmarginal`.
"""
function filtering_model()
    model = Model()
    
    # Prior parameters are data variables (not constants) so that each step's
    # posterior over `x` can be fed back in as the next step's prior.
    x_prior_mean = datavar(model, :x_prior_mean, Dirac{Float64})
    x_prior_var  = datavar(model, :x_prior_var, Dirac{Float64})
    
    # `make_node` returns a pair; only the second element (the node's output
    # variable) is kept. NOTE(review): presumably (node, variable) — confirm.
    _, x_prior = make_node(model, NormalMeanVariance, AutoVar(:x_prior), x_prior_mean, x_prior_var)
    
    # Constant offset added to the prior state: x = x_prior + c.
    c = constvar(model, :c, 1.0)
    
    _, x = make_node(model, +, AutoVar(:x), x_prior, c)
    
    # Observation noise N(0, 200). The same variance (200.0) is used to simulate
    # the data in this notebook, so the model matches the data-generating process.
    noise_mean = constvar(model, :noise_mean, 0.0)
    noise_var  = constvar(model, :noise_var, 200.0)
    
    _, noise = make_node(model, NormalMeanVariance, AutoVar(:noise), noise_mean, noise_var)
    
    # Observed value; fed with `update!` by the caller.
    y = datavar(model, :y, Dirac{Float64})
    
    # Observation equation y = x + noise. Unlike the nodes above, the output slot
    # is the existing datavar `y` rather than a fresh AutoVar.
    _ = make_node(model, +, y, x, noise)
    
    # Finalize the graph. NOTE(review): presumably required before subscribing to
    # marginals or calling update! — confirm against ReactiveMP docs.
    activate!(model)
    
    return x_prior_mean, x_prior_var, x, y
end

"""
    filtering(data; prior_mean = 0.0, prior_var = 100000.0)

Run online filtering of the scalar observations in `data` through a fresh model
from `filtering_model()`.

Every time a marginal of `x` is emitted, three things happen:

1. Its mean is written back into `x_prior_mean`.
2. Its variance is written back into `x_prior_var`.
3. Its data is recorded.

The posterior therefore becomes the prior for the next observation.

# Arguments
- `data`: collection of observations; it must support `length`, and each element is
  passed to `update!(y, d)` in iteration order.

# Keywords
- `prior_mean`, `prior_var`: initial prior over the state, as `Float64` values. The
  defaults reproduce the original broad prior `N(0.0, 100000.0)`.

# Returns
A `Vector{NormalMeanVariance{Float64}}` holding every marginal emitted for `x`, in
emission order.
"""
function filtering(data; prior_mean = 0.0, prior_var = 100000.0)
    x_prior_mean, x_prior_var, x, y = filtering_model()

    marginals = NormalMeanVariance{Float64}[]
    # Presumably one marginal per observation. This is only a capacity hint and does
    # not constrain how many marginals are actually pushed.
    sizehint!(marginals, length(data))

    # Feed each posterior back in as the next prior, then record it.
    subscription = subscribe!(getmarginal(x), (t) -> begin
        update!(x_prior_mean, mean(t))
        update!(x_prior_var, var(t))

        push!(marginals, getdata(t))
    end)

    # try/finally guarantees the subscription is released even if an update! throws.
    try
        update!(x_prior_mean, prior_mean)
        update!(x_prior_var, prior_var)

        for d in data
            update!(y, d)
        end
    finally
        unsubscribe!(subscription)
    end

    return marginals
end

filtering (generic function with 1 method)

In [3]:
# Synthetic observations: the ramp 1, 2, …, N plus Gaussian noise of variance 200.
N = 600
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [4]:
@time(filtering(data));  # first call: timing includes JIT compilation

  3.644662 seconds (9.11 M allocations: 487.569 MiB, 4.21% gc time)


In [5]:
@time(filtering(data));  # warm call: already-compiled code only

  0.001839 seconds (32.30 k allocations: 1.612 MiB)


In [None]:
@btime(filtering($data));  # steady-state benchmark; `$` interpolates the global `data`

In [None]:
using Plots

In [None]:
# Compare the filtered estimates with the ground truth on a fresh, smaller dataset.
# `obs_noise_var` must match the observation-noise variance (200.0) hard-coded in
# `filtering_model`; otherwise the model is misspecified for the data it receives.
n_points      = 100
obs_noise_var = 200.0

real_data = collect(1:n_points)
obs_data  = real_data .+ sqrt(obs_noise_var) .* randn(n_points);
estimated = filtering(obs_data)

# Mean of each marginal, with a ribbon of ± one standard deviation.
graph = plot(mean.(estimated), ribbon = std.(estimated), label = :estimated)
# plot!/scatter! mutate `graph` in place, so their return values need not be reassigned.
plot!(graph, real_data, label = :real)
scatter!(graph, obs_data, ms = 3, label = :observed)

plot(graph, size = (1000, 500))

In [None]:
data |> filtering  # display the marginals computed for the full `data` set