In [69]:
# Feature flags: USE_SUMMARY gates loading and using the BART summarizer,
# FIND_PDFS gates the metapub imports and the PDF-link lookup in displayResults.
USE_SUMMARY = False
FIND_PDFS = False

import os
# NOTE(review): machine-specific absolute JDK path. It is set before pyserini is
# imported, presumably because pyserini needs a JVM at import time -- confirm and
# adjust per machine.
os.environ["JAVA_HOME"] = "/Library/Java/JavaVirtualMachines/jdk-11.0.2.jdk/Contents/Home"
from pyserini.search import pysearch

# NOTE(review): minDate is not referenced anywhere else in the visible notebook.
minDate = '2020/04/02'
# Directory of the Lucene index that the searcher opens in the next cell.
luceneDir = 'lucene-index-covid-2020-04-03/'

import tensorflow as tf
import tensorflow_hub as hub

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
from transformers import BartTokenizer, BartForConditionalGeneration
import numpy as np
# Run on the GPU when torch can see one; otherwise everything stays on the CPU.
if torch.cuda.is_available():
    torch_device = 'cuda'
else:
    torch_device = 'cpu'

# Extractive QA model and its tokenizer (SQuAD fine-tuned BERT-large checkpoint),
# moved to the chosen device and switched to evaluation mode.
QA_MODEL = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
QA_TOKENIZER = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
QA_MODEL.to(torch_device).eval()

# The BART summarizer is only loaded when summaries were requested.
if USE_SUMMARY:
    SUMMARY_TOKENIZER = BartTokenizer.from_pretrained('bart-large-cnn')
    SUMMARY_MODEL = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
    SUMMARY_MODEL.to(torch_device).eval()

# Natural-language question plus extra keywords; both are sent to the Lucene search.
# NOTE(review): 'tramsission' looks like a typo for 'transmission'; left unchanged
# here because it is part of the query actually sent to the index.
query = 'Which non-pharmaceutical interventions limit tramsission'
keywords = '2019-nCoV, SARS-CoV-2, COVID-19, non-pharmaceutical interventions, npi'

In [70]:
import json
# Search the Lucene index with the question and the keyword list as one string.
searcher = pysearch.SimpleSearcher(luceneDir)
hits = searcher.search(query + '. ' + keywords)
n_hits = len(hits)

## index every hit by its docid: the raw JSON record plus title/authors/doi
## taken from the Lucene document
hit_dictionary = {}
for hit in hits:
    doc_json = json.loads(hit.raw)
    idx = str(hit.docid)
    for field in ('title', 'authors', 'doi'):
        doc_json[field] = hit.lucene_document.get(field)
    hit_dictionary[idx] = doc_json

## normalise the abstracts in prep for BERT-SQuAD: every hit gets a list of
## paragraphs and one joined string, since either form may be useful for QA
for idx, v in hit_dictionary.items():
    abs_dirty = v['abstract']
    # start from empty values so hits with an empty abstract still carry both keys
    v['abstract_paragraphs'] = []
    v['abstract_full'] = ''

    if not abs_dirty:
        continue

    if isinstance(abs_dirty, list):
        # list form: one dict per paragraph with the text under the 'text' key
        paragraphs = [p['text'] for p in abs_dirty]
    elif isinstance(abs_dirty, str):
        # plain-string form: keep it as a single paragraph
        paragraphs = [abs_dirty]
    else:
        paragraphs = []

    v['abstract_paragraphs'].extend(paragraphs)
    v['abstract_full'] += ''.join(text + ' \n\n' for text in paragraphs)

In [8]:
# Pick one abstract as a sample document for the cells below.
# NOTE(review): 'we62087x' is a hard-coded docid; this lookup raises KeyError
# unless that document is among the current search hits.
test_txt = hit_dictionary['we62087x']['abstract_full']
document = test_txt
question = query

In [13]:
question

'Which non-pharmaceutical interventions limit tramsission'

In [12]:
test_txt

'Human infections with a novel coronavirus (SARS-CoV-2) were first identified via syndromic surveillance in December of 2019 in Wuhan China. Since identification, infections (coronavirus disease-2019; COVID-19) caused by this novel pathogen have spread globally, with more than 180,000 confirmed cases as of March 16, 2020. Effective public health interventions, including social distancing, contact tracing, and isolation/quarantine rely on the rapid and accurate identification of confirmed cases. However, testing capacity (having sufficient tests and laboratory throughput) to support these non-pharmaceutical interventions remains a challenge for containment and mitigation of COVID-19 infections. We undertook a sentinel event strategy (where single health events signal emerging trends) to estimate the incidence of COVID-19 in the US. Data from a recent national conference, the Conservative Political Action Conference, (CPAC) near Washington, DC and from the outbreak in Wuhan, China were u

In [102]:
# Token counts for the sample question and document. split_doc below encodes
# both again into its own locals and does not read these globals.
question_len = len(QA_TOKENIZER.encode(question))
token_ls = QA_TOKENIZER.encode(document)
def split_doc(text, question, tokenizer, overlap_rate, max_len=500):
    """Split ``text`` into overlapping pieces that fit a token budget with ``question``.

    The text is tokenized once with ``tokenizer.encode``; consecutive windows of
    ``max_len - len(tokenizer.encode(question))`` tokens are decoded back to
    strings, with the literal '[CLS]' / '[SEP]' markers stripped from each piece.

    Args:
        text: document to split.
        question: question that will later be paired with every piece.
        tokenizer: object providing ``encode(str) -> sequence`` and
            ``decode(sequence) -> str``.
        overlap_rate: each new window starts ``int((overlap_rate - 1) * piece_length)``
            tokens before the previous window's end. 1.0 means no overlap;
            values below 1 leave gaps between windows.
        max_len: token budget shared by the question and one piece (default 500).

    Returns:
        A list with at least one decoded piece (a single piece for short or
        empty text).

    Raises:
        ValueError: if the question leaves no room for text, or if the overlap
            is so large that the window would never advance.
    """
    token_ls = tokenizer.encode(text)
    question_len = len(tokenizer.encode(question))
    piece_length = max_len - question_len
    if piece_length <= 0:
        raise ValueError(
            'question uses %d tokens, leaving no room within max_len=%d'
            % (question_len, max_len))

    overlap = int((overlap_rate - 1) * piece_length)
    if overlap >= piece_length:
        # The window start would move back by a full piece every iteration.
        raise ValueError(
            'overlap_rate=%r overlaps whole pieces; the window cannot advance'
            % (overlap_rate,))

    ret = []
    start_idx = 0
    while True:
        # Step back so neighbouring pieces share `overlap` tokens.
        start_idx = max(0, start_idx - overlap)
        end_idx = start_idx + piece_length
        content = tokenizer.decode(token_ls[start_idx: end_idx])
        content = content.replace('[CLS]', '').replace('[SEP]', '').strip()
        ret.append(content)
        start_idx = end_idx
        # '>=' (not '>'): once the window end reaches the last token everything
        # is covered; another pass would only re-emit the overlap region.
        if start_idx >= len(token_ls):
            break
    return ret
    
        
# Smoke test: split the sample abstract, then encode the question together with
# the first piece as one BERT input (used by the bert_pred call in the next cell).
a = split_doc(test_txt, question, QA_TOKENIZER, overlap_rate=1.1)
txt = QA_TOKENIZER.encode(question, a[0])
# len(a[2].split())

In [103]:
def bert_pred(ipt_ids, tokenizer):
    """Run the QA model on one encoded (question, passage) pair.

    Args:
        ipt_ids: list of token ids as produced by ``tokenizer.encode(question, passage)``;
            it must contain the tokenizer's separator id.
        tokenizer: tokenizer that produced ``ipt_ids``; used to map ids back to
            tokens and to locate the first separator.

    Returns:
        ``(start_scores, end_scores, info)`` where the scores are the first two
        model outputs for the whole input and ``info`` holds the token strings
        under 'tokens' and the index of the first separator under 'sep_idx'.

    Uses the module-level ``QA_MODEL`` and ``torch_device``.
    """
    tokens = tokenizer.convert_ids_to_tokens(ipt_ids)
    # Use the tokenizer that was passed in rather than the global one, so the
    # separator id always matches the ids being inspected.
    sep_index = ipt_ids.index(tokenizer.sep_token_id)
    # Segment 0 = everything up to and including the first separator (the
    # question part of the pair); segment 1 = the rest.
    num_seg_a = sep_index + 1
    num_seg_b = len(ipt_ids) - num_seg_a
    segment_ids = [0]*num_seg_a + [1]*num_seg_b
    assert len(segment_ids) == len(ipt_ids)
    n_ids = len(segment_ids)
    assert n_ids < 512
    print(len(ipt_ids))
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = QA_MODEL(torch.tensor([ipt_ids]).to(torch_device),
                           token_type_ids=torch.tensor([segment_ids]).to(torch_device))
    # Index instead of tuple-unpacking so this works whether the model returns
    # a plain tuple or a dict-like output object.
    start_scores, end_scores = outputs[0], outputs[1]

    return start_scores, end_scores, {'tokens': tokens, 'sep_idx': sep_index}

start_scores, end_scores, _ = bert_pred(txt, QA_TOKENIZER)


344


In [104]:
def reconstruct_text(tokens, start=0, stop=-1):
    """Turn a slice of word-piece tokens back into readable text.

    Only ``tokens[start:stop]`` is used (with the defaults the last token is
    left out). If that slice contains '[SEP]', everything up to and including
    the first one is discarded. Word pieces ('##' continuations) are merged,
    whitespace is collapsed, and the spacing around '.', '(', ')', '-' and ','
    is tightened; a comma between two digits is joined without a space.

    Returns the reconstructed string ('' for an empty slice). A leading '. '
    or ', ' is dropped only when the text contained at least one ' , '.
    """
    window = tokens[start: stop]
    if '[SEP]' in window:
        window = window[window.index('[SEP]') + 1:]

    # Merge word pieces, then collapse every run of whitespace to one space.
    text = ' '.join(window).replace(' ##', '').replace('##', '')
    text = ' '.join(text.split())

    # Undo the tokenizer's spacing around punctuation, in this fixed order.
    for spaced, tight in ((' .', '.'), ('( ', '('), (' )', ')'), (' - ', '-')):
        text = text.replace(spaced, tight)

    parts = text.split(' , ')
    if len(parts) == 1:
        return parts[0]

    pieces = []
    for left, right in zip(parts, parts[1:]):
        pieces.append(left)
        # digit , digit -> "1,000"; anything else -> "a, b"
        if left[-1].isdigit() and right[0].isdigit():
            pieces.append(',')
        else:
            pieces.append(', ')
    pieces.append(parts[-1])

    answer = ''.join(pieces)
    if answer.startswith(('. ', ', ')):
        answer = answer[2:]
    return answer

In [105]:
# plan 1, we get the best prediction for each piece
# plan 2, we get the best a few predictions for whole article.
def make_bert_squad_prediction(document, question):
    """Answer ``question`` from ``document`` with the BERT-SQuAD model.

    The document is split into overlapping pieces (split_doc), each piece is
    scored together with the question (bert_pred), and the best span of each
    piece is rebuilt as text (reconstruct_text). The piece whose
    start-score + end-score is highest wins.

    Args:
        document: text to search for an answer.
        question: question string.

    Returns:
        dict with 'answer' (text of the winning span, possibly ''),
        'confidence' (raw start + end score of that span) and 'text'
        (the unchanged ``document``).

    Uses the module-level ``QA_TOKENIZER``.
    """
    overlap_rate = 1.1
    doc_pieces = split_doc(document, question, QA_TOKENIZER, overlap_rate)
    input_ids = [QA_TOKENIZER.encode(question, dp) for dp in doc_pieces]
    print([len(i) for i in input_ids])
    answers = []
    confidences = []
    for ipt_id in input_ids:
        start_scores, end_scores, info = bert_pred(ipt_id, QA_TOKENIZER)
        # Only score positions from the first separator onwards, so the answer
        # cannot be taken from the question itself.
        sep_index = info['sep_idx']
        start_scores = start_scores[:, sep_index:]
        end_scores = end_scores[:, sep_index:]
        tokens_wo_question = info['tokens'][sep_index:]

        # Start and end are picked independently; if the end falls before the
        # start the slice is empty and the answer comes out as ''.
        answer_start = torch.argmax(start_scores).item()
        answer_end = torch.argmax(end_scores).item()
        answer = reconstruct_text(tokens_wo_question, answer_start, answer_end + 1)
        total_score = start_scores[0, answer_start].item() + \
                      end_scores[0, answer_end].item()
        answers.append(answer)
        confidences.append(total_score)

    # Return the answer that belongs to the best-scoring piece (the previous
    # code paired the best confidence with the answer of the *last* piece).
    best = int(np.argmax(confidences))
    return {'answer': answers[best],
            'confidence': confidences[best],
            'text': document}
    
        

In [106]:
from tqdm import tqdm
from collections import OrderedDict
def search_abstracts(hit_dictionary, question):
    """Run the QA model over every hit's abstract and collect the non-empty answers.

    Args:
        hit_dictionary: maps a document id to a record with an 'abstract_full' string.
        question: question passed to make_bert_squad_prediction.

    Returns:
        dict mapping confidence -> answer dict. Each answer dict is what
        make_bert_squad_prediction returned plus 'idx' (the document id, so the
        answer can be traced back to ``hit_dictionary``), with 'confidence'
        rescaled to ``exp(score - best_score)`` so the best answer gets 1.0.
        Because the result is keyed by confidence, answers with exactly equal
        confidence collapse into one entry.
    """
    result = OrderedDict()
    for k, v in tqdm(hit_dictionary.items()):
        abstract = v['abstract_full']
        if abstract:
            ans = make_bert_squad_prediction(abstract, question)
            if ans['answer']:
                # displayResults looks up title/doi through this id.
                ans['idx'] = k
                result[k] = ans
    c_ls = np.array([result[key]['confidence'] for key in result])

    if len(c_ls) != 0:
        # Shift by the maximum before exponentiating so the values stay in (0, 1].
        exp_scores = np.exp(c_ls - c_ls.max())
        for i, k in enumerate(result):
            result[k]['confidence'] = exp_scores[i]

    ret = {}
    for k in result:
        c = result[k]['confidence']
        ret[c] = result[k].copy()
    return ret

In [107]:
answers = search_abstracts(hit_dictionary, query)

  0%|          | 0/10 [00:00<?, ?it/s]

[344]
344


 10%|█         | 1/10 [00:01<00:09,  1.09s/it]

[500, 62]
500
62


 20%|██        | 2/10 [00:03<00:10,  1.37s/it]

[240]
240


 30%|███       | 3/10 [00:03<00:08,  1.17s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (634 > 512). Running this sequence through the model will result in indexing errors


[500, 209]
500
209


 40%|████      | 4/10 [00:06<00:09,  1.54s/it]

[394]
394


 50%|█████     | 5/10 [00:07<00:07,  1.46s/it]

[374]
374


 60%|██████    | 6/10 [00:08<00:05,  1.38s/it]

[284]
284


 70%|███████   | 7/10 [00:09<00:03,  1.22s/it]

[406]
406


 80%|████████  | 8/10 [00:10<00:02,  1.24s/it]

[197]
197


 90%|█████████ | 9/10 [00:11<00:01,  1.04s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (692 > 512). Running this sequence through the model will result in indexing errors


[500, 264]
500
264


100%|██████████| 10/10 [00:13<00:00,  1.40s/it]


In [89]:
answers

{0.0007431192870043687: {'answer': 'medrxiv a license to display the preprint in perpetuity. is the (which was not peer-reviewed',
  'confidence': 0.0007431192870043687,
  'text': 'The COVID-19 outbreak containment strategies in China based on non-pharmaceutical interventions (NPIs) appear to be effective. Quantitative research is still needed however to assess the efficacy of different candidate NPIs and their timings to guide ongoing and future responses to epidemics of this emerging disease across the World. \n\nWe built a travel network-based susceptible-exposed-infectious-removed (SEIR) model to simulate the outbreak across cities in mainland China. We used epidemiological parameters estimated for the early stage of outbreak in Wuhan to parameterise the transmission before NPIs were implemented. To quantify the relative effect of various NPIs, daily changes of delay from illness onset to the first reported case in each county were used as a proxy for the improvement of case identi

In [108]:
from BERT_func import BERT_SQUAD_QA

In [109]:
# NOTE(review): BERT_SQUAD_QA comes from the local BERT_func module, which is not
# shown in this notebook. QA_model (mixed case) is a different name from the
# QA_MODEL network passed into it.
QA_model = BERT_SQUAD_QA(QA_TOKENIZER, QA_MODEL)
print(QA_model.search_abstracts(hit_dictionary, query))

  0%|          | 0/10 [00:00<?, ?it/s]

[344]
344


 10%|█         | 1/10 [00:04<00:38,  4.24s/it]

[500, 62]
500
62


 20%|██        | 2/10 [00:06<00:28,  3.56s/it]

[240]
240


 30%|███       | 3/10 [00:06<00:18,  2.69s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (634 > 512). Running this sequence through the model will result in indexing errors


[500, 209]
500
209


 40%|████      | 4/10 [00:09<00:15,  2.65s/it]

[394]
394


 50%|█████     | 5/10 [00:10<00:11,  2.31s/it]

[374]
374


 60%|██████    | 6/10 [00:12<00:08,  2.04s/it]

[284]
284


 70%|███████   | 7/10 [00:13<00:05,  1.72s/it]

[406]
406


 80%|████████  | 8/10 [00:14<00:03,  1.66s/it]

[197]
197


 90%|█████████ | 9/10 [00:15<00:01,  1.36s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (692 > 512). Running this sequence through the model will result in indexing errors


[500, 264]
500
264


100%|██████████| 10/10 [00:18<00:00,  1.82s/it]

{0.0007431192870043687: {'answer': 'medrxiv a license to display the preprint in perpetuity. is the (which was not peer-reviewed', 'confidence': 0.0007431192870043687, 'text': 'The COVID-19 outbreak containment strategies in China based on non-pharmaceutical interventions (NPIs) appear to be effective. Quantitative research is still needed however to assess the efficacy of different candidate NPIs and their timings to guide ongoing and future responses to epidemics of this emerging disease across the World. \n\nWe built a travel network-based susceptible-exposed-infectious-removed (SEIR) model to simulate the outbreak across cities in mainland China. We used epidemiological parameters estimated for the early stage of outbreak in Wuhan to parameterise the transmission before NPIs were implemented. To quantify the relative effect of various NPIs, daily changes of delay from illness onset to the first reported case in each county were used as a proxy for the improvement of case identifica




In [None]:
workingPath = './kaggle/working'
import pandas as pd
if FIND_PDFS:
    from metapub import UrlReverse
    from metapub import FindIt
from IPython.core.display import display, HTML

#from summarizer import Summarizer
#summarizerModel = Summarizer()
def displayResults(hit_dictionary, answers, question):
    """Render the answers for ``question`` as HTML, reranked by similarity to the question.

    Args:
        hit_dictionary: maps a document id to a record providing 'title' and 'doi'.
        answers: maps a confidence (float) to an answer dict holding 'answer',
            'idx' (document id) and the source abstract under 'abstract_bert'
            or, failing that, 'text' (the key make_bert_squad_prediction uses).
            Modified in place: entries that cannot be shown are removed,
            display fields are added, and the remaining entries are re-keyed
            by their similarity score.
        question: query string shown above the table and used for reranking.

    Reads the module-level FIND_PDFS and USE_SUMMARY flags. Returns None.
    """
    from html import escape  # stdlib; local import keeps this cell self-contained

    question_HTML = '<div style="font-family: Times New Roman; font-size: 28px; padding-bottom:28px"><b>Query</b>: '+escape(question)+'</div>'

    # Keep only entries with a confidence in (0, 1], a non-empty answer and a
    # document id. Everything else is removed so the rerank below never meets
    # an entry without the display fields.
    for c in sorted(answers.keys(), reverse=True):
        entry = answers[c]
        if not (c > 0 and c <= 1 and len(entry['answer']) != 0 and 'idx' in entry):
            answers.pop(c)
            continue

        idx = entry['idx']
        # Title and doi can be missing (None) in the index; degrade gracefully.
        title = hit_dictionary[idx]['title'] or ''
        doi_id = hit_dictionary[idx]['doi']
        if doi_id:
            doi = '<a href="https://doi.org/'+escape(doi_id)+'" target="_blank">' + escape(title) +'</a>'
        else:
            doi = escape(title)

        full_abs = entry.get('abstract_bert', entry.get('text', ''))
        bert_ans = entry['answer']

        # Extend the answer to the sentence around it. partition() splits on
        # the first occurrence only, so a repeated answer no longer cuts the
        # tail short.
        before, found, after = full_abs.partition(bert_ans)
        sentance_beginning = before[before.rfind('.')+1:]
        if not found:
            sentance_end = ''
        else:
            sentance_end_pos = after.find('. ')+1
            if sentance_end_pos == 0:
                sentance_end = after
            else:
                sentance_end = after[:sentance_end_pos]

        entry['full_answer'] = sentance_beginning+bert_ans+sentance_end
        entry['sentence_beginning'] = sentance_beginning
        entry['sentence_end'] = sentance_end
        entry['title'] = title
        entry['doi'] = doi

    ## now rerank based on semantic similarity of the answers to the question
    cList = list(answers.keys())
    allAnswers = [answers[c]['full_answer'] for c in cList]

    messages = [question]+allAnswers

    # NOTE(review): embed_fn is not defined in the visible notebook; presumably a
    # sentence encoder (tensorflow_hub is imported but unused) -- confirm it is
    # defined before this function is called.
    encoding_matrix = embed_fn(messages)
    similarity_matrix = np.inner(encoding_matrix, encoding_matrix)
    # Column 0 holds the similarity of every message to the question.
    rankings = similarity_matrix[1:,0]

    # Build the re-keyed mapping separately, then swap it in, so a similarity
    # that happens to equal a not-yet-processed confidence cannot clobber it.
    reranked = {rankings[i]: answers[c] for i, c in enumerate(cList)}
    answers.clear()
    answers.update(reranked)

    ## now form pandas dv
    confidence = list(answers.keys())
    confidence.sort(reverse=True)
    pandasData = []
    ranked_aswers = []
    for c in confidence:
        entry = answers[c]
        # Abstract text is escaped so characters such as '<' (e.g. "p < 0.05")
        # cannot break the rendered table.
        sentance_html = '<div>' +escape(entry['sentence_beginning']) + " <font color='red'>"+escape(entry['answer'])+"</font> "+escape(entry['sentence_end'])+'</div>'
        pandasData.append([entry['idx'], sentance_html, c, entry['doi']])
        ranked_aswers.append(' '.join([entry['full_answer']]))

    if FIND_PDFS:
        pdata2 = []
        for rowData in pandasData:
            idx = rowData[0]
            if str(idx).startswith('pm_'):
                pmid = idx[3:]
            else:
                # Best effort: any lookup failure just means "no PDF link".
                try:
                    test = UrlReverse('https://doi.org/'+hit_dictionary[idx]['doi'])
                    pmid = test.pmid if test is not None else None
                except Exception:
                    pmid = None
            pdfLink = None
            if pmid is not None:
                try:
                    pdfLink = FindIt(str(pmid))
                except Exception:
                    pdfLink = None
            if pdfLink is not None:
                pdfLink = pdfLink.url

            if pdfLink is None:
                rowData.append('Not Available')
            else:
                rowData.append('<a href="'+pdfLink+'" target="_blank">PDF Link</a>')
            pdata2.append(rowData)
    else:
        pdata2 = pandasData

    display(HTML(question_HTML))

    if USE_SUMMARY:
        ## summarise the six best-ranked answers with the BART model
        allAnswersTxt = ' '.join(ranked_aswers[:6]).replace('\n','')

        answers_input_ids = SUMMARY_TOKENIZER.batch_encode_plus([allAnswersTxt], return_tensors='pt', max_length=1024)['input_ids'].to(torch_device)
        summary_ids = SUMMARY_MODEL.generate(answers_input_ids,
                                               num_beams=10,
                                               length_penalty=1.2,
                                               max_length=1024,
                                               min_length=64,
                                               no_repeat_ngram_size=4)

        exec_sum = SUMMARY_TOKENIZER.decode(summary_ids.squeeze(), skip_special_tokens=True)
        execSum_HTML = '<div style="font-family: Times New Roman; font-size: 18px; margin-bottom:1pt"><b>BART Abstractive Summary:</b>: '+escape(exec_sum)+'</div>'
        display(HTML(execSum_HTML))
        warning_HTML = '<div style="font-family: Times New Roman; font-size: 12px; padding-bottom:12px; color:#CCCC00; margin-top:1pt"> Warning this is an autogenerated summary based on semantic search of abstracts, always examine the sources before accepting this conclusion.  If the evidence only mentions topic in passing or the evidence is not clear, the summary will likely not clearly answer the question.</div>'
        display(HTML(warning_HTML))

    if FIND_PDFS:
        df = pd.DataFrame(pdata2, columns = ['Lucene ID', 'BERT-SQuAD Answer with Highlights', 'Confidence', 'Title/Link','PDF Link'])
    else:
        df = pd.DataFrame(pdata2, columns = ['Lucene ID', 'BERT-SQuAD Answer with Highlights', 'Confidence', 'Title/Link'])

    # escape=False because the cells already contain the HTML built above.
    display(HTML(df.to_html(render_links=True, escape=False)))
    
displayResults(hit_dictionary, answers, query)