In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import os
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from weasel.utils import utils

# Bias in Bios $-$ ProfTeacher

***This tutorial is a minimal example of what you need to define, and how the pieces fit together, in order to train Weasel.***
<br>$\rightarrow$ *For an extended, more feature-rich version of this notebook, see the [1_bias_bios_full script](1_bias_bios_full.py) (including logging to wandb, callbacks, and how to retrieve your stand-alone end-model after training it with Weasel).*
<br>$\rightarrow$ *If you don't want to use Hydra for now, or just want to see explicitly how Weasel interfaces with PyTorch Lightning, see [the no-Hydra version of this notebook](./1_bias_bios_no_hydra.ipynb). Using Hydra will make your project easier, though!*
<br>$\rightarrow$ *At the end of this notebook you will find a performance comparison to [Snorkel](https://github.com/snorkel-team/snorkel).*

First, let's load some pre-defined dataset and model parameters and hyperparameters.
We'll use [Hydra](https://hydra.cc/) as the config manager in this notebook.
It is convenient and flexible, and it speeds up setting up an ML pipeline. Highly recommended for your own problem!

The notebook output below lists the (hyper-)parameters of the submodules needed to train Weasel:
- Data, of course :) - more on it in the **Data** section below.
- An end-model - this is your part: choose any neural net that suits your problem best!
- Weasel's configuration itself, including the encoder network that will learn to predict based on the labeling functions (LF).
- A standard pl.Trainer.
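
To make the dotted override strings used in the next cell (e.g. ``trainer.gpus=1``) more concrete, here is a minimal plain-Python sketch of what such an override does to a nested config. This is *not* how Hydra is implemented; ``apply_override`` and the shortened ``base_config`` (including its default values) are hypothetical and only serve as an illustration.

```python
from copy import deepcopy


def apply_override(config: dict, override: str) -> dict:
    """Return a copy of `config` with one `dotted.key=value` override applied."""
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")

    # Interpret the most common literals; everything else stays a string.
    literals = {"True": True, "False": False, "None": None}
    if raw_value in literals:
        value = literals[raw_value]
    else:
        try:
            value = int(raw_value)
        except ValueError:
            try:
                value = float(raw_value)
            except ValueError:
                value = raw_value

    updated = deepcopy(config)
    node = updated
    for key in parents:
        node = node[key]  # raises KeyError if the config group does not exist
    node[leaf] = value
    return updated


# Hypothetical, heavily shortened stand-in for the composed config.
base_config = {
    "seed": 3,
    "end_model": {"adjust_thresh": True},
    "trainer": {"gpus": 0},
}

config_sketch = base_config
for override in ["end_model.adjust_thresh=False", "trainer.gpus=1"]:
    config_sketch = apply_override(config_sketch, override)

print(config_sketch)
```

Each override addresses one leaf of the nested config by its path, which is why you can change a single hyperparameter without touching the YAML files.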

In [3]:
import hydra
from hydra.utils import instantiate as hydra_instantiate

with hydra.initialize(config_path="examples/configs/"):
    config = hydra.compose(
        config_name="profTeacher_simple.yaml",
        overrides=["end_model.adjust_thresh=False",  # Hydra makes overriding easy!
                   "trainer.gpus=1"]  # train on a GPU, set to 0 for CPU only
    )

utils.print_config(config) # prints only when rich library is installed
_ = seed_everything(config.seed, workers=True)  # seed for reproducibility

Global seed set to 3


## Data in WeaSEL

In this example we have $C=2$ classes (teacher or professor biography within a text document).

We can make use of data features, *X_train* and *X_test*, for all training and test data points.
Eventually we hope to make predictions with our end-model based solely on these features.

However, what if we don't have labels, Y, for our training examples, i.e. can't train our end-model
in the traditional supervised way on (X, Y) examples?
This is where multi-source weak supervision & Weasel come to the rescue! :D

#### Multiple Labeling Heuristics replace ground truth training labels
For this concrete problem, we created $m=99$ regex-based LFs and applied them to the $n=12.5k$ training examples, which gives us
a label matrix $L$ of shape $n \times m$ with values in $\{-1, 0, \ldots, C-1\} = \{-1, 0, 1\}$.

Here, $-1$ means that an LF abstained from labeling, while $0$ and $1$ indicate that the LF believes the
example is a teacher or a professor biography, respectively.
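
As a toy illustration of how such a label matrix comes about, the sketch below applies three regex-based LFs to three short biographies. The patterns, the example texts, and the helper ``make_regex_lf`` are made up for this illustration; they are not the 99 LFs used in this notebook.

```python
import re

ABSTAIN, TEACHER, PROFESSOR = -1, 0, 1


def make_regex_lf(pattern: str, label: int):
    """Build an LF that votes `label` if `pattern` matches and abstains otherwise."""
    compiled = re.compile(pattern, flags=re.IGNORECASE)

    def lf(text: str) -> int:
        return label if compiled.search(text) else ABSTAIN

    return lf


# Hypothetical LFs (m = 3).
lfs = [
    make_regex_lf(r"\bprofessor\b", PROFESSOR),
    make_regex_lf(r"\b(ph\.?d|tenure)\b", PROFESSOR),
    make_regex_lf(r"\b(elementary|high) school\b", TEACHER),
]

# Hypothetical training texts (n = 3).
bios = [
    "She is an associate professor of physics and received her PhD in 2009.",
    "He teaches math at a high school in Ohio.",
    "They enjoy hiking and photography.",
]

# Label matrix of shape n x m: one row per example, one column per LF.
label_matrix = [[lf(bio) for lf in lfs] for bio in bios]
for row in label_matrix:
    print(row)
```

The last row contains only $-1$ entries: no LF fired on that biography, so the weak supervision carries no signal for it.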

#### Optional, but recommended evaluation of the end-model
Lastly, we will also want to evaluate the skill of our end-model on a small test set that contains
ground truth labels *Y_test* (with the corresponding features *X_test*).

### All in one place

All of this is conveniently encapsulated in an [abstract PyTorch Lightning DataModule](../weasel/datamodules/base_datamodule.py),
 which makes it simple to set up your own DataModule and train your end-model, provided that you already have all of the above at hand:

- Training and Test features, *X_train, X_test*
- Label matrix, *L*
- Ground truth test labels, *Y_test*

See the [code that defines ProfTeacher_DataModule](datamodules/ProfTeacher_datamodule.py)
to see how simple it then is to get started.
As an alternative to ``ProfTeacher_DataModule``, which encapsulates the data loading within the DataModule itself,
 you can create a ``base_datamodule.BasicWeaselDataModule`` by passing the four data components above to its constructor;
 see [this notebook](0_full_pipeline.ipynb) for a simple synthetic example.
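
Before handing your own data to a DataModule, it can help to check that the four components are consistent with each other. The sketch below is a hypothetical, plain-Python container (``WeakSupervisionData`` is not part of Weasel) that encodes the constraints described above: one row of *L* per training example, one test label per test example, and LF votes restricted to $\{-1, 0, \ldots, C-1\}$.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class WeakSupervisionData:
    """Hypothetical container for the four data components (not Weasel's API)."""

    X_train: Sequence[Sequence[float]]
    L: Sequence[Sequence[int]]
    X_test: Sequence[Sequence[float]]
    Y_test: Sequence[int]
    n_classes: int = 2

    def __post_init__(self) -> None:
        if len(self.X_train) != len(self.L):
            raise ValueError("X_train and L need one row per training example")
        if len(self.X_test) != len(self.Y_test):
            raise ValueError("X_test and Y_test need one entry per test example")
        allowed_votes = set(range(-1, self.n_classes))  # -1 means 'abstain'
        seen_votes = {vote for row in self.L for vote in row}
        if not seen_votes <= allowed_votes:
            raise ValueError(f"unexpected LF votes: {sorted(seen_votes - allowed_votes)}")


# Tiny made-up example with 3 training points, 2 LFs and 2 test points.
toy_data = WeakSupervisionData(
    X_train=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    L=[[1, -1], [-1, 0], [-1, -1]],
    X_test=[[0.7, 0.8], [0.9, 1.0]],
    Y_test=[1, 0],
)
print(len(toy_data.X_train), "training examples,", len(toy_data.L[0]), "LFs")
```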


In [4]:
# Hydra instantiations make instantiating the DataModule, Model, Trainer, etc. a one-liner :)
profTeacher_data_module = hydra_instantiate(config.datamodule)

## End-model

Having set up the data part, you'll first have to choose your favorite neural net as the end-model <br>
 (the one that you eventually want to use as ``predictions = end_model(X)``).

Here, we'll use a simple 2-layer feed-forward net/MLP, but you
can easily replace it with *any* neural net model, see [these instructions](../weasel/models/downstream_models/README.md).

In [5]:
MLP_end_model = hydra_instantiate(config.end_model)

## Weasel: Marrying your end-model with the LFs
We now pass this end-model to the wrapping [Weasel model](../weasel/models/weasel.py), which takes care of training the end-model based
on *X_train* and the LFs, i.e. the label matrix *L*.
To do so, Weasel has an encoder net at its core (another MLP in this case) that learns to predict labels based on the
LFs.
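
To build some intuition for "predicting labels based on the LFs", here is a deliberately simplified sketch of turning one row of *L* into a soft label, given one score per LF. In Weasel such scores would come from the encoder net; here they are simply made-up numbers, and the exact formulation used by Weasel is in the linked ``weasel.py`` and the paper and may differ from this sketch. The names ``softmax`` and ``aggregate_votes`` are hypothetical helpers.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    shift = max(scores)
    exps = [math.exp(score - shift) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def aggregate_votes(lf_votes, lf_scores, n_classes=2, abstain=-1):
    """Weight each non-abstaining LF vote by its score and normalize per class."""
    class_scores = [0.0] * n_classes
    for vote, score in zip(lf_votes, lf_scores):
        if vote != abstain:  # abstaining LFs do not contribute
            class_scores[vote] += score
    return softmax(class_scores)


# One example with m = 4 LFs: two vote 'professor' (1), one 'teacher' (0), one abstains.
votes = [1, 1, 0, -1]
scores = [2.0, 1.0, 0.5, 3.0]  # made-up per-LF scores
soft_label = aggregate_votes(votes, scores)
print(soft_label)  # probabilities for [teacher, professor]
```

If all LFs abstain, every class receives the same score and the soft label is uniform, i.e. it carries no information.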

In [6]:
weasel_model = hydra_instantiate(config.Weasel, end_model=MLP_end_model, _recursive_=False)


## Training Weasel and end-model

Before fitting Weasel and the end-model, we now just need to instantiate a pl.Trainer instance
(we will checkpoint the best model w.r.t. AUC performance on a small validation set that is split off the test set).
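
Conceptually, checkpointing on ``monitor="Val/auc"`` with ``mode="max"`` means keeping the model from the epoch with the highest validation AUC. The plain-Python sketch below mimics only this selection logic; ``best_epoch`` and the metric values are hypothetical and not taken from PyTorch Lightning.

```python
def best_epoch(history, monitor="Val/auc", mode="max"):
    """Return (epoch, value) for the best logged value of `monitor`."""
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    pick = max if mode == "max" else min
    candidates = ((epoch, metrics[monitor]) for epoch, metrics in history.items())
    return pick(candidates, key=lambda item: item[1])


# Made-up validation metrics per epoch.
history = {
    0: {"Val/auc": 0.78, "Val/f1": 0.67},
    1: {"Val/auc": 0.87, "Val/f1": 0.74},
    2: {"Val/auc": 0.85, "Val/f1": 0.76},
}

epoch, value = best_epoch(history)
print(f"best epoch: {epoch} (Val/auc={value})")
```

Note that the best epoch depends on the monitored metric: in the toy history above, monitoring ``Val/f1`` instead would select epoch 2.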

In [7]:
checkpoint_callback = ModelCheckpoint(monitor="Val/auc", mode="max")

trainer = hydra_instantiate(
        config.trainer, callbacks=checkpoint_callback, deterministic=True, max_epochs=75
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


Then, with all the ease of PyTorch Lightning, we can train our model on the DataModule from above.

*If any strange errors are thrown, try rerunning the cell! (e.g. "ValueError: dictionary update sequence element #0 has length 1; 2 is required")*

In [9]:
trainer.fit(
    model=weasel_model,
    datamodule=profTeacher_data_module
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type       | Params
---------------------------------------------
0 | end_model     | MLPNet     | 18.9 K
1 | encoder       | MLPEncoder | 47.3 K
2 | accuracy_func | Softmax    | 0     
---------------------------------------------
66.2 K    Trainable params
0         Non-trainable params
66.2 K    Total params
0.265     Total estimated model params size (MB)


                                                                      

Global seed set to 3


Epoch 0: 100%|██████████| 197/197 [00:03<00:00, 58.83it/s, Val/accuracy=0.524, Val/recall=1.000, Val/precision=0.506, Val/f1=0.672, Val/auc=0.780, decision_thresh=0.500]
Epoch 1: 100%|██████████| 197/197 [00:03<00:00, 57.43it/s, Val/accuracy=0.648, Val/recall=1.000, Val/precision=0.581, Val/f1=0.735, Val/auc=0.866, decision_thresh=0.500]

Epoch 21: 100%|██████████| 197/197 [00:03<00:00, 64.29it/s, Val/accuracy=0.892, Val/recall=0.910, Val/precision=0.874, Val/f1=0.892, Val/auc=0.937, decision_thresh=0.500]
Epoch 22: 100%|██████████| 197/197 [00:03<00:00, 64.39it/s, Val/accuracy=0.884, Val/recall=0.877, Val/precision=0.884, Val/f1=0.881, Val/auc=0.937, decision_thresh=0.500]
Epoch 23: 100%|██████████| 197/197 [00:03<00:00, 64.58it/s, Val/accuracy=0.872, Val/recall=0.836, Val/precision=0.895, Val/f1=0.864, Val/auc=0.937, decision_thresh=0.500]

Epoch 43: 100%|██████████| 197/197 [00:03<00:00, 62.05it/s, Val/accuracy=0.888, Val/recall=0.869, Val/precision=0.898, Val/f1=0.883, Val/auc=0.938, decision_thresh=0.500]
Epoch 44: 100%|██████████| 197/197 [00:03<00:00, 65.48it/s, Val/accuracy=0.888, Val/recall=0.885, Val/precision=0.885, Val/f1=0.885, Val/auc=0.938, decision_thresh=0.500]

Epoch 64: 100%|██████████| 197/197 [00:03<00:00, 58.58it/s, Val/accuracy=0.892, Val/recall=0.926, Val/precision=0.863, Val/f1=0.893, Val/auc=0.941, decision_thresh=0.500]
Epoch 65: 100%|██████████| 197/197 [00:03<00:00, 57.39it/s, Val/accuracy=0.876, Val/recall=0.959, Val/precision=0.818, Val/f1=0.883, Val/auc=0.941, decision_thresh=0.500]
Epoch 66: 100%|██████████| 197/197 [00:03<00:00, 56.28it/s, Val/accuracy=0.880, Val/recall=0.959, Val/precision=0.824, Val/f1=0.886, Val/auc=0.941, decision_thresh=0.500]

## Evaluation

Now that Weasel has finished training, we can evaluate on the held-out test set to see how well Weasel did,
i.e. how well the MLP end-model from above generalizes beyond the weak training signal given by the 99 LFs.

That is, we evaluate how good the ``predictions = end_model(X_test)`` are with respect to our gold test labels
*Y_test*.  <br>
Note that the LFs, *L*, and Weasel are not needed anymore after training in order to use your end-model for prediction.

In [10]:
test_stats = trainer.test(
    datamodule=profTeacher_data_module,
    ckpt_path='best'
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: 100%|██████████| 189/189 [00:00<00:00, 544.59it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Test/accuracy': 0.8652441049485221,
 'Test/auc': 0.9239741370257348,
 'Test/f1': 0.867239263803681,
 'Test/precision': 0.8472111235416334,
 'Test/recall': 0.8882372654155496,
 'decision_thresh': 0.5}
--------------------------------------------------------------------------------
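
For reference, the sketch below shows how threshold-based metrics like the ones reported above are commonly computed from predicted probabilities for the positive class. ``binary_metrics`` and ``f1_from`` are hypothetical helpers, the toy data is made up, and predicting the positive class for ``p >= decision_thresh`` is an assumption; this is not necessarily how Weasel computes its metrics internally.

```python
def f1_from(precision, recall):
    """F1 is the harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def binary_metrics(probs, labels, decision_thresh=0.5):
    """Accuracy, precision, recall and F1 for positive-class probabilities."""
    preds = [int(p >= decision_thresh) for p in probs]
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    tn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return {
        "accuracy": (tp + tn) / len(labels),
        "precision": precision,
        "recall": recall,
        "f1": f1_from(precision, recall),
    }


# Made-up predictions for six test examples.
toy_metrics = binary_metrics(
    probs=[0.9, 0.8, 0.4, 0.6, 0.2, 0.1],
    labels=[1, 1, 1, 0, 0, 0],
)
print(toy_metrics)

# Sanity check against the test results above: F1 from the reported precision/recall.
print(f1_from(0.8472111235416334, 0.8882372654155496))
```

The last line reproduces the reported ``Test/f1`` up to rounding, which confirms that it is the harmonic mean of the reported precision and recall.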


# * Optional baseline *
For example, we can compare against [Snorkel](https://github.com/snorkel-team/snorkel)
as a baseline. To rerun the cells below, Snorkel needs to be installed (``pip install snorkel``):

### Snorkel generative modeling: Predict probabilistic labels, *Y_train*, based on $L$

In [11]:
import numpy as np
from weasel.datamodules.dataset_classes import BasicDownstreamDataset
from snorkel.labeling.model.label_model import LabelModel

snorkel_label_model = LabelModel(cardinality=config.datamodule.n_classes)
label_matrix = np.array(profTeacher_data_module.ws_train_set.L)
snorkel_label_model.fit(L_train=label_matrix)  # Snorkel only sees the label matrix while learning

profTeacher_snorkel_dm = hydra_instantiate(config.datamodule)  # same as above
# replace the (L, X)-Weasel training set in our DataModule with (X, Y_snorkel),
#   where Y_snorkel are fixed, soft labels that Snorkel learned to predict based on L
Y_probs_snorkel = snorkel_label_model.predict_proba(label_matrix)
# filter_uncertains removes the training samples where all LFs abstained,
#  which would just introduce noise if used for training
profTeacher_snorkel_dm.ws_train_set = BasicDownstreamDataset(
    X=profTeacher_snorkel_dm.ws_train_set.X, Y=Y_probs_snorkel,
    filter_uncertains=True
)

Eliminated noisy samples from DownstreamDataset, 2232 removed.
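
The effect of this filtering step can be sketched in plain Python as follows. Note that this is only an illustration of the stated behaviour (dropping examples on which all LFs abstained): the hypothetical helper ``filter_all_abstain`` uses the label matrix explicitly, whereas ``BasicDownstreamDataset`` only receives *X* and *Y*, so its actual implementation must differ.

```python
def filter_all_abstain(features, label_matrix, soft_labels, abstain=-1):
    """Drop examples on which every LF abstained; return the rest and the drop count."""
    keep = [
        index
        for index, row in enumerate(label_matrix)
        if any(vote != abstain for vote in row)
    ]
    kept_features = [features[index] for index in keep]
    kept_labels = [soft_labels[index] for index in keep]
    return kept_features, kept_labels, len(label_matrix) - len(keep)


# Made-up data: 4 examples, 2 LFs, soft labels over [teacher, professor].
X_toy = [[0.1], [0.2], [0.3], [0.4]]
L_toy = [[-1, -1], [0, -1], [-1, 1], [-1, -1]]
Y_toy = [[0.5, 0.5], [0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]

X_kept, Y_kept, n_removed = filter_all_abstain(X_toy, L_toy, Y_toy)
print(f"{n_removed} removed, {len(X_kept)} kept")
```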


### Re-training the end-model on Snorkel's *Y_train*

In [12]:
# re-instantiate the same MLP architecture that Weasel trained + the same Trainer:
MLP_end_model = hydra_instantiate(config.end_model)
checkpoint_callback = ModelCheckpoint(monitor="Val/auc", mode="max")

trainer = hydra_instantiate(
        config.trainer, callbacks=checkpoint_callback, deterministic=True, max_epochs=75
)
# Besides the training set, the only difference is the model:
trainer.fit(model=MLP_end_model, datamodule=profTeacher_snorkel_dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name    | Type       | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | network | Sequential | 18.9 K | [1, 300] | [1, 2]   
--------------------------------------------------------------
18.9 K    Trainable params
0         Non-trainable params
18.9 K    Total params
0.076     Total estimated model params size (MB)


Data split sizes for training, validation, testing: 10062 250 12044
                                                              

Global seed set to 3


Epoch 0: 100%|██████████| 162/162 [00:01<00:00, 142.43it/s, loss=0.688, Val/accuracy=0.512, Val/recall=0.000, Val/precision=0.000, Val/f1=0.000, Val/auc=0.770, decision_thresh=0.500]
Epoch 1: 100%|██████████| 162/162 [00:01<00:00, 158.05it/s, loss=0.681, Val/accuracy=0.512, Val/recall=0.000, Val/precision=0.000, Val/f1=0.000, Val/auc=0.862, decision_thresh=0.500]

Epoch 20: 100%|██████████| 162/162 [00:01<00:00, 135.98it/s, loss=0.352, Val/accuracy=0.868, Val/recall=0.820, Val/precision=0.901, Val/f1=0.858, Val/auc=0.926, decision_thresh=0.500]
Epoch 21: 100%|██████████| 162/162 [00:01<00:00, 133.27it/s, loss=0.336, Val/accuracy=0.864, Val/recall=0.820, Val/precision=0.893, Val/f1=0.855, Val/auc=0.926, decision_thresh=0.500]

Epoch 40: 100%|██████████| 162/162 [00:01<00:00, 154.33it/s, loss=0.32, Val/accuracy=0.860, Val/recall=0.820, Val/precision=0.885, Val/f1=0.851, Val/auc=0.919, decision_thresh=0.500]
Epoch 41: 100%|██████████| 162/162 [00:01<00:00, 131.95it/s, loss=0.331, Val/accuracy=0.860, Val/recall=0.820, Val/precision=0.885, Val/f1=0.851, Val/auc=0.920, decision_thresh=0.500]

Epoch 60: 100%|██████████| 162/162 [00:01<00:00, 159.25it/s, loss=0.308, Val/accuracy=0.848, Val/recall=0.787, Val/precision=0.889, Val/f1=0.835, Val/auc=0.917, decision_thresh=0.500]
Epoch 61: 100%|██████████| 162/162 [00:01<00:00, 139.37it/s, loss=0.324, Val/accuracy=0.848, Val/recall=0.795, Val/precision=0.882, Val/f1=0.836, Val/auc=0.917, decision_thresh=0.500]

### Evaluating our end-model trained with Snorkel instead of Weasel
You can see below that Weasel improved over Snorkel by about 2.7 F1 points (0.867 vs. 0.841) and 1.3 accuracy points (0.865 vs. 0.852), while the AUC is roughly on par (0.924 vs. 0.929).
<br>We observed such strong performance of Weasel fairly consistently across datasets and LF configurations (see paper).


In [13]:
snorkel_test_stats = trainer.test(datamodule=profTeacher_snorkel_dm, ckpt_path='best')

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: 100%|██████████| 189/189 [00:00<00:00, 560.49it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'Test/accuracy': 0.8522915974759216,
 'Test/auc': 0.9291996970456435,
 'Test/f1': 0.8405485345522989,
 'Test/precision': 0.9036423202929273,
 'Test/recall': 0.7856903485254692,
 'decision_thresh': 0.5}
--------------------------------------------------------------------------------
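
To compare the two runs metric by metric, the sketch below computes the differences in percentage points from the two test result dictionaries printed above. The values are copied from those outputs; the variable names ``weasel_stats`` and ``snorkel_stats`` and the helper ``point_differences`` are only used for this illustration.

```python
# Test results copied from the Weasel and Snorkel evaluation outputs above.
weasel_stats = {
    "Test/accuracy": 0.8652441049485221,
    "Test/auc": 0.9239741370257348,
    "Test/f1": 0.867239263803681,
    "Test/precision": 0.8472111235416334,
    "Test/recall": 0.8882372654155496,
}
snorkel_stats = {
    "Test/accuracy": 0.8522915974759216,
    "Test/auc": 0.9291996970456435,
    "Test/f1": 0.8405485345522989,
    "Test/precision": 0.9036423202929273,
    "Test/recall": 0.7856903485254692,
}


def point_differences(stats_a, stats_b):
    """Difference a - b in percentage points for every metric present in both."""
    return {
        name: round(100 * (stats_a[name] - stats_b[name]), 2)
        for name in stats_a
        if name in stats_b
    }


differences = point_differences(weasel_stats, snorkel_stats)
for name, diff in differences.items():
    print(f"{name:>15}: {diff:+.2f} points (Weasel - Snorkel)")
```

At the fixed decision threshold of 0.5, the Weasel-trained end-model has higher recall, F1 and accuracy, while the Snorkel-trained end-model has higher precision and a slightly higher AUC.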
