<a href="https://colab.research.google.com/github/IsaacFigNewton/SMIED/blob/adding-semantic-decomposition/BeamSemantic_Decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Config

## Install dependencies

## Import, config stuff

In [1]:
"""
Semantic decomposition of ("cat", "eats", "mouse") using WordNet + spaCy + depth-limited GBFS.
- Uses spaCy to parse verb synset glosses and detect subject/object dependencies.
- If both subject and object tokens are present, branches directly toward original triple synsets.
- Otherwise falls back to WordNet relations.
"""
import nltk
import spacy
from nltk.corpus import wordnet as wn
from heapq import heappush, heappop
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import gensim.downloader as api
from collections import deque
import heapq
import itertools
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple
import networkx as nx

In [2]:
nltk.download('wordnet')

# Load spaCy English model for dependency parsing
nlp = spacy.load("en_core_web_sm")

# Download required NLTK data (run once)
nltk.download('wordnet')
nltk.download('omw-1.4')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [3]:
embedding_model = api.load('word2vec-google-news-300')



In [4]:
# Type aliases
SynsetName = str  # e.g., "dog.n.01"
Path = List[SynsetName]
BeamElement = Tuple[Tuple[SynsetName, str], Tuple[SynsetName, str], float]
GetNewBeamsFn = Callable[[nx.DiGraph, SynsetName, SynsetName], List[BeamElement]]
TopKBranchFn = Callable[[List[List], object, int], List[BeamElement]]

# Helpers

In [5]:
def wn_to_nx():
    # Initialize directed graph
    G = nx.DiGraph()

    synset_rels = {
        # holonyms
        "part_holonyms": lambda x: x.part_holonyms(),
        "substance_holonyms": lambda x: x.substance_holonyms(),
        "member_holonyms": lambda x: x.member_holonyms(),

        # meronyms
        "part_meronyms": lambda x: x.part_meronyms(),
        "substance_meronyms": lambda x: x.substance_meronyms(),
        "member_meronyms": lambda x: x.member_meronyms(),

        # other
        "hypernyms": lambda x: x.hypernyms(),
        "hyponyms": lambda x: x.hyponyms(),
        "entailments": lambda x: x.entailments(),
        "causes": lambda x: x.causes(),
        "also_sees": lambda x: x.also_sees(),
        "verb_groups": lambda x: x.verb_groups(),
    }

    # add nodes (synsets) and all their edges (lexical relations) to the nx graph
    for synset in wn.all_synsets():
        for rel_name, rel_func in synset_rels.items():
            for target in rel_func(synset):
                G.add_edge(
                    synset.name(),
                    target.name(),
                    relation = rel_name[:-1]
                )
    return G

In [6]:
def get_all_neighbors(synset: wn.synset) -> List[wn.synset]:
    """Get all neighbors of a synset based on its POS."""
    neighbors = []

    # Add hypernyms and hyponyms
    neighbors.extend(synset.hypernyms())
    neighbors.extend(synset.hyponyms())

    # Add POS-specific neighbors
    if synset.pos() == 'n':
        neighbors.extend(get_noun_neighbors(synset))
    else:
        neighbors.extend(get_verb_neighbors(synset))

    return neighbors


def get_noun_neighbors(syn: wn.synset):
    """Get neighbors for a noun synset."""
    nbrs = set()
    nbrs.update(syn.part_meronyms())
    nbrs.update(syn.substance_meronyms())
    nbrs.update(syn.member_meronyms())
    nbrs.update(syn.part_holonyms())
    nbrs.update(syn.substance_holonyms())
    nbrs.update(syn.member_holonyms())
    return list(nbrs)


def get_verb_neighbors(syn: wn.synset):
    """Get neighbors for a verb synset."""
    nbrs = set()
    nbrs.update(syn.entailments())
    nbrs.update(syn.causes())
    nbrs.update(syn.also_sees())
    nbrs.update(syn.verb_groups())
    return list(nbrs)

# Embedding Helpers

In [7]:
# ============================================================================
# Embedding-based Helper Functions
# ============================================================================

def get_synset_embedding_centroid(synset, model) -> np.ndarray:
    """
    Given a wn.synset, compute centroid (mean) of embeddings for lemmas.
    Returns empty np.array if nothing found.
    """
    try:
        lemmas = [lemma.name().lower().replace("_", " ") for lemma in synset.lemmas()]
        embeddings = []
        for lemma in lemmas:
            if lemma in model:
                embeddings.append(np.asarray(model[lemma], dtype=float))
            elif lemma.replace(" ", "_") in model:
                embeddings.append(np.asarray(model[lemma.replace(" ", "_")], dtype=float))
            elif " " in lemma:
                # try individual words
                words = lemma.split()
                word_embs = [np.asarray(model[w], dtype=float) for w in words if w in model]
                if word_embs:
                    embeddings.append(np.mean(word_embs, axis=0))
        if not embeddings:
            return np.array([])  # empty
        return np.mean(embeddings, axis=0)
    except Exception as e:
        # defensive: return empty arr on any error
        return np.array([])


def embed_lexical_relations(synset, model) -> Dict[str, List[Tuple[SynsetName, np.ndarray]]]:
    """
    Return map: lexical_rel_name -> list of (synset_name, centroid ndarray)
    Filters out relations whose centroid is empty.
    """
    def _rel_centroids(get_attr):
        try:
            items = []
            for s in get_attr(synset):
                cent = get_synset_embedding_centroid(s, model)
                if cent.size > 0:
                    items.append((s.name(), cent))
            return items
        except Exception:
            return []

    return {
        "part_holonyms": _rel_centroids(lambda x: x.part_holonyms()),
        "substance_holonyms": _rel_centroids(lambda x: x.substance_holonyms()),
        "member_holonyms": _rel_centroids(lambda x: x.member_holonyms()),
        "part_meronyms": _rel_centroids(lambda x: x.part_meronyms()),
        "substance_meronyms": _rel_centroids(lambda x: x.substance_meronyms()),
        "member_meronyms": _rel_centroids(lambda x: x.member_meronyms()),
        "hypernyms": _rel_centroids(lambda x: x.hypernyms()),
        "hyponyms": _rel_centroids(lambda x: x.hyponyms()),
        "entailments": _rel_centroids(lambda x: x.entailments()),
        "causes": _rel_centroids(lambda x: x.causes()),
        "also_sees": _rel_centroids(lambda x: x.also_sees()),
        "verb_groups": _rel_centroids(lambda x: x.verb_groups()),
    }


def get_embedding_similarities(rel_embs_1: List[Tuple[str, np.ndarray]],
                              rel_embs_2: List[Tuple[str, np.ndarray]]) -> np.ndarray:
    """
    Return cosine similarity matrix (m x n) for lists of (name, centroid).
    If either is empty, returns empty (0xN or Mx0) array.
    """
    if not rel_embs_1 or not rel_embs_2:
        return np.zeros((0, 0))

    e1 = np.array([x[1] for x in rel_embs_1], dtype=float)  # (m,d)
    e2 = np.array([x[1] for x in rel_embs_2], dtype=float)  # (n,d)

    # avoid divide-by-zero: replace zero norms with eps
    e1_norms = np.linalg.norm(e1, axis=1, keepdims=True)
    e2_norms = np.linalg.norm(e2, axis=1, keepdims=True)
    e1_norms[e1_norms == 0] = 1e-8
    e2_norms[e2_norms == 0] = 1e-8

    e1u = e1 / e1_norms
    e2u = e2 / e2_norms

    sims = np.dot(e1u, e2u.T)
    return sims


def get_top_k_aligned_lex_rel_pairs(
    src_tgt_rel_map: Dict[str, str],
    src_emb_dict: Dict[str, List[Tuple[SynsetName, np.ndarray]]],
    tgt_emb_dict: Dict[str, List[Tuple[SynsetName, np.ndarray]]],
    beam_width: int = 3,
) -> List[BeamElement]:
    """
    src_tgt_rel_map: mapping from relation name in src to relation name in tgt,
      e.g., {'hypernyms': 'hyponyms', ...}

    Returns list of ((src_syn_name, src_rel), (tgt_syn_name, tgt_rel), similarity)
    """
    rel_sims = []
    for e1_rel, e2_rel in src_tgt_rel_map.items():
        e1_list = src_emb_dict.get(e1_rel, [])
        e2_list = tgt_emb_dict.get(e2_rel, [])
        if not e1_list or not e2_list:
            continue
        sims = get_embedding_similarities(e1_list, e2_list)  # shape (m,n)
        if sims.size == 0:
            continue
        for i in range(sims.shape[0]):
            for j in range(sims.shape[1]):
                try:
                    rel_sims.append(((e1_list[i][0], e1_rel), (e2_list[j][0], e2_rel), float(sims[i, j])))
                except IndexError:
                    continue

    # sort and return top-k
    rel_sims.sort(key=lambda x: x[2], reverse=True)
    if beam_width <= 0:
        return rel_sims
    return rel_sims[:beam_width]


# ============================================================================
# Refactored: Embedding-based synset branch ranking
# ============================================================================

def get_top_k_synset_branch_pairs(
    candidates: List[List],  # List of lists of synsets
    target_synsets,  # Single synset or list of synsets
    beam_width: int = 3,
    model=None,  # embedding model
    wn_module=None  # WordNet module
) -> List[BeamElement]:
    """
    Refactored to use embedding-based alignment instead of get_synset_relatedness.

    Given a list of candidate synset lists and target synset(s),
    return the k synset-relation pairs most similar to the target.

    Returns: List of ((synset_name, lexical_rel), (target_syn_name, lexical_rel), similarity)
    """
    if model is None or wn_module is None:
        return []

    # Handle target synsets (single or list)
    if not isinstance(target_synsets, (list, tuple)):
        target_synsets = [target_synsets]

    if not target_synsets:
        return []

    # Relation maps for alignment
    asymm_map = {
        "hypernyms": "hyponyms",
        "hyponyms": "hypernyms",
        "part_meronyms": "part_holonyms",
        "member_meronyms": "member_holonyms",
        "substance_meronyms": "substance_holonyms",
        "entailments": "causes",
        "causes": "entailments",
    }
    symm_map = {
        "hypernyms": "hypernyms",
        "hyponyms": "hyponyms",
        "part_meronyms": "part_meronyms",
        "member_meronyms": "member_meronyms",
        "also_sees": "also_sees",
        "verb_groups": "verb_groups",
    }

    all_results = []

    # Process each target synset
    for target_syn in target_synsets:
        if target_syn is None:
            continue

        # Precompute target embeddings
        tgt_emb_dict = embed_lexical_relations(target_syn, model)

        # Process each candidate list
        for synset_list in candidates:
            if not synset_list:
                continue

            for synset in synset_list:
                if synset is None:
                    continue

                try:
                    # Compute candidate embeddings
                    src_emb_dict = embed_lexical_relations(synset, model)

                    # Get aligned pairs from asymmetric relations
                    asymm_pairs = get_top_k_aligned_lex_rel_pairs(
                        asymm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                    )
                    all_results.extend(asymm_pairs)

                    # Get aligned pairs from symmetric relations
                    symm_pairs = get_top_k_aligned_lex_rel_pairs(
                        symm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                    )
                    all_results.extend(symm_pairs)

                except Exception:
                    continue

    # Sort by similarity and return top-k
    all_results.sort(key=lambda x: x[2], reverse=True)
    return all_results[:beam_width]

In [None]:
import numpy as np
from typing import Dict, List, Tuple, Optional, Iterable, Set, Callable
import networkx as nx

# Type aliases
SynsetName = str
BeamElement = Tuple[Tuple[SynsetName, str], Tuple[SynsetName, str], float]
GetNewBeamsFn = Callable[[nx.DiGraph, SynsetName, SynsetName], List[BeamElement]]
TopKBranchFn = Callable[[List[List], object, int], List[BeamElement]]

# functions moved to EmbeddingHelper.py
# These functions are now imported from EmbeddingHelper.py
# from EmbeddingHelper import (
#     get_top_k_synset_branch_pairs,
#     get_top_k_aligned_lex_rel_pairs,
#     get_embedding_similarities,
#     get_synset_embedding_centroid,
#     embed_lexical_relations,
#     get_all_neighbors,
#     get_noun_neighbors,
#     get_verb_neighbors,
#     wn_to_nx,
# )

# Beam Construction

In [None]:
#  Note: beam construction and search functions have been moved to BeamSearch.py
# from BeamSearch import (
#     get_new_beams,
#     get_top_k_branches,
#     get_top_k_branches_depth_limited,
#     get_top_k_branches_depth_limited_lexical,
#     get_top_k_branches_depth_limited_embedding,
#     get_top_k_branches_depth_limited_embedding_lexical,
#     get_top_k_branches_depth_limited_embedding_lexical_wn,
#     get_top_k_branches_depth_limited_embedding_lexical_wn_lexical,
#     get_top_k_branches_depth_limited_embedding_lexical_wn_embedding,
# )

# Pathing

## PairwiseBidirectionalAStar

In [None]:
# Note: PairWiseBidirectionalAStar has been moved to PairWiseBidirectionalAStar.py
# from PairWiseBidirectionalAStar import PairWiseBidirectionalAStar

## Helpers

In [29]:
# -------------------------
# Refactored top_k_branch_wrapper using new embedding functions
# -------------------------
def top_k_branch_wrapper(
    candidates_lists: List[List],
    target_synset_or_list,
    beam_width: int = 3,
    model=None,  # embedding model
    wn_module=None  # WordNet module
) -> List[BeamElement]:
    """
    Adaptation so that find_connected_paths can call a branch-ranking function
    that uses the new embedding-based alignment functions.

    Args:
      candidates_lists: list of lists of wn.synset objects (from gloss tokens)
      target_synset_or_list: either a single wn.synset or a list of target synsets
      beam_width: number of top pairs to return
      model: embedding model (e.g., Word2Vec, GloVe)
      wn_module: WordNet module
    Returns:
      List of beam elements ((synset_name, rel), (target_syn_name, rel), sim)
    """
    if model is None or wn_module is None:
        return []

    results = []

    # Choose a single representative target synset to align against.
    # If a list is provided, pick the first (you can change this strategy).
    if isinstance(target_synset_or_list, (list, tuple)):
        if not target_synset_or_list:
            return []
        target_syn = target_synset_or_list[0]
    else:
        target_syn = target_synset_or_list

    if target_syn is None:
        return []

    # Import the embedding functions (assuming they're available)
    from embedding_helpers import (
        embed_lexical_relations,
        get_top_k_aligned_lex_rel_pairs
    )

    # Precompute target synset relation embeddings
    tgt_emb_dict = embed_lexical_relations(target_syn, model)

    # Relation maps (same defaults used by the embedding-beam adapter)
    asymm_map = {
        "hypernyms": "hyponyms",
        "hyponyms": "hypernyms",
        "part_meronyms": "part_holonyms",
        "member_meronyms": "member_holonyms",
        "substance_meronyms": "substance_holonyms",
        "entailments": "causes",
        "causes": "entailments",
    }
    symm_map = {
        "hypernyms": "hypernyms",
        "hyponyms": "hyponyms",
        "part_meronyms": "part_meronyms",
        "member_meronyms": "member_meronyms",
        "also_sees": "also_sees",
        "verb_groups": "verb_groups",
    }

    # For each candidate synset in every candidate list, compute its lexical-relation embeddings
    # and align to the target's relation embeddings using get_top_k_aligned_lex_rel_pairs
    for cand_list in candidates_lists:
        for cand_syn in cand_list:
            try:
                src_emb_dict = embed_lexical_relations(cand_syn, model)
            except Exception:
                continue

            # get top pairs from both maps
            try:
                asymm_results = get_top_k_aligned_lex_rel_pairs(
                    asymm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                )
                results.extend(asymm_results)
            except Exception:
                pass

            try:
                symm_results = get_top_k_aligned_lex_rel_pairs(
                    symm_map, src_emb_dict, tgt_emb_dict, beam_width=beam_width
                )
                results.extend(symm_results)
            except Exception:
                pass

    # Sort aggregated results by similarity and return top beam_width
    results.sort(key=lambda x: x[2], reverse=True)
    return results[:beam_width]


# -------------------------
# High-level triple search (subject -> predicate -> object)
# -------------------------
def find_connected_paths(
    g: nx.DiGraph,
    subject_word: str,
    predicate_word: str,
    object_word: str,
    wn_module,                # your WordNet interface (e.g., nltk.corpus.wordnet or custom)
    nlp_func,                 # spaCy NLP call (callable that takes a string -> doc)
    get_new_beams_fn=None,
    top_k_branch_fn: Optional[TopKBranchFn] = None,
    extract_subjects_from_gloss=None,
    extract_objects_from_gloss=None,
    beam_width: int = 3,
    max_depth: int = 8,
    max_self_intersection: int = 5,
    max_results_per_pair: int = 3,
    len_tolerance: int = 1,
    relax_beam: bool = False,
):
    """
    Find connected subject->predicate->object paths.

    Strategy:
      - Get synsets for subject (nouns), predicate (verbs), object (nouns).
      - For each candidate predicate synset:
          - Build gloss seeds for subject->predicate side (use predicate gloss to extract candidate subject tokens)
          - Build gloss seeds for predicate->object side (use predicate gloss to extract object tokens)
          - Run PairwiseBidirectionalAStar twice (subj->pred and pred->obj), requesting multiple paths each
          - Combine pairs of returned paths that share the predicate synset and meet intersection constraints
      - Return sorted list of connected path triples
    """

    # get word synsets
    subject_synsets = wn_module.synsets(subject_word, pos=wn_module.NOUN)
    predicate_synsets = wn_module.synsets(predicate_word, pos=wn_module.VERB)
    object_synsets = wn_module.synsets(object_word, pos=wn_module.NOUN)

    results = []

    # helper: safe top_k_branch function fallback (if not provided)
    def _default_top_k_branch_fn(candidates_lists, target_synset, beam_width_inner=3):
        # fallback: return empty (user should inject a real one)
        return []

    top_k_branch_fn_used = top_k_branch_fn or _default_top_k_branch_fn

    for pred_syn in predicate_synsets:
        pred_name = pred_syn.name()

        # ------------------------
        # Build gloss seeds for subj->pred
        # ------------------------
        pred_gloss_doc = nlp_func(pred_syn.definition())

        # Extract subject tokens from predicate gloss
        active_subject_tokens = []
        if extract_subjects_from_gloss:
            try:
                active_subject_tokens, _ = extract_subjects_from_gloss(pred_gloss_doc)
            except Exception:
                active_subject_tokens = []

        # Convert tokens to candidate synset lists and get top_k branches
        subject_candidate_synsets = []
        for t in active_subject_tokens:
            try:
                synsets = wn_module.synsets(t.text, pos=wn_module.NOUN)
                if synsets:
                    subject_candidate_synsets.append(synsets)
            except Exception:
                pass

        # Get top-k synsets that best align with subject synsets
        subj_gloss_seed_nodes = set()
        if subject_candidate_synsets and top_k_branch_fn_used:
            try:
                subj_top_k = top_k_branch_fn_used(subject_candidate_synsets, subject_synsets, beam_width)
                # Extract synset names from the results
                for beam_elem in subj_top_k:
                    s_pair = beam_elem[0]  # (synset_name, rel)
                    if isinstance(s_pair[0], str):
                        subj_gloss_seed_nodes.add(s_pair[0])
                    elif hasattr(s_pair[0], 'name'):
                        subj_gloss_seed_nodes.add(s_pair[0].name())
            except Exception:
                pass

        # ------------------------
        # Build gloss seeds for pred->obj
        # ------------------------
        objects_tokens = []
        if extract_objects_from_gloss:
            try:
                objects_tokens = extract_objects_from_gloss(pred_gloss_doc)
            except Exception:
                objects_tokens = []

        # Convert tokens to candidate synset lists
        object_candidate_synsets = []
        for t in objects_tokens:
            try:
                synsets = wn_module.synsets(t.text, pos=wn_module.NOUN)
                if synsets:
                    object_candidate_synsets.append(synsets)
            except Exception:
                pass

        # Get top-k synsets that best align with object synsets
        obj_gloss_seed_nodes = set()
        if object_candidate_synsets and top_k_branch_fn_used:
            try:
                obj_top_k = top_k_branch_fn_used(object_candidate_synsets, object_synsets, beam_width)
                # Extract synset names from the results
                for beam_elem in obj_top_k:
                    s_pair = beam_elem[0]  # (synset_name, rel)
                    if isinstance(s_pair[0], str):
                        obj_gloss_seed_nodes.add(s_pair[0])
                    elif hasattr(s_pair[0], 'name'):
                        obj_gloss_seed_nodes.add(s_pair[0].name())
            except Exception:
                pass

        # Now run pairwise searches:
        # For subject->predicate: src = each subject synset, tgt = pred_syn (pred_name)
        subject_paths_map = {}  # subj_syn_name -> list of (path, cost)
        for subj_syn in subject_synsets:
            # Check if cross-POS search
            if subj_syn.pos() != pred_syn.pos():
                # Cross-POS: rely more on gloss seeds, relax beam constraints
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=subj_syn.name(),
                    tgt=pred_name,
                    get_new_beams_fn=None,  # Don't use beam function for cross-POS
                    gloss_seed_nodes=subj_gloss_seed_nodes,  # Use gloss seeds
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=True  # Allow exploration outside beams
                )
            else:
                # Same-POS: use normal beam search with gloss enrichment
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=subj_syn.name(),
                    tgt=pred_name,
                    get_new_beams_fn=get_new_beams_fn,
                    gloss_seed_nodes=subj_gloss_seed_nodes,
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=relax_beam
                )

            subj_paths = pair_search.find_paths(
                max_results=max_results_per_pair,
                len_tolerance=len_tolerance
            )
            if subj_paths:
                subject_paths_map[subj_syn.name()] = subj_paths

        # For predicate->object: src = pred_name, tgt = each object synset
        object_paths_map = {}  # obj_syn_name -> list of (path, cost)
        for obj_syn in object_synsets:
            # Check if cross-POS search
            if pred_syn.pos() != obj_syn.pos():
                # Cross-POS: rely more on gloss seeds, relax beam constraints
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=pred_name,
                    tgt=obj_syn.name(),
                    get_new_beams_fn=None,  # Don't use beam function for cross-POS
                    gloss_seed_nodes=obj_gloss_seed_nodes,  # Use gloss seeds
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=True  # Allow exploration outside beams
                )
            else:
                # Same-POS: use normal beam search with gloss enrichment
                pair_search = PairwiseBidirectionalAStar(
                    g=g,
                    src=pred_name,
                    tgt=obj_syn.name(),
                    get_new_beams_fn=get_new_beams_fn,
                    gloss_seed_nodes=obj_gloss_seed_nodes,
                    beam_width=beam_width,
                    max_depth=max_depth,
                    relax_beam=relax_beam
                )

            obj_paths = pair_search.find_paths(
                max_results=max_results_per_pair,
                len_tolerance=len_tolerance
            )
            if obj_paths:
                object_paths_map[obj_syn.name()] = obj_paths

        # Combine: for any subj_path and obj_path that both go through `pred_name`,
        # produce combined results if intersection is low enough.
        for subj_syn_name, subj_paths in subject_paths_map.items():
            for subj_path, subj_cost in subj_paths:
                for obj_syn_name, obj_paths in object_paths_map.items():
                    for obj_path, obj_cost in obj_paths:
                        # Make sure paths connect through the predicate
                        if subj_path[-1] == pred_name and obj_path[0] == pred_name:
                            # Check intersection to avoid tautological results
                            intersection_size = len(set(subj_path).intersection(set(obj_path)))
                            if intersection_size <= max_self_intersection:
                                # Combined path: subject_path + object_path[1:] (drop duplicate predicate)
                                combined_path = subj_path + obj_path[1:]
                                combined_cost = subj_cost + obj_cost
                                combined_len = len(combined_path)
                                results.append({
                                    "predicate_synset": pred_name,
                                    "subject_path": subj_path,
                                    "object_path": obj_path,
                                    "combined_path": combined_path,
                                    "combined_cost": combined_cost,
                                    "combined_len": combined_len,
                                })

    # sort results by combined_len then cost
    results = sorted(results, key=lambda r: (r["combined_len"], r["combined_cost"]))
    return results

## Gloss Parsing

In [None]:
# Note: Gloss parsing functions have been moved to GlossParser.py
# from GlossParser import (
#     extract_subjects_from_gloss,
#     extract_objects_from_gloss,
# )

In [None]:
# Note: old cross-POS search functions have been moved to CrossPOSSearch.py

In [None]:
# Note: old wrapper functions have been moved to SemanticDecomposer.py
# from SemanticDecomposer import (
#     top_k_branch_wrapper,
#     find_connected_paths,
# )

# Testing

In [33]:
g = wn_to_nx()

## Example

In [40]:
# ------------------------------
# Example call of find_connected_paths
# ------------------------------
results = find_connected_paths(
    g=g,
    subject_word="dog",
    predicate_word="chase",
    object_word="cat",
    wn_module=wn,
    nlp_func=nlp,
    get_new_beams_fn=lambda G, s, t: get_new_beams_from_embeddings(
        G, s, t, wn_module=wn, model=embedding_model, beam_width=4
    ),
    top_k_branch_fn=lambda candidates, target, bw: top_k_branch_wrapper(
        candidates, target, beam_width=bw, model=embedding_model, wn_module=wn
    ),
    extract_subjects_from_gloss=extract_subjects_from_gloss,
    extract_objects_from_gloss=extract_objects_from_gloss,
    extract_verbs_from_gloss=extract_verbs_from_gloss,  # This parameter is now included
    beam_width=10,
    max_depth=10,
    max_self_intersection=5,
    max_results_per_pair=10,
    len_tolerance=5,
    relax_beam=True,
    max_sample_size=5  # Also added this parameter for controlling gloss token sampling
)

# Print the top results
for idx, r in enumerate(results[:10], start=1):
    print(f"Result #{idx}")
    print(" Predicate synset:", r["predicate_synset"])
    print(" Combined length:", r["combined_len"])
    print(" Combined cost:", r["combined_cost"])
    print(" Combined path:")
    print("  -> ".join(r["combined_path"]))
    print("-" * 60)

TypeError: find_connected_paths() got an unexpected keyword argument 'extract_verbs_from_gloss'

In [35]:
print(results)

[]


In [19]:
diagnostic_run(
    g=g,
    src_name="dog.n.01",
    tgt_name="cat.n.01",
    wn_module=wn,
    nlp_func=nlp,
    embedding_model=embedding_model,
    get_new_beams_fn_wrapped=lambda G,s,t: get_new_beams_from_embeddings(G,s,t, wn_module=wn, model=embedding_model, beam_width=6),
    build_gloss_seed_nodes_fn=lambda pred, wn_mod, nlpfn: build_gloss_seed_nodes_from_predicate(pred, wn_mod, nlpfn,
                                                                                              extract_subjects_fn=globals().get('extract_subjects_from_gloss'),
                                                                                              extract_objects_fn=globals().get('extract_objects_from_gloss'),
                                                                                              extract_verbs_fn=globals().get('extract_verbs_from_gloss'),
                                                                                              top_k_branch_fn=lambda cand, target, bw: top_k_branch_wrapper(cand, target, bw, model=embedding_model, wn_module=wn),
                                                                                              mode="subjects",
                                                                                              max_sample_size=6,
                                                                                              beam_width=6),
)


=== Diagnostic start ===
src: dog.n.01
tgt: cat.n.01

1) Node presence checks:
 g has src node? True
 g has tgt node? True

2) Connectivity check using networkx:
 Undirected connectivity between src and tgt? True
 Example undirected shortest path (hop count): 3 hops. path sample: ['dog.n.01', 'domestic_animal.n.01', 'domestic_cat.n.01', 'cat.n.01']

3) Node neighborhood:
 src out-neighbors (sample): ['canis.n.01', 'pack.n.06', 'flag.n.07', 'domestic_animal.n.01', 'canine.n.02', 'corgi.n.01', 'leonberg.n.01', 'cur.n.01', 'poodle.n.01', 'pug.n.01']
 tgt in-neighbors (sample): ['feline.n.01', 'domestic_cat.n.01', 'wildcat.n.03']

4) Embedding beam seeding (get_new_beams_fn):
 beams count: 6
[(('canine.n.02', 'hypernyms'),
  ('feline.n.01', 'hypernyms'),
  0.6772804990703681),
 (('domestic_animal.n.01', 'hypernyms'),
  ('domestic_cat.n.01', 'hyponyms'),
  0.6423323762382909),
 (('pooch.n.01', 'hyponyms'), ('feline.n.01', 'hypernyms'), 0.634097330718018),
 (('canine.n.02', 'hypernyms'),
  (

## Diagnostics

In [17]:
# Debug / diagnostic script
from pprint import pprint
import networkx as nx

def diagnostic_run(
    g, src_name, tgt_name, wn_module, nlp_func, embedding_model,
    get_new_beams_fn_wrapped, build_gloss_seed_nodes_fn,
    pairwise_class=PairwiseBidirectionalAStar
):
    print("=== Diagnostic start ===")
    print("src:", src_name)
    print("tgt:", tgt_name)
    print()

    # 1) Node presence
    print("1) Node presence checks:")
    print(" g has src node?", g.has_node(src_name))
    print(" g has tgt node?", g.has_node(tgt_name))
    if not g.has_node(src_name) or not g.has_node(tgt_name):
        print(" -> Node name mismatch likely. Check g.nodes() samples:")
        print("  sample nodes:", list(g.nodes)[:20])
        print("Stopping diagnostics early (node mismatch).")
        return

    # 2) Connectivity check (fast)
    print("\n2) Connectivity check using networkx:")
    try:
        has_path = nx.has_path(g.to_undirected(), src_name, tgt_name)
    except Exception:
        # graceful fallback for directed graphs if conversion fails
        has_path = nx.has_path(g, src_name, tgt_name) if nx.is_directed(g) else False
    print(" Undirected connectivity between src and tgt?", has_path)
    if not has_path:
        print(" -> Graph disconnected between src and tgt. Try increasing graph connectivity or check edges.")
    else:
        try:
            sp = nx.shortest_path(g.to_undirected(), src_name, tgt_name)
            print(" Example undirected shortest path (hop count):", len(sp)-1, "hops. path sample:", sp[:10])
        except Exception:
            pass

    # 3) print node degrees and first neighbors
    print("\n3) Node neighborhood:")
    try:
        nbrs_src = list(g.neighbors(src_name))[:10]
    except Exception:
        nbrs_src = []
    try:
        preds_tgt = list(g.predecessors(tgt_name))[:10]
    except Exception:
        preds_tgt = []
    print(" src out-neighbors (sample):", nbrs_src)
    print(" tgt in-neighbors (sample):", preds_tgt)

    # 4) Beam seeding via embeddings
    print("\n4) Embedding beam seeding (get_new_beams_fn):")
    try:
        beams = get_new_beams_fn_wrapped(g, src_name, tgt_name)
        print(" beams count:", len(beams))
        pprint(beams[:10])
    except Exception as e:
        print(" get_new_beams_fn raised exception:", e)
        beams = []

    # 5) Gloss seed extraction for predicate (if available)
    print("\n5) Gloss seeds (from predicate gloss):")
    try:
        # attempt to get a predicate synset if tgt is a predicate; otherwise try wn.synset(tgt)
        try:
            pred_syn = wn_module.synset(tgt_name)
        except Exception:
            pred_syn = None
        seeds = set()
        if pred_syn:
            seeds = build_gloss_seed_nodes_fn(pred_syn, wn_module, nlp_func)
        print(" gloss seeds count:", len(seeds))
        print(" seeds sample:", list(seeds)[:10])
    except Exception as e:
        print(" gloss seeding raised exception:", e)
        seeds = set()

    # 6) Run pairwise search with relaxed settings (wider beam, relax_beam True)
    print("\n6) Run PairwiseBidirectionalAStar with relaxed settings (beam_width=10, relax_beam=True, max_depth=10)")
    try:
        search = pairwise_class(
            g=g,
            src=src_name,
            tgt=tgt_name,
            get_new_beams_fn=lambda G, s, t: get_new_beams_fn_wrapped(G, s, t),
            gloss_seed_nodes=seeds,
            beam_width=10,
            max_depth=10,
            relax_beam=True
        )
        paths = search.find_paths(max_results=5, len_tolerance=2)
        print(" found paths (count):", len(paths))
        for p, cost in paths:
            print(" cost:", cost, " hops:", len(p)-1, " path:", p)
    except Exception as e:
        print(" pairwise search raised exception:", e)

    print("\n=== Diagnostic end ===")
