## Imports

In [1]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets
from datasets import load_dataset_builder, load_dataset

In [2]:
ELK_PATH = Path("../../../elk/")
ELK_PATH.resolve()

PosixPath('/rds/user/am3052/hpc-work/elk')

In [3]:
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    if not str(module) in sys.path:
        sys.path.insert(0, str(module.resolve()))

sys.path[:2]

['/rds/user/am3052/hpc-work/elk/elk/promptsource',
 '/rds/user/am3052/hpc-work/elk']

In [4]:
from templates import DatasetTemplates

In [5]:
# Disable the logging of the datasets library
import datasets

datasets.logging.set_verbosity_error()
# logging.getLogger("transformers").setLevel(logging.ERROR)

## Config

In [6]:
SPLIT = "train"

## Inspect the datasets

In [7]:
# all datasets used in Burns et al. (2022)
BURNS_DATASETS = [
    "ag_news",
    "amazon_polarity",
    "dbpedia_14",
    "glue:qnli",
    "imdb",
    "piqa",
    "super_glue:boolq",
    "super_glue:copa",
    "super_glue:rte",
]

Load in the datasets:

In [8]:
N_PER_DATASET = 100
assert N_PER_DATASET <= 1800, "N_PER_DATASET must be <= 1800"

dataset_dict = {}
for dataset_path in BURNS_DATASETS:
    print(dataset_path)

    # Parse dataset name
    dataset_name = None    
    if ":" in dataset_path:
        dataset_path, dataset_name = dataset_path.split(":")
    
    
    # Get the most validation-like split
    available_splits = load_dataset_builder(
        dataset_path, name=dataset_name
    ).info.splits.keys()
    split = "validation" if "validation" in available_splits else "test"
    split = split if SPLIT != "train" else "train"
    print(f"{split=}")

    # Load the dataset
    dataset = load_dataset(
        dataset_path, name=dataset_name, split=split,
    )

    # Get a desired subset of the data
    n = N_PER_DATASET if dataset.num_rows > N_PER_DATASET else dataset.num_rows
    dataset = dataset.shuffle(seed=42).select(range(n))

    key = f"{dataset_path}/{dataset_name}" if dataset_name else dataset_path
    dataset_dict[key] = dataset

    print("-----------------------------------")

ag_news
split='train'
-----------------------------------
amazon_polarity
split='train'
-----------------------------------
dbpedia_14
split='train'
-----------------------------------
glue:qnli
split='train'
-----------------------------------
imdb
split='train'
-----------------------------------
piqa
split='train'
-----------------------------------
super_glue:boolq
split='train'
-----------------------------------
super_glue:copa
split='train'
-----------------------------------
super_glue:rte
split='train'
-----------------------------------


In [9]:
assert all([len(Counter(dataset["label"])) > 1 for dataset in dataset_dict.values()])

In [22]:
for dataset_name, dataset in dataset_dict.items():
    print(f"{dataset_name}: {len(Counter(dataset['label']))}")

ag_news: 4
amazon_polarity: 2
dbpedia_14: 14
glue/qnli: 2
imdb: 2
piqa: 2
super_glue/boolq: 2
super_glue/copa: 2
super_glue/rte: 2


In [34]:
tmp = dataset_dict["ag_news"].features["label"]

[tmp.str2int(x) for x in tmp.names]

[0, 1, 2, 3]

In [10]:
dataset_dict.keys()

dict_keys(['ag_news', 'amazon_polarity', 'dbpedia_14', 'glue/qnli', 'imdb', 'piqa', 'super_glue/boolq', 'super_glue/copa', 'super_glue/rte'])

In [11]:
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    if ":" in dataset_path:
        dataset_path = dataset_path.replace(":", "/")

    dataset_templates = DatasetTemplates(dataset_path)

    dataset_templates.templates = {
        x.name: x for x in dataset_templates.templates.values()
    }

    dataset_template_dict[dataset_path] = dataset_templates

In [12]:
for dataset_name, dataset_templates in dataset_template_dict.items():
    print(f"{dataset_name}: {len(dataset_templates.templates)}")

ag_news: 15
amazon_polarity: 11
dbpedia_14: 11
glue/qnli: 5
imdb: 13
piqa: 9
super_glue/boolq: 10
super_glue/copa: 9
super_glue/rte: 11


## Inspect the prompts

In [None]:
dataset_name = "ag_news"

for template_name, template in dataset_template_dict[dataset_name].templates.items():
    # print(template_name)
    # print(dataset_dict[dataset_name][0])
    q, a = template.apply(
        dataset_dict[dataset_name][0]
    )
    # print(q == q.strip())
    # print(a == a.strip())
    # print(" ".join([q, a.strip()]))
    print(" ".join([q, a]))
    print("---------------------------------")

## Form the dataset for the chosen split

In [54]:
dataset_dict["ag_news"].features["label"]

ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)

In [60]:
tmp = list(dataset_template_dict["ag_news"].templates.values())[-3]
tmp.jinja

'{{text}} Which section of a newspaper would this article likely appear in, choice 1: {{answer_choices[label]}}, or choice 2: {{answer_choices[1 - label]}}? |||  {{answer_choices[label]}}'

In [64]:
dataset_dict["ag_news"][0]

{'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.',
 'label': 0}

In [68]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_fixed_answer_choices_list()}")

#     print("---------------------------------")


In [66]:
tmp.get_answer_choices_list(dataset_dict["ag_news"][0])

['World politics', 'Sports', 'Business', 'Science and technology']

In [55]:
# Reproducibility
random.seed(42)

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    for idx, entry in enumerate(dataset):
        # print(idx)
        # print(entry)
        new_entry = entry.copy()
        print(new_entry)
        
        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(
            list(dataset_template_dict[dataset_name].templates.keys())
        )
        template = dataset_template_dict[dataset_name].templates[template_name]

        # Binarize the labels, format is [false_label, true_label]
        # label_mappping = dataset.features["label"] # [class 1, class 2, ...]
        # all_label_ids = [label_mappping.str2int(x) for x in label_mappping.names] # [0, 1, ...]
        # incorrect_labels = [x for x in all_label_ids if x != new_entry["label"]]
        # random_incorrect_label = random.choice(incorrect_labels)
        # labels = [random_incorrect_label, new_entry["label"]]
        # print(labels)

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])
        print(f"{is_truthful=}")

        # If not truthful, sample a random incorrect label
        if not is_truthful:
            # Approach 1
            # print(new_entry["label"], dataset.features["label"].names[new_entry["label"]])
            # label_mappping = dataset.features["label"] # [class 1, class 2, ...]
            # all_label_ids = [label_mappping.str2int(x) for x in label_mappping.names] # [0, 1, ...]
            # incorrect_labels = [x for x in all_label_ids if x != new_entry["label"]]
            # print(f"{new_entry['label']}, {incorrect_labels=}")
            # new_entry["label"] = random.choice(incorrect_labels)
            # print(new_entry["label"], dataset.features["label"].names[new_entry["label"]])
            
            # Approach 2
            # labels = labels[::-1]
            # print(labels)

            # Approach 3
            print(new_entry["label"], dataset.features["label"].names[new_entry["label"]])
            label_mappping = dataset.features["label"] # [class 1, class 2, ...]
            all_label_ids = [label_mappping.str2int(x) for x in label_mappping.names] # [0, 1, ...]
            incorrect_label_id = all_label_ids[(1 - new_entry["label"]) % len(all_label_ids)]
            incorrect_label = template.get_fixed_answer_choices_list()[incorrect_label_id]
            print(incorrect_label, dataset.features["label"].names[incorrect_label_id])
            
        
        # Apply the template
        new_text = " ".join(template.apply(new_entry))
        new_entry["text"] = new_text

        # We can now change the label to whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        print(new_entry)
        new_dataset.append(new_entry)
        print("---------------------------------")

        if idx >= 10 - 1: break

    # for template_name, template in dataset_template_dict[dataset_name].templates.items():
    #     print(template_name)

    break

ag_news
{'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.', 'label': 0}
is_truthful=True
{'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally. Which is the topic of this example, choice 1: World politics, or choice 2: Sports? World politics', 'label': 1, 'original_dataset': 'ag_news'}
---------------------------------
{'text': 'Desiring Stability Redskins coach Joe Gibbs expects few major personnel changes in the offseason and wants to instill a culture of stability in Washington.', 'label': 1}
is_truthful=False
1 Sports
1, incorrect_labels=[0, 2, 3]
0 World
{'text': 'What label best describes this news article?\nDesiring Stability Redskins coach Joe Gibbs expects few major personnel changes in the offseason and wants 

In [56]:
-2 % 3

1

In [50]:
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['text', 'label', 'original_dataset'],
    num_rows: 10
})

In [51]:
my_dataset[0]

{'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally. Which is the topic of this example, choice 1: World politics, or choice 2: Sports? World politics',
 'label': 1,
 'original_dataset': 'ag_news'}

In [52]:
my_dataset["label"]

[1, 0, 1, 1, 1, 0, 1, 1, 1, 1]