In [1]:
import numpy as np 
import ot 
import sys 
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from constants import ROOT_DIR

sys.path.append(ROOT_DIR)

from methods import DataIO
from methods import GromovWassersteinFramework
import dev.generate_util as Gen

sns.set()
%matplotlib inline

2023-01-25 15:36:43.073165: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def process_graph(graph, weights=None):
    """
    Compute the node probability distribution and shortest path matrix given a graph.

    Every edge adds mass to both of its endpoints: 1 when `weights` is None,
    otherwise `weights[src, dst]`. The accumulated masses are then normalized
    so that they sum to 1.

    Inputs:
        - graph: a networkx graph
        - weights: optional object indexable as `weights[src, dst]`, where
          `src` and `dst` are the node labels of an edge as yielded by
          `graph.edges` (e.g. a 2-D numpy array when nodes are labelled with
          integers). When None, every edge counts as 1.

    Outputs:
        - p: a 1-D numpy vector with one probability per node, ordered like
          `graph.nodes`
        - cost: the matrix returned by `nx.floyd_warshall_numpy(graph)`,
          representing the shortest distance between nodes

    Raises:
        - ValueError: if the graph has nodes but the total mass is zero
          (no edges, or weights summing to zero), since the distribution
          would otherwise silently become NaN.
    """
    nodes = list(graph.nodes)
    # Address rows of `probs` by position in `graph.nodes` rather than by the
    # node label itself: labels may be strings, or integers that are not
    # 0..n-1 in insertion order. With its default nodelist,
    # floyd_warshall_numpy orders rows/columns like graph.nodes, so this keeps
    # `probs` aligned with `cost`.
    index_of = {node: position for position, node in enumerate(nodes)}
    cost = nx.floyd_warshall_numpy(graph)

    probs = np.zeros(len(nodes))
    for edge in graph.edges:
        # Only the first two entries are endpoints; multigraph edges carry an
        # extra key that must be ignored.
        src, dst = edge[0], edge[1]
        mass = 1 if weights is None else weights[src, dst]
        probs[index_of[src]] += mass
        probs[index_of[dst]] += mass

    total = np.sum(probs)
    if nodes and total == 0:
        raise ValueError(
            "Cannot build a node distribution: the total edge mass of the graph is zero."
        )
    # normalize to form the probability
    probs /= total
    return probs, cost

In [3]:
def compute_gwd(p_s, p_t, cost_s, cost_t, k=0.01, get_trans=False):
    """
    Run POT's entropic Gromov-Wasserstein solver with the square loss.

    Inputs:
        - p_s, p_t: distributions of the source and target
        - cost_s, cost_t: cost matrices of the source and target
        - k: value forwarded to the solver as `epsilon`
        - get_trans: when True, also draw the transport plan with
          matplotlib and return it

    Outputs:
        - the `"gw_dist"` entry of the solver log when get_trans is False,
          otherwise the pair (gw_dist, transport plan)
    """
    plan, solver_log = ot.gromov.entropic_gromov_wasserstein(
        cost_s,
        cost_t,
        p_s,
        p_t,
        loss_fun="square_loss",
        epsilon=k,
        log=True,
    )
    distance = solver_log["gw_dist"]
    if not get_trans:
        return distance

    # Show the coupling as an image on a fresh, large figure.
    plt.figure(figsize=(20, 20))
    plt.imshow(plan)
    plt.colorbar(fraction=0.046, pad=0.04)
    return distance, plan

def compute_gwd2(p_s, p_t, cost_s, cost_t, k=0.01, get_trans=False):
    """
    Run POT's `entropic_gromov_wasserstein2` solver with the square loss.

    Inputs:
        - p_s, p_t: distributions of the source and target
        - cost_s, cost_t: cost matrices of the source and target
        - k: value forwarded to the solver as `epsilon`
        - get_trans: when True, also draw the transport plan with
          matplotlib and return it

    Outputs:
        - the distance returned by the solver when get_trans is False,
          otherwise the pair (distance, `"T"` entry of the solver log)
    """
    distance, solver_log = ot.gromov.entropic_gromov_wasserstein2(
        cost_s,
        cost_t,
        p_s,
        p_t,
        loss_fun="square_loss",
        epsilon=k,
        log=True,
    )
    # The plan is read from the log unconditionally, as the solver stores it under "T".
    plan = solver_log["T"]
    if not get_trans:
        return distance

    # Show the coupling as an image on a fresh, large figure.
    plt.figure(figsize=(20, 20))
    plt.imshow(plan)
    plt.colorbar(fraction=0.046, pad=0.04)
    return distance, plan

In [4]:
def compute_gwd_graph(source, target, k=0.01, get_trans=False):
    """
    Compute the entropic GW distance between two graphs.

    Each graph is converted with `process_graph` and the result is handed to
    `compute_gwd`, so the solver call and the optional plot live in one place
    instead of being duplicated here.

    Inputs:
        - source: a networkx graph, the source graph
        - target: a networkx graph, the target graph
        - k: value forwarded to the solver as `epsilon`
        - get_trans: when True, the transport plan is also plotted and returned

    Outputs:
        - gw_dist when get_trans is False, otherwise (gw_dist, trans)
    """
    p_s, cost_s = process_graph(source)
    p_t, cost_t = process_graph(target)
    return compute_gwd(p_s, p_t, cost_s, cost_t, k=k, get_trans=get_trans)

def compute_gwd2_graph(source, target, k=0.01, get_trans=False):
    """
    Compute the entropic GW distance between two graphs via `compute_gwd2`.

    Each graph is converted with `process_graph` and the result is handed to
    `compute_gwd2`, so the solver call and the optional plot live in one place
    instead of being duplicated here.

    Inputs:
        - source: a networkx graph, the source graph
        - target: a networkx graph, the target graph
        - k: value forwarded to the solver as `epsilon`
        - get_trans: when True, the transport plan is also plotted and returned

    Outputs:
        - gw_dist when get_trans is False, otherwise (gw_dist, trans)
    """
    p_s, cost_s = process_graph(source)
    p_t, cost_t = process_graph(target)
    return compute_gwd2(p_s, p_t, cost_s, cost_t, k=k, get_trans=get_trans)

In [5]:
def gwd_growth_epsilon(graph_s, graph_t, params, title=None, verbose=False):
    """
    Visualization of GW distance growth over epsilon values to test numerical instability.
    Displays a plot, no return value.

    Input:
        - graph_s: a networkx graph, represents the source graph
        - graph_t: a networkx graph, represents the target graph
        - params: a sequence of epsilon values to try; it is not modified
        - title: optional plot title; a default title is used when None
        - verbose: when True, print each epsilon before it is processed

    Epsilon values for which `compute_gwd` raises RuntimeError are reported
    on stdout and left as NaN, so they appear as gaps in the curve.
    """
    # Work on a sorted copy so the caller's sequence is left untouched and
    # tuples or other non-list sequences are accepted too.
    epsilons = sorted(params)
    # Float array is required: an integer array would truncate every distance
    # on assignment. NaN marks values that could not be computed.
    gw_dists = np.full(len(epsilons), np.nan)
    p_s, cost_s = process_graph(graph_s)
    p_t, cost_t = process_graph(graph_t)
    for i, k in enumerate(epsilons):
        if verbose:
            print(f"Processing epsilon = {k}.....")
        try:
            gw_dists[i] = compute_gwd(p_s, p_t, cost_s, cost_t, k)
        except RuntimeError:
            print(f"Encounter divide-by-zero error at epsilon = {k}")
    # visualize
    plt.plot(epsilons, gw_dists)
    plt.ylabel("Gromov-Wasserstein Distance")
    plt.xlabel("Epsilon")
    if title is not None:
        plt.title(title)
    else:
        plt.title("Growth of GW Dist vs. Epsilon")