In [1]:
import faiss
import torch
import torch.nn.functional as F
import numpy as np

import coronanlp as corona
from coronanlp.engine import ScibertQuestionAnswering
from coronanlp.ukplab.sentence import SentenceTransformer

In [2]:
SENTENCES_DB = 'dev-data/543820_3858_ml15_select.pkl'
INDEX_IVF_DB = 'dev-data/gold_index_ivg.index'
NLI_ENCODER = '/home/ego/huggingface-models/bundles/CordBERTa/nli_stsb/0_Transformer'

In [3]:
qa = ScibertQuestionAnswering(
    papers=corona.SentenceStore.from_disk(SENTENCES_DB),
    encoder=SentenceTransformer(NLI_ENCODER),
    index=faiss.read_index(INDEX_IVF_DB),
)
qa.all_model_devices

{'summarizer_model_device': device(type='cuda'),
 'sentence_transformer_model_device': device(type='cuda'),
 'question_answering_model_device': device(type='cpu')}

In [4]:
tasklist = corona.TaskList()
tasklist

[Task(id: 1, question: What do we know details diagnostics and surveillance?),
 Task(id: 2, question: What has been published details information sharing and inter-sectoral collaboration?),
 Task(id: 3, question: What has been published details ethical and social science considerations?),
 Task(id: 4, question: What do we know details the effectiveness of non-pharmaceutical interventions?),
 Task(id: 5, question: What has been published details medical care?),
 Task(id: 6, question: What do we know details virus genetics, origin, and evolution?),
 Task(id: 7, question: What do we know details vaccines and therapeutics?),
 Task(id: 8, question: What do we know details COVID-19 risk factors?),
 Task(id: 9, question: What is known details transmission, incubation, and environmental stability?)]

In [5]:
t1 = tasklist[0]
allt1 = t1.all()
print(t1.info)

What has been published concerning systematic, holistic approach to diagnostics (from the public health surveillance perspective to being able to predict clinical outcomes)?


In [6]:
preds = qa.answer(t1.info, topk=5, top_p=25, nprobe=64, mode='bert')
preds.popempty()
preds.ids, preds.dist



(array([[179943,  38779,  48340, 171641,  11026,  16090, 132451,  10551,
         231547,  32627, 203359, 123822, 231157,  59945, 333167, 203328,
          37302,  74584,   1534, 425932, 261597, 268659, 397260,  27072,
         117127]]),
 array([[114.0309 , 119.05531, 127.31623, 128.10754, 134.01633, 135.51642,
         137.91711, 138.64822, 139.26003, 141.16678, 142.14966, 144.36464,
         147.06775, 147.06958, 147.22888, 148.9502 , 149.55885, 149.9512 ,
         150.12485, 150.1834 , 150.33147, 151.04956, 151.89029, 152.3189 ,
         152.47305]], dtype=float32))

In [7]:
list(preds)

[ModelOutput(score=0.02010919339954853, start=11, end=35, answer='laboratory confirmation,'),
 ModelOutput(score=0.00906957034021616, start=133, end=143, answer='sequencing.'),
 ModelOutput(score=0.00279330019839108, start=40, end=143, answer='national reference laboratory aims to obtain material from regional laboratories for further sequencing.'),
 ModelOutput(score=0.002625082153826952, start=11, end=21, answer='laboratory')]

In [8]:
preds.spans()

[(11, 35), (133, 143), (40, 143), (11, 21)]

In [9]:
out = preds[2]
print(out.answer)

national reference laboratory aims to obtain material from regional laboratories for further sequencing.


In [14]:
question, context, answer = preds.q, ' '.join(preds.c), out.answer
corona.render_output(answer=answer, context=context, question=question)

Question: What has been published concerning systematic, holistic approach to diagnostics (from the public health surveillance perspective to being able to predict clinical outcomes)?

Answer: national reference laboratory aims to obtain material from regional laboratories for further sequencing.

Context:

In case of laboratory confirmation, the << national reference laboratory aims to obtain material from regional laboratories for further sequencing. [ANSWER] >> An alternate measure of program success is the extent to which screening delays the first importation of cases into the community, possibly providing additional time to train medical staff, deploy public health responders or refine travel policies (Cowling et al., Unlike the UK national strategy documents and plans, the US National Health Information Infrastructure Strategy document (also known as "Information for Health") refers explicitly to GIS and real-time health and disease monitoring and states that "public health will

In [18]:
sids = preds.ids.tolist()[0]
pids = list(qa.papers.lookup(sids, mode='table').keys())
titles = [x.lower() for x in qa.cord19.titles(pids)]
titles[::2]

['visual tools to assess the plausibility of algorithm- identified infectious disease clusters: an application to mumps data from the netherlands',
 'in silico approach to accelerate the development of mass spectrometry-based proteomics methods for detection of viral proteins: application to covid-19',
 'towards evidence-based, gis-driven national spatial health information infrastructure and surveillance services in the united kingdom',
 'hajj, umrah, and the neglected tropical diseases',
 'the past, present, and future of public health surveillance',
 'improved global capacity for influenza surveillance sign up for twitter and find the latest information about emerging infectious diseases from the eid journal. @cdc_eidjournal',
 'epidemiologic data and pathogen genome sequences: a powerful synergy for public health',
 "strengthening field-based training in low and middle-income countries to build public health capacity: lessons from australia's master of applied epidemiology program"

In [19]:
abstract = qa.cord19.load_paper(pids[0])['abstract']
print(abstract[0]['text'])



In [20]:
pred_labels = corona.common_tokens(titles, nlp=qa.nlp)
labels, freqs = zip(*pred_labels)
labelmap = dict(zip(labels, [k/max(freqs) for k in freqs]))
list(labelmap.items())[:10]

[('health', 1.0),
 ('public', 0.5625),
 ('disease', 0.375),
 ('surveillance', 0.375),
 ('infectious', 0.3125),
 ('detection', 0.25),
 ('diseases', 0.1875),
 ('review', 0.1875),
 ('epidemiology', 0.1875),
 ('global', 0.1875)]

## zero-shot-classification

Let's test the predicted answers from the question answering model finetuned on SQUAD 2.0 with `CordBERTa`.

- **CordBERTa**

    - base-model: roberta-base-cased

    - fine-tuned: language-modeling

    - data/tokens: CORD19 kaggle dataset.
    
    - downstream-task: MNLI (sequence classification)

In [21]:
from transformers import pipeline
from transformers import RobertaTokenizer, RobertaForSequenceClassification
CORDBERTA = '/home/ego/huggingface-models/finetuned/roberta_mnli_cord19/'

In [22]:
classifier = pipeline(
    task='zero-shot-classification',
    model=RobertaForSequenceClassification.from_pretrained(CORDBERTA),
    tokenizer=RobertaTokenizer.from_pretrained(CORDBERTA),
)

In [23]:
hypothesis_template = 'This text is about {}.'
sequences = [o.answer for o in preds]
candidate_labels = list(labels[:5])

zeroshot = classifier(sequences, candidate_labels,
                      hypothesis_template, multi_class=True)
zeroshot

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[{'sequence': 'laboratory confirmation,',
  'labels': ['surveillance', 'disease', 'health', 'public', 'infectious'],
  'scores': [0.8520071506500244,
   0.745579183101654,
   0.6930623054504395,
   0.6136701703071594,
   0.5498717427253723]},
 {'sequence': 'sequencing.',
  'labels': ['surveillance', 'health', 'disease', 'infectious', 'public'],
  'scores': [0.8411460518836975,
   0.6569487452507019,
   0.6352137923240662,
   0.5501412749290466,
   0.5465444922447205]},
 {'sequence': 'national reference laboratory aims to obtain material from regional laboratories for further sequencing.',
  'labels': ['surveillance', 'health', 'disease', 'public', 'infectious'],
  'scores': [0.8664122819900513,
   0.5979554653167725,
   0.40629252791404724,
   0.37213629484176636,
   0.22388437390327454]},
 {'sequence': 'laboratory',
  'labels': ['surveillance', 'disease', 'health', 'public', 'infectious'],
  'scores': [0.7985999584197998,
   0.7546665668487549,
   0.752849280834198,
   0.7494865655899

> Above, we can see that the classifier scores the label `<surveillance>` as the closest match for all predicted answers/sequences (from the question-answering model). Considering we did not directly use the titles as inputs, `<surveillance>` is indeed related to the context of the question asked:

**"What has been published concerning the systematic, holistic approach to diagnostics (from the public health surveillance perspective to being able to predict clinical outcomes)?".**

> Keep in mind; the models did not have access to the titles/labels. Instead, we "forced" the models to put together a "report/summary" using `3,858` research papers or, to be more exact, `543,820` sentences. Another fact to consider; the papers/paper-ids were picked or filtered from a total of `13,202` possible papers to choose from by using the method `Tune IDs to Tasks.` The way it works is simple; Given some `tasks` and a `title_map` (a mapping between `{paper-ids: titles}` e.g., **{4514: "visual tools to assess the plausibility of algorithm ..."}**) where the `tasks` are expressed as queries and `titles` as the database - both represented as dense vectors. We compute each query's association against the database and then sort by the weighted distribution of scored ids. In other words, from the most commonly accessed by each query to the least accessed by each query.

```python
from coronanlp import tune_ids, extract_titles_fast
from coronanlp import TaskList, CORD19
from coronanlp.ukplab import SentenceTransformer

cord19 = CORD19(...)
print(cord19)
```

* The results in this notebook are from the following sources:

```
CORD19(papers: 13202, files_sorted: True, source: [
  biorxiv_medrxiv, noncomm_use_subset, comm_use_subset, pmc_custom_license,
])
```

* Here are the steps to obtain the results metioned above:

```python
...
title_map = extract_titles_fast(cord19, minlen=10, maxids=-1)
gold_ids = tune_ids(
    encoder=SentenceTransformer(...),
    title_map=title_map,
    task_list=TaskList(), 
    target_size=1800,
)
# If you already have an instance of SentenceStore e.g if is the full dataset.
gold_papers = papers.index_select(gold_ids)

# If not you can simply do:
gold_sample = list(gold_ids.sample())
gold_papers = cord19.batch(gold_sample, minlen=15)

print(gold_papers)
# SentenceStore(avg_seqlen=180.96, num_papers=3858, num_sents=543820)
```


In [24]:
encoder_model = qa.encoder._modules['0'].model
encoder_tokenizer = qa.encoder.tokenizer
encoder_maxlength = qa.encoder.max_length
encoder_device = encoder_model.device
print('Encoder: max_length: {}, device: {}'.format(
    encoder_maxlength, encoder_device))

Encoder: max_length: 128, device: cuda:0


In [25]:
text_pairs = [zeroshot[2]['sequence']] + zeroshot[2]['labels']

inputs = encoder_tokenizer.batch_encode_plus(
    text_pairs,
    padding=True,
    truncation=True,
    max_length=encoder_maxlength,
    return_tensors='pt',
)
inputs = inputs.to(encoder_device)
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
output = encoder_model(input_ids, attention_mask=attention_mask)[0]

sentence_rep = output[:1].mean(dim=1)  # == text_pairs[:1]
label_reps = output[1:].mean(dim=1)   # == text_pairs[1:]
sentence_rep.shape, label_reps.shape

(torch.Size([1, 768]), torch.Size([5, 768]))

In [26]:
text_pairs[:1], text_pairs[1:]

(['national reference laboratory aims to obtain material from regional laboratories for further sequencing.'],
 ['surveillance', 'health', 'disease', 'public', 'infectious'])

In [27]:
similarities = F.cosine_similarity(sentence_rep, label_reps, dim=1)
closest = similarities.argsort(descending=True)

for zero_idx, cosine_idx in enumerate(closest):
    print('<k: {1}>\tcosine | zeroshot:\t{0} | {2}'.format(
        text_pairs[1:][cosine_idx], similarities[cosine_idx],
        text_pairs[1:][zero_idx]))

<k: -0.18992996215820312>	cosine | zeroshot:	public | surveillance
<k: -0.2037743777036667>	cosine | zeroshot:	surveillance | health
<k: -0.2134021520614624>	cosine | zeroshot:	health | disease
<k: -0.21375295519828796>	cosine | zeroshot:	disease | public
<k: -0.24101680517196655>	cosine | zeroshot:	infectious | infectious
