In [2]:
using Distributions
using ReactiveMP
using Rocket
using BenchmarkTools

import Base: show

In [16]:
"""
    kalman_filter_graph()

Build the factor graph for one step of a scalar filter with the structure

    x = x_prev + 1.0                       (AdditionNode "x_prev_add")
    x_prev  <- GaussianMeanVariance(x_prev_m, x_prev_v)
    n       <- GaussianMeanVariance(0.0, 200.0)
    y = x + n                              (AdditionNode "add_x_and_noise")

Return the tuple `(x_prev_m, x_prev_v, x, y)`:
- `x_prev_m`, `x_prev_v`: prior variables for the mean/variance of `x_prev`,
  to be fed with `update!` by the caller;
- `x`: the hidden-state random variable;
- `y`: the observed variable that receives data via `update!`.
"""
function kalman_filter_graph()
    # Transition x = x_prev + 1.0; the constant is bound to the second input.
    transition = AdditionNode("x_prev_add")
    ConstantVariable("add_1", 1.0, transition.in2)

    # Gaussian prior over x_prev, parameterised by two externally-updated variables.
    prior      = GaussianMeanVarianceNode("x_prev_prior")
    prior_mean = PriorVariable("x_prev_m", prior.mean)
    prior_var  = PriorVariable("x_prev_v", prior.variance)

    # x_prev links the prior's output to the transition's first input.
    RandomVariable("x_prev", prior.value, transition.in1)

    # Observation noise with fixed mean 0.0 and variance 200.0.
    noise = GaussianMeanVarianceNode("noise_node")
    ConstantVariable("noise_mean", 0.0, noise.mean)
    ConstantVariable("noise_variance", 200.0, noise.variance)

    # Observation model y = x + n.
    observation = AdditionNode("add_x_and_noise")

    state = RandomVariable("x", transition.out, observation.in1)
    RandomVariable("n", noise.value, observation.in2)
    observed = ObservedVariable("y", observation.out)

    return prior_mean, prior_var, state, observed
end

"""
    kalman(data)

Filter the observations in `data` sequentially through the graph built by
`kalman_filter_graph()`.

The first step uses prior mean `0.0` and variance `1000.0` for `x_prev`. After
each step, the mean and variance of the message emitted for the hidden state
become the prior for the next step, and the next element of `data` is observed.

Return a `Vector{AbstractMessage}` whose `i`-th entry is the message received
for the hidden state at step `i`. Return an empty vector when `data` is empty.

Indexing assumes `data` is 1-based (`data[1]`, `data[index + 1]`).
"""
function kalman(data)
    N = length(data)

    # Without this guard `data[1]` would throw inside the @async task below
    # (where the error is not observed) and `wait(actor)` would block forever.
    N == 0 && return Vector{AbstractMessage}()

    x_prev_m, x_prev_v, x, y = kalman_filter_graph()

    messages = Vector{AbstractMessage}(undef, N)

    # Unbounded queue of (prior mean, prior variance, observation) triples,
    # consumed by the channel's bound task. The actor only enqueues work, so
    # presumably this avoids calling `update!` re-entrantly from inside the
    # subscription callback.
    # Iterating the channel (instead of `while true; take!(ch)`) makes the task
    # finish cleanly once `close(updates)` is called, rather than dying with an
    # InvalidStateException from `take!` on a closed, empty channel.
    updates = Channel{Tuple{Float64, Float64, Float64}}(Inf) do ch
        for (prior_mean, prior_var, observation) in ch
            update!(x_prev_m, prior_mean)
            update!(x_prev_v, prior_var)
            update!(y, observation)
        end
    end

    actor = sync(lambda(Tuple{AbstractMessage, Int},
        on_next = (d) -> begin
            message, index = d

            # Record the result before any completion is signalled, so the
            # vector is populated by the time the waiting task can resume.
            messages[index] = message

            if index < N
                # The posterior for this step becomes the prior for the next.
                push!(updates, (mean(message.distribution), var(message.distribution), data[index + 1]))
            else
                complete!(x_prev_m.values)
                complete!(x_prev_v.values)
                complete!(y.values)
                close(updates)
            end
        end,
        on_error = (e) -> println(e)
    ))

    @async begin
        subscribe!(inference(x) |> safe() |> enumerate(), actor)
        # Seed the first step with a broad prior.
        push!(updates, (0.0, 1000.0, data[1]))
    end

    wait(actor)

    return messages
end

kalman (generic function with 2 methods)

In [17]:
# Synthetic observations: the ramp 1, 2, …, N plus Gaussian noise with
# standard deviation sqrt(200.0), i.e. variance 200.0.
N = 5000
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [21]:
# `@btime` evaluates its expression inside a generated benchmark function, so
# writing `@btime s = ...` would only bind a local `s` there. Bind the value
# returned by the macro instead. `$data` interpolates the global so its lookup
# is excluded from the measurement.
s = @btime kalman($data)

  8.528 ms (155492 allocations: 5.40 MiB)


5000-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-5.8090765603313, σ=12.909944487358057))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-11.78623001892783, σ=9.534625892455924))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-10.829235558908996, σ=7.905694150420949))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-7.190578666903145, σ=6.900655593423543))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-4.205553014335523, σ=6.201736729460423))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-1.697258853940365, σ=5.679618342470648))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-2.448883333212301, σ=5.2704627669473))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-1.557294814603282, σ=4.938647983247948))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=1.0083754783289836, σ=4.66252404120157))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=1.5377819379583102, σ=4.428074427700477))
 Sto

In [22]:
# Single timed run (includes allocation counts); binds the result to `s`.
s = @time kalman(data)

  0.013000 seconds (155.49 k allocations: 5.398 MiB)


5000-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-5.8090765603313, σ=12.909944487358057))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-11.78623001892783, σ=9.534625892455924))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-10.829235558908996, σ=7.905694150420949))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-7.190578666903145, σ=6.900655593423543))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-4.205553014335523, σ=6.201736729460423))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-1.697258853940365, σ=5.679618342470648))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-2.448883333212301, σ=5.2704627669473))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-1.557294814603282, σ=4.938647983247948))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=1.0083754783289836, σ=4.66252404120157))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=1.5377819379583102, σ=4.428074427700477))
 Sto