Skip to content

Commit

Permalink
Compat option for jax devices (#320)
Browse files Browse the repository at this point in the history
Closes #318.
  • Loading branch information
pabloferz committed May 25, 2024
1 parent 2da7c79 commit fbe906d
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install python dependecies
run: |
python -m pip install --upgrade pip
pip install ase dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax<=0.4.15" "jaxlib<=0.4.15" "jax-md>=0.2.7" jaxopt pytest matplotlib
pip install ase dill "dm-haiku<0.0.11" "e3nn-jax!=0.20.4" "jax<=0.4.15" "jaxlib<=0.4.15" "jax-md>=0.2.7" jaxopt pytest matplotlib "scipy<1.13"
- name: Install pysages
run: pip install .
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
furo
jaxlib
myst-parser
scipy<1.13
setuptools-scm
sphinx
sphinx-copybutton
Expand Down
12 changes: 10 additions & 2 deletions pysages/methods/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@
from pysages.methods.restraints import canonicalize
from pysages.methods.utils import ReplicasConfiguration
from pysages.typing import Callable, Optional, Union
from pysages.utils import ToCPU, copy, dispatch, dispatch_table, has_method, identity
from pysages.utils import (
ToCPU,
copy,
device_platform,
dispatch,
dispatch_table,
has_method,
identity,
)

# Base Classes
# ============
Expand Down Expand Up @@ -392,7 +400,7 @@ def _run( # noqa: F811 # pylint: disable=C0116,E0102
context_args["context"] = sampling_context.context
sampler = sampling_context.sampler
prev_snapshot = result.snapshots
if sampler.state.xi.device().platform == "cpu":
if device_platform(sampler.state.xi) == "cpu":
prev_snapshot = copy(prev_snapshot, ToCPU())
sampler.restore(prev_snapshot)
sampler.state = result.states
Expand Down
1 change: 1 addition & 0 deletions pysages/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .compat import (
check_device_array,
device_platform,
dispatch_table,
has_method,
is_generic_subclass,
Expand Down
15 changes: 15 additions & 0 deletions pysages/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ def prod(iterable, start=1):
return result


# Compatibility for jax >=0.4.27

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0427-may-7-2024

if _jax_version_tuple < (0, 4, 27):

def device_platform(array):
return array.device().platform

else:

def device_platform(array):
return next(iter(array.devices())).platform


# Compatibility for jax >=0.4.1

# https://github.com/google/jax/releases/tag/jax-v0.4.1
Expand Down

0 comments on commit fbe906d

Please sign in to comment.