In [None]:
!pip install -U spacy[cuda112]
!python -m spacy download en_core_web_trf
!pip install datasets

In [None]:
# Mount Google Drive (Colab only) so the pickles under /content/drive/MyDrive can be
# read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from collections import defaultdict
import random
from os import path
import pickle

import spacy
import spacy_transformers

import datasets

from tqdm.auto import tqdm
import pandas as pd

## Loading Dataset

In [None]:
# Only the SQuAD validation split is used. Rows are addressed by position below and are
# assumed to line up one-to-one with `predicted_answers` (both are indexed by the same i).
squad = datasets.load_dataset("squad", split="validation")

Downloading builder script:   0%|          | 0.00/5.27k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.36k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.67k [00:00<?, ?B/s]

Downloading and preparing dataset squad/plain_text to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/8.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453. Subsequent calls will reuse this data.


## spaCy Configuration

In [None]:
# Run the transformer pipeline on the GPU if one is available. This must happen before
# `spacy.load` below; the call returns whether a GPU was activated (True in the recorded run).
spacy.prefer_gpu()

True

## Load NER model

In [None]:
# Transformer English pipeline with the tagger, parser, attribute ruler and lemmatizer
# disabled: this notebook only reads named entities (`doc.ents`).
nlp = spacy.load('en_core_web_trf', disable=["tagger", "parser", "attribute_ruler", "lemmatizer"])

## Find Named Entities in a Given Text

In [None]:
def get_text_nes(text):
    """Run the module-level spaCy pipeline `nlp` on `text` and list its named entities.

    Returns:
        A list of dicts, one per span in `doc.ents` and in that order, with keys
        'word' (entity text), 'begin' / 'end' (character offsets into `text`)
        and 'type' (the spaCy entity label).
    """
    # Same component list that was disabled when `nlp` was loaded.
    doc = nlp(text, disable=["tagger", "parser", "attribute_ruler", "lemmatizer"])
    return [
        {'word': span.text, 'begin': span.start_char, 'end': span.end_char, 'type': span.label_}
        for span in doc.ents
    ]

## Generating Distractor Entities

In [None]:
def _numeric_distractors(value, count=3, max_offset=10):
    """Return `count` distinct integers near `value`, as strings.

    Each result is `int(value) + offset` for a distinct non-zero offset drawn
    uniformly from [-max_offset, max_offset], so no result equals `value` and
    no two results are the same.

    Raises:
        ValueError: if `value` is not an integer literal, or `count` exceeds
            the 2 * max_offset available offsets.
    """
    base = int(value)
    offsets = [delta for delta in range(-max_offset, max_offset + 1) if delta != 0]
    return [str(base + delta) for delta in random.sample(offsets, count)]


def generate_replacement_entities(index):
    """Pick an entity in the predicted answer for row `index` and propose replacements.

    Reads the module-level `predicted_answers`, `squad` and
    `predefined_entity_lists`, and runs NER through `get_text_nes`.

    Strategy:
      1. Entity to replace: a random NER entity of the predicted answer or, if
         NER finds none there, the first context entity whose text equals the
         whole answer.
      2. Purely numeric CARDINAL/DATE entities get three distinct nearby numbers.
      3. Otherwise, collect other same-type entities from the context (with the
         answer entity removed). If fewer than three are found and the topic has
         a predefined list for this type, that list (minus the answer entity) is
         used instead; otherwise the context candidates are kept.

    Args:
        index: Row position in `squad` / `predicted_answers`.

    Returns:
        (replacement_entities, entity_to_replace):
        * (None, None) when no entity could be identified in the answer;
        * (None, entity_to_replace) when no replacement candidates were found;
        * otherwise a freshly built list of candidate strings (never containing
          `entity_to_replace`) and the entity text they replace.
    """
    answer = predicted_answers[index]['prediction_text']
    context = squad[index]['context']

    answer_nes = get_text_nes(answer)
    if answer_nes:
        chosen = random.choice(answer_nes)
        entity_type, entity_to_replace = chosen['type'], chosen['word']
    else:
        # NER found nothing in the bare answer string; look for a context entity whose text
        # is exactly the answer. Context NER (a transformer pass) only runs in this branch.
        entity_type, entity_to_replace = next(
            ((e['type'], e['word']) for e in get_text_nes(context) if e['word'] == answer),
            (None, None),
        )

    if entity_type is None:
        return None, None

    # isdecimal() (unlike isdigit()) rejects characters such as '²' that int() cannot parse.
    if entity_type in ('CARDINAL', 'DATE') and entity_to_replace.isdecimal():
        return _numeric_distractors(entity_to_replace), entity_to_replace

    # Strip the answer entity from the context so NER cannot offer it back as its own distractor.
    candidates = []
    for entity in get_text_nes(context.replace(entity_to_replace, '')):
        word = entity['word']
        if entity['type'] == entity_type and word != entity_to_replace and word not in candidates:
            candidates.append(word)

    if len(candidates) < 3:
        topic_entities = predefined_entity_lists.get(squad[index]['title'], {}).get(entity_type)
        if topic_entities is not None:
            # Build a NEW list: the predefined lists are shared by every question of this topic,
            # and removing from them in place would permanently drop this entity from the
            # distractor pool of all later questions.
            candidates = [word for word in topic_entities if word != entity_to_replace]

    return (candidates or None), entity_to_replace

## Generate Distractors

In [None]:
# QA-model predictions, assumed aligned with the `squad` validation rows; each entry is
# read through its 'prediction_text' key further down.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
with open('/content/drive/MyDrive/QuestionAnsweringModels/predicted_answers.pkl', 'rb') as f:
    predicted_answers = pickle.load(f)

In [None]:
def generate_predefined_entity_lists(dataset=None, random_state=None):
    """Collect, per article title, the unique named entities of every type in its contexts.

    These topic-wide lists are the fallback distractor pool when a single
    context has too few entities of the needed type.

    Args:
        dataset: Rows with 'title' and 'context' fields — anything `pd.DataFrame`
            accepts (a HF dataset, a list of dicts, a DataFrame). Defaults to the
            module-level `squad`.
        random_state: Seed for the context shuffle; None keeps it unseeded.

    Returns:
        dict mapping title -> defaultdict(list) mapping entity label -> list of
        unique entity strings, in order of first appearance across the
        (shuffled) unique contexts of that title.
    """
    frame = pd.DataFrame(squad if dataset is None else dataset)
    predefined_entity_lists = dict()

    # One groupby pass instead of re-filtering the whole dataset once per topic.
    for topic, group in tqdm(frame.groupby('title'), total=frame['title'].nunique()):
        # Shuffle so list order does not simply mirror the article's paragraph order;
        # several questions share a context, hence unique().
        unique_contexts = group['context'].sample(frac=1, random_state=random_state).unique()
        filtered_entities = defaultdict(list)
        for context in unique_contexts:
            for entity in get_text_nes(context):
                words = filtered_entities[entity['type']]
                if entity['word'] not in words:
                    words.append(entity['word'])
        predefined_entity_lists[topic] = filtered_entities
    return predefined_entity_lists

In [None]:
%%time
# Expensive: runs transformer NER over every unique context of every topic. The next
# cells pickle the result to Drive and load a pickled copy back, so on re-runs this
# cell can be skipped in favour of the load cell.
predefined_entity_lists = generate_predefined_entity_lists()

In [None]:
# Cache the topic entity lists in QuestionAnsweringModels/, next to predicted_answers.pkl
# and distractors.pkl. The load cell below reads this exact path; writing anywhere else
# makes a re-run silently load a stale (or missing) file.
with open('/content/drive/MyDrive/QuestionAnsweringModels/predefined_entity_lists.pkl', 'wb') as f:
    pickle.dump(predefined_entity_lists, f)

In [None]:
# Reload the cached topic entity lists so the expensive generation cell can be skipped.
# NOTE(review): this reads QuestionAnsweringModels/ — make sure the dump cell writes the same path.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
with open('/content/drive/MyDrive/QuestionAnsweringModels/predefined_entity_lists.pkl', 'rb') as f:
    predefined_entity_lists = pickle.load(f)

In [None]:
def generate_distractors(predicted_answers):
    """Build wrong-answer variants of the predicted answer for every `squad` row.

    For row i, `generate_replacement_entities(i)` selects an entity in the
    predicted answer plus candidate replacements; each distractor is the
    predicted answer text with that entity swapped for one candidate.
    Candidates that contain the original entity text are skipped (e.g.
    'Paris' -> 'Paris, France' would not be a real distractor).

    Results depend on the state of the `random` module, which
    `generate_replacement_entities` uses; seed it for reproducible output.

    Args:
        predicted_answers: Sequence aligned with `squad`; element i is a dict with
            a 'prediction_text' key. NOTE(review): entity selection inside
            `generate_replacement_entities` reads the module-level
            `predicted_answers`, not this argument — pass the same object.

    Returns:
        A list with one single-key dict per row, `{question_id: [distractor, ...]}`;
        the list is empty when no replacement could be found.
    """
    distractors = []
    # len(squad) avoids materializing the whole 'context' column just to count rows.
    for i in tqdm(range(len(squad))):
        replacement_entities, entity_to_replace = generate_replacement_entities(i)
        answer_text = predicted_answers[i]['prediction_text']
        wrong_answers = [
            answer_text.replace(entity_to_replace, candidate)
            for candidate in (replacement_entities or [])
            if entity_to_replace not in candidate
        ]
        distractors.append({squad[i]['id']: wrong_answers})
    return distractors

In [None]:
# One entry per validation question: {question_id: [distractor answers]}.
distractors = generate_distractors(predicted_answers)

  0%|          | 0/10570 [00:00<?, ?it/s]

In [None]:
# Persist the distractors on Drive alongside predicted_answers.pkl.
with open('/content/drive/MyDrive/QuestionAnsweringModels/distractors.pkl', 'wb') as f:
    pickle.dump(distractors, f)