## Preparing Dataset

In [7]:
import argparse
import csv
import random
import re
from pathlib import Path
from xml.etree.ElementTree import ElementTree
import nltk 
from tqdm import tqdm

In [8]:
# Fetch the WordNet corpus used through nltk.corpus.wordnet in the cells below;
# when the package is already present NLTK just reports it as up-to-date.
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to /home/abhishek/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [3]:
from nltk.corpus import wordnet as wn 
# Maps the coarse POS tags found in the SemCor XML `pos` attribute (passed through
# get_glosses) to NLTK WordNet POS constants; also enumerates the POS values
# that getAllWordnetLemmaNames walks over.
POS = {'NOUN':wn.NOUN, 'VERB':wn.VERB, 'ADJ':wn.ADJ, 'ADV':wn.ADV}

def getInfo(type, pos, lemma):
    """Map every WordNet sense of ``lemma`` to its gloss or its example sentences.

    Args:
        type: ``'def'`` to return each synset's definition string; any other
            value returns the synset's example sentences (a list of strings).
        pos: A key of ``POS`` (``'NOUN'``, ``'VERB'``, ``'ADJ'``, ``'ADV'``),
            or ``None`` to search every part of speech.
        lemma: The word to look up.

    Returns:
        dict mapping a WordNet sense key (``Lemma.key()``) to the definition or
        example list of the synset it belongs to. Keys are inserted in the
        order ``wn.synsets`` returns synsets (duplicates dropped), so the
        result order is the same on every run.

    Raises:
        KeyError: if ``pos`` is neither ``None`` nor a key of ``POS``.
        AssertionError: if a synset has no lemma equal to ``lemma`` or to one
            of its morphological base forms.
    """
    res = dict()
    word_pos = POS[pos] if pos is not None else None
    target = lemma.lower()

    # Base forms of ``lemma`` (e.g. "running" -> "run"); wn.synsets() looks these
    # up internally, so a synset may be reached only through one of them.
    if word_pos is not None:
        morpho = set(wn._morphy(target, word_pos))
    else:
        # With no POS, wn.synsets() searches every POS, so collect base forms for
        # every POS as well; otherwise inflected input would trip the assert below.
        morpho = {form for p in POS.values() for form in wn._morphy(target, p)}

    # dict.fromkeys de-duplicates while keeping wn.synsets' deterministic order
    # (a plain set iterates in hash order, which changes between processes).
    # No progress bar here: this runs once per SemCor instance.
    for synset in dict.fromkeys(wn.synsets(lemma, pos=word_pos)):
        key = None
        for lem in synset.lemmas():
            name = lem.name().lower()
            if name == target:
                # An exact lemma match always wins.
                key = lem.key()
                break
            if name in morpho:
                # Remember a base-form match but keep looking for an exact one.
                key = lem.key()

        assert key is not None, f"no lemma of {synset.name()} matches {lemma!r}"
        res[key] = synset.definition() if type == 'def' else synset.examples()

    return res

def get_glosses(pos,lemma):
    """Return ``{sense_key: definition}`` for every WordNet sense of ``lemma``.

    ``pos`` is a key of ``POS`` or ``None`` (search all parts of speech).
    """
    wanted = 'def'
    return getInfo(wanted, pos, lemma)

def getexample(pos,lemma):
    """Return ``{sense_key: [example sentences]}`` for every WordNet sense of ``lemma``.

    ``pos`` is a key of ``POS`` or ``None`` (search all parts of speech).
    """
    wanted = 'ex'
    return getInfo(wanted, pos, lemma)

def getAllWordnetLemmaNames():
    """List every WordNet lemma name paired with its coarse part-of-speech tag.

    Returns:
        list of ``(pos, lemma_name)`` tuples, where ``pos`` is a key of ``POS``
        (``'NOUN'``, ``'VERB'``, ``'ADJ'``, ``'ADV'``) and ``lemma_name`` is a
        lemma string yielded by ``wn.all_lemma_names`` for that POS.
    """
    res = []
    for pos, pos_name in POS.items():
        # all_lemma_names already enumerates each lemma of this POS, so no
        # per-synset loop is needed (wn.synsets also requires a lemma argument).
        res.extend((pos, name) for name in wn.all_lemma_names(pos=pos_name))

    return res

In [None]:
xml_file = './SemCor/semcor.data.xml'           # sentences with <instance> targets
gold_txt_file = './SemCor/semcor.gold.key.txt'  # "<instance id> <sense key> ..." per line
output_file = './SemCor/semcor_data.csv'
max_glossKey = 4  # max candidate senses (gold + sampled distractors) per instance
RANDOM_SEED = 42  # makes distractor sampling and shuffling repeatable

# Seed once so regenerating the CSV yields the same rows, given that
# get_glosses returns senses in a stable order.
random.seed(RANDOM_SEED)

print("Creating dataset...")
root = ElementTree(file=xml_file).getroot()
# newline='' is required by the csv module; without it rows get an extra blank
# line on platforms that translate '\n' to '\r\n'.
with open(output_file, 'w', encoding='utf-8', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['id', 'sentence', 'sense_keys', 'glosses', 'target_words'])

    def write_to_csv(id_, sentence_, lemma_, pos_, gold_keys_):
        """Write one CSV row: the gold senses plus sampled distractor senses.

        Args:
            id_: Instance id from the XML.
            sentence_: Sentence text with the target wrapped in ``[TGT]`` markers.
            lemma_: Lemma of the target, as given in the XML.
            pos_: Coarse POS of the target, as given in the XML.
            gold_keys_: Gold sense keys for this instance from the gold key file.

        Raises:
            KeyError: if a gold key is not among WordNet's senses for ``lemma_``.
        """
        sense_i = get_glosses(pos_, lemma_)
        gloss_sense_pairs = []
        for key in gold_keys_:
            gloss_sense_pairs.append((key, sense_i.pop(key)))

        # Clamp at 0: an instance with more gold keys than max_glossKey keeps all of
        # them and gets no distractors (random.sample rejects a negative k).
        rem = max(0, max_glossKey - len(gloss_sense_pairs))
        distractors = list(sense_i.items())
        if len(distractors) > rem:
            distractors = random.sample(distractors, rem)
        gloss_sense_pairs.extend(distractors)

        random.shuffle(gloss_sense_pairs)
        sense_keys = [key for key, _ in gloss_sense_pairs]
        glosses = [gloss for _, gloss in gloss_sense_pairs]

        # Positions of the gold senses inside the shuffled candidate list.
        target_words = [sense_keys.index(key) for key in gold_keys_]
        writer.writerow([id_, sentence_, sense_keys, glosses, target_words])

    with open(gold_txt_file, 'r', encoding='utf-8') as g:
        for dc in tqdm(root):
            for sentence in dc:
                instances = []
                tokens = []
                for token in sentence:
                    tokens.append(token.text)
                    if token.tag == 'instance':
                        # Single-token target: [start, end) covers only this token.
                        strt_index = len(tokens) - 1
                        end_index = strt_index + 1
                        instances.append((token.attrib['id'], strt_index, end_index,
                                          token.attrib['lemma'], token.attrib['pos']))

                for id_, start, end, lemma, pos in instances:
                    # The gold file lists instances in the same order as the XML.
                    gold_key = g.readline().strip().split()
                    gold = gold_key[1:]
                    assert id_ == gold_key[0], f"gold key file out of sync at {id_}"
                    sentence_ = ' '.join(
                        tokens[:start] + ['[TGT]'] + tokens[start:end] + ['[TGT]'] + tokens[end:]
                    )
                    write_to_csv(id_, sentence_, lemma, pos, gold)

print("Done!")


In [13]:
# Read the Senseval-3 gold sense keys and the model's predictions, one instance per
# line. Blank lines (e.g. an extra empty line at the end of a file) are dropped so
# they are neither counted as instances nor crash the accuracy computation.
with open("./senseval3.gold.key.txt", encoding="utf-8") as f:
    lines = [line.strip() for line in f if line.strip()]

with open("./senseval3_predictions.txt", encoding="utf-8") as f:
    test_lines = [line.strip() for line in f if line.strip()]

In [14]:

def wsd_accuracy(gold_lines, pred_lines):
    """Fraction of instances whose predicted sense key is one of the gold keys.

    Args:
        gold_lines: Lines of the form ``"<instance_id> <gold_key> [<gold_key> ...]"``.
        pred_lines: Lines of the form ``"<instance_id> <predicted_key> ..."``, in the
            same order as ``gold_lines``. The leading instance id is assumed to match
            the gold file's id format (TODO confirm for other prediction writers).

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the inputs are empty, differ in length, or list instances
            in a different order.
    """
    if not gold_lines:
        raise ValueError("no gold instances to score")
    if len(gold_lines) != len(pred_lines):
        raise ValueError(
            f"{len(gold_lines)} gold lines but {len(pred_lines)} prediction lines")

    correct = 0
    for gold_line, pred_line in zip(gold_lines, pred_lines):
        gold_id, *gold_keys = gold_line.split()
        pred_id, pred_key = pred_line.split()[:2]
        # Lines are paired by position, so a mismatch would silently skew the score.
        if gold_id != pred_id:
            raise ValueError(
                f"instance id mismatch: gold {gold_id!r} vs prediction {pred_id!r}")
        if pred_key in gold_keys:
            correct += 1
    return correct / len(gold_lines)


print("Accuracy:", wsd_accuracy(lines, test_lines))

Accuracy: 0.7005405405405405


In [6]:
# NOTE(review): this cell repeats the accuracy computation from the cell above. Its
# saved output (In [6], run before In [13]/[14]) was produced from different file
# contents, so the two printed numbers are not comparable. Consider deleting it.
if not lines:
    raise ValueError("no gold instances to score")
if len(lines) != len(test_lines):
    raise ValueError(f"{len(lines)} gold lines but {len(test_lines)} prediction lines")

count = 0
for gold_line, pred_line in zip(lines, test_lines):
    # Gold: "<id> <key> [<key> ...]"; prediction: "<id> <key> ...", paired by position.
    gold_id, *gold_keys = gold_line.split()
    pred_id, pred_key = pred_line.split()[:2]
    if gold_id != pred_id:
        raise ValueError(f"instance id mismatch: gold {gold_id!r} vs prediction {pred_id!r}")
    if pred_key in gold_keys:
        count += 1

print("Accuracy:", count / len(lines))


Accuracy: 0.7354198262787812


In [8]:
import argparse
import re

import torch
from tabulate import tabulate
from torch.nn.functional import softmax
from tqdm import tqdm
from transformers import BertTokenizer
from createFeatures import GlossSelectionRecord, _create_features_from_records
from modelBERT import BERT_for_WSD


# Sequence length handed to _create_features_from_records when building model inputs.
MAX_SEQ_LENGTH = 128
# Run inference on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def get_predictions(model, tokenizer, sentence):
    """Rank the WordNet senses of the ``[TGT]``-marked word in ``sentence``.

    Args:
        model: A ``BERT_for_WSD`` exposing ``.bert`` and ``.ranking_linear``,
            already moved to ``device`` (inputs are sent there); callers in this
            notebook also switch it to eval mode first.
        tokenizer: The tokenizer matching ``model``.
        sentence: Text with the ambiguous word wrapped in ``[TGT]`` markers,
            e.g. ``"He sat on the [TGT] bank [TGT] of the river"``.

    Returns:
        list of ``(sense_key, definition, score)`` tuples sorted by descending
        score (softmax over all candidate senses), or ``None`` when the input
        has no ``[TGT] ... [TGT]`` span or the word has no WordNet senses.
    """
    re_result = re.search(r"\[TGT\](.*)\[TGT\]", sentence)
    if re_result is None:
        print("Incorrect input format")
        return

    ambiguous_word = re_result.group(1).strip()
    glosses = get_glosses(None, ambiguous_word)
    if not glosses:
        # Empty target or a word unknown to WordNet: there is nothing to rank.
        print(f"No WordNet senses found for {ambiguous_word!r}")
        return
    sense_keys = list(glosses.keys())
    definitions = list(glosses.values())

    # Label [-1]: no gold sense is known at inference time.
    record = GlossSelectionRecord(
        "test", sentence, sense_keys, definitions, [-1])
    features = _create_features_from_records([record], MAX_SEQ_LENGTH, tokenizer,
                                             cls_token=tokenizer.cls_token,
                                             sep_token=tokenizer.sep_token,
                                             cls_token_segment_id=1,
                                             pad_token_segment_id=0,
                                             disable_progress_bar=True)[0]

    with torch.no_grad():
        logits = torch.zeros(len(definitions), dtype=torch.double).to(device)
        # One forward pass per (context, gloss) pair; presumably [1] is BERT's
        # pooled [CLS] output, which ranking_linear maps to a single score.
        for i, bert_input in tqdm(list(enumerate(features)), desc="Progress"):
            logits[i] = model.ranking_linear(
                model.bert(
                    input_ids=torch.tensor(
                        bert_input.input_ids, dtype=torch.long).unsqueeze(0).to(device),
                    attention_mask=torch.tensor(
                        bert_input.input_mask, dtype=torch.long).unsqueeze(0).to(device),
                    token_type_ids=torch.tensor(
                        bert_input.segment_ids, dtype=torch.long).unsqueeze(0).to(device)
                )[1]
            )
        scores = softmax(logits, dim=0)

    return sorted(zip(sense_keys, definitions, scores), key=lambda x: x[-1], reverse=True)


# NOTE(review): these are the *pre-trained* bert-base-uncased weights. The load
# warning in this cell's output reports ranking_linear as newly initialized, so the
# scores are effectively untrained (note the near-uniform scores below). Point
# MODEL_NAME_OR_PATH at the fine-tuned WSD checkpoint for meaningful predictions.
MODEL_NAME_OR_PATH = 'bert-base-uncased'
# Inputs (after strip/lower) that end the interactive loop; an empty line also quits.
EXIT_COMMANDS = {'', 'q', 'quit', 'exit'}

model = BERT_for_WSD.from_pretrained(MODEL_NAME_OR_PATH)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model.to(device)
model.eval()

# Without an exit condition this cell never finishes, which blocks Run All.
while True:
    try:
        sentence = input(
            "\nEnter a sentence with an ambiguous word surrounded by [TGT] tokens"
            " (empty line or 'quit' to stop)\n> ")
    except (EOFError, KeyboardInterrupt):
        break
    if sentence.strip().lower() in EXIT_COMMANDS:
        break
    predictions = get_predictions(model, tokenizer, sentence)
    if predictions:
        print("\nPredictions:")
        print(tabulate(
            [[f"{i+1}.", key, gloss, f"{score:.5f}"]
             for i, (key, gloss, score) in enumerate(predictions)],
            headers=["No.", "Sense key", "Definition", "Score"])
        )


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BERT_for_WSD: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BERT_for_WSD from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BERT_for_WSD from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BERT_for_WSD were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['ranking_linear.bias'


Predictions:
  No.  Sense key     Definition                                                                             Score
-----  ------------  -----------------------------------------------------------------------------------  -------
    1  be%1:27:00::  a light strong brittle grey toxic bivalent metallic element                          0.08185
    2  be%2:41:00::  work in a specific place, with a specific subject, or in a specific function         0.08056
    3  be%2:42:05::  occupy a certain position or area; be somewhere                                      0.079
    4  be%2:42:04::  happen, occur, take place                                                            0.07449
    5  be%2:42:13::  to remain unmolested, undisturbed, or uninterrupted -- used only in infinitive form  0.07264
    6  be%2:42:00::  have an existence, be extant                                                         0.07008
    7  be%2:42:03::  have the quality of being; (copula, used with an adject