In [2]:
using Distributions
using Rocket
using ReactiveMP
using BenchmarkTools

using Profile
using PProf
using ProfileSVG

import Base: show

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1260
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP


In [7]:
"""
    kalman_filter_graph()

Build and activate the ReactiveMP graph for one step of a 1-D Kalman filter.

The graph encodes two additions:

- state transition: `x = x_prior + 1.0`, where `x_prior` is a data variable of type
  `Normal{Float64}`;
- observation: `y = x + noise`, where `noise` is the constant `Normal(0.0, sqrt(200.0))`
  and `y` is a `Float64` data variable.

Return the tuple `(x_prior, x, y)`. The caller feeds priors and observations through
`x_prior` and `y`, and reads the marginal of `x`.
"""
function kalman_filter_graph()
    # Transition node: combines the prior state with the constant step 1.0.
    transition = AdditionNode()

    prior_var = datavar(:x_prior, Normal{Float64})
    step_var  = constvar(:add_1, 1.0)

    connect!(transition, :in1, prior_var, 1)
    connect!(transition, :in2, step_var, 1)

    noise_var = constvar(:noise, Normal(0.0, sqrt(200.0)))

    # Observation node: adds the measurement noise to the hidden state.
    observation = AdditionNode()

    state = simplerandomvar(:x)

    # `x` is shared by both nodes: it is the transition's output and the
    # observation's first input (connected at indices 1 and 2 respectively).
    connect!(transition, :out, state, 1)
    connect!(observation, :in1, state, 2)
    connect!(observation, :in2, noise_var, 1)

    obs_var = datavar(:y, Float64)

    connect!(observation, :out, obs_var, 1)

    # Activation comes after all connections, in the same order as the nodes were created.
    activate!(transition)
    activate!(observation)

    return prior_var, state, obs_var
end

"""
    kalman(data)

Run online Kalman filtering over the observations in `data`, using the graph built by
`kalman_filter_graph`.

The first observation is paired with the broad prior `Normal(0.0, 1000.0)`. After that,
the marginal of `x` computed for step `i` becomes the prior for step `i + 1`.

Return a `Vector{Normal{Float64}}` holding one belief per observation. For empty `data`
the vector is empty.
"""
function kalman(data)
    N = length(data)

    # Nothing to filter. Return early so we neither index `data[1]` below nor block in
    # `wait(actor)` until the actor's timeout expires.
    N == 0 && return Vector{Normal{Float64}}(undef, 0)

    x_prior, x, y = kalman_filter_graph()

    beliefs = Vector{Normal{Float64}}(undef, N)

    # Unbounded queue of (prior, observation) pairs, consumed by the channel's own task.
    # Iterating the channel ends the loop cleanly once `updates` is closed and drained.
    # A bare `while true; take!(ch); end` loop would instead throw an
    # InvalidStateException inside the task after `close(updates)`.
    updates = Channel{Tuple{Normal{Float64}, Float64}}(Inf) do ch
        for (prior, observation) in ch
            update!(x_prior, prior)
            update!(y, observation)
        end
    end

    # NOTE(review): `timeout = 10000` is presumably in milliseconds; confirm against
    # Rocket's `sync` actor documentation.
    actor = sync(lambda(Tuple{Int, AbstractBelief},
        on_next = (d) -> begin
            index  = d[1]
            belief = getdata(d[2])

            if index < N
                # This step's posterior becomes the prior for the next observation.
                push!(updates, (belief, data[index + 1]))
            else
                # Last observation processed: complete both data streams and let the
                # consumer task finish by closing its channel.
                finish!(x_prior)
                finish!(y)
                close(updates)
            end

            beliefs[index] = belief
        end,
        on_error = (e) -> begin
            @error "Marginal stream of x failed" exception=e
        end
    ), timeout = 10000)

    @async begin
        try
            subscribe!(getmarginal(x) |> enumerate(), actor)
            # Seed the filter with a broad prior and the first observation.
            push!(updates, (Normal(0.0, 1000.0), data[1]))
        catch err
            @error "Failed to start Kalman filtering" exception=(err, catch_backtrace())
        end
    end

    wait(actor)

    return beliefs
end

kalman (generic function with 1 method)

In [8]:
N = 600
# Synthetic observations: the line t = 1..N plus Gaussian noise with standard deviation
# sqrt(200), the same noise level as the graph's `:noise` constant.
# The fused broadcast builds the Vector{Float64} in one pass, without collecting 1:N
# or allocating an intermediate noise vector.
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [11]:
# Benchmark of the warm `kalman` call. `$data` interpolates the global so that looking
# it up is not included in the measurement.
@btime kalman($data)

  1.185 ms (15637 allocations: 683.77 KiB)


600-element Array{Normal{Float64},1}:
 Normal{Float64}(μ=10.31065874694313, σ=14.140721622265264)
 Normal{Float64}(μ=10.157996227510043, σ=9.999500037496876)
 Normal{Float64}(μ=7.808732845833988, σ=8.164693657357805)
 Normal{Float64}(μ=7.120005927907771, σ=7.070891041799029)
 Normal{Float64}(μ=6.265283887660218, σ=6.3244288330249585)
 Normal{Float64}(μ=7.664453113948766, σ=5.773406469256952)
 Normal{Float64}(μ=8.508833760619225, σ=5.34514847952991)
 Normal{Float64}(μ=10.887970393349898, σ=4.999937501171851)
 Normal{Float64}(μ=11.048981040347211, σ=4.713992830503184)
 Normal{Float64}(μ=14.868724807949096, σ=4.472091234310838)
 Normal{Float64}(μ=14.939385349128038, σ=4.263975563874187)
 Normal{Float64}(μ=17.654292472050876, σ=4.082448884373011)
 Normal{Float64}(μ=17.7324041587127, σ=3.922292531398713)
 ⋮
 Normal{Float64}(μ=589.6967170713511, σ=0.5827164478148905)
 Normal{Float64}(μ=590.6392762518345, σ=0.5822224110578259)
 Normal{Float64}(μ=591.656685816653, σ=0.5817296287317493)
 Normal

In [10]:
# Timing run from cell [10]. `kalman` was already called in cell [9], so this
# measures a warm (already compiled) call.
@time kalman(data)

  0.001932 seconds (15.64 k allocations: 683.781 KiB)


600-element Array{Normal{Float64},1}:
 Normal{Float64}(μ=10.31065874694313, σ=14.140721622265264)
 Normal{Float64}(μ=10.157996227510043, σ=9.999500037496876)
 Normal{Float64}(μ=7.808732845833988, σ=8.164693657357805)
 Normal{Float64}(μ=7.120005927907771, σ=7.070891041799029)
 Normal{Float64}(μ=6.265283887660218, σ=6.3244288330249585)
 Normal{Float64}(μ=7.664453113948766, σ=5.773406469256952)
 Normal{Float64}(μ=8.508833760619225, σ=5.34514847952991)
 Normal{Float64}(μ=10.887970393349898, σ=4.999937501171851)
 Normal{Float64}(μ=11.048981040347211, σ=4.713992830503184)
 Normal{Float64}(μ=14.868724807949096, σ=4.472091234310838)
 Normal{Float64}(μ=14.939385349128038, σ=4.263975563874187)
 Normal{Float64}(μ=17.654292472050876, σ=4.082448884373011)
 Normal{Float64}(μ=17.7324041587127, σ=3.922292531398713)
 ⋮
 Normal{Float64}(μ=589.6967170713511, σ=0.5827164478148905)
 Normal{Float64}(μ=590.6392762518345, σ=0.5822224110578259)
 Normal{Float64}(μ=591.656685816653, σ=0.5817296287317493)
 Normal

In [9]:
# Timing run from cell [9], the first call of `kalman` after its definition in cell [7].
# This measurement most likely includes JIT compilation, which would explain why it
# shows far more time and allocations than the warm call in cell [10].
@time kalman(data)

  2.937399 seconds (7.95 M allocations: 420.222 MiB, 2.73% gc time)


600-element Array{Normal{Float64},1}:
 Normal{Float64}(μ=10.31065874694313, σ=14.140721622265264)
 Normal{Float64}(μ=10.157996227510043, σ=9.999500037496876)
 Normal{Float64}(μ=7.808732845833988, σ=8.164693657357805)
 Normal{Float64}(μ=7.120005927907771, σ=7.070891041799029)
 Normal{Float64}(μ=6.265283887660218, σ=6.3244288330249585)
 Normal{Float64}(μ=7.664453113948766, σ=5.773406469256952)
 Normal{Float64}(μ=8.508833760619225, σ=5.34514847952991)
 Normal{Float64}(μ=10.887970393349898, σ=4.999937501171851)
 Normal{Float64}(μ=11.048981040347211, σ=4.713992830503184)
 Normal{Float64}(μ=14.868724807949096, σ=4.472091234310838)
 Normal{Float64}(μ=14.939385349128038, σ=4.263975563874187)
 Normal{Float64}(μ=17.654292472050876, σ=4.082448884373011)
 Normal{Float64}(μ=17.7324041587127, σ=3.922292531398713)
 ⋮
 Normal{Float64}(μ=589.6967170713511, σ=0.5827164478148905)
 Normal{Float64}(μ=590.6392762518345, σ=0.5822224110578259)
 Normal{Float64}(μ=591.656685816653, σ=0.5817296287317493)
 Normal