# RTE Demo

In [8]:
import os, sys
from pathlib import Path

if "cwd" not in globals():
    cwd = Path(os.getcwd())
sys.path.insert(0, str(cwd.parents[0]))

**Recognizing Textual Entailment (RTE)** is a natural language inference task that addresses whether a _premise_ sentence entails a _hypothesis_ sentence, with labels `entailment` (do the facts of the premise imply the facts of the hypothesis?) or `not_entailment`.
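
To make the task concrete, here is a minimal sketch of a single RTE example. The sentence pair is invented for illustration and is not drawn from the dataset. The `RTEExample` class is hypothetical; only the `sentence1` (premise) and `sentence2` (hypothesis) field names mirror what the slicing functions later in this notebook read.

```python
from dataclasses import dataclass


@dataclass
class RTEExample:
    """Hypothetical container for one premise/hypothesis pair."""
    sentence1: str  # premise
    sentence2: str  # hypothesis
    label: str      # "entailment" or "not_entailment"


# Invented example: the premise implies the hypothesis.
example = RTEExample(
    sentence1="The cat slept on the sofa all afternoon.",
    sentence2="The cat was on the sofa.",
    label="entailment",
)
print(example.label)
```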

In this notebook, we'll...
* Write over a dozen simple slicing functions, based on an [external error analysis](https://arxiv.org/abs/1904.11544), to initialize our model
* Load existing base-architecture weights from a pretrained model and fine-tune the slice heads

## Load data and task

In [11]:
TASK_NAME = "RTE"
BERT_MODEL = "bert-large-cased"

dataloader_config = {
    "batch_size": 8,
    "data_dir": os.environ.get("SUPERGLUEDATA", os.path.join(str(cwd.parents[0]), "data")),
    "splits": ["train", "valid"],
}

trainer_config = {
    "lr": 2e-5,
    "optimizer": "adamax",
    "n_epochs": 15,
    "counter_unit": "epochs",
    "evaluation_freq": 0.25,
}

In [14]:
from dataloaders import get_dataloaders

dataloaders = get_dataloaders(
    task_name=TASK_NAME,
    tokenizer_name=BERT_MODEL,
    **dataloader_config
)

In [13]:
from superglue_tasks import task_funcs

task = task_funcs[TASK_NAME](BERT_MODEL)

100%|██████████| 1242874899/1242874899 [01:18<00:00, 15835137.96B/s]


## Compare to vanilla model

In [15]:
from snorkel.mtl.model import MultitaskModel
from snorkel.mtl.trainer import Trainer

vanilla_model = MultitaskModel(tasks=[task], device=-1, dataparallel=False)

Uncomment the following to train the vanilla BERT model.

In [16]:
# trainer = Trainer(**trainer_config)
# trainer.train_model(vanilla_model, dataloaders)

Or, use the one we pretrained for you!

In [18]:
# If you're missing the model, uncomment this line:
# ! wget https://www.dropbox.com/s/t2ri9o0iz765hsn/RTE_bert.pth?dl=0 && mv RTE_bert.pth?dl=0 RTE_bert.pth

vanilla_model.load("RTE_bert.pth")

In [19]:
%%time
vanilla_model.score(dataloaders[1])

CPU times: user 3min 49s, sys: 23.8 s, total: 4min 13s
Wall time: 1min 3s


{'RTE/SuperGLUE/valid/accuracy': 0.7364620938628159}

## Apply SFs
We emphasize here that our _data slicing_ abstraction follows an intuitive programmer workflow!

We rely on error buckets defined by [Kim et al. 2019](https://arxiv.org/pdf/1904.11544.pdf) to define our slices.
Then, we apply quick-to-write heuristics to target each of these buckets.
Intuitively, these are slices that were important enough to be measured independently by researchers, so we'd like to write slicing functions to try to improve performance on them!

In [20]:
from snorkel.slicing.sf import slicing_function
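
Stripped of the decorator, a slicing function is just a predicate over one example: it returns `True` when the example belongs to the slice. The sketch below illustrates that idea in plain Python, without Snorkel. The function `slice_contains_not` and both examples are hypothetical and are not among the slices used in this notebook.

```python
from types import SimpleNamespace


def slice_contains_not(example):
    """Hypothetical slice: hypotheses that contain the token 'not'."""
    return "not" in example.sentence2.lower().split()


# Invented examples with the same field names the notebook's slices use.
examples = [
    SimpleNamespace(sentence1="It rained all day.", sentence2="It did not rain."),
    SimpleNamespace(sentence1="It rained all day.", sentence2="The weather was wet."),
]

# One boolean per example: is it in the slice?
membership = [slice_contains_not(ex) for ex in examples]
print(membership)  # [True, False]
```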

### Prepositions

In [21]:
@slicing_function()
def slice_temporal_preposition(example):
    temporal_prepositions = ["after", "before", "past"]
    both_sentences = example.sentence1 + example.sentence2
    return any([p in both_sentences for p in temporal_prepositions])

@slicing_function()
def slice_possessive_preposition(example):
    possessive_prepositions = ["inside of", "with", "within"]
    both_sentences = example.sentence1 + example.sentence2
    return any([p in both_sentences for p in possessive_prepositions])

### Comparatives

In [22]:
@slicing_function()
def slice_is_comparative(example):
    comparative_words = ["more", "less", "better", "worse", "bigger", "smaller"]
    both_sentences = example.sentence1 + example.sentence2
    return any([p in both_sentences for p in comparative_words])

### Quantification

In [23]:
@slicing_function()
def slice_is_quantification(example):
    quantification_words = ["all", "some", "none"]
    both_sentences = example.sentence1 + example.sentence2
    return any([p in both_sentences for p in quantification_words])

### Wh-Words

In [24]:
@slicing_function()
def slice_where(example):
    sentences = example.sentence1 + example.sentence2
    return "where" in sentences

@slicing_function()
def slice_who(example):
    sentences = example.sentence1 + example.sentence2
    return "who" in sentences

@slicing_function()
def slice_what(example):
    sentences = example.sentence1 + example.sentence2
    return "what" in sentences

@slicing_function()
def slice_when(example):
    sentences = example.sentence1 + example.sentence2
    return "when" in sentences

### Coordinating Conjunctions

In [25]:
@slicing_function()
def slice_and(example):
    sentences = example.sentence1 + example.sentence2
    return "and" in sentences

@slicing_function()
def slice_but(example):
    sentences = example.sentence1 + example.sentence2
    return "but" in sentences

@slicing_function()
def slice_or(example):
    sentences = example.sentence1 + example.sentence2
    return "or" in sentences
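
One caveat about the `in` checks above: they are substring tests on the raw concatenation of the two sentences, so `"or" in sentences` is also true for text containing "for" or "work". Because the sentences are joined without a separator, a match can even span the sentence boundary. The sketch below shows a stricter whole-word variant for comparison. It is an alternative under stated assumptions, not the logic behind the results reported in this notebook, and the helper `contains_word` is hypothetical.

```python
import re


def contains_word(text, word):
    """Return True if `word` occurs in `text` as a whole word (case-insensitive)."""
    return re.search(rf"\b{re.escape(word)}\b", text, flags=re.IGNORECASE) is not None


# Invented sentence pair, joined with a space so words cannot fuse together.
premise = "He paid for the work."
hypothesis = "He paid."
joined = premise + " " + hypothesis

print("or" in joined)               # True: substring hit inside "for" and "work"
print(contains_word(joined, "or"))  # False: there is no standalone "or"
```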

### Definite-Indefinite Articles
Multiple occurrences of articles like `the` or `an` might refer to different entities; we try to capture this heuristically here!

In [26]:
@slicing_function()
def slice_multiple_articles(example):
    sentences = example.sentence1 + example.sentence2
    multiple_indefinite = sum([int(x == "a") for x in sentences.split()]) > 1 \
        or sum([int(x == "an") for x in sentences.split()]) > 1
    multiple_definite = sum([int(x == "the") for x in sentences.split()]) > 1
    return multiple_indefinite or multiple_definite
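
The same counting logic can be written more compactly with `collections.Counter`. The sketch below is intended to behave like the cell above on a plain string (whitespace tokens, case-sensitive matching); the function name `has_multiple_articles` is hypothetical. It also makes one property of the heuristic visible: a sentence-initial "The" is not counted as `the`.

```python
from collections import Counter


def has_multiple_articles(text):
    """True if 'a', 'an', or 'the' appears more than once as a whitespace token."""
    counts = Counter(text.split())
    return any(counts[article] > 1 for article in ("a", "an", "the"))


print(has_multiple_articles("the dog chased the cat"))  # True
print(has_multiple_articles("The dog chased the cat"))  # False: "The" != "the"
```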

### (Misc.) Sentence Length

In [27]:
@slicing_function()
def slice_short_hypothesis(example, thresh=5):
    return len(example.sentence2.split()) < thresh

@slicing_function()
def slice_long_hypothesis(example, thresh=15):
    return len(example.sentence2.split()) > thresh

@slicing_function()
def slice_short_premise(example, thresh=10):
    return len(example.sentence1.split()) < thresh

@slicing_function()
def slice_long_premise(example, thresh=100):
    return len(example.sentence1.split()) > thresh

### Add slices to dataloaders and model

In [28]:
slicing_functions = [
    slice_temporal_preposition,
    slice_possessive_preposition,
    slice_is_comparative,
    slice_is_quantification,
    slice_where,
    slice_who,
    slice_what,
    slice_when,
    slice_and,
    slice_or,
    slice_but,
    slice_multiple_articles,
    slice_short_hypothesis,
    slice_long_hypothesis,
    slice_short_premise,
    slice_long_premise
]

In [29]:
slice_names = [sf.name for sf in slicing_functions]

In [30]:
from snorkel.slicing.utils import add_slice_labels, convert_to_slice_tasks

# make slice tasks
slice_tasks = convert_to_slice_tasks(task, slice_names)
slice_tasks

[Task(name=RTE_slice:slice_temporal_preposition_ind),
 Task(name=RTE_slice:slice_possessive_preposition_ind),
 Task(name=RTE_slice:slice_is_comparative_ind),
 Task(name=RTE_slice:slice_is_quantification_ind),
 Task(name=RTE_slice:slice_where_ind),
 Task(name=RTE_slice:slice_who_ind),
 Task(name=RTE_slice:slice_what_ind),
 Task(name=RTE_slice:slice_when_ind),
 Task(name=RTE_slice:slice_and_ind),
 Task(name=RTE_slice:slice_or_ind),
 Task(name=RTE_slice:slice_but_ind),
 Task(name=RTE_slice:slice_multiple_articles_ind),
 Task(name=RTE_slice:slice_short_hypothesis_ind),
 Task(name=RTE_slice:slice_long_hypothesis_ind),
 Task(name=RTE_slice:slice_short_premise_ind),
 Task(name=RTE_slice:slice_long_premise_ind),
 Task(name=RTE_slice:base_ind),
 Task(name=RTE_slice:slice_temporal_preposition_pred),
 Task(name=RTE_slice:slice_possessive_preposition_pred),
 Task(name=RTE_slice:slice_is_comparative_pred),
 Task(name=RTE_slice:slice_is_quantification_pred),
 Task(name=RTE_slice:slice_where_pred),
 

In [31]:
from snorkel.slicing.apply import PandasSFApplier
from snorkel.slicing.utils import add_slice_labels
from utils import task_dataset_to_dataframe

applier = PandasSFApplier(slicing_functions)

# add slice labels
for dl in dataloaders:
    df = task_dataset_to_dataframe(dl.dataset)
    S_matrix = applier.apply(df)
    
    # updates dataloaders in place
    add_slice_labels(dl, task, S_matrix, slice_names)
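
As a rough picture of what the loop above produces, the sketch below builds a small indicator matrix by hand, with one row per example and one column per slicing function. It is a toy stand-in, not Snorkel's `PandasSFApplier`: the exact return type of `applier.apply` is not shown in this notebook, and the examples and the two undecorated functions here are made up for illustration.

```python
from types import SimpleNamespace


def short_hypothesis(example, thresh=5):
    """Toy slice: hypothesis has fewer than `thresh` whitespace tokens."""
    return len(example.sentence2.split()) < thresh


def has_and(example):
    """Toy slice: substring 'and' appears in the concatenated sentences."""
    return "and" in example.sentence1 + example.sentence2


# Invented examples.
examples = [
    SimpleNamespace(sentence1="Tom and Ann left early.", sentence2="Tom left."),
    SimpleNamespace(
        sentence1="The museum opened a new wing in March.",
        sentence2="A new wing opened at the museum this year.",
    ),
]
functions = [short_hypothesis, has_and]

# Rows are examples, columns are slices; 1 means the example is in the slice.
matrix = [[int(sf(ex)) for sf in functions] for ex in examples]
print(matrix)  # [[1, 1], [0, 0]]
```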

100%|██████████| 2490/2490 [00:00<00:00, 3484.15it/s]
100%|██████████| 277/277 [00:00<00:00, 3419.15it/s]


In [32]:
slice_model = MultitaskModel(tasks=slice_tasks, device=-1, dataparallel=False)

## Load from pretrained BERT
Given that the `slice_model` shares the same backbone architecture as the `vanilla_model`, we can simply reload these backbone weights (pretrained on RTE), and then fine-tune the slicing heads!

Load previous backbone weights...

In [33]:
slice_model.load_state_dict(vanilla_model.collect_state_dict())
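
The notebook does not show how `load_state_dict` treats the slice-head parameters that have no counterpart in `vanilla_model`. One common pattern for this kind of warm start is to copy only the entries whose names appear in both models and leave the rest at their initial values. The sketch below illustrates that pattern with plain dictionaries standing in for state dicts. The function `transfer_shared_weights`, the parameter names, and the values are all hypothetical.

```python
def transfer_shared_weights(source_state, target_state):
    """Copy entries present in both states from source into a copy of target.

    Returns the merged state and the list of names that were transferred.
    """
    merged = dict(target_state)
    shared = [name for name in source_state if name in target_state]
    for name in shared:
        merged[name] = source_state[name]
    return merged, shared


# Hypothetical parameter names; lists stand in for weight tensors.
vanilla_state = {"bert.layer0.weight": [0.1, 0.2], "RTE_head.weight": [0.3]}
slice_state = {
    "bert.layer0.weight": [0.0, 0.0],
    "RTE_head.weight": [0.0],
    "slice_head.weight": [0.5],  # exists only in the slice model
}

merged, shared = transfer_shared_weights(vanilla_state, slice_state)
print(shared)  # ['bert.layer0.weight', 'RTE_head.weight']
print(merged["slice_head.weight"])  # [0.5], untouched by the transfer
```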

And fine-tune the slice heads!

In [34]:
# trainer = Trainer(**trainer_config)
# trainer.train_model(slice_model, dataloaders)

Or load our pretrained model.

In [35]:
# If you're missing the model, uncomment this line:
# ! wget https://www.dropbox.com/s/18ta5z3tzasba0m/RTE_slice_from_bert.pth?dl=0 && mv RTE_slice_from_bert.pth?dl=0 RTE_slice_from_bert.pth

slice_model.load("RTE_slice_from_bert.pth")

In [36]:
%%time
slice_model.score(dataloaders[1])


CPU times: user 4min 1s, sys: 21 s, total: 4min 22s
Wall time: 1min 5s


{'RTE/SuperGLUE/valid/accuracy': 0.7581227436823105,
 'RTE_slice:slice_temporal_preposition_ind/SuperGLUE/valid/f1': 0.1951219512195122,
 'RTE_slice:slice_temporal_preposition_pred/SuperGLUE/valid/accuracy': 0.8333333333333334,
 'RTE_slice:slice_possessive_preposition_ind/SuperGLUE/valid/f1': 0.29885057471264365,
 'RTE_slice:slice_possessive_preposition_pred/SuperGLUE/valid/accuracy': 0.696969696969697,
 'RTE_slice:slice_is_comparative_ind/SuperGLUE/valid/f1': 0,
 'RTE_slice:slice_is_comparative_pred/SuperGLUE/valid/accuracy': 0.7096774193548387,
 'RTE_slice:slice_is_quantification_ind/SuperGLUE/valid/f1': 0.29059829059829057,
 'RTE_slice:slice_is_quantification_pred/SuperGLUE/valid/accuracy': 0.6818181818181818,
 'RTE_slice:slice_where_ind/SuperGLUE/valid/f1': 0,
 'RTE_slice:slice_where_pred/SuperGLUE/valid/accuracy': 0.7777777777777778,
 'RTE_slice:slice_who_ind/SuperGLUE/valid/f1': 0.24,
 'RTE_slice:slice_who_pred/SuperGLUE/valid/accuracy': 0.7222222222222222,
 'RTE_slice:slice_what

By specifying all of these slicing functions to the model (a shotgun approach of sorts), we see an overall improvement of **+2.2 accuracy points** on the validation set, from 73.6% to 75.8%!
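
The gain can be checked directly from the two scores printed earlier. The sketch below also converts it into a count of examples, assuming the validation split holds 277 examples, as the second progress bar in the slice-labeling step suggests.

```python
# Accuracies as reported by vanilla_model.score and slice_model.score above.
vanilla_acc = 0.7364620938628159
slice_acc = 0.7581227436823105
n_valid = 277  # assumed size of the validation split (from the progress bar)

gain_points = (slice_acc - vanilla_acc) * 100
extra_correct = round(slice_acc * n_valid) - round(vanilla_acc * n_valid)

print(f"{gain_points:.1f} points, {extra_correct} more examples correct")
# 2.2 points, 6 more examples correct
```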