In [275]:
from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.inference import BeliefPropagation
import numpy as np

In [276]:
T, K = 4, 2  # 4 variables in a chain, 2 states {0,1}
Xs = [f"X{t}" for t in range(1, T + 1)]  # ["X1","X2","X3","X4"]

In [277]:

phi = [
    np.array([1.0, 1.0]),  # φ_1
    np.array([1.0, 1.0]),  # φ_2
    np.array([0.3, 3.0]),  # φ_3 (evidence-like bump for state 1)
    np.array([1.0, 1.0]),  # φ_4
]

psi = [
    np.array([[3.0, 0.5], [0.5, 3.0]]),  # ψ_1 between (X1, X2)
    np.array([[1.0, 1.0], [1.0, 1.0]]),  # ψ_2 between (X2, X3)
    np.array([[0.25, 2.0], [2.0, 0.25]]),  # ψ_3 between (X3, X4)
]

In [278]:
G = FactorGraph()
G.add_nodes_from(Xs)  # variable nodes

# Add unary factors and connect them
unary_factors = []
for t in range(T):
    f = DiscreteFactor([Xs[t]], [K], phi[t])
    G.add_factors(f)
    G.add_edge(Xs[t], f)  # connect variable to its unary factor
    unary_factors.append(f)

In [279]:
pair_factors = []
for t in range(T - 1):
    g = DiscreteFactor([Xs[t], Xs[t + 1]], [K, K], psi[t].ravel(order="C"))
    G.add_factors(g)
    G.add_edge(Xs[t], g)
    G.add_edge(Xs[t + 1], g)
    pair_factors.append(g)

In [280]:
def get_first_last_node(G:FactorGraph):
    var_nodes = [n for n in G.nodes() if not isinstance(n, DiscreteFactor)]
    endpoints = [v for v in var_nodes if G.degree(v) == 2]

    return sorted(endpoints, key=lambda x: int(x[1:]))
first, last = get_first_last_node(G)

In [281]:
def is_pairwise_factor(f):
    return isinstance(f, DiscreteFactor) and len(f.scope()) == 2

In [None]:
def marginal (G, node1, node2):
    bp = BeliefPropagation(G)
    q = bp.query(variables=[node1, node2])
    table = q.values.reshape(q.cardinality)
    print(table)
    return table.sum(axis=0)
print(marginal(G, "X3", "X4"))

[[0.01010101 0.08080808]
 [0.80808081 0.1010101 ]]
[0.81818182 0.18181818]


In [283]:
def forward_masseage(G:FactorGraph, start, last):
    forward_dict = dict()
    curr_node = start
    next_node = None
    prev_node = None
    while curr_node != last:
        for f in G.neighbors(curr_node):
            if not (isinstance(f, DiscreteFactor) and len(f.scope()) == 2):
                continue
            other_vars = [
                v for v in G.neighbors(f) if not isinstance(v, DiscreteFactor)
            ]
            candidate = other_vars[0] if other_vars[1] == curr_node else other_vars[1]
            if candidate == prev_node:
                continue
            next_node = candidate
            break
        forward_dict[(curr_node, next_node)] = marginal(G, curr_node, next_node)
        if next_node is None:
            break
        prev_node, curr_node = curr_node, next_node
        print(curr_node)
    return forward_dict


forward_masseage(G, first, last)

X2
X3
X4


{('X1', 'X2'): array([0.5, 0.5]),
 ('X2', 'X3'): array([0.5, 0.5]),
 ('X3', 'X4'): array([0.09090909, 0.90909091])}

In [284]:
def backward_masseage(G: FactorGraph, start, last):
    backward_dict = dict()
    curr_node = last
    next_node = None
    prev_node = None
    while curr_node != start:
        for f in G.neighbors(curr_node):
            if not (isinstance(f, DiscreteFactor) and len(f.scope()) == 2):
                continue
            other_vars = [
                v for v in G.neighbors(f) if not isinstance(v, DiscreteFactor)
            ]
            candidate = other_vars[0] if other_vars[1] == curr_node else other_vars[1]
            if candidate == prev_node:
                continue
            next_node = candidate
            break
        backward_dict[(next_node, curr_node)] = marginal(G, next_node, curr_node)
        if next_node is None:
            break
        prev_node, curr_node = curr_node, next_node
        print(curr_node)
    return backward_dict

backward_masseage(G, first, last)

X3
X2
X1


{('X3', 'X4'): array([0.09090909, 0.90909091]),
 ('X2', 'X3'): array([0.5, 0.5]),
 ('X1', 'X2'): array([0.5, 0.5])}

In [285]:
def inference(G, node1, node2=None):
    first, last = get_first_last_node(G)
    forward_massages = forward_masseage(G, first, last)
    backward_massages = backward_masseage(G, first, last)
    forward_value = np.ones_like(forward_massages[next(iter(forward_massages))])
    backward_value = np.ones_like(backward_massages[next(iter(backward_massages))])
    if node2 is None:
        for key, values in forward_massages.items():
            if key[0] == node1:
                break
            else:
                forward_value = values * forward_value
        for key, values in backward_massages.items():
            if key[0] == node1:
                break
            else:
                backward_value = values * forward_value
        return forward_value * backward_value
    else:
        for key, values in forward_massages.items():
            if key[0] == node1 or key[1] == node1:
                break
            else:
                backward_value = values * forward_value
        for key, values in backward_massages.items():
            if key[0] == node1:
                break
            else:
                backward_value = values * forward_value
        return forward_value * backward_value
    
inference(G, "X4")

X2
X3
X4
X3
X2
X1


array([0.00025826, 0.02582645])