<a href="https://colab.research.google.com/github/amoux/corona/blob/master/notebooks/Difference_between_IndexIVFFlat_and_IndexIVFPQ_for_question_answering_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from coronanlp import CORD19, SentenceEncoder
from coronanlp import clean_tokenization, normalize_whitespace
from coronanlp.engine import ScibertQuestionAnswering, BertSummarizerArguments
from coronanlp.retrival import common_tokens, extract_questions
from coronanlp.parser import regex_bracket, regex_figure
from coronanlp.indexing import fit_index_ivf_fpq, fit_index_ivf_hnsw
from coronanlp.utils import render_output as render
from coronanlp.utils import GRADIENTS as G

In [None]:
%%time
STORE_NAME = 'cord19-large'  # name of the saved CORD19 store to load
cord19 = CORD19.from_store(name=STORE_NAME)
print(cord19)

CORD19(papers: 29,315, files_sorted: True, source: [
  biorxiv_medrxiv, comm_use_subset, noncomm_use_subset, custom_license,
])
CPU times: user 353 ms, sys: 28 ms, total: 381 ms
Wall time: 380 ms


In [None]:
%%time
# Only initialize the sentencizer when it is not enabled yet, so re-running the cell skips the work.
needs_sentencizer = not cord19.sentencizer_enabled
if needs_sentencizer:
    cord19.init_sentencizer()
print(cord19.sentencizer_enabled)

True
CPU times: user 5.46 s, sys: 208 ms, total: 5.67 s
Wall time: 5.7 s


In [None]:
%%time
# Take sample s=0 of the corpus and split its papers into a sentence store,
# keeping sentences within the minlen/maxlen bounds.
biorxiv = cord19.sample(s=0)
sents = cord19.batch(
    biorxiv,
    minlen=16,
    maxlen=1024,
    workers=7,  # NOTE(review): hard-coded worker count; consider deriving it from os.cpu_count()
)

HBox(children=(FloatProgress(value=0.0, description='files/papers', max=885.0, style=ProgressStyle(description…


CPU times: user 6min 23s, sys: 33.5 s, total: 6min 56s
Wall time: 1min 41s


In [None]:
# Summary of the sentence store (paper, sentence and token counts).
print(sents)

SentenceStore(avg_seqlen: 33.49 | num_papers: 885 | num_sents: 83,819 | num_tokens: 2,807,074)


In [None]:
# (min, max, std) of what are presumably per-sentence lengths in the store.
(sents.min(), sents.max(), sents.std())

(16, 889, 18.779493321526648)

In [None]:
# Ten largest entries reported by the store; displayed as the cell output.
largest = sents.largest(10)
largest

[(369, 530),
 (55, 393),
 (229, 386),
 (319, 367),
 (304, 364),
 (783, 364),
 (517, 354),
 (163, 352),
 (489, 351),
 (448, 341)]

In [None]:
# The first entry of `largest` unpacks as (pid, count); print that paper's title.
pid, count = largest[0]
title = cord19.title(pid)
print(title)

Title: HOPS-dependent endosomal fusion required for efficient cytosolic delivery of therapeutic peptides and small proteins


In [None]:
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA (encoding will be much slower there).
ENCODER_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
encoder = SentenceEncoder.from_pretrained("amoux/scibert_nli_squad", device=ENCODER_DEVICE)
tokenizer, device = encoder.tokenizer, encoder.device

In [None]:
# Embed the first five sentences; they serve as probe queries for both indexes.
# NOTE(review): this runs before `clean_function` rewrites `sents` in place,
# so the queries reflect the uncleaned text.
inputs = tokenizer(
    sents[:5],
    padding=True,
    truncation=True,
    max_length=256,
    return_tensors="pt",
)
with torch.no_grad():
    output = encoder(**inputs.to(device))
    queries = output['sentence_embeddings'].to('cpu').numpy()

# Query embeddings used to test both index objects
queries.shape

(5, 768)

In [None]:
%%time
# Encode every sentence in the store; expect exactly one embedding row per sentence.
embeddings = encoder.encode(sents, batch_size=12)
assert len(sents) == embeddings.shape[0]

HBox(children=(FloatProgress(value=0.0, description='batch', max=6985.0, style=ProgressStyle(description_width…


CPU times: user 6min 3s, sys: 7.43 s, total: 6min 11s
Wall time: 5min 21s


In [None]:
%%time
def clean_function(x: str) -> str:
    """Return a cleaned copy of sentence `x`.

    Matches of `regex_bracket` and then `regex_figure` are replaced with a
    single space. The result is passed through `normalize_whitespace` and
    finally `clean_tokenization`.
    """
    for pattern in (regex_bracket, regex_figure):
        x = pattern.sub(' ', x)
    return clean_tokenization(normalize_whitespace(x))

# Rewrites every sentence in `sents` in place.
# NOTE(review): `queries` and `embeddings` were encoded earlier from the
# uncleaned text, and re-running this cell cleans the sentences a second time.
sents.map(clean_function, inplace=True)

CPU times: user 3.11 s, sys: 0 ns, total: 3.11 s
Wall time: 3.13 s


In [None]:
# Presumably an IVF-PQ index over the sentence embeddings (per the helper's name).
# NOTE(review): FAISS k-means warns when there are fewer than ~39 training points
# per centroid. If the index is trained on all ~84k embeddings, nlist=4096 is below
# that threshold, so confirm the value is intended.
fpq_index = fit_index_ivf_fpq(embeddings, k=8, nlist=4096, m=32)
fpq_index.is_trained

True

In [None]:
# Second index for comparison, built by `fit_index_ivf_hnsw` with the L2 metric
# (presumably IVF with an HNSW coarse quantizer, per the helper's name).
hnsw_index = fit_index_ivf_hnsw(embeddings, metric='l2', m=32)
hnsw_index.is_trained

True

In [None]:
# `sid` selects both a row of `queries` and its source sentence. It must be < 5,
# because `queries` holds the embeddings of `sents[:5]`.
sid = 2
fpq_index.nprobe = 256
D0, I0 = fpq_index.search(queries, k=5)
print(f'(query; {sid}): {sents[sid]}\n')
# FAISS returns id -1 when fewer than k neighbours are found. Drop those ids
# so they are not looked up as real sentence ids.
neighbor_ids = [i for i in I0[sid].tolist() if i >= 0]
nearest = sents.get(neighbor_ids)
for k, sent in enumerate(nearest):
    print(f'({k+1}) {sent}')

(query; 2): On December 31, 2019, a total of 27 cases were reported; meanwhile, a rapid response team led by the Chinese Centre for Disease Control and Prevention (China CDC) was formed to conduct detailed epidemiologic and aetiologic investigations in Wuhan.

(1) On the evening of 06/02/2020 the UK definition of a suspected case was extended to include people presenting with respiratory illness (defined as cough, shortness of breath or fever with or without other symptoms) returning from or transiting through China including Hong Kong and Macau, Japan, Malaysia, South Korea, Singapore, Taiwan or Thailand within the last 14 days, with the case definition subsequently changing further on 25/02/2020 to include northern Italy, Iran and further countries in SE Asia.
(2) Blood was centrifuged at 4.8K rpm for 10' followed by 12K rpm at 20' and filtration through a We would like to thank Dr. Julia Calzada-Wack, Jacqueline Mueller and Marion Fisch for their 625 kind assistance with tissue proc

In [None]:
# `sid` selects both a row of `queries` and its source sentence. It must be < 5,
# because `queries` holds the embeddings of `sents[:5]`.
sid = 2
hnsw_index.nprobe = 256
D1, I1 = hnsw_index.search(queries, k=5)
print(f'(query; {sid}): {sents[sid]}\n')
# FAISS returns id -1 when fewer than k neighbours are found. Drop those ids
# so they are not looked up as real sentence ids.
neighbor_ids = [i for i in I1[sid].tolist() if i >= 0]
nearest = sents.get(neighbor_ids)
for k, sent in enumerate(nearest):
    print(f'({k+1}) {sent}')

(query; 2): On December 31, 2019, a total of 27 cases were reported; meanwhile, a rapid response team led by the Chinese Centre for Disease Control and Prevention (China CDC) was formed to conduct detailed epidemiologic and aetiologic investigations in Wuhan.

(1) On December 31, 2019, a total of 27 cases were reported; meanwhile, a rapid response team led by the Chinese Centre for Disease Control and Prevention (China CDC) was formed to conduct detailed epidemiologic and aetiologic investigations in Wuhan.
(2) The From January 21 to February 15, 2020, based on epidemiological evidence, fever and/or respiratory symptoms, chest radiological findings and blood white blood cell (WBC) results, physicians at the Fever Clinic referred 156 cases to panel discussion by multi-discipline experts.
(3) Following Wu et al. 5, we developed a susceptible-exposed-infectious-recovered (SEIR)-based metapopulation model, which has 100 sub-populations representing the 100 cities that have the greatest num

In [None]:
%%time
# QA pipeline over the sentence store, initially backed by the IVF-PQ index.
# NOTE(review): `sents` has been cleaned in place by this point, so min/max may
# differ from the statistics shown earlier.
summarizer_args = BertSummarizerArguments(
    ratio=0.3,
    min_length=sents.min(),
    max_length=sents.max(),
    algorithm="kmeans",
)
Q = ScibertQuestionAnswering(
    sents=sents,
    index=fpq_index,
    encoder=encoder,
    cord19=cord19,
    summarizer_kwargs=summarizer_args,
)

CPU times: user 13.7 s, sys: 833 ms, total: 14.5 s
Wall time: 13.3 s


In [None]:
%%time
# Collect question sentences from the store; remove_empty=True presumably drops
# entries that contain no questions.
questions = extract_questions(sents, remove_empty=True)

HBox(children=(FloatProgress(value=0.0, description='sentences', max=885.0, style=ProgressStyle(description_wi…


CPU times: user 223 ms, sys: 3.85 ms, total: 226 ms
Wall time: 225 ms


In [None]:
def answer(question, index='fpq', a=0, k=5, p=35, probe=256):
    """Run the shared `Q` pipeline on `question` with the chosen index and render it.

    Args:
        question: Question text passed to `Q`.
        index: 'fpq' searches `fpq_index` (rendered with the 'feels' gradient);
            'hnsw' searches `hnsw_index` (rendered with the 'virgin' gradient).
        a: Forwarded to `pred.topk(a)`. The result indexes `pred.a` to choose
            what `render` displays.
        k, p, probe: Forwarded positionally as `Q(question, k, p, probe)`.
            NOTE(review): `probe` presumably sets the index nprobe; confirm.

    Raises:
        ValueError: If `index` is neither 'fpq' nor 'hnsw'. The check runs
            before `Q.index` is touched.

    Side effects:
        Reassigns the global `Q.index`, so the selected index stays active
        after the call.
    """
    # Reject unknown names instead of silently falling back to the HNSW index,
    # so a typo cannot make a comparison run against the wrong index.
    if index == 'fpq':
        Q.index, theme = fpq_index, 'feels'
    elif index == 'hnsw':
        Q.index, theme = hnsw_index, 'virgin'
    else:
        raise ValueError(f"index must be 'fpq' or 'hnsw', got {index!r}")
    pred = Q(question, k, p, probe)
    render(pred, pred.a[pred.topk(a)], grad_pair=G[theme])
    
# NOTE(review): `_store` is a private attribute; prefer a public accessor if one exists.
questions._store

{45: ['Why do they need it in high amounts, provided that the Orf1a stoichiometry is thought to be higher than that of Orf1b products?'],
 100: ['What features of the time series might have driven the results in our correlation analysis?'],
 102: ['How does this interaction occur? What parameters would define how effectively pathogens are transmitted into the reservoir?'],
 188: ['What viral genes should be attenuated? How many attenuating mutations should be made to the genome? What synonymous features should be targeted for deoptimization?'],
 229: ['How does crowding among disordered domains overcome the ability of BAR scaffolds to stabilize lipid tubules?'],
 347: ['How do NPs induce cellular stress responses, and why do the PACAs studied here have different effects?'],
 361: ['How does the system transition from a single energy barrier for homogeneous boundaries to two barriers when a gradient in spontaneous curvature at the boundaries is imposed?'],
 487: ['When you suspect that 

In [None]:
answer(question=questions[0], index='fpq')

In [None]:
answer(question=questions[0], index='hnsw')

In [None]:
answer(question=questions[2], index='fpq')

In [None]:
answer(question=questions[2], index='hnsw')

In [None]:
answer(question=questions[5], index='fpq')

In [None]:
answer(question=questions[5], index='hnsw')

In [None]:
answer(question=questions[6], index='fpq')

In [None]:
answer(question=questions[6], index='hnsw')

In [None]:
%%time
answer(question=questions[7], index='fpq')

CPU times: user 5.67 s, sys: 359 ms, total: 6.03 s
Wall time: 1.97 s


In [None]:
%%time
answer(question=questions[7], index='hnsw')

CPU times: user 5.96 s, sys: 397 ms, total: 6.36 s
Wall time: 2.09 s
