In [1]:
# Imports: stdlib first, then the coronanlp package and its helpers.
import os
import warnings
# Ignore all warnings, including any raised while importing coronanlp below.
warnings.filterwarnings('ignore')

import coronanlp
from coronanlp.ukplab import SentenceTransformer
from coronanlp.engine import ScibertQuestionAnswering, QuestionAnsweringArguments
from coronanlp.summarization import BertSummarizerArguments
from coronanlp.utils import get_store_dir as store_home

# Directory of the CordBERTa sentence-transformer used as the query encoder.
# Set the CORDBERTA_PATH environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
cordberta = os.environ.get(
    'CORDBERTA_PATH',
    '/home/ego/huggingface-models/bundles/CordBERTa/nli_stsb/0_Transformer/',
)

In [2]:
# Select the most recently accessed store: map each store directory's access
# time (via lstat, i.e. of the entry itself) to its directory name.
logs = {
    p.lstat().st_atime: p.name for p in store_home().iterdir()
    if p.is_dir()
}
if not logs:
    # Without this guard max() fails with an opaque "empty sequence" error.
    raise FileNotFoundError(f'no store directories found in {store_home()}')
store = logs[max(logs)]  # last used/accessed
store

'cord19_large'

In [3]:
# Build the QA engine from the selected store: the papers (sentences) and the
# store's index, the CordBERTa encoder on CPU, and the QA/summarizer models on
# the GPU (see the device report below). `coronanlp` is the imported module
# name; `corona_nlp` was never imported and fails on a fresh kernel.
qa = ScibertQuestionAnswering(
    papers=coronanlp.Papers.from_disk(store),
    index=coronanlp.load_store(type_store='index', store_name=store),
    encoder=SentenceTransformer(cordberta, device='cpu'),
    model_device='cuda',
    summarizer_hidden=-4,  # presumably the hidden layer used for embeddings
    summarizer_reduce='max',
    summarizer_kwargs=BertSummarizerArguments(
        ratio=0.2,
        min_length=40,
        max_length=600,
        use_first=True,
        algorithm='gmm'
    ),
)
qa.all_model_devices

{'summarizer_model_device': device(type='cuda'),
 'sentence_transformer_model_device': 'cpu',
 'question_answering_model_device': device(type='cuda', index=0)}

In [4]:
# Task list provided by coronanlp; each entry pairs an id with a question.
tasks = coronanlp.TaskList()
tasks  # rendered as the cell output

[Task(id: 1, question: What do we know details diagnostics and surveillance?),
 Task(id: 2, question: What has been published details information sharing and inter-sectoral collaboration?),
 Task(id: 3, question: What has been published details ethical and social science considerations?),
 Task(id: 4, question: What do we know details the effectiveness of non-pharmaceutical interventions?),
 Task(id: 5, question: What has been published details medical care?),
 Task(id: 6, question: What do we know details virus genetics, origin, and evolution?),
 Task(id: 7, question: What do we know details vaccines and therapeutics?),
 Task(id: 8, question: What do we know details COVID-19 risk factors?),
 Task(id: 9, question: What is known details transmission, incubation, and environmental stability?)]

In [5]:
# Every question text belonging to the first task; display the first one.
t1 = tasks[0].all()
t1[0]  # rendered as the cell output

'How widespread current exposure is to be able to make immediate policy recommendations on mitigation measures. Denominators for testing and a mechanism for rapidly sharing that information, including demographics, to the extent possible. Sampling methods to determine asymptomatic disease (e.g., use of serosurveys (such as convalescent samples) and early detection of disease (e.g., use of screening of neutralizing antibodies such as ELISAs).'

In [18]:
# Answer the first question of task 1. The question is split into sentences
# (3 here), giving an output of shape (3, 25); the 25 presumably comes from
# top_p, and size 15 = 3 splits x topk=5 answers.
preds = qa.answer(
    t1[0],
    topk=5,
    top_p=25,
    nprobe=256,
    mode='bert',
)
print(preds)

QuestionAnsweringOutput(size: 15, shape: (3, 25))


In [73]:
from coronanlp.utils import GRADIENTS as G

# Render the first non-empty answer for one split of the question. The answers
# of each split are stored consecutively in preds (the df2 cell relies on the
# same layout), so split `index` is preds[index * split:(index + 1) * split].
index = 0  # which split (sentence) of the question to inspect; 0 = first of 3
split = len(preds) // preds.shape()[0]  # answers per split (topk)
answers = preds[index * split:(index + 1) * split]
# Skip predictions whose span starts at 0 (presumably the "no answer" span).
output = next((pred for pred in answers if pred.start > 0), None)
if output is None:
    raise ValueError(f'no non-empty answer among the predictions for split {index}')
question, context, answer = preds.q[index], preds.c, output.answer

coronanlp.render_output(
    answer=answer,
    question=question,
    context=context,
    grad_pair=G['mauve'],
)

In [75]:
import pandas as pd

# One row per retrieved sentence: for each question split (rows of preds.ids)
# and each retrieved sentence (columns), record the sentence/paper ids, the
# distance, the paper title, and whether the sentence is in the QA context.
# Abbreviations: sid = sentence-id, pid = paper-id, ctx: context
data = {'sid': [], 'pid': [], 'dist': [], 'in_ctx': [],
        'query': [], 'title': [], 'sent': []}

for q in range(preds.ids.shape[0]):
    query = preds.q[q]
    for p in range(preds.ids.shape[1]):
        sid, dist = preds.ids[q][p], preds.dist[q][p]
        sent_id = sid.item()
        pid = qa.papers.decode(sent_id)
        title = qa.cord19.title(pid)
        sent = qa.papers[sent_id]
        # Keyed by column name so a value can never land in the wrong column
        # (the old positional zip over data.keys() silently depended on order).
        row = {
            'sid': sid,
            'pid': pid,
            'dist': dist,
            'in_ctx': sent in preds.c,  # substring test against the context
            'query': query,
            'title': title,
            'sent': sent,
        }
        for col, value in row.items():
            data[col].append(value)
df1 = pd.DataFrame(data=data)
df1

Unnamed: 0,sid,pid,dist,in_ctx,query,title,sent
0,1713442,11640,135.647797,True,How widespread current exposure is to be able ...,Middle East Respiratory Syndrome Coronavirus T...,Until additional evidence is available to furt...
1,1344096,8886,137.929214,False,How widespread current exposure is to be able ...,,"Alternatively, should we focus limited resourc..."
2,1218420,8033,146.825317,True,How widespread current exposure is to be able ...,"Modeling the impact of air, sea, and land trav...","In mitigating viral pandemics, the benefit to ..."
3,1676860,11342,147.331787,False,How widespread current exposure is to be able ...,Potential Impact of Antiviral Drug Use during ...,"However, the likely rapid global spread of a p..."
4,693563,4573,152.514359,True,How widespread current exposure is to be able ...,The Use of Recombinant Feline Interferon Omega...,"Concerning these results, rFeIFNω seems to be ..."
...,...,...,...,...,...,...,...
70,1649954,11035,150.736465,False,Sampling methods to determine asymptomatic dis...,,The ABCs infrastructure was used to conduct ca...
71,1272224,8400,150.793945,False,Sampling methods to determine asymptomatic dis...,Magnetic Nanotrap Particles Preserve the Stabi...,The current diagnostic approaches to confirm V...
72,1254810,8281,150.911377,False,Sampling methods to determine asymptomatic dis...,Selection of key recommendations for quality i...,We selected key recommendations for the broad ...
73,1256569,8316,150.933273,False,Sampling methods to determine asymptomatic dis...,Evaluation of Targeted Next-Generation Sequenc...,Current tests for infectious disease diagnosis...


In [74]:
# One row per model answer, labelled with the question split it answers.
# The answers of each split are stored consecutively in preds, so the chunk
# size is derived from preds itself; a hard-coded 5 silently mislabels rows
# as soon as `topk` in the qa.answer call changes.
model_data = {'question': [], 'answer': [], 'score': []}
max_length = len(preds)
num_inputs = max(1, max_length // preds.shape()[0])  # answers per split (topk)
for q, x in enumerate(range(0, max_length, num_inputs)):
    split = preds[x: min(x + num_inputs, max_length)]
    query = preds.q[q]
    for pred in split:
        model_data['question'].append(query)
        model_data['answer'].append(pred.answer)
        model_data['score'].append(pred.score)
df2 = pd.DataFrame(data=model_data)
df2

Unnamed: 0,question,answer,score
0,How widespread current exposure is to be able ...,,0.3316216
1,How widespread current exposure is to be able ...,continued use of existing precautionary recomm...,0.01710174
2,How widespread current exposure is to be able ...,Until additional evidence,0.0062957
3,How widespread current exposure is to be able ...,Until additional evidence is available,0.005963233
4,How widespread current exposure is to be able ...,continued use of existing precautionary recomm...,0.005673415
5,Denominators for testing and a mechanism for r...,travel restrictions,0.2583641
6,Denominators for testing and a mechanism for r...,imposing travel restrictions,0.06701491
7,Denominators for testing and a mechanism for r...,travel,0.0349847
8,Denominators for testing and a mechanism for r...,"viral pandemics, the benefit to be gained from...",0.03082118
9,Denominators for testing and a mechanism for r...,travel restrictions as an adjunct to other eff...,0.03055965


In [77]:
# Name the output files after the `topk` leading tokens (longer than 3 chars)
# returned by common_tokens for the lower-cased question splits.
topk = 2
tags = coronanlp.common_tokens(map(str.lower, preds.q), nlp=qa.nlp)
# Presumably (token, count) pairs, most common first; keep only the tokens.
long_tags = [pair for pair in tags if len(pair[0]) > 3][:topk]
if not long_tags:
    # zip(*[]) would otherwise fail with an opaque unpacking error.
    raise ValueError('no tokens longer than 3 characters to build file names from')
tags, _ = zip(*long_tags)
file_1 = '_'.join(tags) + '_predictions.csv'
file_2 = '_'.join(tags) + '_q&a_results.csv'
df1.to_csv(file_1)
df2.to_csv(file_2)

In [79]:
# List the CSV files in the working directory (includes those written above).
!ls *.csv

 disease_widespread_predictions.csv  'disease_widespread_q&a_results.csv'
