PS 3 - Question 2 <br>
Inference and Representation<br>
NYU Center for Data Science<br>
October 3, 2017

This notebook is a Python adaptation of the Matlab code provided in Brown University CS242 Homework 1:
http://cs.brown.edu/courses/cs242/assignments/
The factor graph library (fglib) is a Python 3 package to simulate message passing on factor graphs: https://github.com/danbar/fglib

In [44]:
import numpy as np 
import networkx as nx
from fglib import graphs, nodes, rv, inference

def make_debug_graph():
    """Build and return the four-variable debug factor graph.

    The graph has variable nodes x1..x4 and factor nodes f12, f234, f3, f4,
    connected as x1 - f12 - x2 - f234 - {x3, x4}, with the unary factors
    f3 and f4 attached to x3 and x4 respectively.

    Before returning, ``inference.sum_product`` is run once on the graph
    with x3 as the query node; the belief it returns is discarded.
    NOTE(review): presumably that call is kept for its effect on the
    messages stored in the graph -- confirm against fglib.

    Returns:
        The populated ``graphs.FactorGraph`` instance.
    """
    fg = graphs.FactorGraph()

    # Variable nodes and factor nodes, created in a fixed order.
    x1, x2, x3, x4 = [nodes.VNode(label) for label in ("x1", "x2", "x3", "x4")]
    f12, f234, f3, f4 = [
        nodes.FNode(label) for label in ("f12", "f234", "f3", "f4")
    ]

    # Register the nodes: variables first, then factors.
    fg.set_nodes([x1, x2, x3, x4])
    fg.set_nodes([f12, f234, f3, f4])

    # Connect the nodes; order and orientation of each pair are deliberate.
    connections = [
        (x1, f12),
        (f12, x2),
        (x2, f234),
        (f234, x3),
        (f234, x4),
        (x3, f3),
        (x4, f4),
    ]
    for tail, head in connections:
        fg.set_edge(tail, head)

    # Potential of f3: p(x3)
    dist_f3 = [0.5, 0.5]
    f3.factor = rv.Discrete(dist_f3, x3)

    # Potential of f4: p(x4)
    dist_f4 = [0.4, 0.6]
    f4.factor = rv.Discrete(dist_f4, x4)

    # Potential of f234: p(x2, x3, x4) = p(x2 | x3, x4) p(x3, x4).
    # p(x3, x4) is the outer product of the two unary tables, given a
    # trailing length-1 axis so it broadcasts over the x2 axis.
    joint_x3x4 = np.outer(dist_f3, dist_f4)[:, :, np.newaxis]
    # Conditional table indexed as [x3][x4][x2].
    px2_given_x3x4 = [
        [[0.2, 0.8], [0.25, 0.75]],
        [[0.7, 0.3], [0.3, 0.7]],
    ]
    dist_f234 = joint_x3x4 * px2_given_x3x4
    f234.factor = rv.Discrete(dist_f234, x3, x4, x2)

    # Potential of f12: p(x1, x2) = p(x1 | x2) p(x2).
    # Conditional table indexed as [x2][x1].
    px1_given_x2 = [
        [0.5, 0.5],
        [0.7, 0.3],
    ]
    # Marginalise x3 and x4 out of the f234 table to obtain p(x2).
    px2 = dist_f234.sum(axis=(0, 1))
    dist_f12 = px2.reshape(-1, 1) * px1_given_x2
    f12.factor = rv.Discrete(dist_f12, x2, x1)

    # Run the sum-product algorithm once with x3 as the query node.
    # The returned belief is not used.
    inference.sum_product(fg, x3)
    return fg


In [49]:
# Update messages by visiting edges in the order given by a schedule
def schedule_propagation(schedule, graph):
    """Send one message along every directed edge listed in ``schedule``.

    For each ``(sender, receiver)`` pair, in order, the edge object stored
    under the ``'object'`` key of the graph's edge data is looked up, the
    message is computed with ``sender.spa(receiver)``, and that message is
    stored on the edge via ``set_message``.

    Args:
        schedule: iterable of ``(sender, receiver)`` node tuples.
        graph: graph exposing ``get_edge_data(u, v)``.

    Returns:
        None. The edges are updated in place.
    """
    for sender, receiver in schedule:
        # Fetch the edge object first, then compute and store the message.
        link = graph.get_edge_data(sender, receiver)['object']
        link.set_message(sender, receiver, sender.spa(receiver))

def get_beliefs(fg, n_iteration=10, parallel_update=False):
    """Run message passing on ``fg`` and return every variable node's belief.

    If the graph has no cycles, messages are sent once from the leaves to
    a root and once back, following a depth-first traversal. Otherwise a
    fixed schedule covering every (node, neighbor) pair is repeated
    ``n_iteration`` times (loopy propagation).

    Args:
        fg: factor graph exposing ``get_vnodes()`` and ``get_fnodes()`` and
            usable with networkx graph algorithms.
        n_iteration: number of sweeps over the full schedule; only used
            when the graph contains a cycle.
        parallel_update: accepted for interface compatibility; this
            implementation does not read it.

    Returns:
        dict mapping ``str(vnode)`` to ``vnode.belief().pmf`` for each
        variable node. Empty if the graph has no variable nodes.
    """
    vnodes = list(fg.get_vnodes())
    if not vnodes:
        # Nothing to infer; also avoids indexing an empty node list below.
        return {}

    # An empty cycle basis means the graph is acyclic.
    if not nx.cycle_basis(fg):
        # Depth-first edges from an arbitrary root give the root-to-leaf
        # order; reversing both the list and each pair gives leaf-to-root.
        root_node = vnodes[0]
        root2leaf = list(nx.dfs_edges(fg, root_node))
        leaf2root = [(v, u) for u, v in reversed(root2leaf)]
        schedule_propagation(leaf2root, fg)
        schedule_propagation(root2leaf, fg)
    else:
        # Loopy propagation: factor nodes send first, then variable nodes,
        # each to all of its neighbors, repeated for a fixed sweep count.
        nodes_sequence = list(fg.get_fnodes()) + vnodes
        schedule = [
            (node, neighbor)
            for node in nodes_sequence
            for neighbor in node.neighbors()
        ]
        for _ in range(n_iteration):
            schedule_propagation(schedule, fg)

    # Collect the belief of every variable node.
    return {str(vnode): vnode.belief().pmf for vnode in fg.get_vnodes()}

In [35]:
# NOTE(review): `fg` is not created in this cell; presumably it was built
# in an earlier cell -- confirm before running this cell on its own.
beliefs = get_beliefs(fg)
# Show the heading and the belief dictionary on separate lines.
print("Belief of variable nodes ", beliefs, sep="\n")

Belief of variable nodes 
{'x1': array([ 0.65897284,  0.34102716]), 'x2': array([ 0.20513578,  0.79486422]), 'x3': array([ 0.52640912,  0.47359088]), 'x4': array([ 0.28679718,  0.71320282])}
