In [168]:
using ReactiveMP, Rocket, GraphPPL, Random, LinearAlgebra, Plots, Flux, Zygote, ForwardDiff, DataFrames, DelimitedFiles

In [169]:
# Load sensor locations, observations and positions from delimited text files.
sensors = readdlm("sensors.txt")
# Each of the first three rows of sensors.txt is one sensor location.
sensor1, sensor2, sensor3 = sensors[1, :], sensors[2, :], sensors[3, :]
observation = readdlm("observation.txt")
# NOTE(review): this global shadows `Base.position` in Main; presumably ground-truth
# positions (not used in this cell) — confirm.
position = readdlm("position.txt")
# Use at most 15 time steps, but never more than observation.txt provides. A bare
# `T = 15` throws a BoundsError below when the file has fewer rows.
T = min(15, size(observation, 1))
# One observation vector per time step: row t of observation.txt.
observation_list = [observation[t, :] for t in 1:T];

In [170]:
# State-transition matrix: adds components 3-4 of the state to components 1-2 and keeps
# 3-4 unchanged (a constant-velocity layout if the state is (x, y, vx, vy)).
A = [1.0 0.0 1.0 0.0
     0.0 1.0 0.0 1.0
     0.0 0.0 1.0 0.0
     0.0 0.0 0.0 1.0]
# Selection matrix that extracts the first two state components.
B = [1.0 0.0 0.0 0.0
     0.0 1.0 0.0 0.0]

"""
    f(z)

Nonlinear observation function. Take the first two components of the state `z`
(`B * z`) and return a 3-element vector of their Euclidean distances to the global
sensor locations `sensor1`, `sensor2` and `sensor3`, in that order.
"""
function f(z)
    pos = B * z
    squared_distance(s) = sum((pos - s) .^ 2)
    return [sqrt(squared_distance(sensor1)),
            sqrt(squared_distance(sensor2)),
            sqrt(squared_distance(sensor3))]
end

# Nonlinear state-space model over T time steps.
#   z[t]  latent 4-dimensional state; z[1] is centred on zeros(4), and later steps are
#         centred on A * z[t-1].
#   x[t]  deterministic transform f(z[t]). f is nonlinear, so the node carries a
#         CVIApproximation meta.
#   y[t]  observed vector (datavar), Gaussian around x[t] with precision matrix R.
# W (4×4) and R (3×3) are precision matrices with identity-scale Wishart priors. The same
# W is the precision of z[1] and of every transition.
# Returns the random-variable vectors (z, x, y).
@model function sensor_fusion(T)
    W ~ Wishart(4, diagm(0=>ones(4)))
    R ~ Wishart(3, diagm(0=>ones(3)))

    z = randomvar(T)
    x = randomvar(T)
    y = datavar(Vector{Float64}, T)

    # Initial time step; the loop below repeats this pattern with a linear transition.
    z[1] ~ MvNormalMeanPrecision(zeros(4), W)
    # NOTE(review): the positional CVIApproximation(100, 100, ADAM(), 100, 20) arguments
    # are not explained here — confirm their meaning against the installed ReactiveMP.
    x[1] ~ f(z[1]) where {meta = CVIApproximation(100, 100, ADAM(), 100, 20)}
    y[1] ~ MvNormalMeanPrecision(x[1], R)

    for t in 2:T
        z[t] ~ MvNormalMeanPrecision(A * z[t-1], W)
        x[t] ~ f(z[t]) where {meta = CVIApproximation(100, 100, ADAM(), 100, 20)}
        y[t] ~ MvNormalMeanPrecision(x[t], R)
    end

    return z, x, y
end

# Variational factorisation constraint: the joint posterior over the latent states z and
# the transformed values x is approximated by the product q(z)q(x).
constraints = @constraints begin
    q(z, x) = q(z)q(x)
end;

In [171]:
# Message towards the precision Λ (Marginalisation rule) when `out` carries a Gaussian
# message and μ has a point-mass marginal. The result is a Wishart with d + 1 degrees of
# freedom and scale inv(V_out + (μ - m_out)(μ - m_out)'), where d = length of m_out's mean.
@rule MvNormalMeanPrecision(:Λ, Marginalisation) (m_out::MultivariateGaussianDistributionsFamily, q_μ::PointMass, ) = begin
    out_mean, out_cov = mean_cov(m_out)
    residual = mean(q_μ) - out_mean
    dof = length(out_mean) + 1
    return Wishart(dof, cholinv(out_cov + residual * residual'))
end

# Message towards the precision Λ (Marginalisation rule) when both `out` and μ carry
# Gaussian messages. The result is a Wishart with d + 1 degrees of freedom and scale
# inv(V_μ + V_out + (m_μ - m_out)(m_μ - m_out)'), where d = length of m_out's mean.
# `m_μ` accepts any multivariate Gaussian, matching `m_out`, because the body only needs
# `mean_cov`. Restricting it to MvNormalMeanCovariance rejected precision-parameterised
# messages such as the MvNormalMeanPrecision produced by the `:out` rules.
# NOTE(review): the recorded inference run failed with a RuleMethodError asking for a
# `(q_out::PointMass, m_μ::ProdFinal)` variant of this rule, which is not defined here.
@rule MvNormalMeanPrecision(:Λ, Marginalisation) (m_out::MultivariateGaussianDistributionsFamily, m_μ::MultivariateGaussianDistributionsFamily, ) = begin
    μ_mean, μ_cov = mean_cov(m_μ)
    out_mean, out_cov = mean_cov(m_out)
    residual = μ_mean - out_mean
    return Wishart(length(out_mean) + 1, cholinv(μ_cov + out_cov + residual * residual'))
end

# Message towards `out` (Marginalisation rule) from a μ message of any type and a
# Wishart message on Λ. The result is a Gaussian centred on the mean of m_μ with precision
# equal to the Wishart mean. Only the mean of m_μ is used, so its uncertainty is dropped.
@rule MvNormalMeanPrecision(:out, Marginalisation) (m_μ::Any, m_Λ::Wishart, ) = begin
    centre = mean(m_μ)
    expected_Λ = mean(m_Λ)
    return MvNormalMeanPrecision(centre, expected_Λ)
end

# Message towards `out` (Marginalisation rule) when μ has a point-mass marginal and Λ
# carries a Wishart message. The result is a Gaussian at the point with precision equal to
# the Wishart mean.
@rule MvNormalMeanPrecision(:out, Marginalisation) (q_μ::PointMass, m_Λ::Wishart, ) = begin
    location = mean(q_μ)
    expected_Λ = mean(m_Λ)
    return MvNormalMeanPrecision(location, expected_Λ)
end

# Message towards μ (Marginalisation rule) from a Gaussian `out` message and a Wishart
# message on Λ. The out message's covariance is widened by the inverse of the expected
# precision mean(Λ); the result is returned in mean/precision form.
@rule MvNormalMeanPrecision(:μ, Marginalisation) (m_out::MultivariateGaussianDistributionsFamily, m_Λ::Wishart, ) = begin
    out_mean, out_cov = mean_cov(m_out)
    widened = MvNormalMeanCovariance(out_mean, out_cov + cholinv(mean(m_Λ)))
    μ_mean, μ_precision = mean_precision(widened)
    return MvNormalMeanPrecision(μ_mean, μ_precision)
end

# Message towards μ (Marginalisation rule) when `out` has a point-mass marginal, such as
# an observed datavar. The result is a Gaussian at that point with precision equal to the
# Wishart mean of Λ.
@rule MvNormalMeanPrecision(:μ, Marginalisation) (q_out::PointMass, m_Λ::Wishart, ) = begin
    anchor = mean(q_out)
    expected_Λ = mean(m_Λ)
    return MvNormalMeanPrecision(anchor, expected_Λ)
end

In [172]:
# Run inference on `sensor_fusion` for T steps: 2 iterations, free energy disabled, and
# the z marginals requested with KeepEach(). W and R start from identity-scale Wishart
# messages and marginals; z starts from a broad Gaussian message (precision 0.01·I).
# NOTE(review): in the recorded run this call raised a RuleMethodError for a missing
# `MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::PointMass, m_μ::ProdFinal)` rule.
res = let
    identity_wishart(d) = Wishart(d, diagm(0 => ones(d)))
    inference(
        model = Model(sensor_fusion, T),
        data = (y = observation_list,),
        iterations = 2,
        free_energy = false,
        returnvars = (z = KeepEach(),),
        constraints = constraints,
        initmessages = (
            z = MvNormalMeanPrecision(zeros(4), 0.01 * diagm(0 => ones(4))),
            W = identity_wishart(4),
            R = identity_wishart(3),
        ),
        initmarginals = (R = identity_wishart(3), W = identity_wishart(4)),
    )
end

LoadError: RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule MvNormalMeanPrecision(:Λ, Marginalisation) (q_out::PointMass, m_μ::ProdFinal, ) = begin 
    return ...
end



In [42]:
# mean_precision returns a (mean, precision) pair, so this evaluates to 2.
mean_precision(MvNormalMeanPrecision(zeros(4), diagm(0 => ones(4)))) |> length

2

In [20]:
# NOTE(review): `p` is not defined in any cell shown here — presumably left over from a
# deleted cell; re-running the notebook from scratch would raise an UndefVarError.
p

4×4 Matrix{Float64}:
  1.0   0.0   0.0  -0.0
  0.0   1.0   0.0  -0.0
  0.0   0.0   1.0  -0.0
 -0.0  -0.0  -0.0   1.0

In [87]:
# Mean of a Wishart with df = 1 and identity scale; the recorded output is the 4×4 identity.
Wishart(1, diagm(0 => ones(4))) |> mean

4×4 Matrix{Float64}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0

In [132]:
# `something(nothing, x)` is simply `x`, so use the global RNG directly.
rng = Random.GLOBAL_RNG
# Ten draws from a zero-mean, identity-precision 4-D Gaussian, one sample per column (4×10).
samples = rand(rng, MvNormalMeanPrecision(zeros(4), diagm(0 => ones(4))), 10)
samples

4×10 Matrix{Float64}:
  0.825888   -0.948144  1.31037   …  -0.988617    0.105391  -1.28206
 -1.24359    -0.5526    0.371746      0.416852   -1.15302    0.549141
  1.35556    -1.81755   1.05541       1.29406     0.1835    -0.0721076
  0.0345196  -0.293503  1.69378      -0.0409659   0.573095   0.357154

In [150]:
# Print the concrete type of every column of `samples` after converting it to an Array.
foreach(eachcol(samples)) do column
    println(typeof(convert(Array, column)))
end

Vector{Float64}
Vector{Float64}
Vector{Float64}
Vector{Float64}
Vector{Float64}
Vector{Float64}
Vector{Float64}
Vector{Float64}
Vector{Float64}
Vector{Float64}


In [148]:
# Evaluates the `Array` type itself (the recorded output just echoes `Array`).
Array

Array

In [158]:
# Ten independent zero-mean, identity-precision 4-D Gaussian draws as a Vector of Vectors.
map(_ -> rand(rng, MvNormalMeanPrecision(zeros(4), diagm(0 => ones(4)))), 1:10)

10-element Vector{Vector{Float64}}:
 [1.7302414687022891, -0.9453722711387934, -1.2406293563393882, -0.98154225239336]
 [-0.3434040949807968, -0.21939477559552267, -2.267802617045663, 0.5979133468261842]
 [-0.07944486178621736, 0.19849749458741203, -0.0601064856398875, 1.2132532138542484]
 [-0.7034826533529761, 0.8729854781401021, 0.004504004445781497, 3.202064953219846]
 [0.899147655354539, 0.31012388570745714, -0.7340087882260947, -0.8893638349472888]
 [-1.6000919529227657, 0.2510578654586799, -0.2585899202643073, -1.9386142674160158]
 [0.8830293610357492, -0.5882730188414279, 1.4431209265546403, -0.5224377999112231]
 [-0.8838580074059497, 0.10038503924771103, -0.3708521310083794, -0.2350022584903327]
 [-1.1852581987270783, 0.4255864948201069, -0.17182254035193475, -1.033856942713576]
 [0.4996689457804617, 1.1746641015954244, 1.1477327398908073, 1.1557731759117749]