In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [68]:
# Token-count limit applied to each tokenized chosen / rejected sequence.
MAX_LENGTH = 1024

# Indices into the *unfiltered* train split whose examples tokenize to more
# than MAX_LENGTH tokens; these rows are filtered out of the train split below.
INVALID_INDICES = [
    4050, 5584, 11484, 12970,
    16268, 17807, 27560, 27891,
    31253, 46463, 50176, 53368,
    58029, 58571, 60007, 63541,
    64250, 72894, 73473, 74163,
]
len(INVALID_INDICES)

20

## Load the dataset

In [79]:
# Load all splits of the preference dataset; the HF cache lives two
# directories above this notebook so repeated runs reuse the download.
dataset = load_dataset(
    path="Dahoas/rm-static",
    cache_dir="../../.hf_cache/datasets",
)
dataset

Found cached dataset parquet (/rds/user/am3052/hpc-work/elk-rlhf/notebooks/../../.hf_cache/datasets/Dahoas___parquet/default-b9d2c4937d617106/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Dataset({
    features: ['prompt', 'response', 'chosen', 'rejected'],
    num_rows: 5103
})

In [71]:
# Drop training rows known to tokenize past MAX_LENGTH (listed in INVALID_INDICES).
# WARNING: not idempotent. Re-running this cell filters the already-filtered
# split, whose indices have shifted, and removes the wrong rows. Re-run the
# load cell first if you need to redo this step.
_invalid_train_indices = set(INVALID_INDICES)  # set -> O(1) membership per row
dataset["train"] = dataset["train"].filter(
    lambda _row, idx: idx not in _invalid_train_indices,
    with_indices=True,
)
dataset

                                                                       

DatasetDict({
    train: Dataset({
        features: ['prompt', 'response', 'chosen', 'rejected'],
        num_rows: 76236
    })
    test: Dataset({
        features: ['prompt', 'response', 'chosen', 'rejected'],
        num_rows: 5103
    })
})

In [72]:
# Fraction of all (post-filter) examples that fall in the test split.
n_train = dataset["train"].num_rows
n_test = dataset["test"].num_rows
n_test / (n_test + n_train)

0.06273743222808247

## Tokenize the dataset

In [73]:
TOKENIZER_NAME = "EleutherAI/gpt-j-6b"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# End-of-sequence string exposed by the loaded tokenizer.
EOS_TOKEN = tokenizer.eos_token

In [74]:
def tokenize_row(row, tokenizer, max_length=1024):
    """Tokenize the chosen and rejected completions of one preference-pair row.

    Each completion is tokenized as ``prompt + completion + tokenizer.eos_token``.

    Args:
        row: Mapping with at least ``"prompt"``, ``"chosen"`` and ``"rejected"``
            string fields.
        tokenizer: Hugging Face-style tokenizer: callable returning a mapping
            with ``"input_ids"`` and exposing an ``eos_token`` attribute.
        max_length: Currently unused; no truncation is applied here. Over-long
            rows are handled by the INVALID_INDICES filter and the explicit
            length checks elsewhere in the notebook. Kept for backward
            compatibility with existing callers.

    Returns:
        ``{"chosen": ids, "rejected": ids}`` (in that key order), where each
        ``ids`` is the ``input_ids`` produced with ``return_tensors="pt"``,
        i.e. shape ``(1, seq_len)`` for a single string.
    """
    # Use the EOS token of the tokenizer actually passed in, rather than a
    # global captured from some other tokenizer instance.
    eos = tokenizer.eos_token
    prompt = row["prompt"]
    return {
        key: tokenizer(prompt + row[key] + eos, return_tensors="pt")["input_ids"]
        for key in ("chosen", "rejected")
    }

tokenized_text = tokenize_row(dataset["train"][0], tokenizer)

# Peek at one example: sequence lengths, output keys, first 10 chosen token ids.
seq_lengths = [ids.shape[1] for ids in tokenized_text.values()]
seq_lengths, list(tokenized_text), tokenized_text["chosen"][:, :10]

([183, 135],
 ['chosen', 'rejected'],
 tensor([[  198,   198, 20490,    25,  1680,   345,  6901,   262,  4831,   284]]))

In [75]:
# Replace the raw "chosen"/"rejected" strings with their token ids, per split.
# The guard keeps this cell safe to re-run: once a split is tokenized its
# "chosen" field is no longer a string, and tokenize_row would fail on
# `prompt + <list>`. The "prompt" and "response" text columns are kept; drop
# them with `dataset.remove_columns(["prompt", "response"])` if not needed.
for split_name in ("train", "test"):
    if isinstance(dataset[split_name][0]["chosen"], str):
        dataset[split_name] = dataset[split_name].map(
            tokenize_row, fn_kwargs={"tokenizer": tokenizer}
        )
dataset

Loading cached processed dataset at /rds/user/am3052/hpc-work/.hf_cache/datasets/Dahoas___parquet/default-b9d2c4937d617106/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-5e48db60fca6235f.arrow


DatasetDict({
    train: Dataset({
        features: ['prompt', 'response', 'chosen', 'rejected'],
        num_rows: 76236
    })
    test: Dataset({
        features: ['prompt', 'response', 'chosen', 'rejected'],
        num_rows: 5103
    })
})

In [76]:
# "chosen"/"rejected" now hold nested lists of token ids; the outer list is
# the leading dimension produced by return_tensors="pt".
first_train_example = dataset["train"][0]
first_train_example

{'prompt': '\n\nHuman: Can you describe the steps to clean fingerprints and smudges from a laptop screen\n\nAssistant: Yes, certainly. To clean your screen, you first need to use a microfiber cloth or soft, damp cloth to gently wipe down the surface of the screen. Next, you’ll want to grab a soft, lint-free, microfiber cleaning cloth and gently rub it back and forth across the screen to remove fingerprints and smudges.\n\nHuman: Can I spray isopropyl alcohol onto the cloth and clean it that way?\n\nAssistant:',
 'response': ' Yes, you can do that to help the cloth pick up even more dirt from the screen. Be sure to always use a clean, soft cloth, not a piece of scratchy, roughened, or textured material, and make sure it’s lint-free.',
 'chosen': [[198,
   198,
   20490,
   25,
   1680,
   345,
   6901,
   262,
   4831,
   284,
   3424,
   34290,
   290,
   895,
   463,
   3212,
   422,
   257,
   13224,
   3159,
   198,
   198,
   48902,
   25,
   3363,
   11,
   3729,
   13,
   1675,
 

In [77]:
# Sanity check: after filtering, no training example should have a chosen or
# rejected sequence longer than MAX_LENGTH tokens (expect an empty list).
bad_idxs = [
    idx
    for idx, example in enumerate(dataset["train"])
    if max(len(example["chosen"][0]), len(example["rejected"][0])) > MAX_LENGTH
]
bad_idxs

[]

In [78]:
# Same over-length check as for the train split, collected into a list so the
# result is always displayed (an empty list confirms no test example has a
# chosen or rejected sequence longer than MAX_LENGTH tokens).
bad_test_idxs = [
    i
    for i, sample in enumerate(dataset["test"])
    if len(sample["chosen"][0]) > MAX_LENGTH or len(sample["rejected"][0]) > MAX_LENGTH
]
bad_test_idxs