In [1]:
using Rocket
using Distributions
using ReactiveMP
using BenchmarkTools

import Base: show

┌ Info: Precompiling Rocket [df971d30-c9d6-4b37-b8ff-e965b2cb3a40]
└ @ Base loading.jl:1273
┌ Info: Precompiling Distributions [31c24e10-a181-5473-b8eb-7969acd0382f]
└ @ Base loading.jl:1273
┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1273
┌ Info: Precompiling BenchmarkTools [6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf]
└ @ Base loading.jl:1273


In [2]:
"""
    createSubgraph(index::Int)

Build the noisy-observation branch for step `index`.

The branch consists of an `AdditionNode` and a `GaussianMeanVarianceNode`. The Gaussian
node's `mean` and `variance` interfaces are given the constants `0.0` and `1.0`, its
`value` interface is joined with the addition node's `in2` interface through a
`RandomVariable`, and an `ObservedVariable` is created on the addition node's `out`
interface. All names are prefixed with `[index]`.

Returns the tuple `(yn, noise_add)`: the observed variable and the addition node. The
addition node's `in1` interface is not used here and is left to the caller.
"""
function createSubgraph(index::Int)
    adder = AdditionNode("[$index] noise_add")
    noise = GaussianMeanVarianceNode("[$index] noise_node")

    # These three variables are never read again in this function, so they are built
    # without a binding (presumably construction alone attaches them to the interfaces
    # passed in).
    ConstantVariable("[$index] noise_mean", 0.0, noise.mean)
    ConstantVariable("[$index] noise_variance", 1.0, noise.variance)
    RandomVariable("[$index] tmp", noise.value, adder.in2)

    observation = ObservedVariable("[$index] yn", adder.out)

    return (observation, adder)
end

"""
    createGraph(size::Int)

Construct a chain-shaped graph with `size` state variables and `size` observed variables.

Layout, as built below:

- A `GaussianMeanVarianceNode` whose `mean` and `variance` interfaces receive
  `PriorVariable`s; its `value` feeds an `AdditionNode` that adds the constant `1.0`.
- For every `index` in `1:size-1`: the output of the previous addition node enters an
  `EqualityIOONode`. One equality output goes to the observation branch returned by
  `createSubgraph(index)`, the other goes to a new `AdditionNode` that again adds `1.0`.
- The output of the last addition node goes into a final noise `AdditionNode` whose noise
  term has mean `0.0` and variance `200.0` and whose output is observed.

# Arguments
- `size::Int`: number of states/observations; must be at least 1.

# Returns
A tuple `(xs, ys, gmv_0_m, gmv_0_v)`: the vector of state `RandomVariable`s, the vector of
`ObservedVariable`s, and the two prior variables for the initial mean and variance.

# Throws
- `ArgumentError` if `size < 1`.
"""
function createGraph(size::Int)
    # Reject an empty graph before any node is constructed. Without this guard, a call
    # with `size == 0` builds the prior part first and only then fails with a
    # `BoundsError` on the final `xs[size]` write.
    size >= 1 || throw(ArgumentError("size must be at least 1, got $size"))

    gmv_0   = GaussianMeanVarianceNode("[0] gmv")
    gmv_0_m = PriorVariable("[0] gmv_m", gmv_0.mean)
    gmv_0_v = PriorVariable("[0] gmv_v", gmv_0.variance)

    add_c_0 = AdditionNode("[0] add_c")
    # Neither of these is stored in `xs` or read again; they are created unbound.
    RandomVariable("[0] x", gmv_0.value, add_c_0.in1)
    ConstantVariable("[0] c", 1.0, add_c_0.in2)

    xs = Vector{RandomVariable}(undef, size)
    ys = Vector{ObservedVariable}(undef, size)

    # Addition node whose `out` interface feeds the next step of the chain.
    prev_add_c = add_c_0

    # Steps 1 .. size-1; the last step is built separately after the loop.
    for index in 1:(size - 1)
        equality_n    = EqualityIOONode("[$index] equality")
        yn, noise_add = createSubgraph(index)

        xn = RandomVariable("[$index] xn", prev_add_c.out, equality_n.in1)
        RandomVariable("[$index] xn_", equality_n.out1, noise_add.in1)

        add_c = AdditionNode("[$index] add_c")
        RandomVariable("[$index] xn__", equality_n.out2, add_c.in1)
        ConstantVariable("[$index] cn", 1.0, add_c.in2)

        xs[index] = xn
        ys[index] = yn

        prev_add_c = add_c
    end

    last_noise_add = AdditionNode("[last] noise_add")
    x_last         = RandomVariable("[last] x", prev_add_c.out, last_noise_add.in1)

    # NOTE(review): the observation branches built by `createSubgraph` use a noise
    # variance of 1.0, while this last branch uses 200.0 — confirm the asymmetry is
    # intended.
    last_noise_node = GaussianMeanVarianceNode("[last] noise_node")
    ConstantVariable("[last] noise_mean", 0.0, last_noise_node.mean)
    ConstantVariable("[last] noise_variance", 200.0, last_noise_node.variance)

    RandomVariable("[last] z", last_noise_node.value, last_noise_add.in2)
    y_last = ObservedVariable("[last] y", last_noise_add.out)

    xs[size] = x_last
    ys[size] = y_last

    return (xs, ys, gmv_0_m, gmv_0_v)
end

createGraph (generic function with 1 method)

In [11]:
"""
    smoothing_replay(data)

Build a graph with one observation per element of `data`, push the prior values and the
observations into it, and record the first value emitted by `inference` for each state
variable.

The subscriptions for the first and the last state are created before any `update!` call.
Those for the interior states `2:N-1` are created only after every update has been issued.
NOTE(review): presumably this ordering is deliberate and relies on late subscribers still
receiving the already-computed result (hence "replay") — confirm.

Returns a `Vector{AbstractMessage}` of length `length(data)`. The vector is allocated with
`undef`, so a slot whose subscription never fires remains unassigned.
"""
function smoothing_replay(data)
    N = length(data)
    states, observations, mean_prior, variance_prior = createGraph(N)

    collected = Vector{AbstractMessage}(undef, N)

    # Subscribe to the first inference result of state `k` and store it in slot `k`.
    capture!(k) = subscribe!(
        inference(states[k]) |> take(1),
        lambda(on_next = (msg) -> collected[k] = msg),
    )

    # End points are subscribed up front, before any data enters the graph.
    capture!(1)
    capture!(N)

    update!(mean_prior, 0.0)
    update!(variance_prior, 1000.0)
    foreach(k -> update!(observations[k], data[k]), 1:N)

    # Interior states are subscribed after all updates.
    foreach(capture!, 2:N-1)

    return collected
end

smoothing_replay (generic function with 2 methods)

In [12]:
# Synthetic observations: the ramp 1, 2, …, N plus zero-mean Gaussian noise with
# standard deviation sqrt(200). Broadcasting over the range builds the vector in one pass.
N = 256
data = (1:N) .+ sqrt(200.0) .* randn(N);

In [13]:
# Time a full graph build + inference run on the global `data`; `$data` interpolates the
# value into the benchmark expression so the global lookup is not part of the measurement.
@btime smoothing_replay($data);

  575.817 ms (4329810 allocations: 135.82 MiB)


In [18]:
# Time a 64-sample run. The input is interpolated with `$(...)` so it is generated once,
# outside the timed expression; only `smoothing_replay` itself is measured.
@btime smoothing_replay($(collect(1:64) + sqrt(200.0) * randn(64)))

  37.529 ms (301942 allocations: 9.67 MiB)


64-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=-0.6343240398921595, σ=0.1259821586621601))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=0.36567596010784054, σ=0.1259821586621601))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=1.3656759601078408, σ=0.12598215866216006))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=2.3656759601078403, σ=0.1259821586621601)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.3656759601078403, σ=0.12598215866216006))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=4.365675960107839, σ=0.1259821586621601))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=5.36567596010784, σ=0.1259821586621601))   
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=6.365675960107838, σ=0.12598215866216006)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.365675960107838, σ=0.12598215866216006)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=8.36567596010784, σ=0.12598

In [19]:
# Run once outside the benchmark so the notebook displays the collected messages.
smoothing_replay(data)

256-element Array{AbstractMessage,1}:
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=0.17790546543467523, σ=0.06262169238705584))
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=1.177905465434675, σ=0.06262169238705584))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=2.1779054654346752, σ=0.06262169238705585)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=3.1779054654346757, σ=0.06262169238705584)) 
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=4.177905465434676, σ=0.06262169238705584))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=5.177905465434676, σ=0.06262169238705582))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=6.177905465434675, σ=0.06262169238705582))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=7.177905465434675, σ=0.06262169238705582))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=8.177905465434677, σ=0.06262169238705584))  
 StochasticMessage{Normal{Float64}}(Normal{Float64}(μ=9.177905465434677