<a href="https://colab.research.google.com/github/DavideScassola/PML2024/blob/main/Notebooks/04_exact_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exact inference with Belief Propagation

This notebook is inspired by [Jessica Stringham's work](https://jessicastringham.net).

We are going to perform inference with sum-product message passing, also known as belief propagation, on tree-structured factor graphs (graphs without loops). We work only with discrete distributions and avoid ad-hoc libraries, to better understand the algorithm.

In [1]:
import numpy as np

### Probability distributions

First of all, we need to represent a discrete probability distribution and check that it is normalized.
For example, we can represent a discrete conditional distribution $p(v_1 | h_1)$ with a 2D array, as:

|   | $h_1=a$ | $h_1=b$ | $h_1=c$|
|---|-----|-----|----|
| $v_1=0$ | 0.4  | 0.8  | 0.9|
| $v_1=1$ | 0.6 | 0.2  | 0.1|

We can build a class for distributions that stores the array together with the labels of its axes.


In [2]:
class Distribution():
    """
    Discrete probability distributions, expressed using labeled arrays
    probs: array of probability values
    axes_labels: list of axes names
    """
    def __init__(self, probs, axes_labels):
        self.probs = probs
        self.axes_labels = axes_labels

    def get_axes(self):
        # Returns a dictionary mapping each axis name to its axis index
        return {name: axis for axis, name in enumerate(self.axes_labels)}
    
    def get_other_axes_from(self, axis_label):
        # Returns a tuple with the indices of all axes except the one named axis_label
        return tuple(axis for axis, name in enumerate(self.axes_labels) if name != axis_label)
    
    def is_valid_conditional(self, variable_name):
        #variable_name is the name of the variable for which we are computing the distribution, e.g. in p(y|x) it is 'y'
        return np.all(np.isclose(np.sum(self.probs, axis=self.get_axes()[variable_name]), 1.0))
    
    def is_valid_joint(self):
        return np.all(np.isclose(np.sum(self.probs), 1.0))

In [3]:
#Let's see the previous distribution:

p_v1_given_h1 = Distribution(np.array([[0.4, 0.8, 0.9], [0.6, 0.2, 0.1]]), ['v1', 'h1'])

print('Is p(v1|h1) a valid conditional distribution? ', p_v1_given_h1.is_valid_conditional('v1'))
print('Is p(v1|h1) a valid joint distribution? ', p_v1_given_h1.is_valid_joint())

#Consider also a marginal distribution and a conditional distribution with more than one 'given' variable

p_h1 = Distribution(np.array([0.6, 0.3, 0.1]), ['h1'])

print('Is p(h1) a valid conditional distribution? ', p_h1.is_valid_conditional('h1'))
print('Is p(h1) a valid joint distribution? ', p_h1.is_valid_joint())

p_v1_given_h0_h1 = Distribution(np.array([[[0.9, 0.2, 0.7], [0.3, 0.2, 0.5]],[[0.1, 0.8, 0.3], [0.7, 0.8, 0.5]]]), ['v1', 'h0', 'h1'])
print('Is p(v1|h0, h1) a valid conditional distribution? ', p_v1_given_h0_h1.is_valid_conditional('v1'))
print('Is p(v1|h0, h1) a valid joint distribution? ', p_v1_given_h0_h1.is_valid_joint())

Is p(v1|h1) a valid conditional distribution?  True
Is p(v1|h1) a valid joint distribution?  False
Is p(h1) a valid conditional distribution?  True
Is p(h1) a valid joint distribution?  True
Is p(v1|h0, h1) a valid conditional distribution?  True
Is p(v1|h0, h1) a valid joint distribution?  False


We need to support products of distributions such as $p(v_1|h_1,...,h_n) p(h_i)$, where $p(h_i)$ is a 1D array.
To do so, we can exploit broadcasting. But first, we need to reshape $p(h_i)$ so that its only non-singleton axis lines up with the $h_i$ axis of $p(v_1|h_1,...,h_n)$.
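
As a minimal sketch of the reshaping step in plain NumPy (the names `cond`, `prior_h1`, and `h1_axis` are illustrative and not part of this notebook's classes), the 1D array is reshaped to `(1, 1, 3)` so that it broadcasts along the $h_1$ axis only:

```python
import numpy as np

# p(v1 | h0, h1) with axes ordered (v1, h0, h1), so the shape is (2, 2, 3)
cond = np.array([[[0.9, 0.2, 0.7], [0.3, 0.2, 0.5]],
                 [[0.1, 0.8, 0.3], [0.7, 0.8, 0.5]]])
prior_h1 = np.array([0.6, 0.3, 0.1])  # p(h1), shape (3,)

h1_axis = 2  # position of h1 among the axes of cond
target_shape = [1] * cond.ndim
target_shape[h1_axis] = cond.shape[h1_axis]  # -> [1, 1, 3]

# Broadcasting stretches the singleton axes, giving p(v1, h1 | h0)
product = cond * prior_h1.reshape(target_shape)
print(product.shape)
```

Here $h_1$ happens to be the last axis, so NumPy would broadcast correctly even without the reshape; the explicit reshape is what makes the product work when the matching axis is in any other position.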

In [4]:
def multiply(p_v_given_h: Distribution, p_hi: Distribution) -> Distribution:
    ''' 
    Compute the product of the distributions p(v|h1,..,hn)p(hi) where p(hi) is a 1D array
    '''
    #Get the axis corresponding to hi in the conditional distribution
    axis=p_v_given_h.get_axes()[next(iter(p_hi.get_axes()))]
    
    # Reshape p(hi) in order to exploit broadcasting. Consider also the case in which p(hi) is a scalar.
    dims=np.ones_like(p_v_given_h.probs.shape)
    dims[axis] = p_v_given_h.probs.shape[axis]

    if (p_hi.probs.shape != () ):
        reshaped_p_hi = p_hi.probs.reshape(dims)
    else:
        reshaped_p_hi = p_hi.probs

        
    return Distribution(p_v_given_h.probs*reshaped_p_hi, p_v_given_h.axes_labels)

In [5]:
p_v1_h1 = multiply(p_v1_given_h1, p_h1)
print(p_v1_h1.probs)
print(p_v1_h1.is_valid_joint())

p_v1_h1_given_h0 = multiply(p_v1_given_h0_h1, p_h1)
print(p_v1_h1_given_h0.probs)

[[0.24 0.24 0.09]
 [0.36 0.06 0.01]]
True
[[[0.54 0.06 0.07]
  [0.18 0.06 0.05]]

 [[0.06 0.24 0.03]
  [0.42 0.24 0.05]]]


### Factor graphs

Factor graphs are bipartite graphs, with variable nodes and factor nodes. Edges can only connect nodes of different type. Consider for example:

![factor_ex](imgs/factor_example.png)



In [6]:
class Node(object):
    def __init__(self, name):
        self.name = name
        self.neighbors = []

    def is_valid_neighbor(self, neighbor):
        raise NotImplementedError()

    def add_neighbor(self, neighbor):
        assert self.is_valid_neighbor(neighbor)
        self.neighbors.append(neighbor)


class Variable(Node):
    def is_valid_neighbor(self, factor):
        return isinstance(factor, Factor)  # Variables can only neighbor Factors


class Factor(Node):
    def is_valid_neighbor(self, variable):
        return isinstance(variable, Variable)  # Factors can only neighbor Variables

    def __init__(self, name):
        super(Factor, self).__init__(name)
        self.distribution = None

We can write some parsing helpers to create a factor graph from a string representing the factorization of the joint probability distribution.

In [7]:
from collections import namedtuple
        
ParsedTerm = namedtuple('ParsedTerm', [
    'term',
    'var_name',
    'given',
])


def _parse_term(term):
    # Given a term like (a|b,c), returns the variable name ('a')
    # and the list of conditioned-on variables (['b', 'c'])
    assert term[0] == '(' and term[-1] == ')'
    term_variables = term[1:-1]

    # Handle conditionals
    if '|' in term_variables:
        var, given = term_variables.split('|')
        given = given.split(',')
    else:
        var = term_variables
        given = []

    return var, given


def _parse_model_string_into_terms(model_string):
    return [
        ParsedTerm('p' + term, *_parse_term(term))
        for term in model_string.split('p')
        if term
    ]

def parse_model_into_variables_and_factors(model_string):
    # Takes in a model_string such as p(h1)p(h2|h1)p(v1|h1)p(v2|h2) and returns a
    # list of factors and a dictionary mapping variable names to variables.
    
    # Split model_string into ParsedTerms
    parsed_terms = _parse_model_string_into_terms(model_string)
    
    # First, extract all of the variables from the model_string (h1, h2, v1, v2). 
    # These each will be a new Variable that are referenced from Factors below.
    variables = {}
    for parsed_term in parsed_terms:
        # if the variable name wasn't seen yet, add it to the variables dict
        if parsed_term.var_name not in variables:
            variables[parsed_term.var_name] = Variable(parsed_term.var_name)

    # Now extract factors from the model. Each term (e.g. "p(v1|h1)") corresponds to 
    # a factor. 
    # Then find all variables in this term ("v1", "h1") and add the corresponding Variables
    # as neighbors to the new Factor, and this Factor to the Variables' neighbors.
    factors = []
    for parsed_term in parsed_terms:
        # This factor will be neighbors with all "variables" (left-hand side variables) and given variables
        new_factor = Factor(parsed_term.term)
        all_var_names = [parsed_term.var_name] + parsed_term.given
        for var_name in all_var_names:
            new_factor.add_neighbor(variables[var_name])
            variables[var_name].add_neighbor(new_factor)
        factors.append(new_factor)

    return factors, variables
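
Note that splitting the model string on the character `p` assumes that no variable name contains a `p`. As a hedged alternative, a regular expression can extract the terms directly. The helper name `parse_terms` below is hypothetical and is not used elsewhere in this notebook:

```python
import re

# Matches p(var) or p(var|given1,given2,...)
TERM_PATTERN = re.compile(r'p\(([^)|]+)(?:\|([^)]*))?\)')


def parse_terms(model_string):
    """Return a list of (variable, given_variables) pairs, one per term."""
    terms = []
    for var, given in TERM_PATTERN.findall(model_string):
        # `given` is an empty string when the term has no conditioning bar
        terms.append((var, given.split(',') if given else []))
    return terms


print(parse_terms("p(h1)p(h2|h1)p(pos|h1,h2)"))
```

This sketch only covers the term syntax used in this notebook (one variable on the left, comma-separated variables on the right).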

We can combine factor nodes and variable nodes to create a factor graph and add a distribution to each factor node.

In [8]:
class PGM(object):
    def __init__(self, factors, variables):
        self._factors = factors
        self._variables = variables

    @classmethod
    def from_string(cls, model_string):
        factors, variables = parse_model_into_variables_and_factors(model_string)
        return PGM(factors, variables)

    def set_distributions(self, distributions):
        var_dims = {}
        for factor in self._factors:
            factor_data = distributions[factor.name]

            if set(factor_data.axes_labels) != set(v.name for v in factor.neighbors):
                missing_axes = set(v.name for v in factor.neighbors) - set(distributions[factor.name].axes_labels)
                raise ValueError("data[{}] is missing axes: {}".format(factor.name, missing_axes))
                
            for var_name, dim in zip(factor_data.axes_labels, factor_data.probs.shape):
                if var_name not in var_dims:
                    var_dims[var_name] = dim
    
                if var_dims[var_name] != dim:
                    raise ValueError("data[{}] axes is wrong size, {}. Expected {}".format(factor.name, dim, var_dims[var_name]))            
                    
            factor.distribution = distributions[factor.name]
            
    def variable_from_name(self, var_name):
        return self._variables[var_name]

Notice that, for the factor graph example shown earlier, the marginal can be written as a combination of sums and products:

$$
\begin{aligned}
p(x_5) &= \sum_{x_1, x_2, x_3, x_4}p(x_1, x_2, x_3, x_4, x_5) \\
&= \sum_{x_3, x_4}f_3(x_3,x_4,x_5)\bigg[\sum_{x_1}f_1(x_1, x_3)\bigg]\bigg[\sum_{x_2}f_2(x_2, x_3)\bigg]
\end{aligned}
$$

We can interpret these terms as messages flowing from factors to variables (which involve a summation) or from variables to factors (which involve a multiplication).
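
The factorization above can be checked numerically. The following is a minimal sketch that assumes arbitrary nonnegative tables for $f_1$, $f_2$, $f_3$ (random values and cardinalities chosen purely for illustration) and compares the brute-force sum over the full product with the nested sums:

```python
import numpy as np

rng = np.random.default_rng(0)
f1 = rng.random((2, 3))     # f1(x1, x3)
f2 = rng.random((4, 3))     # f2(x2, x3)
f3 = rng.random((3, 2, 5))  # f3(x3, x4, x5)

# Brute force: build the full table over (x1, x2, x3, x4, x5), then sum out x1..x4
full = np.einsum('ac,bc,cde->abcde', f1, f2, f3)
brute = full.sum(axis=(0, 1, 2, 3))

# Message passing: each bracket in the formula is a factor-to-variable message
msg_f1_to_x3 = f1.sum(axis=0)               # sum over x1
msg_f2_to_x3 = f2.sum(axis=0)               # sum over x2
msg_x3_to_f3 = msg_f1_to_x3 * msg_f2_to_x3  # product of incoming messages at x3
# x4 is a leaf variable, so its message to f3 is 1 and it is simply summed out
msg_f3_to_x5 = np.einsum('cde,c->e', f3, msg_x3_to_f3)

print(np.allclose(brute, msg_f3_to_x5))
```

Because the tables are random rather than proper probability tables, both results are unnormalized; dividing each by its sum gives the same distribution over $x_5$.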

In [12]:
class Messages(object):
    def __init__(self):
        self.messages = {}
        
    def _variable_to_factor_messages(self, variable: Variable, factor: Factor):
        # Take the product of the messages coming into this variable from all neighboring factors except the target factor
        incoming_messages = [
            self.factor_to_variable_message(neighbor_factor, variable)
            for neighbor_factor in variable.neighbors
            if neighbor_factor.name != factor.name
        ]

        # If there are no incoming messages (leaf variable), np.prod of an empty list is 1.0 (BASE CASE)
        return np.prod(incoming_messages, axis=0)
    
    def _factor_to_variable_messages(self, factor, variable: Variable):
        # Wrap the factor's table in a new Distribution; multiply() returns new arrays, so the factor's own table is never modified
        factor_dist = Distribution(factor.distribution.probs, factor.distribution.axes_labels)

        for neighbor_variable in factor.neighbors:
            if neighbor_variable.name == variable.name:
                continue
            #Retrieve the incoming message and multiply the conditional distribution of the factor with the message
            incoming_message = self.variable_to_factor_messages(neighbor_variable, factor)
            factor_dist = multiply(factor_dist, Distribution(incoming_message, [neighbor_variable.name]))

        # Sum over the axes that aren't `variable`
        factor_dist = factor_dist.probs
        other_axes = factor.distribution.get_other_axes_from(variable.name)
        return np.squeeze(np.sum(factor_dist, axis=other_axes))
    
    
    def marginal(self, variable):
        # p(variable) is proportional to the product of incoming messages to variable.
        unnorm_p = np.prod([
            self.factor_to_variable_message(neighbor_factor, variable)
            for neighbor_factor in variable.neighbors
        ], axis=0)

        # At this point, we can normalize this distribution
        return unnorm_p/np.sum(unnorm_p)
    
    def variable_to_factor_messages(self, variable, factor):
        message_name = (variable.name, factor.name)
        if message_name not in self.messages:
            self.messages[message_name] = self._variable_to_factor_messages(variable, factor)
        return self.messages[message_name]
        
    def factor_to_variable_message(self, factor, variable):
        message_name = (factor.name, variable.name)
        if message_name not in self.messages:
            self.messages[message_name] = self._factor_to_variable_messages(factor, variable)
        return self.messages[message_name]  

We can try to build the following factor graph:

![factor1](imgs/factor2.png)

In [10]:
p_h1 = Distribution(np.array([0.2, 0.8]), ['h1'])
p_h2_given_h1 = Distribution(np.array([[0.5, 0.2], [0.5, 0.8]]), ['h2', 'h1'])
p_v1_given_h1 = Distribution(np.array([[0.6, 0.1], [0.4, 0.9]]), ['v1', 'h1'])
p_v2_given_h2 = Distribution(p_v1_given_h1.probs, ['v2', 'h2'])

pgm = PGM.from_string("p(h1)p(h2|h1)p(v1|h1)p(v2|h2)")

pgm.set_distributions({
    "p(h1)": p_h1,
    "p(h2|h1)": p_h2_given_h1,
    "p(v1|h1)": p_v1_given_h1,
    "p(v2|h2)": p_v2_given_h2,
})

We can then compute the marginal distribution $p(v_2)$:

In [13]:
pgm = PGM.from_string("p(h1)p(h2|h1)p(v1|h1)p(v2|h2)")

pgm.set_distributions({
    "p(h1)": p_h1,
    "p(h2|h1)": p_h2_given_h1,
    "p(v1|h1)": p_v1_given_h1,
    "p(v2|h2)": p_v2_given_h2,
})

m = Messages()
m.marginal(pgm.variable_from_name('v2'))

array([0.23, 0.77])

In [14]:
m.messages

{('p(h1)', 'h1'): array([0.2, 0.8]),
 ('v1', 'p(v1|h1)'): 1.0,
 ('p(v1|h1)', 'h1'): array([1., 1.]),
 ('h1', 'p(h2|h1)'): array([0.2, 0.8]),
 ('p(h2|h1)', 'h2'): array([0.26, 0.74]),
 ('h2', 'p(v2|h2)'): array([0.26, 0.74]),
 ('p(v2|h2)', 'v2'): array([0.23, 0.77])}

In [15]:
m.marginal(pgm.variable_from_name('v1'))

array([0.2, 0.8])
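
As a sanity check, the same marginals can be obtained by brute force, building the full joint table and summing out the other variables. This is only a sketch for a graph this small (the joint table grows exponentially with the number of variables), and the plain arrays below restate the tables defined above without using the notebook's classes:

```python
import numpy as np

p_h1 = np.array([0.2, 0.8])
p_h2_given_h1 = np.array([[0.5, 0.2], [0.5, 0.8]])  # axes (h2, h1)
p_v_given_h = np.array([[0.6, 0.1], [0.4, 0.9]])    # axes (v, h), shared by v1 and v2

# Joint p(h1, h2, v1, v2) with axes (a, b, c, d) = (h1, h2, v1, v2)
joint = np.einsum('a,ba,ca,db->abcd',
                  p_h1, p_h2_given_h1, p_v_given_h, p_v_given_h)

p_v1 = joint.sum(axis=(0, 1, 3))  # sum out h1, h2, v2
p_v2 = joint.sum(axis=(0, 1, 2))  # sum out h1, h2, v1
print(p_v1, p_v2)
```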

### Exercise 1

(From Bayesian Reasoning and Machine Learning, David Barber) You live in a house with three rooms, labelled 1, 2, 3. There is a door between rooms 1 and 2 and another between rooms 2 and 3. One cannot directly pass between rooms 1 and 3 in one time-step. An annoying fly is buzzing from one room to another and there is some smelly cheese in room 1 which seems to attract the fly more. Using $x_t$ for which room the fly is in at time $t$, with $dom(x_t) = \{1,2,3\}$, the movement of the fly can be described by a transition:
$p(x_{t+1} = i|x_t = j) = M_{ij}$

where $M$ is the transition matrix:

$$
M = \begin{bmatrix}
0.7 & 0.5 & 0 \\
0.3 & 0.3 & 0.5 \\
0 & 0.2 & 0.5 \\
\end{bmatrix}
$$

Given that the fly is in room 1 at time 1, what is the probability of room occupancy at time t = 5? Assume a Markov chain which is defined by the joint distribution

$p(x_1, \dots, x_T) = p(x_1) \prod_{t=1}^{T-1} p(x_{t+1}|x_t)$

We are asked to compute $p(x_5|x_1 = 1)$, which is given by
$\sum_{x_2, x_3, x_4} p(x_5|x_4)p(x_4|x_3)p(x_3|x_2)p(x_2|x_1 = 1)$

In [17]:
pgm = PGM.from_string("p(x5|x4)p(x4|x3)p(x3|x2)p(x2|x1)p(x1)")

p_x5_given_x4 = Distribution(np.array([[0.7, 0.5, 0], [0.3, 0.3, 0.5], [0, 0.2, 0.5]]), ['x5', 'x4'])
p_x4_given_x3 = Distribution(p_x5_given_x4.probs, ['x4', 'x3'])
p_x3_given_x2 = Distribution(p_x5_given_x4.probs, ['x3', 'x2'])
p_x2_given_x1 = Distribution(p_x5_given_x4.probs, ['x2', 'x1'])
p_x1 = Distribution(np.array([1, 0, 0]), ['x1'])

pgm.set_distributions({
    "p(x5|x4)": p_x5_given_x4,
    "p(x4|x3)": p_x4_given_x3,
    "p(x3|x2)": p_x3_given_x2,
    "p(x2|x1)": p_x2_given_x1,
    "p(x1)": p_x1,
})

m2 = Messages()
m2.marginal(pgm.variable_from_name('x5'))

m2.messages

{('p(x1)', 'x1'): array([1, 0, 0]),
 ('x1', 'p(x2|x1)'): array([1, 0, 0]),
 ('p(x2|x1)', 'x2'): array([0.7, 0.3, 0. ]),
 ('x2', 'p(x3|x2)'): array([0.7, 0.3, 0. ]),
 ('p(x3|x2)', 'x3'): array([0.64, 0.3 , 0.06]),
 ('x3', 'p(x4|x3)'): array([0.64, 0.3 , 0.06]),
 ('p(x4|x3)', 'x4'): array([0.598, 0.312, 0.09 ]),
 ('x4', 'p(x5|x4)'): array([0.598, 0.312, 0.09 ]),
 ('p(x5|x4)', 'x5'): array([0.5746, 0.318 , 0.1074])}
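
For a Markov chain, the chain of messages above amounts to applying the transition matrix repeatedly to the initial distribution. A short cross-check in plain NumPy, independent of the `Messages` class:

```python
import numpy as np

# M[i, j] = p(x_{t+1} = i | x_t = j), so each column sums to 1
M = np.array([[0.7, 0.5, 0.0],
              [0.3, 0.3, 0.5],
              [0.0, 0.2, 0.5]])
p_x1 = np.array([1.0, 0.0, 0.0])  # the fly starts in room 1

# Four transitions take us from t = 1 to t = 5
p_x5 = np.linalg.matrix_power(M, 4) @ p_x1
print(p_x5)
```

The result should agree with the final message `('p(x5|x4)', 'x5')` shown above.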

### Exercise 2: Hidden Markov Models

Imagine you're trying to guess someone's mood without directly asking them or using brain electrodes. Instead, you observe their facial expressions, whether they're smiling or frowning, to make an educated guess.

We assume moods can be categorized into two states: good and bad. When you meet someone for the first time, there's a 70% chance they're in a good mood and a 30% chance they're in a bad mood.

If someone is in a good mood, there's an 80% chance they'll stay in a good mood and a 20% chance they'll switch to a bad mood over time. The same probabilities apply if they start in a bad mood: an 80% chance of staying in a bad mood and a 20% chance of switching to a good mood.

Lastly, when someone is in a good mood, they're 90% likely to smile and 10% likely to frown. Conversely, if they're in a bad mood, they have a 10% chance of smiling and a 90% chance of frowning.

The transitions are summarized in the following graph.

![factor1](imgs/mood.png)
(image by Y. Natsume)

Your task is to use these probabilities to infer the first and second hidden mood states (the probability that the first mood is good/bad and the probability that the second mood is good/bad) given the observed facial expressions (imagine you see the sequence [smiling, frowning]).
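
One way to check a belief-propagation solution to this exercise is brute-force enumeration over the two hidden states. The sketch below is not the intended message-passing solution; it assumes the index convention 0 = good/smiling and 1 = bad/frowning (a labeling chosen here), and it incorporates the evidence by keeping only the row of each emission table that matches the observation:

```python
import numpy as np

p_h1 = np.array([0.7, 0.3])                # p(h1): good, bad
p_h2_given_h1 = np.array([[0.8, 0.2],
                          [0.2, 0.8]])     # axes (h2, h1)
p_v_given_h = np.array([[0.9, 0.1],
                        [0.1, 0.9]])       # axes (v, h): smiling/frowning given mood

observed = [0, 1]  # [smiling, frowning]

# Clamp the observations: likelihood of each observation as a function of the mood
like_1 = p_v_given_h[observed[0]]  # p(v1 = smiling | h1)
like_2 = p_v_given_h[observed[1]]  # p(v2 = frowning | h2)

# Unnormalized p(h1, h2, v1 = smiling, v2 = frowning), axes (h1, h2)
joint = np.einsum('a,ba,a,b->ab', p_h1, p_h2_given_h1, like_1, like_2)
posterior = joint / joint.sum()

p_h1_post = posterior.sum(axis=1)  # p(h1 | v1, v2)
p_h2_post = posterior.sum(axis=0)  # p(h2 | v1, v2)
print(p_h1_post, p_h2_post)
```

The same clamping idea carries over to message passing: the message from an observed variable to its factor becomes an indicator vector for the observed value instead of 1.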