# Import Libraries and Read Datasets

In [13]:
!pip install langchain
!pip install sentence-transformers
!pip install faiss-cpu
!pip install langchainhub
!pip install tqdm
!pip install --upgrade pip
!pip install farm-haystack[colab,inference]#==1.22

zsh:1: no matches found: farm-haystack[colab,inference]#==1.22


In [20]:
import string
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

def retrieve_wiki_headers_and_paragraphs(context, langchain=False):
  data = context.split("\n\n")
  current_header = "General"
  results = []

  for part in data:
    if part[:-1] not in string.punctuation and len(part.split()) < 10:
      current_header = part
    else:
      results.append((current_header, part))

  if not langchain:
    return results
  else:
    return [item[0] + " - " + item[1] for item in results]

In [21]:
def rag_get_context(question, context, log=False):
    paragraphs = retrieve_wiki_headers_and_paragraphs(context, langchain=True)
    vectorstore = FAISS.from_texts(texts=paragraphs, embedding=HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"))
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 1}, return_parents=False)
    par = retriever.get_relevant_documents(question)

    return par[0].page_content

In [22]:
import json
import os
def convert_to_json(input, output_filename):
# Convert the evaluation set to the desired format
    data = []
    for item in input:
        answer = {
            "Aliases": item["answer"]["aliases"],
            "MatchedWikiEntityName": item["answer"]["matched_wiki_entity_name"],
            "NormalizedAliases": item["answer"]["normalized_aliases"],
            "NormalizedMatchedWikiEntityName": item["answer"]["normalized_matched_wiki_entity_name"],
            "NormalizedValue": item["answer"]["normalized_value"],
            "Type": item["answer"]["type"],
            "Value": item["answer"]["value"],
        }
        entity_pages = [
            {
                "DocSource": item["entity_pages"]["doc_source"][index],
                "Filename": item["entity_pages"]["filename"][index],
                "Title": item["entity_pages"]["title"][index],
            }
            for index in range(len(item["entity_pages"]["filename"]))
        ]
        question = item["question"]
        question_id = item["question_id"]
        question_source = item["question_source"]
        search_results = []
        data_item = {
            "Answer": answer,
            "EntityPages": entity_pages,
            "Question": question,
            "QuestionId": question_id,
            "QuestionSource": question_source,
            "SearchResults": search_results,
        }
        data.append(data_item)
    
    output = {
        "Data": data,
        "Domain": "Wikipedia",
        "VerifiedEval": False,
        "Version": 1.0,
    }
    
    # Write the output to a JSON file
    if not os.path.exists("triviaqa/sets"):
        os.makedirs("triviaqa/sets")
    
    with open(f"triviaqa/sets/{output_filename}.json", "w") as f:
        json.dump(output, f)
        
    return output

In [31]:
# # Script based on https://github.com/mandarjoshi90/triviaqa/blob/master/utils/convert_to_squad_format.py
# # We include functions that are modified from https://github.com/mandarjoshi90/triviaqa/tree/master/utils
# # cite: https://github.com/mandarjoshi90/triviaqa/

import os, re, json, nltk
from tqdm import tqdm
from langchain.document_loaders.telegram import text_to_docs


def add_triple_data(datum, page, domain):
    qad = {'Source': domain}
    for key in ['QuestionId', 'Question', 'Answer']:
        qad[key] = datum[key]
    for key in page:
        qad[key] = page[key]
    return qad


def get_qad_triples(data):
    qad_triples = []
    for datum in data['Data']:
        for key in ['EntityPages', 'SearchResults']:
            for page in datum.get(key, []):
                qad = add_triple_data(datum, page, key)
                qad_triples.append(qad)
    return qad_triples

# from utils.utils import get_file_contents

def get_file_contents(filename, encoding='utf-8'):
    with open(filename, encoding=encoding) as f:
        content = f.read()
    return content

# from utils.dataset_utils import read_triviaqa_data, get_question_doc_string

def read_triviaqa_data(qajson):
    data = read_json(qajson)
    # read only documents and questions that are a part of clean data set
    if data['VerifiedEval']:
        clean_data = []
        for datum in data['Data']:
            if datum['QuestionPartOfVerifiedEval']:
                if data['Domain'] == 'Web':
                    datum = read_clean_part(datum)
                clean_data.append(datum)
        data['Data'] = clean_data
    return data

def get_question_doc_string(qid, doc_name):
    return '{}--{}'.format(qid, doc_name)

def read_clean_part(datum):
    for key in ['EntityPages', 'SearchResults']:
        new_page_list = []
        for page in datum.get(key, []):
            if page['DocPartOfVerifiedEval']:
                new_page_list.append(page)
        datum[key] = new_page_list
    assert len(datum['EntityPages']) + len(datum['SearchResults']) > 0
    return datum

def read_json(filename, encoding='utf-8'):
    contents = get_file_contents(filename, encoding=encoding)
    return json.loads(contents)

def select_relevant_portion(text):
    paras = text.split('\n')
    selected = []
    done = False
    for para in paras:
        sents = sent_tokenize.tokenize(para)
        for sent in sents:
            words = nltk.word_tokenize(sent)
            for word in words:
                selected.append(word)
                if len(selected) >= 800:
                    done = True
                    break
            if done:
                break
        if done:
            break
        selected.append('\n')
    st = ' '.join(selected).strip()
    return st


#-------------------------------------------------------



def answer_index_in_document(answer, document):
    answer_list = answer['Aliases'] + answer['NormalizedAliases']
    for answer_string_in_doc in answer_list:
        index = document.find(answer_string_in_doc)
        if index != -1:
            return answer_string_in_doc, index
    return answer['NormalizedValue'], -1



def triviaqa_to_squad_format(triviaqa_file, data_dir, output_file):
    processed_question_ids = set()
    triviaqa_json = read_triviaqa_data(triviaqa_file)
    qad_triples = get_qad_triples(triviaqa_json)

    def generate_data():
        for triviaqa_example in tqdm(qad_triples, desc="Processing QAD Triples"):
            question_text = triviaqa_example['Question']
            if triviaqa_example['QuestionId'] in processed_question_ids:
                continue
            processed_question_ids.add(triviaqa_example['QuestionId'])
            
            # Remove text between square brackets
            text = re.sub(r'\[.*?\]', '',get_file_contents(os.path.join(data_dir, triviaqa_example['Filename']), encoding='utf-8'))
            # Remove links with http/https
            text = re.sub(r'https?://\S+', '', text)
            
            context_str = select_relevant_portion(text)
            
            # context = rag_get_context(question_text, text, log=True)
            # if context:
            #     context_str = f'{context[0][0]}\n{context[0][1]}'
            #     print(f"PARAGRAPH: {context_str}")
            # else:
            #     context_str=""
            #     print("No Context")
            
            para = {'context': context_str, 'qas': [{'question': question_text, 'answers': []}]}
            data = {'paragraphs': [para]}
            qa = para['qas'][0]
            qa['id'] = get_question_doc_string(triviaqa_example['QuestionId'], triviaqa_example['Filename'])
            qa['is_impossible'] = True
            ans_string, index = answer_index_in_document(triviaqa_example['Answer'], context_str)

            if index != -1:
                qa['answers'].append({'text': ans_string, 'answer_start': index})
                qa['is_impossible'] = False
        
            
            yield data

    triviaqa_as_squad = {'data': list(generate_data()), 'version': '2.0'}

    with open(output_file, 'w', encoding='utf-8') as outfile:
        json.dump(triviaqa_as_squad, outfile, indent=2, sort_keys=True, ensure_ascii=False)

# Clear unused variables
    del triviaqa_json
    del triviaqa_as_squad
    del qad_triples

sent_tokenize = nltk.data.load('tokenizers/punkt/english.pickle')

Get the trivia QA dataset

In [12]:
!cd BERT/
!brew install wget
!wget "https://nlp.cs.washington.edu/triviaqa/data/triviaqa-rc.tar.gz"

Running `brew update --auto-update`...
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/portable-ruby/portable-ruby/blobs/sha256:d783cbeb6e6ef0d71c0b442317b54554370decd6fac66bf2d4938c07a63f67be[0m
######################################################################### 100.0%                2.3%
[34m==>[0m [1mPouring portable-ruby-3.1.4.arm64_big_sur.bottle.tar.gz[0m
[34m==>[0m [1mAuto-updated Homebrew![0m
Updated 3 taps (homebrew/services, homebrew/core and homebrew/cask).
[34m==>[0m [1mNew Formulae[0m
abi3audit                                pmix
action-validator                         postgresql@16
ain                                      presenterm
ali                                      pter
amass                                    python-abseil
ansible@8                                python-anytree
apkleaks                                 python-argcomplete
appstream                                python-asn1crypto
argc                     

In [18]:
%%bash
mkdir ./triviaqa
mkdir ./SQuadformatted
tar -C ./triviaqa -zxf triviaqa-rc.tar.gz

mkdir: ./triviaqa: File exists
mkdir: ./SQuadformatted: File exists


In [40]:
from datasets import load_dataset

trivia_qa_wikipedia = load_dataset('trivia_qa', name="rc.wikipedia")

training_split = trivia_qa_wikipedia["train"].train_test_split(shuffle=False, train_size=7900)
validation = convert_to_json(training_split["train"], output_filename="validation")
train = convert_to_json(training_split["test"], output_filename="training")

Convert

In [38]:
triviaqa_to_squad_format(triviaqa_file="triviaqa/sets/training.json", data_dir='triviaqa/evidence/wikipedia', output_file='SQuadformatted/triviaqa_train.json') # 7901 and following

Processing QAD Triples:   0%|          | 243/97386 [00:00<03:02, 532.30it/s]


KeyboardInterrupt: 

In [41]:
triviaqa_to_squad_format('triviaqa/sets/validation.json', 'triviaqa/evidence/wikipedia', 'SQuadformatted/triviaqa_validation.json')

Processing QAD Triples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 13261/13261 [00:23<00:00, 573.33it/s]


### Training

In [46]:
from haystack.nodes import FARMReader

reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", use_gpu=True)
epochs = 1

In [47]:
import logging
from tqdm import tqdm

logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.WARNING)
logging.getLogger("haystack").setLevel(logging.WARNING)

reader.train(data_dir="SQuadformatted", train_filename="triviaqa_train.json", use_gpu=True, n_epochs=epochs, save_dir="my_model/")
#reader.train(data_dir="SQuadformatted", train_filename="triviaqa_train.json", dev_filename="triviaqa_validation.json", checkpoint_every=500, checkpoint_root_dir="checkpoints", evaluate_every=5000, use_gpu=False, n_epochs=epochs, save_dir="my_model/")# "{epoch}"  + "/")     


Preprocessing dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 106/106 [01:07<00:00,  1.58 Dicts/s]
Train epoch 0/0 (Cur. train loss: 2.7010):   0%|          | 62/33859 [00:38<5:49:16,  1.61it/s]


KeyboardInterrupt: 

In [None]:
question=f"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use one to two words or numbers maximum and keep the answer concise. question: {question}"

## context-tools

In [49]:
from datasets import load_dataset
trivia_qa_wikipedia = load_dataset('trivia_qa', name="rc.wikipedia")
test = trivia_qa_wikipedia["validation"]

In [50]:
def build_context(item):
    texts = []
    for text in item["entity_pages"]["wiki_context"]:
      texts.append(text)

    context = " ".join(texts)
    return context

## Prediction

In [53]:
from haystack.nodes import FARMReader
model_name="distilbert-base-uncased-distilled-squad"
reader = FARMReader(model_name_or_path=f"{model_name}", use_gpu=False)

In [54]:
import logging
import os, json
from tqdm import tqdm
from haystack.schema import Document
from haystack.nodes import BM25Retriever

predictions = {}
count=0
failed= []

logging.basicConfig(format="%(levelname)s - %(name)s -  %(message)s", level=logging.ERROR)

for entry in tqdm(test, desc="Processing Predictions"):
    try:
        context = build_context(entry)
        rag_context = rag_get_context(entry["question"], context, log=True)
        #print(rag_context)
        prediction = reader.predict(query=entry["question"], documents=[Document(rag_context)])
        predictions[entry['question_id']] = prediction['answers'][0].answer
        #print(f"question: {entry['question_id']}, answer: {prediction['answers'][0].answer}")
        #print("_"*25)
    except KeyboardInterrupt as error:
        json_string = json.dumps(predictions)
        if not os.path.exists("predictions"):
            os.makedirs("predictions")
        with open(f"predictions/validation_predictions_{model_name}.json", "w") as f:
            f.write(json_string)
        print("saved")
        raise error
    except Exception as error:
        print(f"Failure for question {entry['question_id']} ({type(error).__name__}: {error})")
        failed.append(entry['question_id'])
        
#print(f"FAILED: {failed}")
json_string = json.dumps(predictions)

if not os.path.exists("predictions"):
        os.makedirs("predictions")

with open(f"predictions/validation_predictions_{model_name}.json", "w") as f:
    f.write(json_string)

Processing Predictions:   0%|          | 0/7993 [00:00<?, ?it/s]
Inferencing Samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 14.21 Batches/s]
Processing Predictions:   0%|          | 1/7993 [00:01<3:12:21,  1.44s/it]
Inferencing Samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 18.57 Batches/s]
Processing Predictions:   0%|          | 2/7993 [00:03<3:48:43,  1.72s/it]
Inferencing Samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 19.22 Batches/s]
Processing Predictions:   0%|          | 3/7993 [00:03<2:23:24,  1.08s/it]
Inferencing Samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 19.64 Batches/s]
Processing Predictions:   0%|          | 4/7993 [00:04<2:13:24,  1.00s/it]
Inferencing Samples: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00, 20.12 Batches/s]
Processing Predictions:   0%|          | 5/7993 [00:05<2:14:20,  1.01s/it][E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E th

saved


KeyboardInterrupt: 

## Evaluation

In [21]:
import sys
sys.path.append("../")

In [23]:
from evaluation.triviaqa_evaluation import evaluate_triviaqa
from utils.dataset_utils import *
from utils.utils import read_json

In [24]:
dataset_file = 'triviaqa/sets/evaluation.json'
prediction_file = f'predictions/validation_predictions_{model_name}.json'

expected_version = 1.0
dataset_json = read_triviaqa_data(dataset_file)
if dataset_json['Version'] != expected_version:
    print('Evaluation expects v-{} , but got dataset with v-{}'.format(expected_version,dataset_json['Version']),
          file=sys.stderr)
key_to_ground_truth = get_key_to_ground_truth(dataset_json)
predictions = read_json(prediction_file)
eval_dict = evaluate_triviaqa(key_to_ground_truth, predictions)

em=0: The Sound of Music ['sunset boulevard', 'sunset bulevard', 'west sunset boulevard', 'sunset blvd']
em=0: David Cameron ['henry campbell bannerman', 'sir henry campbell bannerman', 'campbell bannerman']
em=0: her mother ['lauren becall', 'loren bacall', 'lauren becal', 'lauren bacall', 'betty j perske', 'betty perske', 'betty joan perske', 'bacall', 'betty joan perski']
em=0: red ['greenishness', 'color green', '0 255 0', 'green color', 'green', 'rgb 0 255 0', 'avacado color', 'greenest', 'list of terms associated with color green', 'greenness', 'greenishly', 'colour green', 'pastel green', 'green colour']
em=0: Def Leppard ['richard marx', 'richard noel marx']
em=0: The Atlantics ['screaming abdabs', 'megadeaths', 'clive metcalfe', 'pink floyd band', 'pink flowd', 'meggadeaths', 'pi5', 'pink floyd', 'screaming ab dabs', 'grey floyd', 'architectural abdabs', 'pink flod', 'pinkfloyd', 'pik floyd', 'pink floyd sound', 'tea set', 'pink floid', 't set', 'pink floyd trivia', 'notable o

Missed question sfq_18519 will receive score 0.


em=0: zamkowych ['bram stoker s dracula 1992 film', 'bram stoker s dracula', 'bram stoker s dracula film', 'bram stokers dracula', 'dracula 1992 film']
em=0: French ['filerbuster', 'filibusting', 'filibustuster', 'filibuster in senate', 'fillabuster', 'fillibuster', 'talked out', 'filibusters', 'filibustering', 'filibuster', 'filibuster legislative tactic', 'filibustered', 'filabuster']
em=0: the pope ['recent pope', 'timeline of papacy', 'popes by nationality', 'list of roman catholic popes', 'german pope', 'list of german popes', 'list of bishops of rome', 'pope s', 'list of popes', 'list of popes by nationality', 'syrian popes', 'list of roman catholic popes by nationality', 'list of vatican monarchs', 'african popes', 'popes', 'list of syrian popes', 'list of popes of rome', 'list of catholic popes']
em=0: Stockholm syndrome ['stockholm sweden', 'stochholm', 'stockohlm', 'un locode sesto', 'stockolm', 'sthlm', 'science in stockholm', 'sockholm', 'estocolmo', 'weather in stockholm',

In [25]:
print(eval_dict)

{'exact_match': 32.51595145752533, 'f1': 41.96009135655252, 'common': 7992, 'denominator': 7993, 'pred_len': 7992, 'gold_len': 7993}
