<a href="https://colab.research.google.com/github/DAlkemade/bert-for-fever/blob/master/L101_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive so that the files under /content/drive are reachable.
from google.colab import drive
drive.mount('/content/drive')
# IPython shell escape: install the packages listed in the requirements file
# stored on the mounted Drive.
!pip install -r "/content/drive/My Drive/Overig/requirements.txt"

In [0]:
import os

In [0]:
# Directory on the mounted Drive that holds every input and output file.
WORK_DIR = '/content/drive/My Drive/Overig'

# Paths handed to the argument parser further down.
db_file = os.path.join(WORK_DIR, 'fever.db')
index_file = os.path.join(
    WORK_DIR, 'fever-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz'
)
in_file = os.path.join(WORK_DIR, 'test.jsonl')
out_file = os.path.join(WORK_DIR, 'test_baseline_pages.sentences.p5.s5.jsonl')

# Placeholder that parse_doc stores for lines without usable sentence text.
EMPTY_TOKEN = 'EMPTY'
# When True, only the first 10 claims of the input file are processed.
TESTING = False
# When True, process_line also retrieves sentences; otherwise pages only.
PREDICT_SENTENCES = False
# Passed as --max-page and --max-sent respectively.
# NOTE(review): the output file name says "p5" while MAX_DOCS is 50 - confirm.
MAX_DOCS = 50
MAX_SENTS = 5

# Echo the resolved paths, one per line.
print(db_file, index_file, in_file, out_file, sep='\n')

In [0]:
from drqa.retriever import DocDB, utils


# Copied FeverDocDB from Papelo repo
class FEVERDocumentDatabase(DocDB):
    """DocDB subclass with extra queries on the ``documents`` table.

    Both queries use ``self.connection``, which is not set in this class and is
    presumably provided by ``DocDB`` - verify against the drqa implementation.
    """

    def __init__(self, path=None):
        super().__init__(path)

    def get_doc_lines(self, doc_id):
        """Fetch the ``lines`` column of the document with id ``doc_id``.

        The id is passed through ``utils.normalize`` before the lookup.

        Returns:
            The stored ``lines`` value, or None when no row matches.
        """
        cursor = self.connection.cursor()
        try:
            cursor.execute(
                "SELECT lines FROM documents WHERE id = ?",
                (utils.normalize(doc_id),)
            )
            result = cursor.fetchone()
        finally:
            # Release the cursor even when the query raises.
            cursor.close()
        return None if result is None else result[0]

    def get_non_empty_doc_ids(self):
        """Fetch the ids of all documents whose ``text`` is not blank."""
        cursor = self.connection.cursor()
        try:
            cursor.execute("SELECT id FROM documents WHERE length(trim(text)) > 0")
            return [row[0] for row in cursor.fetchall()]
        finally:
            # Release the cursor even when the query raises.
            cursor.close()

In [0]:
from allennlp.common import Registrable
from fever.reader.document_database import FEVERDocumentDatabase


class RetrievalMethod(Registrable):
    """Registrable base class for retrieval methods that read from a document database."""

    def __init__(self, database: FEVERDocumentDatabase):
        # Kept on the instance so that subclasses can query it.
        self.database = database

    def get_sentences_for_claim(self, claim_text, include_text=False):
        """Placeholder hook: the base implementation does nothing and returns None."""
        return None

In [0]:
def parse_doc(doc_raw, empty_token=None):
    """
    Parse a raw document into a list of sentences, with the index in the list
    corresponding to the line index in the data entries.

    Each line is split on tabs and its second field is taken as the sentence.
    Lines whose second field is missing or at most one character long are
    replaced by a placeholder, so list positions stay aligned with line numbers.

    Args:
        doc_raw: Either the raw document text (lines separated by newlines) or
            an iterable of already separated lines. None yields an empty list.
        empty_token: Placeholder for unusable lines. Defaults to the
            module-level EMPTY_TOKEN.

    Returns:
        A list with one entry (sentence or placeholder) per input line.
    """
    if doc_raw is None:
        return []
    placeholder = EMPTY_TOKEN if empty_token is None else empty_token
    if isinstance(doc_raw, str):
        # Iterating a str yields single characters, not lines, so split first.
        doc_raw = doc_raw.split("\n")
    parsed = []
    for line in doc_raw:
        fields = line.split("\t")
        # TODO: THIS MIGHT DROP PARTS OF SENTENCES AFTER A TAB
        if len(fields) > 1 and len(fields[1]) > 1:
            parsed.append(fields[1])
        else:
            # Covers empty sentences as well as lines with no tab at all.
            parsed.append(placeholder)
    return parsed

In [0]:
import math
from drqa import retriever
from drqascripts.retriever.build_tfidf_lines import OnlineTfidfDocRanker

@RetrievalMethod.register("top_docs")
class TopNDocsTopNSents(RetrievalMethod):
    """Retrieve the top ``n_docs`` pages for a claim, then the top ``n_sents``
    sentences among the lines of those pages.

    Pages come from the ranker returned by ``retriever.get_class('tfidf')``;
    sentences are ranked per claim with ``OnlineTfidfDocRanker``.
    """

    class RankArgs:
        """Settings object handed to ``OnlineTfidfDocRanker``."""

        def __init__(self):
            self.ngram = 2
            self.hash_size = 2 ** 24
            self.tokenizer = "simple"
            self.num_workers = None

    def __init__(self, database, index, n_docs, n_sents):
        """Load the TF-IDF index at ``index`` and store the retrieval limits.

        Args:
            database: Document database, stored by the base class as ``self.database``.
            index: Path passed as ``tfidf_path`` to the TF-IDF ranker.
            n_docs: Number of pages to retrieve per claim.
            n_sents: Number of sentences to retrieve per claim.
        """
        super().__init__(database)
        self.n_docs = n_docs
        self.n_sents = n_sents
        print("Retrieve tfidf indices")
        self.ranker = retriever.get_class('tfidf')(tfidf_path=index)
        print("Retrieved tfidf indices")
        self.onlineranker_args = self.RankArgs()

    def get_docs_for_claim(self, claim_text):
        """Return a one-shot iterator of ``(doc_name, doc_score)`` pairs for the claim."""
        doc_names, doc_scores = self.ranker.closest_docs(claim_text, self.n_docs)
        return zip(doc_names, doc_scores)

    def tf_idf_sim(self, claim, lines, freqs=None):
        """Rank candidate lines against the claim and return the best ones.

        Args:
            claim: Claim text used as the query.
            lines: List of dicts, each with at least a ``"sentence"`` key.
            freqs: Forwarded unchanged to ``OnlineTfidfDocRanker``.

        Returns:
            The selected dicts from ``lines`` in the order returned by the
            ranker, each given a ``"score"`` key in place. An empty list when
            there are no candidate lines.
        """
        if not lines:
            # Nothing to rank; avoid building a ranker over an empty corpus.
            return []
        sentences = [line["sentence"] for line in lines]
        tfidf = OnlineTfidfDocRanker(self.onlineranker_args, sentences, freqs)
        line_ids, scores = tfidf.closest_docs(claim, self.n_sents)
        ret_lines = []
        for line_id, score in zip(line_ids, scores):
            entry = lines[line_id]
            entry["score"] = score
            ret_lines.append(entry)
        return ret_lines

    def get_only_docs_for_claim(self, claim_text):
        """Return up to ``n_docs`` page names for the claim, highest score first."""
        ranked = sorted(
            self.get_docs_for_claim(claim_text),
            key=lambda elem: elem[1],
            reverse=True,
        )
        return [name for name, _score in ranked[:self.n_docs]]

    def get_sentences_for_claim(self, claim_text, include_text=False):
        """Return the best-scoring sentences from the top pages for the claim.

        Args:
            claim_text: Claim used both for page and sentence retrieval.
            include_text: When True, return the full dicts (``sentence``,
                ``page``, ``line_on_page``, ``score``); otherwise return
                ``(page, line_on_page)`` tuples.
        """
        lines = []
        # Shares the page ranking with get_only_docs_for_claim instead of repeating it.
        for page in self.get_only_docs_for_claim(claim_text):
            sentences = parse_doc(self.database.get_doc_lines(page))
            lines.extend(
                {"sentence": sentence, "page": page, "line_on_page": idx}
                for idx, sentence in enumerate(sentences)
            )
        scored = self.tf_idf_sim(claim_text, lines)

        if include_text:
            return scored

        return [(s["page"], s["line_on_page"]) for s in scored]

    


In [0]:
import argparse
import json
from multiprocessing.pool import ThreadPool
from tqdm import tqdm

import multiprocessing

# CPU count as reported by multiprocessing.cpu_count().
CORES = multiprocessing.cpu_count()


def process_line(method, line, args, predict_sentences=None):
    """Add retrieval predictions to one claim record and return it.

    Args:
        method: Retrieval object providing ``get_sentences_for_claim`` and
            ``get_only_docs_for_claim``.
        line: Claim record (dict) with a ``"claim"`` key; modified in place.
        args: Unused; kept so existing callers keep working.
        predict_sentences: When True, retrieve sentences and derive the pages
            from them; when False, retrieve pages only. Defaults to the
            module-level PREDICT_SENTENCES flag.

    Returns:
        The same ``line`` dict with ``"predicted_pages"`` set and, in sentence
        mode, ``"predicted_sentences"`` as well.
    """
    if predict_sentences is None:
        predict_sentences = PREDICT_SENTENCES
    if predict_sentences:
        sents = method.get_sentences_for_claim(line["claim"])
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # page list is deterministic and follows the sentence ranking.
        line["predicted_pages"] = list(dict.fromkeys(sent[0] for sent in sents))
        line["predicted_sentences"] = sents
    else:
        line["predicted_pages"] = list(method.get_only_docs_for_claim(line["claim"]))
    return line


def str2bool(v):
    """Convert a command-line string into a bool (case-insensitive).

    Accepts yes/true/t/y/1 as True and no/false/f/n/0 as False.

    Raises:
        argparse.ArgumentTypeError: If the text is none of the accepted spellings.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def get_map_function(parallel, pool=None):
    """Pick the mapping function used to process claims.

    Args:
        parallel: When falsy, the builtin ``map`` is returned.
        pool: Object with an ``imap`` method to use when ``parallel`` is
            truthy. When omitted, the module-level name ``p`` is used, which
            must already be bound by the caller (as the processing cell does
            with its ThreadPool).

    Returns:
        ``pool.imap`` when ``parallel`` is truthy, otherwise ``map``.
    """
    if not parallel:
        return map
    if pool is None:
        # Backward-compatible fallback to the module-level pool.
        pool = p
    return pool.imap


if __name__ == "__main__":
    # Build the command-line interface of the retrieval script. In the notebook
    # the arguments are not taken from sys.argv but from the configuration
    # variables (db_file, index_file, in_file, out_file, MAX_DOCS, MAX_SENTS).
    parser = argparse.ArgumentParser()
    parser.add_argument('--database', type=str, help='/path/to/saved/db.db')
    parser.add_argument('--index', type=str, help='/path/to/saved/tfidf-index.npz')
    parser.add_argument('--in-file', type=str, help='/path/to/claims.jsonl')
    parser.add_argument('--out-file', type=str, help='/path/to/predictions.jsonl')
    parser.add_argument('--max-page', type=int, help='number of pages to retrieve per claim')
    parser.add_argument('--max-sent', type=int, help='number of sentences to retrieve per claim')
    # Parsed but not read by any of the visible cells.
    parser.add_argument('--cuda-device', type=int, default=-1)
    parser.add_argument('--parallel', type=str2bool, default=True)
    parser.add_argument('--threads', type=int, default=None)
    # Example of an equivalent command line; kept for reference, never parsed.
    sequence = '--database /content/drive/My Drive/Overig/fever.db --index /content/drive/My Drive/Overig/fever-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz --in-file /content/drive/My Drive/Overig/dev.jsonl --out-file /content/drive/My Drive/Overig/dev.sentences.p5.s5.jsonl --max-page 5 --max-sent 5'
    # Passing a list keeps paths that contain spaces as single arguments.
    sequence_list = [
        '--database', db_file,
        '--index', index_file,
        '--in-file', in_file,
        '--out-file', out_file,
        '--max-page', str(MAX_DOCS),
        '--max-sent', str(MAX_SENTS),
    ]
    print(sequence_list)
    args = parser.parse_args(sequence_list)

    

In [0]:
# Open the document database given by --database.
print('Create database object')
database = FEVERDocumentDatabase(args.database)

# Build the retriever; its constructor loads the TF-IDF index given by --index.
print("Create TopNDocsTopNSents object")
method = TopNDocsTopNSents(
    database=database,
    index=args.index,
    n_docs=args.max_page,
    n_sents=args.max_sent,
)

In [0]:
# State for resuming a run: predictions that already exist and the claim ids
# they cover. The processing cell writes old_preds back to the output file and
# skips every claim whose id is in completed_claim_ids. Both stay empty unless
# the loader below is uncommented.
old_preds = []
completed_claim_ids = []

# Disabled loader: reads an earlier predictions file and records its claim ids.
# with open('/content/drive/My Drive/Overig/train_baseline_pages.sentences.p5.s5.jsonl',"r") as prev_prediction_file:

#     for line in prev_prediction_file:
#         old_preds.append(json.loads(line))
#     
#     for pred in old_preds:
#         completed_claim_ids.append(pred['id'])

# Notebook display: number of claims that will be skipped.
len(completed_claim_ids)
    

In [0]:
# With --threads unset the pool falls back to the CPU count, so report that
# number instead of printing "None".
print("Using {} cores".format(args.threads if args.threads is not None else CORES))
print("Start processing")

In [0]:
# Read all claims, skip the ones that already have predictions, and write one
# JSON line per processed claim to the output file.
processed = dict()
# A set makes the "already done" check O(1) per claim.
completed_ids = set(completed_claim_ids)
# The handles get their own names so that they do not overwrite the in_file and
# out_file path strings that the argument-parsing cell reads.
with open(args.in_file, "r", encoding="utf-8") as claims_fh, \
        open(args.out_file, "w+", encoding="utf-8") as preds_fh:
    # Blank lines are skipped; json.loads would raise on them.
    lines = [json.loads(raw) for raw in claims_fh if raw.strip()]

    if TESTING:
        lines = lines[:10]
    uncompleted_lines = [claim for claim in lines if claim['id'] not in completed_ids]
    # The output file was truncated on open, so carry earlier predictions over first.
    for pred in old_preds:
        preds_fh.write(json.dumps(pred) + "\n")

    # The pool must be bound to the module-level name `p`: get_map_function
    # reads it from there when parallel processing is requested.
    with ThreadPool(args.threads) as p:
        mapper = get_map_function(args.parallel)
        results = mapper(lambda claim: process_line(method, claim, args), uncompleted_lines)
        for result in tqdm(results, total=len(uncompleted_lines)):
            preds_fh.write(json.dumps(result) + "\n")
