In [None]:
import nltk
nltk.download('wordnet')

import json
import networkx as nx
from nltk.corpus import wordnet as wn
from transformers import AutoTokenizer

import inflect

# Setup

In [2]:
# Gemma tokenizer: its vocabulary decides which words count as single tokens.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
vocab = tokenizer.get_vocab()
# Membership set of token strings (iterating a dict yields its keys).
vocab_set = set(vocab)

# inflect engine, used below to pluralize nouns.
p = inflect.engine()

## nouns

In [3]:
def get_all_hyponym_lemmas(synset):
    """Return the lemma names of every transitive hyponym of `synset`.

    Walks the DAG below `synset` by following ``hyponyms()``. NLTK exposes
    instance hyponyms separately via ``instance_hyponyms()``, which is not
    followed here. It collects ``lemma.name()`` for every descendant synset.
    The lemmas of `synset` itself are NOT included.

    WordNet allows a synset to have several hypernyms, so one descendant can
    be reachable along several paths. Each descendant is visited only once,
    and the walk is iterative instead of recursive.

    Args:
        synset: an NLTK WordNet ``Synset``, or anything exposing
            ``hyponyms()`` and ``lemmas()``.

    Returns:
        A set of lemma-name strings. It is empty when `synset` has no hyponyms.
    """
    lemmas = set()
    visited = set()
    pending = list(synset.hyponyms())
    while pending:
        hyponym = pending.pop()
        if hyponym in visited:
            # shared sub-DAG already collected via another parent
            continue
        visited.add(hyponym)
        lemmas.update(lemma.name() for lemma in hyponym.lemmas())
        pending.extend(hyponym.hyponyms())

    return lemmas

In [4]:
# For every noun synset, collect the lemmas of all its (transitive) hyponyms
# that are single Gemma tokens. The Gemma vocab marks a preceding space with
# "▁", so a lemma qualifies only when "▁" + lemma is a vocab entry.
all_noun_synsets = list(wn.all_synsets(pos=wn.NOUN))
noun_lemmas = {}
for synset in all_noun_synsets:
    descendant_lemmas = get_all_hyponym_lemmas(synset)
    noun_lemmas[synset.name()] = {
        name for name in descendant_lemmas if "▁" + name in vocab_set
    }

# keep only categories with more than 5 single-token members
large_nouns = {
    name: members for name, members in noun_lemmas.items() if len(members) > 5
}

In [5]:
# number of noun synsets overall, then number kept as "large" categories
print(len(all_noun_synsets), len(large_nouns), sep="\n")

82115
2573


In [6]:
# Construct the hypernym inclusion graph among large categories: along every
# hypernym path, each large synset gets an edge from its nearest large ancestor.
G_noun = nx.DiGraph()

nodes = list(large_nouns.keys())
node_set = set(nodes)  # O(1) membership; scanning the list was O(len(nodes))
for key in nodes:
    for path in wn.synset(key).hypernym_paths():
        # Ancestors included in the cleaned set. hypernym_paths() lists each
        # path from the root down to the synset, so ancestors[-1] is `key`
        # itself and ancestors[-2] is its nearest large ancestor.
        ancestors = [s.name() for s in path if s.name() in node_set]
        if len(ancestors) > 1:
            G_noun.add_edge(ancestors[-2], key)
        else:
            print(f"no ancestors for {key}")

# Keep large synsets that received no edge at all instead of silently dropping
# them. Nodes already present keep their insertion order.
G_noun.add_nodes_from(nodes)
G_noun = nx.DiGraph(G_noun.subgraph(nodes))

no ancestors for entity.n.01


In [7]:
# Collapse single-child chains: when a node has exactly one child, and either
# that child has only one parent or the node adds fewer than 6 lemmas beyond
# the child, merge the child into the node.
def merge_nodes(G, lemma_dict):
    """Merge single-child chains of a category DAG in place.

    Nodes are visited in reverse topological order, so descendants come before
    ancestors. Consider a node with exactly one child. The child is folded into
    the node when either of these holds:

    - the child has exactly one predecessor, or
    - the node has fewer than 6 lemmas that the child lacks.

    Folding re-attaches the child's successors (the grandchildren) to the node
    and removes the child from ``G``. Suppose the node has more than one lemma
    not covered by the child and the child had grandchildren. Those leftover
    lemmas are then stored under a new node ``node + '.other'``, which is added
    to both ``G`` and ``lemma_dict``.

    Args:
        G: networkx DiGraph with edges pointing parent -> child. It must be
            acyclic, because ``nx.topological_sort`` requires that. It is
            mutated.
        lemma_dict: maps each node of ``G`` to a set of lemmas. It is mutated:
            only new ``'.other'`` keys are added, and removed children keep
            their entries.

    Returns:
        None. Prints one ``merged <node> and <child>`` line per merge.
    """
    # Snapshot of the visit order. Nodes added later ('.other') are never
    # visited. A removed child was already visited, because it precedes its
    # parent in this order.
    topological_sorted_nodes = list(reversed(list(nx.topological_sort(G))))
    for node in topological_sorted_nodes:
        children = list(G.successors(node))
        if len(children) == 1:
            child = children[0]
            parent_lemmas_not_in_child = lemma_dict[node] - lemma_dict[child]
            # NOTE(review): the header comment says "only one parent", but the
            # `or` also merges a child that has other parents whenever the
            # node adds < 6 lemmas. In that case remove_node() below drops the
            # child's edges from its other parents, and those parents are NOT
            # re-linked to the grandchildren. Confirm this is intended.
            if len(list(G.predecessors(child))) == 1 or len(parent_lemmas_not_in_child) <6:
                grandchildren = list(G.successors(child))
                
                # Keep the node's extra lemmas as a separate sibling category,
                # but only if the merged node will still have other children.
                if len(parent_lemmas_not_in_child) > 1:
                    if len(grandchildren) > 0:
                        lemma_dict[node + '.other'] = parent_lemmas_not_in_child
                        G.add_edge(node, node + '.other')

                # The child's lemma_dict entry is intentionally left in place;
                # callers in this notebook filter the dict by G's nodes afterwards.
                for grandchild in grandchildren:
                    G.add_edge(node, grandchild)
                G.remove_node(child)
                print(f"merged {node} and {child}")

merge_nodes(G_noun, large_nouns)
# Drop categories that merge_nodes folded into a parent (no longer graph nodes).
large_nouns = {name: large_nouns[name] for name in large_nouns if name in G_noun}

merged bovine.n.01 and cattle.n.01
merged horse.n.01 and saddle_horse.n.01
merged percoid_fish.n.01 and scombroid.n.01
merged equine.n.01 and horse.n.01
merged cat.n.01 and domestic_cat.n.01
merged spiny-finned_fish.n.01 and percoid_fish.n.01
merged soft-finned_fish.n.01 and cypriniform_fish.n.01
merged globulin.n.01 and gamma_globulin.n.01
merged hominid.n.01 and homo.n.02
merged even-toed_ungulate.n.01 and ruminant.n.01
merged odd-toed_ungulate.n.01 and equine.n.01
merged canine.n.02 and dog.n.01
merged simple_protein.n.01 and globulin.n.01
merged malignant_tumor.n.01 and cancer.n.01
merged tolerance.n.03 and lenience.n.03
merged fruit_tree.n.01 and citrus.n.02
merged primate.n.02 and hominid.n.01
merged saurian.n.01 and lizard.n.01
merged waterfowl.n.01 and anseriform_bird.n.01
merged domestic_fowl.n.01 and chicken.n.02
merged bony_fish.n.01 and teleost_fish.n.01
merged commissioned_officer.n.01 and commissioned_military_officer.n.01
merged doctor.n.01 and specialist.n.02
merged car

In [9]:
# make a gemma specific version
def _noun_to_gemma_vocab_elements(word):
    word = word.lower()
    plural = p.plural(word)
    add_cap_and_plural = [word, word.capitalize(), plural, plural.capitalize()]
    add_space = ["▁" + w for w in add_cap_and_plural]
    return vocab_set.intersection(add_space)

### save the data
# with open('data/noun_synsets_wordnet_gemma.json', 'w') as f:
#     for synset, lemmas in large_nouns.items():
#         gemma_words = []
#         for w in lemmas:
#             gemma_words.extend(_noun_to_gemma_vocab_elements(w))

#         f.write(json.dumps({synset: gemma_words}) + "\n")
        
# nx.write_adjlist(G_noun, "data/noun_synsets_wordnet_hypernym_graph.adjlist")

#### After this step, we used the GPT-4 API to obtain cleaner features; that code is omitted here.

## verbs

In [10]:
# Build the verb hypernym DAG (edges point from hypernym to hyponym) out of the
# hypernym paths of every leaf synset; each synset lies on some leaf's path.
all_verb_synsets = list(wn.all_synsets(pos=wn.VERB))

G_verb = nx.DiGraph()
verb_leaves = [s.name() for s in all_verb_synsets if len(s.hyponyms()) == 0]
for leaf in verb_leaves:
    for path in wn.synset(leaf).hypernym_paths():
        for i in range(len(path) - 1):
            G_verb.add_edge(path[i].name(), path[i+1].name())
# A verb synset with neither hypernyms nor hyponyms has a single length-1 path,
# so it never received an edge above and would be missing from verb_lemmas.
# Add all leaves explicitly. They are appended, so existing node order is kept.
G_verb.add_nodes_from(verb_leaves)

# Aggregate lemmas bottom-up. Reversed topological order visits every hyponym
# before its hypernyms, so each child's set is complete when merged upward.
verb_lemmas = {}
sorted_verb_synsets = list(reversed(list(nx.topological_sort(G_verb))))
for s in sorted_verb_synsets:
    # The synset's own lemmas that are single Gemma tokens. "▁" marks a
    # preceding space in the Gemma vocab.
    verb_lemmas[s] = {
        lemma.name()
        for lemma in wn.synset(s).lemmas()
        if "▁" + lemma.name() in vocab_set
    }
    for child in G_verb.successors(s):
        verb_lemmas[s].update(verb_lemmas[child])

# keep only categories with more than 5 single-token lemmas
large_verbs = {k: v for k, v in verb_lemmas.items() if len(v) > 5}

In [11]:
# Restrict the DAG to the large verb categories, then collapse single-child chains.
G_verb = nx.DiGraph(G_verb.subgraph(list(large_verbs.keys())))
# Pass large_verbs, as the noun section passes large_nouns. merge_nodes adds
# each "<synset>.other" category to both the graph and the dict it is given.
# Passing verb_lemmas left those nodes in G_verb with no entry in large_verbs.
merge_nodes(G_verb, large_verbs)
large_verbs = {k: v for k, v in large_verbs.items() if k in G_verb.nodes()}

merged provoke.v.03 and entice.v.01
merged challenge.v.02 and provoke.v.03
merged set_forth.v.01 and describe.v.01
merged mean.v.01 and typify.v.02
merged invite.v.04 and challenge.v.02
merged back.v.01 and guarantee.v.04
merged broach.v.01 and cover.v.05
merged rede.v.02 and urge.v.01
merged elaborate.v.01 and set_forth.v.01
merged suppress.v.01 and hush.v.02
merged approve.v.01 and back.v.01
merged deceive.v.01 and cheat.v.03
merged cheat.v.01 and overcharge.v.01
merged argue.v.02 and oppose.v.01
merged clarify.v.01 and elaborate.v.01
merged deceive.v.02 and gull.v.02
merged overstate.v.01 and boast.v.01
merged repeat.v.01 and sum_up.v.01
merged bespeak.v.01 and bode.v.01
merged impart.v.01 and convey.v.01
merged oppose.v.06 and protest.v.02
merged recognize.v.09 and honor.v.01
merged appeal.v.02 and plead.v.01
merged ask.v.02 and request.v.02
merged allocate.v.01 and award.v.01
merged house.v.02 and lodge.v.04
merged ask.v.05 and question.v.03
merged converse.v.01 and argue.v.02
mer

In [12]:
def _verb_to_gemma_vocab_elements(verb):
    #  add in regular verb conjugations to expand vocab
 
    verb = verb.lower()

    # some weird wordnet bugs, filter it
    if len(verb) < 3 and verb not in ["be", "do", "go"]:
        return set()

    # Conjugating to past tense
    if verb.endswith('e'):
        past = verb + 'd'
    else:
        past = verb + 'ed'
    
    # Conjugating to present participle
    if verb.endswith('e'):
        present_participle = verb[:-1] + 'ing'
    else:
        present_participle = verb + 'ing'
    
    # 3psg 
    if verb.endswith(("o", "ch", "s", "sh", "x", "z")):
        third_person = verb + 'es'
    elif verb.endswith("y") and verb[-2] not in "aeiou":
        third_person = verb[:-1] + 'ies'
    else:
        third_person = verb + 's'


    tenses = [verb, past, present_participle, third_person]

    caps = [w.capitalize() for w in tenses]
    
    # Add underscore prefix to each word to match vocab tokens
    add_space = ["▁" + w for w in (caps + tenses)]

    # Return intersection with a hypothetical vocabulary set
    return vocab_set.intersection(add_space)

### save the data
# with open('data/verb_synsets_wordnet_gemma.json', 'w') as f:
#     for synset, lemmas in large_verbs.items():
#         gemma_words = []
#         for w in lemmas:
#             gemma_words.extend(_verb_to_gemma_vocab_elements(w))

#         f.write(json.dumps({synset: gemma_words}) + "\n")
        
# nx.write_adjlist(G_verb, "data/verb_synsets_wordnet_hypernym_graph.adjlist")