##### Implementation of two baseline algorithms for Word Sense Disambiguation (WSD): the most common sense baseline and the plain Lesk algorithm

In [None]:
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')

from nltk.corpus import wordnet as wn
from loader import *
from dict_utilities import *

import numpy as np
from numpy.linalg import norm

import gensim.downloader as api

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


#### The most common sense algorithm for WSD -- predict the sense that WordNet lists first for the lemma.

In [None]:
# Each corpus is a pair: the XML file with the instances and its gold-key file.
SEMCOR_DATA_FILE, SEMCOR_LABELLED = './semcor.data.xml', './semcor.gold.key.txt'
SENSEVAL_2_DATA_FILE, SENSEVAL_2_LABELLED = './senseval2.data.xml', './senseval2.gold.key.txt'
SENSEVAL_3_DATA_FILE, SENSEVAL_3_LABELLED = './senseval3.data.xml', './senseval3.gold.key.txt'

In [None]:
# Load the instances of each corpus, in order: SemCor, Senseval-2, Senseval-3.
# (load_instances comes from the project-local `loader` module.)
semcor_lemmas, senseval_2_lemmas, senseval_3_lemmas = (
  load_instances(data_file)
  for data_file in (SEMCOR_DATA_FILE, SENSEVAL_2_DATA_FILE, SENSEVAL_3_DATA_FILE)
)

In [None]:
def most_common_sense(lemma, pos=None):
  '''
  Return the sense keys of the first WordNet synset listed for `lemma`.

  This is the "most common sense" baseline: the prediction is always the
  first entry returned by `wn.synsets`. Without `pos`, candidates from
  every part of speech are considered, so the first synset may not match
  the target word's part of speech.

  Args:
    lemma: the word to disambiguate.
    pos: optional WordNet part-of-speech filter passed to `wn.synsets`
      (default None = all parts of speech, as before).

  Returns:
    List of the sense keys of every lemma in the first synset, or an empty
    list when WordNet has no synset for `lemma` (so evaluation can count
    it as a miss instead of crashing with IndexError).
  '''
  synsets = wn.synsets(lemma, pos=pos)
  if not synsets:
    return []

  # Only the first synset is needed, so build its key list alone.
  return [sense.key() for sense in synsets[0].lemmas()]

#### Run the algorithm on the labelled instances and compute its accuracy

In [None]:
def get_labels(LABEL_FILE):
  """
  Read a gold-key file and return its annotations as a dictionary.

  Each non-blank line has the form `<instance_id> <sense_key> [<sense_key> ...]`,
  with fields separated by whitespace.

  Args:
    LABEL_FILE: path to the gold-key file.

  Returns:
    dict mapping each instance id (first field of a line) to the list of the
    remaining fields (its gold sense keys).
  """
  labels = {}
  # `with` guarantees the file handle is closed, even if parsing fails.
  with open(LABEL_FILE, encoding="utf-8") as key_file:
    for line in key_file:
      # split() with no argument collapses runs of whitespace, so repeated
      # spaces do not produce empty "keys" and blank lines yield no fields.
      fields = line.split()
      if not fields:
        continue
      labels[fields[0]] = fields[1:]

  return labels

In [None]:
def eval_common_sense(lemmas, labels):
  """
  Return the accuracy (in percent) of the most common sense baseline on one
  dataset (semcor, senseval2 or senseval3).

  Args:
    lemmas: instances indexed by instance id; each must expose `.lemma`
      (presumably the output of `load_instances` -- confirm against loader).
    labels: dict mapping instance id to its list of gold sense keys, as
      returned by `get_labels`.

  Returns:
    Percentage of labelled instances for which at least one predicted sense
    key is among that instance's gold keys; 0.0 when `labels` is empty.
  """
  if not labels:
    # Avoid ZeroDivisionError on an empty gold file.
    return 0.0

  correct_count = 0
  for lemma_id, gold_keys in labels.items():
    predicted_keys = most_common_sense(lemmas[lemma_id].lemma)
    # An instance can carry several gold keys; matching any one of them counts
    # as correct (an empty gold list simply never matches).
    if not set(gold_keys).isdisjoint(predicted_keys):
      correct_count += 1

  return (correct_count / len(labels)) * 100

In [None]:
## Read the semcor labelled lemmas
## Read the gold sense keys of all three datasets (SemCor, Senseval-2, Senseval-3)
semcor_labels, senseval_2_labels, senseval_3_labels = map(
  get_labels, (SEMCOR_LABELLED, SENSEVAL_2_LABELLED, SENSEVAL_3_LABELLED)
)

In [None]:
## Evaluate the most common sense baseline on each dataset, one line per dataset
for _title, _instances, _gold in (
    ("Accuracy on Semcor dataset:", semcor_lemmas, semcor_labels),
    ("Accuracy on Senseval2 dataset:", senseval_2_lemmas, senseval_2_labels),
    ("Accuracy on Senseval3 dataset:", senseval_3_lemmas, senseval_3_labels),
):
  print(_title, eval_common_sense(_instances, _gold))
del _title, _instances, _gold

Accuracy on Semcor dataset: 48.674105009821446
Accuracy on Senseval2 dataset: 49.25503943908852
Accuracy on Senseval3 dataset: 47.945945945945944


#### Implementation of the plain Lesk algorithm -- choose the sense whose gloss (WordNet definition) shares the most words with the context.

In [None]:
def plain_lesk(lemma, context, pos=None):
  """
  Predict the sense of `lemma` in `context` with the plain Lesk algorithm.

  Each candidate WordNet synset is scored by the number of distinct words its
  gloss (`synset.definition()`) shares with the context. The synset with the
  highest score wins. Ties, including the case where no gloss overlaps at all,
  keep the earliest synset in `wn.synsets` order (the first-listed sense).

  Args:
    lemma: the word to disambiguate.
    context: iterable of context words. Presumably these are the lemmas of the
      sentence produced by `load_instances`; confirm against loader. They are
      compared case-insensitively.
    pos: optional WordNet part-of-speech filter passed to `wn.synsets`
      (default None = all parts of speech, as before).

  Returns:
    List of the sense keys of the chosen synset, or an empty list when
    WordNet has no synset for `lemma`.
  """
  import re

  synsets = wn.synsets(lemma, pos=pos)
  if not synsets:
    return []

  context_words = {word.lower() for word in context}

  predicted_synset = synsets[0]
  overlap_count = 0
  for synset in synsets:
    # Tokenise the gloss into words; set(definition) on the raw string would
    # compare individual characters rather than words.
    gloss_words = set(re.findall(r"\w+", synset.definition().lower()))
    overlap = len(gloss_words & context_words)
    # Strict '>' keeps the earlier synset on ties.
    if overlap > overlap_count:
      predicted_synset = synset
      overlap_count = overlap

  ## Sense keys of the synset with the largest overlap
  return [sense.key() for sense in predicted_synset.lemmas()]

In [None]:
def eval_plain_lesk(lemmas, labels):
  """
  Return the accuracy (in percent) of the plain Lesk algorithm on one
  dataset (semcor, senseval2 or senseval3).

  Args:
    lemmas: instances indexed by instance id; each must expose `.lemma` and
      `.context` (presumably the output of `load_instances` -- confirm
      against loader).
    labels: dict mapping instance id to its list of gold sense keys, as
      returned by `get_labels`.

  Returns:
    Percentage of labelled instances for which at least one predicted sense
    key is among that instance's gold keys; 0.0 when `labels` is empty.
  """
  if not labels:
    # Avoid ZeroDivisionError on an empty gold file.
    return 0.0

  correct_count = 0
  for lemma_id, gold_keys in labels.items():
    instance = lemmas[lemma_id]
    predicted_keys = plain_lesk(instance.lemma, instance.context)
    # An instance can carry several gold keys; matching any one of them counts
    # as correct (an empty gold list simply never matches).
    if not set(gold_keys).isdisjoint(predicted_keys):
      correct_count += 1

  return (correct_count / len(labels)) * 100

In [None]:
## Report the plain Lesk accuracy for every dataset, in the same order as above
for _caption, _dataset, _keys in (
    ("Accuracy on Semcor dataset:", semcor_lemmas, semcor_labels),
    ("Accuracy on Senseval2 dataset:", senseval_2_lemmas, senseval_2_labels),
    ("Accuracy on Senseval3 dataset:", senseval_3_lemmas, senseval_3_labels),
):
  print(_caption, eval_plain_lesk(_dataset, _keys))
del _caption, _dataset, _keys

Accuracy on Semcor dataset: 44.2283530057159
Accuracy on Senseval2 dataset: 44.8729184925504
Accuracy on Senseval3 dataset: 43.027027027027025


### Accuracies (%) of the baselines on all three datasets:

| Algorithm | Semcor | Senseval2 | Senseval3 |
| ----------- | ----------- | ----------- | ----------- |
| Most Common Sense | 48.67 | 49.26 | 47.95 |
| Plain Lesk | 44.23 | 44.87 | 43.03 |