In [1]:
from oleace.utils.concept import *
import numpy as np
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Get the MNLI dataset and tokenize it
mnli_dataset = load_dataset("multi_nli", split="train")
mnli_dataset = tokenize_mnli(mnli_dataset)

concept_detectors = {
    "constituent": is_constituent,
    "lexical_overlap": is_lexical_overlap,
    "subsequence": is_subsequence,
}

# Make a list of indices for each concept
concept_indices: dict[str, list[int]] = defaultdict(list)
for idx, example_data in enumerate(
    tqdm(
        mnli_dataset,
        desc="Assigning concepts to MNLI examples",
        unit=" examples",
        leave=False,
    )
):
    no_concept_assigned = True
    for concept, detector in concept_detectors.items():
        if detector(example_data):
            concept_indices[concept].append(idx)
            no_concept_assigned = False
    if no_concept_assigned:
        concept_indices["negative"].append(idx)

Downloading readme: 100%|██████████| 8.89k/8.89k [00:00<00:00, 40.8MB/s]
Downloading data: 100%|██████████| 214M/214M [00:12<00:00, 16.6MB/s] 
Downloading data: 100%|██████████| 4.94M/4.94M [00:00<00:00, 13.4MB/s]
Downloading data: 100%|██████████| 5.10M/5.10M [00:00<00:00, 16.5MB/s]
Generating train split: 100%|██████████| 392702/392702 [00:00<00:00, 531673.96 examples/s]
Generating validation_matched split: 100%|██████████| 9815/9815 [00:00<00:00, 539196.24 examples/s]
Generating validation_mismatched split: 100%|██████████| 9832/9832 [00:00<00:00, 635933.77 examples/s]
tokenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<00:00, 193kB/s]
config.json: 100%|██████████| 570/570 [00:00<00:00, 4.52MB/s]
vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 3.94MB/s]
tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 12.9MB/s]
Map: 100%|██████████| 392702/392702 [00:33<00:00, 11860.88 examples/s]
                                                                                      

In [3]:
for key in concept_indices:
    idc = concept_indices[key]
    print(key)
    labs = [mnli_dataset[i]['label'] for i in tqdm(idc)]
    print(np.bincount(labs))

negative


100%|██████████| 389745/389745 [01:08<00:00, 5659.81it/s]


[128241 130795 130709]
lexical_overlap


100%|██████████| 2419/2419 [00:00<00:00, 5509.20it/s]


[2158   82  179]
subsequence


100%|██████████| 1346/1346 [00:00<00:00, 5563.00it/s]


[1274   39   33]
constituent


100%|██████████| 1062/1062 [00:00<00:00, 5644.91it/s]

[1004   36   22]





In [18]:

from oleace.datasets.hans import HANSDataset
from collections import Counter

from pathlib import Path

In [13]:
hans_train = HANSDataset(split="train", data_dir=Path.cwd() / "data/hans")

Processing HANS dataset:   0%|          | 0/30000 [00:00<?, ?examples/s]

                                                                                      

In [21]:
label_counts = Counter()
for example in hans_train:
    label = (example['labels'], example['heuristic'])
    label_counts[label] += 1

for label, count in label_counts.items():
    print(f"{label}: {count}")


(1, 'lexical_overlap'): 5000
(0, 'lexical_overlap'): 5000
(1, 'subsequence'): 5000
(0, 'subsequence'): 5000
(1, 'constituent'): 5000
(0, 'constituent'): 5000


In [23]:
2158/(179+82)

8.268199233716475