diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68ceb6c5..31829821 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - name: Install python dependecies run: | python -m pip install --upgrade pip - pip install ase dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax<=0.4.15" "jaxlib<=0.4.15" "jax-md>=0.2.7" jaxopt pytest matplotlib + pip install ase dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax<=0.4.15" "jaxlib<=0.4.15" "jax-md>=0.2.7" jaxopt pytest matplotlib "scipy<1.13" - name: Install pysages run: pip install . diff --git a/docs/requirements.txt b/docs/requirements.txt index 2c94661a..8c2c8a9b 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,7 @@ furo jaxlib myst-parser +scipy<1.13 setuptools-scm sphinx sphinx-copybutton diff --git a/pysages/methods/core.py b/pysages/methods/core.py index 6c10fd68..13e004d1 100644 --- a/pysages/methods/core.py +++ b/pysages/methods/core.py @@ -16,7 +16,15 @@ from pysages.methods.restraints import canonicalize from pysages.methods.utils import ReplicasConfiguration from pysages.typing import Callable, Optional, Union -from pysages.utils import ToCPU, copy, dispatch, dispatch_table, has_method, identity +from pysages.utils import ( + ToCPU, + copy, + device_platform, + dispatch, + dispatch_table, + has_method, + identity, +) # Base Classes # ============ @@ -392,7 +400,7 @@ def _run( # noqa: F811 # pylint: disable=C0116,E0102 context_args["context"] = sampling_context.context sampler = sampling_context.sampler prev_snapshot = result.snapshots - if sampler.state.xi.device().platform == "cpu": + if device_platform(sampler.state.xi) == "cpu": prev_snapshot = copy(prev_snapshot, ToCPU()) sampler.restore(prev_snapshot) sampler.state = result.states diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 00279a4c..3351d48e 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -10,6 +10,7 @@ from .compat import ( check_device_array, + device_platform, dispatch_table, has_method, is_generic_subclass, diff --git a/pysages/utils/compat.py b/pysages/utils/compat.py index d72094e0..1544d42a 100644 --- a/pysages/utils/compat.py +++ b/pysages/utils/compat.py @@ -36,6 +36,21 @@ def prod(iterable, start=1): return result +# Compatibility for jax >=0.4.27 + +# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0427-may-7-2024 + +if _jax_version_tuple < (0, 4, 27): + + def device_platform(array): + return array.device().platform + +else: + + def device_platform(array): + return next(iter(array.devices())).platform + + # Compatibility for jax >=0.4.1 # https://github.com/google/jax/releases/tag/jax-v0.4.1