## Imports

In [1]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets
from datasets import load_dataset_builder, load_dataset

from utils import replace_text_with_whitespace

In [2]:
# Location of the local `elk` checkout, expressed relative to this notebook.
ELK_PATH = Path("..") / ".." / ".." / "elk"
# Show the absolute location as a sanity check.
ELK_PATH.resolve()

PosixPath('/fsx/home-augustas/elk')

In [3]:
# Directories that must be importable: the elk checkout itself and the
# promptsource package bundled inside it.
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Compare the *resolved* path, because that is what gets inserted.
    # Checking the unresolved relative path never matches an existing entry,
    # so every re-run of this cell would prepend duplicate entries.
    module_dir = str(module.resolve())
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)

# Each insert goes to the front, so the last module listed ends up first.
sys.path[:2]

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk']

In [4]:
from templates import DatasetTemplates

In [5]:
# Disable the logging of the datasets library: only error-level messages
# are emitted from here on.
# NOTE(review): `datasets` is already imported in the first cell, so the
# import below is redundant (but harmless).
import datasets

datasets.logging.set_verbosity_error()

## Config

In [6]:
# Datasets to build from, written as "<path>" or "<path>:<config>".
# All datasets used in Burns et al. (2022) (apart from story_cloze) are kept
# here for reference:
# BURNS_DATASETS = [
#     "ag_news",
#     "amazon_polarity",
#     "dbpedia_14",
#     "glue:qnli",
#     "imdb",
#     "piqa",
#     "super_glue:boolq",
#     "super_glue:copa",
#     "super_glue:rte",
# ]
BURNS_DATASETS = ["piqa"]

# Version tag used in the paths of the artifacts under `datasets/`.
VERSION = "v2"

# Which split to build; anything other than "train" selects the
# validation-like split when the datasets are loaded.
# SPLIT = "validation"
SPLIT = "train"

# Number of examples to take per dataset (capped by the dataset size at load time).
# NOTE(review): an earlier comment said these numbers were chosen so that
# "both datasets have approximately 10k examples in total" — that presumably
# refers to an older multi-dataset configuration; confirm before relying on it.
N_PER_DATASET = 17000 if SPLIT == "train" else 1550

# Upper bound accepted for N_PER_DATASET on the chosen split.
# NOTE(review): the origin of 25000 / 1838 is not visible here (an earlier
# train bound was 2490) — presumably the smallest usable dataset size; confirm.
MAX_PER_DATASET = 25000 if SPLIT == "train" else 1838
assert (
    N_PER_DATASET <= MAX_PER_DATASET
), f"N_PER_DATASET must be <= {MAX_PER_DATASET}"

# A different seed per split keeps the train and validation draws independent.
SEED = 42 if SPLIT == "train" else 2023
SEED

42

## Load and inspect the datasets

In [7]:
# Maps "<path>" or "<path>/<config>" to the shuffled, size-capped dataset.
dataset_dict = {}

for dataset_path in BURNS_DATASETS:
    print(dataset_path)

    # Entries of the form "<path>:<config>" carry a config name; others do not.
    if ":" in dataset_path:
        dataset_path, dataset_name = dataset_path.split(":")
    else:
        dataset_name = None

    # Choose the split: "train" when requested, otherwise the most
    # validation-like split the dataset offers.
    builder = load_dataset_builder(dataset_path, name=dataset_name)
    available_splits = builder.info.splits.keys()
    if SPLIT == "train":
        split = "train"
    elif "validation" in available_splits:
        split = "validation"
    else:
        split = "test"
    print(f"{split=}")

    # Load the dataset
    dataset = load_dataset(dataset_path, name=dataset_name, split=split)

    # Keep a seeded random subset of at most N_PER_DATASET rows.
    n = min(N_PER_DATASET, dataset.num_rows)
    dataset = dataset.shuffle(seed=SEED).select(range(n))

    print(dataset.num_rows)

    key = dataset_path if not dataset_name else f"{dataset_path}/{dataset_name}"
    dataset_dict[key] = dataset

    print("-----------------------------------")

piqa
split='train'
16113
-----------------------------------


In [8]:
# Total number of examples across all loaded datasets.
sum(loaded.num_rows for loaded in dataset_dict.values())


16113

In [9]:
# Sanity check: every sampled dataset must contain more than one distinct label.
assert all(len(set(loaded["label"])) > 1 for loaded in dataset_dict.values())

In [10]:
# Row count per dataset.
for dataset_name, dataset in dataset_dict.items():
    print(dataset_name, dataset.num_rows, sep=": ")

piqa: 16113


In [11]:
# Number of distinct labels per dataset.
for dataset_name, dataset in dataset_dict.items():
    n_distinct_labels = len(set(dataset["label"]))
    print(dataset_name, n_distinct_labels, sep=": ")

piqa: 2


## Split into train and ppo

In [13]:
# Per-dataset partitions: the first 10k rows of the seeded split go to
# "train", the remainder is kept for PPO training.
train_dataset_dict = {}
ppo_dataset_dict = {}

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    splits = dataset.train_test_split(train_size=10_000, seed=SEED)
    train_part, ppo_part = splits["train"], splits["test"]
    print(train_part.num_rows, ppo_part.num_rows)

    train_dataset_dict[dataset_name] = train_part
    ppo_dataset_dict[dataset_name] = ppo_part

    print("-----------------------------------")

piqa
10000 6113
-----------------------------------


In [14]:
# Total rows in the train partition and in the PPO partition, in that order.
for partition in (train_dataset_dict, ppo_dataset_dict):
    print(sum(part.num_rows for part in partition.values()))

10000
6113


In [16]:
# Persist the raw "train" partition; a later cell reloads it from this path.
# The save is skipped when the target already exists, so re-running the
# notebook never overwrites a previously saved partition, while a fresh run
# still produces the directory that the reload depends on.
# Paths used for earlier variants:
#   datasets/burns_datasets_VINC_train_raw_{VERSION}
#   datasets/burns_datasets_VINC_imdb_train_raw_{VERSION}
train_raw_path = Path(f"datasets/wrapped_piqa_train_raw_{VERSION}")
if train_raw_path.exists():
    print(f"{train_raw_path} already exists; not overwriting")
else:
    datasets.DatasetDict(train_dataset_dict).save_to_disk(str(train_raw_path))

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

In [17]:
# Persist the raw PPO-training partition.
# The save is skipped when the target already exists, so re-running the
# notebook never overwrites a previously saved partition.
# Paths used for earlier variants:
#   datasets/burns_datasets_VINC_ppo_training_raw_{VERSION}
#   datasets/burns_datasets_VINC_imdb_ppo_training_raw_{VERSION}
ppo_raw_path = Path(f"datasets/wrapped_piqa_ppo_training_raw_{VERSION}")
if ppo_raw_path.exists():
    print(f"{ppo_raw_path} already exists; not overwriting")
else:
    datasets.DatasetDict(ppo_dataset_dict).save_to_disk(str(ppo_raw_path))

Saving the dataset (0/1 shards):   0%|          | 0/6113 [00:00<?, ? examples/s]

In [18]:
# List what is on disk under `datasets/` for this VERSION; the active
# command narrows the listing to the piqa artifacts.
# !ls -lah datasets | grep {VERSION}
# !ls -lah datasets | grep {VERSION} | grep imdb
!ls -lah datasets | grep {VERSION} | grep piqa

drwxr-xr-x  3 augustas Domain Users 1.0K Jul 24 20:31 wrapped_piqa_ppo_training_raw_v2
drwxr-xr-x  3 augustas Domain Users  33K Jul 24 20:31 wrapped_piqa_train_raw_v2


## Get the templates

In [22]:
# Continue from the raw partition saved on disk: from here on `dataset_dict`
# holds the reloaded DatasetDict, not the freshly sampled data.
# Paths used for earlier variants:
#   datasets/burns_datasets_VINC_train_raw_{VERSION}
#   datasets/burns_datasets_VINC_imdb_train_raw_{VERSION}
raw_split_path = f"datasets/wrapped_piqa_train_raw_{VERSION}"
dataset_dict = datasets.DatasetDict.load_from_disk(raw_split_path)
dataset_dict

DatasetDict({
    piqa: Dataset({
        features: ['goal', 'sol1', 'sol2', 'label'],
        num_rows: 10000
    })
})

In [23]:
# Total number of examples in the reloaded partition.
sum(reloaded.num_rows for reloaded in dataset_dict.values())

10000

In [24]:
# Number of distinct labels per reloaded dataset.
for dataset_name, dataset in dataset_dict.items():
    n_distinct_labels = len(set(dataset["label"]))
    print(dataset_name, n_distinct_labels, sep=": ")

piqa: 2


In [25]:
# Maps each dataset key to its DatasetTemplates, restricted to the templates
# that yield an answer-choices list for the dataset's first example.
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    dataset_templates = DatasetTemplates(dataset_path)

    usable_templates = {}
    for candidate in dataset_templates.templates.values():
        first_example = dataset_dict[dataset_path][0]
        if candidate.get_answer_choices_list(first_example) is not None:
            # Re-key by template name.
            usable_templates[candidate.name] = candidate
    dataset_templates.templates = usable_templates

    dataset_template_dict[dataset_path] = dataset_templates

In [29]:
# Number of templates kept per dataset.
for dataset_name, dataset_templates in dataset_template_dict.items():
    print(dataset_name, len(dataset_templates.templates), sep=": ")

piqa: 7


In [30]:
# Keep only the templates whose metadata flags the answer choices as being
# part of the prompt, and report how many remain per dataset.
for dataset_name in dataset_template_dict:
    good_templates = {}
    for name, candidate in dataset_template_dict[dataset_name].templates.items():
        if candidate.metadata.choices_in_prompt:
            good_templates[name] = candidate

    dataset_template_dict[dataset_name].templates = good_templates
    print(dataset_name, len(good_templates), sep=": ")

piqa: 6


In [31]:
# Number of templates left per dataset after filtering.
for dataset_name, dataset_templates in dataset_template_dict.items():
    print(dataset_name, len(dataset_templates.templates), sep=": ")

piqa: 6


## Inspect the prompts

In [66]:
dataset_name = "piqa"
inspected_templates = dataset_template_dict[dataset_name].templates

# Render every remaining template on the first example. The unpacking checks
# that each template produces exactly a (prompt, answer) pair; uncomment one
# of the prints below to look at the rendered text.
for template_name, template in inspected_templates.items():
    print(template_name)
    q, a = template.apply(dataset_dict[dataset_name][0])
    # print(template.jinja)
    # print(template.metadata.choices_in_prompt)
    # print(" ".join([q, a.strip()]))
    print("---------------------------------")

what_is_the_correct_ending
---------------------------------
pick_correct_choice_with_choice_given_before_goal
---------------------------------
pick_correct_choice_index
---------------------------------
finish_sentence_with_correct_choice
---------------------------------
choose the most appropriate solution
---------------------------------
Does this solution make sense? sol1
---------------------------------


## Form the dataset for the chosen split

In [25]:
# Disabled debug helper: for every dataset, prints each template's name with
# its fixed answer-choices list. Uncomment to inspect.
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_fixed_answer_choices_list()}")

#     print("---------------------------------")


In [26]:
# Disabled debug helper: for every dataset, prints each template's name with
# the answer-choices list it produces for the dataset's first example.
# Uncomment to inspect.
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_answer_choices_list(dataset_dict[dataset_name][0])}")

#     print("---------------------------------")

In [44]:
# Templates available per dataset when the final dataset is formed.
for dataset_name, dataset_templates in dataset_template_dict.items():
    print(dataset_name, len(dataset_templates.templates), sep=": ")

piqa: 6


In [60]:
%%time

# Build the final dataset: every source example is rendered through a randomly
# chosen template and, with probability 1/2, given a wrong answer. The output
# label is 1 for a truthful (correct-answer) text and 0 otherwise.

# Reproducibility: the template choice, the truthful/untruthful coin flip and
# the incorrect-label draw all go through the global `random` state.
random.seed(SEED)

# Only these keys are kept in the output records.
ALLOWED_KEYS = ["text", "label", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # Loop-invariant lookups, hoisted out of the per-example loop.
    templates = dataset_template_dict[dataset_name].templates
    template_names = list(templates.keys())
    label_mapping = dataset.features["label"]  # [class 1, class 2, ...]
    is_multiclass = len(label_mapping.names) > 2
    # Integer ids of all classes: [0, 1, ...]; only needed for multi-class data.
    all_label_ids = (
        [label_mapping.str2int(x) for x in label_mapping.names]
        if is_multiclass
        else []
    )

    for entry in dataset:
        new_entry = entry.copy()

        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(template_names)
        template = templates[template_name]
        new_entry["template_name"] = template_name

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])

        if is_truthful:
            # Render with the true label; join prompt and answer with exactly
            # one space between them.
            q, a = template.apply(new_entry)
            new_text = " ".join([q.rstrip(), a.strip()])
        elif is_multiclass:
            # Untruthful multi-class case: draw uniformly among the labels
            # that differ from the true one. The previous formula,
            # (1 - label) % n, was deterministic and returned the *correct*
            # label whenever 2 * label == 1 (mod n), e.g. label 2 of 3 classes.
            # This draw consumes one extra random number, so multi-class
            # outputs differ from earlier runs; the binary path is unaffected.
            wrong_label_ids = [i for i in all_label_ids if i != new_entry["label"]]
            incorrect_label_id = random.choice(wrong_label_ids)
            incorrect_label = template.get_fixed_answer_choices_list()[incorrect_label_id]

            q, a = template.apply(new_entry)
            # NOTE(review): presumably swaps the answer text inside `a` for the
            # incorrect label while keeping the surrounding whitespace —
            # confirm against `utils.replace_text_with_whitespace`.
            incorrect_label = replace_text_with_whitespace(a, incorrect_label)
            new_text = " ".join([q, incorrect_label])
        else:
            # Untruthful binary case: flip the label before rendering so the
            # template itself produces the wrong answer.
            new_entry["label"] = 1 - new_entry["label"]
            q, a = template.apply(new_entry)
            new_text = " ".join([q.rstrip(), a.strip()])

        new_entry["text"] = new_text

        # We can now change the label to whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        # Remove all other keys
        new_entry = {k: v for k, v in new_entry.items() if k in ALLOWED_KEYS}

        # Append to the new dataset
        new_dataset.append(new_entry)


piqa


CPU times: user 23.1 s, sys: 0 ns, total: 23.1 s
Wall time: 23.1 s


In [61]:
# Turn the list of record dicts into a `datasets.Dataset` and display its summary.
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['label', 'original_dataset', 'template_name', 'text'],
    num_rows: 10000
})

In [62]:
# Show one full record; change `current_idx` to look at a different one.
current_idx = 0
my_dataset[current_idx]

{'label': 1,
 'original_dataset': 'piqa',
 'template_name': 'Does this solution make sense? sol1',
 'text': 'Does this phrase make sense?\nWhat service to turn dough on when making biscuits? turn the dough onto a lightly floured work surface and continue kneading until everything comes together.\nAnswer with Yes or No Yes'}

In [63]:
# Print the label and rendered text of records 10..19 for a visual check.
for current_index in range(10, 20):
    sample = my_dataset[current_index]
    print(f"label={sample['label']}")
    print(sample["text"])
    print("---------------------------------")

label=0
Does this phrase make sense?
To grill hamburger safely and tastily leave meat under the grill until you can no longer see any blood. Remove.
Answer with Yes or No Yes
---------------------------------
label=0
Solution 1: Sprinkle your chosen cheese on top of the tortilla chips and place in the oven and set to broil.  Leave them in the oven until cheese has melted and the chips are warm.
Solution 2: Sprinkle your chosen cheese on top of the tortilla chips and place on the stove and set high.  Leave them on the stove until cheese has melted and the chips are warm.

Goal: How do you melt cheese on nachos?

Given the goal, what is the correct solution?

Answer by copying the correct solution Sprinkle your chosen cheese on top of the tortilla chips and place on the stove and set high.  Leave them on the stove until cheese has melted and the chips are warm.
---------------------------------
label=0
Given a goal and 2 solutions, choose the most appropriate solution.
Goal: How can I ma

In [67]:
# Label balance recorded from earlier runs, for comparison:
# Counter({1: 5055, 0: 4945}
# Counter({1: 5021, 0: 4979}
# Counter({0: 5031, 1: 4969}
# Counter({0: 5067, 1: 4933}
# Counter({0: 5108, 1: 4892}
label_counts = Counter(my_dataset["label"])
source_counts = Counter(my_dataset["original_dataset"])
label_counts, source_counts

(Counter({1: 5039, 0: 4961}), Counter({'piqa': 10000}))

In [68]:
# Export the final dataset to parquet.
# The write is skipped when the target already exists, so re-running the
# notebook never overwrites a previous export.
# File names used for earlier variants:
#   datasets/burns_datasets_VINC_{SPLIT}_{VERSION}.parquet
#   datasets/burns_datasets_VINC_imdb_{SPLIT}_{VERSION}.parquet
parquet_path = Path(f"datasets/wrapped_piqa_{SPLIT}_{VERSION}.parquet")
if parquet_path.exists():
    print(f"{parquet_path} already exists; not overwriting")
else:
    # Make sure the output directory is there on a fresh checkout.
    parquet_path.parent.mkdir(parents=True, exist_ok=True)
    my_dataset.to_parquet(str(parquet_path))

Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

4256068

In [70]:
# List what is on disk under `datasets/` for this VERSION; the active
# command narrows the listing to the piqa artifacts.
# !ls -lah datasets | grep {VERSION}
# !ls -lah datasets | grep {VERSION} | grep imdb
!ls -lah datasets | grep {VERSION} | grep piqa

drwxr-xr-x  3 augustas Domain Users  25K Jul 24 20:31 wrapped_piqa_ppo_training_raw_v2
-rw-r--r--  1 augustas Domain Users 972K Jul 24 21:26 wrapped_piqa_ppo_training_v2.parquet
drwxr-xr-x  3 augustas Domain Users  33K Jul 24 20:31 wrapped_piqa_train_raw_v2
-rw-r--r--  1 augustas Domain Users 1.4M Jul 24 21:36 wrapped_piqa_train_v2.parquet
