# Virus-Host Species Relation Extraction
## Notebook 1
### UC Davis Epicenter for Disease Dynamics


- Input files required to work:
    - documents: 'pdfs.tsv'
    - host/species names: 'domestic_names.csv', 'ictv_animals.csv', 'ictv_viruses.csv', 'virus_abbrev.csv'

## Part I: Preprocessing the Text Corpus

In [114]:
import numpy as np
import pandas as pd

In [115]:
import os
from pathlib import Path

In [116]:
# Load Snorkel
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession
session = SnorkelSession()

n_docs = 500

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [117]:
from snorkel.parser.spacy_parser import Spacy
from snorkel.parser import CorpusParser
from snorkel.models import Document, Sentence
from snorkel.parser import TSVDocPreprocessor

### Reading in the documents using a document preprocessor

The PDF documents have been converted to a .tsv file, with a format of document name tab-separated by document content. The doc preprocessor reads in the documents. 

In [118]:
#doc_preprocessor = TSVDocPreprocessor('pdfs.tsv', max_docs=n_docs) # old file (39 papers)
doc_preprocessor = TSVDocPreprocessor('pdfs_big.tsv', max_docs=n_docs) # new files (88 papers)

### Running a `CorpusParser`

We use Spacy, an NLP preprocessing tool, which splits the documents into sentences and tokens. 

In [119]:
corpus_parser = CorpusParser(parser=Spacy())
%time corpus_parser.apply(doc_preprocessor, count=n_docs)

Clearing existing...
Running UDF...





  0%|                                                  | 0/500 [00:00<?, ?it/s]


  0%|▏                                         | 2/500 [00:00<01:37,  5.09it/s]


  1%|▎                                         | 3/500 [00:00<01:48,  4.57it/s]


  1%|▎                                         | 4/500 [00:00<01:41,  4.87it/s]


  1%|▍                                         | 5/500 [00:01<01:36,  5.14it/s]


  1%|▌                                         | 6/500 [00:01<01:42,  4.82it/s]


  1%|▌                                         | 7/500 [00:01<01:27,  5.60it/s]


  2%|▋                                         | 8/500 [00:01<01:51,  4.42it/s]


  2%|▊                                         | 9/500 [00:02<02:03,  3.99it/s]


  2%|▊                                        | 10/500 [00:02<02:00,  4.07it/s]


  2%|▉                                        | 11/500 [00:02<01:56,  4.19it/s]


  2%|▉                                        | 12/500 [00:02<01:48,  4.51it/s]


  3%|█       

Wall time: 38.5 s


In [120]:
print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 88
Sentences: 19011


### Import dictionaries for entity matching

We create matcher functions from pre-defined dictionaries to match virus/host names in the text data. The functions match full names, abbreviations, and acronyms. The data is given by ICTV virus classification and IUCN list of animal species.

In [121]:
# Create a list of animal host names 
domestic_names = pd.read_csv('domestic_names.csv')
names1 = domestic_names.iloc[:,0]
names2 = domestic_names.iloc[:,1]
names3 = domestic_names.iloc[:,2]
names_list = names1.append([names2,names3])
names_list = names_list.tolist()
names_list.append("dromedary")
names_list.append("Peking")

In [122]:
ictv_animals = pd.read_csv('ictv_animals.csv')
#print('Total animal names:', ictv_animals.count().sum()) # total number of animal names in the ddict
ictv_series = ictv_animals.stack().reset_index().iloc[:,2]
ictv_list = ictv_series.tolist()

In [123]:
# Function that gets first letter of genus + species name 
def name(s): 
    # split the string into a list  
    l = s.split() 
    new_word = ""  # begins as empty string
    if len(l) == 2:
        for i in range(len(l)-1): 
            s = l[i] 
            # adds the capital first character  
            new_word += (s[0].upper()+'. ') 
        new_word += l[-1].title() # add the last word
        return new_word 
    else:
        return s

In [124]:
ictv_list2 = [name(s) for s in ictv_list] # shortened species names list
animals_list = list(set(names_list + ictv_list + ictv_list2))
# remove terms we don't want to match
dont_want = ['once','ounce','ou','mal']
animals_list = [a for a in animals_list if a.lower() not in dont_want]
#animals_list.remove('Mal')

In [125]:
# Create a list of virus names
ictv_viruses = pd.read_csv('ictv_viruses.csv')
# create copies of certain virus names without the digit at the end
ictv_viruses['Species2'] = ictv_viruses['Species'].str.replace('\d+', '', regex=True)

In [126]:
ictv_v_series = ictv_viruses.stack().reset_index().iloc[:,2].drop_duplicates()
virus_list = ictv_v_series.tolist()

In [127]:
virus_abbrev = pd.read_csv('virus_abbrev.csv', header = None)
virus_list = virus_list + virus_abbrev.iloc[:,0].tolist() 
# remove terms we don't want to match
dont_want2 = ['bat', 'bat virus', 'den', 'langur', 'mcp', 'con', 'spf', '(SPF)', 'his'] 
virus_list = [a for a in virus_list if a.lower() not in dont_want2]


In [128]:
# Clean up white space and remove any empty strings
animals_list = [animal.strip() for animal in animals_list]
animals_list = list(filter(None, animals_list))
virus_list = [virus.strip() for virus in virus_list]
virus_list = list(filter(None, virus_list))

In [129]:
# search the lists for any unwanted terms:
import re
#r = re.compile("mal", flags=re.IGNORECASE)
#animals_list2 = []
#for a in animals_list:
#    if len(a) < 10:
#        animals_list2.append(a)
#new_list = list(filter(r.match, animals_list2))
#print(new_list)

In [130]:
animals_list.remove('Mal')

In [131]:
print('Virus terms to match:', len(virus_list))
print('Host terms to match:', len(animals_list))

Virus terms to match: 8476
Host terms to match: 69873


## Part II: Candidate Extraction

The next step is to extract candidates from the text. A `candidate` in Snorkel is the object we want to make a prediction on. In our case, the candidate are pairs of virus-host species mentions. Our task will be to predict which pairs are correctly described as linked in the text.

In [132]:
from snorkel.matchers import DictionaryMatch
from snorkel.candidates import Ngrams, CandidateExtractor
from snorkel.models import candidate_subclass

In [133]:
# Define the candidate schema to extract (virus-host pair). This is a subclass of candidate and is defined using a helper function. The VirusHost mention connects two Spans of text and creates the table in the database backend.
VirusHost = candidate_subclass('VirusHost', ['virus', 'host'])

### Writing a basic `CandidateExtractor`

* `CandidateExtractor` is a basic function to extract **candidate Virus-Host relation mentions** from the corpus.

* We will extract `Candidates` by identifying, for each `Sentence`, all pairs of n-grams (up to 7-grams) that were tagged. (An n-gram is a span of text made up of n tokens; A token is a string of contiguous characters between two spaces). 

<br>

We do this with three objects:

* A `ContextSpace` defines the "space" of all candidates we even potentially consider; in this case we use the `Ngrams` subclass, and look for all n-grams up to 7 words long

* A `Matcher` heuristically filters the candidates we use. 

* A `CandidateExtractor` combines this all together

In [134]:
# Define the dictionary matchers, define the candidate extractor
ngrams = Ngrams(n_max=10)
virus_matcher = DictionaryMatch(d = virus_list)
animals_matcher = DictionaryMatch(d = animals_list)
cand_extractor = CandidateExtractor(VirusHost, [ngrams, ngrams], [virus_matcher, animals_matcher], nested_relations = True)

### Split the docs into 3 sets: training, development, and testing sets

In [135]:
from snorkel.models import Document

docs = session.query(Document).order_by(Document.name).all()

train_sents = set()
dev_sents   = set()
test_sents  = set()

for i, doc in enumerate(docs):
    for s in doc.sentences:
        if i % 10 == 8:
            dev_sents.add(s)
        elif i % 10 == 9:
            test_sents.add(s)
        else:
            train_sents.add(s)

In [136]:
# Number of candidates per set
print(len(train_sents))
print(len(dev_sents))
print(len(test_sents))

15388
1915
1708


In [137]:
%%time
for i, sents in enumerate([train_sents, dev_sents, test_sents]):
    cand_extractor.apply(sents, split=i)
    print("Number of candidates:", session.query(VirusHost).filter(VirusHost.split == i).count())

Clearing existing...
Running UDF...





  0%|                                                | 0/15388 [00:00<?, ?it/s]


  0%|                                      | 38/15388 [00:00<00:41, 372.16it/s]


  1%|▏                                    | 101/15388 [00:00<00:36, 423.31it/s]


  1%|▎                                    | 152/15388 [00:00<00:34, 443.62it/s]


  1%|▍                                    | 200/15388 [00:00<00:33, 449.97it/s]


  2%|▋                                    | 276/15388 [00:00<00:29, 511.58it/s]


  2%|▊                                    | 335/15388 [00:00<00:28, 529.82it/s]


  3%|▉                                    | 388/15388 [00:00<00:28, 528.13it/s]


  3%|█                                    | 455/15388 [00:00<00:26, 563.82it/s]


  3%|█▏                                   | 512/15388 [00:00<00:30, 493.46it/s]


  4%|█▎                                   | 563/15388 [00:01<00:30, 481.23it/s]


  4%|█▌                                   | 627/15388 [00:01<00:28, 519.80it/s]


  4%|█▋      

 38%|█████████████▌                      | 5787/15388 [00:11<00:19, 503.30it/s]


 38%|█████████████▋                      | 5844/15388 [00:11<00:18, 521.47it/s]


 38%|█████████████▊                      | 5904/15388 [00:11<00:17, 541.17it/s]


 39%|█████████████▉                      | 5960/15388 [00:11<00:17, 543.35it/s]


 39%|██████████████                      | 6016/15388 [00:11<00:17, 524.94it/s]


 39%|██████████████▏                     | 6070/15388 [00:12<00:35, 265.31it/s]


 40%|██████████████▎                     | 6114/15388 [00:12<00:30, 301.12it/s]


 40%|██████████████▍                     | 6171/15388 [00:12<00:26, 350.05it/s]


 40%|██████████████▌                     | 6231/15388 [00:12<00:22, 398.38it/s]


 41%|██████████████▋                     | 6295/15388 [00:12<00:20, 449.17it/s]


 41%|██████████████▊                     | 6352/15388 [00:12<00:18, 477.15it/s]


 42%|██████████████▉                     | 6407/15388 [00:12<00:18, 491.43it/s]


 42%|███████████

 75%|██████████████████████████▍        | 11617/15388 [00:22<00:06, 597.71it/s]


 76%|██████████████████████████▌        | 11678/15388 [00:22<00:06, 572.38it/s]


 76%|██████████████████████████▋        | 11739/15388 [00:23<00:06, 583.01it/s]


 77%|██████████████████████████▊        | 11800/15388 [00:23<00:06, 590.69it/s]


 77%|██████████████████████████▉        | 11862/15388 [00:23<00:05, 599.02it/s]


 78%|███████████████████████████▏       | 11931/15388 [00:23<00:05, 618.49it/s]


 78%|███████████████████████████▎       | 11994/15388 [00:23<00:05, 566.34it/s]


 78%|███████████████████████████▍       | 12052/15388 [00:23<00:05, 568.53it/s]


 79%|███████████████████████████▌       | 12122/15388 [00:23<00:05, 602.33it/s]


 79%|███████████████████████████▋       | 12190/15388 [00:23<00:05, 623.55it/s]


 80%|███████████████████████████▊       | 12254/15388 [00:23<00:05, 572.02it/s]


 80%|████████████████████████████       | 12313/15388 [00:24<00:05, 540.61it/s]


 80%|███████████

Number of candidates: 3872
Clearing existing...
Running UDF...





  0%|                                                 | 0/1915 [00:00<?, ?it/s]


  3%|█                                      | 53/1915 [00:00<00:03, 529.45it/s]


  5%|██                                    | 103/1915 [00:00<00:03, 513.70it/s]


  9%|███▌                                  | 177/1915 [00:00<00:03, 565.45it/s]


 12%|████▋                                 | 239/1915 [00:00<00:02, 578.99it/s]


 15%|█████▋                                | 286/1915 [00:00<00:03, 479.78it/s]


 19%|███████                               | 355/1915 [00:00<00:02, 525.51it/s]


 22%|████████▌                             | 430/1915 [00:00<00:02, 575.89it/s]


 26%|█████████▊                            | 493/1915 [00:00<00:02, 590.96it/s]


 29%|██████████▉                           | 553/1915 [00:00<00:02, 543.44it/s]


 33%|████████████▍                         | 624/1915 [00:01<00:02, 580.14it/s]


 36%|█████████████▋                        | 689/1915 [00:01<00:02, 599.33it/s]


 39%|████████

Number of candidates: 347
Clearing existing...
Running UDF...





  0%|                                                 | 0/1708 [00:00<?, ?it/s]


  2%|▋                                      | 32/1708 [00:00<00:05, 319.70it/s]


  5%|█▉                                     | 84/1708 [00:00<00:04, 356.94it/s]


  8%|██▉                                   | 131/1708 [00:00<00:04, 383.67it/s]


 11%|████                                  | 182/1708 [00:00<00:03, 411.35it/s]


 14%|█████▍                                | 243/1708 [00:00<00:03, 455.79it/s]


 18%|██████▊                               | 305/1708 [00:00<00:02, 495.02it/s]


 21%|███████▉                              | 354/1708 [00:00<00:02, 488.92it/s]


 24%|█████████                             | 410/1708 [00:00<00:02, 504.02it/s]


 28%|██████████▌                           | 474/1708 [00:00<00:02, 538.21it/s]


 31%|███████████▊                          | 529/1708 [00:01<00:02, 504.26it/s]


 34%|█████████████                         | 589/1708 [00:01<00:02, 528.08it/s]


 38%|████████

Number of candidates: 506
Wall time: 40.9 s


In [138]:
print("Number of training candidates:", session.query(VirusHost).filter(VirusHost.split == 0).count())
print("Number of development candidates:", session.query(VirusHost).filter(VirusHost.split == 1).count())
print("Number of test candidates:", session.query(VirusHost).filter(VirusHost.split == 2).count())
print("Total candidates extracted:", session.query(VirusHost).count())

Number of training candidates: 3872
Number of development candidates: 347
Number of test candidates: 506
Total candidates extracted: 4725


In [139]:
cand_extracted = []
for c in session.query(VirusHost).filter(VirusHost.split == 1).all():
    cand_extracted.append(c)
print("Development set candidates extracted:", len(cand_extracted))

Development set candidates extracted: 347


In [140]:
# viewing and hand lableing the first 100 candidates of the development set

from snorkel.viewer import SentenceNgramViewer

SentenceNgramViewer(cand_extracted[0:100], session, height = 350)

<IPython.core.display.Javascript object>

SentenceNgramViewer(cids=[[[45, 48, 50], [98, 99], [58]], [[22], [2], [11, 42]], [[10], [5], [7]], [[44, 51], …

In [141]:
# sentenceviewer can be used to hand-label data for gold label set: just export the sqlite table to csv format and make sure the util_virushost.py file points to the file location. Then the gold labels will be saved and can be exported such as below:

### Part III: Import gold labels (hand labels) to check performance

The hand labeled set is used to evaluate the quality of the model.

In [147]:
from util_virushost import load_external_labels

%time missed = load_external_labels(session, VirusHost, annotator_name = 'gold')

AnnotatorLabels created: 71
Wall time: 572 ms


In [148]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name = 'gold')
L_gold_dev

<3872x1 sparse matrix of type '<class 'numpy.int32'>'
	with 5 stored elements in Compressed Sparse Row format>

### Next steps: Developing Labeling Functions in Notebook 2