Commit

Merge branch 'boxiangw/mlperf-option-add-one-extra-token' into 'main'
[MLPerf] GPT dataset features: drop last partial validation sequence, drop extra token, return sample with 1s loss mask, mock dataset testing

See merge request ADLR/megatron-lm!1223
ericharper committed May 2, 2024
2 parents 2297178 + c90aa16 commit db3a3f7
Showing 27 changed files with 543 additions and 414 deletions.
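Before the per-file diffs, a word on the headline features. The "drop extra token" and "1s loss mask" items describe how a GPT sample is assembled; the sketch below is illustrative only, and the flag name `add_extra_token` and the returned dictionary keys are assumptions for this example, not identifiers taken from this merge request.

```python
import numpy as np

def build_gpt_sample(token_ids: np.ndarray, sequence_length: int, add_extra_token: bool) -> dict:
    """Illustrative sketch (not Megatron-LM code) of the 'extra token' option.

    GPT datasets conventionally return sequence_length + 1 tokens so that inputs
    and labels can be produced by shifting; the MLPerf-style option instead
    returns exactly sequence_length tokens and leaves the shift to the consumer.
    """
    if add_extra_token:
        sample = token_ids[: sequence_length + 1]
        tokens, labels = sample[:-1], sample[1:]
    else:
        tokens, labels = token_ids[:sequence_length], None  # consumer shifts labels itself
    # A loss mask of all ones means every position contributes to the loss.
    loss_mask = np.ones(sequence_length, dtype=np.float32)
    return {"tokens": tokens, "labels": labels, "loss_mask": loss_mask}
```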
examples/run_simple_mcore_train_loop.py (3 changes: 1 addition & 2 deletions)
@@ -49,8 +49,7 @@ def get_train_data_iterator():
     config = GPTDatasetConfig(
         random_seed = 0,
         sequence_length = 64,
-        blend=[],
-        mock=True,
+        blend=None,
         reset_position_ids=False,
         reset_attention_mask=False,
         eod_mask_loss=False,
megatron/core/QuickStart.md (7 changes: 3 additions & 4 deletions)
@@ -86,10 +86,9 @@ from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
 def get_train_data_iterator():
     config = GPTDatasetConfig(
-        random_seed = 0,
-        sequence_length = 64,
-        blend=[],
-        mock=True,
+        random_seed=0,
+        sequence_length=64,
+        blend=None,
         reset_position_ids=False,
         reset_attention_mask=False,
         eod_mask_loss=False,
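The same `blend=[]` / `mock=True` to `blend=None` change appears in both snippets above: mock data generation is now requested by leaving the blends unset rather than by passing a flag. A minimal sketch of how the updated config is then consumed, assuming the `BlendedMegatronDatasetBuilder(cls, sizes, is_built_on_rank, config)` signature and a fully populated GPTDatasetConfig (tokenizer and any remaining required fields set up as in QuickStart.md):

```python
from torch.utils.data import DataLoader

from megatron.core.datasets.blended_megatron_dataset_builder import (
    BlendedMegatronDatasetBuilder,
)
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset


def get_mock_train_iterator(config: GPTDatasetConfig):
    # With blend=None and no blend_per_split, config.mock is True, so
    # MockGPTDataset generates synthetic samples without any .bin/.idx files.
    # Sizes: request 1000 training samples; skip the validation and test splits.
    train_ds, _, _ = BlendedMegatronDatasetBuilder(
        MockGPTDataset, [1000, None, None], lambda: True, config
    ).build()
    return iter(DataLoader(train_ds, batch_size=8, shuffle=True))
```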
megatron/core/datasets/bert_dataset.py (7 changes: 2 additions & 5 deletions)
@@ -38,7 +38,7 @@ class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
         indexed_indices (numpy.ndarray): The set of the documents indices to expose
-        num_samples (int): The number of samples to draw from the indexed dataset
+        num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
         index_split (Split): The indexed_indices Split
@@ -50,17 +50,14 @@ def __init__(
         indexed_dataset: IndexedDataset,
         dataset_path: str,
         indexed_indices: numpy.ndarray,
-        num_samples: int,
+        num_samples: Optional[int],
         index_split: Split,
         config: BERTMaskedWordPieceDatasetConfig,
     ) -> None:
         super().__init__(
             indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
         )
 
     def _finalize(self) -> None:
         """Abstract method implementation
         """
         self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
         # Account for the single <cls> and two <sep> token ids
         self.sample_index = self._build_sample_index(
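Per the updated docstring above, num_samples is now Optional across the dataset classes, with None meaning "build one full epoch". A small illustrative helper (not code from this merge request) showing that convention:

```python
from typing import Optional


def resolve_num_samples(num_samples: Optional[int], samples_per_epoch: int) -> int:
    # None requests exactly one epoch's worth of samples; an integer is honored as-is.
    return samples_per_epoch if num_samples is None else num_samples


assert resolve_num_samples(None, 4096) == 4096
assert resolve_num_samples(1000, 4096) == 1000
```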
megatron/core/datasets/blended_dataset.py (2 changes: 1 addition & 1 deletion)
@@ -166,7 +166,7 @@ def _build_indices(self) -> Tuple[numpy.ndarray, numpy.ndarray]:
             log_single_rank(
                 logger,
                 logging.WARNING,
-                "Unable to save the blending indexes because path_to_cache is None",
+                f"Unable to save the {type(self).__name__} indexes because path_to_cache is None",
             )
 
         t_end = time.time()
megatron/core/datasets/blended_megatron_dataset_builder.py (50 changes: 22 additions & 28 deletions)
@@ -9,7 +9,7 @@
 
 from megatron.core.datasets.blended_dataset import BlendedDataset
 from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
-from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset, MockDataset
+from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset
 from megatron.core.datasets.utils import Split, log_single_rank, normalize
 from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank
 
@@ -51,13 +51,11 @@ def __init__(
 
         log_single_rank(
             logger,
-            logging.WARNING,
+            logging.INFO,
             f"Building dataset splits with cls={cls.__name__}, sizes={self.sizes}, and config={self.config}",
         )
 
-        if self.config.mock:
-            assert issubclass(self.cls, MockDataset)
-        else:
+        if not self.config.mock:
             for split in Split:
                 size_is_none = self.sizes[split.value] is None
                 if self.config.blend_per_split is None:
@@ -151,7 +149,13 @@ def _build_blended_dataset_splits(self,) -> List[Optional[TopLevelDataset]]:
         # Return fake "mock" datasets
         ##
         if self.config.mock:
-            return self._build_megatron_dataset_splits(None, None, self.sizes)
+            split = self.config.split_matrix
+            try:
+                return self._build_megatron_dataset_splits(None, split, self.sizes)
+            except Exception as error:
+                raise Exception(
+                    f"{self.cls.__name__} failed to build as a mock data generator"
+                ) from error
 
         ##
         # All splits come from the same distribution
@@ -282,7 +286,7 @@ def _build_megatron_dataset_splits(
         """Build each MidLevelDataset split from a single LowLevelDataset
 
         Args:
-            dataset_path (Optional[str]): The path on disk which defines the underlying LowLevelDataset, e.g. the .bin and .idx file prefix when self.cls is of type IndexedMegatronDataset or None when self.cls is of type MockDataset
+            dataset_path (Optional[str]): The path on disk which defines the underlying LowLevelDataset, or None for mock dataset classes
 
             split (List[Tuple[float, float]]): The dataset split matrix
@@ -292,33 +296,23 @@
             List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split
         """
         # Build the low level dataset
-        if issubclass(self.cls, MockDataset):
-            low_level_dataset = None
-        elif issubclass(self.cls, MegatronDataset):
-            low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)
-        else:
-            raise NotImplementedError
+        low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)
 
         # Build the split indices for the low level dataset
-        if low_level_dataset is not None:
-            num_elements = self.cls.numel_low_level_dataset(low_level_dataset)
-            split_indices = []
-            for i, _ in enumerate(Split):
-                if split[i] is not None:
-                    beg = int(round(split[i][0] * float(num_elements)))
-                    end = int(round(split[i][1] * float(num_elements)))
-                    split_indices.append(
-                        numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32)
-                    )
-                else:
-                    split_indices.append(None)
-        else:
-            split_indices = [None for _ in Split]
+        num_elements = self.cls.numel_low_level_dataset(low_level_dataset)
+        split_indices = []
+        for i, _ in enumerate(Split):
+            if split[i] is not None:
+                beg = int(round(split[i][0] * float(num_elements)))
+                end = int(round(split[i][1] * float(num_elements)))
+                split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32))
+            else:
+                split_indices.append(None)
 
         # Build the mid level dataset
         mid_level_datasets = []
         for i, _split in enumerate(Split):
-            if not self.config.mock and split[i] is None:
+            if split[i] is None:
                 mid_level_datasets.append(None)
             else:
                 mid_level_datasets.append(
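To make the rewritten split-index logic concrete, here is a standalone worked example of the arithmetic above, using an assumed 1000-element low-level dataset and a split matrix that gives the train split [0.0, 0.8), the validation split [0.8, 0.9), and no test split:

```python
import numpy

# Assumed inputs for illustration only.
num_elements = 1000
split = [(0.0, 0.8), (0.8, 0.9), None]  # train, valid, no test

split_indices = []
for bounds in split:
    if bounds is not None:
        beg = int(round(bounds[0] * float(num_elements)))
        end = int(round(bounds[1] * float(num_elements)))
        split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32))
    else:
        split_indices.append(None)

assert split_indices[0][0] == 0 and split_indices[0][-1] == 799
assert split_indices[1][0] == 800 and split_indices[1][-1] == 899
assert split_indices[2] is None  # no mid-level dataset is built for this split
```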
megatron/core/datasets/blended_megatron_dataset_config.py (67 changes: 35 additions & 32 deletions)
@@ -6,8 +6,6 @@
 from dataclasses import dataclass, field
 from typing import List, Optional, Tuple
 
-import torch
-
 from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
 from megatron.core.datasets.utils import Split, log_single_rank, normalize
 
@@ -53,46 +51,51 @@ class BlendedMegatronDatasetConfig:
     mmap_bin_files: bool = True
     """Whether to mmap the .bin files or use file pointers."""
 
-    mock: bool = False
-    """Whether to bypass real data loading and validation in favor of mock data generation."""
+    mock: bool = field(init=False, default=False)
+    """Whether to bypass real data loading and validation in favor of mock data generation.
+    Created automatically from 'blend' and 'blend_per_split'. Not to be passed in to the
+    constructor.
+    """
 
     tokenizer: Optional[MegatronTokenizer] = None
     """The MegatronTokenizer instance or None. Required for datasets which do online tokenization."""
 
     def __post_init__(self) -> None:
         """Do asserts and set fields post init
         """
-        log_single_rank(logger, logging.INFO, f"mock = {self.mock}")
-
-        if not self.mock:
-            if self.blend_per_split is not None and any(self.blend_per_split):
-                assert self.blend is None, "blend and blend_per_split are incompatible"
-                assert self.split is None, "split and blend_per_split are incompatible"
-                assert len(self.blend_per_split) == len(
-                    Split
-                ), f"blend_per_split must contain {len(Split)} blends"
-                for split in Split:
-                    if self.blend_per_split[split.value] is None:
-                        log_single_rank(
-                            logger, logging.INFO, f"blend not provided for {split.name} split"
-                        )
-                    else:
-                        assert self.blend_per_split[split.value][1] is None or len(
-                            self.blend_per_split[split.value][0]
-                        ) == len(
-                            self.blend_per_split[split.value][1]
-                        ), "blend per split prefixes and weights must be equal in number"
-            else:
-                assert (
-                    self.blend is not None
-                ), "one of either blend or blend_per_split must be provided"
-                assert self.split is not None, "both blend and split must be provided"
+        if self.blend_per_split is not None and any(self.blend_per_split):
+            assert self.blend is None, "blend and blend_per_split are incompatible"
+            assert self.split is None, "split and blend_per_split are incompatible"
+            assert len(self.blend_per_split) == len(
+                Split
+            ), f"blend_per_split must contain {len(Split)} blends"
+            for split in Split:
+                if self.blend_per_split[split.value] is None:
+                    log_single_rank(
+                        logger, logging.INFO, f"blend not provided for {split.name} split"
+                    )
+                else:
+                    assert self.blend_per_split[split.value][1] is None or len(
+                        self.blend_per_split[split.value][0]
+                    ) == len(
+                        self.blend_per_split[split.value][1]
+                    ), "blend per split prefixes and weights must be equal in number"
+        else:
+            assert self.split is not None, "split must be provided in absence of blend_per_split"
+            split_vector = parse_and_normalize_split(self.split)
+            self.split_matrix = convert_split_vector_to_split_matrix(split_vector)
+            log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}")
+            if self.blend is not None:
                 assert self.blend[1] is None or len(self.blend[0]) == len(
                     self.blend[1]
                 ), "blend prefixes and weights must be equal in number"
-                split_vector = parse_and_normalize_split(self.split)
-                self.split_matrix = convert_split_vector_to_split_matrix(split_vector)
-                log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}")
+            else:
+                self.mock = True
+                log_single_rank(
+                    logger,
+                    logging.INFO,
+                    f"Let mock = True, as both blend and blend_per_split are None",
+                )
 
 
 def parse_and_normalize_split(split: str) -> List[float]:
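Net effect of the config change: mock is no longer a constructor argument but is derived in __post_init__ when neither blend nor blend_per_split is given (a split is still required so that a split_matrix can be built). A small sketch of the resulting behavior using the base config class; the split string and keyword set here are assumptions for illustration:

```python
from megatron.core.datasets.blended_megatron_dataset_config import (
    BlendedMegatronDatasetConfig,
)

# No blend and no blend_per_split: __post_init__ sets mock = True on its own.
config = BlendedMegatronDatasetConfig(random_seed=0, sequence_length=64, split="99,1,0")
assert config.mock
assert config.split_matrix is not None

# Passing mock explicitly is no longer possible: the field is declared with
# init=False, so mock=True at construction time would raise a TypeError.
```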
The remaining 21 changed files in this commit are not shown here.
