<a href="https://colab.research.google.com/github/IsaacFigNewton/SMIED/blob/adding-semantic-decomposition/BeamSemantic_Decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Config

## Install dependencies

In [1]:
!pip install gensim

Collecting gensim
  Downloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
Downloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.6/26.6 MB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
[2K   [90m━━━━━━━━━━━

## Import, config stuff

In [1]:
"""
Semantic decomposition of ("cat", "eats", "mouse") using WordNet + spaCy + depth-limited GBFS.
- Uses spaCy to parse verb synset glosses and detect subject/object dependencies.
- If both subject and object tokens are present, branches directly toward original triple synsets.
- Otherwise falls back to WordNet relations.
"""
import nltk
import spacy
from nltk.corpus import wordnet as wn
from heapq import heappush, heappop
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import gensim.downloader as api
from collections import deque
import heapq
import itertools
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple
import networkx as nx

In [2]:
nltk.download('wordnet')

# Load spaCy English model for dependency parsing
nlp = spacy.load("en_core_web_sm")

# Download required NLTK data (run once)
nltk.download('wordnet')
nltk.download('omw-1.4')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [3]:
embedding_model = api.load('word2vec-google-news-300')



In [4]:
# Type aliases
SynsetName = str  # e.g., "dog.n.01"
Path = List[SynsetName]
BeamElement = Tuple[Tuple[SynsetName, str], Tuple[SynsetName, str], float]
GetNewBeamsFn = Callable[[nx.DiGraph, SynsetName, SynsetName], List[BeamElement]]
TopKBranchFn = Callable[[List[List], object, int], List[BeamElement]]

# Helpers

In [5]:
def wn_to_nx():
    # Initialize directed graph
    G = nx.DiGraph()

    synset_rels = {
        # holonyms
        "part_holonyms": lambda x: x.part_holonyms(),
        "substance_holonyms": lambda x: x.substance_holonyms(),
        "member_holonyms": lambda x: x.member_holonyms(),

        # meronyms
        "part_meronyms": lambda x: x.part_meronyms(),
        "substance_meronyms": lambda x: x.substance_meronyms(),
        "member_meronyms": lambda x: x.member_meronyms(),

        # other
        "hypernyms": lambda x: x.hypernyms(),
        "hyponyms": lambda x: x.hyponyms(),
        "entailments": lambda x: x.entailments(),
        "causes": lambda x: x.causes(),
        "also_sees": lambda x: x.also_sees(),
        "verb_groups": lambda x: x.verb_groups(),
    }

    # add nodes (synsets) and all their edges (lexical relations) to the nx graph
    for synset in wn.all_synsets():
        for rel_name, rel_func in synset_rels.items():
            for target in rel_func(synset):
                G.add_edge(
                    synset.name(),
                    target.name(),
                    relation = rel_name[:-1]
                )
    return G

In [6]:
def get_all_neighbors(synset: wn.synset) -> List[wn.synset]:
    """Get all neighbors of a synset based on its POS."""
    neighbors = []

    # Add hypernyms and hyponyms
    neighbors.extend(synset.hypernyms())
    neighbors.extend(synset.hyponyms())

    # Add POS-specific neighbors
    if synset.pos() == 'n':
        neighbors.extend(get_noun_neighbors(synset))
    else:
        neighbors.extend(get_verb_neighbors(synset))

    return neighbors


def get_noun_neighbors(syn: wn.synset):
    """Get neighbors for a noun synset."""
    nbrs = set()
    nbrs.update(syn.part_meronyms())
    nbrs.update(syn.substance_meronyms())
    nbrs.update(syn.member_meronyms())
    nbrs.update(syn.part_holonyms())
    nbrs.update(syn.substance_holonyms())
    nbrs.update(syn.member_holonyms())
    return list(nbrs)


def get_verb_neighbors(syn: wn.synset):
    """Get neighbors for a verb synset."""
    nbrs = set()
    nbrs.update(syn.entailments())
    nbrs.update(syn.causes())
    nbrs.update(syn.also_sees())
    nbrs.update(syn.verb_groups())
    return list(nbrs)

# Embedding Helpers

In [7]:
# ============================================================================
# Embedding-based Helper Functions
# ============================================================================

def get_synset_embedding_centroid(synset, model) -> np.ndarray:
    """
    Given a wn.synset, compute centroid (mean) of embeddings for lemmas.
    Returns empty np.array if nothing found.
    """
    try:
        lemmas = [lemma.name().lower().replace("_", " ") for lemma in synset.lemmas()]
        embeddings = []
        for lemma in lemmas:
            if lemma in model:
                embeddings.append(np.asarray(model[lemma], dtype=float))
            elif lemma.replace(" ", "_") in model:
                embeddings.append(np.asarray(model[lemma.replace(" ", "_")], dtype=float))
            elif " " in lemma:
                # try individual words
                words = lemma.split()
                word_embs = [np.asarray(model[w], dtype=float) for w in words if w in model]
                if word_embs:
                    embeddings.append(np.mean(word_embs, axis=0))
        if not embeddings:
            return np.array([])  # empty
        return np.mean(embeddings, axis=0)
    except Exception as e:
        # defensive: return empty arr on any error
        return np.array([])


def embed_lexical_relations(synset, model) -> Dict[str, List[Tuple[SynsetName, np.ndarray]]]:
    """
    Return map: lexical_rel_name -> list of (synset_name, centroid ndarray)
    Filters out relations whose centroid is empty.
    """
    def _rel_centroids(get_attr):
        try:
            items = []
            for s in get_attr(synset):
                cent = get_synset_embedding_centroid(s, model)
                if cent.size > 0:
                    items.append((s.name(), cent))
            return items
        except Exception:
            return []

    return {
        "part_holonyms": _rel_centroids(lambda x: x.part_holonyms()),
        "substance_holonyms": _rel_centroids(lambda x: x.substance_holonyms()),
        "member_holonyms": _rel_centroids(lambda x: x.member_holonyms()),
        "part_meronyms": _rel_centroids(lambda x: x.part_meronyms()),
        "substance_meronyms": _rel_centroids(lambda x: x.substance_meronyms()),
        "member_meronyms": _rel_centroids(lambda x: x.member_meronyms()),
        "hypernyms": _rel_centroids(lambda x: x.hypernyms()),
        "hyponyms": _rel_centroids(lambda x: x.hyponyms()),
        "entailments": _rel_centroids(lambda x: x.entailments()),
        "causes": _rel_centroids(lambda x: x.causes()),
        "also_sees": _rel_centroids(lambda x: x.also_sees()),
        "verb_groups": _rel_centroids(lambda x: x.verb_groups()),
    }


def get_embedding_similarities(rel_embs_1: List[Tuple[str, np.ndarray]],
                              rel_embs_2: List[Tuple[str, np.ndarray]]) -> np.ndarray:
    """
    Return cosine similarity matrix (m x n) for lists of (name, centroid).
    If either is empty, returns empty (0xN or Mx0) array.
    """
    if not rel_embs_1 or not rel_embs_2:
        return np.zeros((0, 0))

    e1 = np.array([x[1] for x in rel_embs_1], dtype=float)  # (m,d)
    e2 = np.array([x[1] for x in rel_embs_2], dtype=float)  # (n,d)

    # avoid divide-by-zero: replace zero norms with eps
    e1_norms = np.linalg.norm(e1, axis=1, keepdims=True)
    e2_norms = np.linalg.norm(e2, axis=1, keepdims=True)
    e1_norms[e1_norms == 0] = 1e-8
    e2_norms[e2_norms == 0] = 1e-8

    e1u = e1 / e1_norms
    e2u = e2 / e2_norms

    sims = np.dot(e1u, e2u.T)
    return sims


def get_top_k_aligned_lex_rel_pairs(
    src_tgt_rel_map: Dict[str, str],
    src_emb_dict: Dict[str, List[Tuple[SynsetName, np.ndarray]]],
    tgt_emb_dict: Dict[str, List[Tuple[SynsetName, np.ndarray]]],
    beam_width: int = 3,
) -> List[BeamElement]:
    """
    src_tgt_rel_map: mapping from relation name in src to relation name in tgt,
      e.g., {'hypernyms': 'hyponyms', ...}

    Returns list of ((src_syn_name, src_rel), (tgt_syn_name, tgt_rel), similarity)
    """
    rel_sims = []
    for e1_rel, e2_rel in src_tgt_rel_map.items():
        e1_list = src_emb_dict.get(e1_rel, [])
        e2_list = tgt_emb_dict.get(e2_rel, [])
        if not e1_list or not e2_list:
            continue
        sims = get_embedding_similarities(e1_list, e2_list)  # shape (m,n)
        if sims.size == 0:
            continue
        for i in range(sims.shape[0]):
            for j in range(sims.shape[1]):
                try:
                    rel_sims.append(((e1_list[i][0], e1_rel), (e2_list[j][0], e2_rel), float(sims[i, j])))
                except IndexError:
                    continue

    # sort and return top-k
    rel_sims.sort(key=lambda x: x[2], reverse=True)
    if beam_width <= 0:
        return rel_sims
    return rel_sims[:beam_width]


# ============================================================================
# Refactored: Embedding-based synset branch ranking
# ============================================================================

def get_top_k_synset_branch_pairs(
    candidates: List[List],  # List of lists of synsets
    target_synsets,  # Single synset or list of synsets
    beam_width: int = 3,
    model=None,  # embedding model
    wn_module=None  # WordNet module
) -> List[BeamElement]:
    """
    Refactored to use embedding-based alignment instead of get_synset_relatedness.

    Given a list of candidate synset lists and target synset(s),
    return the k synset-relation pairs most similar to the target.

    Returns: List of ((synset_name, lexical_rel), (target_syn_name, lexical_rel), similarity)
    """
    if model is None or wn_module is None:
        return []

    # Handle target synsets (single or list)
    if not isinstance(target_synsets, (list, tuple)):
        target_synsets = [target_synsets]

    if not target_synsets:
        return []

    # Relation maps for alignment
    asymm_map = {
        "hypernyms": "hyponyms",
        "hyponyms": "hypernyms",
        "part_meronyms": "part_holonyms",
        "member_meronyms": "member_holonyms",
        "substance_meronyms": "substance_holonyms",
        "entailments": "causes",
        "causes": "entailments",
    }
    symm_map = {
        "hypernyms": "hypernyms",
        "hyponyms": "hyponyms",
        "part_meronyms": "part_meronyms",
        "member_meronyms": "member_meronyms",
        "also_sees": "also_sees",
        "verb_groups": "verb_groups",
    }

    all_results = []

    # Process each target synset
    for target_syn in target_synsets:
        if target_syn is None:
            continue

        # Precompute target embeddings
        tgt_emb_dict = embed_lexical_relations(target_syn, model)

        # Process each candidate list
        for synset_list in candidates:
            if not synset_list:
                continue

            for synset in synset_list:
                if synset is None:
                    continue

                try:
                    # Compute candidate embeddings
                    src_emb_dict = embed_lexical_relations(synset, model)

                    # Get aligned pairs from asymmetric relations
                    asymm_pairs = get_top_k_aligned_lex_rel_pairs(
                        asymm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                    )
                    all_results.extend(asymm_pairs)

                    # Get aligned pairs from symmetric relations
                    symm_pairs = get_top_k_aligned_lex_rel_pairs(
                        symm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                    )
                    all_results.extend(symm_pairs)

                except Exception:
                    continue

    # Sort by similarity and return top-k
    all_results.sort(key=lambda x: x[2], reverse=True)
    return all_results[:beam_width]

In [8]:
import numpy as np
from typing import Dict, List, Tuple, Optional, Iterable, Set, Callable
import networkx as nx

# Type aliases
SynsetName = str
BeamElement = Tuple[Tuple[SynsetName, str], Tuple[SynsetName, str], float]
GetNewBeamsFn = Callable[[nx.DiGraph, SynsetName, SynsetName], List[BeamElement]]
TopKBranchFn = Callable[[List[List], object, int], List[BeamElement]]

# -------------------------
# 1) Embedding centroids
# -------------------------
def get_synset_embedding_centroid(synset, model) -> np.ndarray:
    """
    Given a wn.synset, compute centroid (mean) of embeddings for lemmas.
    Returns empty np.array if nothing found.
    """
    try:
        lemmas = [lemma.name().lower().replace("_", " ") for lemma in synset.lemmas()]
        embeddings = []
        for lemma in lemmas:
            if lemma in model:
                embeddings.append(np.asarray(model[lemma], dtype=float))
            elif lemma.replace(" ", "_") in model:
                embeddings.append(np.asarray(model[lemma.replace(" ", "_")], dtype=float))
            elif " " in lemma:
                # try individual words
                words = lemma.split()
                word_embs = [np.asarray(model[w], dtype=float) for w in words if w in model]
                if word_embs:
                    embeddings.append(np.mean(word_embs, axis=0))
        if not embeddings:
            return np.array([])  # empty
        return np.mean(embeddings, axis=0)
    except Exception as e:
        # defensive: return empty arr on any error
        print(f"[get_synset_embedding_centroid] Error for {getattr(synset, 'name', lambda: synset)()}: {e}")
        return np.array([])

# -------------------------
# 2) Embed lexical relations
# -------------------------
def embed_lexical_relations(synset, model) -> Dict[str, List[Tuple[SynsetName, np.ndarray]]]:
    """
    Return map: lexical_rel_name -> list of (synset_name, centroid ndarray)
    Filters out relations whose centroid is empty.
    """
    def _rel_centroids(get_attr):
        try:
            items = []
            for s in get_attr(synset):
                cent = get_synset_embedding_centroid(s, model)
                if cent.size > 0:
                    items.append((s.name(), cent))
            return items
        except Exception as e:
            print(f"[embed_lexical_relations] error for {synset.name()}: {e}")
            return []

    return {
        "part_holonyms": _rel_centroids(lambda x: x.part_holonyms()),
        "substance_holonyms": _rel_centroids(lambda x: x.substance_holonyms()),
        "member_holonyms": _rel_centroids(lambda x: x.member_holonyms()),
        "part_meronyms": _rel_centroids(lambda x: x.part_meronyms()),
        "substance_meronyms": _rel_centroids(lambda x: x.substance_meronyms()),
        "member_meronyms": _rel_centroids(lambda x: x.member_meronyms()),
        "hypernyms": _rel_centroids(lambda x: x.hypernyms()),
        "hyponyms": _rel_centroids(lambda x: x.hyponyms()),
        "entailments": _rel_centroids(lambda x: x.entailments()),
        "causes": _rel_centroids(lambda x: x.causes()),
        "also_sees": _rel_centroids(lambda x: x.also_sees()),
        "verb_groups": _rel_centroids(lambda x: x.verb_groups()),
    }

# -------------------------
# 3) Embedding similarities
# -------------------------
def get_embedding_similarities(rel_embs_1: List[Tuple[str, np.ndarray]], rel_embs_2: List[Tuple[str, np.ndarray]]) -> np.ndarray:
    """
    Return cosine similarity matrix (m x n) for lists of (name, centroid).
    If either is empty, returns empty (0xN or Mx0) array.
    """
    if not rel_embs_1 or not rel_embs_2:
        return np.zeros((0, 0))

    e1 = np.array([x[1] for x in rel_embs_1], dtype=float)  # (m,d)
    e2 = np.array([x[1] for x in rel_embs_2], dtype=float)  # (n,d)

    # avoid divide-by-zero: replace zero norms with eps
    e1_norms = np.linalg.norm(e1, axis=1, keepdims=True)
    e2_norms = np.linalg.norm(e2, axis=1, keepdims=True)
    e1_norms[e1_norms == 0] = 1e-8
    e2_norms[e2_norms == 0] = 1e-8

    e1u = e1 / e1_norms
    e2u = e2 / e2_norms

    sims = np.dot(e1u, e2u.T)
    return sims

# -------------------------
# 4) Top-K aligned lex relation pairs
# -------------------------
def get_top_k_aligned_lex_rel_pairs(
    src_tgt_rel_map: Dict[str, str],
    src_emb_dict: Dict[str, List[Tuple[SynsetName, np.ndarray]]],
    tgt_emb_dict: Dict[str, List[Tuple[SynsetName, np.ndarray]]],
    beam_width: int = 3,
) -> List[BeamElement]:
    """
    src_tgt_rel_map: mapping from relation name in src to relation name in tgt,
      e.g., {'hypernyms': 'hyponyms', ...}

    Returns list of ((src_syn_name, src_rel), (tgt_syn_name, tgt_rel), similarity)
    """
    rel_sims = []
    for e1_rel, e2_rel in src_tgt_rel_map.items():
        e1_list = src_emb_dict.get(e1_rel, [])
        e2_list = tgt_emb_dict.get(e2_rel, [])
        if not e1_list or not e2_list:
            continue
        sims = get_embedding_similarities(e1_list, e2_list)  # shape (m,n)
        if sims.size == 0:
            continue
        for i in range(sims.shape[0]):
            for j in range(sims.shape[1]):
                try:
                    rel_sims.append(((e1_list[i][0], e1_rel), (e2_list[j][0], e2_rel), float(sims[i, j])))
                except IndexError as ex:
                    raise IndexError(f"Index error in get_top_k_aligned_lex_rel_pairs: i={i}, j={j}, shapes e1={len(e1_list)}, e2={len(e2_list)}") from ex

    # sort and return top-k
    rel_sims.sort(key=lambda x: x[2], reverse=True)
    if beam_width <= 0:
        return rel_sims
    return rel_sims[:beam_width]

# -------------------------
# 5) Adapter: get_new_beams_fn for PairwiseBidirectionalAStar
# -------------------------
def get_new_beams_from_embeddings(
    g: nx.DiGraph,
    src_name: SynsetName,
    tgt_name: SynsetName,
    wn_module,
    model,
    beam_width: int = 3,
    asymm_map: Optional[Dict[str, str]] = None,
    symm_map: Optional[Dict[str, str]] = None,
) -> List[BeamElement]:
    """
    Adapter to produce the beam format expected by PairwiseBidirectionalAStar.
    - src_name / tgt_name are synset name strings (e.g., 'dog.n.01')
    - wn_module is your WordNet interface (e.g., nltk.corpus.wordnet as wn)
    - model is embedding model (contains token -> vector)
    """
    # default relation maps (tweak as needed)
    if asymm_map is None:
        asymm_map = {
            "hypernyms": "hyponyms",
            "hyponyms": "hypernyms",
            "part_meronyms": "part_holonyms",
            "member_meronyms": "member_holonyms",
            "substance_meronyms": "substance_holonyms",
            "entailments": "causes",
            "causes": "entailments",
        }
    if symm_map is None:
        symm_map = {
            "hypernyms": "hypernyms",
            "hyponyms": "hyponyms",
            "part_meronyms": "part_meronyms",
            "member_meronyms": "member_meronyms",
            "also_sees": "also_sees",
            "verb_groups": "verb_groups",
        }

    try:
        src_syn = wn_module.synset(src_name)
        tgt_syn = wn_module.synset(tgt_name)
    except Exception:
        # If synset names invalid, return empty beams
        return []

    # Embed lexical relations for both synsets
    src_emb_dict = embed_lexical_relations(src_syn, model)
    tgt_emb_dict = embed_lexical_relations(tgt_syn, model)

    # get top k pairs from asymmetric and symmetric maps
    asymm_pairs = get_top_k_aligned_lex_rel_pairs(asymm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width)
    symm_pairs = get_top_k_aligned_lex_rel_pairs(symm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width)

    combined = asymm_pairs + symm_pairs
    # sort by similarity and trim to beam_width
    combined.sort(key=lambda x: x[2], reverse=True)
    return combined[:beam_width]

# -------------------------
# 6) Gloss seeding helper (optional use of top_k_branch_fn)
# -------------------------
def build_gloss_seed_nodes_from_predicate(
    pred_syn,
    wn_module,
    nlp_func,
    mode: str = "subjects",  # "subjects" or "objects" or "verbs"
    extract_subjects_fn: Optional[Callable] = None,
    extract_objects_fn: Optional[Callable] = None,
    extract_verbs_fn: Optional[Callable] = None,
    top_k_branch_fn: Optional[TopKBranchFn] = None,
    target_synsets: Optional[List] = None,
    max_sample_size: int = 5,
    beam_width: int = 3,
) -> Set[SynsetName]:
    """
    Extract tokens from pred_syn gloss and return a set of synset-name seeds.
    If top_k_branch_fn provided, use it to select top-k matching synsets.
    - pred_syn: wn.synset
    - nlp_func: spaCy call (text -> doc)
    - mode: 'subjects'|'objects'|'verbs' decides which extractor to use
    """
    doc = nlp_func(pred_syn.definition())
    tokens = []
    if mode == "subjects" and extract_subjects_fn is not None:
        tokens, _ = extract_subjects_fn(doc)
    elif mode == "objects" and extract_objects_fn is not None:
        tokens = extract_objects_fn(doc)
    elif mode == "verbs" and extract_verbs_fn is not None:
        tokens = extract_verbs_fn(doc)
    else:
        # fallback: use any nouns in doc
        tokens = [tok for tok in doc if tok.pos_ == "NOUN"]

    # candidate synset lists for each token
    candidate_synsets = []
    for tok in tokens[:max_sample_size]:
        try:
            cand = wn_module.synsets(tok.text, pos=wn_module.NOUN if mode != "verbs" else wn_module.VERB)
            candidate_synsets.append(cand)
        except Exception:
            candidate_synsets.append([])

    seeds = set()
    if top_k_branch_fn and target_synsets is not None:
        # top_k_branch_fn is expected to accept (candidates, target_synset_or_list, beam_width)
        top_k = top_k_branch_fn(candidate_synsets[:max_sample_size], target_synsets, beam_width)
        for (s_pair, _, _) in top_k:
            # s_pair is (synset_obj_or_name, lexical_rel); convert to name if synset object
            s = s_pair[0]
            if hasattr(s, "name"):
                seeds.add(s.name())
            elif isinstance(s, str):
                seeds.add(s)
    else:
        # conservative: add the first few candidate synsets' names
        for cand_list in candidate_synsets:
            for s in cand_list[:min(3, len(cand_list))]:
                seeds.add(s.name())
    return seeds


# Beam Construction

In [9]:
asymmetric_pairs_map = {
    # holonyms
    "part_holonyms": "part_meronyms",
    "substance_holonyms": "substance_meronyms",
    "member_holonyms": "member_meronyms",

    # meronyms
    "part_meronyms": "part_holonyms",
    "substance_meronyms": "substance_holonyms",
    "member_meronyms": "member_holonyms",

    # other
    "hypernyms": "hyponyms",
    "hyponyms": "hyponyms"
}


symmetric_pairs_map = {
    # holonyms
    "part_holonyms": "part_holonyms",
    "substance_holonyms": "substance_holonyms",
    "member_holonyms": "member_holonyms",

    # meronyms
    "part_meronyms": "part_meronyms",
    "substance_meronyms": "substance_meronyms",
    "member_meronyms": "member_meronyms",

    # other
    "hypernyms": "hypernyms",
    "hyponyms": "hyponyms",
    "entailments": "entailments",
    "causes": "causes",
    "also_sees": "also_sees",
    "verb_groups": "verb_groups"
}

In [10]:
def get_new_beams(
      g: nx.DiGraph,
      src: str,
      tgt: str,
      model=embedding_model,
      beam_width=3
    ) -> List[Tuple[
        Tuple[str, str],
        Tuple[str, str],
        float
    ]]:
    """
    Get the k closest pairs of lexical relations between 2 synsets.

    Args:
        src: WordNet Synset object (e.g., 'dog.n.01')
        tgt: WordNet Synset object (e.g., 'cat.n.01')
        model: token model (if None, will load default)
        beam_width: max number of pairs to return

    Returns:
        List of tuples of the form:
          (
            (synset1, lexical_rel),
            (synset2, lexical_rel),
            relatedness
          )
    """

    # Build a map of each synset's associated lexical relations
    #   and the centroids of their associated synsets
    src_lex_rel_embs = embed_lexical_relations(wn.synset(src), model)
    tgt_lex_rel_embs = embed_lexical_relations(wn.synset(tgt), model)

    # ensure the edges in the nx graph align with those in the embedding maps
    src_neighbors = {n for n in g.neighbors(src)}
    for rel, synset_list in src_lex_rel_embs.items():
      if not all(s[0] in src_neighbors for s in synset_list):
        raise ValueError(f"Not all lexical properties of {src} ({[s[0] for s in synset_list]}) in graph for relation {rel}")
    tgt_neighbors = {n for n in g.neighbors(tgt)}
    for rel, synset_list in tgt_lex_rel_embs.items():
      if not all(s[0] in tgt_neighbors for s in synset_list):
        raise ValueError(f"Not all lexical properties of {tgt} ({[s[0] for s in synset_list]}) in graph for relation {rel}")
    # in the future, get neighbor relation in node metadata with g.adj[n]

    # Get the asymmetric lexical relation pairings,
    #   sorted in descending order of embedding similarity
    #   e.x. similarity of synset1's hypernyms to synset2's hypernyms
    asymm_lex_rel_sims = get_top_k_aligned_lex_rel_pairs(
        asymmetric_pairs_map,
        src_lex_rel_embs,
        tgt_lex_rel_embs,
        model,
        beam_width
    )
    # Get the symmetric lexical relation pairings,
    #   sorted in descending order of embedding similarity
    #   e.x. similarity of synset1's hypernyms to synset2's hypernyms
    symm_lex_rel_sims = get_top_k_aligned_lex_rel_pairs(
        symmetric_pairs_map,
        src_lex_rel_embs,
        tgt_lex_rel_embs,
        model,
        beam_width
    )
    combined = asymm_lex_rel_sims + symm_lex_rel_sims
    beam = sorted(combined, key=lambda x: x[2], reverse=True)[:beam_width]
    return beam

# Pathing

## PairwiseBidirectionalAStar

In [39]:
import heapq
import itertools
from collections import deque
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple
import networkx as nx
import numpy as np

# Type aliases
SynsetName = str  # e.g., "dog.n.01"
Path = List[SynsetName]
BeamElement = Tuple[Tuple[SynsetName, str], Tuple[SynsetName, str], float]
GetNewBeamsFn = Callable[[nx.DiGraph, SynsetName, SynsetName], List[BeamElement]]
TopKBranchFn = Callable[[List[List], object, int], List[BeamElement]]
GlossSeedFn = Callable[[object], List]  # e.g., (parsed_gloss_doc) -> list of tokens



class PairwiseBidirectionalAStar:
    """
    Beam+depth-constrained, gloss-seeded bidirectional A* for pairwise synset search.

    Dependencies / injectable functions:
      - get_new_beams_fn(g, src, tgt) -> List[((src_node, rel),(tgt_node, rel), sim)]
      - top_k_branch_fn(candidates_lists, target_synset, beam_width) -> List[((synset, rel),(name,rel),sim)]
      - gloss_seed_fn(gloss_doc) -> list of tokens (subject/object/verb tokens)
      - wn and nlp (spaCy) are used by the outer pipeline; this class accepts seeds instead.

    Heuristics:
      - Embedding similarity -> h = 1 - sim (lower is better)
      - Gloss seeds get a small bonus (h -= GLOSS_BONUS)
    """
    GLOSS_BONUS = 0.15  # subtract from h for gloss-seeded nodes (tune)

    def __init__(
        self,
        g: nx.DiGraph,
        src: SynsetName,
        tgt: SynsetName,
        get_new_beams_fn: Optional[GetNewBeamsFn] = None,
        gloss_seed_nodes: Optional[Iterable[SynsetName]] = None,
        beam_width: int = 3,
        max_depth: int = 6,
        relax_beam: bool = False,
    ):
        """
        Args:
          g: nx.DiGraph — graph of synsets (nodes as synset names).
          src, tgt: synset node ids (strings).
          get_new_beams_fn: function to produce embedding-based beam pairs (optional).
          gloss_seed_nodes: explicit list of synset names seeded from glosses (optional).
          beam_width: beam width for initial seeding (passed to get_new_beams if used).
          max_depth: maximum hops allowed (total across both sides; enforced per side).
          relax_beam: if True, allow exploring nodes outside the allowed beams.
        """
        self.g = g
        self.src = src
        self.tgt = tgt
        self.get_new_beams_fn = get_new_beams_fn
        self.gloss_seed_nodes = set(gloss_seed_nodes) if gloss_seed_nodes else set()
        self.beam_width = beam_width
        self.max_depth = max_depth
        self.relax_beam = relax_beam

        # will be set by _build_allowed_and_heuristics
        self.src_allowed: Set[SynsetName] = set()
        self.tgt_allowed: Set[SynsetName] = set()
        self.h_forward: Dict[SynsetName, float] = {}
        self.h_backward: Dict[SynsetName, float] = {}

        # search state
        self._counter = itertools.count()
        self.open_f: List[Tuple[float, int, SynsetName]] = []
        self.open_b: List[Tuple[float, int, SynsetName]] = []
        self.g_f: Dict[SynsetName, float] = {}
        self.g_b: Dict[SynsetName, float] = {}
        self.depth_f: Dict[SynsetName, int] = {}
        self.depth_b: Dict[SynsetName, int] = {}
        self.parent_f: Dict[SynsetName, Optional[SynsetName]] = {}
        self.parent_b: Dict[SynsetName, Optional[SynsetName]] = {}
        self.closed_f: Set[SynsetName] = set()
        self.closed_b: Set[SynsetName] = set()

    # -------------------------
    # Setup: allowed sets & heuristics
    # -------------------------
    def _build_allowed_and_heuristics(self):
        """
        Build allowed node sets and heuristic maps using get_new_beams_fn and gloss seeds.
        - src_allowed/tgt_allowed: union of beam nodes and explicit gloss seeds + src/tgt.
        - h_forward/h_backward: h = 1 - sim (embedding), gloss seeds get bonus.
        """
        beams = []
        if self.get_new_beams_fn is not None:
            try:
                beams = self.get_new_beams_fn(self.g, self.src, self.tgt) or []
            except Exception:
                beams = []

        # base allowed sets from embedding beams
        src_beam_pairs = [b[0] for b in beams]
        tgt_beam_pairs = [b[1] for b in beams]
        self.src_allowed = {p[0] for p in src_beam_pairs}
        self.tgt_allowed = {p[0] for p in tgt_beam_pairs}

        # always include src and tgt
        self.src_allowed.add(self.src)
        self.tgt_allowed.add(self.tgt)

        # Include explicit gloss seeds in allowed sets (boost their priority)
        for node in self.gloss_seed_nodes:
            # you may want to add heuristics only to one side depending on how seeds were created.
            self.src_allowed.add(node)
            self.tgt_allowed.add(node)

        # Base heuristics from embedding beams
        self.h_forward = {}
        self.h_backward = {}
        for (s_pair, t_pair, sim) in beams:
            s_node = s_pair[0]
            t_node = t_pair[0]
            h_val = max(0.0, 1.0 - float(sim))
            # keep smallest h (best sim)
            if s_node not in self.h_forward or h_val < self.h_forward[s_node]:
                self.h_forward[s_node] = h_val
            if t_node not in self.h_backward or h_val < self.h_backward[t_node]:
                self.h_backward[t_node] = h_val

        # Ensure src/tgt have at least default heuristic
        self.h_forward.setdefault(self.src, 0.0)
        self.h_backward.setdefault(self.tgt, 0.0)

        # Apply gloss bonus: reduce heuristic for gloss seeds (they look more promising)
        for node in self.gloss_seed_nodes:
            # subtract GLOSS_BONUS but keep >= 0
            if node in self.h_forward:
                self.h_forward[node] = max(0.0, self.h_forward[node] - self.GLOSS_BONUS)
            else:
                self.h_forward[node] = max(0.0, 0.5 - self.GLOSS_BONUS)  # default h=0.5 if not in beams

            if node in self.h_backward:
                self.h_backward[node] = max(0.0, self.h_backward[node] - self.GLOSS_BONUS)
            else:
                self.h_backward[node] = max(0.0, 0.5 - self.GLOSS_BONUS)

    # -------------------------
    # Initialization of queues
    # -------------------------
    def _init_search_state(self):
        self._counter = itertools.count()
        self.open_f = []
        self.open_b = []
        self.g_f = {self.src: 0.0}
        self.g_b = {self.tgt: 0.0}
        self.depth_f = {self.src: 0}
        self.depth_b = {self.tgt: 0}
        self.parent_f = {self.src: None}
        self.parent_b = {self.tgt: None}
        self.closed_f = set()
        self.closed_b = set()

        heapq.heappush(self.open_f, (self.h_forward.get(self.src, 0.0), next(self._counter), self.src))
        heapq.heappush(self.open_b, (self.h_backward.get(self.tgt, 0.0), next(self._counter), self.tgt))

    # -------------------------
    # Utilities
    # -------------------------
    def _edge_weight(self, u: SynsetName, v: SynsetName) -> float:
        try:
            return float(self.g[u][v].get("weight", 1.0))
        except Exception:
            return 1.0

    def _allowed_forward(self, node: SynsetName) -> bool:
        return self.relax_beam or node in self.src_allowed or node in self.tgt_allowed or node == self.tgt or node == self.src

    def _allowed_backward(self, node: SynsetName) -> bool:
        return self.relax_beam or node in self.tgt_allowed or node in self.src_allowed or node == self.src or node == self.tgt

    # -------------------------
    # Expand one node from forward/back
    # -------------------------
    def _expand_forward_once(self) -> Optional[SynsetName]:
        """Pop one element from forward open and expand. Return meeting node if found."""
        while self.open_f:
            _, _, current = heapq.heappop(self.open_f)
            if current in self.closed_f:
                continue
            self.closed_f.add(current)

            # if already settled by backward:
            if current in self.closed_b:
                return current

            curr_depth = self.depth_f.get(current, 0)
            if curr_depth >= self.max_depth:
                continue

            for nbr in self.g.neighbors(current):
                if not self._allowed_forward(nbr):
                    continue
                tentative_g = self.g_f[current] + self._edge_weight(current, nbr)
                tentative_depth = curr_depth + 1
                if tentative_depth > self.max_depth:
                    continue
                if tentative_g < self.g_f.get(nbr, float("inf")):
                    self.g_f[nbr] = tentative_g
                    self.depth_f[nbr] = tentative_depth
                    self.parent_f[nbr] = current
                    f_score = tentative_g + self.h_forward.get(nbr, 0.0)
                    heapq.heappush(self.open_f, (f_score, next(self._counter), nbr))
                    if nbr in self.closed_b:
                        return nbr
            return None
        return None

    def _expand_backward_once(self) -> Optional[SynsetName]:
        """Pop one element from backward open and expand predecessors. Return meeting node if found."""
        while self.open_b:
            _, _, current = heapq.heappop(self.open_b)
            if current in self.closed_b:
                continue
            self.closed_b.add(current)

            if current in self.closed_f:
                return current

            curr_depth = self.depth_b.get(current, 0)
            if curr_depth >= self.max_depth:
                continue

            for nbr in self.g.predecessors(current):
                if not self._allowed_backward(nbr):
                    continue
                tentative_g = self.g_b[current] + self._edge_weight(nbr, current)
                tentative_depth = curr_depth + 1
                if tentative_depth > self.max_depth:
                    continue
                if tentative_g < self.g_b.get(nbr, float("inf")):
                    self.g_b[nbr] = tentative_g
                    self.depth_b[nbr] = tentative_depth
                    self.parent_b[nbr] = current
                    f_score = tentative_g + self.h_backward.get(nbr, 0.0)
                    heapq.heappush(self.open_b, (f_score, next(self._counter), nbr))
                    if nbr in self.closed_f:
                        return nbr
            return None
        return None

    # -------------------------
    # Path reconstruction
    # -------------------------
    def _reconstruct_path(self, meet: SynsetName) -> Path:
        # forward part
        path_f: Path = []
        n = meet
        while n is not None:
            path_f.append(n)
            n = self.parent_f.get(n)
        path_f.reverse()

        # backward part (exclude meet to avoid dup)
        path_b: Path = []
        n = self.parent_b.get(meet)
        while n is not None:
            path_b.append(n)
            n = self.parent_b.get(n)

        return path_f + path_b

    # -------------------------
    # Core: find multiple paths
    # -------------------------
    def find_paths(self, max_results: int = 3, len_tolerance: int = 0) -> List[Tuple[Path, float]]:
        """
        Run bidirectional beam+depth constrained search and return up to max_results unique paths.
        Paths returned have total cost (sum of g_f + g_b at meet) and are kept while <= best_cost + len_tolerance.

        len_tolerance: integer extra hops allowed beyond the best (shortest) path length.
        """
        # Setup
        self._build_allowed_and_heuristics()
        self._init_search_state()

        results: List[Tuple[Path, float]] = []
        seen_paths: Set[Tuple[SynsetName, ...]] = set()
        best_cost: Optional[float] = None

        # helper to compute current lower bound on any next path cost
        def current_lower_bound() -> float:
            min_f_f = self.open_f[0][0] if self.open_f else float("inf")
            min_f_b = self.open_b[0][0] if self.open_b else float("inf")
            return min_f_f + min_f_b

        # main loop
        while (self.open_f or self.open_b) and len(results) < max_results:
            # stopping condition: if we have a best_cost and the conservative lower bound
            # exceeds best_cost + len_tolerance, we can stop.
            if best_cost is not None and current_lower_bound() > best_cost + float(len_tolerance):
                break

            # expand side with smaller top f
            top_f = self.open_f[0][0] if self.open_f else float("inf")
            top_b = self.open_b[0][0] if self.open_b else float("inf")

            meet = None
            if top_f <= top_b:
                meet = self._expand_forward_once()
            else:
                meet = self._expand_backward_once()

            if meet is None:
                continue

            # When meet occurs, reconstruct path and compute cost.
            path = self._reconstruct_path(meet)
            path_key = tuple(path)
            # compute cost: if both g maps contain meet, sum them; otherwise try to compute from edges
            cost_f = self.g_f.get(meet, float("inf"))
            cost_b = self.g_b.get(meet, float("inf"))
            total_cost = cost_f + cost_b if (cost_f < float("inf") and cost_b < float("inf")) else float("inf")

            # Hop-based cost fallback if edge weights are all 1 or inexact: length-1 equals hops
            if total_cost == float("inf"):
                # fallback to hop count
                total_cost = len(path) - 1

            if path_key not in seen_paths:
                seen_paths.add(path_key)
                # Accept path if within tolerance of current best
                if best_cost is None or total_cost <= best_cost + float(len_tolerance):
                    results.append((path, total_cost))
                    if best_cost is None or total_cost < best_cost:
                        best_cost = total_cost

            # continue searching for more meets until stopping condition triggers

        return results

## Helpers

In [29]:
# -------------------------
# Refactored top_k_branch_wrapper using new embedding functions
# -------------------------
def top_k_branch_wrapper(
    candidates_lists: List[List],
    target_synset_or_list,
    beam_width: int = 3,
    model=None,  # embedding model
    wn_module=None  # WordNet module
) -> List[BeamElement]:
    """
    Adaptation so that find_connected_paths can call a branch-ranking function
    that uses the new embedding-based alignment functions.

    Args:
      candidates_lists: list of lists of wn.synset objects (from gloss tokens)
      target_synset_or_list: either a single wn.synset or a list of target synsets
      beam_width: number of top pairs to return
      model: embedding model (e.g., Word2Vec, GloVe)
      wn_module: WordNet module
    Returns:
      List of beam elements ((synset_name, rel), (target_syn_name, rel), sim)
    """
    if model is None or wn_module is None:
        return []

    results = []

    # Choose a single representative target synset to align against.
    # If a list is provided, pick the first (you can change this strategy).
    if isinstance(target_synset_or_list, (list, tuple)):
        if not target_synset_or_list:
            return []
        target_syn = target_synset_or_list[0]
    else:
        target_syn = target_synset_or_list

    if target_syn is None:
        return []

    # Import the embedding functions (assuming they're available)
    from embedding_helpers import (
        embed_lexical_relations,
        get_top_k_aligned_lex_rel_pairs
    )

    # Precompute target synset relation embeddings
    tgt_emb_dict = embed_lexical_relations(target_syn, model)

    # Relation maps (same defaults used by the embedding-beam adapter)
    asymm_map = {
        "hypernyms": "hyponyms",
        "hyponyms": "hypernyms",
        "part_meronyms": "part_holonyms",
        "member_meronyms": "member_holonyms",
        "substance_meronyms": "substance_holonyms",
        "entailments": "causes",
        "causes": "entailments",
    }
    symm_map = {
        "hypernyms": "hypernyms",
        "hyponyms": "hyponyms",
        "part_meronyms": "part_meronyms",
        "member_meronyms": "member_meronyms",
        "also_sees": "also_sees",
        "verb_groups": "verb_groups",
    }

    # For each candidate synset in every candidate list, compute its lexical-relation embeddings
    # and align to the target's relation embeddings using get_top_k_aligned_lex_rel_pairs
    for cand_list in candidates_lists:
        for cand_syn in cand_list:
            try:
                src_emb_dict = embed_lexical_relations(cand_syn, model)
            except Exception:
                continue

            # get top pairs from both maps
            try:
                asymm_results = get_top_k_aligned_lex_rel_pairs(
                    asymm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                )
                results.extend(asymm_results)
            except Exception:
                pass

            try:
                symm_results = get_top_k_aligned_lex_rel_pairs(
                    symm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                )
                results.extend(symm_results)
            except Exception:
                pass

    # Sort aggregated results by similarity and return top beam_width
    results.sort(key=lambda x: x[2], reverse=True)
    return results[:beam_width]


# -------------------------
# High-level triple search (subject -> predicate -> object)
# -------------------------
def find_connected_paths(
    g: nx.DiGraph,
    subject_word: str,
    predicate_word: str,
    object_word: str,
    wn_module,                # your WordNet interface (e.g., nltk.corpus.wordnet or custom)
    nlp_func,                 # spaCy NLP call (callable that takes a string -> doc)
    get_new_beams_fn=None,
    top_k_branch_fn: Optional[TopKBranchFn] = None,
    extract_subjects_from_gloss=None,
    extract_objects_from_gloss=None,
    beam_width: int = 3,
    max_depth: int = 8,
    max_self_intersection: int = 5,
    max_results_per_pair: int = 3,
    len_tolerance: int = 1,
    relax_beam: bool = False,
):
    """
    Find connected subject->predicate->object paths.

    Strategy:
      - Get synsets for subject (nouns), predicate (verbs), object (nouns).
      - For each candidate predicate synset:
          - Build gloss seeds for subject->predicate side (use predicate gloss to extract candidate subject tokens)
          - Build gloss seeds for predicate->object side (use predicate gloss to extract object tokens)
          - Run PairwiseBidirectionalAStar twice (subj->pred and pred->obj), requesting multiple paths each
          - Combine pairs of returned paths that share the predicate synset and meet intersection constraints
      - Return sorted list of connected path triples
    """

    # get word synsets
    subject_synsets = wn_module.synsets(subject_word, pos=wn_module.NOUN)
    predicate_synsets = wn_module.synsets(predicate_word, pos=wn_module.VERB)
    object_synsets = wn_module.synsets(object_word, pos=wn_module.NOUN)

    results = []

    # helper: safe top_k_branch function fallback (if not provided)
    def _default_top_k_branch_fn(candidates_lists, target_synset, beam_width_inner=3):
        # fallback: return empty (user should inject a real one)
        return []

    top_k_branch_fn_used = top_k_branch_fn or _default_top_k_branch_fn

    for pred_syn in predicate_synsets:
        pred_name = pred_syn.name()

        # ------------------------
        # Build gloss seeds for subj->pred
        # ------------------------
        pred_gloss_doc = nlp_func(pred_syn.definition())

        # Extract subject tokens from predicate gloss
        active_subject_tokens = []
        if extract_subjects_from_gloss:
            try:
                active_subject_tokens, _ = extract_subjects_from_gloss(pred_gloss_doc)
            except Exception:
                active_subject_tokens = []

        # Convert tokens to candidate synset lists and get top_k branches
        subject_candidate_synsets = []
        for t in active_subject_tokens:
            try:
                synsets = wn_module.synsets(t.text, pos=wn_module.NOUN)
                if synsets:
                    subject_candidate_synsets.append(synsets)
            except Exception:
                pass

        # Get top-k synsets that best align with subject synsets
        subj_gloss_seed_nodes = set()
        if subject_candidate_synsets and top_k_branch_fn_used:
            try:
                subj_top_k = top_k_branch_fn_used(subject_candidate_synsets, subject_synsets, beam_width)
                # Extract synset names from the results
                for beam_elem in subj_top_k:
                    s_pair = beam_elem[0]  # (synset_name, rel)
                    if isinstance(s_pair[0], str):
                        subj_gloss_seed_nodes.add(s_pair[0])
                    elif hasattr(s_pair[0], 'name'):
                        subj_gloss_seed_nodes.add(s_pair[0].name())
            except Exception:
                pass

        # ------------------------
        # Build gloss seeds for pred->obj
        # ------------------------
        objects_tokens = []
        if extract_objects_from_gloss:
            try:
                objects_tokens = extract_objects_from_gloss(pred_gloss_doc)
            except Exception:
                objects_tokens = []

        # Convert tokens to candidate synset lists
        object_candidate_synsets = []
        for t in objects_tokens:
            try:
                synsets = wn_module.synsets(t.text, pos=wn_module.NOUN)
                if synsets:
                    object_candidate_synsets.append(synsets)
            except Exception:
                pass

        # Get top-k synsets that best align with object synsets
        obj_gloss_seed_nodes = set()
        if object_candidate_synsets and top_k_branch_fn_used:
            try:
                obj_top_k = top_k_branch_fn_used(object_candidate_synsets, object_synsets, beam_width)
                # Extract synset names from the results
                for beam_elem in obj_top_k:
                    s_pair = beam_elem[0]  # (synset_name, rel)
                    if isinstance(s_pair[0], str):
                        obj_gloss_seed_nodes.add(s_pair[0])
                    elif hasattr(s_pair[0], 'name'):
                        obj_gloss_seed_nodes.add(s_pair[0].name())
            except Exception:
                pass

        # Now run pairwise searches:
        # For subject->predicate: src = each subject synset, tgt = pred_syn (pred_name)
        subject_paths_map = {}  # subj_syn_name -> list of (path, cost)
        for subj_syn in subject_synsets:
            # Check if cross-POS search
            if subj_syn.pos() != pred_syn.pos():
                # Cross-POS: rely more on gloss seeds, relax beam constraints
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=subj_syn.name(),
                    tgt=pred_name,
                    get_new_beams_fn=None,  # Don't use beam function for cross-POS
                    gloss_seed_nodes=subj_gloss_seed_nodes,  # Use gloss seeds
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=True  # Allow exploration outside beams
                )
            else:
                # Same-POS: use normal beam search with gloss enrichment
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=subj_syn.name(),
                    tgt=pred_name,
                    get_new_beams_fn=get_new_beams_fn,
                    gloss_seed_nodes=subj_gloss_seed_nodes,
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=relax_beam
                )

            subj_paths = pair_search.find_paths(
                max_results=max_results_per_pair,
                len_tolerance=len_tolerance
            )
            if subj_paths:
                subject_paths_map[subj_syn.name()] = subj_paths

        # For predicate->object: src = pred_name, tgt = each object synset
        object_paths_map = {}  # obj_syn_name -> list of (path, cost)
        for obj_syn in object_synsets:
            # Check if cross-POS search
            if pred_syn.pos() != obj_syn.pos():
                # Cross-POS: rely more on gloss seeds, relax beam constraints
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=pred_name,
                    tgt=obj_syn.name(),
                    get_new_beams_fn=None,  # Don't use beam function for cross-POS
                    gloss_seed_nodes=obj_gloss_seed_nodes,  # Use gloss seeds
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=True  # Allow exploration outside beams
                )
            else:
                # Same-POS: use normal beam search with gloss enrichment
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=pred_name,
                    tgt=obj_syn.name(),
                    get_new_beams_fn=get_new_beams_fn,
                    gloss_seed_nodes=obj_gloss_seed_nodes,
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=relax_beam
                )

            obj_paths = pair_search.find_paths(
                max_results=max_results_per_pair,
                len_tolerance=len_tolerance
            )
            if obj_paths:
                object_paths_map[obj_syn.name()] = obj_paths

        # Combine: for any subj_path and obj_path that both go through `pred_name`,
        # produce combined results if intersection is low enough.
        for subj_syn_name, subj_paths in subject_paths_map.items():
            for subj_path, subj_cost in subj_paths:
                for obj_syn_name, obj_paths in object_paths_map.items():
                    for obj_path, obj_cost in obj_paths:
                        # Make sure paths connect through the predicate
                        if subj_path[-1] == pred_name and obj_path[0] == pred_name:
                            # Check intersection to avoid tautological results
                            intersection_size = len(set(subj_path).intersection(set(obj_path)))
                            if intersection_size <= max_self_intersection:
                                # Combined path: subject_path + object_path[1:] (drop duplicate predicate)
                                combined_path = subj_path + obj_path[1:]
                                combined_cost = subj_cost + obj_cost
                                combined_len = len(combined_path)
                                results.append({
                                    "predicate_synset": pred_name,
                                    "subject_path": subj_path,
                                    "object_path": obj_path,
                                    "combined_path": combined_path,
                                    "combined_cost": combined_cost,
                                    "combined_len": combined_len,
                                })

    # sort results by combined_len then cost
    results = sorted(results, key=lambda r: (r["combined_len"], r["combined_cost"]))
    return results

## Gloss Parsing

In [30]:
# ============================================================================
# Gloss Analysis Helper Functions (Keep as-is, they work well)
# ============================================================================

def extract_subjects_from_gloss(gloss_doc):
    """Extract subject tokens from a parsed gloss."""
    subjects = []

    # Direct subjects
    subjects.extend([tok for tok in gloss_doc if tok.dep_ == "nsubj"])

    # Passive subjects (which are actually objects semantically)
    # Skip these for actor identification
    passive_subjects = [tok for tok in gloss_doc if tok.dep_ == "nsubjpass"]

    # Filter out passive subjects from the main list
    subjects = [s for s in subjects if s not in passive_subjects]

    return subjects, passive_subjects


def extract_objects_from_gloss(gloss_doc):
    """Extract various types of object tokens from a parsed gloss."""
    objs = []

    # Indirect objects
    iobjs = [tok for tok in gloss_doc if tok.dep_ == "iobj"]
    objs.extend(iobjs)

    # Direct objects
    # Only include if there were no indirect objects,
    #   crude, but good for MVP
    if not iobjs:
        objs.extend([tok for tok in gloss_doc if tok.dep_ == "dobj"])

    # Prepositional objects
    objs.extend([tok for tok in gloss_doc if tok.dep_ == "pobj"])

    # General objects
    objs.extend([tok for tok in gloss_doc if tok.dep_ == "obj"])

    # Check for noun chunks related to root verb
    root_verbs = [tok for tok in gloss_doc if tok.dep_ == "ROOT" and tok.pos_ == "VERB"]
    if root_verbs and not objs:
        for noun_chunk in gloss_doc.noun_chunks:
            if any(token.head == root_verbs[0] for token in noun_chunk):
                objs.append(noun_chunk.root)

    return objs


def extract_verbs_from_gloss(gloss_doc, include_passive=False):
    """Extract verb tokens from a parsed gloss."""
    verbs = [tok for tok in gloss_doc if tok.pos_ == "VERB"]

    if include_passive:
        # Past participles used as adjectives or in relative clauses
        passive_verbs = [tok for tok in gloss_doc if
                        tok.tag_ in ["VBN", "VBD"] and
                        tok.dep_ in ["acl", "relcl", "amod"]]
        verbs.extend(passive_verbs)

    return verbs


def find_instrumental_verbs(gloss_doc):
    """Find verbs associated with instrumental use (e.g., 'used for')."""
    instrumental_verbs = []

    if "used" in gloss_doc.text.lower():
        for i, token in enumerate(gloss_doc):
            if token.text.lower() == "used":
                # Check tokens after "used"
                for j in range(i+1, min(i+4, len(gloss_doc))):
                    if gloss_doc[j].pos_ == "VERB":
                        instrumental_verbs.append(gloss_doc[j])

    return instrumental_verbs


# ============================================================================
# Simple same-POS path finding (for backward compatibility)
# ============================================================================

# def get_all_neighbors(synset, wn_module=None):
#     """Get all lexically related neighbors of a synset."""
#     neighbors = set()

#     # Add all types of relations
#     relation_methods = [
#         'hypernyms', 'hyponyms', 'holonyms', 'meronyms',
#         'similar_tos', 'also_sees', 'verb_groups',
#         'entailments', 'causes', 'attributes'
#     ]

#     for method_name in relation_methods:
#         if hasattr(synset, method_name):
#             try:
#                 related = getattr(synset, method_name)()
#                 neighbors.update(related)
#             except:
#                 pass

#     return list(neighbors)


def path_syn_to_syn(start_synset, end_synset, max_depth=6, wn_module=None):
    """
    Find shortest path between synsets of the same POS using bidirectional BFS.
    Returns a list of synset names (strings) forming the path, or None if no path found.
    """
    # Convert to names for consistency
    start_name = start_synset.name() if hasattr(start_synset, 'name') else str(start_synset)
    end_name = end_synset.name() if hasattr(end_synset, 'name') else str(end_synset)

    # Check if same POS (if we have synset objects)
    if hasattr(start_synset, 'pos') and hasattr(end_synset, 'pos'):
        if start_synset.pos() != end_synset.pos():
            return None

    # Handle the trivial case where start and end are the same
    if start_name == end_name:
        return [start_name]

    # Initialize two search frontiers
    forward_queue = deque([(start_synset, 0)])
    forward_visited = {start_name: [start_name]}

    backward_queue = deque([(end_synset, 0)])
    backward_visited = {end_name: [end_name]}

    def expand_frontier(queue, visited_from_this_side, visited_from_other_side, is_forward):
        """Expand one step of the search frontier."""
        if not queue:
            return None

        curr_synset, depth = queue.popleft()

        if depth >= (max_depth + 1) // 2:
            return None

        curr_name = curr_synset.name() if hasattr(curr_synset, 'name') else str(curr_synset)
        path_to_current = visited_from_this_side[curr_name]

        for neighbor in get_all_neighbors(curr_synset, wn_module):
            neighbor_name = neighbor.name() if hasattr(neighbor, 'name') else str(neighbor)

            if neighbor_name in visited_from_this_side:
                continue

            if is_forward:
                new_path = path_to_current + [neighbor_name]
            else:
                new_path = [neighbor_name] + path_to_current

            if neighbor_name in visited_from_other_side:
                other_path = visited_from_other_side[neighbor_name]

                if is_forward:
                    full_path = path_to_current + other_path
                else:
                    full_path = other_path + path_to_current

                return full_path

            visited_from_this_side[neighbor_name] = new_path
            queue.append((neighbor, depth + 1))

        return None

    # Alternate between forward and backward search
    while forward_queue or backward_queue:
        if forward_queue:
            result = expand_frontier(forward_queue, forward_visited, backward_visited, True)
            if result:
                return result

        if backward_queue:
            result = expand_frontier(backward_queue, backward_visited, forward_visited, False)
            if result:
                return result

    return None

In [31]:
# ============================================================================
# Main wrapper function to replace find_connected_shortest_paths
# ============================================================================

def find_connected_shortest_paths(
    subject_word: str,
    predicate_word: str,
    object_word: str,
    wn_module,
    nlp_func,
    model=None,  # embedding model
    g: nx.DiGraph = None,  # synset graph
    max_depth: int = 10,
    max_self_intersection: int = 5,
    beam_width: int = 3,
    max_results_per_pair: int = 3,
    len_tolerance: int = 1,
    relax_beam: bool = False
):
    """
    Wrapper that uses the new find_connected_paths architecture.
    Returns the best connected path in the old format for backward compatibility.
    """

    # Import the necessary components
    from pathfinding_core import (
        PairwiseBidirectionalAStar,
        find_connected_paths,
        get_new_beams_from_embeddings
    )

    # Create the beam function if we have a model
    get_new_beams_fn = None
    if model is not None and g is not None:
        get_new_beams_fn = lambda g, src, tgt: get_new_beams_from_embeddings(
            g, src, tgt, wn_module, model, beam_width=beam_width
        )

    # Create the top_k_branch function
    top_k_branch_fn = None
    if model is not None:
        top_k_branch_fn = lambda candidates, target, bw: get_top_k_synset_branch_pairs(
            candidates, target, bw, model, wn_module
        )

    # Build the graph if not provided
    if g is None:
        g = build_synset_graph(wn_module)  # You'll need to implement this

    # Call the new find_connected_paths function
    results = find_connected_paths(
        g=g,
        subject_word=subject_word,
        predicate_word=predicate_word,
        object_word=object_word,
        wn_module=wn_module,
        nlp_func=nlp_func,
        get_new_beams_fn=get_new_beams_fn,
        top_k_branch_fn=top_k_branch_fn,
        extract_subjects_from_gloss=extract_subjects_from_gloss,
        extract_objects_from_gloss=extract_objects_from_gloss,
        beam_width=beam_width,
        max_depth=max_depth,
        max_self_intersection=max_self_intersection,
        max_results_per_pair=max_results_per_pair,
        len_tolerance=len_tolerance,
        relax_beam=relax_beam
    )

    # Convert results to old format (best_subject_path, best_object_path, best_predicate)
    if results:
        best_result = results[0]  # Take the best result
        # Convert synset names back to synset objects if needed
        subject_path = [wn_module.synset(name) if isinstance(name, str) else name
                       for name in best_result["subject_path"]]
        object_path = [wn_module.synset(name) if isinstance(name, str) else name
                      for name in best_result["object_path"]]
        pred_synset = wn_module.synset(best_result["predicate_synset"]) \
                     if isinstance(best_result["predicate_synset"], str) \
                     else best_result["predicate_synset"]

        return subject_path, object_path, pred_synset

    return None, None, None


# ============================================================================
# Helper function to build synset graph (if needed)
# ============================================================================

def build_synset_graph(wn_module) -> nx.DiGraph:
    """
    Build a directed graph of synsets with their lexical relations.
    """
    g = nx.DiGraph()

    # Get all synsets (you may want to limit this for performance)
    all_synsets = list(wn_module.all_synsets())

    # Add nodes
    for synset in all_synsets:
        g.add_node(synset.name())

    # Add edges based on relations
    for synset in all_synsets:
        synset_name = synset.name()

        # Add various relation types as edges
        relations = {
            'hypernyms': synset.hypernyms(),
            'hyponyms': synset.hyponyms(),
            'holonyms': synset.holonyms(),
            'meronyms': synset.meronyms(),
            'similar_tos': synset.similar_tos() if hasattr(synset, 'similar_tos') else [],
            'also_sees': synset.also_sees() if hasattr(synset, 'also_sees') else [],
            'verb_groups': synset.verb_groups() if hasattr(synset, 'verb_groups') else [],
            'entailments': synset.entailments() if hasattr(synset, 'entailments') else [],
            'causes': synset.causes() if hasattr(synset, 'causes') else [],
        }

        for rel_type, related_synsets in relations.items():
            for related in related_synsets:
                if related.name() in g:
                    g.add_edge(synset_name, related.name(),
                              relation=rel_type, weight=1.0)

    return g


# ============================================================================
# Display Functions (keep as-is for backward compatibility)
# ============================================================================

def show_path(label, path):
    """Pretty print a path of synsets."""
    if path:
        print(f"{label}:")
        # Handle both synset objects and name strings
        path_str = []
        for s in path:
            if hasattr(s, 'name'):
                path_str.append(f"{s.name()} ({s.definition()})")
            else:
                path_str.append(str(s))
        print(" -> ".join(path_str))
        print(f"Path length: {len(path)}")
        print()
    else:
        print(f"{label}: No path found")
        print()


def show_connected_paths(subject_path, object_path, predicate):
    """Display the connected paths with their shared predicate."""
    if subject_path and object_path and predicate:
        print("=" * 70)
        pred_name = predicate.name() if hasattr(predicate, 'name') else str(predicate)
        print(f"CONNECTED PATH through predicate: {pred_name}")
        print("=" * 70)

        show_path("Subject -> Predicate path", subject_path)
        show_path("Predicate -> Object path", object_path)

        # Show the complete connected path
        complete_path = subject_path + object_path[1:]  # Avoid duplicating the predicate
        print("Complete connected path:")
        path_names = []
        for s in complete_path:
            if hasattr(s, 'name'):
                path_names.append(s.name())
            else:
                path_names.append(str(s))
        print(" -> ".join(path_names))
        print(f"Total path length: {len(complete_path)}")
        print()
    else:
        print("No connected path found through any predicate synset.")

In [32]:
# # ============================================================================
# # Core Path Finding Functions
# # ============================================================================

# def path_syn_to_syn(start_synset, end_synset, max_depth=6):
#     """
#     Find shortest path between synsets of the same POS using bidirectional BFS.
#     Returns a list of synsets forming the path, or None if no path found.
#     """

#     if not (start_synset.pos() == end_synset.pos() and start_synset.pos() in {'n', 'v'}):
#       raise ValueError(f"{start_synset.name()} POS tag != {end_synset.name()}. Synsets must be of the same POS (noun or verb).")

#     # Handle the trivial case where start and end are the same
#     if start_synset.name() == end_synset.name():
#         return [start_synset]

#     # Initialize two search frontiers
#     forward_queue = deque([(start_synset, 0)])
#     forward_visited = {start_synset.name(): [start_synset]}

#     backward_queue = deque([(end_synset, 0)])
#     backward_visited = {end_synset.name(): [end_synset]}

#     def expand_frontier(queue, visited_from_this_side, visited_from_other_side, is_forward):
#         """Expand one step of the search frontier."""
#         if not queue:
#             return None

#         curr_synset, depth = queue.popleft()

#         if depth >= (max_depth + 1) // 2:
#             return None

#         path_to_current = visited_from_this_side[curr_synset.name()]

#         for neighbor in get_all_neighbors(curr_synset):
#             neighbor_name = neighbor.name()

#             if neighbor_name in visited_from_this_side:
#                 continue

#             if is_forward:
#                 new_path = path_to_current + [neighbor]
#             else:
#                 new_path = [neighbor] + path_to_current

#             if neighbor_name in visited_from_other_side:
#                 other_path = visited_from_other_side[neighbor_name]

#                 if is_forward:
#                     full_path = path_to_current + other_path
#                 else:
#                     full_path = other_path + path_to_current

#                 return full_path

#             visited_from_this_side[neighbor_name] = new_path
#             queue.append((neighbor, depth + 1))

#         return None

#     # Alternate between forward and backward search
#     while forward_queue or backward_queue:
#         if forward_queue:
#             result = expand_frontier(forward_queue, forward_visited, backward_visited, True)
#             if result:
#                 return result

#         if backward_queue:
#             result = expand_frontier(backward_queue, backward_visited, forward_visited, False)
#             if result:
#                 return result

#     return None


# # ============================================================================
# # Gloss Analysis Helper Functions
# # ============================================================================

# def extract_subjects_from_gloss(gloss_doc):
#     """Extract subject tokens from a parsed gloss."""
#     subjects = []

#     # Direct subjects
#     subjects.extend([tok for tok in gloss_doc if tok.dep_ == "nsubj"])

#     # Passive subjects (which are actually objects semantically)
#     # Skip these for actor identification
#     passive_subjects = [tok for tok in gloss_doc if tok.dep_ == "nsubjpass"]

#     # Filter out passive subjects from the main list
#     subjects = [s for s in subjects if s not in passive_subjects]

#     return subjects, passive_subjects


# def extract_objects_from_gloss(gloss_doc):
#     """Extract various types of object tokens from a parsed gloss."""
#     objs = []

#     # Indirect objects
#     iobjs = [tok for tok in gloss_doc if tok.dep_ == "iobj"]
#     objs.extend(iobjs)

#     # Direct objects
#     # Only include if there were no indirect objects,
#     #   crude, but good for MVP
#     if not iobjs:
#         objs.extend([tok for tok in gloss_doc if tok.dep_ == "dobj"])

#     # Prepositional objects
#     objs.extend([tok for tok in gloss_doc if tok.dep_ == "pobj"])

#     # General objects
#     objs.extend([tok for tok in gloss_doc if tok.dep_ == "obj"])

#     # Check for noun chunks related to root verb
#     root_verbs = [tok for tok in gloss_doc if tok.dep_ == "ROOT" and tok.pos_ == "VERB"]
#     if root_verbs and not objs:
#         for noun_chunk in gloss_doc.noun_chunks:
#             if any(token.head == root_verbs[0] for token in noun_chunk):
#                 objs.append(noun_chunk.root)

#     return objs


# def extract_verbs_from_gloss(gloss_doc, include_passive=False):
#     """Extract verb tokens from a parsed gloss."""
#     verbs = [tok for tok in gloss_doc if tok.pos_ == "VERB"]

#     if include_passive:
#         # Past participles used as adjectives or in relative clauses
#         passive_verbs = [tok for tok in gloss_doc if
#                         tok.tag_ in ["VBN", "VBD"] and
#                         tok.dep_ in ["acl", "relcl", "amod"]]
#         verbs.extend(passive_verbs)

#     return verbs


# def find_instrumental_verbs(gloss_doc):
#     """Find verbs associated with instrumental use (e.g., 'used for')."""
#     instrumental_verbs = []

#     if "used" in gloss_doc.text.lower():
#         for i, token in enumerate(gloss_doc):
#             if token.text.lower() == "used":
#                 # Check tokens after "used"
#                 for j in range(i+1, min(i+4, len(gloss_doc))):
#                     if gloss_doc[j].pos_ == "VERB":
#                         instrumental_verbs.append(gloss_doc[j])

#     return instrumental_verbs


# # ============================================================================
# # Cross-POS Path Finding Functions
# # ============================================================================
# def get_top_k_synset_branch_pairs(
#       candidates: List[List[wn.synset]],
#       target_synset: wn.synset,
#       beam_width=3
#     ) -> List[Tuple[
#         Tuple[wn.synset, str],
#         Tuple[wn.synset, str],
#         float
#     ]]:
#     """
#     Given a list of candidate tokens and a target synset,


#     Return the k synset subrelation pairs most similar to the target of the form:
#       ((synset, lexical_rel), (name, lexical_rel), relatedness)
#     """
#     top_k_asymm_branches = list()
#     top_k_symm_branches = list()
#     beam = list()
#     # for each list of possible synsets for a candidate token
#     for synsets in candidates:
#         # # filter to subjects based on whether they reside in the same sub-category
#         # #   where the subcategory != 'entity.n.01' or a similar top-level
#         # synsets = [
#         #     s for s in synsets
#         #     if s.root_hypernyms() != s.lowest_common_hypernyms(target_synset)
#         # ]
#         # # if there are synsets left for the candidate token after pruning
#         if synsets:
#             # # if the target is a verb,
#             # #   filter out any synsets with no lemma frames matching the target
#             # #   frame patterns: (Somebody [v] something), (Somebody [v]), ...
#             # if target_synset.pos() == 'v':
#             #     synsets = [
#             #         s for s in synsets
#             #         if any(
#             #             frame in s.frame_ids()
#             #             for frame in target_synset.frame_ids()
#             #         )
#             #     ]
#             for synset in synsets:
#               beam += get_synset_relatedness(synset, target_synset)
#               beam += get_synset_relatedness(synset, target_synset)
#     beam = sorted(
#         beam,
#         key=lambda x: x[2],
#         reverse=True
#     )[:beam_width]
#     return beam


# def find_subject_to_predicate_path(
#       subject_synset: wn.synset,
#       predicate_synset: wn.synset,
#       max_depth:int,
#       visited=set(),
#       max_sample_size=5,
#     ):
#     """Find path from subject (noun) to predicate (verb)."""
#     if subject_synset.name() in visited or predicate_synset.name() in visited:
#       return None

#     paths = []
#     print()
#     print(f"Finding path from {subject_synset.name()} to {predicate_synset.name()}")

#     # Strategy 1: Look for active subjects in verb's gloss
#     pred_gloss_doc = nlp(predicate_synset.definition())
#     # passive subjects are semantically equivalent to objects
#     active_subjects, _ = extract_subjects_from_gloss(pred_gloss_doc)
#     # convert spacy tokens to lists of synsets
#     subjects = [wn.synsets(s.text, pos=subject_synset.pos()) for s in active_subjects]
#     # of the remaining subjects, get the most similar
#     top_k = get_top_k_synset_branches(active_subjects[:max_sample_size], subject_synset)
#     if top_k:
#       print(f"Found best matches for {subject_synset.name()}: {top_k} using strategy 1")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(subject_synset, matched_synset, max_depth-1)
#         if path:
#             paths.append(path + [predicate_synset])

#     # Strategy 2: Look for verbs in the noun's gloss
#     subj_gloss_doc = nlp(subject_synset.definition())
#     verbs = extract_verbs_from_gloss(subj_gloss_doc, include_passive=False)
#     # convert spacy tokens to lists of synsets
#     verbs = [wn.synsets(v.text, pos=predicate_synset.pos()) for v in verbs]
#     # of the remaining subjects, get the most similar
#     top_k = get_top_k_synset_branches(verbs[:max_sample_size], predicate_synset)
#     if top_k:
#       print(f"Found best matches for {predicate_synset.name()}: {top_k} using strategy 2")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(matched_synset, predicate_synset, max_depth-1)
#         if path:
#             paths.append([subject_synset] + path)

#     # Strategy 3: Explore the 3 most promising pairs of neighbors
#     subject_neighbors = get_all_neighbors(subject_synset)
#     predicate_neighbors = get_all_neighbors(predicate_synset)
#     top_k = get_k_closest_synset_pairs(subject_neighbors, predicate_neighbors)
#     if top_k:
#       print(f"Most promising pairs for bidirectional exploration: {top_k} using strategy 3")
#       for s, p, _ in top_k:
#         visited.add(subject_synset.name())
#         visited.add(predicate_synset.name())
#         path = find_subject_to_predicate_path(s, p, max_depth-1, visited)
#         if path:
#             paths.append([subject_synset] + path + [predicate_synset])


#     # Return shortest path if any found
#     return min(paths, key=len) if paths else None


# def find_predicate_to_object_path(
#       predicate_synset: wn.synset,
#       object_synset: wn.synset,
#       max_depth:int,
#       visited=set(),
#       max_sample_size=5,
#     ):
#     """Find path from predicate (verb) to object (noun)."""

#     if predicate_synset.name() in visited or object_synset.name() in visited:
#       return None

#     paths = []
#     print()
#     print(f"Finding path from {predicate_synset.name()} to {object_synset.name()}")

#     # === Strategy 1: Objects in predicate gloss (incl. passive subjects) ===
#     pred_gloss_doc = nlp(predicate_synset.definition())
#     objects = extract_objects_from_gloss(pred_gloss_doc)
#     _, passive_subjects = extract_subjects_from_gloss(pred_gloss_doc)
#     objects.extend(passive_subjects)
#     # convert spacy tokens to lists of synsets
#     objects = [wn.synsets(o.text, pos=object_synset.pos()) for o in objects]
#     top_k = get_top_k_synset_branches(objects[:max_sample_size], object_synset)
#     if top_k:
#       print(f"Found best matches for {object_synset.name()}: {top_k} using strategy 1")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(matched_synset, object_synset, max_depth-1)
#         if path:
#             paths.append([predicate_synset] + path)

#     # === Strategy 2: Verbs in object's gloss ===
#     obj_gloss_doc = nlp(object_synset.definition())
#     verbs = extract_verbs_from_gloss(obj_gloss_doc, include_passive=True)
#     # Use instrumental verbs in object's gloss as backup
#     verbs.extend(find_instrumental_verbs(obj_gloss_doc))
#     # convert spacy tokens to lists of synsets
#     verbs = [wn.synsets(v.text, pos=predicate_synset.pos()) for v in verbs]
#     top_k = get_top_k_synset_branches(verbs[:max_sample_size], predicate_synset)
#     if top_k:
#       print(f"Found best matches for {predicate_synset.name()}: {top_k} using strategy 2")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(predicate_synset, matched_synset, max_depth-1)
#         if path:
#             paths.append(path + [object_synset])

#     # Strategy 3: Explore the 3 most promising neighbors
#     predicate_neighbors = get_all_neighbors(predicate_synset)
#     object_neighbors = get_all_neighbors(object_synset)
#     top_k = get_k_closest_synset_pairs(predicate_neighbors, object_neighbors)
#     if top_k:
#       print(f"Most promising pairs for bidirectional exploration: {top_k} using strategy 3")
#       for p, o, _ in top_k:
#         visited.add(predicate_synset.name())
#         visited.add(object_synset.name())
#         path = find_predicate_to_object_path(p, o, max_depth-1, visited)
#         if path:
#             paths.append([predicate_synset] + path + [object_synset])


#     # Return shortest path if any found
#     return min(paths, key=len) if paths else None


# # ============================================================================
# # Main Connected Path Finding Function
# # ============================================================================

# def find_connected_shortest_paths(
#       subject_word,
#       predicate_word,
#       object_word,
#       max_depth=10,
#       max_self_intersection=5
#     ):
#     """
#     Find shortest connected paths from subject through predicate to object.
#     Ensures that the same predicate synset connects both paths.
#     """

#     # Get synsets for each word
#     subject_synsets = wn.synsets(subject_word, pos=wn.NOUN)
#     predicate_synsets = wn.synsets(predicate_word, pos=wn.VERB)
#     object_synsets = wn.synsets(object_word, pos=wn.NOUN)

#     best_combined_path_length = float('inf')
#     best_subject_path = None
#     best_object_path = None
#     best_predicate = None

#     # Try each predicate synset as the connector
#     for pred in predicate_synsets:
#         # Find paths from all subjects to this specific predicate
#         subject_paths = []
#         for subj in subject_synsets:
#             path = find_subject_to_predicate_path(subj, pred, max_depth)
#             if path:
#                 subject_paths.append(path)

#         # Find paths from this specific predicate to all objects
#         object_paths = []
#         for obj in object_synsets:
#             path = find_predicate_to_object_path(pred, obj, max_depth)
#             if path:
#                 object_paths.append(path)

#         # If we have both paths through this predicate, check if it's the best
#         if subject_paths and object_paths:
#             # find pairs of paths that don't intersect with eachother
#             #   i.e. burglar > break_in > attack > strike > shoot > strike > attack > woman
#             #   would not be allowed, since tautological statements are uninformative
#             valid_pairs = list()
#             for subj_path in subject_paths:
#               for obj_path in object_paths:
#                 if len(set(subj_path).intersection(set(obj_path))) <= max_self_intersection:
#                   valid_pairs.append((
#                       subj_path,
#                       obj_path,
#                       # Calculate combined length (subtract 1 to avoid counting predicate twice)
#                       len(subj_path) + len(obj_path) - 1
#                   ))

#             if not valid_pairs:
#               print(f"No valid pairs of subj, obj paths found for {pred.name()}")
#               break

#             shortest_comb_path = min(valid_pairs, key=lambda x: x[2])

#             if shortest_comb_path[2] < best_combined_path_length:
#                 best_combined_path_length = shortest_comb_path[2]
#                 best_subject_path = shortest_comb_path[0]
#                 best_object_path = shortest_comb_path[1]
#                 best_predicate = pred

#     return best_subject_path, best_object_path, best_predicate


# # ============================================================================
# # Display Functions
# # ============================================================================

# def show_path(label, path):
#     """Pretty print a path of synsets."""
#     if path:
#         print(f"{label}:")
#         print(" -> ".join(f"{s.name()} ({s.definition()})" for s in path))
#         print(f"Path length: {len(path)}")
#         print()
#     else:
#         print(f"{label}: No path found")
#         print()


# def show_connected_paths(subject_path, object_path, predicate):
#     """Display the connected paths with their shared predicate."""
#     if subject_path and object_path and predicate:
#         print("=" * 70)
#         print(f"CONNECTED PATH through predicate: {predicate.name()}")
#         print("=" * 70)

#         show_path("Subject -> Predicate path", subject_path)
#         show_path("Predicate -> Object path", object_path)

#         # Show the complete connected path
#         complete_path = subject_path + object_path[1:]  # Avoid duplicating the predicate
#         print("Complete connected path:")
#         print(" -> ".join(f"{s.name()}" for s in complete_path))
#         print(f"Total path length: {len(complete_path)}")
#         print()
#     else:
#         print("No connected path found through any predicate synset.")


# Testing

In [33]:
g = wn_to_nx()

## Example

In [40]:
# ------------------------------
# Example call of find_connected_paths
# ------------------------------
results = find_connected_paths(
    g=g,
    subject_word="dog",
    predicate_word="chase",
    object_word="cat",
    wn_module=wn,
    nlp_func=nlp,
    get_new_beams_fn=lambda G, s, t: get_new_beams_from_embeddings(
        G, s, t, wn_module=wn, model=embedding_model, beam_width=4
    ),
    top_k_branch_fn=lambda candidates, target, bw: top_k_branch_wrapper(
        candidates, target, beam_width=bw, model=embedding_model, wn_module=wn
    ),
    extract_subjects_from_gloss=extract_subjects_from_gloss,
    extract_objects_from_gloss=extract_objects_from_gloss,
    extract_verbs_from_gloss=extract_verbs_from_gloss,  # This parameter is now included
    beam_width=10,
    max_depth=10,
    max_self_intersection=5,
    max_results_per_pair=10,
    len_tolerance=5,
    relax_beam=True,
    max_sample_size=5  # Also added this parameter for controlling gloss token sampling
)

# Print the top results
for idx, r in enumerate(results[:10], start=1):
    print(f"Result #{idx}")
    print(" Predicate synset:", r["predicate_synset"])
    print(" Combined length:", r["combined_len"])
    print(" Combined cost:", r["combined_cost"])
    print(" Combined path:")
    print("  -> ".join(r["combined_path"]))
    print("-" * 60)

TypeError: find_connected_paths() got an unexpected keyword argument 'extract_verbs_from_gloss'

In [35]:
print(results)

[]


In [19]:
diagnostic_run(
    g=g,
    src_name="dog.n.01",
    tgt_name="cat.n.01",
    wn_module=wn,
    nlp_func=nlp,
    embedding_model=embedding_model,
    get_new_beams_fn_wrapped=lambda G,s,t: get_new_beams_from_embeddings(G,s,t, wn_module=wn, model=embedding_model, beam_width=6),
    build_gloss_seed_nodes_fn=lambda pred, wn_mod, nlpfn: build_gloss_seed_nodes_from_predicate(pred, wn_mod, nlpfn,
                                                                                              extract_subjects_fn=globals().get('extract_subjects_from_gloss'),
                                                                                              extract_objects_fn=globals().get('extract_objects_from_gloss'),
                                                                                              extract_verbs_fn=globals().get('extract_verbs_from_gloss'),
                                                                                              top_k_branch_fn=lambda cand, target, bw: top_k_branch_wrapper(cand, target, bw, model=embedding_model, wn_module=wn),
                                                                                              mode="subjects",
                                                                                              max_sample_size=6,
                                                                                              beam_width=6),
)


=== Diagnostic start ===
src: dog.n.01
tgt: cat.n.01

1) Node presence checks:
 g has src node? True
 g has tgt node? True

2) Connectivity check using networkx:
 Undirected connectivity between src and tgt? True
 Example undirected shortest path (hop count): 3 hops. path sample: ['dog.n.01', 'domestic_animal.n.01', 'domestic_cat.n.01', 'cat.n.01']

3) Node neighborhood:
 src out-neighbors (sample): ['canis.n.01', 'pack.n.06', 'flag.n.07', 'domestic_animal.n.01', 'canine.n.02', 'corgi.n.01', 'leonberg.n.01', 'cur.n.01', 'poodle.n.01', 'pug.n.01']
 tgt in-neighbors (sample): ['feline.n.01', 'domestic_cat.n.01', 'wildcat.n.03']

4) Embedding beam seeding (get_new_beams_fn):
 beams count: 6
[(('canine.n.02', 'hypernyms'),
  ('feline.n.01', 'hypernyms'),
  0.6772804990703681),
 (('domestic_animal.n.01', 'hypernyms'),
  ('domestic_cat.n.01', 'hyponyms'),
  0.6423323762382909),
 (('pooch.n.01', 'hyponyms'), ('feline.n.01', 'hypernyms'), 0.634097330718018),
 (('canine.n.02', 'hypernyms'),
  (

## Diagnostics

In [17]:
# Debug / diagnostic script
from pprint import pprint
import networkx as nx

def diagnostic_run(
    g, src_name, tgt_name, wn_module, nlp_func, embedding_model,
    get_new_beams_fn_wrapped, build_gloss_seed_nodes_fn,
    pairwise_class=PairwiseBidirectionalAStar
):
    print("=== Diagnostic start ===")
    print("src:", src_name)
    print("tgt:", tgt_name)
    print()

    # 1) Node presence
    print("1) Node presence checks:")
    print(" g has src node?", g.has_node(src_name))
    print(" g has tgt node?", g.has_node(tgt_name))
    if not g.has_node(src_name) or not g.has_node(tgt_name):
        print(" -> Node name mismatch likely. Check g.nodes() samples:")
        print("  sample nodes:", list(g.nodes)[:20])
        print("Stopping diagnostics early (node mismatch).")
        return

    # 2) Connectivity check (fast)
    print("\n2) Connectivity check using networkx:")
    try:
        has_path = nx.has_path(g.to_undirected(), src_name, tgt_name)
    except Exception:
        # graceful fallback for directed graphs if conversion fails
        has_path = nx.has_path(g, src_name, tgt_name) if nx.is_directed(g) else False
    print(" Undirected connectivity between src and tgt?", has_path)
    if not has_path:
        print(" -> Graph disconnected between src and tgt. Try increasing graph connectivity or check edges.")
    else:
        try:
            sp = nx.shortest_path(g.to_undirected(), src_name, tgt_name)
            print(" Example undirected shortest path (hop count):", len(sp)-1, "hops. path sample:", sp[:10])
        except Exception:
            pass

    # 3) print node degrees and first neighbors
    print("\n3) Node neighborhood:")
    try:
        nbrs_src = list(g.neighbors(src_name))[:10]
    except Exception:
        nbrs_src = []
    try:
        preds_tgt = list(g.predecessors(tgt_name))[:10]
    except Exception:
        preds_tgt = []
    print(" src out-neighbors (sample):", nbrs_src)
    print(" tgt in-neighbors (sample):", preds_tgt)

    # 4) Beam seeding via embeddings
    print("\n4) Embedding beam seeding (get_new_beams_fn):")
    try:
        beams = get_new_beams_fn_wrapped(g, src_name, tgt_name)
        print(" beams count:", len(beams))
        pprint(beams[:10])
    except Exception as e:
        print(" get_new_beams_fn raised exception:", e)
        beams = []

    # 5) Gloss seed extraction for predicate (if available)
    print("\n5) Gloss seeds (from predicate gloss):")
    try:
        # attempt to get a predicate synset if tgt is a predicate; otherwise try wn.synset(tgt)
        try:
            pred_syn = wn_module.synset(tgt_name)
        except Exception:
            pred_syn = None
        seeds = set()
        if pred_syn:
            seeds = build_gloss_seed_nodes_fn(pred_syn, wn_module, nlp_func)
        print(" gloss seeds count:", len(seeds))
        print(" seeds sample:", list(seeds)[:10])
    except Exception as e:
        print(" gloss seeding raised exception:", e)
        seeds = set()

    # 6) Run pairwise search with relaxed settings (wider beam, relax_beam True)
    print("\n6) Run PairwiseBidirectionalAStar with relaxed settings (beam_width=10, relax_beam=True, max_depth=10)")
    try:
        search = pairwise_class(
            g=g,
            src=src_name,
            tgt=tgt_name,
            get_new_beams_fn=lambda G, s, t: get_new_beams_fn_wrapped(G, s, t),
            gloss_seed_nodes=seeds,
            beam_width=10,
            max_depth=10,
            relax_beam=True
        )
        paths = search.find_paths(max_results=5, len_tolerance=2)
        print(" found paths (count):", len(paths))
        for p, cost in paths:
            print(" cost:", cost, " hops:", len(p)-1, " path:", p)
    except Exception as e:
        print(" pairwise search raised exception:", e)

    print("\n=== Diagnostic end ===")
