In [1]:
using ReactiveMP, Rocket, GraphPPL, Random, LinearAlgebra, Plots, Flux, Zygote, ForwardDiff, DataFrames, DelimitedFiles

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1423
[33m[1m│ [22m[39m- If you have ReactiveMP checked out for development and have
[33m[1m│ [22m[39m  added Flux as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with ReactiveMP


In [2]:
# Sensor locations: one sensor per row of sensors.txt; the first three rows are used.
sensors = readdlm("sensors.txt")
sensor1, sensor2, sensor3 = sensors[1, :], sensors[2, :], sensors[3, :]
# One observation vector per row of observation.txt.
observation = readdlm("observation.txt")
# NOTE(review): this global shadows `Base.position` in Main; it is not used in the
# cells shown here.
position = readdlm("position.txt")
# Only the first 15 time steps are modelled (presumably to keep inference cheap; the
# full horizon would be `size(observation, 1)`). Clamp to the number of available rows
# so a shorter observation file does not raise a BoundsError below.
T = min(15, size(observation, 1))
observation_list = [observation[t, :] for t in 1:T];

In [12]:
# State-transition matrix of the 4-d latent state: rows 1-2 add components 3-4 to
# components 1-2, rows 3-4 keep components 3-4 unchanged.
A = [1.0 0.0 1.0 0.0
     0.0 1.0 0.0 1.0
     0.0 0.0 1.0 0.0
     0.0 0.0 0.0 1.0]
# Selection matrix that extracts the first two state components.
B = [1.0 0.0 0.0 0.0
     0.0 1.0 0.0 0.0]

"""
    f(z)

Nonlinear observation map: project the state `z` to `B * z` and return the
Euclidean distances from that point to `sensor1`, `sensor2` and `sensor3`,
in that order.
"""
function f(z)
    xy = B * z
    distance_to(sensor) = sqrt(sum(abs2, xy - sensor))
    return [distance_to(sensor1), distance_to(sensor2), distance_to(sensor3)]
end

# Nonlinear state-space model over T time steps:
#   z[t] — latent 4-d state; z[1] ~ N(0, W⁻¹), z[t] ~ N(A * z[t-1], W⁻¹),
#   x[t] = f(z[t]) — deterministic map to the distances from the three sensors,
#   y[t] — observed 3-vector, y[t] ~ N(x[t], R⁻¹).
# The precisions W (4×4) and R (3×3) have Wishart priors and are shared by all steps.
# Returns the variable collections (z, x, y).
@model function sensor_fusion(T)
    # Wishart priors on the state (4×4) and observation (3×3) precision matrices.
    W ~ Wishart(4, diagm(0=>ones(4)))
    R ~ Wishart(3, diagm(0=>ones(3)))

    z = randomvar(T)
    x = randomvar(T)
    y = datavar(Vector{Float64}, T)

    # Initial state: zero mean, same precision W as the transitions.
    z[1] ~ MvNormalMeanPrecision(zeros(4), W)
    # The nonlinear node f is handled by CVI approximation.
    # NOTE(review): the meaning of the positional arguments (100, 100, ADAM(), 100, 20)
    # is defined by ReactiveMP's CVIApproximation — confirm which is which.
    x[1] ~ f(z[1]) where {meta = CVIApproximation(100, 100, ADAM(), 100, 20)}
    y[1] ~ MvNormalMeanPrecision(x[1], R)

    # Linear-Gaussian transitions followed by the same nonlinear observation per step.
    for t in 2:T
        z[t] ~ MvNormalMeanPrecision(A * z[t-1], W)
        x[t] ~ f(z[t]) where {meta = CVIApproximation(100, 100, ADAM(), 100, 20)}
        y[t] ~ MvNormalMeanPrecision(x[t], R)
    end

    return z, x, y
end

# Mean-field factorization between the latent state z and the nonlinear-node
# output x: the joint posterior q(z, x) is approximated by q(z) q(x).
constraints = @constraints begin
    q(z, x) = q(z)q(x)
end;

In [54]:

# function ruleVBGaussianMeanPrecisionW(  dist_out::ProbabilityDistribution{Multivariate},
#                                         dist_mean::ProbabilityDistribution{Multivariate},
#                                         dist_prec::Any)

#     (m_mean, v_mean) = unsafeMeanCov(dist_mean)
#     (m_out, v_out) = unsafeMeanCov(dist_out)

#     Message(MatrixVariate, Wishart, v=cholinv( v_mean + v_out + (m_mean - m_out)*(m_mean - m_out)' ), nu=dims(dist_out) + 2.0)
# end

# Message towards the precision Λ of a multivariate Gaussian node when the mean μ is
# a point mass and the `out` edge carries a Gaussian message with mean m, covariance V.
#
# exp(E[log N(out | μ, Λ)]) ∝ |Λ|^(1/2) · exp(-tr(Λ S) / 2),  S = V + (μ - m)(μ - m)'.
# Matching this against the Wishart kernel |Λ|^((ν - d - 1)/2) · exp(-tr(Σ⁻¹ Λ) / 2)
# gives ν = d + 2 and Σ = S⁻¹. This agrees with the rule for a Gaussian mean message
# (take the mean covariance to be zero) and with ForneyLab's
# ruleVBGaussianMeanPrecisionW (nu = dims + 2). Using d + 1 would be one degree of
# freedom short.
@rule MvNormalMeanPrecision(:Λ, Marginalisation) (m_out::MvNormalMeanPrecision, q_μ::PointMass, ) = begin
    μ_mean = mean(q_μ)
    out_mean, out_cov = mean_cov(m_out)
    d = length(μ_mean)
    S = out_cov + (μ_mean - out_mean) * (μ_mean - out_mean)'
    return Wishart(d + 2, cholinv(S))
end

# Message towards the precision Λ when both the mean μ and the `out` edge carry
# Gaussian messages: a Wishart with d + 2 degrees of freedom (d = dimension of μ) and
# scale matrix (Σ_μ + Σ_out + δ δ')⁻¹, where δ is the difference of the two means.
@rule MvNormalMeanPrecision(:Λ, Marginalisation) (m_out::MvNormalMeanPrecision, m_μ::MvNormalMeanCovariance, ) = begin
    μ̄, Σ_μ = mean_cov(m_μ)
    x̄, Σ_out = mean_cov(m_out)
    δ = μ̄ - x̄
    ν = length(μ̄) + 2
    return Wishart(ν, cholinv(Σ_μ + Σ_out + δ * δ'))
end

# Message towards `out`: a Gaussian centred on the mean of the incoming μ message,
# with precision equal to the mean of the Wishart message on Λ.
# NOTE(review): only the mean of m_μ is used — its covariance is not propagated.
# Confirm this approximation is intended.
@rule MvNormalMeanPrecision(:out, Marginalisation) (m_μ::Any, m_Λ::Wishart, ) = begin
    μ̄ = mean(m_μ)
    Λ̄ = mean(m_Λ)
    return MvNormalMeanPrecision(μ̄, Λ̄)
end

# Message towards `out` when μ is a point mass: a Gaussian centred on that point,
# with precision equal to the mean of the Wishart message on Λ.
@rule MvNormalMeanPrecision(:out, Marginalisation) (q_μ::PointMass, m_Λ::Wishart, ) = begin
    μ̄ = mean(q_μ)
    Λ̄ = mean(m_Λ)
    return MvNormalMeanPrecision(μ̄, Λ̄)
end

# Message towards the mean μ: the Gaussian `out` message N(m, V) is widened by the
# inverse of the Wishart mean, giving N(m, V + E[Λ]⁻¹). The result is returned in
# mean–precision form.
@rule MvNormalMeanPrecision(:μ, Marginalisation) (m_out::MvNormalMeanPrecision, m_Λ::Wishart, ) = begin
    x̄, Σ_out = mean_cov(m_out)
    Σ_total = Σ_out + cholinv(mean(m_Λ))
    μ̄, W_μ = mean_precision(MvNormalMeanCovariance(x̄, Σ_total))
    return MvNormalMeanPrecision(μ̄, W_μ)
end

# Message towards the mean μ when `out` is a point mass: a Gaussian centred on that
# point, with precision equal to the mean of the Wishart message on Λ.
@rule MvNormalMeanPrecision(:μ, Marginalisation) (q_out::PointMass, m_Λ::Wishart, ) = begin
    x̄ = mean(q_out)
    Λ̄ = mean(m_Λ)
    return MvNormalMeanPrecision(x̄, Λ̄)
end

In [55]:
# Run 1000 iterations of inference on the T-step model under `constraints`, without
# free-energy tracking. KeepEach() presumably retains z's marginals from every
# iteration — confirm against ReactiveMP docs.
# Identity matrices are built as Matrix(c * I, n, n), equal to c * diagm(0 => ones(n)).
res = inference(
    model = Model(sensor_fusion, T),
    data = (y = observation_list,),
    iterations = 1000,
    free_energy = false,
    returnvars = (z = KeepEach(),),
    constraints = constraints,
    initmessages = (
        z = MvNormalMeanPrecision(zeros(4), Matrix(0.01I, 4, 4)),
        W = Wishart(4, Matrix(1.0I, 4, 4)),
        R = Wishart(3, Matrix(1.0I, 3, 3)),
    ),
    initmarginals = (
        R = Wishart(3, Matrix(1.0I, 3, 3)),
        W = Wishart(4, Matrix(1.0I, 4, 4)),
    ),
)

LoadError: DimensionMismatch("dimensions must match: a has dims (Base.OneTo(2), Base.OneTo(4)), must have singleton at dim 2")

In [42]:
# Sanity check: `mean_precision` returns a (mean, precision) pair, so this evaluates to 2.
mean_precision(MvNormalMeanPrecision(zeros(4), Matrix(1.0I, 4, 4))) |> length

2

In [20]:
# NOTE(review): `p` is not defined in any cell visible here. It is presumably left over
# from an earlier or deleted cell — confirm before relying on the output below.
p

4×4 Matrix{Float64}:
  1.0   0.0   0.0  -0.0
  0.0   1.0   0.0  -0.0
  0.0   0.0   1.0  -0.0
 -0.0  -0.0  -0.0   1.0