# Metal CDR Relation Extraction 

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import sys
sys.path.append('/dfs/scratch0/vschen/metal')
import metal
import torch
from torch.utils.data import Dataset, DataLoader

In [3]:
import numpy as np

In [4]:
print('PyTorch: ', torch.__version__)
print('MeTaL:   ', metal.__version__)
print('Python:  ', sys.version)
print('Python:  ', sys.version_info)

PyTorch:  1.0.0
MeTaL:    0.3.3
Python:   3.6.7 (default, Dec  8 2018, 17:35:14) 
[GCC 5.4.0 20160609]
Python:   sys.version_info(major=3, minor=6, micro=7, releaselevel='final', serial=0)


## Initalize CDR Dataset
To uncompress the SQLite db: ```bzip2 -d cdr.db.bz2```

In [5]:
from metal.contrib.backends.wrapper import SnorkelDataset
import os

db_conn_str   = os.path.join(os.getcwd(),"cdr.db")
candidate_def = ['ChemicalDisease', ['chemical', 'disease']]

train, dev, test = SnorkelDataset.splits(db_conn_str, 
                                         candidate_def, 
                                         max_seq_len=125)

print(f'[TRAIN] {len(train)}')
print(f'[DEV]   {len(dev)}')
print(f'[TEST]  {len(test)}')

Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/backends/cdr.db
Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/backends/cdr.db
Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/backends/cdr.db
[TRAIN] 8272
[DEV]   888
[TEST]  4620


In [6]:
from snorkel import SnorkelSession
session = SnorkelSession()

from snorkel.models import Document, Sentence
print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 1500
Sentences: 14001


In [7]:
from snorkel.models import candidate_subclass

ChemicalDisease = candidate_subclass('ChemicalDisease', ['chemical', 'disease'])

train_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 0).all()
dev_cands = session.query(ChemicalDisease).filter(ChemicalDisease.split == 1).all()

In [8]:
import bz2
from six.moves.cPickle import load

with bz2.BZ2File('data/ctd.pkl.bz2', 'rb') as ctd_f:
    ctd_unspecified, ctd_therapy, ctd_marker = load(ctd_f)

In [9]:
def cand_in_ctd_unspecified(c):
    return 1 if c.get_cids() in ctd_unspecified else 0

def cand_in_ctd_therapy(c):
    return 1 if c.get_cids() in ctd_therapy else 0

def cand_in_ctd_marker(c):
    return 1 if c.get_cids() in ctd_marker else 0

def LF_in_ctd_unspecified(c):
    return -1 * cand_in_ctd_unspecified(c)

def LF_in_ctd_therapy(c):
    return -1 * cand_in_ctd_therapy(c)

def LF_in_ctd_marker(c):
    return cand_in_ctd_marker(c)

In [10]:
import re
from snorkel.lf_helpers import (
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,
)

# List to parenthetical
def ltp(x):
    return '(' + '|'.join(x) + ')'

def LF_induce(c):
    return 1 if re.search(r'{{A}}.{0,20}induc.{0,20}{{B}}', get_tagged_text(c), flags=re.I) else 0

causal_past = ['induced', 'caused', 'due']
def LF_d_induced_by_c(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + '.{0,9}(by|to).{0,50}', 1)
def LF_d_induced_by_c_tight(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(causal_past) + ' (by|to) ', 1)

def LF_induce_name(c):
    return 1 if 'induc' in c.chemical.get_span().lower() else 0     

causal = ['cause[sd]?', 'induce[sd]?', 'associated with']
def LF_c_cause_d(c):
    return 1 if (
        re.search(r'{{A}}.{0,50} ' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
        and not re.search('{{A}}.{0,50}(not|no).{0,20}' + ltp(causal) + '.{0,50}{{B}}', get_tagged_text(c), re.I)
    ) else 0

treat = ['treat', 'effective', 'prevent', 'resistant', 'slow', 'promise', 'therap']
def LF_d_treat_c(c):
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1)
def LF_c_treat_d(c):
    return rule_regex_search_btw_AB(c, '.{0,50}' + ltp(treat) + '.{0,50}', -1)
def LF_treat_d(c):
    return rule_regex_search_before_B(c, ltp(treat) + '.{0,50}', -1)
def LF_c_treat_d_wide(c):
    return rule_regex_search_btw_AB(c, '.{0,200}' + ltp(treat) + '.{0,200}', -1)

def LF_c_d(c):
    return 1 if ('{{A}} {{B}}' in get_tagged_text(c)) else 0

def LF_c_induced_d(c):
    return 1 if (
        ('{{A}} {{B}}' in get_tagged_text(c)) and 
        (('-induc' in c[0].get_span().lower()) or ('-assoc' in c[0].get_span().lower()))
        ) else 0

def LF_improve_before_disease(c):
    return rule_regex_search_before_B(c, 'improv.*', -1)

pat_terms = ['in a patient with ', 'in patients with']
def LF_in_patient_with(c):
    return -1 if re.search(ltp(pat_terms) + '{{B}}', get_tagged_text(c), flags=re.I) else 0

uncertain = ['combin', 'possible', 'unlikely']
def LF_uncertain(c):
    return rule_regex_search_before_A(c, ltp(uncertain) + '.*', -1)

def LF_induced_other(c):
    return rule_regex_search_tagged_text(c, '{{A}}.{20,1000}-induced {{B}}', -1)

def LF_far_c_d(c):
    return rule_regex_search_btw_AB(c, '.{100,5000}', -1)

def LF_far_d_c(c):
    return rule_regex_search_btw_BA(c, '.{100,5000}', -1)

def LF_risk_d(c):
    return rule_regex_search_before_B(c, 'risk of ', 1)

def LF_develop_d_following_c(c):
    return 1 if re.search(r'develop.{0,25}{{B}}.{0,25}following.{0,25}{{A}}', get_tagged_text(c), flags=re.I) else 0

procedure, following = ['inject', 'administrat'], ['following']
def LF_d_following_c(c):
    return 1 if re.search('{{B}}.{0,50}' + ltp(following) + '.{0,20}{{A}}.{0,50}' + ltp(procedure), get_tagged_text(c), flags=re.I) else 0

def LF_measure(c):
    return -1 if re.search('measur.{0,75}{{A}}', get_tagged_text(c), flags=re.I) else 0

def LF_level(c):
    return -1 if re.search('{{A}}.{0,25} level', get_tagged_text(c), flags=re.I) else 0

def LF_neg_d(c):
    return -1 if re.search('(none|not|no) .{0,25}{{B}}', get_tagged_text(c), flags=re.I) else 0

WEAK_PHRASES = ['none', 'although', 'was carried out', 'was conducted',
                'seems', 'suggests', 'risk', 'implicated',
               'the aim', 'to (investigate|assess|study)']

WEAK_RGX = r'|'.join(WEAK_PHRASES)

def LF_weak_assertions(c):
    return -1 if re.search(WEAK_RGX, get_tagged_text(c), flags=re.I) else 0

In [11]:
def LF_ctd_marker_c_d(c):
    return LF_c_d(c) * cand_in_ctd_marker(c)

def LF_ctd_marker_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_marker(c)

def LF_ctd_therapy_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_therapy(c)

def LF_ctd_unspecified_treat(c):
    return LF_c_treat_d_wide(c) * cand_in_ctd_unspecified(c)

def LF_ctd_unspecified_induce(c):
    return (LF_c_induced_d(c) or LF_d_induced_by_c_tight(c)) * cand_in_ctd_unspecified(c)


In [12]:
def LF_closer_chem(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find chemical closer than @dist/2 in either direction
    sent = c.get_parent()
    closest_other_chem = float('inf')
    for i in range(dis_end, min(len(sent.words), dis_end + dist // 2)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return -1
    for i in range(max(0, dis_start - dist // 2), dis_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Chemical' and cid != sent.entity_cids[chem_start]:
            return -1
    return 0

def LF_closer_dis(c):
    # Get distance between chemical and disease
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    dis_start, dis_end = c.disease.get_word_start(), c.disease.get_word_end()
    if dis_start < chem_start:
        dist = chem_start - dis_end
    else:
        dist = dis_start - chem_end
    # Try to find chemical disease than @dist/8 in either direction
    sent = c.get_parent()
    for i in range(chem_end, min(len(sent.words), chem_end + dist // 8)):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return -1
    for i in range(max(0, chem_start - dist // 8), chem_start):
        et, cid = sent.entity_types[i], sent.entity_cids[i]
        if et == 'Disease' and cid != sent.entity_cids[dis_start]:
            return -1
    return 0

In [13]:
LFs = [
    LF_c_cause_d,
    LF_c_d,
    LF_c_induced_d,
    LF_c_treat_d,
    LF_c_treat_d_wide,
    LF_closer_chem,
    LF_closer_dis,
    LF_ctd_marker_c_d,
    LF_ctd_marker_induce,
    LF_ctd_therapy_treat,
    LF_ctd_unspecified_treat,
    LF_ctd_unspecified_induce,
    LF_d_following_c,
    LF_d_induced_by_c,
    LF_d_induced_by_c_tight,
    LF_d_treat_c,
    LF_develop_d_following_c,
    LF_far_c_d,
    LF_far_d_c,
    LF_improve_before_disease,
    LF_in_ctd_therapy,
    LF_in_ctd_marker,
    LF_in_patient_with,
    LF_induce,
    LF_induce_name,
    LF_induced_other,
    LF_level,
    LF_measure,
    LF_neg_d,
    LF_risk_d,
    LF_treat_d,
    LF_uncertain,
    LF_weak_assertions,
]

In [14]:
from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)

In [15]:
%time L_train = labeler.apply(split=0)
L_train

Clearing existing...


  0%|          | 6/8272 [00:00<02:20, 58.78it/s]

Running UDF...


100%|██████████| 8272/8272 [00:38<00:00, 212.26it/s]

CPU times: user 38.4 s, sys: 304 ms, total: 38.7 s
Wall time: 39.6 s





<8272x33 sparse matrix of type '<class 'numpy.int64'>'
	with 20079 stored elements in Compressed Sparse Row format>

In [16]:
from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)

In [17]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel(lf_propensity=True)
gen_model.train(
    L_train, deps=deps, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=0.0
)

Inferred cardinality: 2


In [18]:
train_marginals = gen_model.marginals(L_train)

In [19]:
from metal.contrib.backends.wrapper import SnorkelDataset
train_slice = SnorkelDataset(
    db_conn_str,
    candidate_def,
    split=0,
    L_train=L_train,
)

Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/backends/cdr.db


In [20]:
from metal.end_model import EndModel
from metal.modules import LSTMModule
use_cuda = torch.cuda.is_available()

lstm = LSTMModule(embed_size=50, 
                  hidden_size=100, 
                  vocab_size=train.word_dict.len(),
                  lstm_reduction='attention', 
                  dropout=0,
                  num_layers=1, 
                  freeze=False)

Using randomly initialized embeddings.
Embeddings shape = (9946, 50)
The embeddings are NOT FROZEN
Using lstm_reduction = 'attention'


## Train Slice Model

In [21]:
from metal.contrib.slicing.online_dp import SliceDPModel, LinearModule
from metal.modules import LSTMModule

r_dim = 200
rw = True
accs = np.array(gen_model.learned_lf_stats()['Accuracy'])
accs[np.isnan(accs)] = 0
accs = np.minimum(accs, 0.999)
slice_model = SliceDPModel(lstm, accs, r_dim, rw, seed=123, use_cuda=True)

slice_model.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
slice_model.config['train_config']['validation_metric'] = 'f1'
slice_model.config['train_config']['batch_size'] = 32
slice_model.config['train_config']['n_epochs'] = 5

  "Precision": tp / (tp + fp),
  "Accuracy": (tp + tn) / coverage,
  self.w = torch.from_numpy(np.log(accs / (1-accs))).float()


Slice Heads:
Input Network: Sequential(
  (0): Sequential(
    (0): LSTMModule(
      (embeddings): Embedding(9946, 50)
      (lstm): LSTM(50, 100, batch_first=True, bidirectional=True)
    )
    (1): ReLU()
  )
)
L_head: Linear(in_features=200, out_features=33, bias=False)
Y_head: Linear(in_features=400, out_features=2, bias=False)


In [22]:
%%time
slice_model.train_model(train_slice, dev_data=dev)

Using GPU...


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))

  A = F.softmax(self.forward_L(x)).unsqueeze(1)



Saving model at iteration 0 with best score 0.524
[E:0]	Train Loss: 0.444	Dev f1: 0.524


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 1 with best score 0.539
[E:1]	Train Loss: 0.433	Dev f1: 0.539


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 2 with best score 0.544
[E:2]	Train Loss: 0.426	Dev f1: 0.544


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 3 with best score 0.577
[E:3]	Train Loss: 0.422	Dev f1: 0.577


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:4]	Train Loss: 0.419	Dev f1: 0.560
Restoring best model from iteration 3 with score 0.577
Finished Training
F1: 0.577
        y=1    y=2   
 l=1    271    373   
 l=2    25     219   
CPU times: user 17min 38s, sys: 21.3 s, total: 17min 59s
Wall time: 17min 57s


In [23]:
score = slice_model.score(test, metric=['precision', 'recall', 'f1'])

Precision: 0.400
Recall: 0.908
F1: 0.555
        y=1    y=2   
 l=1   1367   2054   
 l=2    138   1061   


## Train End Model (Random Initalized Embeddings)

In [24]:
end_model = EndModel([200, 2], input_module=lstm, seed=123, use_cuda=use_cuda)

end_model.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
end_model.config['train_config']['validation_metric'] = 'f1'
end_model.config['train_config']['batch_size'] = 32
end_model.config['train_config']['n_epochs'] = 5


Network architecture:
Sequential(
  (0): Sequential(
    (0): LSTMModule(
      (embeddings): Embedding(9946, 50)
      (lstm): LSTM(50, 100, batch_first=True, bidirectional=True)
    )
    (1): ReLU()
  )
  (1): Linear(in_features=200, out_features=2, bias=True)
)



In [25]:
%%time
end_model.train_model(train, dev_data=dev)

Using GPU...


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 0 with best score 0.604
[E:0]	Train Loss: 0.399	Dev f1: 0.604


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:1]	Train Loss: 0.228	Dev f1: 0.584


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 2 with best score 0.619
[E:2]	Train Loss: 0.144	Dev f1: 0.619


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:3]	Train Loss: 0.104	Dev f1: 0.480


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:4]	Train Loss: 0.086	Dev f1: 0.616
Restoring best model from iteration 2 with score 0.619
Finished Training
F1: 0.619
        y=1    y=2   
 l=1    186    119   
 l=2    110    473   
CPU times: user 6min 39s, sys: 8.69 s, total: 6min 48s
Wall time: 6min 46s


In [26]:
score = end_model.score(test, metric=['precision', 'recall', 'f1'])

Precision: 0.544
Recall: 0.619
F1: 0.579
        y=1    y=2   
 l=1    931    780   
 l=2    574   2335   


## Train End Model (Pretrained Embeddings)

Download [GloVe embeddings](http://nlp.stanford.edu/data/glove.6B.zip):
`wget http://nlp.stanford.edu/data/glove.6B.zip \
&& mkdir -p glove.6B \
&& unzip glove.6B.zip -d glove.6B \
&& rm glove.6B.zip`

In [27]:
import string
import numpy as np

import torch.nn.init as init

class EmbeddingLoader(object):
    """
    Simple text file embedding loader. Words with GloVe and FastText.
    """
    def __init__(self, fpath, fmt='text', dim=None, normalize=True):
        assert os.path.exists(fpath)
        self.fpath = fpath
        self.dim = dim
        self.fmt = fmt
        # infer dimension
        if not self.dim:
            header = open(self.fpath, "rU").readline().strip().split(' ')
            self.dim = len(header) - 1 if len(header) != 2 else int(header[-1])

        self.vocab, self.vectors = zip(*[(w,vec) for w,vec in self._read()])
        self.vocab = {w:i for i,w in enumerate(self.vocab)}
        self.vectors = np.vstack(self.vectors)
        if normalize:
            self.vectors = (self.vectors.T / np.linalg.norm(self.vectors, axis=1, ord=2)).T

    def _read(self):
        start = 0 if self.fmt == "text" else 1
        for i, line in enumerate(open(self.fpath, "rU")):
            if i < start:
                continue
            line = line.rstrip().split(' ')
            vec = np.array([float(x) for x in line[1:]])
            if len(vec) != self.dim:
                errors += [line[0]]
                continue
            yield (line[0], vec)
            

def load_embeddings(vocab, embeddings):
    """
    Load pretrained embeddings
    """
    def get_word_match(w, word_dict):
        if w in word_dict:
            return word_dict[w]
        elif w.lower() in word_dict:
            return word_dict[w.lower()]
        elif w.strip(string.punctuation) in word_dict:
            return word_dict[w.strip(string.punctuation)]
        elif w.strip(string.punctuation).lower() in word_dict:
            return word_dict[w.strip(string.punctuation).lower()]
        else:
            return -1

    num_words = vocab.len()
    emb_dim   = embeddings.vectors.shape[1]
    vecs      = init.xavier_normal_(torch.empty(num_words, emb_dim))
    vecs[0]   = torch.zeros(emb_dim)

    n = 0
    for w in vocab.d:
        idx = get_word_match(w, embeddings.vocab)
        if idx == -1:
            continue
        i = vocab.lookup(w)
        vecs[i] = torch.FloatTensor(embeddings.vectors[idx])
        n += 1

    print("Loaded {:2.1f}% ({}/{}) pretrained embeddings".format(float(n) / vocab.len() * 100.0, n, vocab.len() ))
    return vecs         

In [28]:
emb_path  = "glove.6B/glove.6B.50d.txt"
embs  = EmbeddingLoader(emb_path, fmt='text')



In [29]:
from metal.contrib.backends.wrapper import SnorkelDataset

db_conn_str   = "cdr.db"
candidate_def = ['ChemicalDisease', ['chemical', 'disease']]

train, dev, test = SnorkelDataset.splits(db_conn_str, 
                                         candidate_def, 
                                         pretrained_word_dict=embs.vocab, 
                                         max_seq_len=125)

print(f'[TRAIN] {len(train)}')
print(f'[DEV]   {len(dev)}')
print(f'[TEST]  {len(test)}')

Connected to sqlite:///cdr.db
Connected to sqlite:///cdr.db
Connected to sqlite:///cdr.db
[TRAIN] 8272
[DEV]   888
[TEST]  4620


### Initalize pretrained embedding matrix

In [30]:
wembs = load_embeddings(train.word_dict, embs)

Loaded 79.9% (9116/11406) pretrained embeddings


In [31]:
from metal.end_model import EndModel
from metal.modules import LSTMModule
use_cuda = torch.cuda.is_available()

lstm = LSTMModule(embed_size=50, 
                  hidden_size=100, 
                  embeddings=wembs,
                  lstm_reduction='attention', 
                  dropout=0, 
                  num_layers=1, 
                  freeze=False)

end_model = EndModel([200, 2], input_module=lstm, seed=123, use_cuda=use_cuda)

end_model.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
end_model.config['train_config']['validation_metric'] = 'f1'
end_model.config['train_config']['batch_size'] = 32
end_model.config['train_config']['n_epochs'] = 5

Using pretrained embeddings.
Embeddings shape = (11406, 50)
The embeddings are NOT FROZEN
Using lstm_reduction = 'attention'

Network architecture:
Sequential(
  (0): Sequential(
    (0): LSTMModule(
      (embeddings): Embedding(11406, 50)
      (lstm): LSTM(50, 100, batch_first=True, bidirectional=True)
    )
    (1): ReLU()
  )
  (1): Linear(in_features=200, out_features=2, bias=True)
)



In [32]:
%%time
end_model.train_model(train, dev_data=dev)

Using GPU...


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 0 with best score 0.599
[E:0]	Train Loss: 0.533	Dev f1: 0.599


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:1]	Train Loss: 0.293	Dev f1: 0.568


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:2]	Train Loss: 0.174	Dev f1: 0.560


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:3]	Train Loss: 0.119	Dev f1: 0.585


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:4]	Train Loss: 0.078	Dev f1: 0.575
Restoring best model from iteration 0 with score 0.599
Finished Training
F1: 0.599
        y=1    y=2   
 l=1    191    151   
 l=2    105    441   
CPU times: user 6min 19s, sys: 9.48 s, total: 6min 28s
Wall time: 6min 27s


In [33]:
score = end_model.score(test, metric=['precision', 'recall', 'f1'])

Precision: 0.526
Recall: 0.658
F1: 0.585
        y=1    y=2   
 l=1    990    892   
 l=2    515   2223   


### Slicing

In [34]:
from metal.contrib.slicing.online_dp import SliceDPModel, LinearModule
from metal.modules import LSTMModule

r_dim = 200
rw = True
accs = np.array(gen_model.learned_lf_stats()['Accuracy'])
accs[np.isnan(accs)] = 0
accs = np.minimum(accs, 0.999)
slice_model = SliceDPModel(lstm, accs, r_dim, rw, seed=123, use_cuda=True)

slice_model.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
slice_model.config['train_config']['validation_metric'] = 'f1'
slice_model.config['train_config']['batch_size'] = 32
slice_model.config['train_config']['n_epochs'] = 5

Slice Heads:
Input Network: Sequential(
  (0): Sequential(
    (0): LSTMModule(
      (embeddings): Embedding(11406, 50)
      (lstm): LSTM(50, 100, batch_first=True, bidirectional=True)
    )
    (1): ReLU()
  )
)
L_head: Linear(in_features=200, out_features=33, bias=False)
Y_head: Linear(in_features=400, out_features=2, bias=False)


In [35]:
%%time
slice_model.train_model(train_slice, dev_data=dev)

Using GPU...


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 0 with best score 0.526
[E:0]	Train Loss: 0.439	Dev f1: 0.526


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 1 with best score 0.543
[E:1]	Train Loss: 0.428	Dev f1: 0.543


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:2]	Train Loss: 0.422	Dev f1: 0.532


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


Saving model at iteration 3 with best score 0.550
[E:3]	Train Loss: 0.418	Dev f1: 0.550


HBox(children=(IntProgress(value=0, max=259), HTML(value='')))


[E:4]	Train Loss: 0.416	Dev f1: 0.541
Restoring best model from iteration 3 with score 0.550
Finished Training
F1: 0.550
        y=1    y=2   
 l=1    261    392   
 l=2    35     200   
CPU times: user 17min 38s, sys: 22.3 s, total: 18min
Wall time: 17min 58s


In [36]:
score = slice_model.score(test, metric=['precision', 'recall', 'f1'])

Precision: 0.396
Recall: 0.896
F1: 0.549
        y=1    y=2   
 l=1   1349   2061   
 l=2    156   1054   
