In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
import warnings
warnings.filterwarnings("ignore")

In [9]:
import yaml, json
from types import SimpleNamespace
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

# Bias in Bios $-$ ProfTeacher
***This tutorial is a minimal example of what needs to be defined, and how to tie it all together,
so that you can train Weasel without Hydra as the config manager.***
<br>$\rightarrow$ *For an extended, more feature-rich version of this notebook, see the [1_bias_bios_full script](1_bias_bios_full.py) (including logging to wandb, callbacks, and how to retrieve your stand-alone end-model after training it with Weasel).*
<br>$\rightarrow$ *For the recommended config management with Hydra, see [this notebook](./1_bias_bios.ipynb).*

First, let's load some predefined dataset and model parameters and hyperparameters:

In [10]:
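with open("configs/profTeacher_no_hydra.yaml") as f:
    config_dict = yaml.safe_load(f)  # yaml.load() without an explicit Loader fails on PyYAML >= 6
# Round-trip through JSON to turn nested dicts into attribute-accessible SimpleNamespaces
config = json.loads(json.dumps(config_dict), object_hook=lambda d: SimpleNamespace(**d))
seed_everything(config.seed, workers=True)  # seed for reproducibility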
config

Global seed set to 3


namespace(seed=3,
          datamodule=namespace(batch_size=64,
                               val_test_split=[250, -1],
                               seed=3),
          trainer=namespace(gpus=0, max_epochs=100),
          end_model=namespace(dropout=0.3,
                              net_norm='none',
                              activation_func='ReLU',
                              input_dim=300,
                              hidden_dims=[50, 50, 25],
                              output_dim=2,
                              adjust_thresh=True),
          Weasel=namespace(num_LFs=99,
                           n_classes=2,
                           class_balance=[0.5, 0.5],
                           loss_function='cross_entropy',
                           temperature=2.0,
                           accuracy_scaler='sqrt',
                           use_aux_input_for_encoder=True,
                           class_conditional_accuracies=True,
                           encoder=names
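The JSON round trip in the config cell above is what turns the nested YAML dicts into nested ``SimpleNamespace`` objects, so that you can write ``config.datamodule.batch_size`` and unpack whole sections with ``**vars(...)``. A minimal stand-alone sketch of the same trick, using a hypothetical inline excerpt of the config instead of the YAML file:

```python
import json
from types import SimpleNamespace


def to_namespace(d: dict) -> SimpleNamespace:
    """Recursively convert a JSON-serializable nested dict into SimpleNamespace objects."""
    # object_hook is called on every decoded JSON object, innermost first
    return json.loads(json.dumps(d), object_hook=lambda obj: SimpleNamespace(**obj))


# Hypothetical excerpt mirroring the printed config (not the full YAML file)
config_dict = {
    "seed": 3,
    "datamodule": {"batch_size": 64, "val_test_split": [250, -1], "seed": 3},
    "trainer": {"gpus": 0, "max_epochs": 100},
}
config = to_namespace(config_dict)

print(config.datamodule.batch_size)  # 64
print(vars(config.trainer))          # {'gpus': 0, 'max_epochs': 100}
```

Note that ``vars()`` only converts the top level of a section back into a dict: nested entries (such as the ``encoder`` entry of ``config.Weasel``) are still ``SimpleNamespace`` objects when they are passed on as keyword arguments.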

## Data in WeaSEL

In this example we have $C=2$ classes (whether a text document is a teacher's or a professor's biography).

We have access to the features *X_train* and *X_test* of all training and test data points.
Eventually, we want our end-model to make predictions based solely on these features.

However, what if we don't have labels, Y, for our training examples, i.e., we can't train our end-model
in the traditional supervised way on (X, Y) examples?
This is where multi-source weak supervision and Weasel come to the rescue! :D

#### Multiple Labeling Heuristics replace ground truth training labels
For this concrete problem, we created $m=99$ regex-based labeling functions (LFs) and applied them to the $n \approx 12.5$k training examples, which gives us
a label matrix $L$ of shape $n \times m$ with values in $\{-1, 0, \ldots, C-1\} = \{-1, 0, 1\}$.

Here, $-1$ means that an LF abstained from labeling, while $0$ and $1$ indicate that the LF believes the
example is a teacher's or a professor's biography, respectively.
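To make the abstain convention concrete, here is a toy sketch of how regex-based LFs produce the rows of *L*. The patterns, documents, and the helpers ``make_regex_lf``/``apply_lfs`` are hypothetical illustrations, not the 99 LFs used in this tutorial:

```python
import re
from typing import Callable, List

ABSTAIN, TEACHER, PROFESSOR = -1, 0, 1


def make_regex_lf(pattern: str, label: int) -> Callable[[str], int]:
    """Hypothetical LF factory: vote `label` if the regex matches, else abstain."""
    regex = re.compile(pattern, flags=re.IGNORECASE)
    return lambda text: label if regex.search(text) else ABSTAIN


def apply_lfs(docs: List[str], lfs: List[Callable[[str], int]]) -> List[List[int]]:
    """Build the n x m label matrix: one row per document, one column per LF."""
    return [[lf(doc) for lf in lfs] for doc in docs]


lfs = [
    make_regex_lf(r"\bteach(er|es|ing)?\b", TEACHER),
    make_regex_lf(r"\bclassroom\b", TEACHER),
    make_regex_lf(r"\bprofessor\b", PROFESSOR),
    make_regex_lf(r"\bph\.?d\b", PROFESSOR),
]
docs = [
    "She teaches math to 5th graders and loves her classroom.",
    "He is an associate professor and received his PhD from MIT.",
    "They enjoy hiking on weekends.",
]
L = apply_lfs(docs, lfs)
print(L)  # [[0, 0, -1, -1], [-1, -1, 1, 1], [-1, -1, -1, -1]]

# Coverage: fraction of examples on which each LF does not abstain
coverage = [sum(row[j] != ABSTAIN for row in L) / len(L) for j in range(len(lfs))]
print([round(c, 2) for c in coverage])  # [0.33, 0.33, 0.33, 0.33]
```

The last document shows why abstains matter: no LF fires, so that row carries no weak label at all.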

#### Optional but recommended: evaluation of the end-model
Lastly, we also want to evaluate the performance of our end-model on a small test set that contains
ground truth labels *Y_test* (with the corresponding features *X_test*).

### All in one place

All this is conveniently encapsulated in an [abstract PyTorch Lightning DataModule](../weasel/datamodules/base_datamodule.py),
which makes it simple to set up your own DataModule and train your end-model, provided that you already have all of the above at hand:

- Training and Test features, *X_train, X_test*
- Label matrix, *L*
- Ground truth test labels, *Y_test*

See the [code that defines ProfTeacher_DataModule](datamodules/ProfTeacher_datamodule.py)
for how little is needed to get started.
As an alternative to the ``ProfTeacher_DataModule`` approach, which encapsulates the data loading within the DataModule itself,
you can also create a ``base_datamodule.BasicWeaselDataModule`` by passing the four data components above to its constructor;
see [this notebook](0_full_pipeline.ipynb) for a simple synthetic example.
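Before handing your own data to a DataModule, it can help to check that the four components fit together. The helper below is a hypothetical sketch (not part of Weasel) run on random stand-in data; the feature size 300 and the 99 LFs simply mirror the config above:

```python
import numpy as np


def check_weasel_data(X_train, L, X_test, Y_test, n_classes):
    """Hypothetical sanity check of the four data components (not part of Weasel)."""
    X_train, L, X_test, Y_test = (np.asarray(a) for a in (X_train, L, X_test, Y_test))
    if L.ndim != 2:
        raise ValueError("L must be an (n, m) label matrix")
    if X_train.shape[0] != L.shape[0]:
        raise ValueError("need one row of LF votes per training example")
    if X_test.shape[0] != Y_test.shape[0]:
        raise ValueError("need one gold label per test example")
    if X_train.shape[1:] != X_test.shape[1:]:
        raise ValueError("train and test features must have the same shape")
    if not np.isin(L, np.arange(-1, n_classes)).all():
        raise ValueError("L values must lie in {-1, 0, ..., C-1}")
    if not np.isin(Y_test, np.arange(n_classes)).all():
        raise ValueError("Y_test values must lie in {0, ..., C-1}")
    return {"n": L.shape[0], "m": L.shape[1], "n_test": X_test.shape[0]}


# Random stand-in data; rng.integers' upper bound is exclusive, so L has values in {-1, 0, 1}
rng = np.random.default_rng(3)
stats = check_weasel_data(
    X_train=rng.normal(size=(100, 300)),
    L=rng.integers(-1, 2, size=(100, 99)),
    X_test=rng.normal(size=(40, 300)),
    Y_test=rng.integers(0, 2, size=40),
    n_classes=2,
)
print(stats)  # {'n': 100, 'm': 99, 'n_test': 40}
```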

In [4]:
from examples.datamodules.ProfTeacher_datamodule import ProfTeacher_DataModule
data_module = ProfTeacher_DataModule(**vars(config.datamodule))

## End-model

Having set up the data, you first need to choose your favorite neural net as the end-model <br>
(the one that you eventually want to use as ``predictions = end-model(X)``).

Here, we'll use a simple feed-forward net (MLP) with three hidden layers of 50, 50, and 25 units (``hidden_dims`` in the config above), but you
can easily replace it with *any* neural net model; see the instructions in the Readme.
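As a quick consistency check of this architecture: counting the weights and biases of a plain fully connected net with the configured sizes (``input_dim=300``, ``hidden_dims=[50, 50, 25]``, ``output_dim=2``) gives the 18.9 K parameters that Lightning reports for ``end_model`` in the training output. This sketch assumes ``net_norm='none'`` adds no normalization parameters; dropout and ReLU have none:

```python
from typing import List


def mlp_param_count(input_dim: int, hidden_dims: List[int], output_dim: int) -> int:
    """Weights + biases of a plain fully connected net (no normalization layers)."""
    sizes = [input_dim, *hidden_dims, output_dim]
    # Each Linear(d_in, d_out) layer has d_in * d_out weights and d_out biases
    return sum(d_in * d_out + d_out for d_in, d_out in zip(sizes[:-1], sizes[1:]))


print(mlp_param_count(300, [50, 50, 25], 2))  # 18927, i.e. ~18.9 K
```

Most of these parameters sit in the first layer (300 x 50 weights), so the input dimension dominates the end-model's size.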

In [5]:
from weasel.models.downstream_models.MLP import MLPNet
endmodel = MLPNet(**vars(config.end_model))

## Weasel: Marrying your end-model with the LFs
We now pass this end-model to the wrapping Weasel model, which takes care of learning the end-model from
*X_train* and the LFs, i.e., the label matrix *L*.
To do so, Weasel has an encoder network at its core (another MLP in this case) that predicts labels based on the
LFs.
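For some intuition about what such an encoder can do, here is a heavily simplified numpy sketch of one way to turn encoder outputs into soft training labels: the encoder assigns a weight to every (example, LF) pair, and the weighted LF votes are normalized into a distribution over the $C$ classes. The function names, the ``sqrt(m)`` scaling (an assumed reading of ``accuracy_scaler='sqrt'``), and this use of ``temperature`` are assumptions for illustration, not Weasel's actual implementation; the encoder network itself is replaced by fixed logits:

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def soft_labels_from_lfs(L, encoder_logits, n_classes, temperature=2.0):
    """Hypothetical simplification: aggregate LF votes with sample-dependent LF weights."""
    n, m = L.shape
    # One weight per (example, LF); the sqrt(m) scaling is an ASSUMPTION
    weights = np.sqrt(m) * softmax(encoder_logits / temperature, axis=1)  # (n, m)
    # One-hot encode the votes; abstains (-1) stay all-zero and contribute nothing
    one_hot = np.zeros((n, m, n_classes))
    rows, cols = np.nonzero(L != -1)
    one_hot[rows, cols, L[rows, cols]] = 1.0
    class_scores = np.einsum("nm,nmc->nc", weights, one_hot)  # weighted votes per class
    return softmax(class_scores, axis=1)  # (n, C) soft labels


L = np.array([[0, 0, -1, 1],
              [-1, 1, 1, 1],
              [-1, -1, -1, -1]])
# A real encoder would compute these logits from (L, X); zeros give every LF equal weight
encoder_logits = np.zeros(L.shape)
Y_soft = soft_labels_from_lfs(L, encoder_logits, n_classes=2)
print(Y_soft.round(3))  # row 0 leans to class 0, row 1 to class 1, row 2 (all abstain) is uniform
```

The key design point this illustrates is that the weights depend on the example, so an LF can be trusted more on some inputs than on others, unlike a single global accuracy per LF.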

In [14]:
from weasel.models import Weasel
weasel = Weasel(**vars(config.Weasel), end_model=endmodel)


## Training Weasel and end-model

Before fitting Weasel and the end-model, we just need to instantiate a ``pl.Trainer``
(we checkpoint the best model w.r.t. AUC on a small validation set that is split off the test set).
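How ``val_test_split=[250, -1]`` is interpreted is up to the DataModule; a plausible reading, assumed here, is "250 validation examples, with -1 meaning that all remaining examples go to the test set". A minimal sketch of such a split, with a hypothetical helper and an illustrative test-set size of 1000:

```python
import numpy as np


def split_val_test(n_total: int, val_test_split, seed: int = 3):
    """Hypothetical helper: split indices into validation/test; -1 means 'all remaining'."""
    n_val, n_test = val_test_split
    if n_test == -1:
        n_test = n_total - n_val
    if n_val + n_test > n_total:
        raise ValueError("split sizes exceed the number of available examples")
    perm = np.random.default_rng(seed).permutation(n_total)  # seeded for reproducibility
    return perm[:n_val], perm[n_val:n_val + n_test]


val_idx, test_idx = split_val_test(1000, [250, -1])
print(len(val_idx), len(test_idx))  # 250 750
```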

In [18]:
checkpoint_callback = ModelCheckpoint(monitor="Val/auc", mode="max")

trainer = pl.Trainer(
    **vars(config.trainer),
    logger=None,
    deterministic=True,
    callbacks=[checkpoint_callback]
)

trainer.fit(weasel, datamodule=data_module)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type       | Params
---------------------------------------------
0 | end_model     | MLPNet     | 18.9 K
1 | encoder       | MLPEncoder | 47.3 K
2 | accuracy_func | Softmax    | 0     
---------------------------------------------
66.2 K    Trainable params
0         Non-trainable params
66.2 K    Total params
0.265     Total estimated model params size (MB)



Global seed set to 3


Epoch 0:  98%|█████████▊| 193/197 [00:03<00:00, 48.32it/s, Val/accuracy=0.641, Val/recall=0.846, Val/precision=0.604, Val/f1=0.705, Val/auc=0.593, decision_thresh=0.501]
Epoch 0: 100%|██████████| 197/197 [00:04<00:00, 48.36it/s, Val/accuracy=0.680, Val/recall=0.951, Val/precision=0.611, Val/f1=0.744, Val/auc=0.784, decision_thresh=0.499]
Epoch 1:  98%|█████████▊| 193/197 [00:04<00:00, 47.76it/s, Val/accuracy=0.680, Val/recall=0.951, Val/precision=0.611, Val/f1=0.744, Val/auc=0.784, decision_thresh=0.499]
Epoch 1: 100%|██████████| 197/197 [00:04<00:00, 47.82it/s, Val/accuracy=0.780, Val/recall=0.918, Val/precision=0.713, Val/f1=0.803, Val/auc=0.864, decision_thresh=0.501]
Epoch 2:  98%|█████████▊| 193/197 [00:03<00:00, 49.66it/s, Val/accuracy=0.780, Val/recall=0.918, Val/precision=0.713, Val/f1=0.803, Val/auc=0.864, decision_thresh=0.501]
Epoch 2: 100%|██████████| 197/197 [00:03<00:00, 

Epoch 21: 100%|██████████| 197/197 [00:04<00:00, 49.07it/s, Val/accuracy=0.908, Val/recall=0.910, Val/precision=0.902, Val/f1=0.906, Val/auc=0.932, decision_thresh=0.500]
Epoch 22:  98%|█████████▊| 193/197 [00:03<00:00, 48.46it/s, Val/accuracy=0.908, Val/recall=0.910, Val/precision=0.902, Val/f1=0.906, Val/auc=0.932, decision_thresh=0.500]
Epoch 22: 100%|██████████| 197/197 [00:04<00:00, 48.46it/s, Val/accuracy=0.904, Val/recall=0.910, Val/precision=0.895, Val/f1=0.902, Val/auc=0.932, decision_thresh=0.500]
Epoch 23:  98%|█████████▊| 193/197 [00:04<00:00, 47.99it/s, Val/accuracy=0.904, Val/recall=0.910, Val/precision=0.895, Val/f1=0.902, Val/auc=0.932, decision_thresh=0.500]
Epoch 23: 100%|██████████| 197/197 [00:04<00:00, 48.02it/s, Val/accuracy=0.904, Val/recall=0.910, Val/precision=0.895, Val/f1=0.902, Val/auc=0.932, decision_thresh=0.500]
Epoch 24:  98%|█████████▊| 193/197 [00:03<00:00, 48.36it/s, Val/accuracy=0.90

Epoch 43:  98%|█████████▊| 193/197 [00:03<00:00, 49.64it/s, Val/accuracy=0.896, Val/recall=0.902, Val/precision=0.887, Val/f1=0.894, Val/auc=0.931, decision_thresh=0.500]
Epoch 43: 100%|██████████| 197/197 [00:03<00:00, 49.66it/s, Val/accuracy=0.900, Val/recall=0.893, Val/precision=0.901, Val/f1=0.897, Val/auc=0.931, decision_thresh=0.500]
Epoch 44:  98%|█████████▊| 193/197 [00:03<00:00, 49.58it/s, Val/accuracy=0.900, Val/recall=0.893, Val/precision=0.901, Val/f1=0.897, Val/auc=0.931, decision_thresh=0.500]
Epoch 44: 100%|██████████| 197/197 [00:03<00:00, 49.63it/s, Val/accuracy=0.900, Val/recall=0.893, Val/precision=0.901, Val/f1=0.897, Val/auc=0.932, decision_thresh=0.500]
Epoch 45:  98%|█████████▊| 193/197 [00:03<00:00, 50.32it/s, Val/accuracy=0.900, Val/recall=0.893, Val/precision=0.901, Val/f1=0.897, Val/auc=0.932, decision_thresh=0.500]
Epoch 45: 100%|██████████| 197/197 [00:03<0

Epoch 64: 100%|██████████| 197/197 [00:04<00:00, 48.11it/s, Val/accuracy=0.900, Val/recall=0.934, Val/precision=0.870, Val/f1=0.901, Val/auc=0.937, decision_thresh=0.500]
Epoch 65:  98%|█████████▊| 193/197 [00:03<00:00, 50.04it/s, Val/accuracy=0.900, Val/recall=0.934, Val/precision=0.870, Val/f1=0.901, Val/auc=0.937, decision_thresh=0.500]
Epoch 65: 100%|██████████| 197/197 [00:03<00:00, 50.03it/s, Val/accuracy=0.900, Val/recall=0.934, Val/precision=0.870, Val/f1=0.901, Val/auc=0.938, decision_thresh=0.500]
Epoch 66:  98%|█████████▊| 193/197 [00:03<00:00, 48.29it/s, Val/accuracy=0.900, Val/recall=0.934, Val/precision=0.870, Val/f1=0.901, Val/auc=0.938, decision_thresh=0.500]
Epoch 66: 100%|██████████| 197/197 [00:04<00:00, 48.36it/s, Val/accuracy=0.900, Val/recall=0.934, Val/precision=0.870, Val/f1=0.901, Val/auc=0.937, decision_thresh=0.500]
Epoch 67:  98%|█████████▊| 193/197 [00:03<00:00, 48.52it/s, Val/accuracy=0.90

Epoch 86:  98%|█████████▊| 193/197 [00:04<00:00, 47.77it/s, Val/accuracy=0.908, Val/recall=0.918, Val/precision=0.896, Val/f1=0.907, Val/auc=0.950, decision_thresh=0.500]
Epoch 86: 100%|██████████| 197/197 [00:04<00:00, 46.81it/s, Val/accuracy=0.900, Val/recall=0.918, Val/precision=0.882, Val/f1=0.900, Val/auc=0.951, decision_thresh=0.500]
Epoch 87:  98%|█████████▊| 193/197 [00:04<00:00, 47.22it/s, Val/accuracy=0.900, Val/recall=0.918, Val/precision=0.882, Val/f1=0.900, Val/auc=0.951, decision_thresh=0.500]
Epoch 87: 100%|██████████| 197/197 [00:04<00:00, 47.31it/s, Val/accuracy=0.896, Val/recall=0.951, Val/precision=0.853, Val/f1=0.899, Val/auc=0.951, decision_thresh=0.500]
Epoch 88:  98%|█████████▊| 193/197 [00:04<00:00, 47.99it/s, Val/accuracy=0.896, Val/recall=0.951, Val/precision=0.853, Val/f1=0.899, Val/auc=0.951, decision_thresh=0.500]
Epoch 88: 100%|██████████| 197/197 [00:04<0

## Evaluation

Now that Weasel has finished training, we can evaluate on the held-out test set to see how well Weasel did,
i.e., how well the MLP end-model from above generalizes beyond the weak training signal given by the 99 LFs.

That is, we evaluate how good ``predictions = end-model(X_test)`` are with respect to our gold test labels
*Y_test*.  <br>
$\rightarrow$ Note that the LFs, *L*, and Weasel itself are no longer needed after training, i.e., for prediction.
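The ``Val/*`` metrics in the training log, and the test metrics computed below, are standard binary-classification scores. As a hedged illustration (not necessarily how Weasel computes them), here is how they can be computed from the end-model's predicted probability for class 1 (professor). The fixed 0.5 threshold is an assumption: the log also reports a ``decision_thresh``, which need not be exactly 0.5:

```python
import numpy as np


def binary_metrics(y_true, p_pos, thresh=0.5):
    """Accuracy/precision/recall/F1 at a threshold, plus threshold-free ROC AUC."""
    y_true = np.asarray(y_true)
    p_pos = np.asarray(p_pos, dtype=float)
    y_pred = (p_pos >= thresh).astype(int)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # AUC = P(a random positive scores higher than a random negative); ties count 1/2
    pos, neg = p_pos[y_true == 1], p_pos[y_true == 0]
    wins = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    auc = wins / (len(pos) * len(neg))
    return {"accuracy": float(np.mean(y_pred == y_true)), "precision": precision,
            "recall": recall, "f1": f1, "auc": float(auc)}


# Toy gold labels and class-1 probabilities, not real results from this tutorial
print(binary_metrics([0, 0, 1, 1], [0.1, 0.6, 0.4, 0.9]))
# {'accuracy': 0.5, 'precision': 0.5, 'recall': 0.5, 'f1': 0.5, 'auc': 0.75}
```

AUC is the only one of these scores that does not depend on the threshold, which makes it a natural choice for the ``monitor="Val/auc"`` checkpointing used above.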

In [None]:
# See the sister notebook, 1_bias_bios.ipynb, for a Snorkel baseline
# ckpt_path='best' evaluates the checkpoint with the highest Val/auc saved by checkpoint_callback
test_stats = trainer.test(weasel, datamodule=data_module, ckpt_path='best')