In [1]:
using Distributions
using ReactiveMP
using Rocket
using BenchmarkTools

import Base: show

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1273
│ - If you have ReactiveMP checked out for development and have
│   added Rocket as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with ReactiveMP


In [8]:
"""
    InferenceActor(data, y, e_mean, e_variance)

Rocket actor that drives step-by-step inference over `data`.

Each received message is stored in `messages` at the current `index`. Its
(mean, variance) is queued on `communicate`. A background task consumes that
queue and calls `update!`, which feeds the next observation into `y` and the
moments into the prior streams `e_mean` / `e_variance`.

`data` may be any real-valued vector; it is converted to `Vector{Float64}`. A
`Vector{Float64}` argument is stored as-is, without a copy.
"""
mutable struct InferenceActor <: Actor{AbstractMessage}
    index       :: Int                               # 1-based cursor into `data` / `messages`
    size        :: Int                               # number of observations
    data        :: Vector{Float64}                   # observations pushed into `y`
    messages    :: Vector{AbstractMessage}           # message received at each step
    communicate :: Channel{Tuple{Float64, Float64}}  # (mean, variance) queue for the worker task

    y          :: ObservedVariable
    e_mean     :: PriorVariable
    e_variance :: PriorVariable

    function InferenceActor(
        data::AbstractVector{<:Real},
        y::ObservedVariable,
        e_mean::PriorVariable,
        e_variance::PriorVariable,
    )
        observations = convert(Vector{Float64}, data)
        n            = length(observations)
        messages     = Vector{AbstractMessage}(undef, n)
        queue        = Channel{Tuple{Float64, Float64}}(Inf)

        actor = new(1, n, observations, messages, queue, y, e_mean, e_variance)

        # Iterating a Channel stops once it is closed and drained. The worker
        # therefore finishes normally after `close`, instead of dying with an
        # InvalidStateException from `take!` on a closed, empty channel.
        worker = @async begin
            for (m, v) in actor.communicate
                update!(actor, m, v)
            end
        end

        # Tie the channel's lifetime to the worker: if the worker fails, the
        # channel is closed with its exception.
        bind(actor.communicate, worker)

        return actor
    end
end

"""
    update!(actor::InferenceActor, mean::Float64, variance::Float64)

Advance the model by one step:
- push the observation at `actor.index` into the observed variable `y`;
- push `mean` and `variance` into the prior streams `e_mean` and `e_variance`.

Returns the result of the last `next!` call.
"""
function update!(actor::InferenceActor, mean::Float64, variance::Float64)
    observation_stream = actor.y.values
    observation        = actor.data[actor.index]
    next!(observation_stream, observation)

    next!(actor.e_mean.values, mean)
    return next!(actor.e_variance.values, variance)
end

"""
    stop!(actor::InferenceActor)

Call `complete!` on the observation stream, then on the prior-mean stream,
then on the prior-variance stream, in that order. It is invoked once every
observation in `actor.data` has been consumed.

Returns the result of the last `complete!` call.
"""
function stop!(actor::InferenceActor)
    observed = actor.y.values
    complete!(observed)

    prior_mean_stream = actor.e_mean.values
    complete!(prior_mean_stream)

    prior_variance_stream = actor.e_variance.values
    return complete!(prior_variance_stream)
end

# Handle one inference result. The message is summarised by its mean and
# variance, recorded at the current step, and the step cursor advances. If
# observations remain, the moments are queued for the worker task (which feeds
# them back via `update!`); otherwise the model's input streams are completed.
function Rocket.on_next!(actor::InferenceActor, data::AbstractMessage)
    posterior = data.distribution
    moments   = (mean(posterior), var(posterior))

    step = actor.index
    actor.messages[step] = data
    actor.index = step + 1

    if actor.index > actor.size
        # All observations consumed: close the observation and prior streams.
        return stop!(actor)
    else
        # The channel is unbounded, so this `put!` never blocks the caller.
        return put!(actor.communicate, moments)
    end
end

# Propagate an error from the inference stream.
# - The channel is closed first, so the worker task waiting on it is released
#   instead of staying blocked forever.
# - `err` is rethrown as-is. `error(err)` would turn it into a stringified
#   `ErrorException`, losing its type for any caller that catches it.
function Rocket.on_error!(actor::InferenceActor, err)
    close(actor.communicate)
    throw(err)
end

# Stream completion closes the channel, which stops the worker task.
Rocket.on_complete!(actor::InferenceActor) = close(actor.communicate)

# Compact display: print only the type name, not the data/message fields.
function Base.show(io::IO, ::InferenceActor)
    print(io, "InferenceActor")
end

In [9]:
"""
    kalman(; N = 5000, noise_variance = 200.0, prior_mean = 0.0, prior_variance = 1000.0)

Run streaming inference on a synthetic model and return the vector of
inference messages for `x`, one per observation.

The model is:
- `x = x_prev + 1`
- `y = x + n`, with `n ~ Normal(0, noise_variance)`
- `x_prev ~ Normal(x_prev_m, x_prev_v)`

The observations are `t + sqrt(noise_variance) * randn()` for `t = 1:N`. The
first step uses the prior `(prior_mean, prior_variance)`. Each later step is
driven by an `InferenceActor` that feeds the previous message's moments back in
as the prior. Calling `kalman()` with no arguments reproduces the original
hard-coded setup.

# Throws
- `ArgumentError` if `N < 1`, or if `noise_variance` or `prior_variance` is not
  positive.
"""
function kalman(;
    N::Integer = 5000,
    noise_variance::Real = 200.0,
    prior_mean::Real = 0.0,
    prior_variance::Real = 1000.0,
)
    N >= 1 || throw(ArgumentError("N must be at least 1, got $N"))
    noise_variance > 0 ||
        throw(ArgumentError("noise_variance must be positive, got $noise_variance"))
    prior_variance > 0 ||
        throw(ArgumentError("prior_variance must be positive, got $prior_variance"))

    noise_var = Float64(noise_variance)

    # Synthetic data: the latent state increases by 1 per step and is observed
    # under the same Gaussian noise the model assumes (one fused broadcast).
    data = (1:N) .+ sqrt(noise_var) .* randn(N)

    x_prev_add = AdditionNode("x_prev_add")
    add_1      = ConstantVariable("add_1", 1.0, x_prev_add.in2)

    x_prev_prior = GaussianMeanVarianceNode("x_prev_prior")
    x_prev_m     = PriorVariable("x_prev_m", x_prev_prior.mean)
    x_prev_v     = PriorVariable("x_prev_v", x_prev_prior.variance)

    x_prev = RandomVariable("x_prev", x_prev_prior.value, x_prev_add.in1)

    noise_node       = GaussianMeanVarianceNode("noise_node")
    noise_mean       = ConstantVariable("noise_mean", 0.0, noise_node.mean)
    noise_variance_v = ConstantVariable("noise_variance", noise_var, noise_node.variance)

    add_x_and_noise = AdditionNode("add_x_and_noise")

    x = RandomVariable("x", x_prev_add.out, add_x_and_noise.in1)
    n = RandomVariable("n", noise_node.value, add_x_and_noise.in2)
    y = ObservedVariable("y", add_x_and_noise.out)

    actor  = InferenceActor(data, y, x_prev_m, x_prev_v)
    synced = sync(actor)

    # Subscribe and inject the initial prior from a separate task, while this
    # task blocks in `wait(synced)` until the inference stream finishes.
    @async begin
        try
            subscribe!(inference(x), synced)
            update!(actor, Float64(prior_mean), Float64(prior_variance))
        catch e
            # Report with the full backtrace rather than just printing `e`.
            # NOTE(review): if setup fails here, `synced` is presumably never
            # completed and `wait(synced)` below may block indefinitely — confirm.
            @error "kalman: failed to start inference" exception = (e, catch_backtrace())
        end
    end

    wait(synced)

    return actor.messages
end

kalman (generic function with 1 method)

In [12]:
# Benchmark a full `kalman()` run (graph construction plus inference over all
# observations); the trailing `;` suppresses display of the returned messages.
@btime kalman();

  9.245 ms (175429 allocations: 7.08 MiB)


In [11]:
# Single timed run of `kalman()`; with no trailing `;` the returned vector of
# messages is shown as the cell output.
@time kalman()

  0.013627 seconds (175.43 k allocations: 7.076 MiB)


5000-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=1.0653003241119448, σ=12.909944487358057)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.2051518819658047, σ=9.534625892455924))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=9.852077909000885, σ=7.905694150420949))   
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=9.324988777146203, σ=6.900655593423543))   
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=5.809616274608046, σ=6.201736729460423))   
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=8.123870025194876, σ=5.679618342470648))   
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=4.126954209155379, σ=5.2704627669473))     
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=6.772969665860738, σ=4.938647983247948))   
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.700508168166821, σ=4.66252404120157))    
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.341625422505872, σ=4.42