## Imports

In [1]:
import sys
import random
from pathlib import Path
from collections import Counter

import datasets
from datasets import load_dataset_builder, load_dataset

from utils import replace_text_with_whitespace

In [2]:
# Location of the local `elk` checkout, relative to this notebook
ELK_PATH = Path("..", "..", "..", "elk")
ELK_PATH.resolve()

PosixPath('/fsx/home-augustas/elk')

In [3]:
# Directories that must be importable: the elk checkout itself and the
# promptsource package inside it.
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    # Test membership with the resolved path, because that is the form that is
    # inserted. Checking the relative string never matched, so every re-run of
    # this cell prepended duplicate entries to sys.path.
    module_path = str(module.resolve())
    if module_path not in sys.path:
        sys.path.insert(0, module_path)

sys.path[:2]

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk']

In [4]:
from templates import DatasetTemplates

In [5]:
# Disable the logging of the datasets library (only errors are still reported)
# NOTE(review): `datasets` is already imported in the imports cell, so this
# second import is redundant.
import datasets

datasets.logging.set_verbosity_error()

## Config

In [6]:
# Datasets to process; an entry is either "path" or "path:config_name"
BURNS_DATASETS = ["imdb"]

# Version tag used in the names of the artifacts under `datasets/`
VERSION = "v3"

# Which split to build from: "train" or "validation"
SPLIT = "train"

# Upper bound on the number of rows kept per dataset
N_PER_DATASET = 25000 if SPLIT == "train" else 1550

# Each split gets its own shuffling / sampling seed
SEED = 42 if SPLIT == "train" else 2023
SEED

42

## Load and inspect the datasets

In [7]:
dataset_dict = {}
for dataset_spec in BURNS_DATASETS:
    print(dataset_spec)

    # An entry is either "path" or "path:config_name"
    if ":" in dataset_spec:
        dataset_path, dataset_name = dataset_spec.split(":")
    else:
        dataset_path, dataset_name = dataset_spec, None

    # Ask the builder which splits this dataset provides
    builder = load_dataset_builder(dataset_path, name=dataset_name)
    available_splits = builder.info.splits.keys()

    # "train" is taken as is; otherwise use the most validation-like split
    if SPLIT == "train":
        split = "train"
    elif "validation" in available_splits:
        split = "validation"
    else:
        split = "test"
    print(f"{split=}")

    # Load the dataset
    dataset = load_dataset(dataset_path, name=dataset_name, split=split)

    # Shuffle, then keep at most N_PER_DATASET rows
    n = min(N_PER_DATASET, dataset.num_rows)
    dataset = dataset.shuffle(seed=SEED).select(range(n))

    print(dataset.num_rows)

    # Datasets with a config are stored under "path/config_name"
    if dataset_name:
        key = f"{dataset_path}/{dataset_name}"
    else:
        key = dataset_path
    dataset_dict[key] = dataset

    print("-----------------------------------")

imdb


split='train'
25000
-----------------------------------


In [8]:
# Total number of examples across all loaded datasets
sum(ds.num_rows for ds in dataset_dict.values())


25000

In [9]:
# Sanity check: every dataset must contain more than one distinct label
assert all(
    len(Counter(ds["label"])) > 1
    for ds in dataset_dict.values()
)

In [10]:
# Row count of each dataset
for name, ds in dataset_dict.items():
    n_rows = ds.num_rows
    print(f"{name}: {n_rows}")

imdb: 25000


In [11]:
# Number of distinct labels in each dataset
for name, ds in dataset_dict.items():
    n_labels = len(Counter(ds["label"]))
    print(f"{name}: {n_labels}")

imdb: 2


## Split into train and ppo

In [12]:
# Fraction of each dataset that goes to the PPO subset; the remainder becomes
# the train subset.
PPO_FRACTION = 0.6

train_dataset_dict = {}
ppo_dataset_dict = {}

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # `train_test_split` calls its two parts "train" and "test";
    # the "test" part is used as the PPO subset here.
    splits = dataset.train_test_split(test_size=PPO_FRACTION, seed=SEED)
    print(splits["train"].num_rows, splits["test"].num_rows)

    train_dataset_dict[dataset_name] = splits["train"]
    ppo_dataset_dict[dataset_name] = splits["test"]
    print("-----------------------------------")

imdb
10000 15000
-----------------------------------


In [13]:
# Total rows in the train subset, then in the PPO subset
for subset in (train_dataset_dict, ppo_dataset_dict):
    print(sum(ds.num_rows for ds in subset.values()))

10000
15000


In [15]:
# Persist the raw train subset; a later cell reloads it from this path.
# The save is skipped when the directory already exists, so re-running the
# notebook does not overwrite an earlier copy (bump VERSION to regenerate).
train_raw_path = Path(f"datasets/burns_datasets_VINC_imdb_train_raw_{VERSION}")
if not train_raw_path.exists():
    datasets.DatasetDict(train_dataset_dict).save_to_disk(str(train_raw_path))

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

In [16]:
# Persist the raw PPO subset.
# The save is skipped when the directory already exists, so re-running the
# notebook does not overwrite an earlier copy (bump VERSION to regenerate).
ppo_raw_path = Path(f"datasets/burns_datasets_VINC_imdb_ppo_training_raw_{VERSION}")
if not ppo_raw_path.exists():
    datasets.DatasetDict(ppo_dataset_dict).save_to_disk(str(ppo_raw_path))

Saving the dataset (0/1 shards):   0%|          | 0/15000 [00:00<?, ? examples/s]

In [17]:
# List the saved IMDB artifacts of the current VERSION
# (drop the trailing `grep imdb` to list every artifact of this version)
!ls -lah datasets | grep {VERSION} | grep imdb

drwxr-xr-x  3 augustas Domain Users 1.0K Jul 20 14:23 burns_datasets_VINC_imdb_ppo_training_raw_v3
drwxr-xr-x  3 augustas Domain Users 1.0K Jul 20 14:23 burns_datasets_VINC_imdb_train_raw_v3


## Get the templates

In [18]:
# Load the raw *train* subset (the `..._train_raw_` artifact) from disk.
# From here on `dataset_dict` refers to this subset, not to the full
# shuffled datasets loaded earlier.
dataset_dict = datasets.DatasetDict.load_from_disk(
    f"datasets/burns_datasets_VINC_imdb_train_raw_{VERSION}"
)
dataset_dict

DatasetDict({
    imdb: Dataset({
        features: ['text', 'label'],
        num_rows: 10000
    })
})

In [19]:
# Total number of rows across the reloaded datasets
sum(ds.num_rows for ds in dataset_dict.values())

10000

In [21]:
dataset_template_dict = {}

for dataset_path in dataset_dict.keys():
    # Custom templates are looked up under "<dataset>/custom"
    dataset_templates = DatasetTemplates(dataset_path + "/custom")

    # Fetch the first example once instead of once per template
    first_example = dataset_dict[dataset_path][0]

    # Keep only the templates that return answer choices for that example
    dataset_templates.templates = {
        template.name: template
        for template in dataset_templates.templates.values()
        if template.get_answer_choices_list(first_example) is not None
    }

    dataset_template_dict[dataset_path] = dataset_templates

In [22]:
# Number of usable templates per dataset
for name, templates_for_dataset in dataset_template_dict.items():
    n_templates = len(templates_for_dataset.templates)
    print(f"{name}: {n_templates}")

imdb: 5


## Inspect the prompts

In [24]:
# dataset_name = "imdb"

# for template_name, template in dataset_template_dict[dataset_name].templates.items():
#     # print(template_name)
#     q, a = template.apply(
#         dataset_dict[dataset_name][0]
#     )
#     print(" ".join([q, a]))
#     print("---------------------------------")

## Form the dataset for the chosen split

In [25]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_fixed_answer_choices_list()}")

#     print("---------------------------------")


In [26]:
# for dataset_name, dataset_templates in dataset_template_dict.items():
#     print(dataset_name)
#     for template_name, template in dataset_templates.templates.items():
#         print(f"{template_name}: {template.get_answer_choices_list(dataset_dict[dataset_name][0])}")

#     print("---------------------------------")

In [25]:
# Re-check the number of usable templates per dataset before building the data
for name, templates_for_dataset in dataset_template_dict.items():
    n_templates = len(templates_for_dataset.templates)
    print(f"{name}: {n_templates}")

imdb: 5


In [27]:
%%time

# Reproducibility
random.seed(SEED)

# Columns kept in the final dataset
ALLOWED_KEYS = ["text", "label", "original_dataset", "template_name"]

new_dataset = []

for dataset_name, dataset in dataset_dict.items():
    print(dataset_name)

    # The template pool is fixed per dataset, so build the list of names once
    # here instead of once per example.
    templates = dataset_template_dict[dataset_name].templates
    template_names = list(templates.keys())

    for entry in dataset:
        new_entry = entry.copy()

        # In case we need to know which dataset the entry came from
        new_entry["original_dataset"] = dataset_name

        # Sample a random template
        template_name = random.choice(template_names)
        template = templates[template_name]
        new_entry["template_name"] = template_name

        # Whether the sample will be truthful or not
        is_truthful = random.choice([True, False])

        # Untruthful case: flip the label before templating, so the template
        # is applied with the opposite label (only valid for binary 0/1 labels)
        if not is_truthful:
            new_entry["label"] = 1 - new_entry["label"]

        # Apply the template and join the parts it returns with a space
        new_entry["text"] = " ".join(template.apply(new_entry))

        # The label now records whether the sample is truthful or not
        new_entry["label"] = int(is_truthful)

        # Remove all other keys
        new_entry = {k: v for k, v in new_entry.items() if k in ALLOWED_KEYS}

        # Append to the new dataset
        new_dataset.append(new_entry)

imdb


CPU times: user 15.9 s, sys: 70.3 ms, total: 15.9 s
Wall time: 16 s


In [28]:
# Turn the list of per-example dicts into a `datasets.Dataset`
my_dataset = datasets.Dataset.from_list(new_dataset)
my_dataset

Dataset({
    features: ['text', 'label', 'original_dataset', 'template_name'],
    num_rows: 10000
})

In [29]:
# Inspect a single templated example
current_idx = 0
my_dataset[current_idx]

{'text': "This movie i've loved since i was young! Its excellent. Although, it may be a bit much for the average movie watcher if one can't interpret certain subtleties in the film (for example, our hero's name is Achilles, and in the final battle between him and Alexander he's shot in the heel with a rocket, just as Achilles in mythology was shot in his heel). That's a just a little fact that is kind of amusing! Anyway, great movie, good story, it'd be neat to see it redone with today's special effects! Oddly enough, Gary Graham had average success, starring in the T.V. show Alien Nation. This movie is a fun watch and should be more appreciated!\n\nDid the reviewer find this movie good or bad? good",
 'label': 1,
 'original_dataset': 'imdb',
 'template_name': 'Reviewer Opinion bad good choices'}

In [33]:
# Print a handful of templated examples together with their labels
for current_index in range(10, 20):
    example = my_dataset[current_index]
    print(f"label={example['label']}")
    print(example["text"])
    print("---------------------------------")

label=0
This story was never among my favourites in Christie's works so I was pleasantly surprised to quite enjoy this adaptation. The mouse motif was effective if a little overdone, the bones of the story are there although more emphasis is placed on the 'crime in the past' subplot. The students were all pretty much as I imagined them although its a pity they weren't a more cosmopolitan bunch - perhaps the revised thirties setting didn't allow for that! I thought some very daring risks were taken with the filming; perhaps its because I've not long re-read the book but it seemed pretty obvious to me who the murderer was from their appearance in some reveal shots quite early on.<br /><br />Humour was much more prevalent in these early Poirots. Sometimes it works but I found a lot of it rather heavy handed in this episode (though I did smile at the 'Lemon sole' throwaway line). Altogether though, a solid entry in the series though not one of the best. 
Is this review positive or negative

In [34]:
# Distribution of the truthfulness label and of the source dataset
label_counts = Counter(my_dataset["label"])
source_counts = Counter(my_dataset["original_dataset"])
label_counts, source_counts

(Counter({0: 5025, 1: 4975}), Counter({'imdb': 10000}))

In [37]:
# Write the final templated dataset to parquet.
# The write is skipped when the file already exists, so re-running the
# notebook does not overwrite an earlier copy (bump VERSION to regenerate).
parquet_path = Path(f"datasets/burns_datasets_VINC_imdb_{SPLIT}_{VERSION}.parquet")
if not parquet_path.exists():
    my_dataset.to_parquet(str(parquet_path))

In [3]:
# List the IMDB artifacts of the current VERSION, including the final parquet file
!ls -lah datasets | grep {VERSION} | grep imdb

drwxr-xr-x  3 augustas Domain Users  25K Jul 20 14:23 burns_datasets_VINC_imdb_ppo_training_raw_v3
drwxr-xr-x  3 augustas Domain Users  25K Jul 20 14:23 burns_datasets_VINC_imdb_train_raw_v3
-rw-r--r--  1 augustas Domain Users 7.4M Jul 20 14:36 burns_datasets_VINC_imdb_train_v3.parquet
