# Chemical-Disease Relation (CDR) Tutorial

In this example, we'll write an application to extract *mentions of* **chemical-induced-disease relationships** from PubMed abstracts, as per the [BioCreative CDR Challenge](http://www.biocreative.org/resources/corpora/biocreative-v-cdr-corpus/). This tutorial shows off some of Snorkel's more advanced features, so we assume you've followed the Intro tutorial.

Let's start by reloading from the last notebook.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
from snorkel import SnorkelSession

session = SnorkelSession()

In [2]:
from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

train = session.query(ChemicalDisease).filter(ChemicalDisease.split == 0).all()
dev = session.query(ChemicalDisease).filter(ChemicalDisease.split == 1).all()
test = session.query(ChemicalDisease).filter(ChemicalDisease.split == 2).all()

print('Training set:\t{0} candidates'.format(len(train)))
print('Dev set:\t{0} candidates'.format(len(dev)))
print('Test set:\t{0} candidates'.format(len(test)))

Training set:	21305 candidates
Dev set:	2486 candidates
Test set:	11970 candidates


# Part V: Training an LSTM extraction model

In the intro tutorial, we automatically featurized the candidates and trained a linear model over these features. Here, we'll train a more complicated model for relation extraction: an LSTM network. You can read more about LSTMs [here](https://en.wikipedia.org/wiki/Long_short-term_memory) or [here](http://colah.github.io/posts/2015-08-Understanding-LSTMs/). An LSTM is a type of recurrent neural network that automatically generates a numerical representation of each candidate from the sentence text, so there is no need to featurize explicitly as in the intro tutorial. LSTMs take longer to train, and Snorkel doesn't currently support hyperparameter searches for them. We'll train a single model here, but feel free to try out other parameter sets. Just make sure to use the development set, and not the test set, for model selection.

**Note: Again, training for more epochs than below will greatly improve performance. Try it out!**

In [3]:
from snorkel.annotations import load_marginals
train_marginals = load_marginals(session, split=0)

In [4]:
train_marginals.shape

(21305,)

In [4]:
from snorkel.annotations import load_gold_labels
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

In [6]:
from snorkel.learning.pytorch import LSTM
# from tensorboardX import SummaryWriter
# TBlogger = SummaryWriter('cdr_logs')
train_kwargs = {
    'lr':              0.001,
    'embedding_dim':   500,
    'hidden_dim':      500,
    'n_epochs':        40,
    'dropout':         0.2,
    'rebalance':       False,
    'print_freq':      40,
    'seed':            1701,
    'batch_size':      500
}

lstm = LSTM(n_threads=2)
lstm.train(train, train_marginals, X_dev=dev, Y_dev=L_gold_dev, **train_kwargs)



[LSTM] Training model
[LSTM] n_train=19611  #epochs=40  batch size=500
[LSTM] Epoch 1 (558.89s)	Average loss=0.699204	Dev F1=29.34
[LSTM] Epoch 40 (19322.77s)	Average loss=0.661443	Dev F1=47.50
[LSTM] Model saved as <LSTM>
[LSTM] Training done (19331.95s)
[LSTM] Loaded model <LSTM>


In [5]:
from snorkel.learning.pytorch import LSTM
# from tensorboardX import SummaryWriter
# TBlogger = SummaryWriter('cdr_logs')
train_kwargs = {
    'lr':              0.001,
    'embedding_dim':   500,
    'hidden_dim':      500,
    'n_epochs':        40,
    'dropout':         0.2,
    'rebalance':       False,
    'print_freq':      40,
    'seed':            1701,
    'batch_size':      500
}

lstm = LSTM(n_threads=1)
lstm.train(train, train_marginals, X_dev=dev, Y_dev=L_gold_dev, **train_kwargs)



[LSTM] Training model
[LSTM] n_train=19626  #epochs=40  batch size=500
[LSTM] Epoch 1 (476.48s)	Average loss=0.685440	Dev F1=47.01
[LSTM] Epoch 40 (18945.12s)	Average loss=0.630835	Dev F1=48.55
[LSTM] Model saved as <LSTM>
[LSTM] Training done (18954.29s)
[LSTM] Loaded model <LSTM>


### Scoring on the test set

Finally, we'll evaluate our performance on the blind test set of 500 documents. We'll load the gold labels the same way we did for the development set, and use the `score` function of our extraction model to see how we did.

In [6]:
from load_external_annotations import load_external_labels

# Load the gold labels for the test split (split=2)
load_external_labels(session, ChemicalDisease, split=2, annotator='gold')
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)
lstm.score(test, L_gold_test)
tp, fp, tn, fn = lstm.error_analysis(session, test, L_gold_test, set_unlabeled_as_neg=False)

# Tally [TP, FP, TN, FN] counts overall (key 0) and by the number of
# sentences a candidate spans (1 = same sentence, 2 = adjacent sentences)
metrics = {0: [0, 0, 0, 0]}

for i, metric in enumerate([tp, fp, tn, fn]):
    for sample in metric:
        metrics[0][i] += 1
        cand_length = int(abs(sample.chemical.sentence.position - sample.disease.sentence.position) + 1)
        # Create the bucket on first use, then increment the matching count
        metrics.setdefault(cand_length, [0, 0, 0, 0])[i] += 1
print(metrics)

AnnotatorLabels created: 0
Scores (Un-adjusted)
Pos. class accuracy: 0.748
Neg. class accuracy: 0.471
Precision            0.375
Recall               0.748
F1                   0.499
----------------------------------------
TP: 2666 | FP: 4451 | TN: 3957 | FN: 896

{0: [2666, 4451, 3957, 896], 1: [1273, 1945, 1227, 237], 2: [1393, 2506, 2730, 659]}


In [7]:
from load_external_annotations import load_external_labels
load_external_labels(session, ChemicalDisease, split=2, annotator='gold')
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)
L_gold_test

AnnotatorLabels created: 0


<11970x1 sparse matrix of type '<class 'numpy.int64'>'
	with 11970 stored elements in Compressed Sparse Row format>

In [8]:
lstm.score(test, L_gold_test)

(0.38400265516096915, 0.6496350364963503, 0.48268669169795575)

In [9]:
tp, fp, tn, fn = lstm.error_analysis(session, test, L_gold_test, set_unlabeled_as_neg=False)

Scores (Un-adjusted)
Pos. class accuracy: 0.65
Neg. class accuracy: 0.559
Precision            0.384
Recall               0.65
F1                   0.483
----------------------------------------
TP: 2314 | FP: 3712 | TN: 4696 | FN: 1248
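
The tuple returned by `lstm.score` above holds precision, recall, and F1, and it can be recomputed by hand from the confusion counts that `error_analysis` prints. The sketch below is plain Python rather than part of Snorkel, and the helper name `prf1` is made up for illustration.

```python
def prf1(tp, fp, fn):
    """Return (precision, recall, F1) from raw confusion counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

# Counts from the error analysis above: TP=2314, FP=3712, FN=1248
precision, recall, f1 = prf1(2314, 3712, 1248)
print(round(precision, 3), round(recall, 3), round(f1, 3))
# prints: 0.384 0.65 0.483
```

True negatives do not appear in any of the three formulas, which is why F1 is a common choice when negative candidates dominate.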



In [10]:
# Tally [TP, FP, TN, FN] counts overall (key 0) and by the number of
# sentences a candidate spans (1 = same sentence, 2 = adjacent sentences)
metrics = {0: [0, 0, 0, 0]}

for i, metric in enumerate([tp, fp, tn, fn]):
    for sample in metric:
        metrics[0][i] += 1
        cand_length = int(abs(sample.chemical.sentence.position - sample.disease.sentence.position) + 1)
        # Create the bucket on first use, then increment the matching count
        metrics.setdefault(cand_length, [0, 0, 0, 0])[i] += 1

In [11]:
metrics

{0: [2314, 3712, 4696, 1248],
 2: [1142, 1987, 3249, 910],
 1: [1172, 1725, 1447, 338]}
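
To compare same-sentence and cross-sentence candidates, each row of `metrics` can be turned into an F1 score. This is a minimal sketch, assuming each value is ordered `[TP, FP, TN, FN]` as in the loop above. The `f1_from_counts` helper is hypothetical, and the counts are copied from the output above.

```python
def f1_from_counts(tp, fp, tn, fn):
    """F1 = 2*TP / (2*TP + FP + FN); true negatives do not enter the formula."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

metrics = {
    0: [2314, 3712, 4696, 1248],  # all test candidates
    2: [1142, 1987, 3249, 910],   # mentions in adjacent sentences
    1: [1172, 1725, 1447, 338],   # mentions in the same sentence
}

# Sanity check: the per-span rows should add up to the overall row.
totals = [sum(col) for col in zip(metrics[1], metrics[2])]
assert totals == metrics[0]

f1_by_span = {span: f1_from_counts(*counts) for span, counts in sorted(metrics.items())}
for span, f1 in f1_by_span.items():
    print(span, round(f1, 3))
```

On these counts, same-sentence candidates reach roughly 0.53 F1, against roughly 0.44 for candidates whose mentions fall in adjacent sentences.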

In [24]:
from snorkel.learning import RandomSearch

# Hyperparameters to search over
param_ranges = {
    'lr':      [1e-2, 1e-3, 1e-4, 1e-5],
    'dropout': [0.0, 0.25, 0.5, 0.75],
}

# Hyperparameters held fixed for every model in the search
model_class_params = {}
model_hyperparams = {
    'embedding_dim':   100,
    'hidden_dim':      100,
    'n_epochs':        30,
    'rebalance':       0.25,
    'print_freq':      30,
    'seed':            1701
}

# Pass the model class (LSTM, imported above) rather than the trained `lstm`
# instance: the search constructs a new model for each parameter combination.
searcher = RandomSearch(LSTM, param_ranges, train, Y_train=train_marginals,
    model_class_params=model_class_params, model_hyperparams=model_hyperparams)
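
Conceptually, a random search draws a handful of configurations from the grid defined by `param_ranges` and merges each one with the fixed hyperparameters before training. The sketch below illustrates that idea in plain Python. It is not Snorkel's `RandomSearch` implementation, and `sample_configs`, `n`, and the seed value are assumptions made for illustration.

```python
import itertools
import random

param_ranges = {
    'lr':      [1e-2, 1e-3, 1e-4, 1e-5],
    'dropout': [0.0, 0.25, 0.5, 0.75],
}
fixed = {'embedding_dim': 100, 'hidden_dim': 100, 'n_epochs': 30}

def sample_configs(param_ranges, fixed, n, seed=0):
    """Draw up to n distinct configurations from the full grid, without replacement."""
    names = sorted(param_ranges)
    grid = list(itertools.product(*(param_ranges[k] for k in names)))
    rng = random.Random(seed)
    chosen = rng.sample(grid, min(n, len(grid)))
    # Searched values override any fixed value with the same name.
    return [dict(fixed, **dict(zip(names, values))) for values in chosen]

configs = sample_configs(param_ranges, fixed, n=5)
for cfg in configs:
    print(cfg)
```

Each sampled configuration would then be trained and compared on the development set, never the test set, in keeping with the model-selection advice earlier in this part.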

In [44]:
from snorkel.learning.pytorch import LSTM

train_kwargs = {
    'lr':              0.001,
    'embedding_dim':   150,
    'hidden_dim':      150,
    'n_epochs':        120,
    'dropout':         0.2,
    'rebalance':       0.25,
    'print_freq':      30,
    'seed':            1701,
    'batch_size':      200
}

lstm = LSTM(n_threads=2)
lstm.train(train, train_marginals, X_dev=dev, Y_dev=L_gold_dev, **train_kwargs)

from load_external_annotations import load_external_labels
load_external_labels(session, ChemicalDisease, split=2, annotator='gold')
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

lstm.score(test, L_gold_test)

tp, fp, tn, fn = lstm.error_analysis(session, test, L_gold_test, set_unlabeled_as_neg=False)

# Tally [TP, FP, TN, FN] counts overall (key 0) and by the number of
# sentences a candidate spans (1 = same sentence, 2 = adjacent sentences)
metrics = {0: [0, 0, 0, 0]}

for i, metric in enumerate([tp, fp, tn, fn]):
    for sample in metric:
        metrics[0][i] += 1
        cand_length = int(abs(sample.chemical.sentence.position - sample.disease.sentence.position) + 1)
        # Create the bucket on first use, then increment the matching count
        metrics.setdefault(cand_length, [0, 0, 0, 0])[i] += 1

print(metrics)

[LSTM] Training model
[LSTM] n_train=13777  #epochs=120  batch size=200
[LSTM] Epoch 1 (50.57s)	Average loss=0.687261	Dev F1=0.00
[LSTM] Epoch 31 (1611.57s)	Average loss=0.615859	Dev F1=0.00
[LSTM] Epoch 61 (3244.24s)	Average loss=0.613536	Dev F1=0.00
[LSTM] Epoch 91 (5044.41s)	Average loss=0.613115	Dev F1=0.00
[LSTM] Epoch 120 (6563.43s)	Average loss=0.612983	Dev F1=0.00
[LSTM] Training done (6566.93s)
AnnotatorLabels created: 0
Scores (Un-adjusted)
Pos. class accuracy: 0.639
Neg. class accuracy: 0.549
Precision            0.375
Recall               0.639
F1                   0.473
----------------------------------------
TP: 2277 | FP: 3789 | TN: 4619 | FN: 1285

{0: [2277, 3789, 4619, 1285], 2: [1122, 2069, 3167, 930], 1: [1155, 1720, 1452, 355]}
