In [None]:
!pip install --quiet transformers==2.9.0
!pip install --quiet nltk==3.4.5

In [None]:
# Mount Google Drive: the fine-tuned model directory loaded later in this
# notebook lives under "/content/gdrive/My Drive".
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
import torch
import math
from transformers import BertModel, BertConfig, BertPreTrainedModel, BertTokenizer

class BertWSD(BertPreTrainedModel):
    """BERT encoder plus a one-output linear ranking head for gloss selection.

    No ``forward`` is defined: the inference code in this notebook calls the
    submodules directly, feeding ``self.bert(...)[1]`` (presumably the pooled
    [CLS] representation -- confirm against the transformers version in use)
    into ``self.ranking_linear`` to get one score per (context, gloss) pair.

    The attribute names ``bert`` and ``ranking_linear`` become state-dict key
    prefixes, so they must not be renamed or ``from_pretrained`` would no longer
    find the saved weights.
    """

    def __init__(self, config):
        super(BertWSD, self).__init__(config)

        self.bert = BertModel(config)
        # Not used by the scoring loop in this notebook; presumably applied during training.
        self.dropout = torch.nn.Dropout(p=config.hidden_dropout_prob)

        # Maps a hidden_size vector to a single relevance score.
        self.ranking_linear = torch.nn.Linear(in_features=config.hidden_size, out_features=1)

        self.init_weights()

# Use the GPU when one is present, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fine-tuned checkpoint (model weights + tokenizer files) on the mounted Google Drive.
model_directory = "/content/gdrive/My Drive/bert_base-augmented-batch_size=128-lr=2e-5-max_gloss=6"

model = BertWSD.from_pretrained(model_directory)
tokenizer = BertTokenizer.from_pretrained(model_directory)

# Target words are wrapped in "[TGT]" markers. Register the marker as a special
# token when the saved tokenizer does not already have it, and grow the
# embedding matrix so the new token id has a row.
if '[TGT]' not in tokenizer.additional_special_tokens:
    tokenizer.add_special_tokens({'additional_special_tokens': ['[TGT]']})
    assert '[TGT]' in tokenizer.additional_special_tokens
    model.resize_token_embeddings(len(tokenizer))

model.to(DEVICE)
# Inference mode (disables dropout); as the last expression it also displays the model.
model.eval()

BertWSD(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30523, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    

In [None]:
import csv
import os
from collections import namedtuple

import nltk
nltk.download('wordnet')

import torch
from tqdm import tqdm

# One disambiguation example: a sentence, its candidate senses (`sense_keys`)
# and their definitions (`glosses`). `targets` holds the indices of the correct
# glosses; glosses whose index is in `targets` get label 1, all others 0.
GlossSelectionRecord = namedtuple(
    "GlossSelectionRecord",
    "guid sentence sense_keys glosses targets",
)

# One tokenized, padded (context, gloss) pair ready to be fed to BERT.
BertInput = namedtuple(
    "BertInput",
    "input_ids input_mask segment_ids label_id",
)

def _create_features_from_records(records, max_seq_length, tokenizer, cls_token_at_end=False, pad_on_left=False,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=1, pad_token_segment_id=0,
                                  mask_padding_with_zero=True, disable_progress_bar=False):
    """Convert gloss-selection records into padded BERT input features.

    The context sentence of every record is paired with each of the record's
    glosses, giving one ``BertInput`` per gloss. The result is a list with one
    entry per record; each entry is a list of ``BertInput`` in ``record.glosses``
    order. A pair's ``label_id`` is 1 when the gloss index is in
    ``record.targets`` and 0 otherwise.

    Args:
        records: iterable of ``GlossSelectionRecord``.
        max_seq_length: length every pair is truncated / padded to; must be at
            least 3 to leave room for the CLS token and two SEP tokens.
        tokenizer: object providing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``.
        cls_token_at_end: location of the CLS token:
            - False (default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        cls_token, sep_token: the special token strings to insert.
        pad_token: token id used for padding positions.
        sequence_a_segment_id: segment id of the context tokens and their SEP.
        sequence_b_segment_id: segment id of the gloss tokens and their SEP.
        cls_token_segment_id: segment id of the CLS token (defaults to 1 here).
        pad_token_segment_id: segment id of padding positions.
        mask_padding_with_zero: if True real tokens get mask 1 and padding 0;
            if False the values are inverted.
        disable_progress_bar: suppress the tqdm progress bar.

    Returns:
        list[list[BertInput]]

    Raises:
        ValueError: if ``max_seq_length`` is smaller than 3.
    """
    if max_seq_length < 3:
        raise ValueError(f"max_seq_length must be at least 3, got {max_seq_length}")

    # Attention-mask values for real tokens and for padding positions.
    real_mask_value = 1 if mask_padding_with_zero else 0
    pad_mask_value = 0 if mask_padding_with_zero else 1

    features = []
    for record in tqdm(records, disable=disable_progress_bar):
        context_tokens = tokenizer.tokenize(record.sentence)

        pairs = []
        for gloss_index, gloss in enumerate(record.glosses):
            label = 1 if gloss_index in record.targets else 0

            # _truncate_seq_pair mutates its arguments in place. Truncate a fresh
            # copy of the context for every pair so that the truncation needed for
            # one long gloss does not also shorten the context of later glosses.
            tokens_a = list(context_tokens)
            tokens_b = tokenizer.tokenize(gloss)
            # Account for [CLS], [SEP], [SEP] with "- 3".
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

            # Layout (CLS at the front):
            #  tokens:      [CLS] <context> [SEP] <gloss> [SEP]
            #  segment ids: cls_token_segment_id, then sequence_a_segment_id for the
            #               context and its SEP, then sequence_b_segment_id for the
            #               gloss and its SEP.
            tokens = tokens_a + [sep_token]
            segment_ids = [sequence_a_segment_id] * len(tokens)

            tokens += tokens_b + [sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)

            if cls_token_at_end:
                tokens = tokens + [cls_token]
                segment_ids = segment_ids + [cls_token_segment_id]
            else:
                tokens = [cls_token] + tokens
                segment_ids = [cls_token_segment_id] + segment_ids

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [real_mask_value] * len(input_ids)

            # Pad up to max_seq_length (truncation above guarantees padding_length >= 0).
            padding_length = max_seq_length - len(input_ids)
            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                input_mask = ([pad_mask_value] * padding_length) + input_mask
                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                input_mask = input_mask + ([pad_mask_value] * padding_length)
                segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            pairs.append(
                BertInput(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_id=label)
            )

        features.append(pairs)

    return features

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Shorten two token lists in place until their combined length fits.

    One token at a time is removed from the end of whichever list is currently
    longer (``tokens_b`` on a tie). Trimming the longer side rather than an equal
    fraction of both keeps a short sequence intact, since each of its tokens
    likely carries more information.

    Args:
        tokens_a: first token list; modified in place.
        tokens_b: second token list; modified in place.
        max_length: upper bound for ``len(tokens_a) + len(tokens_b)``.

    Returns:
        None.
    """
    while len(tokens_a) + len(tokens_b) > max_length:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wordnet

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
import re
import torch
from tabulate import tabulate
from torch.nn.functional import softmax
from tqdm import tqdm
from transformers import BertTokenizer
import time

# Length every (context, gloss) pair is truncated / padded to, counting the
# [CLS] token and both [SEP] tokens.
MAX_SEQ_LENGTH: int = 128

class DistractorGeneration:
    """Generate multiple-choice distractors for a target word in a sentence.

    Workflow used in this notebook:
      1. ``format_sentence_for_bert`` wraps the target word in ``[TGT]`` markers.
      2. ``get_predictions`` scores every WordNet noun sense of the marked word
         with the fine-tuned BERT WSD model and returns the best-scoring synset.
      3. ``generate_distractors_wordnet`` returns the first lemma names of the
         other hyponyms of that synset's first hypernym.

    ``get_predictions`` relies on the notebook-level globals ``model``,
    ``tokenizer``, ``DEVICE``, ``MAX_SEQ_LENGTH``, ``wordnet``, ``softmax``,
    ``GlossSelectionRecord`` and ``_create_features_from_records``.
    """

    def __init__(self, sentence, targetWord):
        self.sentence = sentence
        self.targetWord = targetWord
        # Set by format_sentence_for_bert; stays "" until a target has been marked.
        self.formattedSentence = ""

    def generate_distractors_wordnet(self, synset, word):
        """Return co-hyponym lemma names of ``synset`` to use as distractors.

        Args:
            synset: WordNet synset of the answer (e.g. from ``get_predictions``),
                or None when no sense could be determined.
            word: the answer word. Lemma names equal to it are excluded; the
                comparison ignores case and treats spaces as underscores.

        Returns:
            First-lemma names of the hyponyms of the synset's first hypernym,
            excluding ``synset`` itself and the answer word. An empty list when
            ``synset`` is None or has no hypernym.
        """
        if synset is None:
            print('no sense available. Unable to generate distractors')
            return []

        hypernyms_of_target_word = synset.hypernyms()
        # Check before indexing: a synset at the top of the hierarchy has no hypernym.
        if len(hypernyms_of_target_word) == 0:
            print('hypernym not found. Unable generate distractors')
            return []

        # Multi-word WordNet lemma names use underscores (e.g. 'ball_game').
        target_word = (word or '').lower().replace(' ', '_')
        distractors = []
        for item in hypernyms_of_target_word[0].hyponyms():
            # Skip the answer's own synset: its first lemma may be a synonym of
            # the answer rather than the answer word itself.
            if item == synset:
                continue
            name = item.lemmas()[0].name()
            if name.lower() != target_word:
                distractors.append(name)
        return distractors

    def get_predictions(self, sentence):
        """Disambiguate the ``[TGT]``-marked word in ``sentence`` with the BERT WSD model.

        Every WordNet noun synset of the marked word is paired with the sentence
        and scored by ``model.ranking_linear`` on top of ``model.bert``. The scores
        are softmax-normalised across the senses, and each candidate meaning is
        printed with its score, highest first.

        Args:
            sentence: text containing the target as ``[TGT]word[TGT]``.

        Returns:
            ``(synset, definition, word)`` for the highest-scoring sense;
            ``(None, None, word)`` when WordNet has no noun sense for the word;
            ``(None, None, None)`` when ``sentence`` has no ``[TGT]...[TGT]`` span.
        """
        re_result = re.search(r"\[TGT\](.*)\[TGT\]", sentence or "")
        if re_result is None:
            print("\nIncorrect input format. Please try again.")
            return (None, None, None)

        ambiguous_word = re_result.group(1).strip()

        # A dict drops duplicate synsets while keeping WordNet's sense order;
        # iterating a set would make the order depend on per-run hash randomisation.
        results = {synset: synset.definition()
                   for synset in wordnet.synsets(ambiguous_word, pos=wordnet.NOUN)}
        if len(results) == 0:
            return (None, None, ambiguous_word)

        sense_keys = list(results.keys())
        definitions = list(results.values())

        # targets=[-1]: there is no gold sense at inference time, so every pair gets label 0.
        record = GlossSelectionRecord("test", sentence, sense_keys, definitions, [-1])

        features = _create_features_from_records([record], MAX_SEQ_LENGTH, tokenizer,
                                                  cls_token=tokenizer.cls_token,
                                                  sep_token=tokenizer.sep_token,
                                                  cls_token_segment_id=1,
                                                  pad_token_segment_id=0,
                                                  disable_progress_bar=True)[0]

        with torch.no_grad():
            logits = torch.zeros(len(definitions), dtype=torch.double).to(DEVICE)
            for i, bert_input in enumerate(features):
                # One forward pass per (context, gloss) pair; output [1] of
                # model.bert is fed to the single-output ranking head.
                logits[i] = model.ranking_linear(
                    model.bert(
                        input_ids=torch.tensor(bert_input.input_ids, dtype=torch.long).unsqueeze(0).to(DEVICE),
                        attention_mask=torch.tensor(bert_input.input_mask, dtype=torch.long).unsqueeze(0).to(DEVICE),
                        token_type_ids=torch.tensor(bert_input.segment_ids, dtype=torch.long).unsqueeze(0).to(DEVICE)
                    )[1]
                )
            scores = softmax(logits, dim=0)

            preds = sorted(zip(sense_keys, definitions, scores), key=lambda x: x[-1], reverse=True)

        for _, definition, score in preds:
            print("meaning: ", definition, "\n score: ", score.item())
            print()

        sense, meaning, _ = preds[0]
        return (sense, meaning, ambiguous_word)

    def format_sentence_for_bert(self):
        """Wrap the first whole-word occurrence of the target word in ``[TGT]`` markers.

        The match is case-sensitive. It must not be preceded or followed by a word
        character, so the target is found at the start or end of the sentence and
        next to punctuation. Occurrences inside longer words (e.g. "cat" in
        "concatenate") are not marked.

        Returns:
            The marked sentence, which is also stored in ``self.formattedSentence``.
            If the target word is not found, returns the unchanged
            ``self.formattedSentence`` (initially "").
        """
        target = self.targetWord
        match = None
        if target:
            pattern = re.compile(r"(?<!\w)" + re.escape(target) + r"(?!\w)")
            match = pattern.search(self.sentence)

        if match is None:
            print('target word does not exist')
            return self.formattedSentence

        start, end = match.span()
        # Splice by position so only this occurrence is marked.
        self.formattedSentence = f"{self.sentence[:start]}[TGT]{target}[TGT]{self.sentence[end:]}"
        print(self.formattedSentence)
        print()
        return self.formattedSentence

In [None]:
# Alternative example exercising the insect sense of the word:
#   sentence = "John is annoyed by a cricket chirping in his room"
sentence = "John's enjoys playing cricket in summer"
targetword = "cricket"

# Mark the target, pick its most likely WordNet sense, then derive distractors from it.
distractorGenerator = DistractorGeneration(sentence=sentence, targetWord=targetword)
formattedSentence = distractorGenerator.format_sentence_for_bert()
sense, meaning, answer = distractorGenerator.get_predictions(formattedSentence)
distractors = distractorGenerator.generate_distractors_wordnet(synset=sense, word=answer)

print("\n -----------------------------------------------------------------------------------------------\n")
print(distractors)

John's enjoys playing [TGT]cricket[TGT] in summer

meaning:  a game played with a ball and bat by two teams of 11 players; teams take turns trying to score runs 
 score:  0.9997342482658158

meaning:  leaping insect; male makes chirping noises by rubbing the forewings together 
 score:  0.0002657517341842314


 -----------------------------------------------------------------------------------------------

['ball_game', 'field_hockey', 'football', 'hurling', 'lacrosse', 'polo', 'pushball', 'ultimate_frisbee']
