<img align="left" src="imgs/logo.jpg" width="50px" style="margin-right:10px">

# Snorkel Workshop: Extracting Spouse Relations <br> from the News
## Part 4: Training our End Extraction Model

In this final section of the tutorial, we'll use the noisy training labels we generated in the previous part to train our end machine learning model.

For this tutorial, we will be training a fairly effective deep learning model. More generally, however, Snorkel integrates with many ML libraries, making it easy to use almost any state-of-the-art model as the end model!

In [1]:
%load_ext autoreload
%autoreload 
%matplotlib inline

import os
import numpy as np

In [2]:
import torch
from torch.utils.data import DataLoader
from snorkel.model.utils import MetalDataset

## I. Loading Candidates and Gold Labels


In [3]:
import pickle

with open('new_dev_data.pkl', 'rb') as f:
    dev_data = pickle.load(f)
    dev_labels = pickle.load(f)
    
with open('new_train_data.pkl', 'rb') as f:
    train_data = pickle.load(f)
    train_labels = pickle.load(f)
    
with open('new_test_data.pkl', 'rb') as f:
    test_data = pickle.load(f)
    test_labels = pickle.load(f)

with open('train_proba.pkl', 'rb') as f:
    train_marginals = pickle.load(f)

## II. Training a _Long Short-term Memory_ (LSTM) Neural Network

[LSTMs](https://en.wikipedia.org/wiki/Long_short-term_memory) can achieve state-of-the-art performance on many text classification tasks. We'll train a simple LSTM model below.

In deep learning, hyperparameter tuning is an important and computationally expensive step in training models. For the purposes of this tutorial, we provide a pre-trained model that was trained on the labels generated in the previous notebook.

### Data Processing for LSTM

First, we prepare the input to our LSTM by adding *markers* to the beginning and end of the person mentions so the LSTM knows which two persons in the sentence we want to learn a relation for. We then featurize the marked tokens using a standard vocabulary.
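
As a rough illustration of the marking step, here is a minimal sketch. It is not the `mark_entities` helper imported from `utils`; the name `mark_entities_sketch` is hypothetical, and the sketch assumes each position is an inclusive `(start, end)` pair of token indices and that the two mentions do not share a start or end token.

```python
def mark_entities_sketch(tokens, positions, markers):
    """Wrap each (start, end) token span in a begin/end marker pair.

    Assumption: spans are inclusive token-index pairs, and markers are
    ordered [begin0, end0, begin1, end1, ...].
    """
    begin_at = {}
    end_at = {}
    for k, (start, end) in enumerate(positions):
        begin_at[start] = markers[2 * k]
        end_at[end] = markers[2 * k + 1]

    marked = []
    for i, token in enumerate(tokens):
        if i in begin_at:
            marked.append(begin_at[i])  # marker goes before the first token
        marked.append(token)
        if i in end_at:
            marked.append(end_at[i])    # marker goes after the last token
    return marked


example_markers = ['[[BEGIN0]]', '[[END0]]', '[[BEGIN1]]', '[[END1]]']
example_tokens = ["Barack", "Obama", "married", "Michelle", "Obama", "."]
marked = mark_entities_sketch(example_tokens, [(0, 1), (3, 4)], example_markers)
print(marked)
```

The markers become ordinary tokens in the sequence, so the model can learn to attend to the text between and around them.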

In [4]:
from utils import EmbeddingFeaturizer
from utils import mark_entities

markers = ['[[BEGIN0]]','[[END0]]','[[BEGIN1]]','[[END1]]']
featurizer = EmbeddingFeaturizer(markers=markers)

def convert_to_lstm_input(data):
    X = []
    # Wrap the two person mentions of each candidate in marker tokens
    for i in range(len(data)):
        cand = data.iloc[i]  # positional lookup, independent of the index labels
        marked_tokens = mark_entities(
                    cand.tokens,
                    positions=[cand.person1_word_idx, cand.person2_word_idx],
                    markers=markers)
        X.append(marked_tokens)
        
    # Convert the marked string tokens to integer features
    featurize_X = featurizer.fit_transform(X)
    return featurize_X

train_X_tensor = convert_to_lstm_input(train_data)
dev_X_tensor = convert_to_lstm_input(dev_data)
test_X_tensor = convert_to_lstm_input(test_data)
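
The featurization step can be pictured as a vocabulary lookup followed by padding. The sketch below is an assumption about what a featurizer of this kind typically does, not the `EmbeddingFeaturizer` implementation; the class name `VocabFeaturizerSketch` and the reserved ids 0 (padding) and 1 (unknown token) are hypothetical.

```python
class VocabFeaturizerSketch:
    """Map string tokens to integer ids and pad rows to equal length."""

    PAD_ID = 0  # hypothetical id reserved for padding
    UNK_ID = 1  # hypothetical id reserved for out-of-vocabulary tokens

    def __init__(self, markers=()):
        self.vocab = {}
        for marker in markers:
            self._add(marker)

    def _add(self, token):
        if token not in self.vocab:
            self.vocab[token] = len(self.vocab) + 2  # ids 0 and 1 are reserved

    def fit(self, sentences):
        for sentence in sentences:
            for token in sentence:
                self._add(token)
        return self

    def transform(self, sentences):
        max_len = max((len(s) for s in sentences), default=0)
        rows = []
        for sentence in sentences:
            ids = [self.vocab.get(tok, self.UNK_ID) for tok in sentence]
            rows.append(ids + [self.PAD_ID] * (max_len - len(ids)))
        return rows


sketch_markers = ['[[BEGIN0]]', '[[END0]]', '[[BEGIN1]]', '[[END1]]']
train_sentences = [['[[BEGIN0]]', 'Ann', '[[END0]]', 'wed', '[[BEGIN1]]', 'Bob', '[[END1]]']]
dev_sentences = [['[[BEGIN0]]', 'Ann', '[[END0]]', 'met', 'Carl'], ['Bob']]

sketch = VocabFeaturizerSketch(markers=sketch_markers).fit(train_sentences)
train_ids = sketch.transform(train_sentences)
dev_ids = sketch.transform(dev_sentences)
print(train_ids)
print(dev_ids)
```

In this sketch the vocabulary is fit once on the training sentences and reused for the other splits, so a token maps to the same id everywhere. Whether `EmbeddingFeaturizer.fit_transform` keeps ids consistent when called once per split is not shown in this notebook.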

### Creating DataLoaders

In [5]:
from utils import upgrade_dataloaders

datasets = []
datasets.append(MetalDataset(train_X_tensor, torch.LongTensor(train_marginals[:,0]))) #TODO: check 
datasets.append(MetalDataset(dev_X_tensor, torch.LongTensor(dev_labels+1.)))
datasets.append(MetalDataset(test_X_tensor, torch.LongTensor(test_labels+1.)))

dataloaders = []
for dataset, split in zip(datasets, ["train", "valid", "test"]):
    dataloader = DataLoader(dataset)
    dataloader.split = split
    dataloaders.append(dataloader)
    
dataloaders = upgrade_dataloaders(dataloaders)

### Training LSTM Model
For the purposes of this tutorial, we have saved a pre-trained model that was trained using the probabilistic labels generated in the previous notebook.
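
To make "training on probabilistic labels" concrete, here is a minimal sketch of a noise-aware loss: the binary cross-entropy taken in expectation under the label probability. This is a generic formulation and an assumption for illustration; the loss actually used to train the saved model is not shown in this notebook, and the names `soft_bce` and `mean_soft_bce` are hypothetical.

```python
import math


def soft_bce(p_pred, p_label, eps=1e-12):
    """Expected binary cross-entropy for one example.

    p_pred:  model probability of the positive class
    p_label: probabilistic (marginal) label for the positive class
    """
    p_pred = min(max(p_pred, eps), 1.0 - eps)  # avoid log(0)
    return -(p_label * math.log(p_pred) + (1.0 - p_label) * math.log(1.0 - p_pred))


def mean_soft_bce(preds, marginals):
    """Average the per-example loss over a batch."""
    return sum(soft_bce(p, y) for p, y in zip(preds, marginals)) / len(preds)


# A prediction that matches the marginal is penalized less than an
# uninformative prediction of 0.5.
loss_matching = soft_bce(0.9, 0.9)
loss_uninformative = soft_bce(0.5, 0.9)
batch_loss = mean_soft_bce([0.9, 0.2], [0.9, 0.1])
print(loss_matching, loss_uninformative, batch_loss)
```

With a hard label (`p_label` equal to 0 or 1) this reduces to the usual binary cross-entropy, so the same code handles both gold and probabilistic labels.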

We define our model here and load the pre-trained weights before evaluation.

In [6]:
import torch.nn as nn
from snorkel.mtl.simple_model import SimpleModel
from utils import LSTMModule, EmbeddingsEncoder

MAX_INT = train_X_tensor.max()
embed_size = 4
hidden_size = 5

lstm_module = LSTMModule(
    embed_size,
    hidden_size,
    bidirectional=False,
    verbose=False,
    lstm_reduction="attention",
    encoder_class=EmbeddingsEncoder,
    encoder_kwargs={"vocab_size": MAX_INT + 1},
)

model = SimpleModel(
    modules=[lstm_module, nn.Linear(lstm_module.output_dim, 1)],
    metrics=['accuracy', 'f1', 'precision', 'recall'],
)
print(model)

SimpleModel(name=SimpleModel)


**Load and Score Pre-Trained Model**

In [7]:
model.load('./trained_spouse_model')

print("Dev Set Scores")
scores = model.score(dataloaders[1:2])
print(scores)

print("Test Set Scores")
scores = model.score(dataloaders[2])
print(scores)

Dev Set Scores
{'task/data_valid/valid/accuracy': 0.3624161073825503, 'task/data_valid/valid/f1': 0.5320197044334976, 'task/data_valid/valid/precision': 0.3624161073825503, 'task/data_valid/valid/recall': 1.0}
Test Set Scores
{'task/data_test/test/accuracy': 0.28289473684210525, 'task/data_test/test/f1': 0.441025641025641, 'task/data_test/test/precision': 0.28289473684210525, 'task/data_test/test/recall': 1.0}
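
In both splits, recall is 1.0 and precision equals accuracy. For a binary classifier with precision below 1, that combination arises only when every candidate is predicted positive. The sketch below reproduces the dev-set numbers under that reading. The counts (54 positives among 149 candidates) are an assumption chosen to be consistent with the reported values, not figures stated in the notebook, and `binary_metrics` is a hypothetical helper.

```python
def binary_metrics(y_true, y_pred):
    """Compute accuracy, precision, recall and F1 for 0/1 labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)

    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Assumed counts: 54 positive and 95 negative dev candidates
y_true = [1] * 54 + [0] * 95
# A degenerate classifier that labels every candidate as positive
y_pred = [1] * len(y_true)

dev_metrics = binary_metrics(y_true, y_pred)
print(dev_metrics)
```

Under these assumptions, accuracy and precision both equal the positive rate, which suggests comparing the scores against this trivial baseline before relying on them.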


### Note: Training Takes > 30 Minutes to Run on a CPU!

In [8]:
# # Train SimpleModel
# from snorkel.mtl.trainer import Trainer
# trainer = Trainer(progress_bar=True, n_epochs=5)
# trainer.train_model(model, dataloaders)