# Exact Inference Algorithms



## CSCI E-83
## Stephen Elston

We have seen several approaches to representing probabilistic graphical models. Now, we will turn our attention to **inference algorithms**. The goal of inference is to compute the **posterior distribution** of one or more variables in the model given **evidence**. Equivalently, inference returns the answer to a **query** on the model. 

In this lesson we will examine three efficient classes of algorithms for inference on graphical models:

1. **Variable elimination**
2. **Message passing, or sum-product or belief propagation algorithms**
3. **Junction tree algorithm**

## Complexity of inference for graphical models

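To understand the need for efficient inference algorithms, it helps to understand the computational complexity of inference on a graphical model. Consider $n$ variables, each taking $k$ possible values. If we use a **tabular** solution algorithm, which works with the full joint probability table, the table has $k^n$ entries, so the number of operations grows as $O(k^n)$, exponentially in the number of variables. More generally, exact inference on arbitrary graphical models is NP-hard. 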

On the face of things, it might seem that performing inference on graphical models of any scale is hopeless. While it is true that there is no efficient general algorithm for exact inference, there are many practical and widely applicable cases for which efficient inference algorithms exist. 

The key to reducing the computational complexity of graphical model inference algorithms is the use of independencies. The naive approach is to compute the full joint probability table of the graph variables and sum out every variable not in the query. This approach has complexity that is exponential in the number of variables. By combining the factorization into conditional probabilities with the evidence, the complexity of marginal inference can be significantly reduced. 

The algorithms we will explore take advantage of special structures commonly found in model graphs. Part of the trick is to rearrange the graph to create the desired structure. Combined with conditional probabilities and evidence, these algorithms make even large-scale models tractable.

## Elimination algorithms

To understand the **elimination algorithm**, let's use an example based on a **chain graph**. Chain graphs occur in a wide range of applications including protein activation models. An example is shown in the figure below.

<img src="img/Chain1.JPG" alt="Drawing" style="width:400px; height:75px"/>
<center> **Chain graph** </center>

Our goal is to compute the marginal distribution of $Z$, $P(Z)$:

$$P(Z) = \sum_V \sum_W \sum_X \sum_Y P(V,W,X,Y,Z)$$

We can decompose this distribution as follows:

$$P(Z) = \sum_V \sum_W \sum_X \sum_Y P(V)\ P(W \ |\ V)\ P(X\ |\ W)\ P(Y\ |\ X)\ P(Z\ |\ Y)$$

Since only $P(V)$ and $P(W\ |\ V)$ depend on $V$, we can move the sum over $V$ inside the other sums:

$$P(Z) =  \sum_W \sum_X \sum_Y P(X\ |\ W)\ P(Y\ |\ X)\ P(Z\ |\ Y) \sum_V P(V)\ P(W \ |\ V)$$

By the law of total probability, this inner sum is the marginal distribution of $W$:

$$p(W) = \sum_V P(V)\ P(W \ |\ V)$$

So we can rewrite the marginal distribution as:

$$P(Z) =  \sum_W \sum_X \sum_Y P(X\ |\ W)\ P(Y\ |\ X)\ P(Z\ |\ Y) p(W)$$

We have **eliminated** $V$ from the graph as shown in the figure below. 

<img src="img/Eliminate1.JPG" alt="Drawing" style="width:400px; height:75px"/>
<center> **Eliminate V from the chain graph** </center>

Only a **local cost** has been paid in this elimination. By local cost we mean that the summation involved only the variable $V$ and the factors containing it. 

We can continue the process by eliminating $W$ using local summation:

$$P(Z) =  \sum_X \sum_Y P(Y\ |\ X)\ P(Z\ |\ Y) \sum_W p(W)\ P(X\ |\ W) \\
= \sum_X \sum_Y P(Y\ |\ X)\ P(Z\ |\ Y)\ p(X)$$

<img src="img/Eliminate2.JPG" alt="Drawing" style="width:400px; height:75px"/>
<center> **Eliminate W from the chain graph** </center>

Continuing the process, $X$ is eliminated using local summation:

$$P(Z) =  \sum_Y P(Z\ |\ Y) \sum_X p(X)\ P(Y\ |\ X) \\
= \sum_Y P(Z\ |\ Y)\ p(Y)$$


<img src="img/Eliminate3.JPG" alt="Drawing" style="width:400px; height:75px"/>
<center> **Eliminate X from the chain graph** </center>

Finally, we can eliminate $Y$ using local summation to compute the marginal distribution of $Z$:

$$P(Z) =  \sum_Y p(Y)\ P(Z\ |\ Y)$$

<img src="img/Eliminate4.JPG" alt="Drawing" style="width:400px; height:75px"/>
<center> **Eliminate Y to compute the marginal distribution $P(Z)$** </center>  

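For a chain of $n$ variables, each taking $k$ values, every elimination step is a local sum over a $k \times k$ table, so the complexity of this elimination process is $O(nk^2)$. This compares rather favorably with the $O(k^n)$ cost of summing over the full joint probability table.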
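The elimination steps above can be sketched in a few lines of Python. This is a minimal illustration, assuming binary variables and made-up conditional probability tables (the values in `prior_v` and `cpts` are hypothetical). It compares elimination, one local sum per variable, against brute-force summation over every entry of the joint table.

```python
import itertools

# Hypothetical chain V -> W -> X -> Y -> Z with binary variables (k = 2).
prior_v = [0.6, 0.4]                      # P(V)
# cpts[t][a][b] = P(next variable = b | previous variable = a); illustrative values
cpts = [
    [[0.7, 0.3], [0.2, 0.8]],             # P(W | V)
    [[0.9, 0.1], [0.4, 0.6]],             # P(X | W)
    [[0.5, 0.5], [0.1, 0.9]],             # P(Y | X)
    [[0.3, 0.7], [0.8, 0.2]],             # P(Z | Y)
]


def eliminate_chain(prior, cpts):
    """Marginal of the last variable, eliminating one variable at a time.

    Each step is a local sum over a k x k table, so the total cost is O(n k^2).
    """
    p = prior
    for cpt in cpts:
        # p_next(b) = sum_a p(a) P(b | a): sum out the previous variable
        p = [sum(p[a] * cpt[a][b] for a in range(len(p))) for b in range(len(cpt[0]))]
    return p


def brute_force_chain(prior, cpts):
    """Same marginal by summing all k^n entries of the joint table."""
    k, n = len(prior), len(cpts) + 1
    marginal = [0.0] * k
    for states in itertools.product(range(k), repeat=n):
        prob = prior[states[0]]
        for t, cpt in enumerate(cpts):
            prob *= cpt[states[t]][states[t + 1]]
        marginal[states[-1]] += prob
    return marginal


p_z = eliminate_chain(prior_v, cpts)
print("P(Z) by elimination: ", p_z)
print("P(Z) by brute force: ", brute_force_chain(prior_v, cpts))
```

Each pass through the loop in `eliminate_chain` corresponds to one elimination figure: the running vector `p` plays the role of $p(W)$, $p(X)$, $p(Y)$ and finally $P(Z)$.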

## Elimination on undirected chains  

We can also apply elimination to **undirected chain graphs**. An example is shown in the figure below. 

<img src="img/Undirected1.JPG" alt="Drawing" style="width:400px; height:75px"/>
<center> **Undirected chain graph** </center>

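Our goal is to compute the marginal distribution of $Z$, $P(Z)$. The joint distribution is a product of pairwise potentials divided by the partition function, which we write as $\mathcal{Z}$ to avoid confusion with the variable $Z$. Since only $\phi(V,W)$ depends on $V$, we can move the sum over $V$ inside, exactly as in the directed case:

$$P(Z) = \sum_V \sum_W \sum_X \sum_Y \frac{1}{\mathcal{Z}} \phi(V,W)\ \phi(W,X)\ \phi(X,Y)\ \phi(Y,Z) \\
= \frac{1}{\mathcal{Z}} \sum_W \sum_X \sum_Y  \phi(W,X)\ \phi(X,Y)\ \phi(Y,Z) \sum_V \phi(V,W) $$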
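Continuing the elimination on the undirected chain works the same way as on the directed chain, except that potentials are not normalized, so the result must be divided by the partition function at the end. A minimal sketch, assuming binary variables and hypothetical potential values, with a brute-force check over all joint states:

```python
import itertools

# Hypothetical pairwise potentials for the undirected chain V - W - X - Y - Z.
# phis[t][a][b] is the nonnegative compatibility of adjacent variables t and t + 1.
phis = [
    [[2.0, 1.0], [1.0, 3.0]],   # phi(V, W)
    [[1.0, 4.0], [2.0, 1.0]],   # phi(W, X)
    [[3.0, 1.0], [1.0, 1.0]],   # phi(X, Y)
    [[1.0, 2.0], [5.0, 1.0]],   # phi(Y, Z)
]


def undirected_chain_marginal(phis):
    """Return (normalized marginal of the last variable, partition function)."""
    m = [1.0] * len(phis[0])    # no unary potentials on V in this example
    for phi in phis:
        # m_next(b) = sum_a m(a) phi(a, b); the first pass gives sum_V phi(V, W)
        m = [sum(m[a] * phi[a][b] for a in range(len(m))) for b in range(len(phi[0]))]
    partition = sum(m)          # summing the unnormalized marginal over Z gives the normalizer
    return [value / partition for value in m], partition


def brute_force_marginal(phis):
    """Sum the unnormalized product of potentials over all joint states."""
    k, n = len(phis[0]), len(phis) + 1
    unnormalized = [0.0] * k
    for states in itertools.product(range(k), repeat=n):
        prod = 1.0
        for t, phi in enumerate(phis):
            prod *= phi[states[t]][states[t + 1]]
        unnormalized[states[-1]] += prod
    total = sum(unnormalized)
    return [u / total for u in unnormalized], total


p_z, partition = undirected_chain_marginal(phis)
print(p_z, partition)
```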


## Sum-product algorithm

The generalization of the elimination method is known as the **sum-product algorithm**. The sum-product algorithm operates on the factors, or clique potentials, of Markov networks (undirected graphs). The general expression can be written as a sum over products:

$$\sum_{\mathbf{Z}} \prod_{\phi \in \mathscr{F}} \phi\\
\text{where}\\
\mathbf{Z} = \text{set of variables being eliminated}\\
\mathscr{F} = \text{set of all factors}
$$

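Let's make this a bit more concrete. Suppose we want to make a **query** on the variable $x_1$ given evidence $e$. With the factors $\phi_i$ restricted to the observed values of the evidence variables, we sum out all the other variables:

$$P(x_1, e) = \sum_{x_n} \cdots \sum_{x_3} \sum_{x_2} \prod_i \phi_i$$

Normalizing over $x_1$ then gives the posterior, $P(x_1\ |\ e) = P(x_1, e) / \sum_{x_1} P(x_1, e)$. The elimination algorithm for this query can be summarized as follows: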


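```
Choose an elimination ordering in which the query node x_1 is last
Place all potentials, with the evidence applied, on the active list
For each node x_i in the ordering, except the query node:
    Remove every potential containing x_i from the active list
    Take the sum-product: multiply these potentials and sum over x_i
    Place the resulting factor on the active list
Multiply the remaining factors and normalize to obtain P(x_1 | e)
```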
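The elimination steps can be sketched in Python. This is a minimal, unoptimized illustration rather than a reference implementation: a factor is assumed to be a pair `(variables, table)` with a dictionary table, every variable is assumed binary, and the helper names (`multiply`, `sum_out`, `reduce_evidence`, `variable_elimination`) are hypothetical. Evidence is applied by zeroing the table entries that disagree with the observed values.

```python
import itertools
from functools import reduce

DOMAIN = (0, 1)   # assumption: every variable is binary


def multiply(f, g):
    """Product of two factors; the result's scope is the union of the two scopes."""
    f_vars, f_tab = f
    g_vars, g_tab = g
    out_vars = f_vars + tuple(v for v in g_vars if v not in f_vars)
    table = {}
    for values in itertools.product(DOMAIN, repeat=len(out_vars)):
        assign = dict(zip(out_vars, values))
        table[values] = (f_tab[tuple(assign[v] for v in f_vars)]
                         * g_tab[tuple(assign[v] for v in g_vars)])
    return out_vars, table


def sum_out(f, var):
    """Marginalize var out of factor f."""
    f_vars, f_tab = f
    idx = f_vars.index(var)
    table = {}
    for values, p in f_tab.items():
        key = values[:idx] + values[idx + 1:]
        table[key] = table.get(key, 0.0) + p
    return f_vars[:idx] + f_vars[idx + 1:], table


def reduce_evidence(f, evidence):
    """Zero out table entries that disagree with an observed value."""
    f_vars, f_tab = f
    table = {values: (p if all(evidence.get(v, x) == x for v, x in zip(f_vars, values)) else 0.0)
             for values, p in f_tab.items()}
    return f_vars, table


def variable_elimination(factors, query, order, evidence=None):
    """P(query | evidence); order lists every non-query variable, evidence variables included."""
    active = [reduce_evidence(f, evidence or {}) for f in factors]
    for var in order:
        related = [f for f in active if var in f[0]]
        active = [f for f in active if var not in f[0]]
        if related:
            # sum-product step: multiply all factors containing var, then sum var out
            active.append(sum_out(reduce(multiply, related), var))
    result_vars, table = reduce(multiply, active)
    if result_vars != (query,):
        raise ValueError("order must eliminate every variable except the query")
    total = sum(table.values())
    return {values[0]: p / total for values, p in table.items()}


# Hypothetical network x1 -> x2 -> x3 with illustrative probabilities.
p_x1 = (("x1",), {(0,): 0.6, (1,): 0.4})
p_x2 = (("x1", "x2"), {(0, 0): 0.7, (0, 1): 0.3, (1, 0): 0.2, (1, 1): 0.8})
p_x3 = (("x2", "x3"), {(0, 0): 0.9, (0, 1): 0.1, (1, 0): 0.4, (1, 1): 0.6})

posterior = variable_elimination([p_x1, p_x2, p_x3], "x1", ["x3", "x2"], {"x3": 1})
print(posterior)
```

As a hand check, $P(x_1 = 0, x_3 = 1) = 0.6\,(0.7 \cdot 0.1 + 0.3 \cdot 0.6) = 0.15$ and $P(x_1 = 1, x_3 = 1) = 0.4\,(0.2 \cdot 0.1 + 0.8 \cdot 0.6) = 0.2$, so the posterior over $x_1$ is about $(0.43, 0.57)$.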

## Message passing algorithms

In this section we will see how the elimination algorithm can be generalized to **tree graphs** to create the **message passing** algorithm, also known as the **belief propagation** algorithm. 

Let's start with a simple example of a tree graph. 

> **Definition:** A tree graph is an undirected acyclic graph in which any pair of nodes is connected by exactly one path. 

An example of a tree graph is shown in the figure below. Notice there is a flow of information from the bottom to the top. 

<img src="img/EliminationTree.JPG" alt="Drawing" style="width:350px; height:400px"/>
<center> **Elimination applied to a tree for query on node 1** </center>

In general, we can compute the factor that results from eliminating node $j$, together with the subtree below it, as follows:

$$m_{ji}(x_i) = \sum_{x_j} \Big(\ \psi(x_j)\ \psi(x_i,x_j)\ \prod_{f \in N(j) \backslash i} m_{fj}(x_j) \Big)$$

> **Note:** The notation $f \in N(j) \backslash i$ indicates all nodes $f$ in the set of neighbors $N(j)$ of node $j$, except node $i$.

We can also think of this formulation as **passing a message** from node $j$ to node $i$. In other words, we can say that $m_{ji}(x_i)$ represents **propagating a "belief"** from node $j$ to node $i$. Notice that **elimination on a tree is equivalent to message passing along the branches of the tree**. 

We can summarize the message passing or belief propagation algorithm:

```
The query node becomes the root of a tree. 
The tree is a directed graph with edges pointing towards leaves from the query node. 
Belief propagation or message passing is ordered depth first.   
Elimination is performed by message-passing, or Belief Propagation, along the branches from the leaves to the query node. 
```

> **Note:** Belief propagation uses the tree graph itself as the representation data structure. 
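The recursive structure of the message equation can be sketched directly in Python. The sketch assumes a tree consistent with the messages listed in the next subsection (edges 1-2, 2-3, 3-4, 3-5), binary variables, and hypothetical node potentials `node_psi` and edge potentials `edge_psi`. The recursion in `message` finishes at the leaves first, giving the depth-first ordering described above, and a brute-force check sums over all joint states.

```python
import itertools

# Tree with edges 1-2, 2-3, 3-4, 3-5; all variables binary, potential values hypothetical.
K = 2
edges = [(1, 2), (2, 3), (3, 4), (3, 5)]
node_psi = {1: [1.0, 1.0], 2: [0.5, 1.5], 3: [1.0, 2.0], 4: [3.0, 1.0], 5: [1.0, 1.0]}
edge_psi = {                       # edge_psi[(a, b)][x_a][x_b] = psi(x_a, x_b)
    (1, 2): [[2.0, 1.0], [1.0, 2.0]],
    (2, 3): [[1.0, 3.0], [2.0, 1.0]],
    (3, 4): [[4.0, 1.0], [1.0, 1.0]],
    (3, 5): [[1.0, 2.0], [3.0, 1.0]],
}
neighbors = {i: [] for i in node_psi}
for a, b in edges:
    neighbors[a].append(b)
    neighbors[b].append(a)


def pair_psi(i, j, xi, xj):
    """psi(x_i, x_j), whichever orientation the edge was stored in."""
    if (i, j) in edge_psi:
        return edge_psi[(i, j)][xi][xj]
    return edge_psi[(j, i)][xj][xi]


def message(j, i):
    """m_ji(x_i) = sum over x_j of psi(x_j) psi(x_i, x_j) times the messages
    into j from every neighbor except i. The recursion reaches the leaves first."""
    incoming = [message(f, j) for f in neighbors[j] if f != i]
    out = []
    for xi in range(K):
        total = 0.0
        for xj in range(K):
            term = node_psi[j][xj] * pair_psi(i, j, xi, xj)
            for m in incoming:
                term *= m[xj]
            total += term
        out.append(total)
    return out


def query_marginal(root):
    """Treat the query node as the root and multiply in the messages from its neighbors."""
    belief = list(node_psi[root])
    for j in neighbors[root]:
        belief = [b * m for b, m in zip(belief, message(j, root))]
    z = sum(belief)
    return [b / z for b in belief]


def brute_force_marginal(root):
    """Check: sum the product of all potentials over every joint state."""
    nodes = sorted(node_psi)
    marg = [0.0] * K
    for values in itertools.product(range(K), repeat=len(nodes)):
        x = dict(zip(nodes, values))
        p = 1.0
        for i in nodes:
            p *= node_psi[i][x[i]]
        for a, b in edges:
            p *= edge_psi[(a, b)][x[a]][x[b]]
        marg[x[root]] += p
    z = sum(marg)
    return [m / z for m in marg]


print(query_marginal(1), brute_force_marginal(1))
```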

### Computing marginals with message passing

Now that we have looked at the basics of the message passing or belief propagation algorithm, we will generalize the method to efficiently compute the marginal distributions of all the nodes of a tree graph. The naive approach would be to perform a separate query on each node in the graph. While this approach would work in principle, it is computationally inefficient. We will explore an algorithm that efficiently computes all the marginal distributions on the graph by reusing messages. 

A key fact to notice is that a node **sends a message to a neighbor only once it has received messages from all of its other neighbors**. For example, in the figure above, the messages must be passed in the following order:

1. $m_{53}(x_3)$ and $m_{43}(x_3)$,
2. $m_{32}(x_2)$, and
3. $m_{21}(x_1)$.

If we were to continue with simple message passing to compute the marginal distributions of all the variables on the graph, we would recompute the same messages several times. In fact, the computational cost of this approach is $NC$, where $N$ is the number of nodes and $C$ is the cost of a single message-passing query on the tree. 

Keeping in mind that the ordering requirement guides the construction of efficient algorithms, we explore a method to extend the message passing algorithm. 

Developing an efficient method for computing the marginal distributions on a tree leads us to the **two pass algorithm**, named for its two message-passing passes, **collect** and **distribute**. The algorithm proceeds in four steps:
1. The **conditional probability distributions** (**CPDs**), or potentials, are updated using the **evidence**.  
2. In the collection phase, the nodes **collect** messages from their neighbors, which **emit** messages to the node. Leaf nodes emit their messages at the start of the collection step. 
3. Once a node has collected messages from its neighbors, it **distributes** messages to its neighbors. In the distribution phase, a node may emit a message to a neighbor only once it has received messages from all of its other neighbors. 
4. The marginal distributions are computed. 

A schematic view of this algorithm is illustrated in the figure below. 

<img src="img/SumProductTree.JPG" alt="Drawing" style="width:450px; height:400px"/>
<center> **Two pass algorithm on a tree** </center>

In the evidence step, the potentials of the evidence nodes are updated using the following relationship, where $\bar{x}_i$ is the observed value of $x_i$:

$$\psi^E(x_i) = \psi(x_i)\ \delta(x_i, \bar{x}_i)\\
\text{where}\\
\delta(x_i, \bar{x}_i) = 1\ \text{if}\ x_i = \bar{x}_i\\
\delta(x_i, \bar{x}_i) = 0\ \text{otherwise}$$

For nodes without evidence, $\psi^E(x_i) = \psi(x_i)$.

Messages for both the collection and distribution steps are computed with the same relationship as before, now using the evidence-updated potentials:

$$m_{ji}(x_i) = \sum_{x_j} \Big(\ \psi^E(x_j)\ \psi(x_i,x_j)\ \prod_{f \in N(j) \backslash i} m_{fj}(x_j) \Big)$$

The collection phase of the algorithm is illustrated in the figure below. 

<img src="img/Collect.JPG" alt="Drawing" style="width:450px; height:400px"/>
<center> **Collect phase of algorithm on a tree** </center>

The distribute phase of the algorithm is shown in the figure below.

<img src="img/Distribute.JPG" alt="Drawing" style="width:450px; height:400px"/>
<center> **Distribute phase of the algorithm on a tree** </center>

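Finally, the marginal distribution of each node is computed by multiplying its evidence potential by all of its incoming messages and normalizing, where $\bar{x}_E$ denotes the observed evidence:

$$p(x_i\ |\ \bar{x}_E) \propto \psi^E(x_i)\ \prod_{j \in N(i)}m_{ji}(x_i)$$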
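The full two pass algorithm, including the evidence step, can be sketched in Python. The tree structure (edges 1-2, 2-3, 3-4, 3-5) is assumed to match the one implied by the message ordering discussed earlier; the potentials and the observation $x_4 = 1$ are hypothetical. The function `send` computes each message once and caches it, so the procedure computes exactly $2(N-1)$ messages for a tree with $N$ nodes instead of repeating a full query per node.

```python
import itertools

# Tree implied by the messages m_53, m_43, m_32, m_21; potentials are hypothetical.
K = 2
edges = [(1, 2), (2, 3), (3, 4), (3, 5)]
node_psi = {1: [1.0, 1.0], 2: [0.5, 1.5], 3: [1.0, 2.0], 4: [3.0, 1.0], 5: [1.0, 1.0]}
edge_psi = {                           # edge_psi[(a, b)][x_a][x_b] = psi(x_a, x_b)
    (1, 2): [[2.0, 1.0], [1.0, 2.0]],
    (2, 3): [[1.0, 3.0], [2.0, 1.0]],
    (3, 4): [[4.0, 1.0], [1.0, 1.0]],
    (3, 5): [[1.0, 2.0], [3.0, 1.0]],
}
evidence = {4: 1}                      # hypothetical observation: x_4 = 1
neighbors = {i: [] for i in node_psi}
for a, b in edges:
    neighbors[a].append(b)
    neighbors[b].append(a)

# Evidence step: psi^E(x_i) = psi(x_i) * delta(x_i, observed value)
psi_e = {i: [p if i not in evidence or x == evidence[i] else 0.0
             for x, p in enumerate(node_psi[i])] for i in node_psi}
messages = {}                          # messages[(j, i)] = m_ji(x_i), each computed once


def pair_psi(i, j, xi, xj):
    if (i, j) in edge_psi:
        return edge_psi[(i, j)][xi][xj]
    return edge_psi[(j, i)][xj][xi]


def send(j, i):
    """Store m_ji(x_i); needs the messages into j from every neighbor except i."""
    incoming = [messages[(f, j)] for f in neighbors[j] if f != i]
    out = []
    for xi in range(K):
        total = 0.0
        for xj in range(K):
            term = psi_e[j][xj] * pair_psi(i, j, xi, xj)
            for m in incoming:
                term *= m[xj]
            total += term
        out.append(total)
    messages[(j, i)] = out


def two_pass(root):
    # Walk the tree from the root; every node is appended to order after its parent.
    parent, order, stack = {root: None}, [root], [root]
    while stack:
        node = stack.pop()
        for nb in neighbors[node]:
            if nb not in parent:
                parent[nb] = node
                order.append(nb)
                stack.append(nb)
    for node in reversed(order[1:]):   # collect: from the leaves toward the root
        send(node, parent[node])
    for node in order[1:]:             # distribute: from the root back to the leaves
        send(parent[node], node)


def marginals():
    """p(x_i | evidence), proportional to psi^E(x_i) times all incoming messages."""
    result = {}
    for i in node_psi:
        belief = list(psi_e[i])
        for j in neighbors[i]:
            belief = [b * m for b, m in zip(belief, messages[(j, i)])]
        z = sum(belief)
        result[i] = [b / z for b in belief]
    return result


def brute_force(i):
    """Check: enumerate all joint states, weighting by psi^E and the edge potentials."""
    nodes = sorted(node_psi)
    marg = [0.0] * K
    for values in itertools.product(range(K), repeat=len(nodes)):
        x = dict(zip(nodes, values))
        p = 1.0
        for n in nodes:
            p *= psi_e[n][x[n]]
        for a, b in edges:
            p *= edge_psi[(a, b)][x[a]][x[b]]
        marg[x[i]] += p
    return [m / sum(marg) for m in marg]


two_pass(root=1)
print(len(messages), marginals())
```

Any node could serve as the root; the choice only changes the order in which messages are scheduled, not the resulting marginals.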

## Introduction to factors and the junction tree algorithm

In the previous section we examined the message passing or belief propagation algorithm. This algorithm can efficiently compute the marginal distributions of variables in a tree graph. 

There is another way to create an efficient message passing algorithm, known as the **junction tree algorithm**. The junction tree algorithm can be applied to graphs that do not initially have a tree structure. Thus, the junction tree algorithm is not only efficient, but quite flexible. 

The general steps of the junction tree algorithm are:
1. Moralize the graph, following the process we have already applied.
2. Triangulate the graph, adding edges so that every cycle of four or more nodes has a chord; a triangulated graph can be organized as a tree of cliques. 
3. Build a clique tree from the transformed graph. The clique tree is composed of clique variables and factors. 
4. Propagate the probabilities by local message passing, both from variables to factors and from factors to variables. 



<img src="img/VarToFactor.JPG" alt="Drawing" style="width:350px; height:200px"/>
<center> **Message from a variable to a factor** </center>


<img src="img/FactorToVar.JPG" alt="Drawing" style="width:350px; height:200px"/>
<center> **Message from a factor to a variable** </center>
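The two message types, variable to factor and factor to variable, can be illustrated on a small **factor graph**, where variables and factors are separate nodes. The sketch below is the factor-graph form of sum-product on a tree-structured factor graph, not a full junction tree implementation; the graph, the factor names `h`, `f`, `g`, and their values are hypothetical, and all variables are assumed binary.

```python
import itertools

DOMAIN = (0, 1)   # assumption: binary variables
# Hypothetical tree-structured factor graph:  h(a) - a - f(a, b) - b - g(b, c) - c
factors = {
    "h": (("a",), {(0,): 0.3, (1,): 0.7}),
    "f": (("a", "b"), {(0, 0): 2.0, (0, 1): 1.0, (1, 0): 1.0, (1, 1): 3.0}),
    "g": (("b", "c"), {(0, 0): 1.0, (0, 1): 4.0, (1, 0): 2.0, (1, 1): 1.0}),
}


def factors_of(var):
    return [name for name, (scope, _) in factors.items() if var in scope]


def var_to_factor(var, fac):
    """Variable -> factor: product of the messages from the variable's other factors."""
    msg = [1.0 for _ in DOMAIN]
    for other in factors_of(var):
        if other != fac:
            msg = [m * i for m, i in zip(msg, factor_to_var(other, var))]
    return msg


def factor_to_var(fac, var):
    """Factor -> variable: sum over the factor's other variables of the factor
    times the messages arriving from those variables."""
    scope, table = factors[fac]
    incoming = {v: var_to_factor(v, fac) for v in scope if v != var}
    msg = [0.0 for _ in DOMAIN]
    for values in itertools.product(DOMAIN, repeat=len(scope)):
        term = table[values]
        for v, x in zip(scope, values):
            if v != var:
                term *= incoming[v][x]
        msg[values[scope.index(var)]] += term
    return msg


def marginal(var):
    """Normalized product of all factor -> variable messages into var."""
    belief = [1.0 for _ in DOMAIN]
    for fac in factors_of(var):
        belief = [b * m for b, m in zip(belief, factor_to_var(fac, var))]
    z = sum(belief)
    return [b / z for b in belief]


print(marginal("a"), marginal("c"))
```

For this example, summing out $a$ gives $(1.3, 2.4)$ over $b$, and then $(6.1, 7.6)$ over $c$, so $P(c = 0) = 6.1 / 13.7$.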



## Junction tree for multiply connected graph

Let's work through a more complex example for a multiply connected graph. An example is shown in the figure below.

<img src="img/MultiConnected.JPG" alt="Drawing" style="width:350px; height:275px"/>
<center> **A multiply connected graph** </center>

This graph is clearly not a tree. There are multiple paths between many nodes. The question is, how can we transform the graph so that we can perform message passing?

### Moralization

The graph illustrated above has several immoralities. However, the moralization process for this graph is straightforward. The result is shown in the figure below.  

<img src="img/MultiConnectedMoralized.JPG" alt="Drawing" style="width:350px; height:275px"/>
<center> **Moralized multiply connected graph** </center>
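Moralization is mechanical: connect ("marry") every pair of parents that share a child, then drop the edge directions. A minimal Python sketch, applied to a small hypothetical DAG rather than the graph in the figure, whose exact edges are not reproduced here:

```python
from itertools import combinations


def moralize(parents):
    """Undirected edges of the moral graph of a DAG.

    parents maps each node to the list of its parents.
    """
    edges = set()
    for child, pars in parents.items():
        for p in pars:
            edges.add(frozenset((p, child)))     # keep each edge, dropping its direction
        for p, q in combinations(pars, 2):
            edges.add(frozenset((p, q)))         # marry parents that share this child
    return edges


# Hypothetical DAG (not the graph in the figure): A -> C <- B and B -> D <- C
dag = {"A": [], "B": [], "C": ["A", "B"], "D": ["B", "C"]}
moral = moralize(dag)
print(sorted(tuple(sorted(e)) for e in moral))
```

The immorality A -> C <- B gains the edge A-B, while the parents B and C of D are already adjacent, so no new edge is needed for D.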

There is a problem with this graph. There are several **cycles** with four variables. 

### Triangulation 

With the graph moralized, you can see that there are cycles with four or more variables. The question at this point is whether the cliques created in these cycles lead to consistent results. To pursue this question, consider the graph and the resulting clique tree shown in the figure below.

<img src="img/FourCycle.JPG" alt="Drawing" style="width:350px; height:150px"/>
<center> **A four cycle graph with the corresponding clique tree** </center>

Notice that there is no way to ensure that the probability associated with variable 3 is consistent between the two branches of the tree. There is no guarantee of **global consistency** when the graph has cycles of four or more variables without a chord (an edge joining two non-adjacent nodes of the cycle). 

To solve the above problem we need to perform a procedure known as **triangulation**. The triangulation procedure adds an edge, called a **chord**, to the four-cycle as shown in the figure below.  

<img src="img/Triangle1.JPG" alt="Drawing" style="width:350px; height:300px"/>
<center> **A triangulated four-cycle graph with the corresponding clique trees** </center>

The triangulated graph leads to a globally consistent clique tree. Notice that there are two possible triangulations, which lead to two different clique trees. While the clique trees are globally consistent, they are not unique. 
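One common way to triangulate reuses the elimination idea from the start of this lesson: eliminate the nodes in some order, and whenever a node is eliminated, connect all of its remaining neighbors. The added fill-in edges are chords, and the sets formed by each node plus its remaining neighbors include all the maximal cliques of the triangulated graph. A minimal sketch, assuming a hypothetical four-cycle labeled A through D (not the labels used in the figure):

```python
from itertools import combinations


def triangulate(nodes, edges, order):
    """Triangulate an undirected graph by simulated elimination.

    Eliminating a node connects all of its not-yet-eliminated neighbors; the
    added fill-in edges are chords. Returns (fill_edges, maximal_cliques).
    """
    adj = {v: set() for v in nodes}
    for a, b in edges:
        adj[a].add(b)
        adj[b].add(a)
    fill, cliques, eliminated = set(), [], set()
    for v in order:
        nbrs = adj[v] - eliminated
        for a, b in combinations(sorted(nbrs), 2):
            if b not in adj[a]:          # missing edge: add a fill-in chord
                adj[a].add(b)
                adj[b].add(a)
                fill.add((a, b))
        clique = frozenset(nbrs | {v})
        if not any(clique <= c for c in cliques):   # keep only maximal cliques
            cliques.append(clique)
        eliminated.add(v)
    return fill, cliques


# Hypothetical four-cycle A - B - C - D - A.
nodes = ["A", "B", "C", "D"]
cycle = [("A", "B"), ("B", "C"), ("C", "D"), ("D", "A")]
for order in (["A", "B", "C", "D"], ["B", "A", "C", "D"]):
    fill, cliques = triangulate(nodes, cycle, order)
    print(order, sorted(fill), [sorted(c) for c in cliques])
```

With the order A, B, C, D the chord is B-D and the cliques are {A, B, D} and {B, C, D}, which share the separator {B, D}; starting with B instead adds the chord A-C. The two orders give the two different triangulations and clique trees.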

<img src="img/Triangle2.JPG" alt="Drawing" style="width:350px; height:275px"/>
<center> **A triangulated multiply connected graph** </center>


#### Copyright 2018, Stephen F Elston. All rights reserved.