# Composite nodes

One of the convenient properties of Forney-style factor graphs (FFGs), as opposed to regular factor graphs, is that they naturally allow for composability: one can draw a box around part of an FFG and treat this box as a new type of factor node. For example, one can combine a gain (multiplication) node with an addition node in a so-called *composite node*, as depicted in the following graph (Fig. 4.2 from [1]):

<img src="./figures/information_filter.png" width="350"/>

Composite nodes are useful for two reasons:

1. Building large graphs becomes more convenient by 'packaging' repetitive parts of the graph as composite nodes.
2. One can define 'shortcut rules' for message updates, which might be more efficient and/or numerically stable than performing vanilla message passing on the internals of the composite node. The shortcut rule might for example exploit the matrix inversion lemma, or involve some optimization algorithm. For example, in the schedule shown above, message (4) is calculated directly from messages (2) and (3), and the message out of the gain node does not have to be represented explicitly.
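Here is a plain-Julia sketch of the kind of matrix-inversion-lemma shortcut mentioned above. It is not ForneyLab code, and it targets Julia 1.x (it uses the `LinearAlgebra` standard library). It looks at the gain-addition combination $z = x + B y$ and assumes the Gaussian messages on $x$ and $y$ are given in precision (information) form, with precision matrices $W_x$ and $W_y$. The numbers are arbitrary. The naive route converts to covariances and inverts $n \times n$ matrices. The matrix inversion lemma only needs the inverse of a $k \times k$ matrix, where $k$ is the dimension of $y$:

```julia
using LinearAlgebra

# Gain-addition combination z = x + B*y, with the incoming Gaussian messages
# on x and y given in precision (information) form. Values are illustrative.
Wx = [2.0 0.3; 0.3 1.0]          # precision matrix of the message on x (n = 2)
Wy = fill(4.0, 1, 1)             # precision of the message on y (k = 1, as a 1×1 matrix)
B  = reshape([1.0, 0.5], 2, 1)   # gain, written as a 2×1 matrix

# Naive route: convert to covariance form, add, and invert back (n×n inversions)
Vz = inv(Wx) + B * inv(Wy) * B'
Wz_naive = inv(Vz)

# Shortcut via the matrix inversion lemma: only a k×k matrix is inverted
Wz_shortcut = Wx - Wx * B * inv(Wy + B' * Wx * B) * B' * Wx

println(isapprox(Wz_naive, Wz_shortcut))
```

Both routes give the same precision matrix for $z$. The second one never inverts an $n \times n$ matrix, and that is the kind of saving a shortcut rule can exploit when $y$ is low-dimensional.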


Composite nodes in ForneyLab
--------------------------------

ForneyLab makes it easy to define composite nodes as well as shortcut rules. If no suitable shortcut rule is available, ForneyLab automatically falls back to explicit message passing on the internals of the composite node.

To demonstrate the use of composite nodes, we'll build a factor graph involving a gain-addition combination like in the "information filter" factor graph shown above.

### 1. 'Flat' factor graph without composite node

```julia
using ForneyLab

# Define factor graph for x1 = x0 + b*u1, where x0 and u1 have Gaussian priors, and b is a constant.
# This is a part of the information filter graph from the introduction.
g = FactorGraph()

b = [1.0; 0.5]

@RV x_0 ~ GaussianMeanVariance(ones(2), eye(2))
@RV u_1 ~ GaussianMeanVariance(1.0, 1.0)
@RV x_1 = x_0 + b*u_1;
```

Let's create a sum-product message passing schedule to calculate $p(x_1)$:

```julia
flat_schedule = sumProductSchedule(x_1)

draw(g, schedule=flat_schedule)

println(flat_schedule)
```

```
	SPClamp{Multivariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Multivariate} clamp_1
	SPClamp{MatrixVariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.MatrixVariate} clamp_2
1.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_1
	SPClamp{Multivariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Multivariate} clamp_5
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_3
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_4
2.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_2
3.	SPMultiplicationOutVGP on Interface 1 (out) of ForneyLab.Multiplication multiplication_1
4.	SPAdditionOutVGG on Interface 1 (out) of ForneyLab.Addition addition_1
```



As expected, we get a schedule with 4 messages. Let's now 'draw a box' around the multiplication and addition nodes and create a composite node that implements the relationship $f(x_0, x_1, u_1) = \delta(x_1 - (x_0 + b\,u_1))$.
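To see what these 4 messages contain, the following plain-Julia sketch applies the standard Gaussian forward rules by hand, in mean/covariance form. It uses the priors and gain from the model above. It is not ForneyLab code and targets Julia 1.x. In the schedule above, message (3) leaves the multiplication node and message (4) leaves the addition node:

```julia
using LinearAlgebra

# Priors from the flat model, in mean/covariance form
m_x0, V_x0 = ones(2), Matrix(1.0I, 2, 2)   # x_0 ~ N(ones(2), I)
m_u1, v_u1 = 1.0, 1.0                      # u_1 ~ N(1, 1)
b = [1.0, 0.5]                             # constant gain

# Message (3): forward through the gain (multiplication) node, b*u_1
m3 = b * m_u1
V3 = b * v_u1 * b'          # rank-one 2×2 covariance

# Message (4): forward through the addition node, x_0 + b*u_1
m4 = m_x0 + m3
V4 = V_x0 + V3

println(m4)   # [2.0, 1.5]
println(V4)   # [2.0 0.5; 0.5 1.25]
```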

### 2. Define composite node and use it in a factor graph

We can easily define a composite node using ForneyLab's `@composite` macro. It works like this:

```julia
# Define a composite node for z = x + b*y
@composite GainAddition (z, x, y) begin
    # Specify the 'internal factor graph' of the GainAddition composite node.
    # z, x, and y can be used as if they are existing Variables in this block.
    b = [1.0; 0.5]
    
    @RV z = x + b*y
end
```

In the above code, `GainAddition` is the name of the composite node that we're defining. The `(z, x, y)` part defines the names of the interfaces of this node. Every interface corresponds to a `Variable` in the internal graph of the composite node, so we can use `z`, `x`, and `y` as if they are existing variables when building the internal graph.
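Conceptually, the factor represented by a composite node comes from integrating out the variables that live only inside its internal graph. Let $w$ denote the internal edge between the gain and addition nodes. This name is introduced here for illustration only, and $w$ is not one of the interfaces. The `GainAddition` factor then works out to

$$
f_{\text{GainAddition}}(z, x, y) = \int \delta\big(z - (x + w)\big)\, \delta(w - b\,y)\, \mathrm{d}w = \delta\big(z - (x + b\,y)\big),
$$

which is exactly the relationship that the flat graph implements with separate gain and addition nodes.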

Now that our custom `GainAddition` composite node is defined, we can use it like any other factor node. Let's build the same factor graph as before, but now with our composite node:

```julia
g2 = FactorGraph()

@RV x_0 ~ GaussianMeanVariance(ones(2), eye(2))
@RV u_1 ~ GaussianMeanVariance(1.0, 1.0)
@RV x_1

gain_addition_node = GainAddition(x_1, x_0, u_1) # This syntax returns the constructed node

composite_schedule = sumProductSchedule(x_1)

draw(g2, schedule=composite_schedule)

println(composite_schedule)
```

```
	SPClamp{Multivariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Multivariate} clamp_1
	SPClamp{MatrixVariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.MatrixVariate} clamp_2
1.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_1
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_3
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_4
2.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_2
3.	(INTERNAL SCHEDULE) SPAdditionOutVGG on Interface 1 (z) of GainAddition gainaddition_1
```



As you can see, the factor graph just got simpler, and the message passing schedule only has 3 entries. However, message (3) is special, since it is produced by executing a message passing schedule on the internal graph of the composite node `gainaddition_1`. We can inspect the internal graph of the composite node and the internal message passing schedule that produces message (3) of the main schedule:

```julia
draw(gain_addition_node.inner_graph, schedule=composite_schedule[end].internal_schedule)
show(composite_schedule[end].internal_schedule)
```

```
	SPClamp{Multivariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Multivariate} clamp_1
1.	SPMultiplicationOutVGP on Interface 1 (out) of ForneyLab.Multiplication multiplication_1
2.	SPAdditionOutVGG on Interface 1 (out) of ForneyLab.Addition addition_1
```


As you can see, each interface of the `GainAddition` node corresponds to a `Terminal` node in its internal graph. The schedule that is executed on the internal graph consists of 2 messages. 

So now we have a message passing schedule that is hierarchical: the last message in `composite_schedule` is produced by another message passing algorithm (executed on the internals of the composite node). However, when we actually generate code for this message passing schedule, the hierarchical schedule will be flattened to get a simple, linear schedule:

```julia
# Print message passing algorithm for the factor graph with the GainAddition composite node
println(messagePassingAlgorithm(composite_schedule))
```

```julia
function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(3))

messages[1] = ruleSPGaussianMeanVarianceOutVPP(nothing, Message(Multivariate, PointMass, m=[1.0, 1.0]), Message(MatrixVariate, PointMass, m=[1.0 0.0; 0.0 1.0]))
messages[2] = ruleSPGaussianMeanVarianceOutVPP(nothing, Message(Univariate, PointMass, m=1.0), Message(Univariate, PointMass, m=1.0))
messages[3] = ruleSPMultiplicationOutVGP(nothing, messages[2], Message(Multivariate, PointMass, m=[1.0, 0.5]))
messages[4] = ruleSPAdditionOutVGG(nothing, messages[1], messages[3])


return marginals

end
```


In fact, the message updates in the flattened algorithm are **identical** to those in the algorithm we get if we don't use the composite node. Only the default length of the preallocated `messages` vector differs. Check for yourself:

```julia
# Print message passing algorithm for the 'flat' factor graph WITHOUT the GainAddition composite node
println(messagePassingAlgorithm(flat_schedule))
```

```julia
function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(4))

messages[1] = ruleSPGaussianMeanVarianceOutVPP(nothing, Message(Multivariate, PointMass, m=[1.0, 1.0]), Message(MatrixVariate, PointMass, m=[1.0 0.0; 0.0 1.0]))
messages[2] = ruleSPGaussianMeanVarianceOutVPP(nothing, Message(Univariate, PointMass, m=1.0), Message(Univariate, PointMass, m=1.0))
messages[3] = ruleSPMultiplicationOutVGP(nothing, messages[2], Message(Multivariate, PointMass, m=[1.0, 0.5]))
messages[4] = ruleSPAdditionOutVGG(nothing, messages[1], messages[3])


return marginals

end
```


This is nice: we can define and use composite nodes without having to implement any additional message update rules.
You can manually flatten schedules containing internal message passing schedules using the `flatten` function:

```julia
show(ForneyLab.flatten(composite_schedule))
```

```
	SPClamp{Multivariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Multivariate} clamp_1
	SPClamp{MatrixVariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.MatrixVariate} clamp_2
1.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_1
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_3
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_4
2.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_2
	SPClamp{Multivariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Multivariate} clamp_1
3.	SPMultiplicationOutVGP on Interface 1 (out) of ForneyLab.Multiplication multiplication_1
4.	SPAdditionOutVGG on Interface 1 (out) of ForneyLab.Addition addition_1
```
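The idea behind flattening fits in a few lines of plain Julia (Julia 1.x). The types `RuleEntry` and `CompositeEntry` and the function `flatten_schedule` are hypothetical and much simpler than ForneyLab's internal schedule representation. For instance, clamp messages are left out. The sketch only shows that flattening recursively replaces each composite entry with its internal schedule:

```julia
# Hypothetical, simplified schedule representation (not ForneyLab's internal types)
abstract type ScheduleEntry end

struct RuleEntry <: ScheduleEntry
    rule::String                      # name of the message update rule
    node::String                      # node whose outbound message is computed
end

struct CompositeEntry <: ScheduleEntry
    node::String                      # the composite node
    internal::Vector{ScheduleEntry}   # schedule on the composite node's inner graph
end

# Recursively replace every composite entry with its (flattened) internal schedule
function flatten_schedule(schedule::Vector{<:ScheduleEntry})
    flat = RuleEntry[]
    for entry in schedule
        if entry isa CompositeEntry
            append!(flat, flatten_schedule(entry.internal))
        else
            push!(flat, entry)
        end
    end
    return flat
end

# Hierarchical schedule mirroring composite_schedule (clamps omitted)
composite = ScheduleEntry[
    RuleEntry("SPGaussianMeanVarianceOutVPP", "gaussian_1"),
    RuleEntry("SPGaussianMeanVarianceOutVPP", "gaussian_2"),
    CompositeEntry("gainaddition_1", ScheduleEntry[
        RuleEntry("SPMultiplicationOutVGP", "multiplication_1"),
        RuleEntry("SPAdditionOutVGG", "addition_1")
    ])
]

for (i, e) in enumerate(flatten_schedule(composite))
    println(i, ". ", e.rule, " on ", e.node)
end
```

Running it lists the same four non-clamp updates, in the same order, as the flattened ForneyLab schedule.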


### 3. Defining shortcut rules

If we actually want to use composite nodes to get a different (e.g. more efficient) algorithm, we have to specify message update rules that apply specifically to the composite node at hand. These rules are sometimes called *shortcut rules*, since they shortcut the calculation of internal messages.

As an example, we'll specify a custom shortcut rule to calculate the (Gaussian) sum-product message towards the `z` interface. There is nothing special about rule definitions for composite nodes; they are identical to rule definitions for regular nodes.

```julia
@sumProductRule(:node_type     => GainAddition,                                 # our custom composite node
                :outbound_type => Message{Gaussian},                            # this rule produces a Gaussian msg
                :inbound_types => (Void, Message{Gaussian}, Message{Gaussian}), # Void marks the outbound interface (z); then the incoming msg types on x and y
                :name          => SPGainAdditionOutVGG)                         # name of the update rule
```
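Such a shortcut rule computes the standard Gaussian forward update for $z = x + b\,y$ in a single step, without ever materializing the message out of the gain node. Below is a plain-Julia sketch of that computation (Julia 1.x). `gain_addition_out` is a hypothetical helper, not a ForneyLab function, and it works on raw means and covariances rather than on ForneyLab `Message` objects:

```julia
using LinearAlgebra

# Hypothetical helper (not part of ForneyLab): the Gaussian message towards z
# for z = x + B*y, computed directly from the incoming messages N(m_x, V_x) on x
# and N(m_y, V_y) on y, in mean/covariance form.
function gain_addition_out(m_x, V_x, m_y, V_y, B)
    m_z = m_x + B * m_y
    V_z = V_x + B * V_y * B'
    return m_z, V_z
end

# Same numbers as in the graphs above; the scalar u_1 is written as a 1-vector
B = reshape([1.0, 0.5], 2, 1)
m_z, V_z = gain_addition_out(ones(2), Matrix(1.0I, 2, 2), [1.0], fill(1.0, 1, 1), B)
# gives m_z = [2.0, 1.5] and V_z = [2.0 0.5; 0.5 1.25]
```

This matches what the two-step schedule through the multiplication and addition nodes produces, with one update instead of two.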

If we now build a new sum-product schedule, the custom shortcut rule is automatically picked up:

```julia
shortcut_schedule = sumProductSchedule(x_1)
show(shortcut_schedule)
```

```
	SPClamp{Multivariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Multivariate} clamp_1
	SPClamp{MatrixVariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.MatrixVariate} clamp_2
1.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_1
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_3
	SPClamp{Univariate} on Interface 1 (out) of ForneyLab.Clamp{ForneyLab.Univariate} clamp_4
2.	SPGaussianMeanVarianceOutVPP on Interface 1 (out) of ForneyLab.GaussianMeanVariance gaussian_2
3.	SPGainAdditionOutVGG on Interface 1 (z) of GainAddition gainaddition_1
```


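Yes, the new schedule uses the shortcut rule, so it calculates the result in just 3 message updates instead of 4. Note that the `@sumProductRule` declaration only makes the rule known to the scheduler. To actually execute the generated algorithm, the update computation itself must also be implemented.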

References
------------

[1] S. Korl, "A Factor Graph Approach to Signal Modelling, System Identification and Filtering", 2005.