In [1]:
using LinearAlgebra, Flux.Optimise, Plots, ForneyLab

In [2]:
graph = FactorGraph()

# Generative model: x has a Gaussian prior (mean 2, variance 1), w has a Gamma
# prior, and the observation y is Gaussian with mean x and precision w.
@RV x ~ GaussianMeanVariance(2.0, 1.0)
@RV w ~ Gamma(2.0, 3.5)
@RV y ~ GaussianMeanPrecision(x, w)

# Mark y as observed; its value is supplied at run time through data[:y].
placeholder(y, :y);

In [3]:
# Split the posterior into a factor around x (id :X) and one around w (id :W);
# the ids name the generated update functions stepX! and stepW!.
pfz = PosteriorFactorization(x, w; ids = [:X, :W])
algo = messagePassingAlgorithm(; free_energy = true)
source_code = algorithmSourceCode(algo; free_energy = true)

# Define the generated functions (stepX!, stepW!, freeEnergy) in the current
# module, then show the generated code.
source_code |> Meta.parse |> eval;
println(source_code)

begin

function stepW!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 2))

messages[1] = ruleVBGammaOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=2.0), ProbabilityDistribution(Univariate, PointMass, m=3.5))
messages[2] = ruleVBGaussianMeanPrecisionW(ProbabilityDistribution(Univariate, PointMass, m=data[:y]), marginals[:x], nothing)

marginals[:w] = messages[1].dist * messages[2].dist

return marginals

end

function stepX!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 2))

messages[1] = ruleVBGaussianMeanVarianceOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=2.0), ProbabilityDistribution(Univariate, PointMass, m=1.0))
messages[2] = ruleVBGaussianMeanPrecisionM(ProbabilityDistribution(Univariate, PointMass, m=data[:y]), nothing, marginals[:w])

marginals[:x] = messages[1].dist * messages[2].dist

return marginals

end

function freeEnergy(data::Dict, marginals::Dict)

F = 0.0

F += 

In [4]:
# Execute algorithm
# Initial beliefs. Of these entries, stepX! reads only :w and then overwrites
# :x, so the vague :x entry merely provides a starting value.
marginals = Dict()
marginals[:x] = vague(GaussianMeanVariance)
marginals[:w] = vague(Gamma)

data = Dict(:y => 11.4)  # observed value of y
n_its = 5
F = zeros(n_its)         # free energy recorded after each iteration

# Alternate the two factor updates, then evaluate the free energy.
for iter in eachindex(F)
    stepX!(data, marginals)
    stepW!(data, marginals)
    F[iter] = freeEnergy(data, marginals)
end

In [5]:
# Free energy recorded after each of the n_its iterations.
F

5-element Array{Float64,1}:
 58.75614770485671
 13.198016410548277
  7.857830010678137
  7.692763736169844
  7.690828089185899

In [6]:
# Final approximate posterior marginals, keyed by variable symbol.
marginals

Dict{Any,Any} with 2 entries:
  :w => Gam(a=2.50, b=43.42)…
  :x => 𝒩(xi=2.66, w=1.06)…

In [7]:
# Mean of the approximate posterior marginal of x.
marginals[:x] |> mean

2.517509781970646

In [8]:
# Variance of the approximate posterior marginal of x.
marginals[:x] |> var

0.9449457678754632

In [9]:
# CVI
# Same model as above, except that the mean of y is x_ = f(x), produced by a
# Cvi node. With the identity transform below, the results should agree with
# the exact message-passing cells above up to Monte Carlo noise.
f(x) = x

graph = FactorGraph()

@RV x ~ GaussianMeanVariance(2.0, 1.0)
# NOTE(review): num_samples / num_iterations presumably set the Monte Carlo
# sample size and the number of optimizer (Descent) steps per update.
@RV x_ ~ Cvi(x, g=f, opt=Descent(0.1), num_samples=1000, num_iterations=100)
@RV w ~ Gamma(2.0, 3.5)
@RV y ~ GaussianMeanPrecision(x_, w)

# Mark y as observed; its value is supplied at run time through data[:y].
placeholder(y, :y);

In [10]:
# Split the posterior into a factor around x (id :X) and one around w (id :W);
# the ids name the generated update functions stepX! and stepW!.
pfz = PosteriorFactorization(x, w; ids = [:X, :W])
algo = messagePassingAlgorithm(; free_energy = true)
source_code = algorithmSourceCode(algo; free_energy = true)

# Define the generated functions (stepX!, stepW!, freeEnergy) in the current
# module, then show the generated code.
source_code |> Meta.parse |> eval;
println(source_code)

begin

function stepW!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 2))

messages[1] = ruleVBGammaOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=2.0), ProbabilityDistribution(Univariate, PointMass, m=3.5))
messages[2] = ruleVBGaussianMeanPrecisionW(ProbabilityDistribution(Univariate, PointMass, m=data[:y]), marginals[:x_], nothing)

marginals[:w] = messages[1].dist * messages[2].dist

return marginals

end

function stepX!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 4))

messages[1] = ruleVBGaussianMeanVarianceOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=2.0), ProbabilityDistribution(Univariate, PointMass, m=1.0))
messages[2] = ruleSPCVIOutVD(:cvi_1, nothing, messages[1])
messages[3] = ruleVBGaussianMeanPrecisionM(ProbabilityDistribution(Univariate, PointMass, m=data[:y]), nothing, marginals[:w])
messages[4] = ruleSPCVIIn1MV(:cvi_1, messages[3], messages[1])

marginals[:x] = m

In [11]:
# Execute algorithm
# The Cvi node works with samples (num_samples in the model cell; x_ ends up as
# a SampleList), so results vary between runs. Fix the global RNG seed so this
# cell is reproducible.
import Random
Random.seed!(1234)

n_its = 5
marginals = Dict()
F = zeros(n_its)         # free energy recorded after each iteration
data = Dict(:y => 11.4)  # observed value of y

# stepX! runs first in every iteration: it overwrites marginals[:x] and sets
# marginals[:x_], which stepW! reads. Only the initial :w entry is consumed.
marginals[:x] = vague(GaussianMeanVariance)
marginals[:w] = vague(Gamma)

for i in eachindex(F)
    stepX!(data, marginals)
    stepW!(data, marginals)

    F[i] = freeEnergy(data, marginals)
end

In [12]:
# Free energy recorded after each iteration; CVI estimates are sample-based,
# so small non-monotone fluctuations can appear.
F

5-element Array{Float64,1}:
 58.75616991383632
 13.159539423818055
  7.830488814981333
  7.686974526271021
  7.690629922850485

In [13]:
# Final marginals, including the sample-based marginal of the transformed x_.
marginals

Dict{Any,Any} with 3 entries:
  :w  => Gam(a=2.50, b=43.41)…
  :x_ => SampleList(s=[1.85, 3.62, 1.25, 3.12, 1.53, 5.59, 1.61, 2.17, 0.56, 2.…
  :x  => 𝒩(xi=2.67, w=1.06)…

In [14]:
# CVI
# Nonlinear variant: the mean of y is x_ = f(x) = x^2, produced by a Cvi node.
# NOTE: this replaces the method f(x) defined in the earlier CVI cell (a Julia
# function has a single method table), so any model still built with g=f from
# that cell now also evaluates x^2.
f(x) = x^2

graph = FactorGraph()

@RV x ~ GaussianMeanVariance(2.0, 1.0)
# NOTE(review): num_samples / num_iterations presumably set the Monte Carlo
# sample size and the number of optimizer (Descent) steps per update.
@RV x_ ~ Cvi(x, g=f, opt=Descent(0.1), num_samples=1000, num_iterations=100)
@RV w ~ Gamma(2.0, 3.5)
@RV y ~ GaussianMeanPrecision(x_, w)

# Mark y as observed; its value is supplied at run time through data[:y].
placeholder(y, :y);

In [15]:
# Split the posterior into a factor around x (id :X) and one around w (id :W);
# the ids name the generated update functions stepX! and stepW!.
pfz = PosteriorFactorization(x, w; ids = [:X, :W])
algo = messagePassingAlgorithm(; free_energy = true)
source_code = algorithmSourceCode(algo; free_energy = true)

# Define the generated functions (stepX!, stepW!, freeEnergy) in the current
# module, then show the generated code.
source_code |> Meta.parse |> eval;
println(source_code)

begin

function stepW!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 2))

messages[1] = ruleVBGammaOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=2.0), ProbabilityDistribution(Univariate, PointMass, m=3.5))
messages[2] = ruleVBGaussianMeanPrecisionW(ProbabilityDistribution(Univariate, PointMass, m=data[:y]), marginals[:x_], nothing)

marginals[:w] = messages[1].dist * messages[2].dist

return marginals

end

function stepX!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 4))

messages[1] = ruleVBGaussianMeanVarianceOut(nothing, ProbabilityDistribution(Univariate, PointMass, m=2.0), ProbabilityDistribution(Univariate, PointMass, m=1.0))
messages[2] = ruleSPCVIOutVD(:cvi_1, nothing, messages[1])
messages[3] = ruleVBGaussianMeanPrecisionM(ProbabilityDistribution(Univariate, PointMass, m=data[:y]), nothing, marginals[:w])
messages[4] = ruleSPCVIIn1MV(:cvi_1, messages[3], messages[1])

marginals[:x] = m

In [16]:
# Execute algorithm
# The Cvi node works with samples (num_samples in the model cell; x_ ends up as
# a SampleList), so results vary between runs. Fix the global RNG seed so this
# cell is reproducible.
import Random
Random.seed!(1234)

n_its = 5
marginals = Dict()
F = zeros(n_its)         # free energy recorded after each iteration
data = Dict(:y => 11.4)  # observed value of y

# stepX! runs first in every iteration: it overwrites marginals[:x] and sets
# marginals[:x_], which stepW! reads. Only the initial :w entry is consumed.
marginals[:x] = vague(GaussianMeanVariance)
marginals[:w] = vague(Gamma)

for i in eachindex(F)
    stepX!(data, marginals)
    stepW!(data, marginals)

    F[i] = freeEnergy(data, marginals)
end

In [17]:
# Free energy recorded after each iteration; CVI estimates are sample-based,
# so small non-monotone fluctuations can appear.
F

5-element Array{Float64,1}:
 17.433302878439797
  3.8589459295245243
  3.9028404679782627
  3.951863299950009
  3.876962222435946

In [18]:
# Final marginals, including the sample-based marginal of the transformed x_.
marginals

Dict{Any,Any} with 3 entries:
  :w  => Gam(a=2.50, b=4.43)…
  :x_ => SampleList(s=[9.69, 12.73, 11.51, 12.70, 11.09, 10.08, 10.96, 10.92, 7…
  :x  => 𝒩(xi=85.45, w=25.66)…