In [3]:
import jax.numpy as jnp
import jax
import numpy as np
import networkx as nx

def jax_geodesic_distance(adj_matrix):
    """
    Compute all-pairs shortest-path (geodesic) distances with Dijkstra's algorithm in JAX.

    A single-source Dijkstra run is expressed as a fixed-length ``jax.lax.scan`` that
    settles one node per step, and that run is ``jax.vmap``-ed over every source node.

    Args:
        adj_matrix: square 2D array-like (JAX or NumPy) of shape (n, n).
            ``adj_matrix[i, j]`` is the weight of the directed edge i -> j and
            ``jnp.inf`` means "no edge". Diagonal entries are ignored (the distance
            from a node to itself is always 0); an off-diagonal 0 is a zero-cost edge.
            Weights must be non-negative: Dijkstra never updates a settled node, so a
            negative edge can produce silently wrong distances. This is NOT checked
            at runtime (the check cannot raise inside traced code).

    Returns:
        A JAX array of shape (n, n) whose entry [i, j] is the shortest distance from
        node i to node j, or ``inf`` when j is unreachable from i.

    Raises:
        ValueError: if ``adj_matrix`` is not a square 2D array.
    """
    # Accept NumPy input too: indexing a NumPy array with a traced index fails.
    adj = jnp.asarray(adj_matrix)
    if adj.ndim != 2 or adj.shape[0] != adj.shape[1]:
        raise ValueError(
            f"adj_matrix must be a square 2D array, got shape {adj.shape}"
        )
    n_nodes = adj.shape[0]
    if n_nodes == 0:
        # jnp.argmin cannot be traced over an empty vector (even inside a
        # zero-length scan), so the empty graph is handled up front.
        return jnp.full((0, 0), jnp.inf)

    def single_source_dijkstra(src):
        # Tentative distances from `src`; only the source starts at 0.
        dist = jnp.full((n_nodes,), jnp.inf)
        dist = dist.at[src].set(0)
        visited = jnp.zeros((n_nodes,), dtype=bool)

        def relax_step(carry, _):
            dist, visited = carry
            # Settle the closest not-yet-settled node. Once every remaining node is
            # unreachable (inf), argmin may return an already-settled node; its
            # out-edges were already relaxed, so re-relaxing is a no-op. That is why
            # a fixed scan length of n_nodes is safe.
            unvisited_dist = jnp.where(visited, jnp.inf, dist)
            u = jnp.argmin(unvisited_dist)
            visited = visited.at[u].set(True)

            # Relax the out-edges of u without touching settled nodes.
            new_dist = jnp.where(
                ~visited,
                jnp.minimum(dist, dist[u] + adj[u]),
                dist,
            )
            return (new_dist, visited), None

        (dist, _), _ = jax.lax.scan(
            relax_step, (dist, visited), None, length=n_nodes
        )
        return dist

    # One independent Dijkstra run per source node, batched with vmap.
    return jax.vmap(single_source_dijkstra)(jnp.arange(n_nodes))


# NetworkX Validation
def networkx_geodesic_distance(adj_matrix):
    """
    Compute all-pairs shortest-path distances with NetworkX, as a validation reference.

    Adjacency convention (the same one ``jax_geodesic_distance`` uses):
    ``adj_matrix[i, j]`` is the weight of the directed edge i -> j, non-finite entries
    (``np.inf``, also NaN) mean "no edge", diagonal entries are ignored, and an
    off-diagonal 0 is a zero-cost edge. Distances come from NetworkX's Dijkstra, so
    weights are expected to be non-negative.

    ``nx.from_numpy_array`` is deliberately not used: it turns every *nonzero* entry
    into an edge, so ``np.inf`` entries would become infinite-weight edges while
    genuine zero-weight edges would be silently dropped.

    Args:
        adj_matrix: square 2D array-like of shape (n, n) holding edge weights.

    Returns:
        A float64 numpy array of shape (n, n) whose entry [i, j] is the shortest
        distance from node i to node j, or ``np.inf`` when j is unreachable from i.

    Raises:
        nx.NetworkXError: if ``adj_matrix`` is not a square 2D array.
    """
    adj = np.asarray(adj_matrix, dtype=float)
    if adj.ndim != 2 or adj.shape[0] != adj.shape[1]:
        raise nx.NetworkXError(f"Adjacency matrix not square: shape={adj.shape}")
    n_nodes = adj.shape[0]

    G = nx.DiGraph()
    # Add every node explicitly so isolated nodes still get a row and a column.
    G.add_nodes_from(range(n_nodes))
    rows, cols = np.nonzero(np.isfinite(adj))
    G.add_weighted_edges_from(
        (int(i), int(j), float(adj[i, j])) for i, j in zip(rows, cols) if i != j
    )

    # Pairs never reported by NetworkX are unreachable and stay at inf.
    dist_matrix = np.full((n_nodes, n_nodes), np.inf)
    for src, lengths in nx.all_pairs_dijkstra_path_length(G):
        for dst, length in lengths.items():
            dist_matrix[src, dst] = length

    return dist_matrix



# Example adjacency matrix: adj_matrix[i, j] is the weight of the directed edge
# i -> j and np.inf means "no edge". The matrix is NOT symmetric
# (adj_matrix[0, 1] = -3 but adj_matrix[1, 0] = 3), so the graph is directed.
# NOTE: the -3 weight violates Dijkstra's non-negative-weight precondition. Both
# implementations still return the correct distances for this particular graph,
# but in general a negative weight can make Dijkstra return wrong distances.
adj_matrix = np.array([
    [0, -3, np.inf, np.inf, 5],
    [3, 0, 8, np.inf, 2],
    [np.inf, 8, 0, 1, np.inf],
    [np.inf, np.inf, 1, 0, 4],
    [5, 2, np.inf, 4, 0]
])

# Convert to JAX array
adj_matrix_jax = jnp.array(adj_matrix)

# Compute geodesic distances
jax_distances = jax_geodesic_distance(adj_matrix_jax)
nx_distances = networkx_geodesic_distance(adj_matrix)

# Compare results
print("JAX Geodesic Distances:\n", jax_distances)
print("NetworkX Geodesic Distances:\n", nx_distances)

# Validation: assert_allclose is not stripped under `python -O` and reports the
# mismatching entries. rtol/atol/equal_nan reproduce np.allclose's tolerance check;
# matching infinities (unreachable pairs) compare equal.
np.testing.assert_allclose(
    jax_distances,
    nx_distances,
    rtol=1e-5,
    atol=1e-5,
    equal_nan=False,
    err_msg="Results do not match!",
)
print("Validation passed!")


JAX Geodesic Distances:
 [[ 0. -3.  4.  3. -1.]
 [ 3.  0.  7.  6.  2.]
 [10.  7.  0.  1.  5.]
 [ 9.  6.  1.  0.  4.]
 [ 5.  2.  5.  4.  0.]]
NetworkX Geodesic Distances:
 [[ 0. -3.  4.  3. -1.]
 [ 3.  0.  7.  6.  2.]
 [10.  7.  0.  1.  5.]
 [ 9.  6.  1.  0.  4.]
 [ 5.  2.  5.  4.  0.]]
Validation passed!


In [8]:
import os
import sys
# os.path.abspath("") is the current working directory; stripping two path
# components yields its grandparent, presumably the project root that holds the
# `main` module — TODO confirm the directory layout.
newPath = os.path.dirname(os.path.dirname(os.path.abspath("")))
# Guard so re-running the cell does not append the same entry again. Note that
# append() puts it LAST, so any earlier sys.path entry providing a `main` module wins.
if newPath not in sys.path:
    sys.path.append(newPath)

# NOTE(review): the star import hides where `bi` (and any other name used later)
# comes from and may shadow names defined above; prefer `from main import bi` once
# no other cell relies on the names it pulls in.
from main import *
m=bi()
# NOTE(review): the last recorded run of this call raised
# `NameError: name 'indegree_jit' is not defined`. That name is not defined in this
# notebook, so the failure presumably originates inside `main` — verify there.
m.net.degree(adj_matrix_jax)

jax.local_device_count 1


NameError: name 'indegree_jit' is not defined