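# TriviaQA

Extractive question answering on the TriviaQA `rc.wikipedia` dataset, evaluated on its official validation split. The pipeline works in three steps:

1. For each question, a Haystack BM25 retriever selects the best-matching article from that question's own Wikipedia evidence.
2. The `deepset/roberta-base-squad2` reader extracts an answer span from the selected article.
3. The predictions are scored with the TriviaQA evaluation script (`triviaqa/evaluation/triviaqa_evaluation.py`).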

In [None]:
!git clone https://github.com/aaronkossler/triviaqa.git

## Load the dataset

In [None]:
%%bash

pip install datasets

In [None]:
from datasets import load_dataset

# TriviaQA reading-comprehension config whose evidence documents come from
# Wikipedia; this notebook uses its "train" and "validation" splits.
trivia_qa_wikipedia = load_dataset(path="trivia_qa", name="rc.wikipedia")

In [None]:
# Deterministic hold-out scheme (no shuffling):
#   * test       = the official TriviaQA "validation" split
#   * validation = the first 7900 questions of the official "train" split
#   * train      = the remaining "train" questions
# `train_test_split` labels the `train_size` portion "train", which is why that
# portion becomes the validation set here.
train_split = trivia_qa_wikipedia["train"].train_test_split(train_size=7900, shuffle=False)
validation, train = train_split["train"], train_split["test"]
test = trivia_qa_wikipedia["validation"]

In [None]:
import json
import os


def _to_triviaqa_record(item):
    """Convert one Hugging Face ``trivia_qa`` example into a TriviaQA-format record.

    The Hugging Face dataset exposes snake_case fields. The TriviaQA evaluation
    utilities used later (``read_triviaqa_data`` / ``get_key_to_ground_truth``)
    read the original CamelCase JSON layout built here.

    Args:
        item: One ``rc.wikipedia`` example with ``answer``, ``entity_pages``,
            ``question``, ``question_id`` and ``question_source`` fields.

    Returns:
        A dict with ``Answer``, ``EntityPages``, ``Question``, ``QuestionId``,
        ``QuestionSource`` and ``SearchResults`` keys. ``SearchResults`` is always
        empty because only the Wikipedia domain is exported.
    """
    answer = item["answer"]
    pages = item["entity_pages"]
    return {
        "Answer": {
            "Aliases": answer["aliases"],
            "MatchedWikiEntityName": answer["matched_wiki_entity_name"],
            "NormalizedAliases": answer["normalized_aliases"],
            "NormalizedMatchedWikiEntityName": answer["normalized_matched_wiki_entity_name"],
            "NormalizedValue": answer["normalized_value"],
            "Type": answer["type"],
            "Value": answer["value"],
        },
        # entity_pages holds parallel per-field lists; regroup them into one dict per page.
        "EntityPages": [
            {"DocSource": doc_source, "Filename": filename, "Title": title}
            for doc_source, filename, title in zip(
                pages["doc_source"], pages["filename"], pages["title"]
            )
        ],
        "Question": item["question"],
        "QuestionId": item["question_id"],
        "QuestionSource": item["question_source"],
        "SearchResults": [],
    }


# Export the held-out test set (the official "validation" split) in the layout
# that the TriviaQA evaluation script reads back in the Evaluation section.
data = [_to_triviaqa_record(item) for item in test]

output = {
    "Data": data,
    "Domain": "Wikipedia",
    "VerifiedEval": False,
    "Version": 1.0,
}

# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs("triviaqa/sets", exist_ok=True)
with open("triviaqa/sets/evaluation.json", "w", encoding="utf-8") as f:
    json.dump(output, f)

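## Preprocessing

No additional preprocessing is applied. Each question's Wikipedia evidence articles are indexed as-is in the next section.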

## Model Prediction

In [None]:
%%bash

pip install --upgrade pip
pip install farm-haystack[colab,inference]

In [None]:
import logging

# Root logger: only WARNING and above from third-party libraries.
logging.basicConfig(level=logging.WARNING, format="%(levelname)s - %(name)s -  %(message)s")

# Haystack's own loggers may additionally report INFO-level messages.
logging.getLogger("haystack").setLevel("INFO")

In [None]:
from haystack.document_stores import InMemoryDocumentStore

# One BM25-indexed in-memory store per question, keyed by question_id, so each
# question's retriever can only return that question's own Wikipedia evidence.
documents = {}

for row in test:
    document_store = InMemoryDocumentStore(use_bm25=True)

    articles = [
        {
            "content": article,
            "meta": {"question_id": row["question_id"]},
        }
        for article in row["entity_pages"]["wiki_context"]
    ]
    # Write all of a question's articles in one call. In Haystack 1.x, with
    # use_bm25=True every write_documents call rebuilds the store's BM25 index,
    # so per-article writes re-indexed once per article. Questions without
    # articles skip the write entirely, as the original per-article loop did.
    if articles:
        document_store.write_documents(articles)

    documents[row["question_id"]] = document_store

In [None]:
from haystack.nodes import FARMReader

# Extractive reader shared by every per-question pipeline. The per-call
# "Inferencing Samples" progress bars are disabled: with one reader call per test
# question they otherwise flood the cell output (thousands of lines).
reader = FARMReader(
    model_name_or_path="deepset/roberta-base-squad2",
    use_gpu=True,
    progress_bar=False,
)

In [None]:
from haystack.nodes import BM25Retriever
from haystack.pipelines import ExtractiveQAPipeline

# Maps question_id -> predicted answer text.
predictions = {}
for entry in test:
    question_id = entry["question_id"]
    # Every question has its own document store, so its retriever (and hence the
    # pipeline) is rebuilt per question; the reader model is shared.
    retriever = BM25Retriever(document_store=documents[question_id])
    pipe = ExtractiveQAPipeline(reader, retriever)
    # Read only the single best-scoring article and keep only the best span.
    prediction = pipe.run(
        query=entry["question"],
        params={"Retriever": {"top_k": 1}, "Reader": {"top_k": 1}},
    )
    answers = prediction["answers"]
    # If the pipeline produced no answer, record an empty prediction instead of
    # failing with IndexError, so the loop completes and every question is scored.
    predictions[question_id] = answers[0].answer if answers else ""

[1;30;43mDie letzten 5000 Zeilen der Streamingausgabe wurden abgeschnitten.[0m
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  2.25 Batches/s]
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  1.72 Batches/s]
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  1.75 Batches/s]
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  2.26 Batches/s]
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  2.81 Batches/s]
Inferencing Samples: 100%|██████████| 2/2 [00:00<00:00,  2.15 Batches/s]
Inferencing Samples: 100%|██████████| 2/2 [00:01<00:00,  1.50 Batches/s]
Inferencing Samples: 100%|██████████| 2/2 [00:01<00:00,  1.40 Batches/s]
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  2.31 Batches/s]
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  5.13 Batches/s]
Inferencing Samples: 100%|██████████| 4/4 [00:02<00:00,  1.60 Batches/s]
Inferencing Samples: 100%|██████████| 1/1 [00:00<00:00,  2.86 Batches/s]
Inferencing Samples: 100%|██████████| 2/2 [

In [None]:
# Persist predictions as {question_id: answer_text}; the Evaluation section reads
# this file back. exist_ok makes directory creation race-free and re-runnable.
os.makedirs("triviaqa/predictions", exist_ok=True)

with open("triviaqa/predictions/test_predictions.json", "w", encoding="utf-8") as f:
    json.dump(predictions, f)

## Evaluation

In [None]:
import sys

# Put the cloned repo directory itself on the import path. Its modules presumably
# import sibling packages (e.g. `utils`) by top-level name. The guard keeps
# re-runs of this cell from appending duplicate entries.
if "./triviaqa" not in sys.path:
    sys.path.append("./triviaqa")

In [None]:
# Import only the names this notebook uses instead of a wildcard import, so each
# name's origin is explicit.
from triviaqa.evaluation.triviaqa_evaluation import evaluate_triviaqa
from triviaqa.utils.dataset_utils import get_key_to_ground_truth, read_triviaqa_data
from triviaqa.utils.utils import read_json

In [None]:
dataset_file = 'triviaqa/sets/evaluation.json'
prediction_file = 'triviaqa/predictions/test_predictions.json'

# The evaluation set was exported with "Version": 1.0; a mismatch only warns.
expected_version = 1.0
dataset_json = read_triviaqa_data(dataset_file)
if dataset_json['Version'] != expected_version:
    print('Evaluation expects v-{} , but got dataset with v-{}'.format(expected_version, dataset_json['Version']),
          file=sys.stderr)

# Ground-truth answers, keyed to match the prediction keys (question ids).
key_to_ground_truth = get_key_to_ground_truth(dataset_json)
predictions = read_json(prediction_file)

# mute=True suppresses the per-question "em=0: <prediction> <gold aliases>" dump,
# which prints one line per wrong answer. Set it to False for error analysis.
eval_dict = evaluate_triviaqa(key_to_ground_truth, predictions, mute=True)

# Muting may also hide the script's notices about unmatched question ids, so warn
# explicitly when fewer questions were matched than were evaluated.
if eval_dict['common'] != eval_dict['denominator']:
    print('Only {} of {} questions had a matching prediction'.format(eval_dict['common'], eval_dict['denominator']),
          file=sys.stderr)

em=0: Phantom of the Opera ['sunset boulevard', 'sunset bulevard', 'west sunset boulevard', 'sunset blvd']
em=0: Sir Robert Walpole ['henry campbell bannerman', 'sir henry campbell bannerman', 'campbell bannerman']
em=0: Phyllis Hyman ['exiles', 'voluntary exile', 'forced exile', 'banish', 'self exile', 'exile politics and government', 'exile in greek tragedy', 'sent into exile', 'banishment', 'transported for life', 'exile', 'internal exile', 'exile and banishment']
em=0: breast cancer ['aids related cancer', 'sporadic cancer', 'cancer disease', 'malignant tumors', 'cancers', 'carcinophobia', 'cancer', 'cancer diagnosis', 'malignant neoplastic disease', 'malignant neoplasm', 'tumour virus', 'cancer medicine', 'deaths by cancer', 'malignant tumour', 'epithelial cancers', 'solid cancer', 'cancerous', 'borderline cancer', 'invasive cancer', 'anti cancer', 'cancer pathology', 'cancer signs', 'cancer aromatase', 'cancer therapy', 'financial toxicity', 'cancerophobia', 'cancer en cuirasse',

In [None]:
# Headline metrics (exact_match and f1 on a 0-100 scale, plus question counts),
# shown via the cell's rich display rather than print().
eval_dict

{'exact_match': 47.59164268735143, 'f1': 55.40320738832585, 'common': 7993, 'denominator': 7993, 'pred_len': 7993, 'gold_len': 7993}
