In [1]:
import nltk
from nltk.corpus import stopwords
import uuid
import json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel
import traceback
import csv

import tfidf_functions

In [2]:
# Read in an entire story based on document id
def get_document_text(document_id):
    try:
        text=""
        try:
            f=open(f"../narrativeqa/tmp/{document_id}.content", "r", encoding="utf-8")
            text = f.read()
            f.close()
        except:
            f=open(f"../narrativeqa/tmp/{document_id}.content", "r", encoding="ISO-8859-1")
            text = f.read()
            f.close()
        return text
    except Exception as e:
        print(f"Error getting document {document_id}")
        print(f"Exception: {e}")        


In [None]:
#def split_document(text, num_characters):
    

In [42]:
# Split text into passages to serve as documents in tfidf
def split_document_and_tfidf_vectorize(text, num_characters=1500):
    passages = [text[i:i+num_characters] for i in range(0, len(text), num_characters)]

    #passages = text.split('\n\n')

    #passages = list(filter(None, passages))
    #print(len(passages))

    vectorizer = TfidfVectorizer(stop_words=set(stopwords.words("english")))
    try:
        tfidf = vectorizer.fit_transform(passages)
        return passages, tfidf, vectorizer
    except Exception:
        traceback.print_exc()
        print(passages)
        

    
    

In [30]:
class QAPair:
    passages = []
    
    def __init__(self, document_id, question, answer1, answer2):
        self.document_id = document_id
        self.question = question
        self.answer1 = answer1
        self.answer2 = answer2
        self.id = uuid.uuid4()


In [5]:
# Load all question answer pairs for the available documents
#document_id, set, question, answer1, answer2, question_tokenized, answer1_tokenized, answer2_tokenized.
def get_question_answer_pairs():
    document_questions = {}
    with open('../narrativeqa/qaps.csv', newline='') as csvfile:
        rows = csv.DictReader(csvfile, delimiter=',')
        for qpair in rows:
            document_id = qpair['document_id']

            if document_id not in document_questions.keys():
                document_questions[document_id] = []

            document_questions[document_id].append(QAPair(document_id, qpair['question_tokenized'], qpair['answer1'], qpair['answer2']))

    return document_questions



In [7]:
document_questions = get_question_answer_pairs()

In [51]:
# Get the top n passage indices in regards to a query within a document
# Based on cosine simliarity
def get_related_passage_indices(question, vectorizer, tfidf, num_passages_to_return=5):
    q_vec = vectorizer.transform([question])
    cosine_similarities = linear_kernel(q_vec, tfidf).flatten()

    related_docs_indices_a = cosine_similarities.argsort()[:-num_passages_to_return:-1]
    related_docs_indices = []
    for index in related_docs_indices_a:
        if abs(cosine_similarities[index]) > 0:
            related_docs_indices.append(index)
            
    #if len(related_docs_indices) == 0:
    #    print(f'empty question: {question}')
    #    print(f'empty qvec: {q_vec}')
    
    return related_docs_indices

# Get the top n passages in regards to a query
def get_related_passages(passages, related_docs_indices):
    related_passages = []
    for i in related_docs_indices:
        related_passages.append(passages[i])
    
    return related_passages

In [9]:
#TESTING

#q1 = [document_questions["0029bdbe75423337b551e42bb31f9a102785376f"][1].question]
#answer1 = document_questions["0029bdbe75423337b551e42bb31f9a102785376f"][1].answer1
#answer2 = document_questions["0029bdbe75423337b551e42bb31f9a102785376f"][1].answer2
#q_vec = vectorizer.transform(q1)
#cosine_similarities = linear_kernel(q_vec, tfidf).flatten()

#num_passages_to_return = 5
#related_docs_indices_a = cosine_similarities.argsort()[:-num_passages_to_return:-1]
#related_docs_indices = []
#or index in related_docs_indices_a:
#    if cosine_similarities[index] > 0:
#        related_docs_indices.append(index)
        
#print(f'Q: {q1}\n')
#for i in related_docs_indices:
#    print(f'index: {i}')
#    print(passages[i])
#    print()
#print('\nAnswers')
#print(answer1)
#print(answer2)

In [37]:
#   {
# .   title:
# .   document_id:
# .   paragraphs:[
#{                    "qas": [
#                        {
#                            "question": "In what country is Normandy located?",
#                            "id": "56ddde6b9a695914005b9628",
#                            "answers": [
#                                {
#                                    "text": "France",
#                                    "answer_start": 159
#                                },
#                                {
#                                    "text": "France",
#                                    "answer_start": 159
#                                },
#
#                            ],
#                            "is_impossible": false
#                        },}
#                       #context:

def convert_question_pair_to_squad_format(qa_pair):
    data = {
        "qas": [
            {
                "question": qa_pair.question,
                "id": str(qa_pair.id),
                "answers": [
                    {
                        "text": qa_pair.answer1
                    },
                    {
                        "text": qa_pair.answer2
                    }
                ]
            }
        ],
        "context": qa_pair.passages[0]
    }
    
    return data

In [11]:
def get_doc_start(text, start_search):
    doc_start = text.find(start_search)
    if doc_start == -1:
        start_search = "*** START "
        doc_start = text.find(start_search, 0)
        if doc_start == -1:
            start_search = "***START "
            doc_start = text.find(start_search, 0)
            if doc_start == -1:
                start_search = "<pre>"
                doc_start = text.find(start_search)
    return doc_start, start_search

def get_doc_end(text, end_search):
    doc_end = text.rfind(end_search)
    if doc_end == -1:
        end_search = "*** END"
        doc_end = text.rfind(end_search)
        if doc_end == -1:
            end_search = "***END"
            doc_end = text.rfind(end_search)
            if doc_end == -1:
                end_search = "</pre>"
                doc_end = text.rfind(end_search)
    return doc_end, end_search

In [None]:
#                   print('-------11111--------------')
#                        print('\ntext length 0')
#                        print(f'docid: {document_id}')
#                        print(f'labeled start: {x_split[8]}')
#                        print(f'labeled end: {x_split[9]}')
#                        print(f'dstart: {doc_start}')
#                        print(f'dend: {doc_end}')
#                        print(f'start search: {start_search}')
#                        print(f'end search: {end_search}')

In [54]:
# Loop through available documents, retrieve the top passages for each question/answer pair
# Write the returned passages to document_qa_passages directory for later use
# Pairs written as directionary of form {q: [passage, passage, etc]}

#document_id,set,kind,story_url,story_file_size,wiki_url,wiki_title,story_word_count,story_start,story_end

# {
# version:
# data: [

#}       
#]
#}
#]
#}
def get_and_write_qa_passages_as_squad(document_questions, max_stories=-1):
    with open('../narrativeqa/documents.csv') as f1:
        rows = csv.DictReader(f1, delimiter=',')
        i = 0
        s = 0
        q = 0
        document_id=""
        data = {}
        data['version']="1.0"
        data['data'] = []
        for doc in rows:
            try:
                i = i + 1
                if i == 1:
                    continue

                book_data = {}
                document_id = doc['document_id']

                book_data['title'] = doc['wiki_title']
                book_data['document_id'] = document_id
                book_data['paragraphs'] = []


                text = get_document_text(document_id)
                text = ' '.join(text.split())

                doc_start, start_search = get_doc_start(text, doc['story_start'])
                doc_end, end_search = get_doc_end(text, doc['story_end'])

                if doc_start != -1:
                    text = text[int(doc_start):int(doc_end)]

                passages, tfidf, vectorizer = split_document_and_tfidf_vectorize(text)
                passages_to_write = {}

                for qa_pair in document_questions[document_id]:
                    q = q + 1
                    related_indices = get_related_passage_indices(qa_pair.question, vectorizer, tfidf, num_passages_to_return=5)
                    related_passages = get_related_passages(passages, related_indices)
                    qa_pair.passages = related_passages

                    if len(qa_pair.passages) == 0:
                        #print(f'skipped pair: {document_id}')
                        s = s + 1
                        continue

                    book_data['paragraphs'].append(convert_question_pair_to_squad_format(qa_pair))
                    passages_to_write[qa_pair.question] = related_passages

                json_question_pairs = json.dumps(passages_to_write)
                fq = open(f"./document_qa_passages/{document_id}.q_passages", "w")
                fq.write(json_question_pairs)
                fq.close()

                data['data'].append(book_data)

                if max_stories != -1 and i > max_stories:
                    break

            except Exception:
                traceback.print_exc()
                print(f"Error processing qa pairs and passages for document {document_id}. Story no {i}")
                break

        squad_json_format_qa = json.dumps(data)
        fq = open(f"./squad_document_qa_passages/train.data", "w")
        fq.write(squad_json_format_qa)
        fq.close()
        
        print(f'total question pairs: {q}')
        print(f'skipped pairs: {s}')


In [55]:
get_and_write_qa_passages_as_squad(document_questions, max_stories=-1)

total question pairs: 46736
skipped pairs: 845


In [None]:
#TODO
#clear html
#try different combinations of characters
#break up by paragraphs?
    #heuristic to aim for size?

#probability of passage  being correct?

#catalog sparknotes urls