<a href="https://colab.research.google.com/github/IshtiSikder/Word-sense-disambiguation-with-BertWSD/blob/main/Word_Sense_Disambiguation_codes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Select Runtime > Run all from the menu above to run all of the following cells in order, and read the numbered text cells alongside the code cells for a description of how the code works.

# We will build a pipeline that takes a sentence with an annotated target word as input and produces a multiple-choice question (MCQ) with the target word as the correct answer. The other options in the MCQ (a.k.a. distractors) will be chosen to be similar to the target word, i.e. if the target word is Red, the other options will be White, Blue, Yellow, etc. To do this effectively, the pipeline needs Word Sense Disambiguation (WSD) capabilities, which we will build into it.

# Once the pipeline is complete, we will run the following two sentences through it to see whether it can build meaningful MCQs from them. Both sentences have **mouse** as the target word, but with a different meaning in each.

#a) "John bought a **mouse** for his computer."
#b) "John saw a **mouse** under his bed."
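The target word is annotated with double asterisks, and the pipeline later rewrites that annotation as `[TGT]` markers for the WSD model. A minimal sketch of that conversion, using only the standard library; the helper name `mark_target` is hypothetical and not part of this notebook.

```python
import re

def mark_target(sentence):
    """Swap the ** annotation for [TGT] markers and pull out the target word.

    Hypothetical helper for illustration only.
    """
    # "**mouse**" becomes " [TGT] mouse [TGT] "; then collapse repeated whitespace.
    marked = " ".join(sentence.replace("**", " [TGT] ").split())
    match = re.search(r"\[TGT\](.*)\[TGT\]", marked)
    if match is None:
        raise ValueError("No **target** annotation found in: " + sentence)
    return marked, match.group(1).strip()

marked, target = mark_target("John bought a **mouse** for his computer.")
print(marked)  # John bought a [TGT] mouse [TGT] for his computer.
print(target)  # mouse
```

Raising an error on a missing annotation is a design choice of this sketch; it keeps a malformed sentence from travelling further down the pipeline.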

# 1. Let's first install some necessary modules/dependencies.

In [None]:
!pip install --quiet transformers==2.9.0
!pip install --quiet nltk==3.4.5
!pip install --quiet gradio==1.4.2



# 2. We will now connect our personal Google Drive to the Colab file system by running the following cell. This lets us store the pre-trained BERT-WSD (WSD = Word Sense Disambiguation) model in Google Drive and access it from Google Colab.

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


# 3. Let's now download the open-source, pre-trained BERT-WSD model from the following link and place it in the root folder (My Drive) of the Google Drive we just connected to our Colab environment.

# Download link: https://entuedu-my.sharepoint.com/personal/boonpeng001_e_ntu_edu_sg/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fboonpeng001%5Fe%5Fntu%5Fedu%5Fsg%2FDocuments%2FBERT%2DWSD%2Fmodel%2Fbert%5Fbase%2Daugmented%2Dbatch%5Fsize%3D128%2Dlr%3D2e%2D5%2Dmax%5Fgloss%3D6

# 4. Let's now extract the downloaded zip file by running the following cell. 

In [None]:
import os
import zipfile

bert_wsd_pytorch = "/content/gdrive/My Drive/bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6.zip"
extract_directory = "/content/gdrive/My Drive"

extracted_folder = bert_wsd_pytorch.replace(".zip","")

if not os.path.isdir(extracted_folder):
  with zipfile.ZipFile(bert_wsd_pytorch, 'r') as zip_ref:
      zip_ref.extractall(extract_directory)
else:
  print (extracted_folder," is extracted already")

/content/gdrive/My Drive/bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6  is extracted already


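# 5. Let's initialize the BertWSD model, which scores each candidate WordNet definition (gloss) of the target word against the input sentence so that the best-matching sense can be selected. WordNet will later supply the distractors for the sense chosen for the target word we annotate in the input sentence.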
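Later in the pipeline, the model produces one score (logit) per candidate definition, and a softmax over those scores decides which sense wins. The sketch below shows only that ranking step in plain Python. The glosses are shortened paraphrases and the logits are made-up numbers chosen for illustration, not output from the model.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Paraphrased candidate definitions for "mouse"; the logits are invented, not model output.
glosses = ["small rodent", "hand-operated pointing device", "timid person"]
toy_logits = [0.5, 3.0, -1.0]

scores = softmax(toy_logits)
# Highest probability first, mirroring how the best sense is picked.
ranked = sorted(zip(glosses, scores), key=lambda pair: pair[1], reverse=True)
best_gloss, best_score = ranked[0]
print(best_gloss)  # hand-operated pointing device
```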

In [None]:
import torch
import math
from transformers import BertModel, BertConfig, BertPreTrainedModel, BertTokenizer

class BertWSD(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

        self.ranking_linear = torch.nn.Linear(config.hidden_size, 1)

        self.init_weights()

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dir = "/content/gdrive/My Drive/bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6"


model = BertWSD.from_pretrained(model_dir)
tokenizer = BertTokenizer.from_pretrained(model_dir)
if '[TGT]' not in tokenizer.additional_special_tokens:
    tokenizer.add_special_tokens({'additional_special_tokens': ['[TGT]']})
    assert '[TGT]' in tokenizer.additional_special_tokens
    model.resize_token_embeddings(len(tokenizer))
    
model.to(DEVICE)
model.eval()

BertWSD(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30523, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    

In [None]:
import csv
import os
from collections import namedtuple

import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wn

import torch
from tqdm import tqdm

GlossSelectionRecord = namedtuple("GlossSelectionRecord", ["guid", "sentence", "sense_keys", "glosses", "targets"])
BertInput = namedtuple("BertInput", ["input_ids", "input_mask", "segment_ids", "label_id"])



def _create_features_from_records(records, max_seq_length, tokenizer, cls_token_at_end=False, pad_on_left=False,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=1, pad_token_segment_id=0,
                                  mask_padding_with_zero=True, disable_progress_bar=False):
    """ Convert records to list of features. Each feature is a list of sub-features where the first element is
        always the feature created from context-gloss pair while the rest of the elements are features created from
        context-example pairs (if available)
        `cls_token_at_end` define the location of the CLS token:
            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
    """
    features = []
    for record in tqdm(records, disable=disable_progress_bar):
        tokens_a = tokenizer.tokenize(record.sentence)

        sequences = [(gloss, 1 if i in record.targets else 0) for i, gloss in enumerate(record.glosses)]

        pairs = []
        for seq, label in sequences:
            tokens_b = tokenizer.tokenize(seq)

            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

            # The convention in BERT is:
            # (a) For sequence pairs:
            #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
            #
            # Where "type_ids" are used to indicate whether this is the first
            # sequence or the second sequence. The embedding vectors for `type=0` and
            # `type=1` were learned during pre-training and are added to the wordpiece
            # embedding vector (and position vector). This is not *strictly* necessary
            # since the [SEP] token unambiguously separates the sequences, but it makes
            # it easier for the model to learn the concept of sequences.
            #
            # For classification tasks, the first vector (corresponding to [CLS]) is
            # used as the "sentence vector". Note that this only makes sense because
            # the entire model is fine-tuned.
            tokens = tokens_a + [sep_token]
            segment_ids = [sequence_a_segment_id] * len(tokens)

            tokens += tokens_b + [sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

            if cls_token_at_end:
                tokens = tokens + [cls_token]
                segment_ids = segment_ids + [cls_token_segment_id]
            else:
                tokens = [cls_token] + tokens
                segment_ids = [cls_token_segment_id] + segment_ids

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = max_seq_length - len(input_ids)
            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
                segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            pairs.append(
                BertInput(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label)
            )

        features.append(pairs)

    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.
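The feature-building code above turns each context-gloss pair into three parallel lists: token ids, an attention mask, and segment ids. The toy sketch below reproduces that layout with plain strings so it is easy to inspect. It assumes whitespace splitting in place of BERT's WordPiece tokenizer, and `build_pair` is a hypothetical name; the default `cls_segment_id=1` mirrors the `cls_token_segment_id=1` default used in this notebook.

```python
def build_pair(tokens_a, tokens_b, max_seq_length, cls_segment_id=1):
    """Lay out [CLS] A [SEP] B [SEP], then right-pad to max_seq_length.

    Toy version: tokens stay as strings instead of vocabulary ids.
    """
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers the sentence and its [SEP]; segment 1 covers the gloss and its [SEP].
    segment_ids = [cls_segment_id] + [0] * (len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
    input_mask = [1] * len(tokens)  # 1 marks a real token

    padding = max_seq_length - len(tokens)
    if padding < 0:
        raise ValueError("pair is longer than max_seq_length; truncate it first")
    tokens += ["[PAD]"] * padding
    segment_ids += [0] * padding
    input_mask += [0] * padding  # 0 marks padding, which is not attended to
    return tokens, segment_ids, input_mask

sentence = "john saw a [TGT] mouse [TGT] under his bed .".split()
gloss = "a small rodent".split()
tokens, segment_ids, input_mask = build_pair(sentence, gloss, max_seq_length=20)
print(tokens)
print(segment_ids)
print(input_mask)
```

One such triple is built for every candidate definition of the target word, so a word with four noun senses yields four pairs for the model to score.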


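# 6. We will now define the helper functions the pipeline needs: **get_sense** picks the WordNet sense of the target word with BertWSD, **get_distractors_wordnet** collects distractors from WordNet, and **get_question** uses a T5 transformer model to generate the question. The **getMCQs** function in the cell after that combines them.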
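One of the helpers defined in this step picks distractors as WordNet "siblings" of the chosen sense: it looks up the sense's first hypernym (its parent category) and lists that parent's other hyponyms. The sketch below shows the same idea on a tiny hand-written taxonomy, so it runs without WordNet. `TOY_HYPERNYM` and `co_hyponyms` are hypothetical names, and the taxonomy is a simplification, not WordNet data.

```python
# Hypothetical mini-taxonomy standing in for WordNet: child -> parent (hypernym).
TOY_HYPERNYM = {
    "mouse": "rodent",
    "beaver": "rodent",
    "agouti": "rodent",
    "rodent": "mammal",
}

def co_hyponyms(word, hypernym_of):
    """Return the capitalized siblings of `word`: the other children of its parent."""
    parent = hypernym_of.get(word)
    if parent is None:
        return []  # no parent known, so no siblings to offer
    return [
        child.capitalize()
        for child, p in hypernym_of.items()
        if p == parent and child != word
    ]

print(co_hyponyms("mouse", TOY_HYPERNYM))  # ['Beaver', 'Agouti']
```

Because siblings share a parent category, they tend to be plausible but wrong answers, which is what a distractor should be.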

In [None]:
import re

import torch
from torch.nn.functional import softmax
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer

question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
question_tokenizer = T5Tokenizer.from_pretrained('t5-base')

MAX_SEQ_LENGTH = 128



def get_sense(sent):

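  re_result = re.search(r"\[TGT\](.*)\[TGT\]", sent)
  if re_result is None:
      # Fail fast: without a [TGT] ... [TGT] pair there is no target word to disambiguate.
      raise ValueError("Incorrect input format: the target word must be wrapped in [TGT] markers.")

  ambiguous_word = re_result.group(1).strip()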




  results = dict()

  wn_pos = wn.NOUN
  for i, synset in enumerate(set(wn.synsets(ambiguous_word, pos=wn_pos))):
      results[synset] =  synset.definition()

  if len(results) ==0:
    return (None,None,ambiguous_word)

  # print (results)
  sense_keys=[]
  definitions=[]
  for sense_key, definition in results.items():
      sense_keys.append(sense_key)
      definitions.append(definition)


  record = GlossSelectionRecord("test", sent, sense_keys, definitions, [-1])

  features = _create_features_from_records([record], MAX_SEQ_LENGTH, tokenizer,
                                            cls_token=tokenizer.cls_token,
                                            sep_token=tokenizer.sep_token,
                                            cls_token_segment_id=1,
                                            pad_token_segment_id=0,
                                            disable_progress_bar=True)[0]

  with torch.no_grad():
      logits = torch.zeros(len(definitions), dtype=torch.double).to(DEVICE)
      for i, bert_input in tqdm(list(enumerate(features)), desc="Progress"):
          logits[i] = model.ranking_linear(
              model.bert(
                  input_ids=torch.tensor(bert_input.input_ids, dtype=torch.long).unsqueeze(0).to(DEVICE),
                  attention_mask=torch.tensor(bert_input.input_mask, dtype=torch.long).unsqueeze(0).to(DEVICE),
                  token_type_ids=torch.tensor(bert_input.segment_ids, dtype=torch.long).unsqueeze(0).to(DEVICE)
              )[1]
          )
      scores = softmax(logits, dim=0)

      preds = (sorted(zip(sense_keys, definitions, scores), key=lambda x: x[-1], reverse=True))


  # print (preds)
  sense = preds[0][0]
  meaning = preds[0][1]
  return (sense,meaning,ambiguous_word)


# Distractors from Wordnet
def get_distractors_wordnet(syn,word):
    """Return the siblings of `syn` (other hyponyms of its first hypernym) as distractors."""
    distractors=[]
    # WordNet lemma names join multi-word terms with underscores, so normalize before comparing.
    word = word.lower().replace(" ","_")
    hypernym = syn.hypernyms()
    if len(hypernym) == 0:
        return distractors
    for item in hypernym[0].hyponyms():
        name = item.lemmas()[0].name()
        # Skip the target word itself.
        if name.lower() == word:
            continue
        name = name.replace("_"," ")
        name = " ".join(w.capitalize() for w in name.split())
        if name not in distractors:
            distractors.append(name)
    return distractors


def get_question(sentence,answer):
  text = "context: {} answer: {} </s>".format(sentence,answer)
  max_len = 256
  encoding = question_tokenizer.encode_plus(text,max_length=max_len, pad_to_max_length=True, return_tensors="pt")

  input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

  outs = question_model.generate(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  early_stopping=True,
                                  num_beams=5,
                                  num_return_sequences=1,
                                  no_repeat_ngram_size=2,
                                  max_length=200)


  dec = [question_tokenizer.decode(ids) for ids in outs]


  Question = dec[0].replace("question:","")
  Question= Question.strip()
  return Question



In [None]:
import random

def getMCQs(sent):
  # BertWSD expects the target word wrapped in [TGT] markers.
  sentence_for_bert = sent.replace("**"," [TGT] ")
  sentence_for_bert = " ".join(sentence_for_bert.split())
  sense,meaning,answer = get_sense(sentence_for_bert)
  if sense is not None:
    distractors = get_distractors_wordnet(sense,answer)
  else:
    distractors = ["Word not found in Wordnet. So unable to extract distractors."]
  # T5 gets the plain sentence with the ** annotation stripped.
  sentence_for_T5 = sent.replace("**"," ")
  sentence_for_T5 = " ".join(sentence_for_T5.split())
  ques = get_question(sentence_for_T5,answer)
  return ques,answer,distractors,meaning

# 7. Let's now try out the aforementioned two sentences with **mouse** as their target word.

In [None]:
sentence = "John bought a **mouse** for his computer."

question,answer,distractors,meaning = getMCQs(sentence)

new = distractors[:3]
                  
new.append(answer)

random.shuffle(new)

print("\n")

print("sentence:",sentence)
print ("question:",question)
for i in new:
  print("-->",i)
print ("answer:", answer)
print ("meaning:",meaning)





sentence: John bought a **mouse** for his computer.
question: What did John buy for his computer?
--> mouse
--> Defibrillator
--> Beeper
--> Answering Machine
answer: mouse
meaning: a hand-operated electronic device that controls the coordinates of a cursor on your computer screen as you move it around on a pad; on the bottom of the device is a ball that rolls on the surface of the pad


In [None]:
sentence = "John saw a **mouse** under his bed."

question,answer,distractors,meaning = getMCQs(sentence)

new = distractors[:3]
                  
new.append(answer)

random.shuffle(new)

print("\n")

print("sentence:",sentence)
print ("question:",question)
for i in new:
  print("-->",i)
print ("answer:", answer)
print ("meaning:",meaning)





sentence: John saw a **mouse** under his bed.
question: What animal did John see under his bed?
--> Abrocome
--> Beaver
--> mouse
--> Agouti
answer: mouse
meaning: any of numerous small rodents typically resembling diminutive rats having pointed snouts and small ears on elongated bodies with slender usually hairless tails
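
The two demo cells above repeat the same steps to assemble the answer options: take the first three distractors, append the answer, and shuffle. One way to factor that out is sketched below. `build_options` is a hypothetical helper, and its `seed` parameter is an addition that makes the shuffle reproducible; neither is part of the original notebook.

```python
import random

def build_options(answer, distractors, n_distractors=3, seed=None):
    """Combine the answer with up to n_distractors distractors and shuffle them."""
    options = list(distractors[:n_distractors]) + [answer]
    # A dedicated Random instance keeps the global random state untouched.
    random.Random(seed).shuffle(options)
    return options

# Distractor names taken from the second example's output above.
options = build_options("mouse", ["Abrocome", "Agouti", "Beaver"], seed=0)
for option in options:
    print("-->", option)
```

Passing `seed=None` (the default) gives a different order on each call, which matches the behavior of the demo cells.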
