In [4]:
using ReactiveMP, Rocket, GraphPPL, Random, LinearAlgebra, Plots, Flux, Zygote, ForwardDiff, DataFrames, DelimitedFiles

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1423
[33m[1m│ [22m[39m- If you have ReactiveMP checked out for development and have
[33m[1m│ [22m[39m  added Flux as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with ReactiveMP


In [5]:
# Load the input data:
# - `sensors.txt`: one sensor location per row (rows 1-3 become sensor1..sensor3)
# - `observation.txt`: one observation vector per row, one row per time step
# - `position.txt`: presumably recorded/ground-truth positions (not used in this cell)
sensors = readdlm("sensors.txt")
sensor1, sensor2, sensor3 = sensors[1, :], sensors[2, :], sensors[3, :]
observation = readdlm("observation.txt")
# NOTE(review): this global shadows `Base.position` in Main; assigning it fails if
# `position` was already resolved to Base in this session. Kept for later cells.
position = readdlm("position.txt")
# Use at most the first 15 time steps, but never more than the file provides,
# so a shorter observation file does not raise a BoundsError below.
T = min(15, size(observation, 1))
observation_list = [observation[t, :] for t in 1:T];

In [6]:
# State-transition matrix: each step adds state components 3:4 to components 1:2
# and carries components 3:4 over unchanged (a unit-step, presumably
# constant-velocity, transition).
A = [1.0 0.0 1.0 0.0
     0.0 1.0 0.0 1.0
     0.0 0.0 1.0 0.0
     0.0 0.0 0.0 1.0]
# Selection matrix: picks out state components 1:2 (the position used by `f`).
B = [1.0 0.0 0.0 0.0
     0.0 1.0 0.0 0.0]
"""
    f(z, sensors = (sensor1, sensor2, sensor3))

Nonlinear observation function: return the Euclidean distance from the position
`B * z` to each sensor location in `sensors`.

# Arguments
- `z`: state vector; `B * z` must have the same length as every sensor location.
- `sensors`: iterable of sensor locations. Defaults to the globals `sensor1`,
  `sensor2` and `sensor3`, which are looked up each time `f` is called.

# Returns
A `Vector` with one distance per sensor, in the order of `sensors`.
"""
function f(z, sensors = (sensor1, sensor2, sensor3))
    pos = B * z  # `B` is the global selection matrix extracting the position
    # `norm(v)` is the idiomatic spelling of sqrt(sum(v .^ 2)) for a vector.
    return [norm(pos - s) for s in sensors]
end

# State-space model over `T` time steps (W and R are precision matrices):
#   z[1] ~ MvNormal(mean = zeros(4),      precision = W)
#   z[t] ~ MvNormal(mean = A * z[t-1],    precision = W)   for t = 2..T
#   x[t] = f(z[t])   deterministic nonlinear node, with CVIApproximation as its meta
#   y[t] ~ MvNormal(mean = x[t],          precision = R)   observed data
# Returns the latent states `z`, the transformed states `x` and the data variables `y`.
@model function sensor_fusion(T)
    # Wishart priors on the 4x4 state precision and the 3x3 observation precision,
    # each with identity scale matrix and degrees of freedom equal to the dimension.
    W ~ Wishart(4, diagm(0=>ones(4)))
    R ~ Wishart(3, diagm(0=>ones(3)))

    z = randomvar(T)                 # latent states
    x = randomvar(T)                 # f(z[t]): noise-free observations
    y = datavar(Vector{Float64}, T)  # observed data, one vector per time step

    # First time step.
    # NOTE(review): the initial state reuses the transition precision W as its prior
    # precision — confirm this is intended.
    z[1] ~ MvNormalMeanPrecision(zeros(4), W)
    # NOTE(review): the meaning of CVIApproximation's positional arguments is defined by
    # ReactiveMP (sample / iteration counts?) — confirm against its docs. A new ADAM()
    # optimizer appears to be constructed for each `f` node.
    x[1] ~ f(z[1]) where {meta = CVIApproximation(100, 100, ADAM(), 100, 20)}
    y[1] ~ MvNormalMeanPrecision(x[1], R)

    # Remaining steps: linear transition through `A`, then the same observation model.
    for t in 2:T
        z[t] ~ MvNormalMeanPrecision(A * z[t-1], W)
        x[t] ~ f(z[t]) where {meta = CVIApproximation(100, 100, ADAM(), 100, 20)}
        y[t] ~ MvNormalMeanPrecision(x[t], R)
    end

    return z, x, y
end

# Factorization constraints on the variational posterior:
# - the states `z` are factorized from their precision `W`,
# - the transformed states `x` are factorized from the observation precision `R`.
constraints = @constraints begin
    q(z, W) = q(z)q(W)
    q(x, R) = q(x)q(R)
    # Disabled alternative: presumably a full mean-field split of q(z) over the
    # individual time steps (GraphPPL `..` range syntax) — left commented out.
    # q(z) = q(z[begin])..q(z[end])
end;

In [10]:
# Run inference on the first `T` observations for 2 iterations without computing the
# free energy. Only the posteriors of `z` are returned (all iterations kept,
# presumably — see `KeepEach`). The initial messages for `z` and the initial
# marginals for `R` and `W` seed the iterative updates.
res = let
    # Fresh n×n identity matrix on each call, so no two distributions share storage.
    identity_matrix(n) = diagm(ones(n))

    inference(
        model         = Model(sensor_fusion, T),
        data          = (y = observation_list,),
        iterations    = 2,
        free_energy   = false,
        returnvars    = (z = KeepEach(),),
        constraints   = constraints,
        initmessages  = (z = MvNormalMeanPrecision(zeros(4), 0.01 * identity_matrix(4)),),
        initmarginals = (
            R = Wishart(3, identity_matrix(3)),
            W = Wishart(4, identity_matrix(4)),
        ),
    )
end

Inference results:
-----------------------------------------
z = Vector{MvNormalWeightedMeanPrecision{Float64, Vector{Float64}, Matrix{Float64}}}...
