<a href="https://colab.research.google.com/github/IsaacFigNewton/SMIED/blob/main/Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from typing import List, Any, Dict, Tuple, Set
import nltk
# Fetch the NLTK corpora behind the `fn` (FrameNet) and `wn` (WordNet) readers
# imported just below.
nltk.download('framenet_v17')
nltk.download('wordnet')
from nltk.corpus import framenet as fn
from nltk.corpus import wordnet as wn

[nltk_data] Downloading package framenet_v17 to /root/nltk_data...
[nltk_data]   Unzipping corpora/framenet_v17.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


# Get overlapping hypernym paths

In [2]:
def overlapping_hypernym_paths(syn1, syn2) -> List[Any]:
    """Join the hypernym paths of two synsets at a lowest common hypernym (LCH).

    For every pair of hypernym paths (one of ``syn1``, one of ``syn2``) that
    both pass through the same LCH, build one path that climbs from ``syn1``
    up to that LCH and then descends to ``syn2``.

    Args:
        syn1: Start synset. Must provide ``lowest_common_hypernyms(other)``
            and ``hypernym_paths()`` (paths listed root-first); the LCHs it
            returns must provide ``name()``.
        syn2: End synset; must provide ``hypernym_paths()``.

    Returns:
        A list of paths ``[syn1, ..., lch, ..., syn2]`` in discovery order,
        with duplicate paths removed. Empty when no pair of paths shares an
        LCH.

    Side effects:
        Prints the names of the lowest common hypernyms.
    """
    lchs = syn1.lowest_common_hypernyms(syn2)
    print("LCHs:", [lch.name() for lch in lchs])

    common_paths = []
    seen = set()

    for p1 in syn1.hypernym_paths():
        for p2 in syn2.hypernym_paths():
            # First node of p1 (walking down from the root) that is an LCH and
            # also lies on p2. The loop variables are never rebound, so every
            # (p1, p2) pair is evaluated against the full, untruncated paths.
            pivot = next(
                (node for node in p1 if node in lchs and node in p2),
                None,
            )
            if pivot is None:
                continue

            # The LCH may sit at a different depth in each path, so locate it
            # independently instead of trimming both paths in lockstep.
            up = p1[p1.index(pivot):][::-1]      # syn1 ... lch
            down = p2[p2.index(pivot) + 1:]      # (lch excluded) ... syn2
            joined = up + down

            # Paths that differ only above the LCH collapse to the same join.
            key = tuple(joined)
            if key not in seen:
                seen.add(key)
                common_paths.append(joined)

    return common_paths

In [3]:
# Example: join the hypernym paths of "cat" and "dog" in both directions.
cat = wn.synset('cat.n.01')
dog = wn.synset('dog.n.01')
for source_syn, target_syn in ((cat, dog), (dog, cat)):
    print()
    overlaps = overlapping_hypernym_paths(source_syn, target_syn)
    for path in overlaps:
        names = [s.name() for s in path]
        print(" → ".join(names))


LCHs: ['carnivore.n.01']
cat.n.01 → feline.n.01 → carnivore.n.01 → canine.n.02 → dog.n.01

LCHs: ['carnivore.n.01']
dog.n.01 → canine.n.02 → carnivore.n.01 → feline.n.01 → cat.n.01


# Get frame info from predicate token

In [54]:
import spacy
from spacy.tokens import Token, Doc
from spacy import displacy
# General-purpose English pipeline used to parse the sentences under analysis.
nlp = spacy.load("en_core_web_sm")

## Get a dict of all dependency schemas matching a synset's frames/usage

In [55]:
import spacy
from spacy.tokens import Token

# Register the custom `orig_tag` Token extension, which keeps the tag a token
# had before any override. The `has_extension` guard makes sure registration
# happens only once, e.g. when this cell is re-run.
if not Token.has_extension("orig_tag"):
    Token.set_extension("orig_tag", default=None)

@spacy.language.Language.component("custom_pos_modifier")
def custom_pos_modifier(doc):
    """Save each token's predicted tag and re-tag noun-tagged verbs as ``VB``.

    Meant for short verb-frame templates, where the verb can be mis-tagged as
    a noun. Placeholder words ("something", "someone", "somebody") are left
    untouched.

    The noun check looks at the fine-grained ``tag_`` as well as ``pos_``.
    When this component sits directly after the tagger, the coarse ``pos_``
    has not been assigned yet (spaCy's English pipelines fill it in later,
    in ``attribute_ruler``), so a ``pos_``-only test would never match there.

    Args:
        doc: The spaCy ``Doc`` being processed; modified in place.

    Returns:
        The same ``doc``, with ``token._.orig_tag`` set on every token and
        ``tag_`` overridden to ``"VB"`` where the rule applies.
    """
    # Placeholder words that must keep their tag. WordNet verb-frame templates
    # spell the animate placeholder "Somebody", so it is protected as well.
    mapping = {"something": "NN", "someone": "NN", "somebody": "NN"}
    # Fine-grained tags that correspond to the coarse NOUN part of speech.
    noun_tags = {"NN", "NNS"}

    for token in doc:
        # Save original tag
        token._.orig_tag = token.tag_

        if token.text.lower() in mapping:
            continue
        if token.pos_ == "NOUN" or token.tag_ in noun_tags:
            token.tag_ = "VB"  # Set fine-grained tag to verb

    return doc

# Separate pipeline instance for parsing verb-frame templates: the same
# "en_core_web_sm" model, with the custom tag-modifying component inserted
# directly after the tagger.
arg_schema_nlp = spacy.load("en_core_web_sm")
arg_schema_nlp.add_pipe("custom_pos_modifier", after="tagger")
# Show the resulting component order.
print(arg_schema_nlp.pipe_names)

['tok2vec', 'tagger', 'custom_pos_modifier', 'parser', 'attribute_ruler', 'lemmatizer', 'ner']


In [56]:
def _flatten_tok_deps(tok: Token) -> Set[Token]:
    deps = set()
    for child in tok.children:
        deps.add(child)
        deps.update(_flatten_tok_deps(child))
    return deps

# TODO: retain original arg structure and add other syntactic/lexical info for better SRL
def _flattened_fn_arg_schema(doc: Doc|Token) -> Dict[str, str]:
      # used as fallback if FE parsing fails later on
      schema_atom_synset_maps = {
          "something": "entity.n.01",
          "someone": "causal_agent.n.01",
      }

      # if it's a token, just flatten its children and call that set a doc
      if isinstance(doc, Token):
          doc: Set[Token] = _flatten_tok_deps(doc)

      arg_schema = dict()
      for tok in doc:
          if tok.lower in schema_atom_synset_maps.keys():
              arg_schema[tok.dep_] = schema_atom_synset_maps[tok.lower]
          else:
              arg_schema[tok.dep_] = tok.lemma_
      return arg_schema

In [57]:
# Order of the core roles in the boolean requirement tuples built by
# `_get_dep_reqs` (position 0 = subject, position 1 = object).
vect_order = ["subject", "object"]  #, "theme"]
# For each core role, the dependency labels whose presence counts as that
# role being filled.
arg_schema_dep_req_map = {
    "subject": {"nsubj", "nsubjpass"},
    "object": {"dobj", "dative"},
    # "theme": {"iobj", "pobj"}
}

In [58]:
def _get_dep_reqs(doc: Doc|Token) -> Tuple[bool, bool]:
    """Report which core roles are present in ``doc``.

    Returns one boolean per role in ``vect_order``: True when at least one of
    the dependency labels listed for that role in ``arg_schema_dep_req_map``
    occurs in the flattened argument schema of ``doc``.
    """
    # Only the dependency labels (the schema's keys) matter here.
    present_deps = set(_flattened_fn_arg_schema(doc))

    # Restrict the evaluation to the core dependencies.
    #   TODO: add more relationships for finer-grained SRL
    return tuple(
        bool(arg_schema_dep_req_map[role] & present_deps)
        for role in vect_order
    )

In [59]:
def _get_f_id_arg_struct_dict(syn: wn.synset) -> Dict[int, Tuple[bool, bool]]:
    """Map each WordNet verb-frame ID of ``syn`` to its core-argument requirements.

    Each frame template string of each lemma of ``syn`` is parsed with
    ``arg_schema_nlp`` and reduced by ``_get_dep_reqs`` to a tuple of booleans
    (one per role in ``vect_order``).

    Args:
        syn: A WordNet synset whose lemmas provide ``frame_ids()`` and
            ``frame_strings()``.

    Returns:
        ``{frame_id: requirement_tuple}``. If the same frame ID occurs for
        several lemmas, the result for the lemma visited last is kept.
    """
    frame_reqs: Dict[int, Tuple[bool, bool]] = {}
    for lemma in syn.lemmas():
        # frame_strings() is index-aligned with frame_ids(); fetch each list
        # once per lemma instead of once per frame.
        for f_id, f_str in zip(lemma.frame_ids(), lemma.frame_strings()):
            words = f_str.split(' ')

            # A trailing token containing uppercase characters is treated as
            # an extra argument beyond subject/object/theme and is dropped.
            if words[-1] != words[-1].lower():
                words = words[:-1]

            # parse the (possibly shortened) template and reduce it to the
            # hashable requirement tuple
            doc = arg_schema_nlp(' '.join(words))
            frame_reqs[f_id] = _get_dep_reqs(doc)

    return frame_reqs

In [60]:
# Smoke test on a single verb synset. Despite the name, the values are the
# boolean requirement tuples produced by `_get_dep_reqs`, not spaCy docs.
f_id_docs = _get_f_id_arg_struct_dict(wn.synset('spin.v.01'))

## Get WordNet frame structures that match the dependency parse of the original predicate token

In [61]:
def _get_candidate_wn_frames(pred_tok: spacy.tokens.Token) -> Dict[str, Set[Tuple[bool, bool]]]:
    """Find the verb synsets of a predicate whose frames fit its dependency structure.

    The predicate token's own (subject, object) requirements are computed
    first and printed. A synset of the predicate's lemma is kept when at
    least one of its WordNet frames yields exactly the same requirement tuple.

    Args:
        pred_tok: The predicate token from a parsed sentence.

    Returns:
        ``{synset_name: {matching_requirement_tuple}}`` for every verb synset
        with at least one matching frame; synsets with no match are absent.
    """
    # WordNet frames carry at most two arguments (even for ditransitive
    #   verbs), so the comparison is limited to the subject/object roles.
    target_reqs = _get_dep_reqs(pred_tok)
    print(f"Original pred_tok dependency structure requirements: {target_reqs}")

    # candidate verb synsets for the predicate, indexed by synset name
    lemma = pred_tok.lemma_.lower()
    synsets_by_name = {}
    for candidate in wn.synsets(lemma, pos=wn.VERB):
        synsets_by_name[candidate.name()] = candidate

    # synset name -> set of frame requirement tuples equal to the target
    matches: Dict[str, Set[Tuple[bool, bool]]] = {}
    for syn_name, synset in synsets_by_name.items():
        for frame_reqs in _get_f_id_arg_struct_dict(synset).values():
            if frame_reqs != target_reqs:
                continue
            matches.setdefault(syn_name, set()).add(frame_reqs)

    return matches

In [62]:
import pprint
# Demo: parse a sentence, draw its dependency tree, and list the candidate
# WordNet synsets/frames for its verb.
test_doc = nlp("I struck the cobra with my axe.")
displacy.render(test_doc, style='dep')
# test_doc[1] is the second token of the sentence (presumably "struck"),
# which serves as the predicate.
pprint.pprint(_get_candidate_wn_frames(test_doc[1]))

Original pred_tok dependency structure requirements: (True, True)
{'assume.v.05': {(True, True)},
 'come_to.v.03': {(True, True)},
 'fall_upon.v.01': {(True, True)},
 'hit.v.12': {(True, True)},
 'mint.v.01': {(True, True)},
 'strickle.v.02': {(True, True)},
 'strike.v.01': {(True, True)},
 'strike.v.04': {(True, True)},
 'strike.v.10': {(True, True)},
 'strike.v.13': {(True, True)},
 'strike.v.14': {(True, True)},
 'strike.v.21': {(True, True)}}
