In [45]:
using GraphPPL, ReactiveMP, Distributions
using Plots

### ET

In [46]:
function f(x)
    return sqrt.(x)
end

function f_inv(x)
    return x .^ 2
end


@model function NKS(meta)
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=ones(2), Λ=diageye(2))
    z ~ f(x) where {meta=meta}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [50]:
# imarginals = (x = MvNormal(zeros(2), diageye(2)), )


result = inference(model = Model(NKS, UT(inverse=f_inv)), data=(y2=4.0,), free_energy=true)

"hi2" = "hi2"
"hi2" = "hi2"


Inference results:
-----------------------------------------
Free Energy: Real[3.59144]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=8.682926672530977, w=2.780487695426886)
x  = MvNormalWeightedMeanPrecision(
xi: [1.1666666666712022, 1.00000000000025]
Λ: [1....


In [9]:
mean(result.posteriors[:x])

2-element Vector{Float64}:
 1.1546391752579255
 0.4999999999999999

In [10]:
function f(x)
    return sqrt.(x)
end


@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=ones(2), Λ=diageye(2))
    z ~ f(x) where {meta=ET()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [11]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[1.19875]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.8, w=2.8)
x  = MvNormalWeightedMeanPrecision(
xi: [1.1666666666664167, 0.99999999999975]
Λ: [1....


In [12]:
mean(result.posteriors[:x])

2-element Vector{Float64}:
 0.9999999999995713
 0.9999999999995

In [13]:
function f(x, θ)
    return x .+ θ
end

function f_x(θ, z)
    return z .- θ
end

function f_θ(x, z)
    return z .- x
end

c = randn(2);

@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ) where {meta=ET(inverse=(f_x, f_θ))}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [14]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[1.70859]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.3333333333333335, w=2.3333333333333335...
θ  = MvNormalWeightedMeanPrecision(
xi: [0.60000000000024, 1.0]
Λ: [1.40000000000036 ...
x  = MvNormalWeightedMeanPrecision(
xi: [5.999645225079745e-13, 9.99999999999e-13]
Λ:...


In [15]:
function f(x, θ, ζ)
    return x .+ θ .+ ζ
end


@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    ζ ~ MvNormal(μ=0.5ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ, ζ) where {meta=ET()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [16]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[1.69876]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.375, w=2.25)
ζ  = MvNormalWeightedMeanPrecision(
xi: [0.4999999999995714, 0.49999999999900013]
Λ: ...
θ  = MvNormalWeightedMeanPrecision(
xi: [1.1428571428568057, 0.9999999999995002]
Λ: [...
x  = MvNormalWeightedMeanPrecision(
xi: [-0.14285714285766313, -1.499911306270087e-12...


In [28]:
for res in ([result, result]..., [result, result]...)
    @show res
end

res = Inference results:
-----------------------------------------
Free Energy: Real[3.59144]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=8.682926672530977, w=2.780487695426886)
x  = MvNormalWeightedMeanPrecision(
xi: [1.1666666666712022, 1.00000000000025]
Λ: [1....

res = Inference results:
-----------------------------------------
Free Energy: Real[3.59144]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=8.682926672530977, w=2.780487695426886)
x  = MvNormalWeightedMeanPrecision(
xi: [1.1666666666712022, 1.00000000000025]
Λ: [1....

res = Inference results:
-----------------------------------------
Free Energy: Real[3.59144]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=8.682926672530977, w=2.780487695426886)
x  = MvNormalWeightedMeanPrecision(
xi: [1.1666666666712022, 1.00000000000025]
Λ: [1....

res = Inference results:
-----------------------------------------


In [43]:
function f₁(x)
    return sqrt.(x)
end

function f₁_inv(x)
    return x .^ 2
end


@model function delta_1input(meta)
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=ones(2), Λ=diageye(2))
    z ~ f₁(x) where {meta=meta}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

function f₂(x, θ)
    return x .+ θ
end

function f₂_x(θ, z)
    return z .- θ
end

function f₂_θ(x, z)
    return z .- x
end

@model function delta_2inputs(meta)
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f₂(x, θ) where {meta=meta}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end


function f₃(x, θ, ζ)
    return x .+ θ .+ ζ
end

@model function delta_3inputs(meta)
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    ζ ~ MvNormal(μ=0.5ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f₃(x, θ, ζ) where {meta=meta}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

function f₄(x, θ)
    return θ.*x
end

@model function delta_2input_1d2d(meta)
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ Normal(μ=0.5, γ=1.0)
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f₄(x, θ) where {meta=meta}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end


## -------------------------------------------- ##
## Inference definition
## -------------------------------------------- ##
function inference_1input(data)
    res = []
    for meta in (ET(inverse=f₁_inv), UT(inverse=f₁_inv), ET(), UT())
        push!(res, inference(model = Model(delta_1input, meta), data=(y2=data,), free_energy=true))
    end
    res
end

function inference_2inputs(data)
    res = []
    for meta in (ET(inverse=(f₂_x, f₂_θ)), UT(inverse=(f₂_x, f₂_θ)), ET(), UT())
        push!(res, inference(model = Model(delta_2inputs, meta), data=(y2=data,), free_energy=true))
    end
    res
end

function inference_3inputs(data)
    res = []
    for meta in (ET(), UT())
        push!(res, inference(model = Model(delta_3inputs, meta), data=(y2=data,), free_energy=true))
    end
    res
end

function inference_2input_1d2d(data)
    res = []
    for meta in (ET(), UT())
        push!(res, inference(model = Model(delta_2input_1d2d, meta), data=(y2=data,), free_energy=true, free_energy_diagnostics=(BetheFreeEnergyCheckNaNs(), BetheFreeEnergyCheckInfs())))
    end
    res
end

inference_2input_1d2d (generic function with 1 method)

In [44]:
data = 4.0
## -------------------------------------------- ##
## Inference execution
result₁ = inference_1input(data)
result₂ = inference_2inputs(data)
result₃ = inference_3inputs(data)
result₄ = inference_2input_1d2d(data)
## -------------------------------------------- ##
## Test inference results


2-element Vector{Any}:
 Inference results:
-----------------------------------------
Free Energy: Real[5.89517]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=8.0, w=2.8)
θ  = NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.00000000000025)
x  = MvNormalWeightedMeanPrecision(
xi: [1.3333333333333333, 0.0]
Λ: [1.1666666666669...

 Inference results:
-----------------------------------------
Free Energy: Real[5.89517]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=8.0, w=2.8)
θ  = NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.00000000000025)
x  = MvNormalWeightedMeanPrecision(
xi: [1.3333333333333333, 0.0]
Λ: [1.1666666666669...


In [38]:
result₁[1]

Inference results:
-----------------------------------------
Free Energy: Real[3.35606]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=8.8, w=2.8)
x  = MvNormalWeightedMeanPrecision(
xi: [1.1666666666669165, 1.0]
Λ: [1.0104166666667...


In [39]:
for res in (result₁..., result₂..., result₃..., result₄)
    res
end

In [42]:
for res in (result₁..., result₂..., result₃..., result₄...)
    @show !isnan(res.free_energy[1]) && !isinf(res.free_energy[1])
end

!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true
!(isnan(res.free_energy[1])) && !(isinf(res.free_energy[1])) = true


In [17]:
function f(x, θ)
    return θ.*x
end

c = randn(2);

@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ Normal(μ=0.5, γ=1.0)
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ) where {meta=ET()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [18]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[1.60946]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.0, w=2.8)
θ  = NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.00000000000025)
x  = MvNormalWeightedMeanPrecision(
xi: [0.3333333333333333, 0.0]
Λ: [1.1666666666669...


### UT

In [19]:
function f(x)
    return x
end

function f_inv(x)
    return x
end

c = randn(2);

@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x) where {meta=UT(inverse=f_inv)}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [20]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[1.57708]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.0, w=2.5)
x  = MvNormalWeightedMeanPrecision(
xi: [0.6666666666676693, 0.0]
Λ: [1.6666666666676...


In [21]:
function f(x)
    return x.^2
end

c = randn(2);

@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x) where {meta=UT()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [22]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[1.40739]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.3333332222149457, w=2.3333332222149457...
x  = MvNormalWeightedMeanPrecision(
xi: [0.0, 0.0]
Λ: [1.0 0.0; 0.0 1.0]
)



In [23]:
function f(x, θ)
    return x .+ θ
end

function f_x(θ, z)
    return z .- θ
end

function f_θ(x, z)
    return z .- x
end

c = randn(2);

@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ MvNormal(μ=ones(2), Λ=diageye(2))
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ) where {meta=UT(inverse=(f_x, f_θ))}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [24]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[1.70859]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.333333333352748, w=2.3333333333333455)...
θ  = MvNormalWeightedMeanPrecision(
xi: [0.5999999999880001, 1.0]
Λ: [1.4000000000003...
x  = MvNormalWeightedMeanPrecision(
xi: [-1.0511773545936107e-11, 1.0000000167191506e...


In [25]:
function f(x, θ)
    return θ.*x
end

c = randn(2);

@model function NKS()
    y2 = datavar(Float64)
    c = zeros(2); c[1] = 1.0;

    θ ~ Normal(μ=0.5, γ=2.0)
    x ~ MvNormal(μ=zeros(2), Λ=diageye(2))
    z ~ f(x, θ) where {meta=UT()}
    y1 ~ Normal(μ=dot(z, c), σ²=1.0)
    y2 ~ Normal(μ=y1, σ²=0.5)
end

In [26]:
imarginals = (x = MvNormal(zeros(2), diageye(2)), )

result = inference(model = Model(NKS), data=(y2=1.0,), free_energy=true)

Inference results:
-----------------------------------------
Free Energy: Real[2.23446]
-----------------------------------------
y1 = NormalWeightedMeanPrecision{Float64}(xi=2.0, w=2.8)
θ  = NormalWeightedMeanPrecision{Float64}(xi=0.0, w=1.00000000000025)
x  = MvNormalWeightedMeanPrecision(
xi: [0.3333333333333333, 0.0]
Λ: [1.1666666666669...
