# Composite nodes

One of the convenient properties of Forney-style factor graphs (as opposed to regular factor graphs) is that they naturally allow for composability: one can draw a box around part of an FFG and treat this box as a new type of factor node. For example, one can combine a gain (multiplication) node with the addition node in a so-called *composite node*, as depicted in the following graph (Fig. 4.2 from Korl's [A factor graph approach to signal modelling, system identification and filtering](https://www.research-collection.ethz.ch/handle/20.500.11850/82737)):

<img src="./figures/information_filter.png" width="350"/>

Composite nodes are useful for two reasons:

1. Building large graphs becomes more convenient by 'packaging' repetitive parts of the graph as composite nodes.
2. One can define 'shortcut rules' for message updates, which might be more efficient and/or numerically stable than performing vanilla message passing on the internals of the composite node. For example, in the schedule shown above, message (4) is calculated directly from messages (2) and (3). The shortcut rule might exploit the matrix inversion lemma, or involve some optimization algorithm.

To demonstrate the use of composite nodes, in this demo we will build a gain-addition combination that constrains

\begin{align*}
    x_1 = x_0 + b\cdot u_1\,,
\end{align*}

where `x_0` and `u_1` have Gaussian priors, and `b` is a constant matrix (a 1x2 row in this demo, so `x_0` and `x_1` are scalar and `u_1` is two-dimensional). We are interested in computing a belief over `x_1`.

We first construct a "flat" graph that represents the gain and addition constraints as two distinct factors, and generate a schedule for inferring a belief over `x_1`. Then, we compare the resulting schedule with a schedule generated on an FFG with a _composite_ gain-addition node. Finally, we show how to register a custom update rule with ForneyLab.
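
As a point of reference for the schedules that follow, the belief over `x_1` can also be worked out by hand with standard linear-Gaussian propagation. The sketch below uses only Julia's `LinearAlgebra` standard library and does not touch ForneyLab; the variable names (`m_x0`, `V_u1`, and so on) are illustrative choices, and the prior values are the ones used later in this demo.

```julia
using LinearAlgebra

# Priors in weighted-mean/precision form (same values as in this demo)
xi_x0, w_x0 = 1.0, 1.0                       # x_0 is scalar
xi_u1, W_u1 = ones(2), Matrix(1.0I, 2, 2)    # u_1 is two-dimensional
b = [1.0, 0.5]'                              # 1x2 gain (row)

# Convert both priors to mean/(co)variance form
m_x0, v_x0 = xi_x0 / w_x0, 1 / w_x0
m_u1, V_u1 = W_u1 \ xi_u1, inv(W_u1)

# Propagate through x_1 = x_0 + b*u_1
# first() unwraps a 1-element result and is harmless on a plain scalar
m_x1 = m_x0 + first(b * m_u1)
v_x1 = v_x0 + first(b * V_u1 * b')

println("mean = ", m_x1, ", variance = ", v_x1)
```

Any message passing schedule on this graph, flat or composite, should arrive at a belief equivalent to this mean and variance.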

## Schedule generation without composite node

In [11]:
using ForneyLab
using LinearAlgebra

# Define factor graph for x1 = x0 + b*u1, where x0 and u1 have Gaussian priors, and b is a constant.
# This is a part of the information filter graph from the introduction.
g = FactorGraph()

b = [1.0, 0.5]' # 1x2 Matrix b

@RV x_0 ~ GaussianWeightedMeanPrecision(1.0, 1.0)
@RV u_1 ~ GaussianWeightedMeanPrecision(ones(2), eye(2))
@RV x_1 = x_0 + b*u_1;

In [12]:
flat_algorithm = messagePassingAlgorithm(x_1)
flat_schedule = flat_algorithm.posterior_factorization.posterior_factors[Symbol("")].schedule

draw(g, schedule=flat_schedule) # Inspect the resulting schedule

println(flat_schedule)

1.	SPGaussianWeightedMeanPrecisionOutNPP on Interface 1 (out) of GaussianWeightedMeanPrecision gaussianweightedmeanprecision_1
2.	SPGaussianWeightedMeanPrecisionOutNPP on Interface 1 (out) of GaussianWeightedMeanPrecision gaussianweightedmeanprecision_2
3.	SPMultiplicationOutNGP on Interface 1 (out) of Multiplication multiplication_1
4.	SPAdditionOutNGG on Interface 1 (out) of Addition addition_1



## Usage of composite nodes in an FFG

Now we 'draw a box' around the multiplication and addition nodes, and create a composite node. We can easily define a composite node using ForneyLab's `@composite` macro.

In [13]:
# Define a composite node for z = x + b*y
@composite GainAddition (z, x, y) begin
    # Specify the 'internal factor graph' of the GainAddition composite node.
    # z, x, and y can be used as if they are existing Variables in this block.
    b = [1.0, 0.5]'
    
    @RV z = x + b*y
end

Here, `GainAddition` is the name of the composite node that we're defining. The tuple `(z, x, y)` defines the variables that this node constrains. The order of these variables simultaneously fixes the argument order for the update rules. Now that our custom `GainAddition` composite node is defined, we can use it in the graph definition.

In [14]:
g2 = FactorGraph()

@RV x_0 ~ GaussianWeightedMeanPrecision(1.0, 1.0)
@RV u_1 ~ GaussianWeightedMeanPrecision(ones(2), eye(2))
@RV x_1 ~ GainAddition(x_0, u_1);

## Defining a custom shortcut rule

Wrapping nodes in a composite does not by itself change the algorithm. To obtain a different (e.g. more efficient) algorithm, we have to specify message update rules that apply specifically to the composite node at hand. These rules are sometimes referred to as _shortcut rules_, since they shortcut the calculation of the internal messages. Declaring a rule for a composite node is analogous to declaring a rule for a regular node. Once the rule is declared, building a new sum-product algorithm automatically inserts the custom shortcut rule `SPGainAdditionOutNGG`. The rule declaration below states that the outgoing message from the composite node `GainAddition` belongs to the `Gaussian` family.

In [15]:
@sumProductRule(:node_type     => GainAddition, # Our custom composite node
                :outbound_type => Message{Gaussian}, # This rule produces a Gaussian message
                :inbound_types => (Nothing, Message{Gaussian}, Message{Gaussian}), # Incoming message types
                :name          => SPGainAdditionOutNGG) # Name of the update rule

SPGainAdditionOutNGG

The schedule generated next contains one message fewer than the flat schedule, because message (3) computes the belief over `x_1` directly from the two prior messages.

In [16]:
composite_algorithm = messagePassingAlgorithm(x_1)
composite_schedule = composite_algorithm.posterior_factorization.posterior_factors[Symbol("")].schedule
draw(g2, schedule=composite_schedule)

println(composite_schedule)

1.	SPGaussianWeightedMeanPrecisionOutNPP on Interface 1 (out) of GaussianWeightedMeanPrecision gaussianweightedmeanprecision_1
2.	SPGaussianWeightedMeanPrecisionOutNPP on Interface 1 (out) of GaussianWeightedMeanPrecision gaussianweightedmeanprecision_2
3.	SPGainAdditionOutNGG on Interface 1 (z) of GainAddition gainaddition_1



## Executing the custom rule
We can directly compile the algorithm to Julia code.

In [17]:
source_code = algorithmSourceCode(composite_algorithm)
eval(Meta.parse(source_code)) # Load algorithm

println(source_code) # Inspect the algorithm

begin

function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 3))

messages[1] = ruleSPGaussianWeightedMeanPrecisionOutNPP(nothing, Message(Univariate, PointMass, m=1.0), Message(Univariate, PointMass, m=1.0))
messages[2] = ruleSPGaussianWeightedMeanPrecisionOutNPP(nothing, Message(Multivariate, PointMass, m=[1.0, 1.0]), Message(MatrixVariate, PointMass, m=Diagonal(Bool[1, 1])))
messages[3] = ruleSPGainAdditionOutNGG(nothing, messages[1], messages[2])

marginals[:x_1] = messages[3].dist

return marginals

end

end # block


The above algorithm makes a call to our custom update rule `ruleSPGainAdditionOutNGG`, which we have not yet implemented. We do so below by defining two methods. The first implements a shortcut rule for the specific case where both incoming messages are parameterized by weighted mean and precision. By exploiting the matrix inversion lemma, this update can lead to a significant speedup in the case of high-dimensional messages. This demo only concerns the 2-D case, but generalizations can be readily implemented. The second method is a catch-all that handles any other Gaussian parameterization by converting to mean-variance form.
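
For reference, the shortcut update can be derived as follows. This is a sketch in our own notation, chosen to mirror the variable names in the code: the incoming messages carry weighted means `xi_x`, `xi_y` and precisions `W_x`, `W_y`, and `H` is an auxiliary matrix. Since `z = x + b*y`, the covariance of the outgoing message is the sum of the propagated covariances, and the matrix inversion lemma rewrites its inverse so that only `H` needs to be computed:

\begin{align*}
    W_z &= \left(W_x^{-1} + b\, W_y^{-1} b^T\right)^{-1} = W_x - W_x b\, H\, b^T W_x\,, \qquad H = \left(W_y + b^T W_x b\right)^{-1}\,,\\
    \xi_z &= W_z \left(W_x^{-1}\xi_x + b\, W_y^{-1}\xi_y\right) = \xi_x + W_x b\, H \left(\xi_y - b^T \xi_x\right)\,.
\end{align*}

The right-hand sides require a single inverse, whereas the direct route converts both incoming messages to covariance form and then inverts the result again.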

In [18]:
# Specific shortcut update with high-performance implementation
function ruleSPGainAdditionOutNGG(
    msg_out::Nothing,
    msg_x::Message{GaussianWeightedMeanPrecision, Univariate},
    msg_y::Message{GaussianWeightedMeanPrecision, Multivariate})

    b = [1.0, 0.5]'
    
    xi_x = msg_x.dist.params[:xi]
    W_x = msg_x.dist.params[:w]
    xi_y = msg_y.dist.params[:xi]
    W_y = msg_y.dist.params[:w]
    
    H = cholinv(W_y + b'*W_x*b)
    
    Message(Univariate, 
            GaussianWeightedMeanPrecision, 
            xi = first(xi_x + W_x*b*H*(xi_y - b'*xi_x)), 
            w  = first(W_x - W_x*b*H*b'*W_x))
end

# Catch-all fallback update (less efficient): converts both messages to mean-variance form
function ruleSPGainAdditionOutNGG(
    msg_out::Nothing,
    msg_x::Message{F1, Univariate},
    msg_y::Message{F2, Multivariate}) where {F1<:Gaussian, F2<:Gaussian}

    b = [1.0, 0.5]'

    d_x = convert(ProbabilityDistribution{Univariate, GaussianMeanVariance}, msg_x.dist)
    d_y = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, msg_y.dist)
    
    Message(Univariate, 
            GaussianMeanVariance, 
            m = first(d_x.params[:m] + b*d_y.params[:m]), 
            v = first(d_x.params[:v] + b*d_y.params[:v]*b'))
end

ruleSPGainAdditionOutNGG (generic function with 2 methods)

In [19]:
step!(Dict()) # Execute the algorithm

Dict{Any,Any} with 1 entry:
  :x_1 => 𝒩(xi=1.11, w=0.44)…
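
As a sanity check on this result, the information-form shortcut can be compared against plain moment-form propagation. The following sketch again relies only on `LinearAlgebra` and not on ForneyLab; it uses `inv` where the rule above uses `cholinv`, and all variable names are illustrative.

```julia
using LinearAlgebra

b = [1.0, 0.5]'                             # 1x2 gain (row)
xi_x, W_x = 1.0, 1.0                        # message from x_0
xi_y, W_y = ones(2), Matrix(1.0I, 2, 2)     # message from u_1

# Shortcut update in weighted-mean/precision form (matrix inversion lemma)
H = inv(W_y + (b' * b) * W_x)               # the only inverse that is needed
gain = W_x * (b * H)                        # 1x2 row
xi_z = xi_x + first(gain * (xi_y - b' * xi_x))
w_z  = W_x - first(gain * b') * W_x

# Reference: moment-form propagation of mean and variance
m_ref = xi_x / W_x + first(b * (W_y \ xi_y))
v_ref = 1 / W_x + first(b * inv(W_y) * b')

println("xi = ", round(xi_z, digits=2), ", w = ", round(w_z, digits=2))
println("parameterizations agree: ", isapprox(xi_z / w_z, m_ref) && isapprox(1 / w_z, v_ref))
```

Both routes describe the same Gaussian; they differ only in which matrices have to be inverted along the way.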