<a href="https://colab.research.google.com/github/IsaacFigNewton/SMIED/blob/main/BFS_Semantic_Decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [27]:
"""
Semantic decomposition of ("cat", "eats", "mouse") using WordNet + spaCy + depth-limited GBFS.
- Uses spaCy to parse verb synset glosses and detect subject/object dependencies.
- If both subject and object tokens are present, branches directly toward original triple synsets.
- Otherwise falls back to WordNet relations.
"""

import nltk
import spacy
from nltk.corpus import wordnet as wn
from nltk.wsd import lesk
from heapq import heappush, heappop

In [2]:
nltk.download('wordnet')

# Load spaCy English model for dependency parsing
nlp = spacy.load("en_core_web_sm")

[nltk_data] Downloading package wordnet to /root/nltk_data...


In [35]:
def path_syn_to_syn(start_synset, end_synset, max_depth=6):
    """
    Get neighbors for a synset.
    """

    assert start_synset.pos() == end_synset.pos() and start_synset.pos() in {'n', 'v'}

    # DFS
    queue = [(start_synset, 0)]
    while queue:
        curr_synset, curr_depth = queue.pop()
        # if a path has been found
        if curr_synset.name() == end_synset.name():
            return [curr_synset.name()]
        # if no path was found
        elif curr_depth > max_depth:
            return None
        else:
            queue.extend(curr_synset.hypernyms())
            queue.extend(curr_synset.hyponyms())
            if start_synset.pos() == 'n':
                neighbors = list(get_noun_neighbors(curr_synset))
                neighbors_depths = [(n, curr_depth+1) for n in neighbors]
                queue.extend(neighbors_depths)
            else:
                neighbors = list(get_verb_neighbors(curr_synset))
                neighbors_depths = [(n, curr_depth+1) for n in neighbors]
                queue.extend(neighbors_depths)



def get_noun_neighbors(syn):
    """
    Get neighbors for a synset.
    """
    nbrs = set()
    nbrs.update(syn.part_meronyms())
    nbrs.update(syn.substance_meronyms())
    nbrs.update(syn.member_meronyms())
    nbrs.update(syn.part_holonyms())
    nbrs.update(syn.substance_holonyms())
    nbrs.update(syn.member_holonyms())
    return nbrs

def get_verb_neighbors(syn):
    """
    Get neighbors for a synset.
    """
    nbrs = set()
    nbrs.update(syn.entailments())
    nbrs.update(syn.causes())
    nbrs.update(syn.also_sees())
    nbrs.update(syn.verb_groups())
    return nbrs


def bfs(start_synset=None, end_synset=None, max_depth=6):
    """
    Get neighbors for a synset.
    If start_synset is a verb and end_synset is a noun,
        parse its gloss with spaCy and check for nsubj/dobj tokens.
    If both found, branch toward subject/object synsets directly.

    If start_synset is a noun and end_synset is a verb,
        parse its gloss with SpaCy and check for verb relations.
    If found, branch toward verb hypernyms/hyponyms/etc.

    Otherwise, fall back to WordNet relations.
    """

    assert start_synset.pos() != end_synset.pos() and start_synset.pos() in {'n', 'v'}

    if start_synset and end_synset:
        # if it's a subject-predicate pairing
        if start_synset.pos() == 'n':
            # get gloss of predicate
            pred_gloss_doc = nlp(end_synset.definition())
            subjs = [tok for tok in pred_gloss_doc if tok.dep_ == "nsubj"]
            preds = [tok for tok in pred_gloss_doc if tok.pos_ == "VERB"]
            objs = [tok for tok in pred_gloss_doc if tok.dep_ == "dobj"]
            if len(preds)>0 and len(subjs)>0:
                return [start_synset] + path_syn_to_syn(
                    start_synset,
                    lesk(pred_gloss_doc.text, subjs[0], pos='n'),
                    max_depth=max_depth,
                )

        # if it's a predicate-object pairing
        if start_synset.pos() == 'v':
            # get gloss of predicate
            pred_gloss_doc = nlp(start_synset.definition())
            subjs = [tok for tok in pred_gloss_doc if tok.dep_ == "nsubj"]
            preds = [tok for tok in pred_gloss_doc if tok.pos_ == "VERB"]
            objs = [tok for tok in pred_gloss_doc if tok.dep_ == "dobj"]
            if len(subjs)>0 and len(objs)>0:
                return [start_synset] + path_syn_to_syn(
                    lesk(pred_gloss_doc.text, objs[0], pos='n'),
                    end_synset,
                    max_depth=max_depth,
                )

# Gather synsets
cat_synsets = wn.synsets("cat", pos=wn.NOUN)
mouse_synsets = wn.synsets("mouse", pos=wn.NOUN)
eat_synsets = wn.synsets("eat", pos=wn.VERB)

# Run BFS from subject and object
cat_paths = list()
mouse_paths = list()
for eat in eat_synsets:
  for cat in cat_synsets:
    cat_path = bfs(
        cat,
        eat,
        max_depth=10
    )
    if cat_path:
      cat_paths.append(cat_path)
  for mouse in mouse_synsets:
    mouse_path = bfs(
        eat,
        mouse,
        max_depth=10
    )
    if mouse_path:
      mouse_paths.append(mouse_path)

# get the shortest cat and mouse paths
cat_path = min(cat_paths, key=len)
mouse_path = min(mouse_paths, key=len)

# Pretty print
def show_path(label, path):
    if path:
        print(f"{label}:")
        print(" -> ".join(f"{s.name()} ({s.definition()})" for s in path))
        print()

show_path("Path from 'cat' to 'eat' (actor slot)", cat_path)
show_path("Path from 'mouse' to 'eat' (object slot)", mouse_path)

ValueError: min() iterable argument is empty