In [1]:
import json
import faiss
import random
import numpy as np

## Loading and inspecting the data

In [3]:
# Load dev-v2.0.json. JSON text is UTF-8 (RFC 8259), so decode explicitly instead of
# relying on the platform's default locale encoding (e.g. cp1252 on Windows), which
# would mangle or reject non-ASCII characters.
with open('../dev-v2.0.json', encoding='utf-8') as f:
    data = json.load(f)

In [3]:
# Layout of the loaded file:
#   data['version']             -> dataset version
#   data['data']                -> list of topics
#     topic['paragraphs']       -> list of paragraphs belonging to that topic
#       paragraph['qas']        -> list of question/answer records
#
# A single question is therefore reached via
#   data['data'][i]['paragraphs'][j]['qas'][k]['question']
first_qa_record = data['data'][0]['paragraphs'][0]['qas'][0]
print(first_qa_record)

{'question': 'In what country is Normandy located?', 'id': '56ddde6b9a695914005b9628', 'answers': [{'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}, {'text': 'France', 'answer_start': 159}], 'is_impossible': False}


In [26]:
# Render the first QA record in the "Question: <question>" / "Answer: <answer>"
# two-line layout that is later stored in the database.
sample_qa = data['data'][0]['paragraphs'][0]['qas'][0]
question_text = sample_qa['question']
answer_text = sample_qa['answers'][0]['text']
text = 'Question: ' + question_text + '\nAnswer: ' + answer_text
print(text)

Question: In what country is Normandy located?
Answer: France


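## Creating a FAISS index for text embeddings

The embeddings below are random placeholders (`random_embedding` ignores its input), so search results are not yet semantically meaningful; swap in a real text-embedding model to get relevant neighbours.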

In [5]:
EMBEDDING_DIM = 16
NUM_RESULTS = 3
SEED = 42

# Dedicated, seeded RNG: "Restart & Run All" reproduces the same embeddings (and hence
# the same search results) without reseeding or consuming the global `random` state.
embedding_rng = random.Random(SEED)

def random_embedding(text, dim=EMBEDDING_DIM):
    """Return a `dim`-dimensional list of uniform random floats in [0, 1).

    Placeholder embedder: `text` is ignored, so the vectors carry no meaning and
    nearest-neighbour results are arbitrary.

    Args:
        text: The text to embed (currently unused).
        dim: Length of the returned vector; defaults to EMBEDDING_DIM.

    Returns:
        list[float] of length `dim`, drawn from `embedding_rng`.
    """
    return [embedding_rng.random() for _ in range(dim)]

In [6]:
class EmbeddingDatabase:
    """Database holding texts in a Python list and their embeddings in a faiss index.

    The i-th added text is stored in the index under id ``i``, so an id returned by a
    search indexes directly into ``self.texts``.
    """

    def __init__(self, EMBEDDING_DIM):
        """Create an empty database for embeddings of length ``EMBEDDING_DIM``."""
        self.texts = []
        # Exact (brute-force) L2 index, wrapped so vectors can be added with explicit ids.
        self.index = faiss.IndexFlatL2(EMBEDDING_DIM)
        self.index_text = faiss.IndexIDMap(self.index)

    def add(self, text, embedding):
        """Add ``text`` together with its ``embedding`` (a sequence of floats)."""
        new_id = len(self.texts)
        # faiss operates on float32 vectors and int64 ids; convert explicitly rather
        # than relying on implicit conversion in the Python wrapper.
        vectors = np.asarray([embedding], dtype=np.float32)
        ids = np.asarray([new_id], dtype=np.int64)
        # Index first, then record the text: if faiss rejects the vector (e.g. wrong
        # dimension) the text list and the index stay in sync.
        self.index_text.add_with_ids(vectors, ids)
        self.texts.append(text)

    def search(self, query, k=5):
        """Return up to ``k`` ``(text, distance)`` pairs nearest to ``query``.

        Distances are as reported by the faiss index. Fewer than ``k`` pairs are
        returned when the database holds fewer than ``k`` entries.
        """
        distances, ids = self.index_text.search(np.asarray([query], dtype=np.float32), k)
        # faiss pads missing results with id -1; without this filter self.texts[-1]
        # would silently return the last stored text as a bogus hit.
        return [
            (self.texts[i], distances[0][j])
            for j, i in enumerate(ids[0])
            if i != -1
        ]
    

In [7]:
def add_question_answer(question_dict, database, embed=None):
    """Format one QA record as text, embed it, and add it to ``database``.

    The stored text is "Question: <question>" and "Answer: <answer>" on two lines.
    Unanswerable records get the answer "The answer is impossible".

    Args:
        question_dict: QA record with a 'question' key, an 'answers' list of
            {'text': ...} dicts, and optionally an 'is_impossible' flag.
        database: Object with an ``add(text, embedding)`` method (e.g. EmbeddingDatabase).
        embed: Callable mapping text to an embedding; defaults to random_embedding.
    """
    if embed is None:
        embed = random_embedding
    question = question_dict['question']
    answers = question_dict.get('answers') or []
    # Records missing 'is_impossible', or flagged answerable but with no answers, are
    # treated as unanswerable instead of raising KeyError / IndexError.
    if question_dict.get('is_impossible', False) or not answers:
        answer = "The answer is impossible"
    else:
        answer = answers[0]['text']

    text = "Question: " + question + '\nAnswer: ' + answer
    database.add(text, embed(text))

In [8]:
# Build the database from every question/answer pair in every paragraph of every topic
db = EmbeddingDatabase(EMBEDDING_DIM)

for topic in data['data']:
    for paragraph in topic['paragraphs']:
        for question_dict in paragraph['qas']:
            add_question_answer(question_dict, db)

# Sanity check: exactly one stored text and one indexed vector per QA record
expected_count = sum(len(p['qas']) for t in data['data'] for p in t['paragraphs'])
assert len(db.texts) == expected_count, (len(db.texts), expected_count)
assert db.index_text.ntotal == expected_count, (db.index_text.ntotal, expected_count)

In [9]:
def get_context(text, database, k=None, embed=None):
    """Embed ``text`` and return its ``k`` nearest stored texts as one context string.

    Args:
        text: Query text.
        database: Object whose ``search(query, k)`` returns sequences whose first
            element is the stored text (e.g. EmbeddingDatabase).
        k: Number of neighbours to include; defaults to NUM_RESULTS.
        embed: Callable mapping text to an embedding; defaults to random_embedding.

    Returns:
        "Context:" followed by a blank line, then the retrieved texts separated by
        blank lines ("Context:\\n\\n" alone when nothing is found).
    """
    if k is None:
        k = NUM_RESULTS
    if embed is None:
        embed = random_embedding
    results = database.search(embed(text), k)
    return "Context:\n\n" + "\n\n".join(result[0] for result in results)

In [11]:
# Query the database with a sample QA-formatted string. random_embedding ignores its
# input, so the retrieved pairs are arbitrary until a real embedder is plugged in.
text = (
    "Question: What is the name of the repository? \n"
    "Answer: The name of the repository is squad"
)
context = get_context(text, db)
print(context)

Context:

Question: What time frame doesn't the Seven Years War cover?
Answer: The answer is impossible

Question: What is R-OOC-R?
Answer: The answer is impossible

Question: What percentage of married couples had children living with them?
Answer: The answer is impossible
