# Virus-Host Species Relation Extraction
## Notebook 2
### UC Davis Epicenter for Disease Dynamics

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()

In [2]:
from snorkel.models import candidate_subclass

VirusHost = candidate_subclass('VirusHost', ['virus', 'host'])

## Part I: Writing Labeling Functions

Labeling functions encode our heuristics and weak supervision signals to generate (noisy) labels for our training candidates.

In Snorkel, our primary interface through which we provide training signal to the end extraction model we are training is by writing **labeling functions (LFs)** (as opposed to hand-labeling massive training sets). 

A labeling function is just a Python function that accepts a `Candidate` and returns `1` to mark the `Candidate` as true, `-1` to mark the `Candidate` as false, and `0` to abstain from labeling the `Candidate`.

In [3]:
import re
from snorkel.lf_helpers import (
    get_left_tokens, 
    get_right_tokens, 
    get_between_tokens,
    get_text_between, 
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,
)

In [4]:
# Text Pattern / Text Rule based labeling functions, which look for certain keywords

# List to parenthetical
def ltp(x):
    return '(' + '|'.join(x) + ')'

# --------------------------------

# Positive LFs:

detect = {'detect', 'detects', 'detected', 'detecting', 'detection', 'detectable'}
infect = {'infect', 'infects', 'infected', 'infecting', 'infection'}
isolate = {'isolate', 'isolates', 'isolated', 'isolating', 'isolation'}
misc = {'isolat(e|es|ed|ing)?', '(sero)?positive', 'found', 'host[s]?', 'remove[d]?', 'prevalen(ce|t)?', 'affect(s|ed|ing)?', 'confirm(s|ed|ing)?', 'case[s]?', 'relat(ed|es|e|ing|ion)?'}

def LF_detect(c):
    if len(detect.intersection(get_between_tokens(c))) > 0: 
        return 1
    elif len(detect.intersection(get_left_tokens(c[0], window=2))) > 0:
        return 1
    elif len(detect.intersection(get_left_tokens(c[1], window=2))) > 0:
        return 1
    elif len(detect.intersection(get_right_tokens(c[0], window=2))) > 0:
        return 1
    elif len(detect.intersection(get_right_tokens(c[1], window=2))) > 0:
        return 1
    else:
        return 0
    
def LF_infect(c):
    if len(infect.intersection(get_between_tokens(c))) > 0: 
        return 1
    elif len(infect.intersection(get_left_tokens(c[0], window=2))) > 0:
        return 1
    elif len(infect.intersection(get_left_tokens(c[1], window=2))) > 0:
        return 1
    elif len(infect.intersection(get_right_tokens(c[0], window=2))) > 0:
        return 1
    elif len(infect.intersection(get_right_tokens(c[1], window=2))) > 0:
        return 1
    else:
        return 0
    
def LF_misc(c):
    if len(misc.intersection(get_between_tokens(c))) > 0: 
        return 1
    elif len(misc.intersection(get_left_tokens(c[0], window=2))) > 0:
        return 1
    elif len(misc.intersection(get_left_tokens(c[1], window=2))) > 0:
        return 1
    elif len(misc.intersection(get_right_tokens(c[0], window=2))) > 0:
        return 1
    elif len(misc.intersection(get_right_tokens(c[1], window=2))) > 0:
        return 1
    else:
        return 0

# Words like 'isolated'
def LF_isolate(c):
    if len(isolate.intersection(get_between_tokens(c))) > 0: 
        return 1
    elif len(isolate.intersection(get_left_tokens(c[0], window=2))) > 0:
        return 1
    elif len(isolate.intersection(get_left_tokens(c[1], window=2))) > 0:
        return 1
    elif len(isolate.intersection(get_right_tokens(c[0], window=2))) > 0:
        return 1
    elif len(isolate.intersection(get_right_tokens(c[1], window=2))) > 0:
        return 1
    else:
        return 0

# Words like 'caused'
causal = ['caus(es|ed|e|ing|ation)?', 'induc(es|ed|e|ing)?', 'associat(ed|ing|es|e|ion)?']

def LF_v_cause_h(c):
    return 1 if (
        re.search(r'{{A}}.{0,50} ' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
        and not re.search('{{A}}.{0,50}(not|no).{0,20}' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
    ) else 0

def LF_v_h(c):
    return 1 if ('{{A}} {{B}}' in get_tagged_text(c)) else 0

# -----------------------------------

# Negative LFs:

# Uncertain pairs
uncertain = ['combin', 'possible', 'unlikely']

def LF_uncertain(c):
    return rule_regex_search_before_A(c, ltp(uncertain) + '.*', -1)

# if candidate pair is too far apart (between 100-5000 characs apart), mark as negative
def LF_far_v_h(c):
    return rule_regex_search_btw_AB(c, '.{100,5000}', -1)

def LF_far_h_v(c):
    return rule_regex_search_btw_BA(c, '.{100,5000}', -1)

def LF_neg_h(c):
    return -1 if re.search('(none|not|no) .{0,25}{{B}}', get_tagged_text(c), flags=re.I) else 0

WEAK_PHRASES = ['none', 'although', 'seems', 'suggests', 
                'risk', 'to (investigate|assess|study)', 
               'against', 'negative', 'negate', 'resist']

WEAK_RGX = r'|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    return -1 if re.search(WEAK_RGX, get_tagged_text(c), flags=re.I) else 0


In [5]:
# list of all LFs
LFs = [
    LF_detect, LF_infect, LF_misc, LF_isolate, LF_v_cause_h, LF_v_h, LF_uncertain,
    LF_far_v_h, LF_far_h_v, LF_neg_h, LF_weak_assertions
]

In [6]:
# To label and view LFs for testing
labeled = []
for c in session.query(VirusHost).filter(VirusHost.split == 0).all():
    for function in LFs:
        if function(c) != 0:
            if c not in labeled:
                labeled.append(c)
print("Number labeled:", len(labeled))

Number labeled: 2965


In [7]:
from snorkel.viewer import SentenceNgramViewer

SentenceNgramViewer(labeled, session)

<IPython.core.display.Javascript object>

SentenceNgramViewer(cids=[[[1586], [582, 1110, 1382, 1645, 1789, 1848], [1426, 1427]], [[138], [37, 928], [175…

## Part II: Applying Labeling Functions

We run the LFs over all training candidates, producing a set of Labels (Virus and Host) and LabelKeys (the names of the LFs) in the database.

In [8]:
# set up the label annotator class
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)

In [9]:
np.random.seed(1701)
%time L_train = labeler.apply(split=0)
L_train

Clearing existing...
Running UDF...

Wall time: 26.5 s


<3780x11 sparse matrix of type '<class 'numpy.int32'>'
	with 4534 stored elements in Compressed Sparse Row format>

Note that the returned matrix is a special subclass of the `scipy.sparse.csr_matrix` class

In [10]:
# get the candidate names and positions of any candidate in the set
L_train.get_candidate(session, 0) 

VirusHost(Span("b'West Nile virus'", sentence=2289, chars=[108,122], words=[25,27]), Span("b'red-legged partridge'", sentence=2289, chars=[191,210], words=[40,43]))

In [11]:
# get the LabelKey (the name of the LF used to identify the candidate)
L_train.get_key(session, 0)

LabelKey (LF_detect)

Viewing statistics about the resulting label matrix:

* **Coverage** is the fraction of candidates that the labeling function emits a non-zero label for.
* **Overlap** is the fraction candidates that the labeling function emits a non-zero label for and that another labeling function emits a non-zero label for.
* **Conflict** is the fraction candidates that the labeling function emits a non-zero label for and that another labeling function emits a *conflicting* non-zero label for.

In [12]:
L_train.lf_stats(session)

Unnamed: 0,j,Coverage,Overlaps,Conflicts
LF_detect,0,0.048413,0.030423,0.02963
LF_infect,1,0.038095,0.022487,0.020635
LF_misc,2,0.015608,0.009788,0.009788
LF_isolate,3,0.043915,0.030159,0.029894
LF_v_cause_h,4,0.005026,0.001323,0.000529
LF_v_h,5,0.008466,0.005291,0.005291
LF_uncertain,6,0.00291,0.002116,0.001058
LF_far_v_h,7,0.445238,0.260317,0.053439
LF_far_h_v,8,0.213228,0.121429,0.01746
LF_neg_h,9,0.003968,0.002646,0.000529


## Part III: Checking Against Gold Labels (Hand Labeled Set)
- Run the labeler on the development set
- Load in some external labels:

### Load Gold Labels
Gold labels are a _small_ set of examples (here, a subset of our training set) which we label by hand and use to help us develop and refine labeling functions. Unlike the _test set_, which we do not look at and use for final evaluation, we can inspect the development set while writing labeling functions.

In [13]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name = "gold", split=1)
L_gold_dev

<346x1 sparse matrix of type '<class 'numpy.int32'>'
	with 127 stored elements in Compressed Sparse Row format>

In [14]:
%time L_dev = labeler.apply_existing(split=1)

Clearing existing...
Running UDF...

Wall time: 2.98 s


In [15]:
# Label Matrix Empirical Accuracies

L_dev.lf_stats(session, labels=L_gold_dev.toarray().ravel())

  ac = (tp+tn) / (tp+tn+fp+fn)


Unnamed: 0,j,Coverage,Overlaps,Conflicts,TP,FP,FN,TN,Empirical Acc.
LF_detect,0,0.017341,0.017341,0.017341,3,2,0,0,0.6
LF_infect,1,0.023121,0.00289,0.00289,7,0,0,0,1.0
LF_misc,2,0.017341,0.00578,0.00578,5,1,0,0,0.833333
LF_isolate,3,0.00289,0.00289,0.00289,0,1,0,0,0.0
LF_v_cause_h,4,0.00578,0.0,0.0,2,0,0,0,1.0
LF_v_h,5,0.0,0.0,0.0,0,0,0,0,
LF_uncertain,6,0.0,0.0,0.0,0,0,0,0,
LF_far_v_h,7,0.34104,0.098266,0.008671,0,0,8,27,0.771429
LF_far_h_v,8,0.297688,0.023121,0.008671,0,0,10,13,0.565217
LF_neg_h,9,0.0,0.0,0.0,0,0,0,0,


#### Iterating on Labeling Function Design:
When writing labeling functions, you will want to iterate on the process outlined above several times. You should focus on tuning individual LFs, based on emprical accuracy metrics, and adding new LFs to improve coverage.

In [16]:
### See Notebook Part 3 for Generative Model Training