
warnings: resuming before epoch end is absolutely normal for long trainings #18780

Open · stas00 opened this issue Oct 11, 2023 · 6 comments
Labels: data handling (Generic data-related topic), feature (Is an improvement or enhancement)

stas00 (Contributor) commented Oct 11, 2023

Description & Motivation

Forking from #18723 (comment), where we were discussing various warnings that don't necessarily apply to everyone.

This issue discusses this warning:

[...]python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:151: UserWarning: You're 
resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if 
further training is done. Consider using an end-of-epoch checkpoint

Many shared SLURM environments impose a relatively short time limit on each job, so one can't complete a single epoch without restarting and resuming; e.g. some allow only 20h at most.

In "can cause unreliable results" are you perhaps implying that there is no guarantee the DL will not continue from where it left off on saving the last checkpoint but will repeat the same data? Shouldn't PTL save the worker RNG state and correctly restore it on resume? Though with a custom DL there is no way PTL could easily do that.

But in general a 3-month training will take many restarts, not only because of a short SLURM job limit, but also because there will be divergences requiring rollbacks, which again means restarts.

And yes, the operator needs to be super-aware of whether the resume breaks the unique flow of samples and leads to samples being repeated.
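
To make the question concrete, here is a minimal sketch of what "saving the sampler RNG state with the checkpoint" could look like from user land. It assumes a map-style dataset and an explicit generator feeding the sampler; the hooks on_save_checkpoint / on_load_checkpoint are real Lightning hooks, but the class and checkpoint-key names are made up for illustration:

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset


class RNGCheckpointedModule(pl.LightningModule):
    """Toy module; the relevant parts are the two checkpoint hooks at the bottom."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)
        # explicit generator that drives the sampler's shuffling
        self.sampler_generator = torch.Generator().manual_seed(42)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(1000, 8))
        sampler = RandomSampler(dataset, generator=self.sampler_generator)
        return DataLoader(dataset, sampler=sampler, batch_size=32)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def on_save_checkpoint(self, checkpoint):
        # persist the sampler's generator state alongside the model/trainer state
        checkpoint["sampler_rng_state"] = self.sampler_generator.get_state()

    def on_load_checkpoint(self, checkpoint):
        if "sampler_rng_state" in checkpoint:
            self.sampler_generator.set_state(checkpoint["sampler_rng_state"])
```

Note that restoring the generator only makes the shuffling reproducible across restarts; by itself it does not skip the batches already consumed in the interrupted epoch, which is exactly the duplication being discussed here.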

cc @Borda @justusschock @awaelchli

@stas00 stas00 added feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers labels Oct 11, 2023
awaelchli (Member) commented Oct 11, 2023

In "can cause unreliable results" are you perhaps implying that there is no guarantee the DL will not continue from where it left off on saving the last checkpoint but will repeat the same data? Shouldn't PTL save the worker RNG state and correctly restore it on resume? Though we custom DL there is no way PTL could easily do that.

Yes, exactly, that's what the warning is trying to say. I often struggle to explain to users that resuming mid-epoch is highly non-trivial. To many users it seems so obvious that Lightning should "support" this, or they assume that Lightning already does, without even questioning what should happen to the random state. The surprise is that the results can be skewed because, due to the restart, the network sees some data more often than the rest.

We spent quite some time (two entire releases) developing a fault-tolerant system, but it never left the experimental state because of several challenges. Capturing the random state in workers was possible, but very costly, and it had a load of edge cases to handle, with lots of caveats around IterableDataset. Even with a limited scope, the complexity became unmanageable.

We ultimately decided to drop the effort of making dataloaders stateful and resumable, and instead only handle the loop state and trainer state. We hoped that eventually DataLoader2 / torchdata would put the necessary building blocks in place to make data pipes serializable, but now that their development has stopped, this won't be possible. For now, we say that Lightning guarantees that the trainer and loop state are managed in a fault-tolerant way, but the data is not and is up to the user.
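
To illustrate what "making dataloaders stateful" would involve, here is a bare-bones sketch of a resumable sampler. This is illustrative only, not Lightning API; the state_dict / load_state_dict names are just a convention, and the scheme already breaks with num_workers > 0, where the sampler is iterated inside worker processes, which hints at the edge cases mentioned above:

```python
import torch
from torch.utils.data import Sampler


class ResumableRandomSampler(Sampler):
    """Random sampler that knows how far into the current epoch it got."""

    def __init__(self, data_source, seed=0):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0
        self.consumed = 0  # indices already handed out in the current epoch

    def __iter__(self):
        g = torch.Generator().manual_seed(self.seed + self.epoch)
        perm = torch.randperm(len(self.data_source), generator=g).tolist()
        for idx in perm[self.consumed:]:  # skip what an interrupted run already used
            self.consumed += 1
            yield idx
        self.epoch += 1
        self.consumed = 0

    def __len__(self):
        return len(self.data_source)

    def state_dict(self):
        return {"seed": self.seed, "epoch": self.epoch, "consumed": self.consumed}

    def load_state_dict(self, state):
        self.seed = state["seed"]
        self.epoch = state["epoch"]
        self.consumed = state["consumed"]
```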

Back to the warning:

You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if
further training is done. Consider using an end-of-epoch checkpoint

I was thinking that we already have an info message when resuming a checkpoint:

Restored all states from the checkpoint at checkpoints/epoch=0-step=500.ckpt

and we could possibly mention in that message that the checkpoint is a mid-epoch checkpoint, essentially folding the warning into it. For example:

Restored all states from the mid-epoch checkpoint at checkpoints/epoch=0-step=500.ckpt
DataLoader sampler state will be reset

(I don't know yet how to best word this in the message, this is just a quick draft)

cc @carmocca

@awaelchli awaelchli added data handling Generic data-related topic and removed needs triage Waiting to be triaged by maintainers labels Oct 11, 2023
stas00 (Contributor, Author) commented Oct 11, 2023

OK, so this is actually not just an advisory warning that's relevant to some users; it's a problem that all users need to be made aware of by all means. In that case I'd raise an exception on resume, explain what happens, and give the user an I-know-what-I-am-doing flag to suppress the exception. The current warning will be ignored by many (most?) users and unexpected problems will follow, since most users will blindly believe that PL has it figured out for them. This is too important not to stop the presses, IMHO.

Somehow I thought we had this right in the HF Trainer w.r.t. RNG restoration in the DL, but it's been a long time since I worked on it, so my memory is foggy. Perhaps at that time we really only dealt with generic situations. But I totally hear you that in some cases it's easy, whereas in other cases the user should be taking care of it.

Perhaps:

  1. the warning should point to the specific problem that may cause the unreliable results, i.e. state explicitly that the RNG state of the DL sampler isn't restored to its pre-checkpoint-saving state and that data will be duplicated before the epoch ends;
  2. it should then ideally point the user to documentation on how they can fix it;
  3. finally, it should ideally give a user who has fixed it a way to say "I fixed it" and stop getting the no-longer-applicable warning, since a new team member will always bring it up and ask whether it is a problem (a stopgap sketch follows below).

And as I suggested in the first paragraph, I think this situation warrants raising an exception.
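
For point 3, until such a flag exists, one stopgap a user who has made their data pipeline resumable could use today is the standard-library warnings machinery; a minimal sketch (the message pattern is copied from the warning quoted above):

```python
import warnings

# Silence only this specific Lightning warning; the pattern must match
# the start of the warning message.
warnings.filterwarnings(
    "ignore",
    message=r".*resuming from a checkpoint that ended before the epoch ended.*",
)
```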

carmocca (Contributor) commented:

These are all good suggestions. The warning is really old and should be updated to describe these limitations. AFAIK the main missing pieces would be to implement #17105 and to implement loading/reloading of self.logged metrics (this was partially implemented before #16516).

@carmocca carmocca added this to the future milestone Oct 24, 2023
RuABraun commented:
Does the checkpoint save the number of batches that were seen in the current epoch? I'm thinking about how to resume from a mid-epoch checkpoint, and I think one could just iterate through the batches until batch_index > saved_index.

BTW, fairseq has this capability built in.
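
A rough sketch of that fast-forward idea (not Lightning's actual resume logic; saved_batch_index is assumed to come from the checkpoint, and the order is only reproduced if the sampler is seeded identically across runs):

```python
import itertools


def fast_forward(dataloader, saved_batch_index):
    """Skip the batches an interrupted run already trained on and return the iterator."""
    it = iter(dataloader)
    for _ in itertools.islice(it, saved_batch_index):
        pass  # consume and discard the already-seen batches
    return it  # training continues from the next batch
```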

stas00 (Contributor, Author) commented Nov 23, 2023

Yes, that's what the HF Trainer does, and if I remember correctly Megatron-LM does as well.

But this only works well if you have a simple data sampler, ideally over already-preprocessed data. If you use a complicated one that requires a lot of real-time processing, such fast-forwarding could be extremely slow, so any transformations would probably need to be disabled during it.

Additionally, if the dataset is remote and a webdataset-like DL is used, this again isn't quite doable, since you would potentially have to re-download many chunks of data from remote storage.

In these complicated cases, keeping track of the RNG states and restoring them is a better solution, albeit handling remote storage can still be a problem.
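
For reference, the "track the RNG states" alternative could start from something like the following sketch. It only covers the three global RNGs; per-worker and custom generators (the hard part) would need the same treatment:

```python
import random

import numpy as np
import torch


def collect_rng_states():
    # snapshot every global RNG the data pipeline might touch
    return {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
    }


def restore_rng_states(states):
    random.setstate(states["python"])
    np.random.set_state(states["numpy"])
    torch.set_rng_state(states["torch"])
```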

albertz commented Jan 27, 2024

In RETURNN, we also have this capability. More specifically, we operate on sub-epochs. The user specifies the number of random partitions of a dataset. E.g. for Librispeech, we use 20, so each sub-epoch covers around 100h of audio. Once a full epoch is finished, the partitioning is redone.

In our case, we shuffle the sequences in advance for the whole dataset for each whole epoch, and then partition it evenly into the sub-epochs. This approach might not scale well for very large corpora though, as you need to operate on the list of sequences after every epoch, which might be too large to handle. (For all our research corpora, it was not a problem so far.)

All our dataset logic, including any on-the-fly processing (e.g. augmentation), uses a shared RNG, which is seeded at the beginning of every sub-epoch. This ensures that we can safely recover after a sub-epoch.

Shuffling the sequences can also be done on-the-fly, so I think this approach can still scale to even much larger corpora.

Maybe such an approach could be interesting here as well.

But if you can properly serialize the RNG state of any data loader iterator and any other data sampler in between, then you can also recover the state after every sequence or mini batch. The approach in RETURNN does not need to serialize the RNG state, though, so it's a bit simpler to implement.
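
A rough sketch of the sub-epoch scheme described above (illustrative, not RETURNN code): the full sequence order is shuffled deterministically per epoch and split into near-equal partitions, so resuming at sub-epoch k of epoch e only requires knowing (e, k) and reseeding the processing RNG for that sub-epoch:

```python
import random


def make_sub_epochs(num_sequences, num_partitions, epoch):
    """Deterministically shuffle the sequence indices for `epoch` and split them evenly."""
    order = list(range(num_sequences))
    random.Random(epoch).shuffle(order)  # same shuffle for the same epoch on every restart
    size, rem = divmod(len(order), num_partitions)
    parts, start = [], 0
    for i in range(num_partitions):
        end = start + size + (1 if i < rem else 0)
        parts.append(order[start:end])
        start = end
    return parts
```

Resuming then just means rebuilding the partitions for the stored epoch and continuing from the stored sub-epoch index.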
