Make iterable dataset more efficient #219

Merged: 7 commits, Jun 21, 2023
5 changes: 0 additions & 5 deletions olmo/config.py
@@ -647,11 +647,6 @@ class TrainConfig(BaseConfig):
to write out a final checkpoint.
"""

save_data_indices: bool = False
"""
If ``True``, write the indices of the examples in each batch for each rank to a tsv file in the save folder.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
20 changes: 15 additions & 5 deletions olmo/data/__init__.py
@@ -1,11 +1,11 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

from ..config import DataConfig, TrainConfig
from ..exceptions import OlmoConfigurationError
from ..util import global_rank
from ..util import barrier, get_global_rank, get_world_size
from .collator import DataCollator
from .iterable_dataset import IterableDataset
from .memmap_dataset import MemMapDataset
@@ -42,15 +42,15 @@ def build_eval_dataloader(
collator = DataCollator(pad_direction=data_config.pad_direction, pad_token_id=train_config.model.pad_token_id)
if data_config.drop_last:
# Make sure batch size is small enough.
samples_per_device = len(dataset) // dist.get_world_size()
samples_per_device = len(dataset) // get_world_size()
batch_size = min(batch_size, samples_per_device)
assert batch_size > 0, f"dataset for {data_config.paths} is too small"
sampler = DistributedSampler(
dataset,
drop_last=data_config.drop_last,
shuffle=shuffle,
num_replicas=dist.get_world_size(),
rank=global_rank(),
num_replicas=get_world_size(),
rank=get_global_rank(),
seed=train_config.seed,
)
return DataLoader(
@@ -72,13 +72,23 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
)
dataset = build_memmap_dataset(train_config, train_config.data)
work_dir = Path(train_config.save_folder) / "train_data"
if get_global_rank() == 0:
if work_dir.is_dir() and not train_config.save_overwrite:
raise OlmoConfigurationError(
"train data working directory already exists, use --save_overwrite to overwrite"
)
else:
work_dir.mkdir(exist_ok=True, parents=True)
barrier()
return DataLoader(
IterableDataset(
dataset, # type: ignore
seed=train_config.seed,
shuffle=True,
drop_last=train_config.data.drop_last,
max_examples=train_config.global_train_batch_size * train_config.max_duration,
work_dir=work_dir,
),
batch_size=train_config.device_train_batch_size,
drop_last=train_config.data.drop_last,
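For context, the new block in build_train_dataloader follows a common distributed pattern: rank 0 prepares the shared working directory while every other rank waits at a barrier. A minimal standalone sketch of that pattern (using a generic RuntimeError instead of OlmoConfigurationError, and a hypothetical helper name prepare_train_data_dir):

from pathlib import Path

from olmo.util import barrier, get_global_rank


def prepare_train_data_dir(save_folder: str, save_overwrite: bool) -> Path:
    # Only rank 0 touches the filesystem; the other ranks just wait.
    work_dir = Path(save_folder) / "train_data"
    if get_global_rank() == 0:
        if work_dir.is_dir() and not save_overwrite:
            raise RuntimeError("train data working directory already exists, use --save_overwrite to overwrite")
        work_dir.mkdir(exist_ok=True, parents=True)
    barrier()  # non-zero ranks block here until the directory exists
    return work_dir

The barrier guarantees that by the time any rank constructs the IterableDataset below, the directory it will memmap files from already exists.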
42 changes: 32 additions & 10 deletions olmo/data/iterable_dataset.py
@@ -1,11 +1,13 @@
import math
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data

from ..util import global_rank
from ..aliases import PathOrStr
from ..util import barrier, get_global_rank, get_world_size

__all__ = ["IterableDataset"]

@@ -29,19 +31,16 @@ def __init__(
drop_last: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
work_dir: Optional[PathOrStr] = None,
):
self.dataset = dataset
self.seed = seed
self.start_index = start_index
self.max_examples = max_examples
self.shuffle = shuffle
self.drop_last = drop_last
self.rank = rank if rank is not None else global_rank()
self.world_size = (
world_size
if world_size is not None
else (dist.get_world_size() if (dist.is_available() and dist.is_initialized()) else 1)
)
self.rank = rank if rank is not None else get_global_rank()
self.world_size = world_size if world_size is not None else get_world_size()
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if self.drop_last and len(self.dataset) % self.world_size != 0: # type: ignore[arg-type]
@@ -53,8 +52,21 @@ def __init__(
else:
num_samples = math.ceil(len(self.dataset) / self.world_size) # type: ignore[arg-type]
self.total_size = num_samples * self.world_size
self.global_indices_file: Optional[Path] = None
if work_dir is not None:
self.global_indices_file = Path(work_dir) / "global_indices.npy"
if self.rank == 0:
self.global_indices_file.parent.mkdir(parents=True, exist_ok=True)
global_indices = self._build_global_indices()
global_indices_mmap = np.memmap(
self.global_indices_file, dtype=np.uint64, mode="w+", shape=(len(global_indices),)
)
global_indices_mmap[:] = global_indices
global_indices_mmap.flush()
del global_indices_mmap
barrier()

def __iter__(self) -> Iterator[Dict[str, Any]]:
def _build_global_indices(self) -> List[int]:
if self.shuffle:
# Deterministically shuffle based on epoch and seed
g = torch.Generator()
@@ -74,6 +86,16 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
# Remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
assert len(indices) == self.total_size
return indices

def get_global_indices(self) -> Sequence[int]:
if self.global_indices_file is not None:
return np.memmap(self.global_indices_file, mode="r", dtype=np.uint64) # type: ignore
else:
return self._build_global_indices()

def __iter__(self) -> Iterator[Dict[str, Any]]:
indices = self.get_global_indices()

# Truncate to max_examples.
if self.max_examples is not None:
@@ -93,7 +115,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
if worker_info is not None:
indices = indices[worker_info.id :: worker_info.num_workers]

return (self._get_dataset_item(idx) for idx in indices)
return (self._get_dataset_item(int(idx)) for idx in indices)

def _get_dataset_item(self, idx: int) -> Dict[str, Any]:
item = self.dataset[idx]
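The efficiency gain appears to come from building the (optionally shuffled) global index order once, persisting it as a uint64 memmap under work_dir, and letting every rank and dataloader worker lazily slice its share out of that file instead of each process materializing the full index list. A self-contained sketch of the idea; the helper names, epoch handling, and exact rank-sharding stride are illustrative rather than copied from the PR:

from pathlib import Path

import numpy as np
import torch


def write_global_indices(work_dir: Path, num_examples: int, seed: int) -> Path:
    # Rank 0 shuffles once and persists the order as a uint64 memmap on disk.
    g = torch.Generator()
    g.manual_seed(seed)
    order = torch.randperm(num_examples, generator=g).numpy().astype(np.uint64)
    path = Path(work_dir) / "global_indices.npy"
    mmap = np.memmap(path, dtype=np.uint64, mode="w+", shape=(num_examples,))
    mmap[:] = order
    mmap.flush()
    del mmap
    return path


def read_rank_worker_shard(path: Path, rank: int, world_size: int, worker_id: int, num_workers: int) -> np.ndarray:
    # Every rank/worker opens the same file read-only; basic slicing of a memmap
    # returns a file-backed view, so no process holds the full shuffled index
    # list in memory.
    indices = np.memmap(path, dtype=np.uint64, mode="r")
    indices = indices[rank::world_size]       # shard across ranks
    return indices[worker_id::num_workers]    # then across dataloader workers

This also explains the int(idx) cast in the diff above: indexing the memmap yields numpy uint64 scalars rather than Python ints.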
2 changes: 1 addition & 1 deletion olmo/data/memmap_dataset.py
@@ -22,7 +22,7 @@ class MemMapDataset(Dataset[Dict[str, Any]]):
remainder of the tokens will be ignored.

No special tokens are added to the input IDs so it's assumed that if you want
EOS tokens between documents, for example, those will already by in the memory-mapped array.
EOS tokens between documents, for example, those will already be in the memory-mapped array.

:param paths: Paths to memory-mapped token arrays.
:param chunk_size: The number of tokens to chunk together into a single instance.
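Since MemMapDataset adds no special tokens itself, the preprocessing step has to bake EOS separators into the token file. A rough sketch of producing such an array; this is not the project's actual preprocessing script, and the tokenizer argument stands in for any tokenizer with an encode method returning token IDs:

import numpy as np


def write_token_file(documents, tokenizer, eos_token_id: int, path: str, dtype=np.uint16) -> None:
    # Concatenate all documents into one flat token stream, inserting EOS
    # between documents so MemMapDataset can chunk it blindly.
    # uint16 assumes the vocabulary fits in 16 bits; widen the dtype otherwise.
    token_ids = []
    for doc in documents:
        token_ids.extend(tokenizer.encode(doc))
        token_ids.append(eos_token_id)
    arr = np.memmap(path, dtype=dtype, mode="w+", shape=(len(token_ids),))
    arr[:] = np.array(token_ids, dtype=dtype)
    arr.flush()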
7 changes: 3 additions & 4 deletions olmo/eval/__init__.py
@@ -1,14 +1,13 @@
from typing import Dict, List, Union

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchmetrics import MeanMetric, Metric

from ..config import EvaluatorConfig, EvaluatorType, TrainConfig
from ..exceptions import OlmoConfigurationError
from ..tokenizer import Tokenizer
from ..util import cycle_through_epochs, global_rank
from ..util import cycle_through_epochs, get_global_rank, get_world_size
from .downstream import ICLMetric, label_to_task_map
from .evaluator import Evaluator

@@ -39,8 +38,8 @@ def build_downstream_evaluator(
ds_eval_dataset,
drop_last=data_config.drop_last,
shuffle=False,
num_replicas=dist.get_world_size(),
rank=global_rank(),
num_replicas=get_world_size(),
rank=get_global_rank(),
seed=train_config.seed,
)
ds_eval_dataloader = DataLoader(
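Throughout the PR, direct dist.get_world_size() and global_rank() calls are replaced with olmo.util wrappers. Judging from the fallback expression removed from IterableDataset.__init__ above, these wrappers presumably degrade to single-process defaults when torch.distributed is not initialized, along the lines of this sketch (an assumption, not the actual olmo/util.py):

import torch.distributed as dist


def get_world_size() -> int:
    return dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1


def get_global_rank() -> int:
    return dist.get_rank() if dist.is_available() and dist.is_initialized() else 0


def barrier() -> None:
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

Centralizing the fallback keeps build_eval_dataloader and build_downstream_evaluator usable in single-GPU or CPU-only runs without guarding every call site.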