# Slicing CDR Relation Extraction 

In [1]:
# Load the autoreload extension and reload all modules before each cell runs,
# so edits to local source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os
import sys
# NOTE(review): hard-coded absolute path to one user's MeTaL checkout; on any
# other machine this presumably has to be edited (or `metal` installed) before
# the `import metal` below can succeed.
sys.path.append('/dfs/scratch0/vschen/metal')

import metal
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
# Print arrays with 4 decimal places and without scientific notation.
np.set_printoptions(precision=4, suppress=True)

In [3]:
# Record the library and interpreter versions this notebook was run with.
for version_label, version_value in (
    ('PyTorch: ', torch.__version__),
    ('MeTaL:   ', metal.__version__),
    ('Python:  ', sys.version),
    ('Python:  ', sys.version_info),
):
    print(version_label, version_value)

PyTorch:  0.4.1
MeTaL:    0.3.3
Python:   3.6.7 (default, Dec  8 2018, 17:35:14) 
[GCC 5.4.0 20160609]
Python:   sys.version_info(major=3, minor=6, micro=7, releaselevel='final', serial=0)


## Initialize CDR Dataset
To decompress the SQLite database, run: ```bzip2 -d cdr.db.bz2```

In [4]:
from metal.contrib.backends.wrapper import SnorkelDataset
import os

# The SQLite database file is looked up in the current working directory.
db_conn_str = os.path.join(os.getcwd(), "cdr.db")
# Candidate definition passed to the dataset wrapper (presumably the relation
# name followed by its argument names -- confirm against SnorkelDataset).
candidate_def = ['ChemicalDisease', ['chemical', 'disease']]

# Build the train / dev / test datasets in a single call.
train, dev, test = SnorkelDataset.splits(
    db_conn_str, candidate_def, max_seq_len=125
)

# Report the number of examples in each split.
for split_size_line in (
    f'[TRAIN] {len(train)}',
    f'[DEV]   {len(dev)}',
    f'[TEST]  {len(test)}',
):
    print(split_size_line)

Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/slicing/CDR/cdr.db
Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/slicing/CDR/cdr.db
Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/slicing/CDR/cdr.db
[TRAIN] 8272
[DEV]   888
[TEST]  4620


## Get Pretrained Embeddings

Download [GloVe embeddings](http://nlp.stanford.edu/data/glove.6B.zip):
`wget http://nlp.stanford.edu/data/glove.6B.zip \
&& mkdir -p glove.6B \
&& unzip glove.6B.zip -d glove.6B \
&& rm glove.6B.zip`

In [None]:
# `embeddings` is not imported anywhere above; presumably a helper module that
# sits next to this notebook -- confirm.
from embeddings import EmbeddingLoader, load_embeddings
# Path to the 50-dimensional GloVe text file.
# NOTE(review): the download instructions above unzip into `./glove.6B`, while
# this path points one directory up (`../glove.6B`) -- confirm which is right.
emb_path  = "../glove.6B/glove.6B.50d.txt"
# Read the embeddings from the plain-text file at `emb_path`.
embs  = EmbeddingLoader(emb_path, fmt='text')

## Generate `L_*` to target slices

In [None]:
from labeling_functions import LFs

# Show the name of every labeling function, in list order.
print([labeling_fn.__name__ for labeling_fn in LFs])

['LF_c_cause_d', 'LF_c_d', 'LF_c_induced_d', 'LF_c_treat_d', 'LF_c_treat_d_wide', 'LF_closer_chem', 'LF_closer_dis', 'LF_ctd_marker_c_d', 'LF_ctd_marker_induce', 'LF_ctd_therapy_treat', 'LF_ctd_unspecified_treat', 'LF_ctd_unspecified_induce', 'LF_d_following_c', 'LF_d_induced_by_c', 'LF_d_induced_by_c_tight', 'LF_d_treat_c', 'LF_develop_d_following_c', 'LF_far_c_d', 'LF_far_d_c', 'LF_improve_before_disease', 'LF_in_ctd_therapy', 'LF_in_ctd_marker', 'LF_in_patient_with', 'LF_induce', 'LF_induce_name', 'LF_induced_other', 'LF_level', 'LF_measure', 'LF_neg_d', 'LF_risk_d', 'LF_treat_d', 'LF_uncertain', 'LF_weak_assertions']


In [None]:
%%time 
from snorkel import SnorkelSession
session = SnorkelSession()

from snorkel.annotations import LabelAnnotator
labeler = LabelAnnotator(lfs=LFs)
L_train = labeler.apply(split=0)
L_dev = labeler.apply(split=1) # used for debugging
L_test = labeler.apply(split=2) # used for evaluation

from snorkel.learning.structure import DependencySelector
ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
from snorkel.learning import GenerativeModel

# need to extract `accs` from gen_model
gen_model = GenerativeModel(lf_propensity=True)
gen_model.train(
    L_train, deps=deps, decay=0.95, step_size=0.1/L_train.shape[0], reg_param=0.0
)

accs = np.array(gen_model.learned_lf_stats()['Accuracy'])
accs[np.isnan(accs)] = 0
accs = np.minimum(accs, 0.999)

gen_marginals = gen_model.marginals(L_train)

Clearing existing...


  0%|          | 5/8272 [00:00<02:51, 48.14it/s]

Running UDF...


 32%|███▏      | 2665/8272 [00:21<01:01, 90.96it/s] 

In [None]:
# Work on a copy so the original Snorkel label matrix keeps its -1 votes.
L = L_train.copy()
L[L == -1] = 2  # convert to multiclass
# Dev labels: the second element of every example in the dev dataset.
Y_dev = np.array([example[1] for example in dev])

In [None]:
from metal.label_model import LabelModel

# The label model is trained on `L`, where the -1 votes were re-encoded as
# class 2. `L_dev` comes straight from the Snorkel labeler and still holds -1
# votes, so apply the same re-encoding to a copy before scoring; otherwise the
# dev score is computed on a label matrix in a different encoding.
L_dev_multiclass = L_dev.copy()
L_dev_multiclass[L_dev_multiclass == -1] = 2

# Two-class label model with a fixed seed.
label_model = LabelModel(k=2, seed=123)
label_model.train_model(L, Y_dev=Y_dev)
# Last expression: the dev-set score is displayed as the cell output.
label_model.score((L_dev_multiclass, Y_dev))

### Weak Labels in Dataset

In [None]:
# Probabilistic training labels predicted by the MeTaL label model for the
# re-encoded training label matrix `L`.
metal_marginals = label_model.predict_proba(L)
# Display the array as the cell output.
metal_marginals

In [None]:
# Two-column marginals: column 0 holds the generative-model marginals and
# column 1 holds their complement.
gen_marginals_complement = 1 - gen_marginals
snorkel_marginals = np.vstack((gen_marginals, gen_marginals_complement)).T
snorkel_marginals

In [None]:
from metal.contrib.slicing.sqlite_wrapper import (
    SnorkelDataset as SnorkelSliceDataset,
)

# Training split (split=0) paired with the MeTaL label-model marginals.
train_metal = SnorkelSliceDataset(
    db_conn_str, candidate_def, split=0, train_marginals=metal_marginals
)

# The same split paired with the Snorkel generative-model marginals.
train_snorkel = SnorkelSliceDataset(
    db_conn_str, candidate_def, split=0, train_marginals=snorkel_marginals
)

### Custom Slicing Dataset

In [None]:
# Training split with the dense labeling-function votes attached.
train_slice = SnorkelSliceDataset(
    db_conn_str, candidate_def, split=0, L_train=L_train.todense()
)

# LF votes plus the MeTaL label-model marginals.
train_slice_metal = SnorkelSliceDataset(
    db_conn_str,
    candidate_def,
    split=0,
    L_train=L_train.todense(),
    train_marginals=metal_marginals,
)

# LF votes plus the Snorkel generative-model marginals.
train_slice_snorkel = SnorkelSliceDataset(
    db_conn_str,
    candidate_def,
    split=0,
    L_train=L_train.todense(),
    train_marginals=snorkel_marginals,
)

## (a) `Oracle`: EndModel Trained on Full GT

In [None]:
from metal.end_model import EndModel
from metal.modules import LSTMModule
use_cuda = torch.cuda.is_available()

wembs = load_embeddings(train.word_dict, embs)
lstm = LSTMModule(embed_size=50, 
                  hidden_size=100, 
                  embeddings=wembs,
                  lstm_reduction='attention', 
                  dropout=0, 
                  num_layers=1, 
                  freeze=False)

oracle = EndModel([200, 2], input_module=lstm, seed=123, use_cuda=use_cuda)
oracle.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
oracle.config['train_config']['validation_metric'] = 'f1'
oracle.config['train_config']['batch_size'] = 32
oracle.config['train_config']['n_epochs'] = 10

%time oracle.train_model(train, dev_data=dev)
oracle.score(test, metric=['precision', 'recall', 'f1'])

## (b) `BaseWeak`: EndModel trained on weak labels

In [None]:
from metal.end_model import EndModel
from metal.modules import LSTMModule
use_cuda = torch.cuda.is_available()

wembs = load_embeddings(train.word_dict, embs)
lstm = LSTMModule(embed_size=50, 
                  hidden_size=100,
                  embeddings=wembs,
                  lstm_reduction='attention', 
                  dropout=0,
                  num_layers=1,
                  freeze=False)

base_weak = EndModel([200, 2], input_module=lstm, seed=123, use_cuda=use_cuda)

base_weak.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
base_weak.config['train_config']['validation_metric'] = 'f1'
base_weak.config['train_config']['batch_size'] = 32
base_weak.config['train_config']['n_epochs'] = 10

%time base_weak.train_model(train_snorkel, dev_data=dev)
base_weak_scores = base_weak.score(test, metric=['precision', 'recall', 'f1'])

## (c) `SliceUW`: Unweighted SliceModel with `rw=False`

In [None]:
from metal.contrib.slicing.online_dp import SliceDPModel, LinearModule

In [None]:
wembs = load_embeddings(train.word_dict, embs)
lstm = LSTMModule(embed_size=50, 
                  hidden_size=100, 
                  embeddings=wembs,
                  lstm_reduction='attention', 
                  dropout=0, 
                  num_layers=1, 
                  freeze=False)

r_dim = 200
rw = False
slice_uw = SliceDPModel(lstm, accs, r_dim, rw, seed=123, use_cuda=True)

slice_uw.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
slice_uw.config['train_config']['validation_metric'] = 'f1'
slice_uw.config['train_config']['batch_size'] = 32
slice_uw.config['train_config']['n_epochs'] = 10

%time slice_uw.train_model(train_slice, dev_data=dev)
slice_uw_scores = slice_uw.score(test, metric=['precision', 'recall', 'f1'])

## (d) `SliceOurs`: Attention SliceModel with `rw=True`

In [None]:
wembs = load_embeddings(train.word_dict, embs)
lstm = LSTMModule(embed_size=50, 
                  hidden_size=100, 
                  embeddings=wembs,
                  lstm_reduction='attention', 
                  dropout=0, 
                  num_layers=1, 
                  freeze=False)


r_dim = 200
rw = True
slice_ours = SliceDPModel(lstm, accs, r_dim, rw, seed=123, use_cuda=True)

slice_ours.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
slice_ours.config['train_config']['validation_metric'] = 'f1'
slice_ours.config['train_config']['batch_size'] = 32
slice_ours.config['train_config']['n_epochs'] = 10

%time slice_ours.train_model(train_slice, dev_data=dev)
slice_ours_scores = slice_ours.score(test, metric=['precision', 'recall', 'f1'])

## (e) `SliceOursWeak`: Slice Model with $\tilde{Y}$ priors

In [None]:
from metal.contrib.slicing.online_dp import SliceDPModel, LinearModule
from metal.modules import LSTMModule

wembs = load_embeddings(train.word_dict, embs)
lstm = LSTMModule(embed_size=50, 
                  hidden_size=100, 
                  embeddings=wembs,
                  lstm_reduction='attention', 
                  dropout=0, 
                  num_layers=1, 
                  freeze=False)

r_dim = 200
rw = True
slice_ours_weak = SliceDPModel(lstm, accs, r_dim, rw, seed=123, use_cuda=True)

slice_ours_weak.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
slice_ours_weak.config['train_config']['validation_metric'] = 'f1'
slice_ours_weak.config['train_config']['batch_size'] = 32
slice_ours_weak.config['train_config']['n_epochs'] = 10

%time slice_ours_weak.train_model(train_slice_snorkel, dev_data=dev)
slice_ours_weak_scores = slice_ours_weak.score(test, metric=['precision', 'recall', 'f1'])

## (f) `SliceUWWeak`: Unweighted Slice model with $\tilde{Y}$ priors

In [None]:
from metal.contrib.slicing.online_dp import SliceDPModel, LinearModule
from metal.modules import LSTMModule

wembs = load_embeddings(train.word_dict, embs)
lstm = LSTMModule(embed_size=50, 
                  hidden_size=100, 
                  embeddings=wembs,
                  lstm_reduction='attention', 
                  dropout=0, 
                  num_layers=1, 
                  freeze=False)


r_dim = 200
rw = False
slice_uw_weak = SliceDPModel(lstm, accs, r_dim, rw, seed=123, use_cuda=True)

slice_uw_weak.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01
slice_uw_weak.config['train_config']['validation_metric'] = 'f1'
slice_uw_weak.config['train_config']['batch_size'] = 32
slice_uw_weak.config['train_config']['n_epochs'] = 10

%time slice_uw_weak.train_model(train_slice_snorkel, dev_data=dev)
slice_uw_weak_scores = slice_uw_weak.score(test, metric=['precision', 'recall', 'f1'])


## Slice-specific scores

In [None]:
# TODO: don't call private fns
# Test-set predictions for every model. Each call returns a (predictions,
# labels) pair; only the labels from the final call are kept in `Y`.
Yp_oracle = oracle._get_predictions(test)[0]
Yp_base_weak = base_weak._get_predictions(test)[0]
Yp_slice_uw = slice_uw._get_predictions(test)[0]
Yp_slice_ours = slice_ours._get_predictions(test)[0]
Yp_slice_ours_weak = slice_ours_weak._get_predictions(test)[0]
Yp_slice_uw_weak, Y = slice_uw_weak._get_predictions(test)

#### `slice_ours` (re-weighting, accuracy priors) vs. `base_weak` (end_model trained on weak labels)

In [None]:
from metal.contrib.slicing.experiment_utils import compare_LF_slices
# Compare the `slice_ours` and `base_weak` test predictions against `Y` using
# the test label matrix and the LF list (presumably one slice per labeling
# function, reporting accuracy differences above 0.05 -- TODO confirm).
compare_LF_slices(Yp_slice_ours, Yp_base_weak, 
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0.05)

#### `slice_ours_weak` (slice model with weak priors + reweighting) vs. `base_weak` (end_model trained on weak labels)

In [None]:
# Compare the `slice_ours_weak` and `base_weak` test predictions against `Y`
# with the accuracy metric and a 0.05 delta threshold.
compare_LF_slices(Yp_slice_ours_weak, Yp_base_weak,
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0.05)

#### `slice_ours` vs. `oracle` (trained on full GT)

In [None]:
# Compare the `slice_ours` and `oracle` test predictions against `Y` with the
# accuracy metric and a 0.05 delta threshold.
compare_LF_slices(Yp_slice_ours, Yp_oracle,
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0.05)

#### `slice_ours` vs. `slice_uw` (unweighted slice model)

In [None]:
# Compare the `slice_ours` and `slice_uw` test predictions against `Y` with the
# accuracy metric and a 0.05 delta threshold.
compare_LF_slices(Yp_slice_ours, Yp_slice_uw,
                  Y, L_test, LFs, metric='accuracy', delta_threshold=0.05)