In [8]:
import sys
sys.path.insert(0, './pyged/lib')
import pyged

import torch 
import numpy as np
import matplotlib.pyplot as plt

from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.transforms import RemoveDuplicatedEdges
from torch_geometric.utils import subgraph
from multiprocessing import Pool, cpu_count
import torch.multiprocessing as mp
mp.set_sharing_strategy('file_system')
from time import time

from data.distances.OCTADataset import OCTADataset
from data.distances.CitiesDataset import CitiesDataset
from data.visualize_sample import draw_graph
from data.distances.vis_utils import plot_ged_res, get_features_df, plot_hist

In [9]:
from importlib import reload
reload(pyged)

<module 'pyged' from '/home/anna_alex/relationformer/./pyged/lib/pyged.cpython-311-x86_64-linux-gnu.so'>

In [20]:
def edge_lengths(g):
    return torch.norm(g.pos, dim=1)

def to_pyged(g):
    nodes = [tuple(pos) for pos in g.pos.tolist()]
    edges = list(zip(*g.edge_index.tolist(), edge_lengths(g).tolist()))
    return (nodes, edges)

def ged_func(args):
    query, target = args[0], args[1]
    q_pyged, t_pyged = to_pyged(query), to_pyged(target)
    return query, target, pyged.ged_dist(q_pyged, t_pyged, method_name, method_args, cost_name), args[2]


def calculate_ged(data):
    start_time = time()

    with Pool(processes=cpu_count()) as pool:
        data_with_ged = pool.map(ged_func, data, chunksize=1)
    
    end_time = time()
    execution_time = end_time - start_time
    print(f"GED execution time: {execution_time:.1f} seconds")

    return data_with_ged

def generate(dataset, file_path, num_pairs=None):
    if num_pairs is None:
        num_pairs = len(dataset) // 2
    data_loader = DataLoader(dataset, batch_size=2, shuffle=True)
#     data_loader2 = DataLoader(dataset2, batch_size=1, shuffle=True)
    graph_pairs = [(next(iter(data_loader))) for _ in range(num_pairs)]
    # graph_pairs = [(dataset[i], dataset[i+1]) for i in range(0, num_pairs, 2)]
    
    graph_pairs_with_ged = calculate_ged(graph_pairs)

    torch.save(graph_pairs_with_ged, file_path)
    
def generate_one(dataset, dataset2, query, file_path, num_pairs=None):
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
    data_loader2 = DataLoader(dataset2, batch_size=1, shuffle=True)
    graph_pairs = [(query, next(iter(data_loader)), 'octa') for _ in range(num_pairs)]
    graph_pairs += [(query, next(iter(data_loader2)), 'cities') for _ in range(num_pairs)]
    graph_pairs_with_ged = calculate_ged(graph_pairs)

    torch.save(graph_pairs_with_ged, file_path)

In [21]:
# f2, branch, ipfp
method_name = ['f2']
method_args = ['--threads 64 --time-limit 60']
# base, node, edge
cost_name = "base" 

In [22]:
dataset = CitiesDataset(root='/media/data/anna_alex/20cities/test_data')


In [23]:
file_path = '/media/data/anna_alex/distances/results/random_pairs/pyged/cities_base_test.pt'

In [24]:
generate(dataset, file_path, num_pairs=100)

Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only - expires 2025-01-15
Academic license - for non-commercial use only -

IndexError: index 3 is out of bounds for dimension 0 with size 3

In [None]:
data = torch.load(file_path)

In [None]:
data_dict = [{'pair': (d[0], d[1]), 'ged': round(d[2][1], 2), 'dataset': 'cities'} for d in data]
df_data = get_features_df(data_dict)

In [None]:
def sort(results, key, reverse=False):
    return sorted(results, key=lambda x: x[key], reverse=reverse)

In [None]:
plot_ged_res(sort(df_data.to_dict(orient='records'), 'ged')[40:60])

In [None]:
plot_ged_res(sort(df_data.to_dict(orient='records'), 'ged', reverse=True)[:20])

In [None]:
import plotly.express as px

def plot_hist(df, x="ged", nbins=None):
    fig = px.histogram(df, x=x, nbins=nbins, color="dataset")
    fig.show()

In [None]:
plot_hist(df_data)