In [None]:
# Reload edited project modules (e.g. the metal checkout) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Make the MeTaL checkout importable. Override the location with the METAL_HOME
# environment variable; the default is the original author's checkout.
# The membership check keeps re-runs of this cell from piling duplicates onto sys.path.
METAL_HOME = os.environ.get('METAL_HOME', '/dfs/scratch0/vschen/metal')
if METAL_HOME not in sys.path:
    sys.path.append(METAL_HOME)

import metal
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
from metal.contrib.slicing.online_dp import SliceDPModel, LinearModule
from metal.contrib.slicing.sqlite_wrapper import SnorkelDataset

In [None]:
# Record the library / interpreter versions this notebook ran with.
for _label, _value in (
    ('PyTorch: ', torch.__version__),
    ('MeTaL:   ', metal.__version__),
    ('Python:  ', sys.version),
    ('Python:  ', sys.version_info),
):
    print(_label, _value)
del _label, _value

In [None]:

# Candidate DB is expected in the current working directory.
db_conn_str = os.path.join(os.getcwd(), "spouses.db")
candidate_def = ['Spouse', ['person1', 'person2']]

# Train / dev / test splits built with max_seq_len=125.
train, dev, test = SnorkelDataset.splits(db_conn_str, candidate_def, max_seq_len=125)

for _tag, _split in (('[TRAIN]', train), ('[DEV]  ', dev), ('[TEST] ', test)):
    print(f'{_tag} {len(_split)}')
del _tag, _split

In [None]:
import numpy as np

# Load precomputed Snorkel artifacts (label matrices, marginals, LF accuracies).
# Using the NpzFile as a context manager closes its file handle once the arrays
# have been read out; indexing an NpzFile returns an in-memory array, so the
# names below remain valid after the `with` block exits.
with np.load('snorkel_data_spouse.npz') as snorkel_data:
    L_train = snorkel_data['L_train']
    L_dev = snorkel_data['L_dev']
    L_test = snorkel_data['L_test']
    train_marginals = snorkel_data['train_marginals']
    dev_marginals = snorkel_data['dev_marginals']
    accs = snorkel_data['accs']
m = len(accs)  # presumably the number of labeling functions — confirm

L_train.shape, L_dev.shape, L_test.shape, len(train_marginals), len(dev_marginals)

In [None]:
from metal.contrib.slicing.CDR.embeddings import EmbeddingLoader, load_embeddings
# Pretrained 50-dimensional GloVe vectors, read from a plain-text file.
emb_path = "../glove.6B/glove.6B.50d.txt"
embs = EmbeddingLoader(
    emb_path,
    fmt='text',
)

In [None]:
from metal.modules import LSTMModule
from metal.tuners import RandomSearchTuner
def _build_lstm(dropout):
    """Build the single-layer, attention-reduced LSTM encoder used by both model factories.

    Embeddings come from the pretrained `embs` for the vocabulary of the global
    `train` split and are left trainable (freeze=False).
    """
    wembs = load_embeddings(train.word_dict, embs)
    return LSTMModule(embed_size=50,
                      hidden_size=50,
                      embeddings=wembs,
                      lstm_reduction='attention',
                      dropout=dropout,
                      num_layers=1,
                      freeze=False)


def _input_layer_config():
    """Return a fresh SliceDPModel input-layer config with ReLU, batchnorm and dropout disabled."""
    return {
        "input_relu": False,
        "input_batchnorm": False,
        "input_dropout": 0.0,
    }


def init_model(use_end_model=False, r=None, reweight=None):
    """Build a fresh, untrained model for the spouse task.

    Args:
        use_end_model: if True, return a MeTaL ``EndModel`` with a [100, 2] head on
            top of the LSTM encoder; otherwise return a ``SliceDPModel``.
        r: forwarded to ``SliceDPModel`` (unused for the EndModel).
        reweight: forwarded to ``SliceDPModel`` (unused for the EndModel).

    Returns:
        The model, with lr=0.01, validation metric 'f1', batch size 64 and
        10 epochs written into ``model.config['train_config']``.
    """
    # Local import so this function does not depend on a later cell having
    # already imported EndModel into the notebook namespace.
    from metal.end_model import EndModel

    # Decide GPU use here rather than reading a global defined further down the
    # notebook. Previously the SliceDPModel branch hard-coded use_cuda=True,
    # which fails on CPU-only machines.
    use_cuda = torch.cuda.is_available()

    lstm = _build_lstm(dropout=0.25)
    if use_end_model:
        # 100 presumably matches the encoder output width (bidirectional, 2 x 50).
        model = EndModel([100, 2], input_module=lstm, seed=123, use_cuda=use_cuda)
    else:
        model = SliceDPModel(lstm, accs, r, reweight, seed=123, use_cuda=use_cuda,
                             input_layer_config=_input_layer_config())

    train_config = model.config['train_config']
    train_config['optimizer_config']['optimizer_common']['lr'] = 0.01
    train_config['validation_metric'] = 'f1'
    train_config['batch_size'] = 64
    train_config['n_epochs'] = 10
    return model


def search_slice_weights(train_loader, dev_loader, r, reweight, max_search=1, search_space=None):
    """Random-search SliceDPModel hyperparameters and return the tuner's best model.

    Args:
        train_loader: training data, passed to each trial via ``train_args``.
        dev_loader: data the tuner scores trials on (validation metric 'f1').
        r, reweight: forwarded to the ``SliceDPModel`` constructor.
        max_search: number of configurations the tuner tries.
        search_space: tuner search space; defaults to a discrete grid over
            ``slice_weight`` in {0, 0.25, 0.5, 0.75, 1.0}.

    Returns:
        Whatever ``RandomSearchTuner.search`` returns (the best trained model).
    """
    # NOTE(review): this single `lstm` is passed through init_args, so unless
    # RandomSearchTuner copies init_args, every trial reuses (and keeps training)
    # the same encoder weights — confirm before trusting the search results.
    lstm = _build_lstm(dropout=0.0)

    searcher = RandomSearchTuner(SliceDPModel, validation_metric='f1', log_dir="./run_logs/alpha-loss")

    if search_space is None:
        search_space = {
            "slice_weight": [0, 0.25, 0.5, 0.75, 1.0]
        }

    return searcher.search(
        search_space,
        dev_loader,
        train_args=[train_loader],
        init_args=[lstm, accs, r, reweight],
        init_kwargs={
            "use_cuda": torch.cuda.is_available(),
            "input_layer_config": _input_layer_config(),
        },
        train_kwargs={
            "lr": 0.01,
            "batch_size": 32,
            "n_epochs": 10
        },
        max_search=max_search
    )

In [None]:
# Two-column probabilistic labels: column 0 is train_marginals, column 1 its complement.
# NOTE(review): assumes train_marginals is 1-D and holds the probability of the
# class MeTaL treats as label 1 — confirm against how the .npz was produced.
snorkel_marginals = np.vstack((train_marginals, 1 - train_marginals)).T
# A non-1-D train_marginals would silently yield the wrong shape here.
assert snorkel_marginals.shape == (len(train_marginals), 2), snorkel_marginals.shape
snorkel_marginals[:5]

In [None]:
from metal.contrib.slicing.sqlite_wrapper \
    import SnorkelDataset as SnorkelSliceDataset

def _split0_dataset(**extra):
    """Build a SnorkelSliceDataset over split=0 of the spouse DB with the given extra kwargs."""
    return SnorkelSliceDataset(db_conn_str, candidate_def, split=0, **extra)

# Weak labels only — consumed by the EndModel baseline.
train_snorkel = _split0_dataset(train_marginals=snorkel_marginals)
# Label matrix only.
train_slice = _split0_dataset(L_train=L_train)
# Label matrix plus weak labels — consumed by the slice models.
train_slice_snorkel = _split0_dataset(L_train=L_train, train_marginals=snorkel_marginals)

del _split0_dataset

## (a) `BaseWeak`: EndModel trained on weak labels

In [None]:
from metal.end_model import EndModel
# Whether a GPU is available in this kernel.
use_cuda = torch.cuda.is_available()

# Baseline: EndModel from init_model, trained on train_snorkel (weak labels only),
# validated on dev, then scored on test.
base_weak = init_model(use_end_model=True)
%time base_weak.train_model(train_snorkel, dev_data=dev)
base_weak_scores = base_weak.score(test, metric=['precision', 'recall', 'f1'])

## (d) `SliceOursWeak`: Slice Model with $\tilde{Y}$ priors

In [None]:
# slice_ours_weak = init_model(use_end_model=False, r=100, reweight=True)
# %time slice_ours_weak.train_model(train_slice_snorkel, dev_data=dev)

search_space = {
    "slice_weight": {"range": [0,1.0] ,"scale": "linear"}
}
%time slice_ours_weak = search_slice_weights(train_slice_snorkel, dev, \
                                             r=100, reweight=True, max_search=10)
slice_ours_weak_scores = slice_ours_weak.score(test, metric=['precision', 'recall', 'f1'])

Loaded 91.0% (29001/31870) pretrained embeddings
Using pretrained embeddings.
Embeddings shape = (31870, 50)
The embeddings are NOT FROZEN
Using lstm_reduction = 'attention'
Slice Heads:
Reweighting: True
Slice Weight: 9.83236397554566
Input Network: Sequential(
  (0): LSTMModule(
    (embeddings): Embedding(31870, 50)
    (lstm): LSTM(50, 50, batch_first=True, bidirectional=True)
  )
)
L_head: Linear(in_features=100, out_features=10, bias=False)
Y_head: Linear(in_features=200, out_features=2, bias=False)
[0] Testing {'slice_weight': 9.83236397554566}
Could not find kwarg "slice_weight" in destination dict.
Using GPU...


HBox(children=(IntProgress(value=0, max=696), HTML(value='')))

Saving model at iteration 0 with best score 0.381
[E:0]	Train Loss: 1.009	Dev f1: 0.381


HBox(children=(IntProgress(value=0, max=696), HTML(value='')))

Process Process-15:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/dfs/scratch0/vschen/snorkel-pytorch/venv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/usr/local/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/usr/local/lib/python3.6/selectors.py", line 376, in se

RuntimeError: DataLoader worker (pid 109719) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

## (e) `SliceUWWeak`: Unweighted Slice model with $\tilde{Y}$ priors

In [None]:
# Same slice model with reweight=False, trained once with init_model's fixed
# train_config (no hyperparameter search), then scored on test.
slice_uw_weak = init_model(use_end_model=False, r=100, reweight=False)
%time slice_uw_weak.train_model(train_slice_snorkel, dev_data=dev)
slice_uw_weak_scores = slice_uw_weak.score(test, metric=['precision', 'recall', 'f1'])

## Overall Scores

In [None]:
# Test-set precision / recall / F1 for each model. The variables are the
# *_scores names assigned in the training cells; the old *_score names were
# never defined and raised NameError.
print("base_weak_score:", base_weak_scores)
print("slice_ours_weak_score:", slice_ours_weak_scores)
print("slice_uw_weak_score:", slice_uw_weak_scores)

## Slice-specific scores

In [None]:
from labeling_functions import LFs
# Names of the imported labeling functions, in LFs order.
print([labeling_fn.__name__ for labeling_fn in LFs])

In [None]:
# TODO: don't call private fns
# TODO: don't call private fns
# Test-set predictions for each model. Each call also returns a label array, and
# Y is reassigned by every call, so the value kept is from slice_uw_weak.
# NOTE(review): presumably all three calls return the same gold labels for `test` — confirm.
Yp_base_weak, Y = base_weak._get_predictions(test)
Yp_slice_ours_weak, Y = slice_ours_weak._get_predictions(test)
Yp_slice_uw_weak, Y = slice_uw_weak._get_predictions(test)

In [None]:
from metal.contrib.slicing.experiment_utils import compare_LF_slices

### `slice_ours_weak` vs. `base_weak`

In [None]:
# Reweighted slice model vs. weak-label baseline, compared per LF slice on accuracy.
print("slice_ours_weak vs base_weak")
compare_LF_slices(
    Yp_slice_ours_weak, Yp_base_weak,
    Y, L_test, LFs,
    metric='accuracy', delta_threshold=0.02,
)

### `slice_ours_weak` vs. `slice_uw_weak`

In [None]:
# Reweighted vs. unweighted slice model, compared per LF slice on accuracy.
print("slice_ours_weak vs slice_uw_weak")
compare_LF_slices(
    Yp_slice_ours_weak, Yp_slice_uw_weak,
    Y, L_test, LFs,
    metric='accuracy', delta_threshold=0.02,
)

### `slice_uw_weak` vs. `base_weak`

In [None]:
# Unweighted slice model vs. weak-label baseline, compared per LF slice on accuracy.
# Previously this passed (Yp_slice_ours_weak, Yp_slice_uw_weak), duplicating the
# previous cell instead of the comparison named in the printed label.
print("slice_uw_weak vs. base_weak")
compare_LF_slices(Yp_slice_uw_weak, Yp_base_weak, Y, L_test, LFs, metric='accuracy', delta_threshold=0.02)