In [1]:
import os
import warnings

import pandas as pd
from tqdm.auto import tqdm

from gest.data.gest import GEST
from gest.service.evaluation.graph_matching.graph import GESTGraph
from gest.service.evaluation.graph_matching.similarity import (
    SimilarityService,
    SimilarityEngine,
)
from gest.service.evaluation.graph_matching.solver import SolverType
from gest.service.evaluation.graph_matching.embedding_type_enum import EmbeddingType

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Input CSV; override with the GEST_CSV_PATH environment variable instead of
# editing this absolute, machine-specific default.
GEST_CSV_PATH = os.environ.get("GEST_CSV_PATH", "/workspaces/GEST/data/gest.csv")
# Columns the pairing/scoring cells below read; frozen so the shared constant
# cannot be mutated accidentally by a later cell.
REQUIRED_COLUMNS = frozenset({"dataset", "id", "text", "gest"})

In [3]:
def ensure_required_columns(df: pd.DataFrame, name: str, required: set):
    """Validate that ``df`` has every column named in ``required``.

    Args:
        df: Frame whose columns are checked.
        name: Human-readable label for ``df``, used in the error message.
        required: Column names that must all be present.

    Raises:
        ValueError: If any required column is absent; the message lists the
            missing names in sorted order.
    """
    present = set(df.columns)
    absent = required - present
    if not absent:
        return
    raise ValueError(f"{name} is missing required columns: {sorted(absent)}")


def ensure_duplicated_pairs(df: pd.DataFrame, name: str) -> pd.DataFrame:
    """Keep only (dataset,id) keys with >=2 distinct texts.

    Args:
        df: Frame with at least ``dataset``, ``id``, ``text`` and ``gest`` columns.
        name: Label for ``df`` used in the "no duplicates" error message.

    Returns:
        A new frame restricted to the ``dataset``, ``id``, ``text`` and ``gest``
        columns, containing only rows whose (dataset, id) key occurs more than
        once AND carries more than one distinct ``text``. Row order follows
        ``df``; the index is a fresh 0..n-1 range (produced by the merge).

    Raises:
        ValueError: If no (dataset, id) key is repeated, or if every repeated
            key has only a single distinct text.
    """
    key_cols = ["dataset", "id"]

    has_partner = df.duplicated(key_cols, keep=False)
    if not has_partner.any():
        raise ValueError(f"{name} has no duplicated (dataset, id) pairs.")

    candidates = df.loc[has_partner, key_cols + ["text", "gest"]].copy()

    # Keys whose rows differ in text; only these can form a meaningful pair.
    text_counts = candidates.groupby(key_cols)["text"].nunique()
    multi_text_keys = (
        text_counts[text_counts > 1].rename("n_unique_texts").reset_index()
    )

    result = candidates.merge(multi_text_keys, on=key_cols, how="inner")
    result = result.drop(columns="n_unique_texts")

    if result.empty:
        raise ValueError(
            "All duplicated keys have identical text only; nothing to pair."
        )
    return result


def safe_sim(
    service: "SimilarityService",
    g1: "GESTGraph",
    g2: "GESTGraph",
    default: float = 0.0,
) -> float:
    """Return the normalized similarity of ``g1`` and ``g2``, or ``default`` on failure.

    Scoring is best-effort: any exception raised while scoring is caught so a
    single bad pair cannot abort a sweep. Unlike a silent fallback, each failure
    is reported with a ``RuntimeWarning``. Substituted scores pull downstream
    means toward ``default``, so they should be visible.

    Args:
        service: Object exposing ``graph_similarity_normalized(g1, g2)``.
        g1: First graph of the pair.
        g2: Second graph of the pair.
        default: Score returned when scoring raises. Defaults to ``0.0``,
            which matches the previous behavior.

    Returns:
        The service's score converted to ``float``, or ``default`` if scoring
        (or the float conversion) raised.
    """
    try:
        return float(service.graph_similarity_normalized(g1, g2))
    except Exception as exc:
        # Deliberate best-effort: keep going, but never silently.
        warnings.warn(
            f"graph_similarity_normalized failed ({type(exc).__name__}: {exc}); "
            f"using default score {default!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return default

In [4]:
# Load the raw synthetic GEST table and fail fast if its schema is incomplete.
synthetic = pd.read_csv(filepath_or_buffer=GEST_CSV_PATH)
ensure_required_columns(df=synthetic, name="synthetic", required=REQUIRED_COLUMNS)

In [5]:
dups = ensure_duplicated_pairs(synthetic, "synthetic").reset_index(names="row_id")

# Self-join on the key to enumerate every row combination sharing (dataset, id),
# then keep each unordered pair once (row_id_val1 < row_id_val2) and only when
# both the texts and the serialized GESTs differ.
pairs = dups.merge(dups, on=["dataset", "id"], how="inner", suffixes=("_val1", "_val2"))
pairs = pairs.loc[
    (pairs["row_id_val1"] < pairs["row_id_val2"])
    & (pairs["text_val1"] != pairs["text_val2"])
    & (pairs["gest_val1"] != pairs["gest_val2"])
].copy()

if len(pairs) == 0:
    raise ValueError("No valid pairs after filtering (need differing text & GEST).")

In [6]:
# Deserialize each first-variant GEST JSON string into a graph object.
tqdm.pandas(desc="Parsing GEST val1")
pairs = pairs.assign(
    g1=pairs["gest_val1"].progress_apply(
        lambda gest_json: GESTGraph(gest=GEST.model_validate_json(gest_json))
    )
)

Parsing GEST val1: 100%|██████████| 97/97 [00:00<00:00, 1776.28it/s]


In [7]:
# Deserialize each second-variant GEST JSON string into a graph object.
tqdm.pandas(desc="Parsing GEST val2")
pairs = pairs.assign(
    g2=pairs["gest_val2"].progress_apply(
        lambda gest_json: GESTGraph(gest=GEST.model_validate_json(gest_json))
    )
)

Parsing GEST val2: 100%|██████████| 97/97 [00:00<00:00, 1147.98it/s]


In [8]:
# One similarity service per (label, solver, use_edges) spec; every configuration
# uses GloVe-300 embeddings. Order here is the order configurations are scored.
configurations = {
    label: SimilarityService(
        engine=SimilarityEngine(
            solver_type=solver,
            embedding_type=EmbeddingType.GLOVE300,
            use_edges=with_edges,
        )
    )
    for label, solver, with_edges in (
        ("Spectral_GloVe300", SolverType.SPECTRAL, True),
        ("NGM_GloVe300", SolverType.NGM, True),
        ("Spectral_GloVe300_NoEdges", SolverType.SPECTRAL, False),
    )
}

In [9]:
# Score every (g1, g2) pair under each configuration. Per-pair scores are kept
# in `scores` (config name -> list of floats in `pairs` row order), so the
# distribution behind each mean can be inspected, not just the mean itself.
# NOTE: safe_sim substitutes a fallback score when a comparison raises, and
# those substituted scores are included in the means below.
scores = {}
means = {}
for name, service in configurations.items():
    config_scores = [
        safe_sim(service, g1, g2)
        for g1, g2 in tqdm(
            zip(pairs["g1"], pairs["g2"]), total=len(pairs), desc=f"{name}"
        )
    ]
    scores[name] = config_scores
    means[name] = (
        sum(config_scores) / len(config_scores) if config_scores else float("nan")
    )

Spectral_GloVe300: 100%|██████████| 97/97 [00:27<00:00,  3.56it/s]
NGM_GloVe300: 100%|██████████| 97/97 [00:23<00:00,  4.15it/s]
Spectral_GloVe300_NoEdges: 100%|██████████| 97/97 [00:02<00:00, 35.93it/s]


In [10]:
# Report every scored configuration in scoring order. Iterating `means` (rather
# than hard-coding names) keeps this cell in sync with `configurations`.
for name, mean_score in means.items():
    print(f"'{name}' mean: {mean_score:.6f}")

'Spectral_GloVe300' mean: 0.372793
'NGM_GloVe300' mean: 0.098853
'Spectral_GloVe300_NoEdges' mean: 0.782064
