# Chemical-Disease Relation (CDR) Tutorial

In this example, we'll write an application to extract *mentions of* **chemical-induced-disease relationships** from PubMed abstracts, as per the [BioCreative CDR Challenge](http://www.biocreative.org/resources/corpora/biocreative-v-cdr-corpus/). This tutorial shows off some of the more advanced features of Snorkel, so we'll assume you've followed the Intro tutorial.

Let's start by reloading from the last notebook.

In [52]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession

session = SnorkelSession()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

# Part III: Writing LFs

This tutorial features some more advanced LFs than the intro tutorial, with more focus on distant supervision and dependencies between LFs.

### Distant supervision approaches

We'll use the [Comparative Toxicogenomics Database](http://ctdbase.org/) (CTD) for distant supervision. The CTD lists chemical-condition entity pairs under three categories: therapy, marker, and unspecified. Therapy means the chemical treats the condition, marker means the chemical is typically present with the condition, and unspecified is...unspecified. We can write LFs based on these categories.

In [54]:
import bz2
from six.moves.cPickle import load

with bz2.BZ2File('data/ctd.pkl.bz2', 'rb') as ctd_f:
    ctd_unspecified, ctd_therapy, ctd_marker = load(ctd_f)

In [55]:
def cand_in_ctd_unspecified(c):
    return 1 if c.get_cids() in ctd_unspecified else 0

def cand_in_ctd_therapy(c):
    return 1 if c.get_cids() in ctd_therapy else 0

def cand_in_ctd_marker(c):
    return 1 if c.get_cids() in ctd_marker else 0

In [56]:
def LF_in_ctd_unspecified(c):
    # Vote negative (2) if the pair is listed as unspecified in the CTD, else abstain (0)
    return 2 if cand_in_ctd_unspecified(c) else 0

def LF_in_ctd_therapy(c):
    # Vote negative (2) if the pair is listed as a therapy in the CTD, else abstain (0)
    return 2 if cand_in_ctd_therapy(c) else 0

def LF_in_ctd_marker(c):
    return cand_in_ctd_marker(c)
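
The label convention these LFs rely on is easier to see in isolation with a toy knowledge base. The sketch below assumes the convention implied by the code above (1 = positive, 2 = negative, 0 = abstain). The `toy_therapy` and `toy_marker` sets, and the ID pairs in them, are made up for illustration and are not taken from the CTD.

```python
# Hypothetical stand-ins for the CTD sets: each holds (chemical_id, disease_id) pairs.
toy_therapy = {("C001", "D100")}
toy_marker = {("C002", "D200")}

POSITIVE, NEGATIVE, ABSTAIN = 1, 2, 0  # assumed label convention

def lf_in_therapy(pair):
    # A chemical that treats a disease is evidence against "chemical induces disease".
    return NEGATIVE if pair in toy_therapy else ABSTAIN

def lf_in_marker(pair):
    # A marker relation is treated as evidence for the relation.
    return POSITIVE if pair in toy_marker else ABSTAIN

pairs = [("C001", "D100"), ("C002", "D200"), ("C003", "D300")]
labels = [(lf_in_therapy(p), lf_in_marker(p)) for p in pairs]
print(labels)  # [(2, 0), (0, 1), (0, 0)]
```

A pair that appears in neither set gets an abstain from both LFs, which is the common case for distant supervision: the knowledge base only covers a fraction of the candidates.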

### Text pattern approaches

Now we'll use some LF helpers to create LFs based on indicative text patterns. We came up with these rules by using the viewer to examine training candidates and noting frequent patterns.

In [57]:
import re
from snorkel.lf_helpers import (
    cross_context_get_tagged_text,
    cross_context_rule_regex_search_tagged_text,
    cross_context_rule_regex_search_btw_AB,
    cross_context_rule_regex_search_btw_BA,
    cross_context_rule_regex_search_before_A,
    cross_context_rule_regex_search_before_B,
    get_sentences,
    get_dep_path
)

# List to parenthetical
def ltp(x):
    return '(' + '|'.join(x) + ')'

def LF_induce(c):
    return 1 if re.search(r'{{A}}.{0,20}induc.{0,20}{{B}}', cross_context_get_tagged_text(c, session), flags=re.I) else 0

causal_past = ['induced', 'caused', 'due']
def LF_d_induced_by_c(c):
    r = cross_context_rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + '.{0,9}(by|to).{0,50}', 1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_d_induced_by_c_tight(c):
    r = cross_context_rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + ' (by|to) ', 1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_induce_name(c):
    return 1 if 'induc' in c.chemical.get_span().lower() else 0     

causal = ['cause[sd]?', 'induce[sd]?', 'associated with']
def LF_c_cause_d(c):
    return 1 if (
        re.search(r'{{A}}.{0,50} ' + ltp(causal) + '.{0,50}{{B}}', cross_context_get_tagged_text(c, session), re.I)
        and not re.search('{{A}}.{0,50}(not|no).{0,20}' + ltp(causal) + '.{0,50}{{B}}', cross_context_get_tagged_text(c, session), re.I)
    ) else 0

treat = ['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap']

def LF_d_treat_c(c):
    r = cross_context_rule_regex_search_btw_BA(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_c_treat_d(c):
    r = cross_context_rule_regex_search_btw_AB(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_treat_d(c):
    r = cross_context_rule_regex_search_before_B(c, ltp(treat) + '.{0,50}', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_c_treat_d_wide(c):
    r = cross_context_rule_regex_search_btw_AB(c, '.{0,200}' + ltp(treat) + '.{0,200}', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_c_d(c):
    return 1 if ('{{A}} {{B}}' in cross_context_get_tagged_text(c, session)) else 0

def LF_c_induced_d(c):
    return 1 if (
        ('{{A}} {{B}}' in cross_context_get_tagged_text(c, session)) and 
        (('-induc' in c[0].get_span().lower()) or ('-assoc' in c[0].get_span().lower()))
        ) else 0

def LF_improve_before_disease(c):
    r = cross_context_rule_regex_search_before_B(c, 'improv.*', -1, session)
    if r == -1:
        return 2
    else:
        return r
pat_terms = ['in a patient with ', 'in patients with']
def LF_in_patient_with(c):
    return 2 if re.search(ltp(pat_terms) + '{{B}}', cross_context_get_tagged_text(c, session), flags=re.I) else 0

uncertain = ['combin', 'possible', 'unlikely']
def LF_uncertain(c):
    r = cross_context_rule_regex_search_before_A(c, ltp(uncertain) + '.*', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_induced_other(c):
    r = cross_context_rule_regex_search_tagged_text(c, '{{A}}.{20,1000}-induced {{B}}', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_far_c_d(c):
    r = cross_context_rule_regex_search_btw_AB(c, '.{100,5000}', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_far_d_c(c):
    r = cross_context_rule_regex_search_btw_BA(c, '.{100,5000}', -1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_risk_d(c):
    r = cross_context_rule_regex_search_before_B(c, 'risk of ', 1, session)
    if r == -1:
        return 2
    else:
        return r
def LF_develop_d_following_c(c):
    return 1 if re.search(r'develop.{0,25}{{B}}.{0,25}following.{0,25}{{A}}', cross_context_get_tagged_text(c, session), flags=re.I) else 0

procedure, following = ['inject', 'administrat'], ['following']
def LF_d_following_c(c):
    return 1 if re.search('{{B}}.{0,50}' + ltp(following) + '.{0,20}{{A}}.{0,50}' + ltp(procedure), cross_context_get_tagged_text(c, session), flags=re.I) else 0

def LF_measure(c):
    return 2 if re.search('measur.{0,75}{{A}}', cross_context_get_tagged_text(c, session), flags=re.I) else 0

def LF_level(c):
    return 2 if re.search('{{A}}.{0,25} level', cross_context_get_tagged_text(c, session), flags=re.I) else 0

def LF_neg_d(c):
    return 2 if re.search('(none|not|no) .{0,25}{{B}}', cross_context_get_tagged_text(c, session), flags=re.I) else 0

WEAK_PHRASES = ['none', 'although', 'was carried out', 'was conducted',
                'seems', 'suggests', 'risk', 'implicated',
               'the aim', 'to (investigate|assess|study)']

WEAK_RGX = r'|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    return 2 if re.search(WEAK_RGX, cross_context_get_tagged_text(c, session), flags=re.I) else 0
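
To make the `{{A}}` / `{{B}}` placeholder idea concrete, here is a minimal, self-contained sketch that does not use Snorkel. The `tag_text` helper is hypothetical and only mimics the idea of replacing the chemical and disease mentions with placeholders; the real helpers work on candidate spans rather than on plain string replacement.

```python
import re

def tag_text(text, chemical, disease):
    # Hypothetical helper: replace the two mentions with placeholders.
    return text.replace(chemical, "{{A}}").replace(disease, "{{B}}")

# Escape the braces so the placeholders are always matched literally.
A, B = re.escape("{{A}}"), re.escape("{{B}}")
induce_rgx = re.compile(A + r".{0,20}induc.{0,20}" + B, flags=re.I)

def lf_induce(tagged):
    # 1 = positive vote, 0 = abstain.
    return 1 if induce_rgx.search(tagged) else 0

tagged_pos = tag_text("Lithium may have induced hypothyroidism in this patient.",
                      "Lithium", "hypothyroidism")
tagged_none = tag_text("Hypothyroidism was treated with lithium.",
                       "lithium", "Hypothyroidism")

print(tagged_pos)              # {{A}} may have induced {{B}} in this patient.
print(lf_induce(tagged_pos))   # 1
print(lf_induce(tagged_none))  # 0
```

Bounding the gaps with `.{0,20}` keeps the pattern local, so a causal verb far from either mention does not trigger the rule.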

### Composite LFs

The following LFs take some of the strongest distant supervision and text pattern LFs, and combine them to form more specific LFs. These LFs introduce some obvious dependencies within the LF set, which we will model later.

In [58]:
def LF_ctd_marker_c_d(c):
    return LF_c_d(c) * cand_in_ctd_marker(c)

def LF_ctd_marker_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_marker(c)

def LF_ctd_therapy_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_therapy(c)

def LF_ctd_unspecified_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_unspecified(c)

def LF_ctd_unspecified_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_unspecified(c)
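
The composition trick above relies on plain integer arithmetic: multiplying an LF's label by a 0/1 indicator keeps the label only when the indicator fires, and `or` returns the first non-abstaining label. A small sketch with stand-in values (the function names here are illustrative, not part of Snorkel):

```python
def compose(lf_label, kb_indicator):
    # lf_label is in {0, 1, 2}; kb_indicator is 0 or 1.
    # The product is the LF's label when the indicator is 1, and abstain (0) otherwise.
    return lf_label * kb_indicator

def first_vote(label_a, label_b):
    # `or` returns label_a unless it is 0 (abstain), in which case it returns label_b.
    return label_a or label_b

table = {(label, ind): compose(label, ind) for label in (0, 1, 2) for ind in (0, 1)}
print(table)
# {(0, 0): 0, (0, 1): 0, (1, 0): 0, (1, 1): 1, (2, 0): 0, (2, 1): 2}

print(first_vote(0, 1), first_vote(1, 0), first_vote(0, 0))  # 1 1 0
```

Because a composite LF can only fire when its component fires, the two are correlated by construction, which is exactly the kind of dependency the text says will be modeled later.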

### Rules based on context hierarchy

The rule below uses `get_dep_path` to collect the words on the dependency path for the candidate, and votes positive when any of them is one of a fixed set of treatment or causal keywords.

In [59]:
pos_keywords = {'treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap',
                'cause', 'caused', 'induce', 'due', 'induced'}

def LF_pos_keywords_in_dep_path(c):
    # Words on the dependency path for the candidate
    path = get_dep_path(c, session)
    words = {word for (word, dep_type) in path}

    # Vote positive (1) if any keyword appears on the path, else abstain (0)
    return 1 if words & pos_keywords else 0
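
The keyword check can be tried without a parser by hand-writing a path. In this sketch the path is a list of `(word, dep_type)` tuples, matching how the LF above unpacks it; the dependency labels and the sentence are illustrative assumptions, and the lowercasing step is an addition that the LF above does not perform.

```python
keywords = {"induce", "induced", "cause", "caused"}

def lf_keywords_on_path(path):
    # Collect the (lowercased) words on the path and test for any keyword.
    words = {word.lower() for word, _dep_type in path}
    return 1 if words & keywords else 0

# Hand-written toy paths; the dependency labels are only illustrative.
path_causal = [("hypothyroidism", "nsubjpass"), ("Induced", "ROOT"), ("lithium", "nmod")]
path_other = [("hypothyroidism", "nsubj"), ("observed", "ROOT"), ("lithium", "nmod")]

print(lf_keywords_on_path(path_causal))  # 1
print(lf_keywords_on_path(path_other))   # 0
```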

### Running the LFs on the training set

In [60]:
LFs = [
    LF_pos_keywords_in_dep_path,
    LF_c_cause_d,
    LF_c_d,
    LF_c_induced_d,
    LF_c_treat_d,
    LF_c_treat_d_wide,
    LF_ctd_marker_c_d,
    LF_ctd_marker_induce,
    LF_ctd_therapy_treat,
    LF_ctd_unspecified_treat,
    LF_ctd_unspecified_induce,
    LF_d_following_c,
    LF_d_induced_by_c,
    LF_d_induced_by_c_tight,
    LF_d_treat_c,
    LF_develop_d_following_c,
    LF_far_c_d,
    LF_far_d_c,
    LF_improve_before_disease,
    LF_in_ctd_therapy,
    LF_in_ctd_marker,
    LF_in_patient_with,
    LF_induce,
    LF_induce_name,
    LF_induced_other,
    LF_level,
    LF_measure,
    LF_neg_d,
    LF_risk_d,
    LF_treat_d,
    LF_uncertain,
    LF_weak_assertions,
]

In [61]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)

In [62]:
L_train = labeler.apply(split=0)

  0%|          | 0/21283 [00:00<?, ?it/s]

Clearing existing...
Running UDF...


100%|██████████| 21283/21283 [02:11<00:00, 161.40it/s]


In [69]:
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
dep_tups = [(x,y) for x,y,z in deps]
dep_tups

[(9, 25),
 (1, 19),
 (26, 30),
 (6, 7),
 (14, 31),
 (19, 21),
 (8, 9),
 (29, 31),
 (21, 25),
 (1, 30),
 (14, 24),
 (3, 14),
 (4, 25),
 (9, 23),
 (21, 28),
 (1, 25),
 (14, 21),
 (11, 23),
 (16, 29),
 (4, 30),
 (9, 18),
 (21, 31),
 (1, 4),
 (11, 18),
 (16, 26),
 (3, 8),
 (4, 19),
 (0, 22),
 (25, 29),
 (10, 12),
 (14, 19),
 (11, 17),
 (3, 7),
 (8, 29),
 (4, 16),
 (18, 23),
 (11, 28),
 (28, 31),
 (8, 26),
 (4, 21),
 (25, 27),
 (18, 24),
 (11, 27),
 (8, 23),
 (1, 8),
 (18, 29),
 (23, 27),
 (3, 28),
 (1, 11),
 (18, 30),
 (23, 30),
 (3, 27),
 (8, 17),
 (9, 14),
 (2, 23),
 (23, 29),
 (15, 27),
 (24, 30),
 (17, 23),
 (27, 31),
 (3, 21),
 (15, 30),
 (24, 27),
 (0, 1),
 (17, 18),
 (5, 25),
 (3, 16),
 (15, 29),
 (17, 29),
 (4, 9),
 (5, 28),
 (15, 16),
 (4, 14),
 (2, 3),
 (19, 26),
 (15, 23),
 (26, 29),
 (17, 27),
 (9, 29),
 (5, 18),
 (19, 25),
 (16, 23),
 (5, 21),
 (1, 18),
 (26, 31),
 (14, 28),
 (4, 5),
 (9, 27),
 (21, 24),
 (1, 29),
 (29, 30),
 (2, 10),
 (14, 25),
 (4, 26),
 (21, 27),
 (1, 24),


In [70]:
len(dep_tups)

253
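
For intuition about what a dependency between two LFs looks like in the label matrix, the sketch below counts how often pairs of LFs vote on the same candidate and how often they agree when they do. This is a crude heuristic on a toy matrix, not the structure learning algorithm that `DependencySelector` runs; the thresholds are arbitrary.

```python
from itertools import combinations

# Toy label matrix: rows are candidates, columns are LFs (0 = abstain).
L_toy = [
    [1, 1, 0],
    [1, 1, 2],
    [0, 0, 2],
    [1, 1, 0],
    [2, 0, 2],
]

def overlap_and_agreement(L, i, j):
    # Rows where both LFs vote, and the fraction of those rows where they agree.
    both = [(row[i], row[j]) for row in L if row[i] != 0 and row[j] != 0]
    if not both:
        return 0, 0.0
    agree = sum(a == b for a, b in both) / len(both)
    return len(both), agree

n_lfs = len(L_toy[0])
stats = {(i, j): overlap_and_agreement(L_toy, i, j)
         for i, j in combinations(range(n_lfs), 2)}

# Flag pairs that overlap on at least 2 rows and agree at least 90% of the time.
suspects = [pair for pair, (n, agree) in stats.items() if n >= 2 and agree >= 0.9]
print(stats)     # {(0, 1): (3, 1.0), (0, 2): (2, 0.5), (1, 2): (1, 0.0)}
print(suspects)  # [(0, 1)]
```

As in `dep_tups`, each flagged pair is a tuple of column indices into the label matrix.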

In [13]:
# from metal.analysis import lf_summary
# lf_summary(L_train)

In [14]:
# import os
# from metal.analysis import view_label_matrix, view_overlaps

# # This if statement and others like it are for our continuous integration tests; you can ignore them.
# if 'CI' not in os.environ:
#     view_label_matrix(L_train)

In [66]:
from metal.label_model import LabelModel
label_model = LabelModel(k=2, seed=123)

In [71]:
import numpy as np
from snorkel.annotations import load_gold_labels_array
from load_external_annotations import load_external_labels

USE_DEV_BALANCE = True
if USE_DEV_BALANCE:
    load_external_labels(session, ChemicalDisease, split=1, annotator='gold')
    L_gold_dev = load_gold_labels_array(session, annotator_name='gold', split=1)

    # Remap negative gold labels from -1 to 2 to match the LF outputs
    y_dev = np.asarray([2 if i == -1 else i for i in L_gold_dev])

    label_model.train(L_train, Y_dev=y_dev, deps=dep_tups, n_epochs=600, print_every=50)
else:
    label_model.train(L_train, n_epochs=600, print_every=50)
  

AnnotatorLabels created: 0


NotImplementedError: Graph triangulation not implemented.

In [50]:
import numpy as np
from metal.analysis import confusion_matrix

EVAL_SPLIT = 2
load_external_labels(session, ChemicalDisease, split=EVAL_SPLIT, annotator='gold')
L_gold_eval = load_gold_labels_array(session, annotator_name='gold', split=EVAL_SPLIT)

# Remap negative gold labels from -1 to 2 to match the LF outputs
y_eval = np.asarray([2 if i == -1 else i for i in L_gold_eval])
L_eval = labeler.apply(split=EVAL_SPLIT)

score = label_model.score(L_eval, y_eval)
scores = label_model.score(L_eval, y_eval, metric=['precision', 'recall', 'f1'])

Y_eval_p = label_model.predict(L_eval)
cm = confusion_matrix(y_eval, Y_eval_p)
print(cm)

  0%|          | 17/4685 [00:00<00:27, 169.45it/s]

AnnotatorLabels created: 0
Clearing existing...
Running UDF...


100%|██████████| 4685/4685 [00:23<00:00, 200.71it/s]


Accuracy: 0.722
Precision: 0.604
Recall: 0.401
F1: 0.482
        y=1    y=2   
 l=1    607    905   
 l=2    398   2775   
[[ 607  905]
 [ 398 2775]]
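
The reported scores can be recomputed by hand from the confusion matrix. The sketch below assumes rows are gold labels and columns are predictions, with label 1 as the positive class; that is the reading under which the arithmetic reproduces the printed accuracy, precision, recall, and F1.

```python
# Confusion matrix printed above (assumed layout: rows = gold, columns = predicted).
cm = [[607, 905],
      [398, 2775]]

tp, fn = cm[0]  # gold positive: predicted positive / predicted negative
fp, tn = cm[1]  # gold negative: predicted positive / predicted negative

accuracy = (tp + tn) / (tp + fn + fp + tn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(round(accuracy, 3))   # 0.722
print(round(precision, 3))  # 0.604
print(round(recall, 3))     # 0.401
print(round(f1, 3))         # 0.482
```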


In [51]:
from metal.label_model.baselines import MajorityLabelVoter

mv = MajorityLabelVoter(seed=123)
scores = mv.score(L_eval, y_eval, metric=['precision', 'recall', 'f1'])
Y_eval_p = mv.predict(L_eval)
cm = confusion_matrix(y_eval, Y_eval_p)
print(cm)

Precision: 0.494
Recall: 0.779
F1: 0.605
        y=1    y=2   
 l=1   1188    324   
 l=2   1247   1926   
[[1188  324]
 [1247 1926]]
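
As a point of reference for the baseline, here is a minimal majority vote over the rows of a toy label matrix. It is a simplified sketch: ties and all-abstain rows return 0 (abstain) here, which may differ from how `MajorityLabelVoter` resolves ties.

```python
from collections import Counter

def majority_vote(row, abstain=0):
    # Count only the non-abstaining votes in this row.
    votes = Counter(label for label in row if label != abstain)
    if not votes:
        return abstain
    ranked = votes.most_common()
    # Simplification: return abstain when the top two labels are tied.
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return abstain
    return ranked[0][0]

L_toy = [
    [1, 1, 2, 0],  # two votes for 1, one for 2
    [2, 0, 2, 1],  # two votes for 2, one for 1
    [0, 0, 0, 0],  # no LF votes
    [1, 2, 0, 0],  # tie
]
preds = [majority_vote(row) for row in L_toy]
print(preds)  # [1, 2, 0, 0]
```

Majority vote weights every LF equally, whereas the label model estimates a separate accuracy for each LF; that difference is one reason the two can trade precision against recall so differently.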


# Part IV: Training the generative model

As mentioned above, we want to include the dependencies between our LFs when training the generative model. Snorkel makes it easy to do this! `DependencySelector` runs a fast structure learning algorithm over the matrix of LF outputs to identify a set of likely dependencies. We can see that these match up with our prior knowledge. For example, it identified a "reinforcing" dependency between `LF_c_induced_d` and `LF_ctd_marker_induce`. Recall that we constructed the latter using the former.

Now we'll train the generative model, using the `deps` argument to account for the learned dependencies. We'll also model LF propensity here, unlike the intro tutorial. In addition to learning the accuracies of the LFs, this also learns their likelihood of labeling an example.
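
To make "accuracy" and "propensity" concrete, the sketch below computes their empirical counterparts on a toy label matrix with known gold labels. This is only an illustration of the two quantities; the generative model estimates them without access to gold labels, and the matrix and labels here are made up.

```python
# Toy label matrix: rows are candidates, columns are LFs (0 = abstain).
L_toy = [
    [1, 0],
    [1, 0],
    [1, 2],
    [0, 2],
]
y_toy = [1, 1, 2, 2]  # made-up gold labels

def propensity(L, j):
    # Fraction of candidates on which LF j votes at all.
    return sum(row[j] != 0 for row in L) / len(L)

def empirical_accuracy(L, y, j):
    # Fraction of LF j's votes that match the gold label (None if it never votes).
    voted = [(row[j], gold) for row, gold in zip(L, y) if row[j] != 0]
    if not voted:
        return None
    return sum(vote == gold for vote, gold in voted) / len(voted)

for j in range(2):
    print(j, propensity(L_toy, j), empirical_accuracy(L_toy, y_toy, j))
```

In this toy matrix the first LF votes on three of four candidates and is right on two of its three votes, while the second votes on half of the candidates and is right both times.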

In [19]:
probs = label_model.predict_proba(L_train)

In [20]:
probs

array([[0.42275871, 0.57724129],
       [0.42275871, 0.57724129],
       [0.42275871, 0.57724129],
       ...,
       [0.16736742, 0.83263258],
       [0.16736742, 0.83263258],
       [0.58740038, 0.41259962]])

In [21]:
# from snorkel.annotations import save_marginals
# save_marginals(session, L_train, train_marginals)

### Checking performance against development set labels

Finally, we'll run the labeler on the development set, load in some external labels, then evaluate the LF performance. The external labels are applied via a small script for convenience. It maps the document-level relation annotations found in the CDR file to mention-level labels. Note that these will not be perfect, although they are pretty good. If we wanted to keep iterating, we could use `snorkel.lf_helpers.test_LF` against the dev set, or look at some false positive and false negative candidates.