## Imports

In [1]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets
from datasets import load_dataset_builder, load_dataset

from utils import replace_text_with_whitespace

In [2]:
# Location of the ELK checkout, given relative to the current working directory
ELK_PATH = Path("..") / ".." / ".." / "elk"
ELK_PATH.resolve()

PosixPath('/fsx/home-augustas/elk')

In [3]:
# Directories that must be importable: the ELK checkout and its bundled promptsource
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Compare the same (resolved) string that gets inserted; checking the
    # unresolved relative path never matches, so re-running the cell would
    # keep prepending duplicates to sys.path.
    module_dir = str(module.resolve())
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)

sys.path[:2]

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk']

In [4]:
from templates import DatasetTemplates

In [5]:
# Disable the logging of the datasets library
# NOTE(review): `datasets` is already imported in the imports cell at the top,
# so this re-import is redundant (harmless, but it scatters imports).
import datasets

datasets.logging.set_verbosity_error()

## Config

In [6]:
# All datasets used in Burns et al. (2022) (apart from story_cloze).
# Entries are either "<path>" or "<path>:<config name>".
BURNS_DATASETS = [
    "ag_news",
    "amazon_polarity",
    "dbpedia_14",
    "glue:qnli",
    "imdb",
    "piqa",
    "super_glue:boolq",
    "super_glue:copa",
    "super_glue:rte",
]
# Alternative subsets (uncomment one to override the list above):
# BURNS_DATASETS = ["ag_news", "amazon_polarity", "dbpedia_14", "glue:qnli", "imdb", "super_glue:boolq", "super_glue:copa", "super_glue:rte"]
# BURNS_DATASETS = ["ag_news", "amazon_polarity", "dbpedia_14", "imdb", "super_glue:boolq", "super_glue:copa"]
# BURNS_DATASETS = ["imdb"]

# Suffix used in the names of the files read from / written to `datasets/`
VERSION = "none"

# SPLIT = "validation"
SPLIT = "train"

# Per-dataset cap on the number of examples taken from the chosen split.
# Earlier values were chosen so that both datasets have approximately 10k
# examples in total (probably a bit less for train split):
# N_PER_DATASET = 1000 if SPLIT == "train" else 1000
# N_PER_DATASET = 1200 if SPLIT == "train" else 1375
# N_PER_DATASET = 1200 if SPLIT == "train" else 1550
N_PER_DATASET = 25000 if SPLIT == "train" else 1550

# Largest cap accepted for the chosen split (the train limit used to be 2490).
# The message is derived from the limit so the two cannot drift apart.
MAX_N_PER_DATASET = 25000 if SPLIT == "train" else 1838
assert N_PER_DATASET <= MAX_N_PER_DATASET, f"N_PER_DATASET must be <= {MAX_N_PER_DATASET}"

# Different seeds for the train and non-train splits
SEED = 42 if SPLIT == "train" else 2023
SEED

42

## Load and inspect the datasets

In [7]:
dataset_dict = {}
for dataset_path in BURNS_DATASETS:
    print(dataset_path)

    # "path:config" entries carry the builder config name after the colon
    if ":" in dataset_path:
        dataset_path, dataset_name = dataset_path.split(":")
    else:
        dataset_name = None

    # Use the train split when requested; otherwise the most
    # validation-like split the builder advertises
    builder = load_dataset_builder(dataset_path, name=dataset_name)
    available_splits = builder.info.splits.keys()
    if SPLIT == "train":
        split = "train"
    elif "validation" in available_splits:
        split = "validation"
    else:
        split = "test"
    print(f"{split=}")

    # Load the dataset
    dataset = load_dataset(dataset_path, name=dataset_name, split=split)

    # Shuffle, then keep at most N_PER_DATASET examples
    n = min(N_PER_DATASET, dataset.num_rows)
    dataset = dataset.shuffle(seed=SEED).select(range(n))

    print(dataset.num_rows)

    # Key is "path/config" when a config name was given, else just "path"
    key = dataset_path if not dataset_name else f"{dataset_path}/{dataset_name}"
    dataset_dict[key] = dataset

    print("-----------------------------------")

imdb


split='train'
25000
-----------------------------------


In [8]:
# Number of examples in total in all datasets
sum(dataset.num_rows for dataset in dataset_dict.values())


25000

In [9]:
# Every dataset must contain more than one distinct label value
assert all(len(set(dataset["label"])) > 1 for dataset in dataset_dict.values())

In [10]:
# Number of examples per dataset
for dataset_name, dataset in dataset_dict.items():
    print(dataset_name, dataset.num_rows, sep=": ")

imdb: 25000


In [11]:
# Number of distinct label values per dataset
for dataset_name, dataset in dataset_dict.items():
    n_distinct_labels = len(set(dataset["label"]))
    print(f"{dataset_name}: {n_distinct_labels}")

imdb: 2


## Split into train and ppo

In [12]:
# Fraction of each dataset that goes to the PPO split (the `test` side of
# train_test_split); the remainder becomes the train split.
PPO_FRACTION = 0.6

train_dataset_dict = {}
ppo_dataset_dict = {}

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)
    # test_size = 225 if dataset.num_rows > 225 else 10 # specialized for COPA
    splits = dataset.train_test_split(test_size=PPO_FRACTION, seed=SEED)
    print(splits["train"].num_rows, splits["test"].num_rows)

    # Every example must land in exactly one of the two splits
    assert splits["train"].num_rows + splits["test"].num_rows == dataset.num_rows

    train_dataset_dict[dataset_name] = splits["train"]
    ppo_dataset_dict[dataset_name] = splits["test"]
    print("-----------------------------------")

imdb
10000 15000
-----------------------------------


In [13]:
# Total number of examples in the train split, then in the PPO split
for split_dict in (train_dataset_dict, ppo_dataset_dict):
    print(sum(dataset.num_rows for dataset in split_dict.values()))

10000
15000


In [16]:
# NOTE(review): this save is currently disabled, but a later cell loads
# f"datasets/burns_datasets_VINC_imdb_train_raw_{VERSION}" from disk. On a fresh
# run that directory must already exist for the current VERSION, or this cell
# has to be re-enabled first.
# datasets.DatasetDict(train_dataset_dict).save_to_disk(
#     # f"datasets/burns_datasets_VINC_train_raw_{VERSION}"
#     f"datasets/burns_datasets_VINC_imdb_train_raw_{VERSION}"
# )

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

In [17]:
# Disabled: when enabled, writes the PPO split (ppo_dataset_dict) to disk under
# a directory name keyed by VERSION.
# datasets.DatasetDict(ppo_dataset_dict).save_to_disk(
#     # f"datasets/burns_datasets_VINC_ppo_training_raw_{VERSION}"
#     f"datasets/burns_datasets_VINC_imdb_ppo_training_raw_{VERSION}"
# )

Saving the dataset (0/1 shards):   0%|          | 0/15000 [00:00<?, ? examples/s]

In [18]:
# !ls -lah datasets | grep {VERSION}
!ls -lah datasets | grep {VERSION} | grep imdb

drwxr-xr-x  3 augustas Domain Users 1.0K Jul  4 23:02 burns_datasets_VINC_imdb_ppo_training_raw_v2
drwxr-xr-x  3 augustas Domain Users 1.0K Jul  4 23:02 burns_datasets_VINC_imdb_train_raw_v2


## Get the templates

In [19]:
# Set dataset_dict to test_dataset
dataset_dict = datasets.DatasetDict.load_from_disk(
#     # f"datasets/burns_datasets_VINC_train_raw_{VERSION}"
    f"datasets/burns_datasets_VINC_imdb_train_raw_{VERSION}"
)
dataset_dict

DatasetDict({
    imdb: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
})

In [20]:
# Total number of examples across the loaded datasets
sum(ds.num_rows for ds in dataset_dict.values())

10000

In [21]:
# Number of distinct label values per loaded dataset
for dataset_name, dataset in dataset_dict.items():
    n_distinct_labels = len(set(dataset["label"]))
    print(f"{dataset_name}: {n_distinct_labels}")

imdb: 2


In [22]:
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    dataset_templates = DatasetTemplates(dataset_path)

    # Loop-invariant: the dataset's first example is the probe for every template,
    # so fetch it once instead of once per template
    first_example = dataset_dict[dataset_path][0]

    # Keep only the templates that return answer choices for the probe example
    dataset_templates.templates = {
        template.name: template
        for template in dataset_templates.templates.values()
        if template.get_answer_choices_list(first_example) is not None
    }

    dataset_template_dict[dataset_path] = dataset_templates

In [23]:
# Number of templates kept per dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    n_templates = len(dataset_templates.templates)
    print(f"{dataset_name}: {n_templates}")

imdb: 13


## Inspect the prompts

In [24]:
# dataset_name = "ag_news"

# for template_name, template in dataset_template_dict[dataset_name].templates.items():
#     # print(template_name)
#     # print(dataset_dict[dataset_name][0])
#     q, a = template.apply(
#         dataset_dict[dataset_name][0]
#     )
#     # print(q == q.strip())
#     # print(a == a.strip())
#     # print(" ".join([q, a.strip()]))
#     print(" ".join([q, a]))
#     print("---------------------------------")

## Form the dataset for the chosen split

In [25]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_fixed_answer_choices_list()}")

#     print("---------------------------------")


In [26]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_answer_choices_list(dataset_dict[dataset_name][0])}")

#     print("---------------------------------")

In [27]:
# Number of templates per dataset (same check as earlier in the notebook)
for dataset_name, dataset_templates in dataset_template_dict.items():
    n_templates = len(dataset_templates.templates)
    print(f"{dataset_name}: {n_templates}")

imdb: 13


In [28]:
%%time

# Reproducibility: seed before any template / truthfulness draws
random.seed(SEED)

# Only these keys are kept on the final entries
ALLOWED_KEYS = ["text", "label", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # Per-dataset invariants, computed once instead of once per example
    templates = dataset_template_dict[dataset_name].templates
    template_names = list(templates.keys())
    label_mapping = dataset.features["label"]  # [class 1, class 2, ...]
    all_label_ids = [label_mapping.str2int(x) for x in label_mapping.names]  # [0, 1, ...]
    n_labels = len(all_label_ids)
    is_multiclass = n_labels > 2

    for entry in dataset:
        new_entry = entry.copy()

        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(template_names)
        template = templates[template_name]
        new_entry["template_name"] = template_name

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])

        if is_truthful:
            new_text = " ".join(template.apply(new_entry))
        elif is_multiclass:
            # Untruthful multi-class case: pick an incorrect label deterministically.
            # (1 - label) % n equals the true label whenever 2 * label == 1 (mod n),
            # which happens for odd n (e.g. n=3, label=2); fall back to the next
            # label so the chosen answer is never the correct one.
            true_label_id = new_entry["label"]
            incorrect_label_id = all_label_ids[(1 - true_label_id) % n_labels]
            if incorrect_label_id == true_label_id:
                incorrect_label_id = all_label_ids[(true_label_id + 1) % n_labels]
            incorrect_label = template.get_fixed_answer_choices_list()[incorrect_label_id]

            q, a = template.apply(new_entry)
            # presumably swaps the answer text in `a` for the incorrect label while
            # keeping its surrounding whitespace — TODO confirm against utils
            incorrect_label = replace_text_with_whitespace(a, incorrect_label)
            new_text = " ".join([q, incorrect_label])
        else:
            # Untruthful binary case: flip the label before applying the template
            new_entry["label"] = 1 - new_entry["label"]
            new_text = " ".join(template.apply(new_entry))

        new_entry["text"] = new_text

        # We can now change the label to whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        # Remove all other keys
        new_entry = {k: v for k, v in new_entry.items() if k in ALLOWED_KEYS}

        # Append to the new dataset
        new_dataset.append(new_entry)

imdb


CPU times: user 14.7 s, sys: 90.9 ms, total: 14.8 s
Wall time: 14.8 s


In [29]:
# Turn the list of per-example dicts (`new_dataset`) into a Dataset and display it
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['text', 'label', 'original_dataset', 'template_name'],
    num_rows: 10000
})

In [30]:
# Index of the example to inspect; the print cell that follows reuses it
current_idx = 0
my_dataset[current_idx]

{'text': "Consider the following example:  ''' This movie i've loved since i was young! Its excellent. Although, it may be a bit much for the average movie watcher if one can't interpret certain subtleties in the film (for example, our hero's name is Achilles, and in the final battle between him and Alexander he's shot in the heel with a rocket, just as Achilles in mythology was shot in his heel). That's a just a little fact that is kind of amusing! Anyway, great movie, good story, it'd be neat to see it redone with today's special effects! Oddly enough, Gary Graham had average success, starring in the T.V. show Alien Nation. This movie is a fun watch and should be more appreciated! '''\n\nBetween 0 and 1, which is the sentiment of this example?\n 1",
 'label': 1,
 'original_dataset': 'imdb',
 'template_name': 'burns_2'}

In [31]:
# NOTE(review): the fragments below look like openings of example texts seen
# in earlier runs — TODO confirm or remove
# Lynch. The man has some really great
# I first learned of the Wendigo
# The following movie review expresses
# Okay this movie fine like I said but
# i liked this film a lot. it's dark
# Print the text of the selected example
print(my_dataset[current_idx]["text"])

Consider the following example:  ''' This movie i've loved since i was young! Its excellent. Although, it may be a bit much for the average movie watcher if one can't interpret certain subtleties in the film (for example, our hero's name is Achilles, and in the final battle between him and Alexander he's shot in the heel with a rocket, just as Achilles in mythology was shot in his heel). That's a just a little fact that is kind of amusing! Anyway, great movie, good story, it'd be neat to see it redone with today's special effects! Oddly enough, Gary Graham had average success, starring in the T.V. show Alien Nation. This movie is a fun watch and should be more appreciated! '''

Between 0 and 1, which is the sentiment of this example?
 1


In [32]:
# Counts of the truthful (1) / untruthful (0) labels and of the source datasets.
# NOTE(review): the Counter lines below appear to be results from earlier runs — TODO confirm
# Counter({1: 5055, 0: 4945}
# Counter({1: 5021, 0: 4979}
# Counter({0: 5031, 1: 4969}
# Counter({0: 5067, 1: 4933}
# Counter({0: 5108, 1: 4892}
Counter(my_dataset["label"]), Counter(my_dataset["original_dataset"])

(Counter({1: 5031, 0: 4969}), Counter({'imdb': 10000}))

In [33]:
# Disabled: when enabled, writes `my_dataset` to a parquet file whose name is
# keyed by SPLIT and VERSION.
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_{SPLIT}_{VERSION}.parquet")
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_imdb_{SPLIT}_{VERSION}.parquet")

Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

12412247

In [35]:
!ls -lah datasets | grep {VERSION} | grep imdb
# !ls -lah datasets | grep {VERSION}

drwxr-xr-x  3 augustas Domain Users  25K Jul  4 23:02 burns_datasets_VINC_imdb_ppo_training_raw_v2
drwxr-xr-x  3 augustas Domain Users  25K Jul  4 23:02 burns_datasets_VINC_imdb_train_raw_v2
-rw-r--r--  1 augustas Domain Users 7.4M Jul  4 23:03 burns_datasets_VINC_imdb_train_v2.parquet
