In [None]:
import nltk
nltk.download('wordnet')

import json
import networkx as nx
from nltk.corpus import wordnet as wn
from transformers import AutoTokenizer

import inflect

# Setup

In [2]:
# Tokenizer whose vocabulary decides which WordNet lemmas are kept.
# NOTE(review): later cells hard-code the "Ġ" token prefix; a tokenizer with a
# different word-boundary marker would presumably need that changed too --
# confirm before switching models.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
# MODEL_NAME = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Only the token strings are needed; keeping them in a set makes the
# membership / intersection tests below cheap.
vocab = tokenizer.get_vocab()
vocab_set = set(vocab)

# inflect engine, used further down to pluralise nouns.
p = inflect.engine()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
def get_all_hyponym_lemmas(synset):
    """Collect the lemma names of every synset below ``synset``.

    Walks the hyponym hierarchy transitively (children, grandchildren, ...)
    and gathers ``lemma.name()`` for each synset reached. The lemmas of
    ``synset`` itself are not included.

    The walk is iterative and remembers which synsets were already expanded,
    so a sub-hierarchy reachable through several parents is processed once
    and deep hierarchies cannot hit the interpreter recursion limit.

    Args:
        synset: object exposing ``hyponyms()`` (iterable of hashable synsets
            with the same interface) and ``lemmas()`` (iterable of objects
            with a ``name()`` method).

    Returns:
        set[str]: lemma names of all transitive hyponyms.
    """
    lemmas = set()
    visited = set()
    pending = list(synset.hyponyms())
    while pending:
        current = pending.pop()
        if current in visited:
            # Already expanded via another parent; its lemmas are in the set.
            continue
        visited.add(current)
        lemmas.update(lemma.name() for lemma in current.lemmas())
        pending.extend(current.hyponyms())

    return lemmas

In [4]:
all_noun_synsets = list(wn.all_synsets(pos=wn.NOUN))

# For each noun synset, keep the descendant lemmas that appear in the
# tokenizer vocabulary once prefixed with "Ġ" (the marker the original note
# describes as the leading space); they are stored without that prefix.
noun_lemmas = {}
for synset in all_noun_synsets:
    prefixed = {"Ġ" + name for name in get_all_hyponym_lemmas(synset)}
    noun_lemmas[synset.name()] = {token[1:] for token in prefixed & vocab_set}

# Categories with more than 5 in-vocabulary lemmas.
large_nouns = {name: words for name, words in noun_lemmas.items() if len(words) > 5}

In [5]:
# Number of noun synsets, then number of categories kept as "large".
print(len(all_noun_synsets), len(large_nouns), sep="\n")

82115
2016


In [6]:
# Construct the hypernym inclusion graph among large categories:
# each large category is linked from its closest ancestor that is also large.
G_noun = nx.DiGraph()

nodes = list(large_nouns.keys())
node_set = set(nodes)  # set membership instead of scanning the list per synset
for key in nodes:
    for path in wn.synset(key).hypernym_paths():
        # Synsets on this path that are retained categories. `key` is itself
        # retained, and the indexing below takes it to be the LAST entry
        # (paths are assumed to run root -> key).
        ancestors = [s.name() for s in path if s.name() in node_set]
        if len(ancestors) > 1:
            # ancestors[-1] is key itself, so [-2] is its nearest retained ancestor
            G_noun.add_edge(ancestors[-2], key)
        else:
            print(f"no ancestors for {key}")

# Detach from the subgraph view into a plain, mutable DiGraph.
G_noun = nx.DiGraph(G_noun.subgraph(nodes))

# list(G.successors('reptile.n.01'))

no ancestors for entity.n.01


In [None]:
# if a node has only one child, and that child has only one parent (or the
# node adds very few lemmas of its own), merge the two nodes
def merge_nodes(G, lemma_dict, min_parent_only_lemmas=6, verbose=True):
    """Collapse single-child links of a category DAG, modifying ``G`` in place.

    Nodes are visited in reverse topological order (descendants first). A node
    with exactly one child absorbs that child when either

    * the child has exactly one predecessor, or
    * the node has fewer than ``min_parent_only_lemmas`` lemmas that are not
      also in the child.

    On a merge the child's children are attached directly to the node and the
    child is removed from ``G``. If the node has more than one lemma missing
    from the child and the child had children, a new leaf ``node + '.other'``
    is added under the node and ``lemma_dict[node + '.other']`` is set to
    those lemmas. The removed child's entry is left in ``lemma_dict``.

    NOTE(review): when the absorbed child has several predecessors, the other
    predecessors simply lose that child (they are not linked to its
    children) -- confirm this is intended.

    Args:
        G: ``networkx.DiGraph`` with edges pointing from category to
            sub-category; must be acyclic (``topological_sort`` is used).
        lemma_dict: dict mapping every node of ``G`` to a set of lemmas;
            may gain ``'<node>.other'`` entries.
        min_parent_only_lemmas: merge threshold described above
            (default 6, the previously hard-coded value).
        verbose: print one line per merge (default True, as before).

    Returns:
        None.
    """
    # The order is materialised up front because G is mutated while looping.
    for node in reversed(list(nx.topological_sort(G))):
        children = list(G.successors(node))
        if len(children) != 1:
            continue

        child = children[0]
        parent_lemmas_not_in_child = lemma_dict[node] - lemma_dict[child]
        if (G.in_degree(child) != 1
                and len(parent_lemmas_not_in_child) >= min_parent_only_lemmas):
            continue

        grandchildren = list(G.successors(child))

        # Keep the node's own lemmas as a sibling of the promoted grandchildren.
        if len(parent_lemmas_not_in_child) > 1 and grandchildren:
            lemma_dict[node + '.other'] = parent_lemmas_not_in_child
            G.add_edge(node, node + '.other')

        for grandchild in grandchildren:
            G.add_edge(node, grandchild)
        G.remove_node(child)
        if verbose:
            print(f"merged {node} and {child}")

# Collapse single-child links in place, then keep only the lemma sets whose
# category is still a node of the graph.
merge_nodes(G_noun, large_nouns)
large_nouns = {name: words for name, words in large_nouns.items() if name in G_noun}

In [8]:
# Sanity check: reports whether the noun graph forms a single weakly connected component.
nx.is_weakly_connected(G_noun)

True

In [9]:
# expand a noun into the surface forms present in the loaded tokenizer's vocab
def _noun_to_gemma_vocab_elements(word):
    """Return the vocabulary tokens that spell ``word`` or its plural.

    The word is lower-cased, pluralised with the module-level inflect engine
    ``p``, and both forms are tried in lower-case and capitalised spelling,
    each prefixed with "Ġ". Only candidates present in ``vocab_set`` are
    returned (as a set, prefix included).
    """
    singular = word.lower()
    plural = p.plural(singular)

    candidates = set()
    for form in (singular, plural):
        candidates.add("Ġ" + form)
        candidates.add("Ġ" + form.capitalize())

    return vocab_set & candidates

# One JSON object per line: {category name: [matching vocabulary tokens]}.
with open('data/noun_synsets_wordnet_llama.json', 'w') as f:
    for synset, lemmas in large_nouns.items():
        llama_words = [
            token
            for w in lemmas
            for token in _noun_to_gemma_vocab_elements(w)
        ]
        f.write(json.dumps({synset: llama_words}) + "\n")

# Store the category hierarchy alongside it as an adjacency list.
nx.write_adjlist(G_noun, "data/noun_synsets_wordnet_hypernym_graph_llama.adjlist")

In [10]:
# Reload the noun categories (one JSON object per line) and their hierarchy.
with open('data/noun_synsets_wordnet_llama.json', 'r') as f:
    cats = {}
    for record in map(json.loads, f):
        cats.update(record)
G = nx.read_adjlist("data/noun_synsets_wordnet_hypernym_graph_llama.adjlist", create_using=nx.DiGraph())

# Keep only categories with more than 50 distinct tokens, de-duplicated.
distinct = {name: set(tokens) for name, tokens in cats.items()}
cats = {name: list(tokens) for name, tokens in distinct.items() if len(tokens) > 50}
G = nx.DiGraph(G.subgraph(cats.keys()))

# Bottom-up pass: a node with a single child absorbs it when the child has no
# other parent or the node has fewer than 5 tokens of its own.
reversed_nodes = list(nx.topological_sort(G))[::-1]
for node in reversed_nodes:
    children = list(G.successors(node))
    if len(children) != 1:
        continue
    child = children[0]
    parent_lemmas_not_in_child = set(cats[node]) - set(cats[child])
    if G.in_degree(child) == 1 or len(parent_lemmas_not_in_child) < 5:
        # Re-attach the child's children to the node before dropping the child.
        for grandchild in list(G.successors(child)):
            G.add_edge(node, grandchild)
        G.remove_node(child)

# Re-key `cats` in topological order, restricted to the surviving nodes.
G = nx.DiGraph(G.subgraph(cats.keys()))
sorted_keys = list(nx.topological_sort(G))
cats = {k: cats[k] for k in sorted_keys}

## verbs

In [11]:
all_verb_synsets = list(wn.all_synsets(pos=wn.VERB))

# Verb hierarchy: for every leaf synset (one without hyponyms), walk each of
# its hypernym paths and link consecutive synsets in path order.
G_verb = nx.DiGraph()
verb_leaves = [s.name() for s in all_verb_synsets if not s.hyponyms()]
for leaf in verb_leaves:
    for path in wn.synset(leaf).hypernym_paths():
        names = [step.name() for step in path]
        G_verb.add_edges_from(zip(names, names[1:]))

# Accumulate lemmas bottom-up: successors are visited before their
# predecessors, so each synset ends up with its own in-vocabulary lemmas plus
# everything collected for its successors. Lemmas are matched against the
# vocabulary with a "Ġ" prefix and stored without it.
verb_lemmas = {}
sorted_verb_synsets = list(nx.topological_sort(G_verb))[::-1]
for name in sorted_verb_synsets:
    own_tokens = vocab_set & {"Ġ" + lemma.name() for lemma in wn.synset(name).lemmas()}
    collected = {token[1:] for token in own_tokens}
    for child in G_verb.successors(name):
        collected |= verb_lemmas[child]
    verb_lemmas[name] = collected

# Categories with more than 5 in-vocabulary lemmas.
large_verbs = {name: words for name, words in verb_lemmas.items() if len(words) > 5}

In [None]:
# Restrict the hierarchy to the large verb categories, then collapse
# single-child links in place.
G_verb = nx.DiGraph(G_verb.subgraph(list(large_verbs.keys())))
# Pass the same dict that is filtered below (as the noun cell does), so the
# '<node>.other' categories created by merge_nodes are kept together with the
# graph nodes they belong to instead of being silently dropped.
merge_nodes(G_verb, large_verbs)
large_verbs = {k: v for k, v in large_verbs.items() if k in G_verb.nodes()}

In [13]:
def _verb_to_gemma_vocab_elements(verb):
    """Return the vocabulary tokens matching simple inflections of ``verb``.

    Builds the base form, past tense, present participle and third-person
    singular, each in lower-case and capitalised spelling, prefixes them with
    "Ġ" and returns the ones present in the module-level ``vocab_set``.

    Spelling rules applied:
      * past: ``-e`` -> ``-ed``; consonant + ``y`` -> ``-ied``; otherwise ``+ed``
      * participle: ``-ie`` -> ``-ying``; ``-ee``/``-oe``/``-ye`` keep the
        ``e``; other ``-e`` drop it; otherwise ``+ing``
      * third person: ``-o/-ch/-s/-sh/-x/-z`` -> ``+es``; consonant + ``y``
        -> ``-ies``; otherwise ``+s``

    The short verbs "be", "do" and "go" use their real forms, because the
    regular rules would yield strings such as "bed" and "bing" that can match
    unrelated vocabulary entries. Other irregular verbs and final-consonant
    doubling are not handled; such forms are only kept if the generated string
    happens to be in the vocabulary.

    Args:
        verb: verb lemma (any case).

    Returns:
        set[str]: matching tokens, including the "Ġ" prefix; empty for
        lemmas shorter than three characters other than "be", "do", "go".
    """
    verb = verb.lower()

    # Very short lemmas are discarded (the original note attributes them to
    # WordNet quirks), except for the genuine short verbs.
    if len(verb) < 3 and verb not in ["be", "do", "go"]:
        return set()

    # (past, present participle, third person singular)
    short_irregulars = {
        "be": ("was", "being", "is"),
        "do": ("did", "doing", "does"),
        "go": ("went", "going", "goes"),
    }

    if verb in short_irregulars:
        past, present_participle, third_person = short_irregulars[verb]
    else:
        # len(verb) >= 3 on this branch, so verb[-2] is always valid.
        consonant_y = verb.endswith("y") and verb[-2] not in "aeiou"

        # Conjugating to past tense
        if verb.endswith('e'):
            past = verb + 'd'
        elif consonant_y:
            past = verb[:-1] + 'ied'          # try -> tried
        else:
            past = verb + 'ed'

        # Conjugating to present participle
        if verb.endswith('ie'):
            present_participle = verb[:-2] + 'ying'   # die -> dying
        elif verb.endswith(('ee', 'oe', 'ye')):
            present_participle = verb + 'ing'         # see -> seeing
        elif verb.endswith('e'):
            present_participle = verb[:-1] + 'ing'    # make -> making
        else:
            present_participle = verb + 'ing'

        # 3psg
        if verb.endswith(("o", "ch", "s", "sh", "x", "z")):
            third_person = verb + 'es'
        elif consonant_y:
            third_person = verb[:-1] + 'ies'
        else:
            third_person = verb + 's'

    tenses = [verb, past, present_participle, third_person]
    caps = [w.capitalize() for w in tenses]

    # Add the "Ġ" prefix to each form to match the vocabulary tokens
    add_space = ["Ġ" + w for w in (caps + tenses)]

    # Keep only the candidates that are actual vocabulary tokens
    return vocab_set.intersection(add_space)

# One JSON object per line: {category name: [matching vocabulary tokens]}.
with open('data/verb_synsets_wordnet_llama.json', 'w') as f:
    for synset, lemmas in large_verbs.items():
        gemma_words = [
            token
            for w in lemmas
            for token in _verb_to_gemma_vocab_elements(w)
        ]
        f.write(json.dumps({synset: gemma_words}) + "\n")

# Store the verb hierarchy alongside it as an adjacency list.
nx.write_adjlist(G_verb, "data/verb_synsets_wordnet_hypernym_graph_llama.adjlist")

In [14]:
# Reload the verb categories (one JSON object per line) and their hierarchy.
with open('data/verb_synsets_wordnet_llama.json', 'r') as f:
    cats = {}
    for record in map(json.loads, f):
        cats.update(record)
G = nx.read_adjlist("data/verb_synsets_wordnet_hypernym_graph_llama.adjlist", create_using=nx.DiGraph())

# Keep only categories with more than 50 distinct tokens, de-duplicated.
distinct = {name: set(tokens) for name, tokens in cats.items()}
cats = {name: list(tokens) for name, tokens in distinct.items() if len(tokens) > 50}
G = nx.DiGraph(G.subgraph(cats.keys()))

# Bottom-up pass: a node with a single child absorbs it when the child has no
# other parent or the node has fewer than 5 tokens of its own.
reversed_nodes = list(nx.topological_sort(G))[::-1]
for node in reversed_nodes:
    children = list(G.successors(node))
    if len(children) != 1:
        continue
    child = children[0]
    parent_lemmas_not_in_child = set(cats[node]) - set(cats[child])
    if G.in_degree(child) == 1 or len(parent_lemmas_not_in_child) < 5:
        # Re-attach the child's children to the node before dropping the child.
        for grandchild in list(G.successors(child)):
            G.add_edge(node, grandchild)
        G.remove_node(child)

# Re-key `cats` in topological order, restricted to the surviving nodes.
G = nx.DiGraph(G.subgraph(cats.keys()))
sorted_keys = list(nx.topological_sort(G))
cats = {k: cats[k] for k in sorted_keys}