In [None]:
import time
import sqlite3

import numpy as np
import pandas as pd
import tensorflow_hub as hub
import ipywidgets as widgets

from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

In [None]:
class Color:
    '''ANSI escape sequences used to style printed text.

    Usage: ``Color.<STYLE> + text + Color.END``. ``END`` resets all styling
    so that subsequent output is printed normally.
    '''
    # Reset all attributes.
    END = '\033[0m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Foreground colors.
    DARKCYAN = '\033[36m'
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    PURPLE = '\033[95m'
    CYAN = '\033[96m'

In [None]:
%%time

# Load the Universal Sentence Encoder (large) from TensorFlow Hub.
# `use_version` is the trailing path component of the model URL.
# NOTE: the search widget refers to the loaded models by the names `use`
# and `sbert`, so these variable names must stay unchanged.
use_version = 5
use = hub.load(f"https://tfhub.dev/google/universal-sentence-encoder-large/{use_version}")

# Load the Sentence-BERT model by name via sentence-transformers.
sbert = SentenceTransformer('bert-base-nli-mean-tokens')

def embed_sentences(sentences,
                    embedding_name,
                    embedding_model):
    '''Turn a list of sentences into a matrix of sentence embeddings.

    Parameters
    ----------
    sentences : List[str]
        The N sentences to embed.
    embedding_name : str
        Which kind of model `embedding_model` is. One of ('USE', 'SBERT').
    embedding_model : tf.Model or torch.Module
        For 'USE': a callable whose result exposes `.numpy()`.
        For 'SBERT': an object exposing `.encode(sentences)`.

    Returns
    -------
    encoded_sentences : np.ndarray
        Numpy array of shape (N, n_dims).

    Raises
    ------
    NotImplementedError
        If `embedding_name` is neither 'USE' nor 'SBERT'.
    '''
    # Universal Sentence Encoder: call the model directly, convert to numpy.
    if embedding_name == 'USE':
        encoded = embedding_model(sentences)
        return encoded.numpy()

    # Sentence-BERT: stack the per-sentence vectors into one 2-D array.
    if embedding_name == 'SBERT':
        per_sentence = embedding_model.encode(sentences)
        return np.stack(per_sentence, axis=0)

    raise NotImplementedError(f'Embedding {repr(embedding_name)} not '
                              f'available!')

In [None]:
# Names of the embedding types accepted by `embed_sentences`.
EMBEDDINGS_NAMES = ['USE', 'SBERT']

In [None]:
# Pre-computed sentence embeddings, indexed by embedding name.
# Each array is used as 2-D: column 0 is the section id, the remaining
# columns are the embedding vector of that section.
embeddings = np.load('sentence_embeddings/sentence_embeddings.npz')

In [None]:
# SQLite database with the `sections` (Id, Article, Text) and
# `articles` (Id, Title) tables queried by the search widget.
db = sqlite3.connect('cord19q/articles.sqlite')

In [None]:
# Query pre-filled in the search text area.
# The two literals are joined by implicit concatenation, so the first one
# must end with a space to keep "Lectins)" and "is" apart.
DEFAULT_QUERY = ("Inhibition of N-glycosylation (using N-glycosylation inhibitors or Lectins) "
                 "is a potential therapeutic approach for COVID-19 therapy.")

In [None]:
def investigate():
    '''Build and display an interactive semantic-search widget.

    The widget offers a model selector, a "Top N" slider, a free-text query
    box and a button. Clicking the button embeds the query, ranks the
    pre-computed section embeddings by cosine similarity to it, and prints
    the best matches together with the title of their article.

    Relies on the notebook-level names ``embeddings``, ``db``,
    ``embed_sentences``, ``DEFAULT_QUERY`` and ``Color``, and on model
    objects stored under the lower-cased model label (``use``, ``sbert``).
    '''

    def on_clicked(b):
        '''Run one search and print its results into the output widget.'''
        wout.clear_output()
        with wout:
            print()
            # All timings printed below are cumulative since the click.
            t0 = time.time()

            model_name = wselect_model.value
            # Resolve the model object by name without eval(). Options with
            # no loaded model or no stored embeddings (e.g. 'BSV') are
            # reported instead of raising NameError / KeyError.
            embedding_model = globals().get(model_name.lower())
            if embedding_model is None or model_name not in embeddings:
                print(f'Model {model_name!r} is not available yet.')
                return

            print('Embedding sentence...', end=' ')
            embedding_query = embed_sentences([wtext.value], model_name, embedding_model)
            print(f'{time.time()-t0:.2f} s.')

            print('Computing similarities...', end=' ')
            arr = embeddings[model_name]
            # Column 0 holds the section id, the remaining columns the vector.
            uids, embedding_docs = arr[:, 0], arr[:, 1:]
            # There is exactly one query, so take row 0 explicitly;
            # squeeze() would yield a 0-d array for a single document.
            similarities = cosine_similarity(X=embedding_query, Y=embedding_docs)[0]
            print(f'{time.time()-t0:.2f} s.')

            print('Ranking documents...', end=' ')
            # Negate so that argsort returns the most similar sections first.
            indices = np.argsort(-similarities)[:wselect_count.value]
            print(f'{time.time()-t0:.2f} s.')

            print()
            for uid_, sim_ in zip(uids[indices], similarities[indices]):
                # The id comes out of the array as a NumPy scalar; cast it to
                # a plain int so sqlite3 can always bind it.
                section_id = int(uid_)
                section_row = db.execute('SELECT Article, Text FROM sections WHERE Id = ?',
                                         [section_id]).fetchone()
                if section_row is None:
                    # Embedding without a matching row: report and carry on.
                    print(f'Section id: {section_id:>7,d} --- not found in database')
                    print()
                    continue
                article_sha, text = section_row
                print(f'Section id: {section_id:>7,d} --- Similarity: {sim_:.2f}')
                print(Color.BLUE + text + Color.END)
                title_row = db.execute('SELECT Title FROM articles WHERE Id = ?',
                                       [article_sha]).fetchone()
                # Missing article row or NULL title must not break the listing.
                if title_row is not None and title_row[0]:
                    article_title = title_row[0]
                else:
                    article_title = '(title unavailable)'
                print(Color.GREEN + 'From: ' + article_title + Color.END)
                print()

    wselect_model = widgets.ToggleButtons(
        options=['USE', 'SBERT', 'BSV'],
        description='Model:',
        tooltips=['Universal Sentence Encoder', 'Sentence BERT', 'Coming Soon'],
    )

    wselect_count = widgets.IntSlider(value=10, min=0, max=100, description='Top N:',)

    wtext = widgets.Textarea(value=DEFAULT_QUERY, layout=widgets.Layout(width='90%', height='80px'))

    button = widgets.Button(description='Investigate!')
    button.on_click(on_clicked)

    # Search output (progress lines and results) is rendered in this box.
    wout = widgets.Output(layout={'border': '1px solid black'})

    display(widgets.VBox([wselect_model, wselect_count, wtext, button, wout]))

In [None]:
# Glucose consumption could also be a risk-factor for COVID-19 severity.

In [None]:
# Build and display the interactive search widget.
investigate()

In [None]:
# embeddings.close()