In [1]:
import torch
import numpy as np

from brec.dataset import BRECDataset
import time
import torch_geometric.utils as tg
from rephine_mt import compute_rephine_batched_mt
from spectre import compute_spectre_batched_mt
from ph_cpu import compute_persistence_homology_batched_mt


import collections
import numpy.linalg as linalg
import numpy as np
import networkx as nx

from torch_geometric.utils.convert import from_networkx, to_networkx


In [2]:
# load the dataset

def process_dataset(name):
    """Load one BREC category as a list of ``(graph_1, graph_2)`` PyG pairs.

    Reads ``datasets/BREC/{name}.npy`` relative to the working directory.

    * ``'basic'``, ``'str'``, ``'dr'``, ``'4vtx'``: the array is read as a flat
      sequence in which consecutive entries ``(2i, 2i + 1)`` form a pair.
    * ``'regular'``, ``'extension'``, ``'cfi'``: entry ``i`` is itself a pair
      and ``data.size // 2`` pairs are read (i.e. the array is assumed to be
      laid out as ``(num_pairs, 2)``).

    Every entry is a graph6 encoding. ``str`` entries are encoded to bytes
    before parsing and ``bytes`` entries are parsed as-is.

    Parameters
    ----------
    name : str
        BREC category name (see above).

    Returns
    -------
    list of tuple
        ``(graph_1, graph_2)`` pairs of ``torch_geometric`` ``Data`` objects
        produced by ``from_networkx``.

    Raises
    ------
    ValueError
        If ``name`` is not a known category, or a flat category file holds an
        odd number of graphs.
    """
    flat_categories = ('basic', 'str', 'dr', '4vtx')
    paired_categories = ('regular', 'extension', 'cfi')
    known_categories = flat_categories + paired_categories
    if name not in known_categories:
        raise ValueError(
            f'unknown BREC category {name!r}; expected one of {known_categories}'
        )

    def to_pyg(graph6):
        # Category files differ in whether graph6 entries are stored as str or
        # bytes; networkx's parser needs bytes.
        if isinstance(graph6, str):
            graph6 = graph6.encode()
        return from_networkx(nx.from_graph6_bytes(graph6))

    data = np.load(f'datasets/BREC/{name}.npy')
    if name in flat_categories:
        if data.size % 2:
            raise ValueError(
                f'{name}.npy holds an odd number of graphs ({data.size}); '
                'cannot split it into pairs'
            )
        pairs = [(data[i], data[i + 1]) for i in range(0, data.size, 2)]
    else:
        pairs = [(data[i][0], data[i][1]) for i in range(data.size // 2)]

    return [(to_pyg(first), to_pyg(second)) for first, second in pairs]
    

In [3]:
def _prepend_zero_column(tensor):
    """Return ``tensor`` with one column of zeros prepended along the last dim."""
    zeros = torch.zeros((*tensor.shape[:-1], 1), device=tensor.device)
    return torch.cat((zeros, tensor), dim=-1)


def compute_persistence(diagram_type, x, e, edge_index, vertex_slices, edge_slices):
    """Compute 0- and 1-dimensional persistence tensors with the selected backend.

    Parameters
    ----------
    diagram_type : str
        ``"rephine"`` or ``"edge"`` -> ``compute_rephine_batched_mt``;
        ``"spectre"`` -> ``compute_spectre_batched_mt``;
        ``"standard"`` -> ``compute_persistence_homology_batched_mt``, with the
        edge values replaced by the max of each edge's two endpoint values.
    x : torch.Tensor
        2-D vertex filtration values, one row per node (``(num_nodes, 1)`` in
        this notebook). Outputs are moved to ``x.device``.
    e : torch.Tensor
        2-D edge filtration values, one row per column of ``edge_index``.
        Ignored (recomputed from ``x``) for ``"standard"``.
    edge_index : torch.Tensor
        ``(2, num_edges)`` node-index pairs.
    vertex_slices, edge_slices : torch.Tensor
        Handed to the backend as CPU int64. Presumably per-graph offsets into
        the vertex / edge arrays of a batch -- TODO confirm against the
        backend; for a single graph this notebook passes ``[0, n]``.

    Returns
    -------
    tuple of torch.Tensor
        ``(persistence0, persistence1)`` on ``x.device``, after ``.squeeze()``.
        For the rephine / spectre / edge backends, columns 1 and 2 of
        ``persistence0`` are swapped, a zero column is prepended to both
        tensors and NaNs in ``persistence0`` are replaced by 1000.

    Raises
    ------
    ValueError
        If ``diagram_type`` is not one of the four supported names.
    """
    if diagram_type in ("rephine", "edge"):
        compute = compute_rephine_batched_mt
    elif diagram_type == "spectre":
        compute = compute_spectre_batched_mt
    elif diagram_type == "standard":
        compute = compute_persistence_homology_batched_mt
    else:
        raise ValueError(
            f"unknown diagram_type {diagram_type!r}; expected one of "
            "'rephine', 'spectre', 'standard', 'edge'"
        )

    filtered_v = x
    filtered_e = e
    if diagram_type == "standard":
        # Each edge takes the larger of its two endpoint vertex values.
        filtered_e = torch.maximum(filtered_v[edge_index[0]], filtered_v[edge_index[1]])

    # The backend receives CPU tensors, transposed and made contiguous.
    vertex_slices = vertex_slices.cpu().long()
    edge_slices = edge_slices.cpu().long()
    filtered_v = filtered_v.transpose(1, 0).cpu().contiguous()
    filtered_e = filtered_e.transpose(1, 0).cpu().contiguous()
    edge_index = edge_index.cpu().transpose(1, 0).contiguous()

    persistence0, persistence1 = compute(
        filtered_v, filtered_e, edge_index, vertex_slices, edge_slices
    )
    persistence0 = persistence0.to(x.device)
    persistence1 = persistence1.to(x.device)

    if diagram_type in ("rephine", "spectre", "edge"):
        # Swap columns 1 and 2 of the dim-0 output, keeping any further columns.
        column_order = [0, 2, 1, *range(3, persistence0.shape[2])]
        persistence0 = persistence0[:, :, column_order]

        persistence0 = _prepend_zero_column(persistence0)
        persistence1 = _prepend_zero_column(persistence1)

        # Finite sentinel for NaN entries, presumably so diagrams stay
        # comparable downstream (NaN != NaN).
        persistence0[persistence0.isnan()] = 1000

    return persistence0.squeeze(), persistence1.squeeze()

def hks_signature(eigenvectors, eigenvals, time=1):
    """Heat kernel signature of every node.

    Computes ``HKS_t(i) = sum_j exp(-t * eigenvals[j]) * eigenvectors[i, j] ** 2``,
    i.e. column ``j`` of ``eigenvectors`` is paired with ``eigenvals[j]``.

    Parameters
    ----------
    eigenvectors : array_like, shape (num_nodes, k)
    eigenvals : array_like, shape (k,)
    time : float or array_like of shape (T,), default 1
        Diffusion time(s). (Note: the parameter name shadows the ``time``
        module inside this function only; kept for keyword compatibility.)

    Returns
    -------
    numpy.ndarray
        Shape ``(num_nodes,)`` for a scalar ``time``, ``(num_nodes, T)`` for
        an array of times.
    """
    squared = np.square(np.asarray(eigenvectors))
    # Column weights exp(-t * lambda_j); broadcasting avoids materialising a
    # k x k diagonal matrix.
    decay = np.exp(-np.multiply.outer(time, np.asarray(eigenvals)))
    return squared @ decay.T


In [None]:
diagram = 'spectre'
name = 'dr'
dataset = process_dataset(name)


def forman_edge_features(adj, degree, edge_index):
    """Per-edge curvature feature, shape ``(num_edges, 1)``.

    For edge ``(u, v)``: ``4 - deg(u) - deg(v) + 3 * #common_neighbours(u, v)``
    (the augmented Forman curvature formula for unweighted graphs).
    ``adj`` is assumed to be a dense 0/1 adjacency of a simple graph, so
    ``(adj @ adj)[u, v]`` counts the common neighbours of ``u`` and ``v``.
    """
    src, dst = edge_index
    common_neighbours = (adj @ adj)[src, dst]
    curvature = 4 - degree[src] - degree[dst] + 3 * common_neighbours
    return curvature.unsqueeze(-1)


def graph_diagrams(graph, diagram_type):
    """Return ``(dim0, dim1)`` persistence tensors for one graph.

    Vertex filtration: node degree, or a constant -1 for ``'edge'``.
    Edge filtration: ``forman_edge_features`` on one copy of each undirected
    edge. ``graph`` itself is not modified.
    """
    num_nodes = graph.num_nodes
    # max_num_nodes keeps trailing isolated nodes, which to_dense_adj would
    # otherwise drop when inferring the size from edge_index.
    adj = tg.to_dense_adj(graph.edge_index, max_num_nodes=num_nodes)[0]
    degree = adj.sum(dim=-1)

    if diagram_type == 'edge':
        x = -1 * torch.ones(num_nodes, 1)
    else:
        x = degree.unsqueeze(-1)

    # edge_index stores each undirected edge in both directions; keep one.
    one_way = graph.edge_index[0] <= graph.edge_index[1]
    edge_index = graph.edge_index[:, one_way]
    edge_features = forman_edge_features(adj, degree, edge_index)

    d0, d1 = compute_persistence(
        diagram_type=diagram_type,
        x=x,
        e=edge_features,
        edge_index=edge_index,
        vertex_slices=torch.tensor([0, num_nodes]).long(),
        edge_slices=torch.tensor([0, edge_index.shape[1]]).long(),
    )
    if diagram_type == 'edge':
        d0, d1 = d0[:, :2], d1[:, :2]
    return d0, d1


def diagram_multiset(diagram_tensor, decimals=3):
    """Order-independent multiset of the diagram's rows, rounded to ``decimals``."""
    # atleast_2d guards against a single-row diagram squeezed down to 1-D.
    rows = np.atleast_2d(diagram_tensor.detach().cpu().numpy().round(decimals))
    return collections.Counter(tuple(row) for row in rows)


# Number of pairs whose diagrams coincide, i.e. pairs NOT distinguished.
count = 0
for g1, g2 in dataset:
    d0_first, d1_first = graph_diagrams(g1, diagram)
    d0_second, d1_second = graph_diagrams(g2, diagram)
    if (diagram_multiset(d0_first) == diagram_multiset(d0_second)
            and diagram_multiset(d1_first) == diagram_multiset(d1_second)):
        count += 1

print(f'{name}, {diagram}, {len(dataset)}:, {(1.0-count/len(dataset)):.2f}')