In [1]:
using Distributions
using Rocket
using ReactiveMP
using BenchmarkTools

import Base: show

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1260
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP


In [2]:
"""
    kalman_filter_graph()

Build the one-step factor graph used by `kalman`:

- `x_prev` is connected to a `GaussianMeanVarianceNode` whose mean and variance
  interfaces are driven by two `PriorVariable`s.
- `x = x_prev + 1.0` (an `AdditionNode` with a constant second input).
- `n` comes from a `GaussianMeanVarianceNode` with constant mean `0.0` and
  constant variance `200.0`.
- `y = x + n` is an `ObservedVariable`.

Returns the tuple `(x_prev_m, x_prev_v, x, y)`: the two prior variables, the
state variable, and the observed variable.
"""
function kalman_filter_graph()
    # Transition: x = x_prev + 1.0
    transition = AdditionNode("x_prev_add")
    ConstantVariable("add_1", 1.0, transition.in2)

    # Gaussian prior over x_prev; its mean and variance are supplied from outside.
    prior      = GaussianMeanVarianceNode("x_prev_prior")
    prior_mean = PriorVariable("x_prev_m", prior.mean)
    prior_var  = PriorVariable("x_prev_v", prior.variance)

    RandomVariable("x_prev", prior.value, transition.in1)

    # Observation noise n with fixed mean 0.0 and variance 200.0.
    noise = GaussianMeanVarianceNode("noise_node")
    ConstantVariable("noise_mean", 0.0, noise.mean)
    ConstantVariable("noise_variance", 200.0, noise.variance)

    # Observation: y = x + n
    observation = AdditionNode("add_x_and_noise")

    state = RandomVariable("x", transition.out, observation.in1)
    RandomVariable("n", noise.value, observation.in2)
    obs   = ObservedVariable("y", observation.out)

    return (prior_mean, prior_var, state, obs)
end

"""
    kalman(data)

Run step-by-step inference on the graph built by `kalman_filter_graph`, using one
element of `data` as the observation `y` at each step.

After each message from `inference(x)`, its mean and variance are fed back as the
prior mean/variance of `x_prev` for the next step, together with the next
observation. The first step uses prior mean `0.0` and prior variance `1000.0`.

Returns a `Vector{AbstractMessage}` whose `i`-th entry is the `i`-th message
emitted by `inference(x)`. If `data` is empty, it returns an empty vector without
building the graph.
"""
function kalman(data)
    N = length(data)
    # Without this guard, indexing the first observation below would throw inside
    # the @async task, and `wait(actor)` would then block forever.
    N == 0 && return Vector{AbstractMessage}()

    x_prev_m, x_prev_v, x, y = kalman_filter_graph()

    messages = Vector{AbstractMessage}(undef, N)
    first_i  = firstindex(data)

    # Consumer task: feeds (prior mean, prior variance, observation) triples into
    # the graph. Iterating the channel ends the task cleanly once `close(updates)`
    # is called. The previous `while true; take!(ch)` loop instead died with an
    # InvalidStateException on close.
    updates = Channel{Tuple{Float64, Float64, Float64}}(Inf) do ch
        for (prior_mean, prior_var, observation) in ch
            update!(x_prev_m, prior_mean)
            update!(x_prev_v, prior_var)
            update!(y, observation)
        end
    end

    actor = sync(lambda(Tuple{AbstractMessage, Int},
        on_next = (d) -> begin
            message, index = d

            # Record the result before signalling completion, so it is stored
            # even if completion wakes `wait(actor)` right away.
            messages[index] = message

            if index < N
                push!(updates, (mean(message.distribution), var(message.distribution), data[first_i + index]))
            else
                complete!(x_prev_m.values)
                complete!(x_prev_v.values)
                complete!(y.values)
                close(updates)
            end
        end,
        on_error = (e) -> @error("kalman: inference stream failed", exception = e)
    ))

    @async begin
        subscribe!(inference(x) |> safe() |> enumerate(), actor)
        # Kick off the first step with prior mean 0.0, prior variance 1000.0.
        push!(updates, (0.0, 1000.0, data[first_i]))
    end

    wait(actor)

    return messages
end

kalman (generic function with 1 method)

In [3]:
N = 128
# Synthetic observations: the ramp 1, 2, …, N plus Gaussian noise with standard
# deviation sqrt(200.0), i.e. the same variance 200.0 used for the graph's noise node.
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [6]:
# Benchmark the filter; `$data` interpolates the global so its lookup is not timed.
@btime(kalman($data));

  345.549 μs (4262 allocations: 154.25 KiB)


In [7]:
# Single timed run; with no trailing semicolon the returned messages are shown as the cell output.
@time(kalman(data))

  0.000710 seconds (4.26 k allocations: 154.234 KiB)


128-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-2.0769999460614073, σ=12.909944487358057))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=0.8912630631289872, σ=9.534625892455924))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=5.33119085457633, σ=7.905694150420949))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=5.742951350408047, σ=6.900655593423543))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=10.319694743802762, σ=6.201736729460423))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=12.392982588110826, σ=5.679618342470648))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=16.429146955706095, σ=5.2704627669473))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=17.33010833714951, σ=4.938647983247948))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=18.26004947000655, σ=4.66252404120157))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=17.245917230473182, σ=4.428074427700477))
 Stochas