In [65]:
from typing import Callable, Dict, List, Tuple
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from functools import partial
import numpy as np
import time
import cv2
from pathlib import Path
import imgui
import moderngl
from pyrr import Matrix44
import moderngl_window as mglw
from moderngl_window import geometry
from moderngl_window.integrations.imgui import ModernglWindowRenderer
import PIL
from scipy.spatial import Delaunay
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

@register_pytree_node_class
class Gaussian:
    def __init__(self, eta, Lam):
        self.eta = eta
        self.Lam = Lam

    def tree_flatten(self):
        return (self.eta, self.Lam), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    
    def mu(self):
        return jnp.where(
            jnp.allclose(self.Lam, 0), self.eta, jnp.linalg.solve(self.Lam, self.eta)
        )
    

    
    def sigma(self):
        return jnp.linalg.inv(self.Lam)

    def zero_like(self):
        return Gaussian(jnp.zeros_like(self.eta), jnp.zeros_like(self.Lam))

    def __repr__(self) -> str:
        return f"Gaussian(eta={self.eta}, lam={self.Lam})"

    def __mul__(self, other):
        return Gaussian(self.eta + other.eta, self.Lam + other.Lam)

    def __truediv__(self, other):
        return Gaussian(self.eta - other.eta, self.Lam - other.Lam)

    def copy(self):
        return Gaussian(self.eta.copy(), self.Lam.copy())


@register_pytree_node_class
class Variable:
    var_id: int
    belief: Gaussian
    msgs: Gaussian
    adj_factor_idx: jnp.array

    def __init__(self, var_id, belief, msgs, adj_factor_idx):
        self.var_id = var_id
        self.belief = belief
        self.msgs = msgs
        self.adj_factor_idx = adj_factor_idx

    def tree_flatten(self):
        return (self.var_id, self.belief, self.msgs, self.adj_factor_idx), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class
class Factor:
    factor_id: jnp.array
    z: jnp.ndarray
    z_Lam: jnp.ndarray
    threshold: jnp.ndarray
    potential: Gaussian
    adj_var_id: jnp.array
    adj_var_idx: jnp.array

    def __init__(
        self, factor_id, z, z_Lam, threshold, potential, adj_var_id, adj_var_idx
    ):
        self.factor_id = factor_id
        self.z = z
        self.z_Lam = z_Lam
        self.threshold = threshold
        self.potential = potential
        self.adj_var_id = adj_var_id
        self.adj_var_idx = adj_var_idx

    def tree_flatten(self):
        return (
            self.factor_id,
            self.z,
            self.z_Lam,
            self.threshold,
            self.potential,
            self.adj_var_id,
            self.adj_var_idx,
        ), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@partial(jax.jit, static_argnames=["i", "j"])
def marginalize(gaussians: Gaussian, i, j): # Equ. (46), (47); Compute msg to i:j Variables from connected factors
    eta = gaussians.eta
    Lam = gaussians.Lam
    k = eta.size
    idx = jnp.arange(0, k)
    aa = idx[i:j] # index from i to j-1
    bb = jnp.concatenate([idx[:i], idx[j:]]) # rest
    aa_eta = eta[aa]
    bb_eta = eta[bb]
    aa_Lam = Lam[aa[:, None], aa]
    ab_Lam = Lam[aa[:, None], bb]
    bb_Lam = Lam[bb][:, bb]
    if bb_Lam.size == 0:
        return Gaussian(aa_eta, aa_Lam)
    # print("How large? ", bb_Lam.shape)

    bb_Cov = jnp.linalg.inv(bb_Lam)
    eta = aa_eta - ab_Lam @ bb_Cov @ bb_eta
    Lam = aa_Lam - ab_Lam @ bb_Cov @ ab_Lam.T
    return Gaussian(eta, Lam)


"""
def tree_stack(tree, axis=0, use_np=True):
    if use_np:
        return jax.tree.map(lambda *v: jnp.array(np.stack(v, axis=axis)), *tree)
    return jax.tree.map(lambda *v: jnp.stack(v, axis=axis), *tree)
"""
def tree_stack(tree, axis=0, use_np=True):
    if use_np:
        return jax.tree_util.tree_map(lambda *v: jnp.array(np.stack(v, axis=axis)), *tree)
    return jax.tree_util.tree_map(lambda *v: jnp.stack(v, axis=axis), *tree)

def h_fn(x):
    return x


def h2_fn(xs):
    """
    xs: shape (2, D), where:
        - xs[0] is x1
        - xs[1] is x2
    """
    x1 = xs[0]
    x2 = xs[1]

    #jax.debug.print("Shape of x2 - x1: {}", (x2 - x1).shape)

    return x2 - x1


# @jax.jit
def update_belief(var: Variable, ftov_msgs): # Calculate Eq. (7)
    belief = var.belief.zero_like()
    for i in range(ftov_msgs.eta.shape[0]):
        belief = belief * Gaussian(ftov_msgs.eta[i], ftov_msgs.Lam[i])


    return belief

"""
# @jax.jit
def update_belief(varis: Variable, ftov_msgs):
    belief = varis.belief.zero_like()

    mask = (varis.adj_factor_idx >= 0)[..., None]

    # Debug: print mask
    jax.debug.print("mask = {}", mask)

    varis.msgs.eta = varis.msgs.eta * mask
    varis.msgs.Lam = varis.msgs.Lam * mask[..., None]

    # Debug: print masked eta
    jax.debug.print("masked msgs.eta = {}", varis.msgs.eta)

    for i in range(ftov_msgs.eta.shape[0]):
        msg = Gaussian(ftov_msgs.eta[i], ftov_msgs.Lam[i])
        belief = belief * msg

    return belief
"""


# @jax.jit
def compute_vtof_msgs(var: Variable, ftov_msgs): # Eq.(19); do for each variable (x_m)
    vtof_msgs = []
    for i, idx in enumerate(var.adj_factor_idx): # for each f_si connected to x_m...
        msg = var.belief / Gaussian(ftov_msgs.eta[i], ftov_msgs.Lam[i]) # Eq.(19) LHS subscript of SUM
        eta = jnp.where(idx < 0, msg.zero_like().eta, msg.eta) # Those not connected should not affect the calculation (idx < 0)
        Lam = jnp.where(idx < 0, msg.zero_like().Lam, msg.Lam) # The reason to not using "if" (while it's per-element) is to optimize better
        vtof_msgs.append(Gaussian(eta, Lam)) # append (x_m -> f_si)
    return tree_stack(vtof_msgs, use_np=False) # [(x_m -> f_s1), (x_m -> f_s2), ... ] # The length is Ni_v


@partial(jax.jit, static_argnames=["h_fn"])
def factor_energy(factor, xs, h_fn):
    h = h_fn(xs)
    z = factor.z
    z_Lam = factor.z_Lam
    r = z - h
    return 0.5 * r @ z_Lam @ r.T



# @partial(jax.jit, static_argnames=["h_fn", "w"])
def factor_update(factor, xs, h_fn, w):
    h = h_fn(xs)
    J = jax.jacrev(h_fn)(xs).reshape(h.size, xs.size) # Jacobian auto-diff (J_s)


    z = factor.z # I think this is a vector
    z_Lam = factor.z_Lam
    
    r = z - h.reshape(-1) # TODO: reshape can be problematic
    s = w(r.T @ z_Lam @ r, factor.threshold) # Scale to consider Robust Loss
    Lam = s * J.T @ z_Lam @ J # Eq. (36)
    eta = s * J.T @ z_Lam @ (J @ xs.reshape(-1) + r) # TODO: reshape can be problematic; Eq. (36); xs should be a vector
    return Gaussian(eta, Lam) # Factor; represented w.r.t. neighboring variables xs


# @jax.jit
def compute_ftov_msg(factor, vtof_msgs): # Ch 3.5 Message Passing at a Factor Node
    # vtof_msgs: Variable to Factor Messages; shape: (N_adj, dim) 
    # where N_adj is the number of adjacent variables and dim is the dimension of each variable

    N_adj, dim = vtof_msgs.eta.shape


    
    pot = factor.potential.copy() # log(f_s), but for only a specific variable a factor is connected to.
    i = 0
    for n in range(N_adj): # Add all! (Produce all)
        j = i + dim
        pot.eta = pot.eta.at[i:j].add(vtof_msgs.eta[n])
        pot.Lam = pot.Lam.at[i:j, i:j].add(vtof_msgs.Lam[n])
        i = j


    ftov_msgs = []
    i = 0
    for n in range(N_adj):
        j = i + dim
        pot_m_1 = pot.copy()
        pot_m_1.eta = pot_m_1.eta.at[i:j].add(-vtof_msgs.eta[n]) # Subtract direction of going out! (42)
        pot_m_1.Lam = pot_m_1.Lam.at[i:j, i:j].add(-vtof_msgs.Lam[n]) # (43)
        msg = marginalize(pot_m_1, i, j) # (46), (47)
        ftov_msgs.append(msg)
        i = j

    
    return tree_stack(ftov_msgs, use_np=False)


@jax.jit
def update_variable(varis): # Update belief with receiving msgs and calculate msg to factors; varis.msgs are up-to-date and varis.belief are not
    varis.belief = jax.vmap(update_belief)(varis, varis.msgs) # Eq. (7); varis.msgs is receiving msgs (ftov)
    vtof_msgs = jax.vmap(compute_vtof_msgs)(varis, varis.msgs) # Variable -> Factor Msg; Eq. (19)
    linpoints = jax.vmap(lambda x: x.mu())(varis.belief) # Current avg of belief! Belief is posterior

    return varis, vtof_msgs, linpoints # vtof msgs: # Var * # Var-direction (factor, Ni_v) msgs


@partial(jax.jit, static_argnames=["f", "w"])
def update_factor(facs, varis, vtof_msgs, linpoints, f, w): # f is factor function, w is robustifier
    vtof_msgs_reordered = jax.tree_util.tree_map( # Variable to factor messages to specific (variable, factor; or variable-direction) pair
        lambda x: x[facs.adj_var_id, facs.adj_var_idx], vtof_msgs # id: Variable id (one end), idx: direction (another end)
    )
    linpoints_reordered = jax.tree_util.tree_map(
        lambda x: x[facs.adj_var_id], linpoints # Reorder linpoints by adj_var_id: variables' mean for factors' one ends
    )
    
    facs.potential = jax.vmap(factor_update, in_axes=(0, 0, None, None))( # Calculate each factor potential (f_s(x, x_1, ..., x_M) of Eq. (15))
        facs, linpoints_reordered, f, w # Each factor contribution of variable-direction pair (factor: variable-direction pair)
    ) # 1 or 2-dimensional!! (gradient / prior factor or smoothness factor)
    ftov_msgs = jax.vmap(compute_ftov_msg)(facs, vtof_msgs_reordered) # ftov calculation by Eq. (15), with potential f_s, and msg vtof

    jax.debug.print("ftov_msgs = {}", ftov_msgs)
    varis.msgs.eta = varis.msgs.eta.at[facs.adj_var_id, facs.adj_var_idx].set( # Setting varis' receiving messages
        ftov_msgs.eta
    )
    varis.msgs.Lam = varis.msgs.Lam.at[facs.adj_var_id, facs.adj_var_idx].set(
        ftov_msgs.Lam
    )

    mask = (varis.adj_factor_idx >= 0)[..., None]
    varis.msgs.eta = varis.msgs.eta * mask
    varis.msgs.Lam = varis.msgs.Lam * mask[..., None]

    return facs, varis


@jax.jit
def huber(e, t):
    x = jnp.sqrt(e)
    return jnp.where(x <= t, 1.0, t / x)


@jax.jit
def l2(e, _):
    return 1.0

In [66]:
def build_pose_slam_graph(N, prior_meas, between_meas, prior_std=0.05, odom_std=0.05, Ni_v=10, D=2):
    """
    Build a 2D pose-SLAM factor graph with:
    - N variable nodes (each 2D position)
    - Prior measurements (with strong precision)
    - Between measurements (with moderate precision from noise_std)
    
    Parameters:
    - N: number of variables
    - prior_meas: list of (i, z) where z is the prior measurement at node i
    - between_meas: list of (i, j, z) where z is relative measurement from i to j
    - noise_std: standard deviation for between measurements (list or array of length D)
    - Ni_v: number of factor connections per variable (default 5)
    - D: dimension of each variable (default 2)
    
    Returns:
    - varis: Variable object
    - prior_facs: Factor object for priors
    - between_facs: Factor object for between factors
    """

    # === Step 1: Initialize Variable nodes ===
    var_ids = jnp.arange(N, dtype=jnp.int32)
    belief = Gaussian(jnp.zeros((N, D)), jnp.tile(jnp.eye(D), (N, 1, 1)))  # initial mean 0, covariance I
    msgs = Gaussian(jnp.zeros((N, Ni_v, D)), jnp.zeros((N, Ni_v, D, D)))  # messages (eta, Lambda) to each factor port
    adj_factor_idx = -jnp.ones((N, Ni_v), dtype=jnp.int32)  # -1 indicates no connected factor at this port

    varis = Variable(var_ids, belief, msgs, adj_factor_idx)

    # === Step 2: Build Prior Factors (strong precision for anchoring the graph) ===
    prior_factor_id = []
    prior_z = []
    prior_z_Lam = []
    prior_threshold = []
    prior_adj_var_id = []
    prior_adj_var_idx = []

    fac_counter = 0  # global factor ID counter

    for (i, z) in prior_meas:
        prior_factor_id.append(fac_counter)
        prior_z.append(jnp.array(z))

        # Very weak prior: large noise variance -> small precision
        prior_z_Lam.append(jnp.eye(D) / prior_std)  # shape (D, D)

        prior_threshold.append(1.0)
        prior_adj_var_id.append([i])     # only connected to variable i
        prior_adj_var_idx.append([0])     # use port 0 for prior

        varis.adj_factor_idx = varis.adj_factor_idx.at[i, 0].set(fac_counter)
        fac_counter += 1

    prior_facs = Factor(
        factor_id=jnp.array(prior_factor_id),
        z=jnp.stack(prior_z),
        z_Lam=jnp.stack(prior_z_Lam),
        threshold=jnp.array(prior_threshold),
        potential=None,
        adj_var_id=jnp.array(prior_adj_var_id),
        adj_var_idx=jnp.array(prior_adj_var_idx),
    )

    # === Step 3: Build Between Factors (relative pose measurements) ===
    between_factor_id = []
    between_z = []
    between_z_Lam = []
    between_threshold = []
    between_adj_var_id = []
    between_adj_var_idx = []

    for (i, j, z) in between_meas:
        between_factor_id.append(fac_counter)
        between_z.append(jnp.array(z))

        # Between-factor noise: use provided noise_std to compute precision
        between_z_Lam.append(jnp.diag(1.0 / (jnp.ones(D)*odom_std)   ))  # shape (D, D)

        between_threshold.append(1.0)

        # Assign first empty port >=1 to variable i and j
        port_i = int(jnp.argmax(varis.adj_factor_idx[i, 1:] == -1)) + 1
        port_j = int(jnp.argmax(varis.adj_factor_idx[j, 1:] == -1)) + 1
        varis.adj_factor_idx = varis.adj_factor_idx.at[i, port_i].set(fac_counter)
        varis.adj_factor_idx = varis.adj_factor_idx.at[j, port_j].set(fac_counter)

        between_adj_var_id.append([i, j])
        between_adj_var_idx.append([port_i, port_j])
        fac_counter += 1

    between_facs = Factor(
        factor_id=jnp.array(between_factor_id),
        z=jnp.stack(between_z),
        z_Lam=jnp.stack(between_z_Lam),
        threshold=jnp.array(between_threshold),
        potential=None,
        adj_var_id=jnp.array(between_adj_var_id),
        adj_var_idx=jnp.array(between_adj_var_idx),
    )

    return varis, prior_facs, between_facs


In [67]:
def generate_grid_slam_data(H=16, W=16, dx=1.0, dy=1.0, prior_noise_std=0.05, odom_noise_std=0.05, seed=0):
    """
    Generate 2D SLAM data over a regular H x W grid.

    Each variable is a node located at position (j*dx, i*dy), where i is row index and j is column index.
    Relative pose measurements (between factors) are added between horizontal and vertical neighbors.
    Each variable also receives a weak but accurate prior to ensure the global graph is well-constrained.

    Args:
        H: number of rows in the grid
        W: number of columns in the grid
        dx: horizontal spacing between grid points
        dy: vertical spacing between grid points
        odom_noise_std: standard deviation of noise added to relative measurements (between factors)
        prior_std: standard deviation for prior factors (should be much larger than odom_noise_std to make priors weak)
        seed: random seed for reproducibility

    Returns:
        positions: (N, 2) array of ground-truth positions (N = H * W)
        prior_meas: list of (i, z) where i is variable index and z is its true position
        between_meas: list of (i, j, z) where (i, j) is a measurement edge and z is the noisy relative pose from i to j
    """
    np.random.seed(seed)
    N = H * W  # total number of variables

    # Step 1: Generate ground-truth positions on the grid
    positions = []
    for i in range(H):
        for j in range(W):
            x = j * dx
            y = i * dy
            positions.append([x, y])
    positions = np.array(positions)  # shape (N, 2)

    # Step 2: Add weak but accurate prior for each variable
    prior_meas = []
    for idx, pos in enumerate(positions):
        noise = np.random.randn(2) * prior_noise_std
        z = (pos + noise).tolist()
        prior_meas.append((idx, z))  # accurate measurement with weak information will be set in z_Lam later

    # Step 3: Add noisy relative pose measurements (between factors)
    between_meas = []
    for i in range(H):
        for j in range(W):
            idx = i * W + j  # flat index of (i, j)

            # Horizontal neighbor (i, j) -> (i, j+1)
            if j < W - 1:
                nbr = i * W + (j + 1)
                rel = positions[nbr] - positions[idx]  # ideal relative translation
                noise = np.random.randn(2) * odom_noise_std
                z = (rel + noise).tolist()
                between_meas.append((idx, nbr, z))

            # Vertical neighbor (i, j) -> (i+1, j)
            if i < H - 1:
                nbr = (i + 1) * W + j
                rel = positions[nbr] - positions[idx]
                noise = np.random.randn(2) * odom_noise_std
                #z = (rel + noise).tolist()
                z = noise.tolist()
                between_meas.append((idx, nbr, z))

    return positions, prior_meas, between_meas


In [68]:
def gbp_solve(varis, prior_facs, between_facs, num_iters=50, visualize=False, prior_h=h_fn, between_h=h2_fn):
    energy_log = []
    
    for i in range(num_iters):
        # Step 1: Variable update
        varis, vtof_msgs, linpoints = update_variable(varis)
        print(vtof_msgs.eta.shape)

        # Step 2: Factor update
        prior_facs, varis = update_factor(prior_facs, varis, vtof_msgs, linpoints, prior_h, l2)
        between_facs, varis = update_factor(between_facs, varis, vtof_msgs, linpoints, between_h, l2)


        if visualize:
            # Step 3: Energy computation
            prior_energy = jnp.sum(jax.vmap(factor_energy, in_axes=(0, 0, None))(
                prior_facs, linpoints[prior_facs.adj_var_id[:, 0]], prior_h
            ))
    
            between_energy = jnp.sum(jax.vmap(factor_energy, in_axes=(0, 0, None))(
                between_facs, linpoints[between_facs.adj_var_id], between_h
            ))
    
            energy = prior_energy + between_energy
            energy_log.append(energy)
        

    # Step 4: Keep linpoints
    linpoints = jax.vmap(lambda x: x.mu())(varis.belief)  
        
    return varis, prior_facs, between_facs, np.array(energy_log), linpoints


In [69]:
positions, prior_meas, between_meas = generate_grid_slam_data(H=16, W=16, prior_noise_std=0.5, odom_noise_std=0.5)
varis, prior_facs, between_facs = build_pose_slam_graph(N=256, prior_meas=prior_meas, between_meas=between_meas, 
                                                        prior_std=1, odom_std=0.1,
                                                        Ni_v=10, D=2)

cpu_device = jax.devices("cpu")[0]
varis= jax.device_put(varis, cpu_device)
prior_facs = jax.device_put(prior_facs, cpu_device)
between_facs = jax.device_put(between_facs, cpu_device)

"""
print("var_id:", varis.var_id.shape)            # should be (N,)
print("belief.eta:", varis.belief.eta.shape)    # should be (N, D)
print("belief.Lam:", varis.belief.Lam.shape)    # should be (N, D, D)
print("msgs.eta:", varis.msgs.eta.shape)        # should be (N, Ni_v, D)
print("msgs.Lam:", varis.msgs.Lam.shape)        # should be (N, Ni_v, D, D)
print("adj_factor_idx:", varis.adj_factor_idx.shape)  # should be (N, Ni_v)
"""

varis, prior_facs, between_facs, energy_log, linpoints = gbp_solve(
    varis, prior_facs, between_facs, num_iters=5, visualize=True
)
energy_log

(256, 10, 2)
ftov_msgs = Gaussian(eta=[[[ 8.82026196e-01  2.00078607e-01]]

 [[ 1.48936903e+00  1.12044656e+00]]

 [[ 2.93377900e+00 -4.88638937e-01]]

 [[ 3.47504425e+00 -7.56786019e-02]]

 [[ 3.94839048e+00  2.05299258e-01]]

 [[ 5.07202196e+00  7.27136731e-01]]

 [[ 6.38051891e+00  6.08375072e-02]]

 [[ 7.22193146e+00  1.66837171e-01]]

 [[ 8.74703979e+00 -1.02579132e-01]]

 [[ 9.15653419e+00 -4.27047879e-01]]

 [[ 8.72350502e+00  3.26809287e-01]]

 [[ 1.14322186e+01 -3.71082515e-01]]

 [[ 1.31348772e+01 -7.27182865e-01]]

 [[ 1.30228796e+01 -9.35919285e-02]]

 [[ 1.47663898e+01  7.34679401e-01]]

 [[ 1.50774736e+01  1.89081267e-01]]

 [[-4.43892866e-01  9.60176624e-03]]

 [[ 8.26043904e-01  1.07817447e+00]]

 [[ 2.61514544e+00  1.60118997e+00]]

 [[ 2.80633664e+00  8.48848641e-01]]

 [[ 3.47572351e+00  2.89991021e-01]]

 [[ 4.14686489e+00  1.97538769e+00]]

 [[ 5.74517393e+00  7.80962825e-01]]

 [[ 6.37360239e+00  1.38874519e+00]]

 [[ 7.19305086e+00  8.93629849e-01]]

 [[ 8.552267

array([22141.648 ,  4695.3857,  2141.432 ,  1835.2019,  1699.0979],
      dtype=float32)

In [70]:
energy_log

array([22141.648 ,  4695.3857,  2141.432 ,  1835.2019,  1699.0979],
      dtype=float32)

In [71]:
def h3_fn(x):
    """
    Predicts measurement h(x) for coarse prior.

    Input:
        x: (8,) → 4 stacked fine-level variables (each 2D)
    Output:
        z_hat: (16,) = [x0, x1, x2, x3, x1-x0, x2-x0, x3-x1, x3-x2]
    """
    x = x.reshape(-1)
    
    x0 = x[0:2]
    x1 = x[2:4]
    x2 = x[4:6]
    x3 = x[6:8]

    # 4 priors (just xi)
    z_hat_0 = x0
    z_hat_1 = x1
    z_hat_2 = x2
    z_hat_3 = x3

    # 4 internal between
    z_hat_4 = x1 - x0
    z_hat_5 = x2 - x0
    z_hat_6 = x3 - x1
    z_hat_7 = x3 - x2

    return jnp.concatenate([
        z_hat_0, z_hat_1, z_hat_2, z_hat_3,
        z_hat_4, z_hat_5, z_hat_6, z_hat_7
    ])


def h4_fn(xs):
    """
    Predicts coarse between measurement h(xs) where:
      - xs[0] is coarse variable i (8D)
      - xs[1] is coarse variable j (8D)
    Fixed: uses v01→v02 and v11→v12 edges
    
    Returns:
        z_hat: shape (4,) = two 2D relative positions
    """
    jax.debug.print("xs = {}", xs)
    xi = xs[0]
    xj = xs[1]

    # First residual: xj[0:2] - xi[2:4]  (v02 - v01)
    r1 = xj[0:2] - xi[2:4]

    # Second residual: xj[4:6] - xi[6:8] (v12 - v11)
    r2 = xj[4:6] - xi[6:8]

    jax.debug.print("xs = {}", xs)
    return jnp.concatenate([r1, r2])  # shape (4,)


def h5_fn(xs):
    """
    Predicts coarse between measurement h(xs) where:
      - xs[0] is coarse variable i (8D)
      - xs[1] is coarse variable j (8D)
    Fixed: uses v10→v20 and v11→v21 edges

    Returns:
        z_hat: shape (4,) = two 2D relative positions
    """
    xi = xs[0]
    xj = xs[1]

    # First residual: xj[0:2] - xi[4:6]  (v20 - v10)
    r1 = xj[0:2] - xi[4:6]

    # Second residual: xj[2:4] - xi[6:8] (v21 - v11)
    r2 = xj[2:4] - xi[6:8]

    return jnp.concatenate([r1, r2])  # shape (4,)


In [None]:
def build_coarse_slam_graph(
    varis_fine: Variable,
    prior_facs_fine: Factor,
    between_facs_fine: Factor,
    H: int, W: int,
    stride: int = 2,
    prior_std: float = 1.0,
    between_std: float = 0.1,
) -> Tuple[Variable, Factor, Factor]:
    D = 2
    patch_map: Dict[int, List[int]] = {}
    fine_to_patch: Dict[int, Tuple[int, int]] = {}
    coarse_var_id = 0

    coarse_beliefs = []
    coarse_msgs_eta = []
    coarse_msgs_Lam = []
    coarse_adj_factor_idx = []


    # === 1. Build Coarse Variables ===
    for i in range(0, H - 1, stride):
        for j in range(0, W - 1, stride):
            v00 = i * W + j
            v01 = v00 + 1
            v10 = v00 + W
            v11 = v10 + 1
            patch = [v00, v01, v10, v11]
            patch_map[coarse_var_id] = patch
            for k, vid in enumerate(patch):
                fine_to_patch[vid] = (coarse_var_id, k)

            Ni_v = 10
            eta = jnp.zeros((8))  # 8D for coarse variable
            Lam = jnp.zeros((8, 8))
            coarse_beliefs.append(Gaussian(eta, Lam))


            coarse_msgs_eta.append(jnp.zeros((Ni_v, 8)))
            coarse_msgs_Lam.append(jnp.zeros((Ni_v, 8, 8)))
            coarse_adj_factor_idx.append(-jnp.ones(Ni_v, dtype=jnp.int32))
            coarse_var_id += 1

    varis_coarse = Variable(
        var_id=jnp.arange(len(patch_map)),
        belief=tree_stack(coarse_beliefs, axis=0),
        msgs=Gaussian(jnp.stack(coarse_msgs_eta), jnp.stack(coarse_msgs_Lam)),
        adj_factor_idx=jnp.stack(coarse_adj_factor_idx),
    )

    # === 2. Build Coarse Priors ===
    fine_between_dict = {
        (int(i), int(j)): k for k, (i, j) in enumerate(between_facs_fine.adj_var_id)
    }

    fine_between_dict.update({(int(j), int(i)): k for k, (i, j) in enumerate(between_facs_fine.adj_var_id)})

    prior_ids, prior_zs, prior_zLams = [], [], []
    adj_var_ids, adj_var_idxs = [], []
    factor_id_counter = 0

    for patch_id, patch in patch_map.items():
        residuals = []
        precisions = []

        for v in patch:
            mask = (prior_facs_fine.adj_var_id[:, 0] == v)
            if jnp.any(mask):
                i = jnp.argmax(mask)
                z_i = prior_facs_fine.z[i]
                z_Lam_i = prior_facs_fine.z_Lam[i]
            else:
                z_i = jnp.zeros((D,))
                z_Lam_i = (1. / (prior_std ** 2)) * jnp.eye(D)
            residuals.append(z_i)
            precisions.append(z_Lam_i)

        edge_indices = [(0, 1), (0, 2), (1, 3), (2, 3)]
        for i, j in edge_indices:
            a, b = patch[i], patch[j]
            key = (a, b)
            if key in fine_between_dict:
                k = fine_between_dict[key]
                a_k, b_k = between_facs_fine.adj_var_id[k]
                z = between_facs_fine.z[k]
                z_Lam = between_facs_fine.z_Lam[k]
                if a_k == b:
                    z = -z
                
                residuals.append(z)
                precisions.append(z_Lam)
            else:
                residuals.append(jnp.zeros((D,)))
                precisions.append((1. / (between_std ** 2)) * jnp.eye(D))

        z = jnp.concatenate(residuals)
        z_Lam = jax.scipy.linalg.block_diag(*precisions)


        prior_ids.append(factor_id_counter)
        prior_zs.append(z)
        prior_zLams.append(z_Lam)
        adj_var_ids.append(jnp.array([patch_id]))
        adj_var_idxs.append(jnp.array([0]))

        varis_coarse.adj_factor_idx = varis_coarse.adj_factor_idx.at[patch_id, 0].set(factor_id_counter)
        factor_id_counter += 1


    prior_facs_coarse = Factor(
        factor_id=jnp.array(prior_ids, dtype=jnp.int32),
        z=jnp.stack(prior_zs),
        z_Lam=jnp.stack(prior_zLams),
        threshold=jnp.ones((len(prior_ids),)),
        potential=None,
        adj_var_id=jnp.stack(adj_var_ids),
        adj_var_idx=jnp.stack(adj_var_idxs),
    )

    # === Build Horizontal & Vertical Between Factors Separately ===
    horizontal_ids, horizontal_zs, horizontal_zLams = [], [], []
    horizontal_adj_ids, horizontal_adj_idxs = [], []

    vertical_ids, vertical_zs, vertical_zLams = [], [], []
    vertical_adj_ids, vertical_adj_idxs = [], []

    height = H // stride
    width = W // stride

    for row in range(height):
        for col in range(width):
            patch_i = row * width + col

            # Horizontal neighbor
            if col < width - 1:
                patch_j = patch_i + 1
                pi_patch, pj_patch = patch_map[patch_i], patch_map[patch_j]

                fine_pairs = [(pi_patch[1], pj_patch[0]), (pi_patch[3], pj_patch[2])]

                residuals = []
                precisions = []
                for a, b in fine_pairs:
                    key = (a, b)
                    if key in fine_between_dict:
                        k = fine_between_dict[key]
                        a_k, b_k = between_facs_fine.adj_var_id[k]
                        z = between_facs_fine.z[k]
                        z_Lam = between_facs_fine.z_Lam[k]
                        if a_k == b:
                            z = -z
                    else:
                        z = jnp.zeros((D,))
                        z_Lam = (1. / (between_std ** 2)) * jnp.eye(D)
                    residuals.append(z)
                    precisions.append(z_Lam)

                z = jnp.concatenate(residuals)
                z_Lam = jax.scipy.linalg.block_diag(*precisions)

                adj_id = jnp.array([patch_i, patch_j])
                port_i = int(jnp.argmax(varis_coarse.adj_factor_idx[patch_i] == -1))
                port_j = int(jnp.argmax(varis_coarse.adj_factor_idx[patch_j] == -1))
                adj_idx = jnp.array([port_i, port_j])

                horizontal_ids.append(factor_id_counter)
                horizontal_zs.append(z)
                horizontal_zLams.append(z_Lam)
                horizontal_adj_ids.append(adj_id)
                horizontal_adj_idxs.append(adj_idx)

                
                varis_coarse.adj_factor_idx = varis_coarse.adj_factor_idx.at[patch_i, port_i].set(factor_id_counter)
                varis_coarse.adj_factor_idx = varis_coarse.adj_factor_idx.at[patch_j, port_j].set(factor_id_counter)

                factor_id_counter += 1

            # Vertical neighbor
            if row < height - 1:
                patch_j = patch_i + width
                pi_patch, pj_patch = patch_map[patch_i], patch_map[patch_j]
                fine_pairs = [(pi_patch[2], pj_patch[0]), (pi_patch[3], pj_patch[1])]

                residuals = []
                precisions = []
                for a, b in fine_pairs:
                    key = (a, b)
                    if key in fine_between_dict:
                        k = fine_between_dict[key]
                        a_k, b_k = between_facs_fine.adj_var_id[k]
                        z = between_facs_fine.z[k]
                        z_Lam = between_facs_fine.z_Lam[k]
                        if a_k == b:
                            z = -z
                    else:
                        z = jnp.zeros((D,))
                        z_Lam = (1. / (between_std ** 2)) * jnp.eye(D)
                    residuals.append(z)
                    precisions.append(z_Lam)

                z = jnp.concatenate(residuals)
                z_Lam = jax.scipy.linalg.block_diag(*precisions)

                adj_id = jnp.array([patch_i, patch_j])
                port_i = int(jnp.argmax(varis_coarse.adj_factor_idx[patch_i] == -1))
                port_j = int(jnp.argmax(varis_coarse.adj_factor_idx[patch_j] == -1))
                adj_idx = jnp.array([port_i, port_j])

                vertical_ids.append(factor_id_counter)
                vertical_zs.append(z)
                vertical_zLams.append(z_Lam)
                vertical_adj_ids.append(adj_id)
                vertical_adj_idxs.append(adj_idx)

                varis_coarse.adj_factor_idx = varis_coarse.adj_factor_idx.at[patch_i, port_i].set(factor_id_counter)
                varis_coarse.adj_factor_idx = varis_coarse.adj_factor_idx.at[patch_j, port_j].set(factor_id_counter)


                factor_id_counter += 1

    horizontal_between_facs = Factor(
        factor_id=jnp.array(horizontal_ids, dtype=jnp.int32),
        z=jnp.stack(horizontal_zs),
        z_Lam=jnp.stack(horizontal_zLams),
        threshold=jnp.ones((len(horizontal_ids),)),
        potential=None,
        adj_var_id=jnp.stack(horizontal_adj_ids),
        adj_var_idx=jnp.stack(horizontal_adj_idxs),
    )

    vertical_between_facs = Factor(
        factor_id=jnp.array(vertical_ids, dtype=jnp.int32),
        z=jnp.stack(vertical_zs),
        z_Lam=jnp.stack(vertical_zLams),
        threshold=jnp.ones((len(vertical_ids),)),
        potential=None,
        adj_var_id=jnp.stack(vertical_adj_ids),
        adj_var_idx=jnp.stack(vertical_adj_idxs),
    )

    return varis_coarse, prior_facs_coarse, horizontal_between_facs, vertical_between_facs

In [73]:
# === Step 2: construct coarse-level graph ===
varis_coarse, prior_facs_coarse, horizontal_between_facs, vertical_between_facs  = build_coarse_slam_graph(
    varis_fine=varis,
    prior_facs_fine=prior_facs,
    between_facs_fine=between_facs,
    H=16, W=16,
    stride = 2,
    prior_std = 1.0,
    between_std = 0.1,
)

# === Step 3: print coarse-level factor numbers to verify===
print("Coarse Variables:", len(varis_coarse.var_id))
print("Coarse Prior Factors:", len(prior_facs_coarse.factor_id))
print("Coarse Horizontal Between Factors:", len(horizontal_between_facs.factor_id))
print("Coarse vertical Between Factors:", len(vertical_between_facs.factor_id))

k = 0  # factor index，e.g. 0~63

print("=== Prior Factor #{} ===".format(k))
print("factor_id:", prior_facs_coarse.factor_id[k])
print("adj_var_id:", prior_facs_coarse.adj_var_id[k])     # coarse variable index
print("adj_var_idx:", prior_facs_coarse.adj_var_idx[k])   # always [0] since 1 variable
print("z (residual target):\n", prior_facs_coarse.z[k])   # shape: (16,)
print("z_Lam (precision matrix):\n", prior_facs_coarse.z_Lam[k])  # shape: (16, 16)
print("threshold:", prior_facs_coarse.threshold[k])

k = 0

print(f"\n=== Coarse Between Factor #{k} ===")
print("factor_id:", horizontal_between_facs.factor_id[k])

print("adj_var_id:", horizontal_between_facs.adj_var_id[k])     # coarse variable IDs
print("adj_var_idx:", horizontal_between_facs.adj_var_idx[k])   # slot idxs on both vars

print("z (residual):", horizontal_between_facs.z[k])            # shape: (4,)
print("z_Lam (precision):\n", horizontal_between_facs.z_Lam[k])  # shape: (4,4)

print("threshold:", horizontal_between_facs.threshold[k])


[0, 1, 16, 17] [2, 3, 18, 19]
[(1, 2), (17, 18)]
[2, 3, 18, 19] [4, 5, 20, 21]
[(3, 4), (19, 20)]
[4, 5, 20, 21] [6, 7, 22, 23]
[(5, 6), (21, 22)]
[6, 7, 22, 23] [8, 9, 24, 25]
[(7, 8), (23, 24)]
[8, 9, 24, 25] [10, 11, 26, 27]
[(9, 10), (25, 26)]
[10, 11, 26, 27] [12, 13, 28, 29]
[(11, 12), (27, 28)]
[12, 13, 28, 29] [14, 15, 30, 31]
[(13, 14), (29, 30)]
[32, 33, 48, 49] [34, 35, 50, 51]
[(33, 34), (49, 50)]
[34, 35, 50, 51] [36, 37, 52, 53]
[(35, 36), (51, 52)]
[36, 37, 52, 53] [38, 39, 54, 55]
[(37, 38), (53, 54)]
[38, 39, 54, 55] [40, 41, 56, 57]
[(39, 40), (55, 56)]
[40, 41, 56, 57] [42, 43, 58, 59]
[(41, 42), (57, 58)]
[42, 43, 58, 59] [44, 45, 60, 61]
[(43, 44), (59, 60)]
[44, 45, 60, 61] [46, 47, 62, 63]
[(45, 46), (61, 62)]
[64, 65, 80, 81] [66, 67, 82, 83]
[(65, 66), (81, 82)]
[66, 67, 82, 83] [68, 69, 84, 85]
[(67, 68), (83, 84)]
[68, 69, 84, 85] [70, 71, 86, 87]
[(69, 70), (85, 86)]
[70, 71, 86, 87] [72, 73, 88, 89]
[(71, 72), (87, 88)]
[72, 73, 88, 89] [74, 75, 90, 91]
[(7

In [50]:
print(varis.msgs.eta.shape, varis.msgs.Lam.shape, varis.adj_factor_idx.shape)
print(varis_coarse.msgs.eta.shape, varis_coarse.msgs.Lam.shape, varis_coarse.adj_factor_idx.shape)

(256, 10, 2) (256, 10, 2, 2) (256, 10)
(64, 10, 8) (64, 10, 8, 8) (64, 10)


In [51]:
cpu_device = jax.devices("cpu")[0]
varis_coarse= jax.device_put(varis_coarse, cpu_device)
prior_facs_coarse = jax.device_put(prior_facs_coarse, cpu_device)
horizontal_between_facs = jax.device_put(horizontal_between_facs, cpu_device)



print("var_id:", varis_coarse.var_id.shape)            # should be (N,)
print("belief.eta:", varis_coarse.belief.eta.shape)    # should be (N, D)
print("belief.Lam:", varis_coarse.belief.Lam.shape)    # should be (N, D, D)
print("msgs.eta:", varis_coarse.msgs.eta.shape)        # should be (N, Ni_v, D)
print("msgs.Lam:", varis_coarse.msgs.Lam.shape)        # should be (N, Ni_v, D, D)
print("adj_factor_idx:", varis_coarse.adj_factor_idx.shape)  # should be (N, Ni_v)



"""
varis_coarse, prior_facs_coarse, between_facs_coarse, energy_log, linpoints = gbp_solve(
    varis_coarse, prior_facs_coarse, between_facs_coarse, num_iters=50, visualize=False, prior_h=h3_fn, between_h=h4_fn
)
"""

varis_coarse, vtof_msgs, linpoints = update_variable(varis_coarse)


# Step 2: Factor update
prior_facs_coarse, varis_coarse = update_factor(prior_facs_coarse, varis_coarse, vtof_msgs, linpoints, h3_fn, l2)
horizontal_between_facs, varis_coarse = update_factor(horizontal_between_facs, varis_coarse, vtof_msgs, linpoints, h4_fn, l2)
#vertical_between_facs, varis_coarse = update_factor(vertical_between_facs, varis_coarse, vtof_msgs, linpoints, h5_fn, l2)


varis_coarse, vtof_msgs, linpoints = update_variable(varis_coarse)


# Step 2: Factor update
prior_facs_coarse, varis_coarse = update_factor(prior_facs_coarse, varis_coarse, vtof_msgs, linpoints, h3_fn, l2)

var_id: (64,)
belief.eta: (64, 8)
belief.Lam: (64, 8, 8)
msgs.eta: (64, 10, 8)
msgs.Lam: (64, 10, 8, 8)
adj_factor_idx: (64, 10)
ftov_msgs = Gaussian(eta=[[[ 8.43366146e-01  9.16410732e+00  3.40691209e+00 -1.74360199e+01
   -1.66935730e+01  2.70799065e+00  1.51968412e+01  7.97222424e+00]]

 [[ 2.45289898e+00  8.52161407e+00  7.67830324e+00 -1.17027130e+01
   -1.00919342e+01 -6.76845407e+00  1.17910366e+01  1.18352747e+01]]

 [[ 1.69943447e+01 -6.12245750e+00  9.62253380e+00  7.97455406e+00
   -2.04388561e+01  1.02271676e-01  1.04649782e+01  1.24344635e+00]]

 [[-1.29676986e+00  1.52147532e+00  1.71238022e+01 -8.33569241e+00
   -6.18527651e-01  1.12449431e+00  1.05127211e+01  8.08710575e+00]]

 [[-4.72071981e+00  1.49125242e+01  1.64136639e+01 -3.12783527e+00
    3.13648415e+00 -3.59879398e+00  1.88194637e+01 -6.62844086e+00]]

 [[-2.59870696e+00 -4.82444620e+00  3.28841858e+01  2.09881043e+00
   -2.98598838e+00  4.87510395e+00  1.35867376e+01 -5.69891691e-01]]

 [[-7.89675474e+00  8.40

In [12]:
prior_energy = jnp.sum(jax.vmap(factor_energy, in_axes=(0, 0, None))(
    prior_facs_coarse, linpoints[prior_facs_coarse.adj_var_id[:, 0]], h3_fn
))

prior_energy

Array(nan, dtype=float32)

In [13]:
varis_coarse.belief.eta

Array([[nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan, nan, nan, nan],
       [nan, nan, na