
Conversation

Member

@le1nux le1nux commented Sep 24, 2024

What does this PR do?

This PR addresses issue #258 (inefficiencies in the dataloader) and additionally introduces a combined dataset, i.e., a dataset that comprises a list of datasets and iterates over them (a minimal sketch follows below).
As part of fixing the dataloader inefficiencies, the sample skipping functionality is no longer implemented on the dataloader level but in an adapted version of the PyTorch DistributedSampler. I reran a warm start and the learning curve is equivalent to that of a full, non-warmstarted run.

[Screenshot (2024-09-27): comparison of the warm-started run and the full, non-warmstarted run]
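
For illustration, here is a minimal sketch of the combined dataset idea, assuming it simply maps a global sample index to the corresponding child dataset (class name and layout are illustrative, not necessarily the implementation added in this PR):

from typing import Any

from torch.utils.data import Dataset


class CombinedDataset(Dataset):
    """Illustrative sketch: presents a list of datasets as one concatenated dataset."""

    def __init__(self, datasets: list[Dataset]):
        self.datasets = datasets
        # cumulative sizes mark where each child dataset ends in the global index space
        self.cumulative_sizes: list[int] = []
        total = 0
        for dataset in datasets:
            total += len(dataset)
            self.cumulative_sizes.append(total)

    def __len__(self) -> int:
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx: int) -> Any:
        # locate the child dataset containing the global index and translate it to a local index
        for dataset_idx, end in enumerate(self.cumulative_sizes):
            if idx < end:
                start = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
                return self.datasets[dataset_idx][idx - start]
        raise IndexError(f"index {idx} is out of range for a combined dataset of length {len(self)}")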

General Changes

  • Introduced ResumableDistributedSampler, a copy of the PyTorch DistributedSampler extended with the ability to skip samples. It is now used for warm starts instead of skip_num_samples in the dataloader. To skip samples, the dataloader previously had to instantiate a ResumableBatchSampler, which internally iterated over all dataset indices. For small datasets this was fine, but for larger datasets (in the trillion-token range) it became a bottleneck at instantiation time:
    self.underlying_batch_sampler = underlying_batch_sampler
    # NOTE: we are only iterating over the indices, not the actual data,
    # so this is relatively cheap
    self.indices = list(iter(self.underlying_batch_sampler))

    Skipping in the ResumableDistributedSampler is now done in O(1) (see the sampler sketch after this list). The ResumableBatchSampler was removed from the codebase.
  • Replaced the packed index generation routine, which was inefficient due to its Python for loop,
    return [
        ((i * self.block_size - i) * self._token_size_in_bytes, self.block_size * self._token_size_in_bytes)
        for i in range(num_samples)
    ]

    with a vectorized version (see the index-generation sketch after this list).
  • Added the new NumberConversion routine num_samples_from_num_tokens.
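
To make the O(1) skipping claim concrete, here is a minimal sketch of the idea: resuming becomes a single offset into the per-rank index list, instead of re-iterating a batch sampler over all indices at instantiation time. The class name, the skip_num_global_samples parameter, and the use of subclassing are illustrative assumptions; the actual ResumableDistributedSampler is an adapted copy of the PyTorch sampler.

from torch.utils.data.distributed import DistributedSampler


class SkippingDistributedSampler(DistributedSampler):
    """Illustrative sketch: skip already-seen samples via an index offset per rank."""

    def __init__(self, dataset, skip_num_global_samples: int = 0, **kwargs):
        super().__init__(dataset, **kwargs)
        # every rank skips its equal share of the globally seen samples
        self.skip_num_local_samples = skip_num_global_samples // self.num_replicas

    def __iter__(self):
        indices = list(super().__iter__())
        # skipping is just an offset into the already-computed per-rank indices;
        # no batch sampler has to be iterated over the full dataset
        return iter(indices[self.skip_num_local_samples:])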
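
And a sketch of what the vectorized index generation could look like, assuming NumPy (function name, signature, and return type are illustrative; the PR's actual implementation may differ). Each sample's byte offset grows by (block_size - 1) * token_size_in_bytes, while its byte length is constant:

import numpy as np


def generate_packed_index(num_samples: int, block_size: int, token_size_in_bytes: int) -> np.ndarray:
    """Illustrative vectorized version: one (offset, length) row per sample."""
    # offset of sample i: i * (block_size - 1) * token_size_in_bytes, which equals
    # (i * block_size - i) * token_size_in_bytes from the loop-based version
    offsets = np.arange(num_samples, dtype=np.int64) * (block_size - 1) * token_size_in_bytes
    lengths = np.full(num_samples, block_size * token_size_in_bytes, dtype=np.int64)
    return np.stack([offsets, lengths], axis=1)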

Breaking Changes

  • Removed RepeatingDataloader, a feature for running multiple epochs that was never actively used and made refactoring the sampling unnecessarily complex to maintain. If needed, we could reimplement it.
  • In the settings, the training_progress section now has num_seen_samples instead of local_num_seen_batches, as skipping is done on the sampler level and no longer on the dataloader level.
  • The batch_size and fast_forward_batch_id fields in the LLMDataLoader are not needed anymore and were removed.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

Member

@flxst flxst left a comment


LGTM :) Left a few minor comments.

raw_data_path (Path): Path to a packed binary file (*.pbin).
Use `modalities data pack_encoded_data` to create one based on a JSONL-file.
sample_key (str): The key to access the sample in the BatchEncoding.
load_index (bool, optional): Flag indicating whether to load the index. Defaults to True.
Member


Wouldn't it be more consistent if this defaulted to False like in PackedMemMapDatasetContinuous (see line 308)?

Member Author


In PackedMemMapDatasetContinuous we would never load the index (except for debugging purposes), which is why I made it default to False. The "Continuous" implementation does not need an index. The PackedMemMapDatasetBase, however, uses the index in its default implementation for packing the data, which is why it defaults to True.

Member


General comments on the changelog:

  • The table in the beginning also needs to be updated
  • It might be useful to reverse the order of PRs so that the newest ones come first

le1nux and others added 8 commits October 24, 2024
Co-authored-by: Felix Stollenwerk <felix.stollenwerk@ai.se>
T_co = TypeVar("T_co", covariant=True)


class ResumableDistributedSampler(Sampler[T_co]):
Member

@mali-git mali-git Nov 22, 2024


Can we avoid copying PyTorch's Sampler and extend it instead? E.g.:

from torch.utils.data.distributed import DistributedSampler

class CustomDistributedSampler(DistributedSampler):
    def __init__(self, dataset, start_idx=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.start_idx = start_idx

    def __iter__(self):
        # Get all indices assigned to this process
        indices = list(super().__iter__())

        # Filter indices to include only those starting from start_idx
        filtered_indices = [idx for idx in indices if idx >= self.start_idx]

        return iter(filtered_indices)

Member Author


I can't remember why we decided to adapt the PyTorch DistributedSampler instead of inheriting from it or using composition. From the start of modalities, we have basically used the adaptation. We could open another issue to investigate whether it makes sense to move to the original PyTorch DistributedSampler.

@le1nux le1nux merged commit 94cf3f0 into main Nov 22, 2024
3 checks passed