In [1]:
using Distributions
using Rocket
using ReactiveMP
using BenchmarkTools

import Base: show

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1260
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP


In [7]:
"""
    kalman_filter_graph(; drift = 1.0, noise_variance = 200.0)

Build the factor graph for one step of a scalar random-walk filter:

    x_prev ~ GaussianMeanVariance(x_prev_m, x_prev_v)
    x      = x_prev + drift
    n      ~ GaussianMeanVariance(0.0, noise_variance)
    y      = x + n

# Keywords
- `drift`: constant added to `x_prev` to obtain `x` (default `1.0`).
- `noise_variance`: variance of the additive observation noise `n` (default `200.0`).

# Returns
The tuple `(x_prev_m, x_prev_v, x, y)`:
- `x_prev_m`, `x_prev_v`: prior variables for the mean and variance of `x_prev`;
  their values are expected to be supplied externally (e.g. with `update!`).
- `x`: the random variable whose inference stream the caller observes.
- `y`: the observed variable that receives each observation.
"""
function kalman_filter_graph(; drift = 1.0, noise_variance = 200.0)
    # x = x_prev + drift
    x_prev_add = AdditionNode("x_prev_add")
    # Not referenced again below; constructed for the edge it creates on `x_prev_add.in2`.
    add_1 = ConstantVariable("add_1", drift, x_prev_add.in2)

    # Prior on x_prev, parameterised by externally supplied mean and variance.
    x_prev_prior = GaussianMeanVarianceNode("x_prev_prior")
    x_prev_m     = PriorVariable("x_prev_m", x_prev_prior.mean)
    x_prev_v     = PriorVariable("x_prev_v", x_prev_prior.variance)

    # Links the prior's output to the adder input; the local itself is not used again.
    x_prev = RandomVariable("x_prev", x_prev_prior.value, x_prev_add.in1)

    # Zero-mean observation noise with fixed variance.
    noise_node = GaussianMeanVarianceNode("noise_node")
    noise_m    = ConstantVariable("noise_mean", 0.0, noise_node.mean)
    noise_v    = ConstantVariable("noise_variance", noise_variance, noise_node.variance)

    # y = x + n
    add_x_and_noise = AdditionNode("add_x_and_noise")

    x = RandomVariable("x", x_prev_add.out, add_x_and_noise.in1)
    n = RandomVariable("n", noise_node.value, add_x_and_noise.in2)
    y = ObservedVariable("y", add_x_and_noise.out)

    return x_prev_m, x_prev_v, x, y
end

"""
    kalman(data)

Run step-by-step inference on the graph from `kalman_filter_graph`, one observation
per step, and collect every message emitted for `x`.

After each emitted message, its mean and variance are fed back as the prior
(`x_prev_m`, `x_prev_v`) for the next step, together with the next observation.

# Arguments
- `data`: observations, accessed as `data[1]`, ..., `data[length(data)]`.

# Returns
A `Vector{AbstractMessage}` of length `length(data)`; entry `k` is the `k`-th
message emitted for `x`. An empty `data` yields an empty vector.
"""
function kalman(data)
    N = length(data)

    # Nothing to filter. Without this guard `data[1]` below would throw inside the
    # @async task (where the error goes unnoticed) and `wait(actor)` would block.
    N == 0 && return Vector{AbstractMessage}(undef, 0)

    x_prev_m, x_prev_v, x, y = kalman_filter_graph()

    messages = Vector{AbstractMessage}(undef, N)

    # Feeder task: writes each queued (prior mean, prior variance, observation)
    # triple into the graph. Iterating the channel ends the loop cleanly once the
    # channel is closed; a bare `take!` would throw on the closed channel instead.
    updates = Channel{Tuple{Float64, Float64, Float64}}(Inf) do ch
        for (prior_mean, prior_var, obs) in ch
            update!(x_prev_m, prior_mean)
            update!(x_prev_v, prior_var)
            update!(y, obs)
        end
    end

    actor = sync(lambda(Tuple{Int, AbstractMessage},
        on_next = (d) -> begin
            index, message = d

            # Record the result before scheduling more work or completing streams.
            messages[index] = message

            if index < N
                # This step's message for `x` becomes the prior of the next step.
                dist = message.distribution
                push!(updates, (mean(dist), var(dist), data[index + 1]))
            else
                # Last observation processed: shut down all input streams and the feeder.
                complete!(x_prev_m.values)
                complete!(x_prev_v.values)
                complete!(y.values)
                close(updates)
            end
        end,
        on_error = (e) -> println(e)
    ), timeout = 10000)

    # NOTE(review): exceptions thrown inside this @async task are not propagated
    # to the caller; `wait(actor)` would then presumably return only via `timeout`.
    @async begin
        subscribe!(inference(x) |> enumerate(), actor)
        # Initial broad prior (mean 0.0, variance 1000.0) with the first observation.
        push!(updates, (0.0, 1000.0, data[1]))
    end

    wait(actor)

    return messages
end

kalman (generic function with 1 method)

In [8]:
N = 600
# Noisy observations of the ramp 1, 2, ..., N: data[k] = k + ε_k, ε_k ~ N(0, 200).
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [11]:
# `$data` interpolates the global into the benchmarked expression, so `kalman`
# is measured as if called with a local value rather than via a global lookup.
@benchmark kalman($data)

BenchmarkTools.Trial: 
  memory estimate:  562.45 KiB
  allocs estimate:  14328
  --------------
  minimum time:     922.926 μs (0.00% GC)
  median time:      1.067 ms (0.00% GC)
  mean time:        1.482 ms (22.59% GC)
  maximum time:     69.094 ms (97.69% GC)
  --------------
  samples:          3390
  evals/sample:     1

In [10]:
# Single timed run. If `kalman` has not been called before in this session,
# the reported time also includes compilation.
@time kalman(data)

  0.001548 seconds (14.33 k allocations: 562.438 KiB)


600-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=23.091443540047656, σ=12.909944487358057))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=6.3334533059203935, σ=9.534625892455924))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.503212814087148, σ=7.905694150420949))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=9.15959418479361, σ=6.900655593423543))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.057642860305852, σ=6.201736729460423))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=4.737479085167591, σ=5.679618342470648))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.9836137790425887, σ=5.2704627669473))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.9612225483105625, σ=4.938647983247948))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=4.877957659443098, σ=4.66252404120157))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.722577243196825, σ=4.428074427700477))
 Stochastic