In [5]:
using Distributions
using Rocket
using ReactiveMP
using BenchmarkTools
using StaticArrays
using Random

using Profile
using PProf
using ProfileSVG

import Base: show

In [6]:
# Wire a single GaussianMeanPrecisionNode to three data variables:
# `meanv` (mean), `precv` (precision) and `y` (observed value).
# Later cells push prior messages into `meanv`/`precv` and one scalar observation
# per update into `y`.
meanv = datavar(:mean, NormalMeanPrecision{Float64})
precv = datavar(:precision, GammaAB{Float64})

# NOTE(review): SA[SA[1], SA[2], SA[3]] presumably places each of the three node
# interfaces in its own factorisation cluster (mean-field) — confirm against ReactiveMP.
gmpnode = GaussianMeanPrecisionNode(factorisation = SA[ SA[1], SA[2], SA[3] ])

# NOTE(review): the trailing `1` is presumably the variable's connection index — confirm.
connect!(gmpnode, :mean, meanv, 1)
connect!(gmpnode, :precision, precv, 1)

# Observation variable; the inference loop calls `update!(y, data[i])` on it.
y = datavar(:y, Float64)

connect!(gmpnode, :value, y, 1)

# Called only after all three interfaces are connected.
activate!(gmpnode)

# Initial beliefs for the mean and precision, used as the priors for the first observation.
setbelief!(meanv, NormalMeanPrecision(0.0, 0.001))
setbelief!(precv, GammaAB(0.01, 0.01))

In [7]:
# Synthetic observations: `nsamples` draws from a Normal with mean `realmean` and
# precision `realprecision` (standard deviation = sqrt(1 / precision)), then shifted
# by `dataoffset`. Because of the shift, the sample mean is near
# `realmean + dataoffset` (-90 here), not `realmean`.
# The fixed seed makes the free-energy tables computed below reproducible.
realprecision = 0.001
realmean = -100.0
dataoffset = 10.0
nsamples = 1000

rng = MersenneTwister(42)
data = rand(rng, Normal(realmean, sqrt(1.0 / realprecision)), (nsamples, )) .+ dataoffset;

In [8]:
# Sequential (online) inference over the first `n` observations of `data`.
#
# For each observation i, the current beliefs of the mean and precision are captured
# as that step's priors. Then `iters` update rounds are run, and the free energy
# recorded after each round is stored in `fe[i, j]`:
#
#     F = U[node] + U[precision prior] + U[mean prior] - (H[q(mean)] + H[q(precision)])

# Average energy of the Gaussian node given the beliefs on all three of its interfaces.
averageE = combineLatest(
    (getbelief(meanv), getbelief(precv), getbelief(y)), 
    true, 
    (Float64, (beliefs) -> averageEnergy(NormalMeanPrecision, beliefs))
)

# Sum of the differential entropies of the mean and precision beliefs.
differentialE = combineLatest(
    (getbelief(meanv), getbelief(precv)),
    true,
    (Float64, (beliefs) -> reduce(+, map(differentialEntropy, beliefs)))
)

n = 100     # number of observations processed (only the first `n` of `data` are used)
iters = 5   # update rounds per observation

n <= length(data) ||
    throw(ArgumentError("n = $n exceeds the number of observations ($(length(data)))"))

# NaN-filled rather than `undef`: any entry the free-energy stream never writes
# shows up as NaN instead of uninitialised memory.
fe = fill(NaN, n, iters)

for i in 1:n
    
    meanprior = nothing
    precprior = nothing
    
    # Capture the current beliefs as messages; `take(1)` keeps only the current value.
    subscribe!(getbelief(meanv) |> take(1) |> discontinue() |> map(AbstractMessage, ReactiveMP.as_message), (d) -> meanprior = d)
    subscribe!(getbelief(precv) |> take(1) |> discontinue() |> map(AbstractMessage, ReactiveMP.as_message), (d) -> precprior = d)
    
    # The priors are used immediately below, so both must have been delivered already.
    if meanprior === nothing || precprior === nothing
        error("no belief was emitted for the mean/precision at observation $i")
    end
    
    # Average energy of the Gamma prior factor: its fixed (a, b) parameters plus q(precision).
    averageGamma = combineLatest(
        (of(as_belief(getdata(precprior).a)), of(as_belief(getdata(precprior).b)), getbelief(precv)),
        true,
        (Float64, (d) -> averageEnergy(GammaAB, d))
    )
    
    # Average energy of the Normal prior factor: its fixed (mean, precision) plus q(mean).
    averageNormal = combineLatest(
        (of(as_belief(getdata(meanprior).mean)), of(as_belief(getdata(meanprior).precision)), getbelief(meanv)),
        true,
        (Float64, (d) -> averageEnergy(NormalMeanPrecision, d))
    ) 
    
    freeEnergy = combineLatest(
        (averageE, averageGamma, averageNormal, differentialE),
        true,
        (Float64, (d) -> d[1] + d[2] + d[3] - d[4])
    )
    
    for j in 1:iters
        # Each emission overwrites fe[i, j], so it ends up holding the last value of the round.
        subscription = subscribe!(freeEnergy, (d) -> fe[i, j] = d)
        try
            update!(meanv, meanprior)
            update!(precv, precprior)
            update!(y, data[i])
        finally
            # Always detach, even if an update throws, so a stale subscription
            # cannot keep writing into `fe` during later rounds.
            unsubscribe!(subscription)
        end
    end
end

# Log the mean of the final beliefs (only the current value, via `take(1)`).
subscribe!(getbelief(meanv) |> map(Any, mean) |> take(1) |> skip_complete(), logger("meanv"))
subscribe!(getbelief(precv) |> map(Any, mean) |> take(1) |> skip_complete(), logger("precv"))

[meanv] Data: -88.05273150121674
[precv] Data: 0.0010335619177398205


SwitchMapSubscription()

In [9]:
# Total free energy per update round: column j summed over all observations (rows).
sum(fe; dims = 1)

1×5 Array{Float64,2}:
 452.288  447.884  447.819  447.815  447.814

In [10]:
# Full free-energy table: row i = observation, column j = update round.
fe

100×5 Array{Float64,2}:
 9.42832  9.384    9.37861  9.37493  9.37399
 5.80164  5.76841  5.76625  5.76616  5.76616
 5.25593  5.25072  5.25063  5.25063  5.25063
 6.004    5.78352  5.76451  5.76426  5.76425
 4.82185  4.68966  4.6892   4.6892   4.6892
 4.53264  4.53224  4.53224  4.53224  4.53224
 5.49632  5.2859   5.28134  5.28133  5.28133
 4.5551   4.48794  4.48778  4.48778  4.48778
 4.35966  4.35736  4.35736  4.35736  4.35736
 6.35152  5.92413  5.91188  5.91186  5.91186
 5.19169  4.90798  4.90759  4.90759  4.90759
 4.60457  4.51108  4.51105  4.51105  4.51105
 4.53626  4.53496  4.53496  4.53496  4.53496
 ⋮                                   
 3.90087  3.90031  3.90031  3.90031  3.90031
 4.19274  4.18639  4.18639  4.18639  4.18639
 3.90305  3.8997   3.8997   3.8997   3.8997
 4.14849  4.14239  4.14239  4.14239  4.14239
 3.92971  3.92224  3.92224  3.92224  3.92224
 4.07275  4.06674  4.06674  4.06674  4.06674
 4.09206  4.08113  4.08113  4.08113  4.08113
 3.98452  3.97609  3.97609  3.97609  3.9

In [9]:
# Sanity check: average energy of the GammaAB factor with both parameter beliefs
# fixed at 1.0 and a GammaAB(1.0, 1.0) belief on the variable (recorded output: 1.0).
let unitbeliefs = (Belief(1.0), Belief(1.0), Belief(GammaAB(1.0, 1.0)))
    ReactiveMP.averageEnergy(GammaAB, unitbeliefs)
end

1.0