# Virus-Host Species Relation Extraction
## Notebook 2
### UC Davis Epicenter for Disease Dynamics

In [62]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os

# TO USE A DATABASE OTHER THAN SQLITE, USE THIS LINE
# Note that this is necessary for parallel execution amongst other things...
# os.environ['SNORKELDB'] = 'postgres:///snorkel-intro'

import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()



In [63]:
from snorkel.models import candidate_subclass

VirusHost = candidate_subclass('VirusHost', ['virus', 'host'])

### Load Gold Labels
Gold labels are a _small_ set of examples (here, a subset of our training set) that we label by hand and use to develop and refine labeling functions. Unlike the _test set_, which we set aside and do not look at until final evaluation, this development set can be inspected freely while writing labeling functions.

In [64]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold')
L_gold_dev

<1063x1 sparse matrix of type '<class 'numpy.int32'>'
	with 5 stored elements in Compressed Sparse Row format>
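As a rough illustration of how a few gold labels can guide LF development, the sketch below measures how often one LF agrees with the gold label on candidates where both are non-zero. `empirical_lf_accuracy` is a hypothetical helper, not a Snorkel function, and it assumes the LF's votes and the gold labels have already been converted to dense 1-D NumPy arrays aligned by candidate (for example with `.toarray().ravel()` on the sparse matrices).

```python
import numpy as np

def empirical_lf_accuracy(lf_votes, gold):
    """Fraction of agreeing votes among candidates where the LF did not
    abstain and a gold label exists. Both inputs use the {-1, 0, 1}
    convention, with 0 meaning abstain / unlabeled."""
    lf_votes = np.asarray(lf_votes)
    gold = np.asarray(gold)
    mask = (lf_votes != 0) & (gold != 0)   # candidates we can actually score
    if not mask.any():
        return float("nan")                # no overlap with the gold set
    return float((lf_votes[mask] == gold[mask]).mean())
```

With only 5 gold-labeled candidates in this notebook, any such estimate will be very noisy.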

## Part I: Writing Labeling Functions

Labeling functions encode our heuristics and weak supervision signals to generate (noisy) labels for our training candidates.

In Snorkel, the primary way we provide training signal to the end extraction model is by writing **labeling functions (LFs)**, rather than by hand-labeling massive training sets.

A labeling function is just a Python function that accepts a `Candidate` and returns `1` to mark the `Candidate` as true, `-1` to mark the `Candidate` as false, and `0` to abstain from labeling the `Candidate`.
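A minimal, self-contained sketch of that return convention. Because a real `Candidate` needs a Snorkel database, it uses a hypothetical `ToyCandidate` stand-in that holds only the tokens between the two mentions; the cue-word sets are likewise illustrative and not taken from this project.

```python
from dataclasses import dataclass, field
from typing import List

# Hypothetical stand-in for a Snorkel Candidate: only the tokens between the mentions.
@dataclass
class ToyCandidate:
    between_tokens: List[str] = field(default_factory=list)

POSITIVE_CUES = {"infected", "isolated", "detected"}
NEGATIVE_CUES = {"not", "negative", "absent"}

def LF_toy(c):
    tokens = {t.lower() for t in c.between_tokens}
    if tokens & NEGATIVE_CUES:
        return -1   # vote "false": probably not a virus-host relation
    if tokens & POSITIVE_CUES:
        return 1    # vote "true"
    return 0        # abstain: no evidence either way
```

Checking the negative cues first means an explicit negation overrides a positive keyword in the same context.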

In [65]:
import re
from snorkel.lf_helpers import (
    get_left_tokens, get_right_tokens, get_between_tokens,
    get_text_between, get_tagged_text,
)

In [66]:
# Pattern-based (rule-based) labeling functions: each one votes 1 if any of its
# keywords appears between the two mentions, or within two tokens to the left
# or right of either mention, and abstains (0) otherwise.
# NOTE: the lf_helpers token functions return single tokens by default (n_max=1),
# so two-word entries such as 'was detected' only match if n_max=2 is passed.

detect = {'detect', 'detected', 'detecting', 'was detected', 'was detectable'}
infect = {'infect', 'infected', 'infecting', 'infection'}
misc = {'isolated', 'positive', 'found', 'host', 'hosts'}

def keyword_near(c, keywords, window=2):
    """True if any keyword occurs between the virus and host mentions,
    or within `window` tokens to the left or right of either mention."""
    contexts = [get_between_tokens(c)]
    for span in (c[0], c[1]):
        contexts.append(get_left_tokens(span, window=window))
        contexts.append(get_right_tokens(span, window=window))
    return any(keywords.intersection(tokens) for tokens in contexts)

def LF_detect(c):
    return 1 if keyword_near(c, detect) else 0

def LF_infect(c):
    return 1 if keyword_near(c, infect) else 0

def LF_misc(c):
    return 1 if keyword_near(c, misc) else 0
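To see why multi-word keywords need n-grams, here is a small sketch. The `ngrams` function is hypothetical and only mimics what we assume the `lf_helpers` token functions do when given `n_max` (emit every 1- to `n_max`-gram joined by spaces); it is not Snorkel's code.

```python
def ngrams(tokens, n_max=1):
    """Yield every 1..n_max-gram of `tokens`, joined by single spaces."""
    for n in range(1, n_max + 1):
        for i in range(len(tokens) - n + 1):
            yield " ".join(tokens[i:i + n])

detect = {'detect', 'detected', 'detecting', 'was detected', 'was detectable'}
context = ['virus', 'was', 'detectable', 'in']

print(detect.intersection(ngrams(context)))           # set()
print(detect.intersection(ngrams(context, n_max=2)))  # {'was detectable'}
```

With unigrams only, 'was detectable' can never match, and since 'detectable' alone is not in the set, such a sentence is missed entirely.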
 

In [67]:
# list of all LFs
LFs = [
    LF_detect, LF_infect, LF_misc
]

In [68]:
# Collect the training candidates (split 0) that at least one LF labels
labeled = []
for c in session.query(VirusHost).filter(VirusHost.split == 0).all():
    if any(lf(c) != 0 for lf in LFs):
        labeled.append(c)
print("Number labeled:", len(labeled))

Number labeled: 167


In [69]:
from snorkel.viewer import SentenceNgramViewer

SentenceNgramViewer(labeled, session)


## Part II: Applying Labeling Functions

We run the LFs over all training candidates (`split=0`), which stores a set of `Label`s (the non-zero votes each LF casts on each candidate) and `LabelKey`s (the names of the LFs) in the database.

In [70]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)

In [71]:
np.random.seed(1701)
%time L_train = labeler.apply(split=0)
L_train

Clearing existing...
Running UDF...


100%|████████████████████████████████| 1063/1063 [00:07<00:00, 141.53it/s]


Wall time: 7.55 s


<1063x3 sparse matrix of type '<class 'numpy.int32'>'
	with 184 stored elements in Compressed Sparse Row format>

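Note that the returned matrix is a special subclass of `scipy.sparse.csr_matrix`, with one row per candidate and one column per LF, extended with Snorkel helper methods such as `get_candidate`, `get_key`, and `lf_stats`.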
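To make the shape of `L_train` concrete, here is a dense NumPy sketch of what applying LFs produces. `build_label_matrix` is a hypothetical helper; Snorkel's `LabelAnnotator` additionally persists the votes to the database and returns a sparse matrix that stores only the non-zero entries, which is why `L_train` reports 184 stored elements out of 1063 × 3.

```python
import numpy as np

def build_label_matrix(candidates, lfs):
    """Apply every LF to every candidate.

    Row i holds the votes for candidates[i]; column j holds the votes of lfs[j].
    Entries are 1 (true), -1 (false), or 0 (abstain).
    """
    L = np.zeros((len(candidates), len(lfs)), dtype=np.int32)
    for i, c in enumerate(candidates):
        for j, lf in enumerate(lfs):
            L[i, j] = lf(c)
    return L

# Toy usage: integers stand in for candidates.
toy_lfs = [lambda x: 1 if x > 2 else 0,
           lambda x: -1 if x % 2 == 0 else 0]
L_toy = build_label_matrix([1, 2, 3, 4], toy_lfs)
```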

In [72]:
# Look up the candidate behind row 0 of the label matrix (its virus and host spans, sentence, and positions)
L_train.get_candidate(session, 0) 

VirusHost(Span("b'TBE'", sentence=11915, chars=[343,345], words=[62,62]), Span("b'human'", sentence=11915, chars=[208,212], words=[39,39]))

In [73]:
# Look up the LabelKey for column 0 of the label matrix, i.e. the name of the first LF
L_train.get_key(session, 0)

LabelKey (LF_detect)

We can now view statistics about the resulting label matrix:

* **Coverage** is the fraction of candidates that the labeling function emits a non-zero label for.
* **Overlap** is the fraction of candidates that the labeling function emits a non-zero label for and that at least one other labeling function also emits a non-zero label for.
* **Conflict** is the fraction of candidates that the labeling function emits a non-zero label for and that another labeling function emits a *conflicting* non-zero label for.
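These three definitions can be computed directly from a dense label matrix. The sketch below is our own NumPy version of the definitions (the name `lf_summary_stats` is hypothetical), assuming binary labels in {-1, 0, 1}; it is not the source of Snorkel's `lf_stats`.

```python
import numpy as np

def lf_summary_stats(L):
    """Per-LF coverage, overlap, and conflict for a dense label matrix L
    (rows = candidates, columns = LFs, entries in {-1, 0, 1})."""
    L = np.asarray(L)
    voted = L != 0                      # which LFs voted on which candidates
    n_voting = voted.sum(axis=1)        # number of LFs voting per candidate
    coverage = voted.mean(axis=0)
    # Overlap: this LF voted and at least one other LF voted too.
    overlap = (voted & (n_voting[:, None] > 1)).mean(axis=0)
    # Conflict: this LF voted and the row contains both +1 and -1, so some
    # other LF must have voted the opposite way.
    mixed = (L == 1).any(axis=1) & (L == -1).any(axis=1)
    conflict = (voted & mixed[:, None]).mean(axis=0)
    return coverage, overlap, conflict
```

Because every LF in this notebook only ever returns `1` or `0`, no candidate can receive opposing votes, which is why the Conflicts column reported by `lf_stats` is all zeros.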

In [74]:
L_train.lf_stats(session)

| LF | j | Coverage | Overlaps | Conflicts |
|---|---|---|---|---|
| LF_detect | 0 | 0.027281 | 0.003763 | 0.0 |
| LF_infect | 1 | 0.037629 | 0.01317 | 0.0 |
| LF_misc | 2 | 0.108184 | 0.01223 | 0.0 |


## Part III: Fitting a Generative Model
Now, we'll train a model of the LFs to estimate their accuracies. Once the model is trained, we can combine the outputs of the LFs into a single, noise-aware training label set for our extractor. Intuitively, we'll model the LFs by observing how they overlap and conflict with each other.
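For intuition about what "combining the outputs of the LFs" means, the sketch below uses a deliberately simplified model: LFs are assumed conditionally independent given the true label, each LF j is assumed to be correct with a known probability `accuracies[j]` whenever it votes, and the class prior is fixed. Under those assumptions each +1 or -1 vote adds or subtracts log(a / (1 - a)) to the posterior log-odds. Snorkel's `GenerativeModel` instead *learns* its parameters from the unlabeled label matrix, so this is only an illustration, and `combine_votes` is a hypothetical name.

```python
import numpy as np

def combine_votes(L, accuracies, class_prior=0.5):
    """Return P(y = +1 | votes) for each row of L (entries in {-1, 0, 1}),
    under the independence and known-accuracy assumptions stated above."""
    L = np.asarray(L, dtype=float)
    acc = np.asarray(accuracies, dtype=float)
    vote_weights = np.log(acc / (1.0 - acc))             # log-odds carried by one vote
    prior_log_odds = np.log(class_prior / (1.0 - class_prior))
    log_odds = prior_log_odds + L @ vote_weights          # abstains (0) contribute nothing
    return 1.0 / (1.0 + np.exp(-log_odds))                # logistic of the log-odds
```

A single vote from a 90%-accurate LF yields a probabilistic label of 0.9, while a row of abstains stays at the prior of 0.5.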

In [76]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel()
gen_model.train(L_train, epochs=100, decay=0.95, step_size=0.1 / L_train.shape[0], reg_param=1e-6)

Inferred cardinality: 2


In [77]:
gen_model.weights.lf_accuracy

array([0.08450605, 0.09574646, 0.117071  ])