<a href="https://colab.research.google.com/github/IsaacFigNewton/SMIED/blob/main/BFS_Semantic_Decomposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Semantic decomposition of ("cat", "eats", "mouse") using WordNet + spaCy + depth-limited GBFS.
- Uses spaCy to parse verb synset glosses and detect subject/object dependencies.
- If both subject and object tokens are present, branches directly toward original triple synsets.
- Otherwise falls back to WordNet relations.
"""

import nltk
import spacy
from nltk.corpus import wordnet as wn
from nltk.wsd import lesk
from heapq import heappush, heappop

In [2]:
nltk.download('wordnet')

# Load spaCy English model for dependency parsing
nlp = spacy.load("en_core_web_sm")

[nltk_data] Downloading package wordnet to /root/nltk_data...


In [9]:
from collections import deque
import nltk
from nltk.corpus import wordnet as wn
from nltk.wsd import lesk
import spacy

# Initialize spaCy (assuming you have it loaded)
# nlp = spacy.load("en_core_web_sm")

# ============================================================================
# Core Path Finding Functions
# ============================================================================

def path_syn_to_syn(start_synset, end_synset, max_depth=6):
    """
    Find shortest path between synsets of the same POS using bidirectional BFS.
    Returns a list of synsets forming the path, or None if no path found.
    """

    assert start_synset.pos() == end_synset.pos() and start_synset.pos() in {'n', 'v'}

    # Handle the trivial case where start and end are the same
    if start_synset.name() == end_synset.name():
        return [start_synset]

    # Initialize two search frontiers
    forward_queue = deque([(start_synset, 0)])
    forward_visited = {start_synset.name(): [start_synset]}

    backward_queue = deque([(end_synset, 0)])
    backward_visited = {end_synset.name(): [end_synset]}

    def expand_frontier(queue, visited_from_this_side, visited_from_other_side, is_forward):
        """Expand one step of the search frontier."""
        if not queue:
            return None

        curr_synset, depth = queue.popleft()

        if depth >= (max_depth + 1) // 2:
            return None

        path_to_current = visited_from_this_side[curr_synset.name()]

        for neighbor in get_all_neighbors(curr_synset):
            neighbor_name = neighbor.name()

            if neighbor_name in visited_from_this_side:
                continue

            if is_forward:
                new_path = path_to_current + [neighbor]
            else:
                new_path = [neighbor] + path_to_current

            if neighbor_name in visited_from_other_side:
                other_path = visited_from_other_side[neighbor_name]

                if is_forward:
                    full_path = path_to_current + other_path
                else:
                    full_path = other_path + path_to_current

                return full_path

            visited_from_this_side[neighbor_name] = new_path
            queue.append((neighbor, depth + 1))

        return None

    # Alternate between forward and backward search
    while forward_queue or backward_queue:
        if forward_queue:
            result = expand_frontier(forward_queue, forward_visited, backward_visited, True)
            if result:
                return result

        if backward_queue:
            result = expand_frontier(backward_queue, backward_visited, forward_visited, False)
            if result:
                return result

    return None


def get_all_neighbors(synset):
    """Get all neighbors of a synset based on its POS."""
    neighbors = []

    # Add hypernyms and hyponyms
    neighbors.extend(synset.hypernyms())
    neighbors.extend(synset.hyponyms())

    # Add POS-specific neighbors
    if synset.pos() == 'n':
        neighbors.extend(get_noun_neighbors(synset))
    else:
        neighbors.extend(get_verb_neighbors(synset))

    return neighbors


def get_noun_neighbors(syn):
    """Get neighbors for a noun synset."""
    nbrs = set()
    nbrs.update(syn.part_meronyms())
    nbrs.update(syn.substance_meronyms())
    nbrs.update(syn.member_meronyms())
    nbrs.update(syn.part_holonyms())
    nbrs.update(syn.substance_holonyms())
    nbrs.update(syn.member_holonyms())
    return list(nbrs)


def get_verb_neighbors(syn):
    """Get neighbors for a verb synset."""
    nbrs = set()
    nbrs.update(syn.entailments())
    nbrs.update(syn.causes())
    nbrs.update(syn.also_sees())
    nbrs.update(syn.verb_groups())
    return list(nbrs)


# ============================================================================
# Gloss Analysis Helper Functions
# ============================================================================

def extract_subjects_from_gloss(gloss_doc):
    """Extract subject tokens from a parsed gloss."""
    subjects = []

    # Direct subjects
    subjects.extend([tok for tok in gloss_doc if tok.dep_ == "nsubj"])

    # Passive subjects (which are actually objects semantically)
    # Skip these for actor identification
    passive_subjects = [tok for tok in gloss_doc if tok.dep_ == "nsubjpass"]

    # Filter out passive subjects from the main list
    subjects = [s for s in subjects if s not in passive_subjects]

    return subjects, passive_subjects


def extract_objects_from_gloss(gloss_doc):
    """Extract various types of object tokens from a parsed gloss."""
    objs = []

    # Direct objects
    objs.extend([tok for tok in gloss_doc if tok.dep_ == "dobj"])
    # Prepositional objects
    objs.extend([tok for tok in gloss_doc if tok.dep_ == "pobj"])
    # Indirect objects
    objs.extend([tok for tok in gloss_doc if tok.dep_ == "iobj"])
    # General objects
    objs.extend([tok for tok in gloss_doc if tok.dep_ == "obj"])

    # Check for noun chunks related to root verb
    root_verbs = [tok for tok in gloss_doc if tok.dep_ == "ROOT" and tok.pos_ == "VERB"]
    if root_verbs and not objs:
        for noun_chunk in gloss_doc.noun_chunks:
            if any(token.head == root_verbs[0] for token in noun_chunk):
                objs.append(noun_chunk.root)

    return objs


def extract_verbs_from_gloss(gloss_doc, include_passive=False):
    """Extract verb tokens from a parsed gloss."""
    verbs = [tok for tok in gloss_doc if tok.pos_ == "VERB"]

    if include_passive:
        # Past participles used as adjectives or in relative clauses
        passive_verbs = [tok for tok in gloss_doc if
                        tok.tag_ in ["VBN", "VBD"] and
                        tok.dep_ in ["acl", "relcl", "amod"]]
        verbs.extend(passive_verbs)

    return verbs


def find_instrumental_verbs(gloss_doc):
    """Find verbs associated with instrumental use (e.g., 'used for')."""
    instrumental_verbs = []

    if "used" in gloss_doc.text.lower():
        for i, token in enumerate(gloss_doc):
            if token.text.lower() == "used":
                # Check tokens after "used"
                for j in range(i+1, min(i+4, len(gloss_doc))):
                    if gloss_doc[j].pos_ == "VERB":
                        instrumental_verbs.append(gloss_doc[j])

    return instrumental_verbs


# ============================================================================
# Cross-POS Path Finding Functions
# ============================================================================

def find_subject_to_predicate_path(subject_synset, predicate_synset, max_depth=6):
    """Find path from subject (noun) to predicate (verb)."""
    paths = []

    # Strategy 1: Look for active subjects in verb's gloss
    pred_gloss_doc = nlp(predicate_synset.definition())
    active_subjects, passive_subjects = extract_subjects_from_gloss(pred_gloss_doc)

    # Try active subjects (true actors)
    for subj in active_subjects[:3]:
        try:
            subject_synset_from_gloss = lesk(pred_gloss_doc.text, subj.text, pos='n')
            if subject_synset_from_gloss:
                path = path_syn_to_syn(subject_synset, subject_synset_from_gloss, max_depth)
                if path:
                    paths.append(path + [predicate_synset])
        except:
            continue

    # Strategy 2: Look for verbs in the noun's gloss
    subj_gloss_doc = nlp(subject_synset.definition())
    verbs = extract_verbs_from_gloss(subj_gloss_doc, include_passive=False)

    for verb in verbs[:3]:
        try:
            verb_synset_from_gloss = lesk(subj_gloss_doc.text, verb.text, pos='v')
            if verb_synset_from_gloss:
                path = path_syn_to_syn(verb_synset_from_gloss, predicate_synset, max_depth)
                if path:
                    paths.append([subject_synset] + path)
        except:
            continue

    # Return shortest path if any found
    return min(paths, key=len) if paths else None


def find_predicate_to_object_path(predicate_synset, object_synset, max_depth=6):
    """Find path from predicate (verb) to object (noun)."""
    paths = []

    # Strategy 1: Look for objects in verb's gloss
    pred_gloss_doc = nlp(predicate_synset.definition())
    objects = extract_objects_from_gloss(pred_gloss_doc)

    # Also check for passive subjects (which are semantic objects)
    _, passive_subjects = extract_subjects_from_gloss(pred_gloss_doc)
    objects.extend(passive_subjects)

    for obj in objects[:3]:
        try:
            object_synset_from_gloss = lesk(pred_gloss_doc.text, obj.text, pos='n')
            if object_synset_from_gloss:
                path = path_syn_to_syn(object_synset_from_gloss, object_synset, max_depth)
                if path:
                    paths.append([predicate_synset] + path)
        except:
            continue

    # Strategy 2: Look for verbs in object's gloss
    obj_gloss_doc = nlp(object_synset.definition())
    verbs = extract_verbs_from_gloss(obj_gloss_doc, include_passive=True)

    for verb in verbs[:3]:
        try:
            verb_synset_from_gloss = lesk(obj_gloss_doc.text, verb.text, pos='v')
            if verb_synset_from_gloss:
                path = path_syn_to_syn(predicate_synset, verb_synset_from_gloss, max_depth)
                if path:
                    paths.append(path + [object_synset])
        except:
            continue

    # Strategy 3: Check for instrumental relationships
    instrumental_verbs = find_instrumental_verbs(obj_gloss_doc)
    for verb in instrumental_verbs[:2]:
        try:
            verb_synset = lesk(obj_gloss_doc.text, verb.text, pos='v')
            if verb_synset:
                path = path_syn_to_syn(predicate_synset, verb_synset, max_depth)
                if path:
                    paths.append(path + [object_synset])
        except:
            continue

    # Return shortest path if any found
    return min(paths, key=len) if paths else None


# ============================================================================
# Main Connected Path Finding Function
# ============================================================================

def find_connected_shortest_paths(subject_word, predicate_word, object_word, max_depth=10):
    """
    Find shortest connected paths from subject through predicate to object.
    Ensures that the same predicate synset connects both paths.
    """

    # Get synsets for each word
    subject_synsets = wn.synsets(subject_word, pos=wn.NOUN)
    predicate_synsets = wn.synsets(predicate_word, pos=wn.VERB)
    object_synsets = wn.synsets(object_word, pos=wn.NOUN)

    best_combined_path_length = float('inf')
    best_subject_path = None
    best_object_path = None
    best_predicate = None

    # Try each predicate synset as the connector
    for pred in predicate_synsets:
        # Find paths from all subjects to this specific predicate
        subject_paths = []
        for subj in subject_synsets:
            path = find_subject_to_predicate_path(subj, pred, max_depth)
            if path:
                subject_paths.append(path)

        # Find paths from this specific predicate to all objects
        object_paths = []
        for obj in object_synsets:
            path = find_predicate_to_object_path(pred, obj, max_depth)
            if path:
                object_paths.append(path)

        # If we have both paths through this predicate, check if it's the best
        if subject_paths and object_paths:
            shortest_subj_path = min(subject_paths, key=len)
            shortest_obj_path = min(object_paths, key=len)

            # Calculate combined length (subtract 1 to avoid counting predicate twice)
            combined_length = len(shortest_subj_path) + len(shortest_obj_path) - 1

            if combined_length < best_combined_path_length:
                best_combined_path_length = combined_length
                best_subject_path = shortest_subj_path
                best_object_path = shortest_obj_path
                best_predicate = pred

    return best_subject_path, best_object_path, best_predicate


# ============================================================================
# Display Functions
# ============================================================================

def show_path(label, path):
    """Pretty print a path of synsets."""
    if path:
        print(f"{label}:")
        print(" -> ".join(f"{s.name()} ({s.definition()})" for s in path))
        print(f"Path length: {len(path)}")
        print()
    else:
        print(f"{label}: No path found")
        print()


def show_connected_paths(subject_path, object_path, predicate):
    """Display the connected paths with their shared predicate."""
    if subject_path and object_path and predicate:
        print("=" * 70)
        print(f"CONNECTED PATH through predicate: {predicate.name()}")
        print("=" * 70)

        show_path("Subject -> Predicate path", subject_path)
        show_path("Predicate -> Object path", object_path)

        # Show the complete connected path
        complete_path = subject_path + object_path[1:]  # Avoid duplicating the predicate
        print("Complete connected path:")
        print(" -> ".join(f"{s.name()}" for s in complete_path))
        print(f"Total path length: {len(complete_path)}")
        print()
    else:
        print("No connected path found through any predicate synset.")


In [11]:
nlp = spacy.load("en_core_web_sm")

# Find shortest connected paths
subject_path, object_path, connecting_predicate = find_connected_shortest_paths(
    "cat", "eat", "mouse", max_depth=10
)

# Display results
show_connected_paths(subject_path, object_path, connecting_predicate)

CONNECTED PATH through predicate: feed.v.06
Subject -> Predicate path:
kat.n.01 (the leaves of the shrub Catha edulis which are chewed like tobacco or used to make tea; has the effect of a euphoric stimulant) -> chew.v.01 (chew (food); to bite and grind with the teeth) -> eat.v.01 (take in solid food) -> feed.v.06 (take in food; used of animals only)
Path length: 4

Predicate -> Object path:
feed.v.06 (take in food; used of animals only) -> animal.n.01 (a living organism characterized by voluntary movement) -> organism.n.01 (a living thing that has (or can develop) the ability to act or function independently) -> person.n.01 (a human being) -> mouse.n.03 (person who is quiet or timid)
Path length: 5

Complete connected path:
kat.n.01 -> chew.v.01 -> eat.v.01 -> feed.v.06 -> animal.n.01 -> organism.n.01 -> person.n.01 -> mouse.n.03
Total path length: 8

