In [1]:
using GraphPPL, ReactiveMP, Distributions, Rocket
using BenchmarkTools

┌ Info: Precompiling ReactiveMP [a194aa59-28ba-4574-a09c-4a745416d6e3]
└ @ Base loading.jl:1423


In [2]:
# A custom three-argument function `f`, used below as a deterministic (delta) node.
# Functions like this can also be defined directly inside the `@model` macro.
function f(x, r, l)
    total = x + r + l
    return exp(total)
end

f (generic function with 1 method)

In [3]:
# Single-argument increment helper; an alternative function for a delta node.
function g(x)
    return x + 1
end

g (generic function with 1 method)

In [4]:
# Let's pretend we have some approximation method with its own internal settings.
# This empty type serves as the `meta` object that the custom `DeltaFn` rules dispatch on;
# real approximation parameters would be added as fields here.
struct CustomApproximationMethod end

In [14]:
# Toy model: three Normal latents combined through the custom nonlinear function `f`
# (a delta node), with the result observed through a Normal likelihood on `y`.
@model function hello()
    
    # Three independent latent inputs, each with a Normal prior
    # parameterised by mean 1.0 and precision 2.0.
    x ~ NormalMeanPrecision(1.0, 2.0)
    r ~ NormalMeanPrecision(1.0, 2.0)
    l ~ NormalMeanPrecision(1.0, 2.0)
    
    # ReactiveMP.jl overloads all unknown functional nodes to use `DeltaFnNode`,
    # so `z` is a deterministic transform of (x, r, l) through `f`.
    # The `meta` object is what the custom `DeltaFn` rules dispatch on
    # (they are declared with `meta::CustomApproximationMethod`).
    z ~ f(x, r, l) where { meta = CustomApproximationMethod() }
    # Disabled alternative: a single-input delta node through `g`.
    # z ~ g(x) where { meta = CustomApproximationMethod() }
    
    # Observed data variable; presumably its value is supplied through the `data`
    # argument of `inference` (NOTE(review): confirm against GraphPPL docs).
    y = datavar(Float64)
    # Likelihood: `y` is Normal around `z` with precision 1.0.
    y ~ NormalMeanPrecision(z, 1.0)
    
    return x, y, z
end

In [15]:
# Placeholder rules that don't do anything meaningful (yet).
# We define rules for `DeltaFn{f}`, where `f` is a callable reference to our function and can be called as, e.g., `f(1, 2, 3)`.
# `m_out` is the input message on the `out` edge.
# `m_ins` is a tuple of input messages.
# `meta` holds a reference to our meta object.
# `N` can be used for dispatch to handle special cases, e.g. `m_ins::NTuple{1, NormalMeanPrecision}` means that `DeltaFn` has only 1 input.

# Message toward the `out` edge: evaluate `f` at the means of the incoming input
# messages and wrap the result in a Normal with fixed precision 1.0.
# Placeholder only: `m_out` is unused, and only the means of `m_ins` are used
# (their uncertainty is not propagated).
@rule DeltaFn{f}(:out, Marginalisation) (m_out::Any, m_ins::NTuple{N, Any}, meta::CustomApproximationMethod) where { f, N } = begin
    return NormalMeanPrecision(f(mean.(m_ins)...), 1.0)
end

# Message toward the `k`-th input edge `(:in, k)`.
# Placeholder only: ignores `q_ins`, `m_in` and `k`, and always returns a Normal
# with mean 0.0 and precision 1.0.
@rule DeltaFn{f}((:in, k), Marginalisation) (q_ins::Any, m_in::Any, meta::CustomApproximationMethod) where { f } = begin 
    return NormalMeanPrecision(0.0, 1.0)
end

# Joint marginal over all `N` inputs of the delta node.
# Placeholder only: ignores the messages and returns a multivariate Normal with zero
# mean and precision `diageye(N)` (presumably the N×N identity, a ReactiveMP helper).
@marginalrule DeltaFn{f}(:ins) (m_out::Any, q_out::Any, m_ins::NTuple{N, Any}, meta::CustomApproximationMethod) where { f, N } = begin 
    return MvNormalMeanPrecision(zeros(N), diageye(N))
end

# Single-input specialization of the joint input marginal. `NTuple{1, Any}` is more
# specific than `NTuple{N, Any}`, so dispatch should select this method for a delta
# node with one input. It returns a univariate Normal (mean 0.0, precision 1.0)
# instead of a 1-dimensional multivariate Normal.
@marginalrule DeltaFn{f}(:ins) (m_out::Any, q_out::Any, m_ins::NTuple{1, Any}, meta::CustomApproximationMethod) where { f } = begin 
    return NormalMeanPrecision(0.0, 1.0)
end

In [16]:
# Run inference on `hello` with `y` observed as 1.0, returning the posterior
# marginals of `z` and `x` (`KeepLast()` presumably keeps only the final estimate).
result = let
    model    = Model(hello)
    observed = (y = 1.0,)
    tracked  = (z = KeepLast(), x = KeepLast())
    inference(; model = model, data = observed, returnvars = tracked)
end

Inference results:
-----------------------------------------
z = NormalWeightedMeanPrecision{Float64}(xi=21.085536923187668, w=2.0)
x = NormalWeightedMeanPrecision{Float64}(xi=2.0, w=3.0)
