# Chemical-Disease Relation (CDR) Tutorial

In this example, we'll write an application to extract *mentions of* **chemical-induced-disease relationships** from PubMed abstracts, as per the [BioCreative CDR Challenge](http://www.biocreative.org/resources/corpora/biocreative-v-cdr-corpus/). This tutorial shows off some of the more advanced features of Snorkel, so we'll assume you've followed the Intro tutorial.

Let's start by reloading from the last notebook.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession

session = SnorkelSession()

In [3]:
from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

# Part III: Writing LFs

This tutorial features some more advanced LFs than the intro tutorial, with more focus on distant supervision and dependencies between LFs.

### Distant supervision approaches

We'll use the [Comparative Toxicogenomics Database](http://ctdbase.org/) (CTD) for distant supervision. The CTD lists chemical-condition entity pairs under three categories: therapy, marker, and unspecified. Therapy means the chemical treats the condition, marker means the chemical is typically present with the condition, and unspecified is...unspecified. We can write LFs based on these categories.

In [4]:
import bz2
from six.moves.cPickle import load

with bz2.BZ2File('data/ctd.pkl.bz2', 'rb') as ctd_f:
    ctd_unspecified, ctd_therapy, ctd_marker = load(ctd_f)

In [5]:
def cand_in_ctd_unspecified(c):
    return 1 if c.get_cids() in ctd_unspecified else 0

def cand_in_ctd_therapy(c):
    return 1 if c.get_cids() in ctd_therapy else 0

def cand_in_ctd_marker(c):
    return 1 if c.get_cids() in ctd_marker else 0

In [6]:
def LF_in_ctd_unspecified(c):
    # Vote negative (2) when CTD lists the pair as unspecified; otherwise abstain (0).
    return 2 if cand_in_ctd_unspecified(c) else 0

def LF_in_ctd_therapy(c):
    # Vote negative (2) when CTD lists the chemical as a therapy for the disease; otherwise abstain (0).
    return 2 if cand_in_ctd_therapy(c) else 0

def LF_in_ctd_marker(c):
    return cand_in_ctd_marker(c)
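
The CTD LFs above appear to follow a categorical convention in which 1 is a positive vote, 2 is a negative vote, and 0 is an abstention. A minimal sketch of that behavior on toy data follows; `ToyCandidate`, the `toy_` functions, and the ID pairs are hypothetical stand-ins for Snorkel candidates and the real CTD sets.

```python
# Hypothetical stand-in for a Snorkel candidate; only get_cids() is modeled.
class ToyCandidate:
    def __init__(self, chemical_id, disease_id):
        self.cids = (chemical_id, disease_id)

    def get_cids(self):
        return self.cids

# Made-up ID pairs standing in for the sets loaded from data/ctd.pkl.bz2.
toy_ctd_therapy = {("C1", "D1")}
toy_ctd_marker = {("C2", "D2")}

def toy_lf_in_ctd_therapy(c):
    # A chemical that treats the disease counts against causation: vote 2.
    return 2 if c.get_cids() in toy_ctd_therapy else 0

def toy_lf_in_ctd_marker(c):
    # A marker pair counts as evidence for the relation: vote 1.
    return 1 if c.get_cids() in toy_ctd_marker else 0

candidates = [
    ToyCandidate("C1", "D1"),  # listed as therapy
    ToyCandidate("C2", "D2"),  # listed as marker
    ToyCandidate("C3", "D3"),  # not in either set
]

votes = [(toy_lf_in_ctd_therapy(c), toy_lf_in_ctd_marker(c)) for c in candidates]
print(votes)
```

Each tuple holds the therapy vote and the marker vote for one candidate; a pair found in neither set produces two abstentions.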

### Text pattern approaches

Now we'll use some LF helpers to create LFs based on indicative text patterns. We came up with these rules by using the viewer to examine training candidates and noting frequent patterns.

In [7]:
import re
from snorkel.lf_helpers import (
    cross_sentence_get_tagged_text,
    cross_sentence_rule_regex_search_tagged_text,
    cross_sentence_rule_regex_search_btw_AB,
    cross_sentence_rule_regex_search_btw_BA,
    cross_sentence_rule_regex_search_before_A,
    cross_sentence_rule_regex_search_before_B,
    get_sentences,
    get_dep_path
)

# List to parenthetical
def ltp(x):
    return '(' + '|'.join(x) + ')'

def LF_induce(c):
    return 1 if re.search(r'{{A}}.{0,20}induc.{0,20}{{B}}', cross_sentence_get_tagged_text(c, session), flags=re.I) else 0

causal_past = ['induced', 'caused', 'due']
def LF_d_induced_by_c(c):
    return cross_sentence_rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + '.{0,9}(by|to).{0,50}', 1, session)

def LF_d_induced_by_c_tight(c):
    return cross_sentence_rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + ' (by|to) ', 1, session)

def LF_induce_name(c):
    return 1 if 'induc' in c.chemical.get_span().lower() else 0

causal = ['cause[sd]?', 'induce[sd]?', 'associated with']
def LF_c_cause_d(c):
    return 1 if (
        re.search(r'{{A}}.{0,50} ' + ltp(causal) + '.{0,50}{{B}}', cross_sentence_get_tagged_text(c, session), re.I)
        and not re.search('{{A}}.{0,50}(not|no).{0,20}' + ltp(causal) + '.{0,50}{{B}}', cross_sentence_get_tagged_text(c, session), re.I)
    ) else 0

treat = ['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap']

def LF_d_treat_c(c):
    return cross_sentence_rule_regex_search_btw_BA(c, '.{0,50}' + ltp(treat) + '.{0,50}', 2, session)

def LF_c_treat_d(c):
    return cross_sentence_rule_regex_search_btw_AB(c, '.{0,50}' + ltp(treat) + '.{0,50}', 2, session)

def LF_treat_d(c):
    return cross_sentence_rule_regex_search_before_B(c, ltp(treat) + '.{0,50}', 2, session)

def LF_c_treat_d_wide(c):
    return cross_sentence_rule_regex_search_btw_AB(c, '.{0,200}' + ltp(treat) + '.{0,200}', 2, session)

def LF_c_d(c):
    return 1 if ('{{A}} {{B}}' in cross_sentence_get_tagged_text(c, session)) else 0

def LF_c_induced_d(c):
    return 1 if (
        ('{{A}} {{B}}' in cross_sentence_get_tagged_text(c, session)) and 
        (('-induc' in c[0].get_span().lower()) or ('-assoc' in c[0].get_span().lower()))
        ) else 0

def LF_improve_before_disease(c):
    return cross_sentence_rule_regex_search_before_B(c, 'improv.*', 2, session)

pat_terms = ['in a patient with ', 'in patients with']
def LF_in_patient_with(c):
    return 2 if re.search(ltp(pat_terms) + '{{B}}', cross_sentence_get_tagged_text(c, session), flags=re.I) else 0

uncertain = ['combin', 'possible', 'unlikely']
def LF_uncertain(c):
    return cross_sentence_rule_regex_search_before_A(c, ltp(uncertain) + '.*', 2, session)

def LF_induced_other(c):
    return cross_sentence_rule_regex_search_tagged_text(c, '{{A}}.{20,1000}-induced {{B}}', 2, session)

def LF_far_c_d(c):
    return cross_sentence_rule_regex_search_btw_AB(c, '.{100,5000}', 2, session)

def LF_far_d_c(c):
    return cross_sentence_rule_regex_search_btw_BA(c, '.{100,5000}', 2, session)

def LF_risk_d(c):
    return cross_sentence_rule_regex_search_before_B(c, 'risk of ', 1, session)

def LF_develop_d_following_c(c):
    return 1 if re.search(r'develop.{0,25}{{B}}.{0,25}following.{0,25}{{A}}', cross_sentence_get_tagged_text(c, session), flags=re.I) else 0

procedure, following = ['inject', 'administrat'], ['following']
def LF_d_following_c(c):
    return 1 if re.search('{{B}}.{0,50}' + ltp(following) + '.{0,20}{{A}}.{0,50}' + ltp(procedure), cross_sentence_get_tagged_text(c, session), flags=re.I) else 0

def LF_measure(c):
    return 2 if re.search('measur.{0,75}{{A}}', cross_sentence_get_tagged_text(c, session), flags=re.I) else 0

def LF_level(c):
    return 2 if re.search('{{A}}.{0,25} level', cross_sentence_get_tagged_text(c, session), flags=re.I) else 0

def LF_neg_d(c):
    return 2 if re.search('(none|not|no) .{0,25}{{B}}', cross_sentence_get_tagged_text(c, session), flags=re.I) else 0

WEAK_PHRASES = ['none', 'although', 'was carried out', 'was conducted',
                'seems', 'suggests', 'risk', 'implicated',
               'the aim', 'to (investigate|assess|study)']

WEAK_RGX = r'|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    return 2 if re.search(WEAK_RGX, cross_sentence_get_tagged_text(c, session), flags=re.I) else 0
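
To see how these pattern LFs behave, here is a small sketch that runs one pattern on hand-written strings. It assumes the tagged text replaces the chemical mention with `{{A}}` and the disease mention with `{{B}}`, as the patterns above suggest; the two sentences and `toy_lf_d_induced_by_c` are made up for illustration, and the braces are escaped here for explicitness.

```python
import re

def ltp(x):
    # Join alternatives into one regex group, e.g. ['a', 'b'] -> '(a|b)'.
    return '(' + '|'.join(x) + ')'

causal_past = ['induced', 'caused', 'due']

# Hand-written tagged sentences (assumed format, not real Snorkel output).
tagged_pos = 'The {{B}} was induced by {{A}} in two patients.'
tagged_neg = '{{A}} was given before {{B}} was diagnosed.'

# Disease, then a causal verb, then "by" or "to", then the chemical.
pattern = (r'\{\{B\}\}.{0,50}' + ltp(causal_past)
           + r'.{0,9}(by|to).{0,50}\{\{A\}\}')

def toy_lf_d_induced_by_c(tagged_text):
    # Vote positive (1) on a match, otherwise abstain (0).
    return 1 if re.search(pattern, tagged_text, flags=re.I) else 0

print(ltp(causal_past))
print(toy_lf_d_induced_by_c(tagged_pos))
print(toy_lf_d_induced_by_c(tagged_neg))
```

The bounded gaps such as `.{0,50}` keep a pattern from matching across unrelated parts of a long abstract, which is why the tighter and wider variants above vote differently on the same candidate.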

### Composite LFs

The following LFs take some of the strongest distant supervision and text pattern LFs, and combine them to form more specific LFs. These LFs introduce some obvious dependencies within the LF set, which we will model later.

In [8]:
def LF_ctd_marker_c_d(c):
    return LF_c_d(c) * cand_in_ctd_marker(c)

def LF_ctd_marker_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_marker(c)

def LF_ctd_therapy_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_therapy(c)

def LF_ctd_unspecified_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_unspecified(c)

def LF_ctd_unspecified_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_unspecified(c)
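
The composite LFs multiply a text-pattern vote by a 0/1 CTD membership flag, so the text vote survives only when the pair is in the knowledge base. A short sketch of that arithmetic, with `compose` and `first_vote` as hypothetical helper names:

```python
def compose(text_label, in_kb):
    # text_label is 0 (abstain), 1, or 2; in_kb is 0 or 1.
    # The product keeps the text label only when the pair is in the knowledge base.
    return text_label * in_kb

def first_vote(a, b):
    # Python's `or` returns the first non-zero vote, or 0 if both abstain.
    return a or b

table = {(t, k): compose(t, k) for t in (0, 1, 2) for k in (0, 1)}
for key in sorted(table):
    print(key, '->', table[key])

# Mirrors the shape (LF_x(c) or LF_y(c)) * cand_in_ctd_marker(c).
combined = compose(first_vote(0, 1), 1)
print(combined)
```

Because every composite reuses the output of a simpler LF, its votes are correlated with that LF by construction, which is the dependency the section text refers to.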

### Rules based on context hierarchy

This last rule uses the dependency path that `get_dep_path` returns for the candidate. It votes positive when any word on that path appears in a set of treatment and causal keywords.

In [10]:
pos_keywords = set(['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap',
                  'cause', 'caused', 'induce', 'due', 'induced'])

def LF_pos_keywords_in_dep_path(c):
    path = get_dep_path(c, session)
    words = {word for (word, dep_type) in path}
    return 1 if words & pos_keywords else 0
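
The dependency-path LF above unpacks the path as `(word, dep_type)` pairs and tests set membership. A sketch on hand-written paths follows; the paths and dependency labels are illustrative only, and `toy_lf_pos_keywords_in_path` is a hypothetical name.

```python
pos_keywords = {'treat', 'effective', 'prevent', 'resistant', 'slow', 'promise',
                'therap', 'cause', 'caused', 'induce', 'due', 'induced'}

def toy_lf_pos_keywords_in_path(path):
    # path: list of (word, dependency_type) pairs.
    words = {word for (word, dep_type) in path}
    return 1 if words & pos_keywords else 0

# Hand-written paths, not parser output.
path_hit = [('hepatitis', 'nsubjpass'), ('induced', 'ROOT'), ('isoniazid', 'agent')]
path_miss = [('hepatitis', 'nsubj'), ('observed', 'ROOT'), ('isoniazid', 'prep')]
path_stem = [('therapy', 'ROOT')]

print(toy_lf_pos_keywords_in_path(path_hit))
print(toy_lf_pos_keywords_in_path(path_miss))
print(toy_lf_pos_keywords_in_path(path_stem))
```

Set intersection matches whole tokens, so if the path holds full surface words, a stem-like entry such as `'therap'` would not match `'therapy'`, unlike the regex-based LFs where the same stems match as substrings.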

### Running the LFs on the training set

In [12]:
LFs = [
    LF_pos_keywords_in_dep_path,
    LF_c_cause_d,
    LF_c_d,
    LF_c_induced_d,
    LF_c_treat_d,
    LF_c_treat_d_wide,
    LF_ctd_marker_c_d,
    LF_ctd_marker_induce,
    LF_ctd_therapy_treat,
    LF_ctd_unspecified_treat,
    LF_ctd_unspecified_induce,
    LF_d_following_c,
    LF_d_induced_by_c,
    LF_d_induced_by_c_tight,
    LF_d_treat_c,
    LF_develop_d_following_c,
    LF_far_c_d,
    LF_far_d_c,
    LF_improve_before_disease,
    LF_in_ctd_therapy,
    LF_in_ctd_marker,
    LF_in_patient_with,
    LF_induce,
    LF_induce_name,
    LF_induced_other,
    LF_level,
    LF_measure,
    LF_neg_d,
    LF_risk_d,
    LF_treat_d,
    LF_uncertain,
    LF_weak_assertions,
]

In [13]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)

In [14]:
L_train = labeler.apply(split=0)

Clearing existing...
Running UDF...
100%|██████████| 21300/21300 [02:00<00:00, 176.60it/s]

In [15]:
from metal.label_model import LabelModel
label_model = LabelModel(k=2, seed=123)

In [19]:
from snorkel.annotations import load_gold_labels_array
from load_external_annotations import load_external_labels
import numpy as np

USE_DEV_BALANCE = True
if USE_DEV_BALANCE:
    load_external_labels(session, ChemicalDisease, split=1, annotator='gold')
    L_gold_dev = load_gold_labels_array(session, annotator_name='gold', split=1)

    # Remap negative gold labels from -1 to 2 to match the LF label convention.
    y_dev = np.asarray([2 if label == -1 else label for label in L_gold_dev])

    # Pass the dev labels to training, as the USE_DEV_BALANCE flag indicates.
    label_model.train(L_train, Y_dev=y_dev, n_epochs=600, print_every=50)
else:
    label_model.train(L_train, n_epochs=600, print_every=50)

AnnotatorLabels created: 0
Computing O...
Estimating \mu...
[E:0]	Train Loss: 4.893
[E:50]	Train Loss: 0.029
[E:100]	Train Loss: 0.026
[E:150]	Train Loss: 0.025
[E:200]	Train Loss: 0.025
[E:250]	Train Loss: 0.024
[E:300]	Train Loss: 0.020
[E:350]	Train Loss: 0.013
[E:400]	Train Loss: 0.010
[E:450]	Train Loss: 0.009
[E:500]	Train Loss: 0.008
[E:550]	Train Loss: 0.008
[E:599]	Train Loss: 0.008
Finished Training
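
The cell above remaps gold labels from -1 to 2 before training. If the gold labels are a flat NumPy array (an assumption; the return type of `load_gold_labels_array` is not shown here), the same remapping can be sketched with `np.where` on made-up values:

```python
import numpy as np

# Toy gold labels in the {-1, +1} convention (values made up for illustration).
gold = np.array([1, -1, -1, 1, -1])

# Replace -1 with 2 and leave every other value unchanged.
y = np.where(gold == -1, 2, gold)
print(y)
```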


In [20]:
from metal.analysis import confusion_matrix

EVAL_SPLIT = 2
load_external_labels(session, ChemicalDisease, split=EVAL_SPLIT, annotator='gold')
L_gold_eval = load_gold_labels_array(session, annotator_name='gold', split=EVAL_SPLIT)

# Remap negative gold labels from -1 to 2 to match the LF label convention.
y_eval = np.asarray([2 if label == -1 else label for label in L_gold_eval])
L_eval = labeler.apply(split=EVAL_SPLIT)

score = label_model.score(L_eval, y_eval)
scores = label_model.score(L_eval, y_eval, metric=['precision', 'recall', 'f1'])

Y_eval_p = label_model.predict(L_eval)
cm = confusion_matrix(y_eval, Y_eval_p)
print(cm)

AnnotatorLabels created: 11973
Clearing existing...
Running UDF...
100%|██████████| 11973/11973 [01:11<00:00, 168.29it/s]

Accuracy: 0.641
Precision: 0.369
Recall: 0.286
F1: 0.322
        y=1    y=2   
 l=1   1022   2548   
 l=2   1746   6657   
[[1022 2548]
 [1746 6657]]
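
As a check on how the scores relate to the confusion matrix printed above, the sketch below recomputes them by hand. Reading rows as gold labels and columns as predictions is the orientation that reproduces the printed precision and recall; that reading is inferred from the numbers, not taken from MeTaL's documentation.

```python
# Confusion matrix copied from the output above.
# Assumed orientation: rows are gold labels, columns are predictions.
cm = [[1022, 2548],
      [1746, 6657]]

tp, fn = cm[0]  # gold class 1: predicted 1, predicted 2
fp, tn = cm[1]  # gold class 2: predicted 1, predicted 2

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + fn + fp + tn)

print(f'Accuracy: {accuracy:.3f}')
print(f'Precision: {precision:.3f}')
print(f'Recall: {recall:.3f}')
print(f'F1: {f1:.3f}')
```

Accuracy and F1 do not depend on the orientation, so only precision and recall would swap under the opposite reading.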


In [23]:
from metal.label_model.baselines import MajorityLabelVoter

mv = MajorityLabelVoter(seed=123)
scores = mv.score(L_eval, y_eval, metric=['precision', 'recall', 'f1'])
Y_eval_p = mv.predict(L_eval)
cm = confusion_matrix(y_eval, Y_eval_p)
print(cm)

Precision: 0.472
Recall: 0.659
F1: 0.550
        y=1    y=2   
 l=1   2343   1227   
 l=2   2655   5748   
[[2343 1227]
 [2655 5748]]
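
For intuition about the baseline, here is a minimal majority-vote sketch over a toy label matrix, with 0 meaning an LF abstained. It is an illustration, not MeTaL's `MajorityLabelVoter` implementation; in particular, abstaining on ties is a choice made here, and the library may break ties differently.

```python
from collections import Counter

def majority_vote(row):
    # row: one candidate's LF outputs; 0 means the LF abstained.
    votes = Counter(v for v in row if v != 0)
    if not votes:
        return 0  # every LF abstained
    ranked = votes.most_common()
    # Illustrative choice: abstain on ties instead of picking a winner.
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return 0
    return ranked[0][0]

# Toy label matrix: one row per candidate, one column per LF.
L_toy = [
    [1, 0, 1, 2],  # two votes for 1, one for 2
    [0, 2, 0, 0],  # a single vote for 2
    [1, 2, 0, 0],  # tie
    [0, 0, 0, 0],  # all abstain
]

preds = [majority_vote(row) for row in L_toy]
print(preds)
```

Unlike the label model, this baseline weights every LF equally and ignores the dependencies among composite LFs.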
