fix tarred dataset len when num shards is not divisible by workers #4553

Merged
merged 20 commits into NVIDIA:main from fix/tarred_dataset_len on Jul 26, 2022

Conversation

itzsimpl
Contributor

@itzsimpl itzsimpl commented Jul 15, 2022

Signed-off-by: Iztok Lebar Bajec ilb@fri.uni-lj.si

What does this PR do ?

Fixes #4522. When using tarred datasets and the number of shards is not divisible by world_size, NeMo logs a warning but training continues. However, in this case the computation of DataLoader.len() becomes incorrect. This leads to an invalid lr_scheduler.max_steps and other issues down the road, one example being the validation check not running at the end of an epoch.
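
To make the failure mode concrete, here is a small, purely illustrative sketch (the shard counts are the ones used in the worked example later in this thread, not from any particular config): with shard_strategy='scatter' each rank takes floor(num_shards / world_size) shards, so a non-divisible remainder is silently dropped and the dataset length no longer matches the manifest.

num_shards, world_size = 146, 8  # assumed example values

shards_per_rank = num_shards // world_size   # 18 shards per rank
shards_used = shards_per_rank * world_size   # 144 shards actually consumed
shards_dropped = num_shards - shards_used    # 2 shards are never read
print(shards_per_rank, shards_used, shards_dropped)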

Collection: [nlp]

  • language_modelling
  • machine_translation
  • text_normalization
  • token_classification

Changelog

In the case of shard_strategy='scatter', len() is computed based on the number of shards taken, assuming the total num_batches are equally distributed over all shards, i.e. that each shard contains the same number of batches. In addition, following other implementations, this PR adds shard_strategy support to NLP/PC tarred datasets.
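
A minimal sketch of this length estimate (illustrative names and numbers, not the exact code in this PR): scale the global batch count by the fraction of shards that this rank configuration actually keeps.

import math

def scattered_len(total_num_batches, total_num_shards, world_size):
    # Each rank keeps floor(total_num_shards / world_size) shards; assuming batches
    # are spread evenly over shards, scale the global batch count accordingly.
    shards_per_rank = total_num_shards // world_size
    return math.floor(total_num_batches * shards_per_rank / total_num_shards)

# e.g. 146 shards with 1000 batches each, scattered over 8 workers
print(scattered_len(total_num_batches=146000, total_num_shards=146, world_size=8))  # 18000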

There are two additional places where tarred datasets are used:

if shard_strategy == 'scatter':

which, based on the implementation, does not seem to be affected by this issue, and
if shard_strategy == 'scatter':
    logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
    if len(text_tar_filepaths) % world_size != 0:
        logging.warning(
            f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
            f"by number of distributed workers ({world_size})."
        )
    begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
    end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
    text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
    logging.info(
        "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
    )

which might be affected, but I chose to leave it as is, as its computation of len()
def __len__(self):
    return (self.metadata['num_text'] - self.max_seq_length) // self.batch_step

greatly differs from the others.

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs to various areas.
Note to @PeganovAnton for NLP/PC.

Additional Information

Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>
@MaximumEntropy MaximumEntropy left a comment
Contributor

Nice find. Thanks for fixing this! Looks good to me for NMT and Language Modeling, will let @ekmb and @PeganovAnton review for Punctuation and Capitalization.

@MaximumEntropy MaximumEntropy requested a review from ekmb July 15, 2022 18:28
@PeganovAnton PeganovAnton added the bug Something isn't working label Jul 16, 2022
@PeganovAnton PeganovAnton self-requested a review July 16, 2022 08:19
@PeganovAnton
Contributor

Great fix! Reviewed punctuation part.

@itzsimpl
Contributor Author

itzsimpl commented Jul 16, 2022

While running NLP/PC on 8 GPUs I've noticed that the lr_scheduler.max_steps computation is still off. Indeed, the implementation of compute_max_steps assumes num_samples is the entire dataset size, not the per-worker size; see L866.

def compute_max_steps(
    max_epochs, accumulate_grad_batches, limit_train_batches, num_workers, num_samples, batch_size, drop_last
):
    _round = math.floor if drop_last else math.ceil

    sampler_num_samples = math.ceil(num_samples / max(1, num_workers))

    if drop_last and num_workers > 1:
        logging.warning(
            "Please note that drop_last is broken in pytorch 1.6.0. We will fix when pytorch 1.7.0 is released"
        )
        # TODO: Master version, not in pytorch 1.6.0
        # sampler_num_samples = math.ceil((num_samples - num_workers)/ num_workers)

    steps_per_epoch = _round(sampler_num_samples / batch_size)
    if isinstance(limit_train_batches, int) or limit_train_batches == 0.0:
        steps_per_epoch = min(steps_per_epoch, int(limit_train_batches))
    elif steps_per_epoch != float('inf'):
        # limit_train_batches is a percentage of batches per epoch
        steps_per_epoch = int(steps_per_epoch * limit_train_batches)

    return math.ceil(steps_per_epoch / accumulate_grad_batches) * max_epochs

But DataLoader.len() will, in the case of tarred datasets, return num_samples per worker, which, as per the line

num_samples = len(train_dataloader.dataset)

will lead to a faulty computation of max_steps, i.e. instead of max_epochs it will optimise for max_epochs/num_workers.

Since prior to this PR DataLoader.len() returned the per-worker count, but its computation was faulty, I would like your advice on what the correct solution is.

Multiplying num_samples by num_workers will probably lead to faulty behaviour with non-tarred datasets, I guess. I can make DataLoader.len() return the correct total number of samples (not per worker), but I do not know whether anywhere in the code DataLoader.len() is expected to return the number of samples per worker.
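
To see the effect with concrete (assumed) numbers, here is a sketch that mirrors the compute_max_steps arithmetic quoted above for drop_last=False, limit_train_batches=1.0 and accumulate_grad_batches=1; the sample counts reuse the 146-shard example discussed further below.

import math

max_epochs, num_workers, batch_size = 10, 8, 1  # assumed values for illustration

def estimated_max_steps(num_samples):
    # Same arithmetic as compute_max_steps above, specialised to drop_last=False,
    # limit_train_batches=1.0 and accumulate_grad_batches=1.
    sampler_num_samples = math.ceil(num_samples / max(1, num_workers))
    steps_per_epoch = math.ceil(sampler_num_samples / batch_size)
    return steps_per_epoch * max_epochs

print(estimated_max_steps(18000))   # 22500: a per-worker len() gets divided by num_workers a second time
print(estimated_max_steps(144000))  # 180000: the actual total len() gives the expected ~18000 steps/epoch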

@itzsimpl
Contributor Author

@titu1994 I would like your thoughts on what DataLoader.len() should return in the case of scattered tarred datasets.

@titu1994
Collaborator

ASR tarred datasets always return the total number of samples in the manifest - https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/data/audio_to_text.py#L765

For tarred datasets we then explicitly limit the train dataset to a fixed number of steps per epoch - https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py#L420-L433

@itzsimpl
Contributor Author

itzsimpl commented Jul 16, 2022

Thanks @titu1994.

Note though that when the number of shards is not divisible by world_size, the scatter shard strategy will not use all the data: some shards (the non-divisible remainder, i.e. at most the last world_size - 1 shards) will be dropped. Hence the actual number of samples will not be equal to the total number of samples in the manifest, i.e. num_shards_per_rank * shard_size * world_size != len(self.manifest_processor.collection), and the reported length will be invalid even in the case of ASR tarred datasets. This could lead to the same problems as I've observed in NLP/PC.

if shard_strategy == 'scatter':
    logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
    if len(audio_tar_filepaths) % world_size != 0:
        logging.warning(
            f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
            f"by number of distributed workers ({world_size})."
        )
    begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
    end_idx = begin_idx + len(audio_tar_filepaths) // world_size
    audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
    logging.info(
        "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
    )

Assume you have 146 tarred files, each containing 1000 samples, and a world_size of 8. Then:

  • len(self.manifest_processor.collection) = 146000
  • each rank works with floor(146/8) = 18 shards, or 18000 samples, so the total actual number of samples equals 18*8*1000 = 144000
  • assuming self._trainer.limit_train_batches=1.0 and train_data_config['batch_size']=1, the expression ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) will give you 18250, whereas it should be 18000 (see the arithmetic sketch after this list).
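
The following sketch simply replays that arithmetic in plain Python (variable names are illustrative, not NeMo code):

import math

num_shards, shard_size, world_size, batch_size = 146, 1000, 8, 1

manifest_len = num_shards * shard_size                      # 146000 samples listed in the manifest
shards_per_rank = num_shards // world_size                  # 18 shards per rank (2 shards dropped)
actual_samples = shards_per_rank * world_size * shard_size  # 144000 samples actually used

steps_from_manifest = math.ceil((manifest_len / world_size) / batch_size)  # 18250
steps_actually_seen = (shards_per_rank * shard_size) // batch_size         # 18000

print(manifest_len, actual_samples, steps_from_manifest, steps_actually_seen)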

@titu1994
Collaborator

In practice it doesn't matter, we shuffle tarfiles before each draw of samples anyway.

@itzsimpl
Contributor Author

@titu1994 Correct me if I'm wrong, but based on the way it is implemented, even in ASR the shuffling occurs after the WebDataset is built from the scattered tar files, i.e. from a subset of all files. I don't see any shuffling of audio_tar_filepaths prior to expand_audio_filepaths, which does the scattering.

Hence, if my understanding is correct, in the above case of 8 GPUs and 146 tar files with 1000 samples each, only the first 144 of the 146 tar files will ever be used. A larger number of samples in limit_train_batches (due to total number of samples / world_size) will result in oversampling the shuffled tar files that were assigned to a specific rank. The 'replicate' strategy will, on the other hand, take samples from all 146 tar files.

What I am seeing in NLP/PC is that in the above case:

  • Prior to this PR, validation at the end of the epoch was not run, due to an invalid number of samples returned by DataLoader.len().
  • If DataLoader.len() returns the correct number of samples a rank has, the number of steps per epoch will be correct, but lr_scheduler will compute an invalid number of effective maximum steps; with num_epochs set, a simple linear lr_scheduler will end sooner than expected, at num_epochs/world_size, and more complex schedulers may hide this problem.
  • If I return the actual total length of the dataset, lr_scheduler will compute the correct number of effective maximum steps, but the progress bars will show a wrong number of samples. Assuming only one validation per epoch, validation will run at approximately 1/world_size of the epoch and the epoch will end immediately afterwards.

I can:

  1. Make DataLoader.len() return the total number of samples in the manifest and modify limit_train_batches in setup_training_data when the data is tarred, like ASR does. When the number of shards is not divisible by world_size and the shard strategy is scatter, this will, as mentioned earlier, lead to oversampling of a subset of all data; the only indication that not all data was used will be the shard strategy warning in the logs.
  2. Make DataLoader.len() return the actual total number of samples used and modify limit_train_batches in setup_training_data when the data is tarred, like ASR does. This avoids oversampling, but since not all data is used when the number of shards is not divisible by world_size and the shard strategy is scatter, the total number of steps will be correct yet different from the case when the shard strategy is replicate or non-tarred data is used. The fact that not all data was used will, besides the shard strategy warning in the logs, also be reflected in the total number of steps (a rough sketch of this option follows below).

Let me know which solution you prefer. Personally, I would opt for the second, especially since some collections use tarred datasets also for validation, where exclusion of data may be acceptable, but oversampling may be less so.
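
A rough sketch of what the second option could look like, with illustrative class and variable names rather than the actual NeMo classes:

import math

class ScatteredTarredDatasetSketch:
    """Toy stand-in for a tarred dataset using shard_strategy='scatter'."""

    def __init__(self, num_shards, shard_size, world_size):
        self.world_size = world_size
        self.shard_size = shard_size
        self.shards_per_rank = num_shards // world_size  # remainder shards are dropped

    def __len__(self):
        # Actual total across all ranks, excluding the dropped remainder shards.
        return self.shards_per_rank * self.world_size * self.shard_size

ds = ScatteredTarredDatasetSketch(num_shards=146, shard_size=1000, world_size=8)
batch_size = 1
# setup_training_data would then cap the number of train batches to the per-rank step count:
limit_train_batches = math.ceil(len(ds) / ds.world_size / batch_size)  # 18000
print(len(ds), limit_train_batches)  # 144000 18000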

PeganovAnton previously approved these changes Jul 17, 2022
@titu1994
Collaborator

Hence, if my understanding is correct, in the above case of 8 GPUs and 146 tar files with 1000 samples each, only the first 144 of the 146 tar files will ever be used.

No, WebDataset shuffles the tarfile list every time iter is called, which is after every epoch - https://github.com/webdataset/webdataset/blob/0.1.65/webdataset/dataset.py#L118

titu1994 previously approved these changes Jul 18, 2022
@itzsimpl
Contributor Author

@titu1994 The DCO check required me to do a rebase, which I followed to the letter; I hope this didn't mess things up.

@MaximumEntropy
Contributor

DCO Rebase always messes things up. I think you need to git checkout upstream/main -- untouched_file.py to reset all the files that you didn't touch.

@itzsimpl
Contributor Author

I'm a bit lost here; I don't have much experience with rebasing.

What I'm seeing is that under the commits section here there are a bunch of commits from the main branch that are now signed off also by me. I don't think doing checkouts of files will unsign those commits. I assume this happened because, while working on this PR, main was merged in at some point (or multiple points) to keep the PR updated, and doing git rebase HEAD~14 --signoff followed by git push --force-with-lease origin fix/tarred_dataset_len, as instructed by the DCO, signed those commits.

I have a backup of the repo prior to doing the rebase, if this helps.

Please advise on what I should do, or let me know if the status is acceptable the way it is.

@titu1994
Collaborator

Skip dco for now, we will override it.

@titu1994
Collaborator

To fix this, ignore the DCO bot and rebase with the main branch using an IDE in interactive mode (there shouldn't be any conflicts), and finally force push. DCO messed up here and made you sign commits you didn't touch.

@MaximumEntropy
Contributor

The way it is right now is definitely not what it's supposed to be before merging; reverting back to what you had before the DCO-related commits might be the best thing to do, and we can set DCO to pass.

@titu1994
Collaborator

Agreed, let's get the PR back to its original state and override DCO.

@itzsimpl
Contributor Author

OK. This should be it then.

@PeganovAnton
Contributor

@titu1994 should I wait until the DCO is resolved?

@MaximumEntropy
Contributor

I just set DCO to pass. @PeganovAnton please take a final look.

Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>
Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>
@itzsimpl
Contributor Author

@MaximumEntropy you'll probably have to once again set DCO to pass.

@itzsimpl itzsimpl requested a review from titu1994 July 23, 2022 11:24
@itzsimpl
Contributor Author

@titu1994 @MaximumEntropy I guess one of you will need to set DCO to pass. If I follow its instructions and rebase it will mess things up.

@PeganovAnton PeganovAnton merged commit 7890979 into NVIDIA:main Jul 26, 2022
Davood-M pushed a commit to Davood-M/NeMo that referenced this pull request Aug 9, 2022
…VIDIA#4553)

* fix tarred dataset len when num shards is not divisible by workers

Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>

* update error reporting on invalid `shard_strategy`

* update NLP/PC tarred dataset docstring

* add `shard_strategy` to NLP/PC `@dataclass`

* update NLP/PC tarred dataset docstring

* add `shard_strategy` to NLP/PC docs

* revert test with Dataloader retruning the actual data length

* make dataloader return actual num of samples, set `limit_train_baches` on `setup_*`

* update `shard_strategy` docstrings

Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>

* update `tarred_dataset` documentation

Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>

* fix style

* update documentation

Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>

* updated docstrings

Signed-off-by: Iztok Lebar Bajec <ilb@fri.uni-lj.si>

Co-authored-by: PeganovAnton <peganoff2@mail.ru>
Signed-off-by: David Mosallanezhad <dmosallanezh@nvidia.com>
hainan-xv pushed a commit to hainan-xv/NeMo that referenced this pull request Nov 29, 2022
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
Labels
bug Something isn't working
Development

Successfully merging this pull request may close these issues:

[NLP] punctuation_capitalization resume training fails