# Test 3 on Belief Propagation

## Imports and MyBeliefPropagation class

In [1]:
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.models import FactorGraph

from functools import reduce
import operator
from collections import defaultdict
from copy import deepcopy

In [2]:
def prod(iterable):
    '''Multiply together every item of *iterable*, starting from the int 1.

    An empty iterable yields 1. Items only need to support ``acc * item``
    (e.g. numbers or DiscreteFactor objects).'''
    accumulated = 1
    for item in iterable:
        # Plain `*` (not `*=`) so no item is ever mutated in place.
        accumulated = accumulated * item
    return accumulated


class MyBeliefPropagation:
    '''
    Sum-product belief propagation on a tree-structured pgmpy FactorGraph.

    Typical use: set_evidence(...) -> run_bp(root) -> get_marginal(...) /
    get_marginal_subset(...). Marginals are returned unnormalized; when the
    factors are the CPDs of a Bayesian network their total mass is the
    probability of the evidence.
    '''

    def __init__(self, factor_graph):
        assert factor_graph.check_model()
        self.original_graph = factor_graph
        self.variables = factor_graph.get_variable_nodes()

        # Union of the state names of every factor: variable -> list of states.
        self.state_names = dict()
        for f in self.original_graph.factors:
            self.state_names.update(f.state_names)

        # Filled in by set_evidence / run_bp. Defining them here lets run_bp
        # work without explicit evidence, and makes get_marginal fail with the
        # intended 'First run BP!' message instead of an AttributeError.
        self.working_graph = None
        self.bp_done = False

    def get_evidence_factors(self, evidence):
        '''
        Build one indicator factor per observed variable.

        evidence: dict mapping variable -> observed state.
        Returns a list of DiscreteFactor, one per evidence variable v, with
        value 1 for the observed state and 0 for every other state.
        Raises KeyError for an unknown variable and ValueError for a state
        that is not among the variable's state names.
        '''
        evidence_factors = []

        for variable, value in evidence.items():
            i = self.state_names[variable].index(value)
            values = [0] * len(self.state_names[variable])
            values[i] = 1.0
            ef = DiscreteFactor(
                variables=[variable],
                cardinality=[len(values)],
                values=values,
                state_names=self.state_names,
            )
            evidence_factors.append(ef)
        return evidence_factors

    def set_evidence(self, evidence):
        '''
        Build a fresh working graph: a copy of the original graph with one
        indicator factor per evidence variable attached to that variable.

        evidence: dict mapping variable -> observed state.
        Invalidates any previous BP run.
        '''
        evidence_factors = self.get_evidence_factors(evidence)
        self.working_graph = self.original_graph.copy()
        for f in evidence_factors:
            self.working_graph.add_factors(f)
            for v in f.variables:
                self.working_graph.add_edge(v, f)
        self.bp_done = False

    def factor_ones(self, v):
        '''
        Returns a DiscreteFactor over variable v whose values are all ones.
        '''
        card = len(self.state_names[v])
        return DiscreteFactor(
            variables=[v],
            cardinality=[card],
            values=[1] * card,
            state_names=self.state_names,
        )

    def initialize_messages(self):
        '''
        Create both messages m(f->v) and m(v->f) for every factor-variable edge,
        each initialized to an all-ones DiscreteFactor over v.

        Messages are stored as messages[to][from]; e.g. m(x->y) lives in
        messages[y][x]. This makes it cheap to collect every message that
        arrives at a given variable or factor.
        '''
        self.messages = defaultdict(dict)
        for f in self.working_graph.get_factors():
            for v in f.variables:
                self.messages[v][f] = self.factor_ones(v)
                self.messages[f][v] = self.factor_ones(v)

    def factor_to_variable(self, f, v):
        '''
        Compute the message m(f->v).

        The factor f is multiplied by the messages it has received from all
        of its other neighbouring variables, and every variable except v is
        then summed out. Returns the message as a DiscreteFactor over v.
        '''
        assert v in self.variables and f in self.working_graph.factors
        excluded = self.messages[f][v]
        # Exclude by identity: every edge owns a distinct message object, so
        # this drops exactly m(v->f). list.remove would instead compare factors
        # by value, which is approximate and costly.
        messages_to_f = [m for m in self.messages[f].values() if m is not excluded]

        m = f * prod(messages_to_f)
        m.marginalize([u for u in m.variables if u != v])
        return m

    def variable_to_factor(self, v, f):
        '''
        Compute the message m(v->f).

        It is the product of the messages v has received from all of its
        other neighbouring factors. If f is v's only neighbour, an all-ones
        message is returned.
        '''
        assert v in self.variables and f in self.working_graph.factors
        excluded = self.messages[v][f]
        # All messages into v are over v and may be numerically equal, so a
        # value-based list.remove could drop the wrong one; exclude by identity.
        messages_to_v = [m for m in self.messages[v].values() if m is not excluded]
        if not messages_to_v:
            # No other neighbours: send the uninformative all-ones message.
            return self.factor_ones(v)
        return prod(messages_to_v)

    def update(self, m_to, m_from):
        '''
        Recompute and store the message m_from -> m_to, using the
        variable-to-factor or factor-to-variable rule depending on m_from.
        '''
        if m_from in self.variables:
            assert m_to in self.working_graph.factors, f'm_from: {m_from}\nm_to: {m_to}'
            self.messages[m_to][m_from] = self.variable_to_factor(m_from, m_to)
        else:
            assert (
                m_from in self.working_graph.factors and m_to in self.variables
            ), f'm_from: {m_from}\nm_to: {m_to}'
            self.messages[m_to][m_from] = self.factor_to_variable(m_from, m_to)

    def collect_evidence(self, node, parent=None):
        '''
        Pass messages from the leaves up to `node` (depth-first, post-order).

        parent is the node we arrived from; skipping it prevents walking back
        up the tree. The graph must be a tree: a cycle would recurse forever.
        Returns node so the caller can use it as the message sender.
        '''
        for child in self.working_graph.neighbors(node):
            if child != parent:
                self.update(node, self.collect_evidence(child, parent=node))
        return node

    def distribute_evidence(self, node, parent=None):
        '''
        Pass messages from `node` down to the leaves (depth-first, pre-order).

        parent is the node we arrived from; skipping it prevents walking back
        up the tree. The graph must be a tree: a cycle would recurse forever.
        '''
        for child in self.working_graph.neighbors(node):
            if child != parent:
                self.update(child, node)
                self.distribute_evidence(child, parent=node)

    def run_bp(self, root):
        '''
        Reset all messages, then run one collect pass and one distribute pass
        rooted at the variable `root`. If set_evidence was never called, BP
        runs on an evidence-free copy of the original graph.
        '''
        assert root in self.variables, 'Variable not in the model'
        if self.working_graph is None:
            self.set_evidence({})
        self.initialize_messages()
        print('Working graph', self.working_graph.check_model())
        self.collect_evidence(root)
        self.distribute_evidence(root)
        self.bp_done = True

    def get_marginal(self, variable):
        '''
        To be used after run_bp. Returns p(variable | evidence) unnormalized:
        the product of all messages arriving at `variable`.
        '''
        assert self.bp_done, 'First run BP!'
        return prod(self.messages[variable].values())

    def get_marginal_subset(self, variables):
        '''
        To be used after run_bp. Returns p(variables | evidence) unnormalized.

        The first factor of the working graph whose scope contains all of
        `variables` is used. Its belief (the factor times every message it
        received) is summed over its remaining variables. After BP on a tree,
        any covering factor should give the same result.
        Raises ValueError if no single factor covers `variables`.
        '''
        assert self.bp_done, 'First run BP!'
        wanted = set(variables)
        factor = next(
            (f for f in self.working_graph.factors if wanted.issubset(f.variables)),
            None,
        )
        if factor is None:
            raise ValueError('Not valid set of variables')

        belief = factor * prod(self.messages[factor][v] for v in factor.variables)
        return belief.marginalize(
            [v for v in factor.variables if v not in wanted], inplace=False
        )

## Declare DiscreteFactors

In [3]:
variables = ['BT', 'UT', 'Ho', 'Pr']
# Every variable is binary, with states ordered [True, False].
state_names = {var: [True, False] for var in variables}

# Probability tables as DiscreteFactors, keyed by the distribution they encode.
# In each two-variable factor the conditioning variable is listed first, so
# each consecutive pair of `values` is the child's distribution (True, False)
# for one state of the parent (True first, then False).
p = dict()
p['Pr'] = DiscreteFactor(
    variables=['Pr'], cardinality=[2], values=[0.87, 0.13], state_names=state_names
)

p['Ho|Pr'] = DiscreteFactor(
    variables=['Pr', 'Ho'],
    cardinality=[2, 2],
    values=[0.99, 0.01, 0.1, 0.9],
    state_names=state_names,
)

p['BT|Ho'] = DiscreteFactor(
    variables=['Ho', 'BT'],
    cardinality=[2, 2],
    values=[0.9, 0.1, 0.3, 0.7],
    state_names=state_names,
)

p['UT|Ho'] = DiscreteFactor(
    variables=['Ho', 'UT'],
    cardinality=[2, 2],
    values=[0.9, 0.1, 0.2, 0.8],
    state_names=state_names,
)

## Create FactorGraph

In [4]:
G = FactorGraph()
# Sanity check: the declared variables are exactly the ones the factors use.
assert set(variables) == {v for f in p.values() for v in f.variables}

G.add_nodes_from(variables)
# Register each factor and connect it to every variable in its scope.
for f in p.values():
    G.add_factors(f)
    for v in f.variables:
        G.add_edge(v, f)

print('Model is ok: ', G.check_model())

Model is ok:  True


In [5]:
# Number of factors registered in G (one per entry of p).
len(G.factors)

4

## Run belief propagation

In [14]:
# Condition on BT=True and UT=True, run BP rooted at 'BT', and query Pr.
bp = MyBeliefPropagation(G)
bp.set_evidence({'BT': True, 'UT': True})
bp.run_bp(root='BT')

res = bp.get_marginal('Pr')  # unnormalized p(Pr | evidence)
res_norm = res.normalize(inplace=False)  # normalized copy; res is left untouched
print(res)
print(res_norm)

Working graph True
+-----------+-----------+
| Pr        |   phi(Pr) |
| Pr(True)  |    0.6982 |
+-----------+-----------+
| Pr(False) |    0.0176 |
+-----------+-----------+
+-----------+-----------+
| Pr        |   phi(Pr) |
| Pr(True)  |    0.9755 |
+-----------+-----------+
| Pr(False) |    0.0245 |
+-----------+-----------+


In [7]:
# Total mass of the unnormalized Pr marginal, i.e. the probability of the evidence.
res.values.sum()

np.float64(0.7157250000000001)

In [8]:
# (variable nodes + factor nodes, variable-factor edges) of G.
len(G.nodes), len(G.edges())

(8, 7)

## Get the normalization constant

In [16]:
# The total mass of every unnormalized marginal should be the same number:
# the normalization constant, whether taken from a joint or single marginal.
print(bp.get_marginal_subset(['Pr', 'Ho']).values.sum())

print(bp.get_marginal('Pr').values.sum())

print(bp.get_marginal('Ho').values.sum())

0.7157250000000001
0.7157250000000001
0.7157250000000001


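All three sums agree. Each is the normalization constant, i.e. the probability of the evidence, $p(\text{BT}=\text{True}, \text{UT}=\text{True}) \approx 0.7157$.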