This notebook contains three implementations for computing the 0-dimensional persistent homology (PH) of the product of two graphs equipped with vertex-based filtrations:

1. Computing PH of the product directly using a union-find method.
2. Theorem 4 of our paper.
3. Computing PH of the product directly using the gudhi library.

We then run all three on pairs of graphs loaded from the BREC dataset.

NOTE: For convenience, time t = +infinity is represented as -1 in the code below, so the filtration values are assumed to be positive. The values at time t = -infinity are precomputed accordingly.

In [None]:
# Uncomment to install the gudhi package
# !pip install gudhi
import torch
import itertools
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# import random
import gudhi
import math
from itertools import product
import time

In [None]:
# The code implementation of the Union Find PH Based Method
# The code is adapted from codes produced in the paper:
# "Going beyond persistent homology using persistent homology"
# by Johanna Immonen, Amauri H. Souza, and Vikas Garg.

class UnionFind:
    """Quick-find disjoint-set structure over the integers 0..N-1.

    Every slot of ``_parents`` stores the representative of its set directly,
    so ``find`` is a single lookup while ``merge`` relabels a whole set.
    """

    def __init__(self, N):
        # Initially every element is its own representative.
        self._parents = [idx for idx in range(N)]

    def find(self, p):
        """Return the representative of the set containing ``p``."""
        return self._parents[p]

    def merge(self, p, q):
        """Join the sets of ``p`` and ``q``; ``q``'s representative survives."""
        absorbed = self._parents[p]
        survivor = self._parents[q]
        # Slice assignment keeps the same list object while relabelling.
        self._parents[:] = [survivor if rep == absorbed else rep for rep in self._parents]

    def connected(self, p, q):
        """Return True when ``p`` and ``q`` share a representative."""
        return self.find(p) == self.find(q)

    def roots(self):
        """Return, in increasing order, every element that is its own representative."""
        return [idx for idx, rep in enumerate(self._parents) if rep == idx]


def persistence_routine(filtered_v, filtered_e, edge_indices):
    """Compute 0-dimensional persistence of a vertex-filtered graph with union-find.

    Args:
        filtered_v: tensor of vertex filtration values, one entry per vertex.
        filtered_e: tensor of edge filtration values aligned with the columns
            of ``edge_indices``.
        edge_indices: integer tensor whose column k holds the two endpoints
            (row 0 and row 1) of edge k.

    Returns:
        A pair ``(persistence, persistence1)``:
        * ``persistence`` has shape (n, 3); row v is ``[v, birth, death]`` for
          the component represented by vertex v. Components that never die
          get death ``-1`` (the stand-in for +infinity).
        * ``persistence1`` has shape (m, 2); each edge that closes a cycle gets
          ``[edge value, -1]``, and all other rows stay zero.
    """
    n, m = filtered_v.shape[0], edge_indices.shape[1]

    # Process edges in increasing filtration order.
    filtered_e, sorted_indices = torch.sort(filtered_e)
    uf = UnionFind(n)
    persistence = torch.zeros((n, 3), device=filtered_v.device)
    persistence1 = torch.zeros((m, 2), device=filtered_v.device)

    # +infinity is encoded as -1 below, so no sentinel is read from
    # filtered_e (which would fail for a graph with no edges).
    for edge_index, edge_weight in zip(sorted_indices, filtered_e):
        nodes = edge_indices[:, edge_index]
        younger = uf.find(nodes[0])
        older = uf.find(nodes[1])

        if younger == older:
            # The edge closes a cycle: record it as a never-dying 1-dim class.
            persistence1[edge_index, 0] = edge_weight
            persistence1[edge_index, 1] = -1
            continue

        # Elder rule: the component born later (larger value) dies.
        # On ties, the component of nodes[0] is the one that dies.
        if filtered_v[younger] < filtered_v[older]:
            younger, older = older, younger
            nodes = torch.flip(nodes, [0])

        persistence[younger, 0] = younger
        persistence[younger, 1] = filtered_v[younger]
        persistence[younger, 2] = edge_weight

        # merge(p, q) keeps q's root, so the older component's root survives.
        uf.merge(nodes[0], nodes[1])

    # Surviving components never die.
    for root in uf.roots():
        persistence[root, 0] = root
        persistence[root, 1] = filtered_v[root]
        persistence[root, 2] = -1

    return persistence, persistence1

In [None]:
# Utility function to print a labelled persistence diagram in a readable way
def pretty_print(pd_with_labels, birth=2.0):
  """Print labelled persistence records.

  Args:
    pd_with_labels: iterable of rows; each row is an iterable of
      ``(i, j, birth, death)`` records (the shape of the labelled table built
      inside ``thm4_ph``).
    birth: only records whose birth equals this value are printed. The
      default ``2.0`` keeps the historical behaviour. Pass ``None`` to print
      every record.
  """
  for row in pd_with_labels:
    for i, j, b, death in row:
      if birth is None or b == birth:
        print("Label: " + str((i, j)) + " B/D: " + str((b, death)))

# Implementation of Theorem 4 of our paper.
# Input: a graph data item, i.e. a tuple (graph_data_item 0, graph_data_item 1).
#
# For i = 0 or 1, graph_data_item i describes a graph G on vertices
# {0, 1, ..., n} with a vertex filtration function fv, as 4 items:
#
# input_v    - a list giving the value of fv on each vertex (not used here)
# edge_index - a (2, m) long tensor; column k holds the endpoints of edge k
# filtered_v - torch.Tensor(input_v)
# filtered_e - the value of fv on edges, e.g. built with:
# filtered_e, _ = torch.max(torch.stack((filtered_v[edge_index[0]], filtered_v[edge_index[1]])), axis=0)


def _thm4_step_tables(impl, filtration_steps):
  """Tabulate, per filtration step, the 0-dim events of one factor graph.

  ``impl`` is the ``persistence`` tensor from ``persistence_routine`` (rows
  ``[vertex, birth, death]``). Each returned list has a leading slot 0 equal
  to ``(0, [])`` (time -infinity), and slot i+1 describes step
  ``filtration_steps[i]`` as ``(count, vertex labels)``.

  Returns ``(nd, nb, b0)``:
    nd - non-trivial deaths (death at this step, birth != death)
    nb - non-trivial births (birth at this step, birth != death)
    b0 - labels of the components alive at this step (0th Betti number)
  """
  num_steps = len(filtration_steps)
  nd = [(0, [])] + [None] * num_steps
  nb = [(0, [])] + [None] * num_steps
  b0 = [(0, [])] + [None] * num_steps

  for i, a_i in enumerate(filtration_steps):
    deaths = [int(x[0].item()) for x in impl if x[2] == a_i and x[1] != x[2]]
    births = [int(x[0].item()) for x in impl if x[1] == a_i and x[1] != x[2]]
    nd[i + 1] = (len(deaths), deaths)
    nb[i + 1] = (len(births), births)

  # Running count: alive now = (alive before + born now) - died now.
  for i in range(num_steps):
    count = b0[i][0] + nb[i + 1][0] - nd[i + 1][0]
    alive = list(set(b0[i][1] + nb[i + 1][1]).difference(set(nd[i + 1][1])))
    b0[i + 1] = (count, alive)

  return nd, nb, b0


def thm4_ph(graph_data_item):
  """Compute the 0-dim persistence diagram of the Cartesian product of G and H via Theorem 4.

  Args:
    graph_data_item: pair ``(G_data, H_data)``, each of the form
      ``(input_v, edge_index, filtered_v, filtered_e)`` described above.

  Returns:
    A flat list of ``[birth, death]`` float pairs, one per product vertex
    (i, j) in row-major order; death ``-1.0`` stands for +infinity.
  """
  _, edge_index_G, filtered_v_G, filtered_e_G = graph_data_item[0]
  _, edge_index_H, filtered_v_H, filtered_e_H = graph_data_item[1]

  # Persistence pairs of each factor from the union-find routine.
  impl1 = persistence_routine(filtered_v_G, filtered_e_G, edge_index_G)[0]
  impl2 = persistence_routine(filtered_v_H, filtered_e_H, edge_index_H)[0]

  # Shared, sorted time axis for both factors; -1 (= +infinity) comes last.
  filtration_steps = torch.unique(torch.cat((filtered_v_G, filtered_v_H))).tolist()
  filtration_steps.sort()
  filtration_steps.append(-1)

  G_nd, G_nb, G_b0 = _thm4_step_tables(impl1, filtration_steps)
  H_nd, H_nb, H_b0 = _thm4_step_tables(impl2, filtration_steps)

  nG = filtered_v_G.shape[0]
  nH = filtered_v_H.shape[0]

  # One record [i, j, birth, death] per product vertex; death None = alive.
  persistence = [
      [[i, j, torch.max(filtered_v_G[i], filtered_v_H[j]).item(), None] for j in range(nH)]
      for i in range(nG)
  ]

  for i, a_i in enumerate(filtration_steps):
    _, g_deaths = G_nd[i + 1]
    _, h_deaths = H_nd[i + 1]
    _, g_births = G_nb[i + 1]
    _, h_births = H_nb[i + 1]
    _, g_betti0 = G_b0[i + 1]
    _, h_betti_prev = H_b0[i]

    # Non-trivial deaths: a product vertex dies when its G- or H-coordinate
    # labels a component dying now, provided it was born strictly earlier
    # (or we are at the +infinity step). Vertices born at a_i are handled in
    # the trivial-death pass below.
    for v in h_deaths:
      for ell in range(nG):
        item = persistence[ell][v]
        if item[3] is None and (item[2] < a_i or a_i == -1):
          item[3] = a_i
    for w in g_deaths:
      for k in range(nH):
        item = persistence[w][k]
        if item[3] is None and (item[2] < a_i or a_i == -1):
          item[3] = a_i

    # Product vertices that start a genuinely new component at this step.
    non_trivial_births_current = list(product(g_births, h_betti_prev)) + list(product(g_betti0, h_births))
    # Using product(g_betti0, h_deaths) here is noted to give the same output.
    non_trivial_deaths_second = list(product(g_births, h_deaths))
    # Set for O(1) membership tests in the nested pair loop below.
    surviving_births = set(non_trivial_births_current).difference(non_trivial_deaths_second)
    # Alternative noted to give the same result: intersect
    #   product(g_births, H_b0[i]) + product(G_b0[i+1], h_births)
    # with
    #   product(G_b0[i], h_births) + product(g_births, H_b0[i+1]).

    # Trivial deaths: everything born at a_i that is not a new component
    # dies immediately (birth == death).
    for a in range(nG):
      for b in range(nH):
        item = persistence[a][b]
        if item[3] is None and item[2] == a_i and (a, b) not in surviving_births:
          item[3] = a_i

  # Drop the (i, j) labels and flatten to [birth, death] pairs.
  return [[float(x[2]), float(x[3])] for row in persistence for x in row]


In [None]:
# Utility Function - prints labelled persistence records of G \Box H. nH is the number of vertices in H.
def pretty_print_prod(prod_pd_with_labels, nH):
  """Print labelled persistence records of the product graph.

  Args:
    prod_pd_with_labels: iterable of ``(i, birth, death)`` records, where ``i``
      is the flattened product-vertex index ``nH * x + y`` (see ``prod``).
    nH: number of vertices of H.
  """
  for i, birth, death in prod_pd_with_labels:
    # Integer division recovers an integral label, e.g. (2, 1) rather than (2.0, 1).
    x, y = divmod(int(i), nH)
    print("Label: " + str((x, y)) + ", B/D: " + str((birth, death)))


# Flattened index of the product vertex (i, j) of G \Box H when the
# nG x nH grid of vertices is laid out row by row (nH = number of vertices of H).
def prod(pair, nH):
  """Return ``nH * pair[0] + pair[1]``."""
  row = pair[0]
  col = pair[1]
  return nH * row + col

# Naive implementation by computing PH of the whole product using the union-find structure
# The inputs are - graph_data_item (explained in the implementation for Theorem 4)
# P - the networkx graph of the product of the graph G and the graph H
def naive_pd(graph_data_item, P):
  """Compute the 0-dim diagram of the product by running union-find on P directly.

  Args:
    graph_data_item: pair ``(G_data, H_data)`` of
      ``(input_v, edge_index, filtered_v, filtered_e)`` tuples.
    P: networkx product graph whose nodes are pairs (i, j).

  Returns:
    A list with one ``[birth, death]`` float pair per product vertex, in
    row-major (i, j) order; death ``-1.0`` stands for +infinity.
  """
  input_v_G, _, filtered_v_G, _ = graph_data_item[0]
  input_v_H, _, filtered_v_H, _ = graph_data_item[1]

  nG = filtered_v_G.shape[0]
  nH = filtered_v_H.shape[0]

  # Product vertex (i, j) sits at flat index nH * i + j with value max(fG(i), fH(j)).
  prod_vert_value = [max(input_v_G[i], input_v_H[j]) for i in range(nG) for j in range(nH)]

  prod_edge_index = [[], []]
  for v, w in P.edges:
    prod_edge_index[0].append(prod(v, nH))
    prod_edge_index[1].append(prod(w, nH))

  prod_edge_index_torch = torch.Tensor(prod_edge_index).long()
  prod_filtered_v = torch.Tensor(prod_vert_value)
  # Each edge takes the larger of its two endpoint values.
  prod_filtered_e, _ = torch.max(torch.stack((prod_filtered_v[prod_edge_index_torch[0]], prod_filtered_v[prod_edge_index_torch[1]])), axis=0)

  # Computes the PH on the entire product using the union-find code.
  prod_ph = persistence_routine(prod_filtered_v, prod_filtered_e, prod_edge_index_torch)[0]
  # Drop the vertex-label column, keeping [birth, death].
  return prod_ph[:, 1:].tolist()

In [None]:
# Naive implementation by computing PH of the whole product using the gudhi library
# The inputs are - graph_data_item (explained in the implementation for Theorem 4)
# P - the networkx graph of the product of the graph G and the graph H
def gudhi_prod(graph_data_item, P):
  """Compute the 0-dim diagram of the product graph P with gudhi.

  Args:
    graph_data_item: pair ``(G_data, H_data)`` of
      ``(input_v, edge_index, filtered_v, filtered_e)`` tuples.
    P: networkx product graph whose nodes are pairs (i, j).

  Returns:
    A list of ``[birth, death]`` pairs of dimension 0, with infinite deaths
    replaced by ``-1.0`` to match the other implementations.
  """
  _, _, filtered_v_G, _ = graph_data_item[0]
  _, _, filtered_v_H, _ = graph_data_item[1]

  nG = filtered_v_G.shape[0]
  nH = filtered_v_H.shape[0]

  def value(vertex):
    # Filtration value of product vertex (i, j): max(fG(i), fH(j)).
    i, j = vertex
    return torch.max(filtered_v_G[i], filtered_v_H[j]).item()

  st = gudhi.SimplexTree()
  for i in range(nG):
    for j in range(nH):
      st.insert([prod((i, j), nH)], filtration=value((i, j)))

  for v, w in P.edges:
    st.insert([prod(v, nH), prod(w, nH)], filtration=max(value(v), value(w)))

  st.make_filtration_non_decreasing()
  # persistence_dim_max=True asks gudhi to also report the top dimension of
  # the complex, so H0 is kept even when P has no edges (a 0-dim complex).
  # The explicit dim == 0 filter below discards any H1 classes this adds.
  dgms = st.persistence(min_persistence=-1, persistence_dim_max=True)

  return [[birth, -1.0 if death == np.inf else death]
          for dim, (birth, death) in dgms if dim == 0]

The rest of the code tests that the three implementations agree with each other. We take the BREC dataset (see the paper for specifications) and compute the 0-dimensional PH of the product of each pair of graphs in it, using a degree-based vertex filtration on both factors.

In [None]:
# Loading the dataset
# Mount Google Drive (Colab-only API) so the BREC .npy files read by
# get_dataset below are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

def get_dataset(name, root='/content/drive/My Drive/Colab Notebooks/datasets/BREC'):
    """Load the graph pairs of one BREC subset.

    Args:
        name: subset name; one of 'basic', 'str', 'dr', '4vtx', 'regular',
            'extension', 'cfi'.
        root: directory containing ``{name}.npy``. The default is the Colab
            Drive location used in this notebook; point it at your own copy
            of the BREC dataset when running elsewhere.

    Returns:
        A list of ``(graph_1, graph_2)`` networkx graph pairs.

    Raises:
        ValueError: if ``name`` is not a known subset.
    """
    flat_subsets = ('basic', 'str', 'dr', '4vtx')
    paired_subsets = ('regular', 'extension', 'cfi')
    if name not in flat_subsets + paired_subsets:
        raise ValueError(
            f"Unknown BREC subset {name!r}; expected one of {flat_subsets + paired_subsets}")

    data = np.load(f'{root}/{name}.npy')
    dataset = []
    if name in flat_subsets:
        # Graphs are stored back to back; entries i and i+1 form one pair.
        for i in range(0, data.size, 2):
            first, second = data[i], data[i + 1]
            if name != 'dr':
                # 'dr' entries are parsed as-is; the others are encoded to
                # bytes first (presumably stored as str -- verify in the files).
                first, second = first.encode(), second.encode()
            dataset.append((nx.from_graph6_bytes(first), nx.from_graph6_bytes(second)))
    else:
        # Each of the first data.size // 2 entries holds one (graph6, graph6) pair.
        for i in range(0, data.size // 2):
            first, second = data[i][0], data[i][1]
            if name == 'extension':
                # Only 'extension' needs encoding; 'regular' and 'cfi' are parsed as-is.
                first, second = first.encode(), second.encode()
            dataset.append((nx.from_graph6_bytes(first), nx.from_graph6_bytes(second)))
    return dataset

dataset = get_dataset('basic')
# Report how many graph pairs were loaded.
num_pairs = len(dataset)
print(num_pairs)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
60


In [None]:
# Converting the dataset inputs to entries for the 3 algorithms
# Here we filtrate the two graphs using a degree-based vertex filtration.
def dataset_entry_to_input(item):
  """Turn a (G, H) graph pair into the inputs used by the PH implementations.

  Returns ``((implG, implH), (G_data, H_data))``:
    implX - the per-vertex table returned by ``persistence_routine`` for X
    X_data - ``(input_v, edge_index, filtered_v, filtered_e)`` for X, where a
      vertex's value is its degree and an edge's value is the max of its
      endpoint values.
  """
  def encode(graph):
    # Vertex value = degree, listed in graph.nodes order.
    degrees = [graph.degree[node] for node in list(graph.nodes)]
    sources = [u for u, _ in graph.edges]
    targets = [v for _, v in graph.edges]
    edge_index = torch.Tensor([sources, targets]).long()
    vertex_values = torch.Tensor(degrees)
    edge_values, _ = torch.max(torch.stack((vertex_values[edge_index[0]], vertex_values[edge_index[1]])), axis=0)
    return degrees, edge_index, vertex_values, edge_values

  G, H = item
  G_data = encode(G)
  H_data = encode(H)

  # persistence_routine takes (vertex values, edge values, edge index).
  implG = persistence_routine(G_data[2], G_data[3], G_data[1])[0]
  implH = persistence_routine(H_data[2], H_data[3], H_data[1])[0]

  return (implG, implH), (G_data, H_data)

# Convert every (G, H) pair once, then split the results into two lists.
converted = [dataset_entry_to_input(item) for item in dataset]
impl_list = [entry_impl for entry_impl, _ in converted]
graph_data_list = [entry_gd for _, entry_gd in converted]


In [None]:
# Computing the results for the 3 implementations

# Theorem 4
thm_4_pd_list = [thm4_ph(graph_data) for graph_data in graph_data_list]

# The actual NetworkX Cartesian product graph of each pair (G, H)
P_list = [nx.cartesian_product(G, H) for G, H in dataset]

# Naive union-find on the full product graph
naive_pd_list = [naive_pd(graph_data, P) for graph_data, P in zip(graph_data_list, P_list)]

# gudhi on the full product graph
gudhi_pd_list = [gudhi_prod(graph_data, P) for graph_data, P in zip(graph_data_list, P_list)]

In [None]:
# Helper functions to check if two persistence diagrams are equal
def pd_to_multiset(pd_list):
  """Return a dict mapping ``str(pair)`` to its multiplicity in ``pd_list``.

  Pairs are keyed by their string form so unhashable entries (lists such as
  ``[birth, death]``) can be counted. Keys keep first-occurrence order.
  """
  from collections import Counter
  return dict(Counter(str(item) for item in pd_list))

def check_pd_equal(pd1, pd2):
  """Compare two persistence diagrams as multisets of pairs.

  Returns ``(multiset1, multiset2, equal)`` where both multisets come from
  ``pd_to_multiset`` and ``equal`` tells whether they coincide.
  """
  multisets = (pd_to_multiset(pd1), pd_to_multiset(pd2))
  return multisets[0], multisets[1], multisets[0] == multisets[1]

In [None]:
# Checking that the outputs match each other.
# Thm4 == gudhi follows by transitivity, so two comparisons suffice.
for i, (thm4_output, naive_output, gudhi_output) in enumerate(
    zip(thm_4_pd_list, naive_pd_list, gudhi_pd_list)):
  left_list, middle_list, value = check_pd_equal(thm4_output, naive_output)
  middle_list, right_list, value1 = check_pd_equal(naive_output, gudhi_output)

  if value and value1:
    print(True)
    continue

  # Mismatch: dump all three multisets for inspection.
  print("False: ", i)
  print("Thm4 = Naive", value)
  print("Naive = gudhi", value1)
  print(left_list)
  print(middle_list)
  print(right_list)
  print("--------------------------------------------")

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
