In [3]:
using Distributions
using Rocket
using ReactiveMP
using BenchmarkTools

import Base: show

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1260
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP


In [77]:
"""
    kalman_filter_graph()

Build and activate the factor graph for one step of a 1-D Kalman filter.

The graph encodes the model

    x_prev ~ N(x_prev_m, x_prev_v)   # prior, supplied as data each step
    x      = x_prev + 1              # state transition with constant drift 1.0
    n      ~ N(0, 200)               # observation noise, variance 200
    y      = x + n                   # observation, supplied as data

Returns `(x_prev_m, x_prev_v, x, y)`:
- `x_prev_m`, `x_prev_v`: data variables carrying the prior mean and variance;
- `x`: the random variable whose belief is the filtered state estimate;
- `y`: the observation data variable.
"""
function kalman_filter_graph()
    # Transition node: connects x_prev (in1) and the constant 1.0 (in2) to x (out).
    x_prev_add = AdditionNode()
    add_1      = constvar(:add_1, 1.0)

    connect!(x_prev_add, :in2, add_1, 1) # TODO index = 1

    # Prior over the previous state; its parameters are fed in as data each step.
    x_prev_prior = GaussianMeanVariance()
    x_prev_m     = datavar(:x_prev_m, Float64)
    x_prev_v     = datavar(:x_prev_v, Float64)

    connect!(x_prev_prior, :mean, x_prev_m, 1)
    connect!(x_prev_prior, :variance, x_prev_v, 1)

    # Random variables are sized for two connections (indices 1 and 2 below):
    # one to the node producing them and one to the node consuming them.
    x_prev = randomvar(:x_prev, 2)

    connect!(x_prev_prior, :value, x_prev, 1)
    connect!(x_prev_add, :in1, x_prev, 2)

    # Observation noise n ~ N(0, 200).
    noise_node     = GaussianMeanVariance()
    noise_mean     = constvar(:noise_mean, 0.0)
    # Was registered under the duplicate name :noise_mean (copy-paste slip).
    noise_variance = constvar(:noise_variance, 200.0)

    connect!(noise_node, :mean, noise_mean, 1)
    connect!(noise_node, :variance, noise_variance, 1)

    # Observation node: y = x + n.
    add_x_and_noise = AdditionNode()

    x = randomvar(:x, 2)

    connect!(x_prev_add, :out, x, 1)
    connect!(add_x_and_noise, :in1, x, 2)

    n = randomvar(:n, 2)

    connect!(noise_node, :value, n, 1)
    connect!(add_x_and_noise, :in2, n, 2)

    y = datavar(:y, Float64)

    connect!(add_x_and_noise, :out, y, 1)

    # Every node must be activated once all of its interfaces are connected.
    activate!(x_prev_add)
    activate!(x_prev_prior)
    activate!(noise_node)
    activate!(add_x_and_noise)

    return x_prev_m, x_prev_v, x, y
end

"""
    kalman(data; prior_mean = 0.0, prior_variance = 1000.0)

Run a streaming 1-D Kalman filter over the observations in `data`. Return the
filtered belief over the state at every time step.

The graph from `kalman_filter_graph` is built once and reused. When the belief over
`x` for step `i` arrives, its mean and variance become the prior for step `i + 1`,
together with the observation `data[i + 1]`. New inputs are queued through a `Channel`
and applied by a separate consumer task, not inside the `on_next` callback itself.
Presumably this avoids re-entrant updates of the graph; confirm before changing it.

# Arguments
- `data`: observations, indexed `1:length(data)`.

# Keywords
- `prior_mean`, `prior_variance`: prior over the state before the first observation.

# Returns
A `Vector{Normal{Float64}}` with one belief per observation. It is empty when `data`
is empty.
"""
function kalman(data; prior_mean = 0.0, prior_variance = 1000.0)
    N = length(data)
    beliefs = Vector{Normal{Float64}}(undef, N)

    # Nothing to filter. Without this guard, `data[1]` below would throw inside the
    # @async task, and `wait(actor)` would block until the timeout.
    N == 0 && return beliefs

    x_prev_m, x_prev_v, x, y = kalman_filter_graph()

    # Consumer task: feeds (prior mean, prior variance, observation) into the graph.
    # Iterating the channel ends cleanly after `close(updates)`. A bare `take!`
    # loop would throw InvalidStateException at that point instead.
    updates = Channel{Tuple{Float64, Float64, Float64}}(Inf) do ch
        for (m, v, obs) in ch
            update!(x_prev_m, m)
            update!(x_prev_v, v)
            update!(y, obs)
        end
    end

    actor = sync(lambda(Tuple{Int, AbstractBelief},
        on_next = (d) -> begin
            index     = d[1]
            posterior = getdata(d[2])

            # Store first, so the last belief is recorded before the streams complete.
            beliefs[index] = posterior

            if index < N
                push!(updates, (mean(posterior), var(posterior), data[index + 1]))
            else
                finish!(x_prev_m)
                finish!(x_prev_v)
                finish!(y)
                close(updates)
            end
        end,
        on_error = (e) -> @error("kalman: belief stream failed", exception = e)
    ), timeout = 10000) # NOTE(review): timeout units presumably ms — confirm in Rocket

    @async begin
        # Subscribe before pushing the first input, so no belief update is missed.
        subscribe!(belief(x) |> enumerate(), actor)
        push!(updates, (prior_mean, prior_variance, data[1]))
    end

    wait(actor)

    return beliefs
end

kalman (generic function with 1 method)

In [78]:
# Synthetic observations: the ramp 1, 2, …, N plus zero-mean Gaussian noise with
# variance 200 (standard deviation sqrt(200)), matching the noise model in the graph.
N = 600
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [76]:
# `$data` interpolation is only understood by BenchmarkTools macros; under `@time`
# it is a "`$` expression outside quote" syntax error. The recorded timing output
# for this cell is in `@btime` format.
@btime kalman($data)

  12.351 ms (21609 allocations: 747.42 KiB)


600-element Array{Any,1}:
 Normal{Float64}(μ=3.364316403367523, σ=12.909944487358057)
 Normal{Float64}(μ=5.2562992747353015, σ=9.534625892455924)
 Normal{Float64}(μ=1.20823697043715, σ=7.905694150420949)
 Normal{Float64}(μ=4.653439830725727, σ=6.900655593423543)
 Normal{Float64}(μ=10.336071622290161, σ=6.201736729460423)
 Normal{Float64}(μ=7.795284753380579, σ=5.679618342470648)
 Normal{Float64}(μ=8.646727835260892, σ=5.2704627669473)
 Normal{Float64}(μ=7.5216986544770625, σ=4.938647983247948)
 Normal{Float64}(μ=8.636975510234665, σ=4.66252404120157)
 Normal{Float64}(μ=9.77466023593911, σ=4.428074427700477)
 Normal{Float64}(μ=11.378428009426054, σ=4.225771273642584)
 Normal{Float64}(μ=11.519592203071745, σ=4.0488816508945815)
 Normal{Float64}(μ=12.224734259265313, σ=3.8924947208076164)
 ⋮
 Normal{Float64}(μ=589.3058833170217, σ=0.5826176387363623)
 Normal{Float64}(μ=590.3454832720731, σ=0.5821238530397734)
 Normal{Float64}(μ=591.3547336371323, σ=0.5816313207127862)
 Normal{Float64}(μ=5

In [11]:
# Full timing distribution; `$data` is interpolated so the global is not part of the timing.
@benchmark(kalman($data))

BenchmarkTools.Trial: 
  memory estimate:  562.45 KiB
  allocs estimate:  14328
  --------------
  minimum time:     922.926 μs (0.00% GC)
  median time:      1.067 ms (0.00% GC)
  mean time:        1.482 ms (22.59% GC)
  maximum time:     69.094 ms (97.69% GC)
  --------------
  samples:          3390
  evals/sample:     1

In [10]:
# One-shot wall-clock timing; on a fresh session this also includes compilation.
@time(kalman(data))

  0.001548 seconds (14.33 k allocations: 562.438 KiB)


600-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=23.091443540047656, σ=12.909944487358057))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=6.3334533059203935, σ=9.534625892455924))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.503212814087148, σ=7.905694150420949))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=9.15959418479361, σ=6.900655593423543))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.057642860305852, σ=6.201736729460423))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=4.737479085167591, σ=5.679618342470648))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.9836137790425887, σ=5.2704627669473))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.9612225483105625, σ=4.938647983247948))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=4.877957659443098, σ=4.66252404120157))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.722577243196825, σ=4.428074427700477))
 Stochastic