In [7]:
using GraphPPL, ReactiveMP, Distributions, Rocket
using BenchmarkTools

In [8]:
# Custom nonlinear function `f` used by the model below. ReactiveMP.jl has no
# dedicated factor node for it, so it is turned into a `DeltaFn` node.
# (Such a function could also be defined inside the `@model` body itself.)
function f(x, r, l)
    exponent = x + r + l
    return exp(exponent)
end

f (generic function with 1 method)

In [9]:
# Let's pretend we have some approximation method with some internal fields.
# An instance is attached to the `f` node as its `meta` object, and the custom
# `DeltaFn` rules dispatch on this type. For now it is an empty placeholder
# (fields describing the approximation could be added here).
struct CustomApproximationMethod end

In [13]:
# Toy model: three Gaussian latent variables `x`, `r`, `l` are combined by the
# custom function `f` into `z`, and `y` is a noisy observation of `z`.
# Calling `hello()` yields the model object together with the `(x, y, z)` tuple
# returned at the end of this body.
@model function hello()
    # Independent Gaussian priors, parameterised as (mean, precision)
    x ~ NormalMeanPrecision(1.0, 2.0)
    r ~ NormalMeanPrecision(1.0, 2.0)
    l ~ NormalMeanPrecision(1.0, 2.0)
    
    # ReactiveMP.jl turns every function without a dedicated factor node into a
    # `DeltaFn` node; the `meta` object selects the custom rules defined for
    # `CustomApproximationMethod`
    z ~ f(x, r, l) where { meta = CustomApproximationMethod() }
    
    # Observed data placeholder; its value is supplied later via `update!`
    y = datavar(Float64)
    # Observation of `z` with unit precision
    y ~ NormalMeanPrecision(z, 1.0)
    
    return x, y, z
end

hello (generic function with 1 method)

In [11]:
# Some placeholder rules that don't do anything meaningful (yet).
# We define rules for `DeltaFn{f}`, where `f` is a callable reference to our function and can be called as `f(1, 2, 3)`.
# `m_ins` is a tuple of input messages.
# `meta` holds a reference to our meta object.
# `N` (the number of inputs) can be used for dispatch to handle special cases, e.g. `m_ins::NTuple{1, NormalMeanPrecision}` matches a `DeltaFn` with only 1 input.

# Message towards the `out` edge: evaluate `f` at the means of all `N` input
# messages and return a unit-precision Gaussian centred on that value.
# Placeholder approximation; the input variances are ignored.
@rule DeltaFn{f}(:out, Marginalisation) (m_ins::NTuple{N, Any}, meta::CustomApproximationMethod) where { f, N } = begin
    return NormalMeanPrecision(f(mean.(m_ins)...), 1.0)
end

# Message towards an input edge. `k` is presumably the input index bound by
# `@rule` for the indexed `(:in, k)` edge; `q_ins` is presumably the joint
# marginal over the inputs and `m_in` the incoming message on that edge.
# Placeholder: all arguments are ignored and a fixed N(0, 1) (mean 0.0,
# precision 1.0) is returned for every input.
@rule DeltaFn{f}((:in, k), Marginalisation) (q_ins::Any, m_in::Any, meta::CustomApproximationMethod) where { f } = begin 
    return NormalMeanPrecision(0.0, 1.0)
end

# Joint marginal over all `N` inputs of the `DeltaFn` node.
# Placeholder: `q_out` and `m_ins` are ignored. The result is an
# `N`-dimensional zero-mean `MvNormalMeanPrecision` whose precision is
# `diageye(N)` (presumably the N×N identity — confirm against ReactiveMP).
@marginalrule DeltaFn{f}(:ins) (q_out::Any, m_ins::NTuple{N, Any}, meta::CustomApproximationMethod) where { f, N } = begin 
    return MvNormalMeanPrecision(zeros(N), diageye(N))
end

In [14]:
# Smoke test: build the model, print the marginals of `z` and `x` whenever they
# are updated, and feed a single observation into `y`. It only checks that no
# errors are thrown; the printed values come from the placeholder rules.
model, (x, y, z) = hello();

subscription1 = subscribe!(getmarginal(z), println)
subscription2 = subscribe!(getmarginal(x), println)

try
    # Supplying data to `y` triggers message passing and hence the prints
    update!(y, 1.0)
finally
    # Always release the subscriptions, even if inference throws
    unsubscribe!(subscription1)
    unsubscribe!(subscription2)
end

Marginal(NormalWeightedMeanPrecision{Float64}(xi=21.085536923187668, w=2.0))
Marginal(NormalWeightedMeanPrecision{Float64}(xi=2.0, w=3.0))
