Add Reward Model training #1246

Draft · wants to merge 13 commits into base: main
27 changes: 27 additions & 0 deletions configs/README.md
@@ -235,6 +235,33 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in
"eval_iters": 10,
```

However, if you want to use DPO-style training, you'll need to set positive/negative data paths instead of a single path, e.g.:

```yaml
"dataset_impl": "pairwise",
"train_impl": "dpo",
"pack_impl": "unpacked",
"dpo_beta": 0.1,
"dpo_fp32": true,
"pos_train_data_path": "data/enwik8/enwik8_text_pos_document",
"pos_valid_data_path": "data/enwik8/enwik8_text_pos_document",
"pos_test_data_path": "data/enwik8/enwik8_text_pos_document",
"neg_train_data_path": "data/enwik8/enwik8_text_neg_document",
"neg_valid_data_path": "data/enwik8/enwik8_text_neg_document",
"neg_test_data_path": "data/enwik8/enwik8_text_neg_document",
## If you have labels (commonly used to mask out user turns)...
"pos_train_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
"pos_valid_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
"pos_test_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
"neg_train_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
"neg_valid_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
"neg_test_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
## If you want to precompute the logits over your dataset...
"precompute_model_name": "gpt2",
## Needed for the generation.py step, if precomputing
"text_gen_type": "precompute"
```
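Here `dpo_beta` is the β coefficient of the DPO objective (0.1 is the common default), `dpo_fp32` presumably casts logits to fp32 before the loss for numerical stability, and `"pack_impl": "unpacked"` keeps each chosen/rejected pair as its own sample rather than packing documents together. For orientation, a minimal sketch of the DPO loss these settings feed into — illustrative only, not this PR's implementation; the tensor names are assumptions:

```python
import torch.nn.functional as F

def dpo_loss(pos_logps, neg_logps, pos_ref_logps, neg_ref_logps, beta=0.1):
    """DPO loss for one batch.

    Each argument is a 1-D tensor of summed per-sequence log-probabilities:
    chosen (pos) / rejected (neg) under the policy and the frozen reference.
    """
    pos_logratio = pos_logps - pos_ref_logps  # log[pi(y_w|x) / pi_ref(y_w|x)]
    neg_logratio = neg_logps - neg_ref_logps  # log[pi(y_l|x) / pi_ref(y_l|x)]
    # -log sigmoid(beta * margin); logsigmoid is the numerically stable form
    return -F.logsigmoid(beta * (pos_logratio - neg_logratio)).mean()
```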

### LR Scheduler settings

```yaml
3 changes: 3 additions & 0 deletions generate.py
@@ -23,6 +23,7 @@
generate_samples_from_prompt,
generate_samples_unconditional,
generate_samples_interactive,
precompute_logits,
)


@@ -83,6 +84,8 @@ def main(input_args=None, overwrite_values=None):
top_p=neox_args.top_p,
)

elif neox_args.text_gen_type == "precompute":
precompute_logits(neox_args=neox_args, model=model)
else:
raise ValueError(
f"`text_gen_type` either not specified or not recognised: {neox_args.text_gen_type}"
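`precompute_logits` itself is defined in `megatron/text_generation_utils.py` and isn't shown in this excerpt. Conceptually, a precompute pass like this runs the frozen reference model over the pos/neg datasets once and serializes per-token log-probabilities, so DPO training doesn't need to hold a second model in memory. A rough sketch of the core loop, under those assumptions (the helper name and batch layout are hypothetical):

```python
import torch

@torch.no_grad()
def precompute_reference_logprobs(model, dataloader):
    """Hypothetical sketch: per-token log-probs from a frozen reference model.

    The real precompute step would also write the results out as an indexed
    dataset suffixed with precompute_model_name (the suffix data_utils.py
    looks for).
    """
    model.eval()
    results = []
    for batch in dataloader:
        tokens = batch["text"]                 # [batch, seq]
        logits = model(tokens)                 # [batch, seq, vocab]
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        # log-prob assigned to each observed next token
        targets = tokens[:, 1:].unsqueeze(-1)  # [batch, seq-1, 1]
        results.append(
            torch.gather(logprobs[:, :-1], -1, targets).squeeze(-1)
        )
    return results
```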
1 change: 1 addition & 0 deletions megatron/checkpointing.py
@@ -393,6 +393,7 @@ def load_checkpoint(
load_lr_scheduler_states=load_optim_and_scheduler,
load_module_only=not load_optim_and_scheduler,
tag=tag,
load_module_strict=neox_args.train_impl != "rm",
)

if checkpoint_name is None:
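Context for the `load_module_strict` change above: a reward model typically adds a scalar head with no counterpart in the base pretrained checkpoint, so when `train_impl == "rm"` the module must be loaded non-strictly (missing keys keep their fresh initialization). DeepSpeed exposes this directly on the engine; a minimal sketch, where `model_engine` is a hypothetical `deepspeed.DeepSpeedEngine`:

```python
# Non-strict load: shared transformer weights are restored from the base
# checkpoint, while the (new) reward head keeps its fresh initialization.
load_path, client_state = model_engine.load_checkpoint(
    load_dir,
    tag=tag,
    load_module_strict=False,  # tolerate missing/unexpected keys
    load_optimizer_states=False,
    load_lr_scheduler_states=False,
)
```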
209 changes: 181 additions & 28 deletions megatron/data/data_utils.py
@@ -23,6 +23,7 @@
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data.pairwise_dataset import PairwiseDataset
from megatron.data.samplers import DistributedBatchSampler


@@ -53,46 +54,120 @@ def make_data_loader(dataset, neox_args):

def build_the_dataset(
data_prefix,
pos_data_prefix,
neg_data_prefix,
name,
data_impl,
pack_impl,
dataset_impl,
allow_chopped,
num_samples,
seq_length,
seed,
skip_warmup,
build_index_mappings=True,
label_prefix=None,
pos_label_prefix=None,
neg_label_prefix=None,
precompute_model_name=None,
):
"""Build train/valid/test datasets."""

indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
if label_prefix is None:
label_dataset = None
if dataset_impl == "gpt2":
indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
if label_prefix is None:
label_dataset = None
else:
label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
if precompute_model_name is not None:
# If we have the name, assume the precomputed dataset exists; if it
# doesn't, make_indexed_dataset simply returns None, which is fine.
precompute_indexed_dataset = make_indexed_dataset(
data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
)
elif dataset_impl == "pairwise":
pos_indexed_dataset = make_indexed_dataset(
pos_data_prefix, data_impl, skip_warmup
)
neg_indexed_dataset = make_indexed_dataset(
neg_data_prefix, data_impl, skip_warmup
)
if pos_label_prefix is None:
pos_label_dataset = None
# pos/neg label datasets must be provided (or omitted) together
assert neg_label_prefix is None
neg_label_dataset = None
else:
pos_label_dataset = make_indexed_dataset(
pos_label_prefix, data_impl, skip_warmup
)
# pos/neg label datasets must be provided (or omitted) together
assert neg_label_prefix is not None
neg_label_dataset = make_indexed_dataset(
neg_label_prefix, data_impl, skip_warmup
)
if precompute_model_name is None:
pos_ref_dataset = None
neg_ref_dataset = None
else:
pos_ref_dataset = make_indexed_dataset(
pos_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
)
neg_ref_dataset = make_indexed_dataset(
neg_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
)
else:
label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented")

total_num_of_documents = indexed_dataset.sizes.shape[0]
total_num_of_documents = (
indexed_dataset.sizes.shape[0]
if dataset_impl == "gpt2"
else pos_indexed_dataset.sizes.shape[0]
)
print_rank_0(" {}:".format(name))
print_rank_0(" no. of documents:{}".format(total_num_of_documents))
dataset = None
documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)
dataset = GPT2Dataset(
name,
data_prefix,
documents,
indexed_dataset,
num_samples,
seq_length,
seed,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
)
if dataset_impl == "gpt2":
dataset = GPT2Dataset(
name,
data_prefix,
documents,
indexed_dataset,
num_samples,
seq_length,
seed,
pack_impl=pack_impl,
allow_chopped=allow_chopped,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
)
elif dataset_impl == "pairwise":
dataset = PairwiseDataset(
name,
pos_data_prefix,
documents,
pos_indexed_dataset,
neg_indexed_dataset,
num_samples,
seq_length,
seed,
pack_impl=pack_impl,
allow_chopped=allow_chopped,
build_index_mappings=build_index_mappings,
pos_label_dataset=pos_label_dataset,
neg_label_dataset=neg_label_dataset,
pos_ref_dataset=pos_ref_dataset,
neg_ref_dataset=neg_ref_dataset,
)
return dataset
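Note the naming convention the pairwise branch relies on: precomputed reference logits are expected as indexed datasets sitting next to the data files, with the model name appended as a suffix. With the README config above:

```python
pos_data_prefix = "data/enwik8/enwik8_text_pos_document"
precompute_model_name = "gpt2"

# build_the_dataset looks for the reference dataset at:
ref_prefix = pos_data_prefix + "_" + precompute_model_name
# -> "data/enwik8/enwik8_text_pos_document_gpt2"
```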


def build_train_valid_test_datasets(
data_prefix,
use_shared_fs,
data_impl,
pack_impl,
allow_chopped,
splits_string,
train_valid_test_num_samples,
seq_length,
@@ -129,7 +204,6 @@ def build_dataset(index, name):
documents = np.arange(
start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
)

dataset = GPT2Dataset(
name,
data_prefix,
@@ -138,6 +212,8 @@ def build_dataset(index, name):
train_valid_test_num_samples[index],
seq_length,
seed,
pack_impl=pack_impl,
allow_chopped=allow_chopped,
use_shared_fs=use_shared_fs,
)
return dataset
@@ -204,54 +280,129 @@ def build_weighted_datasets(
):
# build individual datasets
train_datasets, valid_datasets, test_datasets = [], [], []
for i, (train_path, label_path, valid_path, test_path) in enumerate(
for i, (
train_path,
train_label_path,
valid_path,
valid_label_path,
test_path,
test_label_path,
pos_train_path,
neg_train_path,
pos_train_label_path,
neg_train_label_path,
pos_valid_path,
neg_valid_path,
pos_valid_label_path,
neg_valid_label_path,
pos_test_path,
neg_test_path,
pos_test_label_path,
neg_test_label_path,
) in enumerate(
zip_longest(
neox_args.train_data_paths,
neox_args.label_data_paths if neox_args.label_data_paths else [],
neox_args.valid_data_paths,
neox_args.test_data_paths,
neox_args.train_data_paths if neox_args.train_data_paths else [],
neox_args.train_label_data_paths
if neox_args.train_label_data_paths
else [],
neox_args.valid_data_paths if neox_args.valid_data_paths else [],
neox_args.valid_label_data_paths
if neox_args.valid_label_data_paths
else [],
neox_args.test_data_paths if neox_args.test_data_paths else [],
neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
neox_args.pos_train_label_data_paths
if neox_args.pos_train_label_data_paths
else [],
neox_args.neg_train_label_data_paths
if neox_args.neg_train_label_data_paths
else [],
neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [],
neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [],
neox_args.pos_valid_label_data_paths
if neox_args.pos_valid_label_data_paths
else [],
neox_args.neg_valid_label_data_paths
if neox_args.neg_valid_label_data_paths
else [],
neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [],
neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [],
neox_args.pos_test_label_data_paths
if neox_args.pos_test_label_data_paths
else [],
neox_args.neg_test_label_data_paths
if neox_args.neg_test_label_data_paths
else [],
)
):
if train_path:
if train_path or pos_train_path:
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
name=f"train_{i}",
data_impl=neox_args.data_impl,
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=label_path,
label_prefix=train_label_path,
dataset_impl=neox_args.dataset_impl,
pos_data_prefix=pos_train_path,
neg_data_prefix=neg_train_path,
pos_label_prefix=pos_train_label_path,
neg_label_prefix=neg_train_label_path,
precompute_model_name=neox_args.precompute_model_name,
)
)

if valid_path:
if valid_path or pos_valid_path:
valid_datasets.append(
build_the_dataset(
data_prefix=valid_path,
name=f"valid_{i}",
data_impl=neox_args.data_impl,
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=valid_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=valid_label_path,
dataset_impl=neox_args.dataset_impl,
pos_data_prefix=pos_valid_path,
neg_data_prefix=neg_valid_path,
pos_label_prefix=pos_valid_label_path,
neg_label_prefix=neg_valid_label_path,
precompute_model_name=neox_args.precompute_model_name,
)
)

if test_path:
if test_path or pos_test_path:
test_datasets.append(
build_the_dataset(
data_prefix=test_path,
name=f"test_{i}",
data_impl=neox_args.data_impl,
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=test_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=test_label_path,
dataset_impl=neox_args.dataset_impl,
pos_data_prefix=pos_test_path,
neg_data_prefix=neg_test_path,
pos_label_prefix=pos_test_label_path,
neg_label_prefix=neg_test_label_path,
precompute_model_name=neox_args.precompute_model_name,
)
)
return train_datasets, valid_datasets, test_datasets
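The move from `zip` to `zip_longest` is what makes every path list optional here: absent lists are replaced with `[]`, and `zip_longest` pads them with `None`, so the per-dataset guards (`train_path or pos_train_path`, the `label_prefix is None` checks, and so on) see `None` instead of raising. A quick illustration:

```python
from itertools import zip_longest

train_paths = ["data/a_text_document", "data/b_text_document"]
label_paths = []  # no label datasets configured

for train_path, label_path in zip_longest(train_paths, label_paths):
    print(train_path, label_path)
# data/a_text_document None
# data/b_text_document None
```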
@@ -323,7 +474,7 @@ def build_train_valid_test_data_iterators(neox_args):
test_iters * neox_args.train_batch_size,
]

if neox_args.train_data_paths:
if neox_args.train_data_paths or neox_args.pos_train_data_paths:
# when individual train / valid / test data paths are provided
# normalize weight values and get num samples for each dataset
train_weights, train_num_samples = get_normalized_weights_and_num_samples(
@@ -414,6 +565,8 @@
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
)

# Build dataloaders.