# Overlap based WSD using Lesk's Algorithm with Word2Vec Embeddings

## Install Dependencies
Run this section only once

In [1]:
# Run this only once
import nltk
nltk.download("punkt")
nltk.download("wordnet")
nltk.download("semcor") # downloads the .zip file, but doesn't unzip it
nltk.download("stopwords")
nltk.download("averaged_perceptron_tagger")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.
[nltk_data] Downloading package semcor to /root/nltk_data...
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


True

In [2]:
# Unzip SemCor 3.0
! unzip -q /root/nltk_data/corpora/semcor.zip -d /root/nltk_data/corpora # after this, data in /root/nltk_data/corpora/semcor

In [3]:
# Download pre-trained word2vec embeddings
! wget -c "https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz" # 1.53 GB

--2021-09-08 11:08:16--  https://s3.amazonaws.com/dl4j-distribution/GoogleNews-vectors-negative300.bin.gz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.140.46
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.140.46|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1647046227 (1.5G) [application/x-gzip]
Saving to: ‘GoogleNews-vectors-negative300.bin.gz’


2021-09-08 11:08:53 (42.6 MB/s) - ‘GoogleNews-vectors-negative300.bin.gz’ saved [1647046227/1647046227]



In [4]:
! pip install num2words

Collecting num2words
  Downloading num2words-0.5.10-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 2.2 MB/s 
Installing collected packages: num2words
Successfully installed num2words-0.5.10


In [5]:
# Import w2v here itself as it takes time to load
from gensim.models import KeyedVectors
W2V = KeyedVectors.load_word2vec_format("GoogleNews-vectors-negative300.bin.gz", binary = True)

## Start
To re-run, run from this cell

In [6]:
%reset_selective -f ^(?!W2V).*$ # clear everything except W2V

## Imports

In [7]:
import nltk
from nltk import word_tokenize
from nltk.corpus import semcor # corpus reader: https://github.com/nltk/nltk/blob/develop/nltk/corpus/reader/semcor.py
from nltk.corpus import stopwords
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer

import random
import numpy as np
from tqdm.notebook import tqdm
from string import punctuation
from num2words import num2words
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

## Constants

In [8]:
# Custom stopwords
EXTRA_SW = [
    "''",
    "'s",
    "``"
]

SW = stopwords.words("english")
SW += [p for p in punctuation]
SW += EXTRA_SW

In [9]:
lemmatizer = WordNetLemmatizer()

## Functions

In [10]:
def cosineSimilarity(a, b):
    cs = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return cs

In [11]:
def isNumber(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

In [12]:
def n2w(w):
    # converts given n to word form if n is numeric
    if isNumber(w) and w.lower() != "infinity" and w.lower() != "nan":
        w = num2words(w)
    return w

In [13]:
def lemmatize(w, tag):
    if tag is None:
        return lemmatizer.lemmatize(w)
    else:
        return lemmatizer.lemmatize(w, tag)

In [14]:
def clean(tokens):
    tagged = nltk.pos_tag(tokens)
    lemmatized = [lemmatize(w, treebank2wn(tag)) for w, tag in tagged]
    cleaned = [n2w(w) for w in lemmatized if w.lower() not in SW]
    return cleaned

In [15]:
def getVec(w):
    # Returns (300,) shaped numpy array
    try:
        v = W2V[w]
        return v
    except KeyError:
        return None # ignore words not in vocab

In [16]:
def syn2sense(syn):
    # get the sense (= lemma.postag.num) for a given synset
    s = syn.name()
    # n = ".".join(s.split(".")[-2:]) # n.01 and v.01 are different senses (eg: ash.n.01, ash.v.01)
    return s

In [17]:
def treebank2wn(ttag):
    if ttag.startswith("J"):
        return wn.ADJ
    elif ttag.startswith("V"):
        return wn.VERB
    elif ttag.startswith("N"):
        return wn.NOUN
    elif ttag.startswith("R"):
        return wn.ADV
    else:
        return None

In [18]:
def sent2vec(tokens):

    v = 0
    n = 0

    for w in tokens:

        # Check if w is a named entity (TODO: use wordnet NE tag directly instead of below approach)
        tkns = word_tokenize(w)

        if len(tkns) > 1:
            for t in tkns:
                vt = getVec(t)
                if vt is not None:
                    n += 1
                    v += vt
        else:
            vw = getVec(w)
            if vw is not None:
                n += 1
                v += vw

    if n == 0: # when tokens is empty or no token in word2vec
        v = None
    else:
        v /= n

    return v

In [19]:
def parse(d):
    # d (nltk.corpus.reader.semcor.SemcorSentence) : can have lists as elements or nltk.tree.Tree

    tokens = []
    senses = []

    for e in d:

        if isinstance(e, nltk.tree.Tree):

            # e.label() returns a nltk.corpus.reader.wordnet.Lemma object or simply a string (of the form word.pos.num)
            lemma = e.label()
            
            if isinstance(lemma, nltk.corpus.reader.wordnet.Lemma):
                synset = lemma.synset() # nltk.corpus.reader.wordnet.Synset
                sense = syn2sense(synset)
            else:
                sense = None # ignore all tagged senses which aren't in WN (i.e. present as Lemma)
            
            le = len(e)
            if le == 1:
                w = e[0]
                if isinstance(w, nltk.tree.Tree) or isinstance(w, list):
                    # ignore w.label()
                    lw = len(w)
                    w = " ".join([w[i] for i in range(lw)])
            else:
                w = " ".join([e[i] for i in range(le)])

        elif isinstance(e, list):
            w = e[0]
            sense = None

        else:
            invtype = type(e)
            raise TypeError(f"Invalid type: {invtype}")

        if w:
            tokens.append(w)
            senses.append(sense)

    return tokens, senses

In [20]:
def getCandidates(w, tag):
    # Get candidate sense vectors and labels of a word w

    w = w.replace(".", "") # "Sept." becomes "Sept"
    w = w.replace("-", "") # re-elected becomes "relected"

    # Handle ngrams (like "united states")
    tkns = word_tokenize(w)
    if len(tkns) > 1:
        tagged = nltk.pos_tag(tkns)
        tags = [treebank2wn(p[1]) for p in tagged]
        ltkns = [lemmatize(w, t) for w, t in zip(tkns, tags)]
        w = "_".join(ltkns)

    syns = wn.synsets(w, tag)

    if len(syns) == 0:
        w = "_".join(tkns) # cases where lemmatization doesn't help ("agreed upon")
        syns = wn.synsets(w, tag)

    sense_vectors = []
    sense_labels = []

    for syn in syns:

        label = syn2sense(syn)

        defn = syn.definition() # TODO: Implement the extended Lesk algorithm that uses related synsets as well

        defn = defn.replace("_", " ")
        defn = defn.replace("-", " ")

        tkns = word_tokenize(defn)
        if len(tkns) == 0:
            raise ValueError(f"0 tokens found: {defn}")

        clnd = clean(tkns)
        if len(clnd) < 2:
            clnd = tkns # don't remove stopwords if the sentence is almost entirely made up of them

        sv = sent2vec(clnd)

        if sv is None:
            print(f"Empty sense vector. Word: {w}, Definition: {defn}, Cleaned: {clnd}. Using a random vector as sense.")
            sv = np.random.rand(300,)
        
        sense_vectors.append(sv)
        sense_labels.append(label)

    return sense_vectors, sense_labels # returns empty lists if no synsets found

## Main

In [21]:
data = semcor.tagged_sents(tag = "sem") # 37176 sentences, 224716 tagged words, 34189 unique senses 

In [22]:
n_total = 0
n_correct = 0
n_samples = 0

true = []
pred = []

for d in data:

    try:

        tokens, senses = parse(d)
        n_tokens = len(tokens)

        # Tag and lemmatize tokens, don't remove stopwords here
        tagged = nltk.pos_tag(tokens)
        tags = [treebank2wn(p[1]) for p in tagged]
        tokens = [lemmatize(w, tag) for w, tag in zip(tokens, tags)]

        for i in range(n_tokens):

            w = tokens[i]
            tag = tags[i]
            s_true = senses[i]

            if not isinstance(w, str):
                raise TypeError(f"Invalid type: {type(w)} : {w} : {tokens}")

            # Don't predict for words that aren't sense-tagged
            if s_true is None:
                continue

            # Get context for w (all words in the sentence except w)
            context = tokens.copy()
            del context[i] # more efficient than .pop(i)

            # Remove stopwords and punctuation from context to reduce #elements in the context
            # These don't contribute much to the semantic overlap anyways
            cleaned = clean(context)
            if len(cleaned) < 2:
                cleaned = context # if almost all words are stopwords, don't remove any

            # Get context vector by average w2v vectors for each word
            cv = sent2vec(cleaned)

            if cv is None:
                print(f"Empty context vector. Word: {w}, Cleaned: {cleaned}, Tokens: {tokens}. Using a random vector as context.")
                cv = np.random.rand(300,)

            # Get WordNet candidate senses
            sense_vectors, sense_labels = getCandidates(w, tag)
            n_candidates = len(sense_labels)

            s_pred = None
            if n_candidates == 0:
                # Try without pos tag
                sense_vectors, sense_labels = getCandidates(w, None)
                n_candidates = len(sense_labels)
                if n_candidates == 0:
                    # print(f"No synsets found. Word: {w}, Sense: {s_true}") # don't print, too many NE's in the data
                    s_pred = random.choice(["group.n.01", "person.n.01", "location.n.01"]) # most likely an NE
            
            # Use cosine similarity to get the best senses
            best = -1 
            for j in range(n_candidates):
                sv = sense_vectors[j]
                cs = cosineSimilarity(cv, sv)
                if cs > best:
                    best = cs
                    s_pred = sense_labels[j]

            if s_true == s_pred:
                n_correct += 1
            n_total += 1

            true.append(s_true)
            pred.append(s_pred)

    except Exception as e:
        print(f"Error at: {n_samples}")
        print(str(e))
        raise ValueError("Error")

    n_samples += 1

    if n_samples%200 == 0:
        print(f"{n_samples} sentences processed")
        acc = (n_correct/n_total)*100
        print(f"Accuracy: {acc:.4f}")
        print()

200 sentences processed
Accuracy: 44.2681

400 sentences processed
Accuracy: 41.6832

600 sentences processed
Accuracy: 40.8078

800 sentences processed
Accuracy: 39.9124

1000 sentences processed
Accuracy: 40.2434

1200 sentences processed
Accuracy: 39.9983

1400 sentences processed
Accuracy: 40.0856

1600 sentences processed
Accuracy: 40.5734

1800 sentences processed
Accuracy: 40.5844

2000 sentences processed
Accuracy: 40.5544

2200 sentences processed
Accuracy: 40.5462

Empty context vector. Word: Cancer, Cleaned: ['``', "''", '!'], Tokens: ['``', 'Cancer', "''", '!']. Using a random vector as context.
Empty context vector. Word: By no means, Cleaned: ['.'], Tokens: ['By no means', '.']. Using a random vector as context.
Empty context vector. Word: For instance, Cleaned: [':'], Tokens: ['For instance', ':']. Using a random vector as context.
Empty context vector. Word: Death, Cleaned: ['!'], Tokens: ['Death', '!']. Using a random vector as context.
2400 sentences processed
Accurac

In [23]:
pred_sense_set = set(pred)
true_sense_set = set(true)
all_senses = sorted(list(true_sense_set.union(pred_sense_set)))
not_predicted = true_sense_set - pred_sense_set
extra_predicted = pred_sense_set - true_sense_set

In [30]:
acc = accuracy_score(true, pred)
prec = precision_score(true, pred, average = "macro")
rec = recall_score(true, pred, average = "macro")
f1 = f1_score(true, pred, average = "macro")

print(f"Accuracy: {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall: {rec:.4f}")
print(f"F1-Score: {f1:.4f}")

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3786
Precision: 0.4136
Recall: 0.4009
F1-Score: 0.3776


In [25]:
def predict(sent):

    senses = []
    tokens = word_tokenize(sent)
    # Tag and lemmatize tokens, don't remove stopwords here
    tagged = nltk.pos_tag(tokens)
    tags = [treebank2wn(p[1]) for p in tagged]
    tokens = [lemmatize(w, tag) for w, tag in zip(tokens, tags)]
    n_tokens = len(tokens)

    for i in range(n_tokens):

        w = tokens[i]
        tag = tags[i]

        # Get context for w (all words in the sentence except w)
        context = tokens.copy()
        del context[i] # more efficient than .pop(i)

        # Get context vector by average w2v vectors for each word
        cv = sent2vec(context)

        if cv is None:
            print(f"Empty context vector. Word: {w}, Tokens: {tokens}. Using a random vector as context.")
            cv = np.random.rand(300,)

        # Get WordNet candidate senses
        sense_vectors, sense_labels = getCandidates(w, tag)
        n_candidates = len(sense_labels)

        s_pred = None
        if n_candidates == 0:
            # Try without pos tag
            sense_vectors, sense_labels = getCandidates(w, None)
            n_candidates = len(sense_labels)
            if n_candidates == 0:
                print(f"No synsets found: {w}")
                s_pred = None

        # Use cosine similarity to get the best senses
        best = -1 
        for j in range(n_candidates):
            sv = sense_vectors[j]
            cs = cosineSimilarity(cv, sv)
            if cs > best:
                best = cs
                s_pred = sense_labels[j]

        senses.append(s_pred)

    return senses

In [26]:
sents = [
    "On combustion of coal we get ash", 
    "The bank is located in the city near the river",
    "The stolen credit cards were found near the river bank",
    "The user had to kill the computer process",
    "The trees near Nuclear Power Plant were cut down"
]

for sent in sents:
    senses = predict(sent)
    for s in senses:
        if s is not None:
            print(s, ":", wn.synset(s).definition())
    print()

No synsets found: of
No synsets found: we
on.r.03 : in a state required for something to function or be effective
combustion.n.01 : a process in which a substance reacts with oxygen to give heat and light
ember.n.01 : a hot fragment of wood or coal that is left from a fire and is glowing or smoldering
get.v.01 : come into the possession of something concrete or abstract
ash.n.01 : the residue that remains when something is burned

No synsets found: The
No synsets found: the
No synsets found: the
bank.n.07 : a slope in the turn of a road or track; the outside is higher than the inside in order to reduce the effects of centrifugal force
be.v.03 : occupy a certain position or area; be somewhere
situate.v.01 : determine or indicate the place, site, or limits of, as if by an instrument or by a survey
in.r.01 : to or toward the inside of
city.n.01 : a large and densely populated urban area; may include several independent administrative districts
near.r.01 : near in time or place or relation

## References
- [NLTK Trees](https://stackoverflow.com/questions/62472606/get-the-type-of-a-nltk-tree)
- [NLTK WordNet Lemma](https://www.nltk.org/_modules/nltk/corpus/reader/wordnet.html)  
`Lemma` attributes, accessible via methods with the same name:
    - name: The canonical name of this lemma.
    - synset: The synset that this lemma belongs to.
    - syntactic_marker: For adjectives, the WordNet string identifying the
        syntactic position relative modified noun. See:
        https://wordnet.princeton.edu/documentation/wninput5wn
        For all other parts of speech, this attribute is None.
    - count: The frequency of this lemma in wordnet.