In [73]:
import torch
import numpy as np

from brec.dataset import BRECDataset
import time
import torch_geometric.utils as tg
from rephine_mt import compute_rephine_batched_mt
from spectre import compute_spectre_batched_mt
from ph_cpu import compute_persistence_homology_batched_mt

import collections
import numpy.linalg as linalg
import numpy as np
import networkx as nx

from torch_geometric.utils.convert import from_networkx, to_networkx


In [74]:
# load the dataset

def get_dataset(name):
    """Load one BREC category as a list of ``(graph_1, graph_2)`` PyG pairs.

    The graphs are read from ``datasets/BREC/{name}.npy`` and decoded from
    graph6 with networkx, then converted with ``from_networkx``.

    Two on-disk layouts are handled:

    * ``'basic'``, ``'str'``, ``'dr'``, ``'4vtx'``: a flat array in which
      consecutive entries ``(2k, 2k+1)`` form a pair.
    * ``'regular'``, ``'extension'``, ``'cfi'``: entry ``i`` holds both graphs
      of pair ``i`` (indexed with ``[0]`` and ``[1]``).

    Entries of ``'dr'``, ``'regular'`` and ``'cfi'`` are handed to
    ``nx.from_graph6_bytes`` as-is; entries of the other categories are
    ``.encode()``-d first.

    Args:
        name: BREC category name (one of the seven listed above).

    Returns:
        list of 2-tuples of PyG graphs.

    Raises:
        ValueError: if ``name`` is not a known category. (Previously the
            function silently returned ``None`` in that case.)
    """
    flat_names = ('basic', 'str', 'dr', '4vtx')
    paired_names = ('regular', 'extension', 'cfi')
    if name not in flat_names and name not in paired_names:
        raise ValueError(
            f'unknown BREC category {name!r}; '
            f'expected one of {list(flat_names + paired_names)}'
        )

    def to_pyg(g6, already_bytes):
        # nx.from_graph6_bytes needs bytes; str entries are encoded first.
        raw = g6 if already_bytes else g6.encode()
        return from_networkx(nx.from_graph6_bytes(raw))

    data = np.load(f'datasets/BREC/{name}.npy')

    if name in flat_names:
        already_bytes = name == 'dr'
        return [
            (to_pyg(data[i], already_bytes), to_pyg(data[i + 1], already_bytes))
            for i in range(0, data.size, 2)
        ]

    already_bytes = name in ('regular', 'cfi')
    return [
        (to_pyg(data[i][0], already_bytes), to_pyg(data[i][1], already_bytes))
        for i in range(0, data.size // 2)
    ]
    

In [75]:
def compute_persistence(diagram_type, x, e, edge_index, vertex_slices, edge_slices):
    """Run the persistence backend selected by ``diagram_type``.

    Backend selection:

    * ``"rephine"`` / ``"edge"``                  -> ``compute_rephine_batched_mt``
    * ``"spectre"`` / ``"partial_spectre"`` / ``"ls"`` -> ``compute_spectre_batched_mt``
    * ``"standard"``                              -> ``compute_persistence_homology_batched_mt``;
      here ``e`` is ignored and the edge filtration is the element-wise
      maximum of the two endpoint vertex values.

    All inputs are moved to CPU and transposed with ``transpose(1, 0)``
    before the backend call, so ``x``, ``e`` and ``edge_index`` must be 2-D.
    The two results are moved back to ``x.device``.

    For every type except ``"standard"`` the 0-dim output has its columns 1
    and 2 swapped, both outputs get a zero column prepended on the last
    axis, and NaNs in the 0-dim output are replaced by 1000.

    Args:
        diagram_type: one of the strings listed above.
        x: vertex filtration values.
        e: edge filtration values (unused for ``"standard"``).
        edge_index: 2-row tensor of edge endpoints.
        vertex_slices: per-graph vertex offsets passed to the backend.
        edge_slices: per-graph edge offsets passed to the backend.

    Returns:
        ``(persistence0, persistence1)``, each with ``.squeeze()`` applied.

    Raises:
        ValueError: if ``diagram_type`` is not recognised. (Previously this
            surfaced as an ``UnboundLocalError`` at the backend call.)
    """
    filtered_v = x
    filtered_e = e
    if diagram_type in ("rephine", "edge"):
        compute = compute_rephine_batched_mt
    elif diagram_type in ("spectre", "partial_spectre", "ls"):
        compute = compute_spectre_batched_mt
    elif diagram_type == "standard":
        # An edge enters the filtration when its later endpoint does.
        filtered_e, _ = torch.max(
            torch.stack((filtered_v[edge_index[0]], filtered_v[edge_index[1]])),
            dim=0,
        )
        compute = compute_persistence_homology_batched_mt
    else:
        raise ValueError(
            f"unknown diagram_type {diagram_type!r}; expected one of "
            "'rephine', 'edge', 'spectre', 'partial_spectre', 'ls', 'standard'"
        )

    vertex_slices = vertex_slices.cpu().long()
    edge_slices = edge_slices.cpu().long()
    filtered_v = filtered_v.transpose(1, 0).cpu().contiguous()
    filtered_e = filtered_e.transpose(1, 0).cpu().contiguous()
    edge_index = edge_index.cpu().transpose(1, 0).contiguous()

    persistence0, persistence1 = compute(
        filtered_v, filtered_e, edge_index, vertex_slices, edge_slices
    )

    persistence0 = persistence0.to(x.device)
    persistence1 = persistence1.to(x.device)

    if diagram_type != "standard":
        # Swap columns 1 and 2 of the 0-dim tuples, keep the rest in order.
        full_size = persistence0.shape[2]
        column_order = [0, 2, 1] + list(range(3, full_size))
        persistence0 = persistence0[:, :, column_order]

        # Prepend a zero column to both diagrams.
        pad0 = torch.zeros(
            (persistence0.shape[0], persistence0.shape[1], 1), device=x.device
        )
        pad1 = torch.zeros(
            (persistence1.shape[0], persistence1.shape[1], 1), device=x.device
        )
        persistence0 = torch.cat((pad0, persistence0), dim=-1)
        persistence1 = torch.cat((pad1, persistence1), dim=-1)

        # NaN entries of the 0-dim diagram are replaced by a fixed sentinel
        # so that rows can later be compared by value.
        persistence0[persistence0.isnan()] = 1000

    return persistence0.squeeze(), persistence1.squeeze()


In [76]:
def apply_filtering_function(g):
    """Attach degree-based vertex and edge filtration values to a graph.

    The input graph is NOT modified; a shallow copy carrying the new
    attributes is returned. (The previous in-place version overwrote
    ``g.edge_index`` with the de-duplicated edges, so calling it a second
    time on the same graph computed degrees from half of the edges.)

    On the returned graph:

    * ``x`` has shape ``(num_nodes, 1)`` and holds, for each node, the number
      of columns of ``g.edge_index`` whose source is that node. Nodes that
      appear in no edge get 0 (sized via ``g.num_nodes`` when available).
    * ``edge_index`` keeps only the columns with ``source <= target`` (one
      direction of every undirected edge, plus self-loops).
    * ``edge_features`` has shape ``(num_kept_edges, 1)`` with value
      ``4 - deg(u) - deg(v) + 3 * |N(u) & N(v)|`` for a kept edge ``(u, v)``,
      where ``N`` is taken from the full, original ``edge_index``.

    Args:
        g: graph object exposing ``edge_index`` (2 x E integer tensor) and,
            optionally, ``num_nodes``.

    Returns:
        A shallow copy of ``g`` with ``x``, ``edge_index`` and
        ``edge_features`` set as described above.
    """
    import copy  # local import so this cell does not depend on a new top-level import

    full_edge_index = g.edge_index
    src, dst = full_edge_index[0], full_edge_index[1]

    # Size the degree vector so that isolated, highest-numbered nodes are kept.
    num_nodes = getattr(g, "num_nodes", None)
    num_nodes = int(num_nodes) if num_nodes is not None else 0
    if full_edge_index.numel() > 0:
        num_nodes = max(num_nodes, int(full_edge_index.max()) + 1)

    # Row sums of the dense adjacency == number of edges leaving each node.
    degree = torch.bincount(src, minlength=num_nodes).to(torch.get_default_dtype())

    # Neighbour sets are built once instead of re-scanning edge_index per edge.
    neighbors = [set() for _ in range(degree.shape[0])]
    for u, v in zip(src.tolist(), dst.tolist()):
        neighbors[u].add(v)

    keep = src <= dst  # drop the mirrored copy of every undirected edge
    new_edge_index = full_edge_index[:, keep]
    kept_src, kept_dst = new_edge_index[0], new_edge_index[1]

    shared = torch.tensor(
        [
            len(neighbors[u] & neighbors[v])
            for u, v in zip(kept_src.tolist(), kept_dst.tolist())
        ],
        dtype=degree.dtype,
        device=degree.device,
    )
    edge_features = (4 - degree[kept_src] - degree[kept_dst] + 3 * shared).unsqueeze(1)

    out = copy.copy(g)
    out.x = degree.unsqueeze(1)
    out.edge_index = new_edge_index
    out.edge_features = edge_features
    return out


def compute_diagram(g, diagram):
    """Compute the 0- and 1-dimensional diagrams of a single graph.

    The graph is first passed through ``apply_filtering_function``; the
    resulting ``x`` / ``edge_features`` / ``edge_index`` are then fed to
    ``compute_persistence`` as a batch containing just this one graph.

    Post-processing by ``diagram``:

    * ``'ls'``: columns 2 and 3 are dropped from both diagrams.
    * ``'partial_spectre'``: for every row, the tail of the row is zeroed
      in place, starting at a column derived from the last position of the
      maximum among columns ``4:``.
    * ``'edge'``: only the first two columns are kept.

    Args:
        g: graph accepted by ``apply_filtering_function``.
        diagram: diagram type string forwarded to ``compute_persistence``.

    Returns:
        ``(d0, d1)`` tensors.
    """

    def zero_spectrum_tail(diag):
        # Operates in place on every row of ``diag``.
        width = diag.shape[1]
        for row in range(diag.shape[0]):
            flipped = diag[row, 4:].flip(dims=[0])
            # argmax on the flipped slice locates the LAST maximum of the
            # original slice; ``n`` is its column index in the full row.
            n = width - 1 - torch.argmax(flipped)
            diag[row, 4 + (n // 3) + 1:] = 0

    filtered = apply_filtering_function(g)
    vertex_slices = torch.tensor([0, filtered.x.shape[0]]).long()
    edge_slices = torch.tensor([0, filtered.edge_index.shape[1]]).long()

    d0, d1 = compute_persistence(
        diagram_type=diagram,
        x=filtered.x,
        e=filtered.edge_features,
        edge_index=filtered.edge_index,
        vertex_slices=vertex_slices,
        edge_slices=edge_slices,
    )

    if diagram == 'ls':
        d0 = d0[:, [0, 1] + list(range(4, d0.shape[1]))]
        d1 = d1[:, [0, 1] + list(range(4, d1.shape[1]))]
    elif diagram == 'partial_spectre':
        zero_spectrum_tail(d0)
        zero_spectrum_tail(d1)
    elif diagram == 'edge':
        d0, d1 = d0[:, :2], d1[:, :2]

    return d0, d1

In [77]:
def compare_diagrams(dG_0, dG_1, dH_0, dH_1):
    """Check whether two graphs have the same diagrams as multisets of rows.

    Every row of each diagram is rounded to 3 decimals and turned into a
    tuple; the rows of a diagram are then compared as a multiset (order is
    ignored, multiplicity is not).

    Args:
        dG_0, dG_1: 0- and 1-dimensional diagrams of the first graph.
        dH_0, dH_1: 0- and 1-dimensional diagrams of the second graph.

    Returns:
        True iff the 0-dim multisets match and the 1-dim multisets match.
    """

    def row_multiset(diagram):
        rows = (
            tuple(diagram[idx].numpy().round(3))
            for idx in range(diagram.shape[0])
        )
        return collections.Counter(rows)

    multiset_g0 = row_multiset(dG_0)
    multiset_h0 = row_multiset(dH_0)
    multiset_g1 = row_multiset(dG_1)
    multiset_h1 = row_multiset(dH_1)

    if multiset_g0 != multiset_h0:
        return False
    return multiset_g1 == multiset_h1

def test_isomorphism(g, h, diagram='standard'):
    """Return True when ``g`` and ``h`` yield identical diagrams.

    Both graphs are run through ``compute_diagram`` with the given
    ``diagram`` type and the results are compared with
    ``compare_diagrams``. A True result means this diagram type does not
    tell the two graphs apart.
    """
    diagrams_g = compute_diagram(g, diagram)
    diagrams_h = compute_diagram(h, diagram)
    return compare_diagrams(*diagrams_g, *diagrams_h)


In [78]:
# For every diagram type, report per category the fraction of BREC pairs whose
# diagrams differ (i.e. pairs the method tells apart), then the same fraction
# over all categories together.
diagrams = ['standard', 'edge', 'rephine', 'spectre']
names = ['basic','regular', 'extension', 'cfi', 'dr', 'str', '4vtx']
for diagram in diagrams:
    total = 0    # pairs with identical diagrams, summed over categories
    n_pairs = 0  # pairs seen, summed over categories (replaces a hard-coded 400)
    for name in names:
        # Datasets are reloaded for every diagram type so that each run
        # starts from freshly decoded graphs.
        dataset = get_dataset(name)
        count = sum(1 for g1, g2 in dataset if test_isomorphism(g1, g2, diagram))
        total = total + count
        n_pairs = n_pairs + len(dataset)
        print(f'{name}, {diagram}:, {1.0-(count/len(dataset)):.2f}')
    print(f'Total: {diagram}, all: {(n_pairs-total)/n_pairs:.2f}')

basic, standard:, 0.03
regular, standard:, 0.00
extension, standard:, 0.07
cfi, standard:, 0.03
dr, standard:, 0.00
str, standard:, 0.00
4vtx, standard:, 0.00
Total: standard, all: 0.03
basic, edge:, 0.98
regular, edge:, 0.94
extension, edge:, 0.55
cfi, edge:, 0.03
dr, edge:, 0.00
str, edge:, 0.00
4vtx, edge:, 0.00
Total: edge, all: 0.41
basic, rephine:, 0.98
regular, rephine:, 0.94
extension, rephine:, 0.55
cfi, rephine:, 0.03
dr, rephine:, 0.00
str, rephine:, 0.00
4vtx, rephine:, 0.00
Total: rephine, all: 0.41
basic, spectre:, 1.00
regular, spectre:, 1.00
extension, spectre:, 1.00
cfi, spectre:, 0.04
dr, spectre:, 0.05
str, spectre:, 0.00
4vtx, spectre:, 0.00
Total: spectre, all: 0.54


In [None]:
## Results in Appendix F (Additional Experiments)

# Same per-category report as above for the two extra diagram types; the
# dataset size is printed before each result line.
diagrams = ['ls', 'partial_spectre']
names = ['basic','regular', 'extension', 'cfi', 'dr']
for diagram, name in ((d, n) for d in diagrams for n in names):
    dataset = get_dataset(name)
    print(len(dataset))
    # Number of pairs whose diagrams coincide.
    count = sum(1 for g1, g2 in dataset if test_isomorphism(g1, g2, diagram))
    print(f'{name}, {diagram}:, {1.0-(count/len(dataset)):.2f}')


60
basic, ls:, 1.00
50
regular, ls:, 1.00
100
extension, ls:, 1.00
100
cfi, ls:, 0.04
20
dr, ls:, 0.05
60
basic, partial_spectre:, 1.00
50
regular, partial_spectre:, 1.00
100
extension, partial_spectre:, 1.00
100
