In [1]:
import sys
sys.path.insert(0, './pyged/lib')
import pyged

import torch 
import numpy as np
import matplotlib.pyplot as plt

from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.transforms import RemoveDuplicatedEdges
from torch_geometric.utils import subgraph
from multiprocessing import Pool, cpu_count
import torch.multiprocessing as mp
mp.set_sharing_strategy('file_system')
from time import time

from data.distances.OCTADataset import OCTADataset
from data.distances.CitiesDataset import CitiesDataset
from data.visualize_sample import draw_graph
from data.distances.vis_utils import plot_ged_res, get_features_df, plot_hist
from data.distances.gmd import _compute_gmd

In [2]:
from importlib import reload
# Re-import pyged in the running kernel (presumably to pick up a rebuilt
# library). NOTE(review): pyged is a compiled extension module (a .so, see the
# output below). Per the importlib docs, reloading dynamically loaded modules
# is generally not useful: a rebuilt binary is usually NOT picked up in the
# same process. Restart the kernel after rebuilding instead.
reload(pyged)

<module 'pyged' from '/home/anna_alex/relationformer/./pyged/lib/pyged.cpython-311-x86_64-linux-gnu.so'>

In [3]:
def edge_lengths(g):
    """Return the Euclidean length of every edge of ``g``.

    ``g.pos`` holds one coordinate row per node and ``g.edge_index`` holds
    the (source, target) node indices of each edge as two rows. The result
    has one entry per edge, in ``edge_index`` column order.
    """
    # .long() keeps a float placeholder index (e.g. torch.empty(2, 0)) usable
    # for indexing; reshape turns a flat empty index into shape (2, 0).
    src, dst = g.edge_index.long().reshape(2, -1)
    return torch.linalg.vector_norm(g.pos[dst] - g.pos[src], dim=1)


def to_pyged(g):
    """Convert a graph into the ``(nodes, edges)`` pair passed to ``pyged.ged_dist``.

    Returns:
        nodes: list of coordinate tuples, one per row of ``g.pos``.
        edges: list of ``(src, dst, length)`` tuples with integer node
            indices and the Euclidean length of that edge.
    """
    nodes = [tuple(p) for p in g.pos.tolist()]
    src, dst = g.edge_index.long().reshape(2, -1).tolist()
    # edge_lengths yields exactly one value per edge, so zip never truncates.
    edges = list(zip(src, dst, edge_lengths(g).tolist()))
    return (nodes, edges)

def ged_func(data):
    """Compute the graph edit distance for one record; used as a ``Pool.map`` worker.

    Solver settings are read from the module-level globals ``method_name``,
    ``method_args`` and ``cost_name``, so the configuration cell must run
    before this is called.

    Args:
        data: record whose ``'pair'`` entry holds the query graph at index 0
            and the target graph at index 1 (objects with ``pos`` and
            ``edge_index``).

    Returns:
        The same record with ``'ged'`` set to the solver result. If anything
        raises, the error and the record are printed and ``'ged'`` is set to
        the sentinel value ``1000``.
    """
    try:
        query, target = data['pair'][0], data['pair'][1]
        # A 1-D edge_index is replaced by an empty (2, 0) index. It must be
        # integer-typed to remain a valid index tensor.
        if query.edge_index.ndim == 1:
            query.edge_index = torch.empty(2, 0, dtype=torch.long)
        if target.edge_index.ndim == 1:
            target.edge_index = torch.empty(2, 0, dtype=torch.long)

        q_pyged, t_pyged = to_pyged(query), to_pyged(target)
        # Element [1] of ged_dist's result is what is stored as the GED.
        # NOTE(review): confirm against pyged's API.
        data['ged'] = pyged.ged_dist(q_pyged, t_pyged, method_name, method_args, cost_name)[1]
    except Exception as e:
        # Best-effort: one failing pair must not abort the whole pool.map.
        print(e)
        print(data)
        data['ged'] = 1000

    return data


def calculate_graph_metrics(data, processes=None, chunksize=1):
    """Apply ``ged_func`` to every record in parallel and print the wall time.

    Args:
        data: iterable of records accepted by ``ged_func``.
        processes: number of worker processes; ``None`` (default) uses
            ``cpu_count()`` as before. Consider lowering it when the solver is
            itself multi-threaded (``method_args`` passes ``--threads``),
            otherwise processes x threads can oversubscribe the CPU.
        chunksize: number of records sent to a worker per task.

    Returns:
        List of records returned by ``ged_func``, in input order.
    """
    if processes is None:
        processes = cpu_count()

    start_time = time()
    with Pool(processes=processes) as pool:
        data_with_ged = pool.map(ged_func, data, chunksize=chunksize)
    execution_time = time() - start_time
    print(f"Execution time: {execution_time:.1f} seconds")

    return data_with_ged

In [4]:
# GED solver configuration, read as globals by ged_func.
# Available methods: f2, branch, ipfp
method_name = ["f2"]
# NOTE(review): every pool worker runs the solver with 64 threads;
# check for CPU oversubscription.
method_args = ["--threads 64 --time-limit 60"]
# Available cost functions: base, node, edge
cost_name = "node"

In [5]:
# The file holds pickled Python records (each with a 'pair' of graphs), not
# just tensors. PyTorch >= 2.6 defaults to weights_only=True, which refuses
# such objects, so opt out explicitly. Only load trusted files: unpickling
# can execute arbitrary code.
data = torch.load('/media/data/anna_alex/distances/results/rf_results/full_500.pt', weights_only=False)

In [6]:
# Destination for the records augmented with GED values.
file_path = (
    "/media/data/anna_alex/distances/results/rf_results/"
    "full_500_ged.pt"
)

In [7]:
# Parallel GED over all records (prints wall-clock time); rebinds `data`.
data = calculate_graph_metrics(data=data)

In [8]:
# Persist the GED-augmented records.
torch.save(obj=data, f=file_path)

In [9]:
# Saved records are arbitrary pickled objects; PyTorch >= 2.6 needs
# weights_only=False to load them. Only load trusted files.
data = torch.load(file_path, weights_only=False)

In [10]:
idx = 34  # index of the record to inspect below
(query, target) = data[idx]["pair"]

In [None]:
# Re-run the solver on the selected pair with the current configuration.
q_pyged, t_pyged = map(to_pyged, (query, target))
pyged.ged_dist(q_pyged, t_pyged, method_name, method_args, cost_name)

In [None]:
# Draw the query graph on a fresh single-axes figure.
fig, ax = plt.subplots()
draw_graph(query.pos, query.edge_index.t(), ax)

In [None]:
# Overlay the target graph on the query's axes. With the default inline
# backend, the figure from the previous cell was already rendered and closed,
# so drawing on it shows nothing. Display `fig` again explicitly.
draw_graph(target.pos, target.edge_index.t(), ax)
fig

In [None]:
from torch_geometric.utils import to_networkx
from data.distances.gmd import vis_mapping

def e2e_gmd(results, idx, C_V=1.0, C_E=1.0, multiplier=1.0, shift=1):
    """Compute the GMD for one record, report it, and visualise the mapping.

    Args:
        results: indexable collection of records with a ``'pair'`` entry
            holding two graphs.
        idx: which record of ``results`` to use.
        C_V, C_E, multiplier: forwarded to ``_compute_gmd``.
        shift: forwarded to ``vis_mapping``.
    """
    g1, g2 = results[idx]['pair']
    # Undirected networkx copies (keeping node positions) for plotting.
    nx_graphs = [to_networkx(g, node_attrs=["pos"]).to_undirected() for g in (g1, g2)]

    # Only the GMD computation itself is timed.
    t0 = time()
    res = _compute_gmd(((g1, g2), C_V, C_E, multiplier))
    elapsed = time() - t0

    print(f"GMD execution time: {elapsed:.2f} seconds")
    print(f"GMD: {res['gmd']}")
    vis_mapping(nx_graphs[0], nx_graphs[1], res['gmd_flow'], shift)

In [None]:
# Wrap the selected pair as a one-record list and visualise it without shift.
e2e_gmd(results=[{"pair": (query, target)}], idx=0, shift=0)

In [None]:
# Inspect the query graph's per-node positions.
query.pos

In [None]:
# Inspect the target graph's per-node positions.
target.pos