<a href="https://colab.research.google.com/github/IsaacFigNewton/SMIED/blob/main/BFS_Semantic_Decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Config

In [1]:
!pip install gensim

Collecting gensim
  Downloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Downloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.6/26.6 MB[0m [31m50.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
[2K   [90m━━━━━━━━━━━

## Import, config stuff

In [7]:
"""
Semantic decomposition of ("cat", "eats", "mouse") using WordNet + spaCy + depth-limited GBFS.
- Uses spaCy to parse verb synset glosses and detect subject/object dependencies.
- If both subject and object tokens are present, branches directly toward original triple synsets.
- Otherwise falls back to WordNet relations.
"""
from typing import Tuple, List, Dict, Optional
import nltk
import spacy
from nltk.corpus import wordnet as wn
from heapq import heappush, heappop
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import gensim.downloader as api
from collections import deque
import heapq
import itertools
from typing import Dict, List, Optional, Set, Tuple, Callable
import networkx as nx

In [2]:
# Load spaCy English model for dependency parsing
nlp = spacy.load("en_core_web_sm")

# Download required NLTK data (run once); each corpus is requested exactly once
nltk.download('wordnet')
nltk.download('omw-1.4')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [3]:
# Load the pretrained 'word2vec-google-news-300' vectors through gensim's downloader API.
# Helper functions defined in later cells use `embedding_model` as the default value of
# their `model` parameter, so this cell has to run before those definitions.
embedding_model = api.load('word2vec-google-news-300')



# Helpers

In [24]:
def wn_to_nx():
    """
    Build a directed NetworkX graph of WordNet synsets and their relations.

    Every synset yielded by ``wn.all_synsets()`` becomes a node keyed by its
    name (e.g. ``'dog.n.01'``). Synsets that take part in none of the relations
    listed below are still added as isolated nodes, so graph look-ups such as
    ``G.neighbors(name)`` work for any synset name.

    Each relation instance becomes an edge ``synset -> related synset`` whose
    ``relation`` attribute is the relation name with its last character
    removed (e.g. ``'hypernyms'`` -> ``'hypernym'``).

    Returns:
        nx.DiGraph: the synset graph. A DiGraph keeps a single edge per ordered
        node pair, so when two relations link the same pair the one processed
        later determines the stored ``relation`` attribute.
    """
    # Initialize directed graph
    G = nx.DiGraph()

    # Names of the Synset accessor methods that are turned into edges
    synset_rels = (
        # holonyms
        "part_holonyms",
        "substance_holonyms",
        "member_holonyms",

        # meronyms
        "part_meronyms",
        "substance_meronyms",
        "member_meronyms",

        # other
        "hypernyms",
        "hyponyms",
        "entailments",
        "causes",
        "also_sees",
        "verb_groups",
    )

    # add nodes (synsets) and all their edges (lexical relations) to the nx graph
    for synset in wn.all_synsets():
        source = synset.name()
        # register the node explicitly so synsets without relations are not lost
        G.add_node(source)
        for rel_name in synset_rels:
            for target in getattr(synset, rel_name)():
                G.add_edge(
                    source,
                    target.name(),
                    relation = rel_name[:-1]
                )
    return G

In [8]:
def get_all_neighbors(synset: wn.synset) -> List[wn.synset]:
    """
    Collect the synsets directly related to ``synset``.

    The result lists the hypernyms first, then the hyponyms, then the
    POS-specific relations: noun relations when ``synset.pos()`` is ``'n'``,
    verb relations for every other part of speech.
    """
    # Taxonomic links exist regardless of part of speech
    related = [*synset.hypernyms(), *synset.hyponyms()]

    # Anything that is not a noun is handled by the verb helper
    if synset.pos() == 'n':
        pos_specific = get_noun_neighbors(synset)
    else:
        pos_specific = get_verb_neighbors(synset)

    related += pos_specific
    return related


def get_noun_neighbors(syn: wn.synset):
    """
    Return the de-duplicated part/substance/member meronyms and holonyms
    of a noun synset as a list.
    """
    accessors = (
        syn.part_meronyms,
        syn.substance_meronyms,
        syn.member_meronyms,
        syn.part_holonyms,
        syn.substance_holonyms,
        syn.member_holonyms,
    )
    unique = set()
    for fetch in accessors:
        unique.update(fetch())
    return list(unique)


def get_verb_neighbors(syn: wn.synset):
    """
    Return the de-duplicated entailments, causes, also-sees and verb-group
    members of a verb synset as a list.
    """
    accessors = (
        syn.entailments,
        syn.causes,
        syn.also_sees,
        syn.verb_groups,
    )
    unique = set()
    for fetch in accessors:
        unique.update(fetch())
    return list(unique)

# Embedding Helpers

In [36]:
def get_synset_embedding_centroid(synset:wn.synset, model=embedding_model):
    """
    Average the embeddings of a synset's lemmas into a single centroid vector.

    Each lemma name is lower-cased and its underscores are turned into spaces.
    A vector is then looked up in this order:
      1. the lemma exactly as normalised,
      2. the lemma with spaces replaced by underscores,
      3. for multi-word lemmas, the mean of the vectors of those individual
         words that the model contains.
    Lemmas for which none of the look-ups succeeds are skipped.

    Args:
        synset: WordNet Synset object (e.g., 'dog.n.01')
        model: embedding lookup supporting ``key in model`` and ``model[key]``

    Returns:
        numpy array holding the mean of the per-lemma vectors, or an empty
        numpy array when no lemma could be embedded or an error occurred
        (the error is printed, not raised).
    """
    def _vector_for(phrase):
        # exact match first
        if phrase in model:
            return model[phrase]
        # multi-word terms may be stored with underscores
        underscored = phrase.replace(' ', '_')
        if underscored in model:
            return model[underscored]
        # last resort: average whichever individual words are known
        if ' ' in phrase:
            pieces = [model[word] for word in phrase.split() if word in model]
            if pieces:
                return np.mean(pieces, axis=0)
        return None

    try:
        lemma_vectors = []
        for lemma in synset.lemmas():
            vector = _vector_for(lemma.name().lower().replace('_', ' '))
            if vector is not None:
                lemma_vectors.append(vector)

        if len(lemma_vectors) == 0:
            return np.array([])

        # centroid = mean over all lemma vectors
        return np.mean(lemma_vectors, axis=0)

    except Exception as e:
        print(f"Error processing synset {synset.name()}: {e}")
        return np.array([]) # Return empty array on error


def embed_lexical_relations(synset: wn.synset, model) -> Dict[str, List[Tuple[str, np.ndarray]]]:
    """
    Embed every synset reachable from ``synset`` through one lexical relation.

    Args:
        synset: WordNet Synset object (e.g., 'dog.n.01')
        model: embedding model handed to get_synset_embedding_centroid for each
            related synset; ``None`` keeps that function's default model.

    Returns:
        Dict mapping each relation name (e.g. 'hypernyms') to a list of
        ``(related_synset_name, centroid)`` tuples. Related synsets whose
        centroid is empty are left out, and a relation that fails to evaluate
        maps to an empty list (the error is printed, not raised).
    """
    relation_names = (
        # holonyms
        "part_holonyms",
        "substance_holonyms",
        "member_holonyms",

        # meronyms
        "part_meronyms",
        "substance_meronyms",
        "member_meronyms",

        # other
        "hypernyms",
        "hyponyms",
        "entailments",
        "causes",
        "also_sees",
        "verb_groups",
    )

    def _centroid(related_synset) -> np.ndarray:
        # honour the caller's model instead of always using the default one
        if model is None:
            return get_synset_embedding_centroid(related_synset)
        return get_synset_embedding_centroid(related_synset, model)

    def _rel_centroids(rel_name: str) -> List[Tuple[str, np.ndarray]]:
        try:
            centroids = [
                (s.name(), _centroid(s))
                for s in getattr(synset, rel_name)()
            ]
            # Drop related synsets that could not be embedded
            return [c for c in centroids if c[1].size > 0]
        except Exception as e:
            print(f"Error processing relation for synset {synset.name()}: {e}")
            # keep the declared return type: an empty list, not an array
            return []

    return {rel_name: _rel_centroids(rel_name) for rel_name in relation_names}


def get_embedding_similarities(
      rel_embs_1: List[Tuple[str, np.ndarray]],
      rel_embs_2: List[Tuple[str, np.ndarray]],
    ) -> np.ndarray:
    """
    Cosine similarity of every pairing between two lists of named embeddings.

    Args:
        rel_embs_1: list of ``(name, vector)`` tuples (m entries)
        rel_embs_2: list of ``(name, vector)`` tuples (n entries)

    Returns:
        Array of shape (m, n) where entry ``[i, j]`` is the cosine similarity
        of ``rel_embs_1[i]`` and ``rel_embs_2[j]``. A zero-length vector is
        given similarity 0 to everything (instead of NaN), and an empty input
        list produces an empty (m, n) array.
    """
    # Nothing to compare on at least one side: empty matrix of the right shape
    if len(rel_embs_1) == 0 or len(rel_embs_2) == 0:
        return np.zeros((len(rel_embs_1), len(rel_embs_2)))

    def _unit_rows(rel_embs):
        # stack the vectors into a rank 2 tensor, one row per embedding
        rows = np.array([e[1] for e in rel_embs])
        norms = np.linalg.norm(rows, axis=1, keepdims=True)
        # dividing an all-zero row by 1 leaves it zero and avoids 0/0 -> NaN
        return rows / np.where(norms > 0, norms, 1)

    # Dot product of unit vectors == cosine similarity; result shape (m, n)
    return np.dot(_unit_rows(rel_embs_1), _unit_rows(rel_embs_2).T)


def get_top_k_aligned_lex_rel_pairs(
      src_tgt_rel_map:Dict[str, str],
      src_emb_dict:Dict[str, List[Tuple[str, np.ndarray]]],
      tgt_emb_dict:Dict[str, List[Tuple[str, np.ndarray]]],
      model=embedding_model,
      beam_width:int=-1
    ) -> List[Tuple[
        Tuple[str, str],
        Tuple[str, str],
        float
    ]]:
    """
    Get the k most similar symmetric/asymmetric lexical relationship pairs
      based on the relations' associated synset embeddings.

    Args:
        src_tgt_rel_map: maps a relation name of the source synset to the
            relation name of the target synset it is compared against
        src_emb_dict: relation name -> list of (synset_name, embedding) for the source
        tgt_emb_dict: relation name -> list of (synset_name, embedding) for the target
        model: unused here; kept so existing callers keep working
        beam_width: maximum number of pairs to return; a negative value
            (the default) or None means "no limit"

    Returns a list, sorted by descending similarity, of tuples of the form:
      (
        (synset1, lexical_rel),
        (synset2, lexical_rel),
        relatedness
      )
    """

    # check similarity of relations
    #   i.e. similarity of synset1's meronyms to synset2's meronyms
    rel_sims = list()
    for e1_rel, e2_rel in src_tgt_rel_map.items():
        # get embedding lists for e1_rel and e2_rel in the associated embedding dicts
        e1_rel_syn_embs = src_emb_dict.get(e1_rel)
        e2_rel_syn_embs = tgt_emb_dict.get(e2_rel)

        # skip relations that are missing from either dict or have no embeddings
        if e1_rel_syn_embs is None or e2_rel_syn_embs is None:
            continue
        if len(e1_rel_syn_embs) == 0 or len(e2_rel_syn_embs) == 0:
            continue

        # similarities of all possible pairings, shape (len(e1), len(e2))
        sims = get_embedding_similarities(e1_rel_syn_embs, e2_rel_syn_embs)
        # relate each source neighbor to each target neighbor (row-major order)
        for i, src_entry in enumerate(e1_rel_syn_embs):
            for j, tgt_entry in enumerate(e2_rel_syn_embs):
                rel_sims.append(
                    (
                        (src_entry[0], e1_rel),
                        (tgt_entry[0], e2_rel),
                        sims[i, j]
                    )
                )

    # it doesn't matter which side contains/includes/etc. the other,
    #   as long as they're closer than antonyms or unrelated terms
    # i.e. a good hyponym-hypernym pair is just as important
    #   as a good hypernym-hyponym pair
    ranked = sorted(
        rel_sims,
        key=lambda x: x[2],
        reverse=True
    )
    # a negative width must not be used as a slice bound: [:-1] would silently
    #   drop the lowest-ranked pair instead of returning everything
    if beam_width is None or beam_width < 0:
        return ranked
    return ranked[:beam_width]

# Beam Construction

In [21]:
# Pairs each relation of the source synset with its INVERSE relation on the
# target synset (e.g. the source's hypernyms against the target's hyponyms).
# Every entry must be mirrored: map[a] == b implies map[b] == a.
asymmetric_pairs_map = {
    # holonyms
    "part_holonyms": "part_meronyms",
    "substance_holonyms": "substance_meronyms",
    "member_holonyms": "member_meronyms",

    # meronyms
    "part_meronyms": "part_holonyms",
    "substance_meronyms": "substance_holonyms",
    "member_meronyms": "member_holonyms",

    # other
    "hypernyms": "hyponyms",
    "hyponyms": "hypernyms"
}


# Pairs each relation of the source synset with the SAME relation on the
# target synset (e.g. the source's hypernyms against the target's hypernyms).
symmetric_pairs_map = {
    relation: relation
    for relation in (
        # holonyms
        "part_holonyms",
        "substance_holonyms",
        "member_holonyms",

        # meronyms
        "part_meronyms",
        "substance_meronyms",
        "member_meronyms",

        # other
        "hypernyms",
        "hyponyms",
        "entailments",
        "causes",
        "also_sees",
        "verb_groups",
    )
}

In [28]:
def get_new_beams(
      g: nx.DiGraph,
      src: str,
      tgt: str,
      model=embedding_model,
      beam_width=3
    ) -> List[Tuple[
        Tuple[str, str],
        Tuple[str, str],
        float
    ]]:
    """
    Get the k closest pairs of lexical relations between 2 synsets.

    Args:
        g: synset graph; the neighbors of `src` and `tgt` in it are checked
            against the relations WordNet reports for them
        src: name of the source synset (e.g., 'dog.n.01')
        tgt: name of the target synset (e.g., 'cat.n.01')
        model: embedding model passed on to the embedding helpers
        beam_width: max number of pairs to return

    Returns:
        List, sorted by descending similarity, of tuples of the form:
          (
            (synset1, lexical_rel),
            (synset2, lexical_rel),
            relatedness
          )

    Raises:
        ValueError: if an embedded relation of `src` or `tgt` points to a
            synset that is not one of that node's neighbors in `g`.
    """

    def _require_graph_coverage(node, rel_embs):
        # every embedded related synset must also be a graph neighbor of `node`
        adjacent = set(g.neighbors(node))
        for rel, synset_list in rel_embs.items():
            if any(s[0] not in adjacent for s in synset_list):
                raise ValueError(f"Not all lexical properties of {node} ({[s[0] for s in synset_list]}) in graph for relation {rel}")

    # Per-relation centroid embeddings of the synsets related to src and tgt
    src_lex_rel_embs = embed_lexical_relations(wn.synset(src), model)
    tgt_lex_rel_embs = embed_lexical_relations(wn.synset(tgt), model)

    # The graph edges have to agree with the embedding maps (src first, then tgt)
    _require_graph_coverage(src, src_lex_rel_embs)
    _require_graph_coverage(tgt, tgt_lex_rel_embs)
    # in the future, get neighbor relation in node metadata with g.adj[n]

    # Rank the inverse-relation pairings first (e.g. src hypernyms vs tgt
    #   hyponyms), then the like-for-like pairings (e.g. src hypernyms vs
    #   tgt hypernyms); each call already returns at most beam_width pairs
    candidates = []
    for rel_map in (asymmetric_pairs_map, symmetric_pairs_map):
        candidates = candidates + get_top_k_aligned_lex_rel_pairs(
            rel_map,
            src_lex_rel_embs,
            tgt_lex_rel_embs,
            model,
            beam_width
        )

    # Keep only the overall best beam_width pairs
    candidates.sort(key=lambda x: x[2], reverse=True)
    return candidates[:beam_width]

# Pathing

In [14]:
# Type aliases
# BeamElement: ((source_node, relation), (target_node, relation), similarity)
BeamElement = Tuple[Tuple[str, str], Tuple[str, str], float]
# GetNewBeamsFn: callable taking (graph, src_node_id, tgt_node_id) and returning a list of BeamElement
GetNewBeamsFn = Callable[[nx.DiGraph, str, str], List[BeamElement]]


class BidirectionalAStar:
    """
    Beam-constrained bidirectional A* search between two nodes of a directed graph.

    A forward search grows from ``src`` along outgoing edges while a backward
    search grows from ``tgt`` along incoming edges; the search stops as soon as
    one side reaches a node the other side has already closed. Unless
    ``relax_beam`` is set, only ``src``, ``tgt`` and the nodes named in the
    beam returned by ``get_new_beams_fn`` may be entered.
    """

    def __init__(
        self,
        g: nx.DiGraph,
        src: str,
        tgt: str,
        get_new_beams_fn: GetNewBeamsFn = get_new_beams,
        beam_width: int = 10,
        relax_beam: bool = False,
    ):
        """
        Beam-constrained bidirectional A* encapsulated in a class.

        Args:
            g: directed graph (nx.DiGraph)
            src: source node id
            tgt: target node id
            get_new_beams_fn: function(g, src, tgt) -> List[((src_node, rel),(tgt_node, rel), sim)]
            beam_width: forwarded to get_new_beams_fn as the ``beam_width`` keyword
                when that function declares a parameter with this name
            relax_beam: if True, allow exploring nodes outside the beam sets
        """
        self.g = g
        self.src = src
        self.tgt = tgt
        self.get_new_beams_fn = get_new_beams_fn
        self.beam_width = beam_width
        self.relax_beam = relax_beam

        # Will be filled during setup/search
        self.src_allowed: Set[str] = set()
        self.tgt_allowed: Set[str] = set()
        self.h_f: Dict[str, float] = {}
        self.h_b: Dict[str, float] = {}

        # Search state
        self.g_f: Dict[str, float] = {}
        self.g_b: Dict[str, float] = {}
        self.parent_f: Dict[str, Optional[str]] = {}
        self.parent_b: Dict[str, Optional[str]] = {}
        self.open_f = []
        self.open_b = []
        self.closed_f: Set[str] = set()
        self.closed_b: Set[str] = set()
        # tie-breaker for heap entries so node ids are never compared directly
        self._counter = itertools.count()

        # Results
        self.meet_node: Optional[str] = None
        self.path_cost: Optional[float] = None

    # ----------------------
    # Setup / heuristics
    # ----------------------
    def _request_beams(self):
        """
        Call get_new_beams_fn(g, src, tgt), adding ``beam_width=self.beam_width``
        when the callable declares a keyword-capable ``beam_width`` parameter.
        """
        import inspect  # only needed for this signature check

        fn = self.get_new_beams_fn
        try:
            params = inspect.signature(fn).parameters
        except (TypeError, ValueError):
            # no introspectable signature: fall back to the 3-argument call
            params = {}

        width_param = params.get("beam_width")
        keyword_capable = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        if width_param is not None and width_param.kind in keyword_capable:
            return fn(self.g, self.src, self.tgt, beam_width=self.beam_width)
        return fn(self.g, self.src, self.tgt)

    def build_allowed_sets_and_heuristics(self):
        """
        Calls get_new_beams_fn and builds:
          - src_allowed, tgt_allowed (sets of node ids)
          - h_f, h_b : heuristic maps (node -> float)

        The heuristic of a beam node is ``max(0, 1 - sim)``, using the smallest
        such value when the node occurs in several beam elements.
        """
        # beams are expected to be [((src_node, rel),(tgt_node, rel), sim), ...]
        beams = self._request_beams()

        # always allow src and tgt explicitly
        self.src_allowed = {self.src}
        self.tgt_allowed = {self.tgt}
        self.h_f = {}
        self.h_b = {}

        for s_pair, t_pair, sim in beams:
            s_node = s_pair[0]
            t_node = t_pair[0]
            self.src_allowed.add(s_node)
            self.tgt_allowed.add(t_node)

            # higher similarity -> lower estimated remaining cost
            h_val = max(0.0, 1.0 - float(sim))
            if s_node not in self.h_f or h_val < self.h_f[s_node]:
                self.h_f[s_node] = h_val
            if t_node not in self.h_b or h_val < self.h_b[t_node]:
                self.h_b[t_node] = h_val

    # ----------------------
    # Initialization
    # ----------------------
    def _init_search_state(self):
        """Initialize open queues, g-scores, parents, closed sets and counters."""
        self._counter = itertools.count()
        self.open_f = []
        self.open_b = []
        self.g_f = {self.src: 0.0}
        self.g_b = {self.tgt: 0.0}
        self.parent_f = {self.src: None}
        self.parent_b = {self.tgt: None}
        self.closed_f = set()
        self.closed_b = set()

        # heap entries are (f_score, insertion_counter, node)
        heapq.heappush(self.open_f, (self.h_f.get(self.src, 0.0), next(self._counter), self.src))
        heapq.heappush(self.open_b, (self.h_b.get(self.tgt, 0.0), next(self._counter), self.tgt))

    # ----------------------
    # Utilities
    # ----------------------
    def _edge_weight(self, u: str, v: str) -> float:
        """Return the 'weight' attribute of edge u->v as a float, or 1.0 when it is absent or unusable."""
        data = self.g.get_edge_data(u, v, default={})
        try:
            return float(data.get("weight", 1.0))
        except Exception:
            # deliberate best-effort: any malformed weight counts as a unit step
            return 1.0

    def _allowed(self, node: str) -> bool:
        """True when the beam is relaxed or `node` is src, tgt or part of either beam set."""
        return (
            self.relax_beam
            or node in self.src_allowed
            or node in self.tgt_allowed
            or node == self.src
            or node == self.tgt
        )

    def _allowed_forward(self, node: str) -> bool:
        """Whether the forward search may enter `node`."""
        return self._allowed(node)

    def _allowed_backward(self, node: str) -> bool:
        """Whether the backward search may enter `node`."""
        return self._allowed(node)

    # ----------------------
    # Expansion methods
    # ----------------------
    def _expand(self, forward: bool) -> Optional[str]:
        """
        Expand one not-yet-closed node of the chosen direction.

        Stale heap entries (nodes already closed on this side) are discarded
        until a fresh node is found. Returns the meeting node as soon as the
        expanded node, or a neighbor whose cost was just improved, is already
        closed by the opposite search; otherwise returns None.
        """
        if forward:
            open_heap, closed, other_closed = self.open_f, self.closed_f, self.closed_b
            g_score, parent, heuristic = self.g_f, self.parent_f, self.h_f
            step_nodes, allowed = self.g.neighbors, self._allowed_forward
        else:
            open_heap, closed, other_closed = self.open_b, self.closed_b, self.closed_f
            g_score, parent, heuristic = self.g_b, self.parent_b, self.h_b
            step_nodes, allowed = self.g.predecessors, self._allowed_backward

        while open_heap:
            _, _, current = heapq.heappop(open_heap)
            if current in closed:
                continue
            closed.add(current)

            if current in other_closed:
                return current

            for nbr in step_nodes(current):
                if not allowed(nbr):
                    continue
                # the graph edge always runs src-side -> tgt-side
                if forward:
                    step = self._edge_weight(current, nbr)
                else:
                    step = self._edge_weight(nbr, current)
                tentative = g_score[current] + step
                if tentative < g_score.get(nbr, float("inf")):
                    g_score[nbr] = tentative
                    parent[nbr] = current
                    f_score = tentative + heuristic.get(nbr, 0.0)
                    heapq.heappush(open_heap, (f_score, next(self._counter), nbr))
                    if nbr in other_closed:
                        return nbr
            # we expanded one node; stop to let the other side move if needed
            return None
        return None

    def _expand_forward(self) -> Optional[str]:
        """
        Pop one entry from open_f, expand its outgoing neighbors.
        Return meeting node if found (node in closed_b), else None.
        """
        return self._expand(forward=True)

    def _expand_backward(self) -> Optional[str]:
        """
        Pop one entry from open_b, expand its predecessors (backward search).
        Return meeting node if found, else None.
        """
        return self._expand(forward=False)

    # ----------------------
    # Path reconstruction
    # ----------------------
    def _reconstruct_path(self, meet: str) -> List[str]:
        """
        Build path: src -> ... -> meet -> ... -> tgt using parent_f and parent_b.
        parent_f maps node -> predecessor (toward src),
        parent_b maps node -> successor (toward tgt).
        """
        # forward: walk from meet back to src, then flip
        path_f: List[str] = []
        n = meet
        while n is not None:
            path_f.append(n)
            n = self.parent_f.get(n)
        path_f.reverse()  # src -> ... -> meet

        # backward: nodes after meet toward tgt
        path_b: List[str] = []
        n = self.parent_b.get(meet)
        while n is not None:
            path_b.append(n)
            n = self.parent_b.get(n)

        return path_f + path_b

    # ----------------------
    # Main search
    # ----------------------
    def search(self) -> Optional[List[str]]:
        """
        Run the bidirectional beam-constrained A*. Returns path (list of node ids) or None.
        Also sets .meet_node and .path_cost when a path is found; both are
        reset to None at the start of every call, so a failed search never
        reports the result of an earlier one.
        """
        self.meet_node = None
        self.path_cost = None

        if self.src == self.tgt:
            self.meet_node = self.src
            self.path_cost = 0.0
            return [self.src]

        # prepare heuristics and allowed sets
        self.build_allowed_sets_and_heuristics()
        self._init_search_state()

        unreachable = float("inf")
        while self.open_f and self.open_b:
            # advance whichever side currently has the smaller best f-score
            if self.open_f[0][0] <= self.open_b[0][0]:
                meet = self._expand_forward()
            else:
                meet = self._expand_backward()

            if meet is None:
                continue

            self.meet_node = meet
            path = self._reconstruct_path(meet)
            cost_f = self.g_f.get(meet, unreachable)
            cost_b = self.g_b.get(meet, unreachable)
            # The cost is taken at the first meeting node; no attempt is made to
            # minimise g_f[n] + g_b[n] over all nodes known to both sides.
            if cost_f == unreachable or cost_b == unreachable:
                self.path_cost = None
            else:
                self.path_cost = cost_f + cost_b
            return path

        # no path found
        return None


In [15]:
# # ============================================================================
# # Core Path Finding Functions
# # ============================================================================

# def path_syn_to_syn(start_synset, end_synset, max_depth=6):
#     """
#     Find shortest path between synsets of the same POS using bidirectional BFS.
#     Returns a list of synsets forming the path, or None if no path found.
#     """

#     if not (start_synset.pos() == end_synset.pos() and start_synset.pos() in {'n', 'v'}):
#       raise ValueError(f"{start_synset.name()} POS tag != {end_synset.name()}. Synsets must be of the same POS (noun or verb).")

#     # Handle the trivial case where start and end are the same
#     if start_synset.name() == end_synset.name():
#         return [start_synset]

#     # Initialize two search frontiers
#     forward_queue = deque([(start_synset, 0)])
#     forward_visited = {start_synset.name(): [start_synset]}

#     backward_queue = deque([(end_synset, 0)])
#     backward_visited = {end_synset.name(): [end_synset]}

#     def expand_frontier(queue, visited_from_this_side, visited_from_other_side, is_forward):
#         """Expand one step of the search frontier."""
#         if not queue:
#             return None

#         curr_synset, depth = queue.popleft()

#         if depth >= (max_depth + 1) // 2:
#             return None

#         path_to_current = visited_from_this_side[curr_synset.name()]

#         for neighbor in get_all_neighbors(curr_synset):
#             neighbor_name = neighbor.name()

#             if neighbor_name in visited_from_this_side:
#                 continue

#             if is_forward:
#                 new_path = path_to_current + [neighbor]
#             else:
#                 new_path = [neighbor] + path_to_current

#             if neighbor_name in visited_from_other_side:
#                 other_path = visited_from_other_side[neighbor_name]

#                 if is_forward:
#                     full_path = path_to_current + other_path
#                 else:
#                     full_path = other_path + path_to_current

#                 return full_path

#             visited_from_this_side[neighbor_name] = new_path
#             queue.append((neighbor, depth + 1))

#         return None

#     # Alternate between forward and backward search
#     while forward_queue or backward_queue:
#         if forward_queue:
#             result = expand_frontier(forward_queue, forward_visited, backward_visited, True)
#             if result:
#                 return result

#         if backward_queue:
#             result = expand_frontier(backward_queue, backward_visited, forward_visited, False)
#             if result:
#                 return result

#     return None


# # ============================================================================
# # Gloss Analysis Helper Functions
# # ============================================================================

# def extract_subjects_from_gloss(gloss_doc):
#     """Extract subject tokens from a parsed gloss."""
#     subjects = []

#     # Direct subjects
#     subjects.extend([tok for tok in gloss_doc if tok.dep_ == "nsubj"])

#     # Passive subjects (which are actually objects semantically)
#     # Skip these for actor identification
#     passive_subjects = [tok for tok in gloss_doc if tok.dep_ == "nsubjpass"]

#     # Filter out passive subjects from the main list
#     subjects = [s for s in subjects if s not in passive_subjects]

#     return subjects, passive_subjects


# def extract_objects_from_gloss(gloss_doc):
#     """Extract various types of object tokens from a parsed gloss."""
#     objs = []

#     # Indirect objects
#     iobjs = [tok for tok in gloss_doc if tok.dep_ == "iobj"]
#     objs.extend(iobjs)

#     # Direct objects
#     # Only include if there were no indirect objects,
#     #   crude, but good for MVP
#     if not iobjs:
#         objs.extend([tok for tok in gloss_doc if tok.dep_ == "dobj"])

#     # Prepositional objects
#     objs.extend([tok for tok in gloss_doc if tok.dep_ == "pobj"])

#     # General objects
#     objs.extend([tok for tok in gloss_doc if tok.dep_ == "obj"])

#     # Check for noun chunks related to root verb
#     root_verbs = [tok for tok in gloss_doc if tok.dep_ == "ROOT" and tok.pos_ == "VERB"]
#     if root_verbs and not objs:
#         for noun_chunk in gloss_doc.noun_chunks:
#             if any(token.head == root_verbs[0] for token in noun_chunk):
#                 objs.append(noun_chunk.root)

#     return objs


# def extract_verbs_from_gloss(gloss_doc, include_passive=False):
#     """Extract verb tokens from a parsed gloss."""
#     verbs = [tok for tok in gloss_doc if tok.pos_ == "VERB"]

#     if include_passive:
#         # Past participles used as adjectives or in relative clauses
#         passive_verbs = [tok for tok in gloss_doc if
#                         tok.tag_ in ["VBN", "VBD"] and
#                         tok.dep_ in ["acl", "relcl", "amod"]]
#         verbs.extend(passive_verbs)

#     return verbs


# def find_instrumental_verbs(gloss_doc):
#     """Find verbs associated with instrumental use (e.g., 'used for')."""
#     instrumental_verbs = []

#     if "used" in gloss_doc.text.lower():
#         for i, token in enumerate(gloss_doc):
#             if token.text.lower() == "used":
#                 # Check tokens after "used"
#                 for j in range(i+1, min(i+4, len(gloss_doc))):
#                     if gloss_doc[j].pos_ == "VERB":
#                         instrumental_verbs.append(gloss_doc[j])

#     return instrumental_verbs


# # ============================================================================
# # Cross-POS Path Finding Functions
# # ============================================================================
# def get_top_k_synset_branch_pairs(
#       candidates: List[List[wn.synset]],
#       target_synset: wn.synset,
#       beam_width=3
#     ) -> List[Tuple[
#         Tuple[wn.synset, str],
#         Tuple[wn.synset, str],
#         float
#     ]]:
#     """
#     Given a list of candidate tokens and a target synset,


#     Return the k synset subrelation pairs most similar to the target of the form:
#       ((synset, lexical_rel), (name, lexical_rel), relatedness)
#     """
#     top_k_asymm_branches = list()
#     top_k_symm_branches = list()
#     beam = list()
#     # for each list of possible synsets for a candidate token
#     for synsets in candidates:
#         # # filter to subjects based on whether they reside in the same sub-category
#         # #   where the subcategory != 'entity.n.01' or a similar top-level
#         # synsets = [
#         #     s for s in synsets
#         #     if s.root_hypernyms() != s.lowest_common_hypernyms(target_synset)
#         # ]
#         # # if there are synsets left for the candidate token after pruning
#         if synsets:
#             # # if the target is a verb,
#             # #   filter out any synsets with no lemma frames matching the target
#             # #   frame patterns: (Somebody [v] something), (Somebody [v]), ...
#             # if target_synset.pos() == 'v':
#             #     synsets = [
#             #         s for s in synsets
#             #         if any(
#             #             frame in s.frame_ids()
#             #             for frame in target_synset.frame_ids()
#             #         )
#             #     ]
#             for synset in synsets:
#               beam += get_synset_relatedness(synset, target_synset)
#               beam += get_synset_relatedness(synset, target_synset)
#     beam = sorted(
#         beam,
#         key=lambda x: x[2],
#         reverse=True
#     )[:beam_width]
#     return beam


# def find_subject_to_predicate_path(
#       subject_synset: wn.synset,
#       predicate_synset: wn.synset,
#       max_depth:int,
#       visited=set(),
#       max_sample_size=5,
#     ):
#     """Find path from subject (noun) to predicate (verb)."""
#     if subject_synset.name() in visited or predicate_synset.name() in visited:
#       return None

#     paths = []
#     print()
#     print(f"Finding path from {subject_synset.name()} to {predicate_synset.name()}")

#     # Strategy 1: Look for active subjects in verb's gloss
#     pred_gloss_doc = nlp(predicate_synset.definition())
#     # passive subjects are semantically equivalent to objects
#     active_subjects, _ = extract_subjects_from_gloss(pred_gloss_doc)
#     # convert spacy tokens to lists of synsets
#     subjects = [wn.synsets(s.text, pos=subject_synset.pos()) for s in active_subjects]
#     # of the remaining subjects, get the most similar
#     top_k = get_top_k_synset_branches(active_subjects[:max_sample_size], subject_synset)
#     if top_k:
#       print(f"Found best matches for {subject_synset.name()}: {top_k} using strategy 1")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(subject_synset, matched_synset, max_depth-1)
#         if path:
#             paths.append(path + [predicate_synset])

#     # Strategy 2: Look for verbs in the noun's gloss
#     subj_gloss_doc = nlp(subject_synset.definition())
#     verbs = extract_verbs_from_gloss(subj_gloss_doc, include_passive=False)
#     # convert spacy tokens to lists of synsets
#     verbs = [wn.synsets(v.text, pos=predicate_synset.pos()) for v in verbs]
#     # of the remaining verbs, get the most similar
#     top_k = get_top_k_synset_branches(verbs[:max_sample_size], predicate_synset)
#     if top_k:
#       print(f"Found best matches for {predicate_synset.name()}: {top_k} using strategy 2")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(matched_synset, predicate_synset, max_depth-1)
#         if path:
#             paths.append([subject_synset] + path)

#     # Strategy 3: Explore the 3 most promising pairs of neighbors
#     subject_neighbors = get_all_neighbors(subject_synset)
#     predicate_neighbors = get_all_neighbors(predicate_synset)
#     top_k = get_k_closest_synset_pairs(subject_neighbors, predicate_neighbors)
#     if top_k:
#       print(f"Most promising pairs for bidirectional exploration: {top_k} using strategy 3")
#       for s, p, _ in top_k:
#         visited.add(subject_synset.name())
#         visited.add(predicate_synset.name())
#         path = find_subject_to_predicate_path(s, p, max_depth-1, visited)
#         if path:
#             paths.append([subject_synset] + path + [predicate_synset])


#     # Return shortest path if any found
#     return min(paths, key=len) if paths else None


# def find_predicate_to_object_path(
#       predicate_synset: wn.synset,
#       object_synset: wn.synset,
#       max_depth:int,
#       visited=set(),
#       max_sample_size=5,
#     ):
#     """Find path from predicate (verb) to object (noun)."""

#     if predicate_synset.name() in visited or object_synset.name() in visited:
#       return None

#     paths = []
#     print()
#     print(f"Finding path from {predicate_synset.name()} to {object_synset.name()}")

#     # === Strategy 1: Objects in predicate gloss (incl. passive subjects) ===
#     pred_gloss_doc = nlp(predicate_synset.definition())
#     objects = extract_objects_from_gloss(pred_gloss_doc)
#     _, passive_subjects = extract_subjects_from_gloss(pred_gloss_doc)
#     objects.extend(passive_subjects)
#     # convert spacy tokens to lists of synsets
#     objects = [wn.synsets(o.text, pos=object_synset.pos()) for o in objects]
#     top_k = get_top_k_synset_branches(objects[:max_sample_size], object_synset)
#     if top_k:
#       print(f"Found best matches for {object_synset.name()}: {top_k} using strategy 1")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(matched_synset, object_synset, max_depth-1)
#         if path:
#             paths.append([predicate_synset] + path)

#     # === Strategy 2: Verbs in object's gloss ===
#     obj_gloss_doc = nlp(object_synset.definition())
#     verbs = extract_verbs_from_gloss(obj_gloss_doc, include_passive=True)
#     # Use instrumental verbs in object's gloss as backup
#     verbs.extend(find_instrumental_verbs(obj_gloss_doc))
#     # convert spacy tokens to lists of synsets
#     verbs = [wn.synsets(v.text, pos=predicate_synset.pos()) for v in verbs]
#     top_k = get_top_k_synset_branches(verbs[:max_sample_size], predicate_synset)
#     if top_k:
#       print(f"Found best matches for {predicate_synset.name()}: {top_k} using strategy 2")
#       for matched_synset, _ in top_k:
#         path = path_syn_to_syn(predicate_synset, matched_synset, max_depth-1)
#         if path:
#             paths.append(path + [object_synset])

#     # Strategy 3: Explore the 3 most promising neighbors
#     predicate_neighbors = get_all_neighbors(predicate_synset)
#     object_neighbors = get_all_neighbors(object_synset)
#     top_k = get_k_closest_synset_pairs(predicate_neighbors, object_neighbors)
#     if top_k:
#       print(f"Most promising pairs for bidirectional exploration: {top_k} using strategy 3")
#       for p, o, _ in top_k:
#         visited.add(predicate_synset.name())
#         visited.add(object_synset.name())
#         path = find_predicate_to_object_path(p, o, max_depth-1, visited)
#         if path:
#             paths.append([predicate_synset] + path + [object_synset])


#     # Return shortest path if any found
#     return min(paths, key=len) if paths else None


# # ============================================================================
# # Main Connected Path Finding Function
# # ============================================================================

# def find_connected_shortest_paths(
#       subject_word,
#       predicate_word,
#       object_word,
#       max_depth=10,
#       max_self_intersection=5
#     ):
#     """
#     Find shortest connected paths from subject through predicate to object.
#     Ensures that the same predicate synset connects both paths.
#     """

#     # Get synsets for each word
#     subject_synsets = wn.synsets(subject_word, pos=wn.NOUN)
#     predicate_synsets = wn.synsets(predicate_word, pos=wn.VERB)
#     object_synsets = wn.synsets(object_word, pos=wn.NOUN)

#     best_combined_path_length = float('inf')
#     best_subject_path = None
#     best_object_path = None
#     best_predicate = None

#     # Try each predicate synset as the connector
#     for pred in predicate_synsets:
#         # Find paths from all subjects to this specific predicate
#         subject_paths = []
#         for subj in subject_synsets:
#             path = find_subject_to_predicate_path(subj, pred, max_depth)
#             if path:
#                 subject_paths.append(path)

#         # Find paths from this specific predicate to all objects
#         object_paths = []
#         for obj in object_synsets:
#             path = find_predicate_to_object_path(pred, obj, max_depth)
#             if path:
#                 object_paths.append(path)

#         # If we have both paths through this predicate, check if it's the best
#         if subject_paths and object_paths:
#             # find pairs of paths that don't intersect with each other
#             #   i.e. burglar > break_in > attack > strike > shoot > strike > attack > woman
#             #   would not be allowed, since tautological statements are uninformative
#             valid_pairs = list()
#             for subj_path in subject_paths:
#               for obj_path in object_paths:
#                 if len(set(subj_path).intersection(set(obj_path))) <= max_self_intersection:
#                   valid_pairs.append((
#                       subj_path,
#                       obj_path,
#                       # Calculate combined length (subtract 1 to avoid counting predicate twice)
#                       len(subj_path) + len(obj_path) - 1
#                   ))

#             if not valid_pairs:
#               print(f"No valid pairs of subj, obj paths found for {pred.name()}")
#               break

#             shortest_comb_path = min(valid_pairs, key=lambda x: x[2])

#             if shortest_comb_path[2] < best_combined_path_length:
#                 best_combined_path_length = shortest_comb_path[2]
#                 best_subject_path = shortest_comb_path[0]
#                 best_object_path = shortest_comb_path[1]
#                 best_predicate = pred

#     return best_subject_path, best_object_path, best_predicate


# # ============================================================================
# # Display Functions
# # ============================================================================

# def show_path(label, path):
#     """Pretty print a path of synsets."""
#     if path:
#         print(f"{label}:")
#         print(" -> ".join(f"{s.name()} ({s.definition()})" for s in path))
#         print(f"Path length: {len(path)}")
#         print()
#     else:
#         print(f"{label}: No path found")
#         print()


# def show_connected_paths(subject_path, object_path, predicate):
#     """Display the connected paths with their shared predicate."""
#     if subject_path and object_path and predicate:
#         print("=" * 70)
#         print(f"CONNECTED PATH through predicate: {predicate.name()}")
#         print("=" * 70)

#         show_path("Subject -> Predicate path", subject_path)
#         show_path("Predicate -> Object path", object_path)

#         # Show the complete connected path
#         complete_path = subject_path + object_path[1:]  # Avoid duplicating the predicate
#         print("Complete connected path:")
#         print(" -> ".join(f"{s.name()}" for s in complete_path))
#         print(f"Total path length: {len(complete_path)}")
#         print()
#     else:
#         print("No connected path found through any predicate synset.")


# Testing

In [25]:
# Build the graph `g` that is handed to BidirectionalAStar in the search cell below.
# NOTE(review): wn_to_nx is defined outside this section; presumably it exports
# WordNet as a NetworkX graph — confirm against its definition.
g = wn_to_nx()

In [59]:
# Sanity check: display the direct hypernyms of the synset person.n.01.
wn.synset("person.n.01").hypernyms()

[Synset('causal_agent.n.01'), Synset('organism.n.01')]

In [64]:
# Run a bidirectional search over the graph `g` between two noun synsets.
solver = BidirectionalAStar(
    g, "dog.n.01", "one.n.02",
    get_new_beams_fn=get_new_beams,
    relax_beam=True,
)
path = solver.search()

# Report the outcome; a falsy result from search() is treated as "no path".
if not path:
    print("no path found")
else:
    print("found path:", path)
    # meet_node / path_cost are read from the solver after the search has run.
    print("meet:", solver.meet_node, "cost:", solver.path_cost)

found path: ['dog.n.01', 'pack.n.06', 'animal_group.n.01', 'biological_group.n.01', 'group.n.01', 'abstraction.n.06', 'measure.n.02', 'playing_period.n.01', 'chukker.n.01', 'part.n.09', 'whole.n.01', 'unit.n.04', 'one.n.02']
meet: playing_period.n.01 cost: 12.0
