## Imports

In [20]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets

from utils import combine_strings_with_whitespace

In [3]:
# Relative path to the elk directory; resolve() only displays the absolute
# form here and does not modify ELK_PATH itself.
ELK_PATH = Path("../../../elk/")
ELK_PATH.resolve()

PosixPath('/rds/user/am3052/hpc-work/elk')

In [4]:
# Directories that must be importable: the elk directory itself and the
# promptsource package nested inside it (needed for `from templates import ...`).
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Compare the *resolved* path, because that is the form that gets inserted.
    # Checking the unresolved relative string never matches an earlier
    # insertion, so re-running this cell would keep prepending duplicates.
    module_path = str(module.resolve())
    if module_path not in sys.path:
        sys.path.insert(0, module_path)

# Show the two entries at the front of the import search path
sys.path[:2]

['/rds/user/am3052/hpc-work/elk/elk/promptsource',
 '/rds/user/am3052/hpc-work/elk']

In [5]:
from templates import DatasetTemplates

In [6]:
# Silence the datasets library's logging: only messages at error level are shown.
# (`datasets` is already imported in the first cell; this re-import is harmless.)
import datasets

datasets.logging.set_verbosity_error()

## Load the dataset

In [7]:
# Load the raw PPO DatasetDict from disk (one Dataset per source dataset name)
dataset_dict = datasets.DatasetDict.load_from_disk("datasets/ppo_dataset_raw")
dataset_dict

DatasetDict({
    ag_news: Dataset({
        features: ['text', 'label'],
        num_rows: 1404
    })
    amazon_polarity: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 1404
    })
    dbpedia_14: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 1404
    })
    glue/qnli: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 1404
    })
    imdb: Dataset({
        features: ['text', 'label'],
        num_rows: 1404
    })
    piqa: Dataset({
        features: ['goal', 'sol1', 'sol2', 'label'],
        num_rows: 1404
    })
    super_glue/boolq: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 1404
    })
    super_glue/copa: Dataset({
        features: ['premise', 'choice1', 'choice2', 'question', 'idx', 'label'],
        num_rows: 78
    })
    super_glue/rte: Dataset({
        features: ['premise', 'hypothesis', 'idx', 'label'],
        num_rows: 216
  

In [8]:
# Total number of rows across every dataset in the dict
sum(split.num_rows for split in dataset_dict.values())

10122

In [9]:
# Number of distinct label values found in each dataset
for dataset_name, dataset in dataset_dict.items():
    n_labels = len(set(dataset["label"]))
    print(f"{dataset_name}: {n_labels}")

ag_news: 4
amazon_polarity: 2
dbpedia_14: 14
glue/qnli: 2
imdb: 2
piqa: 2
super_glue/boolq: 2
super_glue/copa: 2
super_glue/rte: 2


## Load the templates

In [10]:
# Maps each dataset name to its DatasetTemplates object, with `.templates`
# replaced by a dict keyed by template *name*.
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    dataset_templates = DatasetTemplates(dataset_path)

    # The probe row is the same for every template, so fetch it once instead
    # of re-indexing the dataset inside the comprehension.
    first_example = dataset_dict[dataset_path][0]

    # Keep only templates whose answer-choice list for the probe row is not
    # None. NOTE(review): only the first row is checked; this presumably
    # assumes a template behaves the same for every row of the dataset.
    dataset_templates.templates = {
        template.name: template
        for template in dataset_templates.templates.values()
        if template.get_answer_choices_list(first_example) is not None
    }

    dataset_template_dict[dataset_path] = dataset_templates

In [11]:
# Number of templates available for each dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    n_templates = len(dataset_templates.templates)
    print(f"{dataset_name}: {n_templates}")

ag_news: 15
amazon_polarity: 11
dbpedia_14: 11
glue/qnli: 5
imdb: 13
piqa: 7
super_glue/boolq: 10
super_glue/copa: 9
super_glue/rte: 11


In [32]:
%%time

# Reproducibility
random.seed(2023)

ALLOWED_KEYS = ["prompt", "best_response", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # if dataset_name != "ag_news": continue
    # if dataset_name != "imdb": continue
    # if dataset_name != "dbpedia_14": continue
    # if dataset_name != "piqa": continue

    for idx, entry in enumerate(dataset):
        new_entry = entry.copy()
        
        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(
            list(dataset_template_dict[dataset_name].templates.keys())
        )
        template = dataset_template_dict[dataset_name].templates[template_name]
        new_entry["template_name"] = template_name

        q, a = template.apply(new_entry)
        new_entry["prompt"] = combine_strings_with_whitespace(a, q)
        if not new_entry["prompt"][-1].isspace():
            new_entry["prompt"] += " "

        # We can now change the label to whether the sample is truthful or not
        new_entry["best_response"] = a.strip()

        # Remove all other keys
        new_entry = { k: v for k, v in new_entry.items() if k in ALLOWED_KEYS }

        # Append to the new dataset
        new_dataset.append(new_entry)

        # print(new_entry["prompt"])
        # print(new_entry["prompt"] + new_entry["best_response"])
        # print("-----------------------------------")
        # if idx >= 10 - 1: break

ag_news
amazon_polarity
dbpedia_14
glue/qnli
imdb
piqa
super_glue/boolq
super_glue/copa
super_glue/rte
CPU times: user 20 s, sys: 61.3 ms, total: 20.1 s
Wall time: 20.6 s


In [33]:
# Turn the list of per-row dicts into a single datasets.Dataset
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['original_dataset', 'template_name', 'prompt', 'best_response'],
    num_rows: 10122
})

In [34]:
# Inspect the first generated entry as a sanity check
my_dataset[0]

{'original_dataset': 'ag_news',
 'template_name': 'burns_2',
 'prompt': "Consider the following example: ''' 49ers #39; home to be renamed Monster Park after stereo cable company Some fans think 44-year-old Candlestick Park is already a dinosaur. Now the San Francisco 49ers #39; home stadium has the name to match. ''' Choice 1: Sports. Choice 2: World politics.Between choice 1 and choice 2, what is the topic of this example? ",
 'best_response': 'Sports'}

In [36]:
# Number of generated entries per source dataset
Counter(my_dataset["original_dataset"])

Counter({'ag_news': 1404,
         'amazon_polarity': 1404,
         'dbpedia_14': 1404,
         'glue/qnli': 1404,
         'imdb': 1404,
         'piqa': 1404,
         'super_glue/boolq': 1404,
         'super_glue/copa': 78,
         'super_glue/rte': 216})

In [37]:
# Write the prompt/response dataset to a single Parquet file.
# (Plain string literal: the previous f-string had no placeholders.)
my_dataset.to_parquet("datasets/ppo_training_dataset.parquet")

Creating parquet from Arrow format:   0%|          | 0/11 [00:00<?, ?ba/s]

5883098

In [38]:
# List the datasets directory to confirm the Parquet file was written
!ls -lah datasets

total 2.9M
drwxrwsr-x 4 am3052 am3052 4.0K Jun 14 23:53 .
drwxrwsr-x 3 am3052 am3052 4.0K Jun 14 22:26 ..
-rw-rw-r-- 1 am3052 am3052 2.8M Jun 14 16:08 burns_datasets_VINC_train.parquet
drwxrwsr-x 9 am3052 am3052 4.0K Jun 14 22:18 ppo_dataset_raw
-rw-rw-r-- 1 am3052 am3052 3.0M Jun 14 23:53 ppo_training_dataset.parquet
drwxrwsr-x 9 am3052 am3052 4.0K Jun 14 22:18 test_dataset_raw
