-
Notifications
You must be signed in to change notification settings - Fork 48
/
builders.py
358 lines (319 loc) · 14.1 KB
/
builders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic file to build datasets
modified from: https://github.com/NVIDIA/NeMo/blob/2baef811f21372c3340dd2d82635d2377e78a660/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
to allow us to build SFT, RewardModel and RLHF datasets
"""
import json
from functools import partial
import numpy as np
import torch
from megatron.core import parallel_state
from omegaconf.dictconfig import DictConfig
from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import (
get_datasets_weights_and_num_samples,
get_train_valid_test_split_,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler
from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import get_indexed_dataset_
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingBatchSampler,
MegatronPretrainingRandomBatchSampler,
)
from nemo.utils import logging
from nemo_aligner.data.nlp.datasets import (
DPOModelDataset,
RegressionRewardModelDataset,
RewardModelDataset,
RLHFDataset,
)
from nemo_aligner.data.nlp.samplers import MegatronPretrainingRandomSampler
from nemo_aligner.utils.utils import collate_with_batch_max_sequence_length
def build_dataset_generic(cls, cfg, data_prefix, data_impl, num_samples, seq_length, seed, tokenizer, name):
    """Build one (possibly blended) dataset of type ``cls`` without train/validation/test splitting.

    Every document of each data source is handed to the dataset.

    Args:
        cls: dataset class to instantiate; it is called with the keyword arguments ``cfg``,
            ``tokenizer``, ``name``, ``data_prefix``, ``documents``, ``data``, ``seq_length``,
            ``seed`` and ``drop_last``.
        cfg: config whose ``cfg.data.get(...)`` supplies ``skip_warmup`` (mmap only) and
            ``validation_drop_last`` (validation split only).
        data_prefix: sequence of data sources. A single entry is built directly; otherwise it is
            parsed by ``get_datasets_weights_and_num_samples`` and the per-source datasets are
            combined in a ``BlendableDataset``.
        data_impl: ``"mmap"`` for an indexed dataset, or any value starting with ``"json"``
            (e.g. ``"json"``/``"jsonl"``) for a file holding one JSON object per line.
        num_samples: total sample count; only used when blending several sources.
        seq_length: forwarded to ``cls``.
        seed: forwarded to ``cls``.
        tokenizer: forwarded to ``cls``.
        name: split name forwarded to ``cls`` and used in logging. ``"validation"`` (or the
            legacy ``"valid"``) makes ``drop_last`` follow ``cfg.data.validation_drop_last``.

    Returns:
        An instance of ``cls``, or a ``BlendableDataset`` of such instances.

    Raises:
        RuntimeError: if ``data_impl`` is neither ``"mmap"`` nor json-like.
    """

    def _load_payload(prefix):
        """Load every document of a single data source."""
        if data_impl == "mmap":
            return get_indexed_dataset_(prefix, data_impl, cfg.data.get("skip_warmup", True))
        if data_impl.startswith("json"):
            with open(prefix, "r", encoding="utf_8") as fr:
                # Skip blank lines (e.g. a trailing empty line): json.loads("") would raise.
                stripped_lines = (line.strip() for line in fr)
                return [json.loads(line) for line in stripped_lines if line]
        raise RuntimeError(f"data.data_impl must be either mmap or json or jsonl, but got {data_impl}")

    def _build_dataset(current_data_prefix):
        """Instantiate ``cls`` over all documents of one data source."""
        data_payload = _load_payload(current_data_prefix)
        total_num_of_documents = len(data_payload)

        # Print stats about the splits.
        logging.info(" > dataset split:")
        logging.info(" Total {} documents is : {} ".format(name, total_num_of_documents))

        drop_last = True
        # Callers in this module pass name="validation"; "valid" is still honoured for
        # backward compatibility with any external caller using the old spelling.
        if name in ("validation", "valid"):
            drop_last = cfg.data.get("validation_drop_last", True)

        return cls(
            cfg=cfg,
            tokenizer=tokenizer,
            name=name,
            data_prefix=current_data_prefix,
            documents=np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32),
            data=data_payload,
            seq_length=seq_length,
            seed=seed,
            drop_last=drop_last,
        )

    if len(data_prefix) == 1:
        return _build_dataset(data_prefix[0])

    # Per-source sample counts are not needed: cls is not given a sample count here, and the
    # blend is sized by the overall num_samples.
    data_prefixes, weights, _ = get_datasets_weights_and_num_samples(data_prefix, num_samples)
    datasets = [_build_dataset(prefix) for prefix in data_prefixes]
    return BlendableDataset(datasets, weights, num_samples)
def build_train_valid_test_datasets(
    cls, cfg, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, tokenizer,
):
    """Build train, validation and test datasets of type ``cls``.

    Three layouts of ``data_prefix`` are supported:

    * a ``DictConfig`` with ``train``, ``validation`` and ``test`` entries: each split is built
      from its own source with ``build_dataset_generic`` and ``splits_string`` is ignored;
    * a sequence with a single source: that source is split according to ``splits_string``;
    * any other sequence: parsed by ``get_datasets_weights_and_num_samples`` into sources and
      weights; each source is split according to ``splits_string`` and the per-split datasets
      are blended with ``BlendableDataset``.

    Args:
        cls: dataset class instantiated for each split.
        cfg: config; ``cfg.data`` is read (``splits_string`` here, more in the helpers).
        data_prefix: data source description (see above).
        data_impl: ``"mmap"`` or a value starting with ``"json"``.
        splits_string: split specification used by the split-based layouts.
        train_valid_test_num_samples: requested sample counts; index 0/1/2 is
            train/validation/test.
        seq_length: forwarded to ``cls``.
        seed: forwarded to ``cls``.
        tokenizer: forwarded to ``cls``.

    Returns:
        ``(train, validation, test)``. In the split-based layouts an entry is ``None`` when
        its split is empty.
    """
    if isinstance(data_prefix, DictConfig):
        assert (
            data_prefix.get("train") is not None
            and data_prefix.get("test") is not None
            and data_prefix.get("validation") is not None
        ), f"Data prefix dictionary should have train, test and validation keys. data_prefix currently has only {data_prefix.keys()}"
        if cfg.data.splits_string is not None:
            logging.warning(cfg.data.splits_string + " ignored since data path is of type dictionary.")
        # Each split gets its own sample count: train_valid_test_num_samples[i] belongs to the
        # i-th split name below (building the splits in one loop avoids copy-paste index slips).
        return tuple(
            build_dataset_generic(
                cls=cls,
                cfg=cfg,
                data_prefix=data_prefix[split_name],
                data_impl=data_impl,
                num_samples=int(train_valid_test_num_samples[index]),
                seq_length=seq_length,
                seed=seed,
                tokenizer=tokenizer,
                name=split_name,
            )
            for index, split_name in enumerate(("train", "validation", "test"))
        )

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            cls=cls,
            cfg=cfg,
            data_prefix=data_prefix[0],
            data_impl=data_impl,
            splits_string=splits_string,
            train_valid_test_num_samples=train_valid_test_num_samples,
            seq_length=seq_length,
            seed=seed,
            tokenizer=tokenizer,
        )

    # Blending dataset: parse weights and per-source sample counts.
    data_prefixes, weights, datasets_train_valid_test_num_samples = get_datasets_weights_and_num_samples(
        data_prefix, train_valid_test_num_samples
    )

    # Build the individual (per-source) split datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for prefix, prefix_num_samples in zip(data_prefixes, datasets_train_valid_test_num_samples):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            cls=cls,
            cfg=cfg,
            data_prefix=prefix,
            data_impl=data_impl,
            splits_string=splits_string,
            train_valid_test_num_samples=prefix_num_samples,
            seq_length=seq_length,
            seed=seed,
            tokenizer=tokenizer,
        )
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Total sample count per split across all sources.
    train_n, valid_n, test_n = map(sum, zip(*datasets_train_valid_test_num_samples))

    # Blend; a split with no datasets at all stays None.
    blending_train_dataset = BlendableDataset(train_datasets, weights, train_n) if train_datasets else None
    blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_n) if valid_datasets else None
    blending_test_dataset = BlendableDataset(test_datasets, weights, test_n) if test_datasets else None

    return (blending_train_dataset, blending_valid_dataset, blending_test_dataset)
def _build_train_valid_test_datasets(
    cls, cfg, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, tokenizer,
):
    """Build train, validation and test datasets from one data source.

    The source is loaded once; ``get_train_valid_test_split_`` turns ``splits_string`` into
    boundaries over the document indices, and each split receives the contiguous index range
    between two consecutive boundaries. All three datasets share the same loaded payload.

    Args:
        cls: dataset class instantiated for each non-empty split.
        cfg: config whose ``cfg.data.get(...)`` supplies ``skip_warmup`` (mmap only) and
            ``validation_drop_last`` (validation split only).
        data_prefix: path/prefix of the single data source.
        data_impl: ``"mmap"`` or a value starting with ``"json"`` (one JSON object per line).
        splits_string: split specification passed to ``get_train_valid_test_split_``.
        train_valid_test_num_samples: per-split sample counts (index 0/1/2 is
            train/validation/test), forwarded to ``cls`` as ``num_samples``.
        seq_length: forwarded to ``cls``.
        seed: forwarded to ``cls``.
        tokenizer: forwarded to ``cls``.

    Returns:
        ``(train, validation, test)``; an entry is ``None`` when its index range is empty.

    Raises:
        RuntimeError: if ``data_impl`` is neither ``"mmap"`` nor json-like.
    """
    # Indexed (mmap) dataset or json/jsonl file.
    if data_impl == "mmap":
        data_payload = get_indexed_dataset_(data_prefix, data_impl, cfg.data.get("skip_warmup", True))
    elif data_impl.startswith("json"):
        with open(data_prefix, "r", encoding="utf_8") as fr:
            # Skip blank lines (e.g. a trailing empty line): json.loads("") would raise.
            stripped_lines = (line.strip() for line in fr)
            data_payload = [json.loads(line) for line in stripped_lines if line]
    else:
        raise RuntimeError(f"data.data_impl must be either mmap or json or jsonl, but got {data_impl}")
    total_num_of_documents = len(data_payload)
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    logging.info(" > dataset split:")

    def print_split_stats(name, index):
        """Log the document index range [splits[index], splits[index + 1]) of one split."""
        logging.info(" {}:".format(name))
        logging.info(
            " document indices in [{}, {}) total of {} "
            "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index])
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)

    def build_dataset(index, name):
        """Instantiate ``cls`` for split ``index``, or return None if that split is empty."""
        if splits[index + 1] <= splits[index]:
            return None
        documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32)
        drop_last = True
        if name == "validation":
            drop_last = cfg.data.get("validation_drop_last", True)
        return cls(
            cfg=cfg,
            tokenizer=tokenizer,
            name=name,
            data_prefix=data_prefix,
            documents=documents,
            data=data_payload,
            num_samples=train_valid_test_num_samples[index],
            seq_length=seq_length,
            seed=seed,
            drop_last=drop_last,
        )

    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "validation")
    test_dataset = build_dataset(2, "test")
    return (train_dataset, valid_dataset, test_dataset)
# Per-task entry points. Each one pre-binds the first positional parameter (``cls``) of
# build_train_valid_test_datasets, so callers pass only the remaining arguments.

# Preference-based objectives.
build_train_valid_test_dpo_datasets = partial(build_train_valid_test_datasets, DPOModelDataset)

# Reward-model training (pairwise and regression variants).
build_train_valid_test_rm_datasets = partial(build_train_valid_test_datasets, RewardModelDataset)
build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset)

# RLHF prompt datasets.
build_train_valid_test_rlhf_datasets = partial(build_train_valid_test_datasets, RLHFDataset)
def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
    """Construct an SFT dataset from the ``data_cfg`` section.

    ``GPTSFTChatDataset`` is used when ``is_chat`` is true, ``GPTSFTDataset`` otherwise.
    ``file_path``, ``max_seq_length`` and ``min_seq_length`` are mandatory attributes of
    ``data_cfg``; every other option is read with ``data_cfg.get`` and falls back to the default
    written below. ``sep_id`` and ``virtual_tokens`` are always 0.

    Args:
        data_cfg: dataset config section.
        tokenizer: tokenizer forwarded to the dataset.
        num_samples: forwarded as ``max_num_samples``.
        answer_only_loss: forwarded unchanged.
        is_chat: selects the dataset class.
        special_tokens: forwarded unchanged.

    Returns:
        The constructed dataset instance.
    """
    dataset_kwargs = {
        "file_path": data_cfg.file_path,
        "tokenizer": tokenizer,
        "max_seq_length": data_cfg.max_seq_length,
        "min_seq_length": data_cfg.min_seq_length,
        "add_bos": data_cfg.get("add_bos", False),
        "add_eos": data_cfg.get("add_eos", True),
        "add_sep": data_cfg.get("add_sep", False),
        "sep_id": 0,
        "max_num_samples": num_samples,
        "seed": data_cfg.get("seed", 1234),
        "label_key": data_cfg.get("label_key", "answer"),
        "answer_only_loss": answer_only_loss,
        "truncation_field": data_cfg.get("truncation_field", "text"),
        "pad_to_max_length": data_cfg.get("pad_to_max_length", False),
        "index_mapping_dir": data_cfg.get("index_mapping_dir", None),
        "prompt_template": data_cfg.get("prompt_template", None),
        "virtual_tokens": 0,
        # Number of workers used to create the memmap index files.
        "memmap_workers": data_cfg.get("memmap_workers", None),
        # True: load the json file with the HuggingFace dataset; False: load the jsonl file
        # with the JSONLMemMapDataset.
        "hf_dataset": data_cfg.get("hf_dataset", False),
        # Truncation method; options: ['random', 'left', 'right'].
        "truncation_method": data_cfg.get("truncation_method", "right"),
        "special_tokens": special_tokens,
        "output_original_text": data_cfg.get("output_original_text", False),
    }

    if is_chat:
        return GPTSFTChatDataset(**dataset_kwargs)
    return GPTSFTDataset(**dataset_kwargs)
def collate_with_pad_to_max_batch(max_seqlen, tokenizer_eos_id, cfg):
    """Return a collate function that pads each sequence to the max in the batch.

    The result is ``collate_with_batch_max_sequence_length`` with these keywords pre-bound:

    * ``response_token_length=max_seqlen`` and ``eos_id=tokenizer_eos_id``;
    * ``reset_position_ids``, ``reset_attention_mask`` and ``eod_mask_loss``, each read from
      ``cfg.model.data`` and defaulting to ``False``.
    """
    data_cfg = cfg.model.data
    mask_flags = {
        flag: data_cfg.get(flag, False) for flag in ("reset_position_ids", "reset_attention_mask", "eod_mask_loss")
    }
    return partial(
        collate_with_batch_max_sequence_length,
        response_token_length=max_seqlen,
        eos_id=tokenizer_eos_id,
        **mask_flags,
    )
def build_dataloader(
    cfg,
    dataset,
    consumed_samples,
    mbs,
    gbs,
    drop_last=True,
    pad_samples_to_global_batch_size=False,
    collate_fn=None,
    load_gbs=True,
    use_random_sampler=True,
):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset`` driven by a Megatron sampler.

    Only ``cfg.model.data.dataloader_type == "single"`` is supported. The sampler class depends
    on two flags:

    * ``use_random_sampler``: a random sampler seeded with ``cfg.model.seed`` versus a
      sequential one;
    * ``load_gbs``: the batch-sampler variant (``MegatronPretraining[Random]BatchSampler``)
      versus ``MegatronPretraining[Random]Sampler``.

    The data-parallel rank and world size come from ``megatron.core.parallel_state``.

    Args:
        cfg: config; reads ``cfg.model.data.dataloader_type``, ``cfg.model.data.num_workers``
            and, for random sampling, ``cfg.model.seed``.
        dataset: dataset to iterate.
        consumed_samples: number of samples already consumed, forwarded to the sampler.
        mbs: micro batch size.
        gbs: global batch size.
        drop_last: forwarded to the sampler.
        pad_samples_to_global_batch_size: forwarded to the sampler.
        collate_fn: forwarded to the DataLoader.
        load_gbs: choose the batch-sampler variant.
        use_random_sampler: choose random instead of sequential sampling.

    Returns:
        The configured ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: if ``cfg.model.data.dataloader_type`` is missing or not ``"single"``.
    """
    logging.info(f"Building dataloader with consumed samples: {consumed_samples}")

    # Keyword arguments shared by every sampler class.
    sampler_kwargs = dict(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=mbs,
        data_parallel_rank=parallel_state.get_data_parallel_rank(),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
        drop_last=drop_last,
        global_batch_size=gbs,
        pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
    )

    if getattr(cfg.model.data, "dataloader_type", None) != "single":
        raise ValueError('`cfg.model.data.dataloader_type` must be set to "single"')

    if not use_random_sampler:
        sampler_cls = MegatronPretrainingBatchSampler if load_gbs else MegatronPretrainingSampler
    else:
        sampler_cls = MegatronPretrainingRandomBatchSampler if load_gbs else MegatronPretrainingRandomSampler
        # Only the random samplers take a seed.
        sampler_kwargs["seed"] = cfg.model.seed

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=sampler_cls(**sampler_kwargs),
        num_workers=cfg.model.data.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )