In [3]:
from pathlib import Path
from collections import defaultdict
import sqlite3
from time import time

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub

import torch

import sent2vec
from nltk import word_tokenize
from nltk.corpus import stopwords
from string import punctuation
from sentence_transformers import SentenceTransformer

In [4]:
# Root of the dated data snapshot; every data-derived path below hangs off it.
data_path = Path("../data/2020-04-08")

cord_path = data_path / "CORD-19-research-challenge"
databases_path = data_path / "databases"
embeddings_path = data_path / "embeddings"
# Assets live outside the dated snapshot directory.
assets_path = Path("../assets")

# Fail fast, and say which directory is missing: a bare assert only shows the
# expression, not the path it evaluated to.
assert data_path.exists(), f"Missing data directory: {data_path}"
assert cord_path.exists(), f"Missing CORD-19 directory: {cord_path}"
assert databases_path.exists(), f"Missing databases directory: {databases_path}"
assert embeddings_path.exists(), f"Missing embeddings directory: {embeddings_path}"
assert assets_path.exists(), f"Missing assets directory: {assets_path}"

In [None]:
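# One-time setup: uncomment and run once to download the NLTK resources this
# notebook relies on ('stopwords' for stopwords.words, 'punkt' for word_tokenize).
# import nltk

# nltk.download('stopwords')
# nltk.download('punkt')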

In [5]:
# Load the Universal Sentence Encoder (large variant) from TensorFlow Hub.
# The model version is kept in its own variable so it can be changed in one place;
# it is interpolated into the TF-Hub URL below.
univsentenc_version = 5
univsentenc_embedder = hub.load(f"https://tfhub.dev/google/universal-sentence-encoder-large/{univsentenc_version}")

INFO:absl:Using /tmp/tfhub_modules to cache modules.


In [6]:
# Load the Sentence-BERT model by name via sentence_transformers.
# It is used later through its .encode() method (see embed_sentences).
sbert_embedder = SentenceTransformer('bert-base-nli-mean-tokens')

In [7]:
%%time

# Load the BioSentVec sent2vec model from the binary file in the assets directory.
# load_model() takes a string path, hence the explicit str() around the Path.
bsv_embedder = sent2vec.Sent2vecModel()
bsv_embedder.load_model(str(assets_path / 'BioSentVec_PubMed_MIMICIII-bigram_d700.bin'))

# English stop words as a set, for fast membership tests in bsv_preprocess.
bsv_stopwords = set(stopwords.words('english'))

def bsv_preprocess(text):
    """Normalise a sentence before it is embedded with the BSV model.

    The text is padded around a few separator characters, lower-cased and
    tokenised; punctuation tokens and English stop words are dropped.

    Parameters
    ----------
    text : str
        Raw sentence.

    Returns
    -------
    str
        The remaining tokens joined by single spaces.
    """
    # Surround separators with spaces so they become separate tokens.
    # Order matters: '.-' has to be handled before the lone '.'.
    padding_rules = (
        ('/', ' / '),
        ('.-', ' .- '),
        ('.', ' . '),
        ('\'', ' \' '),
    )
    for separator, padded in padding_rules:
        text = text.replace(separator, padded)

    kept_tokens = []
    for token in word_tokenize(text.lower()):
        # `punctuation` is a string, so this is a substring membership test.
        if token in punctuation or token in bsv_stopwords:
            continue
        kept_tokens.append(token)
    return ' '.join(kept_tokens)

CPU times: user 2.33 s, sys: 12.1 s, total: 14.4 s
Wall time: 14.3 s


In [8]:
# Parse the synonyms file: each non-empty line is '='-separated, the first term
# being the reference and the remaining terms its synonyms. Everything is
# lower-cased and stripped of surrounding whitespace.
synonyms_dict = {}
with open(assets_path / 'synonyms_list.txt', 'r', encoding='utf-8-sig') as synonyms_file:
    for raw_line in synonyms_file:
        line = raw_line.strip().lower()
        if not line:
            continue
        reference, *aliases = [term.strip() for term in line.split('=')]
        synonyms_dict[reference] = aliases

In [9]:
# Exclude the 'sars' entry from synonym merging.
# NOTE(review): presumably because it would collapse distinct terms - confirm.
# The membership guard keeps this cell re-runnable: a bare `del` raises KeyError
# the second time the cell is executed or if the file has no 'sars' line.
if 'sars' in synonyms_dict:
    del synonyms_dict['sars']

In [10]:
# Invert the mapping: every synonym points back to its reference term.
synonyms_index = {
    synonym.lower(): reference.lower()
    for reference, synonyms in synonyms_dict.items()
    for synonym in synonyms
}

In [11]:
def sent_preprocessing(sentences, synonyms_index):
    """Lower-case, tokenise and replace synonyms in each sentence.

    Parameters
    ----------
    sentences : List[str]
        List of N strings.
    synonyms_index : dict
        Maps a synonym to the reference term that should replace it.

    Returns
    -------
    List[str]
        N strings, each made of the (possibly replaced) tokens joined by a
        single space.

    Notes
    -----
    The lookup is done token by token, so a key of `synonyms_index` is only
    replaced when it appears as one single token of the tokenised sentence.
    """
    processed = []
    for sentence in sentences:
        tokens = word_tokenize(sentence.lower())
        # Tokens without an entry in the index are kept unchanged.
        canonical_tokens = [synonyms_index.get(token, token) for token in tokens]
        processed.append(" ".join(canonical_tokens))
    return processed

In [12]:
def embed_sentences(sentences,
                    embedding_name,
                    embedding_model):
    '''Compute sentence embeddings with the requested model.

    Parameters
    ----------
    sentences : List[str]
        List of N strings.
    embedding_name : str
        Which embedding to compute. One of ('USE', 'SBERT', 'BSV').
    embedding_model : tf.Model or torch.Module
        Model object matching `embedding_name`.

    Return
    ------
    encoded_sentences : np.ndarray
        Numpy array of shape (N, n_dims).

    Raises
    ------
    NotImplementedError
        If `embedding_name` is not one of the supported names.
    '''
    if embedding_name == 'USE':
        # The model is called directly; its output is converted with .numpy().
        use_output = embedding_model(sentences)
        return use_output.numpy()

    if embedding_name == 'SBERT':
        # encode() yields one vector per sentence; stack them into a matrix.
        sbert_vectors = embedding_model.encode(sentences)
        return np.stack(sbert_vectors, axis=0)

    if embedding_name == 'BSV':
        # BSV sentences go through bsv_preprocess before being embedded.
        cleaned = []
        for sentence in sentences:
            cleaned.append(bsv_preprocess(sentence))
        return embedding_model.embed_sentences(cleaned)

    raise NotImplementedError(f'Embedding {repr(embedding_name)} not '
                              f'available!')

In [13]:
# Open the articles database and keep one shared cursor for the queries below.
db_filename = str(databases_path / 'articles.sqlite')
# sqlite3.connect() creates a new, empty database when the file does not exist,
# which would only surface later as a "no such table" error. Check up front.
assert Path(db_filename).exists(), f"Database file not found: {db_filename}"
db = sqlite3.connect(db_filename)
curs = db.cursor()

In [14]:
def create_sentence_embeddings(preprocessing=False):
    """Embed the tagged sections with every model and save the result.

    Reads `Id` and `Text` from the `sections` table (rows where `Tags` is not
    NULL) through the notebook-level cursor `curs`, embeds the texts in
    batches with the USE, SBERT and BSV models, and writes one compressed
    `.npz` file into `embeddings_path`.

    The file holds one array per embedding name ('USE', 'SBERT', 'BSV'). Each
    array has one row per section; the first column is the section id and the
    remaining columns are the embedding.

    Parameters
    ----------
    preprocessing : bool, default False
        If True, the texts go through `sent_preprocessing` (synonym
        replacement) before being embedded and the file is named
        "sentence_embeddings_merged_synonyms.npz". Otherwise the file is
        named "sentence_embeddings.npz".

    Raises
    ------
    ValueError
        If the query returns no rows, so there is nothing to embed.
    """
    embedding_names = ['USE', 'SBERT', 'BSV']
    embedding_models = [univsentenc_embedder, sbert_embedder, bsv_embedder]

    batch_size = 1_000

    # Per-model list of (batch, n_dims) arrays, concatenated once at the end.
    batch_embeddings = defaultdict(list)
    all_ids = []

    curs.execute('SELECT Id, Text FROM sections WHERE Tags IS NOT NULL')
    n_processed = 0
    t0 = time()
    while True:
        batch = curs.fetchmany(batch_size)
        if not batch:
            break
        ids, sentences = zip(*batch)

        all_ids.extend(ids)

        if preprocessing:
            sentences = sent_preprocessing(sentences, synonyms_index)

        for embedding_name, embedding_model in zip(embedding_names,
                                                   embedding_models):
            batch_embeddings[embedding_name].append(
                embed_sentences(sentences,
                                embedding_name=embedding_name,
                                embedding_model=embedding_model))

        # Count the rows actually fetched: the last batch is usually partial,
        # so `batch_size * n_batches` would overstate the progress.
        n_processed += len(batch)
        print(f'Done processing {n_processed} in {time()-t0:.1f} s.')

    if not all_ids:
        # np.concatenate on an empty list fails with an unhelpful message.
        raise ValueError('No sections with non-NULL Tags found; '
                         'nothing to embed.')

    print('Concatenate...')

    # The id column is the same for every model, so build it only once.
    ids_column = np.array(all_ids).reshape((-1, 1))

    arr = {}
    for embedding_name in embedding_names:
        print(f'processing: {embedding_name}')
        embeddings = np.concatenate(batch_embeddings[embedding_name], axis=0)
        # Prepend the ids so each row carries the id of its section.
        arr[embedding_name] = np.concatenate((ids_column, embeddings), axis=1)

    print('Save...')

    if preprocessing:
        file_name = "sentence_embeddings_merged_synonyms.npz"
    else:
        file_name = "sentence_embeddings.npz"

    np.savez_compressed(file=str(embeddings_path / file_name), **arr)

In [15]:
%%time
# Embed the section texts as stored (no synonym replacement); with
# preprocessing=False the result is saved as sentence_embeddings.npz.
create_sentence_embeddings(preprocessing=False)

Done processing 1000 in 9.8 s.
Done processing 2000 in 16.0 s.
Done processing 3000 in 21.0 s.
Done processing 4000 in 26.5 s.
Done processing 5000 in 34.3 s.
Done processing 6000 in 41.4 s.
Done processing 7000 in 47.3 s.
Done processing 8000 in 52.8 s.
Done processing 9000 in 58.8 s.
Done processing 10000 in 64.4 s.
Done processing 11000 in 70.2 s.
Done processing 12000 in 75.4 s.
Done processing 13000 in 80.5 s.
Done processing 14000 in 85.5 s.
Done processing 15000 in 92.0 s.
Done processing 16000 in 98.9 s.
Done processing 17000 in 105.0 s.
Done processing 18000 in 110.8 s.
Done processing 19000 in 116.7 s.
Done processing 20000 in 122.3 s.
Done processing 21000 in 128.4 s.
Done processing 22000 in 133.7 s.
Done processing 23000 in 139.9 s.
Done processing 24000 in 145.5 s.
Done processing 25000 in 151.2 s.
Done processing 26000 in 157.7 s.
Done processing 27000 in 164.0 s.
Done processing 28000 in 169.3 s.
Done processing 29000 in 176.6 s.
Done processing 30000 in 182.2 s.
Done p

In [16]:
%%time
# Embed the section texts after synonym replacement; with preprocessing=True
# the result is saved as sentence_embeddings_merged_synonyms.npz.
create_sentence_embeddings(preprocessing=True)

Done processing 1000 in 6.5 s.
Done processing 2000 in 13.1 s.
Done processing 3000 in 18.5 s.
Done processing 4000 in 24.3 s.
Done processing 5000 in 32.3 s.
Done processing 6000 in 39.6 s.
Done processing 7000 in 45.8 s.
Done processing 8000 in 51.6 s.
Done processing 9000 in 57.8 s.
Done processing 10000 in 63.8 s.
Done processing 11000 in 70.0 s.
Done processing 12000 in 75.7 s.
Done processing 13000 in 81.2 s.
Done processing 14000 in 86.8 s.
Done processing 15000 in 93.8 s.
Done processing 16000 in 101.1 s.
Done processing 17000 in 107.6 s.
Done processing 18000 in 113.9 s.
Done processing 19000 in 120.3 s.
Done processing 20000 in 126.4 s.
Done processing 21000 in 133.1 s.
Done processing 22000 in 138.7 s.
Done processing 23000 in 145.4 s.
Done processing 24000 in 151.5 s.
Done processing 25000 in 157.6 s.
Done processing 26000 in 164.6 s.
Done processing 27000 in 171.4 s.
Done processing 28000 in 177.2 s.
Done processing 29000 in 185.3 s.
Done processing 30000 in 191.3 s.
Done 