Skip to content

Commit

Permalink
Add functions for saving and loading simulation results (#304)
Browse files Browse the repository at this point in the history
This adds two new functions, `pysages.save` and `pysages.load`. The former serializes a result and writes it to a file; the latter reads a file and deserializes its contents.

If a file was pickled before #292, `pysages.load` tries to work around the fact that the sampling method state classes did not yet contain an `ncalls` field.

Closes #298
  • Loading branch information
pabloferz committed Feb 28, 2024
1 parent ef26c2c commit d03cacd
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 35 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ FROM ssages/pysages-openmm
WORKDIR /

RUN python -m pip install --upgrade pip
RUN python -m pip install ase gsd matplotlib "pyparsing<3"
RUN python -m pip install ase dill gsd matplotlib "pyparsing<3"

# Install JAX and JAX-MD
# Install JAX and JAX-based libraries
RUN python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN python -m pip install --upgrade "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax-md>=0.2.7" jaxopt

Expand Down
11 changes: 3 additions & 8 deletions examples/hoomd3/restart/restart.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#!/usr/bin/env python3

import pickle

import hoomd
import hoomd.dlext
import hoomd.md
Expand Down Expand Up @@ -114,15 +112,12 @@ def main():
]
plot(hist_list, target_hist, (-Lmax / 2, Lmax / 2), 2)

# Dump the pickle file for restart. This is the standard way to
# save a system's information to perform a restart in a new run.
with open("restart.pickle", "wb") as f:
pickle.dump(state, f)
# Save the system's information to perform a restart in a new run.
pysages.save(state, "restart.pkl")

# Load the restart file. This is how to run a pysages run from a
# previously stored state.
with open("restart.pickle", "rb") as f:
state = pickle.load(f)
state = pysages.load("restart.pkl")

# When restarting, run the system using the same generate_context function!
state = pysages.run(state, generate_context, int(1e4))
Expand Down
9 changes: 4 additions & 5 deletions examples/openmm/string/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import argparse
import importlib
import os
import pickle
import shutil
import sys

Expand Down Expand Up @@ -81,7 +80,7 @@ def get_args(argv):
parser = argparse.ArgumentParser(
description="Example script to run the spline (improved) string method"
)
for (name, short, T, val, doc) in available_args:
for name, short, T, val, doc in available_args:
parser.add_argument("--" + name, "-" + short, type=T, default=T(val), help=doc)
parser.add_argument("--mpi", action="store_true", help="Use MPI executor")
args = parser.parse_args(argv)
Expand Down Expand Up @@ -121,11 +120,12 @@ def plot_energy(result):
s = np.linspace(0, 1, len(result["point_convergence"]))
free_energy = np.asarray(result["free_energy"])
offset = np.min(free_energy)
ax.plot(s, free_energy - offset, "o-", color="teal")

ax.plot(s, free_energy - offset, "o-", color="teal")
ax2.plot(s, result["point_convergence"], color="maroon")

fig.savefig("energy.pdf", transparent=True, bbox_inches="tight", pad_inches=0)
return fig


def plot_path(result):
Expand Down Expand Up @@ -178,8 +178,7 @@ def main(argv):
plot_path(result)
plot_energy(result)

with open("result.pkl", "wb") as file_handle:
pickle.dump(result, file_handle)
pysages.save(result, "result.pkl")


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[build-system]
requires = [
"setuptools>=41.0",
"wheel>=0.33",
"numpy>=1.16",
"cython>=0.29",
"numpy>=1.16",
"setuptools>=41.0",
"setuptools_scm[toml]>=6.0",
"wheel>=0.33",
]

[tool.pytest.ini_options]
Expand Down
25 changes: 11 additions & 14 deletions pysages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES

# flake8: noqa F401
# flake8: noqa E402,F401

"""
PySAGES: Python Suite for Advanced General Ensemble Simulations
Expand Down Expand Up @@ -58,18 +58,15 @@ def _config_jax():
_set_cuda_visible_devices()
_config_jax()


from . import backends, colvars, methods # noqa: E402, F401
from ._version import version as __version__ # noqa: E402, F401
from ._version import version_tuple as __version_tuple__ # noqa: E402, F401
from .backends import supported_backends # noqa: E402, F401
from .grids import Chebyshev, Grid # noqa: E402, F401
from .methods import ( # noqa: E402, F401
CVRestraints,
ReplicasConfiguration,
SerialExecutor,
)
from .utils import dispatch, dispatch_table # noqa: E402, F401
# pylint: disable=C0413
from . import backends, colvars, methods
from ._version import version as __version__
from ._version import version_tuple as __version_tuple__
from .backends import supported_backends
from .grids import Chebyshev, Grid
from .methods import CVRestraints, ReplicasConfiguration, SerialExecutor
from .serialization import load, save
from .utils import dispatch, dispatch_table

run = dispatch_table(dispatch)["run"]
analyze = dispatch_table(dispatch)["analyze"]
Expand All @@ -81,4 +78,4 @@ def _config_jax():
del os
del _config_jax
del _set_cuda_visible_devices
del _version
del _version # pylint: disable=E0602
132 changes: 132 additions & 0 deletions pysages/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES

"""
Utilities for saving and loading the results of `pysages` simulations.
This module provides two functions for managing the persistent storage of a `pysages`
simulation's state using pickle serialization (via the `dill` library).
* `load(filename)`: Attempts to load the simulation state from a file and return
the corresponding `Result` object.
* `save(result, filename)`: Saves the given `Result` object to a file.
**Note:**
These functions assume pickle's `DEFAULT_PROTOCOL` and data format. Use them with caution
if modifications have been made to the saved data structures.
"""

import dill as pickle

from pysages.methods import Metadynamics
from pysages.methods.core import GriddedSamplingMethod, Result
from pysages.typing import Callable
from pysages.utils import dispatch, identity


def load(filename) -> Result:
    """
    Load the state of a previously run `pysages` simulation from a file.

    The whole file is read into memory and deserialized with `dill`.

    Parameters
    ----------
    filename: str
        The name of the file containing the pickled data.

    Returns
    -------
    Result
        The deserialized object stored in the file.

    Raises
    ------
    TypeError
        If deserialization fails for a reason unrelated to a missing `ncalls`
        argument, or if the file layout is not recognized and cannot be patched.

    Notes
    -----
    This function attempts to recover from deserialization errors related to missing
    `ncalls` attributes: a zero `ncalls` entry is injected into each pickled state,
    the data is deserialized again, and `ncalls` is then replaced by a
    method-specific estimate (see `_ncalls_estimator`).

    Unpickling can execute arbitrary code. Only load files from trusted sources.
    """
    with open(filename, "rb") as io:
        bytestring = io.read()

    try:
        return pickle.loads(bytestring)

    except TypeError as e:
        if "ncalls" not in getattr(e, "message", repr(e)):
            raise

        patched = _insert_default_ncalls(bytestring)
        if patched is None:
            # The expected layout was not found, so there is nothing we can safely
            # patch. Surface the original error instead of corrupting the stream.
            raise

        # Try to deserialize again
        result = pickle.loads(patched)

        # Update results with `ncalls` estimates for each state
        update = _ncalls_estimator(result.method)
        result.states = [update(state) for state in result.states]

        return result


def _insert_default_ncalls(bytestring):
    """
    Return a copy of `bytestring` with a zero `ncalls` value added to each state.

    We know that states precede callbacks, so every occurrence of the byte sequence
    that closes a state's tuple of values, located between the `states` and
    `callbacks` markers, gets `b"K\\x00"` (the pickle opcode pushing the integer 0)
    inserted right before it.

    Returns `None` when either marker is missing or no state boundary is found,
    that is, when there is nothing that can be patched.
    """
    start = bytestring.find(b"\x8c\x06states\x94")
    stop = bytestring.find(b"\x8c\tcallbacks\x94")
    if start == -1 or stop == -1:
        return None

    boundary = b"t\x94\x81\x94"

    chunks = []
    tail = 0
    hit = bytestring.find(boundary, start + 1, stop)
    while hit != -1:
        # Copy everything up to the boundary, then set `ncalls` as zero;
        # the caller adjusts it later.
        chunks.append(bytestring[tail:hit])
        chunks.append(b"K\x00")
        tail = hit
        hit = bytestring.find(boundary, tail + 1, stop)

    if not chunks:
        return None

    chunks.append(bytestring[tail:])
    return b"".join(chunks)


def save(result: Result, filename) -> None:
    """
    Save the result of a `pysages` simulation to a file.

    The given object is serialized with pickle (via the `dill` library) and the
    resulting bytes are written to `filename`, replacing any previous content.

    Parameters
    ----------
    result: Result
        The `Result` object to be saved.
    filename: str
        The name of the file to save the data to.
    """
    # Serialize before opening the file: opening with "wb" truncates it, so if
    # pickling failed afterwards an existing file (e.g. a previous restart file)
    # would be lost or left partially written.
    payload = pickle.dumps(result)

    with open(filename, "wb") as io:
        io.write(payload)


@dispatch
def _ncalls_estimator(_) -> Callable:
    """
    Fallback `ncalls` estimator for methods without a more specific overload.

    Returns `identity`, so states are passed through unchanged and `ncalls` is
    left as zero.
    """
    return identity


@dispatch
def _ncalls_estimator(_: Metadynamics) -> Callable:
    """
    `ncalls` estimator for `Metadynamics`.

    Returns a callable that maps a state to `state._replace(ncalls=state.idx)`,
    using the number of gaussians deposited as a proxy for `ncalls`.
    """

    def set_ncalls_from_idx(state):
        return state._replace(ncalls=state.idx)

    return set_ncalls_from_idx


@dispatch
def _ncalls_estimator(_: GriddedSamplingMethod) -> Callable:
    """
    `ncalls` estimator for `GriddedSamplingMethod`.

    Returns a callable that maps a state to a new state (via `_replace`) whose
    `ncalls` is the histogram's total count, used as a proxy for `ncalls`.
    """

    def set_ncalls_from_hist(state):
        total_count = state.hist.sum().item()
        return state._replace(ncalls=total_count)

    return set_ncalls_from_hist
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ license = MIT/GPL-3.0
packages = find:
python_requires = >=3.6
install_requires =
jax >=0.3.5
cython
dill
jax >=0.3.5
plum-dispatch >=1.5.4, !=2.0.0, !=2.0.1
numba

Expand Down
16 changes: 14 additions & 2 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
import inspect
import pathlib
import tempfile

import dill as pickle
Expand Down Expand Up @@ -168,9 +169,9 @@ def test_pickle_colvars():


def test_pickle_results():
with tempfile.NamedTemporaryFile() as tmp_pickle:
test_result = abf_example.run_simulation(10, write_output=False)
test_result = abf_example.run_simulation(10, write_output=False)

with tempfile.NamedTemporaryFile() as tmp_pickle:
pickle.dump(test_result, tmp_pickle)
tmp_pickle.flush()

Expand All @@ -180,3 +181,14 @@ def test_pickle_results():
assert np.all(test_result.states[0].bias == tmp_result.states[0].bias).item()
assert np.all(test_result.states[0].hist == tmp_result.states[0].hist).item()
assert np.all(test_result.states[0].Fsum == tmp_result.states[0].Fsum).item()

tmp_file = pathlib.Path(".tmp_test_pickle")
pysages.save(test_result, tmp_file)
tmp_result = pysages.load(tmp_file.name)

assert np.all(test_result.states[0].xi == tmp_result.states[0].xi).item()
assert np.all(test_result.states[0].bias == tmp_result.states[0].bias).item()
assert np.all(test_result.states[0].hist == tmp_result.states[0].hist).item()
assert np.all(test_result.states[0].Fsum == tmp_result.states[0].Fsum).item()

tmp_file.unlink()

0 comments on commit d03cacd

Please sign in to comment.