In [4]:
import re
import dash
from dash import html, dcc, Input, Output, State, no_update
import dash_cytoscape as cyto
import numpy as np
from scipy.linalg import block_diag
from collections import defaultdict

# ==== GBP import ====
from gbp.gbp import *

# Dash application object for this notebook cell; the title string is stored on `app.title`.
app = dash.Dash(__name__)
app.title = "Factor Graph SVD Abs&Recovery"

# -----------------------
# SLAM-like base graph
# -----------------------
def make_slam_like_graph(
    N=100, step_size=25, loop_prob=0.05, loop_radius=50, prior_prop=0.0, seed=None
):
    """
    Generate a synthetic SE(2) pose graph (random-walk trajectory + loop closures).

    Parameters
    ----------
    N : number of poses; converted with int() once, so integral floats are accepted.
        The initial pose is always emitted, so N < 1 still yields a single node.
    step_size : distance travelled per step (converted with float()).
    loop_prob : probability that a candidate pair (i, j), j >= i + 5, is tested
        for a loop closure.
    loop_radius : a tested pair becomes a loop edge only if the two poses are
        closer than this distance in the plane.
    prior_prop : fraction of nodes that receive a "prior" edge. <= 0 anchors only
        node 0, >= 1 anchors every node, otherwise max(1, floor(prior_prop * N))
        nodes are drawn without replacement.
    seed : forwarded to np.random.default_rng (None means non-deterministic).

    Returns
    -------
    (nodes, edges) : lists of Cytoscape-style dicts. Node ids are the strings
        "0".."N-1"; each node carries "dim": 3, its heading in "theta" and its
        x/y in "position". Odometry and loop edges carry the exact relative
        pose "z" = [dx_local, dy_local, dtheta] expressed in the source frame;
        prior edges have target "prior" and no measurement.
    """
    # Normalise the numeric inputs once so every use below agrees. Previously the
    # strong-prior section used the raw N, which broke for float N.
    n_poses = int(N)
    step = float(step_size)
    p_loop = float(loop_prob)
    radius = float(loop_radius)

    # default_rng(None) is the unseeded generator, so no branching is needed.
    rng = np.random.default_rng(seed)

    # --- helpers ---
    def wrap_angle(a):
        # Map an angle into [-pi, pi] via atan2.
        return np.arctan2(np.sin(a), np.cos(a))

    def relpose_SE2(pose_i, pose_j):
        """Return z_ij = [dx_local, dy_local, dtheta], expressed in the frame of pose i."""
        xi, yi, thi = pose_i
        xj, yj, thj = pose_j
        Ri = np.array([[np.cos(thi), -np.sin(thi)],
                       [np.sin(thi),  np.cos(thi)]])
        dp = np.array([xj - xi, yj - yi])
        trans_local = Ri.T @ dp
        dth = wrap_angle(thj - thi)
        return np.array([trans_local[0], trans_local[1], dth], dtype=float)

    # --- SE(2) trajectory: heading performs a Gaussian random walk ---
    TURN_STD = 1  # rad, per-step standard deviation of the heading change
    x, y, th = 0.0, 0.0, 0.0
    poses = [(x, y, th)]
    for _ in range(1, n_poses):
        dth = rng.normal(0.0, TURN_STD)
        th = wrap_angle(th + dth)
        x += step * np.cos(th)
        y += step * np.sin(th)
        poses.append((x, y, th))

    # --- nodes (dim 3); "position" holds x, y for plotting ---
    nodes = []
    for i, (px, py, pth) in enumerate(poses):
        nodes.append({
            "data": {"id": f"{i}", "layer": 0, "dim": 3, "theta": float(pth), "num_base": 1},
            "position": {"x": float(px), "y": float(py)}
        })

    # --- sequential odometry edges with their local-frame measurement ---
    edges = []
    for i in range(n_poses - 1):
        z_ij = relpose_SE2(poses[i], poses[i + 1])
        edges.append({
            "data": {"source": f"{i}", "target": f"{i+1}", "kind": "odom", "z": z_ij.tolist()}
        })

    # --- loop-closure edges ---
    # The random draw happens before the distance test for every candidate pair;
    # keeping that order preserves the random stream for a given seed.
    for i in range(n_poses):
        xi, yi, _ = poses[i]
        for j in range(i + 5, n_poses):  # only pairs at least 5 steps apart
            if rng.random() < p_loop:
                xj, yj, _ = poses[j]
                if np.hypot(xi - xj, yi - yj) < radius:
                    z_ij = relpose_SE2(poses[i], poses[j])
                    edges.append({
                        "data": {"source": f"{i}", "target": f"{j}", "kind": "loop", "z": z_ij.tolist()}
                    })

    # --- strong priors: edges to the virtual "prior" node ---
    if prior_prop <= 0.0:
        strong_ids = {0}
    elif prior_prop >= 1.0:
        strong_ids = set(range(n_poses))
    else:
        k = max(1, int(np.floor(prior_prop * n_poses)))
        strong_ids = set(rng.choice(n_poses, size=k, replace=False).tolist())

    for i in strong_ids:
        edges.append({"data": {"source": f"{i}", "target": "prior", "kind": "prior"}})

    return nodes, edges


# -----------------------
# Grid aggregation
# -----------------------
def fuse_to_super_grid(prev_nodes, prev_edges, gx, gy, layer_idx):
    """
    Merge nodes into super nodes by binning their positions on a gx-by-gy grid.

    The grid spans the bounding box of all node positions. Each non-empty cell
    becomes one super node placed at the mean position of its members; its "dim"
    and "num_base" are the sums over the members. Super node ids are "0", "1", ...
    in order of first appearance of each cell.

    Edge rules: an edge between two different super nodes is kept once
    (undirected de-duplication); an edge inside one super node, or an edge to
    "prior", becomes a single super-node -> "prior" edge.

    Returns (super_nodes, super_edges, node_map) where node_map maps each
    previous node id to its super node id.
    """
    positions = np.array(
        [[node["position"]["x"], node["position"]["y"]] for node in prev_nodes],
        dtype=float,
    )
    xmin, ymin = positions.min(axis=0)
    xmax, ymax = positions.max(axis=0)

    # Cell size; fall back to 1.0 for a non-positive grid count or a flat extent.
    cell_w = (xmax - xmin) / gx if gx > 0 else 1.0
    cell_h = (ymax - ymin) / gy if gy > 0 else 1.0
    if cell_w == 0:
        cell_w = 1.0
    if cell_h == 0:
        cell_h = 1.0

    # cell id -> member indices, in order of first appearance
    members_by_cell = {}
    for node_index, node in enumerate(prev_nodes):
        px = node["position"]["x"]
        py = node["position"]["y"]
        # Points on the upper boundary are clamped into the last column/row.
        col = min(int((px - xmin) / cell_w), gx - 1)
        row = min(int((py - ymin) / cell_h), gy - 1)
        cell_id = col + row * gx
        if cell_id not in members_by_cell:
            members_by_cell[cell_id] = []
        members_by_cell[cell_id].append(node_index)

    super_nodes = []
    node_map = {}
    for member_indices in members_by_cell.values():
        centroid_x, centroid_y = positions[member_indices].mean(axis=0)
        total_dim = sum(prev_nodes[m]["data"]["dim"] for m in member_indices)
        total_base = sum(prev_nodes[m]["data"].get("num_base", 1) for m in member_indices)
        super_id = str(len(super_nodes))
        super_nodes.append({
            "data": {
                "id": super_id,
                "layer": layer_idx,
                "dim": int(max(1, total_dim)),
                "num_base": int(total_base),
            },
            "position": {"x": float(centroid_x), "y": float(centroid_y)},
        })
        for m in member_indices:
            node_map[prev_nodes[m]["data"]["id"]] = super_id

    super_edges = []
    seen = set()

    def _add_once(src, dst):
        # Undirected de-duplication: the key ignores endpoint order.
        key = tuple(sorted((src, dst)))
        if key not in seen:
            super_edges.append({"data": {"source": src, "target": dst}})
            seen.add(key)

    for edge in prev_edges:
        src = edge["data"]["source"]
        dst = edge["data"]["target"]
        if dst == "prior":
            _add_once(node_map[src], "prior")
            continue
        super_src, super_dst = node_map[src], node_map[dst]
        # Intra-group edges collapse into a prior on the group.
        _add_once(super_src, super_dst if super_src != super_dst else "prior")

    return super_nodes, super_edges, node_map

# -----------------------
# K-Means aggregation
# -----------------------
def fuse_to_super_kmeans(prev_nodes, prev_edges, k, layer_idx, max_iters=20, tol=1e-6, seed=0):
    """
    Merge nodes into k super nodes by k-means clustering of their positions.

    k is forced to at least 1 and at most the number of nodes. Centres start at
    k distinct nodes sampled with the seeded generator, then Lloyd iterations run
    until the largest centre shift is below tol or max_iters is reached. After
    every assignment, an empty cluster takes the first member of the currently
    largest cluster, so each cluster ends up non-empty.

    Super node ids are the cluster indices as strings. Each super node sits at
    the mean position of its members and sums their "dim" and "num_base".
    Edge rules: edges between different clusters are kept once (undirected);
    intra-cluster edges and "prior" edges become one cluster -> "prior" edge.

    Returns (super_nodes, super_edges, node_map).
    """
    positions = np.array(
        [[node["position"]["x"], node["position"]["y"]] for node in prev_nodes],
        dtype=float,
    )
    n_points = positions.shape[0]
    if k <= 0:
        k = 1
    k = min(k, n_points)
    rng = np.random.default_rng(seed)

    # Distinct starting points, so no two centres coincide at initialisation.
    centers = positions[rng.choice(n_points, size=k, replace=False)]

    def _assign_with_repair():
        """Nearest-centre labels for the current centres, with empty clusters refilled."""
        offsets = positions[:, None, :] - centers[None, :, :]
        labels = np.argmin((offsets ** 2).sum(axis=2), axis=1)
        sizes = np.bincount(labels, minlength=k)
        for empty_cluster in np.where(sizes == 0)[0]:
            donor = np.argmax(sizes)
            stolen = np.where(labels == donor)[0][0]
            labels[stolen] = empty_cluster
            sizes[donor] -= 1
            sizes[empty_cluster] += 1
        return labels

    # Lloyd iterations
    for _ in range(max_iters):
        assign = _assign_with_repair()
        largest_shift = 0.0
        for cluster in range(k):
            updated = positions[np.where(assign == cluster)[0]].mean(axis=0)
            shift = float(np.linalg.norm(updated - centers[cluster]))
            largest_shift = max(largest_shift, shift)
            centers[cluster] = updated
        if largest_shift < tol:
            break

    # One more assignment against the final centres.
    assign = _assign_with_repair()

    # ---------- Build the super graph ----------
    super_nodes = []
    node_map = {}
    for cluster in range(k):
        members = np.where(assign == cluster)[0]
        centroid_x, centroid_y = positions[members].mean(axis=0)
        total_dim = sum([prev_nodes[m]["data"]["dim"] for m in members])
        total_base = sum([prev_nodes[m]["data"].get("num_base", 1) for m in members])
        super_id = f"{cluster}"
        super_nodes.append({
            "data": {
                "id": super_id,
                "layer": layer_idx,
                "dim": int(max(1, total_dim)),
                "num_base": int(total_base),
            },
            "position": {"x": float(centroid_x), "y": float(centroid_y)},
        })
        for m in members:
            node_map[prev_nodes[m]["data"]["id"]] = super_id

    super_edges = []
    seen = set()

    def _add_once(key, src, dst):
        if key not in seen:
            super_edges.append({"data": {"source": src, "target": dst}})
            seen.add(key)

    for edge in prev_edges:
        src = edge["data"]["source"]
        dst = edge["data"]["target"]
        if dst == "prior":
            super_src = node_map[src]
            _add_once((super_src, "prior"), super_src, "prior")
            continue
        super_src, super_dst = node_map[src], node_map[dst]
        if super_src == super_dst:
            # Intra-cluster edge collapses into a prior on the cluster.
            _add_once((super_src, "prior"), super_src, "prior")
        else:
            _add_once(tuple(sorted((super_src, super_dst))), super_src, super_dst)

    return super_nodes, super_edges, node_map


def copy_to_abs(super_nodes, super_edges, layer_idx):
    """
    Clone a super layer into an abstraction layer.

    Node and edge ids have their first "s" (if any) replaced by "a". Each node
    keeps its "dim", "num_base" (default 1) and position, and is tagged with
    layer_idx. Edges keep only their renamed "source" and "target".

    Returns (abs_nodes, abs_edges).
    """
    def _to_abs_id(identifier):
        # Only the first occurrence of "s" is rewritten.
        return identifier.replace("s", "a", 1)

    abs_nodes = [
        {
            "data": {
                "id": _to_abs_id(node["data"]["id"]),
                "layer": layer_idx,
                "dim": node["data"]["dim"],
                "num_base": node["data"].get("num_base", 1),
            },
            "position": {"x": node["position"]["x"], "y": node["position"]["y"]},
        }
        for node in super_nodes
    ]

    abs_edges = [
        {
            "data": {
                "source": _to_abs_id(edge["data"]["source"]),
                "target": _to_abs_id(edge["data"]["target"]),
            }
        }
        for edge in super_edges
    ]

    return abs_nodes, abs_edges

# -----------------------
# Sequential merge (tail group absorbs remainder)
# -----------------------
def fuse_to_super_order(prev_nodes, prev_edges, k, layer_idx, tail_heavy=True):
    """
    Merge nodes into super nodes by chunking them in their current order.

    The nodes are cut into consecutive chunks of k nodes each (k is forced to at
    least 1 and at most len(prev_nodes)); when the count is not a multiple of k,
    one extra, shorter chunk holds the remaining nodes. So k is the chunk size,
    not the number of groups.

    NOTE(review): `tail_heavy` is accepted but never read by this function.

    Super node ids are the chunk indices as strings. Each super node sits at the
    mean position of its members and sums their "dim" and "num_base".
    Edge rules: edges between different chunks are kept once (undirected);
    intra-chunk edges and "prior" edges become one chunk -> "prior" edge.

    Returns (super_nodes, super_edges, node_map).
    """
    n_nodes = len(prev_nodes)
    if k <= 0:
        k = 1
    k = min(k, n_nodes)

    # Chunk sizes: as many full chunks of k as fit, then the leftover (if any).
    full_chunks, leftover = divmod(n_nodes, k)
    chunk_sizes = [k] * full_chunks
    if leftover > 0:
        chunk_sizes.append(leftover)

    # Consecutive index ranges, one per chunk.
    groups = []
    cursor = 0
    for size in chunk_sizes:
        groups.append(list(range(cursor, cursor + size)))
        cursor += size

    # ---- Build super_nodes & node_map ----
    positions = np.array(
        [[node["position"]["x"], node["position"]["y"]] for node in prev_nodes],
        dtype=float,
    )

    super_nodes = []
    node_map = {}
    for group_index, members in enumerate(groups):
        centroid_x, centroid_y = positions[members].mean(axis=0)
        total_dim = sum(prev_nodes[m]["data"]["dim"] for m in members)
        total_base = sum(prev_nodes[m]["data"].get("num_base", 1) for m in members)

        super_id = f"{group_index}"  # ids are strings throughout
        super_nodes.append({
            "data": {
                "id": super_id,
                "layer": layer_idx,
                "dim": int(max(1, total_dim)),
                "num_base": int(total_base),
            },
            "position": {"x": float(centroid_x), "y": float(centroid_y)},
        })
        for m in members:
            node_map[prev_nodes[m]["data"]["id"]] = super_id

    # ---- Super edges ----
    super_edges = []
    seen = set()

    def _add_once(src, dst):
        # Undirected de-duplication: the key ignores endpoint order.
        key = tuple(sorted((src, dst)))
        if key not in seen:
            super_edges.append({"data": {"source": src, "target": dst}})
            seen.add(key)

    for edge in prev_edges:
        src = edge["data"]["source"]
        dst = edge["data"]["target"]
        if dst == "prior":
            _add_once(node_map[src], "prior")
            continue
        super_src, super_dst = node_map[src], node_map[dst]
        # Intra-chunk pairwise edge becomes a prior on the chunk.
        _add_once(super_src, super_dst if super_src != super_dst else "prior")

    return super_nodes, super_edges, node_map


# -----------------------
# Tools
# -----------------------
def parse_layer_name(name):
    """
    Split a layer name into (kind, index).

    "base" gives ("base", 0); "super<digits>" / "abs<digits>" give the prefix
    and the integer suffix. Any other name also falls back to ("base", 0).
    """
    if name != "base":
        match = re.match(r"(super|abs)(\d+)$", name)
        if match:
            kind, digits = match.groups()
            return (kind, int(digits))
    return ("base", 0)

def highest_pair_idx(names):
    """Return the largest index among "super<N>" / "abs<N>" layer names, or 0 if there is none."""
    parsed = (parse_layer_name(layer_name) for layer_name in names)
    return max(
        (index for kind, index in parsed if kind in ("super", "abs")),
        default=0,
    )


# -----------------------
# Initialization & Boundary
# -----------------------
def init_layers(N=100, step_size=25, loop_prob=0.05, loop_radius=50, prior_prop=0.0, seed=None):
    """
    Create the initial layer stack: a one-element list holding the "base" layer.

    All arguments are forwarded to make_slam_like_graph; the resulting nodes and
    edges are stored under "nodes" and "edges".
    """
    nodes, edges = make_slam_like_graph(
        N=N,
        step_size=step_size,
        loop_prob=loop_prob,
        loop_radius=loop_radius,
        prior_prop=prior_prop,
        seed=seed,
    )
    base_layer = {"name": "base", "nodes": nodes, "edges": edges}
    return [base_layer]

# View size; ASPECT is its width / height ratio.
VIEW_W = 960
VIEW_H = 600
ASPECT = VIEW_W / VIEW_H
AXIS_PAD = 20.0

def adjust_bounds_to_aspect(xmin, xmax, ymin, ymax, aspect):
    """
    Grow a bounding box around its centre so that width / height == aspect.

    Only one axis is enlarged: the height when the box is too wide for the
    requested aspect, otherwise the width. The other axis is returned unchanged.

    A zero or negative extent on either axis is treated as an extent of 1
    centred on that axis' midpoint. Previously the substitute extent was used
    only for the ratio, so a degenerate axis could be returned unchanged and
    the resulting box had zero width or height.

    Returns (xmin, xmax, ymin, ymax).
    """
    cx = (xmin + xmax) / 2
    cy = (ymin + ymax) / 2
    dx = xmax - xmin
    dy = ymax - ymin

    # Replace degenerate extents by a unit extent around the centre, and make the
    # bounds themselves reflect it so the untouched axis is never zero-sized.
    if dx <= 0:
        dx = 1
        xmin, xmax = cx - 0.5, cx + 0.5
    if dy <= 0:
        dy = 1
        ymin, ymax = cy - 0.5, cy + 0.5

    if dx / dy > aspect:
        # Too wide: keep x, enlarge y.
        dy_new = dx / aspect
        return xmin, xmax, cy - dy_new / 2, cy + dy_new / 2

    # Too tall (or exact): keep y, enlarge x.
    dx_new = dy * aspect
    return cx - dx_new / 2, cx + dx_new / 2, ymin, ymax

def reset_global_bounds(base_nodes):
    """
    Recompute the module-level bounding-box globals from the base nodes.

    GLOBAL_{X,Y}{MIN,MAX} hold the raw extent of the node positions (all 0.0 when
    there are no nodes); the *_ADJ variants hold the same box after
    adjust_bounds_to_aspect(..., ASPECT).
    """
    global GLOBAL_XMIN, GLOBAL_XMAX, GLOBAL_YMIN, GLOBAL_YMAX
    global GLOBAL_XMIN_ADJ, GLOBAL_XMAX_ADJ, GLOBAL_YMIN_ADJ, GLOBAL_YMAX_ADJ

    xs = [node["position"]["x"] for node in base_nodes]
    ys = [node["position"]["y"] for node in base_nodes]
    # An empty node list collapses the box to the origin.
    if not xs:
        xs = [0.0]
    if not ys:
        ys = [0.0]

    GLOBAL_XMIN = min(xs)
    GLOBAL_XMAX = max(xs)
    GLOBAL_YMIN = min(ys)
    GLOBAL_YMAX = max(ys)

    adjusted = adjust_bounds_to_aspect(
        GLOBAL_XMIN, GLOBAL_XMAX, GLOBAL_YMIN, GLOBAL_YMAX, ASPECT
    )
    GLOBAL_XMIN_ADJ, GLOBAL_XMAX_ADJ, GLOBAL_YMIN_ADJ, GLOBAL_YMAX_ADJ = adjusted

# ==== Global State ====
# Layer stack, starting with only the base layer built from the default parameters.
layers = init_layers()
pair_idx = 0
# Initialise the GLOBAL_* bounds from the base layer's node positions.
reset_global_bounds(layers[0]["nodes"])
gbp_graph = None

# -----------------------
# GBP Graph Construction
# -----------------------
def build_noisy_pose_graph(
    nodes,
    edges,
    prior_sigma: float | tuple | np.ndarray = 1.0,   # A scalar expands to [s, s, s*THETA_RATIO]
    odom_sigma:  float | tuple | np.ndarray = 1.0,   # Same as above
    loop_sigma:  float | tuple | np.ndarray = 1.0,   # Same as above
    tiny_prior: float = 1e-10,
    seed=None,
):
    """
    Build an SE(2) pose factor graph (x, y, theta) with noisy measurements.

    Parameters
    ----------
    nodes : node dicts with "position" {"x", "y"} and optional data["theta"]
        (default 0.0); these define each variable's ground truth (GT).
    edges : edge dicts. An edge whose target is "prior" becomes a unary factor on
        its source; any other edge becomes a binary SE(2) "between" factor.
        Source/target ids are converted with int() and used as indices into the
        variable list. A binary edge may carry its exact measurement in
        data["z"]; otherwise it is computed from the two GT poses.
    prior_sigma, odom_sigma, loop_sigma : measurement standard deviations. A
        scalar s expands to [s, s, s * 0.01]; otherwise a length-3 vector is
        required. Edges with kind "loop" use loop_sigma, every other binary edge
        uses odom_sigma.
    tiny_prior : scale of the weak prior placed on every variable.
    seed : forwarded to np.random.default_rng (None means non-deterministic).

    Behaviour
    ---------
    A tight anchor factor on variable 0 (at its GT) is always added. Initial
    means are obtained by chaining the noisy sequential "odom" measurements from
    variable 0; a variable with a prior edge is reset to its noisy prior
    measurement, and a missing odometry link falls back to GT. All factors are
    linearised at those means before the graph is returned.

    Returns
    -------
    FactorGraph with var_nodes, factors and their counts filled in.
    """
    # default_rng(None) is the unseeded generator, so no branching is needed.
    rng = np.random.default_rng(seed)

    THETA_RATIO = 0.01  # A scalar sigma expands to [s, s, s*0.01]

    def _sigma_vec(s):
        v = np.array(s, dtype=float).ravel()
        if v.size == 1:
            s = float(v.item())
            return np.array([s, s, s * THETA_RATIO], dtype=float)
        assert v.size == 3, "If sigma is not scalar, it must be length-3 (x,y,theta)."
        return v

    def wrap_angle(a):
        return np.arctan2(np.sin(a), np.cos(a))

    def relpose_SE2(pose_i, pose_j):
        """z_ij = [dx_local, dy_local, dtheta], defined in i's coordinate frame."""
        xi, yi, thi = pose_i
        xj, yj, thj = pose_j
        c, s = np.cos(thi), np.sin(thi)
        RT = np.array([[ c, s],
                       [-s, c]])    # R(thi)^T
        dp = np.array([xj - xi, yj - yi])
        trans_local = RT @ dp
        dth = wrap_angle(thj - thi)
        return np.array([trans_local[0], trans_local[1], dth], dtype=float)

    def compose_SE2(pose_i, z_ij):
        """Given local measurement z_ij, propagate from pose_i to pose_j."""
        xi, yi, thi = pose_i
        dx, dy, dth = z_ij
        c, s = np.cos(thi), np.sin(thi)
        t_global = np.array([c*dx - s*dy, s*dx + c*dy])
        xj = xi + t_global[0]
        yj = yi + t_global[1]
        thj = wrap_angle(thi + dth)
        return np.array([xj, yj, thj], dtype=float)

    # ---------- Graph & variables ----------
    fg = FactorGraph(nonlinear_factors=True, eta_damping=0)

    var_nodes = []
    for i, n in enumerate(nodes):
        v = VariableNode(i, dofs=3)
        th = float(n["data"].get("theta", 0.0))  # radians; 0.0 when the node has no heading
        v.GT = np.array([n["position"]["x"], n["position"]["y"], th], dtype=float)
        # Placeholder; the real initial mean is set by the propagation pass below.
        v.mu = np.zeros(3, dtype=float)
        var_nodes.append(v)

    fg.var_nodes = var_nodes
    fg.n_var_nodes = len(var_nodes)

    # ---------- Measurement model (SE(2) between) ----------
    def meas_fn_se2_between(xij, *args):
        # xij = [xi, yi, thi, xj, yj, thj]
        xi, yi, thi, xj, yj, thj = xij
        c, s = np.cos(thi), np.sin(thi)
        RT = np.array([[ c, s],
                       [-s, c]])
        dp = np.array([xj - xi, yj - yi])
        r  = RT @ dp
        dth = wrap_angle(thj - thi)
        return np.array([r[0], r[1], dth], dtype=float)

    def jac_fn_se2_between(xij, *args):
        # J: 3x6 w.r.t [xi, yi, thi, xj, yj, thj]
        xi, yi, thi, xj, yj, thj = xij
        c, s = np.cos(thi), np.sin(thi)
        RT = np.array([[ c, s],
                       [-s, c]])
        dp = np.array([xj - xi, yj - yi])
        r  = RT @ dp                          # [rx, ry]
        dr_dthi = np.array([ r[1], -r[0] ])   # d(R^T dp)/d(thi)
        J = np.zeros((3, 6), dtype=float)
        # wrt i
        J[0:2, 0:2] = -RT
        J[0:2, 2]   = dr_dthi
        # wrt j
        J[0:2, 3:5] = RT
        # angle row
        J[2, 2] = -1.0
        J[2, 5] =  1.0
        return J

    def meas_fn_unary(x, *args):
        return x

    def jac_fn_unary(x, *args):
        return np.eye(3, dtype=float)

    # ---------- Information matrices ----------
    prior_sigma_vec = _sigma_vec(prior_sigma)
    odom_sigma_vec  = _sigma_vec(odom_sigma)
    loop_sigma_vec  = _sigma_vec(loop_sigma)

    Lambda_prior = np.diag(1.0 / (prior_sigma_vec**2))
    Lambda_odom  = np.diag(1.0 / (odom_sigma_vec**2))
    Lambda_loop  = np.diag(1.0 / (loop_sigma_vec**2))

    # ---------- Factors; create noisy measurements first (linearised at the end) ----------
    odom_meas = {}   # (i, j) -> z_noisy, sequential odometry only
    prior_meas = {}  # i -> z_noisy (unary)
    factors = []
    fid = 0

    # Strong anchor on variable 0 fixes the global reference frame.
    v0 = var_nodes[0]
    z_anchor = v0.GT.copy()
    Lambda_anchor = np.diag(1.0 / (np.array([1e-3, 1e-3, 1e-5])**2))
    f0 = Factor(fid, [v0], z_anchor, Lambda_anchor, meas_fn_unary, jac_fn_unary)
    f0.type = "prior"
    f0.compute_factor(linpoint=z_anchor, update_self=True)
    factors.append(f0)
    v0.adj_factors.append(f0)
    fid += 1

    for e in edges:
        src = e["data"]["source"]
        dst = e["data"]["target"]

        if dst != "prior":
            i, j = int(src), int(dst)

            # Noise-free relative pose: stored on the edge, or derived from GT.
            if "z" in e["data"]:
                z = np.array(e["data"]["z"], dtype=float).ravel()
            else:
                z = relpose_SE2(var_nodes[i].GT, var_nodes[j].GT)

            kind = e["data"].get("kind", "between")

            # Noise model per edge kind; anything that is not "loop" is treated as odometry.
            if kind == "loop":
                noise_vec = rng.normal(0.0, loop_sigma_vec, size=3)
                this_Lambda = Lambda_loop
            else:
                noise_vec = rng.normal(0.0, odom_sigma_vec, size=3)
                this_Lambda = Lambda_odom

            z_noisy = z.copy()
            z_noisy[:2] += noise_vec[:2]
            z_noisy[2]  = wrap_angle(z_noisy[2] + noise_vec[2])

            # Only sequential odometry is used for the forward initialisation of mu.
            if kind == "odom" and (j == i + 1):
                odom_meas[(i, j)] = z_noisy

            vi, vj = var_nodes[i], var_nodes[j]
            f = Factor(fid, [vi, vj], z_noisy, this_Lambda, meas_fn_se2_between, jac_fn_se2_between)
            f.type = kind
            factors.append(f)
            vi.adj_factors.append(f)
            vj.adj_factors.append(f)
            fid += 1

        else:
            # Prior edge: noisy absolute measurement of this variable.
            i = int(src)
            z = var_nodes[i].GT.copy()
            noise = rng.normal(0.0, prior_sigma_vec, size=3)
            z[:2] += noise[:2]
            z[2]   = wrap_angle(z[2] + noise[2])
            prior_meas[i] = z

            vi = var_nodes[i]
            f = Factor(fid, [vi], z, Lambda_prior, meas_fn_unary, jac_fn_unary)
            f.type = "prior"
            factors.append(f)
            vi.adj_factors.append(f)
            fid += 1

    # ---------- Initialise mu by forward propagation from variable 0 ----------
    N = len(var_nodes)
    # Copy so that mu never shares memory with GT (the other initialisations
    # below copy as well); an in-place update of mu must not alter the ground truth.
    var_nodes[0].mu = var_nodes[0].GT.copy()

    for i in range(N - 1):
        if (i, i+1) in odom_meas:
            var_nodes[i+1].mu = compose_SE2(var_nodes[i].mu, odom_meas[(i, i+1)])
        else:
            # No sequential odometry link: fall back to GT.
            var_nodes[i+1].mu = var_nodes[i+1].GT.copy()

        # A prior measurement, when present, replaces the propagated value.
        if (i + 1) in prior_meas:
            var_nodes[i+1].mu = prior_meas[i + 1].copy()

    # ---------- Very weak prior on every variable to avoid singular systems ----------
    Lam_weak = tiny_prior * np.diag([1.0, 1.0, 100.0])  # angle dimension slightly stiffer
    for v in var_nodes:
        # Each variable gets its own matrix; sharing one array across all
        # variables would let an in-place update on one prior affect the rest.
        v.prior.lam = Lam_weak.copy()
        v.prior.eta = Lam_weak @ v.mu  # weak prior centred on the initial mean
        v.Sigma = 1/tiny_prior * np.diag([1.0, 1.0, 0.01])

    # ---------- Linearise all factors at the initial means ----------
    for f in factors:
        lin = np.concatenate([vn.mu for vn in f.adj_var_nodes]) if f.adj_var_nodes else np.array([])
        f.compute_factor(linpoint=lin, update_self=True)

    fg.factors = factors
    fg.n_factor_nodes = len(factors)
    return fg


def build_super_graph(layers, eta_damping=0.2):
    """
    Build the super-level factor graph from the previous layer's graph and the newest grouping.

    Reads:
      layers[-2]["graph"]    : already-built lower-level FactorGraph (variables and factors).
      layers[-1]["nodes"]    : super node dicts (data["id"] is the super id).
      layers[-1]["edges"]    : super edge dicts; target "prior" marks a unary super factor.
      layers[-1]["node_map"] : { lower-level node id (str) -> super node id (str) }.
                               Keys are converted with int() and matched against
                               VariableNode.variableID, so they must be int-parsable.

    Each super variable stacks the mu of its member variables (block-diagonal Sigma).
    A "prior" super edge produces one factor stacking every lower-level factor whose
    variables all lie inside the group; any other super edge produces one factor
    stacking the lower-level two-variable factors that cross the two groups.

    NOTE(review): lower-level factors are carried over only through these super edges;
    in-group factors of a group that has no "prior" super edge are not represented.

    Returns the new FactorGraph (linear, with the given eta_damping).
    """
    # ---------- Extract base & super ----------
    base_graph = layers[-2]["graph"]
    super_nodes = layers[-1]["nodes"]
    super_edges = layers[-1]["edges"]
    node_map    = layers[-1]["node_map"]   # lower-level id (str) -> super id (str)

    # base: id(int) -> VariableNode, handy to query dofs and mu
    id2var = {vn.variableID: vn for vn in base_graph.var_nodes}

    # ---------- super_id -> [base_id(int)] ----------
    super_groups = {}
    for b_str, s_id in node_map.items():
        b_int = int(b_str)
        super_groups.setdefault(s_id, []).append(b_int)


    # ---------- For each super group, build a (start, dofs) table ----------
    # local_idx[sid][bid] = (start, dofs), total_dofs[sid] = sum(dofs)
    local_idx   = {}
    total_dofs  = {}
    for sid, group in super_groups.items():
        off = 0
        local_idx[sid] = {}
        for bid in group:
            d = id2var[bid].dofs
            local_idx[sid][bid] = (off, d)
            off += d
        total_dofs[sid] = off


    # ---------- Create super VariableNodes ----------
    fg = FactorGraph(nonlinear_factors=False, eta_damping=eta_damping)

    super_var_nodes = {}
    for i, sn in enumerate(super_nodes):
        sid = sn["data"]["id"]
        dofs = total_dofs.get(sid, 0)

        v = VariableNode(i, dofs=dofs)
        gt_vec = np.zeros(dofs)
        mu_blocks = []
        Sigma_blocks = []
        for bid, (st, d) in local_idx[sid].items():
            # === Stack base GT (zeros when missing or of unexpected length) ===
            gt_base = getattr(id2var[bid], "GT", None)
            if gt_base is None or len(gt_base) != d:
                gt_base = np.zeros(d)
            gt_vec[st:st+d] = gt_base

            # === Stack base belief ===
            vb = id2var[bid]
            mu_blocks.append(vb.mu)
            Sigma_blocks.append(vb.Sigma)

        super_var_nodes[sid] = v
        v.GT = gt_vec

        mu_super = np.concatenate(mu_blocks) if mu_blocks else np.zeros(dofs)
        Sigma_super = block_diag(*Sigma_blocks) if Sigma_blocks else np.eye(dofs)
        lam = np.linalg.inv(Sigma_super)
        eta = lam @ mu_super
        v.mu = mu_super
        v.Sigma = Sigma_super
        v.belief = NdimGaussian(dofs, eta, lam)
        # Prior is the stacked belief scaled down by 1e-10.
        v.prior.lam = 1e-10 * lam
        v.prior.eta = 1e-10 * eta

        fg.var_nodes.append(v)

    fg.n_var_nodes = len(fg.var_nodes)

    # ---------- Utility: assemble a group's linpoint (using base belief means) ----------
    def make_linpoint_for_group(sid):
        x = np.zeros(total_dofs[sid])
        for bid, (st, d) in local_idx[sid].items():
            mu = getattr(id2var[bid], "mu", None)
            if mu is None or len(mu) != d:
                mu = np.zeros(d)
            x[st:st+d] = mu
        return x

    # ---------- 3) super prior (in-group unary + in-group binary) ----------
    def make_super_prior_factor(sid, base_factors):
        # Returns (meas_fn, jac_fn, z, z_lambda) for the stacked in-group factors of `sid`.
        group = super_groups[sid]
        idx_map = local_idx[sid]
        ncols = total_dofs[sid]

        # Select factors whose variables are all within the group (unary or binary)
        in_group = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if all(vid in group for vid in vids):
                in_group.append(f)

        def meas_fn_super_prior(x_super, *args):
            meas_fn = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Assemble this factor's local x
                x_loc_list = []
                for vid in vids:
                    st, d = idx_map[vid]
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn) if meas_fn else np.zeros(0)

        def jac_fn_super_prior(x_super, *args):
            Jrows = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Build this factor's local x for (potentially) nonlinear Jacobian
                x_loc_list = []
                dims = []
                for vid in vids:
                    st, d = idx_map[vid]
                    dims.append(d)
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)

                Jloc = f.jac_fn(x_loc)

                # Map Jloc column blocks back to the super variable columns
                row = np.zeros((Jloc.shape[0], ncols))
                c0 = 0
                for vid, d in zip(vids, dims):
                    st, _ = idx_map[vid]
                    row[:, st:st+d] = Jloc[:, c0:c0+d]
                    c0 += d

                Jrows.append(row)
            return np.vstack(Jrows) if Jrows else np.zeros((0, ncols))

        # z_super: concatenate each base factor's z (same order as the rows above).
        # np.concatenate raises on an empty list, so at least one in-group factor is required.
        z_list = [f.measurement for f in in_group]
        z_lambda_list = [f.measurement_lambda for f in in_group]
        z_super = np.concatenate(z_list)
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_prior, jac_fn_super_prior, z_super, z_super_lambda

    # ---------- 4) super between (cross-group binary) ----------
    def make_super_between_factor(sidA, sidB, base_factors):
        # Returns (meas_fn, jac_fn, z, z_lambda) for the stacked factors crossing groups A and B.
        # The stacked input is ordered [x_A, x_B].
        groupA, groupB = super_groups[sidA], super_groups[sidB]
        idxA, idxB = local_idx[sidA], local_idx[sidB]
        nA, nB = total_dofs[sidA], total_dofs[sidB]

        cross = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if len(vids) != 2:
                continue
            i, j = vids
            # One side in A, the other side in B
            if (i in groupA and j in groupB) or (i in groupB and j in groupA):
                cross.append(f)


        def meas_fn_super_between(xAB, *args):
            xA, xB = xAB[:nA], xAB[nA:]
            meas_fn = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                # Keep the base factor's own (i, j) argument order, whichever group i is in.
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                x_loc = np.concatenate([xi, xj])
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn)

        def jac_fn_super_between(xAB, *args):
            xA, xB = xAB[:nA], xAB[nA:]
            Jrows = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                    left_start, right_start = si, nA + sj
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                    left_start, right_start = nA + si, sj
                x_loc = np.concatenate([xi, xj])
                Jloc = f.jac_fn(x_loc)

                # Scatter the local Jacobian's two column blocks into the [x_A, x_B] layout.
                row = np.zeros((Jloc.shape[0], nA + nB))
                row[:, left_start:left_start+di]   = Jloc[:, :di]
                row[:, right_start:right_start+dj] = Jloc[:, di:di+dj]

                Jrows.append(row)

            return np.vstack(Jrows)

        z_list = [f.measurement for f in cross]
        z_lambda_list = [f.measurement_lambda for f in cross]
        z_super = np.concatenate(z_list)
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_between, jac_fn_super_between, z_super, z_super_lambda


    # ---------- Create one super factor per super edge ----------
    for e in super_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v == "prior":
            meas_fn, jac_fn, z, z_lambda = make_super_prior_factor(u, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_prior"
            lin0 = make_linpoint_for_group(u)
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)

        else:
            meas_fn, jac_fn, z, z_lambda = make_super_between_factor(u, v, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u], super_var_nodes[v]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_between"
            # Linearisation point ordered [group u, group v], matching the [x_A, x_B] layout.
            lin0 = np.concatenate([make_linpoint_for_group(u), make_linpoint_for_group(v)])
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            super_var_nodes[v].adj_factors.append(f)


    fg.n_factor_nodes = len(fg.factors)
    return fg


def build_abs_graph(
    layers,
    r_reduced = 2,
    eta_damping=0.2):
    """
    Build the abstraction ("abs") factor graph on top of the super graph in layers[-2].

    Every super variable x_s is replaced by a lower-dimensional variable x_a through the
    affine lift x_s ≈ B x_a + k, where the columns of B are the eigenvectors of the super
    covariance belonging to its r largest eigenvalues and k = mu_s - B Bᵀ mu_s. Each unary /
    binary super factor is wrapped so that it is evaluated at the lifted state and its
    Jacobian is chained with B. Super factors with any other arity are skipped.

    Parameters
    ----------
    layers : list[dict]
        Only layers[-2]["graph"] (the super FactorGraph) is read.
    r_reduced : int
        Target dimension; variables whose dofs are already <= r_reduced keep their dofs.
    eta_damping : float
        Passed to the FactorGraph constructor.

    Returns
    -------
    abs_fg : FactorGraph
        The abstraction graph (created with nonlinear_factors=False).
    Bs : dict
        variableID -> B, shape (d_super, r).
    ks : dict
        variableID -> k, shape (d_super,).
    k2s : dict
        Always empty here: the residual-covariance assignment is commented out.
    """

    abs_var_nodes = {}
    Bs = {}
    ks = {}
    k2s = {}  # never populated below; returned only to keep the 4-tuple interface

    # === 1. Build Abstraction Variables ===
    abs_fg = FactorGraph(nonlinear_factors=False, eta_damping=eta_damping)
    sup_fg = layers[-2]["graph"]

    for sn in sup_fg.var_nodes:
        if sn.dofs <= r_reduced:
            r = sn.dofs  # No reduction if dofs already <= r
        else:
            r = r_reduced

        sid = sn.variableID
        varis_sup_mu = sn.mu
        varis_sup_sigma = sn.Sigma
        
        # Step 1: Eigen decomposition of the covariance matrix
        eigvals, eigvecs = np.linalg.eigh(varis_sup_sigma)

        # Step 2: Sort eigenvalues and eigenvectors in descending order of eigenvalues
        idx = np.argsort(eigvals)[::-1]      # Get indices of sorted eigenvalues (largest first)
        eigvals = eigvals[idx]               # Reorder eigenvalues
        eigvecs = eigvecs[:, idx]            # Reorder corresponding eigenvectors

        # Step 3: Select the top-r eigenvectors to form the projection matrix (principal subspace)
        B_reduced = eigvecs[:, :r]                 # B_reduced: shape (sup_dof, r), lifts r-dim coordinates to sup_dof
        Bs[sid] = B_reduced                        # Store the projection matrix for this variable

        # Step 4: Project the mean and covariance onto the r-dimensional subspace
        varis_abs_mu = B_reduced.T @ varis_sup_mu          # Projected mean: shape (r,)
        varis_abs_sigma = B_reduced.T @ varis_sup_sigma @ B_reduced  # Projected covariance: shape (r, r)
        ks[sid] = varis_sup_mu - B_reduced @ varis_abs_mu  # Mean offset: the part of mu_s outside span(B)
        #k2s[sid] = varis_sup_sigma - B_reduced @ varis_abs_sigma @ B_reduced.T  # Residual covariance

        varis_abs_lam = np.linalg.inv(varis_abs_sigma)  # Inverse covariance (precision matrix): shape (r, r)
        varis_abs_eta = varis_abs_lam @ varis_abs_mu  # Information vector: shape (r,)

        v = VariableNode(sid, dofs=r)
        v.GT = sn.GT
        # Prior is the projected belief scaled by 1e-10, i.e. numerically negligible
        v.prior.lam = 1e-10 * varis_abs_lam
        v.prior.eta = 1e-10 * varis_abs_eta
        v.mu = varis_abs_mu
        v.Sigma = varis_abs_sigma
        v.belief = NdimGaussian(r, varis_abs_eta, varis_abs_lam)

        abs_var_nodes[sid] = v
        abs_fg.var_nodes.append(v)
    abs_fg.n_var_nodes = len(abs_fg.var_nodes)


    # === 2. Abstract Prior ===
    def make_abs_prior_factor(sup_factor):
        # Wrap a unary super factor: evaluate it at the lifted state B x_abs + k.
        abs_id = sup_factor.adj_var_nodes[0].variableID
        B = Bs[abs_id]
        k = ks[abs_id]

        def meas_fn_abs_prior(x_abs, *args):
            return sup_factor.meas_fn(B @ x_abs + k)
        
        def jac_fn_abs_prior(x_abs, *args):
            # Chain rule: d h(B x + k) / dx = J_sup @ B
            return sup_factor.jac_fn(B @ x_abs + k) @ B

        return meas_fn_abs_prior, jac_fn_abs_prior, sup_factor.measurement, sup_factor.measurement_lambda
    


    # === 3. Abstract Between ===
    def make_abs_between_factor(sup_factor):
        # Wrap a binary super factor; the stacked abs state is [x_i (ni entries), x_j (rest)].
        # Note: the inner closures are named "*_super_between" but operate on the abs state.
        vids = [v.variableID for v in sup_factor.adj_var_nodes]
        i, j = vids # two variable IDs
        ni = abs_var_nodes[i].dofs
        Bi, Bj = Bs[i], Bs[j]
        ki, kj = ks[i], ks[j]                       
    

        def meas_fn_super_between(xij, *args):
            xi, xj = xij[:ni], xij[ni:]
            return sup_factor.meas_fn(np.concatenate([Bi @ xi + ki, Bj @ xj + kj]))

        def jac_fn_super_between(xij, *args):
            xi, xj = xij[:ni], xij[ni:]
            J_sup = sup_factor.jac_fn(np.concatenate([Bi @ xi + ki, Bj @ xj + kj]))
            J_abs = np.zeros((J_sup.shape[0], ni + xj.shape[0]))
            # Bi.shape[0] is the super dimension of variable i, i.e. the column split of J_sup
            J_abs[:, :ni] = J_sup[:, :Bi.shape[0]] @ Bi
            J_abs[:, ni:] = J_sup[:, Bi.shape[0]:] @ Bj
            return J_abs
        
        return meas_fn_super_between, jac_fn_super_between, sup_factor.measurement, sup_factor.measurement_lambda
    

    # === 4. Mirror every unary / binary super factor; abs factors reuse the super factorID ===
    for f in sup_fg.factors:
        if len(f.adj_var_nodes) == 1:
            meas_fn, jac_fn, z, z_lambda = make_abs_prior_factor(f)
            v = abs_var_nodes[f.adj_var_nodes[0].variableID]
            abs_f = Factor(f.factorID, [v], z, z_lambda, meas_fn, jac_fn)
            abs_f.type = "abs_prior"
            abs_f.adj_beliefs = [v.belief]

            # Linearize at the projected mean
            lin0 = v.mu
            abs_f.compute_factor(linpoint=lin0, update_self=True)

            abs_fg.factors.append(abs_f)
            v.adj_factors.append(abs_f)

        elif len(f.adj_var_nodes) == 2:
            meas_fn, jac_fn, z, z_lambda = make_abs_between_factor(f)
            i, j = [v.variableID for v in f.adj_var_nodes]
            vi, vj = abs_var_nodes[i], abs_var_nodes[j]
            abs_f = Factor(f.factorID, [vi, vj], z, z_lambda, meas_fn, jac_fn)
            abs_f.type = "abs_between"
            abs_f.adj_beliefs = [vi.belief, vj.belief]

            # Linearize at the stacked projected means
            lin0 = np.concatenate([vi.mu, vj.mu])
            abs_f.compute_factor(linpoint=lin0, update_self=True)

            abs_fg.factors.append(abs_f)
            vi.adj_factors.append(abs_f)
            vj.adj_factors.append(abs_f)

    abs_fg.n_factor_nodes = len(abs_fg.factors)


    return abs_fg, Bs, ks, k2s



def update_super_graph_linearized(layers, eta_damping=0.2):
    """
    Rebuild the super graph from the base graph in layers[-2] and the super-grouping in
    layers[-1], with every super factor linearized once around the current base means.

    Each super variable stacks the base variables mapped to it (in node_map insertion order):
    its mean is the concatenation of the base means and its covariance the block-diagonal of
    the base covariances. For every super factor the Jacobian J is evaluated at the
    linearization point lin0 assembled from the base means (the argument passed to the
    Jacobian closure is ignored), and the measurement function is the first-order expansion
    h(lin0) + J (x - lin0).

    Parameters
    ----------
    layers : list[dict]
        layers[-2]["graph"]: an already-built base graph (with unary/binary factors).
        layers[-1]["nodes"], layers[-1]["edges"]: element dicts of the super layer; node ids
        are read from ["data"]["id"], edge endpoints from ["data"]["source"/"target"].
        An edge whose target is the string "prior" produces a super prior factor.
        layers[-1]["node_map"]: { base_node_id (str that int() can parse) -> super_node_id (str) }
    eta_damping : float
        Passed to the FactorGraph constructor.

    Returns
    -------
    FactorGraph
        The newly built super graph (created with nonlinear_factors=False).
    """
    # ---------- Extract base & super ----------
    base_graph = layers[-2]["graph"]
    super_nodes = layers[-1]["nodes"]
    super_edges = layers[-1]["edges"]
    node_map    = layers[-1]["node_map"]   # base id (str) -> super id (str)

    # base: id(int) -> VariableNode, handy to query dofs and mu
    id2var = {vn.variableID: vn for vn in base_graph.var_nodes}

    # ---------- super_id -> [base_id(int)] ----------
    super_groups = {}
    for b_str, s_id in node_map.items():
        b_int = int(b_str)
        super_groups.setdefault(s_id, []).append(b_int)


    # ---------- For each super group, build a (start, dofs) table ----------
    # local_idx[sid][bid] = (start, dofs), total_dofs[sid] = sum(dofs)
    local_idx   = {}
    total_dofs  = {}
    for sid, group in super_groups.items():
        off = 0
        local_idx[sid] = {}
        for bid in group:
            d = id2var[bid].dofs
            local_idx[sid][bid] = (off, d)
            off += d
        total_dofs[sid] = off


    # ---------- Create super VariableNodes ----------
    fg = FactorGraph(nonlinear_factors=False, eta_damping=eta_damping)


    super_var_nodes = {}
    for i, sn in enumerate(super_nodes):
        sid = sn["data"]["id"]
        dofs = total_dofs.get(sid, 0)

        # variableID is the position in the super node list, not the string id
        v = VariableNode(i, dofs=dofs)
        gt_vec = np.zeros(dofs)
        mu_blocks = []
        Sigma_blocks = []
        # NOTE(review): local_idx[sid] raises KeyError for a super node without any mapped
        # base node, although total_dofs.get(sid, 0) above tolerates that case — confirm
        # every super node has at least one child.
        for bid, (st, d) in local_idx[sid].items():
            # === Stack base GT ===
            gt_base = getattr(id2var[bid], "GT", None)
            if gt_base is None or len(gt_base) != d:
                gt_base = np.zeros(d)
            gt_vec[st:st+d] = gt_base

            # === Stack base belief ===
            vb = id2var[bid]
            mu_blocks.append(vb.mu)
            Sigma_blocks.append(vb.Sigma)

        super_var_nodes[sid] = v
        v.GT = gt_vec

        mu_super = np.concatenate(mu_blocks) if mu_blocks else np.zeros(dofs)
        Sigma_super = block_diag(*Sigma_blocks) if Sigma_blocks else np.eye(dofs)
        lam = np.linalg.inv(Sigma_super)
        eta = lam @ mu_super
        v.mu = mu_super
        v.Sigma = Sigma_super
        v.belief = NdimGaussian(dofs, eta, lam)
        # Prior is the stacked belief scaled by 1e-10, i.e. numerically negligible
        v.prior.lam = 1e-10 * lam
        v.prior.eta = 1e-10 * eta

        fg.var_nodes.append(v)


    fg.n_var_nodes = len(fg.var_nodes)

    # ---------- Utility: assemble a group's linpoint (using base belief means) ----------
    def make_linpoint_for_group(sid):
        # Missing or wrongly sized base means fall back to zeros for that block.
        x = np.zeros(total_dofs[sid])
        for bid, (st, d) in local_idx[sid].items():
            mu = getattr(id2var[bid], "mu", None)
            if mu is None or len(mu) != d:
                mu = np.zeros(d)
            x[st:st+d] = mu
        return x

    # ---------- 3) super prior (in-group unary + in-group binary) ----------
    def make_super_prior_factor(sid, lin0, base_factors):
        # Returns (meas_fn, jac_fn, z, z_lambda) for the stacked in-group factors of `sid`,
        # linearized at lin0 (the group's stacked state).
        group = super_groups[sid]
        idx_map = local_idx[sid]
        ncols = total_dofs[sid]

        # Select factors whose variables are all within the group (unary or binary)
        in_group = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if all(vid in group for vid in vids):
                in_group.append(f)

        def jac_fn_super_prior(x_super, *args):
            # x_super is intentionally unused: the Jacobian is always evaluated at lin0.
            Jrows = []
            for f in in_group:
                vids = [v.variableID for v in f.adj_var_nodes]
                # Build this factor's local x for (potentially) nonlinear Jacobian
                x_loc_list = []
                dims = []
                for vid in vids:
                    st, d = idx_map[vid]
                    dims.append(d)
                    x_loc_list.append(lin0[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)

                Jloc = f.jac_fn(x_loc)
                # Map Jloc column blocks back to the super variable columns
                row = np.zeros((Jloc.shape[0], ncols))
                c0 = 0
                for vid, d in zip(vids, dims):
                    st, _ = idx_map[vid]
                    row[:, st:st+d] = Jloc[:, c0:c0+d]
                    c0 += d

                Jrows.append(row)
            return np.vstack(Jrows) if Jrows else np.zeros((0, ncols))

        J = jac_fn_super_prior(x_super=lin0)
        # h(lin0): stacked base measurement predictions at the linearization point.
        # The name `meas_fn` is reused below for the resulting vector.
        meas_fn = []
        for f in in_group:
            vids = [v.variableID for v in f.adj_var_nodes]
            # Build this factor's local x for the measurement prediction
            x_loc_list = []
            dims = []
            for vid in vids:
                st, d = idx_map[vid]
                dims.append(d)
                x_loc_list.append(lin0[st:st+d])
            x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)
            meas_fn.append(f.meas_fn(x_loc))
        # np.concatenate raises ValueError on an empty list, i.e. when no factor is in the group
        meas_fn = np.concatenate(meas_fn) 
        def meas_fn_super_prior(x_super, *args):
            # First-order expansion around lin0
            return meas_fn + J@(x_super-lin0)

        # z_super: concatenate each base factor's z
        z_list = [f.measurement for f in in_group]
        z_lambda_list = [f.measurement_lambda for f in in_group]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_prior, jac_fn_super_prior, z_super, z_super_lambda 

    # ---------- 4) super between (cross-group binary) ----------
    def make_super_between_factor(sidA, sidB, lin0, base_factors):
        # Returns (meas_fn, jac_fn, z, z_lambda) for the stacked base factors connecting
        # groups A and B, linearized at lin0 = [state of A, state of B].
        groupA, groupB = super_groups[sidA], super_groups[sidB]
        idxA, idxB = local_idx[sidA], local_idx[sidB]
        nA, nB = total_dofs[sidA], total_dofs[sidB]

        cross = []
        for f in base_factors:
            vids = [v.variableID for v in f.adj_var_nodes]
            if len(vids) != 2:
                continue
            i, j = vids
            # One side in A, the other side in B
            if (i in groupA and j in groupB) or (i in groupB and j in groupA):
                cross.append(f)

        
        def jac_fn_super_between(xAB, *args):
            # xAB is intentionally unused: the Jacobian is always evaluated at lin0.
            xA, xB = lin0[:nA], lin0[nA:]
            Jrows = []
            for f in cross:
                i, j = [v.variableID for v in f.adj_var_nodes]
                if i in groupA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                    left_start, right_start = si, nA + sj
                else:
                    # Base factor is oriented B -> A: its first variable lives in group B
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                    left_start, right_start = nA + si, sj

                x_loc = np.concatenate([xi, xj])
                Jloc = f.jac_fn(x_loc)
                row = np.zeros((Jloc.shape[0], nA + nB))
                # Scatter the two Jacobian column blocks to their columns in [A | B]
                row[:, left_start:left_start+di]   = Jloc[:, :di] 
                row[:, right_start:right_start+dj] = Jloc[:, di:di+dj] 
                Jrows.append(row)

            # np.vstack raises ValueError when `cross` is empty
            return np.vstack(Jrows) 
        
        J = jac_fn_super_between(xAB=lin0)
        xA, xB = lin0[:nA], lin0[nA:]
        # h(lin0): stacked base measurement predictions at the linearization point.
        # The name `meas_fn` is reused below for the resulting vector.
        meas_fn = []
        for f in cross:
            i, j = [v.variableID for v in f.adj_var_nodes]
            if i in groupA:
                si, di = idxA[i]
                sj, dj = idxB[j]
                xi = xA[si:si+di]
                xj = xB[sj:sj+dj]
            else:
                si, di = idxB[i]
                sj, dj = idxA[j]
                xi = xB[si:si+di]
                xj = xA[sj:sj+dj]
            x_loc = np.concatenate([xi, xj])
            meas_fn.append(f.meas_fn(x_loc))
        meas_fn = np.concatenate(meas_fn) 
        def meas_fn_super_between(xAB, *args):
            # First-order expansion around lin0
            return meas_fn + J@(xAB-lin0)

        z_list = [f.measurement for f in cross]
        z_lambda_list = [f.measurement_lambda for f in cross]
        z_super = np.concatenate(z_list) 
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_between, jac_fn_super_between, z_super, z_super_lambda


    # ---------- 5) One super factor per super edge ----------
    for e in super_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v == "prior":
            lin0 = make_linpoint_for_group(u)
            meas_fn, jac_fn, z, z_lambda = make_super_prior_factor(u, lin0, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_prior"
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            
        else:
            lin0 = np.concatenate([make_linpoint_for_group(u), make_linpoint_for_group(v)])
            meas_fn, jac_fn, z, z_lambda = make_super_between_factor(u, v, lin0, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u], super_var_nodes[v]], z, z_lambda, meas_fn, jac_fn)
            f.adj_beliefs = [vn.belief for vn in f.adj_var_nodes]
            f.type = "super_between"
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            super_var_nodes[v].adj_factors.append(f)

    fg.n_factor_nodes = len(fg.factors)
    return fg


def bottom_up_modify_super_graph(layers):
    """
    Placeholder for refreshing the super graph's meas_fn / jac_fn from the base graph's
    current Jacobians.

    At present the function only derives three lookup tables from layers[-2] (base) and
    layers[-1] (super) and discards them: nothing is modified and None is returned.
    It still raises if the expected keys are missing or a node_map key is not an integer
    string.
    """
    lower_layer = layers[-2]
    base_graph = lower_layer["graph"]
    upper_layer = layers[-1]
    super_graph = upper_layer["graph"]
    node_map = upper_layer["node_map"]

    # base variableID -> base VariableNode
    id2var = {}
    for base_var in base_graph.var_nodes:
        id2var[base_var.variableID] = base_var

    # super id -> list of base ids (int), in node_map insertion order
    super_groups = {}
    for base_key, super_key in node_map.items():
        base_id = int(base_key)
        if super_key not in super_groups:
            super_groups[super_key] = []
        super_groups[super_key].append(base_id)

    # super string id -> position in the super node list
    sid2idx = {}
    for position, node in enumerate(upper_layer["nodes"]):
        sid2idx[node["data"]["id"]] = position



def top_down_modify_base_and_abs_graph(layers):
    """
    From the super graph downward, split each super mean into the means of its base
    variables, and simultaneously update the base variables' beliefs and the adjacent
    factors' adj_beliefs / messages.

    Assumes layers[-1] is the super layer and layers[-2] is the base layer. The base
    precision (belief.lam) is kept; only the mean, and therefore eta = lam @ mu, changes.
    The change in the belief's natural parameters is spread evenly over the messages of
    all factors adjacent to the variable.

    Parameters
    ----------
    layers : list[dict]
        layers[-1]["graph"]: super graph; layers[-1]["node_map"]: { base_id(str) -> super_id(str) }.
        layers[-2]["graph"]: base graph, modified in place.

    Returns
    -------
    The base graph object (layers[-2]["graph"]).
    """
    super_graph = layers[-1]["graph"]
    base_graph = layers[-2]["graph"]
    node_map   = layers[-1]["node_map"]  # { base_id(str) -> super_id(str) }

    # super_id -> [base_id(int)], in node_map insertion order; this order defines how the
    # stacked super mean is sliced below.
    super_groups = {}
    for b_str, s_id in node_map.items():
        super_groups.setdefault(s_id, []).append(int(b_str))

    # child lookup
    id2var_base = {vn.variableID: vn for vn in base_graph.var_nodes}

    for s_var in super_graph.var_nodes:
        # NOTE(review): groups are looked up by str(variableID); a super variable whose
        # stringified ID is not a node_map value is silently skipped — confirm that the
        # super node ids really equal str(variableID).
        sid = str(s_var.variableID)
        if sid not in super_groups:
            continue

        # === split super.mu to base ===
        mu_super = s_var.mu
        off = 0
        for bid in super_groups[sid]:
            v = id2var_base[bid]
            d = v.dofs
            # Copy the slice: a plain NumPy slice is a view, which would leave the base
            # mean sharing memory with the super mean.
            mu_child = mu_super[off:off+d].copy()
            off += d

            old_belief = v.belief

            # 1. update mu
            v.mu = mu_child

            # 2. new belief (keep the precision unchanged, use the new mu)
            eta = v.belief.lam @ v.mu
            new_belief = NdimGaussian(v.dofs, eta, v.belief.lam)
            v.belief = new_belief
            # Prior is the new belief scaled by 1e-10, i.e. numerically negligible
            v.prior = NdimGaussian(v.dofs, 1e-10*eta, 1e-10*v.belief.lam)

            # 3. Sync to adjacent factors (this step is important)
            if v.adj_factors:
                n_adj = len(v.adj_factors)
                d_eta = new_belief.eta - old_belief.eta
                d_lam = new_belief.lam - old_belief.lam

                for f in v.adj_factors:
                    if v in f.adj_var_nodes:
                        idx = f.adj_var_nodes.index(v)
                        # update adj_beliefs
                        f.adj_beliefs[idx] = new_belief
                        # correct the corresponding message (modified in place)
                        msg = f.messages[idx]
                        msg.eta += d_eta / n_adj
                        msg.lam += d_lam / n_adj
                        f.messages[idx] = msg

    return base_graph


def top_down_modify_super_graph(layers):
    """
    Lift the abs-layer means back onto the super layer.

    For every super variable that has both a B and a k entry, the mean is replaced by
    mu_s = B @ mu_a + k; the belief is rebuilt from the new mean with the existing
    precision (belief.lam), and the prior is reset to that belief scaled by 1e-10.
    Afterwards every super factor's adj_beliefs entry is pointed at the (possibly new)
    belief of its variable. Covariances and factor messages are not touched.

    Requirements:
      - layers[-1] is abs, layers[-2] is super
      - layers[-1]["Bs"] / ["ks"] are keyed by the super variableID
      - abs_graph.var_nodes is indexed by that same variableID

    Returns None.
    """
    abs_layer = layers[-1]
    abs_graph = abs_layer["graph"]
    super_graph = layers[-2]["graph"]
    lift_mats = abs_layer["Bs"]      # { super_id -> B (d_super x r) }
    lift_offsets = abs_layer["ks"]   # { super_id -> k (d_super,) }

    # ---- Pass 1: lift the mean and rebuild belief / prior ----
    for sup_var in super_graph.var_nodes:
        var_id = sup_var.variableID
        if var_id in lift_mats and var_id in lift_offsets:
            basis = lift_mats[var_id]
            offset = lift_offsets[var_id]

            # x_s = B x_a + k
            sup_var.mu = basis @ abs_graph.var_nodes[var_id].mu + offset

            # Same precision as before, information vector recomputed from the new mean
            info_vec = sup_var.belief.lam @ sup_var.mu
            sup_var.belief = NdimGaussian(sup_var.dofs, info_vec, sup_var.belief.lam)
            sup_var.prior = NdimGaussian(sup_var.dofs, 1e-10*info_vec, 1e-10*sup_var.belief.lam)

    # ---- Pass 2: let each adjacent factor see the current belief of its variable ----
    for sup_var in super_graph.var_nodes:
        for factor in sup_var.adj_factors:
            slot = factor.adj_var_nodes.index(sup_var)
            factor.adj_beliefs[slot] = sup_var.belief

    return None


def refresh_gbp_results(layers):
    """
    Refresh every layer's ``gbp_result`` (string node id -> [x, y]) from the current means.

    For each layer that has a graph, an affine map to the plotting plane is cached in
    L["A"] / L["b"], both keyed by string node id:
      base:   A_i = [[1, 0, 0], [0, 1, 0]] (keeps x, y of a 3-dof state), b_i = 0
      super:  A_s = (1/m) [A_c1, A_c2, ..., A_cm], b_s = (1/m) Σ b_cj   (m children)
      abs:    A_a = A_super(s) @ B_s,             b_a = A_super(s) @ k_s + b_super(s)
    gbp_result is then pos = A @ mu + b (for base layers simply mu[:2]).

    The layer kind is decided by the prefix of L["name"] ("base", "super", "abs");
    any other name gets empty maps and an empty result. Layers whose "graph" is missing
    or None have "A", "b" and "gbp_result" removed. The parent of a super/abs layer is
    the layer directly before it in the list.

    Convention: use string keys everywhere (aligned with Cytoscape ids).

    Parameters
    ----------
    layers : list[dict]
        Modified in place.

    Returns
    -------
    None
    """
    if not layers:
        return

    # ---------- 1) Bottom-up: compute A, b for each layer ----------
    for li, L in enumerate(layers):
        g = L.get("graph")
        if g is None:
            # No graph: drop any stale cached maps / results.
            L.pop("A", None)
            L.pop("b", None)
            L.pop("gbp_result", None)
            continue

        name = L["name"]
        # ---- base ----
        if name.startswith("base"):
            L["A"], L["b"] = {}, {}
            for v in g.var_nodes:
                key = str(v.variableID)
                L["A"][key] = np.array([[1.0, 0.0, 0.0],
                                        [0.0, 1.0, 0.0]], dtype=float)  # 2x3: keep only x, y
                L["b"][key] = np.zeros(2, dtype=float)

        # ---- super ----
        elif name.startswith("super"):
            parent = layers[li - 1]
            node_map = L["node_map"]  # { prev_id(str) -> super_id(str) }

            # Grouping (preserve insertion order to match the concatenation order used
            # when the super variables were stacked)
            groups = {}
            for prev_id, s_id in node_map.items():
                groups.setdefault(str(s_id), []).append(str(prev_id))

            L["A"], L["b"] = {}, {}
            for s_id, children in groups.items():
                # Every group has at least one child by construction, so m >= 1.
                m = len(children)
                # Horizontal concatenation [A_c1, A_c2, ...]; each block has shape 2 x d_c
                A_concat = np.hstack([parent["A"][cid] for cid in children])
                b_sum = sum((parent["b"][cid] for cid in children), start=np.zeros(2, dtype=float))
                L["A"][s_id] = (1.0 / m) * A_concat
                L["b"][s_id] = (1.0 / m) * b_sum

        # ---- abs ----
        elif name.startswith("abs"):
            parent = layers[li - 1]  # the corresponding super layer
            Bs, ks = L["Bs"], L["ks"]  # Note: keys are the super variableIDs (int)

            # Map super variableID (int) -> super string id via the parent's node list order
            int2sid = {i: str(node["data"]["id"]) for i, node in enumerate(parent["nodes"])}

            L["A"], L["b"] = {}, {}
            for av in g.var_nodes:
                sid_int = av.variableID                    # super variableID (int)
                s_id = int2sid.get(sid_int, str(sid_int))  # super string id (also the abs node id)
                B = Bs[sid_int]                            # (sum d_c) x r
                k = ks[sid_int]                            # (sum d_c,)

                A_sup = parent["A"][s_id]                  # 2 x (sum d_c)
                b_sup = parent["b"][s_id]                  # (2,)

                L["A"][s_id] = A_sup @ B                   # 2 x r
                L["b"][s_id] = A_sup @ k + b_sup           # (2,)

        else:
            # Unknown layer type
            L["A"], L["b"] = {}, {}

    # ---------- 2) Compute gbp_result ----------
    for L in layers:
        g = L.get("graph")
        if g is None:
            L.pop("gbp_result", None)
            continue

        name = L["name"]
        res = {}

        if name.startswith("base"):
            for v in g.var_nodes:
                res[str(v.variableID)] = v.mu[:2].tolist()

        elif name.startswith("super") or name.startswith("abs"):
            # The i-th variable corresponds to the i-th entry of the layer's node list;
            # the maps are looked up by that node's string id.
            for i, v in enumerate(g.var_nodes):
                node_id = str(L["nodes"][i]["data"]["id"])
                A, b = L["A"][node_id], L["b"][node_id]
                res[node_id] = (A @ v.mu + b).tolist()

        L["gbp_result"] = res


class VGraph:
    """
    Driver for a V-cycle over a stack of layers (base / super / abs).

    ``layers`` is a list of dicts; layer kinds are recognised by the prefix of
    layer["name"] and the per-layer factor graph lives under layer["graph"].
    """

    def __init__(self,
                 layers,
                 nonlinear_factors=True,
                 eta_damping=0.4,
                 beta=0.0,
                 iters_since_relinear=0,
                 num_undamped_iters=0,
                 min_linear_iters=50,
                 wild_thresh=0):
        """
        Store the layer stack and the scheduling parameters.

        ``beta`` and ``num_undamped_iters`` are accepted but not stored.
        ``iters_since_relinear`` counts vloop calls since the "super1" layer was last
        rebuilt; ``min_linear_iters`` is the threshold that triggers the rebuild.
        """
        self.layers = layers

        # graph-construction settings
        self.nonlinear_factors = nonlinear_factors
        self.eta_damping = eta_damping
        self.wild_thresh = wild_thresh

        # relinearization schedule
        self.min_linear_iters = min_linear_iters
        self.iters_since_relinear = iters_since_relinear

    def vloop(self):
        """
        Run one simplified V-cycle and return the layer list.

        1) bottom-up (layers 1..end): "super1" layers are rebuilt via
           update_super_graph_linearized only once iters_since_relinear has reached
           min_linear_iters (otherwise the counter is incremented); other "super*" layers
           are rebuilt via build_super_graph and "abs*" layers via build_abs_graph on
           every call. Each layer that has a graph then runs one synchronous iteration.
        2) top-down (layers end..1): each layer that has a graph runs one more synchronous
           iteration, after which "super*" layers push their means to the layer below and
           "abs*" layers lift their means to the super layer below.

        The gbp_result refresh for the UI is not performed here.
        """
        stack = self.layers
        depth = len(stack)

        # ---- bottom-up ----
        for level in range(1, depth):
            layer = stack[level]
            kind = layer["name"]
            upto = stack[:level+1]

            if kind.startswith("super1"):
                # Rebuild around the base graph's current means, but only every
                # min_linear_iters calls.
                if self.iters_since_relinear >= self.min_linear_iters:
                    print("relinarizing super1 layer")
                    layer["graph"] = update_super_graph_linearized(upto, eta_damping=self.eta_damping)
                    self.iters_since_relinear = 0
                else:
                    self.iters_since_relinear += 1

            elif kind.startswith("super"):
                layer["graph"] = build_super_graph(upto, eta_damping=self.eta_damping)

            elif kind.startswith("abs"):
                abs_graph, Bs, ks, k2s = build_abs_graph(upto, eta_damping=self.eta_damping)
                layer["graph"] = abs_graph
                layer["Bs"], layer["ks"], layer["k2s"] = Bs, ks, k2s

            # one iteration per layer after the (re)build
            if "graph" in layer:
                layer["graph"].synchronous_iteration()

        # ---- top-down (pass mu) ----
        for level in reversed(range(1, depth)):
            layer = stack[level]

            # iterate once more before projecting downwards
            if "graph" in layer:
                layer["graph"].synchronous_iteration()

            kind = layer["name"]
            if kind.startswith("super"):
                # split super.mu back to the layer below
                top_down_modify_base_and_abs_graph(stack[:level+1])
            elif kind.startswith("abs"):
                # project abs.mu back to super
                top_down_modify_super_graph(stack[:level+1])

        return stack
# NOTE(review): `layers` must already be bound when this line runs; the assignments to
# `layers` visible nearby are in later cells — confirm it is defined earlier, otherwise
# this fails on a fresh kernel (Restart & Run All).
vg = VGraph(layers)

In [5]:
# Experiment configuration. N, step, prob, radius and prior_prop are forwarded to
# init_layers as N / step_size / loop_prob / loop_radius / prior_prop.
N=512
step=25
prob=0.05
radius=50 
prior_prop=0.02
# Noise levels forwarded to build_noisy_pose_graph
prior_sigma=1
odom_sigma=1
loop_sigma=1
layers = []  # overwritten by init_layers just below



layers = init_layers(N=N, step_size=step, loop_prob=prob, loop_radius=radius, prior_prop=prior_prop, seed=2001)
pair_idx = 0


# Build the GBP graph
gbp_graph = build_noisy_pose_graph(layers[0]["nodes"], layers[0]["edges"],
                                    prior_sigma=prior_sigma,
                                    odom_sigma=odom_sigma,
                                    loop_sigma=loop_sigma,
                                    seed=2001)
layers[0]["graph"] = gbp_graph
gbp_graph.num_undamped_iters = 0
gbp_graph.min_linear_iters = 2000
opts=[{"label":"base","value":"base"}]



In [6]:
def max_lam_change(gbp_graph, prev_lams):
    """
    Largest element-wise absolute change of belief.lam across all variables.

    Parameters
    ----------
    gbp_graph : FactorGraph
        Graph whose var_nodes[i].belief.lam is a numpy array.
    prev_lams : list[np.ndarray]
        Reference precision matrices, in the same order as var_nodes.

    Returns
    -------
    float
        max_i max |lam_i - prev_lams[i]|; 0.0 when there is nothing to compare.
        Entries of var_nodes that are None are skipped.

    Raises
    ------
    ValueError
        If prev_lams and var_nodes differ in length.
    """
    nodes = gbp_graph.var_nodes
    if len(prev_lams) != len(nodes):
        raise ValueError("prev_lams size mismatch")

    # Start from 0.0 so the result is 0.0 when every node is skipped.
    candidates = [0.0]
    for pos, node in enumerate(nodes):
        if node is not None:
            candidates.append(np.max(np.abs(node.belief.lam - prev_lams[pos])))

    return max(candidates)

def max_eta_change(gbp_graph, prev_etas):
    """
    Largest element-wise absolute change of belief.eta across all variables.

    Parameters
    ----------
    gbp_graph : FactorGraph
        Graph whose var_nodes[i].belief.eta is a numpy array.
    prev_etas : list[np.ndarray]
        Reference information vectors, in the same order as var_nodes.

    Returns
    -------
    float
        max_i max |eta_i - prev_etas[i]|; 0.0 when there is nothing to compare.
        Entries of var_nodes that are None are skipped.

    Raises
    ------
    ValueError
        If prev_etas and var_nodes differ in length.
    """
    nodes = gbp_graph.var_nodes
    if len(prev_etas) != len(nodes):
        raise ValueError("prev_etas size mismatch")

    # Start from 0.0 so the result is 0.0 when every node is skipped.
    candidates = [0.0]
    for pos, node in enumerate(nodes):
        if node is not None:
            candidates.append(np.max(np.abs(node.belief.eta - prev_etas[pos])))

    return max(candidates)


def energy_map(graph, include_priors: bool = True, include_factors: bool = True) -> float:
    """
    Half the summed squared position error of the estimates against ground truth.

    For each of the first ``graph.n_var_nodes`` entries of ``graph.var_nodes``
    this adds ``0.5 * ||mu[0:2] - GT[0:2]||^2``. Only the first two components
    of ``mu`` and ``GT`` are compared.

    ``include_priors`` and ``include_factors`` are accepted but not used by
    this implementation.
    """
    active_nodes = graph.var_nodes[:graph.n_var_nodes]

    accumulated = 0.0
    for node in active_nodes:
        estimate_xy = np.asarray(node.mu[0:2], dtype=float)
        truth_xy = np.asarray(node.GT[0:2], dtype=float)
        residual = estimate_xy - truth_xy
        accumulated += 0.5 * float(residual.T @ residual)

    return accumulated

# Display the ground-truth position error of the base graph's current estimates.
energy_map(layers[0]["graph"], include_priors=True, include_factors=True)


41233.35115424026

In [7]:
# --- Experiment configuration: GBP on the base graph only ---
N = 512
step = 25
prob = 0.05
radius = 50
prior_prop = 0.02
prior_sigma = 1
odom_sigma = 1
loop_sigma = 10
layers = []

layers = init_layers(N=N, step_size=step, loop_prob=prob, loop_radius=radius, prior_prop=prior_prop, seed=2001)
pair_idx = 0

# Build the GBP graph for the base layer.
gbp_graph = build_noisy_pose_graph(layers[0]["nodes"], layers[0]["edges"],
                                    prior_sigma=prior_sigma,
                                    odom_sigma=odom_sigma,
                                    loop_sigma=loop_sigma,
                                    seed=2001)
layers[0]["graph"] = gbp_graph
gbp_graph.num_undamped_iters = 0
gbp_graph.min_linear_iters = 20000
opts = [{"label": "base", "value": "base"}]

# From here on the graph is always addressed as `basegraph` (the same object as
# layers[0]["graph"]); `gbp_graph` is rebound by other cells, so it is not
# referenced inside the loop.
basegraph = layers[0]["graph"]
basegraph.relinearise_factors()
print(energy_map(basegraph, include_priors=True, include_factors=True))

# 15 outer rounds of 100 synchronous GBP iterations, relinearising the factors
# after every round. Each inner iteration logs the largest change in the belief
# parameters since the previous iteration plus the ground-truth position error.
for k in range(15):
    prev_lams = [v.belief.lam.copy() for v in basegraph.var_nodes]
    prev_etas = [v.belief.eta.copy() for v in basegraph.var_nodes]
    max_lam_change_values = []
    max_eta_change_values = []
    for it in range(100):
        basegraph.synchronous_iteration()

        max_lam_change_value = max_lam_change(basegraph, prev_lams)
        max_eta_change_value = max_eta_change(basegraph, prev_etas)
        max_lam_change_values.append(max_lam_change_value)
        max_eta_change_values.append(max_eta_change_value)

        energy = energy_map(basegraph, include_priors=True, include_factors=True)
        print(f"OuterIter {k}, InnerIter {it:03d}, MaxLamChange {max_lam_change_value:.4e}, MaxEtaChange {max_eta_change_value:.4e} | Energy = {energy:.6f}")
        # Refresh the snapshots from the graph that was just iterated.
        prev_lams = [v.belief.lam.copy() for v in basegraph.var_nodes]
        prev_etas = [v.belief.eta.copy() for v in basegraph.var_nodes]
    basegraph.relinearise_factors()

41233.35115424026
OuterIter 0, InnerIter 000, MaxLamChange 1.0000e+10, MaxEtaChange 2.5923e+04 | Energy = 173519.444239
OuterIter 0, InnerIter 001, MaxLamChange 1.0000e+04, MaxEtaChange 2.7522e+04 | Energy = 102961.225132
OuterIter 0, InnerIter 002, MaxLamChange 5.0706e+03, MaxEtaChange 1.6476e+04 | Energy = 85578.591995
OuterIter 0, InnerIter 003, MaxLamChange 3.4202e+03, MaxEtaChange 1.5710e+04 | Energy = 73855.478107
OuterIter 0, InnerIter 004, MaxLamChange 2.9833e+03, MaxEtaChange 1.4182e+04 | Energy = 64658.225863
OuterIter 0, InnerIter 005, MaxLamChange 2.6257e+03, MaxEtaChange 2.2398e+04 | Energy = 64196.256390
OuterIter 0, InnerIter 006, MaxLamChange 2.1726e+03, MaxEtaChange 1.2257e+04 | Energy = 76153.258196
OuterIter 0, InnerIter 007, MaxLamChange 2.1283e+03, MaxEtaChange 1.1524e+04 | Energy = 117366.902920
OuterIter 0, InnerIter 008, MaxLamChange 3.7869e+03, MaxEtaChange 1.1753e+04 | Energy = 193279.566390
OuterIter 0, InnerIter 009, MaxLamChange 2.0152e+03, MaxEtaChange 1.1

In [8]:
# --- Experiment configuration for the multi-layer (V-cycle) run ---
N = 512
step=25
prob=0.05
radius=50 
prior_prop=0.02
prior_sigma=1
odom_sigma=1
layers = []


# Rebuild the layer list from scratch with a fixed seed; layers[0] is the base layer.
layers = init_layers(N=N, step_size=step, loop_prob=prob, loop_radius=radius, prior_prop=prior_prop, seed=2001)
pair_idx = 0



# Create GBP graph
# NOTE(review): unlike the base-graph-only cell, loop_sigma is not passed here,
# so whatever default build_noisy_pose_graph declares is used — confirm this is intended.
gbp_graph = build_noisy_pose_graph(layers[0]["nodes"], layers[0]["edges"],
                                    prior_sigma=prior_sigma,
                                    odom_sigma=odom_sigma,
                                    seed=2001)
# Attach the graph to the base layer so it can be reached as layers[0]["graph"].
layers[0]["graph"] = gbp_graph
# Solver settings stored on the graph object; exact semantics are defined in
# the gbp package — TODO confirm.
gbp_graph.num_undamped_iters = 0
gbp_graph.min_linear_iters = 2000
# Single label/value entry for the base layer
# (looks like a Dash dropdown option list — verify against the layout code).
opts=[{"label":"base","value":"base"}]


# Build the layer hierarchy on top of the base graph. Each level first fuses the
# current top layer into super nodes ("super<k>"), then copies that super layer
# into an abstraction layer ("abs<k>"). After the loop the stack is
#   layers = [base, super1, abs1, super2, abs2]
# and the globals (k_next, super_layer_idx, abs_layer_idx, k, last, ...) hold
# the values of the last level, exactly as the unrolled version left them.
#
# To stack deeper levels, extend the tuple below. The previously disabled
# copies for k_next = 3 and 4 were identical to this loop body except that
# they called build_abs_graph(layers) without r_reduced.
kk = 10
for k_next in (1, 2):
    # --- super layer at odd index 2*k_next - 1 ---
    super_layer_idx = k_next*2 - 1
    last = layers[-1]
    super_nodes, super_edges, node_map = fuse_to_super_order(last["nodes"], last["edges"], int(kk or 8), super_layer_idx, tail_heavy=True)
    # Ensure the layer below has run at least once before building on top of it
    layers[-1]["graph"].synchronous_iteration()
    layers.append({"name":f"super{k_next}", "nodes":super_nodes, "edges":super_edges, "node_map":node_map})
    layers[super_layer_idx]["graph"] = build_super_graph(layers)

    # --- abs layer directly above the super layer (even index 2*k_next) ---
    abs_layer_idx = super_layer_idx + 1
    k = k_next
    last = layers[-1]
    abs_nodes, abs_edges = copy_to_abs(last["nodes"], last["edges"], abs_layer_idx)
    # Ensure super graph has run at least once
    layers[-1]["graph"].synchronous_iteration()
    layers.append({"name":f"abs{k}", "nodes":abs_nodes, "edges":abs_edges})
    layers[abs_layer_idx]["graph"], layers[abs_layer_idx]["Bs"], layers[abs_layer_idx]["ks"], layers[abs_layer_idx]["k2s"] = build_abs_graph(
        layers, r_reduced=2)


# Disabled experiment, kept as an inert string: iterate the three lowest layers
# in lock-step instead of running V-cycles.
"""
for i in range(2000):
    layers[1]["graph"].synchronous_iteration()
    layers[0]["graph"].synchronous_iteration()
    layers[2]["graph"].synchronous_iteration()
    #top_down_modify_super_graph(layers[:])
    #top_down_modify_base_and_abs_graph(layers[0:2])
    energy = energy_map(layers[0]["graph"], include_priors=True, include_factors=True)
    print(f"Iter {i+1:03d} | Energy = {energy:.6f}")
"""

# Run V-cycles over the layer stack until the base-layer error stops changing
# (two consecutive cycles with |delta energy| < 1e-5) or 200 cycles have run.
vg = VGraph(layers)
energy_prev = float("inf")   # no previous energy yet: the first cycle can never count as a stall
counter = 0                  # consecutive cycles whose energy change was below tolerance
# A named loop index is used because `_` is IPython's "last result" variable;
# assigning to it in a notebook stops IPython from updating it.
for cycle in range(200):
    vg.layers = layers
    vg.r_reduced=3
    vg.eta_damping = 0
    vg.layers = vg.vloop()
    energy = energy_map(layers[0]["graph"], include_priors=True, include_factors=True)
    # Log before the stop test so the final cycle's energy is also shown.
    print(f"Iter {cycle+1:03d} | Energy = {energy:.6f}")

    if np.abs(energy_prev-energy) < 1e-5:
        counter += 1
        if counter >= 2:
            break
    else:
        counter = 0
    # Without this update the test above compares against a constant and the
    # early exit can never trigger.
    energy_prev = energy
refresh_gbp_results(layers)


Iter 001 | Energy = 81392055.033818
Iter 002 | Energy = 48485174.027609
Iter 003 | Energy = 33800898.758707
Iter 004 | Energy = 21052178.838087
Iter 005 | Energy = 16281969.642530
Iter 006 | Energy = 13895113.827902
Iter 007 | Energy = 12707643.940458
Iter 008 | Energy = 11813446.775472
Iter 009 | Energy = 11168375.645369
Iter 010 | Energy = 10676745.913313
Iter 011 | Energy = 10276813.349552
Iter 012 | Energy = 9980598.817900
Iter 013 | Energy = 9729661.282836
Iter 014 | Energy = 9549041.733001
Iter 015 | Energy = 9395055.576061
Iter 016 | Energy = 9289684.099974
Iter 017 | Energy = 9199902.065973
Iter 018 | Energy = 9143947.192402
Iter 019 | Energy = 9095185.021747
Iter 020 | Energy = 9056405.003805
Iter 021 | Energy = 9046030.759119
Iter 022 | Energy = 9037689.701213
Iter 023 | Energy = 9028690.309043
Iter 024 | Energy = 9031299.425099
Iter 025 | Energy = 9029824.316969
Iter 026 | Energy = 9024758.497067
Iter 027 | Energy = 9040490.229401
Iter 028 | Energy = 9050436.161049
Iter 029 

KeyboardInterrupt: 

In [15]:
"""
gtsam_se2_benchmark_lm.py

Purpose
-------
Pure GTSAM benchmark for nonlinear SE(2) pose-graph:
- Use GTSAM's Pose2 / BetweenFactorPose2 / PriorFactorPose2
- Build one full graph (prior + odom + loop closures)
- Solve once with Levenberg-Marquardt (batch, one-shot)
- Report build time + solve time + basic diagnostics

Notes
-----
1) This is NOT using your GBP classes at all.
2) Measurements are SE(2) relative motions z_ij = (dx_local, dy_local, dtheta)
   which matches GTSAM Pose2 "between" convention.
3) Noise injection is done once, deterministically by seed, at graph construction.
4) Anchor: strong prior at node 0, plus optionally additional priors controlled by prior_prop.
   (Matches your generator's "prior edges to virtual prior".)

Run
---
python gtsam_se2_benchmark_lm.py
"""

import time
import numpy as np
import gtsam
from gtsam import symbol


# -----------------------
# Utilities
# -----------------------
def wrap_angle(a: float) -> float:
    """Map angle ``a`` (radians) to its equivalent in [-pi, pi] via atan2(sin a, cos a)."""
    sin_a, cos_a = np.sin(a), np.cos(a)
    wrapped = np.arctan2(sin_a, cos_a)
    return float(wrapped)


def sigma_vec(s, theta_ratio=0.01) -> np.ndarray:
    """
    Expand a sigma specification into a length-3 float array ``[sx, sy, sth]``.

    A scalar (or any single-element input) ``s`` becomes
    ``[s, s, s * theta_ratio]``; a length-3 input is returned as a float copy.

    Raises
    ------
    ValueError
        If ``s`` has neither 1 nor 3 elements.
    """
    flat = np.array(s, dtype=float).ravel()
    count = flat.size

    if count == 3:
        return flat.astype(float)
    if count != 1:
        raise ValueError("sigma must be scalar or length-3 [sx, sy, sth]")

    base = float(flat.item())
    return np.array([base, base, base * theta_ratio], dtype=float)


def relpose_SE2(pose_i, pose_j) -> np.ndarray:
    """
    Relative SE(2) pose of ``pose_j`` expressed in the frame of ``pose_i``.

    Both poses are ``[x, y, theta]``. Returns ``[dx_local, dy_local, dtheta]``:
    the world-frame displacement rotated by ``R(theta_i)^T`` and the wrapped
    heading difference.
    """
    x_i, y_i, theta_i = pose_i
    x_j, y_j, theta_j = pose_j

    cos_i, sin_i = np.cos(theta_i), np.sin(theta_i)
    # Transpose of the rotation of pose_i: maps world-frame vectors into i's frame.
    world_to_i = np.array([[cos_i, sin_i], [-sin_i, cos_i]], dtype=float)
    offset_world = np.array([x_j - x_i, y_j - y_i], dtype=float)
    offset_local = world_to_i @ offset_world

    return np.array(
        [offset_local[0], offset_local[1], wrap_angle(theta_j - theta_i)],
        dtype=float,
    )


def compose_SE2(pose_i, z_ij) -> np.ndarray:
    """
    Compose ``pose_i`` with a relative motion: ``pose_j = pose_i ⊕ z_ij``.

    ``pose_i`` is ``[x, y, theta]`` and ``z_ij`` is ``[dx, dy, dtheta]`` given
    in the local frame of ``pose_i``. Returns ``[x_j, y_j, theta_j]`` with the
    heading wrapped by ``wrap_angle``.
    """
    x_i, y_i, theta_i = pose_i
    dx_local, dy_local, dtheta = z_ij

    cos_i, sin_i = np.cos(theta_i), np.sin(theta_i)
    # Rotate the local displacement into the world frame.
    shift_world = np.array(
        [cos_i * dx_local - sin_i * dy_local, sin_i * dx_local + cos_i * dy_local],
        dtype=float,
    )

    return np.array(
        [x_i + shift_world[0], y_i + shift_world[1], wrap_angle(theta_i + dtheta)],
        dtype=float,
    )


# -----------------------
# SLAM-like generator (clean GT + clean relative z on edges)
# -----------------------
def make_slam_like_graph(
    N=100,
    step_size=25,
    loop_prob=0.05,
    loop_radius=50,
    prior_prop=0.0,
    seed=None,
):
    """
    Generate a noise-free SE(2) trajectory with odometry and loop-closure edges.

    NOTE(review): a function with this name is also defined earlier in the
    notebook (the version that builds node dicts for the Dash view); running
    this cell rebinds the name to this array/tuple-based variant.

    Parameters
    ----------
    N : int
        Number of poses.
    step_size : float
        Distance travelled per step along the current heading.
    loop_prob : float
        Probability with which a candidate pair ``(i, j >= i + 5)`` is tested
        for a loop closure.
    loop_radius : float
        A tested pair becomes a loop edge only if the two positions are closer
        than this.
    prior_prop : float
        ``<= 0`` -> only pose 0 gets a prior; ``>= 1`` -> every pose;
        otherwise ``max(1, floor(prior_prop * N))`` ids drawn without
        replacement.
    seed : int or None
        Seed for ``np.random.default_rng``.

    Returns
    -------
    poses : np.ndarray, shape (N, 3)
        Ground-truth ``[x, y, theta]`` per pose.
    edges : list[tuple]
        ``(i, j, kind, z_clean)`` with ``kind`` in ``{"odom", "loop"}`` and
        ``z_clean = relpose_SE2(poses[i], poses[j])``.
    prior_ids : set[int]
        Ids of the poses that receive a prior.
    """
    rng = np.random.default_rng(seed)

    # --- trajectory: random heading change per step, fixed stride ---
    turn_std = 1.0  # rad per step
    pos_x, pos_y, heading = 0.0, 0.0, 0.0
    track = [(pos_x, pos_y, heading)]
    for _ in range(1, int(N)):
        heading = wrap_angle(heading + rng.normal(0.0, turn_std))
        pos_x += float(step_size) * np.cos(heading)
        pos_y += float(step_size) * np.sin(heading)
        track.append((pos_x, pos_y, heading))

    poses = np.array(track, dtype=float)  # (N,3)

    # --- sequential odometry edges ---
    edges = [
        (i, i + 1, "odom", relpose_SE2(poses[i], poses[i + 1]))
        for i in range(int(N) - 1)
    ]

    # --- loop closures: one random draw per candidate pair, then a distance test ---
    for i in range(int(N)):
        x_i, y_i = poses[i][0], poses[i][1]
        for j in range(i + 5, int(N)):
            # The draw happens for every pair so the RNG sequence is unchanged;
            # the distance is only evaluated when the draw succeeds.
            if rng.random() < float(loop_prob) and (
                np.hypot(x_i - poses[j][0], y_i - poses[j][1]) < float(loop_radius)
            ):
                edges.append((i, j, "loop", relpose_SE2(poses[i], poses[j])))

    # --- which poses receive a prior ---
    if prior_prop <= 0.0:
        return poses, edges, {0}
    if prior_prop >= 1.0:
        return poses, edges, set(range(N))

    n_priors = max(1, int(np.floor(prior_prop * N)))
    chosen = rng.choice(N, size=n_priors, replace=False)
    return poses, edges, set(chosen.tolist())


# -----------------------
# Build GTSAM Pose2 graph + initial
# -----------------------
import numpy as np
import gtsam
from gtsam import symbol

def build_gtsam_pose2_graph(
    GT_poses,  # (N,3) used ONLY to generate clean measurements / prior means
    edges,     # list of (i, j, kind, z_clean)  z_clean = [dx,dy,dth] in i-local
    prior_ids,
    prior_sigma=1.0,
    odom_sigma=1.0,
    loop_sigma=1.0,
    theta_ratio=0.01,
    seed=2001,
    add_measurement_noise=True,
    add_prior_noise=True,
    anchor_strong=True,
    anchor_sigmas=(1e-3, 1e-3, 1e-5),
):
    """
    Build a GTSAM Pose2 factor graph and an initial estimate from a clean graph.

    Factors added, in this order (which is also the order of the RNG draws):
      (A) if ``anchor_strong``: a PriorFactorPose2 on x0 at the exact
          ``GT_poses[0]`` with ``anchor_sigmas``;
      (B) a PriorFactorPose2 for every id in ``prior_ids`` (id 0 is skipped
          when anchored), centred on the GT pose plus optional Gaussian noise;
      (C) a BetweenFactorPose2 per edge; edges with ``kind == "loop"`` use the
          loop sigmas, every other kind uses the odometry sigmas.

    The initial estimate is dead-reckoned from the noisy sequential odometry
    and reset to the stored prior measurement at every node that has one.

    Parameters
    ----------
    GT_poses : np.ndarray, shape (N, 3)
        Ground-truth poses ``[x, y, theta]``.
    edges : list[tuple]
        ``(i, j, kind, z_clean)`` with ``z_clean = [dx, dy, dtheta]`` in the
        local frame of i.
    prior_ids : iterable[int]
        Pose ids that receive a prior factor.
    prior_sigma, odom_sigma, loop_sigma : float or length-3 sequence
        Sigmas, expanded by ``sigma_vec`` (a scalar ``s`` becomes
        ``[s, s, s * theta_ratio]``).
    theta_ratio : float
        Heading/position sigma ratio used when a sigma is scalar.
    seed : int
        Seed of the noise generator.
    add_measurement_noise : bool
        Add Gaussian noise to the between-factor measurements.
    add_prior_noise : bool
        Add Gaussian noise to the prior means (and to the x0 value used for
        initialisation when anchored).
    anchor_strong : bool
        Add the tight anchor prior on x0.
    anchor_sigmas : length-3 sequence
        Sigmas of the anchor prior.

    Returns
    -------
    graph : gtsam.NonlinearFactorGraph
    initial : gtsam.Values
        One Pose2 per key ``x0 .. x{N-1}``.

    Raises
    ------
    ValueError
        If ``anchor_sigmas`` or an edge measurement is not length 3, or a
        sigma is rejected by ``sigma_vec``.
    """
    rng = np.random.default_rng(seed)
    N = int(GT_poses.shape[0])

    # noise models (sigmas)
    prior_sig = sigma_vec(prior_sigma, theta_ratio)  # (3,)
    odom_sig  = sigma_vec(odom_sigma, theta_ratio)   # (3,)
    loop_sig  = sigma_vec(loop_sigma, theta_ratio)   # (3,)

    prior_model = gtsam.noiseModel.Diagonal.Sigmas(gtsam.Point3(*prior_sig))
    odom_model  = gtsam.noiseModel.Diagonal.Sigmas(gtsam.Point3(*odom_sig))
    loop_model  = gtsam.noiseModel.Diagonal.Sigmas(gtsam.Point3(*loop_sig))

    graph = gtsam.NonlinearFactorGraph()
    initial = gtsam.Values()

    def key(i: int):
        """GTSAM key for pose i (symbol 'x' + index)."""
        return symbol("x", int(i))

    # ------------------------------------------------------------
    # Store the EXACT noisy measurements used by factors
    # so init can reuse them (no GT init, deterministic).
    # ------------------------------------------------------------
    strong_prior_meas = {}   # i -> noisy Pose (x,y,th) used in PriorFactorPose2
    odom_meas = {}           # (i,i+1) -> noisy z_ij used in BetweenFactorPose2 (odom edges only)

    # (A) strong anchor on x0 (factor mean uses GT, but init will use a noisy prior measurement)
    if anchor_strong:
        a = np.array(anchor_sigmas, dtype=float).ravel()
        if a.size != 3:
            raise ValueError("anchor_sigmas must be length-3")
        anchor_model = gtsam.noiseModel.Diagonal.Sigmas(
            gtsam.Point3(float(a[0]), float(a[1]), float(a[2]))
        )

        z0 = GT_poses[0].copy()
        graph.add(gtsam.PriorFactorPose2(
            key(0),
            gtsam.Pose2(float(z0[0]), float(z0[1]), float(z0[2])),
            anchor_model
        ))

        # For the init reset, record an anchor "measurement": a noisy copy of
        # GT (drawn with the prior sigmas, not the anchor sigmas) if requested.
        if add_prior_noise:
            n0 = rng.normal(0.0, prior_sig, size=3)
            z0n = z0.copy()
            z0n[:2] += n0[:2]
            z0n[2] = wrap_angle(z0n[2] + n0[2])
            strong_prior_meas[0] = z0n
        else:
            strong_prior_meas[0] = z0.copy()

    # (B) other strong priors: factor mean is noisy pose measurement
    # Iterating in sorted order keeps the RNG draws independent of set ordering.
    for i in sorted(prior_ids):
        if i == 0 and anchor_strong:
            continue

        z = GT_poses[i].copy()
        if add_prior_noise:
            n = rng.normal(0.0, prior_sig, size=3)
            z[:2] += n[:2]
            z[2] = wrap_angle(z[2] + n[2])

        strong_prior_meas[i] = z.copy()

        graph.add(gtsam.PriorFactorPose2(
            key(i),
            gtsam.Pose2(float(z[0]), float(z[1]), float(z[2])),
            prior_model
        ))

    # (C) between factors: add factors; also store sequential odom measurements for init chaining
    for (i, j, kind, z_clean) in edges:
        z = np.array(z_clean, dtype=float).ravel()
        if z.size != 3:
            raise ValueError("z_clean must be length-3 [dx, dy, dtheta]")

        if add_measurement_noise:
            # Any kind other than "loop" is treated as odometry.
            if kind == "loop":
                n = rng.normal(0.0, loop_sig, size=3)
                model = loop_model
            else:
                n = rng.normal(0.0, odom_sig, size=3)
                model = odom_model

            z_noisy = z.copy()
            z_noisy[:2] += n[:2]
            z_noisy[2] = wrap_angle(z_noisy[2] + n[2])
        else:
            model = loop_model if kind == "loop" else odom_model
            z_noisy = z.copy()

        # store odom for chaining ONLY if it's sequential odom edge (i,i+1)
        if kind == "odom" and (j == i + 1):
            odom_meas[(i, j)] = z_noisy.copy()

        graph.add(gtsam.BetweenFactorPose2(
            key(i), key(j),
            gtsam.Pose2(float(z_noisy[0]), float(z_noisy[1]), float(z_noisy[2])),
            model
        ))

    # ------------------------------------------------------------
    # (D) INITIALIZATION:
    # start from (0,0,0), propagate with odom ⊕,
    # when hit a strong prior node, hard reset to that prior mean measurement.
    # ------------------------------------------------------------
    mu = np.zeros((N, 3), dtype=float)
    mu[0] = np.array([0.0, 0.0, 0.0], dtype=float)

    # If node 0 has a stored prior measurement, start from it instead of the origin
    if 0 in strong_prior_meas:
        mu[0] = strong_prior_meas[0].copy()

    for i in range(N - 1):
        if (i, i + 1) in odom_meas:
            mu[i + 1] = compose_SE2(mu[i], odom_meas[(i, i + 1)])
        else:
            # missing odom: hold last pose (still not GT)
            mu[i + 1] = mu[i].copy()

        # reset at strong prior node
        if (i + 1) in strong_prior_meas:
            mu[i + 1] = strong_prior_meas[i + 1].copy()

    for i in range(N):
        initial.insert(key(i), gtsam.Pose2(float(mu[i, 0]), float(mu[i, 1]), float(mu[i, 2])))

    return graph, initial



# -----------------------
# Solve once with LM (batch)
# -----------------------
def solve_lm_once(graph, initial, max_iters=100, rel_error_tol=1e-9, abs_error_tol=1e-9, lambda_initial=1e-3):
    """
    Run one batch Levenberg-Marquardt solve and time it.

    The optimizer is configured and constructed outside the timed region;
    only ``optimize()`` is measured.

    Returns
    -------
    (result, seconds)
        The value returned by ``optimize()`` and the wall-clock duration of
        that call measured with ``time.perf_counter``.
    """
    lm_params = gtsam.LevenbergMarquardtParams()
    lm_params.setVerbosityLM("SILENT")  # keep logging out of the timing
    lm_params.setMaxIterations(int(max_iters))
    lm_params.setRelativeErrorTol(float(rel_error_tol))
    lm_params.setAbsoluteErrorTol(float(abs_error_tol))
    lm_params.setlambdaInitial(float(lambda_initial))

    solver = gtsam.LevenbergMarquardtOptimizer(graph, initial, lm_params)

    started = time.perf_counter()
    solution = solver.optimize()
    finished = time.perf_counter()
    return solution, (finished - started)


def batch_benchmark(
    N=5000,
    step_size=25,
    loop_prob=0.05,
    loop_radius=50,
    prior_prop=0.0,
    prior_sigma=1.0,
    odom_sigma=1.0,
    loop_sigma=1.0,
    theta_ratio=0.01,
    seed=2001,
    runs=5,
    warmup=1,
):
    """
    Generate a synthetic SE(2) pose graph, solve it repeatedly with batch LM,
    and print and return timing statistics.

    The graph and initial estimate are built once (noise and anchor enabled)
    and reused for every solve. After the timed runs a per-iteration LM trace
    is printed via ``lm_iter_timing_print``.

    Parameters
    ----------
    N, step_size, loop_prob, loop_radius, prior_prop, seed
        Forwarded to ``make_slam_like_graph``.
    prior_sigma, odom_sigma, loop_sigma, theta_ratio
        Forwarded to ``build_gtsam_pose2_graph`` (which also receives ``seed``).
    runs : int
        Number of timed solves; must be at least 1.
    warmup : int
        Number of untimed solves executed first; values below 1 mean none.

    Returns
    -------
    dict
        ``build_s``, ``solve_s_mean/std/min/max`` (seconds), ``final_error``,
        ``edges``, ``priors``, ``GT`` (ground-truth poses) and ``result``
        (the values returned by the last timed solve).

    Raises
    ------
    ValueError
        If ``runs`` is smaller than 1.
    """
    # Fail fast: with no timed run there is no result to evaluate, and the
    # statistics below would be computed on an empty array.
    n_runs = int(runs)
    if n_runs < 1:
        raise ValueError("runs must be >= 1 (at least one timed solve is required)")

    # generate GT + clean graph topology
    GT, edges, prior_ids = make_slam_like_graph(
        N=N,
        step_size=step_size,
        loop_prob=loop_prob,
        loop_radius=loop_radius,
        prior_prop=prior_prop,
        seed=seed,
    )

    # Build once; the build time is measured separately from the solve time.
    t_build0 = time.perf_counter()
    graph, initial = build_gtsam_pose2_graph(
        GT, edges, prior_ids,
        prior_sigma=prior_sigma,
        odom_sigma=odom_sigma,
        loop_sigma=loop_sigma,
        theta_ratio=theta_ratio,
        seed=seed,
        add_measurement_noise=True,
        add_prior_noise=True,
        anchor_strong=True,
    )
    t_build1 = time.perf_counter()

    # Warmup solve(s) to avoid first-call overhead in timing (allocations, etc.)
    for _ in range(max(0, int(warmup))):
        solve_lm_once(graph, initial, max_iters=100)

    # Timed runs
    times = []
    last_result = None
    for _ in range(n_runs):
        last_result, dt = solve_lm_once(graph, initial, max_iters=200)
        times.append(dt)

    times = np.array(times, dtype=float)

    # Basic diagnostics: final error
    final_error = graph.error(last_result)

    print("=== GTSAM Pose2 Batch LM Benchmark ===")
    print(f"N={N}, edges={len(edges)} (odom+loop), priors={len(prior_ids)}")
    print(f"sigmas: prior={prior_sigma}, odom={odom_sigma}, loop={loop_sigma}, theta_ratio={theta_ratio}")
    print(f"build_time: {(t_build1 - t_build0)*1000:.3f} ms")
    print(f"solve_time over {runs} runs (after warmup={warmup}):")
    print(f"  mean = {times.mean()*1000:.3f} ms")
    print(f"  std  = {times.std()*1000:.3f} ms")
    print(f"  min  = {times.min()*1000:.3f} ms")
    print(f"  max  = {times.max()*1000:.3f} ms")
    print(f"final_graph_error = {final_error:.6g}")

    # Per-iteration trace on the same problem (its return value is not used here).
    lm_iter_timing_print(graph, initial, max_iters=200)
    return {
        "build_s": float(t_build1 - t_build0),
        "solve_s_mean": float(times.mean()),
        "solve_s_std": float(times.std()),
        "solve_s_min": float(times.min()),
        "solve_s_max": float(times.max()),
        "final_error": float(final_error),
        "edges": int(len(edges)),
        "priors": int(len(prior_ids)),
        "GT": GT,
        "result": last_result,
    }


import time
import numpy as np
import gtsam
from gtsam import symbol

def lm_iter_timing_print(
    graph,
    initial,
    max_iters=200,
    lambda_initial=1e-3,
    abs_tol=1e-9,          # absolute error improvement threshold
    rel_tol=1e-9,          # relative error improvement threshold
    stall_iters=3,         # require k consecutive "small improvements" before stopping
    min_iters=1,           # force at least this many iterations
    print_every=1,
):
    """
    Run LM manually with per-iteration timing and early stopping based on error improvement.

    Stopping rule (after min_iters):
      let e_prev, e_now
      improvement = e_prev - e_now

      stop if improvement <= abs_tol  OR  improvement/|e_prev| <= rel_tol
      for `stall_iters` consecutive iterations.

    Parameters
    ----------
    graph, initial
        GTSAM factor graph and initial values handed to the optimizer.
    max_iters : int
        Upper bound on the number of ``iterate()`` calls.
    lambda_initial : float
        Initial LM damping.
    abs_tol, rel_tol, stall_iters, min_iters
        Parameters of the stopping rule above.
    print_every : int
        Print one line every ``print_every`` iterations; ``0``, ``None`` or a
        negative value disables the per-iteration log.

    Returns
    -------
    (values, times_ms, errors)
        Optimizer values after the last iteration, per-iteration wall time in
        milliseconds, and the error history (initial error first).

    Notes
    -----
    - This mimics LM termination on the Python side when verbosity is not available.
    - If an iteration increases error (rejected step / damping behavior), improvement becomes negative;
      that counts as "stall" and accelerates termination unless you change the logic below.
    """
    # Coerce once so range() and setMaxIterations() see the same integer.
    max_iters = int(max_iters)
    # A non-positive stride means "no per-iteration log" (and avoids a modulo by zero).
    log_stride = int(print_every) if print_every else 0

    params = gtsam.LevenbergMarquardtParams()
    # verbosity in python often prints nothing; keep it but don't rely on it
    params.setVerbosityLM("SILENT")
    params.setlambdaInitial(float(lambda_initial))
    params.setMaxIterations(max_iters)
    # Stopping is decided by the stall test below. The optimizer's own
    # tolerances are set very small (1e-10, not exactly zero); presumably they
    # are only consulted by optimize(), not by iterate() — TODO confirm.
    params.setRelativeErrorTol(1e-10)
    params.setAbsoluteErrorTol(1e-10)

    opt = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)

    ts = []
    errors = []

    e_prev = float(opt.error())
    errors.append(e_prev)

    stall_count = 0

    for k in range(max_iters):
        t0 = time.perf_counter()
        opt.iterate()  # 1 iter = relinearize + linear solve + retract (+ lambda update)
        t1 = time.perf_counter()

        dt_ms = (t1 - t0) * 1000.0
        ts.append(dt_ms)

        e_now = float(opt.error())
        errors.append(e_now)

        improvement = e_prev - e_now
        # The tiny offset guards against division by zero when e_prev == 0.
        rel_impr = improvement / (abs(e_prev) + 1e-30)

        if log_stride > 0 and (k % log_stride) == 0:
            print(
                f"[TIMING] iter {k:03d}: {dt_ms:8.3f} ms | "
                f"error={e_now:.6g} | dE={improvement:.3e} | rel_dE={rel_impr:.3e}"
            )

        # --- early stop logic ---
        if (k + 1) >= min_iters:
            # define "stalled" as very small improvement (or worse)
            stalled = (improvement <= abs_tol) or (rel_impr <= rel_tol)
            if stalled:
                stall_count += 1
            else:
                stall_count = 0

            if stall_count >= stall_iters:
                print(
                    f"[STOP] stalled for {stall_iters} iters "
                    f"(abs_tol={abs_tol:.1e}, rel_tol={rel_tol:.1e}) at iter {k:03d}."
                )
                break

        e_prev = e_now

    return opt.values(), np.array(ts, dtype=float), np.array(errors, dtype=float)


if __name__ == "__main__":
    # Benchmark configuration; the result dict is kept in `stats` for later cells.
    # NOTE(review): prior_prop is 0.0 here (only the x0 anchor), while the GBP
    # cells in this notebook use prior_prop=0.02 — confirm the two setups are
    # meant to differ before comparing their errors.
    stats = batch_benchmark(
        N=512,
        step_size=25,
        loop_prob=0.05,
        loop_radius=50,
        prior_prop=0.0,
        prior_sigma=1,
        odom_sigma=1,
        loop_sigma=10,
        theta_ratio=0.01,   # heading sigma = theta_ratio * position sigma (see sigma_vec)
        seed=2001,
        runs=10,
        warmup=1,
    )




=== GTSAM Pose2 Batch LM Benchmark ===
N=512, edges=578 (odom+loop), priors=1
sigmas: prior=1, odom=1, loop=10, theta_ratio=0.01
build_time: 13.919 ms
solve_time over 10 runs (after warmup=1):
  mean = 31.032 ms
  std  = 2.454 ms
  min  = 28.449 ms
  max  = 35.544 ms
final_graph_error = 97.0286
[TIMING] iter 000:    3.262 ms | error=103.287 | dE=2.308e+06 | rel_dE=1.000e+00
[TIMING] iter 001:    2.892 ms | error=98.4026 | dE=4.885e+00 | rel_dE=4.729e-02
[TIMING] iter 002:    3.096 ms | error=97.6151 | dE=7.875e-01 | rel_dE=8.003e-03
[TIMING] iter 003:    2.816 ms | error=97.302 | dE=3.131e-01 | rel_dE=3.208e-03
[TIMING] iter 004:    2.817 ms | error=97.1343 | dE=1.678e-01 | rel_dE=1.724e-03
[TIMING] iter 005:    2.817 ms | error=97.0288 | dE=1.055e-01 | rel_dE=1.086e-03
[TIMING] iter 006:    3.750 ms | error=97.0286 | dE=1.687e-04 | rel_dE=1.738e-06
[TIMING] iter 007:    3.203 ms | error=97.0286 | dE=3.486e-06 | rel_dE=3.593e-08
[TIMING] iter 008:    3.209 ms | error=97.0286 | dE=1.785

In [None]:
def values_to_numpy_pose2(values: gtsam.Values, N: int):
    """
    Extract the Pose2 entries ``x0 .. x{N-1}`` from a GTSAM ``Values`` object.

    Returns a float array of shape ``(N, 3)`` whose row ``i`` is
    ``[x, y, theta]`` of the pose stored under ``symbol("x", i)``.
    """
    poses = np.zeros((N, 3), dtype=float)
    for idx in range(N):
        p = values.atPose2(symbol("x", idx))
        poses[idx, :] = (p.x(), p.y(), p.theta())
    return poses

# Take the pose count from the benchmark's own ground truth rather than from a
# global N that happens to be left over from another cell.
mu_opt = values_to_numpy_pose2(stats["result"], stats["GT"].shape[0])


array([67.03366727, 11.78831146,  1.19792441])

In [29]:
# Spot check for one pose: ground truth vs. the GTSAM LM estimate vs. the GBP belief.
i= 11
# Ground-truth pose stored on the GBP variable node
print(basegraph.var_nodes[i].GT)
# GTSAM LM estimate for the same index
print(mu_opt[i])
# GBP belief mean, obtained by solving lam @ mu = eta
print(np.linalg.solve(basegraph.var_nodes[i].belief.lam,basegraph.var_nodes[i].belief.eta))

[40.14694024  0.64364101 -2.83206372]
[44.79787344  5.60984232 -2.87827601]
[44.37075522 -1.82902734 -2.8371492 ]


In [34]:
# Half the summed squared pose error (x, y, theta) of the GTSAM LM solution
# against the ground truth stored on the GBP variable nodes.
total = 0.0
for i in range(len(basegraph.var_nodes)):
    gt = np.asarray(basegraph.var_nodes[i].GT[0:3], dtype=float)
    r = mu_opt[i][0:3] - gt
    # Wrap the heading residual so that angles on either side of +/-pi are
    # not counted as an error of about 2*pi.
    r[2] = np.arctan2(np.sin(r[2]), np.cos(r[2]))
    total += 0.5 * float(r.T @ r)

print(total)

2736571.9371848297


In [35]:
# Half the summed squared pose error (x, y, theta) of the GBP belief means
# against the ground truth stored on the same variable nodes.
total = 0.0
for i in range(len(basegraph.var_nodes)):
    gt = np.asarray(basegraph.var_nodes[i].GT[0:3], dtype=float)
    # Belief mean from the information form: solve lam @ mu = eta.
    mu = np.linalg.solve(basegraph.var_nodes[i].belief.lam,basegraph.var_nodes[i].belief.eta)
    r = mu[0:3] - gt
    # Wrap the heading residual so that angles on either side of +/-pi are
    # not counted as an error of about 2*pi.
    r[2] = np.arctan2(np.sin(r[2]), np.cos(r[2]))
    total += 0.5 * float(r.T @ r)
print(total)

15793.554749407427


array([63.95886945,  8.25889067,  1.23746235])