Commit 2d20d86

- Add metrics to forward step to add DPO-specific metrics that are useful (accuracy, etc.)
- Add reference model setup for DPO
- Add pairwise dataset for positive/negative pairs
- Add DPO loss (a minimal sketch of the objective follows)
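
The DPO loss and metrics land in the training code, which is in one of the changed files not loaded on this page. As a reference point, here is a minimal sketch of the standard DPO objective and an accuracy-style metric of the kind the message mentions; the function and argument names are illustrative, not this commit's API:

import torch.nn.functional as F

def dpo_loss(pos_logps, neg_logps, ref_pos_logps, ref_neg_logps, beta=0.1):
    """Standard DPO objective over per-sequence log-probabilities.

    pos_logps / neg_logps: policy log-probs of the chosen/rejected responses.
    ref_pos_logps / ref_neg_logps: frozen reference model log-probs of the same.
    """
    # Margin by which the policy prefers chosen over rejected,
    # relative to the reference model.
    logits = (pos_logps - neg_logps) - (ref_pos_logps - ref_neg_logps)
    # DPO loss: -log(sigmoid(beta * margin)), averaged over the batch.
    loss = -F.logsigmoid(beta * logits).mean()
    # DPO-specific metric: fraction of pairs the implicit reward ranks correctly.
    accuracy = (logits > 0).float().mean()
    return loss, accuracy
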
dmahan93 committed Jun 25, 2024
1 parent 15e3059 commit 2d20d86
Showing 5 changed files with 994 additions and 75 deletions.
159 changes: 136 additions & 23 deletions megatron/data/data_utils.py
@@ -23,6 +23,7 @@
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt2_dataset import GPT2Dataset
+from megatron.data.pairwise_dataset import PairwiseDataset
 from megatron.data.samplers import DistributedBatchSampler


@@ -53,43 +54,105 @@ def make_data_loader(dataset, neox_args):
 
 def build_the_dataset(
     data_prefix,
+    pos_data_prefix,
+    neg_data_prefix,
     name,
     data_impl,
     pack_impl,
+    dataset_impl,
     allow_chopped,
     num_samples,
     seq_length,
     seed,
     skip_warmup,
     build_index_mappings=True,
     label_prefix=None,
+    pos_label_prefix=None,
+    neg_label_prefix=None,
+    pos_ref_prefix=None,
+    neg_ref_prefix=None,
 ):
     """Build train/valid/test datasets."""
 
-    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
-    if label_prefix is None:
-        label_dataset = None
+    if dataset_impl == "gpt2":
+        indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
+        if label_prefix is None:
+            label_dataset = None
+        else:
+            label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
+    elif dataset_impl == "pairwise":
+        pos_indexed_dataset = make_indexed_dataset(
+            pos_data_prefix, data_impl, skip_warmup
+        )
+        neg_indexed_dataset = make_indexed_dataset(
+            neg_data_prefix, data_impl, skip_warmup
+        )
+        if pos_label_prefix is None:
+            pos_label_dataset = None
+            # Also do neg here since they both must be the same
+            assert neg_label_prefix is None
+            neg_label_dataset = None
+        else:
+            pos_label_dataset = make_indexed_dataset(
+                pos_label_prefix, data_impl, skip_warmup
+            )
+            # Also do neg here since they both must be the same
+            assert neg_label_prefix is not None
+            neg_label_dataset = make_indexed_dataset(
+                neg_label_prefix, data_impl, skip_warmup
+            )
+        if pos_ref_prefix is not None:
+            pos_ref_dataset = make_indexed_dataset(
+                pos_ref_prefix, data_impl, skip_warmup
+            )
+            # Also do neg here since they both must be the same
+            assert neg_ref_prefix is not None
+            neg_ref_dataset = make_indexed_dataset(
+                neg_ref_prefix, data_impl, skip_warmup
+            )
+        else:
+            # Default to None so the PairwiseDataset call below is
+            # well-defined when no reference prefixes are provided.
+            pos_ref_dataset = None
+            neg_ref_dataset = None
     else:
-        label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
+        raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented")
 
-    total_num_of_documents = indexed_dataset.sizes.shape[0]
+    total_num_of_documents = (
+        indexed_dataset.sizes.shape[0]
+        if dataset_impl == "gpt2"
+        else pos_indexed_dataset.sizes.shape[0]
+    )
     print_rank_0(" {}:".format(name))
     print_rank_0(" no. of documents:{}".format(total_num_of_documents))
     dataset = None
     documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)
-    dataset = GPT2Dataset(
-        name,
-        data_prefix,
-        documents,
-        indexed_dataset,
-        num_samples,
-        seq_length,
-        seed,
-        pack_impl=pack_impl,
-        allow_chopped=allow_chopped,
-        build_index_mappings=build_index_mappings,
-        label_dataset=label_dataset,
-    )
+    if dataset_impl == "gpt2":
+        dataset = GPT2Dataset(
+            name,
+            data_prefix,
+            documents,
+            indexed_dataset,
+            num_samples,
+            seq_length,
+            seed,
+            pack_impl=pack_impl,
+            allow_chopped=allow_chopped,
+            build_index_mappings=build_index_mappings,
+            label_dataset=label_dataset,
+        )
+    elif dataset_impl == "pairwise":
+        dataset = PairwiseDataset(
+            name,
+            pos_data_prefix,
+            documents,
+            pos_indexed_dataset,
+            neg_indexed_dataset,
+            num_samples,
+            seq_length,
+            seed,
+            pack_impl=pack_impl,
+            allow_chopped=allow_chopped,
+            build_index_mappings=build_index_mappings,
+            pos_label_dataset=pos_label_dataset,
+            neg_label_dataset=neg_label_dataset,
+            pos_ref_dataset=pos_ref_dataset,
+            neg_ref_dataset=neg_ref_dataset,
+        )
     return dataset


@@ -135,7 +198,6 @@ def build_dataset(index, name):
         documents = np.arange(
             start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
         )
-
         dataset = GPT2Dataset(
             name,
             data_prefix,
@@ -219,18 +281,54 @@ def build_weighted_datasets(
         valid_label_path,
         test_path,
         test_label_path,
+        pos_train_path,
+        neg_train_path,
+        pos_train_label_path,
+        neg_train_label_path,
+        pos_valid_path,
+        neg_valid_path,
+        pos_valid_label_path,
+        neg_valid_label_path,
+        pos_test_path,
+        neg_test_path,
+        pos_test_label_path,
+        neg_test_label_path,
     ) in enumerate(
         zip_longest(
-            neox_args.train_data_paths,
+            neox_args.train_data_paths if neox_args.train_data_paths else [],
             neox_args.train_label_data_paths
             if neox_args.train_label_data_paths
             else [],
-            neox_args.valid_data_paths,
+            neox_args.valid_data_paths if neox_args.valid_data_paths else [],
            neox_args.valid_label_data_paths
             if neox_args.valid_label_data_paths
             else [],
-            neox_args.test_data_paths,
+            neox_args.test_data_paths if neox_args.test_data_paths else [],
             neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
+            neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
+            neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
+            neox_args.pos_train_label_data_paths
+            if neox_args.pos_train_label_data_paths
+            else [],
+            neox_args.neg_train_label_data_paths
+            if neox_args.neg_train_label_data_paths
+            else [],
+            neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [],
+            neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [],
+            neox_args.pos_valid_label_data_paths
+            if neox_args.pos_valid_label_data_paths
+            else [],
+            neox_args.neg_valid_label_data_paths
+            if neox_args.neg_valid_label_data_paths
+            else [],
+            neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [],
+            neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [],
+            neox_args.pos_test_label_data_paths
+            if neox_args.pos_test_label_data_paths
+            else [],
+            neox_args.neg_test_label_data_paths
+            if neox_args.neg_test_label_data_paths
+            else [],
         )
     ):
         if train_path:
@@ -247,6 +345,11 @@
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=train_label_path,
+                    dataset_impl=neox_args.dataset_impl,
+                    pos_data_prefix=pos_train_path,
+                    neg_data_prefix=neg_train_path,
+                    pos_label_prefix=pos_train_label_path,
+                    neg_label_prefix=neg_train_label_path,
                 )
             )

@@ -264,6 +367,11 @@
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=valid_label_path,
+                    dataset_impl=neox_args.dataset_impl,
+                    pos_data_prefix=pos_valid_path,
+                    neg_data_prefix=neg_valid_path,
+                    pos_label_prefix=pos_valid_label_path,
+                    neg_label_prefix=neg_valid_label_path,
                 )
             )

@@ -281,6 +389,11 @@
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=test_label_path,
+                    dataset_impl=neox_args.dataset_impl,
+                    pos_data_prefix=pos_test_path,
+                    neg_data_prefix=neg_test_path,
+                    pos_label_prefix=pos_test_label_path,
+                    neg_label_prefix=neg_test_label_path,
                 )
             )
     return train_datasets, valid_datasets, test_datasets
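
To see how the new plumbing fits together, here is a hypothetical call of build_the_dataset in pairwise mode. Every path prefix and number below is illustrative, not taken from this repository's configs:

# Hypothetical invocation of the pairwise branch of build_the_dataset.
train_ds = build_the_dataset(
    data_prefix=None,                    # only used by the "gpt2" branch
    pos_data_prefix="data/dpo_chosen_document",    # tokens of preferred responses
    neg_data_prefix="data/dpo_rejected_document",  # tokens of rejected responses
    name="train",
    data_impl="mmap",
    pack_impl="unpacked",
    dataset_impl="pairwise",
    allow_chopped=False,
    num_samples=100_000,
    seq_length=2048,
    seed=1234,
    skip_warmup=True,
    pos_label_prefix="data/dpo_chosen_label",
    neg_label_prefix="data/dpo_rejected_label",
    pos_ref_prefix=None,   # precomputed reference-model log-probs, if available
    neg_ref_prefix=None,
)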

[Diffs for the remaining 4 changed files were not loaded on this page.]
