## Test PageRank

In [10]:
# First, let's check and configure JAX properly
import jax
import jax.numpy as jnp
import numpy as np

# Report the runtime environment before changing anything: JAX version,
# default backend, visible devices, and whether 64-bit mode is already on.
print("=== JAX Configuration Debug ===")
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"JAX devices: {jax.devices()}")
print(f"Current x64 status: {jax.config.jax_enable_x64}")

# Enable x64 globally for this session
# NOTE(review): the flag is flipped after the backend/devices were queried
# above; JAX documents it as a startup setting — confirm this order is fine.
jax.config.update("jax_enable_x64", True)
print(f"Updated x64 status: {jax.config.jax_enable_x64}")

# Test basic array operations
# The first array lets JAX pick the dtype; the second requests float64
# explicitly, so the two printed dtypes can be compared.
print("\n=== Precision Test ===")
test_arr = jnp.array([1.0, 2.0, 3.0])
print(f"Default array dtype: {test_arr.dtype}")

test_arr_f64 = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float64)
print(f"Explicit float64 dtype: {test_arr_f64.dtype}")

# Test argsort
# Smoke test on the float64 array: a failure is printed instead of raised,
# so the cell always runs through to the final banner.
print("\n=== Argsort Test ===")
try:
    sorted_indices = jnp.argsort(test_arr_f64)
    print(f"Argsort successful: {sorted_indices}")
    print(f"Argsort result dtype: {sorted_indices.dtype}")
except Exception as e:
    print(f"Argsort failed: {e}")

print("=== Configuration Complete ===\n")

=== JAX Configuration Debug ===
JAX version: 0.6.2
JAX backend: cpu
JAX devices: [CpuDevice(id=0)]
Current x64 status: True
Updated x64 status: True

=== Precision Test ===
Default array dtype: float64
Explicit float64 dtype: float64

=== Argsort Test ===
Argsort successful: [0 1 2]
Argsort result dtype: int64
=== Configuration Complete ===



In [11]:
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
from graphs import from_networkx
from algorithms.pagerank import pagerank

def test_pagerank():
    """
    Compare the JAX-based PageRank implementation against NetworkX's.

    Builds the Karate Club graph, runs ``pagerank`` (JAX) and ``nx.pagerank``
    with the same damping factor (0.85) and tolerance (1e-8), and prints graph
    statistics, an error analysis, the top-k nodes of each result and a
    pass/fail verdict.

    The test passes when the two score vectors agree within an absolute
    tolerance of 1e-6 or, failing that, 1e-4. All comparisons use ``rtol=0``
    so the tolerance that is printed is exactly the tolerance that is applied.
    A mismatch is reported by printing and does not raise, so later notebook
    cells keep running.

    Returns:
        None. Everything is written to stdout.
    """
    print("--- Testing PageRank Algorithm ---")

    # 1. Create a graph. The Karate Club graph is a good small test case.
    nx_g = nx.karate_club_graph()

    # Print graph statistics
    print("Graph statistics:")
    print(f"  Nodes: {nx_g.number_of_nodes()}")
    print(f"  Edges: {nx_g.number_of_edges()}")
    print(f"  Is directed: {nx_g.is_directed()}")

    # Convert to our JAX graph format
    jax_g = from_networkx(nx_g)
    print("JAX Graph:")
    print(f"  n_nodes: {jax_g.n_nodes}")
    print(f"  n_edges: {jax_g.n_edges}")
    print(f"  senders shape: {jax_g.senders.shape}")
    print(f"  receivers shape: {jax_g.receivers.shape}")
    print(f"  edge_weights: {jax_g.edge_weights}")

    # 2. Run our JAX PageRank implementation
    print(f"\nCurrent x64 status before PageRank: {jax.config.jax_enable_x64}")
    print("Running JAX PageRank...")
    jax_pr = pagerank(jax_g, damping_factor=0.85, tolerance=1e-8)
    print(f"JAX PageRank result shape: {jax_pr.shape}")
    print(f"JAX PageRank sum: {jnp.sum(jax_pr)}")
    print(f"JAX PageRank dtype: {jax_pr.dtype}")

    # 3. Run the NetworkX PageRank implementation as a baseline
    print("\nRunning NetworkX PageRank...")
    nx_pr = nx.pagerank(nx_g, alpha=0.85, tol=1e-8)
    print(f"NetworkX PageRank result length: {len(nx_pr)}")
    print(f"NetworkX PageRank sum: {sum(nx_pr.values())}")

    # Lay the NetworkX scores out in sorted-node order for comparison.
    # NOTE(review): assumes from_networkx numbers nodes in that same order — verify.
    nx_pr_array = np.array([nx_pr[node] for node in sorted(nx_pr)])
    print(f"NetworkX array shape: {nx_pr_array.shape}")
    print(f"NetworkX array dtype: {nx_pr_array.dtype}")

    # Pull the JAX result into NumPy once so every comparison below is
    # NumPy-vs-NumPy rather than a mix of array types.
    jax_pr_numpy = np.asarray(jax_pr)

    # 4. Compare the results
    print("\n--- Detailed Results Comparison ---")
    # These are the first ten entries by node index, not the ten largest scores.
    print(f"First 10 PageRank scores (JAX): {jax_pr_numpy[:10]}")
    print(f"First 10 PageRank scores (NX):  {nx_pr_array[:10]}")

    # Calculate differences
    abs_diff = np.abs(jax_pr_numpy - nx_pr_array)
    max_diff = np.max(abs_diff)
    mean_diff = np.mean(abs_diff)
    worst_index = int(np.argmax(abs_diff))

    print("\n--- Error Analysis ---")
    print(f"Maximum absolute difference: {max_diff}")
    print(f"Mean absolute difference: {mean_diff}")
    print(f"Index of max difference: {worst_index}")
    print(f"Value at max diff - JAX: {jax_pr_numpy[worst_index]}")
    print(f"Value at max diff - NX:  {nx_pr_array[worst_index]}")

    # Show top-k nodes by PageRank
    k = 5
    print(f"\n--- Top {k} Nodes (Debug) ---")
    print(f"JAX PageRank array type: {type(jax_pr)}")
    print(f"JAX PageRank shape: {jax_pr.shape}")
    print(f"JAX PageRank dtype: {jax_pr.dtype}")
    print(f"Converted to numpy - shape: {jax_pr_numpy.shape}, dtype: {jax_pr_numpy.dtype}")

    # argsort is ascending: take the last k entries and reverse for best-first order.
    jax_top_k = np.argsort(jax_pr_numpy)[-k:][::-1]
    nx_top_k = np.argsort(nx_pr_array)[-k:][::-1]

    print(f"JAX top {k} nodes (by index): {jax_top_k}")
    print(f"NX  top {k} nodes (by index): {nx_top_k}")
    print(f"Top {k} match: {np.array_equal(jax_top_k, nx_top_k)}")

    # Print the actual PageRank values for top nodes
    print(f"\n--- Top {k} Node Values ---")
    for i, (jax_idx, nx_idx) in enumerate(zip(jax_top_k, nx_top_k)):
        print(f"Rank {i+1}: JAX node {jax_idx} (PR={jax_pr_numpy[jax_idx]:.6f}), NX node {nx_idx} (PR={nx_pr_array[nx_idx]:.6f})")

    def within(atol):
        # np.allclose tests |a - b| <= atol + rtol * |b| with rtol defaulting
        # to 1e-5; rtol=0 makes the quoted atol the whole tolerance.
        return bool(np.allclose(jax_pr_numpy, nx_pr_array, rtol=0.0, atol=atol))

    is_close = within(1e-6)
    print(f"\nResults are consistent with NetworkX (atol=1e-6): {is_close}")

    if not is_close:
        # Try with more relaxed tolerance
        print(f"Results are consistent with relaxed tolerance (atol=1e-5): {within(1e-5)}")

        # Widen further, stopping at the first tolerance that matches, to
        # show the order of magnitude of the disagreement.
        for tol in (1e-4, 1e-3, 1e-2):
            is_close_debug = within(tol)
            print(f"Results are consistent with tolerance {tol}: {is_close_debug}")
            if is_close_debug:
                break

    # Lenient verdict for now (1e-4 fallback) while differences are being
    # understood. A plain if/else is used instead of assert-in-try so the
    # check still runs when Python strips asserts (-O).
    if is_close or within(1e-4):
        print("PageRank test passed successfully.\n")
    else:
        print("PageRank test failed: PageRank implementation does not match NetworkX.")
        print("This might be due to algorithmic differences or precision issues.")
        print("Consider investigating the differences further.\n")

In [12]:
# Run the Karate Club comparison defined above; its report is printed to stdout.
test_pagerank()

--- Testing PageRank Algorithm ---
Graph statistics:
  Nodes: 34
  Edges: 78
  Is directed: False
JAX Graph:
  n_nodes: 34
  n_edges: 156
  senders shape: (156,)
  receivers shape: (156,)
  edge_weights: [4. 4. 5. 5. 3. 3. 3. 3. 3. 3. 3. 3. 2. 2. 2. 2. 2. 2. 3. 3. 1. 1. 3. 3.
 2. 2. 2. 2. 2. 2. 2. 2. 6. 6. 3. 3. 4. 4. 5. 5. 1. 1. 2. 2. 2. 2. 2. 2.
 3. 3. 4. 4. 5. 5. 1. 1. 3. 3. 2. 2. 2. 2. 2. 2. 3. 3. 3. 3. 3. 3. 2. 2.
 3. 3. 5. 5. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 4. 4. 2. 2. 3. 3. 3. 3. 2. 2.
 3. 3. 4. 4. 1. 1. 2. 2. 1. 1. 3. 3. 1. 1. 2. 2. 3. 3. 5. 5. 4. 4. 3. 3.
 5. 5. 4. 4. 2. 2. 3. 3. 2. 2. 7. 7. 4. 4. 2. 2. 4. 4. 2. 2. 2. 2. 4. 4.
 2. 2. 3. 3. 3. 3. 4. 4. 4. 4. 5. 5.]

Current x64 status before PageRank: True
Running JAX PageRank...
JAX PageRank result shape: (34,)
JAX PageRank sum: 1.0
JAX PageRank dtype: float64

Running NetworkX PageRank...
NetworkX PageRank result length: 34
NetworkX PageRank sum: 0.9999999999999994
NetworkX array shape: (34,)
NetworkX array dtype: float64

--- 

## Real Data: Web-Google

In [13]:
import os
# Inspect the structural properties of the web-Google graph (directedness,
# size, self-loops, node-ID range, out-degrees) and check that the JAX
# conversion reports the same figures.
print("\n=== Analyzing Web-Google Graph Properties ===")

file_path = 'web-Google.txt'
if os.path.exists(file_path):
    # Load the full graph (no sampling); lines starting with '#' are comments.
    nx_web = nx.read_edgelist(file_path, comments='#', create_using=nx.DiGraph(), nodetype=int)

    print("Web-Google graph properties:")
    print(f"  Is directed: {nx_web.is_directed()}")
    print(f"  Nodes: {nx_web.number_of_nodes()}")
    print(f"  Edges: {nx_web.number_of_edges()}")

    # Only self-loops need checking: a DiGraph keeps at most one edge per
    # ordered node pair, so parallel edges cannot occur.
    self_loops = list(nx.selfloop_edges(nx_web))
    print(f"  Self-loops: {len(self_loops)}")

    # Check node ID ranges; gaps in the IDs mean a conversion has to renumber nodes.
    node_ids = list(nx_web.nodes())
    min_id, max_id = min(node_ids), max(node_ids)
    print(f"  Min node ID: {min_id}")
    print(f"  Max node ID: {max_id}")
    print(f"  Node IDs are consecutive: {max_id - min_id + 1 == len(node_ids)}")

    # Check out-degree distribution. Iterating the degree view yields
    # (node, degree) pairs in one pass instead of one lookup per node.
    out_degrees = [degree for _, degree in nx_web.out_degree()]
    zero_out_degree = out_degrees.count(0)
    print(f"  Nodes with out-degree 0 (dangling): {zero_out_degree}")
    print(f"  Max out-degree: {max(out_degrees)}")
    print(f"  Average out-degree: {np.mean(out_degrees):.2f}")

    # Convert to JAX and check the conversion
    jax_web = from_networkx(nx_web)
    print("\nJAX conversion:")
    print(f"  JAX n_nodes: {jax_web.n_nodes}")
    print(f"  JAX n_edges: {jax_web.n_edges}")
    print(f"  Senders range: {jnp.min(jax_web.senders)} to {jnp.max(jax_web.senders)}")
    print(f"  Receivers range: {jnp.min(jax_web.receivers)} to {jnp.max(jax_web.receivers)}")

    # Check out-degrees in JAX format: scatter-add one per edge onto its
    # sender. An integer accumulator keeps the counts integral, so they print
    # the same way as the NetworkX figures above.
    jax_out_degrees = jnp.zeros(jax_web.n_nodes, dtype=jnp.int32).at[jax_web.senders].add(1)
    jax_zero_out_degree = jnp.sum(jax_out_degrees == 0)
    print(f"  JAX nodes with out-degree 0: {jax_zero_out_degree}")
    print(f"  JAX max out-degree: {jnp.max(jax_out_degrees)}")

else:
    print("web-Google.txt not found")


=== Analyzing Web-Google Graph Properties ===
Web-Google graph properties:
  Is directed: True
  Nodes: 875713
  Edges: 5105039
  Self-loops: 0
Web-Google graph properties:
  Is directed: True
  Nodes: 875713
  Edges: 5105039
  Self-loops: 0
  Min node ID: 0
  Max node ID: 916427
  Node IDs are consecutive: False
  Nodes with out-degree 0 (dangling): 136259
  Max out-degree: 456
  Average out-degree: 5.83
  Min node ID: 0
  Max node ID: 916427
  Node IDs are consecutive: False
  Nodes with out-degree 0 (dangling): 136259
  Max out-degree: 456
  Average out-degree: 5.83

JAX conversion:
  JAX n_nodes: 875713
  JAX n_edges: 5105039
  Senders range: 0 to 875712
  Receivers range: 0 to 875712
  JAX nodes with out-degree 0: 136259
  JAX max out-degree: 456.0

JAX conversion:
  JAX n_nodes: 875713
  JAX n_edges: 5105039
  Senders range: 0 to 875712
  Receivers range: 0 to 875712
  JAX nodes with out-degree 0: 136259
  JAX max out-degree: 456.0


In [25]:
import os
import timeit
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx
from graphs import from_networkx
from algorithms.pagerank import pagerank

def benchmark_pagerank_on_real_data(file_path='web-Google.txt', n_runs=5, top_k=10, rtol=1e-5):
    """
    Load a real-world SNAP graph, then run, time and validate PageRank on it.

    The graph is read as a directed edge list, converted with
    ``from_networkx``, and PageRank is computed with both the JAX
    implementation and ``nx.pagerank``. Timings, an error analysis, the top-k
    nodes of each result and a final verdict are printed.

    Args:
        file_path: Path to the extracted SNAP edge list. Download
            'web-Google.txt' from https://snap.stanford.edu/data/web-Google.html.
        n_runs: Number of timed repetitions per implementation (after the
            warm-up call).
        top_k: How many of the highest-ranked nodes to compare.
        rtol: Relative tolerance for the final verdict. Individual scores are
            on the order of 1/n_nodes, so a fixed absolute tolerance such as
            1e-5 would exceed the typical score and accept almost anything;
            the verdict is therefore purely relative.

    Returns:
        None. Returns early (after printing a hint) when the file is missing.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    # --- 1. Load the real-world dataset ---
    if not os.path.exists(file_path):
        print(f"数据集文件未找到: {file_path}")
        print("请从 SNAP 下载并放置在正确的位置后再运行。")
        return

    print(f"--- 正在从 {file_path} 加载图... ---")
    # SNAP files usually carry comment lines starting with '#';
    # web-Google is a directed graph.
    nx_g_real = nx.read_edgelist(file_path, comments='#', create_using=nx.DiGraph(), nodetype=int)
    print(f"图已加载. 节点数: {nx_g_real.number_of_nodes()}, 边数: {nx_g_real.number_of_edges()}")

    # Convert to our JAX Graph object
    print("正在转换为 JAX Graph 格式...")
    jax_g_real = from_networkx(nx_g_real)
    print("JAX Graph 已创建:")
    print(f"  n_nodes: {jax_g_real.n_nodes}")
    print(f"  n_edges: {jax_g_real.n_edges}")
    print(f"  senders shape: {jax_g_real.senders.shape}")
    print(f"  receivers shape: {jax_g_real.receivers.shape}")
    print(f"  edge_weights 是否为 None: {jax_g_real.edge_weights is None}\n")

    # --- 2. Run PageRank and benchmark it ---
    print("--- 开始 PageRank 基准测试 (JAX vs NetworkX) ---")
    # JAX implementation
    print("正在预热/编译 JAX PageRank...")
    # Warm-up call so that one-off compilation cost (presumably pagerank is
    # jit-compiled — confirm) stays out of the timed runs below.
    jax_pr_real = pagerank(jax_g_real).block_until_ready()
    print("JAX PageRank 预热完成")
    print(f"  结果形状: {jax_pr_real.shape}")
    print(f"  结果类型: {jax_pr_real.dtype}")
    print(f"  结果总和: {jnp.sum(jax_pr_real)}")
    print(f"  最大值: {jnp.max(jax_pr_real)}")
    print(f"  最小值: {jnp.min(jax_pr_real)}")

    print("正在运行 JAX PageRank 性能测试...")
    # block_until_ready() makes the timing include the actual computation,
    # not just the asynchronous dispatch.
    jax_time = timeit.timeit(lambda: pagerank(jax_g_real).block_until_ready(), number=n_runs)
    print(f"JAX PageRank 实现速度: {jax_time / n_runs:.6f}s (平均每次)\n")

    # NetworkX implementation
    print("正在运行 NetworkX PageRank 性能测试...")
    # NetworkX needs no warm-up; this run supplies the reference scores.
    nx_pr_real_dict = nx.pagerank(nx_g_real)
    print("NetworkX PageRank 完成")
    print(f"  结果字典长度: {len(nx_pr_real_dict)}")
    print(f"  结果总和: {sum(nx_pr_real_dict.values())}")
    print(f"  最大值: {max(nx_pr_real_dict.values())}")
    print(f"  最小值: {min(nx_pr_real_dict.values())}")

    nx_time = timeit.timeit(lambda: nx.pagerank(nx_g_real), number=n_runs)
    print(f"NetworkX PageRank 实现速度: {nx_time / n_runs:.6f}s (平均每次)")
    # Both totals cover the same number of runs, so their ratio is the speed-up.
    print(f"JAX 相比 NetworkX 的加速比: {nx_time / jax_time:.2f}x")

    # --- 3. Validate the results ---
    print("\n--- 验证结果一致性 ---")
    # NetworkX might not return a rank for every node (e.g. isolated ones),
    # so build a dense array over all nodes, defaulting to 0.
    # NOTE(review): sorted-node order is assumed to match the node numbering
    # used by from_networkx — verify against that helper.
    all_nodes = sorted(nx_g_real.nodes())
    print(f"图中所有节点数量: {len(all_nodes)}")
    print(f"NetworkX 返回的节点数量: {len(nx_pr_real_dict)}")

    nx_pr_real_array = np.array([nx_pr_real_dict.get(n, 0) for n in all_nodes], dtype=np.float64)
    print(f"NetworkX 数组形状: {nx_pr_real_array.shape}")
    print(f"NetworkX 数组总和: {nx_pr_real_array.sum()}")

    # Re-normalise so the scores sum to 1, like our implementation's output.
    nx_pr_real_array /= nx_pr_real_array.sum()
    print(f"归一化后 NetworkX 数组总和: {nx_pr_real_array.sum()}")

    # Bring the JAX result to NumPy once; everything below is NumPy-vs-NumPy.
    jax_pr_host = np.asarray(jax_pr_real)

    # Difference statistics
    abs_diff = np.abs(jax_pr_host - nx_pr_real_array)
    max_diff = np.max(abs_diff)
    mean_diff = np.mean(abs_diff)

    print("\n--- 误差分析 ---")
    print(f"最大绝对误差: {max_diff}")
    print(f"平均绝对误差: {mean_diff}")
    print(f"最大误差位置: {np.argmax(abs_diff)}")

    # Absolute-tolerance ladder, tightest first; stops at the first match.
    # rtol=0 so each printed tolerance is exactly what was applied. A looser
    # rung would be needed if the scores were computed in float32.
    tolerances = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4]
    for tol in tolerances:
        is_close = np.allclose(jax_pr_host, nx_pr_real_array, rtol=0.0, atol=tol)
        print(f"容忍度 {tol}: {is_close}")
        if is_close:
            break

    # Compare the top-k highest-ranked nodes. The values printed are
    # positions in the sorted node list, not the original node IDs.
    k = top_k
    jax_top_k = np.argsort(jax_pr_host)[-k:][::-1]
    nx_top_k = np.argsort(nx_pr_real_array)[-k:][::-1]

    print(f"\n--- 前 {k} 个最重要节点对比 ---")
    print(f"JAX  前 {k}: {jax_top_k}")
    print(f"NX   前 {k}: {nx_top_k}")
    print(f"前 {k} 节点顺序是否一致: {np.array_equal(jax_top_k, nx_top_k)}")

    # Final verdict: purely relative, because an absolute tolerance larger
    # than the typical score (~1/n_nodes) cannot detect a wrong result.
    if not np.allclose(jax_pr_host, nx_pr_real_array, rtol=rtol, atol=0.0):
        print("\nJAX PageRank 在真实数据集上的结果与 NetworkX 不匹配。")
    else:
        print("\n在 web-Google 数据集上的测试成功")

# Run the web-Google benchmark defined above; it prints a hint and returns
# early if 'web-Google.txt' is not present in the working directory.
benchmark_pagerank_on_real_data()

--- 正在从 web-Google.txt 加载图... ---
图已加载. 节点数: 875713, 边数: 5105039
正在转换为 JAX Graph 格式...
JAX Graph 已创建:
  n_nodes: 875713
  n_edges: 5105039
  senders shape: (5105039,)
  receivers shape: (5105039,)
  edge_weights 是否为 None: True

--- 开始 PageRank 基准测试 (JAX vs NetworkX) ---
正在预热/编译 JAX PageRank...
JAX PageRank 预热完成
  结果形状: (875713,)
  结果类型: float64
  结果总和: 1.0
  最大值: 0.0009521123333770589
  最小值: 2.8281106711291503e-07
正在运行 JAX PageRank 性能测试...
JAX PageRank 实现速度: 0.105448s (平均每次)

正在运行 NetworkX PageRank 性能测试...
NetworkX PageRank 完成
  结果字典长度: 875713
  结果总和: 1.0000000000009635
  最大值: 0.0009521123333766876
  最小值: 2.8281106711282054e-07
NetworkX PageRank 实现速度: 1.970189s (平均每次)
JAX 相比 NetworkX 的加速比: 18.68x

--- 验证结果一致性 ---
图中所有节点数量: 875713
NetworkX 返回的节点数量: 875713
NetworkX 数组形状: (875713,)
NetworkX 数组总和: 0.9999999999997095
归一化后 NetworkX 数组总和: 0.9999999999999999

--- 误差分析 ---
最大绝对误差: 1.599198204416119e-16
平均绝对误差: 1.184823007737448e-19
最大误差位置: 874876
容忍度 1e-08: True

--- 前 10 个最重要节点对比 ---
JAX  前 10