In [7]:
using Distributions
using ReactiveMP
using Rx
using BenchmarkTools

import Base: show

In [8]:
"""
    InferenceActor(data, y, e_mean, e_variance)

Rx actor that drives sequential inference over `data`, one observation per step.

Every message received through `on_next!` is stored in `messages`. Its mean and
variance are queued on `communicate`. A background task started by the
constructor takes each queued pair and passes it to `update!`, which pushes the
next observation into `y` and the pair into `e_mean` / `e_variance`.

# Fields
- `index`: 1-based position of the next observation to feed and next message slot.
- `size`: number of observations, `length(data)`.
- `data`: observations pushed into `y.values`.
- `messages`: received messages; slot `i` is written by the `i`-th `on_next!` call.
- `communicate`: channel with an unbounded (`Inf`) buffer of `(mean, variance)` pairs.
- `y`, `e_mean`, `e_variance`: graph variables whose `values` streams `update!` feeds.
"""
mutable struct InferenceActor <: Actor{AbstractMessage}
    index       :: Int
    size        :: Int
    data        :: Vector{Float64}
    messages    :: Vector{AbstractMessage}
    communicate :: Channel{Tuple{Float64, Float64}}

    y          :: ObservedVariable
    e_mean     :: EstimatedVariable
    e_variance :: EstimatedVariable

    function InferenceActor(data::Vector{Float64}, y::ObservedVariable,
                            e_mean::EstimatedVariable, e_variance::EstimatedVariable)
        n        = length(data)
        messages = Vector{AbstractMessage}(undef, n)
        channel  = Channel{Tuple{Float64, Float64}}(Inf)

        actor = new(1, n, data, messages, channel, y, e_mean, e_variance)

        # Iterating the channel ends normally once it is closed and drained
        # (e.g. by `on_complete!`); a bare `while true; take!(...)` loop would
        # instead terminate by throwing `InvalidStateException` inside the task.
        task = @async for (m, v) in actor.communicate
            update!(actor, m, v)
        end

        # Tie the channel's lifetime to the worker: if `update!` throws, the
        # channel is closed with that exception, so later `put!` calls fail
        # with it instead of queueing silently.
        bind(actor.communicate, task)

        return actor
    end
end

# Push the observation at the actor's current position into `y`, then the
# supplied mean and variance into the two estimated-variable streams.
# The value of the final `next!` call is returned.
function update!(actor::InferenceActor, mean::Float64, variance::Float64)
    observation = actor.data[actor.index]
    next!(actor.y.values, observation)
    next!(actor.e_mean.values, mean)
    return next!(actor.e_variance.values, variance)
end

# Complete the three streams fed by `update!`, in the order y, e_mean,
# e_variance. The value of the final `complete!` call is returned.
function stop!(actor::InferenceActor)
    observed, est_mean, est_variance = actor.y, actor.e_mean, actor.e_variance
    complete!(observed.values)
    complete!(est_mean.values)
    return complete!(est_variance.values)
end

# Record the incoming message in the next slot of `actor.messages`.
# While observations remain, queue the message's mean and variance for the
# worker task; once the last slot has been filled, complete the fed streams.
function Rx.on_next!(actor::InferenceActor, data::AbstractMessage)
    dist = data.distribution
    m, v = mean(dist), var(dist)

    slot = actor.index
    actor.messages[slot] = data
    actor.index = slot + 1

    return actor.index <= actor.size ? put!(actor.communicate, (m, v)) : stop!(actor)
end

# Upstream error: close the channel first, as `on_complete!` does, so the worker
# task created in the constructor is not left waiting on it. Then raise the
# error; `error(err)` wraps it in an `ErrorException`, as before.
function Rx.on_error!(actor::InferenceActor, err)
    close(actor.communicate)
    error(err)
end

# Upstream completion: close the channel read by the worker task created in
# the constructor, which ends that task.
function Rx.on_complete!(actor::InferenceActor)
    return close(actor.communicate)
end

# Compact display: print only the type name, not the (large) field contents.
function Base.show(io::IO, ::InferenceActor)
    print(io, "InferenceActor")
end

In [9]:
"""
    kalman(N, data)

Run sequential message-passing inference over the observations in `data` and
return the vector of messages received for `x` (one slot per observation).

Graph built here, using the node port names:
- `x_prev` is Gaussian with mean `x_prev_m` and variance `x_prev_v` (estimated variables);
- `x = x_prev + 1` (addition node with the constant 1.0);
- `n` is Gaussian with mean 0.0 and variance 200.0;
- `y = x + n` is observed.

The first step is fed mean 0.0 and variance 1000.0. Each later step uses the
mean and variance of the previous message for `x`, as forwarded by
`InferenceActor`.

`N` is not read by the body; the number of steps is `length(data)`.

Throws `ArgumentError` if `data` is empty.
"""
function kalman(N, data)
    # With empty `data`, the initial `update!` below would index `data[1]`
    # inside the `@async` task. That error is not propagated to this function,
    # so `wait(synced)` would never return. Fail fast instead.
    isempty(data) && throw(ArgumentError("data must contain at least one observation"))

    x_prev_add = addition_node("x_prev_add", StochasticMessage{Normal{Float64}},
                               DeterministicMessage, StochasticMessage{Normal{Float64}})
    add_1      = constant_variable("add_1", 1.0, x_prev_add.in2)

    x_prev_prior = gaussian_mean_variance("x_prev_prior")
    x_prev_m     = estimated_variable("x_prev_m", x_prev_prior.mean)
    x_prev_v     = estimated_variable("x_prev_v", x_prev_prior.variance)

    x_prev = random_variable("x_prev", x_prev_prior.value, x_prev_add.in1)

    noise_node     = gaussian_mean_variance("noise_node")
    noise_mean     = constant_variable("noise_mean", 0.0, noise_node.mean)
    noise_variance = constant_variable("noise_variance", 200.0, noise_node.variance)

    add_x_and_noise = addition_node("add_x_and_noise", StochasticMessage{Normal{Float64}},
                                    StochasticMessage{Normal{Float64}}, DeterministicMessage)

    x = random_variable("x", x_prev_add.out, add_x_and_noise.in1)
    n = random_variable("n", noise_node.value, add_x_and_noise.in2)
    y = observed_variable("y", add_x_and_noise.out)

    actor  = InferenceActor(data, y, x_prev_m, x_prev_v)
    synced = sync(actor)

    # Subscribe before feeding the first observation so no message for `x` is missed.
    @async begin
        subscribe!(inference(x), synced)
        update!(actor, 0.0, 1000.0)
    end

    wait(synced)

    return actor.messages
end

kalman (generic function with 1 method)

In [10]:
N = 5000
# Noisy observations of the ramp 1, 2, ..., N with additive Gaussian noise of
# variance 200 (standard deviation sqrt(200)). A single fused broadcast builds
# the result directly, without materialising the range or a scaled-noise temporary.
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [11]:
# `$`-interpolation makes BenchmarkTools treat `N` and `data` as local values
# rather than timing access to untyped globals.
@btime(kalman($N, $data));

  11.438 ms (250348 allocations: 8.83 MiB)


In [12]:
# Display the first 100 messages for `x`.
getindex(kalman(N, data), 1:100)

100-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-14.54379923153286, σ=12.909944487358057))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-9.29585817645424, σ=9.534625892455924))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-4.385580231916261, σ=7.905694150420949)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=0.5780523266070468, σ=6.900655593423543)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=5.197944568681382, σ=6.201736729460423))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.0246440368555483, σ=5.679618342470648)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=10.35345256173157, σ=5.2704627669473))    
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=12.715184274941802, σ=4.938647983247948)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=11.97201296142503, σ=4.66252404120157))   
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=11.99376437255369, σ=4.428074427700