In [None]:
import os
import sys
newPath = os.path.dirname(os.path.abspath(""))
if newPath not in sys.path:
    sys.path.append(newPath)
from BI import bi

import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax
import networkx as nx
m = bi(platform='cpu')
data_path = os.path.dirname(os.path.abspath("")) + "/BI/resources/data/"

jax.local_device_count 16


In [None]:
G = nx.karate_club_graph()
N = G.number_of_nodes()

# Get the adjacency matrix for JAX
adj_matrix_np = nx.to_numpy_array(G)
adj_matrix_jax = jnp.array(adj_matrix_np)

print(m.net.degree(adj_matrix_jax))
nx.degree(G)

[32 18 20 12  6  8  8  8 10  4  6  2  4 10  4  4  4  4  4  6  4  4  4 10
  6  6  4  8  6  8  8 12 24 34]


DegreeView({0: 16, 1: 9, 2: 10, 3: 6, 4: 3, 5: 4, 6: 4, 7: 4, 8: 5, 9: 2, 10: 3, 11: 1, 12: 2, 13: 5, 14: 2, 15: 2, 16: 2, 17: 2, 18: 2, 19: 3, 20: 2, 21: 2, 22: 2, 23: 5, 24: 3, 25: 3, 26: 2, 27: 4, 28: 3, 29: 4, 30: 4, 31: 6, 32: 12, 33: 17})

In [10]:
import jax
import jax.numpy as jnp
from jax import jit
import numpy as np
import networkx as nx

@jit
def _sssp_brandes_source(source_node, adj_matrix):
    """
    Computes dependency accumulation for a single source node using Brandes' algorithm.
    This version is fully JIT-compatible for unweighted graphs.
    """
    n_nodes = adj_matrix.shape[0]
    
    # --- Phase 1: BFS using a JAX-native loop ---
    # S stores the nodes at each distance level. Pre-allocated to max size.
    S = jnp.full((n_nodes, n_nodes), -1, dtype=jnp.int32)
    dist = jnp.full(n_nodes, -1, dtype=jnp.int32)
    sigma = jnp.zeros(n_nodes, dtype=jnp.float32)
    
    dist = dist.at[source_node].set(0)
    sigma = sigma.at[source_node].set(1.0)
    S = S.at[0, 0].set(source_node)
    
    # Loop state: (dist, sigma, S, number_of_nodes_at_current_level)
    initial_bfs_state = (dist, sigma, S, 1)

    def bfs_body_fn(d, state):
        dist, sigma, S, _ = state
        nodes_at_d = S[d, :]
        
        # Get all neighbors of all nodes at the current level in a vectorized way
        # This creates a boolean matrix: [num_nodes_at_d, n_nodes]
        neighbor_matrix = adj_matrix[nodes_at_d, :]
        
        # Summing columns gives the number of shortest paths from the previous level
        new_sigma = jnp.sum(neighbor_matrix * sigma[nodes_at_d, None], axis=0)

        # A node w has been reached for the first time if dist is -1 and it's a neighbor
        newly_reached = (dist == -1) & (jnp.sum(neighbor_matrix, axis=0) > 0)
        
        # Update distance and sigma for newly reached nodes
        dist = jnp.where(newly_reached, d + 1, dist)
        sigma = jnp.where(newly_reached, new_sigma, sigma)

        # Update sigma for nodes that were already on a shortest path
        already_reached = (dist == d + 1)
        sigma = jnp.where(already_reached, sigma + new_sigma, sigma)
        
        # Populate the next level of S
        next_nodes = jnp.where(dist == d + 1, jnp.arange(n_nodes), -1)
        # Use roll and slice to handle storing a variable number of nodes
        num_next = jnp.sum(dist == d + 1)
        S = S.at[d + 1].set(jnp.roll(jnp.sort(next_nodes), n_nodes - num_next))
        
        return dist, sigma, S, num_next

    # Run the BFS loop. It will stop producing new nodes after the graph diameter is reached.
    dist, sigma, S, _ = jax.lax.fori_loop(0, n_nodes - 1, bfs_body_fn, initial_bfs_state)
    num_levels = jnp.max(dist) + 1

    # --- Phase 2: Accumulation of dependencies ---
    dependency = jnp.zeros(n_nodes, dtype=jnp.float32)

    def accum_body_fn(i, dependency):
        d = num_levels - 1 - i
        # Get successors of nodes at level d-1
        nodes_at_prev_level = jnp.where(dist == d - 1, jnp.arange(n_nodes), -1)
        successors_matrix = adj_matrix[nodes_at_prev_level, :]
        
        # Calculate the credit to flow back from nodes at level d
        # credit = (sigma_predecessor / sigma_successor) * (1 + dependency_successor)
        with jnp.errstate(divide='ignore', invalid='ignore'): # Handle division by zero
            ratio = (1 + dependency) / sigma
        
        # Propagate credit backwards to predecessors
        credit = jnp.sum(successors_matrix * ratio[None, :], axis=1)
        
        # Add this credit to the dependency of the predecessors
        dependency = dependency.at[nodes_at_prev_level].add(sigma[nodes_at_prev_level] * credit)
        return dependency

    dependency = jax.lax.fori_loop(0, num_levels, accum_body_fn, dependency)
                        
    return dependency


def betweenness_centrality_jax(adj_matrix, normalized=True):
    """
    Compute betweenness centrality for an unweighted graph using a fully JIT-able approach.
    """
    n_nodes = adj_matrix.shape[0]
    if n_nodes <= 2:
        return jnp.zeros(n_nodes)
    
    # vmap the kernel over all possible source nodes for massive parallelism
    all_dependencies = jax.vmap(_sssp_brandes_source, in_axes=(0, None))(jnp.arange(n_nodes), adj_matrix)
    
    # Sum the results from each source node
    betweenness = jnp.sum(all_dependencies, axis=0)
        
    if normalized:
        # Brandes algorithm on undirected graphs counts each path twice, so we divide by 2.
        # This aligns the result with networkx's definition.
        scale = 1.0 / ((n_nodes - 1) * (n_nodes - 2))
        betweenness = betweenness * scale
    
    return betweenness

# --- Example Usage and Verification ---
# Create a sample graph
G = nx.karate_club_graph()
N = G.number_of_nodes()

# Get the binary adjacency matrix for JAX
adj_matrix_np = (nx.to_numpy_array(G) > 0).astype(float)
adj_matrix_jax = jnp.array(adj_matrix_np)

# Run the JAX implementation. The first run will be slow due to compilation.
print("Running JAX implementation...")
bc_jax = betweenness_centrality_jax(adj_matrix_jax, normalized=True)

# Run the NetworkX implementation for comparison
print("Running NetworkX implementation...")
bc_nx = nx.betweenness_centrality(G, normalized=True)
bc_nx_array = jnp.array([bc_nx[i] for i in range(N)])

# Compare the results
print("\nJAX results (first 5):", np.round(np.array(bc_jax[:5]), 4))
print("NX results (first 5): ", np.round(np.array(bc_nx_array[:5]), 4))
print("\nAre they close?", np.allclose(bc_jax, bc_nx_array, atol=1e-6))

Running JAX implementation...




TypeError: scan body function carry input and carry output must have equal types, but they differ:

The input carry component loop_carry[1][1] has type float32[34] but the corresponding output carry component has type float64[34], so the dtypes do not match.

Revise the function so that all output types match the corresponding input types.

In [19]:
import jax
import jax.numpy as jnp
from jax import jit
import numpy as np
import networkx as nx

# Use a consistent float type for all calculations
DTYPE = jnp.float64

@jit
def _sssp_brandes_source(source_node, adj_matrix):
    """
    Computes dependency accumulation for a single source node using Brandes' algorithm.
    This version is fully JIT-compatible for unweighted graphs.
    """
    n_nodes = adj_matrix.shape[0]
    
    # --- Phase 1: BFS using a JAX-native loop ---
    S = jnp.full((n_nodes, n_nodes), -1, dtype=jnp.int32)
    dist = jnp.full(n_nodes, -1, dtype=jnp.int32)
    # --- FIX: Initialize sigma with the consistent dtype ---
    sigma = jnp.zeros(n_nodes, dtype=DTYPE)
    
    dist = dist.at[source_node].set(0)
    sigma = sigma.at[source_node].set(1.0)
    S = S.at[0, 0].set(source_node)
    
    initial_bfs_state = (dist, sigma, S, 1)

    def bfs_body_fn(d, state):
        dist, sigma, S, _ = state
        nodes_at_d = S[d, :]
        
        neighbor_matrix = adj_matrix[nodes_at_d, :]
        new_sigma = jnp.sum(neighbor_matrix * sigma[nodes_at_d, None], axis=0)
        newly_reached = (dist == -1) & (jnp.sum(neighbor_matrix, axis=0) > 0)
        
        dist = jnp.where(newly_reached, d + 1, dist)
        sigma = jnp.where(newly_reached, new_sigma, sigma)

        already_reached = (dist == d + 1)
        sigma = jnp.where(already_reached, sigma + new_sigma, sigma)
        
        next_nodes = jnp.where(dist == d + 1, jnp.arange(n_nodes), -1)
        num_next = jnp.sum(dist == d + 1)
        S = S.at[d + 1].set(jnp.roll(jnp.sort(next_nodes), n_nodes - num_next))
        
        return dist, sigma, S, num_next

    dist, sigma, S, _ = jax.lax.fori_loop(0, n_nodes - 1, bfs_body_fn, initial_bfs_state)
    num_levels = jnp.max(dist) + 1

    # --- Phase 2: Accumulation of dependencies ---
    # --- FIX: Initialize dependency with the consistent dtype ---
    dependency = jnp.zeros(n_nodes, dtype=DTYPE)

    def accum_body_fn(i, dependency):
        d = num_levels - 1 - i
        nodes_at_prev_level = jnp.where(dist == d - 1, jnp.arange(n_nodes), -1)
        successors_matrix = adj_matrix[nodes_at_prev_level, :]
        
        with jnp.errstate(divide='ignore', invalid='ignore'):
            ratio = (1 + dependency) / sigma
        
        credit = jnp.sum(successors_matrix * ratio[None, :], axis=1)
        dependency = dependency.at[nodes_at_prev_level].add(sigma[nodes_at_prev_level] * credit)
        return dependency

    dependency = jax.lax.fori_loop(0, num_levels - 1, accum_body_fn, dependency)
                        
    return dependency

def betweenness_centrality_jax(adj_matrix, normalized=True):
    """
    Compute betweenness centrality for an unweighted graph using a fully JIT-able approach.
    """
    n_nodes = adj_matrix.shape[0]
    if n_nodes <= 2:
        return jnp.zeros(n_nodes)
    
    # Ensure matrix has the correct dtype
    adj_matrix = adj_matrix.astype(DTYPE)

    all_dependencies = jax.vmap(_sssp_brandes_source, in_axes=(0, None))(jnp.arange(n_nodes), adj_matrix)
    betweenness = jnp.sum(all_dependencies, axis=0)
        
    if normalized:
        scale = 1.0 / ((n_nodes - 1) * (n_nodes - 2))
        betweenness = betweenness * scale
    
    return betweenness

# --- Example Usage and Verification ---
G = nx.karate_club_graph()
N = G.number_of_nodes()

# --- FIX: Ensure the initial matrix uses the consistent dtype ---
adj_matrix_np = (nx.to_numpy_array(G) > 0).astype(np.float64)
adj_matrix_jax = jnp.array(adj_matrix_np)

print("Running JAX implementation...")
bc_jax = betweenness_centrality_jax(adj_matrix_jax, normalized=True)

print("Running NetworkX implementation...")
bc_nx = nx.betweenness_centrality(G, normalized=True)
bc_nx_array = jnp.array([bc_nx[i] for i in range(N)])

print("\nJAX results (first 5):", np.round(np.array(bc_jax[:5]), 4))
print("NX results (first 5): ", np.round(np.array(bc_nx_array[:5]), 4))
print("\nAre they close?", np.allclose(bc_jax, bc_nx_array, atol=1e-6))

Running JAX implementation...




AttributeError: module 'jax.numpy' has no attribute 'errstate'

In [None]:
import jax
import jax.numpy as jnp
from jax import random
from typing import Optional, Tuple, Dict
import numpy as np
from collections import deque

def betweenness_centrality_jax(
    adjacency_matrix: jnp.ndarray,
    k: Optional[int] = None,
    normalized: bool = True,
    weight_matrix: Optional[jnp.ndarray] = None,
    endpoints: bool = False,
    seed: int = 42
) -> jnp.ndarray:
    """
    Compute the shortest-path betweenness centrality for nodes using JAX.
    
    Betweenness centrality of a node v is the sum of the fraction of all-pairs 
    shortest paths that pass through v.
    
    Parameters
    ----------
    adjacency_matrix : jnp.ndarray
        Square adjacency matrix of shape (n, n) where n is the number of nodes.
        adjacency_matrix[i, j] = 1 if there's an edge from i to j, 0 otherwise.
    
    k : int, optional (default=None)
        If k is not None, use k node samples to estimate betweenness.
        Higher values give better approximation.
    
    normalized : bool, optional (default=True)
        If True, normalize betweenness values by the appropriate factor.
    
    weight_matrix : jnp.ndarray, optional (default=None)
        Weight matrix of same shape as adjacency_matrix.
        If None, all edges have weight 1.
    
    endpoints : bool, optional (default=False)
        If True, include endpoints in shortest path counts.
    
    seed : int, optional (default=42)
        Random seed for sampling nodes when k is specified.
    
    Returns
    -------
    betweenness : jnp.ndarray
        Array of betweenness centrality values for each node.
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Initialize betweenness centrality values
    betweenness = jnp.zeros(n_nodes)
    
    # Determine which nodes to use as sources
    if k is None or k >= n_nodes:
        source_nodes = jnp.arange(n_nodes)
    else:
        key = random.PRNGKey(seed)
        source_nodes = random.choice(key, n_nodes, (k,), replace=False)
    
    # Process each source node
    for s in source_nodes:
        if weight_matrix is None:
            # Unweighted shortest paths (BFS-style)
            S, P, sigma = _single_source_shortest_path_unweighted_jax(
                adjacency_matrix, s
            )
        else:
            # Weighted shortest paths (Dijkstra-style)
            S, P, sigma = _single_source_shortest_path_weighted_jax(
                adjacency_matrix, weight_matrix, s
            )
        
        # Accumulate betweenness centrality
        if endpoints:
            betweenness = _accumulate_endpoints_jax(betweenness, S, P, sigma, s)
        else:
            betweenness = _accumulate_basic_jax(betweenness, S, P, sigma, s)
    
    # Rescale the results
    betweenness = _rescale_jax(
        betweenness, n_nodes, normalized, k, endpoints, len(source_nodes)
    )
    
    return betweenness


def _single_source_shortest_path_unweighted_jax(
    adjacency_matrix: jnp.ndarray, 
    source: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Single-source shortest paths for unweighted graphs using BFS approach.
    
    Returns
    -------
    S : jnp.ndarray
        Nodes in order of non-increasing distance from source
    P : jnp.ndarray
        Predecessor matrix: P[v, u] = 1 if u is a predecessor of v
    sigma : jnp.ndarray
        Number of shortest paths from source to each node
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Initialize distances and predecessors
    dist = jnp.full(n_nodes, -1)
    dist = dist.at[source].set(0)
    sigma = jnp.zeros(n_nodes)
    sigma = sigma.at[source].set(1.0)
    
    # Store predecessors as lists (using a different approach)
    P = []
    for _ in range(n_nodes):
        P.append([])
    
    # BFS queue simulation
    current_layer = [source]
    current_dist = 0
    S = []  # Stack for nodes in order of discovery
    
    while current_layer:
        S.extend(current_layer)
        next_layer = []
        
        for v in current_layer:
            # Check all neighbors of v
            for w in range(n_nodes):
                if adjacency_matrix[v, w] > 0:  # There's an edge v -> w
                    # First time we see w
                    if dist[w] < 0:
                        dist = dist.at[w].set(current_dist + 1)
                        next_layer.append(w)
                        sigma = sigma.at[w].set(0.0)
                    
                    # If w is at the next level from v
                    if dist[w] == current_dist + 1:
                        sigma = sigma.at[w].add(sigma[v])
                        P[w].append(v)
        
        current_layer = next_layer
        current_dist += 1
    
    # Convert P to matrix format for compatibility
    P_matrix = jnp.zeros((n_nodes, n_nodes))
    for w in range(n_nodes):
        for v in P[w]:
            P_matrix = P_matrix.at[w, v].set(1.0)
    
    # S should be in reverse order for the algorithm
    S = jnp.array(S[::-1])
    
    return S, P_matrix, sigma


def _single_source_shortest_path_weighted_jax(
    adjacency_matrix: jnp.ndarray,
    weight_matrix: jnp.ndarray,
    source: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Single-source shortest paths for weighted graphs using Dijkstra's algorithm.
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Initialize distances and predecessors
    dist = jnp.full(n_nodes, jnp.inf)
    dist = dist.at[source].set(0.0)
    sigma = jnp.zeros(n_nodes)
    sigma = sigma.at[source].set(1.0)
    
    # Store predecessors as lists
    P = []
    for _ in range(n_nodes):
        P.append([])
    
    # Dijkstra's algorithm
    visited = jnp.zeros(n_nodes, dtype=bool)
    S = []  # Stack for nodes in order of completion
    
    for _ in range(n_nodes):
        # Find unvisited node with minimum distance
        unvisited_dist = jnp.where(visited, jnp.inf, dist)
        if jnp.all(jnp.isinf(unvisited_dist)):
            break
            
        u = int(jnp.argmin(unvisited_dist))
        
        if jnp.isinf(dist[u]):
            break
            
        visited = visited.at[u].set(True)
        S.append(u)
        
        # Update distances to neighbors
        for v in range(n_nodes):
            if adjacency_matrix[u, v] > 0 and not visited[v]:
                edge_weight = weight_matrix[u, v]
                new_dist = dist[u] + edge_weight
                
                # If we found a shorter path
                if dist[v] > new_dist:
                    dist = dist.at[v].set(new_dist)
                    sigma = sigma.at[v].set(sigma[u])
                    P[v] = [u]  # Reset predecessors
                
                # If we found an equally short path
                elif jnp.abs(dist[v] - new_dist) < 1e-10:
                    sigma = sigma.at[v].add(sigma[u])
                    P[v].append(u)
    
    # Convert P to matrix format
    P_matrix = jnp.zeros((n_nodes, n_nodes))
    for w in range(n_nodes):
        for v in P[w]:
            P_matrix = P_matrix.at[w, v].set(1.0)
    
    # S should be in reverse order
    S = jnp.array(S[::-1])
    
    return S, P_matrix, sigma


def _accumulate_basic_jax(
    betweenness: jnp.ndarray,
    S: jnp.ndarray,
    P: jnp.ndarray,
    sigma: jnp.ndarray,
    source: int
) -> jnp.ndarray:
    """
    Accumulate betweenness centrality (basic version without endpoints).
    """
    n_nodes = len(betweenness)
    delta = jnp.zeros(n_nodes)
    
    # Process nodes in reverse topological order (S is already reversed)
    for i in range(len(S)):
        w = int(S[i])
        if w == source:
            continue
            
        # Sum over predecessors
        predecessors = P[w] > 0
        for v in range(n_nodes):
            if predecessors[v] and sigma[w] > 0:
                coeff = (sigma[v] / sigma[w]) * (1.0 + delta[w])
                delta = delta.at[v].add(coeff)
        
        betweenness = betweenness.at[w].add(delta[w])
    
    return betweenness


def _accumulate_endpoints_jax(
    betweenness: jnp.ndarray,
    S: jnp.ndarray,
    P: jnp.ndarray,
    sigma: jnp.ndarray,
    source: int
) -> jnp.ndarray:
    """
    Accumulate betweenness centrality (with endpoints).
    """
    n_nodes = len(betweenness)
    delta = jnp.zeros(n_nodes)
    
    # Add contribution for source node
    betweenness = betweenness.at[source].add(len(S) - 1)
    
    # Process nodes in reverse topological order
    for i in reversed(range(len(S))):
        w = S[i]
        if w == source:
            continue
            
        # Add contribution for this endpoint
        betweenness = betweenness.at[w].add(delta[w] + 1)
        
        # Sum over predecessors
        predecessors = P[w] > 0
        for v in range(n_nodes):
            if predecessors[v]:
                coeff = (sigma[v] / sigma[w]) * (1.0 + delta[w])
                delta = delta.at[v].add(coeff)
    
    return betweenness


def _rescale_jax(
    betweenness: jnp.ndarray,
    n_nodes: int,
    normalized: bool,
    k: Optional[int],
    endpoints: bool,
    n_sampled: int
) -> jnp.ndarray:
    """
    Rescale betweenness centrality values.
    """
    if not normalized:
        return betweenness
    
    if n_nodes <= 2:
        return betweenness
    
    # Normalization factor
    if endpoints:
        # Include endpoints in normalization
        scale = 1.0 / ((n_nodes - 1) * (n_nodes - 2))
    else:
        # Standard normalization
        scale = 2.0 / ((n_nodes - 1) * (n_nodes - 2))
    
    # Adjust for sampling
    if k is not None and n_sampled < n_nodes:
        scale *= n_nodes / n_sampled
    
    return betweenness * scale


# Example usage and testing functions
def create_test_graph() -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Create a simple test graph for demonstration."""
    # Create a simple path graph: 0-1-2-3-4
    n = 5
    adj = jnp.zeros((n, n))
    for i in range(n-1):
        adj = adj.at[i, i+1].set(1.0)
        adj = adj.at[i+1, i].set(1.0)  # Make undirected
    
    # Create weight matrix (all weights = 1)
    weights = jnp.where(adj > 0, 1.0, 0.0)
    
    return adj, weights


def test_betweenness_centrality():
    """Test the betweenness centrality implementation."""
    adj, weights = create_test_graph()
    
    # Test unweighted
    bc_unweighted = betweenness_centrality_jax(adj, normalized=True)
    print("Unweighted betweenness centrality:")
    for i, val in enumerate(bc_unweighted):
        print(f"Node {i}: {val:.6f}")
    
    # Test weighted
    bc_weighted = betweenness_centrality_jax(adj, weight_matrix=weights, normalized=True)
    print("\nWeighted betweenness centrality:")
    for i, val in enumerate(bc_weighted):
        print(f"Node {i}: {val:.6f}")
    
    # Test with sampling
    bc_sampled = betweenness_centrality_jax(adj, k=3, normalized=True, seed=42)
    print("\nSampled betweenness centrality (k=3):")
    for i, val in enumerate(bc_sampled):
        print(f"Node {i}: {val:.6f}")
    
    # Test with endpoints
    bc_endpoints = betweenness_centrality_jax(adj, endpoints=True, normalized=True)
    print("\nBetweenness centrality with endpoints:")
    for i, val in enumerate(bc_endpoints):
        print(f"Node {i}: {val:.6f}")
    
    # Compare with expected NetworkX results for path graph 0-1-2-3-4
    expected = [0.0, 0.5, 0.6666666666666666, 0.5, 0.0]
    print("\nExpected NetworkX results:")
    for i, val in enumerate(expected):
        print(f"Node {i}: {val:.6f}")
    
    print("\nDifference from expected for unweigthed:")
    for i, (got, exp) in enumerate(zip(bc_unweighted/2, expected)):
        print(f"Node {i}: {abs(got - exp):.6f}")

    print("\nDifference from expected for weigthed:")
    for i, (got, exp) in enumerate(zip(bc_weighted/2, expected)):
        print(f"Node {i}: {abs(got - exp):.6f}")

# Additional debugging function
def debug_single_source(adj, source=0):
    """Debug single source shortest paths."""
    print(f"\nDebugging single source from node {source}:")
    S, P, sigma = _single_source_shortest_path_unweighted_jax(adj, source)
    
    print(f"S (traversal order): {S}")
    print(f"Sigma (path counts): {sigma}")
    print(f"Predecessors matrix P:")
    for i in range(len(P)):
        preds = [j for j in range(len(P)) if P[i, j] > 0]
        print(f"  Node {i}: predecessors = {preds}")
    
    return S, P, sigma

In [29]:
test_betweenness_centrality()

Unweighted betweenness centrality:
Node 0: 0.000000
Node 1: 1.000000
Node 2: 1.333333
Node 3: 1.000000
Node 4: 0.000000

Weighted betweenness centrality:
Node 0: 0.000000
Node 1: 1.000000
Node 2: 1.333333
Node 3: 1.000000
Node 4: 0.000000

Sampled betweenness centrality (k=3):
Node 0: 0.000000
Node 1: 0.833333
Node 2: 1.111111
Node 3: 1.111111
Node 4: 0.000000

Betweenness centrality with endpoints:
Node 0: 0.666667
Node 1: 0.666667
Node 2: 0.666667
Node 3: 0.666667
Node 4: 0.666667

Expected NetworkX results:
Node 0: 0.000000
Node 1: 0.500000
Node 2: 0.666667
Node 3: 0.500000
Node 4: 0.000000

Difference from expected:
Node 0: 0.000000
Node 1: 0.500000
Node 2: 0.666667
Node 3: 0.500000
Node 4: 0.000000


In [44]:
adj, weights = create_test_graph()
print(betweenness_centrality_jax(adj, normalized=True)/2)
# convert adj to networkx graph and compute betweenness centrality
G = nx.from_numpy_array(np.array(adj))

print(betweenness_centrality_jax(adj, weight_matrix=weights, normalized=True)/2)

nx.betweenness_centrality(G, normalized=True, weight=None)

[0.         0.5        0.66666667 0.5        0.        ]
[0.         0.5        0.66666667 0.5        0.        ]


{0: 0.0, 1: 0.5, 2: 0.6666666666666666, 3: 0.5, 4: 0.0}

In [45]:
import jax
import jax.numpy as jnp
from jax import random
from typing import Optional, Tuple, Dict
import numpy as np
from collections import deque

def betweenness_centrality_jax(
    adjacency_matrix: jnp.ndarray,
    k: Optional[int] = None,
    normalized: bool = True,
    weight_matrix: Optional[jnp.ndarray] = None,
    endpoints: bool = False,
    seed: int = 42,
    directed: bool = False
) -> jnp.ndarray:
    """
    Compute the shortest-path betweenness centrality for nodes using JAX.
    
    Betweenness centrality of a node v is the sum of the fraction of all-pairs 
    shortest paths that pass through v.
    
    Parameters
    ----------
    adjacency_matrix : jnp.ndarray
        Square adjacency matrix of shape (n, n) where n is the number of nodes.
        adjacency_matrix[i, j] = 1 if there's an edge from i to j, 0 otherwise.
    
    k : int, optional (default=None)
        If k is not None, use k node samples to estimate betweenness.
        Higher values give better approximation.
    
    normalized : bool, optional (default=True)
        If True, normalize betweenness values by the appropriate factor.
    
    weight_matrix : jnp.ndarray, optional (default=None)
        Weight matrix of same shape as adjacency_matrix.
        If None, all edges have weight 1.
    
    endpoints : bool, optional (default=False)
        If True, include endpoints in shortest path counts.
    
    seed : int, optional (default=42)
        Random seed for sampling nodes when k is specified.
    
    directed : bool, optional (default=False)
        If True, treat the graph as directed. If False, treat as undirected.
    
    Returns
    -------
    betweenness : jnp.ndarray
        Array of betweenness centrality values for each node.
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Initialize betweenness centrality values
    betweenness = jnp.zeros(n_nodes)
    
    # Determine which nodes to use as sources
    if k is None or k >= n_nodes:
        source_nodes = jnp.arange(n_nodes)
    else:
        key = random.PRNGKey(seed)
        source_nodes = random.choice(key, n_nodes, (k,), replace=False)
    
    # Process each source node
    for s in source_nodes:
        if weight_matrix is None:
            # Unweighted shortest paths (BFS-style)
            S, P, sigma = _single_source_shortest_path_unweighted_jax(
                adjacency_matrix, s
            )
        else:
            # Weighted shortest paths (Dijkstra-style)
            S, P, sigma = _single_source_shortest_path_weighted_jax(
                adjacency_matrix, weight_matrix, s
            )
        
        # Accumulate betweenness centrality
        if endpoints:
            betweenness = _accumulate_endpoints_jax(betweenness, S, P, sigma, s)
        else:
            betweenness = _accumulate_basic_jax(betweenness, S, P, sigma, s)
    
    # Rescale the results
    betweenness = _rescale_jax(
        betweenness, n_nodes, normalized, k, endpoints, len(source_nodes), directed
    )
    
    return betweenness


def _single_source_shortest_path_unweighted_jax(
    adjacency_matrix: jnp.ndarray, 
    source: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Single-source shortest paths for unweighted graphs using BFS approach.
    
    Returns
    -------
    S : jnp.ndarray
        Nodes in order of non-increasing distance from source
    P : jnp.ndarray
        Predecessor matrix: P[v, u] = 1 if u is a predecessor of v
    sigma : jnp.ndarray
        Number of shortest paths from source to each node
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Initialize distances and predecessors
    dist = jnp.full(n_nodes, -1)
    dist = dist.at[source].set(0)
    sigma = jnp.zeros(n_nodes)
    sigma = sigma.at[source].set(1.0)
    
    # Store predecessors as lists (using a different approach)
    P = []
    for _ in range(n_nodes):
        P.append([])
    
    # BFS queue simulation
    current_layer = [source]
    current_dist = 0
    S = []  # Stack for nodes in order of discovery
    
    while current_layer:
        S.extend(current_layer)
        next_layer = []
        
        for v in current_layer:
            # Check all neighbors of v
            for w in range(n_nodes):
                if adjacency_matrix[v, w] > 0:  # There's an edge v -> w
                    # First time we see w
                    if dist[w] < 0:
                        dist = dist.at[w].set(current_dist + 1)
                        next_layer.append(w)
                        sigma = sigma.at[w].set(0.0)
                    
                    # If w is at the next level from v
                    if dist[w] == current_dist + 1:
                        sigma = sigma.at[w].add(sigma[v])
                        P[w].append(v)
        
        current_layer = next_layer
        current_dist += 1
    
    # Convert P to matrix format for compatibility
    P_matrix = jnp.zeros((n_nodes, n_nodes))
    for w in range(n_nodes):
        for v in P[w]:
            P_matrix = P_matrix.at[w, v].set(1.0)
    
    # S should be in reverse order for the algorithm
    S = jnp.array(S[::-1])
    
    return S, P_matrix, sigma


def _single_source_shortest_path_weighted_jax(
    adjacency_matrix: jnp.ndarray,
    weight_matrix: jnp.ndarray,
    source: int
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Single-source shortest paths for weighted graphs using Dijkstra's algorithm.
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Initialize distances and predecessors
    dist = jnp.full(n_nodes, jnp.inf)
    dist = dist.at[source].set(0.0)
    sigma = jnp.zeros(n_nodes)
    sigma = sigma.at[source].set(1.0)
    
    # Store predecessors as lists
    P = []
    for _ in range(n_nodes):
        P.append([])
    
    # Dijkstra's algorithm
    visited = jnp.zeros(n_nodes, dtype=bool)
    S = []  # Stack for nodes in order of completion
    
    for _ in range(n_nodes):
        # Find unvisited node with minimum distance
        unvisited_dist = jnp.where(visited, jnp.inf, dist)
        if jnp.all(jnp.isinf(unvisited_dist)):
            break
            
        u = int(jnp.argmin(unvisited_dist))
        
        if jnp.isinf(dist[u]):
            break
            
        visited = visited.at[u].set(True)
        S.append(u)
        
        # Update distances to neighbors
        for v in range(n_nodes):
            if adjacency_matrix[u, v] > 0 and not visited[v]:
                edge_weight = weight_matrix[u, v]
                new_dist = dist[u] + edge_weight
                
                # If we found a shorter path
                if dist[v] > new_dist:
                    dist = dist.at[v].set(new_dist)
                    sigma = sigma.at[v].set(sigma[u])
                    P[v] = [u]  # Reset predecessors
                
                # If we found an equally short path
                elif jnp.abs(dist[v] - new_dist) < 1e-10:
                    sigma = sigma.at[v].add(sigma[u])
                    P[v].append(u)
    
    # Convert P to matrix format
    P_matrix = jnp.zeros((n_nodes, n_nodes))
    for w in range(n_nodes):
        for v in P[w]:
            P_matrix = P_matrix.at[w, v].set(1.0)
    
    # S should be in reverse order
    S = jnp.array(S[::-1])
    
    return S, P_matrix, sigma


def _accumulate_basic_jax(
    betweenness: jnp.ndarray,
    S: jnp.ndarray,
    P: jnp.ndarray,
    sigma: jnp.ndarray,
    source: int
) -> jnp.ndarray:
    """
    Accumulate betweenness centrality (basic version without endpoints).
    """
    n_nodes = len(betweenness)
    delta = jnp.zeros(n_nodes)
    
    # Process nodes in reverse topological order (S is already reversed)
    for i in range(len(S)):
        w = int(S[i])
        if w == source:
            continue
            
        # Sum over predecessors
        predecessors = P[w] > 0
        for v in range(n_nodes):
            if predecessors[v] and sigma[w] > 0:
                coeff = (sigma[v] / sigma[w]) * (1.0 + delta[w])
                delta = delta.at[v].add(coeff)
        
        betweenness = betweenness.at[w].add(delta[w])
    
    return betweenness


def _accumulate_endpoints_jax(
    betweenness: jnp.ndarray,
    S: jnp.ndarray,
    P: jnp.ndarray,
    sigma: jnp.ndarray,
    source: int
) -> jnp.ndarray:
    """
    Accumulate betweenness centrality (with endpoints).
    """
    n_nodes = len(betweenness)
    delta = jnp.zeros(n_nodes)
    
    # Add contribution for source node
    betweenness = betweenness.at[source].add(len(S) - 1)
    
    # Process nodes in reverse topological order
    for i in reversed(range(len(S))):
        w = S[i]
        if w == source:
            continue
            
        # Add contribution for this endpoint
        betweenness = betweenness.at[w].add(delta[w] + 1)
        
        # Sum over predecessors
        predecessors = P[w] > 0
        for v in range(n_nodes):
            if predecessors[v]:
                coeff = (sigma[v] / sigma[w]) * (1.0 + delta[w])
                delta = delta.at[v].add(coeff)
    
    return betweenness


def _rescale_jax(
    betweenness: jnp.ndarray,
    n_nodes: int,
    normalized: bool,
    k: Optional[int],
    endpoints: bool,
    n_sampled: int,
    directed: bool = False
) -> jnp.ndarray:
    """
    Rescale betweenness centrality values.
    """
    if not normalized:
        return betweenness
    
    if n_nodes <= 2:
        return betweenness
    
    # Normalization factor based on NetworkX implementation
    if endpoints:
        # Include endpoints in normalization
        scale = 1.0 / ((n_nodes - 1) * (n_nodes - 2))
    else:
        # Standard normalization
        if directed:
            # For directed graphs: 1/((n-1)(n-2))
            scale = 1.0 / ((n_nodes - 1) * (n_nodes - 2))
        else:
            # For undirected graphs: 2/((n-1)(n-2))
            scale = 2.0 / ((n_nodes - 1) * (n_nodes - 2))
    
    # Adjust for sampling
    if k is not None and n_sampled < n_nodes:
        scale *= n_nodes / n_sampled
    
    return betweenness * scale


# Example usage and testing functions
def create_test_graph() -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Create a simple test graph for demonstration."""
    # Create a simple path graph: 0-1-2-3-4
    n = 5
    adj = jnp.zeros((n, n))
    for i in range(n-1):
        adj = adj.at[i, i+1].set(1.0)
        adj = adj.at[i+1, i].set(1.0)  # Make undirected
    
    # Create weight matrix (all weights = 1)
    weights = jnp.where(adj > 0, 1.0, 0.0)
    
    return adj, weights


def test_betweenness_centrality():
    """Test the betweenness centrality implementation."""
    adj, weights = create_test_graph()
    
    # Test unweighted (undirected graph)
    bc_unweighted = betweenness_centrality_jax(adj, normalized=True, directed=False)
    print("Unweighted betweenness centrality:")
    for i, val in enumerate(bc_unweighted):
        print(f"Node {i}: {val:.6f}")
    
    # Test weighted (undirected graph)
    bc_weighted = betweenness_centrality_jax(adj, weight_matrix=weights, normalized=True, directed=False)
    print("\nWeighted betweenness centrality:")
    for i, val in enumerate(bc_weighted):
        print(f"Node {i}: {val:.6f}")
    
    # Test with sampling
    bc_sampled = betweenness_centrality_jax(adj, k=3, normalized=True, seed=42, directed=False)
    print("\nSampled betweenness centrality (k=3):")
    for i, val in enumerate(bc_sampled):
        print(f"Node {i}: {val:.6f}")
    
    # Test with endpoints
    bc_endpoints = betweenness_centrality_jax(adj, endpoints=True, normalized=True, directed=False)
    print("\nBetweenness centrality with endpoints:")
    for i, val in enumerate(bc_endpoints):
        print(f"Node {i}: {val:.6f}")
    
    # Compare with expected NetworkX results for path graph 0-1-2-3-4
    expected = [0.0, 0.5, 0.6666666666666666, 0.5, 0.0]
    print("\nExpected NetworkX results:")
    for i, val in enumerate(expected):
        print(f"Node {i}: {val:.6f}")
    
    print("\nDifference from expected:")
    for i, (got, exp) in enumerate(zip(bc_unweighted, expected)):
        print(f"Node {i}: {abs(got - exp):.6f}")


# Additional debugging function
def debug_single_source(adj, source=0):
    """Debug single source shortest paths."""
    print(f"\nDebugging single source from node {source}:")
    S, P, sigma = _single_source_shortest_path_unweighted_jax(adj, source)
    
    print(f"S (traversal order): {S}")
    print(f"Sigma (path counts): {sigma}")
    print(f"Predecessors matrix P:")
    for i in range(len(P)):
        preds = [j for j in range(len(P)) if P[i, j] > 0]
        print(f"  Node {i}: predecessors = {preds}")
    
    return S, P, sigma

In [53]:
# Create a sample graph
G = nx.karate_club_graph()
N = G.number_of_nodes()

# Get the adjacency matrix for JAX
adj_matrix_np = nx.to_numpy_array(G)
adj_matrix_jax = jnp.array(adj_matrix_np)

bc_unweighted = betweenness_centrality_jax(adj_matrix_jax, normalized=True, directed=False)
bc_weighted = betweenness_centrality_jax(adj_matrix_jax, weight_matrix=adj_matrix_jax, normalized=True, directed=False)

In [58]:
bc_unweighted/2

Array([0.43763528, 0.05393669, 0.14365681, 0.01190927, 0.00063131,
       0.02998737, 0.02998737, 0.        , 0.05592683, 0.00084776,
       0.00063131, 0.        , 0.        , 0.0458634 , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.03247505,
       0.        , 0.        , 0.        , 0.01761364, 0.0022096 ,
       0.00384049, 0.        , 0.02233345, 0.00179473, 0.00292208,
       0.01441198, 0.13827561, 0.14524711, 0.30407498], dtype=float64)

In [54]:
nx.betweenness_centrality(G, normalized=True, weight=None)


{0: 0.43763528138528146,
 1: 0.053936688311688304,
 2: 0.14365680615680618,
 3: 0.011909271284271283,
 4: 0.0006313131313131313,
 5: 0.02998737373737374,
 6: 0.029987373737373736,
 7: 0.0,
 8: 0.05592682780182781,
 9: 0.0008477633477633478,
 10: 0.0006313131313131313,
 11: 0.0,
 12: 0.0,
 13: 0.04586339586339586,
 14: 0.0,
 15: 0.0,
 16: 0.0,
 17: 0.0,
 18: 0.0,
 19: 0.03247504810004811,
 20: 0.0,
 21: 0.0,
 22: 0.0,
 23: 0.017613636363636363,
 24: 0.0022095959595959595,
 25: 0.0038404882154882154,
 26: 0.0,
 27: 0.02233345358345358,
 28: 0.0017947330447330447,
 29: 0.0029220779220779218,
 30: 0.014411976911976909,
 31: 0.13827561327561325,
 32: 0.145247113997114,
 33: 0.30407497594997596}

In [60]:
bc_weighted/2

Array([0.47376894, 0.06401515, 0.06941288, 0.00252525, 0.00094697,
       0.02935606, 0.02935606, 0.        , 0.02481061, 0.01379419,
       0.00094697, 0.        , 0.        , 0.00227273, 0.        ,
       0.        , 0.        , 0.03049242, 0.00568182, 0.24065657,
       0.        , 0.        , 0.        , 0.00189394, 0.06407828,
       0.00094697, 0.        , 0.01231061, 0.01912879, 0.        ,
       0.00568182, 0.12563131, 0.07222222, 0.3967803 ], dtype=float64)

In [55]:
nx.betweenness_centrality(G, normalized=True, weight='weight')

{0: 0.4737689393939393,
 1: 0.06401515151515152,
 2: 0.0694128787878788,
 3: 0.0025252525252525255,
 4: 0.000946969696969697,
 5: 0.029356060606060608,
 6: 0.029356060606060608,
 7: 0.0,
 8: 0.02481060606060606,
 9: 0.01379419191919192,
 10: 0.000946969696969697,
 11: 0.0,
 12: 0.0,
 13: 0.0022727272727272726,
 14: 0.0,
 15: 0.0,
 16: 0.0,
 17: 0.030492424242424244,
 18: 0.005681818181818182,
 19: 0.24065656565656565,
 20: 0.0,
 21: 0.0,
 22: 0.0,
 23: 0.001893939393939394,
 24: 0.06407828282828282,
 25: 0.000946969696969697,
 26: 0.0,
 27: 0.01231060606060606,
 28: 0.019128787878787877,
 29: 0.0,
 30: 0.005681818181818182,
 31: 0.12563131313131315,
 32: 0.07222222222222223,
 33: 0.3967803030303029}

In [64]:
import jax
import jax.numpy as jnp
from jax import random, vmap, lax, jit
from typing import Optional, Tuple
import functools

def betweenness_centrality_jax_optimized(
    adjacency_matrix: jnp.ndarray,
    k: Optional[int] = None,
    normalized: bool = True,
    weight_matrix: Optional[jnp.ndarray] = None,
    endpoints: bool = False,
    seed: int = 42,
    directed: bool = False
) -> jnp.ndarray:
    """
    Optimized JAX implementation of betweenness centrality using vmap and lax operations.
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Determine source nodes
    if k is None or k >= n_nodes:
        source_nodes = jnp.arange(n_nodes)
    else:
        key = random.PRNGKey(seed)
        source_nodes = random.choice(key, n_nodes, (k,), replace=False)
    
    # Choose algorithm based on whether weights are provided
    if weight_matrix is None:
        compute_single_source = functools.partial(
            _single_source_bfs_optimized, 
            adjacency_matrix
        )
    else:
        compute_single_source = functools.partial(
            _single_source_dijkstra_optimized,
            adjacency_matrix,
            weight_matrix
        )
    
    # Vectorize over all source nodes
    vectorized_compute = vmap(compute_single_source)
    all_contributions = vectorized_compute(source_nodes)
    
    # Sum contributions from all sources
    summed_betweenness = jnp.sum(all_contributions, axis=0)
    
    # Brandes' algorithm counts each path twice in an undirected graph.
    # The convention is to divide by 2.
    if not directed:
        summed_betweenness /= 2.0
    
    # Apply normalization
    betweenness = _rescale_optimized(
        summed_betweenness, n_nodes, normalized, k, endpoints, 
        len(source_nodes), directed
    )
    
    return betweenness


def _single_source_bfs_optimized(adjacency_matrix: jnp.ndarray, source: int) -> jnp.ndarray:
    """
    Optimized single-source BFS using JAX control flow primitives.
    """
    n_nodes = adjacency_matrix.shape[0]
    
    # Initialize state
    initial_distances = jnp.full(n_nodes, -1, dtype=jnp.int32).at[source].set(0)
    initial_sigma = jnp.zeros(n_nodes, dtype=jnp.float32).at[source].set(1.0)
    initial_P = jnp.zeros((n_nodes, n_nodes), dtype=jnp.float32)
    initial_layer = jnp.zeros(n_nodes, dtype=bool).at[source].set(True)
    initial_dist = 0

    initial_state = (initial_layer, initial_distances, initial_sigma, initial_P, initial_dist)
    
    def bfs_condition(state):
        current_layer, _, _, _, _ = state
        return jnp.any(current_layer)

    def bfs_body(state):
        current_layer, distances, sigma, P, current_dist = state
        
        # Find neighbors of the current layer
        neighbors = jnp.dot(current_layer.astype(jnp.float32), adjacency_matrix) > 0
        
        # Newly discovered nodes are unvisited neighbors
        newly_discovered = neighbors & (distances == -1)
        new_distances = jnp.where(newly_discovered, current_dist + 1, distances)
        
        # For all nodes at the next level, update their sigma by summing from predecessors
        path_to_next_level = (new_distances == current_dist + 1)
        sigma_from_preds = jnp.dot(adjacency_matrix.T, jnp.where(current_layer, sigma, 0.0))
        new_sigma = jnp.where(path_to_next_level, sigma + sigma_from_preds, sigma)
        
        # Update predecessor matrix for paths to the next level
        pred_updates = jnp.outer(path_to_next_level, current_layer) * adjacency_matrix
        new_P = P + pred_updates.T

        return (newly_discovered, new_distances, new_sigma, new_P, current_dist + 1)

    # Run BFS
    _, final_distances, final_sigma, final_P, _ = lax.while_loop(
        bfs_condition, bfs_body, initial_state
    )
    
    # The stack S is the set of nodes sorted by INCREASING distance.
    # The accumulation loop will iterate backwards over this stack.
    S = jnp.argsort(final_distances)
    
    # Compute betweenness contribution
    betweenness_contrib = _accumulate_dependencies_optimized(
        S, final_P, final_sigma, source
    )
    
    return betweenness_contrib


def _single_source_dijkstra_optimized(
    adjacency_matrix: jnp.ndarray, 
    weight_matrix: jnp.ndarray, 
    source: int
) -> jnp.ndarray:
    """
    Dijkstra's algorithm using lax.fori_loop.
    """
    n_nodes = adjacency_matrix.shape[0]
    
    distances = jnp.full(n_nodes, jnp.inf).at[source].set(0.0)
    sigma = jnp.zeros(n_nodes).at[source].set(1.0)
    P = jnp.zeros((n_nodes, n_nodes))
    S = jnp.full(n_nodes, -1, dtype=jnp.int32)
    visited = jnp.zeros(n_nodes, dtype=bool)

    def dijkstra_body(i, state):
        distances, sigma, visited, P, S = state
        
        unvisited_dist = jnp.where(visited, jnp.inf, distances)
        u = jnp.argmin(unvisited_dist)
        
        S = S.at[i].set(u)
        visited = visited.at[u].set(True)
        
        edge_weights = jnp.where(adjacency_matrix[u] > 0, weight_matrix[u], jnp.inf)
        new_dist_via_u = distances[u] + edge_weights
        
        is_shorter = new_dist_via_u < distances
        is_equal = jnp.isclose(new_dist_via_u, distances)

        new_distances = jnp.where(is_shorter, new_dist_via_u, distances)
        
        reset_sigma = jnp.where(is_shorter, sigma[u], sigma)
        new_sigma = jnp.where(is_equal, reset_sigma + sigma[u], reset_sigma)
        
        reset_P_col = jnp.zeros(n_nodes)
        updated_P_col = jnp.where(is_shorter, reset_P_col.at[u].set(1), P[u])
        updated_P_col = jnp.where(is_equal, updated_P_col.at[u].set(1), updated_P_col)
        new_P = P.at[u].set(updated_P_col)

        return new_distances, new_sigma, visited, new_P, S

    initial_state = (distances, sigma, visited, P, S)
    final_distances, final_sigma, _, final_P, final_S = lax.fori_loop(0, n_nodes, dijkstra_body, initial_state)
    
    # Reorder S by distance for accumulation
    sorted_S = final_S[jnp.argsort(final_distances[final_S])]

    betweenness_contrib = _accumulate_dependencies_optimized(
        sorted_S, final_P.T, final_sigma, source
    )
    
    return betweenness_contrib

def _accumulate_dependencies_optimized(
    S: jnp.ndarray,
    P: jnp.ndarray,
    sigma: jnp.ndarray,
    source: int
) -> jnp.ndarray:
    """
    Optimized dependency accumulation.
    S: stack of nodes in order of INCREASING distance from source.
    P: predecessor matrix where P[w, v] = 1 if v is a predecessor of w.
    """
    n_nodes = S.shape[0]
    delta = jnp.zeros(n_nodes)
    
    def accumulate_step(i, delta):
        # Iterate backwards through S
        w = S[n_nodes - 1 - i]
        
        predecessors_mask = P[w, :] > 0
        sigma_w = sigma[w]
        
        coeff = jnp.where(sigma_w > 0, (sigma / sigma_w) * (1.0 + delta[w]), 0.0)
        
        delta_update = jnp.sum(jnp.where(predecessors_mask, coeff, 0.0), axis=-1)
        new_delta = delta + delta_update * predecessors_mask

        return new_delta
    
    # Corrected logic for accumulation
    def body_fn(i, state):
        delta, betweenness = state
        w = S[n_nodes - 1 - i]

        coeff = (1 + delta[w]) / jnp.maximum(sigma[w], 1e-9)
        
        # Propagate dependency to predecessors
        delta_update = jnp.dot(P[w, :], sigma) * coeff
        delta = delta.at[jnp.arange(n_nodes)].add(jnp.where(P[w, :] > 0, delta_update, 0))

        betweenness = betweenness.at[w].add(delta[w])
        
        return delta, betweenness

    # The dependency accumulation needs to be done carefully
    # Using the Brandes algorithm formulation directly
    delta = jnp.zeros(n_nodes)
    def accumulation_loop(i, state):
        delta, betweenness = state
        w = S[n_nodes - 1 - i]

        # Get predecessors of w
        predecessors = P[w, :]
        
        # Calculate coefficients for predecessors
        coeffs = (sigma / jnp.maximum(sigma[w], 1e-9)) * (1 + delta[w])
        
        # Update delta for all predecessors of w
        delta_update = jnp.dot(predecessors, coeffs)
        
        # The update needs to be applied to the correct indices, which is tricky.
        # Let's use a simpler, more direct loop that is still JIT-able.
        
        w_delta = delta[w]
        
        def inner_loop(j, current_delta):
            v = S[j] # v is a potential predecessor
            is_pred = P[w, v] > 0
            
            update = (sigma[v] / jnp.maximum(sigma[w], 1e-9)) * (1 + w_delta)
            
            return jnp.where(is_pred, current_delta.at[v].add(update), current_delta)

        # This inner loop is inefficient. The vectorized version is better.
        # Let's correct the vectorized update.
        v_indices = jnp.arange(n_nodes)
        is_pred = P[w,:] > 0
        update_values = (sigma[v_indices] / jnp.maximum(sigma[w], 1e-9)) * (1 + delta[w])
        delta_updates = jnp.where(is_pred, update_values, 0)
        
        new_delta = delta + delta_updates
        new_betweenness = jnp.where(w != source, betweenness.at[w].add(delta[w]), betweenness)
        
        return new_delta, new_betweenness

    initial_state = (jnp.zeros(n_nodes), jnp.zeros(n_nodes))
    _, final_betweenness = lax.fori_loop(0, n_nodes, accumulation_loop, initial_state)

    return final_betweenness


def _rescale_optimized(
    betweenness: jnp.ndarray,
    n_nodes: int,
    normalized: bool,
    k: Optional[int],
    endpoints: bool,
    n_sampled: int,
    directed: bool = False
) -> jnp.ndarray:
    """Rescales the betweenness values."""
    if not normalized or n_nodes <= 2:
        return betweenness
    
    # Normalization factor for undirected graphs is (n-1)(n-2)/2
    # For directed, it's (n-1)(n-2)
    # Since we already divided by 2 for undirected, the scale is the same here.
    scale = 1.0 / ((n_nodes - 1) * (n_nodes - 2))
    
    if k is not None and n_sampled < n_nodes:
        scale *= n_nodes / n_sampled
    
    return betweenness * scale

def create_path_graph(n: int) -> jnp.ndarray:
    adj = jnp.zeros((n, n))
    for i in range(n - 1):
        adj = adj.at[i, i + 1].set(1.0)
        adj = adj.at[i + 1, i].set(1.0)
    return adj

def test_optimized_implementation():
    print("Testing optimized JAX betweenness centrality implementation...")
    
    print("\n1. Path graph (5 nodes):")
    path_adj = create_path_graph(5)
    
    bc_func = jit(functools.partial(betweenness_centrality_jax_optimized, normalized=True))
    bc_path = bc_func(adjacency_matrix=path_adj)
    
    print(f"Betweenness centrality: {bc_path}")
    
    expected_path = jnp.array([0.0, 0.5, 2.0/3.0, 0.5, 0.0])
    print(f"Expected values:        {expected_path}")
    print(f"Max difference:         {jnp.max(jnp.abs(bc_path - expected_path))}")
    
    print("\n2. Path graph with sampling (k=3):")
    bc_sampled_func = jit(functools.partial(betweenness_centrality_jax_optimized, k=3, normalized=True, seed=42))
    bc_sampled = bc_sampled_func(adjacency_matrix=path_adj)
    print(f"Sampled betweenness:    {bc_sampled}")
    
    print("\nOptimized implementation test completed!")

if __name__ == "__main__":
    test_optimized_implementation()


Testing optimized JAX betweenness centrality implementation...

1. Path graph (5 nodes):


TypeError: while_loop body function carry input and carry output must have equal types, but they differ:

  * the input carry component state[2] has type float32[5] but the corresponding output carry component has type float64[5], so the dtypes do not match;
  * the input carry component state[3] has type float32[5,5] but the corresponding output carry component has type float64[5,5], so the dtypes do not match.

Revise the function so that all output types match the corresponding input types.