# In [None]:
from sumproduct import Variable, Factor, FactorGraph
import numpy as np


# In [1]:
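# factors: a sequence of 2-D pairwise factor matrices f12, f23, ..., where
# factor i links x(i+1) and x(i+2) and has shape (|x(i+1)|, |x(i+2)|), e.g.
# f12 is |x1| x |x2|, f23 is |x2| x |x3|, and so on.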

def create_2chain_graph(factors):
    """Build a chain-structured factor graph x1 - f12 - x2 - f23 - x3 - ...

    Args:
        factors: sequence of 2-D pairwise factor matrices. ``factors[i]`` is
            the potential between x(i+1) (rows) and x(i+2) (columns), so the
            column count of each factor must equal the row count of the next.

    Returns:
        ``(g, vnames)``: the ``FactorGraph`` and the list of variable names
        ``['x1', 'x2', ...]`` in chain order (``len(factors) + 1`` names).

    Raises:
        ValueError: if ``factors`` is empty, a factor is not 2-D, or two
            consecutive factors disagree on the size of their shared variable.
    """
    factors = [np.asarray(f) for f in factors]

    # Validate up front with real exceptions: `assert` is stripped under -O,
    # and an empty list would otherwise fail with an opaque IndexError.
    if not factors:
        raise ValueError("factors must contain at least one 2-D matrix")
    for i, f in enumerate(factors):
        if f.ndim != 2:
            raise ValueError(
                "factors[{}] must be 2-D, got shape {}".format(i, f.shape))
    for i, (left, right) in enumerate(zip(factors, factors[1:])):
        if left.shape[1] != right.shape[0]:
            raise ValueError(
                "factors[{}] has {} columns but factors[{}] has {} rows".format(
                    i, left.shape[1], i + 1, right.shape[0]))

    # Variable cardinalities: the row count of every factor, plus the column
    # count of the last factor for the final variable in the chain.
    sizes = [f.shape[0] for f in factors] + [factors[-1].shape[1]]
    vnames = ['x{}'.format(i + 1) for i in range(len(sizes))]
    gvars = [Variable(name, size) for name, size in zip(vnames, sizes)]

    g = FactorGraph(silent=True)  # init the graph without message printouts
    for i, f in enumerate(factors):
        fname = 'f{}{}'.format(i + 1, i + 2)
        # The transposed potential has axes (x(i+2), x(i+1)); the neighbours
        # are appended in that same order. This appears to be how sumproduct
        # maps potential axes to variables (verify against the library).
        g.add(Factor(fname, f.transpose()))
        g.append(fname, gvars[i + 1])
        g.append(fname, gvars[i])

    return g, vnames


def compute_2chain_marginals(factors, *, max_iter=15500, tolerance=1e-8):
    """Run sum-product on the chain built from *factors* and collect marginals.

    Args:
        factors: sequence of 2-D pairwise factor matrices, in the form
            accepted by ``create_2chain_graph``.
        max_iter: iteration cap forwarded to ``g.compute_marginals``
            (defaults to the previously hard-coded 15500).
        tolerance: convergence tolerance forwarded to
            ``g.compute_marginals`` (defaults to the previously hard-coded 1e-8).

    Returns:
        A list with one entry per chain variable, in order x1, x2, ...; each
        entry is whatever ``g.nodes[name].marginal()`` returns for that
        variable (presumably a numpy vector; confirm against sumproduct).
    """
    g, vnames = create_2chain_graph(factors)
    g.compute_marginals(max_iter=max_iter, tolerance=tolerance)
    return [g.nodes[name].marginal() for name in vnames]
