In [2]:
using ReactiveMP, Rocket, GraphPPL, Random, LinearAlgebra, Plots, Flux, Zygote, ForwardDiff, DataFrames, DelimitedFiles

[33m[1m│ [22m[39m- If you have ReactiveMP checked out for development and have
[33m[1m│ [22m[39m  added Flux as a dependency but haven't updated your primary
[33m[1m│ [22m[39m  environment's manifest file, try `Pkg.resolve()`.
[33m[1m│ [22m[39m- Otherwise you may need to report an issue with ReactiveMP


In [3]:
"""
    f(x, y, z)

Return the weighted sum `x + 2 * y + z`, in which `y` counts twice and
`x` and `z` count once.
"""
f(x, y, z) = x + 2 * y + z

# Probabilistic model with `N` observations in which three latent variables
# per observation feed a single deterministic node `f`.
#
# Argument:
#   N - number of observations; sizes every per-observation collection below.
#
# Returns the tuple `(μ, z, x, y, out)`; note that `t` is not part of it.
@model function multiple_interfaces(N)
    
    # Shared mean of all per-observation latents, with a NormalMeanVariance(0, 1) prior.
    μ ~ NormalMeanVariance(0, 1)
    # Per-observation latent inputs of `f`.
    z = randomvar(N)
    x = randomvar(N)
    y = randomvar(N)
    # Per-observation output of `f`.
    t = randomvar(N)
    # Observed values, supplied as `Float64` data at inference time.
    out = datavar(Float64, N)

    for i in 1:N
        # The three inputs share the mean `μ` and have unit variance.
        z[i] ~ NormalMeanVariance(μ, 1)
        x[i] ~ NormalMeanVariance(μ, 1)
        y[i] ~ NormalMeanVariance(μ, 1)
        # Deterministic combination of the three inputs through `f`.
        t[i] ~ f(x[i], y[i], z[i])
        # Disabled alternative: the same expression as `f`, written inline.
        # t[i] ~ ((x[i] + (2 * y[i])) + z[i])
        # Observation centred on `t[i]` with unit variance.
        out[i] ~ NormalMeanVariance(t[i], 1)
    end

    return μ, z, x, y, out
end
# Factorization constraint: the joint posterior over `z`, `x`, `y`, `μ` and `t`
# is written as a product of one separate factor per named variable.
constraints = @constraints begin
    q(z, x, y, μ, t) = q(μ)q(z)q(x)q(y)q(t)
end;

# Node meta: every `f(x, y, z)` node in the model is handled with a
# `CVIApproximation` driven by Flux's Adam optimiser with default settings.
# `Adam` is the current Flux name of the optimiser formerly spelled `ADAM`;
# the constructed object is the same `Flux.Optimise.Adam`.
# NOTE(review): the numeric positional arguments are passed through unchanged;
# confirm their meaning against the `CVIApproximation` constructor.
meta = @meta begin
    f(x, y, z) -> CVIApproximation(100, 100, Adam(), 100, 20)
end

Meta specification:
  f(x, y, z) -> CVIApproximation(100, 100, nothing, Flux.Optimise.Adam(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}()), 100, 20)
Options:
  warn = true

In [4]:
# Number of observations, and a synthetic data set of that many samples drawn
# from a default-constructed `NormalMeanVariance`.
N = 1_000
data = rand(NormalMeanVariance(), N);

# Run inference on `multiple_interfaces` with the constraints and meta defined
# above, keeping the posterior of `μ` from each iteration.
res = inference(;
    # Model and observations.
    model = Model(multiple_interfaces, N),
    data = (out = data,),
    # Specifications attached to the model.
    constraints = constraints,
    meta = meta,
    # Starting marginals for the latent variables.
    initmarginals = (
        z = NormalMeanVariance(),
        x = NormalMeanVariance(),
        y = NormalMeanVariance(),
        μ = NormalMeanVariance(),
    ),
    # Disabled alternative initialisation through messages:
    # initmessages = (z=NormalMeanVariance(), x=NormalMeanVariance(), y=NormalMeanVariance(),),
    # Run control and requested outputs.
    iterations = 2,
    free_energy = false,
    returnvars = (μ = KeepEach(),),
    options = (limit_stack_depth = 1000,),
)

└ @ ReactiveMP /home/mykola/repos/ReactiveMP.jl/src/constraints/meta/meta.jl:102


Inference results:
-----------------------------------------
μ = NormalWeightedMeanPrecision{Float64}[NormalWeightedMeanPrecision{Float64}(xi=0.5...
