Skip to content

Commit

Permalink
Test branch for PR 274
Browse files Browse the repository at this point in the history
  • Loading branch information
trunk-io[bot] committed Aug 3, 2023
2 parents 503ef70 + 2dc3b8d commit b948969
Show file tree
Hide file tree
Showing 12 changed files with 411 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install dill "e3nn-jax<0.19.2" jaxlib jax-md jaxopt pytest matplotlib
pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest matplotlib
- name: Install pysages
run: pip install .
Expand Down Expand Up @@ -60,7 +60,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install dill "e3nn-jax<0.19.2" jaxlib jax-md jaxopt pytest pylint flake8
pip install dill jaxlib "jax-md>=0.2.7" jaxopt pytest pylint flake8
pip install -r docs/requirements.txt
- name: Install pysages
run: pip install .
Expand Down
3 changes: 2 additions & 1 deletion CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

- Pablo Zubieta
- Ludwig Schneider
- [Trung Nguyen](https://github.com/ndtrung81)

## Collective Variables

Expand Down Expand Up @@ -38,4 +39,4 @@ details. Specific contributions to this repository are listed below
## Other

For other contributions such as bugfixes or performance improvements, take a look at
https://github.com/SSAGESLabs/PySAGES/graphs/contributors
https://github.com/SSAGESLabs/PySAGES/graphs/contributors.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ RUN python -m pip install ase gsd matplotlib "pyparsing<3"

# Install JAX and JAX-MD
RUN python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN python -m pip install --upgrade "e3nn-jax<0.19.2" jax-md jaxopt
RUN python -m pip install --upgrade "jax-md>=0.2.7" jaxopt

COPY . /PySAGES
RUN pip install /PySAGES/
21 changes: 21 additions & 0 deletions examples/lammps/unbiased/lj.lmp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# 3d Lennard-Jones melt

units lj
atom_style atomic
atom_modify map yes

lattice fcc 0.8442
region box block 0 20 0 20 0 20
create_box 1 box
create_atoms 1 box
mass 1 1.0

velocity all create 1.44 87287 loop geom

pair_style lj/cut 2.5
pair_coeff 1 1 1.0 1.0 2.5

neighbor 0.3 bin
neigh_modify delay 5 every 1

fix 1 all nve
93 changes: 93 additions & 0 deletions examples/lammps/unbiased/unbiased.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#!/usr/bin/env python3

"""
Example unbiased simulation with pysages and lammps.
For a list of possible options for running the script pass `-h` as argument from the
command line, or call `get_args(["-h"])` if the module was loaded interactively.
"""

# %%
import argparse
import sys

from lammps import lammps

import pysages
from pysages.backends import SamplingContext
from pysages.colvars import Component
from pysages.methods import Unbiased


# %%
def generate_context(args="", script="lj.lmp", store_freq=1):
"""
Returns a lammps simulation defined by the contents of `script` using `args` as
initialization arguments.
"""
context = lammps(cmdargs=args.split())
context.file(script)
# Allow for the retrieval of the unwrapped positions
context.command(f"fix unwrap all store/state {store_freq} xu yu zu")
return context


def get_args(argv):
"""Process the command-line arguments to this script."""

available_args = [
("time-steps", "t", int, 1e2, "Number of simulation steps"),
("kokkos", "k", bool, True, "Whether to use Kokkos acceleration"),
]
parser = argparse.ArgumentParser(description="Example script to run pysages with lammps")

for name, short, T, val, doc in available_args:
if T is bool:
action = "store_" + str(val).lower()
parser.add_argument("--" + name, "-" + short, action=action, help=doc)
else:
convert = (lambda x: int(float(x))) if T is int else T
parser.add_argument("--" + name, "-" + short, type=convert, default=T(val), help=doc)

return parser.parse_args(argv)


def main(argv):
"""Example simulation with pysages and lammps."""
args = get_args(argv)

context_args = {"store_freq": args.time_steps}
if args.kokkos:
# Passed to the lammps constructor as `cmdargs` when running the script
# with the --kokkos (or -k) option
context_args["args"] = "-k on g 1 -sf kk -pk kokkos newton on neigh half"

# Setting the collective variable, method, and running the simulation
cvs = [Component([0], i) for i in range(3)]
method = Unbiased(cvs)
sampling_context = SamplingContext(method, generate_context, context_args=context_args)
result = pysages.run(sampling_context, args.time_steps)

# Post-run analysis
# -----------------
context = sampling_context.context
nlocal = sampling_context.sampler.view.local_particle_number()
snapshot = result.snapshots[0]
state = result.states[0]

# Retrieve the pointer to the unwrapped positions,
ptr = context.extract_fix("unwrap", 1, 2)
# and make them available as a numpy ndarray
positions = context.numpy.darray(ptr, nlocal, dim=3)
# Get the map to sort the atoms since they can be reordered during the simulation
ids = context.numpy.extract_atom("id").argsort()

# The ids of the final snapshot in pysages and lammps should be the same
assert (snapshot.ids == ids).all()
# For our example, the last value of the CV should match
# the unwrapped position of the zeroth atom
assert (state.xi.flatten() == positions[ids[0]]).all()


if __name__ == "__main__":
main(sys.argv[1:])
11 changes: 7 additions & 4 deletions pysages/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,12 @@ def __init__(
self._backend_name = "ase"
elif module_name.startswith("hoomd"):
self._backend_name = "hoomd"
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
self._backend_name = "openmm"
elif isinstance(context, JaxMDContext):
self._backend_name = "jax-md"
elif module_name.startswith("lammps"):
self._backend_name = "lammps"
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
self._backend_name = "openmm"

if self._backend_name is None:
backends = ", ".join(supported_backends())
Expand Down Expand Up @@ -113,14 +115,15 @@ def __enter__(self):
"""
if hasattr(self.context, "__enter__"):
return self.context.__enter__()
return self.context

def __exit__(self, exc_type, exc_value, exc_traceback):
"""
Trampoline 'with statements' to the wrapped context when the backend supports it.
"""
if hasattr(self.context, "__exit__"):
return self.context.__exit__(exc_type, exc_value, exc_traceback)
self.context.__exit__(exc_type, exc_value, exc_traceback)


def supported_backends():
return ("ase", "hoomd", "jax-md", "openmm")
return ("ase", "hoomd", "jax-md", "lammps", "openmm")
Loading

0 comments on commit b948969

Please sign in to comment.