# Virus-Host Species Relation Extraction
## Notebook 4
### UC Davis Epicenter for Disease Dynamics

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import re
import numpy as np
from snorkel import SnorkelSession
session = SnorkelSession()

# Connect to the database backend and initalize a Snorkel session
#from lib.init import *
from snorkel.models import candidate_subclass
from snorkel.annotations import load_gold_labels

from snorkel.lf_helpers import (
    get_left_tokens, get_right_tokens, get_between_tokens,
    get_text_between, get_tagged_text,
)

VirusHost = candidate_subclass('VirusHost', ['virus', 'host'])

ModuleNotFoundError: No module named 'snorkel'

In [3]:
from snorkel.annotations import load_marginals

# Pull the VirusHost candidates for each split, ordered by candidate id.
# Split ids follow the variable names they are bound to:
# 0 -> train, 1 -> dev, 2 -> test.
train_cands, dev_cands, test_cands = [
    session.query(VirusHost)
           .filter(VirusHost.split == split_id)
           .order_by(VirusHost.id)
           .all()
    for split_id in (0, 1, 2)
]

# Gold annotations for the dev and test splits, requested with zero_one=True.
# NOTE(review): only the dev labels are requested with load_as_array=True;
# the test labels are loaded without it -- confirm the difference is intended.
L_gold_dev = load_gold_labels(
    session,
    annotator_name='gold',
    split=1,
    load_as_array=True,
    zero_one=True,
)
L_gold_test = load_gold_labels(
    session,
    annotator_name='gold',
    split=2,
    zero_one=True,
)

# Marginals stored for the training split (split 0); used as the training
# targets for the LSTM in the next section.
train_marginals = load_marginals(session, split=0)

## Training a Long Short-Term Memory Neural Network

In [4]:
from snorkel.learning.pytorch.rnn import LSTM

# Keyword arguments forwarded verbatim to `lstm.train` below.
# NOTE(review): confirm that `dim` is a key the PyTorch LSTM actually
# consumes (it may expect separate embedding/hidden size keys) -- TODO verify.
train_kwargs = dict(
    lr=0.001,
    dim=100,
    n_epochs=10,
    dropout=0.25,
    print_freq=1,
    batch_size=128,
    max_sentence_length=100,
)

# Train on the training candidates against their marginals, reporting dev F1
# from the gold dev labels each epoch (print_freq=1).
# NOTE(review): in the recorded run the average loss stays at ~0.6932
# (about ln 2) for all 10 epochs while dev F1 swings between ~13 and ~67,
# which suggests the model is not really learning -- revisit lr / marginals.
lstm = LSTM(n_threads=1)
lstm.train(
    train_cands,
    train_marginals,
    X_dev=dev_cands,
    Y_dev=L_gold_dev,
    **train_kwargs
)

[LSTM] Training model
[LSTM] n_train=890  #epochs=10  batch size=128




[LSTM] Epoch 1 (9.38s)	Average loss=0.693488	Dev F1=12.82
[LSTM] Epoch 2 (19.81s)	Average loss=0.693342	Dev F1=65.08
[LSTM] Epoch 3 (29.67s)	Average loss=0.693277	Dev F1=42.86
[LSTM] Epoch 4 (39.75s)	Average loss=0.693245	Dev F1=60.94
[LSTM] Epoch 5 (52.15s)	Average loss=0.693223	Dev F1=44.44
[LSTM] Epoch 6 (62.68s)	Average loss=0.693210	Dev F1=67.13
[LSTM] Epoch 7 (73.86s)	Average loss=0.693197	Dev F1=46.00
[LSTM] Epoch 8 (86.15s)	Average loss=0.693193	Dev F1=59.26
[LSTM] Epoch 9 (96.88s)	Average loss=0.693191	Dev F1=53.33
[LSTM] Model saved as <LSTM>
[LSTM] Epoch 10 (106.28s)	Average loss=0.693181	Dev F1=50.45
[LSTM] Training done (106.64s)
[LSTM] Loaded model <LSTM>


In [9]:
# Precision, recall and F1 of the trained LSTM on the dev split.
# Scoring on the test split is deferred until test gold labels exist:
#   p, r, f1 = lstm.score(test_cands, L_gold_test)
dev_scores = lstm.score(dev_cands, L_gold_dev)
p, r, f1 = dev_scores

report_template = "Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}"
print(report_template.format(p, r, f1))

Prec: 0.640, Recall: 0.457, F1 Score: 0.533


In [12]:
# Run the error analysis on the dev split and keep the four result groups
# (true/false positives and negatives, going by the names they are bound to).
# The test-split variant stays disabled until test gold labels are available:
#   tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)
dev_error_groups = lstm.error_analysis(session, dev_cands, L_gold_dev)
tp, fp, tn, fn = dev_error_groups

Scores (Un-adjusted)
Pos. class accuracy: 0.457
Neg. class accuracy: 0.438
Precision            0.64
Recall               0.457
F1                   0.533
----------------------------------------
TP: 32 | FP: 18 | TN: 14 | FN: 38



In [13]:
# Persist the trained LSTM under a fixed name so it can be reloaded later.
model_name = "virushost.lstm"
lstm.save(model_name)

[LSTM] Model saved as <virushost.lstm>
