# Conjugate-NonConjugate Varialtional Message Passsing: a tutorial


*Table of contents*
1. [Introduction](#Introduction)
2. [Model specification](#Model-specification)
3. [Limitations](#Limitations)
4. [Inference](#Inference)
5. [Extension](#Extension)

## Introduction

In this tutorial, we show how to excute VMP in models with delta factors of the form $\delta(f(x_{1}, \dots, x_{n}) - y)$ (`f` arbitrary differentiable function) inside `ReactiveMP`.
The `ReactiveMP` implemented the Conjugate-NonConjugate Variational Inference (CVI) inside factor graphs carefully following the paper [Probabilistic programming with stochastic variational message passing](https://reader.elsevier.com/reader/sd/pii/S0888613X22000950?token=EFB22E01793BD0BF73EECC9702C315644969403BD44B13FA850E9F66C8A49E88C0D5C68A9AD03301C609DA443DB33F80&originRegion=eu-west-1&originCreation=20221027115856) (see it for implementation details).

More specifically the tutorial shows:
1. [How specify](#Model-specification) a `Delta factor` of the form $\delta(f(x_{1}, \dots, x_{n}) - y)$ inside the `@model` macro with CVI inference procedure
2. [What limitation](#Limitations) the current implementation has
3. [Show several inference examples](#Inference)
4. [How to extend](#Extension) it from the user perspective

## Model specification

Suppose we have a function `f`:

```
    f(x, y, ..., z) = ...
```

And we have a model where we want define a `Delta factor` with this `f`:

```
@model function model_name(...)
    ... some where here inputs (x, y, ..., z) defiened
    out ~ f(x, y, ..., z) where {meta = CVIApproximation(rng, n_iterations, n_samples, Descent(learning_rate))}
    ...    
end
```

If you want to see a detailed example at this point, go to [Inference](#Inference) section. 

The magic happens inside the `where` block: where we specify through `meta` parameter that `ReactiveMP` should run `CVI` for the messages sent through the `out ~ f(x, y, ..., z)` node.

So to specify it for `out ~ f(x, y, ..., z)` node you need set `meta` to `CVIApproximation(...)` inside `where` block: `meta=CVIApproximation(rng, n_samples, n_iterations, Descent(learning_rate))` in the above example.

The `CVIApproximation` structure serves for two reasons:
1. Marker that the `CVI` rules need to be called
2. Container of the `CVI` hyperparameters.

`CVI` procedure has 4 hyperparameters:
1. random number generator, which will be called inside the `CVI` procedure (`rng`)
2. number of samples to use for the out message approximation (`n_samples`)
3. number of iterations of the CVI procedure (`n_iterations`)
4. optimizer, which will be used to perform the CVI step (`Descent(learning_rate)`)

## Limitations

There are several main limitations for the `CVI` procedure that you need to satisfy:
1. The `CVI` procedure supposes that there is mean-field assumption on the function inputs (`x, y, ..., z`) connected to the node 
2. The connected interface is factorized out in other nodes to which they are connected
3. The messages on input interfaces (`x, y, ..., z`) are from the exponential family distributions

In `ReactiveMP`, you can obtain the first and second assumption through `@constraints` macro:

```
@model function model_name(...)
 ... some where here, inputs (x, y, ..., z) defined
 ... ~ Node1(x, q1, q2, ..., qn) # some node that is using x interface
 out ~ f(x, y, ..., z) where {meta = CVIApproximation(rng, n_iterations, n_samples, Descent(learning_rate))}
 ...
 ... ~ Node2(p1,..., out, pn) # some node that is using out the interface
 ... 
end

constraints = @constraints begin
 q(out, p1, ..., pn) = q(y)q(p1,...,pn)
 q(x, y, ..., z) = q(x)...q(z) # mean-field on the inputs into delta factor
 q(x, q1, ..., qn) = q(x)q(q1,...,qn) # mean-field in the node where x also used
 ...
end;
```

Note that only some exponential family distributions are implemented. If you want to add one not implemented inside `ReactiveMP` read this [example](###Adding-a-custom-implementation-for-an-instance-from-exponential-family-distribution).


## Inference

## Extension

### Adding a custom implementation for an instance from exponential family distribution

You might need a distribution from the exponential distribution family, which has not yet been implemented in `ReactiveMP.`

This section shows how you can implement it yourself.

This example implements `Beta distribution`.

To add a new distribution, we need to implement the following:

1. auxiliary structure for storing `Beta` natural parameters: it is `BetaNaturalParameters` in our example
2. `isproper`: the domain check for `BetaNaturalParameters`
3. auxiliary functions that convert `BetaNaturalParameters` to vector, distribution and vice-versa
4. `lognormalizer` and `logpdf` for `BetaNaturalParameters`
5. subtraction for `BetaNaturalParameters` 

In [4]:
using ReactiveMP, Rocket, GraphPPL, Random, LinearAlgebra, Plots, Flux, ForwardDiff, Plots, SpecialFunctions, Distributions, Base, StableRNGs

`BetaNaturalParameters` structure

In [2]:
struct BetaNaturalParameters{T <: Real} <: NaturalParameters
    α::T
    β::T
end

`isproper` functions

In [3]:
ReactiveMP.isproper(params::BetaNaturalParameters) = ((params.α  - 1)> zero(params.α)) & ((params.β - 1)> zero(params.β))

Needed axillary "convert" functions

In [6]:
ReactiveMP.naturalparams(dist::Beta) = BetaNaturalParameters(dist.α, dist.β)

function Base.convert(::Type{Distribution}, η::BetaNaturalParameters)
    return Beta(η.α, η.β)
end

function Base.vec(p::BetaNaturalParameters)
    return [p.α, p.β]
end

ReactiveMP.as_naturalparams(::Type{T}, args...) where {T <: BetaNaturalParameters} =
    convert(BetaNaturalParameters, args...)

function BetaNaturalParameters(v::AbstractVector{T}) where {T <: Real}
    @assert length(v) === 2 "`BetaNaturalParameters` must accept a vector of length `2`."
    return BetaNaturalParameters(v[1], v[2])
end

Base.convert(::Type{BetaNaturalParameters}, vector::AbstractVector) =
    convert(BetaNaturalParameters{eltype(vector)}, vector)

Base.convert(::Type{BetaNaturalParameters{T}}, vector::AbstractVector) where {T} =
    BetaNaturalParameters(convert(AbstractVector{T}, vector))

`lognormalizer` and `logpdf`

In [7]:
ReactiveMP.lognormalizer(params::BetaNaturalParameters) = loggamma(params.α) + loggamma(params.β) - loggamma(params.α + params.β) 
ReactiveMP.logpdf(params::BetaNaturalParameters, x) = x * (params.α - 1) + (1-x) * (params.β - 1) - lognormalizer(params) 

subtraction

In [8]:
function Base.:-(left::BetaNaturalParameters, right::BetaNaturalParameters)
    return BetaNaturalParameters(
        left.α - right.α,
        left.β - right.β
    )
end

Now, we can specify a model with `Beta distribution` inside `Delta` factor.

Generate some syntatic data

In [13]:
rng = StableRNG(123)
cvi_rng = StableRNG(42)

golden_beta = Beta(3, 4)

num_samples = 1000

means = rand(rng, golden_beta, num_samples)
observations = map(mean -> rand(rng, NormalMeanVariance(mean, 0.01)), means);

Creating a model specification with `@model` macro.

Note, we don't need `q(beta, mean) = q(beta)q(mean)` constraint.

In [14]:
@model function normal_with_beta_mean(num_observations, cvi_rng, cvi_num_samples, cvi_num_iterations, learning_rate)
    observations = datavar(Float64, num_observations) 
    beta ~ Beta(1, 1)
    mean ~ identity(beta) where {meta = CVI(cvi_rng, cvi_num_samples, cvi_num_iterations, Descent(learning_rate))}
    observations .~ NormalMeanVariance(mean, 0.01)
end

┌ Error: Failed to revise /Users/mykola/repos/ReactiveMP.jl/src/nodes/delta/approximations/cvi.jl
│   exception = Revise.ReviseEvalException("/Users/mykola/repos/ReactiveMP.jl/demo/none:0", ErrorException("invalid redefinition of constant CVIApproximation"), Any[(top-level scope at none:0, 1)])
└ @ Revise /Users/mykola/.julia/packages/Revise/do2nH/src/packagedef.jl:715
│ 
│   /Users/mykola/repos/ReactiveMP.jl/src/nodes/delta/approximations/cvi.jl
│ 
│ If the error was due to evaluation order, it can sometimes be resolved by calling `Revise.retry()`.
│ Use Revise.errors() to report errors again. Only the first error in each file is shown.
│ Your prompt color may be yellow until the errors are resolved.
└ @ Revise /Users/mykola/.julia/packages/Revise/do2nH/src/packagedef.jl:825


To run inference we will use `inference` function from `ReactiveMP`.
Note, that for running inference for this model we need to init message for `beta`.

In [11]:
res = inference(
    model = Model(normal_with_beta_mean, num_samples, cvi_rng, 100, 1000, 0.1),
    data = (observations = observations,),
    iterations = 100,
    free_energy = false,
    returnvars = (beta = KeepLast(),),
    initmessages = (beta = Beta(1, 1),),
)

Inference results:
-----------------------------------------
beta = Beta{Float64}(α=464.35142801780785, β=475.24111495855345)


In [12]:
mean(res.posteriors[:beta]), mean(golden_beta)

(0.49420510144415886, 0.42857142857142855)

As we can see by mean comparison: the inference is working correctly.

### Using a custom optimizer

Currently, `CVI` supports only `Flux` optimizers to perform the step optimization step.

Here it is shown how to extend it for a some structure.

Suppose we have `CustomDescent` structure which we want to use for `CVI` optimization step.

To make it possible we need to implement `ReactiveMP.cvi_update!(opt::CustomDescent, λ, ∇)`

In [143]:
struct CustomDescent 
    learning_rate::Float64
end

function ReactiveMP.cvi_update!(opt::CustomDescent, λ, ∇)
    return vec(λ) - (opt.learning_rate .* vec(∇))
end

Let's try to apply it to a model:
$$
\begin{aligned}
 p(x) & = \mathcal{N}(0, 1),\\
 p(y_{i}) & = \mathcal{N}(x^2, 1),\\
\end{aligned}
$$

This model is not particularly interesting, but it show how to use `CVI` approximation with a custom structure.

Let's generate some syntatic data for the model

In [144]:
rng = StableRNG(123)
cvi_rng = StableRNG(123)

golden = NormalMeanVariance(20, 10)

num_samples = 5000

means = rand(rng, golden, num_samples)
observations = map(mean -> rand(rng, NormalMeanVariance(mean^2, 1)), means);

To create a model we use `GraphPPL` package and `@model` macro:

In [145]:
f(x) = x ^ 2
@model function normal_square_mean(num_observations, cvi_rng, cvi_num_samples, cvi_num_iterations, optimizer)
    observations = datavar(Float64, num_observations)
    latent ~ NormalMeanPrecision(0, 10)
    mean ~ f(latent) where {meta = CVIApproximation(cvi_rng, cvi_num_samples, cvi_num_iterations, optimizer)}
    observations .~ NormalMeanVariance(mean, 1)
end

Note we are using optimizer, as a model parameter: `normal_square_mean(..., optimizer)`.

We will use the `inference` function from `ReactiveMP` to run inference, where we provide an instance of the `CustomDescent` structure:
```
... = inference(
model = Model(..., CustomDescent(0.1)),
...)
```

**Side notes**: for running inference for this model, we need to init `latent` message.
 
We do not need to init marginal for `latent,` but with this initialization, the inference procedure is more stable.

In [146]:
res = inference(
    model = Model(normal_square_mean, num_samples, cvi_rng, 100, 5000, CustomDescent(0.1)),
    data = (observations = observations,),
    iterations = 20,
    free_energy = false,
    initmessages = (latent = NormalMeanVariance(0, 10),),
    initmarginals = (latent = NormalMeanVariance(0, 10),)
)

mean(res.posteriors[:latent][end]), var(res.posteriors[:latent][end])

(-20.259631575342336, 1.2181547169391146e-7)

The inference working!