## Imports

In [1]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets
from datasets import load_dataset_builder, load_dataset

from utils import replace_text_with_whitespace

In [2]:
# Location of the local `elk` checkout, given relative to the current working directory
ELK_PATH = Path("../../../elk/")
# Display the absolute form to confirm the relative path points where expected
ELK_PATH.resolve()

PosixPath('/fsx/home-augustas/elk')

In [3]:
# Directories that need to be importable: the elk checkout itself and the
# promptsource directory inside it (presumably where the top-level
# `templates` module imported later lives -- confirm against the elk repo).
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Compare the *resolved* path, because that is the string that gets
    # inserted. Checking the unresolved relative path would never match and
    # every re-run of this cell would add duplicate entries to sys.path.
    module_path = str(module.resolve())
    if module_path not in sys.path:
        sys.path.insert(0, module_path)

# Show the two entries that now sit at the front of the search path
sys.path[:2]

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk']

In [4]:
from templates import DatasetTemplates

In [5]:
# Silence the `datasets` library: restrict its logger to error-level messages.
# NOTE(review): `datasets` is already imported in the first cell, so this
# re-import is redundant (harmless, but imports belong in the imports cell).
import datasets

datasets.logging.set_verbosity_error()

## Config

In [9]:
# Datasets to convert. An entry is either "path" or "path:config_name".
BURNS_DATASETS = ["imdb"]

# Version tag used in the name of the saved dataset file
VERSION = "v3"

# Which split to build; switch by (un)commenting one of the two lines
# SPLIT = "train"
SPLIT = "validation"

# Upper bound on the number of samples taken from each dataset
N_PER_DATASET = 25000 if SPLIT == "train" else 1500

# The train and evaluation splits use different seeds
SEED = 42 if SPLIT == "train" else 2023
SEED

2023

## Load and inspect the datasets

In [10]:
dataset_dict = {}
for dataset_path in BURNS_DATASETS:
    print(dataset_path)

    # An entry written as "path:name" carries a config name after the colon;
    # a bare "path" has none.
    dataset_name = None
    if ":" in dataset_path:
        dataset_path, dataset_name = dataset_path.split(":")

    # Ask the dataset builder which splits exist
    builder = load_dataset_builder(dataset_path, name=dataset_name)
    available_splits = builder.info.splits.keys()

    # "train" is used as requested; any other SPLIT maps to "validation"
    # when the dataset has one and to "test" otherwise.
    if SPLIT == "train":
        split = "train"
    elif "validation" in available_splits:
        split = "validation"
    else:
        split = "test"
    print(split)

    # Load the chosen split
    dataset = load_dataset(dataset_path, name=dataset_name, split=split)

    # Shuffle with the fixed seed and keep at most N_PER_DATASET rows
    n = min(N_PER_DATASET, dataset.num_rows)
    dataset = dataset.shuffle(seed=SEED).select(range(n))

    # Datasets with a config name are stored under "path/name"
    key = dataset_path if not dataset_name else f"{dataset_path}/{dataset_name}"
    dataset_dict[key] = dataset

    print("-----------------------------------")

imdb


test
-----------------------------------


In [11]:
# Total number of rows kept across all loaded datasets
sum(subset.num_rows for subset in dataset_dict.values())

1500

In [12]:
# Sanity check: every sampled subset must contain more than one distinct label.
# Collect the offending dataset names so a failure says which ones are affected.
single_label_datasets = [
    name
    for name, subset in dataset_dict.items()
    if len(Counter(subset["label"])) <= 1
]
assert not single_label_datasets, (
    f"Datasets with a single label value after sampling: {single_label_datasets}"
)

In [13]:
# Show how many rows were kept for each dataset
for dataset_name, dataset in dataset_dict.items():
    row_count = dataset.num_rows
    print(f"{dataset_name}: {row_count}")

imdb: 1500


In [14]:
# Show how many distinct label values each dataset contains
for dataset_name, dataset in dataset_dict.items():
    n_label_values = len(Counter(dataset["label"]))
    print(f"{dataset_name}: {n_label_values}")

imdb: 2


## Get the templates

In [16]:
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    # Templates are looked up under "<dataset key>/custom"
    dataset_templates = DatasetTemplates(dataset_path + "/custom")

    # The first row is used to probe every template; fetch it once instead of
    # re-reading it from the dataset for each template.
    example = dataset_dict[dataset_path][0]

    # Keep only the templates that produce answer choices for the example,
    # and key the surviving templates by their name.
    dataset_templates.templates = {
        template.name: template
        for template in dataset_templates.templates.values()
        if template.get_answer_choices_list(example) is not None
    }

    dataset_template_dict[dataset_path] = dataset_templates

In [19]:
# Report how many usable templates remain for each dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    n_templates = len(dataset_templates.templates)
    print(f"{dataset_name}: {n_templates}")

imdb: 5


## Inspect the prompts

In [19]:
# dataset_name = "super_glue/rte"
# dataset_name = "glue/qnli"

# for template_name, template in dataset_template_dict[dataset_name].templates.items():
#     # print(template_name)
#     # print(dataset_dict[dataset_name][0])
#     q, a = template.apply(
#         dataset_dict[dataset_name][1]
#     )
#     # print(q == q.strip())
#     # print(a == a.strip())
#     # print(" ".join([q, a.strip()]))
#     print(" ".join([q, a]))
#     print(len(a))
#     print("---------------------------------")


## Form the dataset for the chosen split

In [20]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_fixed_answer_choices_list()}")

#     print("---------------------------------")


In [21]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_answer_choices_list(dataset_dict[dataset_name][0])}")

#     print("---------------------------------")

In [20]:
# Number of templates available per dataset, shown again before sampling from them
for dataset_name, dataset_templates in dataset_template_dict.items():
    template_count = len(dataset_templates.templates)
    print(f"{dataset_name}: {template_count}")

imdb: 5


In [21]:
%%time

# Reproducibility: all template / truthfulness choices below come from `random`
random.seed(SEED)

# Only these fields are kept in the generated entries
ALLOWED_KEYS = ["text", "label", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # Loop-invariant per dataset: look the templates up once rather than
    # rebuilding the list of names for every entry. The order and number of
    # `random` calls is unchanged, so the generated data stays the same.
    templates = dataset_template_dict[dataset_name].templates
    template_names = list(templates.keys())

    for entry in dataset:
        new_entry = entry.copy()

        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(template_names)
        template = templates[template_name]
        new_entry["template_name"] = template_name

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])

        # For an untruthful sample, flip the label before applying the
        # template so that the rendered answer is the wrong one.
        # This assumes binary 0/1 labels.
        if not is_truthful:
            new_entry["label"] = 1 - new_entry["label"]

        # Apply the template and join the pieces it returns with a space
        new_entry["text"] = " ".join(template.apply(new_entry))

        # We can now change the label to whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        # Remove all other keys
        new_entry = {k: v for k, v in new_entry.items() if k in ALLOWED_KEYS}

        # Append to the new dataset
        new_dataset.append(new_entry)

imdb
CPU times: user 2.77 s, sys: 105 ms, total: 2.87 s
Wall time: 3.75 s


In [22]:
# Turn the list of per-sample dicts into a `datasets.Dataset` and display its summary
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['text', 'label', 'original_dataset', 'template_name'],
    num_rows: 1500
})

In [23]:
# Inspect a single generated sample; change the index to look at another row
current_idx = 0
my_dataset[current_idx]

{'text': 'Consider the following example:  \'\'\' I am shocked and amazed to find reviews short of miserable for this horrible film. I rented this "movie" or feces, whatever you wish to call it, with several friends and after thirty minutes we had to stop watching. Just listening to the dialog left a horrible taste of sour milk in my mouth. This film was about as intelligent as an ass pimple.I hope I never see that bra-less, raggedy Anne look alike (Julianne Nicholson) again.It was like watching the most putrid pilot for a sitcom that will never make it to television, but instead of being a quick but painful 30 minutes( all I could bare)this was an excruciating 90 minutes. \'\'\'\n\nBetween 0 and 1, which is the sentiment of this example? 1',
 'label': 0,
 'original_dataset': 'imdb',
 'template_name': 'burns_2'}

In [28]:
# Print rows 10-19 of the generated dataset for manual inspection
for current_index in range(10, 20):
    # Fetch the row once instead of indexing the dataset for every field
    sample = my_dataset[current_index]
    print(f"label={sample['label']}")
    print(f"template_name={sample['template_name']}")
    print(sample["text"])
    print("---------------------------------")

label=1
template_name=Sentiment with choices 
Oh boy.. This movie is so mediocre I don't really know what exactly to write about it. <br /><br />I think it's easier to write what it's not: <br /><br />It's not very entertaining. It's not original. And there's not one character in the whole movie I cared about.<br /><br />Kind of reminds me of a certain reality TV show on MTV, but without any interesting people. It just drags on and on and I could hardly wait for it to end. The only thing that kept me from switching it off was Jennifer Lyons (c:<br /><br />I thought a long time about this movie to find one good thing to say about it. What I liked was the reminder not to judge a person by the first impression you get (as Holly did when she accused Nicole) which earns it a score of 2 out of 10 instead of a 1. 
Is this review positive or negative? negative
---------------------------------
label=1
template_name=custom_1
Question: is the movie review given below in triple backticks positive

In [25]:
# Balance check: truthful vs. untruthful samples, and samples per source dataset
label_counts = Counter(my_dataset["label"])
source_counts = Counter(my_dataset["original_dataset"])
label_counts, source_counts

(Counter({0: 764, 1: 736}), Counter({'imdb': 1500}))

In [29]:
# NOTE(review): the save below is commented out, yet the recorded output of this
# cell comes from a run in which it was active -- presumably it was disabled
# afterwards to avoid overwriting the file. Uncomment to (re)write the dataset.
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_imdb_{SPLIT}_{VERSION}.parquet")

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

1895355

In [30]:
# List the entries in `datasets/` whose names contain both the current VERSION and "imdb"
!ls -lah datasets | grep {VERSION} | grep imdb

drwxr-xr-x  3 augustas Domain Users  25K Jul 20 14:23 burns_datasets_VINC_imdb_ppo_training_raw_v3
-rw-r--r--  1 augustas Domain Users  12M Jul 20 15:06 burns_datasets_VINC_imdb_ppo_training_v3.parquet
drwxr-xr-x  3 augustas Domain Users  25K Jul 20 14:23 burns_datasets_VINC_imdb_train_raw_v3
-rw-r--r--  1 augustas Domain Users 7.4M Jul 20 14:36 burns_datasets_VINC_imdb_train_v3.parquet
-rw-r--r--  1 augustas Domain Users 1.2M Jul 20 15:14 burns_datasets_VINC_imdb_validation_v3.parquet
