## Imports

In [1]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets

from utils import combine_strings_with_whitespace

In [2]:
# Location of the local `elk` checkout, given relative to the notebook's working directory.
ELK_PATH = Path("../../../elk/")
# Displayed only as a sanity check: the resolved path is not stored, ELK_PATH stays relative.
ELK_PATH.resolve()

PosixPath('/fsx/home-augustas/elk')

In [3]:
# Directories that must be on sys.path so the imports in the following cell work
# (presumably the `templates` module lives under elk/promptsource -- confirm).
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Test membership with the same resolved string that gets inserted. Checking the
    # unresolved relative path never matched, so every re-run prepended duplicates.
    module_path = str(module.resolve())
    if module_path not in sys.path:
        sys.path.insert(0, module_path)

# Show the two entries at the front of the search path
sys.path[:2]

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk']

In [4]:
from templates import DatasetTemplates

In [5]:
# Restrict the `datasets` library's log output to error level.
# NOTE(review): `datasets` is already imported in the imports cell at the top of the
# notebook, so this re-import is redundant and could be dropped.
import datasets

datasets.logging.set_verbosity_error()

## Load the dataset

In [6]:
# Version suffix shared by the dataset directories/files this notebook reads and lists.
VERSION = "v2"

In [7]:
# List the entries under `datasets/` whose names contain the current VERSION.
# The commented variants show all datasets or only the imdb ones; the active one shows piqa.
# !ls -lah datasets | grep {VERSION}
# !ls -lah datasets | grep {VERSION} | grep imdb
!ls -lah datasets | grep {VERSION} | grep piqa

drwxr-xr-x  3 augustas Domain Users  25K Jul 24 20:31 wrapped_piqa_ppo_training_raw_v2
drwxr-xr-x  3 augustas Domain Users  33K Jul 24 20:31 wrapped_piqa_train_raw_v2
-rw-r--r--  1 augustas Domain Users 1.4M Jul 24 21:15 wrapped_piqa_train_v2.parquet


In [8]:
# Load the raw PPO-training DatasetDict from disk.
# (For the Burns datasets, use f"datasets/burns_datasets_VINC_ppo_training_raw_{VERSION}" instead.)
raw_dataset_dir = f"datasets/wrapped_piqa_ppo_training_raw_{VERSION}"
dataset_dict = datasets.DatasetDict.load_from_disk(raw_dataset_dir)
dataset_dict

DatasetDict({
    piqa: Dataset({
        features: ['goal', 'sol1', 'sol2', 'label'],
        num_rows: 6113
    })
})

In [9]:
# Total number of rows across every dataset in the dict
sum(split.num_rows for split in dataset_dict.values())

6113

In [10]:
# Number of distinct label values in each dataset
for dataset_name, dataset in dataset_dict.items():
    distinct_labels = set(dataset["label"])
    print(f"{dataset_name}: {len(distinct_labels)}")

piqa: 2


## Load the templates

In [11]:
# Map each dataset name to its DatasetTemplates, keeping only the templates that
# yield an answer-choices list for that dataset's first example.
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    dataset_templates = DatasetTemplates(dataset_path)

    # The probe example is the same for every template, so fetch it once instead of
    # re-reading row 0 from the dataset inside the comprehension.
    first_example = dataset_dict[dataset_path][0]

    # Re-key the surviving templates by their human-readable name.
    dataset_templates.templates = {
        template.name: template
        for template in dataset_templates.templates.values()
        if template.get_answer_choices_list(first_example) is not None
    }

    dataset_template_dict[dataset_path] = dataset_templates

In [12]:
# How many templates each dataset currently has
for dataset_name, dataset_templates in dataset_template_dict.items():
    n_templates = len(dataset_templates.templates)
    print(f"{dataset_name}: {n_templates}")

piqa: 7


In [13]:
# Keep only the templates whose metadata says the answer choices appear in the prompt.
for dataset_name in dataset_template_dict:
    candidate_templates = dataset_template_dict[dataset_name].templates
    good_templates = {
        name: template
        for name, template in candidate_templates.items()
        if template.metadata.choices_in_prompt
    }
    # Replace the stored templates with the filtered subset and report how many remain.
    dataset_template_dict[dataset_name].templates = good_templates
    print(f"{dataset_name}: {len(good_templates)}")

piqa: 6


In [14]:
# Re-check the number of templates stored per dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    n_templates = len(dataset_templates.templates)
    print(f"{dataset_name}: {n_templates}")

piqa: 6


In [41]:
%%time

# Reproducibility
random.seed(2023)

ALLOWED_KEYS = ["prompt", "best_response", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    for idx, entry in enumerate(dataset):
        new_entry = entry.copy()
        
        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(
            list(dataset_template_dict[dataset_name].templates.keys())
        )
        template = dataset_template_dict[dataset_name].templates[template_name]
        new_entry["template_name"] = template_name

        q, a = template.apply(new_entry)
        new_entry["prompt"] = q.rstrip()
        # new_entry["prompt"] = combine_strings_with_whitespace(a, q)
        # if not new_entry["prompt"][-1].isspace():
        #     new_entry["prompt"] += " "

        # We can now change the label to whether the sample is truthful or not
        # new_entry["best_response"] = a.strip()
        new_entry["best_response"] = a.strip()

        # Remove all other keys
        new_entry = { k: v for k, v in new_entry.items() if k in ALLOWED_KEYS }

        # Append to the new dataset
        new_dataset.append(new_entry)

piqa
CPU times: user 13.8 s, sys: 7.27 ms, total: 13.8 s
Wall time: 13.8 s


In [42]:
# Turn the list of record dicts into a `datasets.Dataset` and display its summary.
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['original_dataset', 'template_name', 'prompt', 'best_response'],
    num_rows: 6113
})

In [43]:
# Index of the record inspected in this and the next two cells.
current_idx = 0
my_dataset[current_idx]

{'original_dataset': 'piqa',
 'template_name': 'finish_sentence_with_correct_choice',
 'prompt': 'Finish the following sentence with the best choice: To make a cat like medicine more, you can\n\nChoices:\n- Mix it with tuna juice\n- Mix it with lemon juice\n\nAnswer:',
 'best_response': 'Mix it with tuna juice'}

In [44]:
# Print the prompt wrapped in single quotes so leading/trailing whitespace is visible
print("'" + my_dataset[current_idx]["prompt"] + "'")

'Finish the following sentence with the best choice: To make a cat like medicine more, you can

Choices:
- Mix it with tuna juice
- Mix it with lemon juice

Answer:'


In [45]:
# Print the response wrapped in single quotes so leading/trailing whitespace is visible
print("'" + my_dataset[current_idx]["best_response"] + "'")

'Mix it with tuna juice'


In [47]:
# Preview records 10-19 as "prompt + space + response", one separator line after each
for current_index in range(10, 20):
    example = my_dataset[current_index]
    print(f'{example["prompt"]} {example["best_response"]}')
    print("---------------------------------")

Goal: how do you turn off a light?

Which is the correct ending?
- flip the switch upward.
- flip the switch downward.

Answer: flip the switch downward.
---------------------------------
Does this phrase make sense?
Make dusting windows faster. use leaf blower to blow dust off.
Answer with Yes or No No
---------------------------------
Sentence: how to keep your drink holders from getting gunk in them

Choice 1: use rainex in the holder, followed by a little gum.

Choice 2: line with a silicone cupcake liner

What is the index of the correct choice for ending for the sentence?

Answer: 2
---------------------------------
Does this phrase make sense?
brush can be gripped by  fingernails 
Answer with Yes or No No
---------------------------------
Given a goal and 2 solutions, choose the most appropriate solution.
Goal: To make bread softer when baking it in the oven,
- Solution 1: add a small pan of sugar under the bread pan as it cooks.
- Solution 2: add a small pan of ice under the br

In [33]:
# Sanity check: number of records per source dataset.
Counter(my_dataset["original_dataset"])

Counter({'piqa': 6113})

In [49]:
# Export the wrapped prompts to parquet. The recorded output of this cell and the
# directory listing in the next cell show that the piqa export was actually run,
# so it is re-enabled here; with every line commented out, Restart & Run All would
# not regenerate the file.
# Alternative targets for the other dataset variants:
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_imdb_ppo_training_{VERSION}.parquet")
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_ppo_training_{VERSION}.parquet")
my_dataset.to_parquet(f"datasets/wrapped_piqa_ppo_training_{VERSION}.parquet")

Creating parquet from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

2532048

In [50]:
# List the piqa entries under `datasets/` for the current VERSION again, to check
# the result of the export cell above.
# !ls -lah datasets | grep {VERSION}
# !ls -lah datasets | grep {VERSION} | grep imdb
!ls -lah datasets | grep {VERSION} | grep piqa

drwxr-xr-x  3 augustas Domain Users  25K Jul 24 20:31 wrapped_piqa_ppo_training_raw_v2
-rw-r--r--  1 augustas Domain Users 972K Jul 24 21:26 wrapped_piqa_ppo_training_v2.parquet
drwxr-xr-x  3 augustas Domain Users  33K Jul 24 20:31 wrapped_piqa_train_raw_v2
-rw-r--r--  1 augustas Domain Users 1.4M Jul 24 21:15 wrapped_piqa_train_v2.parquet
