### TODO

03.04.2020
- Use 'bert-large-nli-mean-tokens'.

06.04.2020
- Add down-ranking of some keywords (e.g. 'diabetes').
- Explore how synonyms impact sentence embeddings space search.

---

### Context

**Dataset**

Human-curated WHO papers, plus the results of the query below on PMC / bioRxiv / medRxiv.

**Query**

- "COVID-19"
- OR Coronavirus
- OR "Corona virus"
- OR "2019-nCoV"
- OR "SARS-CoV"
- OR "MERS-CoV"
- OR "Severe Acute Respiratory Syndrome"
- OR "Middle East Respiratory Syndrome"

---

## Imports

In [None]:
import textwrap
import hashlib
import time
import sqlite3
from pathlib import Path
import json
import logging
from functools import partial
import datetime

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import tensorflow_hub as hub
import ipywidgets as widgets
from IPython.core.display import HTML

import sent2vec
import nltk
from nltk import word_tokenize
from nltk.corpus import stopwords
from string import punctuation
from sentence_transformers import SentenceTransformer

from sklearn.metrics.pairwise import cosine_similarity

import pdfkit

In [None]:
nltk.download('punkt')

## Definitions

In [None]:
# main_dir = Path("/raid/covid19_kaggle-data")

# data_path = main_dir / "v6"
# sql_db_path = main_dir / "cord19q" / "articles.sqlite"
# pafe_path = main_dir / "pafe"

In [None]:
# Locations of the dataset, the derived databases/embeddings and the assets.
data_path = Path("/raid/sschmidt/covid/data/2020-04-08")
cord_path = data_path / "CORD-19-research-challenge"
databases_path = data_path / "databases"
embeddings_path = data_path / "embeddings"
assets_path = Path("/raid/sschmidt/covid/assets")

# Fail early and name the missing directory.  An explicit exception is used
# because `assert` statements are skipped when Python runs with -O.
for required_dir in (data_path, cord_path, databases_path, embeddings_path, assets_path):
    if not required_dir.exists():
        raise FileNotFoundError(f"Required directory does not exist: {required_dir}")

In [None]:
class Color:
    """ANSI escape sequences for colouring and styling terminal output.

    Prefix a string with one or more of these codes and append ``END`` to
    reset the formatting afterwards.
    """

    # Foreground colours
    PURPLE = '\x1b[95m'
    CYAN = '\x1b[96m'
    DARKCYAN = '\x1b[36m'
    BLUE = '\x1b[94m'
    GREEN = '\x1b[92m'
    YELLOW = '\x1b[93m'
    RED = '\x1b[91m'

    # Text attributes
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'

    # Reset all colours and attributes
    END = '\x1b[0m'

In [None]:
for var in dir(Color):
    if not var.startswith('__') and var != 'END':
        c = getattr(Color, var)
        print(c + f"This is {var}" + Color.END)

In [None]:
print(Color.BOLD + Color.PURPLE + "This is a test" + Color.END)

## Build SQL Database

In [None]:
# !pip install --user git+https://github.com/neuml/cord19q

In [None]:
# Install scispacy model
# !pip install --user https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.2.4/en_core_sci_sm-0.2.4.tar.gz

In [None]:
# import spacy
# spacy.load('en_core_sci_sm')

In [None]:
# from cord19q.etl.execute import Execute as Etl

# Build SQLite database for metadata.csv and json full text files
# Etl.run(str(cord_path), str(databases_path))

## Load Data: SQL, JSON, Metadata

In [None]:
db = sqlite3.connect(str(databases_path / "articles.sqlite"))

In [None]:
df_metadata_original = pd.read_csv(cord_path / "metadata.csv")
df_metadata_original.head(2)

Remove rows with no title and no SHA

In [None]:
# Drop the rows that have neither a title nor a SHA.
mask_useless = df_metadata_original['title'].isna() & df_metadata_original['sha'].isna()
# Take an explicit copy: the 'sha' column of this frame is assigned to in
# place further down, which must not operate on a slice of the original frame.
df_metadata = df_metadata_original[~mask_useless].copy()

Generate fake SHAs

In [None]:
mask = df_metadata['sha'].isna()
df_metadata.loc[mask, 'sha'] = df_metadata.loc[mask, 'title'].apply(
    lambda text: hashlib.sha1(str(text).encode("utf-8")).hexdigest())
df_metadata.head(2)

Load JSON Files

In [None]:
# Collect the paths once instead of walking the directory tree twice.
json_paths = list(data_path.rglob("*.json"))
n_json = len(json_paths)
json_files = []

for f in tqdm(json_paths, total=n_json):
    # Close every file right after parsing and decode it as UTF-8, the
    # standard JSON encoding, rather than with the locale default.
    with open(f, encoding="utf-8") as json_fh:
        json_files.append(json.load(json_fh))

Fill in missing titles from the metadata

In [None]:
# Build the SHA -> title lookup once (first row with a title wins) instead of
# scanning the whole 'sha' column for every JSON file.  Rows without a title
# are left out so that a missing value is never written into a JSON file.
title_by_sha = (
    df_metadata.dropna(subset=['title'])
    .drop_duplicates(subset='sha')
    .set_index('sha')['title']
    .to_dict()
)

for json_file in tqdm(json_files):
    if json_file['metadata']['title'] == '':
        new_title = title_by_sha.get(json_file['paper_id'])
        if new_title is not None:
            json_file['metadata']['title'] = new_title

Create a dictionary with JSON files based on their SHAs

In [None]:
json_files_d = {
    json_file['paper_id']: json_file
    for json_file in json_files
}

## Load Models

In [None]:
%%time

# Load USE
use_version = 5
use = hub.load(f"https://tfhub.dev/google/universal-sentence-encoder-large/{use_version}")

In [None]:
%%time

# Load SBERT
sbert = SentenceTransformer('bert-base-nli-mean-tokens')

In [None]:
nltk.download('stopwords')

Source: https://github.com/ncbi-nlp/BioSentVec

In [None]:
%%time

# Load BioSentVec
bsv = sent2vec.Sent2vecModel()
bsv.load_model(str(assets_path / 'BioSentVec_PubMed_MIMICIII-bigram_d700.bin'))

bsv_stopwords = set(stopwords.words('english'))

def bsv_preprocess(text):
    """Normalise a sentence before embedding it with BioSentVec.

    Slashes, ``.-``, periods and apostrophes are padded with spaces, the
    text is lower-cased and tokenised with NLTK, and every token that occurs
    in ``string.punctuation`` or in ``bsv_stopwords`` is dropped.

    Parameters
    ----------
    text : str
        The raw sentence.

    Returns
    -------
    str
        The remaining tokens joined by single spaces.
    """
    # Order matters: '.-' has to be padded before the bare '.' is.
    for symbol in ('/', '.-', '.', '\''):
        text = text.replace(symbol, f' {symbol} ')

    kept_tokens = []
    for token in word_tokenize(text.lower()):
        if token in punctuation or token in bsv_stopwords:
            continue
        kept_tokens.append(token)
    return ' '.join(kept_tokens)

## Preprocessing of Sentences

In [None]:
# Every non-empty line of the synonyms file has the form
# "reference = synonym = synonym ..."; everything is lower-cased.
synonyms_dict = {}
with open(assets_path / 'synonyms_list.txt', 'r', encoding='utf-8-sig') as f:
    for raw_line in f:
        line = raw_line.strip().lower()
        if not line:
            continue
        reference, *synonyms = (term.strip() for term in line.split('='))
        synonyms_dict[reference] = synonyms

# The 'sars' entry is deliberately not used for synonym merging.
del synonyms_dict['sars']

# Reverse lookup: synonym -> reference term.
synonyms_index = {}
for reference, synonyms in synonyms_dict.items():
    for synonym in synonyms:
        synonyms_index[synonym.lower()] = reference.lower()

def sent_preprocessing(sentences, synonyms_index):
    """Lower-case and tokenise sentences and replace synonyms token by token.

    Parameters
    ----------
    sentences : List[str]
        List of N strings.
    synonyms_index : dict
        Maps a synonym to its reference term.  Tokens that are not a key of
        this dictionary are kept unchanged.

    Returns
    -------
    List[str]
        The N preprocessed sentences, each with its tokens joined by single
        spaces.
    """
    processed = []
    for sentence in sentences:
        tokens = word_tokenize(sentence.lower())
        processed.append(" ".join(synonyms_index.get(token, token) for token in tokens))
    return processed

In [None]:
def embed_sentences(sentences, embedding_name, embedding_model):
    """Embed sentences with the requested model.

    Parameters
    ----------
    sentences : List[str]
        The sentences to embed.
    embedding_name : str
        One of 'USE', 'SBERT' or 'BSV'; selects how `embedding_model` is called.
    embedding_model
        The loaded model object matching `embedding_name`.

    Returns
    -------
    The embeddings as produced by the selected model: the ``.numpy()`` value
    of the model output for 'USE', the stacked ``encode`` outputs for
    'SBERT', and the result of ``embed_sentences`` on the BioSentVec
    preprocessed sentences for 'BSV'.

    Raises
    ------
    NotImplementedError
        If `embedding_name` is not one of the supported names.
    """
    if embedding_name == 'USE':
        encoded = embedding_model(sentences)
        return encoded.numpy()

    if embedding_name == 'SBERT':
        vectors = embedding_model.encode(sentences)
        return np.stack(vectors, axis=0)

    if embedding_name == 'BSV':
        # BioSentVec expects its own tokenisation / stop-word removal
        cleaned = list(map(bsv_preprocess, sentences))
        return embedding_model.embed_sentences(cleaned)

    raise NotImplementedError(f'Embedding {embedding_name!r} not available!')

In [None]:
EMBEDDINGS_NAMES = ['USE', 'SBERT', 'BSV']

In [None]:
embeddings = np.load(embeddings_path / 'sentence_embeddings.npz')

In [None]:
embeddings_syns = np.load(embeddings_path / 'sentence_embeddings_merged_synonyms.npz')

## Actual Widget

In [None]:
logger = logging.getLogger("My logger")
logger.setLevel(logging.WARNING)

In [None]:
def find_paragraph(uid, sentence, db):
    """Find the paragraph containing the given sentence.

    The article and the section name of the sentence are looked up in the
    ``sections`` table.  A sentence from a ``TITLE`` or ``ABSTRACT`` section
    whose article is listed in ``df_metadata`` is searched for in the
    metadata title and abstract.  Otherwise, if the article has an entry in
    ``json_files_d``, the sentence is searched for in the title, the abstract
    and the body text of that JSON file.

    Parameters
    ----------
    uid : int
        The identifier of the given sentence (``Id`` in ``sections``).
    sentence : str
        The sentence to look for.
    db : sqlite3.Connection
        The database connection.

    Returns
    -------
    paragraph : str
        The title, abstract or text chunk containing `sentence`.

    Raises
    ------
    ValueError
        If there is no section with this `uid`, if the article is available
        neither in the metadata (for title/abstract sentences) nor among the
        JSON files, or if `sentence` is not contained in any of the
        candidate texts.
    """

    # Bind the id as a parameter, like the other queries in this notebook,
    # instead of formatting it into the SQL string.  `int` also turns numpy
    # scalars into a type that sqlite3 can bind.
    row = db.execute('SELECT Article, Name FROM sections WHERE Id = ?', [int(uid)]).fetchone()
    if row is None:
        raise ValueError(f"No section found for uid {uid}")
    sha, where_from = row

    logger.debug(f"uid = {uid}")
    logger.debug(f"sha = {sha}")
    logger.debug(f"where_from = {where_from}")
    logger.debug(f"sentence = {sentence}")

    if where_from in ('TITLE', 'ABSTRACT'):
        matching_rows = df_metadata[df_metadata['sha'] == sha]
        if not matching_rows.empty:
            df_row = matching_rows.iloc[0]
            for column in ('title', 'abstract'):
                candidate = df_row[column]
                # Missing values are NaN (a float), not a string: skip them.
                if isinstance(candidate, str) and sentence in candidate:
                    return candidate
            raise ValueError("Sentence not found in title nor in abstract")

    if sha in json_files_d:
        json_file = json_files_d[sha]
        title = json_file['metadata']['title']
        if isinstance(title, str) and sentence in title:
            return title

        # Tolerate JSON files that have no 'abstract' entry.
        text_chunks = json_file.get('abstract', []) + json_file['body_text']
        for text_chunk in text_chunks:
            paragraph = text_chunk['text']
            if sentence in paragraph:
                return paragraph
        raise ValueError("sentence not found in body_text and abstract")

    raise ValueError("SHA not found")

In [None]:
def highlight_in_paragraph(paragraph, sentence, width=80, indent=0, color=Color.BOLD + Color.PURPLE):
    """Highlight the first occurrence of a sentence in a paragraph.

    The sentence is wrapped in HTML ``<font color="purple"> <b>`` tags, then
    the paragraph is re-wrapped to `width` columns and every line is
    indented.

    Parameters
    ----------
    paragraph : str
        The paragraph containing `sentence`.
    sentence : str
        The sentence to highlight.
    width : int
        The width to which the returned paragraph is wrapped.
    indent : int
        The number of spaces put in front of every line of the returned
        paragraph.
    color : str
        Currently unused: the highlight is always rendered with the fixed
        purple/bold HTML markup.

    Returns
    -------
    formatted_paragraph : str
        The wrapped paragraph with `sentence` highlighted.

    Raises
    ------
    ValueError
        If `sentence` does not occur in `paragraph`.
    """

    begin = paragraph.index(sentence)
    stop = begin + len(sentence)

    marked_up = (
        f'{paragraph[:begin]}'
        f'<font color="purple"> <b>{paragraph[begin:stop]}</b> </font>'
        f'{paragraph[stop:]}'
    )

    padding = ' ' * indent
    return '\n'.join(padding + line for line in textwrap.wrap(marked_up, width=width))

In [None]:
uid = 81135
sentence = "This agent binds towards the pocket entrance, but fails to occupy the end of the pocket (Chapman et al., 1991) ."

paragraph = find_paragraph(uid, sentence, db)
print(paragraph)
# print(highlight_in_paragraph(paragraph, sentence, width=80, indent=10))

In [None]:
html_report = ""

def investigate():
    """Build and display the interactive sentence-search widget.

    Creates the controls (model, number of results, synonym merging, whole
    paragraph display, query, deprioritised text and its strength, substring
    exclusions), wires the two buttons to their callbacks and displays
    everything in a ``VBox``.  Search results are rendered into an output
    widget and accumulated in the module-level ``html_report`` so that they
    can be exported to a PDF file.
    """

    # Model objects are looked up lazily by name (replacing `eval`), so that
    # only the model that is actually selected has to be loaded.
    model_getters = {
        'USE': lambda: use,
        'SBERT': lambda: sbert,
        'BSV': lambda: bsv,
    }

    # (query weight, exclusion weight) for every deprioritization strength.
    deprioritizations = {
        'None': (1, 0),
        'Weak': (0.9, 0.1),
        'Mild': (0.8, 0.3),
        'Strong': (0.5, 0.5),
        'Stronger': (0.5, 0.7),
    }

    def pdf_button_on_click(b):
        """Write the search parameters and the latest results to a PDF file."""

        print("Saving the results to a pdf file.")

        formatted_html_report =  "<h1> Parameters </h1>"
        formatted_html_report += f"""<ul>
                                        <li> Model: {wselect_model.value} </li>
                                        <li> Merge synonyms enabled: {wcheck.value} </li>
                                        <li> Query: {wtext_query.value} </li>
                                        <li> Deprioritised text: {wtext_exclusion.value} </li>
                                        <li> Deprioritised strength: {deprioritization_toggles.value} </li>
                                        <li> Excluded text: {wtext_str_exclusion.value} </li>
                                     </ul>
                                 """
        formatted_html_report += f"<h1> Results </h1> {html_report}"
        pdfkit.from_string(formatted_html_report, f"report_{datetime.datetime.now()}.pdf")

    def investigate_on_click(b):
        """Embed the query, rank all sentences and display the top results."""

        global html_report
        html_report = ""
        wout.clear_output()
        with wout:
            print()
            t0 = time.time()

            model_name = wselect_model.value
            merge_synonyms = wcheck.value

            if merge_synonyms:
                query_value = sent_preprocessing([wtext_query.value], synonyms_index)
                exclu_value = sent_preprocessing([wtext_exclusion.value], synonyms_index)
            else:
                query_value = [wtext_query.value]
                exclu_value = [wtext_exclusion.value]
            has_exclusion = bool(exclu_value[0])

            embedding_model = model_getters[model_name]()

            print('Embedding query...    ', end=' ')
            embedding_query = embed_sentences(query_value, model_name, embedding_model)
            print(f'{time.time()-t0:.2f} s.')

            if has_exclusion:
                print('Embedding exclusion...    ', end=' ')
                embedding_exclu = embed_sentences(exclu_value, model_name, embedding_model)
                print(f'{time.time()-t0:.2f} s.')

            print('Computing similarities...', end=' ')
            # For scalability, we will replace this part with FAISS, as in the other part of the code base.
            arr = (embeddings_syns if merge_synonyms else embeddings)[model_name]
            # The first column holds the sentence ids, the remaining ones the embedding.
            uids, embedding_docs = arr[:, 0], arr[:, 1:]

            # There is exactly one query, so take row 0.  Unlike `.squeeze()`
            # this stays one-dimensional even for a single-sentence corpus.
            similarities_query = cosine_similarity(X=embedding_query, Y=embedding_docs)[0]
            if has_exclusion:
                similarities_exclu = cosine_similarity(X=embedding_exclu, Y=embedding_docs)[0]
            else:
                similarities_exclu = np.zeros_like(similarities_query)

            # now: maximize L = a1 * cos(x, query) - a2 * cos(x, exclusions)
            alpha_1, alpha_2 = deprioritizations[deprioritization_toggles.value]
            similarities = alpha_1 * similarities_query - alpha_2 * similarities_exclu

            print(f'{time.time()-t0:.2f} s.')

            print('Ranking documents...     ', end=' ')

            # SUBSTRING EXCLUSIONS (one per line, empty lines are ignored)
            excluded_words = [x for x in wtext_str_exclusion.value.lower().split('\n') if x]

            indices = np.argsort(-similarities)
            indices_without_excluded = []
            n_wanted = wselect_count.value

            ix = 0
            # Also stop once every sentence has been inspected; otherwise the
            # loop would index past the end when too many sentences are excluded.
            while len(indices_without_excluded) < n_wanted and ix < len(indices):
                sentence_text = db.execute('SELECT Text FROM sections WHERE Id = ?',
                                           [uids[indices[ix]]]).fetchall()[0][0].lower()
                is_contained = any(w in sentence_text for w in excluded_words)

                if not is_contained:
                    indices_without_excluded.append(indices[ix])

                ix += 1

            print(f'{time.time()-t0:.2f} s. '
                  f'Excluded {ix - len(indices_without_excluded)} items based on substrings.')

            print(Color.RED + f'\nInvestigating: {query_value[0]}\n' + Color.END)

            width = 80
            for i, uid_ in enumerate(uids[indices_without_excluded]):
                article_sha, section_name, text = db.execute(
                    'SELECT Article, Name, Text FROM sections WHERE Id = ?', [uid_]).fetchall()[0]
                article_auth, article_title, _published, ref = db.execute(
                    'SELECT Authors, Title, Published, Reference FROM articles WHERE Id = ?',
                    [article_sha]).fetchall()[0]

                # Missing (NULL) columns must not abort the display.
                if article_auth:
                    article_auth = article_auth.split(';')[0] + ' et al.'
                else:
                    article_auth = 'Unknown'
                ref = ref if ref else ''
                section_name = section_name if section_name else ''

                if w_check_whole_paragraph.value:
                    logger.debug(f"UID={uid_}")
                    try:
                        paragraph = find_paragraph(uid_, text, db)
                        formatted_output = highlight_in_paragraph(paragraph, text, width=width, indent=2)
                    except Exception:
                        # Best effort: fall back to the bare sentence.  The
                        # angle brackets are HTML-escaped, otherwise the
                        # notice is parsed as a tag and never shown.
                        logger.debug("Could not retrieve the paragraph", exc_info=True)
                        formatted_output = ("&lt;there was a problem retrieving the paragraph, "
                                            "the original sentence is:&gt;\n")
                        formatted_output += text
                else:
                    formatted_output = textwrap.fill(text, width=width)

                formatted_output = f'<em> {formatted_output} </em>'

                article_metadata = f"""<a href="{ref}">&nbsp;[{i+1:2d}] <br> Source: {article_title} </a>
                                   <br> Author: {article_auth}
                                   <br> Section: {section_name.lower().title()}"""

                display(HTML(article_metadata))
                display(HTML(formatted_output))
                print()

                html_report += article_metadata + f" <br> <p> {formatted_output} </p> <br>"

    wselect_model = widgets.ToggleButtons(
        options=list(model_getters),
        description='Model:',
        tooltips=['Universal Sentence Encoder', 'Sentence BERT', 'BioSentVec'],
    )

    wselect_count = widgets.IntSlider(value=10, min=0, max=100, description='Top N:',)

    wcheck = widgets.Checkbox(value=False, description='merge synonyms')
    w_check_whole_paragraph = widgets.Checkbox(value=False, description='show whole paragraph')

    wtext_query = widgets.Textarea(layout=widgets.Layout(width='90%', height='80px'),
                                   value='Glucose is a risk factor for COVID-19.',
                                   description='Query: ')
    wtext_exclusion = widgets.Textarea(layout=widgets.Layout(width='90%', height='80px'),
                                       value='',
                                       description='Deprioritize: ')
    deprioritization_toggles = widgets.ToggleButtons(
        options=list(deprioritizations),
        description='Deprioritization strength',
        disabled=False,
        button_style='info', # 'success', 'info', 'warning', 'danger' or ''
        style={'description_width': 'initial', 'button_width': '80px'},
    )

    wtext_str_exclusion = widgets.Textarea(layout=widgets.Layout(width='90%', height='80px'),
                                       value='',
                                       description='Substring Exclusion (newline separated): ',
                                       style={'description_width': 'initial'})
    investigate_button = widgets.Button(description='Investigate!')
    investigate_button.on_click(investigate_on_click)
    pdf_download_button = widgets.Button(description='Generate PDF Report', layout=widgets.Layout(width='25%'))

    pdf_download_button.on_click(pdf_button_on_click)

    wout = widgets.Output(layout={'border': '1px solid black'})

    display(widgets.VBox([wselect_model,
                          wselect_count,
                          wcheck,
                          w_check_whole_paragraph,
                          wtext_query,
                          wtext_exclusion,
                          deprioritization_toggles,
                          wtext_str_exclusion,
                          investigate_button,
                          pdf_download_button,
                          wout]))
investigate()

---

#### Example Queries

1. Inhibition of N-glycosylation (using N-glycosylation inhibitors or Lectins) is a potential therapeutic approach for COVID-19 therapy.
1. Is high blood / plasma sugar level or hyperglycemia associated with higher susceptibility to coronavirus infection or higher virus replication?
1. Glucose or sugar is a risk factor for COVID-19.
1. Ketogenic diet is protective against COVID-19.

## Sandbox

In [None]:
synonyms_dict['sugar']

In [None]:
synonyms_dict['risk factor']

In [None]:
HTML('And everyone knows <font style="background-color: #992200"> coronavirus</font> is dangerous.')

In [None]:
# db.close()

In [None]:
# embeddings.close()

In [None]:
# embeddings_syns.close()