<font size="5" >🔍 Know-Corona : COVID-19 Open Research Dataset Challenge </font>

### Loading metadata dataframe

In [1]:
!pip install rank_bm25 -q

import numpy as np
import pandas as pd 
from pathlib import Path, PurePath

import nltk
from nltk.corpus import stopwords
import re
import string
import torch

from rank_bm25 import BM25Okapi

In [4]:
"""
Load metadata
"""

# Directory holding the input CSV (and, later, the exported JSON answers).
data_dir = PurePath('../data')
metadata_path = data_dir.joinpath('covid_19_metadata.csv')

# Read the identifier columns as strings, then keep only the papers that have
# both an abstract and a title, renumbering the rows from 0.
metadata_df = (
    pd.read_csv(
        metadata_path,
        dtype={'Microsoft Academic Paper ID': str, 'pubmed_id': str},
    )
    .dropna(subset=['abstract', 'title'])
    .reset_index(drop=True)
)

### Covid Search Engine

In [6]:
from rank_bm25 import BM25Okapi

# adapted from https://www.kaggle.com/dgunning/building-a-cord19-research-engine-with-bm25
# De-duplicated NLTK English stopwords, materialised as a list.
english_stopwords = [*set(stopwords.words('english'))]

class CovidSearchEngine:
    """
    Simple BM25 search engine over the abstracts and titles of a dataframe.

    Usage:

    cse = CovidSearchEngine(metadata_df) # metadata_df is a pandas dataframe with 'title' and 'abstract' columns
    search_results = cse.search("What is coronavirus", num=10) # Return `num` top-results
    """

    def remove_special_character(self, text):
        """
        Return `text` with every character of `string.punctuation` removed.
        """
        return text.translate(str.maketrans('', '', string.punctuation))

    def tokenize(self, text):
        """
        Tokenize with NLTK and return the list of distinct kept tokens.

        Rules:
            - drop all words of a single character
            - drop all stopwords
            - drop all numbers

        The tokens go through a `set`, so duplicates are removed and the order
        of the returned list is not the order of appearance in `text`.
        """
        words = nltk.word_tokenize(text)
        kept_words = {
            word for word in words
            if len(word) > 1
            and word not in english_stopwords
            and not word.isnumeric()
        }
        return list(kept_words)

    def preprocess(self, text):
        """
        Lower-case `text`, strip punctuation and tokenize the result.
        """
        return self.tokenize(self.remove_special_character(text.lower()))


    def __init__(self, corpus: pd.DataFrame):
        """
        Build the BM25 index.

        corpus: dataframe with 'abstract' and 'title' columns; both are
                concatenated (missing values replaced by '') and preprocessed
                into one token list per row.
        """
        self.corpus = corpus
        self.columns = corpus.columns
        raw_search_str = self.corpus.abstract.fillna('') + ' ' + self.corpus.title.fillna('')
        self.index = raw_search_str.apply(self.preprocess).to_frame()
        self.index.columns = ['terms']
        self.index.index = self.corpus.index
        self.bm25 = BM25Okapi(self.index.terms.tolist())

    def search(self, query, num):
        """
        Return the (at most) `num` rows of the corpus that best match `query`.

        The result is a new dataframe holding the corpus columns plus a 'score'
        column, sorted by decreasing score; rows whose score is not strictly
        positive are dropped. The original row labels are moved into a column
        by the final `reset_index()`.
        """
        search_terms = self.preprocess(query)
        doc_scores = self.bm25.get_scores(search_terms) # one score per corpus row

        ind = np.argsort(doc_scores)[::-1][:num] # positions of the `num` best scores

        # `assign` builds a fresh dataframe, so the score column is never
        # written onto a slice of `self.corpus` (no SettingWithCopyWarning and
        # no risk of touching the corpus itself).
        results = self.corpus.iloc[ind][self.columns].assign(score=doc_scores[ind])
        results = results[results['score'] > 0]
        return results.reset_index()
    
# Search engine indexed over the abstracts and titles of `metadata_df`.
cse = CovidSearchEngine(corpus=metadata_df)

### Question-Answering system (BioBert)

In [10]:
%%time

"""
LIBRARIES
"""

import torch
from transformers import BertTokenizer
from transformers import BertForQuestionAnswering
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

"""
SETTINGS
"""

# Default number of search results (contexts) retrieved for each question.
NUM_CONTEXT_FOR_EACH_QUESTION = 10


"""
Transformers
"""

# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    torch_device = 'cuda'
else:
    torch_device = 'cpu'

print("Code running on: {}".format(torch_device) )

# Question-answering model and its matching tokenizer; the model is moved to
# the selected device and switched to evaluation mode.
model = AutoModelForQuestionAnswering.from_pretrained('ktrapeznikov/biobert_v1.1_pubmed_squad_v2').to(torch_device)
tokenizer = AutoTokenizer.from_pretrained('ktrapeznikov/biobert_v1.1_pubmed_squad_v2')

model.eval()

def answer_question(question, context):
    """
    Extract from `context` the span that the QA model picks as the answer.

    question: question text.
    context:  passage in which the answer is searched.

    Returns the decoded tokens between the highest-scoring start position and
    the highest-scoring end position (inclusive), with any '[CLS]' marker
    removed. The string is empty when the end position precedes the start.
    """
    encoded_dict = tokenizer.encode_plus(
                        question, context, # Sentence to encode.
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        max_length = 256,  # Pad & truncate all sentences.
                        pad_to_max_length = True,
                        return_attention_mask = True,   # Construct attn. masks.
                        return_tensors = 'pt'     # Return pytorch tensors.
                   )

    input_ids = encoded_dict['input_ids'].to(torch_device)
    token_type_ids = encoded_dict['token_type_ids'].to(torch_device) # segments
    # The input is padded to `max_length`, so the mask must reach the model:
    # without it the padding positions are attended like real tokens.
    attention_mask = encoded_dict['attention_mask'].to(torch_device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask)

    # Positional access works for plain tuples and for dict-like output
    # objects alike; tuple-unpacking a dict-like output would yield its keys.
    start_scores, end_scores = outputs[0], outputs[1]

    start_index = int(torch.argmax(start_scores))
    end_index = int(torch.argmax(end_scores))

    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    answer = tokenizer.convert_tokens_to_string(all_tokens[start_index : end_index + 1])

    answer = answer.replace('[CLS]', '')

    return answer



from transformers import BartTokenizer, BartModel
# `generate()` needs a language-modelling head; the bare `BartModel` has none,
# so the summarizer is loaded as `BartForConditionalGeneration`.
from transformers import BartForConditionalGeneration

tokenizer_summarize = BartTokenizer.from_pretrained('bart-large')
model_summarize = BartForConditionalGeneration.from_pretrained('bart-large').to(torch_device)

# Set the model in evaluation mode to deactivate the DropOut modules
model_summarize.eval()

def get_summary(text, max_length=5):
    """
    Generate a summary of `text` with the summarization model.

    text:       a string, or an iterable of strings which is joined with
                single spaces into one document before encoding.
    max_length: `max_length` passed to `generate` (upper bound on the
                generated sequence); the default keeps the previously
                hard-coded value.

    Returns the decoded summary as a string.
    """
    # Callers may hand over a list of passages; summarize them as one text.
    if not isinstance(text, str):
        text = ' '.join(str(part) for part in text)

    answers_input_ids = tokenizer_summarize.batch_encode_plus(
        [text], return_tensors='pt', max_length=1024
    )['input_ids']

    answers_input_ids = answers_input_ids.to(torch_device)

    summary_ids = model_summarize.generate(answers_input_ids,
                                           num_beams=4,
                                           max_length=max_length,
                                           early_stopping=True
                                          )

    # Take the single sequence of the batch explicitly: `squeeze()` would also
    # drop the sequence dimension when only one token is generated.
    return tokenizer_summarize.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

    
"""
Main 
"""



def create_output_results(question, all_contexts, all_answers, summary_answer, summary_context):
    """
    Return a dictionary of the form

    {
        'question': 'what is coronavirus',
        'summary_answer': '...',
        'summary_context': '...',
        'results': [
            {
                'context': 'coronavirus is an infectious disease caused by',
                'answer': 'infectious disease',
                'start_index': 18,
                'end_index': 36
            },
            {
                ...
            }
        ]
    }

    Contexts and answers are paired positionally. Start and end index give the
    position of the answer inside the context (case-insensitive match); when
    the answer does not occur in the context they span the whole context,
    i.e. (0, len(context)).
    """

    def find_start_end_index_substring(context, answer):
        # Match case-insensitively on the ORIGINAL context: lower-casing the
        # context first can change its length for some characters, which would
        # shift the returned offsets away from the text they are used to slice.
        search_re = re.search(re.escape(answer), context, re.IGNORECASE)
        if search_re:
            return search_re.start(), search_re.end()
        else:
            return 0, len(context)

    output = {}
    output['question'] = question
    output['summary_answer'] = summary_answer
    output['summary_context'] = summary_context
    results = []
    for c, a in zip(all_contexts, all_answers):

        span = {}
        span['context'] = c
        span['answer'] = a
        span['start_index'], span['end_index'] = find_start_end_index_substring(c, a)

        results.append(span)

    output['results'] = results

    return output


def get_all_context(query, num_results):
    """
    Query the search engine and return the abstracts of its hits as a list.

    At most `num_results` abstracts come back, in the order given by the
    search, each with every occurrence of the word "Abstract" removed.
    """
    abstracts = cse.search(query, num_results)['abstract']
    cleaned_abstracts = abstracts.str.replace("Abstract", "")
    return cleaned_abstracts.tolist()


def get_all_answers(question, all_context):
    """
    Ask `question` against every context and collect the answers.

    The returned list has one answer per element of `all_context`, in the
    same order.
    """
    return [answer_question(question, passage) for passage in all_context]

    
def get_results(question, summarize=False, num_results=NUM_CONTEXT_FOR_EACH_QUESTION, verbose=True):
    """
    Return a dict object containing all contexts and answers related to the (sub)question.

    question:    question text, used both as search query and as QA question.
    summarize:   when True, also summarize the answers and the contexts.
    num_results: maximum number of contexts retrieved for the question.
    verbose:     when True, print a progress message before each step.

    The dict is the one built by `create_output_results`; both summaries are
    empty strings when `summarize` is False or when nothing was retrieved.
    """

    if verbose:
        print("Getting context ...")
    all_contexts = get_all_context(question, num_results)

    if verbose:
        print("Answering to all questions ...")
    all_answers = get_all_answers(question, all_contexts)

    summary_answer = ''
    summary_context = ''
    if summarize:
        if verbose:
            print("Adding summary ...")
        # `get_summary` encodes ONE text, so the lists are joined into a single
        # document first; with no retrieved context there is nothing to
        # summarize and the summaries stay empty.
        if all_contexts:
            summary_answer = get_summary(' '.join(all_answers))
            summary_context = get_summary(' '.join(all_contexts))

    if verbose:
        print("output.")

    return create_output_results(question, all_contexts, all_answers, summary_answer, summary_context)

Code running on: cpu


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1264.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1625270110.0, style=ProgressStyle(descr…


Wall time: 7min 28s


### Dict object to store all Kaggle CORD-19 tasks

In [11]:
# adapted from https://www.kaggle.com/dirktheeng/anserini-bert-squad-for-semantic-corpus-search

# Kaggle CORD-19 tasks, each with the sub-questions sent to the search engine
# and the QA model. The question text is used verbatim as the search query, so
# it must be spelled correctly: a misspelled word can never match a term of
# the index.
covid_kaggle_questions = {
"data":[
          {
              "task": "What is known about transmission, incubation, and environmental stability?",
              "questions": [
                  "Is the virus transmitted by aerosol, droplets, food, close contact, fecal matter, or water?",
                  "How long is the incubation period for the virus?",
                  "Can the virus be transmitted asymptomatically or during the incubation period?",
                  "How does weather, heat, and humidity affect the transmission of 2019-nCoV?",
                  "How long can the 2019-nCoV virus remain viable on common surfaces?"
              ]
          },
          {
              "task": "What do we know about COVID-19 risk factors?",
              "questions": [
                  "What risk factors contribute to the severity of 2019-nCoV?",
                  "How does hypertension affect patients?",
                  "How does heart disease affect patients?",
                  "How does copd affect patients?",
                  "How does smoking affect patients?",
                  "How does pregnancy affect patients?",
                  "What is the fatality rate of 2019-nCoV?",
                  "What public health policies prevent or control the spread of 2019-nCoV?"
              ]
          },
          {
              "task": "What do we know about virus genetics, origin, and evolution?",
              "questions": [
                  "Can animals transmit 2019-nCoV?",
                  "What animal did 2019-nCoV come from?",
                  "What real-time genomic tracking tools exist?",
                  "What geographic variations are there in the genome of 2019-nCoV?",
                  "What efforts are being done in asia to prevent further outbreaks?"
              ]
          },
          {
              "task": "What do we know about vaccines and therapeutics?",
              "questions": [
                  "What drugs or therapies are being investigated?",
                  "Are anti-inflammatory drugs recommended?"
              ]
          },
          {
              "task": "What do we know about non-pharmaceutical interventions?",
              "questions": [
                  "Which non-pharmaceutical interventions limit transmission?",
                  "What are most important barriers to compliance?"
              ]
          },
          {
              "task": "What has been published about medical care?",
              "questions": [
                  "How does extracorporeal membrane oxygenation affect 2019-nCoV patients?",
                  "What telemedicine and cybercare methods are most effective?",
                  "How is artificial intelligence being used in real time health delivery?",
                  "What adjunctive or supportive methods can help patients?"
              ]
          },
          {
              "task": "What do we know about diagnostics and surveillance?",
              "questions": [
                  "What diagnostic tests (tools) exist or are being developed to detect 2019-nCoV?"
              ]
          },
          {
              "task": "Other interesting questions",
              "questions": [
                  "What is the immune system response to 2019-nCoV?",
                  "Can personal protective equipment prevent the transmission of 2019-nCoV?",
                  "Can 2019-nCoV infect patients a second time?"
              ]
          }
   ]
}

### Answer to all questions

Store the results in the `all_answers` dictionary (the list of per-task results sits under its `data` key).

In [12]:
# One entry per Kaggle task: the task title plus the result dict of each of
# its questions, in the order they are listed.
all_tasks = []

for task_number, task_spec in enumerate(covid_kaggle_questions['data'], start=1):
    print("Answering question to task {}. ...".format(task_number))
    all_tasks.append({
        'task': task_spec['task'],
        'questions': [get_results(sub_question, verbose=False)
                      for sub_question in task_spec['questions']],
    })

# Top-level container used for display and for the JSON export.
all_answers = {'data': all_tasks}


Answering question to task 1. ...
Answering question to task 2. ...
Answering question to task 3. ...
Answering question to task 4. ...
Answering question to task 5. ...
Answering question to task 6. ...
Answering question to task 7. ...
Answering question to task 8. ...


###  Display questions, context and answers

In [25]:
# Adapted from https://jbesomi.github.io/Korono/

from IPython.display import display, Markdown, Latex, HTML

def layout_style():
    """
    Return the CSS rules used by the rendered answers, wrapped in a
    `<style>` element so they can be prepended to an HTML fragment.
    """
    
    
    # Rules for the classes emitted by the display helpers of this cell
    # (single_answer, answer, question_title) plus the notebook output area.
    style = """
        div {
            color: black;
        }
        
        .single_answer {
            border-left: 3px solid #dc7b15;
            padding-left: 10px;
            font-family: Arial;
            font-size: 16px;
            color: #777777;
            margin-left: 5px;

        }
        
        .answer{
            color: #dc7b15;
        }
        
        .question_title {
            color: grey;
            display: block;
            text-transform: none;
        }
               
        div.output_scroll { 
            height: auto; 
        }
    
    """
    
    return "<style>" + style + "</style>"

def dm(x):
    """Render the string `x` as Markdown in the cell output."""
    markdown_object = Markdown(x)
    display(markdown_object)


def dh(x):
    """Render the HTML fragment `x`, preceded by the shared stylesheet."""
    styled_fragment = layout_style() + x
    display(HTML(styled_fragment))
    
def display_task(task):
    """
    Render the title of `task` (its 'task' entry) as a level-2 Markdown heading.
    """
    # The Markdown helper of this cell is `dm`; `m` is not defined anywhere.
    dm("## " + task['task'])
    
#display_task(task1['data'][0])


def display_single_context(context, start_index, end_index):
    """
    Render `context` with the answer span highlighted.

    context:     plain text of the passage.
    start_index: offset of the first character of the answer in `context`.
    end_index:   offset just past the last character of the answer.

    The text before, inside and after the span is HTML-escaped, so characters
    such as '<' or '&' in the passage are shown literally instead of being
    read as markup. Returns whatever `dh` returns.
    """
    from html import escape

    before_answer = escape(context[:start_index])
    answer = escape(context[start_index:end_index])
    after_answer = escape(context[end_index:])

    content = before_answer + "<span class='answer'>" + answer + "</span>" + after_answer

    return dh("""<div class="single_answer">{}</div>""".format(content))

def display_question_title(question):
    """
    Render `question` as an <h2> heading of class 'question_title'.

    Only the first character is upper-cased. `str.capitalize()` would also
    lower-case the rest of the string and mangle names such as '2019-nCoV'.
    Returns whatever `dh` returns.
    """
    title = question[:1].upper() + question[1:]
    return dh("<h2 class='question_title'>{}</h2>".format(title))

def answer_not_found(context, start_index, end_index):
    """
    Tell whether a span means "no answer was located in the context".

    That is the case when the span starts at 0 and either covers the whole
    context or is empty.
    """
    if start_index != 0:
        return False
    return len(context) == end_index or end_index == 0
def display_all_context(index, question):
    """
    Render one question and every context in which an answer was located.

    index:    0-based position of the question inside its task; shown 1-based.
    question: result dict with a 'question' string and a 'results' list whose
              items carry 'context', 'start_index' and 'end_index'.
    """
    # Upper-case only the first letter: `str.capitalize()` would lower-case
    # the rest of the question and turn e.g. '2019-nCoV' into '2019-ncov'.
    question_text = question['question']
    question_text = question_text[:1].upper() + question_text[1:]

    display_question_title(str(index + 1) + ". " + question_text)

    # display context
    for i in question['results']:
        if answer_not_found(i['context'], i['start_index'], i['end_index']):
            continue # skip contexts in which no answer was located
        display_single_context(i['context'], i['start_index'], i['end_index'])

def display_task_title(index, task):
    """
    Render "Task <index>: <task>" as an <h1> heading of class 'task_title'.

    Returns whatever `dh` returns.
    """
    heading = "<h1 class='task_title'>{}</h1>".format("Task " + str(index) + ": " + task)
    return dh(heading)

def display_single_task(index, task):
    """
    Render a whole task: its title first, then each of its questions.

    index: number shown in the task title.
    task:  dict with a 'task' title and a 'questions' list of result dicts.
    """
    display_task_title(index, task['task'])

    # Questions are numbered from 0 here; the question renderer shifts to 1-based.
    position = 0
    for question_result in task['questions']:
        display_all_context(position, question_result)
        position += 1

# 1-based task number; `all_tasks` is 0-based.
task = 1
display_single_task(index=task, task=all_tasks[task - 1])

In [14]:
# 1-based task number; `all_tasks` is 0-based.
task = 2
display_single_task(index=task, task=all_tasks[task - 1])

In [15]:
# 1-based task number; `all_tasks` is 0-based.
task = 3
display_single_task(index=task, task=all_tasks[task - 1])

In [16]:
# 1-based task number; `all_tasks` is 0-based.
task = 4
display_single_task(index=task, task=all_tasks[task - 1])

In [17]:
# 1-based task number; `all_tasks` is 0-based.
task = 5
display_single_task(index=task, task=all_tasks[task - 1])

In [18]:
# 1-based task number; `all_tasks` is 0-based.
task = 6
display_single_task(index=task, task=all_tasks[task - 1])

In [19]:
# 1-based task number; `all_tasks` is 0-based.
task = 7
display_single_task(index=task, task=all_tasks[task - 1])

In [20]:
# 1-based task number; `all_tasks` is 0-based.
task = 8
display_single_task(index=task, task=all_tasks[task - 1])

### Export Results

In [22]:
# Destination of the exported answers, next to the input data.
output_path = data_dir.joinpath("covid_kaggle_answer_from_biobert_qa.json")

In [24]:
import json
# Serialize every task, question, context and answer to `output_path`.
with open(output_path, mode="w") as json_file:
    json_file.write(json.dumps(all_answers))