# An Introduction to WISER, Part 2: Generative Models

In this part of the tutorial, we will take the results of the labeling functions from part 1 and learn a generative model that combines them.

We will start by reloading the data with the labeling function outputs from part 1.

## Reloading Data

In [5]:
import pickle

with open('output/tmp/train_data.p', 'rb') as f:
    train_data = pickle.load(f)

with open('output/tmp/dev_data.p', 'rb') as f:
    dev_data = pickle.load(f)
    
with open('output/tmp/test_data.p', 'rb') as f:
    test_data = pickle.load(f)

## Reinspecting Data

We can view the data again with all of the tagging rule annotations.

In [6]:
from wiser.viewer import Viewer
Viewer(dev_data, height=120)


We can inspect the precision, recall, and F1 scores of an unweighted combination of the tagging rules with ``score_labels_majority_vote``.

In [7]:
from wiser.eval import score_labels_majority_vote
score_labels_majority_vote(dev_data)

| | TP | FP | FN | P | R | F1 |
|---|---|---|---|---|---|---|
| Majority Vote | 224 | 993 | 1385 | 0.1841 | 0.1392 | 0.1585 |
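To make the columns in this table concrete, here is a minimal sketch of how precision, recall, and F1 follow from the TP, FP, and FN counts. The helper name `precision_recall_f1` is hypothetical and is not part of `wiser.eval`; it only reproduces the standard definitions.

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall, and F1 from raw counts (standard definitions)."""
    # Precision: fraction of predicted entities that are correct.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    # Recall: fraction of true entities that were predicted.
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    # F1: harmonic mean of precision and recall.
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return round(precision, 4), round(recall, 4), round(f1, 4)


# Counts taken from the majority vote row above.
print(precision_recall_f1(224, 993, 1385))
```

With the counts from the table, this returns the same P, R, and F1 values shown in the majority vote row.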


# Generative Model

To aggregate the tagging and linking rules, we will need to train a generative model.

## Defining a Generative Model

We now need to declare a generative model. In this tutorial, we will use the *linked HMM*, a model that uses linking rules to model dependencies between adjacent tokens. Other generative models are available in `labelmodels`.

Generative models have the following hyperparameters:
* Initial Accuracy (`init_acc`) is the initial estimate of the tagging and linking rule accuracies. It is also used as the mean of the prior distribution over the model parameters.

* Strength of Regularization (`acc_prior`) is the weight of the regularizer that pulls the tagging and linking rule accuracies toward their initial values.

* Balance Prior (`balance_prior`) regularizes the class prior in Naive Bayes, or the initial class distribution and the transition matrix in the HMM and linked HMM, toward a more uniform distribution.

For more details on generative models and the *linked HMM*, please refer to our paper.
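As a rough intuition for how `init_acc` and `acc_prior` interact, the sketch below uses a weighted squared-distance penalty. This is an assumption chosen for illustration only: the function `accuracy_penalty` is hypothetical, and the exact form of the regularizer used inside `labelmodels` may differ.

```python
def accuracy_penalty(accuracies, init_acc, acc_prior):
    """Illustrative penalty (assumed form, not the labelmodels implementation).

    Grows as the estimated rule accuracies move away from init_acc,
    scaled by the regularization weight acc_prior.
    """
    return acc_prior * sum((acc - init_acc) ** 2 for acc in accuracies)


# Estimates that stay at the initial value incur no penalty.
at_init = accuracy_penalty([0.95, 0.95], init_acc=0.95, acc_prior=50)

# Estimates that drift away are penalized, more so with a larger weight.
drifted = accuracy_penalty([0.80, 0.60], init_acc=0.95, acc_prior=50)
drifted_strong = accuracy_penalty([0.80, 0.60], init_acc=0.95, acc_prior=500)

print(at_init, drifted, drifted_strong)
```

Under this assumed form, a larger `acc_prior` keeps the learned accuracies closer to `init_acc`, while a smaller one lets the training data move them further.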

In [12]:
from labelmodels import NaiveBayes, HMM, LinkedHMM
from wiser.generative import Model

# model = Model(LinkedHMM, init_acc=0.9, acc_prior=50, balance_prior=500)
model = Model(LinkedHMM, init_acc=0.95, acc_prior=50, balance_prior=500)

## Training a Generative Model

Once we're done creating our generative model, we're ready to begin training! We first need to create a ``LearningConfig`` to specify the training configuration for the model.

In [13]:
from labelmodels import LearningConfig

config = LearningConfig()
config.num_epochs = 5

Then, we pass the config object to the ``train`` method, along with the training and development data.

In [14]:
model.train(config, train_data=train_data[:750], dev_data=dev_data)

(0.7487, 0.5556, 0.6379)

## Evaluating a Generative Model

We can evaluate the performance of any generative model with the ``evaluate`` method. Here, we'll evaluate our *linked HMM* on the test set.

In [15]:
model.evaluate(test_data)

| | TP | FP | FN | P | R | F1 |
|---|---|---|---|---|---|---|
| Predictions | 657 | 190 | 718 | 0.7757 | 0.4778 | 0.5914 |


## Saving the Output of the Generative Model

After training your generative model, you need to save its probabilistic training labels. The ``save_output`` method used below saves the training, development, and test data with their probabilistic labels to the specified paths. We will use these labels in the next part of the tutorial to train a recurrent neural network.

In [17]:
model.save_output(data=train_data, path='output/generative/link_hmm/train_data.p', save_distribution=True)
model.save_output(data=train_data[:750], path='output/generative/link_hmm/train_data_750.p', save_distribution=True)
model.save_output(data=dev_data, path='output/generative/link_hmm/dev_data.p', save_distribution=True, save_tags=True)
model.save_output(data=test_data, path='output/generative/link_hmm/test_data.p', save_distribution=True, save_tags=True)
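Before moving on, it can be useful to check that a saved file loads back correctly. The sketch below assumes the `.p` files written above are ordinary pickles, like the inputs loaded at the start of this part; that is an assumption about the output format, and `load_split` is a hypothetical helper, not part of WISER.

```python
import pickle
from pathlib import Path


def load_split(path):
    """Load one saved data split, assuming it was written as a pickle file."""
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"No saved output found at {path}")
    with path.open('rb') as f:
        return pickle.load(f)


# Example usage (path taken from the save_output calls above):
# train_750 = load_split('output/generative/link_hmm/train_data_750.p')
# print(len(train_750))
```

If the length printed matches the number of instances you saved, the file is ready to be used in the next part of the tutorial.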