## Imports

In [2]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets
from datasets import load_dataset_builder, load_dataset

from utils import replace_text_with_whitespace

In [3]:
# Location of the local `elk` checkout, given relative to the current working directory
ELK_PATH = Path("../../../elk/")
# Displayed for inspection only; ELK_PATH itself stays relative
ELK_PATH.resolve()

PosixPath('/fsx/home-augustas/elk')

In [4]:
# Directories that must be importable: the elk checkout itself and its
# bundled promptsource package (the `templates` module is imported from there).
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Compare the *resolved* path, because that is what gets inserted.
    # Checking the unresolved string never matched an existing entry, so
    # every re-run of this cell pushed duplicate entries onto sys.path.
    module_path = str(module.resolve())
    if module_path not in sys.path:
        sys.path.insert(0, module_path)

sys.path[:2]

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk']

In [5]:
from templates import DatasetTemplates

In [6]:
# Only let error-level messages of the datasets library through
# (info and warning logs are silenced).
# `datasets` is already imported in the imports cell; importing it again is a no-op.
import datasets

datasets.logging.set_verbosity_error()

## Config

In [10]:
# Datasets from Burns et al. (2022), story_cloze excluded.
# Each entry is "<path>" or "<path>:<config name>"; trim the list to run on a subset.
BURNS_DATASETS = [
    "ag_news",
    "amazon_polarity",
    "dbpedia_14",
    "glue:qnli",
    "imdb",
    "piqa",
    "super_glue:boolq",
    "super_glue:copa",
    "super_glue:rte",
]

# Suffix used in the names of the saved artifacts
VERSION = "v4"

# Which split to build: "train" or "validation"
SPLIT = "validation"

# The per-dataset sample size and the RNG seed both depend on the split.
# The sample sizes are chosen so that both datasets have approximately
# 10k examples in total (probably a bit less for the train split).
if SPLIT == "train":
    N_PER_DATASET = 1200
    assert N_PER_DATASET <= 2490, "N_PER_DATASET must be <= 2490"
    SEED = 42
else:
    N_PER_DATASET = 1550
    assert N_PER_DATASET <= 1838, "N_PER_DATASET must be <= 1838"
    SEED = 2023

SEED

2023

## Load and inspect the datasets

In [11]:
# Load every configured dataset and keep a fixed-size shuffled subset of it.
dataset_dict = {}
for dataset_path in BURNS_DATASETS:
    print(dataset_path)

    # Entries are either "<path>" or "<path>:<config name>"
    if ":" in dataset_path:
        dataset_path, dataset_name = dataset_path.split(":")
    else:
        dataset_name = None

    # Choose the split: "train" when requested, otherwise the most
    # validation-like split the dataset provides
    builder = load_dataset_builder(dataset_path, name=dataset_name)
    available_splits = builder.info.splits.keys()
    if SPLIT == "train":
        split = "train"
    elif "validation" in available_splits:
        split = "validation"
    else:
        split = "test"
    print(f"{split=}")

    dataset = load_dataset(dataset_path, name=dataset_name, split=split)

    # Shuffle with a fixed seed and keep at most N_PER_DATASET rows
    n = min(N_PER_DATASET, dataset.num_rows)
    dataset = dataset.shuffle(seed=42).select(range(n))
    print(dataset.num_rows)

    # Stored under "<path>" or "<path>/<config name>"
    key = dataset_path if not dataset_name else f"{dataset_path}/{dataset_name}"
    dataset_dict[key] = dataset

    print("-----------------------------------")

ag_news


split='test'
1550
-----------------------------------
amazon_polarity
split='test'
1550
-----------------------------------
dbpedia_14
split='test'
1550
-----------------------------------
glue:qnli
split='validation'
1550
-----------------------------------
imdb
split='test'
1550
-----------------------------------
piqa
split='validation'
1550
-----------------------------------
super_glue:boolq
split='validation'
1550
-----------------------------------
super_glue:copa
split='validation'
100
-----------------------------------
super_glue:rte
split='validation'
277
-----------------------------------


In [12]:
# Number of examples in total in all datasets
sum(dataset.num_rows for dataset in dataset_dict.values())


11227

In [13]:
# Sanity check: every sampled subset must contain more than one distinct
# label value. Checked per dataset so a failure names the offending one.
for dataset_name, dataset in dataset_dict.items():
    assert len(Counter(dataset["label"])) > 1, (
        f"{dataset_name}: the sampled subset contains a single label value"
    )

In [14]:
# Row count of each sampled subset
for dataset_name, dataset in dataset_dict.items():
    n_rows = dataset.num_rows
    print(f"{dataset_name}: {n_rows}")

ag_news: 1550
amazon_polarity: 1550
dbpedia_14: 1550
glue/qnli: 1550
imdb: 1550
piqa: 1550
super_glue/boolq: 1550
super_glue/copa: 100
super_glue/rte: 277


In [15]:
# Number of distinct label values in each sampled subset
for dataset_name, dataset in dataset_dict.items():
    label_counts = Counter(dataset["label"])
    print(f"{dataset_name}: {len(label_counts)}")

ag_news: 4
amazon_polarity: 2
dbpedia_14: 14
glue/qnli: 2
imdb: 2
piqa: 2
super_glue/boolq: 2
super_glue/copa: 2
super_glue/rte: 2


## Split into PPO-training and test subsets

In [16]:
# Divide every subset into a PPO-training part and a held-out test part,
# using the same fraction and seed for all datasets.
ppo_dataset_dict = {}
test_dataset_dict = {}

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)
    splits = dataset.train_test_split(test_size=0.1093, seed=42)
    ppo_part, test_part = splits["train"], splits["test"]
    print(ppo_part.num_rows, test_part.num_rows)

    ppo_dataset_dict[dataset_name] = ppo_part
    test_dataset_dict[dataset_name] = test_part
    print("-----------------------------------")

ag_news


1380 170
-----------------------------------
amazon_polarity
1380 170
-----------------------------------
dbpedia_14
1380 170
-----------------------------------
glue/qnli
1380 170
-----------------------------------
imdb
1380 170
-----------------------------------
piqa
1380 170
-----------------------------------
super_glue/boolq
1380 170
-----------------------------------
super_glue/copa
89 11
-----------------------------------
super_glue/rte
246 31
-----------------------------------


In [17]:
# Total number of examples in the PPO-training part, then in the test part
for split_dict in (ppo_dataset_dict, test_dataset_dict):
    print(sum(dataset.num_rows for dataset in split_dict.values()))

9995
1232


In [21]:
# datasets.DatasetDict(ppo_dataset_dict).save_to_disk(
#     f"datasets/burns_datasets_VINC_ppo_training_raw_{VERSION}"
# )

In [22]:
# datasets.DatasetDict(test_dataset_dict).save_to_disk(
#     f"datasets/burns_datasets_VINC_raw_{VERSION}"
# )

In [23]:
# List the entries of the datasets/ directory whose listing line contains VERSION
!ls -lah datasets | grep {VERSION}

drwxr-xr-x  9 augustas Domain Users  25K Jun 29 11:14 burns_datasets_VINC_ppo_training_raw_v4
drwxr-xr-x  9 augustas Domain Users  33K Jun 29 11:14 burns_datasets_VINC_raw_v4
-rw-r--r--  1 augustas Domain Users 2.8M Jun 29 10:48 burns_datasets_VINC_train_v4.parquet


## Get the templates

In [24]:
# Set dataset_dict to test_dataset
dataset_dict = datasets.DatasetDict.load_from_disk(
    f"datasets/burns_datasets_VINC_raw_{VERSION}"
)
dataset_dict

DatasetDict({
    ag_news: Dataset({
        features: ['text', 'label'],
        num_rows: 170
    })
    amazon_polarity: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 170
    })
    dbpedia_14: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 170
    })
    glue/qnli: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 170
    })
    imdb: Dataset({
        features: ['text', 'label'],
        num_rows: 170
    })
    piqa: Dataset({
        features: ['goal', 'sol1', 'sol2', 'label'],
        num_rows: 170
    })
    super_glue/boolq: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 170
    })
    super_glue/copa: Dataset({
        features: ['premise', 'choice1', 'choice2', 'question', 'idx', 'label'],
        num_rows: 11
    })
    super_glue/rte: Dataset({
        features: ['premise', 'hypothesis', 'idx', 'label'],
        num_rows: 31
    })
})

In [25]:
# Total number of examples across all loaded datasets
sum(map(len, dataset_dict.values()))

1232

In [26]:
# Number of distinct label values in each loaded dataset
for dataset_name, dataset in dataset_dict.items():
    distinct_labels = set(dataset["label"])
    print(f"{dataset_name}: {len(distinct_labels)}")

ag_news: 4
amazon_polarity: 2
dbpedia_14: 14
glue/qnli: 2
imdb: 2
piqa: 2
super_glue/boolq: 2
super_glue/copa: 2
super_glue/rte: 2


In [27]:
# Collect the prompt templates of every dataset, keeping only the templates
# that yield answer choices for the dataset's first example. The surviving
# templates are stored keyed by template name.
dataset_template_dict = {}

for dataset_path in dataset_dict:
    dataset_templates = DatasetTemplates(dataset_path)
    first_example = dataset_dict[dataset_path][0]

    usable_templates = {}
    for template in dataset_templates.templates.values():
        if template.get_answer_choices_list(first_example) is not None:
            usable_templates[template.name] = template
    dataset_templates.templates = usable_templates

    dataset_template_dict[dataset_path] = dataset_templates

In [28]:
# Number of usable templates per dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    n_templates = len(dataset_templates.templates)
    print(f"{dataset_name}: {n_templates}")

ag_news: 15
amazon_polarity: 11
dbpedia_14: 11
glue/qnli: 5
imdb: 13
piqa: 7
super_glue/boolq: 10
super_glue/copa: 9
super_glue/rte: 11


## Inspect the prompts

In [29]:
# dataset_name = "ag_news"

# for template_name, template in dataset_template_dict[dataset_name].templates.items():
#     # print(template_name)
#     # print(dataset_dict[dataset_name][0])
#     q, a = template.apply(
#         dataset_dict[dataset_name][0]
#     )
#     # print(q == q.strip())
#     # print(a == a.strip())
#     # print(" ".join([q, a.strip()]))
#     print(" ".join([q, a]))
#     print("---------------------------------")

## Form the dataset for the chosen split

In [33]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_fixed_answer_choices_list()}")

#     print("---------------------------------")


In [34]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_answer_choices_list(dataset_dict[dataset_name][0])}")

#     print("---------------------------------")

In [35]:
# Number of usable templates per dataset
for dataset_name, dataset_templates in dataset_template_dict.items():
    print(dataset_name, len(dataset_templates.templates), sep=": ")

ag_news: 15
amazon_polarity: 11
dbpedia_14: 11
glue/qnli: 5
imdb: 13
piqa: 7
super_glue/boolq: 10
super_glue/copa: 9
super_glue/rte: 11


In [36]:
%%time

# Build the final dataset: every raw example is rendered through a randomly
# chosen prompt template, with either its true answer (label 1, "truthful")
# or an incorrect answer (label 0, "untruthful").

# Reproducibility
random.seed(SEED)

# Keys kept in the final examples; everything else is dropped
ALLOWED_KEYS = ["text", "label", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # Per-dataset invariants, hoisted out of the per-example loop
    templates = dataset_template_dict[dataset_name].templates
    template_names = list(templates.keys())
    num_labels = len(dataset.features["label"].names)
    is_multiclass = num_labels > 2

    for entry in dataset:
        new_entry = entry.copy()
        true_label = new_entry["label"]

        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # NOTE: the order of the random draws below determines the generated
        # dataset for a given SEED, so it must not be changed.

        # Sample a random template
        template_name = random.choice(template_names)
        template = templates[template_name]
        new_entry["template_name"] = template_name

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])

        if is_truthful:
            new_text = " ".join(template.apply(new_entry))
        elif is_multiclass:
            # Pick an incorrect label deterministically (this is NOT a random
            # draw). Label ids are the positions in the ClassLabel's `names`.
            incorrect_label_id = (1 - true_label) % num_labels
            if incorrect_label_id == true_label:
                # With an odd number of classes the formula above maps one
                # label onto itself (e.g. 3 classes, label 2); step to the
                # next label so the answer is guaranteed to be wrong.
                incorrect_label_id = (incorrect_label_id + 1) % num_labels
            # NOTE(review): assumes the template has fixed answer choices
            # listed in label-id order - confirm for newly added templates.
            incorrect_label = template.get_fixed_answer_choices_list()[incorrect_label_id]

            # Render the prompt with the true label, then swap in the wrong answer.
            # presumably `replace_text_with_whitespace` keeps the whitespace
            # surrounding the original answer - verify against utils
            q, a = template.apply(new_entry)
            incorrect_label = replace_text_with_whitespace(a, incorrect_label)
            new_text = " ".join([q, incorrect_label])
        else:
            # Untruthful binary case: flip the label and let the template
            # render the (now wrong) answer
            new_entry["label"] = 1 - true_label
            new_text = " ".join(template.apply(new_entry))

        new_entry["text"] = new_text

        # We can now change the label to whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        # Remove all other keys
        new_entry = {k: v for k, v in new_entry.items() if k in ALLOWED_KEYS}

        # Append to the new dataset
        new_dataset.append(new_entry)

ag_news


amazon_polarity
dbpedia_14
glue/qnli
imdb
piqa
super_glue/boolq
super_glue/copa
super_glue/rte
CPU times: user 2.55 s, sys: 185 ms, total: 2.74 s
Wall time: 2.75 s


In [37]:
# Turn the list of per-example dicts into a Hugging Face Dataset
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['text', 'label', 'original_dataset', 'template_name'],
    num_rows: 1232
})

In [40]:
# Index of the example to inspect
current_idx = 0
my_dataset[current_idx]

{'text': "Consider the following example: ''' Bush Backs US Tariffs on Shrimp Foreign shrimp producers have denied they are selling shrimp at artificially low prices as a way to win a larger share of the US market. ''' Choice 1: Business. Choice 2: Science and technology.Between choice 1 and choice 2, what is the topic of this example? Science and technology",
 'label': 0,
 'original_dataset': 'ag_news',
 'template_name': 'burns_2'}

In [41]:
# Print only the rendered prompt text of the inspected example
print(my_dataset[current_idx]["text"])

Consider the following example: ''' Bush Backs US Tariffs on Shrimp Foreign shrimp producers have denied they are selling shrimp at artificially low prices as a way to win a larger share of the US market. ''' Choice 1: Business. Choice 2: Science and technology.Between choice 1 and choice 2, what is the topic of this example? Science and technology


In [42]:
# Counts of truthful (1) vs. untruthful (0) examples, and of examples per source dataset
Counter(my_dataset["label"]), Counter(my_dataset["original_dataset"])

(Counter({0: 623, 1: 609}),
 Counter({'ag_news': 170,
          'amazon_polarity': 170,
          'dbpedia_14': 170,
          'glue/qnli': 170,
          'imdb': 170,
          'piqa': 170,
          'super_glue/boolq': 170,
          'super_glue/rte': 31,
          'super_glue/copa': 11}))

In [45]:
# my_dataset.to_parquet(f"datasets/burns_datasets_VINC_{SPLIT}_{VERSION}.parquet")

In [44]:
# List the entries of the datasets/ directory whose listing line contains VERSION
!ls -lah datasets | grep {VERSION}

drwxr-xr-x  9 augustas Domain Users  25K Jun 29 11:14 burns_datasets_VINC_ppo_training_raw_v4
drwxr-xr-x  9 augustas Domain Users  33K Jun 29 11:14 burns_datasets_VINC_raw_v4
-rw-r--r--  1 augustas Domain Users 2.8M Jun 29 10:48 burns_datasets_VINC_train_v4.parquet
-rw-r--r--  1 augustas Domain Users 380K Jun 29 11:18 burns_datasets_VINC_validation_v4.parquet
