In [1]:
using GraphPPL, ReactiveMP, Distributions
# using Plots

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1662


In [85]:
# Generate toy dataset.
# The state follows the recurrence x[t] = a_data * x[t-1], started from x_0_data,
# and each observation is the state rounded up to the next integer.
T = 15
x_0_data = 0.6
a_data = 1.2

x_data = Vector{Float64}(undef, T)
y_data = Vector{Float64}(undef, T)

# The loop only assigns a loop-local name and array elements, so it behaves the
# same in a notebook/REPL (soft scope) and when run as a plain script, where
# assigning an existing global inside a top-level loop would create a new local.
for t in 1:T
    x_prev = t == 1 ? x_0_data : x_data[t - 1]
    x_data[t] = a_data * x_prev
    y_data[t] = ceil(x_data[t])
end

# Kept for compatibility with the previous version of this cell: holds the last state.
x_t_min_data = T > 0 ? x_data[T] : x_0_data;

In [95]:
# Specify forward transition: the next state is the previous state `x_t_min`
# scaled by the coefficient `a`.
# Accepts any `Real` arguments (not only `Float64`), so integer or mixed-type
# inputs no longer raise a MethodError.
f(x_t_min::Real, a::Real) = a * x_t_min


# State-space model over `n` time steps: every state z[i] is produced by the
# transition `f` applied to the previous state and a per-step coefficient a[i],
# and every observation o[i] is Normal around z[i].
# NOTE(review): the recorded run of this model aborted with
# `MethodError: no method matching -(::Vector{Float64}, ::Float64)`.
@model function NKS(n)
    # `o` holds the n data variables (bound through `data=(o=...,)` at inference
    # time); `z` and `a` are n random variables each.
    o = datavar(Float64, n)
    z = randomvar(n)
    a = randomvar(n)

    # Prior on the initial state (γ is presumably a precision — TODO confirm).
    z_0 ~ Normal(μ=0.0, γ=1.0)
    z_prev = z_0
    for i in 1:n
        a[i] ~ Normal(μ=1.0, γ=2.0)
        # Deterministic transition through `f`, handled with the UT meta.
        z[i] ~ f(z_prev, a[i]) where {meta=UT()}
        o[i] ~ Normal(μ=z[i], σ²=0.2)

        # Chain the next step onto the state that was just defined.
        z_prev = z[i]
    end
end

In [98]:
# NOTE(review): `imarginals` is never passed to `inference` below, and it is a
# 2-dimensional MvNormal although `z_0` is given a univariate Normal prior in the
# model — confirm whether it should be wired in or dropped.
imarginals = (z_0 = MvNormal(zeros(2), diageye(2)), )
n = length(y_data)
# Run inference on NKS with the toy observations bound to `o`; free energy is off.
result = inference(model = Model(NKS, n), data=(o=y_data,), free_energy=false);

"called marginal" = "called marginal"
(μ_in, Σ_in) = mean_cov(q_ins) = ([8.004174031576499, 1.2481783026629607], [10.0406171356769 39.959262791877606; 39.959262791877606 321.34853658074485])


LoadError: MethodError: no method matching -(::Vector{Float64}, ::Float64)
For element-wise subtraction, use broadcasting with dot syntax: array .- scalar
[0mClosest candidates are:
[0m  -([91m::T[39m, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:384
[0m  -([91m::LinearAlgebra.UniformScaling[39m, ::Number) at ~/.julia/juliaup/julia-1.8.0+0.x64/share/julia/stdlib/v1.8/LinearAlgebra/src/uniformscaling.jl:146
[0m  -([91m::ReactiveMP.InfCountingReal[39m, ::Real) at ~/.julia/dev/ReactiveMP/src/helpers.jl:148
[0m  ...

In [34]:
# using Plots
# plot(first.(mean.(result.posteriors[:z])), ribbon=sqrt.(first.(cov.(result.posteriors[:z]))))

### ET: deterministic nodes with the `ET` (`DeltaExtended`) meta

In [99]:
# Identity transform: returns its argument unchanged.
f(x) = x

# Inverse of the identity transform, which is again the identity.
f_inv(x) = x

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: a 2-dimensional `x` is passed through `f` using the ET meta with `f_inv`
# supplied as the inverse; `y1` is Normal around dot(z, c) and the data variable
# `y2` is Normal around `y1`.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x) where {meta=ET(inverse=f_inv)}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [100]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

Inference results:
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.0, w=2.5)
x  = MvNormalWeightedMeanPrecision(
xi: [0.6666666666666665, 0.0]
Λ: [1.6666666666676...


In [101]:
# Element-wise square of the input.
f(x) = x .^ 2

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: a 2-dimensional `x` is passed through `f` using the ET meta (no inverse
# supplied); `y1` is Normal around dot(z, c) and the data variable `y2` is Normal
# around `y1`.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x) where {meta=ET()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [102]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

m_ins = (MvNormalMeanPrecision(
μ: [0.0, 0.0]
Λ: [1.0 0.0; 0.0 1.0]
)
,)
(A, b) = ([0.0 0.0; 0.0 0.0], [0.0, 0.0])
μ_fw_out = [0.0, 0.0]


Inference results:
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.0, w=3.0)
x  = MvNormalWeightedMeanPrecision(
xi: [0.0, 0.0]
Λ: [1.0 0.0; 0.0 1.0]
)



In [114]:
# Element-wise sum of the two inputs.
f(x, θ) = x .+ θ

# Inverse of `z = x .+ θ` with respect to `x`: given the other input `θ` and the
# output `z`, recover `x = z .- θ`.
# The previous body returned random noise unrelated to its arguments, which is
# not an inverse and made results non-reproducible.
function f_inv_x(θ, z)
    return z .- θ
end

# Inverse of `z = x .+ θ` with respect to `θ`: given the other input `x` and the
# output `z`, recover `θ = z .- x`.
# The previous body returned random noise unrelated to its arguments, which is
# not an inverse and made results non-reproducible.
function f_inv_θ(x, z)
    return z .- x
end

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: 2-dimensional `θ` and `x` are combined through `f` using the ET meta,
# with one inverse function per input passed as a tuple; `y1` is Normal around
# dot(z, c) and the data variable `y2` is Normal around `y1`.
# NOTE(review): the recorded run of this model failed with a `RuleMethodError`
# for `DeltaFn{f}` on `(:in, k)` with `DeltaExtended` meta.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ) where {meta=ET(inverse=(f_inv_x, f_inv_θ))}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [115]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

LoadError: RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule DeltaFn{f}(:(:in, k), Marginalisation) (m_out::MvNormalWeightedMeanPrecision, m_ins::Tuple, meta::DeltaExtended{Tuple{typeof(f_inv_x), typeof(f_inv_θ)}}) = begin 
    return ...
end



In [103]:
# Element-wise sum of the three inputs, accumulated left to right.
f(x, θ, ζ) = (x .+ θ) .+ ζ

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: three 2-dimensional inputs `x`, `θ` and `ζ` are combined through `f`
# using the ET meta (no inverse supplied); `y1` is Normal around dot(z, c) and
# the data variable `y2` is Normal around `y1`.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    ζ ~ MvNormal(μ=0.5ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ, ζ) where {meta=ET()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [104]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

Inference results:
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.375, w=2.25)
ζ  = MvNormalWeightedMeanPrecision(
xi: [0.23333333333296002, 0.49999999999800004]
Λ:...
θ  = MvNormalWeightedMeanPrecision(
xi: [0.5333333333327801, 0.9999999999975]
Λ: [0.5...
x  = MvNormalWeightedMeanPrecision(
xi: [-0.06666666666685998, -1.4999113062670868e-1...


### UT: deterministic nodes with the `UT` (unscented transform) meta

In [105]:
# Identity transform: hands the input back untouched.
f(x) = identity(x)

# Inverse of the identity transform; it is the identity as well.
f_inv(x) = identity(x)

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: a 2-dimensional `x` is passed through `f` using the UT meta with `f_inv`
# supplied as the inverse; `y1` is Normal around dot(z, c) and the data variable
# `y2` is Normal around `y1`.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    # γ ~ Gamma(α=1.0, β=1.0)
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x) where {meta=UT(inverse=f_inv)}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [106]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

(μ_tilde, Σ_tilde, _) = unscentedStatistics(μ_fw_in1, Σ_fw_in1, f; alpha = meta.alpha, beta = meta.beta, kappa = meta.kappa) = ([0.0, 0.0], [1.0 0.0; 0.0 1.0], [1.0 0.0; 0.0 1.0])


Inference results:
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.0, w=2.5)
x  = MvNormalWeightedMeanPrecision(
xi: [0.6666666666676693, 0.0]
Λ: [1.6666666666676...


In [107]:
# Squares every element of the input.
f(x) = @. x^2

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: a 2-dimensional `x` is passed through `f` using the UT meta (no inverse
# supplied); `y1` is Normal around dot(z, c) and the data variable `y2` is Normal
# around `y1`.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x) where {meta=UT()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [108]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

(μ_tilde, Σ_tilde, _) = unscentedStatistics(μ_fw_in1, Σ_fw_in1, f; alpha = meta.alpha, beta = meta.beta, kappa = meta.kappa) = ([1.0, 1.0], [2.000001000065822 1.9999990000505932; 1.9999990000505932 2.000001000065822], [0.0 0.0; 0.0 0.0])
(μ_tilde, Σ_tilde, C_tilde) = ([1.0, 1.0], [2.000001000065822 1.9999990000505932; 1.9999990000505932 2.000001000065822], [0.0 0.0; 0.0 0.0])
(μ_in, Σ_in) = ([0.0, 0.0], [1.0 0.0; 0.0 1.0])
(μ_in, Σ_in) = mean_cov(q_ins) = ([0.0, 0.0], [1.0 -0.0; -0.0 1.0])


Inference results:
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.3333332222149457, w=2.3333332222149457...
x  = MvNormalWeightedMeanPrecision(
xi: [0.0, 0.0]
Λ: [1.0 0.0; 0.0 1.0]
)



In [109]:
# Adds the two inputs element by element.
f(x, θ) = @. x + θ

# Placeholder "inverse" for the `x` input: returns standard-normal noise with as
# many entries as the argument. It is not a true inverse of `f`.
# The previous body sized the output with `θ`, which is not a parameter of this
# function, so every call raised an UndefVarError unless a global `θ` existed.
function f_inv_x(x)
    return randn(length(x))
end

# Placeholder "inverse" for the `θ` input: standard-normal noise with one entry
# per element of the argument. It does not depend on the argument's values.
f_inv_θ(x) = randn(length(x))

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: 2-dimensional `θ` and `x` are combined through `f` using the UT meta,
# with the two inverse functions passed as a Vector; `y1` is Normal around
# dot(z, c) and the data variable `y2` is Normal around `y1`.
# NOTE(review): the recorded run of this model failed with
# `DimensionMismatch` (dims 2 vs 4) — confirm how `inverse` is expected to be
# supplied for several inputs.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ) where {meta=UT(inverse=[f_inv_x, f_inv_θ])}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [110]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

"why am i called" = "why am i called"
"called marginal" = "called marginal"
(meta.inverse[k])(mean(q_ins)) = [-1.7506976494524147, 0.14681859776414038, 0.20173249993936687, 0.6442088876701773]
(ms, Vs) = mean_cov(q_ins) = ([-1.705935392948323e-11, -9.999778782799056e-13, 0.9999999999829406, 0.999999999999], [1.6666666666676055 -1.2503668776828297e-32 0.6666666666675873 5.382188518894395e-22; -1.2503668776828297e-32 1.000000000001 -1.1494886002502958e-32 9.999778782818784e-13; 0.6666666666675873 -1.1494886002502958e-32 1.666666666667569 5.382188518904331e-22; 5.382188518894395e-22 9.999778782818784e-13 5.382188518904331e-22 1.000000000001])


LoadError: DimensionMismatch: dimensions must match: a has dims (Base.OneTo(2),), b has dims (Base.OneTo(4),), mismatch at 1

In [111]:
# Broadcasted addition of the two inputs.
f(x, θ) = broadcast(+, x, θ)

# NOTE(review): this global `c` looks unused — the model body assigns its own `c`.
c = randn(2);

# Model: 2-dimensional `θ` and `x` are combined through `f` using the UT meta
# (no inverse supplied); `y1` is Normal around dot(z, c) and the data variable
# `y2` is Normal around `y1`.
@model function NKS()
    y2 = datavar(Float64)
    # c ends up as [1.0, 0.0], so dot(z, c) picks the first component of z.
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    # ζ ~ MvNormal(μ=0.5ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ) where {meta=UT()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [81]:
# NOTE(review): `imarginals` is built but never passed to `inference` below —
# confirm whether it was meant to be supplied as initial marginals or can go.
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

# Run inference on NKS with the single observation y2 = 1.0; free energy is off.
result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=false)

"called marginal" = "called marginal"
(μ_in, Σ_in) = mean_cov(q_ins) = ([-1.705935392948323e-11, -9.999778782799056e-13, 0.9999999999829406, 0.999999999999], [1.6666666666676055 -1.2503668776828297e-32 0.6666666666675873 5.382188518894395e-22; -1.2503668776828297e-32 1.000000000001 -1.1494886002502958e-32 9.999778782818784e-13; 0.6666666666675873 -1.1494886002502958e-32 1.666666666667569 5.382188518904331e-22; 5.382188518894395e-22 9.999778782818784e-13 5.382188518904331e-22 1.000000000001])
(μ_in, Σ_in) = mean_cov(q_ins) = ([-1.705935392948323e-11, -9.999778782799056e-13, 0.9999999999829406, 0.999999999999], [1.6666666666676055 -1.2503668776828297e-32 0.6666666666675873 5.382188518894395e-22; -1.2503668776828297e-32 1.000000000001 -1.1494886002502958e-32 9.999778782818784e-13; 0.6666666666675873 -1.1494886002502958e-32 1.666666666667569 5.382188518904331e-22; 5.382188518894395e-22 9.999778782818784e-13 5.382188518904331e-22 1.000000000001])


Inference results:
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.333333333352748, w=2.3333333333333455)...
θ  = MvNormalWeightedMeanPrecision(
xi: [0.5999999999894395, 0.9999999999979999]
Λ: [...
x  = MvNormalWeightedMeanPrecision(
xi: [-1.023561235768417e-11, -9.999778782789055e-...
