In [1]:
from oleace.utils.concept import *
import numpy as np
from datasets import load_dataset, concatenate_datasets, ClassLabel
from importlib import reload

from oleace.datasets import hans
HANSDataset = hans.HANSDataset

from collections import Counter

from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
!pip install -e .

In [2]:
# Get the MNLI dataset and tokenize it
mnli_dataset = load_dataset("multi_nli", split="train")
mnli_dataset = tokenize_mnli(mnli_dataset)

concept_detectors = {
    "constituent": is_constituent,
    "lexical_overlap": is_lexical_overlap,
    "subsequence": is_subsequence,
}

# Make a list of indices for each concept
concept_indices: dict[str, list[int]] = defaultdict(list)
for idx, example_data in enumerate(
    tqdm(
        mnli_dataset,
        desc="Assigning concepts to MNLI examples",
        unit=" examples",
        leave=False,
    )
):
    no_concept_assigned = True
    for concept, detector in concept_detectors.items():
        if detector(example_data):
            concept_indices[concept].append(idx)
            no_concept_assigned = False
    if no_concept_assigned:
        concept_indices["negative"].append(idx)

Downloading readme: 100%|██████████| 8.89k/8.89k [00:00<00:00, 40.8MB/s]
Downloading data: 100%|██████████| 214M/214M [00:12<00:00, 16.6MB/s] 
Downloading data: 100%|██████████| 4.94M/4.94M [00:00<00:00, 13.4MB/s]
Downloading data: 100%|██████████| 5.10M/5.10M [00:00<00:00, 16.5MB/s]
Generating train split: 100%|██████████| 392702/392702 [00:00<00:00, 531673.96 examples/s]
Generating validation_matched split: 100%|██████████| 9815/9815 [00:00<00:00, 539196.24 examples/s]
Generating validation_mismatched split: 100%|██████████| 9832/9832 [00:00<00:00, 635933.77 examples/s]
tokenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<00:00, 193kB/s]
config.json: 100%|██████████| 570/570 [00:00<00:00, 4.52MB/s]
vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 3.94MB/s]
tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 12.9MB/s]
Map: 100%|██████████| 392702/392702 [00:33<00:00, 11860.88 examples/s]
                                                                                      

In [3]:
for key in concept_indices:
    idc = concept_indices[key]
    print(key)
    labs = [mnli_dataset[i]['label'] for i in tqdm(idc)]
    print(np.bincount(labs))

negative


100%|██████████| 389745/389745 [01:08<00:00, 5659.81it/s]


[128241 130795 130709]
lexical_overlap


100%|██████████| 2419/2419 [00:00<00:00, 5509.20it/s]


[2158   82  179]
subsequence


100%|██████████| 1346/1346 [00:00<00:00, 5563.00it/s]


[1274   39   33]
constituent


100%|██████████| 1062/1062 [00:00<00:00, 5644.91it/s]

[1004   36   22]





In [13]:
hans_train = HANSDataset(split="train", data_dir=Path.cwd() / "data/hans")

Processing HANS dataset:   0%|          | 0/30000 [00:00<?, ?examples/s]

                                                                                      

In [21]:
label_counts = Counter()
for example in hans_train:
    label = (example['labels'], example['heuristic'])
    label_counts[label] += 1

for label, count in label_counts.items():
    print(f"{label}: {count}")


(1, 'lexical_overlap'): 5000
(0, 'lexical_overlap'): 5000
(1, 'subsequence'): 5000
(0, 'subsequence'): 5000
(1, 'constituent'): 5000
(0, 'constituent'): 5000


In [76]:
hans_train_dataset[0]

{'premise': 'The doctors supported the scientist .',
 'hypothesis': 'The scientist supported the doctors .',
 'label': 1,
 'parse_premise': '(ROOT (S (NP (DT The) (NNS doctors)) (VP (VBD supported) (NP (DT the) (NN scientist))) (. .)))',
 'parse_hypothesis': '(ROOT (S (NP (DT The) (NN scientist)) (VP (VBD supported) (NP (DT the) (NNS doctors))) (. .)))',
 'binary_parse_premise': '( ( The doctors ) ( ( supported ( the scientist ) ) . ) )',
 'binary_parse_hypothesis': '( ( The scientist ) ( ( supported ( the doctors ) ) . ) )',
 'heuristic': 'lexical_overlap',
 'subcase': 'ln_subject/object_swap',
 'template': 'temp1'}

In [None]:
# Prepare HANS dataset
hans_train_dataset = load_dataset("hans", split="train")
hans_val_dataset = load_dataset("hans", split="validation")

# Prepare MNLI dataset
mnli_train_dataset = load_dataset("multi_nli", split="train")
mnli_val_dataset = load_dataset("multi_nli", split="validation_matched")

# Tokenize HANS dataset
hans_train_dataset = tokenize_mnli(hans_train_dataset)
hans_val_dataset = tokenize_mnli(hans_val_dataset)

# Tokenize MNLI dataset
mnli_train_dataset = tokenize_mnli(mnli_train_dataset)
mnli_val_dataset = tokenize_mnli(mnli_val_dataset)

dataset = "hansmnli"

match dataset:
    case "mnli":
        train_dataset = mnli_train_dataset
        val_dataset = mnli_val_dataset
    case "hansmnli":
        # shuffle HANS
        hans_train_dataset = hans_train_dataset.shuffle(seed=42)
        hans_val_dataset = hans_val_dataset.shuffle(seed=42)
        # get first third of HANS examples with label 1
        hans_train_noentail = hans_train_dataset.filter(lambda example: example["label"] == 1)
        hans_val_noentail = hans_val_dataset.filter(lambda example: example["label"] == 1)
        hans_train_noentail = hans_train_noentail.select(range(len(hans_train_noentail) // 3))
        hans_val_noentail = hans_val_noentail.select(range(len(hans_val_noentail) // 3))
        # get all HANS examples with label 0
        hans_train_entail = hans_train_dataset.filter(lambda example: example["label"] == 0)
        hans_val_entail = hans_val_dataset.filter(lambda example: example["label"] == 0)
        # concatenate HANS examples
        hans_train_dataset = concatenate_datasets([hans_train_noentail, hans_train_entail])
        hans_val_dataset = concatenate_datasets([hans_val_noentail, hans_val_entail])

        # train: cut down MNLI to match HANS
        mnli_train_dataset = mnli_train_dataset.shuffle(seed=42)
        mnli_train_dataset = mnli_train_dataset.select(range(len(hans_train_dataset)))
        # take MNLI examples with no applicable heuristic
        mnli_train_dataset = mnli_train_dataset.filter(no_heuristic)
        mnli_val_dataset = mnli_val_dataset.filter(no_heuristic)
        # val: cut down HANS to match MNLI
        hans_val_dataset = hans_val_dataset.shuffle(seed=42)
        hans_val_dataset = hans_val_dataset.select(range(len(mnli_val_dataset)))


        # rename HANS 'binary_parse_premise' to 'premise_binary_parse'
        # and 'binary_parse_hypothesis' to 'hypothesis_binary_parse'
        hans_train_dataset = hans_train_dataset.rename_column("binary_parse_premise", "premise_binary_parse")
        hans_train_dataset = hans_train_dataset.rename_column("binary_parse_hypothesis", "hypothesis_binary_parse")
        hans_val_dataset = hans_val_dataset.rename_column("binary_parse_premise", "premise_binary_parse")
        hans_val_dataset = hans_val_dataset.rename_column("binary_parse_hypothesis", "hypothesis_binary_parse")


        # relabel MNLI 'label' 1 and 2 both to 1
        mnli_train_dataset = mnli_train_dataset.map(lambda example: {"label": 1 if example["label"] > 0 else 0})
        mnli_val_dataset = mnli_val_dataset.map(lambda example: {"label": 1 if example["label"] > 0 else 0})
        # change MNLI 'label' from ClassLabel(names=['entailment', 'neutral', 'contradiction']) 
        # to ClassLabel(names=['entailment', 'non-entailment'])
        mnli_features = mnli_train_dataset.features.copy()
        mnli_features["label"] = ClassLabel(names=["entailment", "non-entailment"])
        mnli_train_dataset = mnli_train_dataset.cast(mnli_features)
        mnli_val_dataset = mnli_val_dataset.cast(mnli_features)

        # concatenate HANS and MNLI examples
        hans_columns = set(hans_train_dataset.column_names)
        mnli_columns = set(mnli_train_dataset.column_names)
        hans_only_columns = hans_columns - mnli_columns
        mnli_only_columns = mnli_columns - hans_columns
        hans_train_dataset = hans_train_dataset.remove_columns(list(hans_only_columns))
        mnli_train_dataset = mnli_train_dataset.remove_columns(list(mnli_only_columns))
        hans_val_dataset = hans_val_dataset.remove_columns(list(hans_only_columns))
        mnli_val_dataset = mnli_val_dataset.remove_columns(list(mnli_only_columns))
        train_dataset = concatenate_datasets([hans_train_dataset, mnli_train_dataset])
        val_dataset = concatenate_datasets([hans_val_dataset, mnli_val_dataset])

In [10]:
# get MNLI column types
load_dataset("hans", split="train").features

{'premise': Value(dtype='string', id=None),
 'hypothesis': Value(dtype='string', id=None),
 'label': ClassLabel(names=['entailment', 'non-entailment'], id=None),
 'parse_premise': Value(dtype='string', id=None),
 'parse_hypothesis': Value(dtype='string', id=None),
 'binary_parse_premise': Value(dtype='string', id=None),
 'binary_parse_hypothesis': Value(dtype='string', id=None),
 'heuristic': Value(dtype='string', id=None),
 'subcase': Value(dtype='string', id=None),
 'template': Value(dtype='string', id=None)}

In [11]:

load_dataset("multi_nli", split="train").features

{'promptID': Value(dtype='int32', id=None),
 'pairID': Value(dtype='string', id=None),
 'premise': Value(dtype='string', id=None),
 'premise_binary_parse': Value(dtype='string', id=None),
 'premise_parse': Value(dtype='string', id=None),
 'hypothesis': Value(dtype='string', id=None),
 'hypothesis_binary_parse': Value(dtype='string', id=None),
 'hypothesis_parse': Value(dtype='string', id=None),
 'genre': Value(dtype='string', id=None),
 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}

In [6]:
mnli_features = mnli_train_dataset.features.copy()
mnli_features["label"] = ClassLabel(names=["entailment", "non-entailment"])
mnli_train_dataset = mnli_train_dataset.cast(mnli_features)
mnli_val_dataset = mnli_val_dataset.cast(mnli_features)

{'premise': Value(dtype='string', id=None),
 'hypothesis': Value(dtype='string', id=None),
 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}