# Snorkel Workshop: Slicing Tutorial

## Setup
To start, let's make sure that we have the right paths/environment variables set by following the instructions in `snorkel-superglue/README.md`.

Specifically, ensure that (1) `snorkel` is installed and (2) `SUPERGLUEDATA` points to the location where [download_superglue_data.py](https://github.com/HazyResearch/snorkel-superglue/blob/staging/download_superglue_data.py) was called.

In [1]:
import sys, os
from pathlib import Path

if "cwd" not in globals():
    cwd = Path(os.getcwd())
sys.path.insert(0, str(cwd.parents[0]))

In [2]:
import pandas as pd
# Don't truncate the sentence when viewing examples
# Note: pandas >= 1.0 expects None here instead of -1
pd.set_option('display.max_colwidth', -1)

Note that we rely heavily on the `snorkel.mtl` module, which is a great abstraction for implementing these slicing tasks.
Intuitively, we want an API for adding capacity corresponding to each slice, which is exactly what the task flows in these packages allow!

In [3]:
from snorkel.mtl.data import MultitaskDataLoader
from snorkel.mtl.model import MultitaskModel
from snorkel.mtl.snorkel_config import default_config as config
from snorkel.mtl.trainer import Trainer

In [4]:
import models
from tokenizer import get_tokenizer
from utils import task_dataset_to_dataframe

## Explore the WiC dataset
We'll be working with the [Words in Context (WiC) task](https://pilehvar.github.io/wic/). To start, let's look at a few examples. To do so, we'll convert them to dataframes.

In [5]:
from dataloaders import get_jsonl_path
from parsers.wic import get_rows

task_name = "WiC"
data_dir = os.environ["SUPERGLUEDATA"]
split = "valid"
max_data_samples = None # max examples to include in dataset

jsonl_path = get_jsonl_path(data_dir, task_name, split)
wic_df = pd.DataFrame.from_records(get_rows(jsonl_path, max_data_samples=max_data_samples))

Recall that the WiC task is to identify the intended meaning of a specified word across multiple contexts: the `label` indicates whether the word is used in the same sense in both `sentence1` and `sentence2`.

In [6]:
wic_df[["sentence1", "sentence2", "word", "label"]].head()

| | sentence1 | sentence2 | word | label |
| --- | --- | --- | --- | --- |
| 0 | Room and board . | He nailed boards across the windows . | board | False |
| 1 | Circulate a rumor . | This letter is being circulated among the faculty . | circulate | False |
| 2 | Hook a fish . | He hooked a snake accidentally , and was so scared he dropped his rod into the water . | hook | True |
| 3 | For recreation he wrote poetry and solved crossword puzzles . | Drug abuse is often regarded as a form of recreation . | recreation | True |
| 4 | Making a hobby of domesticity . | A royal family living in unpretentious domesticity . | domesticity | False |


## Train a model using BERT
Now, let's train a model using the Snorkel API with the [BERT](https://arxiv.org/abs/1810.04805) model, a powerful pre-training mechanism for general language understanding.
Thanks to the folks at [huggingface](https://github.com/huggingface/pytorch-pretrained-BERT), we can use this model with a simple import statement!

In [7]:
bert_model = "bert-large-cased"
tokenizer_name = "bert-large-cased"
batch_size = 4
max_sequence_length = 256

In [8]:
# load the word-piece tokenizer for the 'bert-large-cased' vocabulary
tokenizer = get_tokenizer(tokenizer_name)
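
To give a feel for what a word-piece tokenizer does, here is a minimal sketch of greedy longest-match-first subword splitting. The tiny vocabulary and the `wordpiece` function are hypothetical and for illustration only; the real tokenizer returned by `get_tokenizer` uses the full `bert-large-cased` vocabulary and additional preprocessing.

```python
# Toy vocabulary (an assumption for illustration, not BERT's real vocabulary).
# Pieces starting with "##" may only continue a word, never start one.
TOY_VOCAB = {"circulate", "##d", "##s", "board", "un", "##hook", "##ed", "[UNK]"}


def wordpiece(word, vocab=TOY_VOCAB, unk="[UNK]"):
    """Split one word into sub-word pieces by greedy longest-match-first."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it.
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece matches: the whole word maps to the unknown token.
            return [unk]
        pieces.append(match)
        start = end
    return pieces


print(wordpiece("circulated"))  # ['circulate', '##d']
print(wordpiece("boards"))      # ['board', '##s']
print(wordpiece("unhooked"))    # ['un', '##hook', '##ed']
print(wordpiece("xyz"))         # ['[UNK]']
```

Splitting rare words into known pieces is what lets a fixed vocabulary cover inflected forms such as "circulated" without a dedicated entry for each.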

In the style of our MTL tutorial, we'll use a few helpers to load the data into PyTorch datasets and our `MultitaskDataLoader`.

In [9]:
from dataloaders import get_dataset

datasets = []
dataloaders = []
for split in ["train", "valid"]:
    # parse raw data and format it as a PyTorch dataset
    dataset = get_dataset(
        data_dir, task_name, split, tokenizer, max_data_samples, max_sequence_length
    )
    dataloader = MultitaskDataLoader(
        task_to_label_dict={task_name: "labels"},
        dataset=dataset,
        split=split,
        batch_size=batch_size,
        shuffle=(split == "train"),
    )
    datasets.append(dataset)
    dataloaders.append(dataloader)

In [10]:
# Construct the base task for WiC (slice tasks are added later)
base_task = models.model[task_name](bert_model)
tasks = [base_task]
tasks

[Task(name=WiC)]

In [11]:
model = MultitaskModel(
    name="SuperGLUE",
    tasks=tasks,
    dataparallel=False,
    device=-1 # use CPU
)

We've pretrained a model for you, but feel free to uncomment these lines to experiment with training yourself!

In [12]:
# trainer = Trainer(**config)
# trainer.train_model(model, dataloaders)

In [13]:
# If you're missing the model, uncomment this line:
# ! wget https://www.dropbox.com/s/vix9bhzy18o3wjl/vanilla_model.pth?dl=0 && mv vanilla_model.pth?dl=0 vanilla_model.pth

In [14]:
wic_path = "vanilla_model.pth"
model.load(wic_path)

How well do we do on the valid set?

In [15]:
%%time
model.score(dataloaders[1])

CPU times: user 13min 50s, sys: 1.83 s, total: 13min 51s
Wall time: 30.7 s


{'WiC/SuperGLUE/valid/accuracy': 0.7460815047021944}

## Error analysis (specific to the slicing functions we plan to write)

Error analysis is the key to debugging machine learning models. Let's look at a few examples that we get wrong.

In [16]:
%%time
results = model.predict(dataloaders[1], return_preds=True)
golds, preds = results["golds"][task_name], results["preds"][task_name]

CPU times: user 13min 24s, sys: 5.99 s, total: 13min 30s
Wall time: 29.6 s


In [17]:
incorrect_preds = golds != preds
wic_df[incorrect_preds][["sentence1", "sentence2", "word", "label"]].head()

| | sentence1 | sentence2 | word | label |
| --- | --- | --- | --- | --- |
| 1 | Circulate a rumor . | This letter is being circulated among the faculty . | circulate | False |
| 4 | Making a hobby of domesticity . | A royal family living in unpretentious domesticity . | domesticity | False |
| 5 | The child 's acquisition of language . | That graphite tennis racquet is quite an acquisition . | acquisition | False |
| 7 | They swam in the nude . | The marketing rule ' nude sells ' spread from verbal to visual mainstream media in the 20th century . | nude | False |
| 16 | He took the manuscript in both hands and gave it a mighty tear . | There were big tears rolling down Lisa 's cheeks . | tear | False |


We notice that one particular error mode occurs when the target **word** is a _verb_. Let's investigate further...

We view examples where we make the wrong prediction _and_ the target word is a verb.

In [18]:
target_is_verb = wic_df["pos"] == "V"
df_wrong_and_target_is_verb = wic_df[incorrect_preds & target_is_verb]
df_wrong_and_target_is_verb[["sentence1", "sentence2", "word", "pos", "label"]].head()

| | sentence1 | sentence2 | word | pos | label |
| --- | --- | --- | --- | --- | --- |
| 1 | Circulate a rumor . | This letter is being circulated among the faculty . | circulate | V | False |
| 45 | To clutch power . | She clutched her purse . | clutch | V | True |
| 62 | She used to wait down at the Dew Drop Inn . | Wait here until your car arrives . | wait | V | False |
| 78 | Wear gloves so your hands stay warm . | Stay with me , please . | stay | V | True |
| 83 | You need to push quite hard to get this door open . | Nora pushed through the crowd . | push | V | True |


In [19]:
len(df_wrong_and_target_is_verb) / len(wic_df[incorrect_preds])

0.3765432098765432

This error mode accounts for over **37%** of our incorrect predictions! Let's address it with _slicing_.
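
Before moving on, note that the same kind of check works for any categorical column, not just verbs. As a minimal sketch in plain Python, the snippet below computes each part of speech's share of the errors on a handful of made-up rows; the `rows` data and the `error_share_by` helper are hypothetical and not part of the tutorial's codebase.

```python
from collections import Counter

# Made-up examples for illustration: part of speech, gold label, predicted label.
rows = [
    {"pos": "V", "gold": False, "pred": True},
    {"pos": "V", "gold": True, "pred": False},
    {"pos": "N", "gold": True, "pred": True},
    {"pos": "N", "gold": False, "pred": True},
    {"pos": "V", "gold": True, "pred": True},
    {"pos": "N", "gold": False, "pred": False},
]


def error_share_by(rows, key):
    """Return, for each value of `key`, its fraction of all incorrect predictions."""
    errors = [r for r in rows if r["gold"] != r["pred"]]
    if not errors:
        return {}
    counts = Counter(r[key] for r in errors)
    return {value: n / len(errors) for value, n in counts.items()}


shares = error_share_by(rows, "pos")
print(shares)  # V accounts for 2 of the 3 errors, N for 1 of 3
```

A category that holds a large share of the errors, as verbs do here, is a natural candidate for a slice.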

## Write slicing functions
We write slicing functions to target specific subsets of the data that we care about—this could correspond to the examples we find underperforming in an error analysis, or specific subsets that are application critical (e.g. night-time images in a self-driving dataset). Then, we'd like to add slice-specific capacity to our model so that it pays more attention to these examples!

By applying the `@slicing_function()` decorator, we wrap each function so that it has access to previously defined preprocessors, resources, etc., just like with labeling functions!

In [20]:
from snorkel.slicing.sf import slicing_function
from snorkel.types import DataPoint

@slicing_function()
def SF_verb(x: DataPoint) -> int:
    return x.pos == 'V'

slicing_functions = [SF_verb]
slice_names = [sf.name for sf in slicing_functions]
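
Conceptually, a slicing function is just a predicate over a data point, and applying a list of them yields one membership column per slice. The sketch below illustrates that idea in plain Python. It is not Snorkel's implementation: `apply_slicing_functions`, `sf_verb`, `sf_long_word`, and the toy data points are all hypothetical.

```python
from types import SimpleNamespace

# Toy data points standing in for dataframe rows (made up for illustration).
points = [
    SimpleNamespace(word="circulate", pos="V"),
    SimpleNamespace(word="table", pos="N"),
    SimpleNamespace(word="run", pos="V"),
]


def sf_verb(x):
    """Slice: the target word is a verb."""
    return x.pos == "V"


def sf_long_word(x):
    """Slice: the target word has more than five characters."""
    return len(x.word) > 5


def apply_slicing_functions(points, sfs):
    """Return one row per data point and one 0/1 column per slicing function."""
    return [[int(sf(x)) for sf in sfs] for x in points]


S = apply_slicing_functions(points, [sf_verb, sf_long_word])
print(S)  # [[1, 1], [0, 0], [1, 0]]
```

Slices may overlap, as in the first row, since each column is computed independently of the others.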

## Train _slice-aware_ model

Now, let's update our tasks to add _additional capacity_ corresponding to each slice we've specified.
At a high level, `convert_to_slice_tasks` will take an existing task and add additional PyTorch modules in the task flow corresponding to each of the `slice_names` you've provided.

Note that this plays nicely into our MTL abstraction—additional operations in the task flow are specific for each slice!

In [21]:
from snorkel.slicing.utils import convert_to_slice_tasks

slice_tasks = convert_to_slice_tasks(base_task, slice_names)
slice_tasks

[Task(name=WiC_slice:SF_verb_ind),
 Task(name=WiC_slice:base_ind),
 Task(name=WiC_slice:SF_verb_pred),
 Task(name=WiC_slice:base_pred),
 Task(name=WiC)]

We then update our dataloaders so that our labels are set up to appropriately train on these slices.

In [22]:
from snorkel.slicing.apply import PandasSFApplier
from snorkel.slicing.utils import add_slice_labels

slice_dataloaders = []
applier = PandasSFApplier(slicing_functions)

for dl in dataloaders:
    df = task_dataset_to_dataframe(dl.dataset)
    S_matrix = applier.apply(df)
    # updates dataloaders in place
    add_slice_labels(dl, base_task, S_matrix, slice_names)

100%|██████████| 5428/5428 [00:00<00:00, 24523.58it/s]
100%|██████████| 638/638 [00:00<00:00, 41594.89it/s]
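
To make the labeling step more concrete, here is a rough sketch of how per-slice label sets could be derived from the base labels and a slice membership column. This is an assumption about the general approach suggested by the `_ind` and `_pred` task names, not the actual behavior of `add_slice_labels`: the `make_slice_labels` helper is hypothetical, and using `None` to mark examples that should be ignored is an illustrative choice.

```python
def make_slice_labels(base_labels, in_slice):
    """Build indicator and prediction labels for a single slice.

    ind:  1 if the example belongs to the slice, else 0.
    pred: the base label inside the slice, None (ignored) outside it.
    """
    ind = [int(m) for m in in_slice]
    pred = [y if m else None for y, m in zip(base_labels, in_slice)]
    return {"ind": ind, "pred": pred}


# Made-up base labels and verb-slice membership for four examples.
base_labels = [True, False, True, False]
in_verb_slice = [1, 0, 0, 1]

labels = make_slice_labels(base_labels, in_verb_slice)
print(labels["ind"])   # [1, 0, 0, 1]
print(labels["pred"])  # [True, None, None, False]
```

Under this reading, the indicator head learns where the slice is, while the prediction head is only trained on examples inside it.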


We initialize a new _slice-aware model_, and train!

In [23]:
slice_model = MultitaskModel(
    name="SuperGLUE",
    tasks=slice_tasks,
    dataparallel=False,
    device=-1
)

Again, we've provided a pretrained model for you to explore, but you can uncomment the lines below to train one yourself if you'd like.

In [24]:
# trainer = Trainer(**config)
# trainer.train_model(slice_model, dataloaders)

In [25]:
# If you're missing the model, uncomment this line:
# ! wget https://www.dropbox.com/s/h6620vfeompgu9o/slice_model.pth?dl=0 && mv slice_model.pth?dl=0 slice_model.pth

In [26]:
slice_wic_path = "slice_model.pth"
slice_model.load(slice_wic_path)

## Evaluate _slice-aware_ model

In [27]:
%%time 
slice_model.score(dataloaders[1])


CPU times: user 13min 38s, sys: 13.3 s, total: 13min 52s
Wall time: 30.7 s


{'WiC/SuperGLUE/valid/accuracy': 0.7554858934169278,
 'WiC_slice:SF_verb_ind/SuperGLUE/valid/f1': 0.5493230174081237,
 'WiC_slice:SF_verb_pred/SuperGLUE/valid/accuracy': 0.51440329218107,
 'WiC_slice:base_ind/SuperGLUE/valid/f1': 1.0,
 'WiC_slice:base_pred/SuperGLUE/valid/accuracy': 0.7570532915360502}

With some simple error analysis and an interface for specifying which _slice_ of the data we care about, we've improved our model by **0.94 accuracy points** over a previous state-of-the-art model!
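
As a quick sanity check on that figure, the snippet below recomputes the gain from the two validation accuracies reported above and converts them to counts using the 638 validation examples shown in the progress bar output. The variable names are ours.

```python
# Validation accuracies reported by model.score and slice_model.score above.
base_acc = 0.7460815047021944
slice_acc = 0.7554858934169278
n_valid = 638  # size of the valid split, from the progress bar output

gain_points = (slice_acc - base_acc) * 100
base_correct = round(base_acc * n_valid)
slice_correct = round(slice_acc * n_valid)

print(f"gain: {gain_points:.2f} accuracy points")                 # gain: 0.94 accuracy points
print(f"correct: {base_correct} -> {slice_correct} of {n_valid}")  # correct: 476 -> 482 of 638
```

In other words, the slice-aware model gets six more validation examples right than the vanilla model.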