This document contains 3 implementations of computing the 0-dimensional PH of the product of 2 edge-based filtrations:

1. Computing PH of the product directly using an union-find method.
2. Theorem 5 of our paper.
3. Computing PH of the product directly using the gudhi library.

We then assess their runtime performance on pairs of graphs loaded with (all of) the BREC dataset.

NOTE: For convenience, we modeled time t = +infinity as -1 in the code file here, and the filtration values are assumed to be positive in the algorithm. We also precomputed the values at time t = -infinity appropriately.

In [1]:
# Uncomment to install the gudhi package
!pip install gudhi

import torch
import itertools
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# import random
import gudhi
import math
import time



In [2]:
# The code implementation of the Union Find PH Based Method
# The code is adapted from codes produced in the paper:
# "Going beyond persistent homology using persistent homology"
# by Johanna Immonen, Amauri H. Souza, and Vikas Garg.

class UnionFind():
    def __init__(self,N):
        self._parents = list(range(0, N))

    def find(self, p):
        return(self._parents[p])

    def merge(self, p, q):
        root_p, root_q = self._parents[p], self._parents[q]
        for i in range(0, len(self._parents)):
            if(self._parents[i] == root_p):
                self._parents[i] = root_q

    def connected(self,p,q):
        return self._parents[p] == self._parents[q]

    def roots(self):
        roots = []
        for i in range(0, len(self._parents)):
            if self._parents[i] == i:
                roots.append(i)
        return roots


def persistence_routine(filtered_v, filtered_e, edge_indices):

    n, m = filtered_v.shape[0], edge_indices.shape[1]

    filtered_e, sorted_indices = torch.sort(filtered_e)
    uf = UnionFind(n)
    persistence = torch.zeros((n, 2), device=filtered_v.device)
    persistence1 = torch.zeros((m, 2), device=filtered_v.device)

    unpaired_value = filtered_e[-1]  # used as infinity

    for edge_index, edge_weight in zip(sorted_indices, filtered_e):

        nodes = edge_indices[:, edge_index]
        younger = uf.find(nodes[0])
        older = uf.find(nodes[1])
        if younger == older:
            persistence1[edge_index, 0] = edge_weight # filtered_e[edge_index]
            persistence1[edge_index, 1] = -1
            continue
        else:
            if filtered_v[younger] < filtered_v[older]:
                younger, older = older, younger
                nodes = torch.flip(nodes, [0])

        persistence[younger, 0] = filtered_v[younger]
        persistence[younger, 1] = edge_weight

        uf.merge(nodes[0], nodes[1])

    for root in uf.roots():
        persistence[root, 0] = filtered_v[root]
        persistence[root, 1] = -1

    return persistence, persistence1

In [3]:
# Theorem 5 Implementation

# Implementation of Theorem 5 of our paper.
# Input: Takes in a graph data item of the form of a tuple
# (graph_data_item 0, graph_data_item 1)

# For i = 0 or 1, graph_data_item i is composed of 4 items for a graph G of n + 1 vertices with a vertex filtration function fv
# The vertex set of the graph is given by {0, 1, ..., n}.
# The 5 items in graph_data_item are:

# input_v - a list of numbers that specifies the value of fv on the index i (this is not used in Theorem 5, they are always zero for edge filtrations)
# input_e - a list of numbers that specifies the values of fe on the edge indices, ordered according to edge_index
# edge_index - list of edges of the graph, being tuples of the form (i,j) for vertices i and j (made in pytorch)
# filtered_v - torch.Tensor(input_v)
# filtered_e = torch.Tensor(input_e)

def thm_5_pd(graph_data_item):
  input_v_G, input_e_G, edge_index_G, filtered_v_G, filtered_e_G = graph_data_item[0]
  input_v_H, input_e_H, edge_index_H, filtered_v_H, filtered_e_H = graph_data_item[1]

  # Precompute the Persistence Pairs using the Union Find Code Above
  implG = persistence_routine(filtered_v_G, filtered_e_G, edge_index_G)[0]
  implH = persistence_routine(filtered_v_H, filtered_e_H, edge_index_H)[0]

  #  Both should have the same filtration steps, this uniformizes their time steps
  filtration_steps = torch.unique(torch.cat((filtered_e_G, filtered_e_H))).tolist()
  filtration_steps.sort()
  filtration_steps.append(-1.0)
  # print(filtration_steps)

  # Initialize lists for number of vertices that dies non-trivially at each time
  G_nd = [0]
  H_nd = [0]
  nG = filtered_v_G.shape[0]
  nH = filtered_v_H.shape[0]

  # Initialize lists for number of vertices still alive at each time
  G_sa = [nG]
  H_sa = [nH]

  # Precompute the two respective lists
  for i in range(0, len(filtration_steps)):
    a_i = filtration_steps[i]

    # Note for edge filtrations there are no trivial deaths of vertices
    G_deaths = len([int(x[0].item()) for x in implG if x[1] == a_i])
    H_deaths = len([int(x[0].item()) for x in implH if x[1] == a_i])

    G_nd.append(G_deaths)
    H_nd.append(H_deaths)

    G_sa.append(G_sa[i] - G_deaths)
    H_sa.append(H_sa[i] - H_deaths)

  # Now we start computing 0-dim PH of the Product
  persistence = []
  for i in range(0, len(filtration_steps)):
    a_i = filtration_steps[i]
    hb_prev = H_sa[i]
    gb_current = G_sa[i+1]

    hd_current = H_nd[i+1]
    gd_current = G_nd[i+1]

    num = hb_prev*gd_current + gb_current*hd_current
    persistence += [[0.0, a_i] for i in range(0, num)]

  return persistence

In [4]:
# Naive PD List

# Given a pair (i, j), thought of as a node in the product G \Box H, and nH, the number of vertices in H
# Returns the index (i, j) has if the matrix of vertices is flattened into a single list.
def prod(pair, nH):
  x = pair[0]
  y = pair[1]
  return nH*x + y

# Naive implementation by computing PH of the whole product using the union-find structure
# The inputs are - graph_data_item (explained in the implementation for Theorem 5)
# P - the networkx graph of the product of the graph G and the graph H
def naive_pd_edge(graph_data_item, P):
  input_v_G, input_e_G, edge_index_G, filtered_v_G, filtered_e_G = graph_data_item[0]
  input_v_H, input_e_H, edge_index_H, filtered_v_H, filtered_e_H = graph_data_item[1]

  prod_vert_index = []
  prod_vert_value = []
  prod_edge_index = [[], []]
  prod_edge_value = []

  base_edge_index_G = np.transpose(edge_index_G.numpy()).tolist()
  base_edge_index_H = np.transpose(edge_index_H.numpy()).tolist()
  for item in base_edge_index_G:
    item.sort()
  for item in base_edge_index_H:
    item.sort()

  # print(base_edge_index_G)
  # print(base_edge_index_H)
  # print("-----------------")

  nG = filtered_v_G.shape[0]
  nH = filtered_v_H.shape[0]

  for i in range(0, nG):
    for j in range(0, nH):
      prod_vert_index.append((i,j))
      prod_vert_value.append(0)
  # print(len(prod_vert_index))
  for v,w in P.edges:
    prod_edge_index[0].append(prod(v, nH))
    prod_edge_index[1].append(prod(w, nH))

    if v[0] == w[0]:
      input_index = base_edge_index_H.index([v[1],w[1]])
      edge_value = input_e_H[input_index]
    else:
      input_index = base_edge_index_G.index([v[0],w[0]])
      edge_value = input_e_G[input_index]
    prod_edge_value.append(edge_value)

  prod_edge_index_torch = torch.Tensor(prod_edge_index).long()
  prod_filtered_v = torch.Tensor(prod_vert_value)
  prod_filtered_e = torch.Tensor(prod_edge_value)

  # Computes the PH on the entire product using the union-find code
  prod_ph = persistence_routine(prod_filtered_v, prod_filtered_e, prod_edge_index_torch)[0].tolist()

  return prod_ph

In [5]:
# Naive implementation by computing PH of the whole product using the gudhi library
# The inputs are - graph_data_item (explained in the implementation for Theorem 5)
# P - the networkx graph of the product of the graph G and the graph H
def gudhi_pd_edge(graph_data_item, P):
  input_v_G, input_e_G, edge_index_G, filtered_v_G, filtered_e_G = graph_data_item[0]
  input_v_H, input_e_H, edge_index_H, filtered_v_H, filtered_e_H = graph_data_item[1]

  st = gudhi.SimplexTree()

  prod_vert_index = []

  base_edge_index_G = np.transpose(edge_index_G.numpy()).tolist()
  base_edge_index_H = np.transpose(edge_index_H.numpy()).tolist()
  for item in base_edge_index_G:
    item.sort()
  for item in base_edge_index_H:
    item.sort()

  # print(base_edge_index_G)
  # print(base_edge_index_H)
  # print("-----------------")

  nG = filtered_v_G.shape[0]
  nH = filtered_v_H.shape[0]

  for i in range(0, nG):
    for j in range(0, nH):
      prod_vert_index.append((i,j))
      st.insert([prod((i,j), nH)], filtration=0)

  for v,w in P.edges:
    if v[0] == w[0]:
      input_index = base_edge_index_H.index([v[1],w[1]])
      edge_value = input_e_H[input_index]
    else:
      input_index = base_edge_index_G.index([v[0],w[0]])
      edge_value = input_e_G[input_index]
    st.insert([prod(v, nH), prod(w, nH)], filtration=edge_value)

  st.make_filtration_non_decreasing()
  dgms = st.persistence(min_persistence=-1)

  dgms_no_dim = [list(x[1]) for x in dgms]
  for item in dgms_no_dim:
    _, d = item
    if d == np.inf:
      item[1] = -1.0
  return dgms_no_dim



The rest of the Code computes the runtime of all 3 implementations on each dataset in the BREC dataset.

We take the BREC dataset (see the paper for specifications) and compute the 0-th dim PH of the product of pair of graphs in the BREC dataset, using an augmented Forman-Ricci-curvature filtration on the two components.

In [6]:
# Loading the dataset
from google.colab import drive
drive.mount('/content/drive')

def get_dataset(name):
    dataset = []
    if name in ['basic', 'str', 'dr', '4vtx']:
        # When you are running this, replace the following line with a link to the relevant file in the BREC dataset.
        data = np.load(f'/content/drive/My Drive/Colab Notebooks/datasets/BREC/{name}.npy')
        for i in range(0, data.size, 2):
            if name == 'dr':
                pyg_graph_1 =  nx.from_graph6_bytes(data[i])
                pyg_graph_2 =  nx.from_graph6_bytes(data[i+1])
            else:
                pyg_graph_1 =  nx.from_graph6_bytes(data[i].encode())
                pyg_graph_2 =  nx.from_graph6_bytes(data[i+1].encode())
            dataset.append((pyg_graph_1, pyg_graph_2))
        return dataset
    elif name in ['regular', 'extension', 'cfi']:
        # When you are running this, replace the following line with a link to the relevant file in the BREC dataset.
        data = np.load(f'/content/drive/My Drive/Colab Notebooks/datasets/BREC/{name}.npy')
        for i in range(0, data.size // 2):
            g6_tuple = data[i]
            if name == 'regular' or name == 'cfi':
                pyg_graph_1 = nx.from_graph6_bytes(g6_tuple[0])
                pyg_graph_2 = nx.from_graph6_bytes(g6_tuple[1])
            else:
                pyg_graph_1 = nx.from_graph6_bytes(g6_tuple[0].encode())
                pyg_graph_2 = nx.from_graph6_bytes(g6_tuple[1].encode())
            dataset.append((pyg_graph_1, pyg_graph_2))
        return dataset

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
# Converting the dataset inputs to entries for the 3 algorithms
# Here we filtrate the two graphs using an augmented Forman-Ricci curvature (as edge filtrations)

def dataset_entry_to_input(item):
  G, H = item

  G_edge_list = [[], []]
  H_edge_list = [[], []]
  for v0, v1 in G.edges:
    G_edge_list[0].append(v0)
    G_edge_list[1].append(v1)

  for v0, v1 in H.edges:
    H_edge_list[0].append(v0)
    H_edge_list[1].append(v1)

  G_nodes = list(G.nodes)
  H_nodes = list(H.nodes)

  G_edges = list(G.edges)
  H_edges = list(H.edges)

  input_v_G = [0 for _ in range(0, len(G_nodes))]
  input_v_H = [0 for _ in range(0, len(H_nodes))]

  input_e_G = []
  input_e_H = []

  # augmented Formanâ€“Ricci curvatures
  for u, w in G_edges:
    u_neighbors = list(G.neighbors(u))
    w_neighbors = list(G.neighbors(w))
    intersection = list(set(u_neighbors).intersection(set(w_neighbors)))
    input_e_G.append(4 - len(u_neighbors) - len(w_neighbors) + 3*len(intersection) + 2*len(G_nodes))
  for u, w in H_edges:
    u_neighbors = list(H.neighbors(u))
    w_neighbors = list(H.neighbors(w))
    intersection = list(set(u_neighbors).intersection(set(w_neighbors)))
    input_e_H.append(4 - len(u_neighbors) - len(w_neighbors) + 3*len(intersection) + 2*len(H_nodes))

  edge_index_G = torch.Tensor(G_edge_list).long()
  filtered_v_G = torch.Tensor(input_v_G)
  filtered_e_G = torch.Tensor(input_e_G)

  edge_index_H = torch.Tensor(H_edge_list).long()
  filtered_v_H = torch.Tensor(input_v_H)
  filtered_e_H = torch.Tensor(input_e_H)

  return ((input_v_G, input_e_G, edge_index_G, filtered_v_G, filtered_e_G), (input_v_H, input_e_H, edge_index_H, filtered_v_H, filtered_e_H))

In [8]:
# Preparing Dataset Entries for All of Them
# dataset_names = ['basic', 'str']
dataset_names = ['basic', 'str', 'dr', '4vtx', 'regular', 'extension', 'cfi']

raw_data_list = []
graph_data_list = []
for name in dataset_names:
  dataset = get_dataset(name)
  print(name, ": ", len(dataset))

  data_list = []
  for item in dataset:
    entry_gd = dataset_entry_to_input(item)
    data_list.append(entry_gd)

  raw_data_list.append(dataset)
  graph_data_list.append(data_list)

basic :  60
str :  50
dr :  20
4vtx :  20
regular :  50
extension :  100
cfi :  100


In [9]:
# Computing the results for the 3 implementations
thm5_time_list = []
naive_time_list = []
gudhi_time_list = []

thm_5_pd_list = []
naive_pd_e_list = []
gudhi_pd_e_list = []

for i in range(0, len(dataset_names)):
  data_list = graph_data_list[i]
  dataset = raw_data_list[i]

  # Theorem 5
  thm5_t = time.time()
  for gditem in data_list:
    thm_5_pd_list.append(thm_5_pd(gditem))
  thm5_time = time.time() - thm5_t
  thm5_time_list.append(thm5_time)

  # Compute the actual NetworkX graph of the product of the two graphs
  P_list = []
  for item in dataset:
    G, H = item
    P = nx.cartesian_product(G, H)
    P_list.append(P)

  # Naive Union-Find
  naive_t = time.time()
  for ind in range(0, len(data_list)):
    gditem = data_list[ind]
    naive_pd_e_list.append(naive_pd_edge(gditem, P_list[ind]))
  naive_time = time.time() - naive_t
  naive_time_list.append(naive_time)

  # gudhi
  gd_t = time.time()
  for ind in range(0, len(data_list)):
    gditem = data_list[ind]
    gudhi_pd_e_list.append(gudhi_pd_edge(gditem, P_list[ind]))
  gudhi_time = time.time() - gd_t
  gudhi_time_list.append(gudhi_time)


  print(dataset_names[i])
  print("Theorem 5: ", thm5_time)
  print("Union Find: ", naive_time)
  print("gudhi: ", gudhi_time)
  print("-------------------------")

basic
Theorem 5:  0.8054904937744141
Union Find:  2.7030720710754395
gudhi:  0.11327505111694336
-------------------------
str
Theorem 5:  2.2883822917938232
Union Find:  32.799060583114624
gudhi:  3.503389596939087
-------------------------
dr
Theorem 5:  0.3792917728424072
Union Find:  22.13980793952942
gudhi:  1.7820289134979248
-------------------------
4vtx
Theorem 5:  1.3597157001495361
Union Find:  149.3400056362152
gudhi:  42.74673390388489
-------------------------
regular
Theorem 5:  0.19175100326538086
Union Find:  0.9566614627838135
gudhi:  0.05453777313232422
-------------------------
extension
Theorem 5:  0.4298710823059082
Union Find:  2.8201935291290283
gudhi:  0.1920773983001709
-------------------------
cfi
Theorem 5:  5.464076519012451
Union Find:  780.0260004997253
gudhi:  33.380260705947876
-------------------------


In [10]:
# Helper functions to check if two persistence diagrams are equal
def pd_to_multiset(pd_list):
  output = {}
  for item in pd_list:
    output[str(item)] = 0
  for item in pd_list:
    output[str(item)] += 1
  return output

def check_pd_equal(pd1, pd2):
  pd1_list = pd_to_multiset(pd1)
  pd2_list = pd_to_multiset(pd2)

  return pd1_list, pd2_list, pd1_list == pd2_list

In [11]:
# Checking that the outputs match each other
for i in range(0, len(thm_5_pd_list)):
  thm4_output = thm_5_pd_list[i]
  naive_output = naive_pd_e_list[i]
  gudhi_output = gudhi_pd_e_list[i]

  left_list, middle_list, value = check_pd_equal(thm4_output, naive_output)
  middle_list, right_list, value1 = check_pd_equal(naive_output, gudhi_output)

  if not (value and value1):
    print("False: ", i)
    print("Thm4 = Naive", value)
    print("Naive = gudhi", value1)
    print(left_list)
    print(middle_list)
    print(right_list)
    print("--------------------------------------------")