# TriviaQA

In [None]:
!git clone https://github.com/aaronkossler/triviaqa.git

## Import Dataset

In [None]:
%%bash

pip install datasets

In [3]:
from datasets import load_dataset

# TriviaQA reading-comprehension configuration whose evidence documents are Wikipedia pages.
DATASET_NAME = 'trivia_qa'
DATASET_CONFIG = "rc.wikipedia"

trivia_qa_wikipedia = load_dataset(DATASET_NAME, name=DATASET_CONFIG)

In [39]:
# Number of leading training questions carved off as the evaluation subset.
# Kept tiny for quick iteration; the larger configuration used 7900.
EVAL_SIZE = 5

# shuffle=False keeps the split deterministic. The first EVAL_SIZE training questions
# are the split's "train" half and become `evaluation`; the remainder ("test" half)
# becomes `train`.
train_split = trivia_qa_wikipedia["train"].train_test_split(shuffle=False, train_size=EVAL_SIZE)
evaluation = train_split["train"]
train = train_split["test"]
test = trivia_qa_wikipedia["validation"]

In [40]:
import json
import os


def to_triviaqa_record(item):
    """Convert one Hugging Face ``trivia_qa`` example into a TriviaQA dataset-file record.

    The HF example uses snake_case answer fields and stores its entity pages as
    parallel per-column lists (``doc_source``, ``filename``, ``title``). The record
    returned here uses CamelCase keys and one dict per entity page. This is the
    layout read back by ``read_triviaqa_data`` in the evaluation section.
    ``SearchResults`` is always empty here.
    """
    hf_answer = item["answer"]
    pages = item["entity_pages"]
    return {
        "Answer": {
            "Aliases": hf_answer["aliases"],
            "MatchedWikiEntityName": hf_answer["matched_wiki_entity_name"],
            "NormalizedAliases": hf_answer["normalized_aliases"],
            "NormalizedMatchedWikiEntityName": hf_answer["normalized_matched_wiki_entity_name"],
            "NormalizedValue": hf_answer["normalized_value"],
            "Type": hf_answer["type"],
            "Value": hf_answer["value"],
        },
        "EntityPages": [
            {
                "DocSource": pages["doc_source"][index],
                "Filename": filename,
                "Title": pages["title"][index],
            }
            for index, filename in enumerate(pages["filename"])
        ],
        "Question": item["question"],
        "QuestionId": item["question_id"],
        "QuestionSource": item["question_source"],
        "SearchResults": [],
    }


# Convert the evaluation subset to the TriviaQA dataset-file format.
output = {
    "Data": [to_triviaqa_record(item) for item in evaluation],
    # NOTE(review): "Web" rather than "Wikipedia" looks deliberate. It presumably makes
    # the evaluation utils key gold answers per "<QuestionId>--<Filename>", matching
    # the per-page prediction keys built later. Confirm in
    # triviaqa/utils/dataset_utils.py before changing.
    "Domain": "Web",
    "VerifiedEval": False,
    "Version": 1.0,
}

# exist_ok makes the cell safe to re-run without a separate existence check.
os.makedirs("triviaqa/sets", exist_ok=True)
with open("triviaqa/sets/evaluation.json", "w", encoding="utf-8") as f:
    json.dump(output, f)

## Question Answering (pretrained reader, no fine-tuning)

In [None]:
%%bash

pip install --upgrade pip
pip install farm-haystack[colab,inference]

In [8]:
import logging

# Root logging stays at WARNING to keep third-party output quiet, while the
# "haystack" logger is lowered to INFO so its own progress messages still show.
LOG_FORMAT = "%(levelname)s - %(name)s -  %(message)s"
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT)

haystack_logger = logging.getLogger("haystack")
haystack_logger.setLevel(logging.INFO)

In [None]:
from haystack.document_stores import InMemoryDocumentStore

# One BM25-enabled in-memory store per (question, entity page), keyed
# "<question_id>--<filename>". Each store holds exactly that page's text, so a
# question is later answered against one page at a time.
documents = {}

for row in evaluation:
    pages = row["entity_pages"]
    qid = row["question_id"]
    for index, filename in enumerate(pages["filename"]):
        document_store = InMemoryDocumentStore(use_bm25=True)
        document_store.write_documents([
            {
                "content": pages["wiki_context"][index],
                "meta": {"question_id": qid, "filename": filename},
            }
        ])
        documents[f"{qid}--{filename}"] = document_store


In [None]:
from haystack.nodes import FARMReader

# Pretrained extractive reader shared by every per-page pipeline; use_gpu=True
# requests GPU inference.
READER_MODEL = "deepset/roberta-base-squad2"
reader = FARMReader(use_gpu=True, model_name_or_path=READER_MODEL)

In [None]:
from haystack.nodes import BM25Retriever
from haystack.pipelines import ExtractiveQAPipeline

# Top answer text per (question, entity page), keyed "<question_id>--<filename>".
# These are the same keys used for the per-page document stores in `documents`.
predictions = {}
for entry in evaluation:
    for filename in entry["entity_pages"]["filename"]:
        key = f"{entry['question_id']}--{filename}"
        retriever = BM25Retriever(document_store=documents[key])
        pipe = ExtractiveQAPipeline(reader, retriever)
        prediction = pipe.run(
            query=entry["question"],
            params={"Retriever": {"top_k": 1}, "Reader": {"top_k": 1}},
        )
        answers = prediction["answers"]
        # The pipeline can return no answers (e.g. if the retriever yields no document).
        # Record an empty prediction rather than aborting the whole run with IndexError.
        predictions[key] = answers[0].answer if answers else ""

In [44]:
# Persist predictions as a flat {"<question_id>--<filename>": answer_text} JSON object.
# The evaluation section reads this file back with read_json.
# exist_ok makes the cell safe to re-run.
os.makedirs("triviaqa/predictions", exist_ok=True)

with open("triviaqa/predictions/evaluation_predictions.json", "w", encoding="utf-8") as f:
    json.dump(predictions, f)

## Evaluation

In [14]:
import sys

# Make the cloned repo root importable. NOTE(review): presumably needed because the
# repo's modules import each other as top-level packages (e.g. `utils`); confirm.
# Guarded so re-running the cell does not append duplicate sys.path entries.
TRIVIAQA_REPO = "./triviaqa"
if TRIVIAQA_REPO not in sys.path:
    sys.path.append(TRIVIAQA_REPO)

In [34]:
# Explicit imports (instead of `import *`) so it is clear where each evaluation helper
# comes from and no unrelated module names leak into the notebook namespace.
from triviaqa.evaluation.triviaqa_evaluation import evaluate_triviaqa
from triviaqa.utils.dataset_utils import get_key_to_ground_truth, read_triviaqa_data
from triviaqa.utils.utils import read_json

In [None]:
dataset_file = 'triviaqa/sets/evaluation.json'
prediction_file = 'triviaqa/predictions/evaluation_predictions.json'

# Load the gold file written earlier and warn (without stopping) if its version
# differs from what this evaluation expects.
expected_version = 1.0
dataset_json = read_triviaqa_data(dataset_file)
found_version = dataset_json['Version']
if found_version != expected_version:
    message = 'Evaluation expects v-{} , but got dataset with v-{}'.format(expected_version, found_version)
    print(message, file=sys.stderr)

# Score the saved predictions against the gold answers keyed the same way.
key_to_ground_truth = get_key_to_ground_truth(dataset_json)
predictions = read_json(prediction_file)
eval_dict = evaluate_triviaqa(key_to_ground_truth, predictions)

In [46]:
# Evaluation summary: exact_match and f1 are reported on a 0-100 scale. common,
# denominator, pred_len and gold_len are prediction/gold key counts.
# Displayed as the cell's last expression (rich display) rather than via print().
eval_dict

{'exact_match': 33.333333333333336, 'f1': 38.888888888888886, 'common': 9, 'denominator': 9, 'pred_len': 9, 'gold_len': 9}
