In [None]:
# Graphs to sweep, looked up by name via get_dataset_path(). Uncomment an
# entry to add it to the run; only "USpowerGrid" is currently enabled.
DATASET_NAMES: list[str] = [
    # "3elt",
    # "1138_bus",
    # "bull",
    # "chvatal",
    # "cubical",
    # "davis_southern_women",
    # "desargues",
    # "diamond",
    # "dodecahedral",
    # "dwt_1005",
    # "dwt_2680",
    # "florentine_families",
    # "frucht",
    # "heawood",
    # "hoffman_singleton",
    # "house_x",
    # "house",
    # "icosahedral",
    # "karate_club",
    # "krackhardt_kite",
    # "les_miserables",
    # "moebius_kantor",
    # "octahedral",
    # "pappus",
    # "petersen",
    # "poli",
    # "qh882",
    # "sedgewick_maze",
    # "tutte",
    "USpowerGrid",
]

# Edge weight shared by graph preprocessing, the all-sources BFS distance
# matrix and the layout call, so all three agree on edge length.
EDGE_WEIGHT: int = 30

In [None]:
from egraph import Drawing, all_sources_bfs
from ex_utils.config.paths import get_dataset_path

# from ex_utils.share import draw, draw_and_measure
from ex_utils.utils.graph import (
    egraph_graph,
    load_nx_graph,
    nx_graph_preprocessing,
)
import networkx as nx
import matplotlib.pyplot as plt
from time import perf_counter
from tqdm import tqdm
from ex_utils.share import calc_bounds

In [None]:
from egraph import crossing_edges
from ex_utils.share import draw, measure_quality_metrics


def draw_and_measure(
    pivots,
    iterations,
    eps,
    eg_graph,
    eg_indices,
    eg_drawing,
    eg_distance_matrix,
    edge_weight,
    seed,
):
    """Run one layout with the given parameters, time it, and score it.

    ``pivots``, ``iterations`` and ``eps`` are packed into a params dict and
    handed to ``draw`` together with the graph, drawing, edge weight and
    seed. Only the ``draw`` call is timed. Edge crossings are then computed
    on ``eg_drawing`` and passed, with ``eg_distance_matrix``, to
    ``measure_quality_metrics``. NOTE(review): this presumes ``draw`` updates
    ``eg_drawing`` in place, since the metrics are read from it afterwards.

    Returns:
        A tuple ``(params, quality_metrics, pos)``:
          - ``params``: the dict passed to ``draw``, with keys ``"pivots"``,
            ``"iterations"`` and ``"eps"``.
          - ``quality_metrics``: the dict returned by
            ``measure_quality_metrics`` plus a ``"runtime"`` entry.
          - ``pos``: whatever ``draw`` returned.

    Note:
        ``quality_metrics["runtime"]`` holds the *negated* ``perf_counter``
        duration of ``draw`` in seconds, so it is <= 0. NOTE(review): this
        presumably keeps every metric "higher is better"; confirm before
        changing the sign.
    """
    params = dict(pivots=pivots, iterations=iterations, eps=eps)

    t_begin = perf_counter()
    pos = draw(
        params=params,
        eg_graph=eg_graph,
        eg_indices=eg_indices,
        eg_drawing=eg_drawing,
        edge_weight=edge_weight,
        seed=seed,
    )
    elapsed = perf_counter() - t_begin

    quality_metrics = measure_quality_metrics(
        eg_graph=eg_graph,
        eg_drawing=eg_drawing,
        eg_crossings=crossing_edges(eg_graph, eg_drawing),
        eg_distance_matrix=eg_distance_matrix,
    )
    # Stored negated, see the Note in the docstring.
    quality_metrics["runtime"] = -elapsed

    return params, quality_metrics, pos


In [None]:
import numpy as np

# Samples per hyper-parameter axis; the sweep visits N_SPLIT**3 combinations.
N_SPLIT = 7

# 3-D parameter grid. With indexing="ij", element [p, i, e] of the three
# arrays is the p-th pivots value, i-th iterations value and e-th eps value.
pivots_v, iterations_v, eps_v = np.meshgrid(
    np.linspace(1, 100, num=N_SPLIT, dtype=int),  # pivots: 1..100, linear, truncated to int
    np.linspace(1, 200, num=N_SPLIT, dtype=int),  # iterations: 1..200, linear, truncated to int
    np.geomspace(0.01, 1, num=N_SPLIT),  # eps: 0.01..1, evenly spaced on a log scale
    indexing="ij",
)

In [None]:
def _plot_layout(nx_graph, pos, title):
    """Draw ``nx_graph`` at positions ``pos`` in a new figure, show it, then close it."""
    fig, ax = plt.subplots(dpi=300, facecolor="white")
    ax.set_aspect("equal")
    ax.set_title(title)
    nx.draw(
        nx_graph,
        pos=pos,
        node_size=5,
        node_color="#AB47BC",
        edge_color="#CFD8DC",
        ax=ax,
    )
    plt.show()
    # Release the figure explicitly: the sweep renders N_SPLIT**3 300-dpi
    # figures per dataset, and relying on the backend to free them is fragile.
    plt.close(fig)


for dataset_name in DATASET_NAMES:
    # Load/preprocess the graph and its distance matrix once per dataset;
    # every parameter combination below reuses them.
    dataset_path = get_dataset_path(dataset_name=dataset_name)
    nx_graph = nx_graph_preprocessing(
        load_nx_graph(dataset_path=dataset_path), EDGE_WEIGHT
    )

    eg_graph, eg_indices = egraph_graph(nx_graph=nx_graph)
    eg_distance_matrix = all_sources_bfs(eg_graph, EDGE_WEIGHT)

    # Visit every (pivots, iterations, eps) cell of the meshgrid built above.
    for pi in tqdm(range(N_SPLIT), desc=dataset_name):
        for ii in tqdm(range(N_SPLIT), leave=False):
            for ei in tqdm(range(N_SPLIT), leave=False):
                pivots = pivots_v[pi, ii, ei]
                iterations = iterations_v[pi, ii, ei]
                eps = eps_v[pi, ii, ei]
                # Fresh initial placement so every run starts from the same state.
                eg_drawing = Drawing.initial_placement(eg_graph)

                params, quality_metrics, pos = draw_and_measure(
                    pivots=pivots,
                    iterations=iterations,
                    eps=eps,
                    eg_graph=eg_graph,
                    eg_indices=eg_indices,
                    eg_drawing=eg_drawing,
                    eg_distance_matrix=eg_distance_matrix,
                    edge_weight=EDGE_WEIGHT,
                    seed=0,
                )

                # draw_and_measure stores the runtime as a negative number;
                # show its magnitude so the title reads as elapsed seconds.
                runtime_sec = abs(quality_metrics["runtime"])
                title = (
                    f"{dataset_name} runtime={runtime_sec:.3f}s\n"
                    f"pivots={pivots},iter={iterations},eps={round(eps, 4)}"
                )
                _plot_layout(nx_graph, pos, title)
