# Project: Node Prediction for OGB-Arxiv using Curvature Graph Neural Networks

**CS224W: Machine Learning with Graphs**


_Stanford University. Winter, 2021._

---

**Team Members:** Gongqi Li, Khushal Sethi, Prathyusha Burugupalli

---
This Colab notebook generates Ollivier-Ricci curvature information for the OGB-Arxiv (`ogbn-arxiv`) dataset.

## Environment Setup

In [1]:
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-geometric
!pip install ogb

[K     |████████████████████████████████| 2.6MB 2.2MB/s 
[K     |████████████████████████████████| 1.5MB 2.3MB/s 
[K     |████████████████████████████████| 194kB 5.5MB/s 
[K     |████████████████████████████████| 235kB 7.9MB/s 
[K     |████████████████████████████████| 2.2MB 8.9MB/s 
[K     |████████████████████████████████| 51kB 5.2MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone
Collecting ogb
[?25l  Downloading https://files.pythonhosted.org/packages/34/47/16573587124ee85c8255cebd30c55981fa78c815eaff966ff111fb11c32c/ogb-1.3.0-py3-none-any.whl (67kB)
[K     |████████████████████████████████| 71kB 2.9MB/s 
Collecting outdated>=0.2.0
  Downloading https://files.pythonhosted.org/packages/86/70/2f166266438a30e94140f00c99c0eac1c45807981052a1d4c123660e1323/outdated-0.2.0.tar.gz
Collecting littleutils
  Downloading https://files.pythonhosted.org/packages/4e/b1/bb4e06f010947d67349f863b6a2ad71577f85590180a935f60543f622652/littleutils-0.2.2.tar.gz
Buildi

In [2]:
!pip install POT
!pip install path.py
!pip3 install cmake cython
!pip3 install networkit

Collecting POT
[?25l  Downloading https://files.pythonhosted.org/packages/1f/47/1ead874bd6ac538f246dcabfe04d3cc67ec94cf8fc8f952ff4e56c07cfe2/POT-0.7.0-cp37-cp37m-manylinux2010_x86_64.whl (430kB)
[K     |████████████████████████████████| 440kB 5.1MB/s 
Installing collected packages: POT
Successfully installed POT-0.7.0
Collecting path.py
  Downloading https://files.pythonhosted.org/packages/8f/04/130b7a538c25693c85c4dee7e25d126ebf5511b1eb7320e64906687b159e/path.py-12.5.0-py3-none-any.whl
Collecting path
  Downloading https://files.pythonhosted.org/packages/d3/2a/b0f97e1b736725f6ec48a8bd564ee1d1f3f945bb5d39cb44ef8bbe66bd14/path-15.1.2-py3-none-any.whl
Installing collected packages: path, path.py
Successfully installed path-15.1.2 path.py-12.5.0
Collecting networkit
[?25l  Downloading https://files.pythonhosted.org/packages/58/7a/4ef04f2b34fc81c5f2a5060b5cb989509f847d8aa6e359899922acffacf8/networkit-8.1.tar.gz (3.1MB)
[K     |████████████████████████████████| 3.1MB 5.8MB/s 
Building w

In [31]:
import torch
from torch.nn import Sequential as seq, Parameter, LeakyReLU, init, Linear
import torch.nn.functional as F
import numpy as np

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops, degree, softmax
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_undirected


from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

In [23]:
# Instantiate the OGB node-property-prediction dataset named 'ogbn-arxiv'
# through its PyTorch Geometric wrapper.
dataset = PygNodePropPredDataset(name='ogbn-arxiv')

In [32]:
def to_networkx(data, node_attrs=None, edge_attrs=None, to_undirected=False,
                remove_self_loops=False):
    r"""Build a NetworkX graph from a :class:`torch_geometric.data.Data` object.

    A :obj:`networkx.Graph` is returned when :attr:`to_undirected` is truthy,
    otherwise a :obj:`networkx.DiGraph`.

    Args:
        data (torch_geometric.data.Data): The data object. It must expose
            ``num_nodes`` and ``edge_index`` and be iterable as
            ``(key, value)`` pairs.
        node_attrs (iterable of str, optional): Names of the node attributes
            to copy onto the graph nodes. (default: :obj:`None`)
        edge_attrs (iterable of str, optional): Names of the edge attributes
            to copy onto the graph edges. (default: :obj:`None`)
        to_undirected (bool, optional): If set to :obj:`True`, an undirected
            graph is built and every edge ``(u, v)`` with ``v > u`` is
            skipped, so only entries with ``v <= u`` are added.
            (default: :obj:`False`)
        remove_self_loops (bool, optional): If set to :obj:`True`, edges with
            ``u == v`` are skipped. (default: :obj:`False`)

    Returns:
        The populated NetworkX graph with nodes ``0 .. data.num_nodes - 1``.
    """
    graph = nx.Graph() if to_undirected else nx.DiGraph()
    graph.add_nodes_from(range(data.num_nodes))

    # Convert every stored item to plain Python values so that attributes can
    # be looked up by edge / node position below.
    attr_values = {}
    for name, raw in data:
        converted = raw.squeeze().tolist() if torch.is_tensor(raw) else raw
        if isinstance(converted, (list, tuple)) and len(converted) == 1:
            converted = raw[0]
        attr_values[name] = converted

    edge_keys = [] if edge_attrs is None else edge_attrs
    node_keys = [] if node_attrs is None else node_attrs

    for idx, (src, dst) in enumerate(data.edge_index.t().tolist()):
        skip_mirrored = to_undirected and dst > src
        skip_loop = remove_self_loops and src == dst
        if skip_mirrored or skip_loop:
            continue

        graph.add_edge(src, dst)
        for name in edge_keys:
            graph[src][dst][name] = attr_values[name][idx]

    for name in node_keys:
        for idx, node_data in graph.nodes(data=True):
            node_data[name] = attr_values[name][idx]

    return graph


In [33]:
import random
import heapq
import importlib
import math
import multiprocessing as mp
import time
from functools import lru_cache

import cvxpy as cvx
import networkit as nk
import networkx as nx
import numpy as np
import ot

import logging
import community as community_louvain
import networkx as nx
import numpy as np
from functools import partial, partialmethod

# Register a custom TRACE log level whose numeric value is DEBUG + 5.
logging.TRACE = logging.DEBUG + 5
logging.addLevelName(logging.TRACE, 'TRACE')
# Expose the new level as Logger.trace(...) and as module-level logging.trace(...).
logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE)
logging.trace = partial(logging.log, logging.TRACE)

# Named logger used by the curvature helpers defined in the following cells.
logger = logging.getLogger("GraphRicciCurvature")

In [34]:
def set_verbose(verbose="ERROR"):
    """Set up the verbose level of the GraphRicciCurvature logger.

    Parameters
    ----------
    verbose : {"INFO", "TRACE", "DEBUG", "ERROR"}
        Verbose level. (Default value = "ERROR")
            - "INFO": show only iteration process log.
            - "TRACE": show detailed iteration process log.
            - "DEBUG": show all output logs.
            - "ERROR": only show log if error happened.
        Any other value prints a notice and falls back to "ERROR".
    """
    if verbose in ("INFO", "TRACE", "DEBUG", "ERROR"):
        # Each accepted name is also the name of the level constant on the
        # `logging` module; TRACE is the custom level registered at import
        # time. The lookup is lazy, so only the requested level is resolved.
        logger.setLevel(getattr(logging, verbose))
    else:
        # The message now lists every supported option, including "TRACE".
        print('Incorrect verbose level, option:["INFO","TRACE","DEBUG","ERROR"], use "ERROR" instead.')
        logger.setLevel(logging.ERROR)


def cut_graph_by_cutoff(G_origin, cutoff, weight="weight"):
    """Return a copy of the graph without the edges whose "weight" exceeds "cutoff".

    Parameters
    ----------
    G_origin : NetworkX graph
        A graph carrying ``weight`` as Ricci flow metric on its edges. It is
        not modified; the trimming happens on a copy.
    cutoff : float
        Threshold; every edge with ``weight`` strictly greater than it is removed.
    weight : str
        The edge attribute used as Ricci flow metric. (Default value = "weight")

    Returns
    -------
    G: NetworkX graph
        The trimmed copy of ``G_origin``.
    """
    assert nx.get_edge_attributes(G_origin, weight), "No edge weight detected, abort."

    trimmed = G_origin.copy()
    # Collect first, remove afterwards: the edge view must not change while
    # it is being iterated.
    over_cutoff = [(u, v) for u, v in trimmed.edges() if trimmed[u][v][weight] > cutoff]
    trimmed.remove_edges_from(over_cutoff)
    return trimmed


def get_rf_metric_cutoff(G_origin, weight="weight", cutoff_step=0.025, drop_threshold=0.01):
    """Get good clustering cutoff points for Ricci flow metric by detecting the change of
    modularity while removing edges.

    Parameters
    ----------
    G_origin : NetworkX graph
        A graph with "weight" as Ricci flow metric to cut. It is not modified.
    weight : str
        The edge weight used as Ricci flow metric. (Default value = "weight")
    cutoff_step : float
        The step size to find the good cutoff points.
    drop_threshold : float
        At least drop this much to considered as a drop for good_cut.

    Returns
    -------
    good_cuts : list of float
        A list of possible cutoff point, usually we use the first one as the best cut.
        Empty when the largest edge weight is not above 1, because then there is no
        cutoff to scan.

    Raises
    ------
    ValueError
        If no edge of the graph carries the ``weight`` attribute (``max`` of an empty
        sequence).
    """
    G = G_origin.copy()
    maxw = max(nx.get_edge_attributes(G, weight).values())
    # Scan cutoffs from the largest weight down towards 1 (1 itself excluded).
    cutoff_range = np.arange(maxw, 1, -cutoff_step)

    modularity = []
    for cutoff in cutoff_range:
        # The graph is trimmed cumulatively: each step cuts the previous result.
        G = cut_graph_by_cutoff(G, cutoff, weight=weight)
        # Get connected component after cut as clustering
        clustering = {c: idx for idx, comp in enumerate(nx.connected_components(G)) for c in comp}
        # Compute modularity
        modularity.append(community_louvain.modularity(clustering, G, weight))

    if not modularity:
        # maxw <= 1 yields an empty cutoff range; previously this crashed on
        # modularity[-1] with an IndexError.
        return []

    good_cuts = []
    mod_last = modularity[-1]

    # check drop from 1 -> maxw
    # On the first iteration mod_now equals mod_last, so the strict comparison is
    # False and cutoff_range[i + 1] is never evaluated out of range.
    for i in range(len(modularity) - 1, 0, -1):
        mod_now = modularity[i]
        if mod_last > mod_now > 1e-4 and abs(mod_last - mod_now) / mod_last > drop_threshold:
            logger.trace("Cut detected: cut:%f, diff:%f, mod_now:%f, mod_last:%f" % (
                cutoff_range[i+1], mod_last - mod_now, mod_now, mod_last))
            good_cuts.append(cutoff_range[i+1])
        mod_last = mod_now

    return good_cuts

In [35]:
EPSILON = 1e-7  # to prevent divided by zero

# ---Shared global variables for multiprocessing used.---
# These are module-level defaults; `_compute_ricci_curvature_edges` overwrites
# them before the worker pool is created, and the per-edge helpers read them.
_Gk = nk.graph.Graph()          # Networkit graph the helpers operate on
_alpha = 0.5                    # share of mass kept on the node itself
_weight = "weight"              # name of the edge weight attribute
_method = "Sinkhorn"            # transportation distance: "OTD", "ATD" or "Sinkhorn"
_base = math.e                  # base of the neighbor weight distribution
_exp_power = 2                  # exponent of the neighbor weight distribution
_proc = mp.cpu_count()          # number of worker processes
_cache_maxsize = 1000000        # LRU cache size for pairwise shortest paths
_shortest_path = "all_pairs"    # "all_pairs" or "pairwise"
_nbr_topk = 1000                # neighbors kept per node for the distribution
_apsp = {}                      # placeholder, replaced by the all-pairs distance matrix
# -------------------------------------------------------

#@lru_cache(_cache_maxsize)

In [36]:

def _get_single_node_neighbors_distributions(node, direction="successors"):
    """Get the neighbor density distribution of given node `node`.

    The node itself keeps a share `_alpha` of the mass; the remaining
    `1 - _alpha` is spread over (at most `_nbr_topk`) neighbors proportionally
    to ``_base ** (-(edge weight) ** _exp_power)``.

    Parameters
    ----------
    node : int
        Node index in Networkit graph `_Gk`.
    direction : {"predecessors", "successors"}
        Direction of neighbors in directed graph. (Default value: "successors")

    Returns
    -------
    distributions : lists of float
        Density distributions of neighbors up to top `_nbr_topk` nodes, followed
        by the share kept on `node` itself. ``[1]`` when `node` has no neighbor.
    nbrs : lists of int
        Neighbor index up to top `_nbr_topk` nodes, followed by `node` itself.
        ``[node]`` when `node` has no neighbor.
    """
    if _Gk.isDirected() and direction == "predecessors":
        neighbors = list(_Gk.iterInNeighbors(node))
    else:  # successors, or any direction in an undirected graph
        neighbors = list(_Gk.iterNeighbors(node))

    if not neighbors:
        # No neighbor, all mass stay at node
        return [1], [node]

    # Min-heap of (weight, neighbor): once it holds `_nbr_topk` entries,
    # heappushpop drops the smallest so only the top-k weights survive.
    heap_weight_node_pair = []
    for nbr in neighbors:
        if direction == "predecessors":
            w = _base ** (-_Gk.weight(nbr, node) ** _exp_power)
        else:  # successors
            w = _base ** (-_Gk.weight(node, nbr) ** _exp_power)

        if len(heap_weight_node_pair) < _nbr_topk:
            heapq.heappush(heap_weight_node_pair, (w, nbr))
        else:
            heapq.heappushpop(heap_weight_node_pair, (w, nbr))

    nbr_edge_weight_sum = sum(w for w, _ in heap_weight_node_pair)

    if nbr_edge_weight_sum > EPSILON:
        # Sum need to be not too small to prevent divided by zero
        distributions = [(1.0 - _alpha) * w / nbr_edge_weight_sum for w, _ in heap_weight_node_pair]
    else:
        # Sum too small, just evenly distribute to every neighbors
        # The "%s" placeholder is required: passing the list as a bare extra
        # argument makes logging fail to format the record.
        logger.warning("Neighbor weight sum too small, list: %s", heap_weight_node_pair)
        distributions = [(1.0 - _alpha) / len(heap_weight_node_pair)] * len(heap_weight_node_pair)

    nbr = [x[1] for x in heap_weight_node_pair]
    return distributions + [_alpha], nbr + [node]


def _distribute_densities(source, target):
    """Build the two mass distributions of an edge and the cost matrix between their supports.

    Only the neighbors with the top `_nbr_topk` edge weights take part, as selected by
    `_get_single_node_neighbors_distributions`.

    Parameters
    ----------
    source : int
        Source node index in Networkit graph `_Gk`.
    target : int
        Target node index in Networkit graph `_Gk`.

    Returns
    -------
    x : (m, 1) numpy.ndarray
        Source's density distribution over source's neighbors and source itself.
    y : (n, 1) numpy.ndarray
        Target's density distribution over target's neighbors and target itself.
    d : (m, n) numpy.ndarray
        Shortest path lengths between the two supports.
    """
    started = time.time()

    # In a directed graph the source mass sits on its predecessors, while the
    # target mass always sits on its successors.
    source_direction = "predecessors" if _Gk.isDirected() else "successors"
    x, source_topknbr = _get_single_node_neighbors_distributions(source, source_direction)
    y, target_topknbr = _get_single_node_neighbors_distributions(target, "successors")

    logger.debug("%8f secs density distribution for edge." % (time.time() - started))

    # Cost between every support node of x and every support node of y.
    started = time.time()

    if _shortest_path == "pairwise":
        d = np.array([
            [_source_target_shortest_path(src, tgt) for tgt in target_topknbr]
            for src in source_topknbr
        ])
    else:  # all_pairs: slice the pre-computed distance matrix
        d = _apsp[np.ix_(source_topknbr, target_topknbr)]

    # Column vectors: mass initially owned by the source side / required by the target side.
    x = np.array([x]).T
    y = np.array([y]).T

    logger.debug("%8f secs density matrix construction for edge." % (time.time() - started))

    return x, y, d


@lru_cache(_cache_maxsize)
def _source_target_shortest_path(source, target):
    """Shortest path length from `source` to `target` in `_Gk`, via Networkit's BidirectionalDijkstra.

    Results are memoised per (source, target) pair by the surrounding LRU cache.

    Parameters
    ----------
    source : int
        Source node index in Networkit graph `_Gk`.
    target : int
        Target node index in Networkit graph `_Gk`.

    Returns
    -------
    length : float
        Pairwise shortest path length.
    """
    search = nk.distance.BidirectionalDijkstra(_Gk, source, target).run()
    length = search.getDistance()
    # A huge distance is treated as "no path found".
    assert length < 1e300, "Shortest path between %d, %d is not found" % (source, target)
    return length


def _get_all_pairs_shortest_path():
    """Compute the all-pairs shortest path lengths of the shared graph `_Gk`.

    Returns
    -------
    numpy.ndarray
        The distances reported by Networkit's APSP, converted to an array.
    """
    logger.trace("Start to compute all pair shortest path.")

    started = time.time()
    solver = nk.distance.APSP(_Gk).run()
    distances = solver.getDistances()
    logger.trace("%8f secs for all pair by NetworKit." % (time.time() - started))

    return np.array(distances)


def _optimal_transportation_distance(x, y, d):
    """Solve the optimal transportation distance (OTD) between two densities with CVXPY.

    Parameters
    ----------
    x : numpy.ndarray
        Source's density distribution (m entries), over source's neighbors and source.
    y : numpy.ndarray
        Target's density distribution (n entries), over target's neighbors and target.
    d : (m, n) numpy.ndarray
        Shortest path matrix.

    Returns
    -------
    m : float
        Optimal transportation distance, i.e. the optimal objective value.
    """
    started = time.time()
    n_source, n_target = len(x), len(y)

    # rho[j, i] is the fraction of source mass x[i] that is moved to target j.
    rho = cvx.Variable((n_target, n_source))

    # Objective: sum over all (i, j) of d(i, j) * x[i] * rho[j, i] (element-wise product).
    weighted_cost = np.multiply(d.T, x.T)
    objective = cvx.Minimize(cvx.sum(cvx.multiply(weighted_cost, rho)))

    # Every source entry must be shipped completely: each column of rho sums to 1.
    column_totals = cvx.sum(rho, axis=0, keepdims=True)
    constraints = [
        rho @ x == y,
        column_totals == np.ones((1, n_source)),
        0 <= rho,
        rho <= 1,
    ]
    problem = cvx.Problem(objective, constraints)

    m = problem.solve()  # change solver here if you want

    logger.debug("%8f secs for cvxpy. \t#source_nbr: %d, #target_nbr: %d" % (time.time() - started, len(x), len(y)))

    return m


def _sinkhorn_distance(x, y, d):
    """Approximate the optimal transportation distance with POT's Sinkhorn solver.

    Parameters
    ----------
    x : numpy.ndarray
        Source's density distribution (m entries), over source's neighbors and source.
    y : numpy.ndarray
        Target's density distribution (n entries), over target's neighbors and target.
    d : (m, n) numpy.ndarray
        Shortest path matrix.

    Returns
    -------
    m : float
        Sinkhorn distance, an approximate optimal transportation distance.
    """
    started = time.time()
    # 1e-1 is passed as the fourth positional argument of `ot.sinkhorn2`;
    # the first element of its result is used as the distance.
    losses = ot.sinkhorn2(x, y, d, 1e-1, method='sinkhorn')
    m = losses[0]
    elapsed = time.time() - started
    logger.debug(
        "%8f secs for Sinkhorn. dist. \t#source_nbr: %d, #target_nbr: %d" % (elapsed, len(x), len(y)))

    return m


def _average_transportation_distance(source, target):
    """Compute the average transportation distance (ATD) of the given edge.

    A share `_alpha` of the mass travels directly from `source` to `target`; the
    remaining `1 - _alpha` is moved uniformly between every (source neighbor,
    target neighbor) pair. Distances are read from the pre-computed `_apsp`.

    Parameters
    ----------
    source : int
        Source node index in Networkit graph `_Gk`.
    target : int
        Target node index in Networkit graph `_Gk`.

    Returns
    -------
    m : float
        Average transportation distance.
    """

    t0 = time.time()
    if _Gk.isDirected():
        source_nbr = list(_Gk.iterInNeighbors(source))
    else:
        source_nbr = list(_Gk.iterNeighbors(source))
    target_nbr = list(_Gk.iterNeighbors(target))

    # A node without neighbors keeps its mass on itself, the same rule that
    # `_get_single_node_neighbors_distributions` applies. Without this
    # fallback an empty neighbor list (e.g. a source with no predecessor in a
    # directed graph) caused a ZeroDivisionError below.
    source_support = source_nbr or [source]
    target_support = target_nbr or [target]

    share = (1.0 - _alpha) / (len(source_support) * len(target_support))
    cost_nbr = 0
    cost_self = _alpha * _apsp[source][target]

    for src in source_support:
        for tgt in target_support:
            cost_nbr += _apsp[src][tgt] * share

    m = cost_nbr + cost_self  # Average transportation cost

    logger.debug("%8f secs for avg trans. dist. \t#source_nbr: %d, #target_nbr: %d" % (time.time() - t0,
                                                                                       len(source_nbr),
                                                                                       len(target_nbr)))
    return m


def _compute_ricci_curvature_single_edge(source, target):
    """Compute the Ricci curvature of one edge of `_Gk`.

    Parameters
    ----------
    source : int
        Source node index in Networkit graph `_Gk`.
    target : int
        Target node index in Networkit graph `_Gk`.

    Returns
    -------
    result : dict[(int,int), float]
        The Ricci curvature of given edge in dict format. E.g.: {(node1, node2): ricciCurvature}
    """
    assert source != target, "Self loop is not allowed."  # to prevent self loop

    edge = (source, target)
    edge_length = _Gk.weight(source, target)

    # An (almost) zero-length edge would blow up the division below; report 0.
    if edge_length < EPSILON:
        logger.warning("Zero weight edge detected for edge (%s,%s), return Ricci Curvature as 0 instead." %
                       (source, target))
        return {edge: 0}

    assert _method in ["OTD", "ATD", "Sinkhorn"], \
        'Method %s not found, support method:["OTD", "ATD", "Sinkhorn"]' % _method

    # Transportation cost between the two neighborhoods; 1 is the initial cost
    # that stays in place when no method branch applies.
    m = 1
    if _method == "ATD":
        m = _average_transportation_distance(source, target)
    elif _method in ("OTD", "Sinkhorn"):
        x, y, d = _distribute_densities(source, target)
        solver = _optimal_transportation_distance if _method == "OTD" else _sinkhorn_distance
        m = solver(x, y, d)

    # Ricci curvature: k = 1 - m_{x,y} / d(x,y)
    result = 1 - (m / edge_length)
    logger.debug("Ricci curvature (%s,%s) = %f" % (source, target, result))

    return {edge: result}


def _wrap_compute_single_edge(stuff):
    """Unpack one argument tuple for `_compute_ricci_curvature_single_edge` (pool workers get a single argument)."""
    edge_result = _compute_ricci_curvature_single_edge(*stuff)
    return edge_result


def _compute_ricci_curvature_edges(G: nx.Graph, weight="weight", edge_list=[],
                                   alpha=0.5, method="OTD",
                                   base=math.e, exp_power=2, proc=mp.cpu_count(), chunksize=None, cache_maxsize=1000000,
                                   shortest_path="all_pairs", nbr_topk=1000):
    """Compute Ricci curvature for edges in given edge lists.

    The settings are published through module-level globals so that the forked
    worker processes can read them. If no edge of `G` carries `weight`, the
    attribute is written onto every edge of `G` with value 1.0.

    Parameters
    ----------
    G : NetworkX graph
        A given directional or undirectional NetworkX graph.
    weight : str
        The edge weight used to compute Ricci curvature. (Default value = "weight")
    edge_list : list of edges
        The list of edges to compute Ricci curvature, set to [] to run for all edges in G. (Default value = [])
        The default list is only read, never modified.
    alpha : float
        The parameter for the discrete Ricci curvature, range from 0 ~ 1.
        It means the share of mass to leave on the original node.
        E.g. x -> y, alpha = 0.4 means 0.4 for x, 0.6 to evenly spread to x's nbr.
        (Default value = 0.5)
    method : {"OTD", "ATD", "Sinkhorn"}
        The optimal transportation distance computation method. (Default value = "OTD")
        Transportation method:
            - "OTD" for Optimal Transportation Distance,
            - "ATD" for Average Transportation Distance.
            - "Sinkhorn" for OTD approximated Sinkhorn distance.  (faster)
    base : float
        Base variable for weight distribution. (Default value = `math.e`)
    exp_power : float
        Exponential power for weight distribution. (Default value = 2)
    proc : int
        Number of processor used for multiprocessing. (Default value = `cpu_count()`)
    chunksize : int
        Chunk size for multiprocessing, set None for auto decide. (Default value = `None`)
    cache_maxsize : int
        Max size for LRU cache for pairwise shortest path computation.
        Set this to `None` for unlimited cache. (Default value = 1000000)
    shortest_path : {"all_pairs","pairwise"}
        Method to compute shortest path. (Default value = `all_pairs`)
    nbr_topk : int
        Only take the top k edge weight neighbors for density distribution.
        Smaller k run faster but the result is less accurate. (Default value = 1000)

    Returns
    -------
    output : dict[(int,int), float]
        A dictionary of edge Ricci curvature. E.g.: {(node1, node2): ricciCurvature}.
        Empty when there is no edge to process.
    """

    logger.trace("Number of nodes: %d" % G.number_of_nodes())
    logger.trace("Number of edges: %d" % G.number_of_edges())

    if not nx.get_edge_attributes(G, weight):
        logger.info('Edge weight not detected in graph, use "weight" as default edge weight.')
        for (v1, v2) in G.edges():
            G[v1][v2][weight] = 1.0

    # ---set to global variable for multiprocessing used.---
    global _Gk
    global _alpha
    global _weight
    global _method
    global _base
    global _exp_power
    global _proc
    global _cache_maxsize
    global _shortest_path
    global _nbr_topk
    global _apsp
    # -------------------------------------------------------

    _Gk = nk.nxadapter.nx2nk(G, weightAttr=weight)
    _alpha = alpha
    _weight = weight
    _method = method
    _base = base
    _exp_power = exp_power
    _proc = proc
    _cache_maxsize = cache_maxsize
    _shortest_path = shortest_path
    _nbr_topk = nbr_topk

    # Translation tables between NetworkX node labels and Networkit indices
    # (position of the node in G.nodes()).
    nx2nk_ndict = {n: idx for idx, n in enumerate(G.nodes())}
    nk2nx_ndict = {idx: n for n, idx in nx2nk_ndict.items()}

    # Edges to process, expressed in Networkit indices.
    edges = edge_list if edge_list else G.edges()
    args = [(nx2nk_ndict[source], nx2nk_ndict[target]) for source, target in edges]

    if not args:
        # Nothing to compute. Without this guard the automatic chunksize
        # below becomes 0, which the pool rejects.
        return {}

    if _shortest_path == "all_pairs":
        # Construct the all pair shortest path matrix
        _apsp = _get_all_pairs_shortest_path()

    # Start compute edge Ricci curvature
    t0 = time.time()

    with mp.get_context('fork').Pool(processes=_proc) as pool:
        # WARNING: Now only fork works, spawn will hang.

        # Decide chunksize following method in map_async
        if chunksize is None:
            chunksize, extra = divmod(len(args), _proc * 4)
            if extra:
                chunksize += 1

        # Compute Ricci curvature for edges
        result = pool.imap_unordered(_wrap_compute_single_edge, args, chunksize=chunksize)
        pool.close()
        pool.join()

    # Convert edge index from nk back to nx for final output
    output = {}
    for rc in result:
        for (nk_source, nk_target), curvature in rc.items():
            output[(nk2nx_ndict[nk_source], nk2nx_ndict[nk_target])] = curvature

    logger.info("%8f secs for Ricci curvature computation." % (time.time() - t0))

    return output


def _compute_ricci_curvature(G: nx.Graph, weight="weight", **kwargs):
    """Annotate a graph with edge and node Ricci curvature.

    Edge curvature comes from `_compute_ricci_curvature_edges`. A node's curvature is
    the sum of the curvature found on the edges towards its ``G.neighbors`` divided by
    ``G.degree`` of the node; nodes of degree 0 get no attribute.

    Parameters
    ----------
    G : NetworkX graph
        A given directional or undirectional NetworkX graph. It is modified in place.
    weight : str
        The edge weight used to compute Ricci curvature. (Default value = "weight")
    **kwargs
        Additional keyword arguments passed to `_compute_ricci_curvature_edges`.

    Returns
    -------
    G: NetworkX graph
        The same graph object, with "ricciCurvature" on nodes and edges.
    """
    edge_curvatures = _compute_ricci_curvature_edges(G, weight=weight, **kwargs)
    nx.set_edge_attributes(G, edge_curvatures, "ricciCurvature")

    for node in G.nodes():
        if G.degree(node) == 0:
            continue

        # Only edges that actually received a curvature contribute to the sum.
        curvature_total = sum(
            G[node][nbr]['ricciCurvature']
            for nbr in G.neighbors(node)
            if 'ricciCurvature' in G[node][nbr]
        )

        G.nodes[node]['ricciCurvature'] = curvature_total / G.degree(node)
        logger.debug("node %s, Ricci Curvature = %f" % (node, G.nodes[node]['ricciCurvature']))

    return G




In [37]:
# Settings for the curvature computation; most are forwarded as keyword
# arguments to `_compute_ricci_curvature_edges` by the batch loop below.
weight="weight"  # edge attribute holding the edge weight
alpha=0.5  # share of mass left on the node itself
method="Sinkhorn"  # transportation distance: "OTD", "ATD" or "Sinkhorn"
base=math.e  # base of the neighbor weight distribution
exp_power=2  # exponent of the neighbor weight distribution
proc=mp.cpu_count()  # NOTE(review): the batch loop below uses 10 workers instead of this value
chunksize=None  # None lets the function derive the multiprocessing chunk size
shortest_path="all_pairs"  # "all_pairs" or "pairwise"
cache_maxsize=1000000  # LRU cache size for pairwise shortest paths
nbr_topk=1000  # only the top-k weighted neighbors enter a node's distribution
verbose="ERROR"  # NOTE(review): not passed to `set_verbose` in the visible cells

In [43]:
data = dataset[0]
# Rewrite edge_index with torch_geometric's `to_undirected`
# (presumably adds the reverse of every edge — confirm against the PyG docs).
data.edge_index = to_undirected(data.edge_index, data.num_nodes)
# `to_networkx` is called with its default to_undirected=False, so `gdir` is a nx.DiGraph.
gdir = to_networkx(data)
len(gdir.edges)

2315598

In [None]:
#gdir = to_networkx(dataset[0])
# Number of edges handled per call and number of worker processes per call.
BATCH_SIZE = 12000
NUM_WORKERS = 10

remaining_edges = list(gdir.edges)
print(len(remaining_edges))
count = 0
curv_dict = {}  # (source, target) -> Ricci curvature, accumulated over all batches

# NOTE(review): every batch is evaluated on the subgraph induced by its own
# edges only, so neighborhoods and shortest paths ignore all edges outside the
# batch. Confirm this approximation is intended.
while remaining_edges:
    count += 1
    print(count)

    # Take the next batch off the front; the last batch may be shorter.
    sampled_edges = remaining_edges[:BATCH_SIZE]
    remaining_edges = remaining_edges[BATCH_SIZE:]

    sampled_graph = gdir.edge_subgraph(sampled_edges)
    ricci_c = _compute_ricci_curvature_edges(G=sampled_graph, weight=weight,
                                             alpha=alpha, method=method,
                                             base=base, exp_power=exp_power,
                                             proc=NUM_WORKERS, chunksize=chunksize,
                                             cache_maxsize=cache_maxsize,
                                             shortest_path=shortest_path, nbr_topk=nbr_topk)
    curv_dict.update(ricci_c)
    print(len(sampled_edges))

2315598
1
12000
2
12000
3
12000
4
12000
5
12000
6
12000
7
12000
8
12000
9
12000
10
12000
11
12000
12
12000
13
12000
14
12000
15
12000
16


In [40]:
!pip install ujson
import ujson

Collecting ujson
[?25l  Downloading https://files.pythonhosted.org/packages/17/4e/50e8e4cf5f00b537095711c2c86ac4d7191aed2b4fffd5a19f06898f6929/ujson-4.0.2-cp37-cp37m-manylinux1_x86_64.whl (179kB)
[K     |█▉                              | 10kB 14.9MB/s eta 0:00:01[K     |███▋                            | 20kB 16.2MB/s eta 0:00:01[K     |█████▌                          | 30kB 9.1MB/s eta 0:00:01[K     |███████▎                        | 40kB 9.0MB/s eta 0:00:01[K     |█████████▏                      | 51kB 4.6MB/s eta 0:00:01[K     |███████████                     | 61kB 5.3MB/s eta 0:00:01[K     |████████████▉                   | 71kB 5.6MB/s eta 0:00:01[K     |██████████████▋                 | 81kB 5.7MB/s eta 0:00:01[K     |████████████████▌               | 92kB 6.0MB/s eta 0:00:01[K     |██████████████████▎             | 102kB 4.4MB/s eta 0:00:01[K     |████████████████████▏           | 112kB 4.4MB/s eta 0:00:01[K     |██████████████████████          | 122kB 4

In [44]:
# Write the accumulated curvature dictionary to disk as JSON text.
# NOTE(review): `curv_dict` is keyed by (source, target) tuples; confirm how
# ujson serializes non-string keys before relying on this file downstream.
with open('curvature_full.txt', 'w') as file:
    file.write(ujson.dumps(curv_dict))