In [1]:
import datasets
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Identifier of the dataset to load; its samples expose a "chosen" and a
# "rejected" text field (both are read in collate_fn below).
dataset_name = "Anthropic/hh-rlhf"
# Load the dataset; it is indexed by split name further down (dataset["train"]).
dataset = load_dataset(dataset_name)

Found cached dataset json (/home/codespace/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-1cfce6f8a62eee84/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 2/2 [00:05<00:00,  2.61s/it]


In [6]:
from transformers import AutoTokenizer

# for now we use the pretrained gpt2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# collate_fn asks the tokenizer to pad (padding="max_length"), which needs a
# pad token; reuse the end-of-sequence token for that purpose
tokenizer.pad_token = tokenizer.eos_token

In [7]:
# now we want to wrap the dataset in a pytorch dataloader;
# for that we have to create a custom collate function

# every tokenized text is padded or truncated to exactly this many tokens
# (used as max_length together with padding="max_length" and truncation=True)
max_length = 512

def collate_fn(list_of_samples):
    """
    Collate a list of dataset samples into one tokenized batch.

    Each sample is expected to be a dict holding two texts:
        - "chosen" : the chosen text
        - "rejected" : the rejected text

    Both fields are tokenized with the module-level ``tokenizer``; every
    sequence is truncated / padded to exactly ``max_length`` tokens.

    Args:
        list_of_samples: the samples that make up one batch.

    Returns:
        A dict with the keys "chosen" and "rejected", each mapping to the
        tokenizer output for that field across the whole batch. No
        ``return_tensors`` is requested from the tokenizer.
    """

    def _tokenize_field(field):
        # gather this field from every sample, then tokenize them in one call
        texts = [sample[field] for sample in list_of_samples]
        return tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )

    # "chosen" is tokenized first, then "rejected"; the key order is kept
    return {field: _tokenize_field(field) for field in ("chosen", "rejected")}

# we want to use the pytorch dataloader
from torch.utils.data import DataLoader

# build the dataloader over the "train" split: samples are grouped into
# batches of up to 8 and each group is tokenized by collate_fn
dataloader = DataLoader(dataset["train"], batch_size=8, collate_fn=collate_fn)

# sanity check: fetch only the first batch, print it, and stop iterating
for batch in dataloader:
    print(batch)
    break


{'chosen': {'input_ids': [[198, 198, 20490, 25, 1867, 389, 617, 269, 1046, 2456, 287, 46932, 30, 198, 198, 48902, 25, 3423, 447, 247, 82, 281, 17503, 1351, 13, 198, 198, 8021, 11, 19317, 11, 809, 26679, 11, 18824, 11, 5089, 11, 7510, 11, 21551, 11, 256, 2799, 11, 7510, 2256, 11, 7510, 21454, 11, 629, 10599, 388, 11, 40267, 11, 40107, 11, 5089, 263, 11, 7510, 12, 30041, 11, 10973, 11, 269, 2178, 38811, 11, 5089, 77, 1018, 1136, 11, 475, 400, 2305, 11, 40125, 11, 14509, 562, 11, 269, 3320, 12603, 11, 29836, 11, 43546, 11, 18314, 11, 19311, 11, 6611, 11, 266, 962, 11, 474, 1042, 11, 10973, 12, 82, 19296, 11, 22938, 378, 11, 277, 9460, 313, 11, 24506, 11, 474, 6457, 11, 474, 6457, 12, 75, 7958, 11, 37833, 11, 33526, 11, 1125, 729, 11, 329, 6988, 1352, 11, 781, 2238, 7357, 11, 9583, 1891, 11, 10816, 11, 16949, 11, 32581, 296, 578, 11, 3095, 1136, 11, 285, 1689, 447, 247, 82, 2933, 11, 277, 9460, 313, 11, 583, 1851, 11, 24506, 11, 629, 2178, 363, 11, 21551, 11, 198, 198, 20490, 25, 1867, 338