In [1]:
import re
import dash
from dash import html, dcc, Input, Output, State, no_update
import dash_cytoscape as cyto
import numpy as np
from scipy.linalg import block_diag
from collections import defaultdict

# ==== GBP import ====
from gbp.gbp import *

app = dash.Dash(__name__)
app.title = "Factor Graph SVD Abs&Recovery"



def update_super_graph_linearized(layers, eta_damping=0.2):
    """
    Construct the super graph based on the base graph in layers[-2] and the super-grouping in layers[-1].
    Requirement: layers[-2]["graph"] is an already-built base graph (with unary/binary factors).
    layers[-1]["node_map"]: { base_node_id (str, e.g., 'b12') -> super_node_id (str) }
    """
    # ---------- Extract base & super ----------
    base_graph = layers[-2]["graph"]
    super_nodes = layers[-1]["nodes"]
    super_edges = layers[-1]["edges"]
    node_map    = layers[-1]["node_map"]   # 'bN' -> 'sX_...'

    # base: id(int) -> VariableNode, handy to query dofs and mu
    id2var = {vn.variableID: vn for vn in base_graph.var_nodes}

    # ---------- super_id -> [base_id(int)] ----------
    super_groups = {}
    for b_str, s_id in node_map.items():
        b_int = int(b_str)
        super_groups.setdefault(s_id, []).append(b_int)


    # ---------- For each super group, build a (start, dofs) table ----------
    # local_idx[sid][bid] = (start, dofs), total_dofs[sid] = sum(dofs)
    local_idx   = {}
    total_dofs  = {}
    for sid, group in super_groups.items():
        off = 0
        local_idx[sid] = {}
        for bid in group:
            d = id2var[bid].dofs
            local_idx[sid][bid] = (off, d)
            off += d
        total_dofs[sid] = off


    # ---------- Create super VariableNodes ----------
    fg = FactorGraph(nonlinear_factors=False, eta_damping=eta_damping)

    super_var_nodes = {}
    for i, sn in enumerate(super_nodes):
        sid = sn["data"]["id"]
        dofs = total_dofs.get(sid, 0)

        v = VariableNode(i, dofs=dofs)
        gt_vec = np.zeros(dofs)
        mu_blocks = []
        Sigma_blocks = []
        for bid, (st, d) in local_idx[sid].items():
            # === Stack base GT ===
            gt_base = getattr(id2var[bid], "GT", None)
            if gt_base is None or len(gt_base) != d:
                gt_base = np.zeros(d)
            gt_vec[st:st+d] = gt_base

            # === Stack base belief ===
            vb = id2var[bid]
            mu_blocks.append(vb.mu)
            Sigma_blocks.append(vb.Sigma)

        super_var_nodes[sid] = v
        v.GT = gt_vec

        mu_super = np.concatenate(mu_blocks) if mu_blocks else np.zeros(dofs)
        Sigma_super = block_diag(*Sigma_blocks) if Sigma_blocks else np.eye(dofs)
        lam = np.linalg.inv(Sigma_super)
        eta = lam @ mu_super
        v.mu = mu_super
        v.Sigma = Sigma_super
        v.belief = NdimGaussian(dofs, eta, lam)
        v.prior.lam = 1e-10 * lam
        v.prior.eta = 1e-10 * eta

        fg.var_nodes.append(v)


    fg.n_var_nodes = len(fg.var_nodes)

    # ---------- Utility: assemble a group's linpoint (using base belief means) ----------
    def make_linpoint_for_group(sid):
        x = np.zeros(total_dofs[sid])
        for bid, (st, d) in local_idx[sid].items():
            mu = getattr(id2var[bid], "mu", None)
            if mu is None or len(mu) != d:
                mu = np.zeros(d)
            x[st:st+d] = mu
        return x

    # ---------- 3) super prior (in-group unary + in-group binary) ----------
    def make_super_prior_factor(sid, lin0, base_factors):
        group = super_groups[sid]
        idx_map = local_idx[sid]
        ncols = total_dofs[sid]

        # Select factors whose variables are all within the group (unary or binary)
        in_group = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if all(vid in group for vid in vids):
                in_group.append(f)

        def jac_fn_super_prior(x_super, *args):
            Jrows = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Build this factor's local x for (potentially) nonlinear Jacobian
                x_loc_list = []
                dims = []
                for vid in vids:
                    st, d = idx_map[vid]
                    dims.append(d)
                    x_loc_list.append(lin0[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)

                Jloc = f.jac_fn(x_loc)
                # Map Jloc column blocks back to the super variable columns
                row = np.zeros((Jloc.shape[0], ncols))
                c0 = 0
                for vid, d in zip(vids, dims):
                    st, _ = idx_map[vid]
                    row[:, st:st+d] = Jloc[:, c0:c0+d]
                    c0 += d

                Jrows.append(row)
            return np.vstack(Jrows) if Jrows else np.zeros((0, ncols))

        J = jac_fn_super_prior(x_super=lin0)
        meas_fn = []
        for f in in_group:
            vids = [v.variableID for v in f.adj_var_nodes]
            # Build this factor's local x for (potentially) nonlinear Jacobian
            x_loc_list = []
            dims = []
            for vid in vids:
                st, d = idx_map[vid]
                dims.append(d)
                x_loc_list.append(lin0[st:st+d])
            x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)
            meas_fn.append(f.meas_fn(x_loc))
        meas_fn = np.concatenate(meas_fn) 
        def meas_fn_super_prior(x_super, *args):
            return meas_fn + J@(x_super-lin0)

        # z_super: concatenate each base factor's z
        z_list = [f.measurement for f in in_group]
        z_lambda_list = [f.measurement_lambda for f in in_group]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_prior, jac_fn_super_prior, z_super, z_super_lambda 

    # ---------- 4) super between (cross-group binary) ----------
    def make_super_between_factor(sidA, sidB, lin0, base_factors):
        groupA, groupB = super_groups[sidA], super_groups[sidB]
        idxA, idxB = local_idx[sidA], local_idx[sidB]
        nA, nB = total_dofs[sidA], total_dofs[sidB]

        cross = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if len(vids) != 2:
                continue
            i, j = vids
            # One side in A, the other side in B
            if (i in groupA and j in groupB) or (i in groupB and j in groupA):
                cross.append(f)

        
        def jac_fn_super_between(xAB, *args):
            xA, xB = lin0[:nA], lin0[nA:]
            Jrows = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                    left_start, right_start = si, nA + sj
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                    left_start, right_start = nA + si, sj

                x_loc = np.concatenate([xi, xj])
                Jloc = f.jac_fn(x_loc)
                row = np.zeros((Jloc.shape[0], nA + nB))
                row[:, left_start:left_start+di]   = Jloc[:, :di] 
                row[:, right_start:right_start+dj] = Jloc[:, di:di+dj] 
                Jrows.append(row)

            return np.vstack(Jrows) 
        
        J = jac_fn_super_between(xAB=lin0)
        xA, xB = lin0[:nA], lin0[nA:]
        meas_fn = []
        for f in cross:
            i, j = [v.variableID for v in f.adj_var_nodes]
            if i in groupA:
                si, di = idxA[i]
                sj, dj = idxB[j]
                xi = xA[si:si+di]
                xj = xB[sj:sj+dj]
            else:
                si, di = idxB[i]
                sj, dj = idxA[j]
                xi = xB[si:si+di]
                xj = xA[sj:sj+dj]
            x_loc = np.concatenate([xi, xj])
            meas_fn.append(f.meas_fn(x_loc))
        meas_fn = np.concatenate(meas_fn) 
        def meas_fn_super_between(xAB, *args):
            return meas_fn + J@(xAB-lin0)

        z_list = [f.measurement for f in cross]
        z_lambda_list = [f.measurement_lambda for f in cross]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_between, jac_fn_super_between, z_super, z_super_lambda


    for e in super_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v == "prior":
            lin0 = make_linpoint_for_group(u)
            meas_fn, jac_fn, z, z_lambda = make_super_prior_factor(u, lin0, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_prior"
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            
        else:
            lin0 = np.concatenate([make_linpoint_for_group(u), make_linpoint_for_group(v)])
            meas_fn, jac_fn, z, z_lambda = make_super_between_factor(u, v, lin0, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u], super_var_nodes[v]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_between"
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            super_var_nodes[v].adj_factors.append(f)


    fg.n_factor_nodes = len(fg.factors)
    return fg



# -----------------------
# SLAM-like base graph
# -----------------------
def make_slam_like_graph(N=100, step_size=25, loop_prob=0.05, loop_radius=50, prior_prop=0.0, seed=None):
    if seed is None :
        rng = np.random.default_rng()  # ✅ Ensure we have an RNG
    else:
        rng = np.random.default_rng(seed)
    nodes, edges = [], []
    positions = []
    x, y = 0.0, 0.0
    positions.append((x, y))

    # ✅ Deterministic-by-RNG: trajectory generation
    for _ in range(1, int(N)):
        dx, dy = rng.standard_normal(2)  # replace np.random.randn
        norm = np.sqrt(dx**2 + dy**2) + 1e-6
        dx, dy = dx / norm * float(step_size), dy / norm * float(step_size)
        x, y = x + dx, y + dy
        positions.append((x, y))

    # Sequential edges along the path
    for i, (px, py) in enumerate(positions):
        nodes.append({
            "data": {"id": f"{i}", "layer": 0, "dim": 2, "num_base": 1},
            "position": {"x": float(px), "y": float(py)}
        })

    for i in range(int(N) - 1):
        edges.append({"data": {"source": f"{i}", "target": f"{i+1}"}})

    # ✅ Deterministic-by-RNG: loop-closure edges
    for i in range(int(N)):
        for j in range(i + 5, int(N)):
            if rng.random() < float(loop_prob):  # replace np.random.rand
                xi, yi = positions[i]
                xj, yj = positions[j]
                if np.hypot(xi - xj, yi - yj) < float(loop_radius):
                    edges.append({"data": {"source": f"{i}", "target": f"{j}"}})

    # ✅ Sample priors using the same RNG
    if prior_prop <= 0.0:
        strong_ids = {0}
    elif prior_prop >= 1.0:
        strong_ids = set(range(N))
    else:
        k = max(1, int(np.floor(prior_prop * N)))
        strong_ids = set(rng.choice(N, size=k, replace=False).tolist())

    # Add edges for nodes with strong priors
    for i in strong_ids:
        edges.append({"data": {"source": f"{i}", "target": "prior"}})

    return nodes, edges



# -----------------------
# Grid aggregation
# -----------------------
def fuse_to_super_grid(prev_nodes, prev_edges, gx, gy, layer_idx):
    positions = np.array([[n["position"]["x"], n["position"]["y"]] for n in prev_nodes], dtype=float)
    xmin, ymin = positions.min(axis=0); xmax, ymax = positions.max(axis=0)
    cell_w = (xmax - xmin) / gx if gx > 0 else 1.0
    cell_h = (ymax - ymin) / gy if gy > 0 else 1.0
    if cell_w == 0: cell_w = 1.0
    if cell_h == 0: cell_h = 1.0
    cell_map = {}
    for idx, n in enumerate(prev_nodes):
        x, y = n["position"]["x"], n["position"]["y"]
        cx = min(int((x - xmin) / cell_w), gx - 1)
        cy = min(int((y - ymin) / cell_h), gy - 1)
        cid = cx + cy * gx
        cell_map.setdefault(cid, []).append(idx)
    super_nodes, node_map = [], {}
    for cid, indices in cell_map.items():
        pts = positions[indices]
        mean_x, mean_y = pts.mean(axis=0)
        child_dims = [prev_nodes[i]["data"]["dim"] for i in indices]
        child_nums = [prev_nodes[i]["data"].get("num_base", 1) for i in indices]
        dim_val = int(max(1, sum(child_dims)))
        num_val = int(sum(child_nums))
        nid = str(len(super_nodes))
        super_nodes.append({
            "data": {
                "id": nid,
                "layer": layer_idx,
                "dim": dim_val,
                "num_base": num_val   # Inherit the sum
            },
            "position": {"x": float(mean_x), "y": float(mean_y)}
        })
        for i in indices:
            node_map[prev_nodes[i]["data"]["id"]] = nid
    super_edges, seen = [], set()
    for e in prev_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v != "prior":
            su, sv = node_map[u], node_map[v]
            if su != sv:
                eid = tuple(sorted((su, sv)))
                if eid not in seen:
                    super_edges.append({"data": {"source": su, "target": sv}})
                    seen.add(eid)
            elif su == sv:
                eid = tuple(sorted((su, "prior")))
                if eid not in seen:
                    super_edges.append({"data": {"source": su, "target": "prior"}})
                    seen.add(eid)

        elif v == "prior":
            su = node_map[u]
            eid = tuple(sorted((su, v)))
            if eid not in seen:
                super_edges.append({"data": {"source": su, "target": "prior"}})
                seen.add(eid)

    return super_nodes, super_edges, node_map

# -----------------------
# K-Means aggregation
# -----------------------
def fuse_to_super_kmeans(prev_nodes, prev_edges, k, layer_idx, max_iters=20, tol=1e-6, seed=0):
    positions = np.array([[n["position"]["x"], n["position"]["y"]] for n in prev_nodes], dtype=float)
    n = positions.shape[0]
    if k <= 0: 
        k = 1
    k = min(k, n)
    rng = np.random.default_rng(seed)

    # -------- Improved initialization --------
    # Randomly sample k points without replacement to ensure each cluster starts with a distinct point
    init_idx = rng.choice(n, size=k, replace=False)
    centers = positions[init_idx]

    # Lloyd iterations
    for _ in range(max_iters):
        d2 = ((positions[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        assign = np.argmin(d2, axis=1)

        # -------- Empty-cluster fix --------
        counts = np.bincount(assign, minlength=k)
        empty_clusters = np.where(counts == 0)[0]
        for ci in empty_clusters:
            # Find the largest cluster
            big_cluster = np.argmax(counts)
            big_idxs = np.where(assign == big_cluster)[0]
            # Steal one point over
            steal_idx = big_idxs[0]
            assign[steal_idx] = ci
            counts[big_cluster] -= 1
            counts[ci] += 1

        moved = 0.0
        for ci in range(k):
            idxs = np.where(assign == ci)[0]
            new_c = positions[idxs].mean(axis=0)
            moved = max(moved, float(np.linalg.norm(new_c - centers[ci])))
            centers[ci] = new_c
        if moved < tol:
            break

    # Final assign (redo once to be safe)
    d2 = ((positions[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    assign = np.argmin(d2, axis=1)

    counts = np.bincount(assign, minlength=k)
    empty_clusters = np.where(counts == 0)[0]
    for ci in empty_clusters:
        big_cluster = np.argmax(counts)
        big_idxs = np.where(assign == big_cluster)[0]
        steal_idx = big_idxs[0]
        assign[steal_idx] = ci
        counts[big_cluster] -= 1
        counts[ci] += 1

    # ---------- Build the super graph ----------
    super_nodes, node_map = [], {}
    for ci in range(k):
        idxs = np.where(assign == ci)[0]
        pts = positions[idxs]
        mean_x, mean_y = pts.mean(axis=0)
        child_dims = [prev_nodes[i]["data"]["dim"] for i in idxs]
        child_nums = [prev_nodes[i]["data"].get("num_base", 1) for i in idxs]
        dim_val = int(max(1, sum(child_dims)))
        num_val = int(sum(child_nums)) 
        nid = f"{ci}"
        super_nodes.append({
            "data": {
                "id": nid,
                "layer": layer_idx,
                "dim": dim_val,
                "num_base": num_val   # Inherit the sum
            },
            "position": {"x": float(mean_x), "y": float(mean_y)}
        })
        for i in idxs:
            node_map[prev_nodes[i]["data"]["id"]] = nid

    super_edges, seen = [], set()
    for e in prev_edges:
        u, v = e["data"]["source"], e["data"]["target"]
        if v != "prior":
            su, sv = node_map[u], node_map[v]
            if su != sv:
                eid = tuple(sorted((su, sv)))
                if eid not in seen:
                    super_edges.append({"data": {"source": su, "target": sv}})
                    seen.add(eid)
            else:
                eid = (su, "prior")
                if eid not in seen:
                    super_edges.append({"data": {"source": su, "target": "prior"}})
                    seen.add(eid)
        else:
            su = node_map[u]
            eid = (su, "prior")
            if eid not in seen:
                super_edges.append({"data": {"source": su, "target": "prior"}})
                seen.add(eid)

    return super_nodes, super_edges, node_map


def copy_to_abs(super_nodes, super_edges, layer_idx):
    abs_nodes = []
    for n in super_nodes:
        nid = n["data"]["id"].replace("s", "a", 1)
        abs_nodes.append({
            "data": {
                "id": nid,
                "layer": layer_idx,
                "dim": n["data"]["dim"],
                "num_base": n["data"].get("num_base", 1)  # Inherit
            },
            "position": {"x": n["position"]["x"], "y": n["position"]["y"]}
        })
    abs_edges = []
    for e in super_edges:
        abs_edges.append({"data": {
            "source": e["data"]["source"].replace("s", "a", 1),
            "target": e["data"]["target"].replace("s", "a", 1)
        }})
    return abs_nodes, abs_edges

# -----------------------
# Sequential merge (tail group absorbs remainder)
# -----------------------
def fuse_to_super_order(prev_nodes, prev_edges, k, layer_idx, tail_heavy=True):
    """
    Sequentially split prev_nodes in current order into k groups; the last group absorbs the remainder (tail_heavy=True).
    Reuse existing rules for aggregating dim/num_base, deduplicating edges, and propagating prior.
    """
    n = len(prev_nodes)
    if k <= 0: k = 1
    k = min(k, n)

    # Group sizes
    base = n // k
    rem  = n %  k
    if rem > 0:
        sizes = [k]*(base) + [rem]     # Tail absorbs remainder: ..., last += rem
    else:
        sizes = [k]*(base)

    # Build groups: record indices per group
    groups = []
    start = 0
    for s in sizes:
        groups.append(list(range(start, start+s)))
        start += s

    # ---- Build super_nodes & node_map ----
    positions = np.array([[n["position"]["x"], n["position"]["y"]] for n in prev_nodes], dtype=float)

    super_nodes, node_map = [], {}
    for gi, idxs in enumerate(groups):
        pts = positions[idxs]
        mean_x, mean_y = pts.mean(axis=0)

        child_dims = [prev_nodes[i]["data"]["dim"] for i in idxs]
        child_nums = [prev_nodes[i]["data"].get("num_base", 1) for i in idxs]
        dim_val = int(max(1, sum(child_dims)))
        num_val = int(sum(child_nums))

        nid = f"{gi}"  # Same as kmeans: use group index as id (string)
        super_nodes.append({
            "data": {
                "id": nid,
                "layer": layer_idx,
                "dim": dim_val,
                "num_base": num_val
            },
            "position": {"x": float(mean_x), "y": float(mean_y)}
        })
        # Build base-id -> super-id mapping (note: ids are strings throughout)
        for i in idxs:
            node_map[prev_nodes[i]["data"]["id"]] = nid

    # ---- Super edges: keep and deduplicate inter-group edges; intra-group edges collapse to prior; prior edges roll up to their owning super ----
    super_edges, seen = [], set()
    for e in prev_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v != "prior":
            su, sv = node_map[u], node_map[v]
            if su != sv:
                eid = tuple(sorted((su, sv)))
                if eid not in seen:
                    super_edges.append({"data": {"source": su, "target": sv}})
                    seen.add(eid)
            else:
                # Intra-group pairwise edge → group prior (consistent with grid/kmeans handling)
                eid = tuple(sorted((su, "prior")))
                if eid not in seen:
                    super_edges.append({"data": {"source": su, "target": "prior"}})
                    seen.add(eid)
        else:
            su = node_map[u]
            eid = tuple(sorted((su, "prior")))
            if eid not in seen:
                super_edges.append({"data": {"source": su, "target": "prior"}})
                seen.add(eid)

    return super_nodes, super_edges, node_map


# -----------------------
# Tools
# -----------------------
def parse_layer_name(name):
    if name == "base": return ("base", 0)
    m = re.match(r"(super|abs)(\d+)$", name)
    return (m.group(1), int(m.group(2))) if m else ("base", 0)

def highest_pair_idx(names):
    hi = 0
    for nm in names:
        kind, k = parse_layer_name(nm)
        if kind in ("super","abs"): hi = max(hi, k)
    return hi


# -----------------------
# Initialization & Boundary
# -----------------------
def init_layers(N=100, step_size=25, loop_prob=0.05, loop_radius=50, prior_prop=0.0, seed=None):
    base_nodes, base_edges = make_slam_like_graph(N, step_size, loop_prob, loop_radius, prior_prop, seed)
    return [{"name": "base", "nodes": base_nodes, "edges": base_edges}]

VIEW_W, VIEW_H = 960, 600
ASPECT = VIEW_W / VIEW_H
AXIS_PAD=20.0
# ==== Blobal Status ====
layers = init_layers()


def adjust_bounds_to_aspect(xmin, xmax, ymin, ymax, aspect):
    cx=(xmin+xmax)/2; cy=(ymin+ymax)/2
    dx=xmax-xmin; dy=ymax-ymin
    if dx<=0: dx=1
    if dy<=0: dy=1
    if dx/dy > aspect:
        dy_new=dx/aspect
        return xmin,xmax,cy-dy_new/2,cy+dy_new/2
    else:
        dx_new=dy*aspect
        return cx-dx_new/2,cx+dx_new/2,ymin,ymax

def reset_global_bounds(base_nodes):
    global GLOBAL_XMIN, GLOBAL_XMAX, GLOBAL_YMIN, GLOBAL_YMAX
    global GLOBAL_XMIN_ADJ, GLOBAL_XMAX_ADJ, GLOBAL_YMIN_ADJ, GLOBAL_YMAX_ADJ
    xs=[n["position"]["x"] for n in base_nodes] or [0.0]
    ys=[n["position"]["y"] for n in base_nodes] or [0.0]
    GLOBAL_XMIN,GLOBAL_XMAX=min(xs),max(xs)
    GLOBAL_YMIN,GLOBAL_YMAX=min(ys),max(ys)
    GLOBAL_XMIN_ADJ,GLOBAL_XMAX_ADJ,GLOBAL_YMIN_ADJ,GLOBAL_YMAX_ADJ=adjust_bounds_to_aspect(
        GLOBAL_XMIN,GLOBAL_XMAX,GLOBAL_YMIN,GLOBAL_YMAX,ASPECT)

# ==== Blobal Status ====
layers = init_layers()
pair_idx = 0
reset_global_bounds(layers[0]["nodes"])
gbp_graph = None

# -----------------------
# GBP Graph Construction
# -----------------------
def build_noisy_pose_graph(
    nodes,
    edges,
    prior_sigma: float = 10,
    odom_sigma: float = 10,
    tiny_prior: float = 1e-12,
    seed=None,
):
    
    """
    Construct a 2D pose-only factor graph (linear, Gaussian) and inject noise.
    Parameters:
      prior_sigma : standard deviation of the strong prior (smaller = stronger)
      odom_sigma  : standard deviation of odometry measurement noise
      prior_prop  : 0.0 = anchor only; (0,1) = randomly select by proportion; >=1.0 = all
      tiny_prior  : a tiny prior added to all nodes to prevent singularity
      seed        : random seed (for reproducibility)
    """

    fg = FactorGraph(nonlinear_factors=False, eta_damping=0)

    var_nodes = []
    I2 = np.eye(2, dtype=float)
    N = len(nodes)

    # ---- Pre-generate noise ----
    prior_noises = {}
    odom_noises = {}

    if seed is None:
        rng = np.random.default_rng()
    else:
        rng = np.random.default_rng(seed)

    # Generate noise for all edges
    for e in edges:
        src = e["data"]["source"]; dst = e["data"]["target"]
        # Binary edge
        if dst != "prior":
            odom_noises[(int(src[:]), int(dst[:]))] = rng.normal(0.0, odom_sigma, size=2)
        # Unary edge (strong prior)
        elif dst == "prior":
            prior_noises[int(src[:])] = rng.normal(0.0, prior_sigma, size=2)


    # ---- variable nodes ----
    for i, n in enumerate(nodes):
        v = VariableNode(i, dofs=2)
        v.GT = np.array([n["position"]["x"], n["position"]["y"]], dtype=float)

        # Tiny prior
        v.prior.lam = tiny_prior * I2
        v.prior.eta = np.zeros(2, dtype=float)

        var_nodes.append(v)

    fg.var_nodes = var_nodes
    fg.n_var_nodes = len(var_nodes)


    # ---- prior factors ----
    def meas_fn_unary(x, *args):
        return x
    def jac_fn_unary(x, *args):
        return np.eye(2)
    # ---- odometry factors ----
    def meas_fn(xy, *args):
        return xy[2:] - xy[:2]
    def jac_fn(xy, *args):
        return np.array([[-1, 0, 1, 0],
                         [ 0,-1, 0, 1]], dtype=float)
    
    factors = []
    fid = 0

    for e in edges:
        src = e["data"]["source"]; dst = e["data"]["target"]
        if dst != "prior":
            i, j = int(src[:]), int(dst[:])
            vi, vj = var_nodes[i], var_nodes[j]

            meas = (vj.GT - vi.GT) + odom_noises[(i, j)]

            meas_lambda = np.eye(len(meas))/ (odom_sigma**2)
            f = Factor(fid, [vi, vj], meas, meas_lambda, meas_fn, jac_fn)
            f.type = "base"
            linpoint = np.r_[vi.GT, vj.GT]
            f.compute_factor(linpoint=linpoint, update_self=True)

            factors.append(f)
            vi.adj_factors.append(f)
            vj.adj_factors.append(f)
            fid += 1

        else:
            i = int(src[:])
            vi = var_nodes[i]
            z = vi.GT + prior_noises[i]

            z_lambda = np.eye(len(meas))/ (prior_sigma**2)
            f = Factor(fid, [vi], z, z_lambda, meas_fn_unary, jac_fn_unary)
            f.type = "prior"
            f.compute_factor(linpoint=z, update_self=True)

            factors.append(f)
            vi.adj_factors.append(f)
            fid += 1

        # anchor for initial position
        v0 = var_nodes[0]
        z = v0.GT

        z_lambda = np.eye(len(meas))/ ((1e-3)**2)
        f = Factor(fid, [v0], z, z_lambda, meas_fn_unary, jac_fn_unary)
        f.type = "prior"
        f.compute_factor(linpoint=z, update_self=True)

        factors.append(f)
        v0.adj_factors.append(f)
        fid += 1

    fg.factors = factors
    fg.n_factor_nodes = len(factors)
    return fg


def build_super_graph(layers, eta_damping=0.2):
    """
    Construct the super graph based on the base graph in layers[-2] and the super-grouping in layers[-1].
    Requirement: layers[-2]["graph"] is an already-built base graph (with unary/binary factors).
    layers[-1]["node_map"]: { base_node_id (str, e.g., 'b12') -> super_node_id (str) }
    """
    # ---------- Extract base & super ----------
    base_graph = layers[-2]["graph"]
    super_nodes = layers[-1]["nodes"]
    super_edges = layers[-1]["edges"]
    node_map    = layers[-1]["node_map"]   # 'bN' -> 'sX_...'

    # base: id(int) -> VariableNode, handy to query dofs and mu
    id2var = {vn.variableID: vn for vn in base_graph.var_nodes}

    # ---------- super_id -> [base_id(int)] ----------
    super_groups = {}
    for b_str, s_id in node_map.items():
        b_int = int(b_str)
        super_groups.setdefault(s_id, []).append(b_int)


    # ---------- For each super group, build a (start, dofs) table ----------
    # local_idx[sid][bid] = (start, dofs), total_dofs[sid] = sum(dofs)
    local_idx   = {}
    total_dofs  = {}
    for sid, group in super_groups.items():
        off = 0
        local_idx[sid] = {}
        for bid in group:
            d = id2var[bid].dofs
            local_idx[sid][bid] = (off, d)
            off += d
        total_dofs[sid] = off


    # ---------- Create super VariableNodes ----------
    fg = FactorGraph(nonlinear_factors=False, eta_damping=eta_damping)

    super_var_nodes = {}
    for i, sn in enumerate(super_nodes):
        sid = sn["data"]["id"]
        dofs = total_dofs.get(sid, 0)

        v = VariableNode(i, dofs=dofs)
        gt_vec = np.zeros(dofs)
        mu_blocks = []
        Sigma_blocks = []
        for bid, (st, d) in local_idx[sid].items():
            # === Stack base GT ===
            gt_base = getattr(id2var[bid], "GT", None)
            if gt_base is None or len(gt_base) != d:
                gt_base = np.zeros(d)
            gt_vec[st:st+d] = gt_base

            # === Stack base belief ===
            vb = id2var[bid]
            mu_blocks.append(vb.mu)
            Sigma_blocks.append(vb.Sigma)

        super_var_nodes[sid] = v
        v.GT = gt_vec

        mu_super = np.concatenate(mu_blocks) if mu_blocks else np.zeros(dofs)
        Sigma_super = block_diag(*Sigma_blocks) if Sigma_blocks else np.eye(dofs)
        lam = np.linalg.inv(Sigma_super)
        eta = lam @ mu_super
        v.mu = mu_super
        v.Sigma = Sigma_super
        v.belief = NdimGaussian(dofs, eta, lam)
        v.prior.lam = 1e-12 * lam
        v.prior.eta = 1e-12 * eta

        fg.var_nodes.append(v)

    fg.n_var_nodes = len(fg.var_nodes)

    # ---------- Utility: assemble a group's linpoint (using base belief means) ----------
    def make_linpoint_for_group(sid):
        x = np.zeros(total_dofs[sid])
        for bid, (st, d) in local_idx[sid].items():
            mu = getattr(id2var[bid], "mu", None)
            if mu is None or len(mu) != d:
                mu = np.zeros(d)
            x[st:st+d] = mu
        return x

    # ---------- 3) super prior (in-group unary + in-group binary) ----------
    def make_super_prior_factor(sid, base_factors):
        group = super_groups[sid]
        idx_map = local_idx[sid]
        ncols = total_dofs[sid]

        # Select factors whose variables are all within the group (unary or binary)
        in_group = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if all(vid in group for vid in vids):
                in_group.append(f)

        def meas_fn_super_prior(x_super, *args):
            meas_fn = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Assemble this factor's local x
                x_loc_list = []
                for vid in vids:
                    st, d = idx_map[vid]
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn) if meas_fn else np.zeros(0)

        def jac_fn_super_prior(x_super, *args):
            Jrows = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Build this factor's local x for (potentially) nonlinear Jacobian
                x_loc_list = []
                dims = []
                for vid in vids:
                    st, d = idx_map[vid]
                    dims.append(d)
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)

                Jloc = f.jac_fn(x_loc)
                
                # Map Jloc column blocks back to the super variable columns
                row = np.zeros((Jloc.shape[0], ncols))
                c0 = 0
                for vid, d in zip(vids, dims):
                    st, _ = idx_map[vid]
                    row[:, st:st+d] = Jloc[:, c0:c0+d]
                    c0 += d

                Jrows.append(row)
            return np.vstack(Jrows) if Jrows else np.zeros((0, ncols))

        # z_super: concatenate each base factor's z
        z_list = [f.measurement for f in in_group]
        z_lambda_list = [f.measurement_lambda for f in in_group]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_prior, jac_fn_super_prior, z_super, z_super_lambda 

    # ---------- 4) super between (cross-group binary) ----------
    def make_super_between_factor(sidA, sidB, base_factors):
        groupA, groupB = super_groups[sidA], super_groups[sidB]
        idxA, idxB = local_idx[sidA], local_idx[sidB]
        nA, nB = total_dofs[sidA], total_dofs[sidB]

        cross = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if len(vids) != 2:
                continue
            i, j = vids
            # One side in A, the other side in B
            if (i in groupA and j in groupB) or (i in groupB and j in groupA):
                cross.append(f)


        def meas_fn_super_between(xAB, *args):
            xA, xB = xAB[:nA], xAB[nA:]
            meas_fn = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                x_loc = np.concatenate([xi, xj])
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn) 

        def jac_fn_super_between(xAB, *args):
            xA, xB = xAB[:nA], xAB[nA:]
            Jrows = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                    left_start, right_start = si, nA + sj
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                    left_start, right_start = nA + si, sj
                x_loc = np.concatenate([xi, xj])
                Jloc = f.jac_fn(x_loc)

                row = np.zeros((Jloc.shape[0], nA + nB))
                row[:, left_start:left_start+di]   = Jloc[:, :di] 
                row[:, right_start:right_start+dj] = Jloc[:, di:di+dj] 

                Jrows.append(row)
            
            return np.vstack(Jrows) 

        z_list = [f.measurement for f in cross]
        z_lambda_list = [f.measurement_lambda for f in cross]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_between, jac_fn_super_between, z_super, z_super_lambda


    for e in super_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v == "prior":
            meas_fn, jac_fn, z, z_lambda = make_super_prior_factor(u, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_prior"
            lin0 = make_linpoint_for_group(u)
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            
        else:
            meas_fn, jac_fn, z, z_lambda = make_super_between_factor(u, v, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u], super_var_nodes[v]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_between"
            lin0 = np.concatenate([make_linpoint_for_group(u), make_linpoint_for_group(v)])
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            super_var_nodes[v].adj_factors.append(f)


    fg.n_factor_nodes = len(fg.factors)
    return fg


def build_abs_graph(
    layers,
    r_reduced = 2,
    eta_damping=0.2):

    abs_var_nodes = {}
    Bs = {}
    ks = {}
    k2s = {}

    # === 1. Build Abstraction Variables ===
    abs_fg = FactorGraph(nonlinear_factors=False, eta_damping=eta_damping)
    sup_fg = layers[-2]["graph"]

    for sn in sup_fg.var_nodes:
        if sn.dofs <= r_reduced:
            r = sn.dofs  # No reduction if dofs already <= r
        else:
            r = r_reduced

        sid = sn.variableID
        varis_sup_mu = sn.mu
        varis_sup_sigma = sn.Sigma
        
        # Step 1: Eigen decomposition of the covariance matrix
        eigvals, eigvecs = np.linalg.eigh(varis_sup_sigma)

        # Step 2: Sort eigenvalues and eigenvectors in descending order of eigenvalues
        idx = np.argsort(eigvals)[::-1]      # Get indices of sorted eigenvalues (largest first)
        eigvals = eigvals[idx]               # Reorder eigenvalues
        eigvecs = eigvecs[:, idx]            # Reorder corresponding eigenvectors

        # Step 3: Select the top-k eigenvectors to form the projection matrix (principal subspace)
        B_reduced = eigvecs[:, :r]                 # B_reduced: shape (sup_dof, r), projects r to sup_dof
        Bs[sid] = B_reduced                        # Store the projection matrix for this variable

        # Step 4: Project eta and Lam onto the reduced 2D subspace
        varis_abs_mu = B_reduced.T @ varis_sup_mu          # Projected natural mean: shape (2,)
        varis_abs_sigma = B_reduced.T @ varis_sup_sigma @ B_reduced  # Projected covariance: shape (2, 2)
        ks[sid] = varis_sup_mu - B_reduced @ varis_abs_mu  # Store the mean offset for this variable
        #k2s[sid] = varis_sup_sigma - B_reduced @ varis_abs_sigma @ B_reduced.T  # Residual covariance

        varis_abs_lam = np.linalg.inv(varis_abs_sigma)  # Inverse covariance (precision matrix): shape (2, 2)
        varis_abs_eta = varis_abs_lam @ varis_abs_mu  # Natural parameters: shape (2,)

        v = VariableNode(sid, dofs=r)
        v.GT = sn.GT
        v.prior.lam = 1e-10 * varis_abs_lam
        v.prior.eta = 1e-10 * varis_abs_eta
        v.mu = varis_abs_mu
        v.Sigma = varis_abs_sigma
        v.belief = NdimGaussian(r, varis_abs_eta, varis_abs_lam)

        abs_var_nodes[sid] = v
        abs_fg.var_nodes.append(v)
    abs_fg.n_var_nodes = len(abs_fg.var_nodes)


    # === 2. Abstract Prior ===
    def make_abs_prior_factor(sup_factor):
        abs_id = sup_factor.adj_var_nodes[0].variableID
        B = Bs[abs_id]
        k = ks[abs_id]

        def meas_fn_abs_prior(x_abs, *args):
            return sup_factor.meas_fn(B @ x_abs + k)
        
        def jac_fn_abs_prior(x_abs, *args):
            return sup_factor.jac_fn(B @ x_abs + k) @ B

        return meas_fn_abs_prior, jac_fn_abs_prior, sup_factor.measurement, sup_factor.measurement_lambda
    


    # === 3. Abstract Between ===
    def make_abs_between_factor(sup_factor):
        vids = [v.variableID for v in sup_factor.adj_var_nodes]
        i, j = vids # two variable IDs
        ni = abs_var_nodes[i].dofs
        Bi, Bj = Bs[i], Bs[j]
        ki, kj = ks[i], ks[j]                       
    

        def meas_fn_super_between(xij, *args):
            xi, xj = xij[:ni], xij[ni:]
            return sup_factor.meas_fn(np.concatenate([Bi @ xi + ki, Bj @ xj + kj]))

        def jac_fn_super_between(xij, *args):
            xi, xj = xij[:ni], xij[ni:]
            J_sup = sup_factor.jac_fn(np.concatenate([Bi @ xi + ki, Bj @ xj + kj]))
            J_abs = np.zeros((J_sup.shape[0], ni + xj.shape[0]))
            J_abs[:, :ni] = J_sup[:, :Bi.shape[0]] @ Bi
            J_abs[:, ni:] = J_sup[:, Bi.shape[0]:] @ Bj
            return J_abs
        
        return meas_fn_super_between, jac_fn_super_between, sup_factor.measurement, sup_factor.measurement_lambda
    

    for f in sup_fg.factors:
        if len(f.adj_var_nodes) == 1:
            meas_fn, jac_fn, z, z_lambda = make_abs_prior_factor(f)
            v = abs_var_nodes[f.adj_var_nodes[0].variableID]
            abs_f = Factor(f.factorID, [v], z, z_lambda, meas_fn, jac_fn)
            abs_f.type = "abs_prior"
            abs_f.adj_beliefs = [v.belief]

            lin0 = v.mu
            abs_f.compute_factor(linpoint=lin0, update_self=True)

            abs_fg.factors.append(abs_f)
            v.adj_factors.append(abs_f)

        elif len(f.adj_var_nodes) == 2:
            meas_fn, jac_fn, z, z_lambda = make_abs_between_factor(f)
            i, j = [v.variableID for v in f.adj_var_nodes]
            vi, vj = abs_var_nodes[i], abs_var_nodes[j]
            abs_f = Factor(f.factorID, [vi, vj], z, z_lambda, meas_fn, jac_fn)
            abs_f.type = "abs_between"
            abs_f.adj_beliefs = [vi.belief, vj.belief]

            lin0 = np.concatenate([vi.mu, vj.mu])
            abs_f.compute_factor(linpoint=lin0, update_self=True)

            abs_fg.factors.append(abs_f)
            vi.adj_factors.append(abs_f)
            vj.adj_factors.append(abs_f)

    abs_fg.n_factor_nodes = len(abs_fg.factors)


    return abs_fg, Bs, ks, k2s


def bottom_up_modify_super_graph(layers):
    """
    Update super-node means (mu) from base nodes,
    and simultaneously adjust variable beliefs and adjacent messages.
    """

    base_graph = layers[-2]["graph"]
    super_graph = layers[-1]["graph"]
    node_map = layers[-1]["node_map"]

    id2var = {vn.variableID: vn for vn in base_graph.var_nodes}

    super_groups = {}
    for b_str, s_id in node_map.items():
        b_int = int(b_str)
        super_groups.setdefault(s_id, []).append(b_int)

    sid2idx = {sn["data"]["id"]: i for i, sn in enumerate(layers[-1]["nodes"])}

    for sid, group in super_groups.items():
        mu_blocks = [id2var[bid].mu for bid in group]
        mu_super = np.concatenate(mu_blocks) if mu_blocks else np.zeros(0)

        if sid in sid2idx:
            idx = sid2idx[sid]
            v = super_graph.var_nodes[idx]

            # Old belief
            old_belief = v.belief

            # 1. Update mu
            #v.mu = mu_super

            # 2. New belief (use old Sigma + new mu)
            #lam = np.linalg.inv(v.Sigma)
            #eta = lam @ v.mu
            #new_belief = NdimGaussian(v.dofs, eta, lam)
            #v.belief = new_belief
            #v.prior = NdimGaussian(v.dofs)


        """
            # 3. update adj_beliefs and messages
            if v.adj_factors:
                n_adj = len(v.adj_factors)
                d_eta = new_belief.eta - old_belief.eta
                d_lam = new_belief.lam - old_belief.lam
                for f in v.adj_factors:
                    if v in f.adj_var_nodes:
                        idx_in_factor = f.adj_var_nodes.index(v)
                        # update factor's adj_belief
                        f.adj_beliefs[idx_in_factor] = new_belief
                        # update corresponding messages
                        msg = f.messages[idx_in_factor]
                        msg.eta += d_eta / n_adj
                        msg.lam += d_lam / n_adj
                        f.messages[idx_in_factor] = msg
        """


def top_down_modify_base_and_abs_graph(layers):
    """
    From the super graph downward, split μ to the base graph,
    and simultaneously update the base variables' beliefs and the adjacent factors'
    adj_beliefs / messages.

    Assume layers[-1] is the super layer and layers[-2] is the base layer.
    """
    super_graph = layers[-1]["graph"]
    base_graph = layers[-2]["graph"]
    node_map   = layers[-1]["node_map"]  # { base_id(str) -> super_id(str) }


    # super_id -> [base_id(int)]
    super_groups = {}
    for b_str, s_id in node_map.items():
        b_int = int(b_str)
        super_groups.setdefault(s_id, []).append(b_int)

    # child lookup
    id2var_base = {vn.variableID: vn for vn in base_graph.var_nodes}

    a = 0
    for s_var in super_graph.var_nodes:
        sid = str(s_var.variableID)
        if sid not in super_groups:
            continue
        base_ids = super_groups[sid]

        # === split super.mu to base ===
        mu_super = s_var.mu
        off = 0
        for bid in base_ids:
            v = id2var_base[bid]
            d = v.dofs
            mu_child = mu_super[off:off+d]
            off += d

            old_belief = v.belief

            # 1. update mu
            v.mu = mu_child

            # 2. new belief（keep Σ unchanged，use new mu）
            eta = v.belief.lam @ v.mu
            new_belief = NdimGaussian(v.dofs, eta, v.belief.lam)
            v.belief = new_belief
            v.prior = NdimGaussian(v.dofs, 1e-10*eta, 1e-10*v.belief.lam)

            # 3. Sync to adjacent factors (this step is important)
            if v.adj_factors:
                n_adj = len(v.adj_factors)
                d_eta = new_belief.eta - old_belief.eta
                d_lam = new_belief.lam - old_belief.lam

                for f in v.adj_factors:
                    if v in f.adj_var_nodes:
                        idx = f.adj_var_nodes.index(v)
                        # update adj_beliefs
                        f.adj_beliefs[idx] = new_belief
                        # correct coresponding message
                        msg = f.messages[idx]
                        msg.eta += d_eta / n_adj
                        msg.lam += d_lam / n_adj
                        f.messages[idx] = msg

    return base_graph


def top_down_modify_super_graph(layers):
    """
    From the abs graph downward, project mu / Sigma back to the super graph,
    and simultaneously update the super variables' beliefs and the adjacent
    factors' adj_beliefs / messages.

    Requirements:
      - layers[-1] is abs, layers[-2] is super
      - Factors at the abs level and the super level share the same factorID (one-to-one)
      - The columns of B are orthonormal (from covariance eigenvectors; eigenvectors from np.linalg.eigh are orthogonal)
    """

    abs_graph   = layers[-1]["graph"]
    super_graph = layers[-2]["graph"]
    Bs  = layers[-1]["Bs"]   # { super_id(int) -> B (d_super × r) }
    ks  = layers[-1]["ks"]   # { super_id(int) -> k (d_super,) }
    #k2s = layers[-1]["k2s"]  # { super_id(int) -> residual covariance (d_super × d_super) }

    # Prebuild abs factor index: factorID -> Factor
    #abs_f_by_id = {f.factorID: f for f in getattr(abs_graph, "factors", [])}

    # ---- First project variables' mu / Sigma and update beliefs ----
    for sn in super_graph.var_nodes:
        sid = sn.variableID
        if sid not in Bs or sid not in ks:
            continue
        B  = Bs[sid]    # (d_s × r)
        k  = ks[sid]    # (d_s,)
        #k2 = k2s[sid]   # (d_s × d_s)

        # x_s = B x_a + k; Σ_s = B Σ_a Bᵀ + k2
        mu_a    = abs_graph.var_nodes[sid].mu
        mu_s    = B @ mu_a + k
        sn.mu   = mu_s

        # Refresh super belief (natural parameters) with the new μ and Σ
        eta = sn.belief.lam @ sn.mu
        new_belief = NdimGaussian(sn.dofs, eta, sn.belief.lam)
        sn.belief  = new_belief
        sn.prior = NdimGaussian(sn.dofs, 1e-10*eta, 1e-10*sn.belief.lam)

    # ---- Then push abs messages back to super, preserving the original super messages on the orthogonal complement ----
    # Idea: for the side of the super factor f_sup connected to variable sid:
    #   η_s_new = B η_a + (I - B Bᵀ) η_s_old
    #   Λ_s_new = B Λ_a Bᵀ + (I - B Bᵀ) Λ_s_old (I - B Bᵀ)
    # This ensures the subspace is governed by the abs message, while the orthogonal complement keeps the old super message.
    for sn in super_graph.var_nodes:
        # Iterate over super factors adjacent to this super variable
        for f_sup in sn.adj_factors:
            idx_side = f_sup.adj_var_nodes.index(sn)
            # update the factor's recorded adjacent belief on that side (optional; usually refreshed in the next iteration)
            f_sup.adj_beliefs[idx_side] = sn.belief

    return



def refresh_gbp_results(layers):
    """
    Precompute an affine map to the base plane for each layer:
      base:   A_i = I2, b_i = 0
      super:  A_s = (1/m) [A_c1, A_c2, ..., A_cm], b_s = (1/m) Σ b_cj
      abs:    A_a = A_super(s) @ B_s,             b_a = A_super(s) @ k_s + b_super(s)
    Then refresh gbp_result via pos = A @ mu + b.
    Convention: use string keys everywhere (aligned with Cytoscape ids).
    """
    if not layers:
        return

    # ---------- 1) Bottom-up: compute A, b for each layer ----------
    for li, L in enumerate(layers):
        g = L.get("graph")
        if g is None:
            L.pop("A", None); L.pop("b", None); L.pop("gbp_result", None)
            continue

        name = L["name"]
        # ---- base ----
        if name.startswith("base"):
            L["A"], L["b"] = {}, {}
            for v in g.var_nodes:
                key = str(v.variableID)
                L["A"][key] = np.eye(2)
                L["b"][key] = np.zeros(2, dtype=float)

        # ---- super ----
        elif name.startswith("super"):
            parent = layers[li - 1]
            node_map = L["node_map"]  # { prev_id(str) -> super_id(str) }

            # Grouping (preserve insertion order to match the concatenation order in build_super_graph)
            groups = {}
            for prev_id, s_id in node_map.items():
                prev_id = str(prev_id); s_id = str(s_id)
                groups.setdefault(s_id, []).append(prev_id)

            L["A"], L["b"] = {}, {}
            for s_id, children in groups.items():
                m = len(children)
                # Horizontal concatenation [A_c1, A_c2, ...]
                A_blocks = [parent["A"][cid] for cid in children]  # each block has shape 2×d_c
                A_concat = np.hstack(A_blocks) if A_blocks else np.zeros((2, 0))
                b_sum = sum((parent["b"][cid] for cid in children), start=np.zeros(2, dtype=float))
                L["A"][s_id] = (1.0 / m) * A_concat
                L["b"][s_id] = (1.0 / m) * b_sum

        # ---- abs ----
        elif name.startswith("abs"):
            parent = layers[li - 1]  # the corresponding super layer
            Bs, ks = L["Bs"], L["ks"]  # Note: keys are the super variableIDs (int)

            # Build a mapping between super variableID (int) and the super string id (follow node list order)
            # The order of nodes in the parent (super) and this (abs) layer is consistent (copy_to_abs preserves order)
            int2sid = {i: str(parent["nodes"][i]["data"]["id"]) for i in range(len(parent["nodes"]))}

            L["A"], L["b"] = {}, {}
            for av in g.var_nodes:
                sid_int = av.variableID              # super variableID (int)
                s_id = int2sid.get(sid_int, str(sid_int))  # super string id (also the abs node id)
                B = Bs[sid_int]                       # (sum d_c) × r
                k = ks[sid_int]                       # (sum d_c,)

                A_sup = parent["A"][s_id]             # shape 2 × (sum d_c)
                b_sup = parent["b"][s_id]             # shape (2,)

                L["A"][s_id] = A_sup @ B              # 2 × r
                L["b"][s_id] = A_sup @ k + b_sup      # 2,

        else:
            # Unknown layer type
            L["A"], L["b"] = {}, {}

    # ---------- 2) Compute gbp_result ----------
    for li, L in enumerate(layers):
        g = L.get("graph")
        if g is None:
            L.pop("gbp_result", None)
            continue

        name = L["name"]
        res = {}

        if name.startswith("base"):
            for v in g.var_nodes:
                vid = str(v.variableID)
                res[vid] = v.mu[:2].tolist()

        elif name.startswith("super"):
            # Directly use A_super, b_super mapping
            # nodes order is consistent with var_nodes order
            for i, v in enumerate(g.var_nodes):
                s_id = str(L["nodes"][i]["data"]["id"])
                A, b = L["A"][s_id], L["b"][s_id]   # A: 2×(sum d_c)
                res[s_id] = (A @ v.mu + b).tolist()

        elif name.startswith("abs"):
            parent = layers[li - 1]
            # Also align via string ids
            for i, v in enumerate(g.var_nodes):
                a_id = str(L["nodes"][i]["data"]["id"])  # same text as the super s_id
                A, b = L["A"][a_id], L["b"][a_id]        # A: 2×r
                res[a_id] = (A @ v.mu + b).tolist()

        L["gbp_result"] = res



def vloop(layers):
    """
    Simplified V-cycle:
    1) bottom-up: rebuild and iterate once for base / super / abs in order
    2) top-down: propagate mu from super -> base
    3) refresh gbp_result on each layer for UI use
    """

    # ---- bottom-up ----
    #if layers and "graph" in layers[0]:
    #    layers[0]["graph"].synchronous_iteration()
        
    for i in range(1, len(layers)):
        name = layers[i]["name"]

        if name.startswith("super1"):
            # Update super using the previous layer's graph
            # layers[i]["graph"] = build_super_graph(layers[:i+1])
            #layers[i]["graph"].synchronous_iteration()
            #bottom_up_modify_super_graph(layers[:i+1])
            #build_super_graph(layers[:i+1])
            #update_super_graph_linearized(layers[:i+1])
            pass

        elif name.startswith("super"):
            # Update super using the previous layer's graph
            layers[i]["graph"] = build_super_graph(layers[:i+1])
            #layers[i]["graph"] = update_super_graph_linearized(layers[:i+1])

        elif name.startswith("abs"):
            # Rebuild abs using the previous super
            abs_graph, Bs, ks, k2s = build_abs_graph(layers[:i+1])
            layers[i]["graph"] = abs_graph
            layers[i]["Bs"], layers[i]["ks"], layers[i]["k2s"] = Bs, ks, k2s

        # After build, one iteration per layer
        if "graph" in layers[i]:
            layers[i]["graph"].synchronous_iteration()

    # ---- top-down (pass mu) ----
    for i in range(len(layers) - 1, 0, -1):
        # After one iterations per layer, reproject
        if "graph" in layers[i]:
            layers[i]["graph"].synchronous_iteration()

        # this is very important, but dont know why yet
        # so abs layer need more iterations
        #if name.startswith("abs"):
            #layers[i]["graph"].synchronous_iteration()  

        name = layers[i]["name"]
        if name.startswith("super"):
            # Split super.mu back to base/abs
            top_down_modify_base_and_abs_graph(layers[:i+1])

        elif name.startswith("abs"):
            # Project abs.mu back to super
            top_down_modify_super_graph(layers[:i+1])


    # ---- refresh gbp_result for UI ----
    refresh_gbp_results(layers)



def compute_energy(layers):
    """
    energy = 0.5 * sum_i || mu_i[0:2] - GT_i[0:2] ||^2  over base layer variables
    vectorized version (no per-node np calls)
    """
    try:
        base_graph = layers[0].get("graph", None)
        var_nodes = getattr(base_graph, "var_nodes", None)
        if base_graph is None or not var_nodes:
            return "Energy: -"

        # Stack mu[:2] and GT[:2] for all variables into (N, 2)
        mus_2 = np.stack([np.asarray(v.mu[:2], dtype=float) for v in var_nodes], axis=0)
        gts_2 = np.stack([np.asarray(v.GT[:2], dtype=float) for v in var_nodes], axis=0)

        diff = mus_2 - gts_2                      # shape (N, 2)
        total = 0.5 * np.sum(diff * diff)         # 0.5 * sum of squared norms

        return f"Energy: {float(total):.4f}"
    except Exception:
        return "Energy: -"
    


class VGraph:
    def __init__(self,
                 layers,
                 nonlinear_factors=True,
                 eta_damping=0.4,
                 beta=0.0,
                 iters_since_relinear=0,
                 num_undamped_iters=0,
                 min_linear_iters=100,
                 wild_thresh=0):

        self.layers = layers
        self.iters_since_relinear = iters_since_relinear
        self.min_linear_iters = min_linear_iters
        self.nonlinear_factors = nonlinear_factors
        self.eta_damping = eta_damping
        self.wild_thresh = wild_thresh

        #self.energy_history = []
        #self.error_history = []
        #self.nmsgs_history = []
        #self.mus = []


    def vloop(self):
        """
        Simplified V-cycle:
        1) bottom-up: rebuild and iterate once for base / super / abs in order
        2) top-down: propagate mu from super -> base
        3) refresh gbp_result on each layer for UI use
        """

        layers = self.layers

        # ---- bottom-up ----
        #if layers and "graph" in layers[0]:
        #    layers[0]["graph"].synchronous_iteration()
            
        for i in range(1, len(layers)):
            name = layers[i]["name"]

            if name.startswith("super1"):
                # Update super using the previous base graph's new linearization points
                pass

            elif name.startswith("super"):
                # Update super using the previous layer's graph
                layers[i]["graph"] = build_super_graph(layers[:i+1], eta_damping=self.eta_damping)

            elif name.startswith("abs"):
                # Rebuild abs using the previous super
                abs_graph, Bs, ks, k2s = build_abs_graph(layers[:i+1], eta_damping=self.eta_damping)
                layers[i]["graph"] = abs_graph
                layers[i]["Bs"], layers[i]["ks"], layers[i]["k2s"] = Bs, ks, k2s

            # After build, one iteration per layer
            if "graph" in layers[i]:
                layers[i]["graph"].synchronous_iteration()

        # ---- top-down (pass mu) ----
        for i in range(len(layers) - 1, 0, -1):
            # After one iterations per layer, reproject
            if "graph" in layers[i]:
                layers[i]["graph"].synchronous_iteration()

            # this is very important, but dont know why yet
            # so abs layer need more iterations
            #if name.startswith("abs"):
                #layers[i]["graph"].synchronous_iteration()  

            name = layers[i]["name"]
            if name.startswith("super"):
                # Split super.mu back to base/abs
                top_down_modify_base_and_abs_graph(layers[:i+1])

            elif name.startswith("abs"):
                # Project abs.mu back to super
                top_down_modify_super_graph(layers[:i+1])

        # ---- refresh gbp_result for UI ----
        #refresh_gbp_results(layers)
        return layers
vg = VGraph(layers)


In [3]:
N=512
step=25
prob=0.05
radius=50 
prior_prop=0.02
prior_sigma=1
loop_sigma=0.01
layers = []



layers = init_layers(N=N, step_size=step, loop_prob=prob, loop_radius=radius, prior_prop=prior_prop, seed=2001)
pair_idx = 0


# 构建 GBP 图
gbp_graph = build_noisy_pose_graph(layers[0]["nodes"], layers[0]["edges"],
                                    prior_sigma=prior_sigma,
                                    odom_sigma=odom_sigma,
                                    seed=2001)
layers[0]["graph"] = gbp_graph
gbp_graph.num_undamped_iters = 0
gbp_graph.min_linear_iters = 2000
opts=[{"label":"base","value":"base"}]



In [4]:
def energy_map(graph, include_priors: bool = True, include_factors: bool = True) -> float:
    """
    It is actually the sum of squares of distances.
    """
    total = 0.0

    for v in graph.var_nodes[:graph.n_var_nodes]:
        gt = np.asarray(v.GT[0:2], dtype=float)
        r = np.asarray(v.mu[0:2], dtype=float) - gt
        total += 0.5 * float(r.T @ r)

    return total


In [6]:
basegraph = layers[0]["graph"]
for it in range(2000):
    basegraph.synchronous_iteration()
    energy = energy_map(basegraph, include_priors=True, include_factors=True)
    print(f"Iter {it+1:03d} | Energy = {energy:.6f}")


Iter 001 | Energy = 81216946.922214
Iter 002 | Energy = 60764519.818854
Iter 003 | Energy = 44846122.972340
Iter 004 | Energy = 37681013.622285
Iter 005 | Energy = 34425457.010148
Iter 006 | Energy = 31314304.652222
Iter 007 | Energy = 28258940.918308
Iter 008 | Energy = 25613438.511047
Iter 009 | Energy = 23416615.787896
Iter 010 | Energy = 21187442.170679
Iter 011 | Energy = 18912041.630807
Iter 012 | Energy = 16593571.609666
Iter 013 | Energy = 14328361.945581
Iter 014 | Energy = 12135673.713626
Iter 015 | Energy = 9538801.845862
Iter 016 | Energy = 7012843.175017
Iter 017 | Energy = 4819820.326770
Iter 018 | Energy = 3501050.320654
Iter 019 | Energy = 2616284.895363
Iter 020 | Energy = 1738106.951841
Iter 021 | Energy = 866153.570009
Iter 022 | Energy = 2337.924787
Iter 023 | Energy = 2293.491626
Iter 024 | Energy = 2254.633303
Iter 025 | Energy = 2249.315470
Iter 026 | Energy = 2267.489087
Iter 027 | Energy = 2312.282560
Iter 028 | Energy = 2356.970555
Iter 029 | Energy = 2390.154

In [None]:
N=512
step=25
prob=0.05
radius=50 
prior_prop=0.02
prior_sigma=1
odom_sigma=1
layers = []



layers = init_layers(N=N, step_size=step, loop_prob=prob, loop_radius=radius, prior_prop=prior_prop, seed=2001)
pair_idx = 0


# Create GBP graph
gbp_graph = build_noisy_pose_graph(layers[0]["nodes"], layers[0]["edges"],
                                    prior_sigma=prior_sigma,
                                    odom_sigma=odom_sigma,
                                    seed=2001)
layers[0]["graph"] = gbp_graph
gbp_graph.num_undamped_iters = 0
gbp_graph.min_linear_iters = 200
opts=[{"label":"base","value":"base"}]

basegraph = layers[0]["graph"]
for it in range(20):
    basegraph.synchronous_iteration()
    energy = energy_map(basegraph, include_priors=True, include_factors=True)
    print(f"Iter {it+1:03d} | Energy = {energy:.6f}")


Iter 001 | Energy = 99117762.977742
Iter 002 | Energy = 93472676.746413
Iter 003 | Energy = 81216946.922214
Iter 004 | Energy = 60764519.818854
Iter 005 | Energy = 44846122.972340
Iter 006 | Energy = 37681013.622285
Iter 007 | Energy = 34425457.010148
Iter 008 | Energy = 31314304.652222
Iter 009 | Energy = 28258940.918308
Iter 010 | Energy = 25613438.511047
Iter 011 | Energy = 23416615.787896
Iter 012 | Energy = 21187442.170679
Iter 013 | Energy = 18912041.630807
Iter 014 | Energy = 16593571.609666
Iter 015 | Energy = 14328361.945581
Iter 016 | Energy = 12135673.713626
Iter 017 | Energy = 9538801.845862
Iter 018 | Energy = 7012843.175017
Iter 019 | Energy = 4819820.326770
Iter 020 | Energy = 3501050.320654


In [None]:
N=512
step=25
prob=0.05
radius=50 
prior_prop=0.02
prior_sigma=1
odom_sigma=1
layers = []



layers = init_layers(N=N, step_size=step, loop_prob=prob, loop_radius=radius, prior_prop=prior_prop, seed=2001)
pair_idx = 0


# Create GBP graph
gbp_graph = build_noisy_pose_graph(layers[0]["nodes"], layers[0]["edges"],
                                    prior_sigma=prior_sigma,
                                    odom_sigma=odom_sigma,
                                    seed=2001)
layers[0]["graph"] = gbp_graph
gbp_graph.num_undamped_iters = 0
gbp_graph.min_linear_iters = 0
opts=[{"label":"base","value":"base"}]

basegraph = layers[0]["graph"]

basegraph.joint_distribution_cov()[0].shape

print(energy_map(basegraph, include_priors=True, include_factors=True))

101089344.5794572


In [10]:
for i in range(10):
    total = 0
    a = basegraph.joint_distribution_cov()[0].reshape(512,2)[:,:]
    for i,v in enumerate(basegraph.var_nodes[:basegraph.n_var_nodes]):
        gt = np.asarray(v.GT[0:2], dtype=float)
        r = np.asarray(a[i][0:2], dtype=float) - gt
        v.mu = a[i]
        total += 0.5 * float(r.T @ r)
    print(total)
    basegraph.relinearise_factors()



2324.863742651605
2324.863742651605
2324.863742651605
2324.863742651605
2324.863742651605
2324.863742651605
2324.863742651605
2324.863742651605
2324.863742651605
2324.863742651605


In [41]:
supergraph = layers[-1]["graph"]

"""
# 找到所有 super_prior factors
prior_factors = [f for f in supergraph.factors if getattr(f, "type", "") == "super_prior"]

# 对 super_prior factors 的邻居变量
prior_vars = []
for f in prior_factors:
    for v in f.adj_var_nodes:
        if v not in prior_vars:
            prior_vars.append(v)

# 找到所有 super_between factors
between_factors = [f for f in supergraph.factors if getattr(f, "type", "") != "super_prior"]

# 对 super_between factors 的邻居变量
between_vars = []
for f in between_factors:
    for v in f.adj_var_nodes:
        if v not in between_vars:
            between_vars.append(v)
"""

for it in range(1000):
    supergraph.synchronous_iteration()
    energy = supergraph.energy_map(include_priors=True, include_factors=True)
    print(f"Iter {it+1:03d} | Energy = {energy:.6f}")

Iter 001 | Energy = 8041.390395
Iter 002 | Energy = 8041.390395
Iter 003 | Energy = 8041.390395
Iter 004 | Energy = 8041.390395
Iter 005 | Energy = 8041.390395
Iter 006 | Energy = 8041.390395
Iter 007 | Energy = 8041.390395
Iter 008 | Energy = 8041.390395
Iter 009 | Energy = 8041.390395
Iter 010 | Energy = 8041.390395
Iter 011 | Energy = 8041.390395
Iter 012 | Energy = 8041.390395
Iter 013 | Energy = 8041.390395
Iter 014 | Energy = 8041.390395
Iter 015 | Energy = 8041.390395
Iter 016 | Energy = 8041.390395
Iter 017 | Energy = 8041.390395
Iter 018 | Energy = 8041.390395
Iter 019 | Energy = 8041.390395
Iter 020 | Energy = 8041.390395
Iter 021 | Energy = 8041.390395
Iter 022 | Energy = 8041.390395
Iter 023 | Energy = 8041.390395
Iter 024 | Energy = 8041.390395
Iter 025 | Energy = 8041.390395
Iter 026 | Energy = 8041.390395
Iter 027 | Energy = 8041.390395
Iter 028 | Energy = 8041.390395
Iter 029 | Energy = 8041.390395
Iter 030 | Energy = 8041.390395
Iter 031 | Energy = 8041.390395
Iter 032