"""
Load PPLM dataset
"""
from typing import Dict, List, Optional, Iterator, Callable, Union, Tuple
import os
import json
from collections import defaultdict
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from toxicity.train_dpo.dpo_utils import get_local_dir, TemporarilySeededRandom
from constants import DATA_DIR, GPT2_PAD_IDX
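
# Each .jsonl record under DATA_DIR/toxicity_pairwise is expected to supply at
# least the fields consumed below (a schema inferred from the keys this module
# accesses; any other fields are ignored):
#   prompt_text      - the prompt fed to the model
#   unpert_gen_text  - the unperturbed ("gold") continuation
#   pert_gen_text    - the PPLM-perturbed continuation (the negative example)
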
def get_pplm_batch_iterator(
    tokenizer,
    config,
    split: str = "train",
    device: str = "cuda",
) -> Iterator[Dict]:
    """
    Get an iterator over batches of pairwise PPLM data.

    :param tokenizer: Tokenizer for prompts and generations (assumed to
        arrive with padding_side="left"; see the comments below).
    :param config: Config exposing batch_size, eval_batch_size,
        max_prompt_length, max_new_tokens, and valid_size.
    :param split: Which split to use ("train" or "valid").
    :param device: Device onto which tokenized batches are moved.
    """
    assert split in ["train", "valid"]
    data_dir = os.path.join(DATA_DIR, "toxicity_pairwise")

    batch_size = config.batch_size
    if split == "valid":
        batch_size = config.eval_batch_size
    max_prompt_length = config.max_prompt_length
    max_new_tokens = config.max_new_tokens
    valid_size = config.valid_size

    filenames = [
        os.path.join(data_dir, filename)
        for filename in os.listdir(data_dir)
        if filename.endswith(".jsonl")
    ]
    data = []
    for filename in tqdm(filenames):
        with open(filename, "r") as file_p:
            file_data = file_p.readlines()
        data.extend(file_data)

    # Shuffle the combined data with a fixed seed so the train/valid
    # partition stays disjoint across separate calls.
    random.Random(0).shuffle(data)

    # The last `valid_size` shuffled lines are held out for validation.
    if split == "train":
        data = data[:-valid_size]
    else:
        data = data[-valid_size:]
    data_size = len(data)
    for idx in range(0, data_size, batch_size):
        batch = data[idx : idx + batch_size]
        batch = [json.loads(x.strip()) for x in batch]

        prompt_text = [x["prompt_text"] for x in batch]
        gold_text = [x["unpert_gen_text"] for x in batch]

        # Prompts are left-padded so that continuations can be appended
        # directly after the final prompt token.
        prompt_tokenized = tokenizer(
            prompt_text,
            max_length=max_prompt_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        prompt_input_ids = prompt_tokenized["input_ids"]
        prompt_attention_mask = prompt_tokenized["attention_mask"]
        # Continuations are right-padded so that padding ends up at the tail
        # of the concatenated (prompt + continuation) sequences; the
        # tokenizer's left padding is restored afterwards.
        tokenizer.padding_side = "right"
        gold_tokenized = tokenizer(
            gold_text,
            max_length=max_new_tokens,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        gold_input_ids = gold_tokenized["input_ids"].long()

        pplm_text = [x["pert_gen_text"] for x in batch]
        pplm_tokenized = tokenizer(
            pplm_text,
            max_length=max_new_tokens,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        tokenizer.padding_side = "left"
        # Positive sequences: prompt + gold continuation.
        pos_input_ids = torch.concat(
            [prompt_input_ids, gold_tokenized["input_ids"]], dim=1
        )
        # Negative sequences: prompt + PPLM-perturbed continuation.
        neg_input_ids = torch.concat(
            [prompt_input_ids, pplm_tokenized["input_ids"]], dim=1
        )

        # Mask the prompt positions with -100 so the loss is computed only
        # over continuation tokens.
        prompt_shape = prompt_input_ids.shape[1]
        pos_labels = pos_input_ids.detach().clone()
        pos_labels[:, :prompt_shape] = -100
        neg_labels = neg_input_ids.detach().clone()
        neg_labels[:, :prompt_shape] = -100
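
        # Layout per row, for a left-padded prompt of length P and a
        # right-padded continuation of length G (schematic):
        #   pos_input_ids: [pad ... p_1 ... p_P, g_1 ... g_G, pad ...]
        #   pos_labels:    [-100 .......... -100, g_1 ... g_G, pad ...]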
        yield {
            "prompt_input_ids": prompt_input_ids,
            "prompt_attention_mask": prompt_attention_mask,
            "gold_text": gold_text,
            "gold_input_ids": gold_input_ids,
            "pos_text": gold_text,
            "pos_input_ids": pos_input_ids,
            "pos_attention_mask": pos_input_ids != tokenizer.pad_token_id,
            "pos_labels": pos_labels,
            "neg_text": pplm_text,
            "neg_input_ids": neg_input_ids,
            "neg_attention_mask": neg_input_ids != tokenizer.pad_token_id,
            "neg_labels": neg_labels,
        }
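

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical driver for the iterator above. The
# SimpleNamespace config and its attribute values are illustrative stand-ins
# for the project's real config object; only the attribute names are taken
# from the code above, and a GPT-2 tokenizer is assumed.
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # prompts are left-padded (see above)

    config = SimpleNamespace(
        batch_size=4,
        eval_batch_size=8,
        max_prompt_length=64,
        max_new_tokens=64,
        valid_size=256,
    )

    for batch in get_pplm_batch_iterator(
        tokenizer, config, split="valid", device="cpu"
    ):
        print(batch["pos_input_ids"].shape, batch["neg_input_ids"].shape)
        break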