From 2d20d86526f0714a475434f16fe9bc9ad7d48e8c Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 24 Jun 2024 20:27:37 -0500 Subject: [PATCH] - Add metrics to forward step to add DPO specific metrics that are useful (accuracy, etc) - Add reference model setup for DPO - Add pairwise dataset for positive/negative pairs - Add DPO loss --- megatron/data/data_utils.py | 159 ++++++-- megatron/data/pairwise_dataset.py | 585 +++++++++++++++++++++++++++ megatron/neox_arguments/neox_args.py | 56 +++ megatron/training.py | 267 +++++++++--- megatron/utils.py | 2 +- 5 files changed, 994 insertions(+), 75 deletions(-) create mode 100644 megatron/data/pairwise_dataset.py diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 7e4dbdb37..2c548077d 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -23,6 +23,7 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from megatron.data.blendable_dataset import BlendableDataset from megatron.data.gpt2_dataset import GPT2Dataset +from megatron.data.pairwise_dataset import PairwiseDataset from megatron.data.samplers import DistributedBatchSampler @@ -53,9 +54,12 @@ def make_data_loader(dataset, neox_args): def build_the_dataset( data_prefix, + pos_data_prefix, + neg_data_prefix, name, data_impl, pack_impl, + dataset_impl, allow_chopped, num_samples, seq_length, @@ -63,33 +67,92 @@ def build_the_dataset( skip_warmup, build_index_mappings=True, label_prefix=None, + pos_label_prefix=None, + neg_label_prefix=None, + pos_ref_prefix=None, + neg_ref_prefix=None, ): """Build train/valid/test datasets.""" - - indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) - if label_prefix is None: - label_dataset = None + if dataset_impl == "gpt2": + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + if label_prefix is None: + label_dataset = None + else: + label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup) + elif dataset_impl == "pairwise": + pos_indexed_dataset = make_indexed_dataset( + pos_data_prefix, data_impl, skip_warmup + ) + neg_indexed_dataset = make_indexed_dataset( + neg_data_prefix, data_impl, skip_warmup + ) + if pos_label_prefix is None: + pos_label_dataset = None + # Also do neg here since they both must be the same + assert neg_label_prefix is None + neg_label_dataset = None + else: + pos_label_dataset = make_indexed_dataset( + pos_label_prefix, data_impl, skip_warmup + ) + # Also do neg here since they both must be the same + assert neg_label_prefix is not None + neg_label_dataset = make_indexed_dataset( + neg_label_prefix, data_impl, skip_warmup + ) + if pos_ref_prefix is not None: + pos_ref_dataset = make_indexed_dataset( + pos_ref_prefix, data_impl, skip_warmup + ) + # Also do neg here since they both must be the same + assert neg_ref_prefix is not None + neg_ref_dataset = make_indexed_dataset( + neg_ref_prefix, data_impl, skip_warmup + ) else: - label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup) + raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented") - total_num_of_documents = indexed_dataset.sizes.shape[0] + total_num_of_documents = ( + indexed_dataset.sizes.shape[0] + if dataset_impl == "gpt2" + else pos_indexed_dataset.sizes.shape[0] + ) print_rank_0(" {}:".format(name)) print_rank_0(" no. 
of documents:{}".format(total_num_of_documents)) dataset = None documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) - dataset = GPT2Dataset( - name, - data_prefix, - documents, - indexed_dataset, - num_samples, - seq_length, - seed, - pack_impl=pack_impl, - allow_chopped=allow_chopped, - build_index_mappings=build_index_mappings, - label_dataset=label_dataset, - ) + if dataset_impl == "gpt2": + dataset = GPT2Dataset( + name, + data_prefix, + documents, + indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, + build_index_mappings=build_index_mappings, + label_dataset=label_dataset, + ) + elif dataset_impl == "pairwise": + dataset = PairwiseDataset( + name, + pos_data_prefix, + documents, + pos_indexed_dataset, + neg_indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, + build_index_mappings=build_index_mappings, + pos_label_dataset=pos_label_dataset, + neg_label_dataset=neg_label_dataset, + pos_ref_dataset=pos_ref_dataset, + neg_ref_dataset=neg_ref_dataset, + ) return dataset @@ -135,7 +198,6 @@ def build_dataset(index, name): documents = np.arange( start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 ) - dataset = GPT2Dataset( name, data_prefix, @@ -219,18 +281,54 @@ def build_weighted_datasets( valid_label_path, test_path, test_label_path, + pos_train_path, + neg_train_path, + pos_train_label_path, + neg_train_label_path, + pos_valid_path, + neg_valid_path, + pos_valid_label_path, + neg_valid_label_path, + pos_test_path, + neg_test_path, + pos_test_label_path, + neg_test_label_path, ) in enumerate( zip_longest( - neox_args.train_data_paths, + neox_args.train_data_paths if neox_args.train_data_paths else [], neox_args.train_label_data_paths if neox_args.train_label_data_paths else [], - neox_args.valid_data_paths, + neox_args.valid_data_paths if neox_args.valid_data_paths else [], neox_args.valid_label_data_paths if neox_args.valid_label_data_paths else [], - neox_args.test_data_paths, + neox_args.test_data_paths if neox_args.pos_train_data_paths else [], neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], + neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [], + neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [], + neox_args.pos_train_label_data_paths + if neox_args.pos_train_label_data_paths + else [], + neox_args.neg_train_label_data_paths + if neox_args.neg_train_label_data_paths + else [], + neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [], + neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [], + neox_args.pos_valid_label_data_paths + if neox_args.pos_valid_label_data_paths + else [], + neox_args.neg_valid_label_data_paths + if neox_args.neg_valid_label_data_paths + else [], + neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [], + neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [], + neox_args.pos_test_label_data_paths + if neox_args.pos_test_label_data_paths + else [], + neox_args.neg_test_label_data_paths + if neox_args.neg_test_label_data_paths + else [], ) ): if train_path: @@ -247,6 +345,11 @@ def build_weighted_datasets( skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, label_prefix=train_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_train_path, + neg_data_prefix=neg_train_path, + 
pos_label_prefix=pos_train_label_path, + neg_label_prefix=neg_train_label_path, ) ) @@ -264,6 +367,11 @@ def build_weighted_datasets( skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, label_prefix=valid_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_valid_path, + neg_data_prefix=neg_valid_path, + pos_label_prefix=pos_valid_label_path, + neg_label_prefix=neg_valid_label_path, ) ) @@ -281,6 +389,11 @@ def build_weighted_datasets( skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, label_prefix=test_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_test_path, + neg_data_prefix=neg_test_path, + pos_label_prefix=pos_test_label_path, + neg_label_prefix=neg_test_label_path, ) ) return train_datasets, valid_datasets, test_datasets diff --git a/megatron/data/pairwise_dataset.py b/megatron/data/pairwise_dataset.py new file mode 100644 index 000000000..b59218f08 --- /dev/null +++ b/megatron/data/pairwise_dataset.py @@ -0,0 +1,585 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pairwise style dataset.""" + +import os +import time + +import numpy as np +import torch + +from megatron import mpu, print_rank_0 + + +class PairwiseDataset(torch.utils.data.Dataset): + def __init__( + self, + name, + pos_data_prefix, # Don't need neg since it's assumed you have paired the data already. + documents, + pos_indexed_dataset, + neg_indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl="unpacked", + build_index_mappings=True, + use_shared_fs=True, + pos_label_dataset=None, + pos_ref_dataset=None, + neg_label_dataset=None, + neg_ref_dataset=None, + allow_chopped=True, + ): + + self.name = name + self.pos_indexed_dataset = pos_indexed_dataset + self.pos_label_dataset = pos_label_dataset + self.pos_ref_dataset = pos_ref_dataset + self.neg_indexed_dataset = neg_indexed_dataset + self.neg_label_dataset = neg_label_dataset + self.neg_ref_dataset = neg_ref_dataset + self.pack_impl = pack_impl + self.seq_length = seq_length + # Checks + assert np.min(documents) >= 0 + assert (neg_label_dataset is not None and pos_label_dataset is not None) or ( + neg_label_dataset is None and pos_label_dataset is None + ), "Label datasets must be both None or both not None" + assert np.max(documents) < pos_indexed_dataset.sizes.shape[0] + assert pos_indexed_dataset.sizes.shape[0] == neg_indexed_dataset.sizes.shape[0] + assert ( + pack_impl != "packed" + ), "Packed implementation not supported for pairwise dataset" + + if build_index_mappings: + # Build index mappings. 
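+            # doc_idx orders the documents used for training, sample_idx stores the
+            # (document index, token offset) pair at which each sample starts, and
+            # shuffle_idx maps a sample index to a shuffled position in sample_idx.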
+ self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, + pos_data_prefix, + documents, + self.pos_indexed_dataset.sizes, + self.neg_indexed_dataset.sizes, + self.pos_label_dataset, + self.neg_label_dataset, + num_samples, + seq_length, + seed, + pack_impl, + use_shared_fs=use_shared_fs, + allow_chopped=allow_chopped, + ) + self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 + self.sample_idx_len = self.sample_idx.shape[0] - 1 + + if self.shuffle_idx_len != self.sample_idx_len - 1: + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) + + def __len__(self): + return min(self.shuffle_idx_len, self.sample_idx_len) + + def __getitem__(self, idx): + try: + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # Labels and texts are supposed to be fully in sync. + datasets = ( + [self.pos_indexed_dataset, self.neg_indexed_dataset] + if self.pos_label_dataset is None + else [ + self.pos_indexed_dataset, + self.neg_indexed_dataset, + self.pos_label_dataset, + self.neg_label_dataset, + ] + ) + samples = [] + pos_ref_samples = [] + neg_ref_samples = [] + # If we are within the same document, just extract the chunk. + for n, dataset in enumerate(datasets): + if doc_index_f == doc_index_l: + samples.append( + dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_samples.append( + self.pos_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + neg_ref_samples.append( + self.neg_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + + else: + # Otherwise, get the rest of the initial document. + sample_list = [ + dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] + + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_sample_list = [ + self.pos_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + ) + ] + neg_ref_sample_list = [ + self.neg_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + ) + ] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(dataset.get(self.doc_idx[i])) + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_sample_list.append( + self.pos_ref_dataset.get( + self.doc_idx[i], + ) + ) + neg_ref_sample_list.append( + self.neg_ref_dataset.get( + self.doc_idx[i], + ) + ) + # And finally add the relevant portion of last document. 
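+                        # length=offset_l + 1 keeps tokens [0, offset_l] of the final document.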
+ sample_list.append( + dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + samples.append(np.concatenate(sample_list)) + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_sample_list.append( + self.pos_ref_dataset.get( + self.doc_idx[doc_index_l], length=offset_l + 1 + ) + ) + pos_ref_samples.append(np.concatenate(pos_ref_sample_list)) + neg_ref_sample_list.append( + self.neg_ref_dataset.get( + self.doc_idx[doc_index_l], length=offset_l + 1 + ) + ) + neg_ref_samples.append(np.concatenate(neg_ref_sample_list)) + if self.pos_ref_dataset is not None: + if len(pos_ref_samples[0]) < (self.seq_length): + # Pad with 0s + pos_ref_samples[0] = np.pad( + pos_ref_samples[0], + (0, (self.seq_length) - len(pos_ref_samples[0])), + mode="constant", + constant_values=0, + ) + elif len(pos_ref_samples[0]) > (self.seq_length): + # Check for overflow and truncate. + pos_ref_samples[0] = pos_ref_samples[0][: (self.seq_length)] + if len(neg_ref_samples[0]) < (self.seq_length): + # Pad with 0s + neg_ref_samples[0] = np.pad( + neg_ref_samples[0], + (0, (self.seq_length) - len(neg_ref_samples[0])), + mode="constant", + constant_values=0, + ) + elif len(neg_ref_samples[0]) > (self.seq_length): + # Check for overflow and truncate. + neg_ref_samples[0] = neg_ref_samples[0][: (self.seq_length)] + if len(datasets) == 2: + # pos + if len(samples[0]) < (self.seq_length + 1): + # Pad with -100s so the masking function can ignore these. + samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[0] = samples[0][: (self.seq_length + 1)] + # neg + if len(samples[1]) < (self.seq_length + 1): + # Pad with -100s so the masking function can ignore these. + samples[1] = np.pad( + samples[1], + (0, (self.seq_length + 1) - len(samples[1])), + mode="constant", + constant_values=-100, + ) + elif len(samples[1]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[1] = samples[1][: (self.seq_length + 1)] + ret = { + "pos": np.array(samples[0], dtype=np.int64), + "neg": np.array(samples[1], dtype=np.int64), + } + if self.pos_ref_dataset is not None: + ret["pos_ref"] = np.array(pos_ref_samples[0], dtype=np.float32) + ret["neg_ref"] = np.array(neg_ref_samples[0], dtype=np.float32) + return ret + else: + # pos + if len(samples[0]) < (self.seq_length + 1): + # Pad with 0s, can use any number since it's masked. + samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=0, + ) + # pad with -100s so we can mask it out + samples[2] = np.pad( + samples[2], + (0, (self.seq_length + 1) - len(samples[2])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[0] = samples[0][: (self.seq_length + 1)] + samples[2] = samples[2][: (self.seq_length + 1)] + # neg + if len(samples[1]) < (self.seq_length + 1): + # Pad with 0s, can use any number since it's masked. + samples[1] = np.pad( + samples[1], + (0, (self.seq_length + 1) - len(samples[1])), + mode="constant", + constant_values=0, + ) + # pad with -100s so we can mask it out + samples[3] = np.pad( + samples[3], + (0, (self.seq_length + 1) - len(samples[3])), + mode="constant", + constant_values=-100, + ) + elif len(samples[1]) > (self.seq_length + 1): + # Check for overflow and truncate. 
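+                    # Samples are kept at seq_length + 1 tokens so that inputs and labels
+                    # can be shifted against each other by one position downstream.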
+ samples[1] = samples[1][: (self.seq_length + 1)] + samples[3] = samples[3][: (self.seq_length + 1)] + ret = { + "pos": np.array(samples[0], dtype=np.int64), + "neg": np.array(samples[1], dtype=np.int64), + "pos_label": np.array(samples[2], dtype=np.int64), + "neg_label": np.array(samples[3], dtype=np.int64), + } + if self.pos_ref_dataset is not None: + ret["pos_ref"] = np.array(pos_ref_samples[0], dtype=np.float32) + ret["neg_ref"] = np.array(neg_ref_samples[0], dtype=np.float32) + return ret + except IndexError: + new_idx = idx % len(self) + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) + return self[new_idx] + + +def _build_index_mappings( + name, + pos_data_prefix, + documents, + pos_sizes, + neg_sizes, + pos_label_dataset, + neg_label_dataset, + num_samples, + seq_length, + seed, + packing_impl, + use_shared_fs=True, + allow_chopped=True, +): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, pos_sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. + _filename = pos_data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + _filename += "_{}pi".format(packing_impl) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" + + if not use_shared_fs: + should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0 + else: + should_process_dataset = torch.distributed.get_rank() == 0 + + # Build the indexed mapping if not exist. + if should_process_dataset: + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + print_rank_0( + " > WARNING: could not find index map files, building " + "the indices on rank 0 ..." + ) + # doc-idx. + start_time = time.time() + if packing_impl == "pack_until_overflow": + # Naively pack data until it overflows, then roll it over to a new one instead. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = [] + doc_idx = [] + # Iterate over files until we have enough samples. + temp_shuffle_idx = np.arange(len(documents)) + np_rng.shuffle(temp_shuffle_idx) + running_length = 0 + curr_shuffle_idx = 0 + while len(sample_idx) < num_samples: + # If not allow_chopped, skip this item if it's chopped. + if not allow_chopped: + if ( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + if ( + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + # Then, check if we need to skip this item... 
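+                    # Skip pairs whose labels are entirely -100 (fully masked) within the
+                    # first seq_length + 1 tokens, since they contribute nothing to the loss.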
+ if pos_label_dataset is not None: + if np.all( + pos_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + if np.all( + neg_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + doc_length = max( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]], + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]], + ) + if running_length == 0: + sample_idx.append(np.array([len(doc_idx), 0])) + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + running_length += doc_length + else: + if running_length + doc_length > (seq_length + 1): + running_length = doc_length + sample_idx.append(np.array([len(doc_idx), 0])) + else: + running_length += doc_length + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + curr_shuffle_idx += 1 + if curr_shuffle_idx == len(documents): + curr_shuffle_idx = 0 + np_rng.shuffle(temp_shuffle_idx) + sample_idx.append(np.array([len(doc_idx), 0])) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + elif packing_impl == "unpacked": + # Unpacked data, one sample per document. + shuffle_idx = np.array([i % len(documents) for i in range(num_samples)]) + np_rng.shuffle(shuffle_idx) + sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64) + sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)]) + sample_idx[:, 1] = 0 + doc_idx = list() + doc_i = 0 + while len(doc_idx) <= num_samples: + # Check if we need to skip this item... + if not allow_chopped: + # +1 since we shift left/right by 1 + if pos_sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + if neg_sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + # In theory if we don't allow chopped we should be able to skip it, but the warm fuzzies I get + # from this are worth the extra bool check + if np.all(pos_label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + if np.all(neg_label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + doc_idx.append(doc_i) + doc_i = (doc_i + 1) % len(documents) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_io_parallel_group() + ) + + # Load mappings. 
+ start_time = time.time() + print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading sample-idx mapping from {}".format(sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence length, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng): + """Build an array with length = number-of-epochs * number-of-documents. + Each index is mapped to a corresponding document.""" + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Beginning offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += remaining_seq_length + doc_length - 1 + remaining_seq_length = 0 + else: + # Otherwise, start from the beginning of the next document. + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. 
+ sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(size, np_rng): + """Build the range [0, size) and shuffle.""" + dtype_ = np.uint32 + if size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx) + return shuffle_idx diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 6878c79eb..7b1a60d46 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -873,6 +873,42 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to validation label datasets (not shifted by 1 yet!). """ + pos_train_data_paths: list = None + neg_train_data_paths: list = None + """ + List of paths to positive and negative training datasets. + """ + + pos_train_label_data_paths: list = None + neg_train_label_data_paths: list = None + """ + List of paths to positive and negative training label datasets (not shifted by 1 yet!). + """ + + pos_valid_data_paths: list = None + neg_valid_data_paths: list = None + """ + List of paths to positive and negative validation datasets. + """ + + pos_valid_label_data_paths: list = None + neg_valid_label_data_paths: list = None + """ + List of paths to positive and negative validation label datasets (not shifted by 1 yet!). + """ + + pos_test_data_paths: list = None + neg_test_data_paths: list = None + """ + List of paths to positive and negative test datasets. + """ + + pos_test_label_data_paths: list = None + neg_test_label_data_paths: list = None + """ + List of paths to positive and negative test label datasets (not shifted by 1 yet!). + """ + train_data_weights: list = None """ List of 'weights' that decide how often to sample from each training dataset when blending datasets. If None, defaults to equal weighting. @@ -929,6 +965,26 @@ class NeoXArgsTraining(NeoXArgsTemplate): warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets """ + dataset_impl: Literal["gpt2", "pairwise"] = "gpt2" + """ + Dataset implementation, can be one of "gpt2" or "pairwise" + """ + + train_impl: Literal["normal", "dpo"] = "normal" + """ + Training implementation, can be one of "normal" or "dpo" + """ + + dpo_fp32: bool = True + """ + Whether to cast logits to fp32 for DPO loss calculation. + """ + + dpo_beta: float = 0.1 + """ + Beta value for DPO + """ + allow_chopped: bool = True """ WARNING: if your packing impl is packed, this is ignored. 
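Taken together, the new arguments above might be combined as in the sketch below for a DPO run. This is illustrative only: the paths and numeric values are placeholders, not taken from the patch, and the exact keys should be checked against neox_args.py.

# Hypothetical excerpt of a NeoX config for DPO training (values are placeholders).
dpo_config = {
    "dataset_impl": "pairwise",          # use PairwiseDataset instead of GPT2Dataset
    "train_impl": "dpo",                 # enable the DPO loss and reference model
    "pos_train_data_paths": ["data/chosen_train_document"],
    "neg_train_data_paths": ["data/rejected_train_document"],
    "pos_train_label_data_paths": ["data/chosen_train_label_document"],
    "neg_train_label_data_paths": ["data/rejected_train_label_document"],
    "pack_impl": "unpacked",             # "packed" is not supported by PairwiseDataset
    "dpo_beta": 0.1,                     # weight on the policy/reference log-ratio
    "dpo_fp32": True,                    # compute log-probabilities in fp32
}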
diff --git a/megatron/training.py b/megatron/training.py index 482b4154f..3c6a6b506 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -21,12 +21,14 @@ """Pretrain utilities.""" from datetime import datetime from functools import partial +from collections import defaultdict import math import sys from contextlib import nullcontext import torch +import torch.nn.functional as F import deepspeed from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler import numpy as np @@ -44,6 +46,7 @@ SoftEmbedding, get_params_for_weight_decay_optimization, ) +from megatron.mpu.mappings import gather_from_model_parallel_region from megatron.checkpointing import load_checkpoint, save_checkpoint from megatron.data.data_utils import build_train_valid_test_data_iterators from megatron.initialize import initialize_megatron @@ -136,7 +139,7 @@ def gen(): old_hidden_size = neox_args.hidden_size neox_args.hidden_size = hidden_size - model, optimizer, _ = setup_model_and_optimizer( + model, optimizer, _, _ = setup_model_and_optimizer( neox_args=neox_args, use_cache=False ) @@ -192,7 +195,7 @@ def pretrain(neox_args): # Model, optimizer, and learning rate. timers("model and optimizer").start() - model, optimizer, lr_scheduler = setup_model_and_optimizer( + model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer( neox_args=neox_args, use_cache=False, iteration=neox_args.iteration ) timers("model and optimizer").stop() @@ -230,6 +233,7 @@ def pretrain(neox_args): neox_args=neox_args, timers=timers, model=model, + reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, train_data_iterator=train_data_iterator, @@ -310,7 +314,14 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. - keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] + if neox_args.train_impl == "normal": + keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + elif neox_args.train_impl == "dpo": + keys = ( + [["pos", "pos_label"], ["neg", "neg_label"]] + if neox_args.pos_label_data_paths + else [["pos"], ["neg"]] + ) datatype = torch.int64 # Broadcast data. @@ -318,13 +329,33 @@ def get_batch(neox_args, data_iterator): data = next(data_iterator) else: data = None - return _get_batch( - neox_args=neox_args, - tokenizer=neox_args.tokenizer, - keys=keys, - data=data, - datatype=datatype, - ) + if neox_args.train_type == "normal": + return _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + elif neox_args.train_type == "dpo": + pos_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[0], + data=data, + datatype=datatype, + ) + neg_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[1], + data=data, + datatype=datatype, + ) + return [ + torch.cat((pos_item, neg_item), dim=0) + for pos_item, neg_item in zip(pos_tup, neg_tup) + ] def get_batch_pipe(data, neox_args, curr_scheduler=None): @@ -418,8 +449,23 @@ def mb_moe_loss_func(args, loss_mask, output_tensor=None): return averaged_lbl, loss_dict +def get_pos_neg_logp(logits, labels, force_fp32=False): + if force_fp32: + logits = logits.float() + logp = logits.log_softmax(dim=-1) + per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + # Split to pos/neg... 
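+    # get_batch concatenates the chosen and rejected sequences along the batch
+    # dimension, so chunking dim 0 in half recovers per-sequence logps for each.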
+ return torch.chunk(per_token_logp, 2, 0) + + def forward_step( - data_iterator, model, neox_args, timers, return_logits=False, is_train=False + data_iterator, + model, + neox_args, + timers, + return_logits=False, + is_train=False, + reference_model=None, ): """Forward step.""" if neox_args.is_pipe_parallel: @@ -441,38 +487,97 @@ def forward_step( if neox_args.memory_profiling: torch.cuda.nvtx.range_push(f"Forward pass") - # Sequential returns moe_losses, but this is not yet supported by pipe parallel - maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) - if type(maybe_tuple) is tuple: - outputs, moe_losses = maybe_tuple - else: - outputs = maybe_tuple - moe_losses = [] - if ( - is_train - and neox_args.curriculum_learning - and neox_args.curriculum_seqlen < neox_args.seq_length - ): - loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() - labels = labels[:, : neox_args.curriculum_seqlen].contiguous() - main_loss = cross_entropy( - outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy - ) - if neox_args.moe_num_experts > 1: - if neox_args.moe_type == "deepspeed": - moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) - elif neox_args.moe_type == "megablocks": - moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + metrics = {} + if neox_args.train_impl == "normal": + # Sequential returns moe_losses, but this is not yet supported by pipe parallel + maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) + if type(maybe_tuple) is tuple: + outputs, moe_losses = maybe_tuple else: - raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") - else: - moe_loss = 0.0 - loss = main_loss + moe_loss + outputs = maybe_tuple + moe_losses = [] + if ( + is_train + and neox_args.curriculum_learning + and neox_args.curriculum_seqlen < neox_args.seq_length + ): + loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() + labels = labels[:, : neox_args.curriculum_seqlen].contiguous() + main_loss = cross_entropy( + outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy + ) + if neox_args.moe_num_experts > 1: + if neox_args.moe_type == "deepspeed": + moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) + elif neox_args.moe_type == "megablocks": + moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + else: + raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") + else: + moe_loss = 0.0 + loss = main_loss + moe_loss + elif neox_args.train_type == "dpo": + # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + with torch.no_grad(): + # So we can gather token logps... + token_logp_labels = labels.clone() + token_logp_labels[token_logp_labels == -100] = 0 + pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? 
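+                # Only the reference logits are needed here; any auxiliary MoE loss
+                # returned by the reference model is discarded.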
+ ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + ref_pos, ref_neg = get_pos_neg_logp( + ref_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + ref_pos = (ref_pos * pos_loss_mask).sum(-1) + ref_neg = (ref_neg * neg_loss_mask).sum(-1) + chosen_maybe_tuple = model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(chosen_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + chosen_outputs, _ = chosen_maybe_tuple + else: + chosen_outputs = chosen_maybe_tuple + chosen_outputs = gather_from_model_parallel_region(chosen_outputs) + chosen_pos, chosen_neg = get_pos_neg_logp( + chosen_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + chosen_pos = (chosen_pos * pos_loss_mask).sum(-1) + chosen_neg = (chosen_neg * neg_loss_mask).sum(-1) + with torch.no_grad(): + # Collect metrics... + metrics["ref_neg"] = ref_neg.clone().detach().mean() + metrics["ref_pos"] = ref_pos.clone().detach().mean() + metrics["chosen_neg"] = chosen_neg.clone().detach().mean() + metrics["chosen_pos"] = chosen_pos.clone().detach().mean() + chosen_rewards = neox_args.dpo_beta * ( + chosen_pos.clone().detach() - ref_pos.clone().detach() + ) + rejected_rewards = neox_args.dpo_beta * ( + chosen_neg.clone().detach() - ref_neg.clone().detach() + ) + reward_acc = (chosen_rewards > rejected_rewards).float() + metrics["reward_acc"] = reward_acc.mean() + metrics["chosen_rewards"] = chosen_rewards.mean() + metrics["rejected_rewards"] = rejected_rewards.mean() + metrics["margins"] = (chosen_rewards - rejected_rewards).mean() + pi_logrations = chosen_pos - chosen_neg + ref_logrations = ref_pos - ref_neg + logits = pi_logrations - ref_logrations + loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean() if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: - return loss, outputs - return loss + return loss, outputs, metrics + return loss, metrics def get_model(neox_args, use_cache=False): @@ -547,7 +652,7 @@ def get_model(neox_args, use_cache=False): raise ValueError("Must be using deepspeed to run neox") -def get_optimizer(model, neox_args): +def get_optimizer(model, neox_args, dummy=False): """Set up the optimizer.""" if neox_args.no_load_optim: return None, None @@ -583,8 +688,13 @@ def get_optimizer(model, neox_args): _param_groups = [] for param_group in param_groups: trainable_params = [p for p in param_group["params"] if p.requires_grad] + if dummy: + trainable_params = [trainable_params[0]] # just take the first one param_group["params"] = trainable_params _param_groups.append(param_group) + if dummy: + # Only need one. + break param_groups = _param_groups # If we're using mup, then the optimizer must be adam or sgd @@ -743,10 +853,24 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" + needs_reference_model = neox_args.train_type == "dpo" model = get_model(neox_args=neox_args, use_cache=use_cache) + if needs_reference_model: + reference_model = get_model(neox_args=neox_args, use_cache=use_cache) + else: + reference_model = None optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args) lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args) - + if neox_args.deepspeed and needs_reference_model: + # Need an optimizer & lr_scheduler so make a very small one to keep deepspeed happy... 
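+        # deepspeed.initialize expects an optimizer and scheduler, so build a dummy
+        # optimizer over a single parameter for the frozen reference model.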
+ ref_optimizer, ref_param_groups = get_optimizer( + model=reference_model, neox_args=neox_args, dummy=True + ) + ref_lr_scheduler = get_learning_rate_scheduler( + optimizer=ref_optimizer, neox_args=neox_args + ) + else: + ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None if neox_args.deepspeed: print_rank_0("DeepSpeed is enabled.") if neox_args.no_load_optim: @@ -768,6 +892,16 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + if needs_reference_model: + reference_model, _, _, _ = deepspeed.initialize( + model=reference_model, + optimizer=ref_optimizer, + args=neox_args, + lr_scheduler=ref_lr_scheduler, + dist_init_required=False, + model_parameters=ref_param_groups, + mpu=mpu if not neox_args.is_pipe_parallel else None, + ) if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE. model.has_moe_layers = True @@ -799,10 +933,19 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): neox_args.iteration = load_checkpoint( neox_args=neox_args, model=model, + reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, iteration=iteration, ) + if needs_reference_model: + _ = load_checkpoint( + neox_args=neox_args, + model=reference_model, + optimizer=ref_optimizer, + lr_scheduler=ref_lr_scheduler, + iteration=iteration, + ) print_rank_0( f"Loading checkpoint and starting from iteration {neox_args.iteration}" ) @@ -814,7 +957,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): if lr_scheduler is not None: lr_scheduler.optimizer = model.optimizer - return model, optimizer, lr_scheduler + return model, optimizer, lr_scheduler, reference_model def backward_step(neox_args, timers, optimizer, model, loss): @@ -836,7 +979,15 @@ def backward_step(neox_args, timers, optimizer, model, loss): raise ValueError("Must be using deepspeed to run neox") -def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler): +def train_step( + neox_args, + timers, + data_iterator, + model, + optimizer, + lr_scheduler, + reference_model=None, +): """Single training step.""" # Pipeline parallelism schedules forward/backward/step @@ -844,6 +995,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) reduced_loss = train_step_pipe( neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator ) + reduced_metrics = {"lm_loss": reduced_loss} if ( neox_args.memory_profiling and neox_args.iteration >= neox_args.profile_step_start @@ -853,18 +1005,22 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) save_snapshot(neox_args) else: losses = [] + metric_dicts = defaultdict(list) for _ in range(neox_args.gradient_accumulation_steps): # Forward model for one step. timers("forward").start() - loss = forward_step( + loss, metric_dict = forward_step( neox_args=neox_args, timers=timers, data_iterator=data_iterator, model=model, is_train=True, + reference_model=reference_model, ) timers("forward").stop() losses.append(loss) + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # Calculate gradients, reduce across processes, and clip. 
if ( neox_args.profile @@ -913,17 +1069,20 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) and torch.distributed.get_rank() == 0 ): save_snapshot(neox_args) - reduced_loss = { - "lm_loss": reduce_losses(losses).mean() - } # reduces losses across machines for logging + # reduces metrics across machines for logging + reduce_metrics = { + key: reduce_losses([metric_dicts[key]]).mean() + for key in metric_dicts.keys() + } + reduce_metrics["lm_loss"] = reduce_losses(losses).mean() if neox_args.precision == "fp16" and model.optimizer.overflow: skipped_iter = 1 else: skipped_iter = 0 - collect_loss_for_unit_test(reduced_loss["lm_loss"]) - return reduced_loss, skipped_iter + collect_loss_for_unit_test(reduce_metrics["lm_loss"]) + return reduce_metrics, skipped_iter def train_step_pipe(neox_args, timers, model, data_iterator): @@ -949,6 +1108,7 @@ def train( neox_args, timers, model, + reference_model, optimizer, lr_scheduler, train_data_iterator, @@ -1004,6 +1164,7 @@ def train( model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, + reference_model=reference_model, ) if neox_args.profile and iteration == neox_args.profile_step_stop: torch.cuda.cudart().cudaProfilerStop() @@ -1094,6 +1255,7 @@ def evaluate( # Turn on evaluation mode which disables dropout. model.eval() losses = [] + metric_dicts = defaultdict(list) if neox_args.char_level_ppl: data_iterator = CharCounter(data_iterator, neox_args.tokenizer) @@ -1115,14 +1277,15 @@ def evaluate( else neox_args.gradient_accumulation_steps ): # Forward evaluation - loss = forward_step_fn( + loss, metric_dict = forward_step_fn( model=model, data_iterator=data_iterator, neox_args=neox_args, timers=timers, ) losses.append(loss) - + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # When contiguous memory optimizations are enabled, the buffers # allocated by the optimizations are deallocated during backward pass # in the absence of backward pass the buffers should be reset after each @@ -1132,6 +1295,8 @@ def evaluate( # reduces losses across processes for logging & run eval harness tasks eval_results = {"lm_loss": reduce_losses(losses).mean().item()} + for key in metric_dicts.keys(): + eval_results[key] = reduce_losses(metric_dicts[key]).mean().item() eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"]) if neox_args.char_level_ppl: diff --git a/megatron/utils.py b/megatron/utils.py index 26b4439bd..a64a8ba6c 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -449,7 +449,7 @@ def setup_for_inference_or_eval(use_cache=True, overwrite_values=None, input_arg initialize_megatron(neox_args) # set up model and load checkpoint. - model, _, _ = setup_model_and_optimizer( + model, _, _, _ = setup_model_and_optimizer( neox_args=neox_args, use_cache=use_cache, iteration=neox_args.iteration,
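For readers skimming the diff, the DPO objective wired into forward_step above reduces to the standalone sketch below. It is illustrative only: it assumes per-sequence summed log-probabilities as inputs, omits the loss-mask handling and model-parallel gather done in the patch, and beta mirrors dpo_beta.

import torch
import torch.nn.functional as F

def dpo_loss_sketch(chosen_pos, chosen_neg, ref_pos, ref_neg, beta=0.1):
    """Illustrative DPO loss over summed sequence log-probabilities.

    chosen_pos/chosen_neg: policy log-probs of the chosen/rejected responses.
    ref_pos/ref_neg: reference-model log-probs of the same responses.
    """
    pi_logratios = chosen_pos - chosen_neg
    ref_logratios = ref_pos - ref_neg
    logits = pi_logratios - ref_logratios
    loss = -F.logsigmoid(beta * logits).mean()
    # Reward-style metrics mirroring those logged by forward_step.
    chosen_rewards = beta * (chosen_pos - ref_pos).detach()
    rejected_rewards = beta * (chosen_neg - ref_neg).detach()
    metrics = {
        "reward_acc": (chosen_rewards > rejected_rewards).float().mean(),
        "margins": (chosen_rewards - rejected_rewards).mean(),
    }
    return loss, metrics

A typical call, under the same assumptions, would be loss, metrics = dpo_loss_sketch(chosen_pos, chosen_neg, ref_pos, ref_neg, beta=neox_args.dpo_beta).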