In [1]:
import os
import sys
newPath = os.path.dirname(os.path.dirname(os.path.abspath("")))
if newPath not in sys.path:
    sys.path.append(newPath)

from main import *
import networkx as nx

m=bi()

  from .autonotebook import tqdm as notebook_tqdm


jax.local_device_count 16


## Data

In [2]:
network = np.array([
    [0, 3,0, 0, 5],
    [3, 0, 8, 0, 2],
    [0, 8, 0, 1, 0],
    [0, 0, 1, 0, 4],
    [5, 2, 0, 4, 0]
])
G = nx.from_numpy_array(network, create_using=nx.DiGraph)

In [None]:
bi.net.is_symmetric(network)

Array(True, dtype=bool)

## Nodal measures

## Degree

In [5]:
print(m.net.indegree(network))
print(G.in_degree())

[2 3 2 2 3]
[(0, 2), (1, 3), (2, 2), (3, 2), (4, 3)]


In [6]:
print(m.net.outdegree(network))
print(print(G.out_degree()))

[2 3 2 2 3]
[(0, 2), (1, 3), (2, 2), (3, 2), (4, 3)]
None


In [7]:
print(m.net.degree(network))
print(print(G.degree()))

[4 6 4 4 6]
[(0, 4), (1, 6), (2, 4), (3, 4), (4, 6)]
None


In [8]:
# n alters to do 

## Strength


In [9]:
print(m.net.strength(network))
print(G.degree(weight='weight'))

[16 26 18 10 22]
[(0, 16), (1, 26), (2, 18), (3, 10), (4, 22)]

[(0, 16), (1, 26), (2, 18), (3, 10), (4, 22)]


In [10]:
print(m.net.instrength(network))
print(G.in_degree(weight='weight'))

[ 8 13  9  5 11]
[(0, 8), (1, 13), (2, 9), (3, 5), (4, 11)]


In [11]:
print(m.net.outstrength(network))
print(G.out_degree(weight='weight'))

[ 8 13  9  5 11]
[(0, 8), (1, 13), (2, 9), (3, 5), (4, 11)]


## Eingenvector

In [12]:
import numpy as np
import numba
from numba import njit
import networkx as nx
import math

@njit
def normalize_vector(x):
    norm = np.linalg.norm(x)
    if norm == 0:
        norm = 1
    return x / norm

@njit
def check_convergence(x, xlast, nnodes, tol):
    return np.sum(np.abs(x - xlast)) < nnodes * tol

@njit
def power_iteration(G, nstart, max_iter, tol, weight):
    nnodes = G.shape[0]
    if nstart is None:
        nstart = np.ones(nnodes)
    if np.all(nstart == 0):
        raise nx.NetworkXError("initial vector cannot have all zero values")

    nstart_sum = np.sum(nstart)
    x = nstart / nstart_sum

    for _ in range(max_iter):
        xlast = x.copy()
        x = np.zeros(nnodes)

        for n in range(nnodes):
            for nbr in range(nnodes):
                if G[n, nbr] != 0:
                    w = G[n, nbr] if weight else 1
                    x[nbr] += xlast[n] * w

        x = normalize_vector(x)

        if check_convergence(x, xlast, nnodes, tol):
            return x

    raise nx.PowerIterationFailedConvergence(max_iter)


nstart = None
max_iter = 100
tol = 1e-6
weight = 'weight'

result = power_iteration(network, nstart, max_iter, tol, weight)
print(result)
print(nx.eigenvector_centrality(G, max_iter=300, weight=weight))


[0.3882815  0.61232662 0.51425236 0.21454302 0.40473373]
{0: 0.3882811376028844, 1: 0.6123280935190829, 2: 0.5142509671571023, 3: 0.21454330028691557, 4: 0.4047334668013984}


## Clustering coefficient

In [16]:
print(bi.net.cc(network))
nx.clustering(G).values()

[1.0714285  0.3846154  0.         0.         0.54545456]


dict_values([1.0, 0.3333333333333333, 0, 0, 0.3333333333333333])

## Betweeness

In [14]:
import jax.numpy as jnp
from jax import jit, lax

@jit
def dijkstra(adjacency_matrix, source):
    n = adjacency_matrix.shape[0]
    visited = jnp.zeros(n, dtype=bool)
    dist = jnp.inf * jnp.ones(n)
    dist = dist.at[source].set(0)

    def body_fn(carry):
        visited, dist = carry
        
        # Find the next node to process
        u = jnp.argmin(jnp.where(visited, jnp.inf, dist))
        visited = visited.at[u].set(True)

        # Update distances to all neighbors
        def update_dist(v, dist):
            return jax.lax.cond(
                jnp.logical_and(jnp.logical_not(visited[v]), adjacency_matrix[u, v] > 0),
                lambda _: jnp.minimum(dist[v], dist[u] + adjacency_matrix[u, v]),
                lambda _: dist[v],
                None
            )

        dist = lax.fori_loop(0, n, lambda v, dist: dist.at[v].set(update_dist(v, dist)), dist)

        return visited, dist

    def cond_fn(carry):
        visited, _ = carry
        return jnp.any(jnp.logical_not(visited))

    # Loop until all nodes are visited
    visited, dist_final = lax.while_loop(cond_fn, body_fn, (visited, dist))

    return dist_final

print(dijkstra(network,0))
nx.single_source_dijkstra_path_length(G, 0).values()

[ 0.  3. 10.  9.  5.]


dict_values([0, 3, 5, 9, 10])

In [17]:
import jax.numpy as jnp
from jax import jit
import networkx as nx

@jit
def betweenness_centrality(adjacency_matrix):
    n = adjacency_matrix.shape[0]
    centrality = jnp.zeros(n)

    for s in range(n):
        dist = dijkstra(adjacency_matrix, s)
        paths = jnp.zeros((n, n))  # paths[i, j] = number of shortest paths from i to j
        predecessors = jnp.zeros((n, n), dtype=bool)  # Track predecessors

        for t in range(n):
            if s != t:
                path_count = 0
                for u in range(n):
                    path_count += jax.lax.cond(
                        jnp.logical_and(u != s, dist[u] + adjacency_matrix[s, u] == dist[t]),
                        lambda _: 1.0,  # Return a float here
                        lambda _: 0.0,  # Return a float here
                        None
                    )
                    predecessors = predecessors.at[t, u].set(jnp.logical_and(u != s, dist[u] + adjacency_matrix[s, u] == dist[t]))

                paths = paths.at[s, t].set(path_count)

        # Update betweenness centrality based on contributions from each node
        for t in range(n):
            if s != t:
                for u in range(n):
                    contribution = jax.lax.cond(
                        predecessors[t, u],
                        lambda _: paths[s, t] / paths[s, t],
                        lambda _: 0.0,  # Return a float here
                        None
                    )
                    centrality = centrality.at[u].add(contribution)

    return centrality

# Compute betweenness centrality using JAX
jax_betweenness = betweenness_centrality(adjacency_matrix).block_until_ready()

# Compute betweenness centrality using NetworkX

nx_betweenness = nx.betweenness_centrality(G)

# Print results
print("JAX Betweenness Centrality:", jax_betweenness)
print("NetworkX Betweenness Centrality:", [nx_betweenness[i] for i in range(len(nx_betweenness))])



NameError: name 'adjacency_matrix' is not defined

# Dyadic measures
## Assortatitivity
## Transitivity

# Global measures

## Geodesic distance

In [4]:
# NetworkX Validation
def networkx_geodesic_distance(adj_matrix):
    """
    Compute the geodesic distance using NetworkX for validation.
    Args:
        adj_matrix: 2D numpy array representing the weighted adjacency matrix of a graph.

    Returns:
        A 2D numpy array containing the shortest path distances between all pairs of nodes.
    """
    G = nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph)
    dist = dict(nx.all_pairs_dijkstra_path_length(G))
    n_nodes = len(dist)
    dist_matrix = np.full((n_nodes, n_nodes), np.inf)

    for i in range(n_nodes):
        for j in dist[i]:
            dist_matrix[i, j] = dist[i][j]
    
    return dist_matrix


# Example Usage
if __name__ == "__main__":
    # Convert to JAX array
    adj_matrix_jax = jnp.array(network)

    # Compute geodesic distances
    jax_distances = bi.net.geodesic_distance(jnp.array(network))
    nx_distances = networkx_geodesic_distance(network)

    # Compare results
    print("JAX Geodesic Distances:\n", jax_distances)
    print("NetworkX Geodesic Distances:\n", nx_distances)

    # Validation
    assert np.allclose(jax_distances, nx_distances, atol=1e-5), "Results do not match!"
    print("Validation passed!")



JAX Geodesic Distances:
 [[ 0.  3. 10.  9.  5.]
 [ 3.  0.  7.  6.  2.]
 [10.  7.  0.  1.  5.]
 [ 9.  6.  1.  0.  4.]
 [ 5.  2.  5.  4.  0.]]
NetworkX Geodesic Distances:
 [[ 0.  3. 10.  9.  5.]
 [ 3.  0.  7.  6.  2.]
 [10.  7.  0.  1.  5.]
 [ 9.  6.  1.  0.  4.]
 [ 5.  2.  5.  4.  0.]]
Validation passed!


## Diameter

In [4]:
bi.net.diameter(jnp.array(network))

Array(10., dtype=float64)

# Density

# Modularity

## Global clustering coefficient

## Global efficiency