# Conjugate-NonConjugate Variational Message Passing: a tutorial


*Table of contents*
1. [Introduction](#Introduction)
2. [Model specification](#Model-specification)
3. [Limitations](#Limitations)
4. [Inference](#Inference)
5. [Extension](#Extension)

## Introduction

In this tutorial, we show how to execute variational message passing (VMP) inside `ReactiveMP` in models with delta factors of the form $\delta(f(x_{1}, \dots, x_{n}) - y)$, where `f` is an arbitrary differentiable function.
`ReactiveMP` implements Conjugate-NonConjugate Variational Inference (CVI) on factor graphs, closely following the paper [Probabilistic programming with stochastic variational message passing](https://reader.elsevier.com/reader/sd/pii/S0888613X22000950?token=EFB22E01793BD0BF73EECC9702C315644969403BD44B13FA850E9F66C8A49E88C0D5C68A9AD03301C609DA443DB33F80&originRegion=eu-west-1&originCreation=20221027115856) (see it for implementation details).

More specifically, the tutorial shows:
1. [How to specify](#Model-specification) a `Delta factor` of the form $\delta(f(x_{1}, \dots, x_{n}) - y)$ inside the `@model` macro with the CVI inference procedure
2. [What limitations](#Limitations) the current implementation has
3. [Several inference examples](#Inference)
4. [How to extend](#Extension) it from the user's perspective

## Model specification

Suppose we have a function `f`:

```
    f(x, y, ..., z) = ...
```

Suppose further that we have a model in which we want to define a `Delta factor` with this `f`:

```
@model function model_name(...)
    # ... somewhere here the inputs (x, y, ..., z) are defined
    out ~ f(x, y, ..., z) where {meta = CVIApproximation(rng, n_samples, n_iterations, Descent(learning_rate))}
    ...
end
```

If you want to see a detailed example at this point, go to the [Inference](#Inference) section.

The key part is the `where` block: through the `meta` parameter we tell `ReactiveMP` to run `CVI` for the messages sent through the `out ~ f(x, y, ..., z)` node.

So, to enable `CVI` for the `out ~ f(x, y, ..., z)` node, set `meta` to `CVIApproximation(...)` inside the `where` block, i.e. `meta = CVIApproximation(rng, n_samples, n_iterations, Descent(learning_rate))` in the example above.
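Conceptually, a delta factor $\delta(f(x_{1}, \dots, x_{n}) - y)$ makes the output a deterministic function of the inputs, so samples drawn from the beliefs over the inputs can be pushed through `f` to approximate the belief over the output. The plain-Julia sketch below illustrates this idea for two independent Gaussian inputs and summarizes the pushed-forward samples by their mean and variance. `pushforward_moments` and `f_example` are hypothetical names used only for illustration, and `ReactiveMP`'s own `CVI` rules may build the outgoing message differently.

```julia
using Random, Statistics

# Hypothetical helper (not part of ReactiveMP): approximates the belief over
# out = f(x, y) by sampling x ~ N(mx, vx) and y ~ N(my, vy), pushing the
# samples through f and summarizing the result by its mean and variance.
function pushforward_moments(f, rng::AbstractRNG, n_samples::Int;
                             mx = 0.0, vx = 1.0, my = 0.0, vy = 1.0)
    xs = mx .+ sqrt(vx) .* randn(rng, n_samples)
    ys = my .+ sqrt(vy) .* randn(rng, n_samples)
    outs = f.(xs, ys)      # the delta factor makes out a deterministic function of (x, y)
    return mean(outs), var(outs)
end

# Example differentiable function
f_example(x, y) = 2x + y + 1

# With x, y ~ N(0, 1) independent, out has mean 1 and variance 4 + 1 = 5
out_mean, out_var = pushforward_moments(f_example, MersenneTwister(42), 100_000)
```

For a linear `f` the moments are known exactly, which makes the sample-based estimate easy to check; for a nonlinear `f` the same code still runs, but the Gaussian summary becomes an approximation.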

The `CVIApproximation` structure serves two purposes:
1. It marks that the `CVI` rules need to be called.
2. It stores the `CVI` hyperparameters.

The `CVI` procedure has four hyperparameters:
1. the random number generator used inside the `CVI` procedure (`rng`)
2. the number of samples used to approximate the outgoing message (`n_samples`)
3. the number of iterations of the `CVI` procedure (`n_iterations`)
4. the optimizer used to perform each `CVI` step (`Descent(learning_rate)`)
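To make the roles of these hyperparameters concrete, here is a plain-Julia sketch of the conjugate-computation variational inference update of Khan and Lin, on which the cited paper builds, for a one-dimensional Gaussian posterior. Each of `n_iterations` steps draws `n_samples` samples from the current posterior using `rng`, estimates the gradient $\nabla_{\mu}$ of the expected non-conjugate log-factor with respect to the mean parameters, and takes the step $\lambda \leftarrow (1 - \rho)\lambda + \rho(\eta + \nabla_{\mu})$ in natural parameters. Treating the step size $\rho$ as the analogue of the optimizer's learning rate is an assumption, and `cvi_gaussian`, `dlogg` and `d2logg` are hypothetical names, not `ReactiveMP`'s implementation.

```julia
using Random, Statistics

# Hypothetical sketch (not ReactiveMP's implementation) of the CVI update for a
# 1-D Gaussian posterior q(x) = N(m, v) with natural parameters λ = (m/v, -1/(2v)).
#   η      : natural parameters of the incoming conjugate message
#   dlogg  : first derivative of the non-conjugate log-factor log g(x)
#   d2logg : second derivative of log g(x)
#   ρ      : step size (assumed analogue of the learning rate)
function cvi_gaussian(rng::AbstractRNG, η::NTuple{2,Float64}, dlogg, d2logg;
                      n_samples::Int = 1000, n_iterations::Int = 200, ρ::Float64 = 0.1)
    λ = η                                          # initialize q with the incoming message
    for _ in 1:n_iterations
        v = -1 / (2 * λ[2])                        # current variance
        m = λ[1] * v                               # current mean
        xs = m .+ sqrt(v) .* randn(rng, n_samples) # samples from q
        g1 = mean(dlogg.(xs))                      # Monte Carlo estimate of E_q[(log g)']
        g2 = mean(d2logg.(xs))                     # Monte Carlo estimate of E_q[(log g)'']
        # gradient of E_q[log g] w.r.t. the mean parameters (m, m^2 + v)
        ∇μ = (g1 - m * g2, g2 / 2)
        # natural-gradient step: λ ← (1 - ρ) λ + ρ (η + ∇μ)
        λ = ((1 - ρ) * λ[1] + ρ * (η[1] + ∇μ[1]),
             (1 - ρ) * λ[2] + ρ * (η[2] + ∇μ[2]))
    end
    v = -1 / (2 * λ[2])
    return λ[1] * v, v                             # posterior mean and variance
end

# Usage: prior N(0, 1) and the factor log g(x) = log N(2; x, 0.5) + const.
# This factor is actually conjugate, so the exact posterior N(4/3, 1/3) is known.
η_prior = (0.0, -0.5)
dlogg(x) = (2.0 - x) / 0.5
d2logg(x) = -1 / 0.5
post_m, post_v = cvi_gaussian(MersenneTwister(1), η_prior, dlogg, d2logg)
```

The usage example uses a Gaussian factor only so that the result can be checked against the exact conjugate posterior; for a genuinely non-conjugate factor only `dlogg` and `d2logg` change. The Monte Carlo noise in the result shrinks as `n_samples` grows or the step size decreases.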

## Limitations

The current `CVI` implementation relies on several assumptions that your model must satisfy:
1. The `CVI` procedure assumes a mean-field factorization over the interfaces connected to the node (`out, x, y, ..., z`).
2. Each connected interface is also factorized out in the other nodes it is connected to.
3. The messages on the input interfaces (`x, y, ..., z`) are exponential family distributions.

In `ReactiveMP`, you can enforce the first and second assumptions with the `@constraints` macro:

```
@model function model_name(...)
    # ... somewhere here the inputs (x, y, ..., z) are defined
    ... ~ Node1(x, q1, q2, ..., qn) # some node that uses the x interface
    out ~ f(x, y, ..., z) where {meta = CVIApproximation(rng, n_samples, n_iterations, Descent(learning_rate))}
    ...
    ... ~ Node2(p1, ..., out, pn) # some node that uses the out interface
    ...
end

constraints = @constraints begin
    q(out, p1, ..., pn) = q(out)q(p1, ..., pn)
    q(out, x, y, ..., z) = q(out)q(x)q(y)...q(z)
    q(x, q1, ..., qn) = q(x)q(q1, ..., qn)
end;
```

Note that only some exponential family distributions are implemented. To use one that is not yet implemented in `ReactiveMP`, you need to implement at least the following (see the [example](#Adding-a-custom-implementation-for-an-instance-from-exponential-family-distribution), which also defines several auxiliary methods):
1. `naturalparams(dist)`, which returns the natural parameters of the distribution `dist`.
2. `lognormalizer(natparams)`, the log normalizer of the distribution expressed in terms of its natural parameters.
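For example, the Beta distribution used in the [Extension](#Extension) section is an exponential family with sufficient statistics $(\log x, \log(1 - x))$. Its natural parameters and log normalizer are given below. A useful property of the log normalizer is that its gradient yields the expected sufficient statistics (the mean parameters); here $\psi$ denotes the digamma function.

$$
\mathrm{Beta}(x \mid a, b) = \exp\big(\eta_1 \log x + \eta_2 \log(1 - x) - A(\eta)\big), \qquad \eta = (a - 1,\; b - 1),
$$

$$
A(\eta) = \log\Gamma(\eta_1 + 1) + \log\Gamma(\eta_2 + 1) - \log\Gamma(\eta_1 + \eta_2 + 2),
$$

$$
\nabla_{\eta} A(\eta) = \big(\psi(a) - \psi(a + b),\; \psi(b) - \psi(a + b)\big) = \big(\mathbb{E}[\log x],\; \mathbb{E}[\log(1 - x)]\big).
$$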


## Inference

## Extension

### Adding a custom implementation for an instance from exponential family distribution

You might need a distribution from the exponential family that is not yet implemented in `ReactiveMP`.

This section shows how you can implement it yourself.

We will use the Beta distribution as an example.

```julia
using ReactiveMP, Rocket, GraphPPL, Random, LinearAlgebra, Plots, Flux, ForwardDiff, SpecialFunctions, Distributions, StableRNGs
```

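First, define a structure for the natural parameters of the Beta distribution. With sufficient statistics $(\log x, \log(1 - x))$, the natural parameters of $\mathrm{Beta}(a, b)$ are $(a - 1, b - 1)$, so the fields `α` and `β` below store these shifted values rather than the shape parameters themselves:

```julia
# Natural parameters of Beta(a, b): α = a - 1, β = b - 1
struct BetaNaturalParameters{T <: Real} <: NaturalParameters
    α::T
    β::T
end
```

Second, implement the `isproper` function, which checks that the natural parameters correspond to a proper Beta distribution ($a > 0$ and $b > 0$):

```julia
ReactiveMP.isproper(params::BetaNaturalParameters) = (params.α > -one(params.α)) && (params.β > -one(params.β))
```

Next, implement `naturalparams`:

```julia
ReactiveMP.naturalparams(dist::Beta) = BetaNaturalParameters(dist.α - 1, dist.β - 1)
```

You also need a conversion from `BetaNaturalParameters` back to a `Distribution`:

```julia
function Base.convert(::Type{Distribution}, η::BetaNaturalParameters)
    return Beta(η.α + 1, η.β + 1)
end
```

Then implement `lognormalizer` and `logpdf`:

```julia
# log B(a, b) = logΓ(a) + logΓ(b) - logΓ(a + b), with a = α + 1 and b = β + 1
ReactiveMP.lognormalizer(params::BetaNaturalParameters) = loggamma(params.α + 1) + loggamma(params.β + 1) - loggamma(params.α + params.β + 2)
ReactiveMP.logpdf(params::BetaNaturalParameters, x) = params.α * log(x) + params.β * log(1 - x) - ReactiveMP.lognormalizer(params)
```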
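A quick way to sanity-check a log normalizer is to compare it with a direct numerical integration of the unnormalized density $x^{\eta_1}(1 - x)^{\eta_2}$. The plain-Julia sketch below does this with the midpoint rule; `numeric_lognormalizer` and `logbeta_int` are hypothetical helpers, and the reference value uses the integer-argument identity $B(a, b) = (a - 1)!\,(b - 1)!/(a + b - 1)!$ so that no special functions are needed.

```julia
# Hypothetical helper: A(η) = log ∫₀¹ x^η₁ (1 - x)^η₂ dx approximated with the
# midpoint rule. Accurate for η₁, η₂ ≥ 0 (i.e. a, b ≥ 1); for a < 1 or b < 1 the
# integrand is singular at an endpoint and this simple rule is unreliable.
function numeric_lognormalizer(η1::Real, η2::Real; n::Int = 100_000)
    h = 1.0 / n
    s = 0.0
    for i in 1:n
        x = (i - 0.5) * h                 # midpoint of the i-th subinterval
        s += x^η1 * (1 - x)^η2
    end
    return log(s * h)
end

# Reference for integer shapes: B(a, b) = (a - 1)! (b - 1)! / (a + b - 1)!
logbeta_int(a::Int, b::Int) = log(factorial(a - 1) * factorial(b - 1) / factorial(a + b - 1))

a, b = 3, 4
A_numeric = numeric_lognormalizer(a - 1, b - 1)   # natural parameters (2, 3)
A_exact = logbeta_int(a, b)                       # log(1/60)
```

The same comparison can be run against `ReactiveMP.lognormalizer(BetaNaturalParameters(2.0, 3.0))` once the definitions above are loaded.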

You also need to specify how to subtract natural parameters:

```julia
function Base.:-(left::BetaNaturalParameters, right::BetaNaturalParameters)
    return BetaNaturalParameters(
        left.α - right.α,
        left.β - right.β
    )
end
```
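Subtraction is needed because dividing one exponential family density by another corresponds to subtracting natural parameters, just as multiplying densities adds them; in message passing, this is how an outgoing message can be obtained from a posterior and an incoming message. The standalone sketch below checks this for the Beta family with a hypothetical `BetaEta` type (deliberately separate from `BetaNaturalParameters`) that stores $\eta = (a - 1, b - 1)$.

```julia
# Hypothetical standalone type: Beta natural parameters η = (a - 1, b - 1).
struct BetaEta
    η1::Float64
    η2::Float64
end

to_eta(a, b) = BetaEta(a - 1, b - 1)          # Beta(a, b) -> η
from_eta(p::BetaEta) = (p.η1 + 1, p.η2 + 1)    # η -> (a, b)

# Multiplying densities adds natural parameters; dividing subtracts them.
Base.:+(p::BetaEta, q::BetaEta) = BetaEta(p.η1 + q.η1, p.η2 + q.η2)
Base.:-(p::BetaEta, q::BetaEta) = BetaEta(p.η1 - q.η1, p.η2 - q.η2)

# Beta(2, 2) × Beta(3, 4) ∝ x^3 (1 - x)^4, i.e. Beta(4, 5)
product = to_eta(2, 2) + to_eta(3, 4)
# Dividing the product by Beta(2, 2) recovers Beta(3, 4)
quotient = product - to_eta(2, 2)
```

With this convention the uniform Beta(1, 1) maps to $\eta = (0, 0)$, the neutral element of the arithmetic; storing the shape parameters directly instead would make every sum and difference off by one.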

Finally, several auxiliary functions need to be implemented:

```julia
# Flatten the natural parameters into a vector
function Base.vec(p::BetaNaturalParameters)
    return [p.α, p.β]
end

# Build natural parameters of type T from raw arguments (e.g. a vector)
ReactiveMP.as_naturalparams(::Type{T}, args...) where {T <: BetaNaturalParameters} =
    convert(BetaNaturalParameters, args...)

# Construct from a vector, keeping its element type
function BetaNaturalParameters(v::AbstractVector{T}) where {T <: Real}
    @assert length(v) === 2 "`BetaNaturalParameters` must accept a vector of length `2`."
    return BetaNaturalParameters(v[1], v[2])
end

# Conversions from vectors to (typed) BetaNaturalParameters
Base.convert(::Type{BetaNaturalParameters}, vector::AbstractVector) =
    convert(BetaNaturalParameters{eltype(vector)}, vector)

Base.convert(::Type{BetaNaturalParameters{T}}, vector::AbstractVector) where {T} =
    BetaNaturalParameters(convert(AbstractVector{T}, vector))
```
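The vector conversions above let the natural parameters be rebuilt from a plain vector of any real element type, which is what automatic differentiation tools such as `ForwardDiff` (imported above) rely on when they replace `Float64` with their own number types. The standalone sketch below mimics this pattern with a hypothetical `EtaPair{T}` type, using `BigFloat` and `Rational` as stand-ins for such number types.

```julia
# Hypothetical parametric pair mirroring the pattern used for BetaNaturalParameters.
struct EtaPair{T <: Real}
    a::T
    b::T
end

Base.vec(p::EtaPair) = [p.a, p.b]

# Rebuild from a vector, keeping the vector's element type
function EtaPair(v::AbstractVector{T}) where {T <: Real}
    @assert length(v) == 2 "`EtaPair` expects a vector of length 2"
    return EtaPair{T}(v[1], v[2])
end

p = EtaPair(2.0, 3.0)
big_p = EtaPair(big.(vec(p)))          # element type becomes BigFloat
rat_p = EtaPair([1 // 2, 3 // 4])      # element type Rational{Int}
```

Because the element type is taken from the vector rather than fixed to `Float64`, the round trip `vec` then constructor preserves whatever number type the caller supplies.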

Now we can specify a model in which a Beta-distributed variable enters a `Delta` factor.

Generate some synthetic data:

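```julia
rng = StableRNG(123)       # RNG for data generation
cvi_rng = StableRNG(42)    # RNG used inside the CVI procedure

golden_beta = Beta(3, 4)   # ground-truth distribution of the latent means

num_samples = 1000

means = rand(rng, golden_beta, num_samples)
# NormalMeanVariance is parameterized by the variance: the noise std is 0.1
observations = map(mean -> rand(rng, NormalMeanVariance(mean, 0.01)), means);
```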
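In this generative process each observation is a Beta-distributed mean plus Gaussian noise of variance 0.01 (standard deviation 0.1), so the sample mean of the observations should approach the Beta(3, 4) mean $3/7 \approx 0.4286$. The sketch below reproduces the process with the standard library only, as a stand-in for `Distributions`: it samples Beta(3, 4) as the third smallest of six uniform draws, a standard order-statistic identity for integer shapes. `sample_beta_int` and the `sketch_*` variables are hypothetical names chosen so as not to overwrite the variables above.

```julia
using Random, Statistics

# Beta(a, b) with integer a, b: the a-th smallest of (a + b - 1) uniform draws
sample_beta_int(rng::AbstractRNG, a::Int, b::Int) = sort(rand(rng, a + b - 1))[a]

sketch_rng = MersenneTwister(123)
n = 100_000
sketch_means = [sample_beta_int(sketch_rng, 3, 4) for _ in 1:n]
# Gaussian noise with variance 0.01, i.e. standard deviation 0.1
sketch_obs = sketch_means .+ sqrt(0.01) .* randn(sketch_rng, n)
```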

Create the model specification:

```julia
@model function normal_with_beta_mean(num_observations, cvi_rng, cvi_num_samples, cvi_num_iterations, learning_rate)
    observations = datavar(Float64, num_observations)
    # Uniform prior over the latent mean
    beta ~ Beta(1, 1)
    # Delta factor mean = identity(beta), whose messages are computed with CVI
    mean ~ identity(beta) where {meta = CVIApproximation(cvi_rng, cvi_num_samples, cvi_num_iterations, Descent(learning_rate))}
    # Gaussian likelihood with variance 0.01
    observations .~ NormalMeanVariance(mean, 0.01)
end
```
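Written out, with $N$ = `num_observations`, $m$ the variable `mean`, and $y_i$ the observations, the model is:

$$
\beta \sim \mathrm{Beta}(1, 1), \qquad p(m \mid \beta) = \delta(\beta - m), \qquad y_i \mid m \sim \mathcal{N}(m, 0.01), \quad i = 1, \dots, N.
$$

A single latent $\beta$ is shared by all observations. The Gaussian likelihood is not conjugate to the Beta prior, and the `CVIApproximation` meta makes `ReactiveMP` compute the messages through the `identity` delta node with CVI.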

Run inference:

```julia
res = inference(
    # model arguments: num_observations, cvi_rng, cvi_num_samples, cvi_num_iterations, learning_rate
    model = Model(normal_with_beta_mean, num_samples, cvi_rng, 100, 2000, 0.1),
    data = (observations = observations,),
    iterations = 100,
    free_energy = false,
    returnvars = (beta = KeepLast(),),
    # constraints = constraints,
    initmessages = (beta = Beta(1, 1),),
)
```

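The inferred posterior for `beta`:

```
Inference results:
-----------------------------------------
beta = Beta{Float64}(α=10921.612503258872, β=11205.411979087337)
```

The posterior mean of `beta` next to the mean of `golden_beta` ($3/7$):

```
(0.4935870393225512, 0.42857142857142855)
```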
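The printed parameters can be summarized with the Beta mean $\alpha/(\alpha + \beta)$ and variance $\alpha\beta/\big((\alpha + \beta)^2(\alpha + \beta + 1)\big)$. The plain-Julia snippet below evaluates these for the printed posterior: the mean matches the first number of the printed tuple, and the standard deviation (about 0.0034) is close to $\sqrt{0.01/1000} \approx 0.0032$, the posterior standard deviation of a Gaussian mean under a flat prior with 1000 observations of variance 0.01.

```julia
# Posterior parameters printed by the inference run above
α_post, β_post = 10921.612503258872, 11205.411979087337

post_mean = α_post / (α_post + β_post)
post_var = α_post * β_post / ((α_post + β_post)^2 * (α_post + β_post + 1))
post_std = sqrt(post_var)

# Reference: std of a Gaussian-mean posterior with a flat prior,
# 1000 observations and noise variance 0.01
ref_std = sqrt(0.01 / 1000)
```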

### Using custom optimizer
