## Imports

In [105]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets
from datasets import load_dataset_builder, load_dataset

from utils import replace_text_with_whitespace

In [106]:
# Relative path to the ELK directory; the sys.path entries added below are built from it.
ELK_PATH = Path("../../../elk/")
# Show the absolute location as a quick visual sanity check.
ELK_PATH.resolve()

PosixPath('/fsx/home-augustas/elk')

In [107]:
# Directories to put on the import path: the ELK directory itself and the
# promptsource directory inside it (presumably what makes `templates`
# importable in the next cell — TODO confirm).
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Compare the *resolved* path, because that is the string that gets
    # inserted. Testing the relative string never matches an absolute entry,
    # so re-running the cell used to prepend duplicates every time.
    resolved = str(module.resolve())
    if resolved not in sys.path:
        sys.path.insert(0, resolved)

# Each insert goes to the front, so the last module listed ends up first.
sys.path[:2]

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk']

In [108]:
from templates import DatasetTemplates

In [109]:
# Raise the datasets library's log threshold to "error": its info/warning
# messages are hidden in the cells below, errors are still reported.
import datasets

datasets.logging.set_verbosity_error()

## Config

In [110]:
# all datasets used in Burns et al. (2022) (apart from story_cloze)
# BURNS_DATASETS = [
#     "ag_news",
#     "amazon_polarity",
#     "dbpedia_14",
#     "glue:qnli",
#     "imdb",
#     "piqa",
#     "super_glue:boolq",
#     "super_glue:copa",
#     "super_glue:rte",
# ]
# BURNS_DATASETS = [
#     "ag_news",
#     "amazon_polarity",
#     "dbpedia_14",
#     "glue:qnli",
#     "imdb",
#     "super_glue:boolq",
#     "super_glue:copa",
#     "super_glue:rte",
# ]
# BURNS_DATASETS = [
#     "ag_news",
#     "amazon_polarity",
#     "dbpedia_14",
#     "imdb",
#     "super_glue:boolq",
#     "super_glue:copa",
# ]
# Datasets to process, written as "path" or "path:config".
BURNS_DATASETS = ["imdb"]

# VERSION is used in the output file names; it evaluates to e.g. "v15".
version_number = 5
VERSION = f"v1{version_number}"

# SPLIT = "train"
SPLIT = "validation"

# These numbers are chosen so that both datasets have
# approximately 10k examples in total (probably a bit less for train split)
# N_PER_DATASET = 1000 if SPLIT == "train" else 1000
# N_PER_DATASET = 1200 if SPLIT == "train" else 1375
# N_PER_DATASET = 1200 if SPLIT == "train" else 1550
N_PER_DATASET = 25000 if SPLIT == "train" else 1500

# Upper bound on the per-dataset sample size for the selected split
# (previously 2490 for train). The assertion message is built from the same
# value so it can no longer disagree with the condition being checked.
MAX_N_PER_DATASET = 25000 if SPLIT == "train" else 1838
assert N_PER_DATASET <= MAX_N_PER_DATASET, (
    f"N_PER_DATASET must be <= {MAX_N_PER_DATASET}"
)

# The non-train split is seeded with the version number, so every version
# gets a different shuffle / template assignment.
# SEED = 42 if SPLIT == "train" else 2023
SEED = 42 if SPLIT == "train" else version_number
SEED

5

## Load and inspect the datasets

In [111]:
# Maps "<path>" or "<path>/<config>" to the shuffled, size-capped dataset.
dataset_dict = {}
for dataset_path in BURNS_DATASETS:
    print(dataset_path)

    # Entries are written either as "path" or as "path:config"
    dataset_name = None
    if ":" in dataset_path:
        dataset_path, dataset_name = dataset_path.split(":")

    # Choose the split: "train" when requested, otherwise "validation" if the
    # dataset provides one, falling back to "test"
    builder = load_dataset_builder(dataset_path, name=dataset_name)
    available_splits = builder.info.splits.keys()
    if SPLIT == "train":
        split = "train"
    elif "validation" in available_splits:
        split = "validation"
    else:
        split = "test"
    print(split)

    dataset = load_dataset(dataset_path, name=dataset_name, split=split)

    # Keep at most N_PER_DATASET rows, drawn after a seeded shuffle
    n = min(N_PER_DATASET, dataset.num_rows)
    dataset = dataset.shuffle(seed=SEED).select(range(n))

    key = dataset_path if not dataset_name else f"{dataset_path}/{dataset_name}"
    dataset_dict[key] = dataset

    print("-----------------------------------")

imdb


test
-----------------------------------


In [112]:
# Total number of examples across all loaded datasets
sum(ds.num_rows for ds in dataset_dict.values())

1500

In [113]:
# Every sampled dataset must contain more than one distinct label value
assert all(
    len(Counter(ds["label"])) > 1
    for ds in dataset_dict.values()
)

In [114]:
# Number of rows kept per dataset
for dataset_name, dataset in dataset_dict.items():
    print(dataset_name, dataset.num_rows, sep=": ")

imdb: 1500


In [115]:
# Number of distinct label values present in each sampled dataset
for dataset_name, dataset in dataset_dict.items():
    print(dataset_name, len(set(dataset["label"])), sep=": ")

imdb: 2


## Get the templates

In [116]:
# Maps each dataset key to its DatasetTemplates, restricted to usable templates.
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    dataset_templates = DatasetTemplates(dataset_path)

    # Probe every template with the first example and keep only those for
    # which get_answer_choices_list returns something other than None.
    first_example = dataset_dict[dataset_path][0]
    usable_templates = {}
    for template in dataset_templates.templates.values():
        if template.get_answer_choices_list(first_example) is not None:
            usable_templates[template.name] = template
    dataset_templates.templates = usable_templates

    dataset_template_dict[dataset_path] = dataset_templates

In [117]:
# Number of templates that survived the filtering, per dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    print(dataset_name, len(dataset_templates.templates), sep=": ")

imdb: 13


In [118]:
# dataset_name = "ag_news"
# for idx, (template_name, template) in enumerate(dataset_template_dict[dataset_name].templates.items()):
#     print(f"{idx+1}) {dataset_name}/{template_name}: {template.metadata.choices_in_prompt}")


In [119]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     good_templates = [x for x in dataset_templates.templates.values() if x.metadata.choices_in_prompt]
#     print(f"{dataset_name}: {len(good_templates)}")

## Inspect the prompts

In [120]:
# dataset_name = "super_glue/rte"
# dataset_name = "glue/qnli"

# for template_name, template in dataset_template_dict[dataset_name].templates.items():
#     # print(template_name)
#     # print(dataset_dict[dataset_name][0])
#     q, a = template.apply(
#         dataset_dict[dataset_name][1]
#     )
#     # print(q == q.strip())
#     # print(a == a.strip())
#     # print(" ".join([q, a.strip()]))
#     print(" ".join([q, a]))
#     print(len(a))
#     print("---------------------------------")


## Form the dataset for the chosen split

In [121]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_fixed_answer_choices_list()}")

#     print("---------------------------------")


In [122]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_answer_choices_list(dataset_dict[dataset_name][0])}")

#     print("---------------------------------")

In [123]:
# Re-check the template counts right before building the new dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    print(dataset_name, len(dataset_templates.templates), sep=": ")

imdb: 13


In [124]:
%%time

# Build the truthful/untruthful dataset: every source example is rendered
# through a randomly chosen template, either with its real label or with a
# wrong one, and the stored "label" becomes 1 (truthful) or 0 (untruthful).

# Reproducibility
random.seed(SEED)

# Only these keys are kept on each output entry
ALLOWED_KEYS = ["text", "label", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # Loop-invariant lookups, hoisted out of the per-example loop
    templates = dataset_template_dict[dataset_name].templates
    template_names = list(templates.keys())
    label_feature = dataset.features["label"]  # [class 1, class 2, ...]
    is_multiclass = len(label_feature.names) > 2
    all_label_ids = (
        [label_feature.str2int(x) for x in label_feature.names]  # [0, 1, ...]
        if is_multiclass
        else []
    )

    for entry in dataset:
        new_entry = entry.copy()

        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(template_names)
        template = templates[template_name]
        new_entry["template_name"] = template_name

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])

        if is_truthful:
            new_text = " ".join(template.apply(new_entry))
        elif is_multiclass:
            # Untruthful multi-class case: sample uniformly among the labels
            # that differ from the true one. The previous formula,
            # (1 - label) % n, was deterministic and returned the *true*
            # label whenever 2 * label == 1 (mod n), e.g. label 2 of 3 classes.
            # NOTE: this draws one extra random number per such example, so
            # multi-class outputs differ from earlier runs with the same seed;
            # binary datasets are unaffected.
            wrong_label_ids = [i for i in all_label_ids if i != new_entry["label"]]
            incorrect_label_id = random.choice(wrong_label_ids)
            incorrect_label = template.get_fixed_answer_choices_list()[incorrect_label_id]

            q, a = template.apply(new_entry)
            # Presumably swaps the wrong label into the rendered answer while
            # keeping its surrounding whitespace — TODO confirm in utils.
            incorrect_label = replace_text_with_whitespace(a, incorrect_label)
            new_text = " ".join([q, incorrect_label])
        else:
            # Untruthful binary case: flip the label before rendering
            new_entry["label"] = 1 - new_entry["label"]
            new_text = " ".join(template.apply(new_entry))

        new_entry["text"] = new_text

        # We can now change the label to whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        # Remove all other keys
        new_entry = {k: v for k, v in new_entry.items() if k in ALLOWED_KEYS}

        # Append to the new dataset
        new_dataset.append(new_entry)

imdb
CPU times: user 2.3 s, sys: 16.6 ms, total: 2.32 s
Wall time: 2.55 s


In [125]:
# Turn the list of entry dicts into a datasets.Dataset and display its summary
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['text', 'label', 'original_dataset', 'template_name'],
    num_rows: 1500
})

In [126]:
# Index of the entry inspected in this cell and the next one
current_idx = 0
my_dataset[current_idx]

{'text': 'At the very beginning, the look at a control panel that reads "8 miles of the cost of California", and no, I didn\'t misspell that, they really did not realize the put of the cost instead of off the coast. These people must have been morons.<br /><br />It\'s good if you\'re into terrible movies, but the sheer fact they couldn\'t catch a simple spelling issue make me believe they really didn\'t put any effort into creating the movie whatsoever. The Navy uniforms are not correct at all in any manner whatsoever.<br /><br />Wow, completely ridiculous, but good if you are looking for something insanely stupid to watch. How these folks made any money off this is beyond me.\nWhat is the sentiment expressed by the reviewer for the movie?\n positive',
 'label': 0,
 'original_dataset': 'imdb',
 'template_name': 'Reviewer Expressed Sentiment'}

In [127]:
# Opening words of the inspected entry, apparently noted down on earlier runs
# (the last one matches the current output — TODO confirm what the others
# correspond to):
# It's unbelievable but the fourth is better
# The following movie review expresses what
# I was going to use 'The German Scream' as a summary
# This movie is by far one of the worst B-movies
# At the very beginning, the look at a control
print(my_dataset[current_idx]["text"])

At the very beginning, the look at a control panel that reads "8 miles of the cost of California", and no, I didn't misspell that, they really did not realize the put of the cost instead of off the coast. These people must have been morons.<br /><br />It's good if you're into terrible movies, but the sheer fact they couldn't catch a simple spelling issue make me believe they really didn't put any effort into creating the movie whatsoever. The Navy uniforms are not correct at all in any manner whatsoever.<br /><br />Wow, completely ridiculous, but good if you are looking for something insanely stupid to watch. How these folks made any money off this is beyond me.
What is the sentiment expressed by the reviewer for the movie?
 positive


In [128]:
# Balance of truthful (1) vs. untruthful (0) entries, and entries per source
# dataset. The commented lines look like label counts noted on earlier runs
# (TODO confirm):
# Counter({0: 755, 1: 745})
# Counter({1: 750, 0: 750})
# Counter({0: 774, 1: 726})
# Counter({0: 752, 1: 748})
# Counter({0: 765, 1: 735})
Counter(my_dataset["label"]), Counter(my_dataset["original_dataset"])

(Counter({0: 765, 1: 735}), Counter({'imdb': 1500}))

In [129]:
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_{SPLIT}_{VERSION}.parquet")
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_imdb_{SPLIT}_{VERSION}.parquet")

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

1881013

In [130]:
# List the entries in datasets/ whose names contain the current VERSION tag
!ls -lah datasets | grep {VERSION}

drwxr-xr-x  3 augustas Domain Users  25K Jul  4 22:35 burns_datasets_VINC_imdb_train_raw_v15
-rw-r--r--  1 augustas Domain Users 7.4M Jul  4 22:35 burns_datasets_VINC_imdb_train_v15.parquet
-rw-r--r--  1 augustas Domain Users 1.2M Jul  4 22:44 burns_datasets_VINC_imdb_validation_v15.parquet


In [63]:
# parquet_data_files = {
#     "train": "datasets/my_dataset_train.parquet",
#     "validation": "datasets/my_dataset_validation.parquet",
# }

# my_dataset = load_dataset("parquet", data_files=parquet_data_files)
# my_dataset