In [11]:
import json
import faiss
import random
import numpy as np
from transformers import BertModel, BertTokenizer
from sentence_transformers import SentenceTransformer
import torch

from tqdm.notebook import tqdm

## Loading and Testing

In [3]:
# Load dev-v2.0.json
with open('../data/dev-v2.0.json') as f:
    data = json.load(f)

# Load train-v2.0.json
with open('../data/train-v2.0.json') as f:
    train = json.load(f)

In [4]:
# The data is divided as follows:
# First there is the version and the data
# The data is divided into topics, which are divided into paragraphs for each topic
# Each paragraph has a list of qas (questions and answers)

# Get relevant data as data['data'][i]['paragraphs'][j]['qas'][k]['question']
print(data['data'][0]['paragraphs'][0]['qas'][0])
print(train['data'][0]['paragraphs'][0]['qas'][0])

{'question': 'In what country is Normandy located?', 'id': '56ddde6b9a695914005b9628', 'answers': [{'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}], 'is_impossible': False}
{'question': 'When did Beyonce start becoming popular?', 'id': '56be85543aeaaa14008c9063', 'answers': [{'text': 'in the late 1990s', 'answer_start': 269}], 'is_impossible': False}


In [63]:
# Combine the question and the answer text to get the form
# question: answer
text = "Question: " + data['data'][0]['paragraphs'][0]['qas'][0]['question'] + '\nAnswer: ' + data['data'][0]['paragraphs'][0]['qas'][0]['answers'][0]['text']
print(text)

Question: In what country is Normandy located?
Answer: France


## Embeddings

In [33]:
# Load pre-trained model tokenizer 
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# # Load pre-trained model
# model = BertModel.from_pretrained('bert-base-uncased')

model = SentenceTransformer('sentence-transformers/paraphrase-albert-base-v2')

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.74k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/827 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/46.7M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/464 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/245 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [6]:
"Creates the embedings for the given text, returning a list with 768 elements"
"This should be used if huggingface's transformers are used"
def create_embedings(tokenizer, model, text):  
  input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=True)

  with torch.no_grad():
    outputs = model(**input_ids)

  "It is considered as sentence embeddings the weights of the [CLS] token (first token) of the last layer"
  embedding = outputs.last_hidden_state[:, 0, :]
  return embedding

In [34]:
BATCH_SIZE = 256

# Collect all questions and their corresponding data items and paragraphs
all_questions = []
data_paragraphs = []
for data_item in train['data']:
    for paragraph in data_item['paragraphs']:
        questions = [qa['question'] for qa in paragraph['qas']]
        all_questions.extend(questions)
        data_paragraphs.extend([(data_item, paragraph)] * len(questions))

# Encode all questions in a batch
embeddings = model.encode(all_questions, batch_size=BATCH_SIZE, show_progress_bar=True)
question_data = []

# Iterate over the embeddings and the corresponding data items and paragraphs
for embedding, (data_item, paragraph) in zip(embeddings, data_paragraphs):
    for i, qa in enumerate(paragraph['qas']):
        # Setup the tuple based on 'is_impossible' flag
        if qa['is_impossible']:
            qa_object = (qa['question'], None, embedding)
        else:
            qa_object = (qa['question'], qa['answers'][0]['text'], embedding)

        question_data.append(qa_object)

Batches:   0%|          | 0/510 [00:00<?, ?it/s]

## Creating FAISS index for textual embeddings

In [21]:
EMBEDDING_DIM = model.get_sentence_embedding_dimension()
NUM_RESULTS = 3

In [22]:
class EmbeddingDatabase:
    "Database which contains the text in a list and the embeddings in a faiss index"

    def __init__(self, EMBEDDING_DIM):
        self.texts = []
        self.index = faiss.IndexFlatL2(EMBEDDING_DIM)
        self.index_text = faiss.IndexIDMap(self.index)

    def add(self, text, embedding):
        "Add the text and the embedding to the database"
        self.texts.append(text)
        self.index_text.add_with_ids(np.array([embedding]), np.array([len(self.texts)-1]))

    def search(self, query, k=5):
        "Search for the k nearest neighbours of the query"
        D, I = self.index_text.search(np.array([query]), k)
        return [(self.texts[i], D[0][j]) for j, i in enumerate(I[0])]


In [23]:
def add_question_answer(qa_object, database):
    question, answer, embedding = qa_object

    if answer is None:
        answer = ""

    text = "Question: " + question + '\nAnswer: ' + answer

    database.add(text, embedding)

In [25]:
# Test the database
db = EmbeddingDatabase(EMBEDDING_DIM)

# Add all the questions and answers from all paragraphs from all topics to the database
for qa_object in tqdm(question_data, desc="Adding questions and answers to database"):
    add_question_answer(qa_object, db)

Adding questions and answers to database:   0%|          | 0/1074011 [00:00<?, ?it/s]

In [26]:
def get_context(text, database):
    "Converts the text into an embedding and then gets the formatted text of the nearest neighbour"
    embedding = model.encode(text)
    results = database.search(embedding, NUM_RESULTS)

    context = "Context:\n\n" + "\n\n".join(result[0] for result in results)

    return context

In [32]:
text = "Who is Rutherford?"
context = get_context(text, db)

print(context)

Context:

Question: When did Rutherford introduce the new name for the Society?
Answer: July 26, 1931

Question: What biblical passage was the name Jehovah's witnesses based on?
Answer: Isaiah 43:10

Question: What system did Rutherford eliminate in 1932?
Answer: locally elected elders
