In [None]:
# Dataset version. It names the MySQL database to query, the local Parquet cache
# file, and is passed to load_embeddings() to locate the embeddings file.
VERSION = 'cord19_v47'

In [None]:
# True: fetch the sentences from the database and cache them to a Parquet file.
# False: read the sentences back from that Parquet file.
FIRST_RUN = True

In [None]:
# True: continue with a random subsample of the sentences instead of the full set.
DEBUG_RUN = False

In [None]:
# Random state handed to pandas sampling so that the drawn sentences are reproducible.
SEED = 9173

# Imports

In [None]:
# ! pip install pyarrow  # Needed by Pandas for Parquet operations.

In [None]:
import re
from pathlib import Path
from datetime import datetime

In [None]:
import torch
import spacy
import numpy as np
import pandas as pd
from tqdm import tqdm
from bbsearch.utils import H5

# Load sentences

In [None]:
# Sentences are cached in a local Parquet file named after VERSION, so the
# database only has to be queried when FIRST_RUN is set.
filename = f'sentences_{VERSION}.parquet'

if FIRST_RUN:
    # Imported here because the SQL stack is only needed to build the cache.
    import sqlalchemy
    # NOTE(review): the connection URL embeds credentials; consider reading it
    # from the environment before sharing this notebook.
    engine = sqlalchemy.create_engine(f'mysql+pymysql://guest:guest@dgx1.bbp.epfl.ch:8853/{VERSION}')
    sentences = pd.read_sql("SELECT sentence_id, text FROM sentences", engine, index_col='sentence_id')
    sentences.to_parquet(filename, index=True)
else:
    sentences = pd.read_parquet(filename)

if DEBUG_RUN:
    # Seeded so that debug runs are reproducible; capped so that sampling does
    # not raise when fewer than 10,000 sentences are available.
    sentences = sentences.sample(min(10000, len(sentences)), random_state=SEED)

# len() counts rows; DataFrame.size would count rows x columns.
scount = len(sentences)
print(f'{scount:,} sentences')

# Explore sentences

In [None]:
%matplotlib inline

## Duplicates

In [None]:
# drop_duplicates() compares column values only (the sentence_id index is not
# part of the comparison) and keeps the first occurrence of each duplicate.
sentences = sentences.drop_duplicates()

# len() counts rows; DataFrame.size would count rows x columns.
dcount = len(sentences)
print(f'{dcount:,} sentences (- {scount-dcount:,} duplicates)')

## Length

In [None]:
# Distribution of sentence lengths in characters, with log-scaled counts.
ax = sentences.text.str.len().hist(figsize=(8, 6), bins=75, log=True)
# Label the figure so it can be read on its own; the trailing ';' hides the repr.
ax.set(xlabel='Sentence length (characters)', ylabel='Number of sentences (log scale)',
       title='Distribution of sentence lengths');

# Select sentences

Strategies (checked items are the ones implemented in this notebook):
- [x] keywords
- [ ] annotations
- [ ] k-means
- [ ] LDA

## Keywords

In [None]:
# All keywords in bold from BBS Ontology v0.3 on 17.09.2020.
# Entries are lowercase; some are multi-word phrases or hyphenated terms.

keywords = {'pathogens', 'cardiac injury', 'cardiovascular disease', 'sars',
            'acute respiratory distress syndrome', 'gas exchange', 'inflammation',
            'sars-cov-2 infection', 'viral entry', 'glucose metabolism', 'golgi', 'human',
            'dry cough', 'mammals', 'cardiovascular injury', 'glycation', 'endoplasmic reticulum',
            'carbohydrates', 'innate immunity', 'igt', 'polysaccharide', 'hypertension',
            'thrombotic events', 'neutrophils', 'dc cells', 'obesity', 'congested cough',
            'influenzavirus', 'viral replication', 'septic shock', 'macrophages', 'cvd', 'lactate',
            'myalgia', 'chest pain', 'oxygen', 'mucociliary clearance', 'high blood sugar level',
            'respiratory failure', 'fever', 'systemic disorder', 'flu', 'influenzae',
            'hyperglycemia', 'impaired glucose tolerance', 'iron',
            'severe acute respiratory syndrome', 'immunity', 'host defense',
            'respiratory viral infection', 'multi-organs failure', 'blood clot',
            'viral infection', 'hypoxia', 'glucose homeostasis', 'vasoconstriction', 'covid-19',
            'sars-cov-2', 'fatigue', 'multiple organ failure', 'productive cough',
            'adaptive immunity', 'atp', 'bacteria', 'nk cells', 'coagulation', 'ards', 'diarrhea',
            'cytokine storm', 'dendritic cells', 'pneumonia', 'thrombosis', 'phagocytosis',
            'alveolar macrophages', 'glucose', 'clearance', 'epithelial cells', 'glucose uptake',
            'coronavirus', 'plasma membrane', 'lymphocytes', 'oxidative stress', 'glycans',
            'glycolysis', 'pulmonary embolism', 'glycosylation', 'viruses',
            'viral respiratory tract infection', 'diabetes', 'life-cycle', 'mammalia',
            'antimicrobials activity', 'ketones', 'immune system', 'pathogen'}

In [None]:
def ok(text: str, vocabulary=None) -> bool:
    """Tell whether a sentence qualifies for selection.

    A sentence is kept when all of the following hold:
    - it is between 100 and 300 characters long (both bounds included),
    - it starts with a capitalized word followed by a space,
    - at least one of its whitespace-separated tokens, lowercased, is in the vocabulary.

    Args:
        text: The sentence to check. Non-string values (e.g. missing cells) are rejected.
        vocabulary: Collection of lowercase keywords to look for. Defaults to the
            notebook-level `keywords` set.

    Returns:
        True if the sentence passes every check, False otherwise.
    """
    if vocabulary is None:
        vocabulary = keywords
    if not isinstance(text, str):
        return False
    # The previous form `100 >= len(text) <= 300` only kept sentences of at most 100 characters.
    if not 100 <= len(text) <= 300:
        return False
    if re.match('^[A-Z][a-z]+ .*', text) is None:
        return False
    # TODO Improve matching: tokens with attached punctuation and multi-word
    # keywords cannot match with this whitespace tokenization.
    tokens = {token.lower() for token in text.split()}
    # TODO Keep only English.
    return not tokens.isdisjoint(vocabulary)

# Keep the sentences accepted by ok(); copy so that columns added to `filtered`
# later do not touch `sentences`.
filtered = sentences[sentences.text.map(ok)].copy()

# len() counts rows; DataFrame.size would count rows x columns.
fcount = len(filtered)
print(f'{fcount:,} sentences ({scount-fcount:,} not selected)')

In [None]:
# Row position of each selected sentence within `filtered`. The length is taken
# from the frame itself so this cell does not depend on a counter set elsewhere.
filtered['mapping'] = np.arange(len(filtered))

# Sample sentences

In [None]:
# Number of sentences to draw from the filtered set.
n = 20

# Seeded so that the same sentences are drawn on every run.
sampled = filtered.sample(n, random_state=SEED)

# Load embeddings

In [None]:
def load_embeddings(model: str, version: str) -> torch.Tensor:
    """Load the embeddings of one model and scale every row to unit L2 norm.

    Args:
        model: Name of the dataset to read from the HDF5 file.
        version: Dataset version, used to build the path of the HDF5 file.

    Returns:
        A tensor of row-normalized embeddings. Rows with zero norm are left as zeros.
    """
    # 'model' in ['Sent2Vec', 'BSV']
    h5_path = Path(f'/raid/sync/proj115/bbs_data/{version}/embeddings/embeddings.h5')
    # TODO Load only filtered indices.
    # NB H5.load(...) changes the ordering if given indices.
    # NOTE(review): the third argument is presumably a batch size — confirm against H5.load.
    vectors = torch.from_numpy(H5.load(h5_path, model, 10000))
    lengths = vectors.norm(dim=1, keepdim=True)
    # Divide all-zero rows by 1 so they stay zero instead of becoming NaN.
    lengths.masked_fill_(lengths == 0, 1)
    vectors.div_(lengths)
    return vectors

# Rows of the full embeddings matrix to keep, derived from the sentence IDs.
# NOTE(review): assumes the embedding of sentence_id i is stored at row i - 1 — confirm.
mapping = filtered.index.values - 1
# Keep only the embeddings of the filtered sentences, in the order of `filtered`.
embeddings = load_embeddings('Sent2Vec', VERSION)[mapping]

# Sanity check: prints whether there is one embedding per filtered sentence,
# followed by the difference between the two counts.
ecount = embeddings.size()[0]
print(f'{ecount == fcount} (- {fcount-ecount:,})')

# Pair sentences

Strategies (checked items are the ones implemented in this notebook):
- [ ] random
- [x] most similar
- [ ] quartiles
- [ ] power law

In [None]:
# spaCy pipeline used below to get the lemmas of both sentences of a pair.
nlp = spacy.load('en_core_sci_lg')

In [None]:
def compute_similarities(index: int, embeddings: torch.Tensor) -> torch.Tensor:
    """Compute the dot product of one normalized embedding with every row.

    The embedding at `index` is scaled to unit L2 norm (a zero vector is left
    as is) and multiplied with all rows of `embeddings`. When the rows are
    themselves unit vectors, the results are cosine similarities.

    Args:
        index: Row of `embeddings` to use as the query.
        embeddings: 2-D tensor with one embedding per row. It is not modified.

    Returns:
        A 1-D tensor with one similarity per row of `embeddings`.
    """
    embedding = embeddings[index]
    norm = torch.norm(embedding).item()
    norm = 1 if norm == 0 else norm
    # Out-of-place division: `embeddings[index]` is a view, so an in-place
    # `/=` would rescale that row of the caller's matrix.
    embedding = embedding / norm
    return torch.nn.functional.linear(embedding, embeddings)

def _words_similarity(doc0, doc1) -> float:
    """Return 1 - min(|A - B| / |A|, |B - A| / |B|) for the alphabetic lemma sets A and B.

    Returns 0.0 when either document has no alphabetic token, since the ratio
    is undefined in that case.
    """
    lemmas0 = {token.lemma_ for token in doc0 if token.is_alpha}
    lemmas1 = {token.lemma_ for token in doc1 if token.is_alpha}
    if not lemmas0 or not lemmas1:
        return 0.0
    dissimilarity = min(len(lemmas0 - lemmas1) / len(lemmas0),
                        len(lemmas1 - lemmas0) / len(lemmas1))
    return 1 - dissimilarity


rows = []

for x in tqdm(sampled.itertuples(), total=len(sampled)):
    similarities = compute_similarities(x.mapping, embeddings)
    # Exclude the query itself explicitly: relying on it being ranked first
    # breaks when another sentence has the same embedding.
    similarities[x.mapping] = float('-inf')

    # TODO Add other strategies.
    # Most similar other sentence; a max is enough, no full sort needed.
    best_sim, best_idx = similarities.max(dim=0)
    sim, idx = best_sim.item(), best_idx.item()

    i0, s0 = x.Index, x.text
    # `embeddings` follows the row order of `filtered`, so `idx` is a row position.
    i1, s1 = filtered.index[idx], filtered.text.iloc[idx]

    rows.append((i0, s0, i1, s1, sim, _words_similarity(nlp(s0), nlp(s1))))

cols = ['sentence_id_1', 'sentence_text_1', 'sentence_id_2', 'sentence_text_2',
        'vectors_similarity', 'words_similarity']

# Most similar pairs first.
pairs = pd.DataFrame(rows, columns=cols).sort_values('vectors_similarity', ascending=False)

In [None]:
def format_results(pairs: pd.DataFrame) -> str:
    """Render sentence pairs as human-readable text.

    Each pair becomes a header line (running pair number, both sentence IDs,
    both similarity scores) followed by the two stripped sentences, each
    introduced by a '-' line. Pairs are separated by blank lines.

    Args:
        pairs: Frame with the columns sentence_id_1, sentence_text_1,
            sentence_id_2, sentence_text_2, vectors_similarity, words_similarity.

    Returns:
        The formatted text; an empty string if `pairs` has no rows.
    """
    blocks = []
    # The pair number is the row's position in `pairs`, not its index label.
    for number, pair in enumerate(pairs.itertuples()):
        header = (
            f'pair: {number}  id_1: {pair.sentence_id_1}  id_2: {pair.sentence_id_2}  '
            f'vectors_sim: {pair.vectors_similarity:.3f}  words_sim: {pair.words_similarity:.3f}\n'
        )
        first = f'-\n{pair.sentence_text_1.strip()}\n'
        second = f'-\n{pair.sentence_text_2.strip()}\n'
        blocks.append(header + first + second)
    return '\n\n'.join(blocks)

# Preview the first ten pairs; `pairs` is sorted by descending vector similarity.
print(format_results(pairs[:10]))

In [None]:
def write_results_txt(pairs: pd.DataFrame, n: int, directory: str) -> None:
    """Write the formatted pairs to a timestamped text file and print its name.

    Args:
        pairs: Sentence pairs, rendered with format_results().
        n: Number that is embedded in the file name.
        directory: Directory in which the file is created.
    """
    stamp = datetime.now().strftime("%Y-%m-%d_%Hh%M")
    filename = f'pairs_n{n}_{stamp}.txt'
    target = Path(directory, filename)
    target.write_text(format_results(pairs), encoding='utf-8')
    print(f'<file> {filename}')

# Save all pairs to a timestamped text file in the current directory.
write_results_txt(pairs, n, '.')