In [30]:
using Distributions
using Rocket
using ReactiveMP
using BenchmarkTools

import Base: show

In [37]:
"""
    kalman_filter_graph()

Construct the factor graph for one step of a scalar Kalman filter and return the
handles used to drive it, as the tuple `(x_prev_m, x_prev_v, x, y)`:

- `x_prev_m`, `x_prev_v`: prior variables attached to the mean and variance ports
  of the Gaussian prior on the previous state `x_prev`;
- `x`: random variable for the current state, `x = x_prev + 1.0`;
- `y`: observed variable, `y = x + n`, where `n` comes from a Gaussian node with
  constant mean `0.0` and constant variance `200.0`.

Each variable is constructed together with the node port(s) it attaches to, so the
statement order below mirrors the wiring of the graph.
"""
function kalman_filter_graph()
    # State transition: x = x_prev + 1.0
    transition      = AdditionNode("x_prev_add")
    transition_step = ConstantVariable("add_1", 1.0, transition.in2)

    # Gaussian prior on the previous state, parameterised by mean and variance
    prior      = GaussianMeanVarianceNode("x_prev_prior")
    prior_mean = PriorVariable("x_prev_m", prior.mean)
    prior_var  = PriorVariable("x_prev_v", prior.variance)

    previous_state = RandomVariable("x_prev", prior.value, transition.in1)

    # Observation noise with fixed mean 0.0 and variance 200.0
    noise          = GaussianMeanVarianceNode("noise_node")
    noise_mean_arg = ConstantVariable("noise_mean", 0.0, noise.mean)
    noise_var_arg  = ConstantVariable("noise_variance", 200.0, noise.variance)

    # Observation model: y = x + n
    observation = AdditionNode("add_x_and_noise")

    state       = RandomVariable("x", transition.out, observation.in1)
    noise_value = RandomVariable("n", noise.value, observation.in2)
    observed    = ObservedVariable("y", observation.out)

    return prior_mean, prior_var, state, observed
end

"""
    kalman(data)

Run the scalar Kalman filter built by `kalman_filter_graph` over the observations
in `data` and return a `Vector{AbstractMessage}` of length `length(data)`.
Entry `i` is the `i`-th message emitted by `inference(x)`, produced after
observation `data[i]` was fed into the graph.

The filter is driven reactively:
- each inferred message for `x` is fed back as the prior (mean, variance) of the
  next step, together with the next observation;
- the first step uses a prior with mean `0.0` and variance `1000.0`.

Returns an empty vector when `data` is empty.

NOTE(review): errors reported by the inference stream are only printed
(`on_error`). Any entries not filled before that point presumably remain `#undef`.
"""
function kalman(data)
    N = length(data)
    messages = Vector{AbstractMessage}(undef, N)

    # Nothing to filter. Without this guard, `data[1]` in the starter task below
    # would throw a BoundsError that nobody observes, and `wait(actor)` could then
    # only end via its timeout.
    N == 0 && return messages

    x_prev_m, x_prev_v, x, y = kalman_filter_graph()

    # Consumer task: each queued (prior mean, prior variance, observation) triple
    # updates the graph inputs. Iterating the channel, instead of
    # `while true; take!(ch)`, ends the loop cleanly once `close(updates)` is
    # called. The `take!` form would fail with an InvalidStateException.
    updates = Channel{Tuple{Float64, Float64, Float64}}(Inf) do ch
        for (prior_mean, prior_var, observation) in ch
            update!(x_prev_m, prior_mean)
            update!(x_prev_v, prior_var)
            update!(y, observation)
        end
    end

    actor = sync(lambda(Tuple{Int, AbstractMessage},
        on_next = (d) -> begin
            index, message = d

            if index < N
                # Feed this step's message for `x` back as the next prior,
                # paired with the next observation.
                dist = message.distribution
                push!(updates, (mean(dist), var(dist), data[index + 1]))
            else
                # Last observation processed: complete the input streams and
                # stop the consumer task.
                complete!(x_prev_m.values)
                complete!(x_prev_v.values)
                complete!(y.values)
                close(updates)
            end

            messages[index] = message
        end,
        on_error = (e) -> println(e)
    ), timeout = 10000)

    # Subscribe first, then enqueue the first step (initial prior mean 0.0,
    # variance 1000.0, first observation) to start the feedback loop.
    @async begin
        subscribe!(inference(x) |> safe() |> enumerate(), actor)
        push!(updates, (0.0, 1000.0, data[1]))
    end

    wait(actor)

    return messages
end

kalman (generic function with 1 method)

In [38]:
N = 600
# Noisy observations of the linear trend 1, 2, …, N, with additive Gaussian noise
# of variance 200.0 (standard deviation sqrt(200.0)). Broadcasting directly over
# the range fuses the arithmetic into one pass and avoids `collect(1:N)`.
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [43]:
# Benchmark a full filter run. `$data` interpolates the value of the global `data`
# into the benchmark expression, so the cost of accessing a non-constant global is
# not included in the measured time.
@btime kalman($data)

  1.009 ms (14331 allocations: 562.70 KiB)


600-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-15.063328211119412, σ=12.909944487358057))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-2.0944705427605106, σ=9.534625892455924))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-0.6053103347121965, σ=7.905694150420949))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-1.086678923847172, σ=6.900655593423543))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-0.7397687660612686, σ=6.201736729460423))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-1.9460486303666849, σ=5.679618342470648))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=0.09812697734232478, σ=5.2704627669473))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.874076042837787, σ=4.938647983247948))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=6.770683972930398, σ=4.66252404120157))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.143948799383181, σ=4.428074427700477))
 