In [None]:
import nltk
import random
nltk.download('senseval')
from nltk.corpus import senseval

[nltk_data] Downloading package senseval to /root/nltk_data...
[nltk_data]   Package senseval is already up-to-date!


In [None]:
def senses_func(word):
    """
    This takes a target word from senseval-2 (find out what the possible
    are by running senseval.fileides()), and it returns the list of possible 
    senses for the word
    """
    return list(set(i.senses[0] for i in senseval.instances(word)))

In [None]:
from collections import defaultdict
from nltk.classify import accuracy
nltk.download('stopwords')
stop_words = nltk.corpus.stopwords.words('english')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [None]:
import string

In [None]:
# Some helper functions we'll need to train our model

def extract_vocab_frequency(instances, target, stop_words=stop_words, vocab_size=300):
    """
    Given a list of senseval instances, return a list of the n most frequent words that
    appears in its context (i.e., the sentence with the target word in), output is in order
    of frequency and includes also the number of instances in which that key appears in the
    context of instances.
    """
    fd = nltk.FreqDist()
    #print(fd)
    for i in instances:
        (target, suffix) = i.word.split('-')               #suffix = pos, target = target word
        words = (c[0] for c in i.context if not c[0] == target)
        for word in (set(words) - set(stop_words)) - set(string.punctuation):
            fd[word] += 1
    #print(fd)
    return fd.most_common()[:vocab_size+1]

In [None]:
def extract_vocab(instances, target, stop_words=stop_words, vocab_size=300):
    return [w for w, f in extract_vocab_frequency(instances,target,stop_words,vocab_size)]

In [None]:
# Feature extraction

def wsd_context_features(instance, window_size=3):
    features = {}
    ind = instance.position
    con = instance.context
    for i in range(max(0, ind-window_size), ind):
        j = ind-i
        features[con[i][0]] = True

    for i in range(ind+1, min(ind+window_size+1, len(con))):
        j = i-ind
        features[con[i][0]] = True

    return features

In [None]:
def wsd_word_features(instance, vocab):
    """
    Create a feature set where every key returns False unless it occurs in the
    instance's context
    """
    features = defaultdict(lambda:False)
    #features['alwayson'] = True
    #cur_words = [w for (w, pos) in i.context]
    try:
      # 
        for (w, pos) in instance.context:
            if w in vocab:                            #vocab == most frequent words
                features[w] = True
    except ValueError:
        pass
    return features

In [None]:
def sense_instances(instances, sense):
    """
    This returns the list of instances in instances that have the same sense
    """
    return [instance for instance in instances if instance.senses[0]==sense]

In [None]:
def theta_k(train_corpus, context, sense, vocab, window_size):
  instances = sense_instances(train_corpus, sense)
  features = wsd_context_features(context, window_size=window_size)

  nr_occ = []
  for feature in features.keys():
    nr_occ_single_f = 0
    for instance in instances:
      context = instance.context
      for tuple_ in context:
        if len(tuple_) == 2 and tuple_[0] == feature:
          nr_occ_single_f += 1
    nr_occ.append(nr_occ_single_f)

  nr_occ_total = np.sum(nr_occ)

  nr_occ = np.array(nr_occ) / nr_occ_total  

  return np.prod(nr_occ)

In [None]:
def alpha_k(corpus, target_word, sense):
  nr_occ_sense = 0
  nr_occ_w = 0
  for instance in corpus:
    if instance.senses[0] == sense:
      nr_occ_sense += 1
  return nr_occ_sense/len(corpus)

In [None]:
import numpy as np

In [None]:
_inst_cache = {}

In [None]:
def extract_data(target_word, stop_words=stop_words, vocab_size=300):
  corpus = senseval.instances(target_word+".pos")
  senses = senses_func(target_word+".pos")
  vocab = extract_vocab(corpus, target_word, stop_words=stop_words, vocab_size=vocab_size)   #most frequent words in the context
  return corpus, senses, vocab

In [None]:
def data_split(target_word):
    print("Reading data...")
    global _inst_cache

    if target_word not in _inst_cache:
        _inst_cache[target_word] = [(i, i.senses[0]) for i in senseval.instances(target_word+".pos")]
    events = _inst_cache[target_word][:]

    # Split the instances into a training and test set
    n = len(events)
    random.seed(334)
    random.shuffle(events)
    training_data = events[:int(0.5 * n)]
    print("Training data size: ", len(training_data))
    validation_data = events[int(0.5 * n):int(0.85 * n)]
    print("Validation data size: ", len(validation_data))
    test_data = events[int(0.85 * n):n]
    print("Test data size: ", len(test_data))

    return training_data, validation_data, test_data

In [None]:
def wsd_classifier(corpus, target_word, senses, vocab, data, window_size, example_errors=False):
    #probabilities of each sense in the corpus:
    prob_senses = []
    for sense in senses:
      prob_sense = alpha_k(corpus, target_word, sense)
      prob_senses.append(prob_sense)

    
    probs = []
    for i in range(len(data)):
      probs_aux = []
      for j, sense in enumerate(senses):
        theta = theta_k(list(zip(*data))[0], data[i][0], sense, vocab, window_size=window_size)
        probs_aux.append(np.log(prob_senses[j]*theta))
      probs.append(probs_aux)

    predictions = []
    for probs_aux in probs:
      predictions.append(np.argmax(probs_aux))

    
    #Accuracy:
    correct = 0
    for i, pred in enumerate(predictions):
      if pred == senses.index(data[i][1]):
        correct += 1

    accuracy = correct / len(predictions) 

    print("Accuracy: {}".format(accuracy))
    
    gold = [data[i][1] for i in range(len(data))]
    derived = [senses[j] for j in predictions]
    cm = nltk.ConfusionMatrix(gold, derived)
    print(cm) 

    k = 0
    if example_errors:
      for i, der in enumerate(derived):
        if k < 5:
          if der != data[i][1]:
            print("Context: {}".format(data[i][0].context))
            print("Correct sense: {}".format(data[i][1]))
            print("Predicted sense: {}".format(derived[i]))
            print("\n")
            k += 1
        else:
          break
    return accuracy  

In [None]:
def test_words(target_word, stage, vocab_size, window_size, example_errors=False):
  train_data, validation_data, test_data = data_split(target_word)
  corpus, senses, vocab = extract_data(target_word, vocab_size=vocab_size)

  if stage == "train":
    wsd_classifier(corpus, target_word, senses, vocab, train_data, window_size, example_errors)
  elif stage == "validation":
    wsd_classifier(corpus, target_word, senses, vocab, validation_data, window_size, example_errors)
  else:
    wsd_classifier(corpus, target_word, senses, vocab, test_data, window_size, example_errors=True)

Experimenting with window size and vocabulary size:

In [None]:
test_words("line", "validation", vocab_size=300, window_size=3)

Reading data...
Training data size:  2073
Validation data size:  1451
Test data size:  622


  


Accuracy: 0.9131633356305996
          |           f             |
          |       d   o             |
          |       i   r       p     |
          |       v   m       r     |
          |       i   a   p   o     |
          |   c   s   t   h   d   t |
          |   o   i   i   o   u   e |
          |   r   o   o   n   c   x |
          |   d   n   n   e   t   t |
----------+-------------------------+
     cord |<118>  .   .   1   9   1 |
 division |   1<125>  .   .   1   . |
formation |   1   2 <97>  5   6   3 |
    phone |   4   1   2<116> 18   5 |
  product |   2   3  23   5<752>  6 |
     text |   1   .   3   3  20<117>|
----------+-------------------------+
(row = reference; col = test)



In [None]:
test_words("line", "validation", vocab_size=300, window_size=7)

Reading data...
Training data size:  2073
Validation data size:  1451
Test data size:  622


  


Accuracy: 0.9703652653342523
          |           f             |
          |       d   o             |
          |       i   r       p     |
          |       v   m       r     |
          |       i   a   p   o     |
          |   c   s   t   h   d   t |
          |   o   i   i   o   u   e |
          |   r   o   o   n   c   x |
          |   d   n   n   e   t   t |
----------+-------------------------+
     cord |<129>  .   .   .   .   . |
 division |   .<126>  1   .   .   . |
formation |   .   1<113>  .   .   . |
    phone |   1   .   .<141>  1   3 |
  product |   .   1  29   2<758>  1 |
     text |   .   2   .   .   1<141>|
----------+-------------------------+
(row = reference; col = test)



In [None]:
test_words("line", "validation", vocab_size=300, window_size=10)

Reading data...
Training data size:  2073
Validation data size:  1451
Test data size:  622


  


Accuracy: 0.9710544452101999
          |           f             |
          |       d   o             |
          |       i   r       p     |
          |       v   m       r     |
          |       i   a   p   o     |
          |   c   s   t   h   d   t |
          |   o   i   i   o   u   e |
          |   r   o   o   n   c   x |
          |   d   n   n   e   t   t |
----------+-------------------------+
     cord |<129>  .   .   .   .   . |
 division |   .<126>  1   .   .   . |
formation |   .   .<114>  .   .   . |
    phone |   1   .   1<143>  .   1 |
  product |   .   1  35   1<754>  . |
     text |   .   .   .   .   1<143>|
----------+-------------------------+
(row = reference; col = test)



In [None]:
test_words("line", "validation", vocab_size=250, window_size=10)

Reading data...
Training data size:  2073
Validation data size:  1451
Test data size:  622


  


Accuracy: 0.9710544452101999
          |           f             |
          |       d   o             |
          |       i   r       p     |
          |       v   m       r     |
          |       i   a   p   o     |
          |   c   s   t   h   d   t |
          |   o   i   i   o   u   e |
          |   r   o   o   n   c   x |
          |   d   n   n   e   t   t |
----------+-------------------------+
     cord |<129>  .   .   .   .   . |
 division |   .<126>  1   .   .   . |
formation |   .   .<114>  .   .   . |
    phone |   1   .   1<143>  .   1 |
  product |   .   1  35   1<754>  . |
     text |   .   .   .   .   1<143>|
----------+-------------------------+
(row = reference; col = test)



In [None]:
test_words("hard", "validation", vocab_size=300, window_size=3)

Reading data...
Training data size:  2166
Validation data size:  1517
Test data size:  650


  


Accuracy: 0.930784442979565
      |    H    H    H |
      |    A    A    A |
      |    R    R    R |
      |    D    D    D |
      |    1    2    3 |
------+----------------+
HARD1 |<1149>  45   27 |
HARD2 |   27 <140>   . |
HARD3 |    6    . <123>|
------+----------------+
(row = reference; col = test)



In [None]:
test_words("hard", "validation", vocab_size=300, window_size=7)

Reading data...
Training data size:  2166
Validation data size:  1517
Test data size:  650


  


Accuracy: 0.977587343441002
      |    H    H    H |
      |    A    A    A |
      |    R    R    R |
      |    D    D    D |
      |    1    2    3 |
------+----------------+
HARD1 |<1194>  20    7 |
HARD2 |    6 <161>   . |
HARD3 |    1    . <128>|
------+----------------+
(row = reference; col = test)



In [None]:
test_words("hard", "validation", vocab_size=300, window_size=13)

Reading data...
Training data size:  2166
Validation data size:  1517
Test data size:  650


  


Accuracy: 0.995385629531971
      |    H    H    H |
      |    A    A    A |
      |    R    R    R |
      |    D    D    D |
      |    1    2    3 |
------+----------------+
HARD1 |<1216>   5    . |
HARD2 |    1 <166>   . |
HARD3 |    1    . <128>|
------+----------------+
(row = reference; col = test)



In [None]:
test_words("serve", "validation", vocab_size=300, window_size=3)

Reading data...
Training data size:  2189
Validation data size:  1532
Test data size:  657


  


Accuracy: 0.9445169712793734
        |   S   S         |
        |   E   E   S   S |
        |   R   R   E   E |
        |   V   V   R   R |
        |   E   E   V   V |
        |   1   1   E   E |
        |   0   2   2   6 |
--------+-----------------+
SERVE10 |<611>  2  11   1 |
SERVE12 |  12<432> 14   5 |
 SERVE2 |  14   7<263>  1 |
 SERVE6 |   3   8   7<141>|
--------+-----------------+
(row = reference; col = test)



In [None]:
test_words("serve", "validation", vocab_size=300, window_size=7)

Reading data...
Training data size:  2189
Validation data size:  1532
Test data size:  657


  


Accuracy: 0.9849869451697127
        |   S   S         |
        |   E   E   S   S |
        |   R   R   E   E |
        |   V   V   R   R |
        |   E   E   V   V |
        |   1   1   E   E |
        |   0   2   2   6 |
--------+-----------------+
SERVE10 |<617>  .   8   . |
SERVE12 |   2<453>  6   2 |
 SERVE2 |   1   .<284>  . |
 SERVE6 |   .   .   4<155>|
--------+-----------------+
(row = reference; col = test)



In [None]:
test_words("interest", "validation", vocab_size=300, window_size=3)

Reading data...
Training data size:  1184
Validation data size:  828
Test data size:  356


  


Accuracy: 0.9118357487922706
           |   i   i   i   i   i   i |
           |   n   n   n   n   n   n |
           |   t   t   t   t   t   t |
           |   e   e   e   e   e   e |
           |   r   r   r   r   r   r |
           |   e   e   e   e   e   e |
           |   s   s   s   s   s   s |
           |   t   t   t   t   t   t |
           |   _   _   _   _   _   _ |
           |   1   2   3   4   5   6 |
-----------+-------------------------+
interest_1 |<110>  4   .   .   2   7 |
interest_2 |   .  <4>  .   .   .   . |
interest_3 |   .   2 <17>  1   .   . |
interest_4 |   .   1   . <55>  .   2 |
interest_5 |   3   1   .   1<155> 10 |
interest_6 |   1  24   4   8   2<414>|
-----------+-------------------------+
(row = reference; col = test)



In [None]:
test_words("interest", "validation", vocab_size=300, window_size=7)

Reading data...
Training data size:  1184
Validation data size:  828
Test data size:  356


  


Accuracy: 0.9927536231884058
           |   i   i   i   i   i   i |
           |   n   n   n   n   n   n |
           |   t   t   t   t   t   t |
           |   e   e   e   e   e   e |
           |   r   r   r   r   r   r |
           |   e   e   e   e   e   e |
           |   s   s   s   s   s   s |
           |   t   t   t   t   t   t |
           |   _   _   _   _   _   _ |
           |   1   2   3   4   5   6 |
-----------+-------------------------+
interest_1 |<123>  .   .   .   .   . |
interest_2 |   .  <4>  .   .   .   . |
interest_3 |   .   . <20>  .   .   . |
interest_4 |   .   .   . <56>  .   2 |
interest_5 |   .   .   .   .<169>  1 |
interest_6 |   .   2   1   .   .<450>|
-----------+-------------------------+
(row = reference; col = test)



Testing:

In [None]:
test_words("line", "test", vocab_size=300, window_size=10)

Reading data...
Training data size:  2073
Validation data size:  1451
Test data size:  622


  


Accuracy: 0.9678456591639871
          |           f             |
          |       d   o             |
          |       i   r       p     |
          |       v   m       r     |
          |       i   a   p   o     |
          |   c   s   t   h   d   t |
          |   o   i   i   o   u   e |
          |   r   o   o   n   c   x |
          |   d   n   n   e   t   t |
----------+-------------------------+
     cord | <58>  .   .   .   .   . |
 division |   . <63>  .   .   .   . |
formation |   .   . <57>  .   .   . |
    phone |   1   .   1 <61>  .   . |
  product |   .   1  17   .<314>  . |
     text |   .   .   .   .   . <49>|
----------+-------------------------+
(row = reference; col = test)

Context: [('international', 'NNP'), ('business', 'NNP'), ('machines', 'NNP'), ('corp', 'NNP'), ('.', '.'), ('is', 'VBZ'), ('using', 'VBG'), ('the', 'DT'), ('huge', 'JJ'), ('computer', 'NN'), ('trade', 'NN'), ('show', 'NN'), ('here', 'RB'), ('this', 'DT'), ('week', 'NN'), ('to', 'TO'), ('try', 

In [None]:
test_words("hard", "test", vocab_size=300, window_size=13)

Reading data...
Training data size:  2166
Validation data size:  1517
Test data size:  650


  


Accuracy: 0.9984615384615385
      |   H   H   H |
      |   A   A   A |
      |   R   R   R |
      |   D   D   D |
      |   1   2   3 |
------+-------------+
HARD1 |<525>  .   . |
HARD2 |   1 <71>  . |
HARD3 |   .   . <53>|
------+-------------+
(row = reference; col = test)

Context: [('but', 'CC'), ('what', 'WP'), ('does', 'VBZ'), ('he', 'PRP'), ('get', 'VBP'), ('for', 'IN'), ('his', 'PRP$'), ('hard', 'JJ'), ('work', 'NN'), ('?', '.')]
Correct sense: HARD2
Predicted sense: HARD1




In [None]:
test_words("serve", "test", vocab_size=300, window_size=7)

Reading data...
Training data size:  2189
Validation data size:  1532
Test data size:  657


  


Accuracy: 0.9923896499238964
        |   S   S         |
        |   E   E   S   S |
        |   R   R   E   E |
        |   V   V   R   R |
        |   E   E   V   V |
        |   1   1   E   E |
        |   0   2   2   6 |
--------+-----------------+
SERVE10 |<288>  .   5   . |
SERVE12 |   .<177>  .   . |
 SERVE2 |   .   .<118>  . |
 SERVE6 |   .   .   . <69>|
--------+-----------------+
(row = reference; col = test)

Context: [('cover', 'NNP'), ('the', 'DT'), ('pot', 'NN'), ('again', 'RB'), (',', ','), ('turn', 'VB'), ('flame', 'NN'), ('down', 'RB'), ('to', 'TO'), ('low', 'JJ'), (',', ','), ('and', 'CC'), ('continue', 'VB'), ('to', 'TO'), ('simmer', 'VB'), ('for', 'IN'), ('1', 'CD'), 'FRASL', ('2', 'CD'), ('hour', 'NN'), ('more', 'JJR'), ('.', '.'), ('fourth', 'NNP'), (':', ':'), ('just', 'RB'), ('before', 'IN'), ('serving', 'VBG'), (',', ','), ('remove', 'VB'), ('1', 'CD'), 'FRASL', ('2', 'CD'), ('cup', 'NN'), ('broth', 'NN'), ('to', 'TO'), ('a', 'DT'), ('teacup', 'NN'), (',', ',')

In [None]:
test_words("interest", "test", vocab_size=300, window_size=7)

Reading data...
Training data size:  1184
Validation data size:  828
Test data size:  356


  


Accuracy: 0.9859550561797753
           |   i   i   i   i   i   i |
           |   n   n   n   n   n   n |
           |   t   t   t   t   t   t |
           |   e   e   e   e   e   e |
           |   r   r   r   r   r   r |
           |   e   e   e   e   e   e |
           |   s   s   s   s   s   s |
           |   t   t   t   t   t   t |
           |   _   _   _   _   _   _ |
           |   1   2   3   4   5   6 |
-----------+-------------------------+
interest_1 | <50>  .   .   .   .   1 |
interest_2 |   .  <3>  .   .   .   . |
interest_3 |   .   .  <7>  .   .   . |
interest_4 |   .   .   . <29>  .   . |
interest_5 |   .   1   .   . <88>  1 |
interest_6 |   .   2   .   .   .<174>|
-----------+-------------------------+
(row = reference; col = test)

Context: [('one', 'CD'), ('suit', 'NN'), (',', ','), ('filed', 'VBN'), ('by', 'IN'), ('more', 'JJR'), ('than', 'IN'), ('three', 'CD'), ('dozen', 'NN'), ('investors', 'NNS'), (',', ','), ('charges', 'VBZ'), ('that', 'DT'), ('mr', 'NN'), ('