Add checkpointing and automatic restart logic to VMC/TDVP drivers #1237
Comments
The checkpointing should work together with the logging in some way. Using the information from the logger, saving the checkpoint as part of the log at certain intervals (and then adding code for resuming from that checkpoint) would be the best way to implement it, in my view. Do you agree, @PhilipVinc?
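(As an illustration of "saving the checkpoint at certain intervals alongside the log", here is a minimal sketch of a NetKet run-callback; the `save_every` interval, the `checkpoint.mpack` filename, and the use of `flax.serialization` as the backend are all assumptions, not anything decided in this thread:)

```python
import flax.serialization


def make_checkpoint_callback(save_every=100, path="checkpoint.mpack"):
    """Build a NetKet run-callback that dumps the variational state
    every `save_every` steps alongside the normal logging."""

    def callback(step, log_data, driver):
        if step % save_every == 0:
            # `driver.state` is the driver's variational state; serialise
            # its variables (parameters + model state) to msgpack bytes.
            with open(path, "wb") as f:
                f.write(flax.serialization.to_bytes(driver.state.variables))
        return True  # returning True keeps the optimisation running

    return callback


# usage, assuming a driver and logger already exist:
# driver.run(n_iter=1000, out=logger, callback=make_checkpoint_callback())
```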
I've implemented something along these lines in my scripts, so not within the driver class, but with similar logic: if a checkpoint is available in the current working directory, load it and restore its content; otherwise create a new one after a certain number of iterations. I've used the checkpointing functions provided by flax. In code, I do something like this before initializing the driver:
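(The original snippet was not preserved in this thread; below is a rough reconstruction of the restore-or-start logic described, using `flax.training.checkpoints`. The directory name, the example system, and the step bookkeeping are assumptions made for illustration:)

```python
import os

import netket as nk
from flax.training import checkpoints

CKPT_DIR = os.path.abspath("checkpoints")  # hypothetical location

# A small example system, just so the sketch is self-contained.
hilbert = nk.hilbert.Spin(s=1 / 2, N=8)
model = nk.models.RBM(alpha=1)
sampler = nk.sampler.MetropolisLocal(hilbert)
vstate = nk.vqs.MCState(sampler, model, n_samples=512)

# If a checkpoint is available in the working directory, load it and
# restore its content; otherwise keep the fresh initialisation
# (restore_checkpoint returns `target` unchanged when nothing is found).
vstate.parameters = checkpoints.restore_checkpoint(CKPT_DIR, target=vstate.parameters)

# ... now initialise the driver with the (possibly restored) vstate,
# and inside the run, every so often:
#     checkpoints.save_checkpoint(CKPT_DIR, vstate.parameters, step=step, keep=3)
```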
One catch with using the functions provided by flax is that they depend on some tensorflow library, so that's probably not desired. However, as suggested by @femtobit, having a single-file logger with append mode already provides most of what is needed: it would only be necessary to write a function that takes a path as an argument, checks whether there is a logfile with checkpoints stored, and if so, loads the last one. For some cases (e.g. optimizing with Adam), one would also need to save the state of the optimizer in the checkpoint in order to properly restart a run. Maybe it would be worth defining a …
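(A sketch of the restore function described above; all names here are hypothetical, and the on-disk format is assumed to be a single msgpack dump, written with `flax.serialization`, of a dict holding both the variational-state variables and the optimiser state:)

```python
import os

import flax.serialization


def maybe_restore(path, vstate, opt_state):
    """If `path` contains a previously written checkpoint, load the last
    one and return the restored (variables, opt_state); otherwise return
    the inputs unchanged."""
    if not os.path.exists(path):
        return vstate.variables, opt_state
    with open(path, "rb") as f:
        data = f.read()
    # The templates tell from_bytes which pytree structure to fill in.
    target = {"variables": vstate.variables, "opt_state": opt_state}
    restored = flax.serialization.from_bytes(target, data)
    return restored["variables"], restored["opt_state"]
```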
One problem that would arise with checkpointing is how to handle MPI training: our variational states currently store a per-rank sampler state with a per-rank key. If we are running under MPI, what shall we do? Save a per-rank variational state? That would be wasteful, by the way, because it would save the parameters multiple times. As for what @maxbortone said, I gave a look at …
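(One possible compromise for the MPI question, sketched here with mpi4py — the file layout and helper name are made up: write the replicated parameters once from rank 0, and only the small per-rank sampler state from every rank:)

```python
import flax.serialization
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()


def save_checkpoint_mpi(vstate, prefix="ckpt"):
    # Parameters are replicated across ranks, so writing them once is enough.
    if rank == 0:
        with open(f"{prefix}_params.mpack", "wb") as f:
            f.write(flax.serialization.to_bytes(vstate.parameters))
    # The sampler state (chains + RNG key) differs per rank and is small,
    # so each rank writes its own copy.
    with open(f"{prefix}_sampler_{rank}.mpack", "wb") as f:
        f.write(flax.serialization.to_bytes(vstate.sampler_state))
    comm.Barrier()  # make sure all files exist before anyone proceeds
```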
@PhilipVinc the only problem with …
Let's see if I can …
So, just by asking, the PR was approved, and I guess in due time it will show up in a release. The next step for netket would then be to actually make sure that everything works. Most likely this means adding … In general these things work out nicely if you have an 'append by default' system that works, but most of our loggers don't support appending. Maybe we should fix that as well. Also, a complication is given by the fact that... how do we de-serialize the sampling state if one is running under MPI? We should serialise N_rank different states, but there's a lot of redundant information there...
It would be great to add a flag to drivers, like `checkpoint=True`, which automatically serialises the state of the optimiser, the variational state, and everything that's needed while the simulation runs. Moreover, if that flag is specified and a simulation is run with the same parameters as a previous run, the driver would automatically load the last checkpoint and restart from there.
API-wise, I think this flag should be passed to the constructor of a driver.
We should design the API a bit and decide how to handle reloading of loggers (e.g., to support resuming of JSON logging, we would need to implement deserialisation of our JSON outputs...).
But it would be a great addition.
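(For concreteness, a hypothetical usage of the proposed flag; nothing below exists in NetKet yet, and `checkpoint=True` is precisely the API this issue proposes:)

```python
import netket as nk

hilbert = nk.hilbert.Spin(s=1 / 2, N=8)
hamiltonian = nk.operator.Ising(hilbert, nk.graph.Chain(8), h=1.0)
sampler = nk.sampler.MetropolisLocal(hilbert)
vstate = nk.vqs.MCState(sampler, nk.models.RBM(alpha=1))
optimizer = nk.optimizer.Sgd(learning_rate=0.01)

driver = nk.driver.VMC(
    hamiltonian,
    optimizer,
    variational_state=vstate,
    checkpoint=True,  # proposed flag: does not exist yet
)

# On a fresh run this would start from scratch; if an earlier run with the
# same parameters left a checkpoint behind, the driver would restore it,
# resume from there, and re-open the logger in append mode.
driver.run(n_iter=1000, out="log")
```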