In [5]:
"""
Libraries
"""

!pip install rank_bm25 -q

import numpy as np
import pandas as pd 
from pathlib import Path, PurePath

import nltk
from nltk.corpus import stopwords
import re
import string
import torch

from rank_bm25 import BM25Okapi # Search engine

In [16]:
"""
Load metadata df
"""

input_dir = PurePath('CORD-19-research-challenge')
metadata_path = input_dir / 'metadata.csv'
metadata_df = pd.read_csv(metadata_path, low_memory=False)

In [18]:
metadata_df = metadata_df.dropna(subset=['abstract', 'title']) \
                            .reset_index(drop=True)

In [28]:
raw_search_str = metadata_df.abstract.fillna('') + ' ' \
                            + metadata_df.title.fillna('')

In [32]:
tt = raw_search_str.to_frame()
tt.columns = ['terms']
tt.index = metadata_df.index
mybm25 = BM25Okapi(tt.terms.tolist())

In [35]:
mybm25

<rank_bm25.BM25Okapi at 0x1a23c55d30>

In [37]:
mybm25.get_scores('covid')

array([25.153851  , 25.39474883, 25.31214667, ..., 25.49566284,
       22.84587602, 25.48895717])

In [39]:
mysearch_terms = 'what is covid' 
myscores = mybm25.get_scores(mysearch_terms)

# sort by scores
myind = np.argsort(myscores)[::-1][:10] 

# select top results and returns
results = metadata_df.iloc[myind][metadata_df.columns]
results['score'] = myscores[myind]
results = results[results.score > 0]

In [40]:
results

Unnamed: 0,cord_uid,sha,source_x,title,doi,pmcid,pubmed_id,license,abstract,publish_time,authors,journal,Microsoft Academic Paper ID,WHO #Covidence,has_full_text,full_text_file,url,score
24725,tl8vp8o7,ef8d5f283816c67a1af2a88c454c165b04946775,Elsevier,Intestinal antibody response after vaccination...,10.1016/0165-2427(86)90087-5,,3006327.0,els-covid,Abstract The intestinal and systemic antibody ...,1986-01-31,"Van Zaane, Dick; Ijzerman, Johan; De Leeuw, Pe...",Veterinary Immunology and Immunopathology,,,True,custom_license,https://doi.org/10.1016/0165-2427(86)90087-5,67.303364
17636,ar702fxq,,PMC,Carnivore Parvovirus Ecology in the Serengeti ...,10.1128/jvi.02220-18,PMC6580958,30996096.0,unk,Carnivore parvoviruses infect wild and domesti...,2019-04-17,"Calatayud, Olga; Esperón, Fernando; Cleaveland...",Journal of Virology,,,False,,https://jvi.asm.org/content/jvi/93/13/e02220-1...,67.249887
25015,a6ia8kxf,9bb776ccd5016928c925a331e21c2d98813db6a8,Elsevier,Efficacy of an inactivated oil-adjuvanted rota...,10.1016/0264-410x(89)90241-7,,2551102.0,els-covid,Abstract We have assessed the potency of an in...,1989-06-30,"Bellinzoni, R.C.; Blackhall, J.; Baro, N.; Auz...",Vaccine,,,True,custom_license,https://doi.org/10.1016/0264-410x(89)90241-7,67.249641
12729,ydssboge,6dab7c67db93625257a0e3aa8a8d88f22573d1fe,PMC,Acute phase response in bovine coronavirus pos...,10.1186/s13028-019-0471-3,PMC6659199,31345246.0,cc-by,Bovine coronavirus (BCoV) is associated with s...,2019-07-25,"Chae, Jeong-Byoung; Park, Jinho; Jung, Suk-Han...",Acta Vet Scand,,,True,comm_use_subset,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6...,67.247517
30646,00t28df6,dd8738920d2fdd2f606f249d47665b32dc33f1ea,Elsevier,Bovine Neonatal Pancytopenia: Is this alloimmu...,10.1016/j.vaccine.2011.05.012,,21605614.0,els-covid,Abstract Bovine Neonatal Pancytopenia (BNP) is...,2011-07-18,"Bastian, Max; Holsteg, Mark; Hanke-Robinson, H...",Vaccine,,,True,custom_license,https://doi.org/10.1016/j.vaccine.2011.05.012,67.243239
19411,2gsy750k,be105f13ce0785aa5c3f21cc1e27a803fdb5b79b,PMC,Immunogenicity and Protective Efficacy in Mice...,10.1089/vim.2010.0028,PMC2967819,20883165.0,unk,The immunogenicity and efficacy of β-propiolac...,2010-10-01,"Roberts, Anjeanette; Lamirande, Elaine W.; Vog...",Viral Immunology,,,True,custom_license,http://europepmc.org/articles/pmc2967819?pdf=r...,67.229453
10004,xzgfwu8a,7461fe0adbb9a865f8a79e994c6cd8b8ebdd4e78,PMC,Would it be legally justified to impose vaccin...,10.1186/s13584-017-0182-z,PMC5661933,29084599.0,cc-by,BACKGROUND: The detection of wild poliovirus i...,2017-10-30,"Kamin-Friedman, Shelly",Isr J Health Policy Res,,,True,comm_use_subset,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5...,67.22154
31634,u7glxpum,dcecb3ef49f22680bf9e96142f00142f3b1d0873,Elsevier,An evolutionary vaccination game in the modifi...,10.1016/j.physa.2015.09.073,,,els-covid,"Abstract In this paper, we explore an evolutio...",2016-02-01,"Han, Dun; Sun, Mei",Physica A: Statistical Mechanics and its Appli...,,,True,custom_license,https://doi.org/10.1016/j.physa.2015.09.073,67.218345
27138,q5vqiivm,fd993078c0b3ffca161716125696ca01b106e173,Elsevier,Attempted immunisation of cats against feline ...,10.1016/s0034-5288(18)30970-6,,,els-covid,Specific pathogen free kittens were vaccinated...,1988-11-30,"STODDART, C.A.; BARLOUGH, J.E.; BALDWIN, C.A.;...",Research in Veterinary Science,,,True,custom_license,https://doi.org/10.1016/s0034-5288(18)30970-6,67.209602
12424,t1anr8gd,3df56c2e76799309cc7bdbc5ef8f968a1569a08c,PMC,Fecal Viral Diversity of Captive and Wild Tasm...,10.1128/jvi.00205-19,PMC6532096,30867308.0,cc-by,The Tasmanian devil is an endangered carnivoro...,2019-05-15,"Chong, Rowena; Shi, Mang; Grueber, Catherine E...",J Virol,,,True,comm_use_subset,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6...,67.206913


In [43]:
english_stopwords = list(set(stopwords.words('english')))

class CovidSearchEngine:
    """
    Simple CovidSearchEngine.
    """
    
    def remove_special_character(self, text):
        #Remove special characters from text string
        return text.translate(str.maketrans('', '', string.punctuation))

    def tokenize(self, text):
        # tokenize text
        words = nltk.word_tokenize(text)
        return list(set([word for word in words 
                         if len(word) > 1
                         and not word in english_stopwords
                         and not word.isnumeric() 
                        ])
                   )
    
    def preprocess(self, text):
        # Clean and tokenize text input
        return self.tokenize(self.remove_special_character(text.lower()))


    def __init__(self, corpus: pd.DataFrame):
        self.corpus = corpus
        self.columns = corpus.columns
        
        raw_search_str = self.corpus.abstract.fillna('') + ' ' \
                            + self.corpus.title.fillna('')
        
        self.index = raw_search_str.apply(self.preprocess).to_frame()
        self.index.columns = ['terms']
        self.index.index = self.corpus.index
        self.bm25 = BM25Okapi(self.index.terms.tolist())
    
    def search(self, query, num):
        """
        Return top `num` results that better match the query
        """
        # obtain scores
        search_terms = self.preprocess(query) 
        doc_scores = self.bm25.get_scores(search_terms)
        
        # sort by scores
        ind = np.argsort(doc_scores)[::-1][:num] 
        
        # select top results and returns
        results = self.corpus.iloc[ind][self.columns]
        results['score'] = doc_scores[ind]
        results = results[results.score > 0]
        return results.reset_index()

In [46]:
cse = CovidSearchEngine(metadata_df)

In [51]:
"""
Download pre-trained QA model
"""

import torch
from transformers import BertTokenizer
from transformers import BertForQuestionAnswering

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

BERT_SQUAD = 'bert-large-uncased-whole-word-masking-finetuned-squad'

model = BertForQuestionAnswering.from_pretrained(BERT_SQUAD)
tokenizer = BertTokenizer.from_pretrained(BERT_SQUAD)

model = model.to(torch_device)
model.eval()

print()

100%|██████████| 398/398 [00:00<00:00, 64590.17B/s]
100%|██████████| 1340675298/1340675298 [03:10<00:00, 7042502.94B/s] 
100%|██████████| 231508/231508 [00:00<00:00, 466672.11B/s]





In [59]:
tokenizer.encode_plus

<bound method PreTrainedTokenizer.encode_plus of <transformers.tokenization_bert.BertTokenizer object at 0x1a3725f438>>

In [63]:
def answer_question(question, context):
    # anser question given question and context
    encoded_dict = tokenizer.encode_plus(
                        question, context,
                        add_special_tokens = True,
                        max_length = 256,
                        pad_to_max_length = True,
                        return_tensors = 'pt'
                   )
    
    input_ids = encoded_dict['input_ids'].to(torch_device)
    token_type_ids = encoded_dict['token_type_ids'].to(torch_device)
    
    start_scores, end_scores = model(input_ids, token_type_ids=token_type_ids)

    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    start_index = torch.argmax(start_scores)
    end_index = torch.argmax(end_scores)
    
    answer = tokenizer.convert_tokens_to_string(all_tokens[start_index:end_index+1])
    answer = answer.replace('[CLS]', '')
    return answer

In [64]:
# adapted from https://www.kaggle.com/dirktheeng/anserini-bert-squad-for-semantic-corpus-search

covid_kaggle_questions = {
"data":[
          {
              "task": "What is known about transmission, incubation, and environmental stability?",
              "questions": [
                  "Is the virus transmitted by aerisol, droplets, food, close contact, fecal matter, or water?",
                  "How long is the incubation period for the virus?",
                  "Can the virus be transmitted asymptomatically or during the incubation period?",
                  "How does weather, heat, and humidity affect the tramsmission of 2019-nCoV?",
                  "How long can the 2019-nCoV virus remain viable on common surfaces?"
              ]
          },
          {
              "task": "What do we know about COVID-19 risk factors?",
              "questions": [
                  "What risk factors contribute to the severity of 2019-nCoV?",
                  "How does hypertension affect patients?",
                  "How does heart disease affect patients?",
                  "How does copd affect patients?",
                  "How does smoking affect patients?",
                  "How does pregnancy affect patients?",
                  "What is the fatality rate of 2019-nCoV?",
                  "What public health policies prevent or control the spread of 2019-nCoV?"
              ]
          },
          {
              "task": "What do we know about virus genetics, origin, and evolution?",
              "questions": [
                  "Can animals transmit 2019-nCoV?",
                  "What animal did 2019-nCoV come from?",
                  "What real-time genomic tracking tools exist?",
                  "What geographic variations are there in the genome of 2019-nCoV?",
                  "What effors are being done in asia to prevent further outbreaks?"
              ]
          },
          {
              "task": "What do we know about vaccines and therapeutics?",
              "questions": [
                  "What drugs or therapies are being investigated?",
                  "Are anti-inflammatory drugs recommended?"
              ]
          },
          {
              "task": "What do we know about non-pharmaceutical interventions?",
              "questions": [
                  "Which non-pharmaceutical interventions limit tramsission?",
                  "What are most important barriers to compliance?"
              ]
          },
          {
              "task": "What has been published about medical care?",
              "questions": [
                  "How does extracorporeal membrane oxygenation affect 2019-nCoV patients?",
                  "What telemedicine and cybercare methods are most effective?",
                  "How is artificial intelligence being used in real time health delivery?",
                  "What adjunctive or supportive methods can help patients?"
              ]
          },
          {
              "task": "What do we know about diagnostics and surveillance?",
              "questions": [
                  "What diagnostic tests (tools) exist or are being developed to detect 2019-nCoV?"
              ]
          },
          {
              "task": "Other interesting questions",
              "questions": [
                  "What is the immune system response to 2019-nCoV?",
                  "Can personal protective equipment prevent the transmission of 2019-nCoV?",
                  "Can 2019-nCoV infect patients a second time?"
              ]
          }
   ]
}

In [65]:
NUM_CONTEXT_FOR_EACH_QUESTION = 10


def get_all_context(query, num_results):
    # Return ^num_results' papers that better match the query
    
    papers_df = cse.search(query, num_results)
    return papers_df['abstract'].str.replace("Abstract", "").tolist()


def get_all_answers(question, all_contexts):
    # Ask the same question to all contexts (all papers)
    
    all_answers = []
    
    for context in all_contexts:
        all_answers.append(answer_question(question, context))
    return all_answers


def create_output_results(question, 
                          all_contexts, 
                          all_answers, 
                          summary_answer='', 
                          summary_context=''):
    # Return results in json format
    
    def find_start_end_index_substring(context, answer):   
        search_re = re.search(re.escape(answer.lower()), context.lower())
        if search_re:
            return search_re.start(), search_re.end()
        else:
            return 0, len(context)
        
    output = {}
    output['question'] = question
    output['summary_answer'] = summary_answer
    output['summary_context'] = summary_context
    results = []
    for c, a in zip(all_contexts, all_answers):

        span = {}
        span['context'] = c
        span['answer'] = a
        span['start_index'], span['end_index'] = find_start_end_index_substring(c,a)

        results.append(span)
    
    output['results'] = results
        
    return output

    
def get_results(question, 
                summarize=False, 
                num_results=NUM_CONTEXT_FOR_EACH_QUESTION,
                verbose=True):
    # Get results

    all_contexts = get_all_context(question, num_results)
    
    all_answers = get_all_answers(question, all_contexts)
    
    if summarize:
        # NotImplementedYet
        summary_answer = get_summary(all_answers)
        summary_context = get_summary(all_contexts)
    
    return create_output_results(question, 
                                 all_contexts, 
                                 all_answers)


In [66]:
all_tasks = []

for i, t in enumerate(covid_kaggle_questions['data']):
    print("Answering questions to task {}. ...".format(i+1))
    answers_to_question = []
    for q in t['questions']:
            answers_to_question.append(get_results(q, verbose=False))
    task = {}
    task['task'] = t['task']
    task['questions'] = answers_to_question
    
    all_tasks.append(task)

all_answers = {}
all_answers['data'] = all_tasks

Answering questions to task 1. ...


TypeError: _tokenize() got an unexpected keyword argument 'pad_to_max_length'