# Demo: Algorithm vs Paper Alignment
This notebook constructs a deterministic toy example where a greedy nearest-neighbor (NN) matcher fails but the QUBO-based solver succeeds. We contrast two tiny knowledge graphs in quantum computing—one about canonical algorithms and another about the landmark papers that introduced them. Carefully crafted text attributes trick the NN into the wrong matches, while structural rewards let the QUBO recover the globally consistent alignment.

In [1]:
from pathlib import Path
import json
import sys
from typing import Any, Dict, List, Sequence, Tuple
import re

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import ipywidgets as ipw
from IPython.display import display
from rdflib import Graph, URIRef, Literal
from rdflib.namespace import RDF

# The notebook lives one level below the repository root; expose that root so
# the `src.*` imports below resolve. `repo_root` is kept as an alias of the
# same Path object.
PROJECT_ROOT = Path().resolve().parent
repo_root = PROJECT_ROOT
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)

from src.config import (
    OUTPUT_DIR,
    GAEA_MMD_WEIGHT,
    GAEA_STATS_WEIGHT,
    GAEA_MAX_ALIGN_SAMPLES,
)
from src.embedding.generate_embeddings import load_pyg_data_from_ttl, train_gaea_joint
from src.qubo_alignment import formulate
from src.evaluation.solvers import (
    _solve_with_simulated_annealing,
    _extract_alignments,
    Alignment,
)

# Compact pandas display so the small demo tables render without wrapping.
for _option, _value in (
    ("display.max_rows", 30),
    ("display.max_columns", 10),
    ("display.width", 120),
):
    pd.set_option(_option, _value)

# Output location for the serialized toy graphs (created on import).
DEMO_DIR = OUTPUT_DIR / "demo"
DEMO_DIR.mkdir(exist_ok=True, parents=True)
DEMO_WIKI_PATH, DEMO_ARXIV_PATH = (
    DEMO_DIR / file_name
    for file_name in ("kg_algorithms_demo.ttl", "kg_papers_demo.ttl")
)

# URI namespaces for entities of each graph and for relation predicates.
WIKI_BASE, ARXIV_BASE, REL_BASE = (
    f"http://demo.local/{segment}/" for segment in ("wiki", "arxiv", "rel")
)

# Schema terms shared by both graphs: entity class, label, and attribute text.
_SCHEMA_BASE = "http://demo.local/schema/"
ENTITY_CLASS_URI = URIRef(_SCHEMA_BASE + "Entity")
LABEL_PREDICATE = URIRef(_SCHEMA_BASE + "label")
ATTRIBUTE_PREDICATE = URIRef(_SCHEMA_BASE + "attribute")

## 1. Define Toy Algorithm/Paper Graphs
The JSON blocks below describe two hand-crafted knowledge graphs. KG1 captures four flagship quantum algorithms; KG2 lists the four seminal papers that introduced them, plus an extra "Qubit" resource node (five entities in total). The traps live in the text attributes: VQE carries the QAOA paper title and QAOA carries the VQE title, so any similarity model that greedily trusts text will swap their alignments.

In [2]:
# Entity counts enforced by parse_inputs(); the editors must keep exactly this
# many entries or the build/alignment run fails with a ValueError.
EXPECTED_WIKI_COUNT = 4
EXPECTED_ARXIV_COUNT = 5  # four papers plus the extra "Qubit" resource node

# KG1 ("wiki" side): entity label -> attribute metadata.
# _attribute_text() prefers "paper_title", then "name", so the two hybrid
# algorithms are represented by their (deliberately swapped) paper titles.
DEFAULT_WIKI_ENTITIES: Dict[str, Any] = {
    "Shor's Algorithm": {
        "name": "Shor's Algorithm",
        "category": "Quantum Algorithm",
        "year": 1994,
    },
    "Grover's Algorithm": {
        "name": "Grover's Algorithm",
        "category": "Quantum Algorithm",
        "year": 1996,
    },
    "VQE (Variational Quantum Eigensolver)": {
        # Trap: this is the QAOA paper title (Farhi et al.), not VQE's own.
        "paper_title": "A quantum approximate optimization algorithm",
        "category": "Hybrid Algorithm",
        "year": 2014,
    },
    "QAOA (Quantum Approximate Optimization Algorithm)": {
        # Trap: this is the VQE paper title (Peruzzo et al.), not QAOA's own.
        "paper_title": "A variational eigenvalue solver on a quantum processor",
        "category": "Hybrid Algorithm",
        "year": 2014,
    },
}

# KG2 ("arxiv" side): entity label -> attribute metadata. The two 2014 papers
# carry their true titles, which exactly match the swapped titles in KG1.
DEFAULT_ARXIV_ENTITIES: Dict[str, Any] = {
    "Shor, 1994": {
        "name": "Shor, 1994",
        "venue": "FOCS",
        "year": 1994,
    },
    "Grover, 1996": {
        "name": "Grover, 1996",
        "venue": "STOC",
        "year": 1996,
    },
    "Peruzzo et al., 2014": {
        "paper_title": "A variational eigenvalue solver on a quantum processor",
        "venue": "Nature Communications",
        "year": 2014,
    },
    "Farhi et al., 2014": {
        "paper_title": "A quantum approximate optimization algorithm",
        "venue": "arXiv",
        "year": 2014,
    },
    # Not a paper: a hub node linked to every paper via "appears_in" below.
    "Qubit": {
        "role": "Shared resource",
    },
}

# KG1 edges as (subject, relation, object). handle_alignment() only pairs edges
# whose normalized relation labels are equal, and "references" has no KG2
# counterpart, so it contributes no structural reward.
# NOTE(review): "published_same_year_as" is symmetric in both graphs
# (VQE<->QAOA and Peruzzo<->Farhi), so the structural reward appears to favour
# the correct and the swapped VQE/QAOA assignments equally -- confirm the
# structural term can actually break the text trap.
DEFAULT_WIKI_TRIPLES: List[Tuple[str, str, str]] = [
    ("Shor's Algorithm", "references", "Grover's Algorithm"),
    (
        "VQE (Variational Quantum Eigensolver)",
        "published_same_year_as",
        "QAOA (Quantum Approximate Optimization Algorithm)",
    ),
    (
        "QAOA (Quantum Approximate Optimization Algorithm)",
        "published_same_year_as",
        "VQE (Variational Quantum Eigensolver)",
    ),
]

# KG2 edges. "appears_in" has no KG1 counterpart; only the symmetric
# "published_same_year_as" pair can match KG1 edges.
DEFAULT_ARXIV_TRIPLES: List[Tuple[str, str, str]] = [
    ("Qubit", "appears_in", "Shor, 1994"),
    ("Qubit", "appears_in", "Grover, 1996"),
    ("Qubit", "appears_in", "Peruzzo et al., 2014"),
    ("Qubit", "appears_in", "Farhi et al., 2014"),
    (
        "Peruzzo et al., 2014",
        "published_same_year_as",
        "Farhi et al., 2014",
    ),
    (
        "Farhi et al., 2014",
        "published_same_year_as",
        "Peruzzo et al., 2014",
    ),
]

# Initial values for the sliders built in the next cell.
# similarity_threshold == 0.0 means "no threshold": handle_alignment() passes
# None to the QUBO formulation and the NN baseline in that case.
DEFAULT_HYPERPARAMS: Dict[str, Any] = {
    "epochs": 80,
    "learning_rate": 0.01,
    "node_weight": 1.0,
    "structural_weight": 2.5,
    "structural_match_bonus": 1.0,
    "wiki_penalty": 2.0,
    "arxiv_penalty": 2.0,
    "similarity_threshold": 0.0,
    "num_reads": 160,
    "mmd_weight": GAEA_MMD_WEIGHT,
    "stats_weight": GAEA_STATS_WEIGHT,
    # Capped at 512 for this tiny demo (the project default may be larger).
    "max_align_samples": min(GAEA_MAX_ALIGN_SAMPLES, 512),
}


def to_pretty_json(value: Any) -> str:
    """Serialize ``value`` as 2-space-indented, ASCII-only JSON for the editor widgets.

    Tuples are emitted as JSON arrays, so triples round-trip as 3-item lists.
    """
    encoder = json.JSONEncoder(indent=2, ensure_ascii=True)
    return encoder.encode(value)


## 2. Build Input Interface with ipywidgets
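The following cell builds editable JSON text areas for the entities and triples of both graphs, plus sliders for the training and solver hyperparameters, so you can experiment without touching the code above. Entity counts are validated on every run: keep exactly four algorithm entities and five paper-side entities.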

In [3]:
# Widget construction for toy data editing and solver tuning


def _json_editor(default: Any, placeholder: str) -> ipw.Textarea:
    """Full-width JSON text area pre-filled with ``default`` rendered as pretty JSON."""
    return ipw.Textarea(
        value=to_pretty_json(default),
        layout=ipw.Layout(width="100%", height="240px"),
        placeholder=placeholder,
    )


# Entity editors accept either a JSON object (label -> metadata) or a JSON array of labels.
wiki_entities_editor = _json_editor(
    DEFAULT_WIKI_ENTITIES,
    "[\n  \"Quantum Computer\",\n  \"Superconducting Qubit\"\n]",
)
arxiv_entities_editor = _json_editor(
    DEFAULT_ARXIV_ENTITIES,
    "[\n  \"Quantum processors\",\n  \"Superconducting circuits\"\n]",
)

# Triple editors expect a JSON array of [subject, relation, object] entries.
wiki_triples_editor = _json_editor(
    DEFAULT_WIKI_TRIPLES,
    "[\n  [\"Quantum Computer\", \"implements\", \"Quantum Algorithm\"]\n]",
)
arxiv_triples_editor = _json_editor(
    DEFAULT_ARXIV_TRIPLES,
    "[\n  [\"Quantum processors\", \"implements\", \"Search routines\"]\n]",
)

# --- GAEA joint-training controls (read by handle_alignment -> train_gaea_joint) ---
# All sliders use continuous_update=False so values commit on release only.
epochs_slider = ipw.IntSlider(
    value=int(DEFAULT_HYPERPARAMS["epochs"]),
    min=20,
    max=400,
    step=20,
    description="Epochs",
    continuous_update=False,
)
# Log-scale slider: exponent range -4..-1, i.e. LR between 1e-4 and 1e-1.
learning_rate_slider = ipw.FloatLogSlider(
    value=float(DEFAULT_HYPERPARAMS["learning_rate"]),
    base=10,
    min=-4,
    max=-1,
    step=0.1,
    description="LR",
    continuous_update=False,
)
# --- QUBO objective weights (passed to formulate.formulate) ---
node_weight_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["node_weight"]),
    min=0.0,
    max=5.0,
    step=0.1,
    description="H_node",
    continuous_update=False,
)
structural_weight_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["structural_weight"]),
    min=0.0,
    max=5.0,
    step=0.1,
    description="H_struct",
    continuous_update=False,
)
# Weight assigned to each pair of edges with matching relation labels.
structural_match_bonus_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["structural_match_bonus"]),
    min=0.0,
    max=5.0,
    step=0.1,
    description="Match bonus",
    continuous_update=False,
)
# Penalty weights passed as wiki_penalty / arxiv_penalty to the QUBO formulation.
wiki_penalty_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["wiki_penalty"]),
    min=0.1,
    max=5.0,
    step=0.1,
    description="Wiki pen.",
    continuous_update=False,
)
arxiv_penalty_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["arxiv_penalty"]),
    min=0.1,
    max=5.0,
    step=0.1,
    description="Arxiv pen.",
    continuous_update=False,
)
# 0.0 disables the threshold (handle_alignment converts it to None); shared by QUBO and NN.
similarity_threshold_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["similarity_threshold"]),
    min=0.0,
    max=0.95,
    step=0.01,
    description="Sim thresh",
    continuous_update=False,
)
# Number of simulated-annealing reads.
num_reads_slider = ipw.IntSlider(
    value=int(DEFAULT_HYPERPARAMS["num_reads"]),
    min=10,
    max=1000,
    step=10,
    description="Num reads",
    continuous_update=False,
)
# --- GAEA loss weights and sample cap (also passed to train_gaea_joint) ---
mmd_weight_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["mmd_weight"]),
    min=0.0,
    max=2.0,
    step=0.05,
    description="MMD λ",
    continuous_update=False,
)
stats_weight_slider = ipw.FloatSlider(
    value=float(DEFAULT_HYPERPARAMS["stats_weight"]),
    min=0.0,
    max=1.0,
    step=0.05,
    description="Stats λ",
    continuous_update=False,
)
# NOTE(review): min is 128, so a smaller default (e.g. a project-wide
# GAEA_MAX_ALIGN_SAMPLES below 128) would presumably be clamped by the widget.
max_samples_slider = ipw.IntSlider(
    value=int(DEFAULT_HYPERPARAMS["max_align_samples"]),
    min=128,
    max=4096,
    step=128,
    description="MMD samples",
    continuous_update=False,
)

# Action buttons: one only (re)writes the TTL files, the other runs the full alignment.
build_graphs_button, run_alignment_button = (
    ipw.Button(description=caption, button_style=style, icon=icon, tooltip=hint)
    for caption, style, icon, hint in (
        ("Write TTL", "info", "upload", "Serialize the current toy graphs to output/demo"),
        ("Run alignment", "success", "play", "Train the joint GAEA model and compare QUBO vs NN"),
    )
)


def _labelled_column(sections, padding: str) -> ipw.VBox:
    """Half-width column of (bold title, editor) pairs, stacked in the given order."""
    children = []
    for title, editor in sections:
        children.extend([ipw.HTML(f"<b>{title}</b>"), editor])
    return ipw.VBox(children, layout=ipw.Layout(width="50%", padding=padding))


def _slider_row(sliders) -> ipw.HBox:
    """Full-width row that spreads the given sliders evenly."""
    return ipw.HBox(
        list(sliders),
        layout=ipw.Layout(width="100%", justify_content="space-between"),
    )


# Left column: KG1 ("wiki") editors; right column: KG2 ("arxiv") editors.
inputs_column = _labelled_column(
    [("Wiki entities", wiki_entities_editor), ("Wiki triples", wiki_triples_editor)],
    "0 8px 0 0",
)
outputs_column = _labelled_column(
    [("ArXiv entities", arxiv_entities_editor), ("ArXiv triples", arxiv_triples_editor)],
    "0 0 0 8px",
)
text_input_panel = ipw.HBox([inputs_column, outputs_column], layout=ipw.Layout(width="100%"))

slider_row_one = _slider_row(
    [epochs_slider, learning_rate_slider, mmd_weight_slider, stats_weight_slider]
)
slider_row_two = _slider_row(
    [
        node_weight_slider,
        structural_weight_slider,
        structural_match_bonus_slider,
        similarity_threshold_slider,
    ]
)
slider_row_three = _slider_row(
    [wiki_penalty_slider, arxiv_penalty_slider, num_reads_slider, max_samples_slider]
)

buttons_panel = ipw.HBox(
    [build_graphs_button, run_alignment_button],
    layout=ipw.Layout(justify_content="flex-start", gap="12px"),
)

_panel_children = [
    text_input_panel,
    ipw.HTML("<b>Training and solver controls</b>"),
    slider_row_one,
    slider_row_two,
    slider_row_three,
    buttons_panel,
]
widget_panel = ipw.VBox(_panel_children, layout=ipw.Layout(width="100%", gap="10px"))

## 3. Bind Events and Display Outputs
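The final cell wires up the widget callbacks, writes the TTL files, trains the joint GAEA model, and renders alignment diagnostics so you can compare the QUBO and nearest-neighbor alignments on your toy case. Both methods score candidate pairs with the same Jaccard attribute similarity. The QUBO additionally rewards pairs of edges whose relation labels match. The GAEA embedding similarity is displayed only as a diagnostic.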

In [4]:
# Helper logic to bind widgets, serialize graphs, and run alignments

# Both result panes share the same bordered look (separate Layout instances each).
_BORDERED_BOX = dict(border="1px solid #dcdcdc", padding="8px")
output_area = ipw.Output(layout=ipw.Layout(**_BORDERED_BOX))
alignment_area = ipw.Output(layout=ipw.Layout(**_BORDERED_BOX))

# Splits lower-cased attribute text into alphanumeric tokens (see _token_set).
TOKEN_SPLIT_RE = re.compile(r"[^a-z0-9]+")
# Latest parsed graphs/metadata, refreshed in place by prepare_state().
current_state: Dict[str, Any] = {}


def _slug(label: str) -> str:
    cleaned = "".join(ch if ch.isalnum() or ch in {"_", "-"} else "_" for ch in label.strip())
    parts = [part for part in cleaned.split("_") if part]
    return "_".join(parts) or "entity"


def _normalize_relation(label: str) -> str:
    return " ".join(label.lower().split())


def _parse_meta_value(raw: Any) -> Any:
    if raw is None:
        return None
    return raw


def _attribute_text(label: str, meta_value: Any) -> str:
    if isinstance(meta_value, dict):
        if "paper_title" in meta_value:
            return str(meta_value["paper_title"])
        if "name" in meta_value:
            return str(meta_value["name"])
        return " ".join(str(v) for v in meta_value.values()) or label
    if isinstance(meta_value, (list, tuple, set)):
        return " ".join(str(v) for v in meta_value) or label
    if meta_value is None:
        return label
    return str(meta_value)


def _token_set(text: str) -> set[str]:
    """Lower-case ``text``, split on TOKEN_SPLIT_RE, and keep tokens longer than one character."""
    return {piece for piece in TOKEN_SPLIT_RE.split(text.lower()) if len(piece) > 1}


def _jaccard(set_a: set[str], set_b: set[str]) -> float:
    if not set_a and not set_b:
        return 0.0
    intersection = len(set_a & set_b)
    if intersection == 0:
        return 0.0
    union = len(set_a | set_b)
    if union == 0:
        return 0.0
    return intersection / union


def _label_for_uri(payload: Dict[str, Any], uri: URIRef) -> str:
    """Human-readable label for ``uri``.

    Uses the payload's ``uri_to_label`` mapping when it has a non-empty entry;
    otherwise falls back to the URI fragment (after ``#``) or last path segment.
    """
    known = payload["uri_to_label"].get(uri)
    if known:
        return known
    text = str(uri)
    for separator in ("#", "/"):
        if separator in text:
            return text.rsplit(separator, 1)[-1]
    return text


def _ensure_tuple(entry: Any) -> Tuple[str, str, str]:
    if isinstance(entry, dict):
        subj = entry.get("subject") or entry.get("subj")
        rel = entry.get("relation") or entry.get("predicate")
        obj = entry.get("object") or entry.get("obj")
        values = (subj, rel, obj)
    elif isinstance(entry, (list, tuple)) and len(entry) == 3:
        values = entry
    else:
        raise ValueError("Triple entries must be 3-item lists or dicts with subject, relation, and object keys.")
    subj, rel, obj = [str(item).strip() for item in values]
    if not subj or not rel or not obj:
        raise ValueError("Triple entries cannot contain empty strings.")
    return subj, rel, obj


def _parse_entity_block(raw: Any) -> Tuple[List[str], Dict[str, Any]]:
    """Split an entity editor payload into ordered labels and per-label metadata.

    A JSON object maps label -> metadata (blank labels skipped; ``None``/``""``
    metadata not stored). A JSON array is a plain list of labels with no metadata.

    Raises:
        ValueError: if ``raw`` is neither a dict nor a list.
    """
    if not isinstance(raw, (dict, list)):
        raise ValueError("Entities must be provided as a JSON array or object.")

    labels: List[str] = []
    meta: Dict[str, Any] = {}

    if isinstance(raw, list):
        stripped = (str(item).strip() for item in raw)
        labels.extend(name for name in stripped if name)
        return labels, meta

    for key, value in raw.items():
        name = str(key).strip()
        if not name:
            continue
        labels.append(name)
        parsed = _parse_meta_value(value)
        if parsed is None or parsed == "":
            continue
        meta[name] = parsed
    return labels, meta


def _attribute_similarity(
    wiki_labels: Sequence[str],
    wiki_meta: Dict[str, Any],
    arxiv_labels: Sequence[str],
    arxiv_meta: Dict[str, Any],
) -> torch.Tensor:
    """Pairwise Jaccard similarity of attribute-text token sets.

    Returns a float32 tensor of shape ``(len(wiki_labels), len(arxiv_labels))``;
    entry ``[i, j]`` compares ``wiki_labels[i]`` with ``arxiv_labels[j]``.
    """
    def tokens_for(labels: Sequence[str], meta: Dict[str, Any]) -> List[set[str]]:
        return [_token_set(_attribute_text(name, meta.get(name))) for name in labels]

    wiki_tokens = tokens_for(wiki_labels, wiki_meta)
    arxiv_tokens = tokens_for(arxiv_labels, arxiv_meta)

    scores = torch.zeros((len(wiki_tokens), len(arxiv_tokens)), dtype=torch.float32)
    for row, left in enumerate(wiki_tokens):
        scores[row] = torch.tensor(
            [_jaccard(left, right) for right in arxiv_tokens], dtype=torch.float32
        )
    return scores


def build_graph_payload(
    name: str,
    entity_labels: Sequence[str],
    triples: Sequence[Tuple[str, str, str]],
    *,
    base_uri: str,
    meta: Dict[str, Any] | None = None,
):
    """Build an in-memory RDF graph plus tabular summaries for one toy KG.

    Each label becomes ``base_uri + _slug(label)`` typed as ENTITY_CLASS_URI, with
    its label and attribute text attached as literals. Each triple becomes an
    edge whose predicate is ``REL_BASE + _slug(relation)``. Nothing is written
    to disk here.

    Args:
        name: Graph name used in error messages (e.g. "Algorithms").
        entity_labels: Entity labels, in order.
        triples: ``(subject, relation, object)`` label triples.
        base_uri: Namespace prefix for entity URIs.
        meta: Optional label -> metadata mapping used for attribute text.

    Returns:
        Dict with ``graph``, ``entity_df`` (sorted by label), ``edge_df``,
        ``label_to_uri``, ``uri_to_label``, ``triples``, ``meta`` and
        ``attribute_texts``.

    Raises:
        ValueError: if a label is duplicated or two labels slug to the same URI
            (which would otherwise silently merge nodes), or if a triple uses a
            label that is not among ``entity_labels``.
    """
    meta = meta or {}
    graph = Graph()
    label_to_uri: Dict[str, URIRef] = {}
    uri_to_label: Dict[URIRef, str] = {}
    attribute_texts: Dict[str, str] = {}
    entity_rows: List[Dict[str, Any]] = []

    for label in entity_labels:
        uri = URIRef(base_uri + _slug(label))
        clash = uri_to_label.get(uri)
        if clash is not None:
            if clash == label:
                raise ValueError(f"{name} entity '{label}' is listed more than once.")
            raise ValueError(
                f"{name} entities '{clash}' and '{label}' both map to URI {uri}; rename one of them."
            )
        label_to_uri[label] = uri
        uri_to_label[uri] = label
        graph.add((uri, RDF.type, ENTITY_CLASS_URI))
        graph.add((uri, LABEL_PREDICATE, Literal(label)))
        meta_value = meta.get(label)
        attr_text = _attribute_text(label, meta_value)
        attribute_texts[label] = attr_text
        if attr_text:
            graph.add((uri, ATTRIBUTE_PREDICATE, Literal(attr_text)))
        # Structured metadata is shown as compact JSON in the summary table.
        display_meta = meta_value
        if isinstance(meta_value, (dict, list)):
            display_meta = json.dumps(meta_value, ensure_ascii=True)
        entity_rows.append(
            {
                "label": label,
                "meta": display_meta,
                "attribute_text": attr_text,
                "uri": str(uri),
            }
        )

    edge_rows: List[Dict[str, str]] = []
    for subj, rel, obj in triples:
        if subj not in label_to_uri:
            raise ValueError(f"{name} triple uses unknown subject '{subj}'.")
        if obj not in label_to_uri:
            raise ValueError(f"{name} triple uses unknown object '{obj}'.")
        subj_uri = label_to_uri[subj]
        obj_uri = label_to_uri[obj]
        rel_uri = URIRef(REL_BASE + _slug(rel))
        graph.add((subj_uri, rel_uri, obj_uri))
        edge_rows.append(
            {
                "subject": subj,
                "relation": rel,
                "object": obj,
                "relation_uri": str(rel_uri),
            }
        )

    # Explicit columns keep both frames well-formed even when empty
    # (sort_values("label") would otherwise raise KeyError on an empty frame).
    entity_df = (
        pd.DataFrame(entity_rows, columns=["label", "meta", "attribute_text", "uri"])
        .sort_values("label")
        .reset_index(drop=True)
    )
    edge_df = pd.DataFrame(edge_rows, columns=["subject", "relation", "object", "relation_uri"])

    return {
        "graph": graph,
        "entity_df": entity_df,
        "edge_df": edge_df,
        "label_to_uri": label_to_uri,
        "uri_to_label": uri_to_label,
        "triples": list(triples),
        "meta": meta,
        "attribute_texts": attribute_texts,
    }


def parse_inputs() -> Tuple[
    List[str],
    Dict[str, Any],
    List[str],
    Dict[str, Any],
    List[Tuple[str, str, str]],
    List[Tuple[str, str, str]],
]:
    """Read and validate the four JSON editors.

    Returns:
        ``(wiki_entities, wiki_meta, arxiv_entities, arxiv_meta, wiki_triples, arxiv_triples)``.

    Raises:
        ValueError: on malformed JSON, wrong entity counts, or malformed triples.
    """
    editors = (
        wiki_entities_editor,
        arxiv_entities_editor,
        wiki_triples_editor,
        arxiv_triples_editor,
    )
    try:
        wiki_entities_raw, arxiv_entities_raw, wiki_triples_raw, arxiv_triples_raw = (
            json.loads(editor.value) for editor in editors
        )
    except json.JSONDecodeError as exc:
        raise ValueError(f"JSON decode error: {exc}") from exc

    wiki_entities, wiki_meta = _parse_entity_block(wiki_entities_raw)
    arxiv_entities, arxiv_meta = _parse_entity_block(arxiv_entities_raw)

    # The demo is sized for a fixed number of entities on each side.
    for found, expected, noun in (
        (wiki_entities, EXPECTED_WIKI_COUNT, "Wiki"),
        (arxiv_entities, EXPECTED_ARXIV_COUNT, "Paper"),
    ):
        if len(found) != expected:
            raise ValueError(f"{noun} entity count must be {expected}, found {len(found)}.")

    if not all(isinstance(block, list) for block in (wiki_triples_raw, arxiv_triples_raw)):
        raise ValueError("Triple sections must be JSON arrays of 3-item entries.")

    wiki_triples = list(map(_ensure_tuple, wiki_triples_raw))
    arxiv_triples = list(map(_ensure_tuple, arxiv_triples_raw))

    return wiki_entities, wiki_meta, arxiv_entities, arxiv_meta, wiki_triples, arxiv_triples


def prepare_state(verbose: bool = False) -> Dict[str, Any]:
    """Parse the editors, build both graphs, write them as Turtle, and refresh ``current_state``.

    Args:
        verbose: When True, clear ``output_area`` and show the output paths plus
            entity/edge tables for both graphs.

    Returns:
        The module-level ``current_state`` dict, updated in place.
    """
    (
        wiki_entities,
        wiki_meta,
        arxiv_entities,
        arxiv_meta,
        wiki_triples,
        arxiv_triples,
    ) = parse_inputs()

    wiki_payload = build_graph_payload(
        "Algorithms", wiki_entities, wiki_triples, base_uri=WIKI_BASE, meta=wiki_meta
    )
    arxiv_payload = build_graph_payload(
        "Papers", arxiv_entities, arxiv_triples, base_uri=ARXIV_BASE, meta=arxiv_meta
    )

    for payload, destination in ((wiki_payload, DEMO_WIKI_PATH), (arxiv_payload, DEMO_ARXIV_PATH)):
        payload["graph"].serialize(str(destination), format="turtle")

    current_state.update(
        wiki=wiki_payload,
        arxiv=arxiv_payload,
        wiki_entities=wiki_entities,
        arxiv_entities=arxiv_entities,
        wiki_meta=wiki_meta,
        arxiv_meta=arxiv_meta,
    )

    if not verbose:
        return current_state

    tagged = (("Algorithms", wiki_payload), ("Papers", arxiv_payload))
    output_area.clear_output()
    with output_area:
        print("Toy graphs written to:")
        print(f"  Algorithms -> {DEMO_WIKI_PATH}")
        print(f"  Papers     -> {DEMO_ARXIV_PATH}")
        for tag, payload in tagged:
            display(payload["entity_df"].assign(graph=tag))
        for tag, payload in tagged:
            if not payload["edge_df"].empty:
                display(payload["edge_df"].assign(graph=tag))
    return current_state


def _edges_with_indices(payload: Dict[str, Any], node_map: Dict[URIRef, int]):
    """Translate a payload's label triples into index-based edge dicts.

    Each dict has ``src``/``dst`` (node indices from ``node_map``), the raw
    ``relation`` text, and ``label`` (the normalized relation used for matching).
    """
    label_to_uri = payload["label_to_uri"]
    indexed = []
    for subject, relation, target in payload["triples"]:
        src_uri, dst_uri = label_to_uri[subject], label_to_uri[target]
        indexed.append(
            dict(
                src=node_map[src_uri],
                dst=node_map[dst_uri],
                relation=relation,
                label=_normalize_relation(relation),
            )
        )
    return indexed


def run_nearest_neighbor(
    similarity_tensor: torch.Tensor,
    wiki_nodes: Sequence[URIRef],
    arxiv_nodes: Sequence[URIRef],
    threshold: float | None,
) -> List[Alignment]:
    """Greedy one-to-one matching baseline.

    Candidate pairs are visited in descending score order (equal scores keep
    row-major order, since the sort is stable). A pair is accepted when neither
    endpoint is already matched. Pairs scoring below ``threshold`` are ignored
    when a threshold is given.
    """
    rows, cols = similarity_tensor.shape
    scored = [
        (float(similarity_tensor[r, c]), r, c)
        for r in range(rows)
        for c in range(cols)
    ]
    if threshold is not None:
        scored = [entry for entry in scored if not entry[0] < threshold]
    scored.sort(key=lambda entry: entry[0], reverse=True)

    taken_rows: set[int] = set()
    taken_cols: set[int] = set()
    matches: List[Alignment] = []
    for score, r, c in scored:
        if r in taken_rows or c in taken_cols:
            continue
        taken_rows.add(r)
        taken_cols.add(c)
        matches.append(
            Alignment(
                wiki_index=r,
                arxiv_index=c,
                wiki_uri=wiki_nodes[r],
                arxiv_uri=arxiv_nodes[c],
                similarity=score,
            )
        )
    return matches


def _pad_and_normalize_features(data: torch.Tensor, target_dim: int) -> torch.Tensor:
    if data.size(1) < target_dim:
        pad_width = (0, target_dim - data.size(1))
        data = F.pad(data, pad_width)
    return F.normalize(data, dim=1)


def handle_build(_=None):
    """Button callback: rebuild and write the TTL files, showing a summary in ``output_area``.

    On failure the error message replaces the summary and ``alignment_area`` is
    cleared, since any previous alignment no longer matches the inputs.
    """
    try:
        prepare_state(verbose=True)
    except Exception as exc:
        message = f"Error while building graphs: {exc}"
        output_area.clear_output()
        with output_area:
            print(message)
        alignment_area.clear_output()


def handle_alignment(_=None, *, seed: int | None = 0):
    """Run the full demo pipeline and render QUBO vs. nearest-neighbour diagnostics.

    Steps:
    1. Rebuild and re-serialize both toy graphs from the editors (``prepare_state``).
    2. Load them back as PyG data and train the joint GAEA model. Its embeddings
       are shown only as a diagnostic.
    3. Score candidate pairs with the Jaccard attribute similarity and collect
       structural rewards for edge pairs whose normalized relation labels match.
    4. Solve the QUBO with simulated annealing and run the greedy NN baseline on
       the same similarity matrix.
    5. Display summaries in ``alignment_area``.

    Args:
        _: The clicked button passed by ``Button.on_click``; unused.
        seed: Seed applied to torch/numpy before loading/training and forwarded to
            the simulated-annealing solver, so identical inputs reproduce identical
            results. ``None`` leaves everything unseeded (the previous behaviour).

    Any exception is caught and reported, together with its traceback, inside
    ``alignment_area`` instead of propagating out of the widget callback.
    """
    import traceback  # local: only needed for error reporting in this callback

    def _ordered_nodes(node_map: Dict[URIRef, int]) -> List[URIRef]:
        # Invert {uri: index} into a list where position == node index.
        nodes: List[URIRef] = [None] * len(node_map)
        for uri, idx in node_map.items():
            nodes[idx] = uri
        return nodes

    try:
        if seed is not None:
            # Seed before loading/training in case either draws random numbers.
            torch.manual_seed(seed)
            np.random.seed(seed)

        state = prepare_state(verbose=False)
        wiki_payload = state["wiki"]
        arxiv_payload = state["arxiv"]

        wiki_data, wiki_map = load_pyg_data_from_ttl(
            DEMO_WIKI_PATH,
            tokenizer=None,
            model=None,
            use_scibert_features=False,
        )
        arxiv_data, arxiv_map = load_pyg_data_from_ttl(
            DEMO_ARXIV_PATH,
            tokenizer=None,
            model=None,
            use_scibert_features=False,
        )

        # Both graphs must share one feature width for the joint encoder.
        target_dim = max(wiki_data.num_node_features, arxiv_data.num_node_features)
        wiki_data.x = _pad_and_normalize_features(wiki_data.x, target_dim)
        arxiv_data.x = _pad_and_normalize_features(arxiv_data.x, target_dim)

        joint_model = train_gaea_joint(
            wiki_data,
            arxiv_data,
            in_channels=target_dim,
            hidden_channels=target_dim,
            out_channels=target_dim,
            epochs=int(epochs_slider.value),
            lr=float(learning_rate_slider.value),
            mmd_weight=float(mmd_weight_slider.value),
            stats_weight=float(stats_weight_slider.value),
            max_samples=int(max_samples_slider.value),
        )

        joint_model.eval()
        with torch.no_grad():
            wiki_embed = joint_model.encode(wiki_data.x, wiki_data.edge_index)
            arxiv_embed = joint_model.encode(arxiv_data.x, arxiv_data.edge_index)

        wiki_embed = F.normalize(wiki_embed.cpu(), dim=1)
        arxiv_embed = F.normalize(arxiv_embed.cpu(), dim=1)

        # Cosine similarity of the learned embeddings (diagnostic only).
        embedding_similarity = torch.matmul(wiki_embed, arxiv_embed.T)

        wiki_nodes = _ordered_nodes(wiki_map)
        arxiv_nodes = _ordered_nodes(arxiv_map)

        wiki_labels = [_label_for_uri(wiki_payload, uri) for uri in wiki_nodes]
        arxiv_labels = [_label_for_uri(arxiv_payload, uri) for uri in arxiv_nodes]

        # The node term of both methods uses the attribute (Jaccard) similarity.
        attribute_similarity = _attribute_similarity(
            wiki_labels,
            state["wiki_meta"],
            arxiv_labels,
            state["arxiv_meta"],
        )
        similarity = attribute_similarity

        wiki_lookup = {uri: idx for idx, uri in enumerate(wiki_nodes)}
        arxiv_lookup = {uri: idx for idx, uri in enumerate(arxiv_nodes)}

        wiki_edges = _edges_with_indices(wiki_payload, wiki_map)
        arxiv_edges = _edges_with_indices(arxiv_payload, arxiv_map)

        # Reward every (wiki edge, paper edge) pair with the same relation label.
        match_bonus = float(structural_match_bonus_slider.value)
        structural_weights: Dict[Tuple[int, int, int, int], float] = {}
        structural_pairs: List[Dict[str, Any]] = []
        for w_edge in wiki_edges:
            for a_edge in arxiv_edges:
                if w_edge["label"] != a_edge["label"]:
                    continue
                structural_weights[(w_edge["src"], w_edge["dst"], a_edge["src"], a_edge["dst"])] = match_bonus
                structural_pairs.append(
                    {
                        "relation": w_edge["relation"],
                        "wiki_edge": f"{wiki_labels[w_edge['src']]} → {wiki_labels[w_edge['dst']]}",
                        "paper_edge": f"{arxiv_labels[a_edge['src']]} → {arxiv_labels[a_edge['dst']]}",
                        "weight": match_bonus,
                    }
                )

        structural_info = {
            "weights": structural_weights,
            "wiki_edges": wiki_edges,
            "arxiv_edges": arxiv_edges,
        }

        node_info = {
            "wiki_nodes": wiki_nodes,
            "arxiv_nodes": arxiv_nodes,
            "wiki_lookup": wiki_lookup,
            "arxiv_lookup": arxiv_lookup,
            "similarity": similarity,
            "wiki_labels": wiki_labels,
            "arxiv_labels": arxiv_labels,
        }

        # A slider value of 0 means "no threshold".
        threshold_value = float(similarity_threshold_slider.value)
        threshold = threshold_value if threshold_value > 0 else None

        qubo_result = formulate.formulate(
            node_info,
            structural_info,
            node_weight=float(node_weight_slider.value),
            structural_weight=float(structural_weight_slider.value),
            wiki_penalty=float(wiki_penalty_slider.value),
            arxiv_penalty=float(arxiv_penalty_slider.value),
            similarity_threshold=threshold,
        )

        sampleset = _solve_with_simulated_annealing(
            qubo_result["Q"],
            num_reads=int(num_reads_slider.value),
            beta_range=None,
            seed=seed,
            sampler=None,
        )
        record = sampleset.first
        qubo_alignments = _extract_alignments(record.sample, qubo_result["reverse_index"], node_info)
        qubo_energy = float(record.energy)

        nn_alignments = run_nearest_neighbor(similarity, wiki_nodes, arxiv_nodes, threshold)

        def _alignments_to_frame(items: List[Alignment], method: str) -> pd.DataFrame:
            rows = []
            for item in items:
                rows.append(
                    {
                        "method": method,
                        "wiki_entity": wiki_labels[item.wiki_index],
                        "paper_entity": arxiv_labels[item.arxiv_index],
                        "similarity": float(item.similarity),
                    }
                )
            return pd.DataFrame(rows)

        qubo_df = _alignments_to_frame(qubo_alignments, "QUBO")
        nn_df = _alignments_to_frame(nn_alignments, "Nearest neighbor")
        combined_df = (
            pd.concat([qubo_df, nn_df], ignore_index=True)
            if not qubo_df.empty or not nn_df.empty
            else pd.DataFrame(columns=["method", "wiki_entity", "paper_entity", "similarity"])
        )

        summary_rows = []
        for label, items, energy in [
            ("QUBO", qubo_alignments, qubo_energy),
            ("Nearest neighbor", nn_alignments, float("nan")),
        ]:
            sims = [float(item.similarity) for item in items]
            summary_rows.append(
                {
                    "method": label,
                    "matches": len(items),
                    "mean_similarity": float(np.mean(sims)) if sims else np.nan,
                    "min_similarity": float(np.min(sims)) if sims else np.nan,
                    "max_similarity": float(np.max(sims)) if sims else np.nan,
                    "energy": energy,
                }
            )
        summary_df = pd.DataFrame(summary_rows)

        similarity_df = pd.DataFrame(similarity.numpy(), index=wiki_labels, columns=arxiv_labels)
        structural_df = pd.DataFrame(structural_pairs)
        embedding_df = pd.DataFrame(embedding_similarity.numpy(), index=wiki_labels, columns=arxiv_labels)

        alignment_area.clear_output()
        with alignment_area:
            print("Joint GAEA training complete (embeddings shown for reference).")
            print(f"QUBO best energy: {qubo_energy:.4f}")
            display(summary_df)
            if not combined_df.empty:
                display(
                    combined_df
                    .sort_values(["method", "similarity"], ascending=[True, False])
                    .reset_index(drop=True)
                )
            print("Attribute-driven similarity matrix (Jaccard overlaps of text features):")
            display(similarity_df)
            print("Embedding cosine similarity (optional diagnostic):")
            display(embedding_df)
            if not structural_df.empty:
                print("Structural matches contributing to H_structure:")
                display(structural_df)
            else:
                print("No structural matches discovered; adjust relation labels if needed.")
    except Exception as exc:
        alignment_area.clear_output()
        with alignment_area:
            print(f"Alignment run failed: {exc}")
            # Full traceback so failures inside training/solver calls are debuggable.
            print(traceback.format_exc())


# Wire each button to its callback.
for _button, _handler in (
    (build_graphs_button, handle_build),
    (run_alignment_button, handle_alignment),
):
    _button.on_click(_handler)

_ui_children = [
    widget_panel,
    ipw.HTML("<b>Graph summary</b>"),
    output_area,
    ipw.HTML("<b>Alignment results</b>"),
    alignment_area,
]
ui_container = ipw.VBox(_ui_children, layout=ipw.Layout(width="100%", gap="12px"))

display(ui_container)

# Run once to populate default TTL files
handle_build()


VBox(children=(VBox(children=(HBox(children=(VBox(children=(HTML(value='<b>Wiki entities</b>'), Textarea(value…