In [1]:
import numpy as np

def make_two_cliques_graph(prior_sigma=10.0, odom_sigma=10.0, rng=None):
    """
    Build a small, fixed-topology test graph made of two cliques.

    Layout:
      * nodes 0..3 form the left clique (fully connected),
      * nodes 4..6 form the right clique (fully connected),
      * one bridge edge joins node 3 and node 4,
      * every node additionally receives a unary "prior" measurement.

    Each measurement is the ground-truth value plus zero-mean Gaussian noise.
    Exactly one 2-vector of noise is drawn from ``rng`` per edge, in edge
    order (pairwise edges first, then the priors).

    Args:
        prior_sigma: standard deviation of the prior measurement noise; the
            prior factor precision is ``I / prior_sigma**2``.
        odom_sigma: the same for the pairwise ("odom") measurements.
        rng: a ``numpy.random.Generator``. Defaults to ``default_rng(0)`` so
            repeated calls give the same graph.

    Returns:
        ``(nodes, edges, fg)`` where ``nodes`` and ``edges`` are lists of
        dicts describing the graph and ``fg`` is the populated FactorGraph.
    """
    rng = np.random.default_rng(0) if rng is None else rng

    # Fixed coordinates: they double as the ground truth and keep plots stable.
    positions = {
        0: (0, 0),
        1: (0, 10),
        2: (0, 20),
        3: (0, 30),
        4: (50, 0),
        5: (50, 10),
        6: (50, 20),
    }
    n_nodes = len(positions)

    nodes = [
        {
            "data": {"id": f"{k}", "layer": 0, "dim": 2},
            "position": {"x": float(positions[k][0]), "y": float(positions[k][1])},
        }
        for k in range(n_nodes)
    ]

    # Pairwise edges: left clique, right clique, then the bridge (3, 4).
    pairs = [(a, b) for a in range(4) for b in range(a + 1, 4)]
    pairs += [(a, b) for a in range(4, 7) for b in range(a + 1, 7)]
    pairs.append((3, 4))
    edges = [{"data": {"source": f"{a}", "target": f"{b}"}} for a, b in pairs]
    # One prior edge per node, encoded with the pseudo-target "prior".
    edges += [{"data": {"source": f"{k}", "target": "prior"}} for k in range(n_nodes)]

    # ----------------- assemble the FactorGraph -----------------
    fg = FactorGraph(nonlinear_factors=False, eta_damping=0)

    var_nodes = []
    for k in range(n_nodes):
        node = VariableNode(k, dofs=2)
        node.GT = np.array(positions[k], dtype=float)
        # Very weak zero-mean prior; a fresh array per node avoids aliasing.
        node.prior.lam = 1e-10 * np.eye(2)
        node.prior.eta = np.zeros(2)
        var_nodes.append(node)
    fg.var_nodes = var_nodes
    fg.n_var_nodes = len(var_nodes)

    # Linear measurement models and their (constant) Jacobians.
    def identity_meas(x, *a):
        return x

    def identity_jac(x, *a):
        return np.eye(2)

    def difference_meas(xy, *a):
        return xy[2:] - xy[:2]

    def difference_jac(xy, *a):
        return np.array([[-1, 0, 1, 0], [0, -1, 0, 1]], dtype=float)

    factors = []
    for fid, edge in enumerate(edges):
        src, dst = edge["data"]["source"], edge["data"]["target"]
        first = var_nodes[int(src)]
        if dst == "prior":
            noise = rng.normal(0.0, prior_sigma, size=2)
            z = first.GT + noise
            z_lambda = np.eye(len(z)) / (prior_sigma**2)
            factor = Factor(fid, [first], z, z_lambda, identity_meas, identity_jac)
            factor.type = "prior"
            # A unary factor is linearised at its own measurement.
            factor.compute_factor(linpoint=z, update_self=True)
            attached = [first]
        else:
            second = var_nodes[int(dst)]
            noise = rng.normal(0.0, odom_sigma, size=2)
            z = (second.GT - first.GT) + noise
            z_lambda = np.eye(len(z)) / (odom_sigma**2)
            factor = Factor(fid, [first, second], z, z_lambda, difference_meas, difference_jac)
            factor.type = "odom"
            # Pairwise factors are linearised at the stacked ground truth.
            factor.compute_factor(linpoint=np.r_[first.GT, second.GT], update_self=True)
            attached = [first, second]
        factors.append(factor)
        for node in attached:
            node.adj_factors.append(factor)

    fg.factors = factors
    fg.n_factor_nodes = len(factors)
    return nodes, edges, fg


In [2]:
"""
    Defines classes for variable nodes, factor nodes and edges and factor graph.
"""

import numpy as np
import time
import scipy.linalg

from utils.gaussian import NdimGaussian
from utils.distances import bhattacharyya, mahalanobis

#from amg import classes as amg_cls
#from amg import functions as amg_fnc

class FactorGraph:
    def __init__(self,
                 nonlinear_factors=True,
                 eta_damping=0.0,
                 beta=None,
                 num_undamped_iters=None,
                 min_linear_iters=None,
                 wild_thresh=0):

        self.var_nodes = []
        self.factors = []

        self.n_var_nodes = 0
        self.n_factor_nodes = 0
        self.n_edges = 0
        self.n_msgs = 0

        self.nonlinear_factors = nonlinear_factors

        self.eta_damping = eta_damping

        self.Q = []
        self.b_wild = False
        self.wild_thresh = wild_thresh
        self.multigrid_vars = [[]]
        self.multigrid_factors = [[]]
        self.multigrid = False
        self.conv_width = 1
        self.conv_stride = 1

        self.energy_history = []
        self.error_history = []
        self.nmsgs_history = []
        self.mus = []

        if nonlinear_factors:
            # For linearising nonlinear measurement factors.
            self.beta = beta  # Threshold change in mean of adjacent beliefs for relinearisation.
            self.num_undamped_iters = num_undamped_iters  # Number of undamped iterations after relinearisation before damping is set to 0.4
            self.min_linear_iters = min_linear_iters  # Minimum number of linear iterations before a factor is allowed to realinearise.

    def energy(self, vars=None):
        """
            Computes the sum of all of the squared errors in the graph using the appropriate local loss function.
        """
        # if slice_e is None:
        #     slice_e = slice(len(self.factors))
        # energy = 0
        # for factor in self.factors[slice_e]:
        #     # Variance of Gaussian noise at each factor is weighting of each term in squared loss.
        #     energy += 0.5 * np.linalg.norm(factor.compute_residual()) ** 2
        # return energy
        if vars is None:
            vars = self.var_nodes
        energy = 0
        for var in vars:
            if var.type != "multigrid":
                # Variance of Gaussian noise at each factor is weighting of each term in squared loss.
                energy += 0.5 * np.linalg.norm(var.residual) ** 2
        return energy
    
    def energy_map(self, include_priors: bool = True, include_factors: bool = True) -> float:
        """
        实际上是距离平方和
        """
        total = 0.0

        for v in self.var_nodes[:self.n_var_nodes]:
            gt = np.asarray(v.GT, dtype=float)
            r = np.asarray(v.mu, dtype=float) - gt
            total += 0.5 * float(r.T @ r)

        return total

    def compute_all_messages(self, factors=None, level=None, local_relin=True):
        if factors is None:
            factors = self.factors[:self.n_factor_nodes]
        if level is not None:
            factors = self.multigrid_factors[level]
        for count, factor in enumerate(factors):
            if factor.active:
                # If relinearisation is local then damping is also set locally per factor.
                if self.nonlinear_factors and local_relin:
                    if factor.iters_since_relin == self.num_undamped_iters:
                        factor.eta_damping = self.eta_damping
                    factor.compute_messages(factor.eta_damping)
                else:
                    factor.compute_messages(self.eta_damping)
                    self.n_msgs += 2

    def compute_all_smoothing_messages(self, factors=None, level=None, local_relin=True):
        if factors is None:
            factors = self.factors[:self.n_factor_nodes]
        if level is not None:
            factors = self.multigrid_factors[level]
        for count, factor in enumerate(factors):
            factor.smoothing_compute_messages(self.eta_damping)

    def update_all_beliefs(self, vars=None, level=None, smoothing=False):
        if vars is None:
            vars = self.var_nodes[:self.n_var_nodes]
        if level is not None:
            vars = self.multigrid_vars[level]

        for var in vars:
            if var.active:
                if smoothing:
                    var.update_smooth_belief()
                else:
                    var.update_belief()


    def update_all_residuals(self, vars=None, level=None, smoothing=False):
        if vars is None:
            vars = self.var_nodes[:self.n_var_nodes]
        if level is not None:
            vars = self.multigrid_vars[level]

        for var in vars:
            if var.active:
                res = var.compute_residual()

    def restrict_all_residuals(self, vars=None, level=None, smoothing=False):
        if vars is None:
            vars = self.var_nodes[:self.n_var_nodes]
        if level is not None:
            vars = self.multigrid_vars[level]

        for var in vars:
            if var.active:
                var.multigrid.send_restricted_residual()

    def update_all_residual_etas(self, vars=None, level=None, smoothing=False):
        if vars is None:
            vars = self.var_nodes[:self.n_var_nodes]
        if level is not None:
            vars = self.multigrid_vars[level]

        for var in vars:
            if var.active:
                if var.type[0:5] == "multi":
                    for i_var in var.multigrid.interpolation_vars:
                        i_var.compute_residual()
                        i_var.multigrid.send_restricted_residual()
                    var.multigrid.update_eta()
                else:
                    print("You just tried to update the eta on a base variable... you should probably \
                        check something because this ain't it!")
                
    def prolongate_corrections(self, vars=None, level=None, smoothing=False):
        if vars is None:
            vars = self.var_nodes[:self.n_var_nodes]
        if level is not None:
            vars = self.multigrid_vars[level]

        for var in vars:
            if var.active:
                if var.type[0:5] == "multi":
                    var.multigrid.send_corrections()
                else:
                    print("You just tried to prolongate using a base variable... you should probably \
                        check something because this ain't it!")

    def compute_all_factors(self, factors=None, level=None):
        if factors is None:
            factors = self.factors[:self.n_factor_nodes]
        if level is not None:
            factors = self.multigrid_factors[level]
        for count, factor in enumerate(factors):
            factor.compute_factor()

    def relinearise_factors(self):
        """
            Compute the factor distribution for all factors for which the local belief mean has deviated a distance
            greater than beta from the current linearisation point.
            Relinearisation is only allowed at a maximum frequency of once every min_linear_iters iterations.
        """
        if self.nonlinear_factors:
            for factor in self.factors:
                adj_belief_means = np.array([])
                for belief in factor.adj_beliefs:
                    adj_belief_means = np.concatenate((adj_belief_means, 1/np.diagonal(belief.lam) * belief.eta))
                if np.linalg.norm(factor.linpoint - adj_belief_means) > self.beta and factor.iters_since_relin >= self.min_linear_iters:
                    factor.compute_factor(linpoint=adj_belief_means)
                    factor.iters_since_relin = 0
                    factor.eta_damping = 0.0
                else:
                    factor.iters_since_relin += 1

    def robustify_all_factors(self):
        for factor in self.factors[:self.n_factor_nodes]:
            factor.robustify_loss()

    def synchronous_iteration(self, factors=None, level=None, local_relin=True, robustify=False):
        if level is not None:
            vars = self.multigrid_vars[level]
            factors = self.multigrid_factors[level]
        else:
            vars = self.var_nodes[:self.n_var_nodes]
            factors = self.factors[:self.n_factor_nodes]

        if robustify:
            self.robustify_all_factors(factors)
        if self.nonlinear_factors and local_relin:
            self.relinearise_factors(factors)

        self.compute_all_messages(factors, local_relin=local_relin)
        time.sleep(1e-9)
        self.update_all_beliefs(vars)

    def synchronous_smooth(self, level=None, local_relin=True, robustify=False):
        if level is not None:
            vars = self.multigrid_vars[level]
            factors = self.multigrid_factors[level]
        else:
            vars = self.var_nodes
            factors = self.factors
        if robustify:
            self.robustify_all_factors()
        if self.nonlinear_factors and local_relin:
            self.relinearise_factors()
        self.compute_all_smoothing_messages(local_relin=local_relin)
        self.update_all_beliefs(smoothing=True)
            

    def synchronous_loop(self, vis):
        """
        Run synchronous iterations until ``vis.reset_event`` is set.

        Each pass: wait while ``vis.pause_event`` is set, pull any newly
        added factors/variables via ``visualisation_sync``, run one
        ``synchronous_iteration``, refresh the residuals, then record and
        print the energy, the mean distance between belief means and ground
        truth (non-multigrid variables only) and the message count, followed
        by per-level multigrid statistics.

        vis: object exposing ``reset_event``, ``pause_event`` and
            ``skip_event`` plus whatever ``visualisation_sync`` reads.
        """
        i=0
        # self.get_means()
        while not vis.reset_event.isSet(): #i<1000 and
            # Idle (polling twice a second) while paused.
            while vis.pause_event.isSet() and not vis.reset_event.isSet():
                time.sleep(0.5)

            self.visualisation_sync(vis)

            self.synchronous_iteration()
            self.update_all_residuals()

            i+=1
            # Mean Euclidean error of the belief means w.r.t. ground truth.
            av_dist = np.mean(np.linalg.norm(np.array([var.mu - var.GT for var in self.var_nodes if var.type != "multigrid"]),axis=1))
            self.energy_history.append(self.energy())
            self.error_history.append(av_dist)
            self.nmsgs_history.append(self.n_msgs)
            print(f'Iteration {i}  // Energy {self.energy_history[-1]:.6f} // ' 
                  f'Average error {av_dist:.4f} // msgs sent {self.n_msgs/1e6:.3f}x10^6')
            
            # NOTE(review): get_multigrid_stats reads var.multigrid, which
            # VariableNode.__init__ in this file does not assign (the line
            # is commented out) - confirm it is attached elsewhere.
            self.get_multigrid_stats()
            for level in range(len(self.n_active)):
                if self.n_active[level] > 0:
                    print(f'Multigrid stats // level {level} // {(self.n_coarse[level]/(len(self.multigrid_vars[level])))*100:.2f}% coarse ' \
                        f'// {(self.n_active[level]/(len(self.multigrid_vars[level])))*100:.2f}% active ' \
                        f'// {len(self.multigrid_vars[level])} total ')
            
            print('')        
            # "Skip" means: run exactly one iteration, then pause again.
            if vis.skip_event.isSet():
                vis.pause_event.set()
                vis.skip_event.clear()

    def vcycle_loop(self, vis):
        """
        Run multigrid V-cycles until ``vis.reset_event`` is set.

        Each pass: wait while paused, sync with the visualisation, run one
        synchronous iteration on level 0, then walk up the levels
        (update residual etas and beliefs, one iteration, refresh residuals)
        and back down (one iteration, then prolongate the corrections).
        Afterwards the energy, mean ground-truth error and message count are
        recorded and printed together with per-level multigrid statistics.

        vis: object exposing ``reset_event``, ``pause_event`` and
            ``skip_event`` plus whatever ``visualisation_sync`` reads.
        """
        i=0
        # self.get_means()

        while  not vis.reset_event.isSet():
            # Idle (polling twice a second) while paused.
            while vis.pause_event.isSet() and not vis.reset_event.isSet():
                time.sleep(0.5)

            self.visualisation_sync(vis)

            # if i == 10:  # Number of damped iterations before applying undamping
            #     self.eta_damping = 0.0

            # Sweep on the base level.
            for _ in range(1):
                self.synchronous_iteration(level=0)
                time.sleep(1e-9)
            # self.update_all_residuals(level=0)
            # self.restrict_all_residuals(level=0)
            
            # Upward leg: level 1 to the top level.
            for level in range(1,len(self.multigrid_vars)):  #range(1, 5):
                self.update_all_residual_etas(level=level)
                self.update_all_beliefs(level=level) 
                for _ in range(1):
                    self.synchronous_iteration(level=level)
                    time.sleep(1e-9)
                self.update_all_residuals(level=level)
                #self.restrict_all_residuals(level=level)                
                
            # Downward leg: top level back to level 1, pushing corrections down.
            for level in range(len(self.multigrid_vars)-1,0,-1):  #range(4,0,-1):
                for _ in range(1):
                    self.synchronous_iteration(level=level)
                    time.sleep(1e-9)
                self.prolongate_corrections(level=level)

            i+=1
            # Mean Euclidean error of the belief means w.r.t. ground truth.
            av_dist = np.mean(np.linalg.norm(np.array([var.mu - var.GT for var in self.var_nodes if var.type != "multigrid"]),axis=1))
            self.energy_history.append(self.energy())
            self.error_history.append(av_dist)
            self.nmsgs_history.append(self.n_msgs)
            print(f'Iteration {i}  // Energy {self.energy_history[-1]:.6f} // ' 
                  f'Average error {av_dist:.4f} // msgs sent {self.n_msgs/1e6:.3f}x10^6')
            
            self.get_multigrid_stats()
            for level in range(len(self.n_active)):
                if self.n_active[level] > 0:
                    print(f'Multigrid stats // level {level} // {(self.n_coarse[level]/(len(self.multigrid_vars[level])))*100:.2f}% coarse ' \
                        f'// {(self.n_active[level]/(len(self.multigrid_vars[level])))*100:.2f}% active ' \
                        f'// {len(self.multigrid_vars[level])} total ')
            
            print('')
            # "Skip" means: run exactly one cycle, then pause again.
            if vis.skip_event.isSet():
                vis.pause_event.set()
                vis.skip_event.clear()

    def wildfire_iteration(self, vis, local_relin=True, robustify=False):
        """
        Queue-driven ("wildfire") message passing until ``vis.reset_event``
        is set.

        The factor at the head of ``self.Q`` recomputes its messages; each
        adjacent variable then updates its belief and residual, and if any
        component of the factor's ``messages_dist`` entry for that variable
        exceeds ``wild_thresh``, the variable's other factors are appended
        to the queue (if not already queued). While paused or while the
        queue is empty the loop only polls the visualisation for new factors.

        ``local_relin`` and ``robustify`` are accepted but not used.
        """
        breakout_count = 0
        i = 0
        while not vis.reset_event.isSet():
            # Operator precedence: (paused and not reset) or (queue empty).
            if vis.pause_event.isSet() and not vis.reset_event.isSet() or not self.Q:
                time.sleep(0.1)
                _ , new_factors = self.visualisation_sync(vis)
                if new_factors:
                    self.Q = new_factors
            else:
                _ , new_factors = self.visualisation_sync(vis)
                if new_factors:
                    # New factors jump to the front of the queue.
                    self.Q[0:0] = new_factors

                self.Q[0].compute_messages(self.eta_damping)
                self.n_msgs += 2

                for count, var in enumerate(self.Q[0].adj_var_nodes):
                    var.update_belief()
                    var.compute_residual()
                    breakout_count += 1
                    # Spread to neighbouring factors only if the message
                    # towards this variable changed by more than the threshold.
                    if any(self.Q[0].messages_dist[count] > self.wild_thresh):
                        for f in var.adj_factors:
                            if f not in self.Q:
                                self.Q.append(f)

                self.Q.pop(0)

                # Report once per "iteration", i.e. whenever the number of
                # processed factors is a multiple of the factor count.
                if (self.n_msgs / 2) % len(self.factors) == 0:
                    i += 1
                    vis.read_event.clear()
                    av_dist = np.mean(np.linalg.norm(np.array([var.mu - var.GT for var in self.var_nodes if var.type != "multigrid"]),axis=1))
                    self.energy_history.append(self.energy())
                    self.error_history.append(av_dist)
                    self.nmsgs_history.append(self.n_msgs)
                    print(f'Iteration {i}  // Energy {self.energy_history[-1]:.6f} // ' 
                        f'Average error {av_dist:.4f} // msgs sent {self.n_msgs/1e6:.3f}x10^6')
                    
                    self.get_multigrid_stats()
                    for level in range(len(self.n_active)):
                        if self.n_active[level] > 0:
                            print(f'Multigrid stats // level {level} // {(self.n_coarse[level]/(len(self.multigrid_vars[level])))*100:.2f}% coarse ' \
                                f'// {(self.n_active[level]/(len(self.multigrid_vars[level])))*100:.2f}% active ' \
                                f'// {len(self.multigrid_vars[level])} total ')
                    
                    print('')

                    breakout_count = 0

                    # "Skip" means: run one reporting period, then pause again.
                    if vis.skip_event.isSet():
                        vis.pause_event.set()
                        vis.skip_event.clear()

    def visualisation_sync(self, vis):
        """
        Absorb factors/variables that the visualisation side has added.

        If ``vis.n_factors`` exceeds ``self.n_factor_nodes``, the newest
        entries at the tail of ``self.var_nodes`` / ``self.factors`` are
        registered on multigrid level 0, wired into their adjacent
        variables, and initialised (belief update, then ``compute_factor``).
        The node counters are then synchronised and entries whose ``type``
        is 'dead' are dropped from ``vis`` and from their multigrid level.

        Returns:
            ``(new_vars, new_factors)`` when something was added, otherwise
            ``(None, None)``.
        """
        if vis.n_factors > self.n_factor_nodes:
            # Wait for any in-progress write, then flag that we are reading.
            while vis.write_event.is_set():
                time.sleep(0.001)
            vis.read_event.set()
            new_n_factors = vis.n_factors - self.n_factor_nodes
            new_n_vars = vis.n_vars - self.n_var_nodes
            
            # New items are taken from the tail of the lists.
            new_vars = self.var_nodes[slice(int(len(self.var_nodes) - new_n_vars), int(len(self.var_nodes)))]
            self.multigrid_vars[0].extend(new_vars)
            new_factors = self.factors[slice(int(len(self.factors) - new_n_factors), int(len(self.factors)))]
            self.multigrid_factors[0].extend(new_factors)

            vars_to_update = []

            for factor in new_factors:
                if vis.b_wild:
                    factor.b_calc_mess_dist = True
                for adj_var in factor.adj_var_nodes:
                    adj_var.adj_factors.append(factor)
                    if adj_var not in vars_to_update:
                        vars_to_update.append(adj_var)
                
            for var in vars_to_update:
                var.update_belief()
            
            for factor in new_factors:
                factor.compute_factor()

            self.n_var_nodes = int(vis.n_vars)
            self.n_factor_nodes = int(vis.n_factors)

            #if new_vars and vis.b_multi: # i.e. if there are new vars
            #    amg_fnc.coarsen_graph(self, vars_to_update)


            vis.n_vars = int(self.n_var_nodes)
            vis.n_factors = int(self.n_factor_nodes)
            
            self.n_vars_active = int(len(self.var_nodes))
            nodes_removed = 0

            # NOTE(review): the range comes from self.var_nodes but the
            # indexing and pop() act on vis.var_nodes; this presumably
            # relies on both names referring to the same list - confirm.
            for var_id in range(self.n_vars_active):
                if vis.var_nodes[var_id - nodes_removed].type == 'dead':
                    self.multigrid_vars[vis.var_nodes[var_id - nodes_removed].multigrid.level].remove(vis.var_nodes[var_id - nodes_removed])
                    vis.var_nodes.pop(var_id - nodes_removed)
                    nodes_removed += 1

            self.n_vars_active = int(len(self.var_nodes))
            self.n_factors_active = int(len(self.factors))

            # Same pattern for factors (see the note above).
            factors_removed = 0
            for factor_id in range(self.n_factors_active):
                if vis.factors[factor_id - factors_removed].type == 'dead':
                    self.multigrid_factors[int(vis.factors[factor_id - factors_removed].level)].remove(vis.factors[factor_id - factors_removed])
                    vis.factors.pop(factor_id - factors_removed)
                    factors_removed += 1

            self.n_factors_active = int(len(self.factors))

            # print("{:} node(s) removed : {:} factor(s) removed".format(nodes_removed ,factors_removed))
            # self.n_var_nodes = int(vis.n_vars)
            # self.n_factor_nodes = int(vis.n_factors)

            vis.read_event.clear()

            return new_vars, new_factors
                
        else:

            return None, None
    

    def joint_distribution_inf(self):
        """
            Get the joint distribution over all variables in the information form
            If nonlinear factors, it is taken at the current linearisation point.
        """

        eta = np.array([])
        lam = np.array([])
        var_ix = np.zeros(len(self.var_nodes)).astype(int)
        tot_n_vars = 0
        for var_node in self.var_nodes:
            var_ix[var_node.variableID] = int(tot_n_vars)
            tot_n_vars += var_node.dofs
            eta = np.concatenate((eta, var_node.prior.eta))
            if var_node.variableID == 0:
                lam = var_node.prior.lam
            else:
                lam = scipy.linalg.block_diag(lam, var_node.prior.lam)

        for count, factor in enumerate(self.factors):
            factor_ix = 0
            for adj_var_node in factor.adj_var_nodes:
                vID = adj_var_node.variableID
                # Diagonal contribution of factor
                eta[var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \
                    factor.factor.eta[factor_ix:factor_ix + adj_var_node.dofs]
                lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \
                    factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs]
                other_factor_ix = 0
                for other_adj_var_node in factor.adj_var_nodes:
                    if other_adj_var_node.variableID > adj_var_node.variableID:
                        other_vID = other_adj_var_node.variableID
                        # Off diagonal contributions of factor
                        lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs] += \
                            factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, other_factor_ix:other_factor_ix + other_adj_var_node.dofs]
                        lam[var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \
                            factor.factor.lam[other_factor_ix:other_factor_ix + other_adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs]
                    other_factor_ix += other_adj_var_node.dofs
                factor_ix += adj_var_node.dofs

        return eta, lam


    def joint_distribution_cov(self):
        """
            Get the joint distribution over all variables in the covariance.
            If nonlinear factors, it is taken at the current linearisation point.
        """
        eta, lam = self.joint_distribution_inf()
        sigma = np.linalg.inv(lam)
        mu = sigma @ eta
        return mu, sigma
    

    def get_multigrid_stats(self):
        self.n_coarse = [0 for _ in self.multigrid_vars]
        self.n_active = [0 for _ in self.multigrid_vars]
        self.n_fine = [0 for _ in self.multigrid_vars]
        for var in self.var_nodes:
            if var.active:
                self.n_active[var.multigrid.level] += 1
            if var.multigrid.classification == "coarse":
                self.n_coarse[var.multigrid.level] += 1
            elif var.multigrid.classification == "fine":
                self.n_fine[var.multigrid.level] += 1



    # def get_means(self, slice_m=None):
    #     """
    #         Get an array containing all current estimates of belief means.
    #     """
    #     if slice_m is None:
    #         slice_m = slice(0,len(self.var_nodes),1)
    #     if len(self.mus) != len(self.var_nodes):
    #         self.mus = [None]*(len(self.var_nodes))

    #     for index, var_node in enumerate(self.var_nodes[slice_m]):
    #         self.mus[index] =  [var_node.mu[0],var_node.mu[1]]


    #     return self.mus
    
    def get_sigmas(self):
        """
            Get an array containing all current estimates of belief sigmas.
        """
        sigmas = np.array([])
        for var_node in self.var_nodes:
            sigmas = np.concatenate((sigmas, var_node.Sigma[0]))
        return sigmas
    
    def get_residuals(self, level=None):
        """
            Recompute ``var.residual`` for a range of variables by summing
            the (sign-adjusted) residuals of their adjacent factors.

            level: if given, the range of variables is derived from
                ``multigrid_vars[level].var_ids``; otherwise all variables
                are processed.
        """
        if level is not None:
            # NOTE(review): multigrid_vars[level] is created as a plain list
            # in __init__, which has no ``var_ids`` attribute - confirm what
            # container this branch expects.
            slice_v = slice(int(self.multigrid_vars[level].var_ids[0]), int(self.multigrid_vars[level].var_ids[-1]+1))
        else:
            slice_v = slice(int(len(self.var_nodes)))

        for var in self.var_nodes[slice_v]:
            var.residual = np.zeros(var.dofs)
            for factor in var.adj_factors:
                # The factor residual is negated when this variable is the
                # factor's second variable. Indexing adj_vIDs[1] assumes
                # every factor has at least two adjacent variables.
                residual = factor.compute_residual() * (1 - 2 * int(factor.adj_vIDs[1] == var.variableID))
                # Only the first two components are accumulated.
                var.residual += residual[:2]

    def get_var_residuals(self, level=None):
        if level is not None:
            slice_v = slice(int(self.multigrid_vars[level].var_ids[0]), int(self.multigrid_vars[level].var_ids[-1]+1))
        else:
            slice_v = slice(int(len(self.var_nodes)))

        for var in self.var_nodes[slice_v]:
            var.residual = var.compute_residual()


class VariableNode:
    def __init__(self,
                 variable_id,
                 dofs, level=0):

        self.variableID = variable_id
        self.adj_factors = []
        self.InfoMat = []  # Row vector of prior Information vector in factor order
        self.EtaVec = []  # Vector of prior eta values
        #self.multigrid = amg_cls.mutligrid_var_info(self)
        self.type = "None specified"
        self.active = True

        # Node variables are position of landmark in world frame. Initialize variable nodes at origin
        self.mu = np.zeros(dofs)
        self.Sigma = np.zeros([dofs, dofs])
        self.residual = np.zeros(dofs)

        self.belief = NdimGaussian(dofs)

        self.prior = NdimGaussian(dofs)
        self.prior_lambda_end = -1  # -1 flag if the sigma of self.prior is prior_sigma_end
        self.prior_lambda_logdiff = -1

        self.dofs = dofs

    def update_belief(self):
        """
            Update local belief estimate by taking product of all incoming messages along all edges.
            Then send belief to adjacent factor nodes.
        """
        # Update local belief
        eta = self.prior.eta.copy()
        lam = self.prior.lam.copy()
        for factor in self.adj_factors:
            message_ix = factor.adj_var_nodes.index(self)
            eta_inward, lam_inward = factor.messages[message_ix].eta, factor.messages[message_ix].lam
            eta += eta_inward
            lam += lam_inward

        self.belief.eta = eta 
        self.belief.lam = lam

        self.Sigma = np.linalg.inv(self.belief.lam)
        self.mu = self.Sigma @ self.belief.eta
        
        # Send belief to adjacent factors
        for factor in self.adj_factors:
            belief_ix = factor.adj_var_nodes.index(self)
            factor.adj_beliefs[belief_ix].eta, factor.adj_beliefs[belief_ix].lam = self.belief.eta, self.belief.lam


    def update_smooth_belief(self):
        """
            Update local belief estimate by taking product of all incoming messages along all edges.
            Then send belief to adjacent factor nodes.
        """
        # Update local belief
        eta = self.prior.eta.copy()
        lam = self.prior.lam.copy()
        for factor in self.adj_factors:
            message_ix = factor.adj_vIDs.index(self.variableID)
            mu_inward, lam_inward = factor.messages[message_ix].eta, factor.messages[message_ix].lam
            factor.messages_prior[message_ix].eta = mu_inward  # Update messages for belief calculation now sync itr is complete
            factor.messages_prior[message_ix].lam = lam_inward  # If don't have a prior messages then its not truely parallel
            eta += np.diag(lam_inward) * mu_inward
            lam += lam_inward

        self.belief.eta = eta 
        self.belief.lam = lam
        self.Sigma = 1/np.diagonal(self.belief.lam)
        self.mu = self.Sigma * self.belief.eta
        
        # Send belief to adjacent factors
        for factor in self.adj_factors:
            belief_ix = factor.adj_vIDs.index(self.variableID)
            factor.adj_beliefs[belief_ix].eta, factor.adj_beliefs[belief_ix].lam = self.belief.eta, self.belief.lam


    def compute_residual(self):
        
        # Ax = np.zeros(2)

        # Ax += self.prior.lam @ self.mu

        # for factor in self.adj_factors:
        #     for var in factor.adj_var_nodes:
        #         if var.variableID != self.variableID:
        #             Ax += var.mu @ factor.factor.lam[0:2, 2:4]

        # d = self.prior.eta - Ax

        # res = self.prior.eta - self.prior.lam @ self.mu

        # for factor in self.adj_factors:
        #     # get factor residual. Second part flips the sign if the var is second var in the factor
        #     res += factor.compute_residual() * (1 - 2 * int(factor.adj_vIDs[1] == self.variableID))

        # self.residual = res

        res = self.prior.eta - self.prior.lam @ self.mu
        for factor in self.adj_factors:
            if factor.adj_vIDs.index(self.variableID) == 0:
                 res += factor.factor.eta[:self.dofs] - (factor.factor.lam[:self.dofs, :self.dofs] @ self.mu \
                        + factor.factor.lam[:self.dofs, self.dofs:] @ factor.adj_var_nodes[1].mu)
            else:
                res += factor.factor.eta[-self.dofs:] - (factor.factor.lam[-self.dofs:, -self.dofs:] @ self.mu  \
                        + factor.factor.lam[-self.dofs:, :-self.dofs] @ factor.adj_var_nodes[0].mu)
        
        self.residual = res

        return res


class Factor:
    """
        A Gaussian factor of the GBP factor graph, kept in information form (``factor.eta``, ``factor.lam``)
        over the stacked degrees of freedom of its adjacent variable nodes.

        For every adjacent variable the factor keeps:
          * ``adj_beliefs[i]``    - a copy of that variable's current belief,
          * ``messages[i]``       - the last message sent from this factor to that variable,
          * ``messages_prior[i]`` - a second message slot read by ``smoothing_compute_messages``,
          * ``messages_dist[i]``  - distance between the previous and the new message (only when ``wildfire``).
    """

    def __init__(self,
                 factor_id,
                 adj_var_nodes,
                 measurement,
                 measurement_lambda,
                 meas_fn,
                 jac_fn,
                 loss=None,
                 mahalanobis_threshold=2,
                 wildfire=False,
                 *args):
        """
            factor_id: identifier stored as ``factorID``.
            adj_var_nodes: variable nodes joined by this factor; each must expose ``dofs`` and ``variableID``.
            measurement: measurement z (a float or an array).
            measurement_lambda: precision (inverse covariance) of the measurement.
            meas_fn: measurement function, called as ``meas_fn(x, *args)``.
            jac_fn: Jacobian of ``meas_fn``, called as ``jac_fn(x, *args)``; a plain list is treated as a
                    constant Jacobian (see ``compute_factor``).
            loss: None, 'huber' or 'constant'; read by ``robustify_loss``.
            mahalanobis_threshold: number of standard deviations from the mean at which the loss
                    transitions to the robust loss function.
            wildfire: when True, ``compute_messages`` also records message distances in ``messages_dist``.
            *args: extra positional arguments forwarded to ``meas_fn`` and ``jac_fn``.
        """

        self.factorID = factor_id

        self.dofs_conditional_vars = 0
        self.adj_var_nodes = adj_var_nodes
        self.adj_vIDs = []
        self.adj_beliefs = []
        self.messages = []
        self.messages_prior = []
        self.messages_dist = []

        self.active = True

        self.level = 0

        self.type = "factor"

        for adj_var_node in self.adj_var_nodes:
            self.dofs_conditional_vars += adj_var_node.dofs
            self.adj_vIDs.append(adj_var_node.variableID)
            self.adj_beliefs.append(NdimGaussian(adj_var_node.dofs))
            self.messages.append(NdimGaussian(adj_var_node.dofs))
            self.messages_prior.append(NdimGaussian(adj_var_node.dofs))
            self.messages_dist.append(np.zeros(adj_var_node.dofs))

        self.factor = NdimGaussian(self.dofs_conditional_vars)
        self.linpoint = np.zeros(self.dofs_conditional_vars)  # linearisation point

        self.residual = None
        self.b_calc_mess_dist = wildfire

        # Measurement model
        self.measurement = measurement
        self.measurement_lambda = measurement_lambda
        self.meas_fn = meas_fn
        self.jac_fn = jac_fn
        self.args = args

        # Robust loss function
        self.loss = loss
        self.mahalanobis_threshold = mahalanobis_threshold
        self.robust_flag = False

        # Local relinearisation
        self.eta_damping = 0.
        self.iters_since_relin = 1

    def _adj_belief_means(self):
        """
            Stack the means of the adjacent beliefs into one vector, in ``adj_beliefs`` order.

            Each mean is recovered as eta / diag(lam), i.e. only the diagonal of the belief precision is
            used (exact only when the belief precision is diagonal).
        """
        means = np.array([])
        for belief in self.adj_beliefs:
            means = np.concatenate((means, 1/np.diagonal(belief.lam) * belief.eta))
        return means

    def compute_residual(self):
        """
            Evaluate the factor's linear residual at the current adjacent belief means and store it in
            ``self.residual``.

            The expression is hard-coded for the first two rows of the factor precision
            (``lam[:2, 2:]`` and ``lam[:2, :]``), so it presumes a pairwise factor between two 2-dof
            variables - verify before using it on factors of another shape.
        """
        adj_belief_means = self._adj_belief_means()

        d = np.array(self.measurement) @ self.factor.lam[:2,2:] - self.factor.lam[:2,:] @ adj_belief_means

        self.residual = d

        return d

    def energy(self):
        """
            Return 0.5 * ||residual||^2 for the residual stored by the last ``compute_residual`` call.
            No robust loss is applied here.
        """
        return 0.5 * np.linalg.norm(self.residual) ** 2

    def compute_factor(self, linpoint=None, update_self=True):
        """
            Compute the factor given the linearisation point.
            If not given then linearisation point is mean of belief of adjacent nodes.
            If measurement model is linear then factor will always be the same regardless of linearisation point.

            Returns ``(eta_factor, lambda_factor)``; they are also written to ``self.factor`` when
            ``update_self`` is True.
        """
        if linpoint is None:
            self.linpoint = list(self._adj_belief_means())
        else:
            self.linpoint = linpoint

        if isinstance(self.jac_fn, list):
            # A list is taken as a constant Jacobian of a linear measurement model.
            J = np.array(self.jac_fn)
            pred_measurement = J @ self.linpoint
        else:
            J = self.jac_fn(self.linpoint, *self.args)
            pred_measurement = self.meas_fn(self.linpoint, *self.args)

        if isinstance(self.measurement, float):
            # Scalar measurement: J is a single row and measurement_lambda a scalar.
            lambda_factor = self.measurement_lambda * np.outer(J, J)
            eta_factor = self.measurement_lambda * J.T * (J @ self.linpoint + self.measurement - pred_measurement)
        else:
            lambda_factor = J.T @ self.measurement_lambda @ J
            eta_factor = (J.T @ self.measurement_lambda) @ (J @ self.linpoint + self.measurement - pred_measurement)

        if update_self:
            self.factor.eta, self.factor.lam = eta_factor, lambda_factor

        return eta_factor, lambda_factor

    def robustify_loss(self):
        """
            Rescale the variance of the noise in the Gaussian measurement model if necessary and update the factor
            correspondingly.

            NOTE(review): this method reads ``self.gauss_noise_var`` and ``self.adaptive_gauss_noise_var``,
            which ``__init__`` does not set. They must be assigned on the instance before calling,
            otherwise an AttributeError is raised.
        """
        old_adaptive_gauss_noise_var = self.adaptive_gauss_noise_var
        if self.loss is None:
            self.adaptive_gauss_noise_var = self.gauss_noise_var

        else:
            pred_measurement = self.meas_fn(self.linpoint, *self.args)

            if self.loss == 'huber':  # Loss is linear after Nstds from mean of measurement model
                mahalanobis_dist = np.linalg.norm(self.measurement - pred_measurement) / np.sqrt(self.gauss_noise_var)
                if mahalanobis_dist > self.mahalanobis_threshold:
                    self.adaptive_gauss_noise_var = self.gauss_noise_var * mahalanobis_dist**2 / \
                            (2*(self.mahalanobis_threshold * mahalanobis_dist - 0.5 * self.mahalanobis_threshold**2))
                    self.robust_flag = True
                else:
                    self.robust_flag = False
                    self.adaptive_gauss_noise_var = self.gauss_noise_var

            elif self.loss == 'constant':  # Loss is constant after Nstds from mean of measurement model
                mahalanobis_dist = np.linalg.norm(self.measurement - pred_measurement) / np.sqrt(self.gauss_noise_var)
                if mahalanobis_dist > self.mahalanobis_threshold:
                    self.adaptive_gauss_noise_var = mahalanobis_dist**2
                    self.robust_flag = True
                else:
                    self.robust_flag = False
                    self.adaptive_gauss_noise_var = self.gauss_noise_var

        # Update factor using existing linearisation point (we are not relinearising).
        self.factor.eta *= old_adaptive_gauss_noise_var / self.adaptive_gauss_noise_var
        self.factor.lam *= old_adaptive_gauss_noise_var / self.adaptive_gauss_noise_var

    def relinearise(self, min_linear_iters, beta):
        """
            Recompute the factor around the current adjacent belief means when they have moved more than
            ``beta`` (Euclidean norm) from the stored linearisation point and at least
            ``min_linear_iters`` calls have passed since the last relinearisation; otherwise just count
            the call.
        """
        adj_belief_means = self._adj_belief_means()
        if np.linalg.norm(self.linpoint - adj_belief_means) > beta and self.iters_since_relin >= min_linear_iters:
            self.compute_factor(linpoint=adj_belief_means)
            self.iters_since_relin = 0
            self.eta_damping = 0.0
        else:
            self.iters_since_relin += 1

    def compute_messages(self, eta_damping):
        """
            Compute all outgoing messages from the factor.
            This is specialised for one and two variable factors.

            eta_damping: weight in [0, 1] given to the previous message when blending it with the newly
                         computed one (0 means no damping). It is not applied to unary factors.

            All new messages are computed from the old ones first and only then written back, so the
            update of one message never sees the freshly updated value of another.
        """

        if len(self.adj_vIDs) == 1:
            # Unary factor: the outgoing message is the factor itself.
            v = 0
            new_eta, new_lam = self.factor.eta.copy(), self.factor.lam.copy()
            if self.b_calc_mess_dist:
                # Distance between the message sent last time and the one about to be sent.
                self.messages_dist[v] = mahalanobis(self.messages[v], NdimGaussian(len(new_eta), eta=new_eta, lam=new_lam))
            self.messages[v].eta = new_eta
            self.messages[v].lam = new_lam
            return

        messages_eta, messages_lam = [], []

        for v in range(len(self.adj_vIDs)):
            eta_factor, lam_factor = self.factor.eta.copy(), self.factor.lam.copy()

            # Take product of factor with incoming messages
            # (belief of the other variable with this factor's own previous message removed).
            mess_start_dim = 0
            for var in range(len(self.adj_vIDs)):
                if var != v:
                    var_dofs = self.adj_var_nodes[var].dofs
                    eta_factor[mess_start_dim:mess_start_dim + var_dofs] += self.adj_beliefs[var].eta - self.messages[var].eta
                    lam_factor[mess_start_dim:mess_start_dim + var_dofs, mess_start_dim:mess_start_dim + var_dofs] += self.adj_beliefs[var].lam - self.messages[var].lam
                mess_start_dim += self.adj_var_nodes[var].dofs

            # Divide up parameters of distribution: "o" = outgoing variable, "no" = the other one
            divide =  self.adj_var_nodes[0].dofs
            if v == 0:
                eo = eta_factor[:divide]
                eno = eta_factor[divide:]

                loo = lam_factor[:divide, :divide]
                lono = lam_factor[:divide, divide:]
                lnoo = lam_factor[divide:, :divide]
                lnono = lam_factor[divide:, divide:]
            elif v == 1:
                eo = eta_factor[divide:]
                eno = eta_factor[:divide]

                loo = lam_factor[divide:, divide:]
                lono = lam_factor[divide:, :divide]
                lnoo = lam_factor[:divide, divide:]
                lnono = lam_factor[:divide, :divide]

            # Tiny diagonal jitter keeps the block invertible; lnono is a view of the local copy.
            lnono += 1e-12 * np.eye(lnono.shape[0])
            lnono_inv = np.linalg.inv(lnono)

            # Marginalise out the other variable (Schur complement), then damp.
            new_message_lam = loo - lono @ lnono_inv @ lnoo
            messages_lam.append((1 - eta_damping) * new_message_lam + eta_damping * self.messages[v].lam)
            new_message_eta = eo - lono @ lnono_inv @ eno
            messages_eta.append((1 - eta_damping) * new_message_eta + eta_damping * self.messages[v].eta)

        for v in range(len(self.adj_vIDs)):
            if self.b_calc_mess_dist:
                self.messages_dist[v] = mahalanobis(self.messages[v], NdimGaussian(len(messages_eta[v]), eta=messages_eta[v], lam=messages_lam[v]))
            self.messages[v].lam = messages_lam[v]
            self.messages[v].eta = messages_eta[v]

    def smoothing_compute_messages(self, eta_damping):
        """
            Alternative message update that works element-wise on diagonal entries.

            For each adjacent variable ``v`` it reads the coupling entries ``factor.lam[0, 2]`` and
            ``factor.lam[1, 3]``, the belief of ``v`` and ``messages_prior[v]``, and writes a damped
            update into the message for the *other* variable (index ``1 - v``).

            The hard-coded indices presume a pairwise factor between two 2-dof variables with a
            diagonal coupling block - verify before using it on other factors.
        """
        for v in range(len(self.adj_vIDs)):
            Aij = np.array([self.factor.lam[0,2],self.factor.lam[1,3]])
            Pij = (-Aij**2) / (np.diagonal(self.adj_beliefs[v].lam - self.messages_prior[v].lam))
            uij = (self.adj_beliefs[v].eta - self.messages_prior[v].eta * np.diag(self.messages_prior[v].lam)) / Aij

            self.messages[1-v].lam[0,0] = (1 - eta_damping) * Pij[0] + eta_damping * self.messages[1-v].lam[0,0]
            self.messages[1-v].lam[1,1] = (1 - eta_damping) * Pij[1] + eta_damping * self.messages[1-v].lam[1,1]
            self.messages[1-v].eta = (1 - eta_damping) * uij + eta_damping * self.messages[1-v].eta



In [3]:
def fuse_to_super_kmeans(prev_nodes, prev_edges, k, layer_idx, max_iters=20, tol=1e-6, seed=0):
    """
    Group the nodes of the previous layer into k super nodes with k-means on their 2D positions.

    prev_nodes / prev_edges: node and edge dicts of the previous layer (edges whose target is
        "prior" are unary).
    k: requested number of clusters; values <= 0 become 1 and k is capped at the node count.
    layer_idx: value written to every super node's "layer" field.
    max_iters, tol: Lloyd iteration limit and centre-movement threshold for early stop.
    seed: seed of the generator that picks the initial centres.

    Returns (super_nodes, super_edges, node_map) where node_map sends each previous node id to
    its super node id. Edges inside one cluster collapse to a single "prior" edge of that
    cluster; edges between two clusters collapse to one edge per cluster pair.
    """
    positions = np.array([[nd["position"]["x"], nd["position"]["y"]] for nd in prev_nodes], dtype=float)
    n = positions.shape[0]
    k = 1 if k <= 0 else k
    k = min(k, n)
    rng = np.random.default_rng(seed)

    # Initial centres: k distinct points drawn without replacement.
    centers = positions[rng.choice(n, size=k, replace=False)]

    def _assign_without_empty_clusters():
        """Nearest-centre labels; every empty cluster steals the first point of the largest one."""
        sq_dist = ((positions[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = np.argmin(sq_dist, axis=1)
        sizes = np.bincount(labels, minlength=k)
        for empty in np.where(sizes == 0)[0]:
            donor = np.argmax(sizes)
            stolen = np.where(labels == donor)[0][0]
            labels[stolen] = empty
            sizes[donor] -= 1
            sizes[empty] += 1
        return labels

    # Lloyd iterations
    for _ in range(max_iters):
        labels = _assign_without_empty_clusters()
        largest_shift = 0.0
        for ci in range(k):
            updated = positions[np.where(labels == ci)[0]].mean(axis=0)
            shift = float(np.linalg.norm(updated - centers[ci]))
            if shift > largest_shift:
                largest_shift = shift
            centers[ci] = updated
        if largest_shift < tol:
            break

    # One last assignment against the final centres.
    labels = _assign_without_empty_clusters()

    # ---------- build the super graph ----------
    super_nodes = []
    node_map = {}
    for ci in range(k):
        members = np.where(labels == ci)[0]
        mean_x, mean_y = positions[members].mean(axis=0)
        total_dim = sum([prev_nodes[i]["data"]["dim"] for i in members])
        nid = f"{ci}"
        super_nodes.append({
            "data": {"id": nid, "layer": layer_idx, "dim": int(max(1, total_dim))},
            "position": {"x": float(mean_x), "y": float(mean_y)}
        })
        for i in members:
            node_map[prev_nodes[i]["data"]["id"]] = nid

    super_edges = []
    seen = set()
    for edge in prev_edges:
        src, dst = edge["data"]["source"], edge["data"]["target"]
        s_src = node_map[src]
        s_dst = s_src if dst == "prior" else node_map[dst]
        if s_src == s_dst:
            # Unary edge, or a binary edge swallowed by one cluster: becomes that cluster's prior.
            key, target = (s_src, "prior"), "prior"
        else:
            key, target = tuple(sorted((s_src, s_dst))), s_dst
        if key in seen:
            continue
        seen.add(key)
        super_edges.append({"data": {"source": s_src, "target": target}})

    return super_nodes, super_edges, node_map


In [4]:
def fuse_to_super_grid(prev_nodes, prev_edges, gx, gy, layer_idx):
    """
    Group the nodes of the previous layer into super nodes by binning their 2D positions on a
    gx-by-gy grid that spans the bounding box of all positions.

    Super node ids are "0", "1", ... in the order in which occupied cells are first met while
    scanning prev_nodes. Each super node sits at the mean position of its members and its "dim"
    is the sum of the members' dims (at least 1).

    Returns (super_nodes, super_edges, node_map). Edges inside one cell collapse to a single
    "prior" edge of that cell; edges between two cells collapse to one edge per cell pair.
    """
    positions = np.array([[nd["position"]["x"], nd["position"]["y"]] for nd in prev_nodes], dtype=float)
    xmin, ymin = positions.min(axis=0)
    xmax, ymax = positions.max(axis=0)

    # Cell size; fall back to 1.0 for a non-positive grid count or a degenerate extent.
    cell_w = (xmax - xmin) / gx if gx > 0 else 1.0
    cell_h = (ymax - ymin) / gy if gy > 0 else 1.0
    if cell_w == 0:
        cell_w = 1.0
    if cell_h == 0:
        cell_h = 1.0

    # cell id -> indices of the nodes falling into it (points on the far border go to the last cell)
    buckets = {}
    for idx, nd in enumerate(prev_nodes):
        place = nd["position"]
        x, y = place["x"], place["y"]
        col = min(int((x - xmin) / cell_w), gx - 1)
        row = min(int((y - ymin) / cell_h), gy - 1)
        cell_id = col + row * gx
        if cell_id not in buckets:
            buckets[cell_id] = []
        buckets[cell_id].append(idx)

    super_nodes = []
    node_map = {}
    for new_id, members in enumerate(buckets.values()):
        mean_x, mean_y = positions[members].mean(axis=0)
        total_dim = sum([prev_nodes[i]["data"]["dim"] for i in members])
        nid = str(new_id)
        super_nodes.append({
            "data": {"id": nid, "layer": layer_idx, "dim": int(max(1, total_dim))},
            "position": {"x": float(mean_x), "y": float(mean_y)}
        })
        for i in members:
            node_map[prev_nodes[i]["data"]["id"]] = nid

    super_edges = []
    seen = set()
    for edge in prev_edges:
        src, dst = edge["data"]["source"], edge["data"]["target"]
        s_src = node_map[src]
        s_dst = s_src if dst == "prior" else node_map[dst]
        # Same cell (or a unary edge) -> prior of that cell; otherwise an edge between two cells.
        target = "prior" if s_src == s_dst else s_dst
        key = tuple(sorted((s_src, target)))
        if key in seen:
            continue
        seen.add(key)
        super_edges.append({"data": {"source": s_src, "target": target}})

    return super_nodes, super_edges, node_map

In [5]:
def build_super_graph(layers):
    """
    Build the super-level factor graph from the base graph in layers[-2] and the grouping
    stored in layers[-1].

    Requirements:
      * layers[-2]["graph"] is an already built base graph (with its unary / binary factors).
      * layers[-1] provides "nodes", "edges" and "node_map", where node_map sends each base node
        id to a super node id. Base ids must be strings that int() can parse, because they are
        matched against the base graph's integer ``variableID``s.

    Each super variable stacks the dofs of the base variables of its group. A super edge whose
    target is "prior" becomes one unary factor that stacks every base factor lying entirely
    inside the group; any other super edge becomes one binary factor that stacks every base
    binary factor crossing between the two groups.

    Returns the new FactorGraph.
    """
    from scipy.linalg import block_diag
    # ---------- fetch base & super ----------
    base_graph = layers[-2]["graph"]
    super_nodes = layers[-1]["nodes"]
    super_edges = layers[-1]["edges"]
    node_map    = layers[-1]["node_map"]   # base node id (str) -> super node id (str)

    # base: id(int) -> VariableNode, used to look up dofs and mu
    id2var = {vn.variableID: vn for vn in base_graph.var_nodes}

    # ---------- super_id -> [base_id(int)] ----------
    super_groups = {}
    for b_str, s_id in node_map.items():
        b_int = int(b_str)
        super_groups.setdefault(s_id, []).append(b_int)


    # ---------- per super group, a (start, dofs) table ----------
    # local_idx[sid][bid] = (start, dofs), total_dofs[sid] = sum(dofs)
    local_idx   = {}
    total_dofs  = {}
    for sid, group in super_groups.items():
        off = 0
        local_idx[sid] = {}
        for bid in group:
            d = id2var[bid].dofs
            local_idx[sid][bid] = (off, d)
            off += d
        total_dofs[sid] = off


    # ---------- create the super VariableNodes ----------
    fg = FactorGraph(nonlinear_factors=False, eta_damping=0)

    super_var_nodes = {}
    for i, sn in enumerate(super_nodes):
        sid = sn["data"]["id"]
        dofs = total_dofs.get(sid, 0)

        v = VariableNode(i, dofs=dofs)

        # === stack the base ground truth (zeros where a base GT is missing or has the wrong length) ===
        gt_vec = np.zeros(dofs)
        for bid, (st, d) in local_idx[sid].items():
            gt_base = getattr(id2var[bid], "GT", None)
            if gt_base is None or len(gt_base) != d:
                gt_base = np.zeros(d)
            gt_vec[st:st+d] = gt_base
        v.GT = gt_vec
        # Near-zero prior precision on the super variable itself.
        v.prior.lam = 1e-10 * np.eye(dofs, dtype=float)
        v.prior.eta = np.zeros(dofs, dtype=float)

        super_var_nodes[sid] = v
        fg.var_nodes.append(v)


    fg.n_var_nodes = len(fg.var_nodes)

    # ---------- helper: assemble a group's linearisation point from the base variables' mu ----------
    def make_linpoint_for_group(sid):
        x = np.zeros(total_dofs[sid])
        for bid, (st, d) in local_idx[sid].items():
            mu = getattr(id2var[bid], "mu", None)
            # Fall back to zeros when a base variable has no mu yet or its length does not match.
            if mu is None or len(mu) != d:
                mu = np.zeros(d)
            x[st:st+d] = mu
        return x

    # ---------- 3) super prior (in-group unary + in-group binary factors) ----------
    def make_super_prior_factor(sid, base_factors):
        group = super_groups[sid]
        idx_map = local_idx[sid]
        ncols = total_dofs[sid]

        # Select the factors whose variables all belong to this group (unary or binary).
        in_group = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if all(vid in group for vid in vids):
                in_group.append(f)

        def meas_fn_super_prior(x_super, *args):
            # Stacked measurement prediction: each base factor evaluated on its slice of x_super.
            meas_fn = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Assemble this factor's local x
                x_loc_list = []
                for vid in vids:
                    st, d = idx_map[vid]
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn) if meas_fn else np.zeros(0)

        def jac_fn_super_prior(x_super, *args):
            # Stacked Jacobian: one row block per base factor, columns placed at the group offsets.
            Jrows = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Assemble this factor's local x, needed for a (potentially) nonlinear Jacobian
                x_loc_list = []
                dims = []
                for vid in vids:
                    st, d = idx_map[vid]
                    dims.append(d)
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)

                Jloc = f.jac_fn(x_loc)
                # Map the column blocks of Jloc back onto the columns of the super variable
                row = np.zeros((Jloc.shape[0], ncols))
                c0 = 0
                for vid, d in zip(vids, dims):
                    st, _ = idx_map[vid]
                    row[:, st:st+d] = Jloc[:, c0:c0+d]
                    c0 += d

                Jrows.append(row)
            return np.vstack(Jrows) if Jrows else np.zeros((0, ncols))

        # z_super: concatenation of the base factors' measurements; the precisions go block-diagonal.
        # NOTE(review): np.concatenate raises on an empty list, so a group without any in-group
        # factor cannot get a super prior - confirm every group always has at least one.
        z_list = [f.measurement for f in in_group]
        z_lambda_list = [f.measurement_lambda for f in in_group]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_prior, jac_fn_super_prior, z_super, z_super_lambda 

    # ---------- 4) super between (cross-group binary factors) ----------
    def make_super_between_factor(sidA, sidB, base_factors):
        groupA, groupB = super_groups[sidA], super_groups[sidB]
        idxA, idxB = local_idx[sidA], local_idx[sidB]
        nA, nB = total_dofs[sidA], total_dofs[sidB]

        cross = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if len(vids) != 2:
                continue
            i, j = vids
            # one end in A, the other end in B
            if (i in groupA and j in groupB) or (i in groupB and j in groupA):
                cross.append(f)


        def meas_fn_super_between(xAB, *args):
            # xAB is [xA, xB]; each base factor gets its two variables in the factor's own order.
            xA, xB = xAB[:nA], xAB[nA:]
            meas_fn = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                x_loc = np.concatenate([xi, xj])
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn) 

        def jac_fn_super_between(xAB, *args):
            xA, xB = xAB[:nA], xAB[nA:]
            Jrows = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                    # Column offsets in [xA, xB]: B's block starts after the nA columns of A.
                    left_start, right_start = si, nA + sj
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                    left_start, right_start = nA + si, sj
                x_loc = np.concatenate([xi, xj])
                Jloc = f.jac_fn(x_loc)

                # Scatter the factor's two column blocks to their positions in the joint vector.
                row = np.zeros((Jloc.shape[0], nA + nB))
                row[:, left_start:left_start+di]   = Jloc[:, :di] 
                row[:, right_start:right_start+dj] = Jloc[:, di:di+dj] 

                Jrows.append(row)
            return np.vstack(Jrows) 

        z_list = [f.measurement for f in cross]
        z_lambda_list = [f.measurement_lambda for f in cross]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_between, jac_fn_super_between, z_super, z_super_lambda


    # One factor per super edge; the factor id is its position in fg.factors.
    for e in super_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v == "prior":
            meas_fn, jac_fn, z, z_lambda = make_super_prior_factor(u, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u]], z, z_lambda, meas_fn, jac_fn)
            f.type = "super_prior"
            lin0 = make_linpoint_for_group(u)
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            
        else:
            meas_fn, jac_fn, z, z_lambda = make_super_between_factor(u, v, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u], super_var_nodes[v]], z, z_lambda, meas_fn, jac_fn)
            f.type = "super_between"
            lin0 = np.concatenate([make_linpoint_for_group(u), make_linpoint_for_group(v)])
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            super_var_nodes[v].adj_factors.append(f)


    fg.n_factor_nodes = len(fg.factors)
    return fg


In [6]:
def copy_to_abs(super_nodes, super_edges, layer_idx):
    """
    Clone a super layer's nodes and edges as the node/edge lists of an abstraction layer.

    Every id (node id, edge source, edge target) has its first "s" replaced by "a"; ids without
    an "s" are copied unchanged. Nodes keep their dim and position and get ``layer_idx`` as
    their layer.

    Returns (abs_nodes, abs_edges).
    """
    def _to_abs_id(identifier):
        return identifier.replace("s", "a", 1)

    abs_nodes = [
        {
            "data": {"id": _to_abs_id(node["data"]["id"]), "layer": layer_idx, "dim": node["data"]["dim"]},
            "position": {"x": node["position"]["x"], "y": node["position"]["y"]},
        }
        for node in super_nodes
    ]
    abs_edges = [
        {"data": {"source": _to_abs_id(edge["data"]["source"]),
                  "target": _to_abs_id(edge["data"]["target"])}}
        for edge in super_edges
    ]
    return abs_nodes, abs_edges

In [7]:
def build_abs_graph(
    layers,
    r_reduced = 2):
    """
    Build an abstraction-level factor graph by projecting every variable of the graph in
    layers[-2] onto a low-dimensional subspace.

    For each variable the top ``r`` eigenvectors of its covariance ``Sigma`` form a basis B
    (r = r_reduced, or the variable's own dofs when that is already <= r_reduced). The reduced
    variable x_abs relates to the original one through x_sup = B @ x_abs + k, with
    k = mu - B @ (B.T @ mu). Every factor of the source graph is wrapped so that its measurement
    function and Jacobian are evaluated through this affine map.

    Returns (abs_fg, Bs, ks): the new FactorGraph and, keyed by variable id, the basis matrices
    and offsets.
    """

    abs_var_nodes = {}
    Bs = {}
    ks = {}
    r = 2

    # === 1. Build Abstraction Variables ===
    abs_fg = FactorGraph(nonlinear_factors=False, eta_damping=0)
    sup_fg = layers[-2]["graph"]

    for sn in sup_fg.var_nodes:
        if sn.dofs <= r_reduced:
            r = sn.dofs  # No reduction if dofs already <= r
        else:
            r = r_reduced

        sid = sn.variableID
        varis_sup_mu = sn.mu
        # NOTE(review): eigh below needs a square matrix - confirm the belief update used on the
        # source graph stores Sigma as a full covariance matrix, not just its diagonal.
        varis_sup_sigma = sn.Sigma
        
        # Step 1: Eigen decomposition of the covariance matrix
        eigvals, eigvecs = np.linalg.eigh(varis_sup_sigma)

        # Step 2: Sort eigenvalues and eigenvectors in descending order of eigenvalues
        idx = np.argsort(eigvals)[::-1]      # Get indices of sorted eigenvalues (largest first)
        eigvals = eigvals[idx]               # Reorder eigenvalues
        eigvecs = eigvecs[:, idx]            # Reorder corresponding eigenvectors

        # Step 3: Select the top-r eigenvectors to form the projection matrix (principal subspace)
        B_reduced = eigvecs[:, :r]                 # B_reduced: shape (sup_dof, r), maps r-dim coordinates to sup_dof
        Bs[sid] = B_reduced                        # Store the projection matrix for this variable

        # Step 4: Project the mean and covariance onto the reduced r-dimensional subspace
        varis_abs_mu = B_reduced.T @ varis_sup_mu          # Projected mean: shape (r,)
        varis_abs_sigma = B_reduced.T @ varis_sup_sigma @ B_reduced  # Projected covariance: shape (r, r)
        ks[sid] = varis_sup_mu - B_reduced @ varis_abs_mu  # Offset: the part of mu outside the subspace

        varis_abs_lam = np.linalg.inv(varis_abs_sigma)  # Inverse covariance (precision matrix): shape (r, r)
        varis_abs_eta = varis_abs_lam @ varis_abs_mu  # Information vector: shape (r,)

        v = VariableNode(sid, dofs=r)
        # The ground truth is copied unprojected, so it keeps the source variable's length.
        v.GT = sn.GT
        v.prior.lam = 1e-10 * np.eye(r, dtype=float)
        v.prior.eta = np.zeros(r, dtype=float)
        v.mu = varis_abs_mu
        v.Sigma = varis_abs_sigma
        v.belief = NdimGaussian(r, varis_abs_eta, varis_abs_lam)

        abs_var_nodes[sid] = v
        abs_fg.var_nodes.append(v)
    abs_fg.n_var_nodes = len(abs_fg.var_nodes)


    # === 2. Abstract Prior ===
    def make_abs_prior_factor(sup_factor):
        # Wrap a unary source factor so it is evaluated at B @ x_abs + k.
        abs_id = sup_factor.adj_var_nodes[0].variableID
        B = Bs[abs_id]
        k = ks[abs_id]

        def meas_fn_abs_prior(x_abs, *args):
            return sup_factor.meas_fn(B @ x_abs + k)
        
        def jac_fn_abs_prior(x_abs, *args):
            # Chain rule: d h(B x + k) / d x = J_sup @ B
            return sup_factor.jac_fn(B @ x_abs + k) @ B


        return meas_fn_abs_prior, jac_fn_abs_prior, sup_factor.measurement, sup_factor.measurement_lambda
    


    # === 3. Abstract Between ===
    def make_abs_between_factor(sup_factor):
        # Wrap a binary source factor; each of its two variables has its own basis and offset.
        vids = [v.variableID for v in sup_factor.adj_var_nodes]
        i, j = vids # two variable IDs
        ni = abs_var_nodes[i].dofs
        Bi, Bj = Bs[i], Bs[j]
        ki, kj = ks[i], ks[j]                       


        def meas_fn_super_between(xij, *args):
            xi, xj = xij[:ni], xij[ni:]
            return sup_factor.meas_fn(np.concatenate([Bi @ xi + ki, Bj @ xj + kj]))

        def jac_fn_super_between(xij, *args):
            xi, xj = xij[:ni], xij[ni:]
            J_sup = sup_factor.jac_fn(np.concatenate([Bi @ xi + ki, Bj @ xj + kj]))
            # Chain rule per variable block: the first Bi.shape[0] columns of J_sup belong to i.
            J_abs = np.zeros((J_sup.shape[0], ni + xj.shape[0]))
            J_abs[:, :ni] = J_sup[:, :Bi.shape[0]] @ Bi
            J_abs[:, ni:] = J_sup[:, Bi.shape[0]:] @ Bj
            return J_abs
        
        return meas_fn_super_between, jac_fn_super_between, sup_factor.measurement, sup_factor.measurement_lambda
    
    # Mirror every source factor (same factorID), linearised at the projected means.
    # Factors with more than two adjacent variables are not copied.
    for f in sup_fg.factors:
        if len(f.adj_var_nodes) == 1:
            meas_fn, jac_fn, z, z_lambda = make_abs_prior_factor(f)
            v = abs_var_nodes[f.adj_var_nodes[0].variableID]
            abs_f = Factor(f.factorID, [v], z, z_lambda, meas_fn, jac_fn)
            abs_f.type = "abs_prior"
            # Share the variable's belief object with the factor instead of the empty default.
            abs_f.adj_beliefs = [v.belief]

            lin0 = v.mu
            abs_f.compute_factor(linpoint=lin0, update_self=True)
            abs_fg.factors.append(abs_f)
            v.adj_factors.append(abs_f)

        elif len(f.adj_var_nodes) == 2:
            meas_fn, jac_fn, z, z_lambda = make_abs_between_factor(f)
            i, j = [v.variableID for v in f.adj_var_nodes]
            vi, vj = abs_var_nodes[i], abs_var_nodes[j]
            abs_f = Factor(f.factorID, [vi, vj], z, z_lambda, meas_fn, jac_fn)
            abs_f.type = "abs_between"
            abs_f.adj_beliefs = [vi.belief, vj.belief]

            lin0 = np.concatenate([vi.mu, vj.mu])
            abs_f.compute_factor(linpoint=lin0, update_self=True)
            abs_fg.factors.append(abs_f)
            vi.adj_factors.append(abs_f)
            vj.adj_factors.append(abs_f)
    abs_fg.n_factor_nodes = len(abs_fg.factors)


    return abs_fg, Bs, ks


In [8]:
# -----------------------
# Build the layered GBP graphs: base -> super -> abs
# -----------------------

# Base layer: the deterministic two-clique graph.
base_nodes, base_edges, gbp_graph = make_two_cliques_graph(prior_sigma=10.0, odom_sigma=10.0, rng=None)
layers = [{"name": "base", "nodes": base_nodes, "edges": base_edges}]
layers[0]["graph"] = gbp_graph
pair_idx = 0
opts=[{"label":"base","value":"base"}]


last = layers[-1]
super_layer_idx = 1


# Super layer: group the base nodes into 2 clusters with k-means.
# Alternative grouping on a 2x1 grid:
#   fuse_to_super_grid(last["nodes"], last["edges"], 2, 1, super_layer_idx)
# (it used to be called here too, but its result was immediately overwritten by the k-means one).
super_nodes, super_edges, node_map = fuse_to_super_kmeans(last["nodes"], last["edges"], 2, super_layer_idx)
layers.append({"name":f"super{1}", "nodes":super_nodes, "edges":super_edges, "node_map":node_map})
layers[-1]["graph"] = build_super_graph(layers)

# Run the super graph before abstracting it: build_abs_graph reads each variable's mu and Sigma.
for i in range(100):
    layers[-1]["graph"].synchronous_iteration()


# Abstraction layer: same topology as the super layer, variables reduced to at most 2 dofs.
abs_nodes, abs_edges   = copy_to_abs(super_nodes, super_edges, super_layer_idx + 1)
layers.append({"name":f"abs{1}", "nodes":abs_nodes, "edges":abs_edges})
layers[-1]["graph"], Bs, ks = build_abs_graph(layers, 2)

In [17]:
# Inspect the base layer: one line per edge (binary edges first, then the per-node "prior" edges).
for edge in base_edges:
    print(edge)

{'data': {'source': '0', 'target': '1'}}
{'data': {'source': '0', 'target': '2'}}
{'data': {'source': '0', 'target': '3'}}
{'data': {'source': '1', 'target': '2'}}
{'data': {'source': '1', 'target': '3'}}
{'data': {'source': '2', 'target': '3'}}
{'data': {'source': '4', 'target': '5'}}
{'data': {'source': '4', 'target': '6'}}
{'data': {'source': '5', 'target': '6'}}
{'data': {'source': '3', 'target': '4'}}
{'data': {'source': '0', 'target': 'prior'}}
{'data': {'source': '1', 'target': 'prior'}}
{'data': {'source': '2', 'target': 'prior'}}
{'data': {'source': '3', 'target': 'prior'}}
{'data': {'source': '4', 'target': 'prior'}}
{'data': {'source': '5', 'target': 'prior'}}
{'data': {'source': '6', 'target': 'prior'}}


In [14]:
# Inspect the super layer: its nodes, its edges and the base-id -> super-id mapping.
print(super_nodes)
print(super_edges)
print(node_map)

[{'data': {'id': '0', 'layer': 1, 'dim': 8}, 'position': {'x': 0.0, 'y': 15.0}}, {'data': {'id': '1', 'layer': 1, 'dim': 6}, 'position': {'x': 50.0, 'y': 10.0}}]
[{'data': {'source': '0', 'target': 'prior'}}, {'data': {'source': '1', 'target': 'prior'}}, {'data': {'source': '0', 'target': '1'}}]
{'0': '0', '1': '0', '2': '0', '3': '0', '4': '1', '5': '1', '6': '1'}


In [9]:
# Quick check: number of layers among the first two slots (base and super).
len(layers[:2])

2

In [10]:
# Run synchronous GBP on the base graph and print the energy after every iteration.
# layers[-3] is the base layer only while layers holds exactly [base, super, abs].
basegraph = layers[-3]["graph"]

for it in range(100):
    basegraph.synchronous_iteration()
    energy = basegraph.energy_map(include_priors=True, include_factors=True)
    print(f"Iter {it+1:03d} | Energy = {energy:.6f}")



Iter 001 | Energy = 315.810054
Iter 002 | Energy = 182.343369
Iter 003 | Energy = 190.018171
Iter 004 | Energy = 191.794979
Iter 005 | Energy = 194.556536
Iter 006 | Energy = 193.906010
Iter 007 | Energy = 194.076984
Iter 008 | Energy = 193.931422
Iter 009 | Energy = 193.556378
Iter 010 | Energy = 193.490921
Iter 011 | Energy = 193.451201
Iter 012 | Energy = 193.397052
Iter 013 | Energy = 193.380109
Iter 014 | Energy = 193.371536
Iter 015 | Energy = 193.364074
Iter 016 | Energy = 193.360586
Iter 017 | Energy = 193.358605
Iter 018 | Energy = 193.357341
Iter 019 | Energy = 193.356659
Iter 020 | Energy = 193.356249
Iter 021 | Energy = 193.356011
Iter 022 | Energy = 193.355878
Iter 023 | Energy = 193.355799
Iter 024 | Energy = 193.355753
Iter 025 | Energy = 193.355727
Iter 026 | Energy = 193.355712
Iter 027 | Energy = 193.355703
Iter 028 | Energy = 193.355698
Iter 029 | Energy = 193.355695
Iter 030 | Energy = 193.355693
Iter 031 | Energy = 193.355692
Iter 032 | Energy = 193.355691
Iter 033

In [11]:
"""
# 找到所有 super_prior factors
prior_factors = [f for f in supergraph.factors if getattr(f, "type", "") == "super_prior"]
# 对 super_prior factors 的邻居变量
prior_vars = []
for f in prior_factors:
    for v in f.adj_var_nodes:
        if v not in prior_vars:
            prior_vars.append(v)
supergraph.compute_all_messages(prior_factors, local_relin=True)
supergraph.update_all_beliefs(prior_vars)

energy = supergraph.energy_map(include_priors=True, include_factors=True)
print(f"Iter {it+1:03d} | Energy = {energy:.6f}")
"""
# Take the graph stored second from the end of `layers` and run 100 synchronous
# iterations on it, printing the energy after each one.
# NOTE(review): presumably this is the super layer built above — the negative
# index depends on the current length of `layers`; confirm.
supergraph = layers[-2]["graph"]
for it in range(100):
    supergraph.synchronous_iteration()
    # Energy with both the prior terms and the factor terms included.
    energy = supergraph.energy_map(include_priors=True, include_factors=True)
    print(f"Iter {it+1:03d} | Energy = {energy:.6f}")

Iter 001 | Energy = 193.355691
Iter 002 | Energy = 193.355691
Iter 003 | Energy = 193.355691
Iter 004 | Energy = 193.355691
Iter 005 | Energy = 193.355691
Iter 006 | Energy = 193.355691
Iter 007 | Energy = 193.355691
Iter 008 | Energy = 193.355691
Iter 009 | Energy = 193.355691
Iter 010 | Energy = 193.355691
Iter 011 | Energy = 193.355691
Iter 012 | Energy = 193.355691
Iter 013 | Energy = 193.355691
Iter 014 | Energy = 193.355691
Iter 015 | Energy = 193.355691
Iter 016 | Energy = 193.355691
Iter 017 | Energy = 193.355691
Iter 018 | Energy = 193.355691
Iter 019 | Energy = 193.355691
Iter 020 | Energy = 193.355691
Iter 021 | Energy = 193.355691
Iter 022 | Energy = 193.355691
Iter 023 | Energy = 193.355691
Iter 024 | Energy = 193.355691
Iter 025 | Energy = 193.355691
Iter 026 | Energy = 193.355691
Iter 027 | Energy = 193.355691
Iter 028 | Energy = 193.355691
Iter 029 | Energy = 193.355691
Iter 030 | Energy = 193.355691
Iter 031 | Energy = 193.355691
Iter 032 | Energy = 193.355691
Iter 033

In [12]:
# Compare the first variable of the second-to-last layer's graph with the first
# abs-layer variable after mapping the latter through Bs[0] @ mu + ks[0].
absgraph = layers[-1]["graph"]

# Reference value: mean of variable 0 in the graph second from the end of `layers`.
print(layers[-2]["graph"].var_nodes[0].mu)

# Run 100 rounds on the abs graph: recompute all factor messages with
# local_relin=True, then update the beliefs of all variable nodes.
for i in range(100):
    absgraph.compute_all_messages(absgraph.factors, local_relin=True)
    absgraph.update_all_beliefs(absgraph.var_nodes)


# Apply Bs[0] and ks[0] to the abs-layer mean of variable 0. In the recorded
# output this vector is identical to the reference printed above.
# NOTE(review): presumably Bs/ks lift abs coordinates back to the super layer —
# confirm against build_abs_graph.
print(Bs[0] @ absgraph.var_nodes[0].mu + ks[0])

[-1.34553367  3.38894029 -2.90695587 12.40028538  6.31483732 23.53412873
 -5.19983453 27.30061548]
[-1.34553367  3.38894029 -2.90695587 12.40028538  6.31483732 23.53412873
 -5.19983453 27.30061548]


In [54]:
# Show the eta vector and lam matrix of the belief held by super variable 1,
# one after the other.
print(
    supergraph.var_nodes[1].belief.eta,
    supergraph.var_nodes[1].belief.lam,
    sep="\n",
)

[ 1.15158182 -0.12118028  0.220961    0.08883329  0.30506051  0.44922781]
[[ 0.03714286  0.         -0.01        0.         -0.01        0.        ]
 [ 0.          0.03714286  0.         -0.01        0.         -0.01      ]
 [-0.01        0.          0.03        0.         -0.01        0.        ]
 [ 0.         -0.01        0.          0.03        0.         -0.01      ]
 [-0.01        0.         -0.01        0.          0.03        0.        ]
 [ 0.         -0.01        0.         -0.01        0.          0.03      ]]


In [58]:
# lam matrix of the first adjacent belief of base factor 6
# (a 2x2 diagonal matrix in the recorded output).
basegraph.factors[6].adj_beliefs[0].lam

array([[0.02997088, 0.        ],
       [0.        , 0.02997088]])

In [57]:
# lam matrix of the first adjacent belief of base factor 7; the recorded output
# equals the one shown for factor 6.
basegraph.factors[7].adj_beliefs[0].lam

array([[0.02997088, 0.        ],
       [0.        , 0.02997088]])

In [44]:
# Solve lam @ x = eta for the first adjacent belief of base factor 6.
# NOTE(review): x is presumably the belief mean if eta/lam are information-form
# parameters — confirm.
np.linalg.solve(basegraph.factors[6].adj_beliefs[0].lam, basegraph.factors[6].adj_beliefs[0].eta)

array([52.1165681 ,  5.44711511])

In [46]:
# Solve lam @ x = eta for the belief of super variable 1. In the recorded output
# the first two components equal the solution obtained from base factor 6's
# first adjacent belief.
np.linalg.solve(supergraph.var_nodes[1].belief.lam, supergraph.var_nodes[1].belief.eta)

array([52.1165681 ,  5.44711511, 38.15757778, 11.67015352, 40.26006536,
       20.68001639])

In [49]:
# Inverse of the lam matrix of super variable 1's belief
# (presumably its covariance if lam is a precision matrix — confirm).
np.linalg.inv(supergraph.var_nodes[1].belief.lam)

array([[36.84210504,  0.        , 18.42105243,  0.        , 18.42105243,
         0.        ],
       [ 0.        , 36.84210504,  0.        , 18.42105243,  0.        ,
        18.42105243],
       [18.42105243,  0.        , 46.71052601,  0.        , 21.71052607,
         0.        ],
       [ 0.        , 18.42105243,  0.        , 46.71052601,  0.        ,
        21.71052607],
       [18.42105243,  0.        , 21.71052607,  0.        , 46.71052601,
         0.        ],
       [ 0.        , 18.42105243,  0.        , 21.71052607,  0.        ,
        46.71052601]])

In [50]:
# Inverse of the lam matrix of base factor 6's first adjacent belief.
# NOTE(review): the recorded diagonal (33.37) differs from the leading 2x2 block
# of the inverted super-variable lam (36.84) — worth checking whether that gap
# is expected.
np.linalg.inv(basegraph.factors[6].adj_beliefs[0].lam)

array([[33.36571711,  0.        ],
       [ 0.        , 33.36571711]])

In [280]:
# Euclidean distance between the GT of abs-layer variable 0 and its mean mapped
# through the Bs/ks entry selected by the variable's own variableID.
_abs_v0 = layers[-1]["graph"].var_nodes[0]
np.linalg.norm(
    _abs_v0.GT - (Bs[_abs_v0.variableID] @ _abs_v0.mu + ks[_abs_v0.variableID])
)

10.686408732643704

In [281]:
# Euclidean distance between the GT of abs-layer variable 1 and its mean mapped
# through the Bs/ks entry selected by the variable's own variableID.
_abs_v1 = layers[-1]["graph"].var_nodes[1]
np.linalg.norm(
    _abs_v1.GT - (Bs[_abs_v1.variableID] @ _abs_v1.mu + ks[_abs_v1.variableID])
)

16.507938987468407

In [282]:
# Half the sum of the squared GT-error norms of the first two abs-layer
# variables, computed from the live graph. This replaces two norm values that
# had been pasted in by hand; those literals were close to, but not identical
# with, the norms recorded in the cells above, so they came from a different run.
# The result is meant to be compared with the converged energy printed by the
# iteration loops (193.355691 in the recorded output).
float(
    0.5
    * sum(
        np.linalg.norm(v.GT - (Bs[v.variableID] @ v.mu + ks[v.variableID])) ** 2
        for v in layers[-1]["graph"].var_nodes[:2]
    )
)

193.35569060552677

In [283]:
# Evaluate the Jacobian function of super factor 2 at an arbitrary test point:
# 8 entries followed by 6 entries, written as one 14-element integer array.
a = supergraph.factors[2].jac_fn(
    np.array([1, 2, 1, 2, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5])
)
a

array([[ 0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  1.,  0.,  0.,  0.,  0.,
         0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  1.,  0.,  0.,  0.,
         0.]])

In [284]:
# Difference between super factor 2's lam matrix and a.T @ a, where `a` is the
# Jacobian evaluated in the previous cell. The recorded entries are +/-0.99
# where a.T @ a is +/-1, i.e. lam is 0.01 * a.T @ a on those entries.
# NOTE(review): consistent with a 1/sigma^2 weight for sigma = 10 — confirm.
supergraph.factors[2].factor.lam - a.T@a

array([[ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  , -0.99,  0.  ,  0.99,
         0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  , -0.99,  0.  ,
         0.99,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.99,  0.  , -0.99,
         

In [285]:
# Mean of variable 0 in the graph stored second from the end of `layers`.
layers[-2]["graph"].var_nodes[0].mu

array([-1.34553367,  3.38894029, -2.90695587, 12.40028538,  6.31483732,
       23.53412873, -5.19983453, 27.30061548])

In [286]:
# Mean of variable 0 read through the `supergraph` handle; the recorded output
# matches the value obtained via layers[-2]["graph"].
supergraph.var_nodes[0].mu

array([-1.34553367,  3.38894029, -2.90695587, 12.40028538,  6.31483732,
       23.53412873, -5.19983453, 27.30061548])