Add checkpointing and automatic restart logic to VMC/TDVP drivers #1237

Open
PhilipVinc opened this issue Jun 10, 2022 · 6 comments
Labels: contributor welcome, enhancement

Comments

PhilipVinc (Member) commented Jun 10, 2022

It would be great to add a flag to drivers, like ...checkpoint=True, which automatically serialises the state of the optimiser, the variational state, and everything else that's needed while the simulation runs.

Moreover, if that flag is specified, and a simulation is run with the same parameters as a previous run, the driver would automatically load the last checkpoint and restart from there.

API-wise, I think this flag should be passed to the constructor of the driver.

class AbstractVariationalDriver:
    def __init__(self, .........., checkpoint: bool = False):
        if checkpoint:
            if exists(checkpoint_data):
                self.load_checkpoint(..)
            else:
                self.setup_checkpoint(..)
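
A minimal sketch of what save_checkpoint / load_checkpoint could do, reusing flax.serialization for the heavy lifting. The checkpoint_path argument and the attributes pulled from the driver (step_count, state.variables, the optimizer state) are assumptions for illustration, not an agreed API:

from flax import serialization

def save_checkpoint(self, checkpoint_path):
    # Bundle everything needed to resume a run into one pytree.
    bundle = {
        "step": self.step_count,             # assumed driver attribute
        "variables": self.state.variables,   # parameters (+ model state)
        "optimizer": self._optimizer_state,  # assumed internal attribute
    }
    with open(checkpoint_path, "wb") as f:
        f.write(serialization.to_bytes(bundle))

def load_checkpoint(self, checkpoint_path):
    # Restore into a template with the same structure as the saved bundle.
    template = {
        "step": self.step_count,
        "variables": self.state.variables,
        "optimizer": self._optimizer_state,
    }
    with open(checkpoint_path, "rb") as f:
        bundle = serialization.from_bytes(template, f.read())
    self.state.variables = bundle["variables"]   # assumes a variables setter
    self._optimizer_state = bundle["optimizer"]
    self._step_count = bundle["step"]            # assumed internal attribute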

We should spend some time designing the API and decide how to handle reloading of loggers (e.g., to support resuming JSON logging, we would need to implement deserialisation of our JSON outputs...).

But it would be a great addition.

PhilipVinc added the enhancement and contributor welcome labels on Jun 10, 2022
femtobit (Collaborator) commented Jun 10, 2022

Checkpointing should work together with logging in some way. The information from HDF5Log with saving of intermediate parameters enabled already provides about 80% of the data required for a checkpoint (the other loggers do too, but in a more scattered fashion), and storing additional data such as the sampler state is also supported by the logging interface.

So saving the checkpoint as part of the log at certain intervals (and then adding code for resuming from that checkpoint) would be the best way to implement it in my view. Do you agree, @PhilipVinc?
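
As a rough sketch of how "saving the checkpoint at certain intervals" could look from the user's side, here is a run callback that calls a hypothetical save_checkpoint helper every few steps; this assumes the (step, log_data, driver) callback signature used by the drivers' run method:

class CheckpointCallback:
    """Save a checkpoint every `every` optimisation steps (sketch)."""

    def __init__(self, path, every=100):
        self.path = path
        self.every = every

    def __call__(self, step, log_data, driver):
        if step % self.every == 0:
            save_checkpoint(driver, self.path)  # hypothetical helper, see above
        return True  # True means: keep running

# usage sketch:
# vmc.run(n_iter=1000, out=logger, callback=CheckpointCallback("checkpoint.mpack"))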

maxbortone (Contributor) commented

I've implemented something along these lines in my scripts, so not within the driver class, but with similar logic: if a checkpoint is available in the current working directory, load it and restore its contents, otherwise create a new one after a certain number of iterations. I've used the functions provided by the flax.training.checkpoints package to save and restore checkpoints.

In code, I do something like this before initializing the driver:

import netket as nk
from flax.training.checkpoints import restore_checkpoint

# vs: variational state, op: optimiser, ha: Hamiltonian, sr: preconditioner (defined earlier in the script)
parameters = vs.parameters.unfreeze()
opt_state = op.init(parameters)
initial_step = 1
# if a checkpoint exists in workdir, its contents replace the defaults above
opt_state, parameters, initial_step = restore_checkpoint(workdir, (opt_state, parameters, initial_step))
vs.parameters = parameters
vmc = nk.driver.VMC(ha, op, variational_state=vs, preconditioner=sr)

One catch with the functions provided by flax is that they depend on a TensorFlow library, so that's probably not desired. However, as suggested by @femtobit, a single-file logger with append mode already provides most of what is needed; it would only be necessary to write a function that takes a path as an argument, checks whether there is a logfile with checkpoints stored, and if so, loads the last one.

In some cases (e.g. optimizing with Adam), one also needs to save the state of the optimizer in the checkpoint in order to properly restart a run. It might be worth defining a Checkpoint dataclass, which stores the step, model parameters and optimizer state, and saving that in the HDF5 logfile in a checkpoint group.
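
A minimal sketch of such a Checkpoint container, assuming flax.struct.dataclass so that it serialises like any other pytree; the field names and the load_latest_checkpoint helper are illustrative, not a fixed design:

import os
from flax import struct, serialization

@struct.dataclass
class Checkpoint:
    step: int
    parameters: dict   # model parameters
    opt_state: tuple   # optimiser state (e.g. from optax)

def load_latest_checkpoint(path, template: Checkpoint) -> Checkpoint:
    # If a checkpoint file exists at `path`, restore it into `template`;
    # otherwise return the freshly-initialised template unchanged.
    if not os.path.exists(path):
        return template
    with open(path, "rb") as f:
        return serialization.from_bytes(template, f.read())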

PhilipVinc (Member, Author) commented

One problem that would arise with checkpointing is how to handle MPI training: our variational states right now store a per-rank sampler state with a per-rank key. If we are running under MPI, what should we do? Save a per-rank variational state (which would be wasteful, by the way, because it would save the parameters multiple times)?

As for what @maxbortone said, I took a look at flax.training.checkpoints; it's quite well written and would answer our needs. Actually, if we declared how flax can serialise/deserialise the driver, we'd essentially be done already..
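
To make that last point concrete, here is a rough sketch of declaring driver serialisation via flax.serialization.register_serialization_state. The attributes read from the driver (step_count, state.variables, _optimizer_state, _step_count) are guesses at internals for illustration only:

from flax import serialization
from netket.driver import VMC

def _vmc_to_state_dict(driver):
    # Everything needed to resume a run, as a plain dict of pytrees.
    return {
        "step": driver.step_count,
        "variables": driver.state.variables,
        "optimizer": driver._optimizer_state,
    }

def _vmc_from_state_dict(driver, state):
    # Restore in place and return the driver, as flax expects.
    driver.state.variables = state["variables"]
    driver._optimizer_state = state["optimizer"]
    driver._step_count = state["step"]
    return driver

serialization.register_serialization_state(VMC, _vmc_to_state_dict, _vmc_from_state_dict)

# After this, serialization.to_bytes(vmc) / serialization.from_bytes(vmc, data) work on the driver.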

maxbortone (Contributor) commented

@PhilipVinc the only problem with flax.training.checkpoints is that the package currently depends on tensorflow.io.gfile, an I/O package that supports different filesystems. I've circumvented this by rewriting flax.training.checkpoints without that package, since I have no need to write to cloud storage or other filesystems. This flax PR has tackled the issue and provided a shim to remove the TF dependency, but it has not yet been merged into the main branch. We could maybe offer to contribute to it and help finish it for the next release.

PhilipVinc (Member, Author) commented Jun 15, 2022

Let's see if I can ~~harass them into submission to the NetKet empire~~ convince them to merge the PR...

PhilipVinc (Member, Author) commented

So, just by asking, the PR was approved and I guess in due time it will show up in a release.

So the next step for netket would be to actually make sure that everything works.

Most likely this means adding flax.serialization support to the drivers by doing something like this. Then it should work out of the box, except for the loggers. Those... might be a bit harder to get working.

In general these things work out nicely if you have an 'append by default' system that works. But most of our loggers don't support appending. Maybe we should fix that as well.

Also, there is a complication: how do we de-serialise the sampler state if one is running under MPI? We would need to serialise N_rank different states, but there's a lot of redundant information there...
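
As one rough option (not a decided design): write the shared parameters once from rank 0 and only the small per-rank sampler state separately. A sketch using mpi4py and flax.serialization, with illustrative file names:

from mpi4py import MPI
from flax import serialization

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def save_mpi_checkpoint(prefix, vstate):
    # Parameters are identical on every rank, so write them only once.
    if rank == 0:
        with open(f"{prefix}_params.mpack", "wb") as f:
            f.write(serialization.to_bytes(vstate.parameters))
    # The sampler state (chains, RNG key) differs per rank: one small file per rank.
    with open(f"{prefix}_sampler_{rank}.mpack", "wb") as f:
        f.write(serialization.to_bytes(vstate.sampler_state))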
