In [35]:
from typing import Callable, Dict, List
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from functools import partial
import numpy as np
import time


@register_pytree_node_class
class Gaussian:
    """Gaussian in information (canonical) form.

    ``eta`` is the information vector and ``Lam`` the precision matrix, so the
    mean is ``Lam^-1 eta`` and the covariance ``Lam^-1``. Multiplying /
    dividing two densities is adding / subtracting their ``(eta, Lam)``.
    Registered as a JAX pytree with children ``(eta, Lam)`` so instances can
    pass through ``jit`` / ``vmap`` / ``scan`` and carry leading batch axes.
    """

    def __init__(self, eta, Lam):
        self.eta = eta  # information vector
        self.Lam = Lam  # precision (information) matrix

    def tree_flatten(self):
        return (self.eta, self.Lam), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def mu(self):
        """Mean ``Lam^-1 eta``; returns ``eta`` unchanged when ``Lam`` is ~0.

        "~0" means ``jnp.allclose(Lam, 0)`` with default tolerances (e.g. a
        belief that has received no information yet). ``jnp.where`` evaluates
        both branches, so the singular ``Lam`` is swapped for the identity
        before solving. The selected value is identical, but the discarded
        branch no longer produces inf/NaN that would leak into gradients and
        trip ``jax_debug_nans``.
        """
        is_zero = jnp.allclose(self.Lam, 0)
        eye = jnp.eye(self.Lam.shape[-1], dtype=self.Lam.dtype)
        safe_Lam = jnp.where(is_zero, eye, self.Lam)
        return jnp.where(is_zero, self.eta, jnp.linalg.solve(safe_Lam, self.eta))

    def sigma(self):
        """Covariance ``Lam^-1`` via an explicit inverse (``Lam`` must be non-singular)."""
        return jnp.linalg.inv(self.Lam)

    def zero_like(self):
        """Gaussian with all-zero ``eta`` and ``Lam`` of the same shapes (identity for ``*``)."""
        return Gaussian(jnp.zeros_like(self.eta), jnp.zeros_like(self.Lam))

    def __repr__(self) -> str:
        return f"Gaussian(eta={self.eta}, lam={self.Lam})"

    def __mul__(self, other):
        """Product of densities: parameters add in information form."""
        return Gaussian(self.eta + other.eta, self.Lam + other.Lam)

    def __truediv__(self, other):
        """Quotient of densities: parameters subtract in information form."""
        return Gaussian(self.eta - other.eta, self.Lam - other.Lam)

    def copy(self):
        """Copy with both parameter arrays copied."""
        return Gaussian(self.eta.copy(), self.Lam.copy())


@register_pytree_node_class
class Variable:
    """Batched GBP variable state, registered as a JAX pytree.

    All four fields are pytree children (no static aux data), so the whole
    object can be traced under ``jit``, mapped with ``vmap`` and carried
    through ``lax.scan``.
    """

    var_id: int
    belief: Gaussian
    msgs: Gaussian
    adj_factor_idx: jnp.array

    # Field order used both for flattening and for the constructor.
    _CHILDREN = ("var_id", "belief", "msgs", "adj_factor_idx")

    def __init__(self, var_id, belief, msgs, adj_factor_idx):
        self.var_id, self.belief, self.msgs, self.adj_factor_idx = (
            var_id,
            belief,
            msgs,
            adj_factor_idx,
        )

    def tree_flatten(self):
        children = tuple(getattr(self, name) for name in self._CHILDREN)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # aux_data is always None (see tree_flatten).
        return cls(*children)


@register_pytree_node_class
class Factor:
    """Batched GBP factor state, registered as a JAX pytree.

    Every field is a pytree child (no static aux data). ``adj_var_id`` and
    ``adj_var_idx`` hold, per factor and per connected variable, the variable
    id and the message-port index on that variable.
    """

    factor_id: jnp.array
    z: jnp.ndarray
    z_Lam: jnp.ndarray
    threshold: jnp.ndarray
    potential: Gaussian
    adj_var_id: jnp.array
    adj_var_idx: jnp.array

    # Field order used both for flattening and for the constructor.
    _CHILDREN = (
        "factor_id",
        "z",
        "z_Lam",
        "threshold",
        "potential",
        "adj_var_id",
        "adj_var_idx",
    )

    def __init__(
        self, factor_id, z, z_Lam, threshold, potential, adj_var_id, adj_var_idx
    ):
        self.factor_id, self.z, self.z_Lam = factor_id, z, z_Lam
        self.threshold, self.potential = threshold, potential
        self.adj_var_id, self.adj_var_idx = adj_var_id, adj_var_idx

    def tree_flatten(self):
        children = tuple(getattr(self, name) for name in self._CHILDREN)
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # aux_data is always None (see tree_flatten).
        return cls(*children)


@partial(jax.jit, static_argnames=["i", "j"])
def marginalize(gaussians: Gaussian, i, j):  # Eq. (46), (47); message to the i:j block
    """Marginalize an information-form Gaussian onto the entries ``[i, j)``.

    With ``a = [i, j)`` (kept) and ``b`` = everything else, returns the Schur
    complement::

        eta = eta_a - Lam_ab Lam_bb^-1 eta_b
        Lam = Lam_aa - Lam_ab Lam_bb^-1 Lam_ba

    Args:
        gaussians: joint Gaussian with 1-D ``eta`` and square ``Lam``.
        i, j: static Python ints delimiting the kept block.

    Returns:
        Gaussian over the kept entries. If nothing is marginalized out (``b``
        empty, e.g. a unary factor), the kept block is returned as is.
    """
    eta = gaussians.eta
    Lam = gaussians.Lam
    idx = jnp.arange(0, eta.size)
    aa = idx[i:j]  # kept indices i .. j-1
    bb = jnp.concatenate([idx[:i], idx[j:]])  # indices marginalized out
    aa_eta = eta[aa]
    bb_eta = eta[bb]
    aa_Lam = Lam[aa[:, None], aa]
    ab_Lam = Lam[aa[:, None], bb]
    ba_Lam = Lam[bb[:, None], aa]  # == ab_Lam.T when Lam is symmetric
    bb_Lam = Lam[bb[:, None], bb]
    if bb_Lam.size == 0:  # static shape check, fine under jit
        return Gaussian(aa_eta, aa_Lam)

    # One solve against [eta_b | Lam_ba] instead of forming inv(Lam_bb):
    # cheaper and numerically more accurate than an explicit inverse.
    rhs = jnp.concatenate([bb_eta[:, None], ba_Lam], axis=1)
    sol = jnp.linalg.solve(bb_Lam, rhs)
    eta_out = aa_eta - ab_Lam @ sol[:, 0]
    Lam_out = aa_Lam - ab_Lam @ sol[:, 1:]
    return Gaussian(eta_out, Lam_out)


def tree_stack(tree, axis=0, use_np=True):
    if use_np:
        return jax.tree.map(lambda *v: jnp.array(np.stack(v, axis=axis)), *tree)
    return jax.tree.map(lambda *v: jnp.stack(v, axis=axis), *tree)


def h_fn(x):
    """Identity measurement function (used for unary prior factors)."""
    measured = x
    return measured


def h2_fn(x):
    """Relative-displacement measurement ``x[1] - x[0]`` (between factors)."""
    first, second = x[0], x[1]
    return second - first


@jax.jit
def update_belief(var: Variable, ftov_msgs):  # Eq. (7)
    """Belief of one variable as the product of all its incoming messages.

    Args:
        var: a single (un-batched) variable; only the shapes of its belief
            are used.
        ftov_msgs: incoming factor-to-variable messages with a leading port
            axis.

    Returns:
        The product of the messages, in information form.

    NOTE(review): the product starts from zeros, not from ``var.belief``, so
    any initial/tiny prior stored in the belief is not included here.
    """
    acc = var.belief.zero_like()
    n_ports = ftov_msgs.eta.shape[0]
    for port in range(n_ports):
        incoming = Gaussian(ftov_msgs.eta[port], ftov_msgs.Lam[port])
        acc = acc * incoming
    return acc


@jax.jit
def compute_vtof_msgs(var: Variable, ftov_msgs):  # Eq. (19), for one variable x_m
    """Variable-to-factor messages for every port of one variable.

    The message on each port is the belief divided by the message that
    arrived on that same port. Ports whose ``adj_factor_idx`` is negative are
    padding and always send an all-zero message. ``jnp.where`` is used
    instead of a Python ``if`` because the index is a traced value.

    Returns:
        Gaussian with a leading port axis (length ``Ni_v``).
    """
    outgoing = []
    for port, fac_idx in enumerate(var.adj_factor_idx):
        incoming = Gaussian(ftov_msgs.eta[port], ftov_msgs.Lam[port])
        cavity = var.belief / incoming  # belief without this factor's own message
        unused_port = fac_idx < 0
        empty = cavity.zero_like()
        outgoing.append(
            Gaussian(
                jnp.where(unused_port, empty.eta, cavity.eta),
                jnp.where(unused_port, empty.Lam, cavity.Lam),
            )
        )
    return tree_stack(outgoing, use_np=False)  # [(x_m -> f_s1), (x_m -> f_s2), ...]


@partial(jax.jit, static_argnames=["h_fn"])
def factor_energy(factor, xs, h_fn):
    """Quadratic energy ``0.5 * r^T z_Lam r`` of one factor, ``r = z - h_fn(xs)``.

    Both ``z`` and ``h_fn(xs)`` are flattened before forming the residual,
    the same way ``factor_update`` does. Without this, a measurement function
    that keeps the ``(1, D)`` shape of a unary factor's ``xs`` (such as the
    identity ``h_fn``) produced a ``(1, 1)`` array instead of a scalar.

    Args:
        factor: object with ``z`` (measurement) and ``z_Lam`` (its precision).
        xs: stacked states of the variables the factor connects.
        h_fn: static measurement function.

    Returns:
        Scalar energy (no robust re-weighting is applied).
    """
    r = factor.z.reshape(-1) - h_fn(xs).reshape(-1)
    return 0.5 * r @ factor.z_Lam @ r


@partial(jax.jit, static_argnames=["h_fn", "w"])
def factor_update(factor, xs, h_fn, w):
    """Linearized (and robust-weighted) factor potential, Eq. (36).

    Linearizes ``h_fn`` at ``xs`` with Jacobian ``J`` and residual
    ``r = z - h(xs)``, then returns::

        Lam = s J^T z_Lam J
        eta = s J^T z_Lam (J x0 + r)

    where ``x0`` is ``xs`` flattened and ``s = w(r^T z_Lam r, threshold)`` is
    the robust weight.

    Returns:
        Gaussian over the concatenated states of the connected variables.
    """
    x0 = xs.reshape(-1)  # TODO: reshape can be problematic; xs must flatten to a vector
    h0 = h_fn(xs)
    jac = jax.jacrev(h_fn)(xs).reshape(h0.size, xs.size)  # J_s via autodiff
    z_Lam = factor.z_Lam
    r = factor.z - h0.reshape(-1)  # TODO: reshape can be problematic
    scale = w(r.T @ z_Lam @ r, factor.threshold)  # robust-loss weight
    weighted_jt = scale * jac.T
    Lam = weighted_jt @ z_Lam @ jac
    eta = weighted_jt @ z_Lam @ (jac @ x0 + r)
    return Gaussian(eta, Lam)


@jax.jit
def compute_ftov_msg(factor, vtof_msgs):  # Ch. 3.5, message passing at a factor node
    """Factor-to-variable messages for every variable attached to one factor.

    All incoming variable-to-factor messages are first absorbed into the
    factor potential, each on its own ``dim``-sized diagonal block. For each
    port, that port's incoming message is then subtracted again (Eq. 42, 43),
    and the result is marginalized onto the port's block (Eq. 46, 47).

    Args:
        factor: a single factor whose ``potential`` spans ``N_adj * dim`` entries.
        vtof_msgs: incoming messages with shapes ``(N_adj, dim)`` / ``(N_adj, dim, dim)``.

    Returns:
        Gaussian with a leading axis of length ``N_adj``.
    """
    n_ports, d = vtof_msgs.eta.shape
    joint = factor.potential.copy()
    for port in range(n_ports):
        lo, hi = port * d, (port + 1) * d
        joint.eta = joint.eta.at[lo:hi].add(vtof_msgs.eta[port])
        joint.Lam = joint.Lam.at[lo:hi, lo:hi].add(vtof_msgs.Lam[port])

    outgoing = []
    for port in range(n_ports):
        lo, hi = port * d, (port + 1) * d
        cavity = joint.copy()
        # Remove the message coming from the destination variable itself.
        cavity.eta = cavity.eta.at[lo:hi].add(-vtof_msgs.eta[port])
        cavity.Lam = cavity.Lam.at[lo:hi, lo:hi].add(-vtof_msgs.Lam[port])
        outgoing.append(marginalize(cavity, lo, hi))
    return tree_stack(outgoing, use_np=False)


@jax.jit
def update_variable(vars):
    """Variable-side half of one GBP iteration.

    Expects ``vars.msgs`` to hold the latest factor-to-variable messages.
    It then:

    1. recomputes every belief from them (Eq. 7), storing it on ``vars``;
    2. computes variable-to-factor messages (Eq. 19);
    3. takes the belief means as the new linearization points.

    Returns:
        ``(vars, vtof_msgs, linpoints)``; ``vtof_msgs`` has one message per
        (variable, port) pair.
    """
    vars.belief = jax.vmap(update_belief)(vars, vars.msgs)
    vtof_msgs = jax.vmap(compute_vtof_msgs)(vars, vars.msgs)
    linpoints = jax.vmap(Gaussian.mu)(vars.belief)  # posterior means
    return vars, vtof_msgs, linpoints


@partial(jax.jit, static_argnames=["f", "w"])
def update_factor(facs, vars, vtof_msgs, linpoints, f, w):
    """Factor-side half of one GBP iteration for a group of same-type factors.

    Args:
        facs: batched factors; ``adj_var_id`` / ``adj_var_idx`` give, for each
            factor slot, the variable id and the port on that variable.
        vars: batched variables whose incoming messages are overwritten.
        vtof_msgs: variable-to-factor messages indexed ``[var_id, port]``.
        linpoints: linearization point per variable.
        f: static measurement function.
        w: static robust weight function.

    Returns:
        ``(facs, vars)``. ``facs.potential`` holds the relinearized
        potentials; ``vars.msgs`` holds the new factor-to-variable messages,
        with the last (padding) port forced to zero.
    """
    ids, ports = facs.adj_var_id, facs.adj_var_idx
    # Gather the messages/linpoints each factor slot needs.
    incoming = jax.tree_util.tree_map(lambda leaf: leaf[ids, ports], vtof_msgs)
    lin_by_factor = jax.tree_util.tree_map(lambda leaf: leaf[ids], linpoints)

    relinearize = jax.vmap(factor_update, in_axes=(0, 0, None, None))
    facs.potential = relinearize(facs, lin_by_factor, f, w)  # f_s of Eq. (15)
    outgoing = jax.vmap(compute_ftov_msg)(facs, incoming)

    msgs = vars.msgs  # same object: updates below land on vars.msgs
    msgs.eta = msgs.eta.at[ids, ports].set(outgoing.eta)
    msgs.Lam = msgs.Lam.at[ids, ports].set(outgoing.Lam)
    # The last port is padding and must always receive a zero message.
    msgs.eta = msgs.eta.at[:, -1].set(0)
    msgs.Lam = msgs.Lam.at[:, -1].set(0)

    return facs, vars


@jax.jit
def huber(e, t):
    """Huber re-weighting factor from a squared Mahalanobis distance.

    Args:
        e: squared distance ``r^T Lam r``.
        t: threshold on the (unsquared) distance.

    Returns:
        ``1.0`` inside the threshold, ``t / sqrt(e)`` outside it.
    """
    dist = jnp.sqrt(e)
    inlier = dist <= t
    return jnp.where(inlier, 1.0, t / dist)


@jax.jit
def l2(e, _):
    """Plain least-squares weight: always 1, regardless of error or threshold."""
    weight = 1.0
    return weight


In [36]:
import jax
import jaxlib

# Report library versions and the active backend/devices for reproducibility.
versions = {
    "jax version:": jax.__version__,
    "jaxlib version:": jaxlib.__version__,
}
for label, value in versions.items():
    print(label, value)
print("backend:", jax.default_backend())
print("devices:", jax.devices())


jax version: 0.6.2
jaxlib version: 0.6.2
backend: cpu
devices: [CpuDevice(id=0)]


In [37]:
# -----------------------
# SLAM-like base graph
# -----------------------
def make_slam_like_graph(N=100, step_size=25, loop_prob=0.05, loop_radius=50, prior_prop=0.0, seed=None):
    """Generate a random 2D SLAM-like trajectory as a node/edge graph.

    The trajectory is a random walk starting at the origin whose steps have
    length ``step_size`` (up to a 1e-6 guard). Edges are
    ``{"data": {"source", "target"}}`` dicts, in this order:

      - odometry edges ``i -> i+1``;
      - loop closures ``i -> j`` for ``j >= i + 5``. Each candidate pair is
        kept with probability ``loop_prob``, and only if the two poses are
        closer than ``loop_radius``;
      - ``i -> "prior"`` edges. When ``prior_prop <= 0`` only node 0 gets
        one; when ``prior_prop >= 1`` every node does; otherwise
        ``max(1, floor(prior_prop * N))`` nodes are sampled without
        replacement;
      - a final ``0 -> "anchor"`` edge.

    All randomness comes from ``np.random.default_rng(seed)``, so a fixed
    seed reproduces the graph; ``seed=None`` uses fresh OS entropy.

    Returns:
        ``(nodes, edges)``; each node dict carries ``"position": {"x", "y"}``.
    """
    N = int(N)  # range() and rng.choice() need an int; accept float sizes too
    step_size = float(step_size)
    loop_prob = float(loop_prob)
    loop_radius = float(loop_radius)
    rng = np.random.default_rng(seed)

    # Random-walk trajectory with direction-normalized steps.
    positions = [(0.0, 0.0)]
    x, y = 0.0, 0.0
    for _ in range(1, N):
        dx, dy = rng.standard_normal(2)
        norm = np.sqrt(dx**2 + dy**2) + 1e-6  # guards a zero-length draw
        dx, dy = dx / norm * step_size, dy / norm * step_size
        x, y = x + dx, y + dy
        positions.append((x, y))

    nodes = [
        {
            "data": {"id": f"{i}", "layer": 0, "dim": 2, "num_base": 1},
            "position": {"x": float(px), "y": float(py)},
        }
        for i, (px, py) in enumerate(positions)
    ]

    # Sequential odometry edges along the path.
    edges = [{"data": {"source": f"{i}", "target": f"{i+1}"}} for i in range(N - 1)]

    # Loop closures. One vectorized draw per row consumes the RNG stream in
    # exactly the same order as the former per-pair rng.random() calls, so a
    # seed yields the same graph, but without O(N^2) Python-level RNG calls.
    for i in range(N):
        first_j = i + 5
        n_candidates = N - first_j
        if n_candidates <= 0:
            continue
        accepted = np.flatnonzero(rng.random(n_candidates) < loop_prob) + first_j
        xi, yi = positions[i]
        for j in accepted:
            j = int(j)
            xj, yj = positions[j]
            if np.hypot(xi - xj, yi - yj) < loop_radius:
                edges.append({"data": {"source": f"{i}", "target": f"{j}"}})

    # Strong priors, sampled from the same RNG stream after the loop closures.
    if prior_prop <= 0.0:
        strong_ids = {0}
    elif prior_prop >= 1.0:
        strong_ids = set(range(N))
    else:
        k = max(1, int(np.floor(prior_prop * N)))
        strong_ids = set(rng.choice(N, size=k, replace=False).tolist())

    for i in strong_ids:
        edges.append({"data": {"source": f"{i}", "target": "prior"}})

    edges.append({"data": {"source": "0", "target": "anchor"}})
    return nodes, edges

def build_posegraph_jax_from_nodes_edges(
    nodes,
    edges,
    prior_sigma: float = 10.0,
    odom_sigma: float = 10.0,
    tiny_prior: float = 1e-12,
    anchor_sigma: float = 1e-4,
    seed=None,
    device="cpu",
    robust_threshold: float = 1.0,  # kept tunable (the grid version used threshold=ones(..., 1))
):
    """Build JAX GBP variables and factors for a noisy 2D pose graph.

    JAX counterpart of the older ``build_noisy_pose_graph()``:

      - one 2D variable per entry of ``nodes``, with ground truth read from
        ``node["position"]["x"/"y"]``;
      - every edge whose target is a node id becomes a between factor with
        measurement ``(GT_j - GT_i) + N(0, odom_sigma^2 I)``;
      - every edge with target ``"prior"`` becomes a strong unary prior with
        ``z = GT_v + N(0, prior_sigma^2 I)``. Repeated priors on one variable
        are collapsed, keeping first-seen order;
      - one anchor factor is ALWAYS attached to variable 0, with
        ``z = GT_0`` (no noise) and precision ``1 / anchor_sigma^2``.
        ``"anchor"`` edges are accepted but do not change this;
      - every variable's initial belief is ``eta = 0``,
        ``Lam = tiny_prior * I``. NOTE(review): the visible ``update_belief``
        rebuilds beliefs from incoming messages only, so this initial value is
        overwritten before it is used.

    Noise comes from ``np.random.default_rng(seed)``: one D-vector per between
    factor (in edge order), then one per strong prior.

    Args:
        nodes, edges: graph in the format produced by ``make_slam_like_graph``.
        prior_sigma: std of strong-prior measurement noise.
        odom_sigma: std of between-factor measurement noise.
        tiny_prior: scale of the initial belief precision.
        anchor_sigma: std of the anchor factor on variable 0.
        seed: seed for the measurement-noise RNG.
        device: JAX platform name passed to ``jax.devices``.
        robust_threshold: value stored in every ``Factor.threshold``.

    Returns:
        ``(vars, prior_facs, between_facs)``. ``prior_facs`` holds the strong
        priors first and the anchor as its last row. Factor ids are
        ``0..P-1`` for unary factors and ``P..P+E-1`` for between factors.
        Each variable gets ``max_degree + 1`` message ports; the last port is
        always unused (``adj_factor_idx == -1``).

    Raises:
        ValueError: if ``nodes`` is empty (the anchor needs variable 0).
    """
    D = 2

    # 1) ground-truth positions
    N = len(nodes)
    if N == 0:
        raise ValueError("nodes must be non-empty: the anchor factor is attached to node 0")
    GT = np.zeros((N, D), dtype=np.float32)
    for i, n in enumerate(nodes):
        GT[i, 0] = float(n["position"]["x"])
        GT[i, 1] = float(n["position"]["y"])

    # 2) split edges into between pairs and strong priors
    binary_pairs = []
    prior_vars = []
    for e in edges:
        src = e["data"]["source"]
        dst = e["data"]["target"]
        if dst == "prior":
            prior_vars.append(int(src))
        elif dst != "anchor":
            binary_pairs.append((int(src), int(dst)))
        # "anchor" edges need no handling: the anchor on variable 0 is always added.
    prior_vars = list(dict.fromkeys(prior_vars))  # unique, order-preserving
    E = len(binary_pairs)
    P_strong = len(prior_vars)
    P = P_strong + 1  # strong priors + the anchor

    # 3) measurement noise, drawn in the same order as the old code
    rng = np.random.default_rng(seed)
    odom_noise = np.zeros((E, D), dtype=np.float32)
    for k in range(E):
        odom_noise[k] = rng.normal(0.0, odom_sigma, size=D).astype(np.float32)
    prior_noise = np.zeros((P_strong, D), dtype=np.float32)
    for k in range(P_strong):
        prior_noise[k] = rng.normal(0.0, prior_sigma, size=D).astype(np.float32)

    # 4) measurements and their precisions
    eye = np.eye(D, dtype=np.float32)

    z_b = np.zeros((E, D), dtype=np.float32)
    zLam_b = np.zeros((E, D, D), dtype=np.float32)
    inv_odom_var = 1.0 / (odom_sigma * odom_sigma)
    for k, (i, j) in enumerate(binary_pairs):
        z_b[k] = (GT[j] - GT[i]) + odom_noise[k]
        zLam_b[k] = eye * inv_odom_var

    z_p = np.zeros((P, D), dtype=np.float32)
    zLam_p = np.zeros((P, D, D), dtype=np.float32)
    inv_prior_var = 1.0 / (prior_sigma * prior_sigma)
    for k, v in enumerate(prior_vars):
        z_p[k] = GT[v] + prior_noise[k]
        zLam_p[k] = eye * inv_prior_var
    # the anchor is the last unary factor
    z_p[P_strong] = GT[0]
    zLam_p[P_strong] = eye * (1.0 / (anchor_sigma * anchor_sigma))

    # 5) globally unique factor ids: unary 0..P-1, between P..P+E-1
    prior_factor_id = np.arange(0, P, dtype=np.int32)
    between_factor_id = np.arange(P, P + E, dtype=np.int32)
    anchor_fid = int(prior_factor_id[P_strong])

    # 6) per-variable factor lists -> message ports
    #    order per variable: strong prior (if any), anchor (variable 0 only),
    #    then between factors in edge order
    var_adj = [[] for _ in range(N)]
    for k, v in enumerate(prior_vars):
        var_adj[v].append(int(prior_factor_id[k]))
    var_adj[0].append(anchor_fid)
    for k, (i, j) in enumerate(binary_pairs):
        fid = int(between_factor_id[k])
        var_adj[i].append(fid)
        var_adj[j].append(fid)

    max_deg = max(len(a) for a in var_adj)
    Ni_v = max_deg + 1  # one extra port that always stays unused (-1)
    adj_factor_idx = -np.ones((N, Ni_v), dtype=np.int32)
    port_of = {}  # (variable, factor id) -> port index on that variable
    for v, fids in enumerate(var_adj):
        for p, fid in enumerate(fids):
            adj_factor_idx[v, p] = fid
            port_of[(v, fid)] = p

    # 7) factor -> (variable id, port) adjacency
    adj_var_id_p = np.zeros((P, 1), dtype=np.int32)
    adj_var_idx_p = np.zeros((P, 1), dtype=np.int32)
    for k, v in enumerate(prior_vars):
        adj_var_id_p[k, 0] = v
        adj_var_idx_p[k, 0] = port_of[(v, int(prior_factor_id[k]))]
    adj_var_id_p[P_strong, 0] = 0
    adj_var_idx_p[P_strong, 0] = port_of[(0, anchor_fid)]

    adj_var_id_b = np.zeros((E, 2), dtype=np.int32)
    adj_var_idx_b = np.zeros((E, 2), dtype=np.int32)
    for k, (i, j) in enumerate(binary_pairs):
        fid = int(between_factor_id[k])
        adj_var_id_b[k, 0] = i
        adj_var_id_b[k, 1] = j
        adj_var_idx_b[k, 0] = port_of[(i, fid)]
        adj_var_idx_b[k, 1] = port_of[(j, fid)]

    # 8) variables: initial belief and all-zero incoming messages
    belief = Gaussian(
        eta=jnp.zeros((N, D), dtype=jnp.float32),
        Lam=(tiny_prior * jnp.eye(D, dtype=jnp.float32))[None, :, :].repeat(N, axis=0),
    )
    msgs = Gaussian(
        eta=jnp.zeros((N, Ni_v, D), dtype=jnp.float32),
        Lam=jnp.zeros((N, Ni_v, D, D), dtype=jnp.float32),
    )
    vars = Variable(
        var_id=jnp.arange(N, dtype=jnp.int32),
        belief=belief,
        msgs=msgs,
        adj_factor_idx=jnp.asarray(adj_factor_idx),
    )

    # 9) factors; potentials are zero placeholders that update_factor overwrites
    thr_p = np.ones((P, 1), dtype=np.float32) * float(robust_threshold)
    thr_b = np.ones((E, 1), dtype=np.float32) * float(robust_threshold)
    pot_p = Gaussian(
        eta=jnp.zeros((P, D), dtype=jnp.float32),
        Lam=jnp.zeros((P, D, D), dtype=jnp.float32),
    )
    pot_b = Gaussian(
        eta=jnp.zeros((E, 2 * D), dtype=jnp.float32),
        Lam=jnp.zeros((E, 2 * D, 2 * D), dtype=jnp.float32),
    )

    prior_facs = Factor(
        factor_id=jnp.asarray(prior_factor_id),
        z=jnp.asarray(z_p),
        z_Lam=jnp.asarray(zLam_p),
        threshold=jnp.asarray(thr_p),
        potential=pot_p,
        adj_var_id=jnp.asarray(adj_var_id_p),
        adj_var_idx=jnp.asarray(adj_var_idx_p),
    )
    between_facs = Factor(
        factor_id=jnp.asarray(between_factor_id),
        z=jnp.asarray(z_b),
        z_Lam=jnp.asarray(zLam_b),
        threshold=jnp.asarray(thr_b),
        potential=pot_b,
        adj_var_id=jnp.asarray(adj_var_id_b),
        adj_var_idx=jnp.asarray(adj_var_idx_b),
    )

    # 10) move everything to the requested device
    dev = jax.devices(device)[0]
    return (
        jax.device_put(vars, dev),
        jax.device_put(prior_facs, dev),
        jax.device_put(between_facs, dev),
    )


from functools import partial
def one_step(state, _):
    """One synchronous GBP iteration, shaped as a ``jax.lax.scan`` body.

    Variables are updated first. Both factor groups then read the same
    variable-to-factor messages and linearization points: first the priors
    (measurement ``h_fn``), then the between factors (``h2_fn``). Both use the
    plain ``l2`` weight.

    Returns:
        ``(new_state, None)``.
    """
    vars, prior_facs, between_facs = state
    vars, vtof_msgs, linpoints = update_variable(vars)
    shared = (vtof_msgs, linpoints)
    prior_facs, vars = update_factor(prior_facs, vars, *shared, h_fn, l2)
    between_facs, vars = update_factor(between_facs, vars, *shared, h2_fn, l2)
    new_state = (vars, prior_facs, between_facs)
    return new_state, None


@partial(jax.jit, static_argnames=("num_iters",))
def run_iters(vars, prior_facs, between_facs, num_iters: int):
    """Run ``num_iters`` GBP iterations inside a single compiled ``lax.scan``.

    ``num_iters`` is static, so each distinct value triggers a recompile.

    Returns:
        ``(vars, prior_facs, between_facs)`` after the last iteration.
    """
    init_state = (vars, prior_facs, between_facs)
    final_state, _ = jax.lax.scan(one_step, init_state, xs=None, length=num_iters)
    return final_state



@jax.jit
def energy_map_jax(vars: Variable, GT: jnp.ndarray):
    """Squared error of the belief means with respect to ground truth.

    Equivalent to the old ``energy_map()``: ``sum_i 0.5 * ||mu_i - GT_i||^2``.

    Args:
        vars: batched variables; their belief means are compared to ``GT``.
        GT: ground-truth states, same shape as the stacked means.

    Returns:
        Scalar energy.
    """
    means = jax.vmap(lambda b: b.mu())(vars.belief)
    return 0.5 * jnp.sum(jnp.square(means - GT))

In [46]:
# -----------------------
# Experiment configuration
# -----------------------
N = 5000            # number of poses
step = 25           # step length of the random-walk trajectory
prob = 0.0          # loop-closure acceptance probability per candidate pair (0 => pure chain)
radius = 50         # loop closures only between poses closer than this
prior_prop = 0.02   # fraction of poses that receive a strong prior
prior_sigma = 1.0   # std of strong-prior noise
odom_sigma = 1.0    # std of between-factor (odometry) noise
seed = 2001

# 1) ground-truth trajectory and graph topology
nodes, edges = make_slam_like_graph(
    N=N, step_size=step, loop_prob=prob, loop_radius=radius,
    prior_prop=prior_prop, seed=seed,
)

# 2) noisy measurements -> JAX variables and factors
vars, prior_facs, between_facs = build_posegraph_jax_from_nodes_edges(
    nodes,
    edges,
    prior_sigma=prior_sigma,
    odom_sigma=odom_sigma,
    tiny_prior=1e-12,   # same as the earlier version
    anchor_sigma=1e-4,  # same as the earlier version
    seed=seed,
    device="cpu",
)



In [50]:
NUM_ITERS = 1000  # GBP iterations to run

GT = jnp.asarray(
    np.array([[n["position"]["x"], n["position"]["y"]] for n in nodes], dtype=np.float32)
)
# Store results under new names: reassigning `vars, prior_facs, between_facs`
# made this cell non-idempotent (re-running continued from the previous run's
# converged state and printed a different energy).
vars_opt, prior_facs_opt, between_facs_opt = run_iters(
    vars, prior_facs, between_facs, num_iters=NUM_ITERS
)
energy = energy_map_jax(vars_opt, GT)
print(float(energy))



85003.796875


In [47]:
adj = np.array(vars.adj_factor_idx)   # (N, Ni_v); -1 marks an unused port
deg = (adj >= 0).sum(axis=1)          # true degree of each variable
ni_v = adj.shape[1]
p95_deg = np.percentile(deg, 95)
print("Ni_v =", ni_v)
print("max_deg =", deg.max(), "mean_deg =", deg.mean(), "p95_deg =", p95_deg)


Ni_v = 4
max_deg = 3 mean_deg = 2.0198 p95_deg = 2.0


In [48]:
import numpy as np

adj = np.array(vars.adj_factor_idx)   # (N, Ni_v); -1 marks an unused port
deg = (adj >= 0).sum(axis=1)          # (N,) true degree of each variable

max_deg = deg.max()
nodes_max = np.flatnonzero(deg == max_deg)
v_max = int(nodes_max[0])  # first variable that attains the maximum degree

print("max_deg =", max_deg)
print("nodes with max_deg:", nodes_max)
print("one max-deg node:", v_max)


max_deg = 3
nodes with max_deg: [  22   74   88  114  126  210  354  417  432  435  481  486  612  616
  628  694  699  712  795  817  824  840  871  873  899  943  991 1005
 1022 1064 1069 1125 1168 1207 1236 1249 1261 1266 1277 1311 1388 1477
 1480 1497 1529 1581 1740 1741 1848 1878 1884 1933 2030 2054 2167 2321
 2406 2732 2785 2793 2838 2900 2919 2924 2972 2991 2994 3031 3128 3210
 3251 3292 3333 3392 3404 3436 3441 3466 3513 3542 3596 3628 3677 3887
 3900 3907 4041 4069 4184 4227 4255 4388 4449 4459 4695 4809 4895 4897
 4928 4997]
one max-deg node: 22


In [49]:
# Cross-check v_max's degree against the raw edge list: between edges touching
# v, a "prior" edge on v, and (for v == 0 only) the "anchor" edge.
v = v_max
cnt = 0
for e in edges:
    src, dst = e["data"]["source"], e["data"]["target"]
    if dst == "prior":
        cnt += int(int(src) == v)
    elif dst == "anchor":
        cnt += int(v == 0)
    elif v in (int(src), int(dst)):
        cnt += 1

print("edge-count from edges list:", cnt)


edge-count from edges list: 3
