## WSD using BERT Masked Language Model
This notebook explores part of the idea proposed by Ajit Rajasekharan in his blog post 
[Examining BERT raw embeddings](https://towardsdatascience.com/examining-berts-raw-embeddings-fd905cb22df7). 

The idea is that examining the predictions of a masked language model for a masked ambiguous word can yield insights into the semantic meaning of the ambiguous word.

We use the HuggingFace BERT for Masked LM with weights from a bert-base-cased pre-trained model for our experiment.

We mask the ambiguous word (here, bank) in each sentence and pass the sentence through the BERT MLM. The output is an array of logits for each position of the input sequence: for a sentence with T tokens and a vocabulary of size V, the predictions have shape (1, T, V), where 1 is the batch size (one input sentence at a time in our experiment).

To find the top k predictions, we apply a softmax to the logits at the masked position and take the k most probable tokens.
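
As a toy illustration of the softmax and top-k step, here is a minimal pure-Python sketch. The four-word vocabulary and the logit values are made up for the example, and the helper names `softmax` and `top_k` are hypothetical; the notebook itself does this with PyTorch on the real model output.

```python
import math

def softmax(logits):
    # Subtract the maximum for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(vocab, logits, k):
    # Pair each vocabulary entry with its probability and keep the k largest
    probs = softmax(logits)
    ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Made-up logits for a single masked position over a toy vocabulary (V = 4)
toy_vocab = ["bank", "office", "store", "river"]
toy_logits = [4.0, 2.0, 1.0, -1.0]

top2 = top_k(toy_vocab, toy_logits, 2)
print(top2)
```

The probabilities sum to 1 over the whole vocabulary, so the top-k values can be read directly as the model's confidence in each candidate.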



## Prepare your environment

As always, we highly recommend that you install all packages with a virtual environment manager, like [venv](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/) or [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html), to prevent version conflicts between packages.  

### Masked LM Model and Tokenizer 
[tutorial](https://huggingface.co/docs/transformers/tasks/language_modeling)  
The task is to predict masked words with BERT, so we use the BertForMaskedLM model and BertTokenizer with the pre-trained bert-base-cased weights.

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.24.0


In [None]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForMaskedLM

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForMaskedLM.from_pretrained('bert-base-cased', return_dict=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


We are going to use the pre-trained BERT language model in inference mode only.

The tokenizer splits the input sequence into tokens and wraps it with the special [CLS] and [SEP] tokens.

The output produced by the model has two components, loss and logits; the loss is None here because we pass no labels. The logits component has shape (1, number_of_tokens, vocab_size), where the leading 1 represents the single input sentence.

We will identify the logits corresponding to the position of our masked token, identify the top 5 vocabulary words predicted for that position, and return the softmax probabilities for each of the top 5 predicted words.

In [None]:
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
outputs = model(**inputs)

In [None]:
tokenizer.convert_ids_to_tokens(inputs.input_ids[0])


['[CLS]', 'The', 'capital', 'of', 'France', 'is', '[MASK]', '.', '[SEP]']

In [None]:
outputs

MaskedLMOutput(loss=None, logits=tensor([[[ -7.1545,  -6.9931,  -7.1826,  ...,  -5.9124,  -5.6733,  -5.9854],
         [ -8.0190,  -8.1319,  -8.0509,  ...,  -6.5679,  -6.4058,  -6.8998],
         [ -4.9772,  -6.1781,  -6.0669,  ...,  -5.6362,  -4.6603,  -5.1241],
         ...,
         [ -3.4420,  -3.2557,  -3.5733,  ...,  -2.4606,  -2.6495,  -3.1952],
         [-10.5890, -10.4621, -11.7181,  ...,  -7.4646,  -9.9543,  -8.3927],
         [-14.8900, -14.8873, -14.4569,  ..., -11.6588, -13.0151, -11.6073]]],
       grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)

In [None]:
def get_mask_index(input_ids, tokenizer):
    x = input_ids[0]
    is_masked = torch.where(x == tokenizer.mask_token_id, x, 0)
    mask_idx = torch.nonzero(is_masked)
    return mask_idx.item()

mask_idx = get_mask_index(inputs.input_ids, tokenizer)
mask_idx

6

In [None]:
def get_top_k_predictions(pred_logits, mask_idx, top_k):
    probs = torch.nn.functional.softmax(pred_logits[0, mask_idx, :], dim=-1)
    top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
    top_k_pct_weights = [100 * x.item() for x in top_k_weights]
    top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices)
    return list(zip(top_k_tokens, top_k_pct_weights))


get_top_k_predictions(outputs.logits, mask_idx, 5)

[('Paris', 44.46818828582764),
 ('Lyon', 9.396008402109146),
 ('Toulouse', 8.23453962802887),
 ('Lille', 7.515150308609009),
 ('Marseille', 5.692291632294655)]

### WSD Test Sentences
We take a pair of sentences that use the word bank in different senses, mask the word, and extract the top 20 predictions from the pre-trained BERT MLM.

As expected, the first set of predictions is dominated by bank (about 70%) alongside other places of business such as office and store, whereas the second set points predominantly to geographical features around bodies of water (bank, basin, bed, pier).

In [None]:
sentences = [
  "Go to the [MASK] and deposit your pay check.",
  "Jim and Janet went down to the river [MASK] to admire the swans."
]

In [None]:
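def get_predictions(sentence, tokenizer, model):
    # Tokenize the masked sentence and run the model without tracking gradients
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Locate the [MASK] position and return its 20 most probable tokens
    mask_idx = get_mask_index(inputs.input_ids, tokenizer)
    top_preds = get_top_k_predictions(outputs.logits, mask_idx, 20)
    return top_preds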

In [None]:
get_predictions(sentences[0], tokenizer, model)

[('bank', 70.31395435333252),
 ('office', 10.280592739582062),
 ('register', 1.7452014610171318),
 ('store', 1.628476195037365),
 ('bathroom', 0.9394762106239796),
 ('library', 0.8934843353927135),
 ('desk', 0.8724356070160866),
 ('counter', 0.7977331057190895),
 ('hotel', 0.5163723137229681),
 ('lobby', 0.49569709226489067),
 ('kitchen', 0.3637074725702405),
 ('garage', 0.34799198620021343),
 ('door', 0.3412732621654868),
 ('car', 0.33113667741417885),
 ('house', 0.26490529999136925),
 ('airport', 0.25470301043242216),
 ('elevator', 0.2491131192073226),
 ('back', 0.24807637091726065),
 ('computer', 0.24019568227231503),
 ('banks', 0.2349143149331212)]

In [None]:
get_predictions(sentences[1], tokenizer, model)

[('##bank', 32.602110505104065),
 ('below', 13.03199827671051),
 ('bank', 11.940895020961761),
 (',', 5.626505985856056),
 ('##boat', 3.1638897955417633),
 ('##front', 2.733229286968708),
 ('basin', 1.6210535541176796),
 ('##bed', 1.2178422883152962),
 ('together', 1.1841707862913609),
 ('bed', 0.9657179936766624),
 ('again', 0.8369828574359417),
 ('deck', 0.8356167003512383),
 ('valley', 0.7271416950970888),
 ('mouth', 0.7227536290884018),
 ('boat', 0.7151056081056595),
 ('pier', 0.6493288092315197),
 ('house', 0.6301583256572485),
 ('banks', 0.5700568202883005),
 ('pool', 0.5345712415874004),
 ('Thames', 0.4995541647076607)]

## Assignment
In this week's assignment, you are tasked with processing SemCor data and feeding it into the BERT masked LM. After that, use the predictions to find the most likely sense of the target word using WordNet similarity.

### Data Preprocessing 
You can find a sample of the SemCor dataset [here](https://drive.google.com/file/d/1inmv3rUcGrtiS4VQwTMsT9HF-iL8jc5V/view?usp=sharing) and load it as follows.

In [None]:
import json
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet as wn
sents = []
tokens = []
wn_id = []
lemmatizer = WordNetLemmatizer()

with open('/content/drive/MyDrive/graduate/nlp/week11/semcor.sample.jsonl') as f:
    for line in f:
        data = json.loads(line)
        sents.append(data['sent'])
        tokens.append(data['tokens'])
        wn_id.append(data['wnid'])


In [None]:
print(sents[10])
print(tokens[10])
print(wn_id[10])

implementation of georgia 's automobile title law was also recommended by the outgoing jury . 
['implementation', 'of', 'georgia', "'s", 'automobile', 'title', 'law', 'was', 'also', 'recommended', 'by', 'the', 'outgoing', 'jury', '.']
['implementation%1:04:01::', 0, 'georgia%1:15:00::', 0, 'automobile%1:06:00::', 'title%1:10:04::', 'law%1:10:00::', 0, 'also%4:02:00::', 'recommend%2:32:01::', 0, 0, 'outgoing%3:00:00::', 'jury%1:14:00::', 0]


In [None]:
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


True

In [None]:
# The WordNet ID can be converted to NLTK Lemma using the following function

wn.lemma_from_key('implementation%1:04:01::')

Lemma('execution.n.06.implementation')
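
The sense key passed to `lemma_from_key` follows WordNet's `lemma%ss_type:lex_filenum:lex_id:head_word:head_id` layout, where `ss_type` encodes the part of speech (1 noun, 2 verb, 3 adjective, 4 adverb, 5 adjective satellite). A minimal sketch of a parser, with a hypothetical function name and hypothetical dictionary keys, might look like this:

```python
# ss_type codes as defined by the WordNet sense key format
SS_TYPES = {1: "noun", 2: "verb", 3: "adjective", 4: "adverb", 5: "adjective satellite"}

def parse_sense_key(key):
    # Split "lemma%lex_sense" at the last percent sign
    lemma, lex_sense = key.rsplit("%", 1)
    # lex_sense has five colon-separated fields; the last two may be empty
    ss_type, lex_filenum, lex_id, head_word, head_id = lex_sense.split(":")
    return {
        "lemma": lemma,
        "pos": SS_TYPES[int(ss_type)],
        "lex_filenum": int(lex_filenum),
        "lex_id": int(lex_id),
        "head_word": head_word or None,
        "head_id": int(head_id) if head_id else None,
    }

parsed = parse_sense_key("implementation%1:04:01::")
print(parsed)
```

The `head_word` and `head_id` fields are only filled for adjective satellites, which is why they are empty in the example key.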

### TODO 
Please implement a method that converts the data to BERT Masked-LM format and keeps track of the headword. Store the data in the following lists:

word[i] = 'implementation'  
ground_truth[i] = 'implementation%1:04:01::'  
sent[i] = "[MASK] of georgia 's automobile title law was also recommended by the outgoing jury ."  



In [None]:
sent_ex = "I saw a huge saw in a saw"
token_ex =['I','saw','a','huge','saw','in','a','saw']
tags_ex = [0,'saw%2:32:00::',0,'huge%2:32:00::','saw%2:32:00::',0,0,'saw%2:32:00::']

star_str = sent_ex
print(sent_ex)
print('--------------')
for index, tags in enumerate(tags_ex):
  if tags != 0:
    start_pos = star_str.index(token_ex[index])
    end_pos = start_pos + len(token_ex[index])
    new_str = sent_ex[:start_pos] + '[mask]' + sent_ex[end_pos:]
    star_str = star_str[:start_pos] + '*' * len(token_ex[index]) + star_str[end_pos:]
    print('start_pos:',start_pos,' end_pos:',end_pos)
    print(new_str)
    print(star_str)
    print('------------')

I saw a huge saw in a saw
--------------
start_pos: 2  end_pos: 5
I [mask] a huge saw in a saw
I *** a huge saw in a saw
------------
start_pos: 8  end_pos: 12
I saw a [mask] saw in a saw
I *** a **** saw in a saw
------------
start_pos: 13  end_pos: 16
I saw a huge [mask] in a saw
I *** a **** *** in a saw
------------
start_pos: 22  end_pos: 25
I saw a huge saw in a [mask]
I *** a **** *** in a ***
------------


In [None]:
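mask_token = []    # headword that was masked
ground_truth = []  # WordNet sense key of the headword
sent = []          # sentence with the headword replaced by [MASK]

for sentence,token,tags in zip(sents,tokens,wn_id):
  # Working copy in which already-processed tokens are overwritten with '*',
  # so that a repeated word is found at its next occurrence
  sent_replace_with_star = sentence
  for index, sense in enumerate(tags):
    if sense != 0:  # 0 marks a token without a sense annotation
      start_pos = sent_replace_with_star.index(token[index])
      end_pos = start_pos + len(token[index])
      mask_sent = sentence[:start_pos] + '[MASK]' + sentence[end_pos:]
      sent_replace_with_star = sent_replace_with_star[:start_pos] + '*' * len(token[index]) + sent_replace_with_star[end_pos:]
      mask_token.append(token[index])
      ground_truth.append(sense)
      sent.append(mask_sent)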

In [None]:
print(mask_token[:3])
print(len(mask_token))
print(ground_truth[:3])
print(len(ground_truth))
print(sent[:3])
print(len(sent))

['said', 'friday', 'investigation']
1042
['say%2:32:00::', 'friday%1:28:00::', 'investigation%1:09:00::']
1042
['the fulton_county_grand_jury [MASK] friday an investigation of atlanta \'s recent primary_election produced " no evidence " that any irregularities took_place . ', 'the fulton_county_grand_jury said [MASK] an investigation of atlanta \'s recent primary_election produced " no evidence " that any irregularities took_place . ', 'the fulton_county_grand_jury said friday an [MASK] of atlanta \'s recent primary_election produced " no evidence " that any irregularities took_place . ']
1042
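
The substring search above locates each token with `str.index`, which can match inside an earlier untagged word (for example, a tagged "man" that follows an untagged "woman"). A minimal alternative sketch masks by token position instead. It assumes that joining the tokens with single spaces reproduces the sentence, which appears to hold for the sample printed above but is not guaranteed for the whole dataset; the function name `mask_tagged_tokens` is hypothetical.

```python
def mask_tagged_tokens(tokens, tags, mask="[MASK]"):
    """Return one (word, sense_key, masked_sentence) triple per tagged token."""
    examples = []
    for i, (tok, tag) in enumerate(zip(tokens, tags)):
        if tag == 0:  # 0 marks a token without a sense annotation
            continue
        # Replace only the token at position i, leaving all others untouched
        masked = tokens[:i] + [mask] + tokens[i + 1:]
        examples.append((tok, tag, " ".join(masked)))
    return examples

# Same toy input as the example cell above
token_ex = ['I', 'saw', 'a', 'huge', 'saw', 'in', 'a', 'saw']
tags_ex = [0, 'saw%2:32:00::', 0, 'huge%2:32:00::', 'saw%2:32:00::', 0, 0, 'saw%2:32:00::']

examples = mask_tagged_tokens(token_ex, tags_ex)
for word, key, masked_sentence in examples:
    print(word, key, masked_sentence)
```

Because the position comes from the token list rather than from a text search, repeated words need no special handling.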


#### Identify the top 5 predictions other than the headword using Masked-LM 
1. Use get_predictions to get the predicted words
2. Use lemmatizer to lemmatize the prediction
3. Remove headword
4. Keep top 5 unique predictions
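
The steps above can be sketched on a toy prediction list without loading the model. This is an illustration under stated assumptions, not the required solution: the function name `top_unique_candidates` and the crude `toy_lemmatize` helper are hypothetical, the prediction scores are made up, and two choices go beyond the listed steps (WordPiece continuation pieces starting with `##` are dropped, and the headword is compared after lemmatization).

```python
import string

def top_unique_candidates(predictions, headword, lemmatize=lambda w: w, k=5):
    scores = {}
    for token, prob in predictions:
        # Assumption: skip WordPiece continuation pieces such as '##bank'
        if token.startswith("##"):
            continue
        # Skip tokens made up only of punctuation characters
        if all(ch in string.punctuation for ch in token):
            continue
        lemma = lemmatize(token)
        # Remove the headword itself (compared after lemmatization)
        if lemma == headword:
            continue
        # Merge predictions that share a lemma by summing their scores
        scores[lemma] = scores.get(lemma, 0.0) + prob
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [lemma for lemma, _ in ranked[:k]]

# Made-up scores; a real run would pass the output of get_predictions
toy_predictions = [("##bank", 32.6), ("below", 13.0), ("bank", 11.9),
                   (",", 5.6), ("basin", 1.6), ("banks", 0.6)]

# Hypothetical stand-in for a real lemmatizer: strips a trailing 's'
toy_lemmatize = lambda w: w[:-1] if w.endswith("s") else w

candidates = top_unique_candidates(toy_predictions, "bank", toy_lemmatize, k=2)
print(candidates)
```

In the notebook, `lemmatizer.lemmatize` could be passed as the `lemmatize` argument in place of the toy helper.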

In [None]:
import string

candidate_lemmas = []
for index,sentence in enumerate(sent):
  # Sum the probabilities of predictions that share a lemma
  pred_words_with_score = {}
  for pred_word, pred_score in get_predictions(sentence, tokenizer, model):
    # Skip the headword itself and single punctuation marks
    if pred_word != mask_token[index] and pred_word not in string.punctuation:
      pred_word_lemma = lemmatizer.lemmatize(pred_word)  # default POS is noun
      pred_words_with_score[pred_word_lemma] = pred_words_with_score.get(pred_word_lemma, 0) + pred_score
  # Keep the five lemmas with the highest summed probability
  sorted_pred_words_with_score = sorted(pred_words_with_score.items(), key=lambda x:x[1],reverse=True)
  candidate_lemmas.append([word for word, _ in sorted_pred_words_with_score[:5]])


len(candidate_lemmas)



1042

In [None]:
for i in range(3):
  print(sent[i])
  print(mask_token[i])
  print(ground_truth[i])
  print(candidate_lemmas[i])
  print('--------------------------')

the fulton_county_grand_jury [MASK] friday an investigation of atlanta 's recent primary_election produced " no evidence " that any irregularities took_place . 
said
say%2:32:00::
['found', 'reported', 'told', 'stated', 'announced']
--------------------------
the fulton_county_grand_jury said [MASK] an investigation of atlanta 's recent primary_election produced " no evidence " that any irregularities took_place . 
friday
friday%1:28:00::
['that', 'after', 'in', 'during', 'of']
--------------------------
the fulton_county_grand_jury said friday an [MASK] of atlanta 's recent primary_election produced " no evidence " that any irregularities took_place . 
investigation
investigation%1:09:00::
['analysis', 'examination', 'audit', 'evaluation', 'inspection']
--------------------------


example:  
candidate_lemmas = ['office', 'register', 'store', 'bathroom', 'library']


Identify the sense of the headword that is most similar to any sense of the 5 unique candidates.
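
One common similarity measure for this step is Wu-Palmer similarity, which scores two senses by the depth of their deepest common ancestor: 2 * depth(lcs) / (depth(s1) + depth(s2)). The sketch below computes it on a tiny made-up taxonomy so that the arithmetic can be checked by hand. The taxonomy, the node names, and the helper names are all hypothetical, and NLTK's `wn.wup_similarity` handles details this sketch ignores (multiple hypernym paths, differing parts of speech).

```python
# Hypothetical toy taxonomy: each node maps to its single parent (root has None)
parent = {
    "entity": None,
    "organization": "entity",
    "institution": "organization",
    "bank": "institution",
    "office": "organization",
    "river": "entity",
}

def path_to_root(node, parent):
    # Walk upward from the node to the root, collecting every ancestor
    path = [node]
    while parent[node] is not None:
        node = parent[node]
        path.append(node)
    return path

def depth(node, parent):
    # Depth counts the root as 1
    return len(path_to_root(node, parent))

def wup(a, b, parent):
    ancestors_b = set(path_to_root(b, parent))
    # The first node on a's upward path that is also above b is the deepest common ancestor
    lcs = next(n for n in path_to_root(a, parent) if n in ancestors_b)
    return 2 * depth(lcs, parent) / (depth(a, parent) + depth(b, parent))

print(wup("bank", "office", parent))  # lcs is 'organization': 2*2 / (4+3)
print(wup("bank", "river", parent))   # lcs is 'entity': 2*1 / (4+2)
```

Senses that share a deep ancestor score close to 1, while senses that only meet at the root score low.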



In [None]:
from itertools import product
predicted_sense = []

for idx,head_word in enumerate(mask_token):
  # Split the headword on whitespace (a no-op for single tokens)
  headword_list = head_word.split()
  candidate_word_list = candidate_lemmas[idx]

  # Collect every WordNet synset of the headword and of the candidates
  headword_syns = set(ss for word in headword_list for ss in wn.synsets(word))
  candidate_word_syns = set(ss for word in candidate_word_list for ss in wn.synsets(word))
  try:
    # Pick the (headword sense, candidate sense) pair with the highest Wu-Palmer similarity
    best = max((wn.wup_similarity(s1, s2) or 0, s1, s2) for s1, s2 in product(headword_syns, candidate_word_syns))
    predicted_sense.append(best)
  except Exception:
    # e.g. max() over an empty product when no synsets are found
    predicted_sense.append('')

len(predicted_sense)

1042

In [None]:
for i in range(3):
  print(sent[i])
  print(mask_token[i])
  print(ground_truth[i])
  print(candidate_lemmas[i])
  print(predicted_sense[i])
  print('--------------------------')

  

the fulton_county_grand_jury [MASK] friday an investigation of atlanta 's recent primary_election produced " no evidence " that any irregularities took_place . 
said
say%2:32:00::
['found', 'reported', 'told', 'stated', 'announced']
(1.0, Synset('state.v.01'), Synset('state.v.01'))
--------------------------
the fulton_county_grand_jury said [MASK] an investigation of atlanta 's recent primary_election produced " no evidence " that any irregularities took_place . 
friday
friday%1:28:00::
['that', 'after', 'in', 'during', 'of']
(0.375, Synset('friday.n.01'), Synset('inch.n.01'))
--------------------------
the fulton_county_grand_jury said friday an [MASK] of atlanta 's recent primary_election produced " no evidence " that any irregularities took_place . 
investigation
investigation%1:09:00::
['analysis', 'examination', 'audit', 'evaluation', 'inspection']
(0.9411764705882353, Synset('investigation.n.02'), Synset('examination.n.05'))
--------------------------


For evaluation purposes, please run the process for i = 50 and print out the following:  
1. word[50]
2. ground_truth[50] (in synset or lemma)
3. sent[50]
4. candidate_lemmas
5. predicted_sense (in synset or lemma)    

Also, please print out the accuracy of the process over our dataset

In [None]:
i = 50
print(f'word[{i}]:  ',mask_token[i])
print(f'ground_truth[{i}]:  ',ground_truth[i])
print(f'sent[{i}]:  ',sent[i])
print(f'candidate_lemmas:  ',candidate_lemmas[i])
print(f'predicted_sense: ',predicted_sense[i][1])

word[50]:   size
ground_truth[50]:   size%1:07:00::
sent[50]:   " only a relative handful of such reports was received " , the jury said , " considering the widespread interest in the election , the number of voters and the [MASK] of_this city " . 
candidate_lemmas:   ['population', 'status', 'reputation', 'character', 'state']
predicted_sense:  Synset('size.n.04')


In [None]:
score = []

for idx, result in enumerate(predicted_sense):
  if result != '':
    try:
      # Convert the ground-truth sense key to its synset
      ground_truth_lemma = wn.lemma_from_key(ground_truth[idx])
      ground_truth_synset = ground_truth_lemma.synset()
      score_with_ground_truth = wn.wup_similarity(ground_truth_synset, result[1])
      score.append((score_with_ground_truth,ground_truth_synset,result[1]))
    except Exception:
      # The lookup failed: retry the key with ss_type 5 (adjective satellite) instead of 3 (adjective)
      ground_truth_replace =  ground_truth[idx].split('%3')
      g_t_replace = ground_truth_replace[0] + '%5' + ground_truth_replace[1]
      ground_truth_lemma = wn.lemma_from_key(g_t_replace)
      ground_truth_synset = ground_truth_lemma.synset()
      score_with_ground_truth = wn.wup_similarity(ground_truth_synset, result[1])
      score.append((score_with_ground_truth,ground_truth_synset,result[1]))

  else:
    # No sense could be predicted for this headword
    score.append(0)


In [None]:
correct_count = 0
for result in score:
  if result != 0 :
    if result[1] == result[2]:
      correct_count+=1
  
  

accuracy = correct_count / len(score)
accuracy

0.3397312859884837

## TA's Note

Congratulations, you made it to the end of the tutorial! Make sure you make an appointment to show your work and turn in your finished assignment before next week's lesson. We will ask you to run your code, so double-check that everything is working and that your model is saved. Don't worry if you didn't pass the evaluation requirements; you'll still get partial points for trying.