In [2]:
import numpy as np

def r_of_theta(theta, dp):
    """Apply the transposed 2-D rotation matrix R(theta)^T to ``dp``.

    Args:
        theta: Rotation angle in radians.
        dp: Array multiplied on the left by R(theta)^T (a 2-vector in this notebook).

    Returns:
        The product ``R(theta)^T @ dp``.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # R(theta) = [[c, -s], [s, c]]; its transpose flips the sign of the s entries.
    rot_transposed = np.array([[cos_t, sin_t],
                               [-sin_t, cos_t]])
    return np.matmul(rot_transposed, dp)

# Sanity check of d r / d theta for r(theta) = R(theta)^T dp: compare a
# finite-difference estimate against two candidate closed forms.
theta = 0.7
dp = np.array([1.2, -0.8])

# Numerical derivative (central difference)
eps = 1e-7
rn = r_of_theta(theta + eps, dp)  # r evaluated at theta + eps
rp = r_of_theta(theta - eps, dp)  # r evaluated at theta - eps
dr_num = (rn - rp) / (2*eps)

# Correct analytic derivative: -J r = [ r_y, -r_x ]
r = r_of_theta(theta, dp)
dr_correct = np.array([ r[1], -r[0] ])

# Old formulation (wrong): +J r = [-r_y, r_x]
dr_old = np.array([-r[1],  r[0]])

# Report each candidate next to its distance from the numerical derivative.
print("numeric     :", dr_num)
print("correct (-J r):", dr_correct, "  |err|=", np.linalg.norm(dr_correct - dr_num))
print("old (+J r)    :", dr_old,     "  |err|=", np.linalg.norm(dr_old - dr_num))


numeric     : [-1.38493497 -0.40243648]
correct (-J r): [-1.38493497 -0.40243647]   |err|= 3.058302616545954e-10
old (+J r)    : [1.38493497 0.40243647]   |err|= 2.884441020424412


In [36]:
def _fill_empty_clusters(assign, k):
    """Make sure each of the ``k`` clusters owns at least one point.

    For every cluster without members, the lowest-index point of the currently
    largest cluster is reassigned to it. ``assign`` is modified in place and
    also returned.
    """
    counts = np.bincount(assign, minlength=k)
    for empty_ci in np.where(counts == 0)[0]:
        donor = np.argmax(counts)
        stolen = np.where(assign == donor)[0][0]
        assign[stolen] = empty_ci
        counts[donor] -= 1
        counts[empty_ci] += 1
    return assign


def fuse_to_super_kmeans(prev_nodes, prev_edges, k, layer_idx, max_iters=20, tol=1e-6, seed=0):
    """Cluster a layer's nodes with k-means and collapse each cluster into a super node.

    Args:
        prev_nodes: Sequence of node dicts; each provides ``["position"]["x"]``,
            ``["position"]["y"]``, ``["data"]["id"]`` and ``["data"]["dim"]``.
        prev_edges: Iterable of edge dicts with ``["data"]["source"]`` and
            ``["data"]["target"]``; the target ``"prior"`` marks a prior edge.
        k: Requested number of clusters. Values <= 0 become 1 and values larger
            than the number of nodes are clamped to it.
        layer_idx: Stored as ``"layer"`` in every super node.
        max_iters: Maximum number of Lloyd iterations.
        tol: Iteration stops once the largest centre movement is below ``tol``.
        seed: Seed of the generator that picks the initial centres.

    Returns:
        ``(super_nodes, super_edges, node_map)``. ``node_map`` maps every previous
        node id to the id of its super node (``"0"`` .. ``str(k - 1)``). An edge
        inside one cluster, or a prior edge, becomes a single
        ``source -> "prior"`` edge of that super node; edges between two clusters
        are de-duplicated regardless of direction. With no input nodes the
        result is ``([], [], {})``.
    """
    # Nothing to cluster: return an empty super layer instead of failing on
    # an empty position array.
    if len(prev_nodes) == 0:
        return [], [], {}

    positions = np.array(
        [[node["position"]["x"], node["position"]["y"]] for node in prev_nodes],
        dtype=float,
    )
    n = positions.shape[0]
    if k <= 0:
        k = 1
    k = min(k, n)
    rng = np.random.default_rng(seed)

    # Initialisation: draw k distinct points so every cluster starts on its own point.
    init_idx = rng.choice(n, size=k, replace=False)
    centers = positions[init_idx]  # fancy indexing copies, so positions is not mutated

    # Lloyd iterations
    for _ in range(max_iters):
        d2 = ((positions[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        assign = np.argmin(d2, axis=1)
        _fill_empty_clusters(assign, k)

        moved = 0.0
        for ci in range(k):
            new_c = positions[assign == ci].mean(axis=0)
            moved = max(moved, float(np.linalg.norm(new_c - centers[ci])))
            centers[ci] = new_c
        if moved < tol:
            break

    # Final assignment against the last centres (also covers max_iters == 0).
    d2 = ((positions[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    assign = _fill_empty_clusters(np.argmin(d2, axis=1), k)

    # ---------- build the super graph ----------
    super_nodes, node_map = [], {}
    for ci in range(k):
        idxs = np.where(assign == ci)[0]
        mean_x, mean_y = positions[idxs].mean(axis=0)
        child_dims = [prev_nodes[i]["data"]["dim"] for i in idxs]
        nid = f"{ci}"
        super_nodes.append({
            "data": {"id": nid, "layer": layer_idx, "dim": int(max(1, sum(child_dims)))},
            "position": {"x": float(mean_x), "y": float(mean_y)}
        })
        for i in idxs:
            node_map[prev_nodes[i]["data"]["id"]] = nid

    super_edges, seen = [], set()
    for e in prev_edges:
        u, v = e["data"]["source"], e["data"]["target"]
        su = node_map[u]
        if v == "prior" or node_map[v] == su:
            # Prior edges and edges internal to one cluster both collapse to
            # a single prior edge of that super node.
            target, key = "prior", (su, "prior")
        else:
            target = node_map[v]
            key = tuple(sorted((su, target)))  # direction-independent de-duplication
        if key not in seen:
            seen.add(key)
            super_edges.append({"data": {"source": su, "target": target}})

    return super_nodes, super_edges, node_map


In [37]:
def fuse_to_super_grid(prev_nodes, prev_edges, gx, gy, layer_idx):
    """Bucket a layer's nodes into a ``gx`` x ``gy`` grid and collapse each occupied cell.

    Args:
        prev_nodes: Sequence of node dicts; each provides ``["position"]["x"]``,
            ``["position"]["y"]``, ``["data"]["id"]`` and ``["data"]["dim"]``.
        prev_edges: Iterable of edge dicts with ``["data"]["source"]`` and
            ``["data"]["target"]``; the target ``"prior"`` marks a prior edge.
        gx: Number of grid cells along x; values <= 0 are treated as 1.
        gy: Number of grid cells along y; values <= 0 are treated as 1.
        layer_idx: Stored as ``"layer"`` in every super node.

    Returns:
        ``(super_nodes, super_edges, node_map)``. Super nodes are numbered
        ``"0"``, ``"1"``, ... in the order their cell is first encountered and
        sit at the mean position of their members. An edge inside one cell, or
        a prior edge, becomes a single ``source -> "prior"`` edge of that super
        node; edges between cells are de-duplicated regardless of direction.
        With no input nodes the result is ``([], [], {})``.
    """
    # Nothing to bucket: min()/max() below would fail on an empty array.
    if len(prev_nodes) == 0:
        return [], [], {}

    # A non-positive grid size would make ``gx - 1`` / ``gy - 1`` negative and
    # corrupt the cell index, so fall back to a single column / row.
    if gx <= 0:
        gx = 1
    if gy <= 0:
        gy = 1

    positions = np.array(
        [[node["position"]["x"], node["position"]["y"]] for node in prev_nodes],
        dtype=float,
    )
    xmin, ymin = positions.min(axis=0)
    xmax, ymax = positions.max(axis=0)
    cell_w = (xmax - xmin) / gx
    cell_h = (ymax - ymin) / gy
    # Degenerate extent (all nodes share a coordinate): any non-zero size works.
    if cell_w == 0:
        cell_w = 1.0
    if cell_h == 0:
        cell_h = 1.0

    # cell id -> indices of the nodes that fall into it
    cell_map = {}
    for idx, (x, y) in enumerate(positions):
        # Nodes on the max boundary would land in cell gx / gy; clamp them back.
        cx = min(int((x - xmin) / cell_w), gx - 1)
        cy = min(int((y - ymin) / cell_h), gy - 1)
        cell_map.setdefault(cx + cy * gx, []).append(idx)

    super_nodes, node_map = [], {}
    for indices in cell_map.values():
        mean_x, mean_y = positions[indices].mean(axis=0)
        child_dims = [prev_nodes[i]["data"]["dim"] for i in indices]
        nid = str(len(super_nodes))
        super_nodes.append({
            "data": {"id": nid, "layer": layer_idx, "dim": int(max(1, sum(child_dims)))},
            "position": {"x": float(mean_x), "y": float(mean_y)}
        })
        for i in indices:
            node_map[prev_nodes[i]["data"]["id"]] = nid

    super_edges, seen = [], set()
    for e in prev_edges:
        u, v = e["data"]["source"], e["data"]["target"]
        su = node_map[u]
        if v == "prior" or node_map[v] == su:
            # Prior edges and edges internal to one cell both collapse to a
            # single prior edge of that super node.
            target = "prior"
        else:
            target = node_map[v]
        key = tuple(sorted((su, target)))  # direction-independent de-duplication
        if key not in seen:
            seen.add(key)
            super_edges.append({"data": {"source": su, "target": target}})

    return super_nodes, super_edges, node_map

In [38]:
def build_super_graph(layers):
    """
    Build the super factor graph from the base graph in ``layers[-2]`` and the
    super grouping in ``layers[-1]``.

    Requirements:
        layers[-2]["graph"]: the base graph, already built; its ``var_nodes`` and
            ``factors`` are read here.
        layers[-1]["nodes"], layers[-1]["edges"]: super nodes and super edges. An
            edge whose target is "prior" becomes a unary super factor; every
            other edge becomes a binary super factor between its two super nodes.
        layers[-1]["node_map"]: {base_node_id -> super_node_id}. Each base id must
            be convertible with ``int()`` (e.g. '12') and equal the ``variableID``
            of a base variable.

    Returns:
        A FactorGraph with one variable per super node (its state is the
        concatenation of the member base variables, in node_map order) and one
        factor per super edge.
    """
    from scipy.linalg import block_diag

    # ---------- fetch base & super ----------
    base_graph = layers[-2]["graph"]
    super_nodes = layers[-1]["nodes"]
    super_edges = layers[-1]["edges"]
    node_map = layers[-1]["node_map"]   # base node id (str) -> super node id

    # base: variableID -> VariableNode, used to look up dofs, mu and GT
    id2var = {vn.variableID: vn for vn in base_graph.var_nodes}

    # ---------- super_id -> [base variableID] ----------
    super_groups = {}
    for b_str, s_id in node_map.items():
        super_groups.setdefault(s_id, []).append(int(b_str))

    # ---------- per-group (start, dofs) table ----------
    # local_idx[sid][bid] = (start, dofs), total_dofs[sid] = sum(dofs)
    local_idx = {}
    total_dofs = {}
    for sid, group in super_groups.items():
        off = 0
        local_idx[sid] = {}
        for bid in group:
            d = id2var[bid].dofs
            local_idx[sid][bid] = (off, d)
            off += d
        total_dofs[sid] = off

    # ---------- create the super VariableNodes ----------
    fg = FactorGraph(nonlinear_factors=False, eta_damping=0)

    super_var_nodes = {}
    for i, sn in enumerate(super_nodes):
        sid = sn["data"]["id"]
        # A super node without mapped base nodes gets a 0-dof variable.
        dofs = total_dofs.get(sid, 0)

        v = VariableNode(i, dofs=dofs)

        # Stack the members' ground truth; zeros where it is missing or mis-sized.
        gt_vec = np.zeros(dofs)
        for bid, (st, d) in local_idx.get(sid, {}).items():
            gt_base = getattr(id2var[bid], "GT", None)
            if gt_base is None or len(gt_base) != d:
                gt_base = np.zeros(d)
            gt_vec[st:st+d] = gt_base
        v.GT = gt_vec
        # Very small prior precision with zero information vector.
        v.prior.lam = 1e-10 * np.eye(dofs, dtype=float)
        v.prior.eta = np.zeros(dofs, dtype=float)

        super_var_nodes[sid] = v
        fg.var_nodes.append(v)

    fg.n_var_nodes = len(fg.var_nodes)

    # ---------- helper: linearisation point of a group (stacked base means) ----------
    def make_linpoint_for_group(sid):
        x = np.zeros(total_dofs[sid])
        for bid, (st, d) in local_idx[sid].items():
            mu = getattr(id2var[bid], "mu", None)
            if mu is None or len(mu) != d:
                mu = np.zeros(d)
            x[st:st+d] = mu
        return x

    # ---------- super prior (in-group unary + in-group binary factors) ----------
    def make_super_prior_factor(sid, base_factors):
        idx_map = local_idx[sid]   # its keys are exactly the group's variable ids
        ncols = total_dofs[sid]

        # Keep the factors whose variables all belong to this group.
        in_group = [
            f for f in base_factors
            if all(var.variableID in idx_map for var in f.adj_var_nodes)
        ]

        def meas_fn_super_prior(x_super, *args):
            meas_fn = []
            for f in in_group:
                vids = [var.variableID for var in f.adj_var_nodes]
                # Assemble this factor's local x from the super state.
                x_loc_list = []
                for vid in vids:
                    st, d = idx_map[vid]
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn) if meas_fn else np.zeros(0)

        def jac_fn_super_prior(x_super, *args):
            Jrows = []
            for f in in_group:
                vids = [var.variableID for var in f.adj_var_nodes]
                # Local x is needed because the Jacobian may depend on the state.
                x_loc_list = []
                dims = []
                for vid in vids:
                    st, d = idx_map[vid]
                    dims.append(d)
                    x_loc_list.append(x_super[st:st+d])
                x_loc = np.concatenate(x_loc_list) if x_loc_list else np.zeros(0)

                Jloc = f.jac_fn(x_loc)
                # Scatter the column blocks of Jloc to the super variable's columns.
                row = np.zeros((Jloc.shape[0], ncols))
                c0 = 0
                for vid, d in zip(vids, dims):
                    st, _ = idx_map[vid]
                    row[:, st:st+d] = Jloc[:, c0:c0+d]
                    c0 += d

                Jrows.append(row)
            return np.vstack(Jrows) if Jrows else np.zeros((0, ncols))

        # Stack the base measurements and their precisions. With no in-group
        # factor, return empty arrays matching what meas_fn / jac_fn produce
        # (np.concatenate([]) would raise).
        z_list = [f.measurement for f in in_group]
        z_lambda_list = [f.measurement_lambda for f in in_group]
        z_super = np.concatenate(z_list) if z_list else np.zeros(0)
        z_super_lambda = block_diag(*z_lambda_list) if z_lambda_list else np.zeros((0, 0))

        return meas_fn_super_prior, jac_fn_super_prior, z_super, z_super_lambda

    # ---------- super between (binary factors crossing two groups) ----------
    def make_super_between_factor(sidA, sidB, base_factors):
        idxA, idxB = local_idx[sidA], local_idx[sidB]
        nA, nB = total_dofs[sidA], total_dofs[sidB]

        cross = []
        for f in base_factors:
            vids = [var.variableID for var in f.adj_var_nodes]
            if len(vids) != 2:
                continue
            i, j = vids
            # One side in A, the other side in B.
            if (i in idxA and j in idxB) or (i in idxB and j in idxA):
                cross.append(f)

        def meas_fn_super_between(xAB, *args):
            xA, xB = xAB[:nA], xAB[nA:]
            meas_fn = []
            for f in cross:
                i, j = [var.variableID for var in f.adj_var_nodes]
                # Preserve the base factor's own (i, j) variable order.
                if i in idxA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                x_loc = np.concatenate([xi, xj])
                meas_fn.append(f.meas_fn(x_loc))
            return np.concatenate(meas_fn)

        def jac_fn_super_between(xAB, *args):
            xA, xB = xAB[:nA], xAB[nA:]
            Jrows = []
            for f in cross:
                i, j = [var.variableID for var in f.adj_var_nodes]
                if i in idxA:
                    si, di = idxA[i]
                    sj, dj = idxB[j]
                    xi = xA[si:si+di]
                    xj = xB[sj:sj+dj]
                    left_start, right_start = si, nA + sj
                else:
                    si, di = idxB[i]
                    sj, dj = idxA[j]
                    xi = xB[si:si+di]
                    xj = xA[sj:sj+dj]
                    left_start, right_start = nA + si, sj
                x_loc = np.concatenate([xi, xj])
                Jloc = f.jac_fn(x_loc)

                # Columns of B are offset by nA in the stacked [xA, xB] state.
                row = np.zeros((Jloc.shape[0], nA + nB))
                row[:, left_start:left_start+di] = Jloc[:, :di]
                row[:, right_start:right_start+dj] = Jloc[:, di:di+dj]

                Jrows.append(row)
            return np.vstack(Jrows)

        z_list = [f.measurement for f in cross]
        z_lambda_list = [f.measurement_lambda for f in cross]
        z_super = np.concatenate(z_list)
        z_super_lambda = block_diag(*z_lambda_list)

        return meas_fn_super_between, jac_fn_super_between, z_super, z_super_lambda

    # ---------- one super factor per super edge ----------
    for e in super_edges:
        u, v = e["data"]["source"], e["data"]["target"]

        if v == "prior":
            meas_fn, jac_fn, z, z_lambda = make_super_prior_factor(u, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u]], z, z_lambda, meas_fn, jac_fn)
            f.type = "super_prior"
            lin0 = make_linpoint_for_group(u)
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)

        else:
            meas_fn, jac_fn, z, z_lambda = make_super_between_factor(u, v, base_graph.factors)
            f = Factor(len(fg.factors), [super_var_nodes[u], super_var_nodes[v]], z, z_lambda, meas_fn, jac_fn)
            f.type = "super_between"
            lin0 = np.concatenate([make_linpoint_for_group(u), make_linpoint_for_group(v)])
            f.compute_factor(linpoint=lin0, update_self=True)
            fg.factors.append(f)
            super_var_nodes[u].adj_factors.append(f)
            super_var_nodes[v].adj_factors.append(f)

    fg.n_factor_nodes = len(fg.factors)
    return fg


In [39]:
# Experiment configuration; the values are forwarded to init_layers(...) and
# build_noisy_pose_graph(...) further down in this cell.
N=50               # forwarded as init_layers(N=...) — presumably the number of poses; confirm in make_slam_like_graph
step=25            # forwarded as step_size
prob=0.1           # forwarded as loop_prob
radius=50          # forwarded as loop_radius
prior_prop=0       # forwarded as prior_prop
prior_sigma=0.01   # passed to build_noisy_pose_graph(prior_sigma=...)
odom_sigma=10      # passed to build_noisy_pose_graph(odom_sigma=...)
layers = []        # placeholder; reassigned from init_layers(...) later in this cell


# -----------------------
# Initialization & boundaries
# -----------------------
def init_layers(N=10, step_size=25, loop_prob=0.05, loop_radius=50, prior_prop=0.0):
    """Create the layer stack, which initially holds only the base layer.

    All arguments are forwarded positionally, in order, to
    ``make_slam_like_graph``.

    Returns:
        A one-element list with the dict
        ``{"name": "base", "nodes": ..., "edges": ...}``.
    """
    graph_args = (N, step_size, loop_prob, loop_radius, prior_prop)
    nodes, edges = make_slam_like_graph(*graph_args)
    base_layer = {"name": "base", "nodes": nodes, "edges": edges}
    return [base_layer]

# Build the base layer from the configuration constants of this cell.
layers = init_layers(N=N, step_size=step, loop_prob=prob, loop_radius=radius, prior_prop=prior_prop)
pair_idx = 0  # not used in this cell; presumably read elsewhere in the notebook — confirm


# Build the GBP graph for the base layer and attach it to that layer.
gbp_graph = build_noisy_pose_graph(layers[0]["nodes"], layers[0]["edges"],
                                    prior_sigma=prior_sigma,
                                    odom_sigma=odom_sigma
                                    )
layers[0]["graph"] = gbp_graph
opts=[{"label":"base","value":"base"}]  # not used in this cell; presumably UI options — confirm


last = layers[-1]
super_layer_idx = 1


# Fuse the base layer into one super layer; a 1x1 grid puts every node in the same cell.
super_nodes, super_edges, node_map = fuse_to_super_grid(last["nodes"], last["edges"], 1, 1, super_layer_idx)
# Alternative grouping (disabled): k-means with 2 clusters.
#super_nodes, super_edges, node_map = fuse_to_super_kmeans(last["nodes"], last["edges"], 2, super_layer_idx)
# Name the layer after super_layer_idx so the two cannot drift apart.
layers.append({"name":f"super{super_layer_idx}", "nodes":super_nodes, "edges":super_edges, "node_map":node_map})
# build_super_graph reads layers[-2]["graph"] (base) and the entry just appended.
layers[-1]["graph"] = build_super_graph(layers)




In [40]:
# Run synchronous iterations on the base graph (the layer below the top one)
# and print the graph energy after each iteration.
# NOTE(review): in the recorded output the energy reaches its minimum around
# iteration 13 and rises steadily from iteration 18 onward — confirm that the
# iteration is actually converging.
basegraph = layers[-2]["graph"]
for it in range(2000):
    basegraph.synchronous_iteration()
    # Energy includes both the prior terms and the factor terms.
    energy = basegraph.energy_map(include_priors=True, include_factors=True)
    print(f"Iter {it+1:03d} | Energy = {energy:.6f}")


Iter 001 | Energy = 108089.375641
Iter 002 | Energy = 97496.061196
Iter 003 | Energy = 84881.233525
Iter 004 | Energy = 73279.742772
Iter 005 | Energy = 60143.371892
Iter 006 | Energy = 47318.063302
Iter 007 | Energy = 36002.380055
Iter 008 | Energy = 27075.876182
Iter 009 | Energy = 20396.397966
Iter 010 | Energy = 14535.620767
Iter 011 | Energy = 10734.652882
Iter 012 | Energy = 8029.017172
Iter 013 | Energy = 6609.852917
Iter 014 | Energy = 6942.211186
Iter 015 | Energy = 6758.090473
Iter 016 | Energy = 6691.763564
Iter 017 | Energy = 6683.746428
Iter 018 | Energy = 6713.574141
Iter 019 | Energy = 6780.943638
Iter 020 | Energy = 6812.682685
Iter 021 | Energy = 6841.707528
Iter 022 | Energy = 6864.746327
Iter 023 | Energy = 6896.602211
Iter 024 | Energy = 6918.260058
Iter 025 | Energy = 6936.846105
Iter 026 | Energy = 6954.346905
Iter 027 | Energy = 6974.851936
Iter 028 | Energy = 6993.889378
Iter 029 | Energy = 7012.928016
Iter 030 | Energy = 7033.135675
Iter 031 | Energy = 7051.634

In [41]:
# Run synchronous iterations on the graph of the top layer.
supergraph = layers[-1]["graph"]

# The triple-quoted block below is a bare string literal, i.e. disabled code
# that is never executed. It would split the factors into "super_prior" and
# all other factors and collect the variables adjacent to each set.
"""
# 找到所有 super_prior factors
prior_factors = [f for f in supergraph.factors if getattr(f, "type", "") == "super_prior"]

# 对 super_prior factors 的邻居变量
prior_vars = []
for f in prior_factors:
    for v in f.adj_var_nodes:
        if v not in prior_vars:
            prior_vars.append(v)

# 找到所有 super_between factors
between_factors = [f for f in supergraph.factors if getattr(f, "type", "") != "super_prior"]

# 对 super_between factors 的邻居变量
between_vars = []
for f in between_factors:
    for v in f.adj_var_nodes:
        if v not in between_vars:
            between_vars.append(v)
"""

# One synchronous iteration per pass, then print the energy (priors + factors).
# NOTE(review): the recorded output shows the same energy (8041.390395) on
# every iteration — confirm whether iterating this graph is expected to
# change anything.
for it in range(1000):
    supergraph.synchronous_iteration()
    energy = supergraph.energy_map(include_priors=True, include_factors=True)
    print(f"Iter {it+1:03d} | Energy = {energy:.6f}")

Iter 001 | Energy = 8041.390395
Iter 002 | Energy = 8041.390395
Iter 003 | Energy = 8041.390395
Iter 004 | Energy = 8041.390395
Iter 005 | Energy = 8041.390395
Iter 006 | Energy = 8041.390395
Iter 007 | Energy = 8041.390395
Iter 008 | Energy = 8041.390395
Iter 009 | Energy = 8041.390395
Iter 010 | Energy = 8041.390395
Iter 011 | Energy = 8041.390395
Iter 012 | Energy = 8041.390395
Iter 013 | Energy = 8041.390395
Iter 014 | Energy = 8041.390395
Iter 015 | Energy = 8041.390395
Iter 016 | Energy = 8041.390395
Iter 017 | Energy = 8041.390395
Iter 018 | Energy = 8041.390395
Iter 019 | Energy = 8041.390395
Iter 020 | Energy = 8041.390395
Iter 021 | Energy = 8041.390395
Iter 022 | Energy = 8041.390395
Iter 023 | Energy = 8041.390395
Iter 024 | Energy = 8041.390395
Iter 025 | Energy = 8041.390395
Iter 026 | Energy = 8041.390395
Iter 027 | Energy = 8041.390395
Iter 028 | Energy = 8041.390395
Iter 029 | Energy = 8041.390395
Iter 030 | Energy = 8041.390395
Iter 031 | Energy = 8041.390395
Iter 032