In [2]:
import sys
import os

# Add the parent of the notebook's working directory to the import path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from gbp.gbp import *
from gbp.factor import *
from gbp.grid import *

# Grid dimensions passed to the data generator and the coarse-graph builder.
H = 16
W = 16

# Noise standard deviations. Each value is bound to two names because the
# data generator and the graph builder take differently named keywords.
prior_std = 1
prior_noise_std = prior_std
odom_std = 0.01
odom_noise_std = odom_std

# Seed for the data generator and the iteration count setting.
seed = 0
num_iters = 100

In [14]:
# Generate a noisy H x W grid pose-SLAM problem and build the fine-level graph.
positions, prior_meas, between_meas = generate_grid_slam_data(
    H=H, W=W,
    prior_noise_std=prior_noise_std, odom_noise_std=odom_noise_std, seed=seed,
)
varis, prior_facs, between_facs = build_pose_slam_graph(
    N=H * W,  # derived from the grid size; equals the former hard-coded 256 for 16 x 16
    prior_meas=prior_meas, between_meas=between_meas,
    prior_std=prior_std, odom_std=odom_std,
    Ni_v=10, D=2,
)


# Build the coarse-level pose SLAM grid from the fine-level factors.
varis_sup, prior_facs_sup, horizontal_facs_sup, vertical_facs_sup = build_coarse_slam_graph(
    prior_facs_fine=prior_facs,
    between_facs_fine=between_facs,
    H=H, W=W,
    stride=2,
)


# Step 1: variable update (yields variable-to-factor messages and linearization points).
varis_sup, vtof_msgs, linpoints = update_variable(varis_sup)

# Step 2: factor update for the coarse priors only.
# The result is bound back to `prior_facs_sup`. Binding it to `prior_facs`
# (as before) overwrote the fine-level prior factors and left the coarse
# priors used by later cells un-updated.
prior_facs_sup, varis_sup = update_factor(prior_facs_sup, varis_sup, vtof_msgs, linpoints, h3_fn, l2)

# Step 3: variable update so the between factors see the prior information.
varis_sup, vtof_msgs, linpoints = update_variable(varis_sup)


# Step 4: factor update for the horizontal and vertical between factors.
horizontal_facs_sup, varis_sup = update_factor(horizontal_facs_sup, varis_sup, vtof_msgs,
                                               linpoints, h4_fn, l2)

vertical_facs_sup, varis_sup = update_factor(vertical_facs_sup, varis_sup, vtof_msgs,
                                             linpoints, h5_fn, l2)


# Step 5: final variable update.
varis_sup, vtof_msgs, linpoints = update_variable(varis_sup)


In [None]:
@jax.jit
def h3_fn_tilde(x, B):
    """
    Measurement prediction h(x) for the abstracted prior factor.

    The abstract state is lifted with ``B`` into a stacked vector that is read
    as four consecutive 2D nodes ``x0..x3``; the prediction is those four nodes
    followed by the four differences between them.

    Input:
        x: abstract state, flattened before use (documented as (2,))
        B: lifting matrix (documented as (8, 2)) mapping the abstract state
           to the 8D stacked vector
    Output:
        z_hat: (16,) = [x0, x1, x2, x3, x1-x0, x2-x0, x3-x1, x3-x2]
    """
    lifted = B @ x.reshape(-1)

    # Four 2D nodes stored back-to-back in the lifted vector.
    nodes = [lifted[2 * n:2 * n + 2] for n in range(4)]

    # Internal between measurements as (source, target) node pairs.
    edges = ((0, 1), (0, 2), (1, 3), (2, 3))
    diffs = [nodes[dst] - nodes[src] for src, dst in edges]

    # Priors first, then the internal betweens.
    return jnp.concatenate(nodes + diffs)


@jax.jit
def h4_fn_tilde(xs, Bs):
    """
    Measurement prediction for the abstracted horizontal between factor.

    Args:
        xs: abstract states of the two adjacent variables; xs[0] is
            variable i and xs[1] is variable j.
        Bs: lifting matrices; Bs[0] @ xs[0] and Bs[1] @ xs[1] give the
            stacked 8D vectors of variables i and j.

    Uses the v01→v02 and v11→v12 edges.

    Returns:
        z_hat: shape (4,) = two 2D relative positions
    """
    # Lift both abstract states to their stacked 8D form.
    lifted_i = Bs[0] @ xs[0]
    lifted_j = Bs[1] @ xs[1]

    # v02 - v01: first node of j minus second node of i.
    top_edge = lifted_j[0:2] - lifted_i[2:4]
    # v12 - v11: third node of j minus fourth node of i.
    bottom_edge = lifted_j[4:6] - lifted_i[6:8]

    return jnp.concatenate([top_edge, bottom_edge])


@jax.jit
def h5_fn_tilde(xs, Bs):
    """
    Measurement prediction for the abstracted vertical between factor.

    Args:
        xs: abstract states of the two adjacent variables; xs[0] is
            variable i and xs[1] is variable j.
        Bs: lifting matrices; Bs[0] @ xs[0] and Bs[1] @ xs[1] give the
            stacked 8D vectors of variables i and j.

    Uses the v10→v20 and v11→v21 edges.

    Returns:
        z_hat: shape (4,) = two 2D relative positions
    """
    # Lift both abstract states to their stacked 8D form.
    lifted_i = Bs[0] @ xs[0]
    lifted_j = Bs[1] @ xs[1]

    # v20 - v10: first node of j minus third node of i.
    left_edge = lifted_j[0:2] - lifted_i[4:6]
    # v21 - v11: second node of j minus fourth node of i.
    right_edge = lifted_j[2:4] - lifted_i[6:8]

    return jnp.concatenate([left_edge, right_edge])

(32,)

In [None]:
def build_abs_slam_graph(
    varis_sup: Variable,
    prior_facs_sup: Factor,
    horizontal_facs_sup: Factor,
    vertical_facs_sup: Factor,
) -> Tuple[Variable, Factor, Factor, Factor]:
    """
    Build the abstraction-level graph from the coarse-level graph.

    For every coarse variable, the belief mean and covariance are projected
    onto the k = 2 eigenvectors of the belief precision (inverse covariance)
    that have the largest eigenvalues. The factors are copied from the coarse
    graph with their ``potential`` reset to ``None``.

    Args:
        varis_sup: coarse-level variables; ``belief.mu()`` / ``belief.sigma()``
            are indexed per variable.
        prior_facs_sup: coarse-level prior factors.
        horizontal_facs_sup: coarse-level horizontal between factors.
        vertical_facs_sup: coarse-level vertical between factors.

    Returns:
        ``(varis_abs, prior_facs_abs, horizontal_facs_abs, vertical_facs_abs)``.

    NOTE(review): the per-variable projection bases ``B_k`` are neither stored
    nor returned, although ``h3_fn_tilde`` / ``h4_fn_tilde`` / ``h5_fn_tilde``
    take them as input — callers currently have to recompute them.
    """
    k = 2  # dimension of the abstract state

    def _without_potential(fac):
        # Copy a factor's measurement and connectivity, dropping its potential.
        return Factor(
            factor_id=fac.factor_id,
            z=fac.z,
            z_Lam=fac.z_Lam,
            threshold=fac.threshold,
            potential=None,
            adj_var_id=fac.adj_var_id,
            adj_var_idx=fac.adj_var_idx,
        )

    abs_beliefs = []
    abs_msgs_eta = []
    abs_msgs_Lam = []

    # === 1. Build Abstraction Variables ===
    varis_sup_mu = varis_sup.belief.mu()
    varis_sup_sigma = varis_sup.belief.sigma()

    for i in range(len(varis_sup.var_id)):
        varis_sup_mu_i = varis_sup_mu[i]
        varis_sup_sigma_i = varis_sup_sigma[i]

        # Eigen-decomposition of the precision matrix (inverse covariance).
        eigvals, eigvecs = np.linalg.eigh(np.linalg.inv(varis_sup_sigma_i))

        # eigh returns ascending eigenvalues; keep the k eigenvectors with the
        # largest precision eigenvalues as the columns of the basis.
        idx = np.argsort(eigvals)[::-1]
        B_k = eigvecs[:, idx[:k]]

        # Project the mean and covariance into the k-dimensional subspace.
        varis_abs_mu_i = B_k.T @ varis_sup_mu_i
        varis_abs_sigma_i = B_k.T @ varis_sup_sigma_i @ B_k

        # Convert the projected moments to natural parameters (eta, Lam).
        varis_abs_lam_i = jnp.linalg.inv(varis_abs_sigma_i)
        varis_abs_eta_i = varis_abs_lam_i @ varis_abs_mu_i
        abs_beliefs.append(Gaussian(varis_abs_eta_i, varis_abs_lam_i))

        # NOTE(review): messages are still allocated with the 8-D coarse size
        # while the belief above is k-D — confirm the intended message shape.
        abs_msgs_eta.append(jnp.zeros((8)))
        abs_msgs_Lam.append(jnp.zeros((8, 8)))

    varis_abs = Variable(
        var_id=varis_sup.var_id,
        belief=tree_stack(abs_beliefs, axis=0),
        msgs=Gaussian(jnp.stack(abs_msgs_eta), jnp.stack(abs_msgs_Lam)),
        adj_factor_idx=jnp.stack(varis_sup.adj_factor_idx),
    )

    # === 2. Build abstraction-level prior and between factors ===
    prior_facs_abs = _without_potential(prior_facs_sup)
    horizontal_facs_abs = _without_potential(horizontal_facs_sup)
    vertical_facs_abs = _without_potential(vertical_facs_sup)

    return varis_abs, prior_facs_abs, horizontal_facs_abs, vertical_facs_abs

In [20]:
# Derive the abstraction-level variables and factors from the coarse-level graph.
varis_abs, prior_facs_abs, horizontal_facs_abs, vertical_facs_abs = build_abs_slam_graph(
    varis_sup, prior_facs_sup, horizontal_facs_sup, vertical_facs_sup
)

In [None]:
def _coarse_graph_energy(prior_facs, horizontal_between_facs, vertical_between_facs,
                         linpoints, prior_h, between_h):
    """Sum ``factor_energy`` over prior, horizontal and vertical factors at ``linpoints``."""
    batched_energy = jax.vmap(factor_energy, in_axes=(0, 0, None))

    prior_energy = jnp.sum(batched_energy(
        prior_facs, linpoints[prior_facs.adj_var_id[:, 0]], prior_h
    ))
    horizontal_between_energy = jnp.sum(batched_energy(
        horizontal_between_facs, linpoints[horizontal_between_facs.adj_var_id], between_h[0]
    ))
    vertical_between_energy = jnp.sum(batched_energy(
        vertical_between_facs, linpoints[vertical_between_facs.adj_var_id], between_h[1]
    ))

    return prior_energy + horizontal_between_energy + vertical_between_energy


def gbp_solve_coarse(varis, prior_facs, horizontal_between_facs, vertical_between_facs, num_iters=50, visualize=False, prior_h=h3_fn, between_h=(h4_fn, h5_fn)):
    """
    Run GBP on a graph with prior, horizontal-between and vertical-between factors.

    The first pass performs a variable update followed by a prior-only factor
    update; the remaining ``num_iters - 1`` passes each perform a variable
    update followed by updates of all three factor groups.

    Args:
        varis: variables of the graph.
        prior_facs: prior factors.
        horizontal_between_facs: horizontal between factors.
        vertical_between_facs: vertical between factors.
        num_iters: total number of passes, including the prior-only first pass.
        visualize: when True, record the total factor energy and the
            linearization points after every pass.
        prior_h: measurement function used for the prior factors.
        between_h: pair of measurement functions, indexed as
            ``between_h[0]`` (horizontal) and ``between_h[1]`` (vertical).
            The default is a tuple rather than a list to avoid a shared
            mutable default argument.

    Returns:
        ``(varis, prior_facs, horizontal_between_facs, vertical_between_facs,
        energy_log, positions_log)``; the two logs are numpy arrays and are
        empty when ``visualize`` is False.

    NOTE(review): the loop ends on a factor update, so whether the returned
    beliefs already include the last messages depends on ``update_factor`` —
    confirm, or call ``update_variable`` on the result.
    """
    energy_log = []
    positions_log = []

    def _record(linpoints):
        # Log the energy and linearization points for the current pass.
        energy_log.append(_coarse_graph_energy(
            prior_facs, horizontal_between_facs, vertical_between_facs,
            linpoints, prior_h, between_h,
        ))
        positions_log.append(linpoints)

    # Initialize variables using only the prior factors.
    varis, vtof_msgs, linpoints = update_variable(varis)
    prior_facs, varis = update_factor(prior_facs, varis, vtof_msgs, linpoints, prior_h, l2)
    if visualize:
        _record(linpoints)

    for _ in range(num_iters - 1):
        # Step 1: Variable update
        varis, vtof_msgs, linpoints = update_variable(varis)

        # Step 2: Factor update (all groups use the same messages and linpoints)
        prior_facs, varis = update_factor(prior_facs, varis, vtof_msgs, linpoints, prior_h, l2)

        horizontal_between_facs, varis = update_factor(horizontal_between_facs, varis, vtof_msgs,
                                                       linpoints, between_h[0], l2)

        vertical_between_facs, varis = update_factor(vertical_between_facs, varis, vtof_msgs,
                                                     linpoints, between_h[1], l2)

        # Step 3: Energy at the linearization points used in this pass
        if visualize:
            _record(linpoints)

    return varis, prior_facs, horizontal_between_facs, vertical_between_facs, \
            np.array(energy_log), np.array(positions_log)