In [5]:
# Autoreload
%load_ext autoreload
%autoreload 2

## Imports

In [25]:
import random
from collections import Counter

import datasets
datasets.logging.set_verbosity_error() # Disable the logging of the datasets library
from datasets import load_dataset_builder, load_dataset

from fastchat.model import get_conversation_template

## Config

In [8]:
BURNS_DATASETS = ["glue:qnli"]

VERSION = f"v1"

# SPLIT = "validation"
SPLIT = "train"

N_PER_DATASET = 105_000 if SPLIT == "train" else 1550

SEED = 42 if SPLIT == "train" else 2023
SEED

42

## Load and inspect the datasets

In [9]:
dataset_dict = {}
for dataset_path in BURNS_DATASETS:
    print(dataset_path)

    # Parse dataset name
    dataset_name = None    
    if ":" in dataset_path:
        dataset_path, dataset_name = dataset_path.split(":")
    
    
    # Get the most validation-like split
    available_splits = load_dataset_builder(
        dataset_path, name=dataset_name
    ).info.splits.keys()
    split = "validation" if "validation" in available_splits else "test"
    split = split if SPLIT != "train" else "train"
    print(f"{split=}")

    # Load the dataset
    dataset = load_dataset(
        dataset_path, name=dataset_name, split=split,
    )

    # Get a desired subset of the data
    n = N_PER_DATASET if dataset.num_rows > N_PER_DATASET else dataset.num_rows
    dataset = dataset.shuffle(seed=SEED).select(range(n))

    print(dataset.num_rows)

    key = f"{dataset_path}/{dataset_name}" if dataset_name else dataset_path
    dataset_dict[key] = dataset

    print("-----------------------------------")

glue:qnli


split='train'
104743
-----------------------------------


In [10]:
# Number of examples in total in all datasets
sum([dataset.num_rows for dataset in dataset_dict.values()])


104743

In [11]:
assert all([len(Counter(dataset["label"])) > 1 for dataset in dataset_dict.values()])

In [12]:
for dataset_name, dataset in dataset_dict.items():
    print(f"{dataset_name}: {dataset.num_rows}")

glue/qnli: 104743


In [13]:
for dataset_name, dataset in dataset_dict.items():
    print(f"{dataset_name}: {len(Counter(dataset['label']))}")

glue/qnli: 2


## Split into train and ppo

In [14]:
train_dataset_dict = {}
ppo_dataset_dict = {}

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)
    # test_size = 225 if dataset.num_rows > 225 else 10 # specialized for COPA
    splits = dataset.train_test_split(train_size=10_000, seed=SEED)
    # print(splits)
    print(splits["train"].num_rows, splits["test"].num_rows)

    train_dataset_dict[dataset_name] = splits["train"]
    ppo_dataset_dict[dataset_name] = splits["test"]
    print("-----------------------------------")

glue/qnli
10000 94743
-----------------------------------


In [15]:
print(sum(dataset.num_rows for dataset in train_dataset_dict.values()))
print(sum(dataset.num_rows for dataset in ppo_dataset_dict.values()))

10000
94743


In [21]:
# datasets.DatasetDict(train_dataset_dict).save_to_disk(
#     f"datasets/qnli_vicuna_train_raw_{VERSION}"
# )

In [20]:
# datasets.DatasetDict(ppo_dataset_dict).save_to_disk(
#     f"datasets/qnli_vicuna_ppo_training_raw_{VERSION}"
# )

In [22]:
!ls -lah datasets | grep {VERSION} | grep qnli

drwxr-xr-x  3 augustas Domain Users 1.0K Jul 21 22:46 qnli_vicuna_ppo_training_raw_v1
drwxr-xr-x  3 augustas Domain Users  33K Jul 21 22:45 qnli_vicuna_train_raw_v1


## The template

In [23]:
# Set dataset_dict to test_dataset
dataset_dict = datasets.DatasetDict.load_from_disk(
    f"datasets/qnli_vicuna_train_raw_{VERSION}"
)
dataset_dict

DatasetDict({
    glue/qnli: Dataset({
        features: ['question', 'sentence', 'label', 'idx'],
        num_rows: 10000
    })
})

In [24]:
sum(dataset.num_rows for dataset in dataset_dict.values())

10000

In [27]:
def format_label(label):
    return "No" if label else "Yes"

# Dataset template
def template(example, answer_prefix="Answer:"):
    conv = get_conversation_template("lmsys/vicuna-7b-v1.3")

    message = (
        "Consider the sentence below in triple backticks "
        "and corresponding question. Does the sentence contain enough information "
        "to answer the question? Your answer should be either yes or no.\n\n"
        "Desired format:\n"
        "Answer: <your_answer>\n"
        f"Do not print \"{answer_prefix}\" again, just what you think the answer is.\n\n"
        f"Sentence:\n```\n{example['sentence']}\n```\n"
        f"Question: {example['question']}?\n"
        f"{answer_prefix}"
    )

    conv.append_message(conv.roles[0], message)
    conv.append_message(conv.roles[1], None)

    return conv.get_prompt(), format_label(example["label"])


q, a = template(dataset[2])
print(" ".join([q, a]))

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Consider the sentence below in triple backticks and corresponding question. Does the sentence contain enough information to answer the question? Your answer should be either yes or no.

Desired format:
Answer: <your_answer>
Do not print "Answer:" again, just what you think the answer is.

Sentence:
```
The Quran in its present form is generally considered by academic scholars to record the words spoken by Muhammad because the search for variants has not yielded any differences of great significance.[page needed]
```
Question: Which which university is Fred Donner affiliated??
Answer: ASSISTANT: No


## Form the dataset for the chosen split

In [28]:
%%time

# Reproducibility
random.seed(SEED)

ALLOWED_KEYS = ["text", "label", "original_dataset"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    for idx, entry in enumerate(dataset):
        new_entry = entry.copy()
        
        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])            
        
        # Apply the template
        if not is_truthful:
            # Untruthful binary case
            new_entry["label"] = 1 - new_entry["label"]
        new_text = " ".join(template(new_entry))

        new_entry["text"] = new_text

        # We can now change the label to whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        # Remove all other keys
        new_entry = { k: v for k, v in new_entry.items() if k in ALLOWED_KEYS }

        # Append to the new dataset
        new_dataset.append(new_entry)

glue/qnli
CPU times: user 421 ms, sys: 15.5 ms, total: 436 ms
Wall time: 469 ms


In [29]:
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['label', 'original_dataset', 'text'],
    num_rows: 10000
})

In [30]:
current_idx = 0
my_dataset[current_idx]

{'label': 1,
 'original_dataset': 'glue/qnli',
 'text': 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions. USER: Consider the sentence below in triple backticks and corresponding question. Does the sentence contain enough information to answer the question? Your answer should be either yes or no.\n\nDesired format:\nAnswer: <your_answer>\nDo not print "Answer:" again, just what you think the answer is.\n\nSentence:\n```\nModern web browsers support a combination of standards-based and de facto HTML and XHTML, which should be rendered in the same way by all browsers.\n```\nQuestion: HTML and XHTML should be what by all browsers??\nAnswer: ASSISTANT: Yes'}

In [31]:
for current_index in range(10, 20):
    print(f"label={my_dataset[current_index]['label']}")
    print(my_dataset[current_index]["text"])
    print("---------------------------------")

label=1
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Consider the sentence below in triple backticks and corresponding question. Does the sentence contain enough information to answer the question? Your answer should be either yes or no.

Desired format:
Answer: <your_answer>
Do not print "Answer:" again, just what you think the answer is.

Sentence:
```
Over time, Roman architecture was modified as their urban requirements changed, and the civil engineering and building construction technology became developed and refined.
```
Question: Do any Roman structures still exist in our time??
Answer: ASSISTANT: No
---------------------------------
label=1
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Consider the sentence below in triple backticks and corresponding

In [32]:
Counter(my_dataset["label"]), Counter(my_dataset["original_dataset"])

(Counter({0: 5016, 1: 4984}), Counter({'glue/qnli': 10000}))

In [35]:
# my_dataset.to_parquet(f"datasets/qnli_vicuna_{SPLIT}_{VERSION}.parquet")

In [36]:
!ls -lah datasets | grep {VERSION} | grep qnli

drwxr-xr-x  3 augustas Domain Users  25K Jul 21 22:46 qnli_vicuna_ppo_training_raw_v1
drwxr-xr-x  3 augustas Domain Users  33K Jul 21 22:45 qnli_vicuna_train_raw_v1
-rw-r--r--  1 augustas Domain Users 2.0M Jul 21 23:02 qnli_vicuna_train_v1.parquet
