In [1]:
# Reactive programming layer for Julia
using Rocket 
# Core package for Constrained Bethe Free Energy minimisation with factor graphs and message passing
using ReactiveMP
# High-level, user-friendly language for specifying probabilistic models and constraints for ReactiveMP
using GraphPPL
# Optionally load Distributions.jl and the Random standard library
using Distributions, Random



This notebook covers the fundamentals of ReactiveMP.jl; for advanced usage, we refer the reader to the documentation.

## General syntax for model creation

We use the `@model` macro from the `GraphPPL.jl` package to create a probabilistic model $p(s, y)$ and to specify extra constraints on the variational family of distributions $\mathcal{Q}$.
Below is a simple example of the general syntax for model creation. This tutorial does not cover all possible ways to create models or all extra features of `GraphPPL.jl`; we refer the reader to the documentation for more rigorous explanations and examples.

In [2]:
# `@model` macro accepts a regular Julia function
@model function test_model1(s_mean, s_precision)
    
    # We use `randomvar` function to create 
    # random variables in our model
    s = randomvar()
    
    # The `~` (tilde) expression creates a functional dependency
    # between variables in our model and can be read as 
    # `sampled from`
    s ~ GaussianMeanPrecision(s_mean, s_precision)
    
    # We use the `datavar` function to create 
    # observed data variables in our models.
    # We also need to specify the type of our data;
    # in this example it is `Float64`
    y = datavar(Float64)
    
    y ~ GaussianMeanPrecision(s, 1.0)
    
    return s, y
end

test_model1 (generic function with 1 method)

The `@model` macro creates a function with the same name and the same set of input arguments as the original function (`test_model1(s_mean, s_precision)` in this example). However, the return value is modified so that it contains a reference to the model object as the first value and the user-specified return values in the form of a tuple as the second value.

In [3]:
model, (s, y) = test_model1(0.0, 1.0);

Later on, we can examine the model structure with the help of utility functions such as: 
- `getnodes()`: returns an array of factor nodes in the corresponding factor graph
- `getrandom()`: returns an array of random variables in the model
- `getdata()`: returns an array of data inputs in the model
- `getconstant()`: returns an array of constant values in the model

In [4]:
getnodes(model)

2-element Vector{ReactiveMP.AbstractFactorNode}:
 FactorNode:
 form            : NormalMeanPrecision
 sdtype          : Stochastic()
 interfaces      : (Interface(out, Marginalisation()), Interface(μ, Marginalisation()), Interface(τ, Marginalisation()))
 factorisation   : ((1, 2, 3),)
 local marginals : (:out_μ_τ,)
 metadata        : nothing
 pipeline        : FactorNodePipeline(functional_dependencies = DefaultFunctionalDependencies(), extra_stages = EmptyPipelineStage()

 FactorNode:
 form            : NormalMeanPrecision
 sdtype          : Stochastic()
 interfaces      : (Interface(out, Marginalisation()), Interface(μ, Marginalisation()), Interface(τ, Marginalisation()))
 factorisation   : ((1, 2, 3),)
 local marginals : (:out_μ_τ,)
 metadata        : nothing
 pipeline        : FactorNodePipeline(functional_dependencies = DefaultFunctionalDependencies(), extra_stages = EmptyPipelineStage()


In [5]:
getrandom(model) .|> name

1-element Vector{Symbol}:
 :s

In [6]:
getdata(model) .|> name

1-element Vector{Symbol}:
 :y

In [7]:
getconstant(model) .|> getconst

3-element Vector{Float64}:
 0.0
 1.0
 1.0

It is also possible to use control flow statements such as `if` or `for` blocks in a model specification. In principle, any valid Julia code can be used inside the `@model` block.

In [8]:
@model function test_model2(n)
    
    if n <= 1
        error("`n` argument must be greater than one.")
    end
    
    # `randomvar(n)` creates a dense sequence of 
    # random variables
    s = randomvar(n)
    
    # `datavar(Float64, n)` creates a dense sequence of 
    # observed data variables of type `Float64`
    y = datavar(Float64, n)
    
    s[1] ~ GaussianMeanPrecision(0.0, 0.1)
    y[1] ~ GaussianMeanPrecision(s[1], 1.0)
    
    for i in 2:n
        s[i] ~ GaussianMeanPrecision(s[i - 1], 1.0)
        y[i] ~ GaussianMeanPrecision(s[i], 1.0)
    end
    
    return s, y
end

test_model2 (generic function with 1 method)

In [9]:
model, (s, y) = test_model2(10);

In [10]:
# Number of factor nodes in the generated factor graph
getnodes(model) |> length

20

In [11]:
# Number of random variables
getrandom(model) |> length

10

In [12]:
# Number of data inputs
getdata(model) |> length

10

In [13]:
# Number of constant values
getconstant(model) |> length

21
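
These counts follow directly from the structure of `test_model2`. The plain-Julia sketch below reproduces them with a hypothetical helper `test_model2_counts` (not part of `GraphPPL.jl`), assuming, as the `test_model1` output above suggests, that every literal passed to a `~` expression becomes its own constant:

```julia
# Hypothetical helper: counts implied by the structure of `test_model2(n)`,
# assuming every literal in a `~` expression becomes a separate constant
function test_model2_counts(n::Int)
    nodes     = 2n          # one node per `s[i] ~ ...` and one per `y[i] ~ ...`
    randoms   = n           # s[1], ..., s[n]
    datavars  = n           # y[1], ..., y[n]
    # `s[1] ~ GaussianMeanPrecision(0.0, 0.1)` contributes two literals;
    # each of the remaining 2n - 1 `~` expressions contributes one literal (`1.0`)
    constants = 2 + (2n - 1)
    return (nodes = nodes, randoms = randoms, data = datavars, constants = constants)
end

println(test_model2_counts(10))   # (nodes = 20, randoms = 10, data = 10, constants = 21)
```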

It is also possible to use complex expressions inside functional dependency (`~`) expressions:

```julia
y ~ NormalMeanPrecision(2.0 * (s + 1.0), 1.0)
```

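The `~` operator automatically creates a random variable if no variable with the same name was created before, and throws an error if the name is already bound to something other than a random variable (for example, a value assigned with `=`):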

```julia
# s = randomvar() here is optional
# `~` creates random variables automatically
s ~ NormalMeanPrecision(0.0, 1.0)
```

An error example:

In [14]:
@model function error_model1()
    s = 1.0
    s ~ NormalMeanPrecision(0.0, 1.0)
end

LoadError: LoadError: Invalid name 's' for new random variable. 's' was already initialized with '=' operator before.
in expression starting at /Users/bvdmitri/.julia/dev/GraphPPL/src/GraphPPL.jl:161

By default, `GraphPPL.jl` creates a new reference for every constant (literals like `0.0` or `1.0`) in a model. In some situations this may be inefficient, especially if these constants are matrices, since `GraphPPL.jl` creates a new copy of a constant matrix every time the model uses it. To avoid this, use the `constvar()` function to create a constant once and reuse it in the model specification:

```julia
# Creates constant reference in a model with a prespecified value
c = constvar(0.0)
```

An example:

In [15]:
@model function test_model5(dim::Int, n::Int, A::Matrix, P::Matrix, Q::Matrix)
    
    s = randomvar(n)
    
    y = datavar(Vector{Float64}, n)
    
    # Here we create constant references
    # for constant matrices in our model 
    # to make inference a little bit more efficient
    cA = constvar(A)
    cP = constvar(P)
    cQ = constvar(Q)
    
    s[1] ~ MvGaussianMeanCovariance(zeros(dim), cP)
    y[1] ~ MvGaussianMeanCovariance(s[1], cQ)
    
    for i in 2:n
        s[i] ~ MvGaussianMeanCovariance(cA * s[i - 1], cP)
        y[i] ~ MvGaussianMeanCovariance(s[i], cQ)
    end
    
    return s, y
end

test_model5 (generic function with 1 method)

A `~` expression may also return a reference to the newly created node in the corresponding factor graph, for convenience or later use:

```julia
@model function test_model()

    # In this example `ynode` refers to the corresponding 
    # `GaussianMeanVariance` node created in the factor graph
    ynode, y ~ GaussianMeanVariance(0.0, 1.0)
    
    return ynode, y
end
```

## Inference in ReactiveMP.jl

ReactiveMP.jl uses the `Rocket.jl` library API for inference routines. `Rocket.jl` is a reactive programming extensions library for Julia that is highly inspired by `RxJS` and similar libraries from the `Rx` ecosystem. It consists of **observables**, **actors**, **subscriptions** and **operators**. For more information and rigorous examples, see the [Rocket.jl GitHub page](https://github.com/biaslab/Rocket.jl).

### Observables
Observables are lazy push-based collections and deliver their values over time.

In [16]:
# Timer that emits a new value every second and has an initial one second delay 
observable = timer(1000, 1000)

TimerObservable(1000, 1000)

A subscription lets us listen to future values of an observable, and an actor specifies what to do with each new value:

In [17]:
actor = (value) -> println(value)
subscription1 = subscribe!(observable, actor)

TimerSubscription()

0
1
2
3
4


In [18]:
# We always need to unsubscribe from running observables
unsubscribe!(subscription1)

In [19]:
# We can transform observables with operators
modified = observable |> filter(d -> rem(d, 2) === 1) |> map(Int, d -> d ^ 2)

ProxyObservable(Int64, MapProxy(Int64))

In [20]:
subscription2 = subscribe!(modified, (value) -> println(value))

TimerSubscription()

1
9
25
49
81


In [21]:
unsubscribe!(subscription2)
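
The operator chain above behaves like applying the same `filter` and `map` to an ordinary collection, except that the observable applies them lazily to values as they arrive over time. As a plain-Julia sketch (no Rocket.jl involved), assuming the subscription saw the timer values `0, 1, ..., 9` before `unsubscribe!` was called:

```julia
# Eager analogue of `observable |> filter(...) |> map(Int, ...)` on a finite collection
ticks = 0:9                                        # assumed timer values seen by the subscription
odd_ticks = filter(d -> rem(d, 2) == 1, ticks)     # keep odd values only
odd_squares = map(d -> d^2, odd_ticks)             # square each remaining value
println(odd_squares)                               # [1, 9, 25, 49, 81]
```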

The `ReactiveMP.jl` library returns posterior marginals in the form of an observable. It is possible to subscribe to their future updates, but for convenience `ReactiveMP.jl` also caches the last obtained value of every marginal in a model. To get a reference to the posterior marginal of a random variable in a model, `ReactiveMP.jl` exports two functions: 
- `getmarginal(x)`: for a single random variable `x`
- `getmarginals(xs)`: for a dense sequence of random variables `xs`

Let's see how it works in practice. Here we create a simple coin toss model. We assume that observations are governed by the `Bernoulli` distribution with an unknown bias parameter `θ`. To have a fully Bayesian treatment of this problem, we endow `θ` with a `Beta` prior.

In [22]:
@model function coin_toss_model(n)

    # `datavar` creates data 'inputs' in our model
    # We will pass data later on to these inputs
    # In this example we create a sequence of inputs that accepts Float64
    y = datavar(Float64, n)
    
    # We endow θ parameter of our model with some prior
    θ ~ Beta(2.0, 7.0)
    
    # We assume that outcome of each coin flip 
    # is governed by the Bernoulli distribution
    for i in 1:n
        y[i] ~ Bernoulli(θ)
    end
    
    # We return references to our data inputs and θ parameter
    # We will use these references later on during inference step
    return y, θ
end

coin_toss_model (generic function with 1 method)

In [46]:
_, (y, θ) = coin_toss_model(500);

In [47]:
# As soon as we have a new value for the marginal posterior over the `θ` variable,
# we print its mean and standard deviation
θ_subscription = subscribe!(getmarginal(θ), (marginal) -> println("New update: mean(θ) = ", mean(marginal), ", std(θ) = ", std(marginal)));

Next, let's define our dataset:

In [48]:
p = 0.75 # Bias of a coin

dataset = float.(rand(Bernoulli(p), 500));

To pass data to our model, we use the `update!` function:

In [49]:
update!(y, dataset)

New update: mean(θ) = 0.7269155206286837, std(θ) = 0.01972901448985252
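
Because the `Beta` prior is conjugate to the `Bernoulli` likelihood, this posterior also has a closed form: `Beta(2 + number of ones, 7 + number of zeros)`. Below is a plain-Julia sketch of that update with helper names of our own (not ReactiveMP.jl functions). The counts `368` and `132` are not stated in the notebook; they are what the `Beta(370.0, 139.0)` marginal shown further down implies for this 500-sample dataset:

```julia
# Conjugate update for a Beta(a, b) prior and 0/1 (Bernoulli) observations
function beta_bernoulli_posterior(a, b, observations)
    ones_count  = sum(observations)                      # number of 1.0 outcomes
    zeros_count = length(observations) - ones_count      # number of 0.0 outcomes
    return (a + ones_count, b + zeros_count)
end

beta_mean(a, b) = a / (a + b)
beta_std(a, b)  = sqrt(a * b / ((a + b)^2 * (a + b + 1)))

# 368 ones and 132 zeros: the counts implied by the Beta(370.0, 139.0) marginal
coin_data = vcat(ones(368), zeros(132))
a_post, b_post = beta_bernoulli_posterior(2.0, 7.0, coin_data)
println("mean(θ) = ", beta_mean(a_post, b_post), ", std(θ) = ", beta_std(a_post, b_post))
```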


In [50]:
# It is necessary to always unsubscribe from running observables
unsubscribe!(θ_subscription)

In [51]:
# ReactiveMP.jl inference backend is lazy and does not compute posterior marginals if no one is listening for them
# At this moment we have already unsubscribed from new posterior updates, so this `update!` does nothing
update!(y, dataset)

Rocket.jl provides some useful built-in actors for collecting posterior marginals, especially with static datasets.

In [70]:
# The `keep` actor stores all incoming updates, in order, in an internal storage
θvalues = keep(Marginal)

KeepActor{Marginal}(Marginal[])

In [71]:
# `getmarginal` always emits the last cached value as its first value
subscribe!(getmarginal(θ) |> take(1), θvalues);

In [72]:
getvalues(θvalues)

1-element Vector{Marginal}:
 Marginal(Beta{Float64}(α=370.0, β=139.0))

In [73]:
subscribe!(getmarginal(θ) |> take(1), θvalues);

In [74]:
getvalues(θvalues)

2-element Vector{Marginal}:
 Marginal(Beta{Float64}(α=370.0, β=139.0))
 Marginal(Beta{Float64}(α=370.0, β=139.0))

In [75]:
# The `buffer` actor keeps only the very last incoming update in an internal storage and can also store 
# an array of the latest updates for a sequence of random variables
θbuffer = buffer(Marginal, 1)

BufferActor{Marginal, Vector{Marginal}}(Marginal[#undef])

In [77]:
subscribe!(getmarginals([ θ ]) |> take(1), θbuffer);

In [78]:
getvalues(θbuffer)

1-element Vector{Marginal}:
 Marginal(Beta{Float64}(α=370.0, β=139.0))

In [79]:
subscribe!(getmarginals([ θ ]) |> take(1), θbuffer);

In [80]:
getvalues(θbuffer)

1-element Vector{Marginal}:
 Marginal(Beta{Float64}(α=370.0, β=139.0))

That was an example of exact Bayesian inference with the sum-product (or belief propagation) algorithm. However, ReactiveMP.jl is not limited to the sum-product algorithm; it also supports variational message passing with [Constrained Bethe Free Energy Minimisation](https://www.mdpi.com/1099-4300/23/7/807).

## Variational inference

At a very high level, ReactiveMP.jl aims to solve the Constrained Bethe Free Energy minimisation problem. For this task, we often need to specify an extra factorisation on the variational family of distributions $q \in \mathcal{Q}$. For this purpose, the `@model` macro supports optional `where { ... }` clauses for every `~` expression in a model specification.

In [194]:
@model function test_model6(n)
    τ ~ GammaShapeRate(1.0, 1.0) 
    μ ~ NormalMeanVariance(0.0, 100.0)
    
    y = datavar(Float64, n)
    
    for i in 1:n
        # Here we impose a mean-field assumption on our 
        # variational family of distributions, locally for the current node
        y[i] ~ NormalMeanPrecision(μ, τ) where { q = q(y[i])q(μ)q(τ) }
    end
    
    return μ, τ, y
end

test_model6 (generic function with 1 method)

There are several options to specify the mean-field factorisation constraint. 

```julia
y[i] ~ NormalMeanPrecision(μ, τ) where { q = q(y[i])q(μ)q(τ) } # With names from model specification
y[i] ~ NormalMeanPrecision(μ, τ) where { q = q(out)q(mean)q(precision) } # With names from node specification
y[i] ~ NormalMeanPrecision(μ, τ) where { q = MeanField() } # With alias name
```

It is also possible to use local structured factorisation:

```julia
y[i] ~ NormalMeanPrecision(μ, τ) where { q = q(y[i], μ)q(τ) } # With names from model specification
y[i] ~ NormalMeanPrecision(μ, τ) where { q = q(out, mean)q(precision) } # With names from node specification
```

As an option, the `@model` macro accepts optional arguments for the model specification. One of them is `default_factorisation`, which accepts `MeanField()` as its argument for convenience:

```julia
@model [ default_factorisation = MeanField() ] function test_model(...)
    ...
end
```

To run inference on this model we again need to create a synthetic dataset:

In [208]:
dataset = rand(Normal(-3.0, inv(sqrt(5.0))), 1000);

In [209]:
model, (μ, τ, y) = test_model6(length(dataset));

For variational inference, we also usually need to set initial marginals for our inference procedure. For that purpose, `ReactiveMP.jl` exports the `setmarginal!` function:

In [210]:
setmarginal!(μ, vague(NormalMeanPrecision))
setmarginal!(τ, vague(GammaShapeRate))

In [198]:
μ_values = keep(Marginal)
τ_values = keep(Marginal)

μ_subscription = subscribe!(getmarginal(μ), μ_values)
τ_subscription = subscribe!(getmarginal(τ), τ_values)

for i in 1:10
    update!(y, dataset)
end

In [199]:
getvalues(μ_values)

10-element Vector{Marginal}:
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-3.0067821818264343e-9, w=0.010000001002000566))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-27.53188940920315, w=9.184909095347356))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-9542.198405846082, w=3179.91536872275))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-14558.342535507441, w=4851.528446734716))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-14565.984264326395, w=4854.075027036634))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-14565.991894780827, w=4854.077569859672))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-14565.991902396012, w=4854.077572397454))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-14565.991902403606, w=4854.077572400089))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-14565.99190240359, w=4854.077572400089))
 Marginal(NormalWeightedMeanPrecision{Float64}(xi=-14565.99190240359, w=4854.077572400089))

In [200]:
getvalues(τ_values)

10-element Vector{Marginal}:
 Marginal(GammaShapeRate{Float64}(a=501.0, b=5.000000000046053e14))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=54605.44565548494))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=157.55185828100187))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=103.26663816710823))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=103.2124615573716))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=103.21240748910154))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=103.21240743514142))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=103.21240743508743))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=103.21240743508754))
 Marginal(GammaShapeRate{Float64}(a=501.0, b=103.21240743508754))

In [201]:
println("μ: mean = ", mean(last(μ_values)), ", std = ", std(last(μ_values)))

μ: mean = -3.0007744386337576, std = 0.014353130838941412


In [202]:
println("τ: mean = ", mean(last(τ_values)), ", std = ", std(last(τ_values)))

τ: mean = 4.854067572400048, std = 0.21686374576309103
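
To see where these numbers come from, here is a plain-Julia sketch of the standard mean-field (coordinate-ascent) updates for `test_model6`, written without ReactiveMP.jl and assuming the same priors (`μ` with prior mean `0` and variance `100`, `τ` with Gamma shape `1` and rate `1`). It is not the library's implementation and it draws its own synthetic data, so the exact values differ, but the structure matches: the shape of `q(τ)` is `1 + n / 2 = 501` for `n = 1000`, and the precision of `q(μ)` is `1 / 100 + n * E[τ]`, which agrees with `w ≈ 4854.08` for `E[τ] ≈ 4.854` above.

```julia
using Random, Statistics

# Mean-field coordinate-ascent updates for the model
#   μ ~ N(m0, variance v0),  τ ~ Gamma(shape a0, rate b0),  y[i] ~ N(μ, precision τ),
# with q(μ, τ) = q(μ) q(τ), q(μ) Gaussian and q(τ) Gamma(shape, rate)
function meanfield_normal_gamma(y; iterations = 10, m0 = 0.0, v0 = 100.0, a0 = 1.0, b0 = 1.0)
    n = length(y)
    mean_τ = 1.0                      # initial guess for E[τ]
    m, v = m0, v0                     # mean and variance of q(μ)
    a, b = a0, b0                     # shape and rate of q(τ)
    for _ in 1:iterations
        # q(μ): precisions add up (w) and so do precision-weighted means (xi)
        w  = 1 / v0 + n * mean_τ
        xi = m0 / v0 + mean_τ * sum(y)
        m, v = xi / w, 1 / w
        # q(τ): uses E[(y[i] - μ)^2] = (y[i] - m)^2 + v under q(μ)
        a = a0 + n / 2
        b = b0 + (sum(abs2, y .- m) + n * v) / 2
        mean_τ = a / b
    end
    return (m = m, v = v, a = a, b = b)
end

Random.seed!(42)
obs = -3.0 .+ randn(1000) ./ sqrt(5.0)   # synthetic data with mean -3 and precision 5
q_mf = meanfield_normal_gamma(obs)
println("μ: mean = ", q_mf.m, ", std = ", sqrt(q_mf.v))
println("τ: mean = ", q_mf.a / q_mf.b, ", shape = ", q_mf.a)
```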


### Form constraints

To support form constraints, the `randomvar()` function also accepts a `where { ... }` clause with optional arguments. One of them is `form_constraint`, which specifies an additional form constraint on a random variable in the model. Another is `prod_constraint`, which specifies additional constraints for the computation of the product of two colliding messages. For example, we can perform the EM algorithm by assigning a point-mass form constraint to some variables in the model.

<img style="display: block;
  margin-left: auto;
  margin-right: auto;
  width: 50%;" src="./pics/posterior.png" />

In [223]:
@model function test_model7(n)
    τ ~ GammaShapeRate(1.0, 1.0) 
    
    # With form constraints, an explicit `randomvar()` call is necessary
    μ = randomvar() where { form_constraint = PointMassFormConstraint() }
    μ ~ NormalMeanVariance(0.0, 100.0)
    
    y = datavar(Float64, n)
    
    for i in 1:n
        y[i] ~ NormalMeanPrecision(μ, τ) where { q = q(y[i])q(μ)q(τ) }
    end
    
    return μ, τ, y
end

test_model7 (generic function with 1 method)

In [184]:
model, (μ, τ, y) = test_model7(length(dataset));

In [185]:
setmarginal!(μ, vague(NormalMeanPrecision))
setmarginal!(τ, PointMass(1.0))

μ_values = keep(Marginal)
τ_values = keep(Marginal)

μ_subscription = subscribe!(getmarginal(μ), μ_values)
τ_subscription = subscribe!(getmarginal(τ), τ_values)

for i in 1:10
    update!(y, dataset)
end

In [186]:
getvalues(μ_values) |> last

Marginal(PointMass{Float64}(-2.9913047925590215))

In [187]:
getvalues(τ_values) |> last 

Marginal(GammaShapeRate{Float64}(a=501.0, b=106.01537627984227))

By default, `ReactiveMP.jl` tries to compute an analytical product of two colliding messages and throws an error if no analytical solution is known. However, it is possible to fall back to a generic product that does not require a known analytical solution. In this case, the inference backend simply propagates the product of two messages in the form of a tuple. Such a tuple-product cannot be used directly during inference, so in this case it is mandatory to use a form constraint to approximate it.

```julia
μ = randomvar() where { 
    prod_constraint = ProdGeneric(),
    form_constraint = SampleListFormConstraint() 
}
```

Sometimes it is useful to preserve a specific parametrisation of the resulting product later on in the inference procedure. `ReactiveMP.jl` exports a special `prod_constraint` called `ProdPreserveType` for that purpose:

```julia
μ = randomvar() where { prod_constraint = ProdPreserveType(NormalWeightedMeanPrecision) }
```

### Free Energy

During variational inference, `ReactiveMP.jl` optimises a special functional called the Bethe Free Energy. Its values over VMP iterations can be obtained with the `score` function.

In [215]:
model, (μ, τ, y) = test_model6(length(dataset));

In [216]:
bfe_observable = score(BetheFreeEnergy(), model)

ProxyObservable(Real, MapProxy(Tuple{ReactiveMP.InfCountingReal, ReactiveMP.InfCountingReal}))

In [217]:
bfe_subscription = subscribe!(bfe_observable, (fe) -> println("Current BFE value: ", fe));

In [218]:
# Reset the model with vague marginals
setmarginal!(μ, vague(NormalMeanPrecision))
setmarginal!(τ, vague(GammaShapeRate))

for i in 1:10
    update!(y, dataset)
end

Current BFE value: 14763.268311193242
Current BFE value: 3275.486553095013
Current BFE value: 676.8537773787698
Current BFE value: 637.9744430974015
Current BFE value: 637.974374505051
Current BFE value: 637.9743745049896
Current BFE value: 637.9743745049882
Current BFE value: 637.974374504985
Current BFE value: 637.9743745049841
Current BFE value: 637.9743745049841


In [219]:
# It is always necessary to unsubscribe to release computer resources
unsubscribe!([ μ_subscription, τ_subscription, bfe_subscription ])

### Meta specification

During model specification, some functional dependencies accept an optional `meta` object in the `where { ... }` clause. The purpose of the `meta` object is to adjust, modify, or supply extra information to the inference backend during message computations. For example, a `meta` object may specify the approximation method to use, or the trade-off between accuracy and performance:

```julia
# In this example, the `meta` object for the autoregressive `AR` node specifies the variate type of 
# the autoregressive process and its order. In addition, it specifies that message computation rules 
# favour accuracy over speed with the `ARsafe()` strategy. In contrast, the `ARunsafe()` strategy tries to speed up computations
# at the cost of possible numerical instabilities during inference
s[i] ~ AR(s[i - 1], θ, γ) where { q = q(s[i - 1], s[i])q(θ)q(γ), meta = ARMeta(Multivariate, order, ARsafe()) }
...
s[i] ~ AR(s[i - 1], θ, γ) where { q = q(s[i - 1], s[i])q(θ)q(γ), meta = ARMeta(Univariate, order, ARunsafe()) }
```

Another example with the `GaussianControlledVariance` (or simply `GCV`) node [see Hierarchical Gaussian Filter]:

```julia
# In this example, we specify a structured factorisation and a meta object with the `GaussHermiteCubature` 
# method with `21` sigma points to approximate the non-linearity between hierarchy layers
xt ~ GCV(xt_min, zt, real_k, real_w) where { q = q(xt, xt_min)q(zt)q(κ)q(ω), meta = GCVMetadata(GaussHermiteCubature(21)) }
```

A meta object is useful for passing any extra information to a node that is neither a random variable nor a constant model variable. It may include approximation methods, differentiation methods, optional non-linear functions, extra inference parameters, etc.

## Creating custom nodes and message computation rules

### Custom nodes

To create a custom functional form and make it available during model specification, `ReactiveMP.jl` exports the `@node` macro:

```julia
# The `@node` macro accepts the name of the functional form, its type (either `Stochastic` or `Deterministic`), and an array of interfaces:
@node NormalMeanVariance Stochastic [ out, μ, v ]

# Interfaces may have aliases for their names that might be convenient for factorisation constraints specification
@node NormalMeanVariance Stochastic [ out, (μ, aliases = [ mean ]), (v, aliases = [ var ]) ]

# `NormalMeanVariance` structure declaration must exist, otherwise `@node` macro will throw an error
struct NormalMeanVariance end 

@node NormalMeanVariance Stochastic [ out, μ, v ]

# It is also possible to use function objects as a node functional form
function dot end

# The syntax for functions is a bit different, as it is necessary to use `typeof(...)` for them 
# out = dot(x, a)
@node typeof(dot) Deterministic [ out, x, a ]
```

**Note**: Deterministic nodes do not support factorisation constraints with the `where { q = ... }` clause.

After that, the newly defined node can be used during model specification:

```julia
@model function test_model()
    ...
    y ~ dot(x, a)
    ...
end
```

### Custom messages computation rules

`ReactiveMP.jl` exports the `@rule` macro to create custom message computation rules. For example, let us create a simple `+` node that can be used in a model specification. We refer to *A Factor Graph Approach to Signal Modelling, System Identification and Filtering* [ Sascha Korl, 2005, page 32 ] for a rigorous explanation of the `+` node in factor graphs. According to Korl, assuming that the inputs are Gaussian, the sum-product message computation rule for the `+` node is the following:

$$
\begin{aligned}
\mu_z &= \mu_x + \mu_y \\
V_z &= V_x + V_y
\end{aligned}
$$

To specify this in `ReactiveMP.jl` we use `@node` and `@rule` macros:
 
```julia
@node typeof(+) Deterministic  [ z, x, y ]

@rule typeof(+)(:z, Marginalisation) (m_x::UnivariateNormalDistributionsFamily, m_y::UnivariateNormalDistributionsFamily) = begin
    x_mean, x_var = mean_var(m_x)
    y_mean, y_var = mean_var(m_y)
    return NormalMeanVariance(x_mean + y_mean, x_var + y_var)
end
```

In this example, for the `@rule` macro, we specify the type of our functional form: `typeof(+)`. Next, we specify the edge for which we compute an outbound message. `Marginalisation` indicates that the corresponding message respects the marginalisation constraint on the posterior over the corresponding edge:

$$
q(z) = \int q(z, x, y) \mathrm{d}x\mathrm{d}y
$$
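
The rule above simply adds means and variances, which holds for independent Gaussian inputs. A quick plain-Julia Monte Carlo check of that arithmetic (independent of ReactiveMP.jl; `plus_rule` is a hypothetical helper working on mean/variance pairs):

```julia
using Random, Statistics

# Sum-product rule for z = x + y with independent Gaussian inputs, on (mean, variance) pairs
plus_rule(x_mean, x_var, y_mean, y_var) = (x_mean + y_mean, x_var + y_var)

z_mean, z_var = plus_rule(1.0, 2.0, -3.0, 0.5)   # (-2.0, 2.5)

# Monte Carlo check: sample x ~ N(1, 2) and y ~ N(-3, 0.5) independently and add them
Random.seed!(1)
n_samples = 1_000_000
x_samples = 1.0 .+ sqrt(2.0) .* randn(n_samples)
y_samples = -3.0 .+ sqrt(0.5) .* randn(n_samples)
z_samples = x_samples .+ y_samples
println("rule: ", (z_mean, z_var), ", samples: ", (mean(z_samples), var(z_samples)))
```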

Comparing sum-product rules with variational rules under the mean-field assumption, we notice that they require different local information to compute an outgoing message:

<img style="display: block;
  margin-left: auto;
  margin-right: auto;
  width: 30%;" src="./pics/sp.png" width="20%" />

$$
\mu(z) = \int f(x, y, z)\mu(x)\mu(y)\mathrm{d}x\mathrm{d}y
$$

<img style="display: block;
  margin-left: auto;
  margin-right: auto;
  width: 30%;" src="./pics/vmp.png" width="20%" />

$$
\nu(z) = \exp\left( \int \log f(x, y, z)q(x)q(y)\mathrm{d}x\mathrm{d}y \right)
$$

The `@rule` macro supports both cases with special argument prefixes:
- `m_` prefix corresponds to the incoming message on a specific edge
- `q_` prefix corresponds to the posterior marginal of a specific edge

Example of a Sum-Product rule with `m_` messages used:

```julia
@rule NormalMeanPrecision(:μ, Marginalisation) (m_out::UnivariateNormalDistributionsFamily, m_τ::PointMass) = begin 
    m_out_mean, m_out_cov = mean_cov(m_out)
    return NormalMeanPrecision(m_out_mean, inv(m_out_cov + inv(mean(m_τ))))
end
```
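
This rule follows from integrating out `out`: with `τ` known, `out = μ + ε` where `ε` has variance `1/τ`, so the message towards `μ` is Gaussian with mean `mean(m_out)` and variance `cov(m_out) + 1/τ`, i.e. precision `inv(cov(m_out) + inv(τ))`. A plain-Julia sketch (hypothetical helper names, scalar case only) that checks the closed form against a numerical evaluation of the sum-product integral:

```julia
# Gaussian density with mean m and variance v
normal_pdf(x, m, v) = exp(-abs2(x - m) / (2v)) / sqrt(2π * v)

# Closed-form sum-product message towards μ as (mean, precision), mirroring the rule above
sp_message_mu(m_out, v_out, τ) = (mean = m_out, precision = inv(v_out + inv(τ)))

# The same message by brute force: ∫ N(out | μ, 1/τ) N(out | m_out, v_out) d(out) on a fine grid
function sp_message_numeric(μ, m_out, v_out, τ; h = 1e-3)
    grid = -20.0:h:20.0
    return h * sum(normal_pdf(o, μ, inv(τ)) * normal_pdf(o, m_out, v_out) for o in grid)
end

m_out, v_out, τ_known = 2.0, 0.5, 4.0
msg = sp_message_mu(m_out, v_out, τ_known)       # precision 1 / (0.5 + 0.25)
for μ in (0.0, 1.0, 2.5)
    closed_form = normal_pdf(μ, msg.mean, inv(msg.precision))
    println(μ, ": ", closed_form, " ≈ ", sp_message_numeric(μ, m_out, v_out, τ_known))
end
```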

Example of a Variational rule with Mean-Field assumption with `q_` posteriors used:

```julia
@rule NormalMeanPrecision(:μ, Marginalisation) (q_out::Any, q_τ::Any) = begin 
    return NormalMeanPrecision(mean(q_out), mean(q_τ))
end
```
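
For this rule only the expectations `E[out]` and `E[τ]` matter: as a function of `μ`, the expected log-density `E_q[log N(out | μ, 1/τ)]` equals `-E[τ] * (Var[out] + (E[out] - μ)^2) / 2` plus terms that do not depend on `μ`, which is the log of a Gaussian with mean `E[out]` and precision `E[τ]`. A small plain-Julia sketch (hypothetical helpers) that checks the curvature and the maximum numerically:

```julia
# μ-dependent part of E_q[log N(out | μ, 1/τ)] under q(out) q(τ),
# where q(out) has mean m_out and variance v_out and q(τ) has mean mean_τ
expected_loglik(μ, m_out, v_out, mean_τ) = -0.5 * mean_τ * (v_out + abs2(m_out - μ))

# Mean-field VMP message towards μ as (mean, precision), mirroring the rule above
vmp_message_mu(m_out, mean_τ) = (mean = m_out, precision = mean_τ)

m_out, v_out, mean_τ = 1.5, 0.3, 2.0
msg = vmp_message_mu(m_out, mean_τ)

# The second finite difference of the expected log-likelihood equals minus the precision
ell(t) = expected_loglik(t, m_out, v_out, mean_τ)
h = 1e-3
curvature = (ell(1.0 + h) - 2 * ell(1.0) + ell(1.0 - h)) / h^2
println("precision = ", msg.precision, ", -curvature ≈ ", -curvature)
```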

`ReactiveMP.jl` also supports structured rules, which have access to the joint marginal over a set of edges:

```julia
@rule NormalMeanPrecision(:τ, Marginalisation) (q_out_μ::Any, ) = begin
    m, V = mean_cov(q_out_μ)
    θ = 2 / (V[1,1] - V[1,2] - V[2,1] + V[2,2] + abs2(m[1] - m[2]))
    α = convert(typeof(θ), 1.5)
    return Gamma(α, θ)
end
```
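
The expression for `θ` comes from `E[(out - μ)^2] = Var[out - μ] + (E[out] - E[μ])^2` under the joint marginal `q(out, μ)` with mean `m` and covariance `V`. With it, `exp(E[log f]) ∝ τ^(1/2) exp(-τ E[(out - μ)^2] / 2)`, i.e. a Gamma distribution with shape `3/2` and scale `2 / E[(out - μ)^2]` (the `Gamma(α, θ)` constructor from Distributions.jl uses the shape-scale parametrisation). A plain-Julia sketch with a hypothetical helper that checks the expectation by Monte Carlo:

```julia
using LinearAlgebra, Random, Statistics

# Message towards τ from N(out | μ, 1/τ) given a joint Gaussian q(out, μ) with mean m and covariance V,
# returned as Gamma (shape, scale)
function structured_tau_message(m::AbstractVector, V::AbstractMatrix)
    # E[(out - μ)^2] = Var[out - μ] + (E[out] - E[μ])^2
    e = V[1, 1] - V[1, 2] - V[2, 1] + V[2, 2] + abs2(m[1] - m[2])
    return (shape = 1.5, scale = 2 / e)
end

m_joint = [1.0, 0.2]
V_joint = [1.0 0.3; 0.3 0.5]
msg_τ = structured_tau_message(m_joint, V_joint)     # E[(out - μ)^2] = 1.54

# Monte Carlo estimate of E[(out - μ)^2] from samples of the joint Gaussian
Random.seed!(3)
chol_L = cholesky(Symmetric(V_joint)).L
draws = [m_joint .+ chol_L * randn(2) for _ in 1:200_000]
mc = mean(d -> abs2(d[1] - d[2]), draws)
println("analytic: ", 2 / msg_τ.scale, ", Monte Carlo: ", mc)
```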

**NOTE**: In a `@rule` specification, message and marginal arguments **must** follow the same order as the interfaces in the `@node` specification:

```julia
# Inference backend expects arguments in `@rule` macro to be in the same order
@node NormalMeanPrecision Stochastic [ out, μ, τ ]
```

Every rule has access to meta information through the hidden `meta::Any` variable:

```julia
@rule MyCustomNode(:out, Marginalisation) (m_in1::Any, m_in2::Any) = begin 
    ...
    println(meta)
    ...
end
```

It is also possible to dispatch on a specific type of a meta object:

```julia
@rule MyCustomNode(:out, Marginalisation) (m_in1::Any, m_in2::Any, meta::LaplaceApproximation) = begin 
    ...
end
```

or

```julia
@rule MyCustomNode(:out, Marginalisation) (m_in1::Any, m_in2::Any, meta::GaussHermiteCubature) = begin 
    ...
end
```

### Customizing messages computational pipeline

In certain situations, it might be convenient to customize the default message computation pipeline. `GraphPPL.jl` supports the `pipeline` keyword in the `where { ... }` clause to add extra steps after a message has been computed. Use cases include an extra approximation method to preserve conjugacy in the model, debugging, or simple printing.

<img style="display: block;
  margin-left: auto;
  margin-right: auto;
  width: 30%;" src="./pics/pipeline.png" width="20%" />

```julia
y[i] ~ NormalMeanPrecision(x[i], 1.0) where { pipeline = LoggerPipelineStage() }
y[i] ~ NormalMeanPrecision(x[i], 1.0) where { pipeline = LaplaceApproximation() }
```

Let us return to the coin toss model, but this time we want to print the messages flowing through the graph:

In [149]:
@model function coin_toss_model_log(n)

    y = datavar(Float64, n)

    θ ~ Beta(2.0, 7.0) where { pipeline = LoggerPipelineStage("θ") }

    for i in 1:n
        y[i] ~ Bernoulli(θ)  where { pipeline = LoggerPipelineStage("y[$i]") }
    end
    
    return y, θ
end

coin_toss_model_log (generic function with 1 method)

In [150]:
_, (y, θ) = coin_toss_model_log(5);

In [151]:
θ_subscription = subscribe!(getmarginal(θ), (value) -> println("New posterior marginal for θ: ", value));

[θ][Beta][out]: Message(Beta{Float64}(α=2.0, β=7.0))


In [152]:
coinflips = float.(rand(Bernoulli(0.5), 5));

In [153]:
update!(y, coinflips)

[y[1]][Bernoulli][p]: Message(Beta{Float64}(α=1.0, β=2.0))
[y[2]][Bernoulli][p]: Message(Beta{Float64}(α=1.0, β=2.0))
[y[3]][Bernoulli][p]: Message(Beta{Float64}(α=1.0, β=2.0))
[y[4]][Bernoulli][p]: Message(Beta{Float64}(α=1.0, β=2.0))
[y[5]][Bernoulli][p]: Message(Beta{Float64}(α=2.0, β=1.0))
New posterior marginal for θ: Marginal(Beta{Float64}(α=3.0, β=11.0))
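
The logged messages can be reproduced by hand. For an observed `y`, the `Bernoulli` factor `θ^y (1 - θ)^(1 - y)` is, as a function of `θ`, an unnormalised `Beta(1 + y, 2 - y)` density, so the four `Beta(1.0, 2.0)` messages correspond to zeros and the `Beta(2.0, 1.0)` message to a one. Multiplying Beta densities adds their exponents, `Beta(a1, b1) * Beta(a2, b2) ∝ Beta(a1 + a2 - 1, b1 + b2 - 1)`, so the posterior is the prior message times all five likelihood messages. A plain-Julia sketch with hypothetical helpers working on `(α, β)` tuples:

```julia
# Message from a Bernoulli factor with observed outcome y ∈ {0.0, 1.0} towards θ, as Beta (α, β)
bernoulli_message(y) = (1 + y, 2 - y)

# Product of two Beta densities up to normalisation: exponents add
beta_prod((a1, b1), (a2, b2)) = (a1 + a2 - 1, b1 + b2 - 1)

flips     = [0.0, 0.0, 0.0, 0.0, 1.0]      # outcomes implied by the logged messages
prior     = (2.0, 7.0)                     # message from the Beta(2.0, 7.0) prior
messages  = bernoulli_message.(flips)      # (1.0, 2.0) four times, then (2.0, 1.0)
posterior = foldl(beta_prod, messages; init = prior)
println(posterior)                         # (3.0, 11.0)
```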


In [154]:
unsubscribe!(θ_subscription)

In [155]:
# Inference is lazy and does not send messages if no one is listening for them
update!(y, coinflips)

### Customizing posterior computational pipeline